├── Figures
├── bgMLP.png
└── overview_latest.png
├── iid-dictusers
├── ICH_1037_55000.npy
└── ChestXray14_1037_85000.npy
├── model
├── build_model.py
├── efficientnet.py
└── all_models.py
├── preprocess
├── split_train_test.py
├── count.py
├── count_pwise_disease.py
├── count_mean_dev.py
├── label_rectify.py
└── ICH_process.py
├── utils
├── feature_visual.py
├── valloss_cal.py
├── sampling.py
├── utils.py
├── multilabel_metrixs.py
├── FedAvg.py
├── options.py
├── FedNoRo.py
├── evaluations.py
├── FixMatch.py
└── local_training.py
├── requirements.txt
├── README.md
├── dataset
├── all_dataset.py
└── dataset.py
└── main.py
/Figures/bgMLP.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/szbonaldo/FedMLP/HEAD/Figures/bgMLP.png
--------------------------------------------------------------------------------
/Figures/overview_latest.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/szbonaldo/FedMLP/HEAD/Figures/overview_latest.png
--------------------------------------------------------------------------------
/iid-dictusers/ICH_1037_55000.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/szbonaldo/FedMLP/HEAD/iid-dictusers/ICH_1037_55000.npy
--------------------------------------------------------------------------------
/iid-dictusers/ChestXray14_1037_85000.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/szbonaldo/FedMLP/HEAD/iid-dictusers/ChestXray14_1037_85000.npy
--------------------------------------------------------------------------------
/model/build_model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from .all_models import get_model, modify_last_layer
3 |
4 |
def build_model(args):
    """Instantiate the backbone named by ``args.model`` and resize its head.

    The classifier layer is replaced so the network emits ``args.n_classes``
    logits; ``args.pretrained`` controls whether pretrained weights load.
    """
    backbone = get_model(args.model, args.pretrained)
    backbone, _ = modify_last_layer(args.model, backbone, args.n_classes)
    return backbone
--------------------------------------------------------------------------------
/preprocess/split_train_test.py:
--------------------------------------------------------------------------------
import random
import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import train_test_split

# ChestXray14 / ICH: split the preprocessed label csv into train/test csvs.
# Seed every RNG so the run is reproducible.
seed = 2023
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)

data = pd.read_csv(r'E:\ICH_stage2\ICH_stage2\data_png185k_512.csv')

train_ratio = 0.7
test_ratio = 0.3

# Fix: use test_ratio directly (it was defined but unused) and pin
# random_state so the split no longer depends on global NumPy RNG state.
train_data, test_data = train_test_split(data, test_size=test_ratio, random_state=seed)

print(f"训练集大小: {len(train_data)}")
print(f"测试集大小: {len(test_data)}")

train_data.to_csv('train_dataset.csv', index=False)
test_data.to_csv('test_dataset.csv', index=False)
--------------------------------------------------------------------------------
/preprocess/count.py:
--------------------------------------------------------------------------------
1 | import csv
2 | # ChestXray14
3 | import pandas as pd
4 | import os
5 |
6 |
def file_name(file_dir):
    """Return one list of file names per directory visited by ``os.walk``.

    Element 0 corresponds to ``file_dir`` itself; deeper directories follow
    in walk order.
    """
    return [names for _root, _dirs, names in os.walk(file_dir)]
12 |
# Hard-coded source folders for NIH ChestXray14 image shards 004 and 005.
image_4_path = r'E:\szb\Hust\2022~2023\lab\博一\images_004.tar\images_004\images'
image_5_path = 'E:/szb/Hust/2022~2023/lab/博一/images_005.tar/images_005/images'
# file_name returns per-directory name lists; [0] keeps the top level only.
file_5_list = file_name(image_5_path)[0]
# NOTE(review): file_5_list is computed but never used below — confirm
# whether a 'filtered_data_5.csv' pass was also intended.
file_4_list = file_name(image_4_path)[0]
csv_path = '../Data_Entry_2017_v2020.csv'

# Keep only label rows whose image file actually exists in shard 004.
with open(csv_path, 'r') as csv_file:
    csv_reader = csv.reader(csv_file)
    with open('filtered_data_4.csv', 'w', newline='') as new_csv_file:
        csv_writer = csv.writer(new_csv_file)
        for row in csv_reader:
            if row[0] in file_4_list:
                csv_writer.writerow(row)
26 |
--------------------------------------------------------------------------------
/preprocess/count_pwise_disease.py:
--------------------------------------------------------------------------------
1 | import csv
2 | # ChestXray14
3 | import numpy as np
4 | import pandas as pd
5 |
# Count, per patient, which of the 14 diseases ever appear across that
# patient's images, then print the per-patient disease frequency.
csv_path = '../onehot-label.csv'
count = 1
pre_patient = ''
total_disease = [0]*14  # per-class count of patients carrying the disease
new_disease = [0]*14    # OR-accumulated labels of the current patient
patients = 0
with open(csv_path, 'r') as csv_file:
    csv_reader = csv.reader(csv_file)
    for row in csv_reader:
        if count == 1:
            # first row is the csv header — skip it
            count += 1
            print(count)
        else:
            count += 1
            print(count)
            # image ids look like '<patient>_<...>'; a new prefix starts a
            # new patient
            if row[0].split('_')[0] != pre_patient:
                patients += 1
                pre_patient = row[0].split('_')[0]
                # flush the previous patient's accumulated labels
                total_disease = [i + j for i, j in zip(total_disease, new_disease)]
                new_disease = list(map(int, row[1:]))
            else:
                # same patient: OR in the labels of this additional image
                new_disease = [int(new_disease[i]) or int(row[i+1]) for i in range(len(new_disease))]
# flush the final patient (the loop only flushes on a patient change)
# NOTE(review): this line is assumed to sit outside the loop — confirm
# against the original indentation.
total_disease = [i + j for i, j in zip(total_disease, new_disease)]
print(np.array(total_disease)/patients)
30 |
31 |
--------------------------------------------------------------------------------
/preprocess/count_mean_dev.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision.transforms import transforms
3 | # ChestXray14
4 | from dataset.all_dataset import ChestXray14
5 |
6 |
def getStat(train_data):
    """Estimate per-channel mean and std of an image dataset.

    Iterates the dataset one image at a time and averages per-image
    channel statistics (an approximation of the global statistics).

    :param train_data: dataset yielding ``(image, label)`` pairs with
        images shaped ``(3, H, W)``
    :return: ``(mean, std)`` — two lists of three floats
    """
    print('Compute mean and variance for training data.')
    print(len(train_data))
    loader = torch.utils.data.DataLoader(
        train_data, batch_size=1, shuffle=False, num_workers=0,
        pin_memory=True)
    channel_mean = torch.zeros(3)
    channel_std = torch.zeros(3)
    for batch, _ in loader:
        for ch in range(3):
            channel_mean[ch] += batch[:, ch, :, :].mean()
            channel_std[ch] += batch[:, ch, :, :].std()
    channel_mean.div_(len(train_data))
    channel_std.div_(len(train_data))
    return list(channel_mean.numpy()), list(channel_std.numpy())
27 |
28 |
if __name__ == '__main__':
    # Resize to the training resolution before computing the statistics.
    preprocessing = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    dataset = ChestXray14('/home/szb/multilabel', "train", preprocessing)
    print(getStat(dataset))
35 |
--------------------------------------------------------------------------------
/preprocess/label_rectify.py:
--------------------------------------------------------------------------------
1 | import csv
2 | # ChestXray14
3 | import numpy as np
4 | import pandas as pd
5 |
# df = pd.read_csv('Data_Entry_2017_v2020.csv')
# print(df)
# Convert the NIH label csv into a one-hot csv, keeping only PA-view images.
csv_path = './Data_Entry_2017_v2020.csv'
count = 1
title = ['Image Index', 'Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Effusion', 'Emphysema', 'Fibrosis', 'Hernia', 'Infiltration', 'Mass', 'Nodule', 'Pleural_Thickening', 'Pneumonia', 'Pneumothorax']
with open(csv_path, 'r') as csv_file:
    csv_reader = csv.reader(csv_file)

    # NOTE(review): the writer block is nested so the reader stays open
    # while iterating — confirm against the original indentation.
    with open('../onehot-label-PA.csv', 'w', newline='') as new_csv_file:
        csv_writer = csv.writer(new_csv_file)
        for row in csv_reader:
            if count == 1:
                # first input row is the header; emit our one-hot header
                csv_writer.writerow(title)
                count += 1
                print(count)
            else:
                count += 1
                print(count)
                # column 6 is the view position; keep PA images only
                if row[6] == 'PA':
                    label_row = [row[0]] + [0]*14
                    label = row[1]  # str, '|'-separated finding names
                    if label == 'No Finding':
                        csv_writer.writerow(label_row)
                    else:
                        # set a 1 in each named finding's column
                        label_list = label.split('|')
                        for i in label_list:
                            label_row[title.index(i)] = 1
                        csv_writer.writerow(label_row)
                else:
                    pass
37 |
--------------------------------------------------------------------------------
/utils/feature_visual.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import torch
4 | from numpy import where
5 | from sklearn.manifold import TSNE
6 |
7 |
# Per-class scatter colours for the t-SNE plots (supports up to 8 classes).
color_map = ['r', 'y', 'k', 'g', 'b', 'm', 'c', 'peru']
# color_map = ['r', 'g']
10 |
11 |
def plot_embedding_2D(data, label, title, rnd):
    """Scatter-plot 2-D embedded points, one colour per class.

    :param data: (n_samples, 2) array of embedded coordinates
    :param label: per-sample integer class labels
    :param title: figure title (also used in the output file name)
    :param rnd: communication round, used to tag the saved figure
    :return: the matplotlib figure
    """
    # Normalise coordinates to [0, 1] so axes are comparable across rounds.
    x_min, x_max = np.min(data, 0), np.max(data, 0)
    data = (data - x_min) / (x_max - x_min)
    fig = plt.figure()
    for i in range(len(np.unique(label))):
        # renamed from `list`, which shadowed the builtin
        class_points = data[label == i]
        plt.scatter(class_points[:, 0], class_points[:, 1], marker='o', s=1, color=color_map[i], label='class:{}'.format(i))
    plt.legend()
    plt.xticks([])
    plt.yticks([])
    plt.title(title)
    plt.savefig('proto_fig/' + f'rnd:{rnd}' + title + '.png')
    # plt.show()
    plt.clf()
    return fig
27 |
28 |
def tnse_Visual(data, label, rnd, title):
    """Project features to 2-D with t-SNE and save a scatter plot."""
    n_samples, n_features = data.shape

    print('Begining......')
    reducer = TSNE(n_components=2, init='pca', random_state=0, perplexity=5)
    embedded = reducer.fit_transform(data)
    print('Finished......')

    # Render the 2-D embedding with matplotlib.
    plot_embedding_2D(embedded, label, title, rnd)
39 |
40 |
if __name__ == '__main__':
    # Smoke test on random features and binary labels.
    labels = torch.randint(high=2, low=0, size=(1000, ))
    print(labels)
    features = torch.rand(1000, 512)
    print(features.shape)
    tnse_Visual(features, labels, 1, '666')
47 |
48 |
--------------------------------------------------------------------------------
/preprocess/ICH_process.py:
--------------------------------------------------------------------------------
1 | import csv
2 | from collections import Counter
3 |
4 | import pandas as pd
5 | import numpy as np
6 |
7 | import os
8 | from tqdm import tqdm
# Build a per-image multi-label csv for the RSNA ICH dataset.
data = pd.read_csv(r"E:\ICH_stage2\ICH_stage2\stage_2_train.csv")
data = np.array(data)
# The csv holds 6 rows per image (one per label); only 5 are kept below
# (the 6th is presumably the aggregate 'any' label — confirm).
patient_num = len(data) // 6
ID_all = []
label_all = []
label_add = [0]*5
for i in range(patient_num):
    # row ids look like '<id>_epidural', '<id>_intraparenchymal', ...
    ID = data[6*i][0].split('_epidural')[0]
    label_add = [data[6*i][1], data[6*i+1][1], data[6*i+2][1], data[6*i+3][1], data[6*i+4][1]]
    ID_all.append(ID)
    label_all.append(label_add)

# Keep only ids whose converted png actually exists on disk.
ID_have = []
label_have = []
for i in tqdm(range(patient_num)):
    name = ID_all[i] + ".png"
    path = os.path.join(r"E:\ICH_stage2\ICH_stage2\png185k_512", name)
    if os.path.exists(path):
        ID_have.append(name)
        label_have.append(label_all[i])


csv_path = r'E:\ICH_stage2\ICH_stage2\data_png185k_512.csv'
count = 1
title = ['Image Index', 'epidural', 'intraparenchymal', 'intraventricular', 'subarachnoid', 'subdural']
with open(csv_path, 'w', newline='') as new_csv_file:
    csv_writer = csv.writer(new_csv_file)
    # len+1 iterations: iteration 0 writes the header, iteration i writes
    # sample i-1.
    for i in tqdm(range(len(ID_have)+1)):
        if count == 1:
            csv_writer.writerow(title)
            count += 1
        else:
            count += 1
            csv_writer.writerow([ID_have[i-1]] + label_have[i-1])
label_have_total = np.sum(label_have, axis=0)  # per-class positive counts
label_have_class = np.sum(label_have, axis=1)  # positive labels per sample
print(label_have_total) # [2761 32564 23766 32122 42496]
print(Counter(label_have_class)) # Counter({0: 87948, 1: 67969, 2: 22587, 3: 5642, 4: 885, 5: 20})
47 |
--------------------------------------------------------------------------------
/utils/valloss_cal.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import numpy as np
4 | import torch
5 | from torch import nn
6 | from torch.utils.data import DataLoader, SubsetRandomSampler
7 |
8 |
def get_num_of_each_class(args, num_samples_to_split, test_dataset):
    """Sum the multi-hot targets of the first ``num_samples_to_split`` samples.

    :return: per-class positive counts as a plain Python list of floats
    """
    counts = np.zeros(args.n_classes)
    for sample_idx in range(num_samples_to_split):
        counts = counts + test_dataset.targets[sample_idx]
    return counts.tolist()
14 |
def valloss(net, test_dataset, args):
    """Mean class-weighted BCE loss of ``net`` on a slice of the test set.

    A ``SubsetRandomSampler`` draws from the first 10% of ``test_dataset``;
    each class's BCE ``pos_weight`` is (subset size / positive count) to
    counter class imbalance.  Returns the mean loss over batches.
    """
    net.eval()
    # NOTE(review): this loader is built only to reach ``test_loader.dataset``
    # below — it is never iterated itself.
    test_loader = DataLoader(dataset=test_dataset, batch_size=args.batch_size*4, shuffle=False, num_workers=4)
    split_ratio = 0.1

    num_samples = len(test_loader.dataset)
    num_samples_to_split = int(num_samples * split_ratio)

    # sample (in random order) from the first 10% of the dataset only
    random_sampler = SubsetRandomSampler(range(num_samples_to_split))

    val_data_loader = DataLoader(
        dataset=test_loader.dataset,
        batch_size=args.batch_size*4,
        sampler=random_sampler
    )
    class_num_list = get_num_of_each_class(args, num_samples_to_split, test_dataset)
    # inverse-frequency positive weights; assumes every class has >0 positives
    loss_w = [num_samples_to_split / i for i in class_num_list]
    print(class_num_list, num_samples_to_split, loss_w)
    bce_criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(loss_w).cuda()) # include sigmoid
    batch_loss = []
    with torch.no_grad():
        for samples in val_data_loader:
            images, labels = samples["image"].to(args.device), samples["target"].to(args.device)
            # net appears to return a (features, logits) pair; only the
            # logits are scored — confirm against the model definition
            _, outputs = net(images)
            val_loss = bce_criterion(outputs, labels)
            batch_loss.append(val_loss.item())
    val_loss_mean = np.array(batch_loss).mean()
    logging.info(val_loss_mean)
    return val_loss_mean
44 |
45 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==1.2.0
2 | astunparse==1.6.3
3 | attrs==22.1.0
4 | bleach==6.1.0
5 | bytecode==0.15.1
6 | cachetools==5.2.0
7 | certifi==2024.2.2
8 | charset-normalizer==2.1.1
9 | conda-pack==0.6.0
10 | contextlib2==21.6.0
11 | contourpy==1.0.5
12 | cupy @ file:///opt/conda/conda-bld/cupy_1610065936122/work
13 | cxxfilt==0.3.0
14 | cycler==0.11.0
15 | efficientnet-pytorch==0.7.1
16 | einops==0.5.0
17 | fastrlock==0.8.1
18 | filelock==3.8.0
19 | flatbuffers==22.12.6
20 | fonttools==4.37.4
21 | gast==0.4.0
22 | google-auth==2.12.0
23 | google-auth-oauthlib==0.4.6
24 | google-pasta==0.2.0
25 | grpcio==1.49.1
26 | h5py==3.7.0
27 | hausdorff==0.2.6
28 | huggingface-hub==0.10.0
29 | idna==3.4
30 | imageio==2.22.1
31 | imbalanced-learn==0.11.0
32 | imblearn==0.0
33 | importlib-metadata==5.0.0
34 | iniconfig==1.1.1
35 | joblib==1.2.0
36 | kaggle==1.6.14
37 | keras==2.11.0
38 | kiwisolver==1.4.4
39 | libclang==14.0.6
40 | llvmlite==0.39.1
41 | Markdown==3.4.1
42 | MarkupSafe==2.1.1
43 | matplotlib==3.6.1
44 | MedPy==0.4.0
45 | mkl-fft==1.3.1
46 | mkl-random @ file:///tmp/build/80754af9/mkl_random_1626186066731/work
47 | mkl-service==2.4.0
48 | ml-collections==0.1.1
49 | monai==1.1.0
50 | munch==4.0.0
51 | networkx==2.8.7
52 | nibabel==4.0.2
53 | numba==0.56.2
54 | numpy @ file:///croot/numpy_and_numpy_base_1672336185480/work
55 | oauthlib==3.2.1
56 | opencv-python==4.4.0.46
57 | opt-einsum==3.3.0
58 | packaging==21.3
59 | pandas==1.5.0
60 | Pillow==9.2.0
61 | pluggy==1.0.0
62 | pretrainedmodels==0.7.4
63 | protobuf==3.19.6
64 | py==1.11.0
65 | pyasn1==0.4.8
66 | pyasn1-modules==0.2.8
67 | pyparsing==3.0.9
68 | pytest==7.1.3
69 | python-dateutil==2.8.2
70 | python-slugify==8.0.4
71 | pytz==2022.4
72 | PyWavelets==1.4.1
73 | PyYAML==6.0
74 | requests==2.28.1
75 | requests-oauthlib==1.3.1
76 | rsa==4.9
77 | scikit-image==0.19.3
78 | scikit-learn==1.1.3
79 | scipy==1.9.2
80 | seaborn==0.12.0
81 | setproctitle==1.3.2
82 | SimpleITK==2.2.0
83 | six @ file:///tmp/build/80754af9/six_1644875935023/work
84 | sklearn==0.0.post1
85 | swin-window-process==0.0.0
86 | tb-nightly==2.12.0a20221114
87 | tensorboard-data-server==0.6.1
88 | tensorboard-plugin-wit==1.8.1
89 | tensorboardX==2.5.1
90 | tensorflow==2.11.0
91 | tensorflow-estimator==2.11.0
92 | tensorflow-io-gcs-filesystem==0.28.0
93 | termcolor==1.1.0
94 | text-unidecode==1.3
95 | thop==0.1.1.post2209072238
96 | threadpoolctl==3.1.0
97 | tifffile==2022.10.10
98 | timm==0.4.12
99 | tomli==2.0.1
100 | torch==1.12.1+cu116
101 | torchaudio==0.12.1+cu116
102 | torchvision==0.13.1+cu116
103 | tqdm==4.64.1
104 | typing_extensions==4.4.0
105 | urllib3==1.22
106 | visualizerX==0.0.1
107 | vit-pytorch==0.35.8
108 | webencodings==0.5.1
109 | Werkzeug==2.2.2
110 | wrapt==1.14.1
111 | xarray==2022.11.0
112 | yacs==0.1.8
113 | zipp==3.9.0
114 |
--------------------------------------------------------------------------------
/utils/sampling.py:
--------------------------------------------------------------------------------
1 | # This file is borrowed from https://github.com/Xu-Jingyi/FedCorr/blob/main/util/sampling.py
2 |
3 | import numpy as np
4 |
5 |
def iid_sampling(n_train, num_users, seed):
    """Partition ``n_train`` sample indices uniformly across ``num_users``.

    Each client receives ``n_train // num_users`` indices drawn without
    replacement; leftover indices (when not divisible) are assigned to
    nobody.  Returns {client_id: list of indices}.
    """
    np.random.seed(seed)
    shard_size = int(n_train / num_users)
    remaining = [i for i in range(n_train)]
    picks = {}
    for uid in range(num_users):
        # draw this client's shard without replacement, then remove it
        # from the candidate pool
        picks[uid] = set(np.random.choice(remaining, shard_size, replace=False))
        remaining = list(set(remaining) - picks[uid])
    return {uid: list(chosen) for uid, chosen in picks.items()}
18 |
19 |
def non_iid_dirichlet_sampling(y_train, num_classes, p, num_users, seed, alpha_dirichlet):
    """Dirichlet-based non-IID partition of multi-label samples across clients.

    Samples are grouped by *how many* positive labels they carry
    (0..num_classes) and each group is split across clients with
    Dirichlet-distributed proportions.  Returns {client_id: list of indices}.
    """
    np.random.seed(seed)
    Phi = np.random.binomial(1, p, size=(num_users, num_classes))  # indicate the classes chosen by each client
    n_classes_per_client = np.sum(Phi, axis=1)
    # re-draw until every client has at least one class
    while np.min(n_classes_per_client) == 0:
        invalid_idx = np.where(n_classes_per_client == 0)[0]
        Phi[invalid_idx] = np.random.binomial(1, p, size=(len(invalid_idx), num_classes))
        n_classes_per_client = np.sum(Phi, axis=1)
    Psi = [list(np.where(Phi[:, j] == 1)[0]) for j in range(num_classes)]  # indicate the clients that choose each class
    num_clients_per_class = np.array([len(x) for x in Psi])
    dict_users = {}
    for class_i in range(num_classes+1):
        # all_idxs = np.where(y_train == class_i)[0]
        # group: samples with exactly ``class_i`` positive labels
        n_classes_per_sample = np.sum(y_train, axis=1)
        all_idxs = np.where(n_classes_per_sample == class_i)[0]
        # else_idxs = np.where(y_train[class_i*25907:(class_i+1)*25907, class_i] != 1)[0] + class_i*25907
        # NOTE(review): Psi[0] / num_clients_per_class[0] are indexed with a
        # constant 0 rather than class_i, so every group is spread over the
        # clients that chose class 0 — confirm this is intended.
        p_dirichlet = np.random.dirichlet([alpha_dirichlet] * num_clients_per_class[0])
        assignment = np.random.choice(Psi[0], size=len(all_idxs), p=p_dirichlet.tolist())
        # assignment_else = np.random.choice(Psi[class_i], size=len(else_idxs), p=p_dirichlet.tolist())

        for client_k in Psi[0]:
            if client_k in dict_users:
                # dict_users[client_k] = set(dict_users[client_k] | set(all_idxs[(assignment == client_k)]) | set(else_idxs[(assignment_else == client_k)]))
                dict_users[client_k] = set(dict_users[client_k] | set(all_idxs[(assignment == client_k)]))
            else:
                # dict_users[client_k] = set(np.concatenate((all_idxs[(assignment == client_k)], else_idxs[(assignment_else == client_k)])))
                dict_users[client_k] = set(all_idxs[(assignment == client_k)])
    for key in dict_users.keys():
        dict_users[key] = list(dict_users[key])
    return dict_users
50 |
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import random
4 | import shutil
5 | import sys
6 | import heapq
7 | import numpy as np
8 | import torch
9 | from tensorboardX import SummaryWriter
10 |
11 |
def set_seed(seed):
    """Seed every RNG the project uses (torch CPU/CUDA, NumPy, stdlib)."""
    for seeder in (torch.manual_seed, torch.cuda.manual_seed,
                   torch.cuda.manual_seed_all, np.random.seed, random.seed):
        seeder(seed)
18 |
19 |
20 | # def max_m_indices(arr, m):
21 | # max_m_values = heapq.nlargest(m, arr)
22 | # max_m_indices = [arr.index(value) for value in max_m_values]
23 | # return max_m_indices
def max_m_indices(lst, n):
    """Indices of the ``n`` largest values; ties keep original list order."""
    ranked = sorted(enumerate(lst), key=lambda pair: pair[1], reverse=True)
    return [idx for idx, _value in ranked[:n]]
29 |
30 |
def min_n_indices(lst, n):
    """Indices of the ``n`` smallest values; ties keep original list order."""
    ranked = sorted(enumerate(lst), key=lambda pair: pair[1])
    return [idx for idx, _value in ranked[:n]]
36 | # def min_n_indices(arr, n):
37 | # min_n_values = heapq.nsmallest(n, arr)
38 | # min_n_indices = [arr.index(value) for value in min_n_values]
39 | # return min_n_indices
40 |
41 |
def set_output_files(args):
    """Create the per-experiment output directory tree and configure logging.

    :return: ``(writer1, models_dir)`` — a tensorboardX ``SummaryWriter``
        and the directory in which checkpoints should be saved.
    """
    # experiment tag encodes dataset / partition / model / missing-class setup
    outputs_dir = 'outputs_' + str(args.dataset) + '_' + str(
        args.alpha_dirichlet) + '_' + str(args.n_clients) + '_' + str(args.model) + '_' + str(args.n_classes-args.annotation_num) + '5000classmiss7_1037_iid_FedMLP_stage1glo'
    # outputs_dir = 'outputs_' + str(args.dataset) + '_' + str(args.model) + '_dataaug_' + str(
    #     args.n_classes - args.annotation_num) + 'classmiss_' + 'loss_distribution' # demo
    # os.mkdir (not makedirs): each level is created one at a time
    if not os.path.exists(outputs_dir):
        os.mkdir(outputs_dir)
    exp_dir = os.path.join(outputs_dir, args.exp + '_' +
                           str(args.batch_size) + '_' + str(args.base_lr) + '_' +
                           str(args.rounds_warmup) + '_' +
                           str(args.rounds_corr) + '_' +
                           str(args.rounds_finetune) + '_' + str(args.local_ep))
    if not os.path.exists(exp_dir):
        os.mkdir(exp_dir)
    models_dir = os.path.join(exp_dir, 'models')
    if not os.path.exists(models_dir):
        os.mkdir(models_dir)
    logs_dir = os.path.join(exp_dir, 'logs')
    if not os.path.exists(logs_dir):
        os.mkdir(logs_dir)
    tensorboard_dir = os.path.join(exp_dir, 'tensorboard')
    # if not os.path.exists(tensorboard_dir):
    #     os.mkdir(tensorboard_dir)
    # reset the code snapshot directory (archiving itself is disabled below)
    code_dir = os.path.join(exp_dir, 'code')
    if os.path.exists(code_dir):
        shutil.rmtree(code_dir)
    os.mkdir(code_dir)
    # shutil.make_archives(code_dir, 'zip', base_dir='/home/szb/multilabel/')

    # log to file and mirror every record to stdout
    logging.basicConfig(filename=logs_dir+'/logs.txt', level=logging.INFO,
                        format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(str(args))
    # NOTE(review): concatenation without a path separator yields
    # '<exp>/tensorboardwriter1' — confirm this is intended.
    writer1 = SummaryWriter(tensorboard_dir + 'writer1')
    return writer1, models_dir
77 |
--------------------------------------------------------------------------------
/utils/multilabel_metrixs.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import numpy as np
4 |
def Hamming_Loss(y_true, y_pred, classid=None):
    """Fraction of label slots where prediction and ground truth disagree.

    ``classid`` is accepted for signature symmetry with the other metrics
    but is not used: the loss is always computed over all classes.
    """
    mismatches = 0
    for truth_row, pred_row in zip(y_true, y_pred):
        agree = truth_row == pred_row
        mismatches += np.size(agree) - np.count_nonzero(agree)
    return mismatches / (y_true.shape[0] * y_true.shape[1])
10 |
11 |
12 | # def Recall(y_true, y_pred): # sample-wise
13 | # temp = 0
14 | # for i in range(y_true.shape[0]):
15 | # if sum(y_true[i]) == 0:
16 | # continue
17 | # temp += sum(np.logical_and(y_true[i], y_pred[i])) / sum(y_true[i])
18 | # return temp / y_true.shape[0]
19 |
20 |
def Recall(y_true, y_pred, classid=None):  # class-wise
    """Class-wise recall (TP / positives).

    With ``classid`` given, returns that class's recall; otherwise the
    unweighted mean over all classes.
    """
    truth_cols, pred_cols = y_true.T, y_pred.T
    if classid is not None:
        return sum(np.logical_and(truth_cols[classid], pred_cols[classid])) / sum(truth_cols[classid])
    acc = 0
    for t_col, p_col in zip(truth_cols, pred_cols):
        acc += sum(np.logical_and(t_col, p_col)) / sum(t_col)
    return acc / truth_cols.shape[0]
30 |
31 |
def BACC(y_true, y_pred, classid=None):
    """Class-wise balanced accuracy: mean of sensitivity and specificity.

    With ``classid`` given, returns that class's score; otherwise logs and
    averages the score over all classes.
    """
    def _bacc_for(t_col, p_col):
        # sensitivity: TP / positives; specificity: TN / negatives
        sens = sum(np.logical_and(t_col, p_col)) / sum(t_col)
        spec = sum(~np.logical_or(t_col, p_col)) / (t_col.size - np.count_nonzero(t_col))
        return (sens + spec) / 2

    if classid is not None:
        return _bacc_for(y_true.T[classid], y_pred.T[classid])
    total = 0
    for cls in range(y_true.T.shape[0]):
        score = _bacc_for(y_true.T[cls], y_pred.T[cls])
        logging.info('BACC:class%d : %f' % (cls, score))
        total += score
    return total / y_true.T.shape[0]
47 |
48 |
def Precision(y_true, y_pred, classid=None):
    """Class-wise precision (TP / predicted positives).

    With ``classid`` given, returns that class's precision; otherwise the
    average over all classes — classes with no positive predictions are
    skipped in the sum but still counted in the denominator.
    """
    if classid is not None:
        t_col, p_col = y_true.T[classid], y_pred.T[classid]
        return sum(np.logical_and(t_col, p_col)) / sum(p_col)
    acc = 0
    n_classes = y_true.T.shape[0]
    for cls in range(n_classes):
        t_col, p_col = y_true.T[cls], y_pred.T[cls]
        if sum(p_col) == 0:
            continue
        value = sum(np.logical_and(t_col, p_col)) / sum(p_col)
        logging.info('P:class%d : %f' % (cls, value))
        acc += value
    return acc / n_classes
61 |
62 |
def F1Measure(y_true, y_pred, classid=None):
    """Class-wise F1 = 2*TP / (positives + predicted positives).

    With ``classid`` given, returns that class's F1; otherwise logs each
    class's score and returns the unweighted mean.
    """
    def _f1(t_col, p_col):
        return (2*sum(np.logical_and(t_col, p_col))) / (sum(t_col)+sum(p_col))

    if classid is not None:
        return _f1(y_true.T[classid], y_pred.T[classid])
    total = 0
    for cls in range(y_true.T.shape[0]):
        score = _f1(y_true.T[cls], y_pred.T[cls])
        logging.info('f1:class%d : %f' % (cls, score))
        total += score
    return total / y_true.T.shape[0]
72 |
73 |
74 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FedMLP
This is the official implementation for the paper: "FedMLP: Federated Multi-Label Medical Image Classification under Task Heterogeneity", which is accepted at MICCAI'24 (Early Accept, top 11% in total 2869 submissions).
3 |
4 |
5 |
6 |
7 | ## Introduction
Cross-silo federated learning (FL) enables decentralized organizations to collaboratively train models while preserving data privacy and has made significant progress in medical image classification. One common assumption is task homogeneity where each client has access to all classes during training. However, in clinical practice, given a multi-label classification task, constrained by the level of medical knowledge and the prevalence of diseases, each institution may diagnose only partial categories, resulting in task heterogeneity. How to pursue effective multi-label medical image classification under task heterogeneity is under-explored. In this paper, we first formulate such a realistic label missing setting in the multi-label FL domain and propose a two-stage method FedMLP to combat class missing from two aspects: pseudo label tagging and global knowledge learning. The former utilizes a warmed-up model to generate class prototypes and select samples with high confidence to supplement missing labels, while the latter uses a global model as a teacher for consistency regularization to prevent forgetting missing class knowledge. Experiments on two publicly-available medical datasets validate the superiority of FedMLP against both state-of-the-art federated semi-supervised and noisy label learning approaches under task heterogeneity.
10 |
11 |
12 |
13 |
14 |
15 | ## Related Work
16 | - RSCFed [[paper](https://openaccess.thecvf.com/content/CVPR2022/papers/Liang_RSCFed_Random_Sampling_Consensus_Federated_Semi-Supervised_Learning_CVPR_2022_paper.pdf)] [[code](https://github.com/XMed-Lab/RSCFed)]
17 | - FedNoRo [[paper](https://arxiv.org/pdf/2305.05230)] [[code](https://github.com/wnn2000/FedNoRo)]
18 | - CBAFed [[paper](https://openaccess.thecvf.com/content/CVPR2023/papers/Li_Class_Balanced_Adaptive_Pseudo_Labeling_for_Federated_Semi-Supervised_Learning_CVPR_2023_paper.pdf)] [[code](https://github.com/minglllli/CBAFed)]
19 |
20 |
21 | ## Dataset
Please download the ICH dataset from [kaggle](https://www.kaggle.com/c/rsna-intracranial-hemorrhage-detection) and preprocess it following this [notebook](https://www.kaggle.com/guiferviz/prepare-dataset-resizing-and-saving-as-png). Please download the ChestXray14 dataset from this [link](https://nihcc.app.box.com/v/ChestXray-NIHCC).
23 |
24 |
25 | ## Requirements
26 | We recommend using conda to setup the environment, See the `requirements.txt` for environment configuration.
27 |
28 |
29 | ## Citation
30 | If this repository is useful for your research, please consider citing:
31 | ```shell
32 | @inproceedings{sun2024fedmlp,
33 | title={FedMLP: Federated Multi-label Medical Image Classification Under Task Heterogeneity},
34 | author={Sun, Zhaobin and Wu, Nannan and Shi, Junjie and Yu, Li and Cheng, Kwang-Ting and Yan, Zengqiang},
35 | booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
36 | pages={394--404},
37 | year={2024},
38 | organization={Springer}
39 | }
40 | ```
41 |
42 |
43 | ## Contact
44 | For any questions, please contact 'zbsun@hust.edu.cn'.
45 |
--------------------------------------------------------------------------------
/dataset/all_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from PIL import Image
4 | import pandas as pd
5 |
6 | import torch
7 | from torch.utils.data import Dataset
8 |
9 |
class ChestXray14(Dataset):
    """Multi-label NIH ChestXray14 dataset read from a pre-split csv.

    The csv holds an "Image Index" column followed by one 0/1 column per
    class; targets are exposed as float32 multi-hot vectors.
    """

    def __init__(self, datapath, mode, transform=None):
        self.datapath = datapath
        self.mode = mode
        self.transform = transform

        assert self.mode in ["train", "test"]
        csv_file = os.path.join("/home/szb/multilabel/", self.mode + "_dataset_8class.csv")
        self.file = pd.read_csv(csv_file)

        self.image_list = self.file["Image Index"].values
        self.targets = self.file.iloc[0:, 1:].values.astype(np.float32)

    def __getitem__(self, index: int):
        image_id = self.image_list[index]
        target = self.targets[index]
        image = self.read_image(image_id)

        if self.transform is None:
            # no transform configured: mirror the original's implicit None
            return None
        if isinstance(self.transform, tuple):
            # a pair of transforms yields two augmented views of the image
            return {"image_aug_1": self.transform[0](image),
                    "image_aug_2": self.transform[1](image),
                    "target": target,
                    "index": index,
                    "image_id": image_id}
        return {"image": self.transform(image),
                "target": target,
                "index": index,
                "image_id": image_id}

    def __len__(self):
        return len(self.targets)

    def read_image(self, image_id):
        image_path = os.path.join("/home/szb/ChestXray14/images/image/", image_id)
        image = Image.open(image_path).convert("RGB")
        return image
50 |
51 |
class ICH(Dataset):
    """Multi-label RSNA intracranial-hemorrhage dataset from a pre-split csv.

    The csv holds an "Image Index" column followed by one 0/1 column per
    hemorrhage subtype; targets are exposed as float32 multi-hot vectors.
    """

    def __init__(self, datapath, mode, transform=None):
        self.datapath = datapath
        self.mode = mode
        self.transform = transform

        assert self.mode in ["train", "test"]
        csv_file = os.path.join("/home/szb/ICH_stage2/ICH_stage2/", self.mode + "_dataset_ICH.csv")
        # csv_file = os.path.join("/home/szb/ICH_stage2/ICH_stage2/", self.mode + '_demo.csv') # demo exp(5000 samples)
        self.file = pd.read_csv(csv_file)

        self.image_list = self.file["Image Index"].values
        self.targets = self.file.iloc[0:, 1:].values.astype(np.float32)

    def __getitem__(self, index: int):
        image_id = self.image_list[index]
        target = self.targets[index]
        image = self.read_image(image_id)

        if self.transform is None:
            # no transform configured: mirror the original's implicit None
            return None
        if isinstance(self.transform, tuple):
            # a pair of transforms yields two augmented views of the image
            return {"image_aug_1": self.transform[0](image),
                    "image_aug_2": self.transform[1](image),
                    "target": target,
                    "index": index,
                    "image_id": image_id}
        return {"image": self.transform(image),
                "target": target,
                "index": index,
                "image_id": image_id}

    def __len__(self):
        return len(self.targets)

    def read_image(self, image_id):
        image_path = os.path.join("/home/szb/ICH_stage2/ICH_stage2/png185k_512/", image_id)
        image = Image.open(image_path).convert("RGB")
        return image
--------------------------------------------------------------------------------
/model/efficientnet.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division, absolute_import
2 | from efficientnet_pytorch import EfficientNet
3 |
# ImageNet preprocessing metadata (colour space, input size/range, mean/std)
# copied onto EfficientNet instances by initialize_pretrained_model() below.
pretrained_settings = {
    'efficientnet': {
        'imagenet': {
            'input_space': 'RGB',
            'input_size': [3, 224, 224],
            'input_range': [0, 1],
            'mean': [0.485, 0.456, 0.406],
            'std': [0.229, 0.224, 0.225],
            'num_classes': 1000
        }
    },
}
16 |
17 |
def initialize_pretrained_model(model, num_classes, settings):
    """Attach ImageNet preprocessing metadata from `settings` onto `model`."""
    assert num_classes == settings['num_classes'],'num_classes should be {}, but is {}'.format(
        settings['num_classes'], num_classes)
    # Copy the preprocessing fields verbatim onto the model instance.
    for attr in ('input_space', 'input_size', 'input_range', 'mean', 'std'):
        setattr(model, attr, settings[attr])
26 |
27 |
def efficientnet_b0(num_classes=1000, pretrained='imagenet'):
    """EfficientNet-B0; attaches ImageNet preprocessing metadata when pretrained."""
    model = EfficientNet.from_pretrained('efficientnet-b0', advprop=False)
    if pretrained is not None:
        initialize_pretrained_model(model, num_classes,
                                    pretrained_settings['efficientnet'][pretrained])
    return model
34 |
35 |
def efficientnet_b1(num_classes=1000, pretrained='imagenet'):
    """EfficientNet-B1; attaches ImageNet preprocessing metadata when pretrained."""
    model = EfficientNet.from_pretrained('efficientnet-b1', advprop=False)
    if pretrained is not None:
        initialize_pretrained_model(model, num_classes,
                                    pretrained_settings['efficientnet'][pretrained])
    return model
42 |
43 |
def efficientnet_b2(num_classes=1000, pretrained='imagenet'):
    """EfficientNet-B2; attaches ImageNet preprocessing metadata when pretrained."""
    model = EfficientNet.from_pretrained('efficientnet-b2', advprop=False)
    if pretrained is not None:
        initialize_pretrained_model(model, num_classes,
                                    pretrained_settings['efficientnet'][pretrained])
    return model
50 |
51 |
def efficientnet_b3(num_classes=1000, pretrained='imagenet'):
    """EfficientNet-B3; attaches ImageNet preprocessing metadata when pretrained."""
    model = EfficientNet.from_pretrained('efficientnet-b3', advprop=False)
    if pretrained is not None:
        initialize_pretrained_model(model, num_classes,
                                    pretrained_settings['efficientnet'][pretrained])
    return model
58 |
59 |
def efficientnet_b4(num_classes=1000, pretrained='imagenet'):
    """EfficientNet-B4; attaches ImageNet preprocessing metadata when pretrained."""
    model = EfficientNet.from_pretrained('efficientnet-b4', advprop=False)
    if pretrained is not None:
        initialize_pretrained_model(model, num_classes,
                                    pretrained_settings['efficientnet'][pretrained])
    return model
66 |
67 |
def efficientnet_b5(num_classes=1000, pretrained='imagenet'):
    """EfficientNet-B5; attaches ImageNet preprocessing metadata when pretrained."""
    model = EfficientNet.from_pretrained('efficientnet-b5', advprop=False)
    if pretrained is not None:
        initialize_pretrained_model(model, num_classes,
                                    pretrained_settings['efficientnet'][pretrained])
    return model
74 |
75 |
def efficientnet_b6(num_classes=1000, pretrained='imagenet'):
    """EfficientNet-B6; attaches ImageNet preprocessing metadata when pretrained."""
    model = EfficientNet.from_pretrained('efficientnet-b6', advprop=False)
    if pretrained is not None:
        initialize_pretrained_model(model, num_classes,
                                    pretrained_settings['efficientnet'][pretrained])
    return model
82 |
83 |
def efficientnet_b7(num_classes=1000, pretrained='imagenet'):
    """EfficientNet-B7; attaches ImageNet preprocessing metadata when pretrained.

    Bug fix: the original ended with a bare ``return`` and therefore always
    returned None instead of the constructed model (all sibling b0-b6
    factories return the model).
    """
    model = EfficientNet.from_pretrained('efficientnet-b7', advprop=False)
    if pretrained is not None:
        settings = pretrained_settings['efficientnet'][pretrained]
        initialize_pretrained_model(model, num_classes, settings)
    return model
--------------------------------------------------------------------------------
/utils/FedAvg.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import math
3 |
4 | import torch
5 | import numpy as np
6 |
def FedAvg(w, dict_len):
    """Weighted mean of client state dicts, weights = per-client sample counts."""
    total = sum(dict_len)
    w_avg = copy.deepcopy(w[0])
    for key in w_avg.keys():
        weighted = w[0][key] * dict_len[0]
        for client_idx in range(1, len(w)):
            weighted = weighted + w[client_idx][key] * dict_len[client_idx]
        w_avg[key] = weighted / total
    return w_avg
15 |
def Fed_w(w, weight):
    """Weighted mean of state dicts with arbitrary (unnormalised) weights."""
    norm = sum(weight)
    w_avg = copy.deepcopy(w[0])
    for key in w_avg.keys():
        acc = w[0][key] * weight[0]
        for client_idx in range(1, len(w)):
            acc = acc + w[client_idx][key] * weight[client_idx]
        w_avg[key] = acc / norm
    return w_avg
24 |
def RSCFed(DMA, w_locals, K, dict_len, M):
    # RSCFed-style random sub-consensus aggregation: DMA holds M groups of K
    # client ids; each group is averaged into a sub-consensus model, then the
    # M sub-consensus models are averaged uniformly into the global model.
    w_sub = []
    for group in DMA:
        w_select = []
        N_total = 0
        for id in group:
            w_select.append(w_locals[id])
            N_total += dict_len[id]
        # Unweighted mean of the group serves as the distance anchor.
        w_avg = Fed_w(w_select, [1]*K)
        w = []
        for id in group:
            # a: data-size share; b: down-weights clients far from the anchor.
            a = dict_len[id] / N_total
            b = math.exp((-0.01)*(model_dist(w_locals[id], w_avg)/dict_len[id]))
            w.append(a*b)
        w_sub.append(Fed_w(w_select, w))
    w_glob = Fed_w(w_sub, [1]*M)
    return w_glob
42 |
def model_dist(w_1, w_2):
    """Sum over parameters of the L2 distance between two state dicts."""
    assert w_1.keys() == w_2.keys(), "Error: cannot compute distance between dict with different keys"
    total = torch.zeros(1).float()
    for key in w_1:
        total = total + torch.norm(w_1[key].cpu() - w_2[key].cpu()).cpu()
    return total.cpu().item()
50 |
def FedAvg_tao(t, weight, class_active_client_list = None):
    """Aggregate per-class thresholds (tao) across clients.

    Without an active-client list, every client contributes to every class
    with its weight. With one, class c averages only over clients active on
    c; a class with no active client falls back to threshold 1.0.
    """
    n_classes = len(t[0])
    if class_active_client_list is None:
        t_avg = np.zeros(n_classes)
        for client_idx, tao in enumerate(t):
            t_avg += tao * float(weight[client_idx])
        return t_avg / float(sum(weight))
    t_avg = np.zeros(n_classes)
    for cls, active in enumerate(class_active_client_list):
        if len(active) == 0:
            # No client annotates this class: use the fallback threshold.
            t_avg[cls] = 1.
            continue
        weight_sum = 0.
        for client_idx, tao in enumerate(t):
            if client_idx in active:
                t_avg[cls] += tao[cls] * float(weight[client_idx])
                weight_sum += float(weight[client_idx])
        t_avg[cls] = t_avg[cls] / weight_sum
    return t_avg
71 |
def FedAvg_proto(Prototypes, weight, class_active_client_list):
    # Aggregate class prototypes across clients. Prototypes[c] is a flat list
    # with two feature vectors per class (indices 2*cls and 2*cls+1); only
    # the clients listed in class_active_client_list[cls] contribute to cls.
    Prototype_avg = torch.zeros((len(Prototypes[0]), len(Prototypes[0][0])))
    # Prototype_avg = np.array([torch.zeros_like(Prototypes[0][0])] * len(Prototypes[0]))
    for cls, cls_active_clients in enumerate(class_active_client_list):
        Prototype_class_0_avg = torch.zeros_like(Prototypes[0][0])
        Prototype_class_1_avg = torch.zeros_like(Prototypes[0][0])
        for client_id in cls_active_clients:
            Prototype_class_0_avg = Prototypes[client_id][2*cls] * weight[client_id] + Prototype_class_0_avg
            Prototype_class_1_avg = Prototypes[client_id][2*cls+1] * weight[client_id] + Prototype_class_1_avg
        # print(cls)
        # print(cls_active_clients)
        # print(Prototype_class_0_avg)
        # print(Prototype_class_1_avg)
        # NOTE(review): divides by the summed weights of the active clients;
        # an empty active list would yield a 0/0 NaN prototype — presumably
        # prevented upstream, verify against callers.
        Prototype_class_0_avg = Prototype_class_0_avg / np.sum(np.array(weight)[cls_active_clients])
        Prototype_class_1_avg = Prototype_class_1_avg / np.sum(np.array(weight)[cls_active_clients])
        # print(Prototype_class_0_avg)
        # print(Prototype_class_1_avg)
        Prototype_avg[2*cls] = Prototype_class_0_avg
        Prototype_avg[2*cls+1] = Prototype_class_1_avg
    # print(Prototype_avg)
    # input()
    return Prototype_avg
94 |
def FedAvg_rela(Prototypes, weight, class_active_client_list):
    """Aggregate one relation prototype per class over its active clients."""
    n_classes = len(Prototypes[0])
    Prototype_avg = torch.zeros((n_classes, len(Prototypes[0][0])))
    weight_arr = np.array(weight)
    for cls, active in enumerate(class_active_client_list):
        acc = torch.zeros_like(Prototypes[0][0])
        for client_id in active:
            acc = acc + Prototypes[client_id][cls] * weight[client_id]
        # Normalise by the summed weights of the contributing clients.
        Prototype_avg[cls] = acc / np.sum(weight_arr[active])
    return Prototype_avg
--------------------------------------------------------------------------------
/utils/options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
def args_parser():
    """Build and parse the command-line arguments for training/evaluation.

    Returns:
        argparse.Namespace with all experiment, FL, and method-specific flags.
    """
    parser = argparse.ArgumentParser()

    # system setting
    parser.add_argument('--deterministic', type=int, default=1,
                        help='whether use deterministic training')
    parser.add_argument('--seed', type=int, default=1037, help='random seed2023,1037')
    parser.add_argument('--gpu', type=str, default='2', help='GPU to use')

    # basic setting
    parser.add_argument('--exp', type=str,
                        default='FedMLP', help='experiment name')
    parser.add_argument('--dataset', type=str,
                        default='ChestXray14', help='dataset name')
    parser.add_argument('--model', type=str,
                        default='Resnet18', help='model name')
    parser.add_argument('--batch_size', type=int,
                        default=32, help='batch_size per gpu')
    parser.add_argument('--feature_dim', type=int,
                        default=512, help='feature_dim of ResNet18')
    parser.add_argument('--base_lr', type=float, default=3e-5,
                        help='base learning rate,ICH=3e-5,ChestXray14=3e-6')
    parser.add_argument('--pretrained', type=int, default=1)
    parser.add_argument('--train', type=int, default=1)

    # PSL setting
    # Bug fix: default was the string '1' while type=int; use a real int default.
    parser.add_argument('--annotation_num', type=int,
                        default=1, help='The number of categories annotated by each client.')

    # for FL
    parser.add_argument('--n_clients', type=int, default=8,
                        help='number of users')
    parser.add_argument('--n_classes', type=int, default=8,
                        help='number of classes')
    parser.add_argument('--iid', type=int, default=1, help="i.i.d. or non-i.i.d.")
    parser.add_argument('--alpha_dirichlet', type=float,
                        default=0.5, help='parameter for non-iid')
    parser.add_argument('--local_ep', type=int, default=1, help='local epoch')
    parser.add_argument('--rounds_warmup', type=int, default=500, help='rounds')
    parser.add_argument('--rounds_corr', type=int, default=200, help='rounds')
    parser.add_argument('--rounds_distillation', type=int, default=200, help='rounds')
    parser.add_argument('--rounds_finetune', type=int, default=50, help='rounds')
    parser.add_argument('--rounds_FedMLP_stage1', type=int, default=50, help='rounds')
    parser.add_argument('--U', type=float, default=0.7, help='tao_upper_bound')
    parser.add_argument('--L', type=float, default=0.3, help='tao_lower_bound')
    parser.add_argument('--tao_min', type=float, default=0.1, help='tao_min')
    parser.add_argument('--runs', type=int, default=1, help='training seed')

    # RoFL
    parser.add_argument('--forget_rate', type=float, default=0.2, help='forget_rate')
    parser.add_argument('--num_gradual', type=int, default=10, help='T_k')
    parser.add_argument('--T_pl', type=int, help='T_pl: When to start using global guided pseudo labeling', default=100)
    parser.add_argument('--lambda_cen', type=float, help='lambda_cen', default=1.0)
    parser.add_argument('--lambda_e', type=float, help='lambda_e', default=0.8)

    # FedMLP_abu
    parser.add_argument('--difficulty_estimate', type=int, default=1, help='tao=1 or cal')
    parser.add_argument('--miss_client_difficulty', type=int, default=1, help='consider or not(tao agg method)')
    parser.add_argument('--mixup', type=int, default=1, help='y/n')
    parser.add_argument('--clean_threshold', type=float, default=0.005, help='clean_threshold')
    parser.add_argument('--noise_threshold', type=float, default=0.01, help='noise_threshold')

    # FedLSR
    parser.add_argument('--t_w', type=int, default=40, help='clean_threshold')
    # FedIRM
    parser.add_argument('--rounds_FedIRM_sup', type=int, default=20, help='rounds')
    parser.add_argument('--consistency', type=float, default=1, help='consistency')
    parser.add_argument('--consistency_rampup', type=float, default=30, help='consistency_rampup')
    parser.add_argument('--ema_decay', type=float, default=0.99, help='ema_decay')
    # FedNoRo
    parser.add_argument('--rounds_FedNoRo_warmup', type=int, default=500, help='rounds')
    parser.add_argument('--begin', type=int, default=10, help='ramp up begin')
    parser.add_argument('--end', type=int, default=499, help='ramp up end')
    parser.add_argument('--a', type=float, default=0.8, help='a')
    #CBAFed
    parser.add_argument('--rounds_CBAFed_warmup', type=int, default=50, help='rounds')
    args = parser.parse_args()
    return args
82 |
--------------------------------------------------------------------------------
/utils/FedNoRo.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
class LogitAdjust_Multilabel(nn.Module):
    """Element-wise multilabel BCE; the class-prior adjustment is disabled."""

    def __init__(self, cls_num_list, num, tau=1, weight=None):
        super(LogitAdjust_Multilabel, self).__init__()
        # Per-class positive frequency, kept for the (disabled) adjustment.
        # NOTE(review): torch.cuda.FloatTensor requires a CUDA device.
        cls_num_list = torch.cuda.FloatTensor(cls_num_list)
        self.cls_p_list = cls_num_list / num
        self.weight = weight

    def forward(self, x, target):
        x_m = x.clone()
        # The prior-based rescaling of x_m (abu1 ablation) was commented out
        # upstream, so the loss is plain per-element BCE on probabilities.
        return F.binary_cross_entropy(x_m, target, weight=self.weight, reduction='none')
23 |
24 |
class LA_KD(nn.Module):
    """Combined loss: BCE on the client's annotated classes plus MSE
    distillation against soft targets on its missing (negative) classes."""

    def __init__(self, cls_num_list, num, active_class_list_client, negative_class_list_client, tau=1, weight=None):
        super(LA_KD, self).__init__()
        self.active_class_list_client = active_class_list_client
        self.negative_class_list_client = negative_class_list_client
        # NOTE(review): torch.cuda.FloatTensor requires a CUDA device.
        self.cls_p_list = torch.cuda.FloatTensor(cls_num_list) / num
        self.weight = weight
        self.bce = LogitAdjust_Multilabel(cls_num_list, num)

    def forward(self, x, target, soft_target, w_kd):
        n_active = len(self.active_class_list_client)
        n_negative = len(self.negative_class_list_client)
        # Supervised term: mean BCE over annotated classes only.
        sup = self.bce(x, target)[:, self.active_class_list_client].sum() / (len(x) * n_active)
        # Distillation term: mean MSE to the soft targets on missing classes.
        kd = F.mse_loss(x, soft_target, reduction='none')[:, self.negative_class_list_client].sum() / (len(x) * n_negative)
        return w_kd * kd + (1 - w_kd) * sup
39 |
40 |
def get_output(loader, net, args, sigmoid=False, criterion=None):
    """Run `net` over every batch of `loader` and return the stacked outputs
    as a numpy array (optionally sigmoid-activated), plus the stacked
    per-sample losses when `criterion` is given."""
    net.eval()
    # Force deterministic cuDNN so repeated passes give identical outputs.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    with torch.no_grad():
        for i, samples in enumerate(loader):
            images = samples["image"].to(args.device)
            labels = samples["target"].to(args.device)
            if sigmoid == True:
                _, outputs = net(images)
                outputs = torch.sigmoid(outputs)
            else:
                _, outputs = net(images)
            if criterion is not None:
                loss = criterion(outputs, labels)
            # First batch initialises the accumulators; later batches append.
            if i == 0:
                output_whole = np.array(outputs.cpu())
                if criterion is not None:
                    loss_whole = np.array(loss.cpu())
            else:
                output_whole = np.concatenate(
                    (output_whole, outputs.cpu()), axis=0)
                if criterion is not None:
                    loss_whole = np.concatenate(
                        (loss_whole, loss.cpu()), axis=0)
        if criterion is not None:
            return output_whole, loss_whole
        else:
            return output_whole
70 |
71 |
def sigmoid_rampup(current, begin, end):
    """Exponential ramp-up from https://arxiv.org/abs/1610.02242: ramps from
    exp(-5) at `begin` to 1.0 at `end`, clamped outside that interval."""
    progress = np.clip(current, begin, end)
    phase = 1.0 - (progress - begin) / (end - begin)
    return float(np.exp(-5.0 * phase * phase))
77 |
78 |
def get_current_consistency_weight(rnd, begin, end):
    """Consistency ramp-up weight for round `rnd`
    (https://arxiv.org/abs/1610.02242)."""
    return sigmoid_rampup(rnd, begin, end)
82 |
83 |
def DaAgg(w, dict_len, clean_clients, noisy_clients):
    """Distance-aware aggregation (FedNoRo): noisy clients are down-weighted
    by exp(-d), where d is their normalised parameter distance to the nearest
    clean client; clean clients keep distance 0 (weight factor 1)."""
    client_weight = np.array(dict_len)
    client_weight = client_weight / client_weight.sum()
    distance = np.zeros(len(dict_len))
    for n_idx in noisy_clients:
        dis = []
        for c_idx in clean_clients:
            dis.append(model_dist(w[n_idx], w[c_idx]))
        distance[n_idx] = min(dis)
    # Normalise distances to [0, 1] before the exponential down-weighting.
    distance = distance / distance.max()
    client_weight = client_weight * np.exp(-distance)
    client_weight = client_weight / client_weight.sum()
    # print(client_weight)

    w_avg = copy.deepcopy(w[0])
    for k in w_avg.keys():
        w_avg[k] = w_avg[k] * client_weight[0]
        for i in range(1, len(w)):
            w_avg[k] += w[i][k] * client_weight[i]
    return w_avg
104 |
105 |
def model_dist(w_1, w_2):
    """Sum of per-tensor L2 distances between two state dicts; integer-typed
    entries (e.g. BatchNorm batch counters) are skipped."""
    assert w_1.keys() == w_2.keys(), "Error: cannot compute distance between dict with different keys"
    dists = [torch.norm(w_1[key] - w_2[key]).cpu()
             for key in w_1.keys()
             if "int" not in str(w_1[key].dtype)]
    total = torch.zeros(1).float()
    for d in dists:
        total += d
    return total.cpu().item()
--------------------------------------------------------------------------------
/utils/evaluations.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import numpy as np
4 |
5 | import torch
6 | import torch.optim
7 | from matplotlib import pyplot as plt
8 | from torch.utils.data import DataLoader
9 | from sklearn.metrics import accuracy_score, balanced_accuracy_score, f1_score, roc_auc_score, confusion_matrix, recall_score, roc_curve, auc
10 | from sklearn.metrics import average_precision_score
11 |
12 | from utils.multilabel_metrixs import Recall, Hamming_Loss, F1Measure, Precision, BACC
13 |
14 |
def globaltest(net, test_dataset, args):
    """Evaluate `net` on the full test set and return a dict of multi-label
    metrics: mAP, BACC, recall, F1, macro AUC, precision, Hamming loss."""
    auroc = 0
    net.eval()
    test_loader = DataLoader(dataset=test_dataset, batch_size=args.batch_size*4, shuffle=False, num_workers=4)
    all_preds = np.array([])
    all_probs = []
    all_labels = np.array(test_dataset.targets)
    with torch.no_grad():
        for samples in test_loader:
            images = samples["image"].to(args.device)
            _, outputs = net(images)
            probs = torch.sigmoid(outputs) # soft predict
            accuracy_th = 0.5
            preds = probs > accuracy_th # hard predict
            all_probs.append(probs.detach().cpu())
            # First batch: all_preds is still the 1-D empty placeholder.
            if all_preds.ndim == 1:
                all_preds = preds.detach().cpu().numpy()
            else:
                all_preds = np.concatenate([all_preds, preds.detach().cpu().numpy()], axis=0)

    all_probs = torch.cat(all_probs).numpy()
    assert all_probs.shape[0] == len(test_dataset)
    assert all_probs.shape[1] == args.n_classes
    logging.info(np.sum(all_preds, axis=0))
    logging.info(np.sum(all_labels, axis=0))

    APs = []
    # Per-class average precision, then the mean over classes (mAP).
    for label_index in range(all_labels.shape[1]):
        true_labels = all_labels[:, label_index]
        predicted_scores = all_probs[:, label_index]
        ap = average_precision_score(true_labels, predicted_scores)
        APs.append(ap)

    mAP = torch.tensor(APs).mean()

    bacc = BACC(all_labels, all_preds)
    R = Recall(all_labels, all_preds)
    hamming_loss = Hamming_Loss(all_labels, all_preds)
    F1 = F1Measure(all_labels, all_preds)
    P = Precision(all_labels, all_preds)
    colors = ['#FF0000', '#00FF00', '#0000FF', '#FFFF00', '#00FFFF', '#FF00FF', '#000000', '#808080', '#C0C0C0', '#800000',
              '#008000', '#000080', '#808000', '#008080']
    # print(all_probs)
    # Macro-average ROC AUC across classes.
    for i in range(len(all_labels.T)):
        fpr, tpr, th = roc_curve(all_labels.T[i], all_probs.T[i], pos_label=1)
        # ROCprint(fpr, tpr, i, colors[i])
        auroc += auc(fpr, tpr)
        # print('class: {}, auc: '.format(i), auc(fpr, tpr))
        # plt.show()
    auroc /= len(all_labels.T)

    return {"mAP": mAP,
            "BACC": bacc,
            "R": R,
            "F1": F1,
            "auc": auroc,
            "P": P,
            "hamming_loss": hamming_loss}
74 |
75 |
def ROCprint(fpr, tpr, name, colorname):
    """Plot one ROC curve (AUC in the legend) onto the current matplotlib
    figure and save it to 'multi_models_roc.png'; callable once per class."""
    plt.plot(fpr, tpr, lw=1, label='{} (AUC={:.3f})'.format(name, auc(fpr, tpr)), color=colorname)
    # Diagonal chance line for reference.
    plt.plot([0, 1], [0, 1], '--', lw=1, color='grey')
    plt.axis('square')
    plt.xlim([0, 1])
    plt.ylim([0, 1])
    plt.xlabel('False Positive Rate', fontsize=20)
    plt.ylabel('True Positive Rate', fontsize=20)
    plt.title('ROC Curve', fontsize=25)
    plt.legend(loc='lower right', fontsize=20)
    plt.savefig('multi_models_roc.png')
87 |
88 |
def classtest(net, test_dataset, args, classid):
    """Evaluate `net` on the test set and return BACC/recall/F1/precision
    restricted to the single class `classid`."""
    auroc = 0  # NOTE(review): the AUC loop below is commented out, stays 0.
    net.eval()
    test_loader = DataLoader(dataset=test_dataset, batch_size=args.batch_size*4, shuffle=False, num_workers=4)
    all_preds = np.array([])
    all_probs = []
    all_labels = np.array(test_dataset.targets)
    with torch.no_grad():
        for samples in test_loader:
            images = samples["image"].to(args.device)
            _, outputs = net(images)
            probs = torch.sigmoid(outputs) # soft predict
            accuracy_th = 0.5
            preds = probs > accuracy_th # hard predict
            all_probs.append(probs.detach().cpu())
            # First batch: all_preds is still the 1-D empty placeholder.
            if all_preds.ndim == 1:
                all_preds = preds.detach().cpu().numpy()
            else:
                all_preds = np.concatenate([all_preds, preds.detach().cpu().numpy()], axis=0)

    all_probs = torch.cat(all_probs).numpy()
    assert all_probs.shape[0] == len(test_dataset)
    assert all_probs.shape[1] == args.n_classes
    logging.info(np.sum(all_preds, axis=0))
    logging.info(np.sum(all_labels, axis=0))

    # Metrics are computed only for the requested class index.
    bacc = BACC(all_labels, all_preds, classid)
    R = Recall(all_labels, all_preds, classid)
    F1 = F1Measure(all_labels, all_preds, classid)
    P = Precision(all_labels, all_preds, classid)
    colors = ['#FF0000', '#00FF00', '#0000FF', '#FFFF00', '#00FFFF', '#FF00FF', '#000000', '#808080', '#C0C0C0', '#800000',
              '#008000', '#000080', '#808000', '#008080']
    # print(all_probs)
    # for i in range(len(all_labels.T)):
    #     fpr, tpr, th = roc_curve(all_labels.T[i], all_probs.T[i], pos_label=1)
    #     # ROCprint(fpr, tpr, i, colors[i])
    #     auroc += auc(fpr, tpr)
    #     # print('class: {}, auc: '.format(i), auc(fpr, tpr))
    #     # plt.show()
    # auroc /= len(all_labels.T)

    return {"BACC": bacc,
            "R": R,
            "F1": F1,
            "P": P}
134 |
--------------------------------------------------------------------------------
/model/all_models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import copy
5 | from torchvision import models
6 | import pretrainedmodels
7 | from .efficientnet import efficientnet_b0
8 | from .efficientnet import efficientnet_b1
9 | from .efficientnet import efficientnet_b2
10 | from .efficientnet import efficientnet_b3
11 | from .efficientnet import efficientnet_b4
12 | from .efficientnet import efficientnet_b5
13 | from .efficientnet import efficientnet_b6
14 | from .efficientnet import efficientnet_b7
15 |
16 |
class FCNorm(nn.Module):
    """Cosine classifier: scaled dot product of L2-normalised features
    (rows of x) and L2-normalised weight columns, scale s = 30."""

    def __init__(self, in_features, out_features):
        super(FCNorm, self).__init__()
        self.weight = nn.Parameter(torch.Tensor(in_features, out_features))
        # Uniform init followed by the renorm trick to get unit-norm columns.
        self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)
        self.s = 30  # logit scale factor

    def forward(self, x):
        x_unit = F.normalize(x, dim=1)
        w_unit = F.normalize(self.weight, dim=0)
        return self.s * x_unit.mm(w_unit)
27 |
28 |
def get_model(model_name, pretrained=False):
    """Returns a CNN model
    Args:
        model_name: model name
        pretrained: True or False
    Returns:
        model: the desired model
    Raises:
        ValueError: If model name is not recognized.
    """
    # torchvision models take a boolean `pretrained`; pretrainedmodels and
    # the EfficientNet wrappers take the string mode ('imagenet') or None.
    if pretrained == False:
        pt = None
    else:
        pt = 'imagenet'
    print(f"model pretrained: {pretrained}, mode: {pt}", )

    if model_name == 'Vgg11':
        return models.vgg11(pretrained=pretrained)
    elif model_name == 'Vgg13':
        return models.vgg13(pretrained=pretrained)
    elif model_name == 'Vgg16':
        return models.vgg16(pretrained=pretrained)
    elif model_name == 'Vgg19':
        return models.vgg19(pretrained=pretrained)
    elif model_name == 'Resnet18':
        return models.resnet18(pretrained=pretrained)
    elif model_name == 'Resnet34':
        return models.resnet34(pretrained=pretrained)
    elif model_name == 'Resnet50':
        return models.resnet50(pretrained=pretrained)
    elif model_name == 'Resnet101':
        return models.resnet101(pretrained=pretrained)
    elif model_name == 'Resnet152':
        return models.resnet152(pretrained=pretrained)
    elif model_name == 'Dense121':
        return models.densenet121(pretrained=pretrained)
    elif model_name == 'Dense169':
        return models.densenet169(pretrained=pretrained)
    elif model_name == 'Dense201':
        return models.densenet201(pretrained=pretrained)
    elif model_name == 'Dense161':
        return models.densenet161(pretrained=pretrained)
    elif model_name == 'SENet50':
        return pretrainedmodels.__dict__['se_resnet50'](num_classes=1000, pretrained=pt)
    elif model_name == 'SENet101':
        return pretrainedmodels.__dict__['se_resnet101'](num_classes=1000, pretrained=pt)
    elif model_name == 'SENet152':
        return pretrainedmodels.__dict__['se_resnet152'](num_classes=1000, pretrained=pt)
    elif model_name == 'SENet154':
        return pretrainedmodels.__dict__['senet154'](num_classes=1000, pretrained=pt)
    elif model_name == 'Efficient_b0':
        return efficientnet_b0(num_classes=1000, pretrained=pt)
    elif model_name == 'Efficient_b1':
        return efficientnet_b1(num_classes=1000, pretrained=pt)
    elif model_name == 'Efficient_b2':
        return efficientnet_b2(num_classes=1000, pretrained=pt)
    elif model_name == 'Efficient_b3':
        return efficientnet_b3(num_classes=1000, pretrained=pt)
    elif model_name == 'Efficient_b4':
        return efficientnet_b4(num_classes=1000, pretrained=pt)
    elif model_name == 'Efficient_b5':
        return efficientnet_b5(num_classes=1000, pretrained=pt)
    elif model_name == 'Efficient_b6':
        return efficientnet_b6(num_classes=1000, pretrained=pt)
    elif model_name == 'Efficient_b7':
        return efficientnet_b7(num_classes=1000, pretrained=pt)
    else:
        raise ValueError('Name of model unknown %s' % model_name)
97 |
98 |
def modify_last_layer(model_name, model, num_classes, normed=False, bias=True):
    """modify the last layer of the CNN model to fit the num_classes
    Args:
        model_name: model name
        model: CNN model
        num_classes: class number
    Returns:
        model: the desired model, and the freshly created last layer
    """
    # Each architecture family stores its classification head under a
    # different attribute; replace it in place with a new `classifier(...)`.
    if 'Vgg' in model_name:
        in_dim = model.classifier._modules['6'].in_features
        new_layer = classifier(in_dim, num_classes, normed, bias)
        model.classifier._modules['6'] = new_layer
    elif 'Dense' in model_name:
        in_dim = model.classifier.in_features
        new_layer = classifier(in_dim, num_classes, normed, bias)
        model.classifier = new_layer
    elif 'Resnet' in model_name:
        in_dim = model.fc.in_features
        new_layer = classifier(in_dim, num_classes, normed, bias)
        model.fc = new_layer
    elif 'Efficient' in model_name:
        in_dim = model._fc.in_features
        new_layer = classifier(in_dim, num_classes, normed, bias)
        model._fc = new_layer
    else:
        in_dim = model.last_linear.in_features
        new_layer = classifier(in_dim, num_classes, normed, bias)
        model.last_linear = new_layer
    # print(model)
    return model, new_layer
131 |
132 |
def classifier(num_features, num_classes, normed, bias):
    """Build the final classification layer: FCNorm when `normed`,
    otherwise a plain fully-connected Linear layer."""
    if normed:
        return FCNorm(num_features, num_classes)
    return nn.Linear(num_features, num_classes, bias=bias)
139 |
140 |
def get_feature_length(model_name, model):
    """get the feature length of the last feature layer
    Args:
        model_name: model name
        model: CNN model
    Returns:
        num_ftrs: the feature length of the last feature layer
    """
    # Look up the classification head for each architecture family and
    # report its input width.
    if 'Vgg' in model_name:
        return model.classifier._modules['6'].in_features
    if 'Dense' in model_name:
        return model.classifier.in_features
    if 'Resnet' in model_name:
        return model.fc.in_features
    if 'Efficient' in model_name:
        return model._fc.in_features
    if 'RegNet' in model_name:
        return model.head.fc.in_features
    return model.last_linear.in_features
--------------------------------------------------------------------------------
/utils/FixMatch.py:
--------------------------------------------------------------------------------
1 | # code in this file is adpated from
2 | # https://github.com/ildoonet/pytorch-randaugment/blob/master/RandAugment/augmentations.py
3 | # https://github.com/google-research/fixmatch/blob/master/third_party/auto_augment/augmentations.py
4 | # https://github.com/google-research/fixmatch/blob/master/libml/ctaugment.py
5 | import logging
6 | import random
7 |
8 | import numpy as np
9 | import PIL
10 | import PIL.ImageOps
11 | import PIL.ImageEnhance
12 | import PIL.ImageDraw
13 | from PIL import Image
14 |
15 | logger = logging.getLogger(__name__)
16 |
17 | PARAMETER_MAX = 10
18 |
19 |
def AutoContrast(img, **kwarg):
    # Magnitude-free op: maximize image contrast (extra kwargs ignored).
    return PIL.ImageOps.autocontrast(img)
22 |
23 |
def Brightness(img, v, max_v, bias=0):
    # Map magnitude v in [0, PARAMETER_MAX] to [bias, max_v + bias], then brighten.
    v = _float_parameter(v, max_v) + bias
    return PIL.ImageEnhance.Brightness(img).enhance(v)
27 |
28 |
def Color(img, v, max_v, bias=0):
    # Map magnitude v in [0, PARAMETER_MAX] to [bias, max_v + bias], then adjust saturation.
    v = _float_parameter(v, max_v) + bias
    return PIL.ImageEnhance.Color(img).enhance(v)
32 |
33 |
def Contrast(img, v, max_v, bias=0):
    # Map magnitude v in [0, PARAMETER_MAX] to [bias, max_v + bias], then adjust contrast.
    v = _float_parameter(v, max_v) + bias
    return PIL.ImageEnhance.Contrast(img).enhance(v)
37 |
38 |
def Cutout(img, v, max_v, bias=0):
    # Cut out a gray square whose side is a fraction v of the shorter image side.
    if v == 0:
        return img
    v = _float_parameter(v, max_v) + bias
    v = int(v * min(img.size))
    return CutoutAbs(img, v)
45 |
46 |
def CutoutAbs(img, v, **kwarg):
    # Paint a gray v-by-v square at a uniformly random position (clipped
    # to the image bounds); operates on a copy of the input.
    w, h = img.size
    x0 = np.random.uniform(0, w)
    y0 = np.random.uniform(0, h)
    x0 = int(max(0, x0 - v / 2.))
    y0 = int(max(0, y0 - v / 2.))
    x1 = int(min(w, x0 + v))
    y1 = int(min(h, y0 + v))
    xy = (x0, y0, x1, y1)
    # gray
    color = (127, 127, 127)
    img = img.copy()
    PIL.ImageDraw.Draw(img).rectangle(xy, color)
    return img
61 |
62 |
def Equalize(img, **kwarg):
    # Magnitude-free op: histogram-equalize the image (extra kwargs ignored).
    return PIL.ImageOps.equalize(img)
65 |
66 |
def Identity(img, **kwarg):
    # No-op augmentation: return the image unchanged.
    return img
69 |
70 |
def Invert(img, **kwarg):
    # Magnitude-free op: invert all pixel values (photographic negative).
    return PIL.ImageOps.invert(img)
73 |
74 |
def Posterize(img, v, max_v, bias=0):
    # Reduce each channel to (scaled v + bias) bits.
    v = _int_parameter(v, max_v) + bias
    return PIL.ImageOps.posterize(img, v)
78 |
79 |
def Rotate(img, v, max_v, bias=0):
    # Rotate by up to max_v degrees; rotation direction chosen at random.
    v = _int_parameter(v, max_v) + bias
    if random.random() < 0.5:
        v = -v
    return img.rotate(v)
85 |
86 |
def Sharpness(img, v, max_v, bias=0):
    # Map magnitude v in [0, PARAMETER_MAX] to [bias, max_v + bias], then sharpen.
    v = _float_parameter(v, max_v) + bias
    return PIL.ImageEnhance.Sharpness(img).enhance(v)
90 |
91 |
def ShearX(img, v, max_v, bias=0):
    # Horizontal shear with randomly chosen sign, via an affine transform.
    v = _float_parameter(v, max_v) + bias
    if random.random() < 0.5:
        v = -v
    return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))
97 |
98 |
def ShearY(img, v, max_v, bias=0):
    # Vertical shear with randomly chosen sign, via an affine transform.
    v = _float_parameter(v, max_v) + bias
    if random.random() < 0.5:
        v = -v
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))
104 |
105 |
def Solarize(img, v, max_v, bias=0):
    # Invert all pixels above the threshold 256 - v.
    v = _int_parameter(v, max_v) + bias
    return PIL.ImageOps.solarize(img, 256 - v)
109 |
110 |
def SolarizeAdd(img, v, max_v, bias=0, threshold=128):
    """Shift all pixel values by a random-signed offset v (clamped to
    [0, 255]), then solarize at `threshold`."""
    v = _int_parameter(v, max_v) + bias
    if random.random() < 0.5:
        v = -v
    # Bug fix: `np.int` (alias of builtin int) was deprecated in NumPy 1.20
    # and removed in 1.24; use an explicit integer dtype instead.
    img_np = np.array(img).astype(np.int64)
    img_np = img_np + v
    img_np = np.clip(img_np, 0, 255)
    img_np = img_np.astype(np.uint8)
    img = Image.fromarray(img_np)
    return PIL.ImageOps.solarize(img, threshold)
121 |
122 |
def TranslateX(img, v, max_v, bias=0):
    # Shift horizontally by a random-signed fraction v of the image width.
    v = _float_parameter(v, max_v) + bias
    if random.random() < 0.5:
        v = -v
    v = int(v * img.size[0])
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
129 |
130 |
def TranslateY(img, v, max_v, bias=0):
    # Shift vertically by a random-signed fraction v of the image height.
    v = _float_parameter(v, max_v) + bias
    if random.random() < 0.5:
        v = -v
    v = int(v * img.size[1])
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
137 |
138 |
def _float_parameter(v, max_v):
    """Map magnitude v on [0, PARAMETER_MAX] to a float on [0, max_v]."""
    scaled = float(v) * max_v
    return scaled / PARAMETER_MAX
141 |
142 |
def _int_parameter(v, max_v):
    """Map magnitude v on [0, PARAMETER_MAX] to an int on [0, max_v] (truncating)."""
    scaled = v * max_v / PARAMETER_MAX
    return int(scaled)
145 |
146 |
def fixmatch_augment_pool():
    """(op, max_v, bias) triples from the FixMatch paper's RandAugment pool."""
    return [
        (AutoContrast, None, None),
        (Brightness, 0.9, 0.05),
        (Color, 0.9, 0.05),
        (Contrast, 0.9, 0.05),
        (Equalize, None, None),
        (Identity, None, None),
        (Posterize, 4, 4),
        (Rotate, 30, 0),
        (Sharpness, 0.9, 0.05),
        (ShearX, 0.3, 0),
        (ShearY, 0.3, 0),
        (Solarize, 256, 0),
        (TranslateX, 0.3, 0),
        (TranslateY, 0.3, 0),
    ]
164 |
165 |
def my_augment_pool():
    """Extended (op, max_v, bias) pool with stronger photometric ops, for testing."""
    return [
        (AutoContrast, None, None),
        (Brightness, 1.8, 0.1),
        (Color, 1.8, 0.1),
        (Contrast, 1.8, 0.1),
        (Cutout, 0.2, 0),
        (Equalize, None, None),
        (Invert, None, None),
        (Posterize, 4, 4),
        (Rotate, 30, 0),
        (Sharpness, 1.8, 0.1),
        (ShearX, 0.3, 0),
        (ShearY, 0.3, 0),
        (Solarize, 256, 0),
        (SolarizeAdd, 110, 0),
        (TranslateX, 0.45, 0),
        (TranslateY, 0.45, 0),
    ]
185 |
186 |
class RandAugmentPC(object):
    """RandAugment variant: fixed magnitude m, per-op random apply probability."""

    def __init__(self, n, m):
        # n ops per image; magnitude must stay within the pool's calibration range.
        assert n >= 1
        assert 1 <= m <= 10
        self.n = n
        self.m = m
        self.augment_pool = my_augment_pool()

    def __call__(self, img):
        chosen = random.choices(self.augment_pool, k=self.n)
        for op, max_v, bias in chosen:
            prob = np.random.uniform(0.2, 0.8)
            if random.random() + prob >= 1:
                img = op(img, v=self.m, max_v=max_v, bias=bias)
        # Always finish with a fixed-size cutout patch.
        return CutoutAbs(img, int(32 * 0.5))
203 |
204 |
class RandAugmentMC(object):
    """RandAugment (FixMatch flavor): random magnitude, 50% apply chance per op."""

    def __init__(self, n, m):
        # n ops per image; m is the upper bound on the sampled magnitude.
        assert n >= 1
        assert 1 <= m <= 10
        self.n = n
        self.m = m
        self.augment_pool = fixmatch_augment_pool()

    def __call__(self, img):
        for op, max_v, bias in random.choices(self.augment_pool, k=self.n):
            strength = np.random.randint(1, self.m)
            if random.random() < 0.5:
                img = op(img, v=strength, max_v=max_v, bias=bias)
        # Always finish with a fixed-size cutout patch.
        return CutoutAbs(img, int(32 * 0.5))
--------------------------------------------------------------------------------
/dataset/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | from torchvision.transforms import transforms
5 |
6 | from dataset.all_dataset import ChestXray14, ICH
7 | from utils.FixMatch import RandAugmentMC
8 | from utils.sampling import iid_sampling, non_iid_dirichlet_sampling
9 |
10 |
def _weak_transform(normalize):
    # "Weak" training augmentation: light geometric jitter + flip.
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomAffine(degrees=10, translate=(0.02, 0.02)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])


def _strong_transform(normalize):
    # "Strong" FixMatch-style augmentation: weak pipeline + RandAugment.
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomAffine(degrees=10, translate=(0.02, 0.02)),
        transforms.RandomHorizontalFlip(),
        RandAugmentMC(n=2, m=10),
        transforms.ToTensor(),
        normalize,
    ])


def _test_transform(normalize):
    # Deterministic evaluation pipeline: resize + normalize only.
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])


def get_dataset(args):
    """Build train/test datasets for ``args.dataset`` and the per-client index split.

    Side effects: mutates ``args`` (n_classes, n_clients, num_users, input_channel).
    Returns (train_dataset, test_dataset, dict_users) where dict_users maps
    client id -> indices into the training set, cached on disk per configuration.

    Refactor notes: the six copy-pasted transforms.Compose pipelines are
    deduplicated into the helpers above, and an unrecognized ``args.exp`` now
    fails with an explicit message instead of a NameError on ``train_dataset``.
    """
    if args.dataset == "ChestXray14":
        root = "/home/szb/multilabel/onehot-label-PA.csv"
        args.n_classes = 8
        args.n_clients = 8
        dataset_cls = ChestXray14
    elif args.dataset == "ICH":
        root = "/home/szb/ICH_stage2/ICH_stage2/data_png185k_512.csv"
        args.n_classes = 5
        args.n_clients = 5
        dataset_cls = ICH
    else:
        exit("Error: unrecognized dataset")
    args.num_users = args.n_clients
    args.input_channel = 3

    # ImageNet channel statistics (the dataset-specific stats are deliberately unused).
    normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])

    if args.exp == 'FedAVG' or args.exp == 'RoFL' or args.exp == 'FedNoRo' or args.exp == 'CBAFed':
        train_transform = _weak_transform(normalize)
    elif args.exp == 'RSCFed' or args.exp == 'FedPN' or args.exp == 'FedLSR' or args.exp == 'FedIRM':
        # Two independently-sampled weak views per image (consistency-style methods).
        train_transform = (_weak_transform(normalize), _weak_transform(normalize))
    elif args.exp == 'FedAVG+FixMatch':
        # FixMatch pairs a weak view (pseudo-label) with a strong view (student input).
        train_transform = (_weak_transform(normalize), _strong_transform(normalize))
    else:
        exit("Error: unrecognized exp")
    train_dataset = dataset_cls(root, "train", train_transform)
    test_dataset = dataset_cls(root, "test", _test_transform(normalize))

    n_train = len(train_dataset)
    y_train = np.array(train_dataset.targets)
    assert n_train == len(y_train)
    print(n_train)

    # Load a cached client split if one exists for this configuration;
    # otherwise sample a fresh split and cache it.
    if args.iid == 0:  # non-iid
        path = f"non-iid-dictusers/{str(args.dataset)+'_'+str(args.seed)+'_'+str(args.n_clients)+'_'+str(args.alpha_dirichlet)}.npy"
        if os.path.exists(path):
            dict_users = np.load(path, allow_pickle=True).item()
        else:
            dict_users = non_iid_dirichlet_sampling(y_train, args.n_classes, 1.0, args.n_clients, seed=args.seed, alpha_dirichlet=args.alpha_dirichlet)
            np.save(path, dict_users, allow_pickle=True)
    else:
        # NOTE: '5000' concatenates onto n_clients (e.g. 5 -> "...55000.npy"),
        # matching the filenames shipped under iid-dictusers/.
        path = f"iid-dictusers/{str(args.dataset)+'_'+str(args.seed)+'_'+str(args.n_clients) + '5000'}.npy"
        if os.path.exists(path):
            dict_users = np.load(path, allow_pickle=True).item()
        else:
            dict_users = iid_sampling(n_train, args.n_clients, args.seed)
            np.save(path, dict_users, allow_pickle=True)
    return train_dataset, test_dataset, dict_users
182 |
183 |
184 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import random
4 | from collections import Counter
5 | from copy import deepcopy
6 |
7 | import numpy as np
8 | import torch
9 | from numpy import where
10 | from sklearn.mixture import GaussianMixture
11 | from torch import nn
12 | from torch.backends import cudnn
13 | from torch.utils.data import DataLoader
14 | from tqdm import tqdm
15 |
16 | from dataset.dataset import get_dataset
17 | from model.build_model import build_model
18 | from utils.FedAvg import FedAvg, RSCFed, FedAvg_tao, FedAvg_proto, FedAvg_rela
19 | from utils.FedNoRo import get_output, get_current_consistency_weight, DaAgg
20 | from utils.evaluations import globaltest, classtest
21 | from utils.feature_visual import tnse_Visual
22 | from utils.local_training import LocalUpdate
23 | from utils.options import args_parser
24 | from utils.utils import set_seed, set_output_files
25 | from utils.valloss_cal import valloss
26 |
# Print full arrays in logs (class-count vectors etc.) instead of truncated ones.
np.set_printoptions(threshold=np.inf)

# Entry point: federated multi-label training driver. Selects one of several
# FL algorithms via args.exp, runs a warm-up training stage round-by-round,
# aggregates client weights, and periodically evaluates the global model.
if __name__ == '__main__':
    args = args_parser()
    args.num_users = args.n_clients
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    args.device = "cuda" if torch.cuda.is_available() else "cpu"

    # ------------------------------ deterministic or not ------------------------------
    if args.deterministic:
        cudnn.benchmark = False
        cudnn.deterministic = True
        set_seed(args.seed)

    # ------------------------------ output files ------------------------------
    # writer1: TensorBoard writer; models_dir: checkpoint directory.
    writer1, models_dir = set_output_files(args)

    # ------------------------------ dataset ------------------------------
    dataset_train, dataset_test, dict_users = get_dataset(args)
    # for key in dict_users.keys():
    #     print(len(dict_users[key]))
    #     dict_users[key] = random.sample(dict_users[key], 2500)
    # np.save(f"iid-dictusers/{str(args.dataset) + '_' + str(args.seed) + '_' + str(args.n_clients) + '5000'}.npy",
    #         dict_users, allow_pickle=True)

    logging.info(
        f"train: {np.sum(dataset_train.targets, axis=0)}, total: {len(dataset_train.targets)}")
    logging.info(
        f"test: {np.sum(dataset_test.targets, axis=0)}, total: {len(dataset_test.targets)}")

    # class_pos_idx_1[c]: training-sample indices that are positive for class c.
    row_idx_1, column_idx_1 = where(dataset_train.targets == 1)
    class_pos_idx_1 = []
    for i in range(args.n_classes):
        class_pos_idx_1.append(row_idx_1[where(column_idx_1 == i)[0]])

    # p_pos_1 is the kept-positive ratio; with 0. every positive of each class is
    # sampled into class_neg_idx_1 (candidates whose labels will be hidden).
    p_pos_1 = 0.
    class_neg_idx_1 = []
    for i in range(args.n_classes):
        class_neg_idx_1.append(np.random.choice(class_pos_idx_1[i], int((1-p_pos_1)*len(class_pos_idx_1[i])), replace=False)) # list(array, ...)
    # --------------------- Partially Labelling ---------------------------
    if args.train:
        # ------------------------------ local settings ------------------------------
        user_id = list(range(args.n_clients))
        dict_len = [len(dict_users[idx]) for idx in user_id]
        trainer_locals_1 = []
        netglob = build_model(args)
        # One LocalUpdate per client; client i is "active" (fully labelled) only
        # for class i (active_class_list=[i]).
        for i in user_id:
            trainer_locals_1.append(LocalUpdate(
                args, i, deepcopy(dataset_train), dict_users[i], class_pos_idx_1, class_neg_idx_1, dataset_test=dataset_test, active_class_list=[i], student=deepcopy(netglob).to(args.device),
                teacher_neg=deepcopy(netglob).to(args.device), teacher_act=deepcopy(netglob).to(args.device))) # student initial is global
            # trainer_locals_1.append(LocalUpdate(
            #     args, i, deepcopy(dataset_train), dict_users[i], class_pos_idx_1, class_neg_idx_1, dataset_test=dataset_test, student=deepcopy(netglob).to(args.device),
            #     teacher_neg=deepcopy(netglob).to(args.device),
            #     teacher_act=deepcopy(netglob).to(args.device))) # student initial is global

        # ------------------------------ begin training ------------------------------
        for run in range(args.runs):
            set_seed(int(run))
            netclass = build_model(args)
            logging.info(f"\n===============================> beging, run: {run} <===============================\n")
            w_class_fl = []
            active_class_list = [] # [, , , , ]
            negetive_class_list = [] # [, , , , ]
            class_active_client_list = []
            class_negative_client_list = []

            # FeMLP
            # tao: per-class pseudo-label thresholds; Prototype: class prototypes.
            tao = [0] * args.n_classes
            Prototype = []
            # RoFL: Initialize f_G
            f_G = torch.randn(2*args.n_classes, args.feature_dim, device=args.device) # [[cls0_0],[cls0_1],[cls1_0]...]
            forget_rate_schedule = []
            forget_rate = args.forget_rate
            exponent = 1
            # Linearly ramp the forget rate over the first num_gradual rounds.
            forget_rate_schedule = np.ones(args.rounds_warmup) * forget_rate
            forget_rate_schedule[:args.num_gradual] = np.linspace(0, forget_rate ** exponent, args.num_gradual)
            # ------------------------------ stage1:warm-up ------------------------------
            for rnd in range(0, args.rounds_warmup):
                # if rnd == 49:
                #     netglob.load_state_dict(torch.load(
                #         "/home/szb/multilabel/chest 5 miss model/model_warmup_globdistill_0.3_0.1_49.pth"))
                #     negetive_class_list = [[1,2,3,4], [0,2,3,4], [0,1,3,4], [0,1,2,4], [0,1,2,3]]
                #     active_class_list = [[0], [1], [2], [3], [4]]
                #     class_active_client_list = [[0, 1, 2, 3, 4], [0, 1, 2], [0, 2, 3, 4], [0, 1, 3, 4], [1, 2, 3, 4]]
                #     class_negative_client_list = [[], [3, 4], [1], [2], [0]]
                if args.exp == 'RSCFed':
                    # Draw M sub-consensus groups of K clients each round.
                    M = 10
                    K = 6
                    DMA = []
                    for i in range(M):
                        random_numbers = random.sample(range(args.n_clients), K)
                        DMA.append(random_numbers)
                    print('DMA: ', DMA)
                # if args.exp == 'RoFL':
                #     if len(forget_rate_schedule) > 0:
                #         args.forget_rate = forget_rate_schedule[rnd]
                #     logging.info('remember_rate: %f' % (1-args.forget_rate))
                # if args.exp == 'FedNoRo' and rnd >= args.rounds_FedNoRo_warmup:
                if args.exp == 'FedNoRo':
                    # Distillation weight ramps with the round index.
                    weight_kd = get_current_consistency_weight(rnd, args.begin, args.end) * args.a
                logging.info("\n------------------------------> training, run: %d, round: %d <------------------------------" % (run, rnd))
                w_locals, loss_locals = [], []
                class_num_lists, data_nums = [], []
                taos, Prototypes = [], []
                # RoFL
                f_locals = []
                # Every client trains locally with its algorithm-specific routine.
                for i in tqdm(user_id): # training over the subset
                    local = trainer_locals_1[i]
                    if args.exp == 'FedAVG':
                        w_local, loss_local, loss_false_negetive, loss_true_negetive, negetive_class_list_client, active_class_list_client = local.train(rnd,
                            net=deepcopy(netglob).to(args.device), writer1=writer1)
                    if args.exp == 'FedNoRo':
                        if rnd < args.rounds_FedNoRo_warmup:
                            w_local, loss_local, loss_false_negetive, loss_true_negetive, negetive_class_list_client, active_class_list_client = local.train_FedNoRo(
                                i, rnd,
                                net=deepcopy(netglob).to(args.device), writer1=writer1, weight_kd=weight_kd)
                        # else:
                        #     w_local, loss_local, loss_false_negetive, loss_true_negetive, negetive_class_list_client, active_class_list_client = local.train_FedNoRo(
                        #         i, rnd,
                        #         net=deepcopy(netglob).to(args.device), writer1=writer1, weight_kd = weight_kd, clean_clients=clean_clients, noisy_clients = noisy_clients)
                    if args.exp == 'CBAFed':
                        # NOTE(review): past warm-up this passes pt/tao computed during
                        # aggregation of an earlier round; assumes rounds_CBAFed_warmup >= 1.
                        if rnd < args.rounds_CBAFed_warmup:
                            w_local, loss_local, loss_false_negetive, loss_true_negetive, negetive_class_list_client, active_class_list_client, class_num_list, data_num = local.train_CBAFed(
                                rnd, net=deepcopy(netglob).to(args.device))
                        else:
                            w_local, loss_local, loss_false_negetive, loss_true_negetive, negetive_class_list_client, active_class_list_client, class_num_list, data_num = local.train_CBAFed(
                                rnd, net=deepcopy(netglob).to(args.device), pt=pt, tao=tao)
                        class_num_lists.append(deepcopy(class_num_list))
                        data_nums.append(deepcopy(data_num))
                    if args.exp == 'FedAVG+FixMatch':
                        w_local, loss_local, loss_false_negetive, loss_true_negetive, negetive_class_list_client, active_class_list_client = local.train_FixMatch(rnd,
                            net=deepcopy(netglob).to(args.device))
                    if args.exp == 'FedLSR':
                        w_local, loss_local, loss_false_negetive, loss_true_negetive, negetive_class_list_client, active_class_list_client = local.train_FedLSR(rnd,
                            net=deepcopy(netglob).to(args.device))
                    if args.exp == 'RSCFed':
                        w_local, loss_local, loss_false_negetive, loss_true_negetive, negetive_class_list_client, active_class_list_client = local.train_RSCFed(rnd,
                            net=deepcopy(netglob).to(args.device))
                    # if args.exp == 'RoFL':
                    #     w_local, loss_local, f_k = local.train_RoFL(deepcopy(netglob).to(args.device), deepcopy(f_G).to(args.device), rnd)
                    if args.exp == 'FedIRM':
                        if rnd < args.rounds_FedIRM_sup - 1:
                            w_local, loss_local, loss_false_negetive, loss_true_negetive, negetive_class_list_client, active_class_list_client = local.train_FedIRM(
                                rnd, Prototype, writer1, negetive_class_list=None, active_class_list_client_i=None,
                                net=deepcopy(netglob).to(args.device))
                        else:
                            w_local, loss_local, loss_false_negetive, loss_true_negetive, negetive_class_list_client, active_class_list_client, Prototype_local = local.train_FedIRM(
                                rnd, Prototype, writer1, negetive_class_list[i], active_class_list[i],
                                net=deepcopy(netglob).to(args.device))
                    # NOTE(review): this tag is 'FeMLP' but the aggregation branch below
                    # checks 'FedMLP' — confirm which spelling args_parser emits; with a
                    # mismatch one of the two paths never runs.
                    if args.exp == 'FeMLP':
                        if rnd < args.rounds_FeMLP_stage1-1:
                            w_local, loss_local, loss_false_negetive, loss_true_negetive, negetive_class_list_client, active_class_list_client = local.train_FedMLP(
                                rnd, tao, Prototype, writer1, negetive_class_list=None, active_class_list_client_i=None, net=deepcopy(netglob).to(args.device))
                        else:
                            w_local, loss_local, loss_false_negetive, loss_true_negetive, negetive_class_list_client, active_class_list_client, tao_local, Prototype_local = local.train_FeMLP(
                                rnd, tao, Prototype, writer1, negetive_class_list[i], active_class_list[i], net=deepcopy(netglob).to(args.device))
                    # Record each client's active/missing class lists once, at round 0.
                    if rnd == 0 and args.exp != 'RoFL':
                        active_class_list.append(active_class_list_client)
                        negetive_class_list.append(negetive_class_list_client)
                    # store every updated model
                    if args.exp == 'FedMLP' and rnd >= args.rounds_FedMLP_stage1-1:
                        taos.append(deepcopy(tao_local))
                        Prototypes.append(deepcopy(Prototype_local))
                    if args.exp == 'FedIRM' and rnd >= args.rounds_FedIRM_sup-1:
                        Prototypes.append(deepcopy(Prototype_local.detach().cpu()))
                    if args.exp == 'RoFL':
                        # NOTE(review): f_k is only produced by the commented-out
                        # train_RoFL call above, so this would raise NameError as-is.
                        f_locals.append(f_k)
                    w_locals.append(deepcopy(w_local))
                    loss_locals.append(deepcopy(loss_local))
                    writer1.add_scalar(f'train_run{run}/warm-up-loss/client{i}', loss_local, rnd)
                # aggregation
                # Invert the per-client class lists into per-class client lists.
                if rnd == 0 and args.exp != 'RoFL':
                    for i in range(len(active_class_list)):
                        class_active_client_list.append([])
                        class_negative_client_list.append([])
                        for j in range(len(active_class_list)):
                            if i in active_class_list[j]:
                                class_active_client_list[i].append(j)
                            if i in negetive_class_list[j]:
                                class_negative_client_list[i].append(j)
                    logging.info(class_active_client_list)
                    logging.info(class_negative_client_list)
                assert i == user_id[-1]
                assert len(w_locals) == len(dict_len) == args.n_clients
                if args.exp == 'RSCFed':
                    w_glob_fl = RSCFed(DMA, w_locals, K, dict_len, M)
                    netglob.load_state_dict(deepcopy(w_glob_fl))
                elif args.exp == 'FedMLP':
                    if rnd < args.rounds_FedMLP_stage1 - 1:
                        w_glob_fl = FedAvg(w_locals, dict_len)
                        netglob.load_state_dict(deepcopy(w_glob_fl))
                    else:
                        # Stage 2: also aggregate thresholds and (EMA-style) prototypes.
                        w_glob_fl = FedAvg(w_locals, dict_len)
                        netglob.load_state_dict(deepcopy(w_glob_fl))
                        tao = FedAvg_tao(taos, dict_len, class_negative_client_list)
                        print('avg_tao: ', tao)
                        # if args.miss_client_difficulty == 1:
                        #     tao = FedAvg_tao(taos, dict_len)
                        # else:
                        #     tao = FedAvg_tao(taos, dict_len, class_active_client_list)
                        # print('avg_tao: ', tao)
                        if rnd == args.rounds_FedMLP_stage1 - 1:
                            Prototype = FedAvg_proto(Prototypes, dict_len, class_active_client_list)
                        else:
                            # lam = 1.0 means the new aggregate fully replaces the old.
                            lam = 1.0
                            Prototype = (1-lam)*Prototype + lam*FedAvg_proto(Prototypes, dict_len, class_active_client_list)
                        print('rnd: ', rnd, 'ok')
                    if (rnd + 1) % 10 == 0 and run == 0:
                        torch.save(netglob.state_dict(), models_dir + f'/model_warmup_globdistill_0.3_0.1_{rnd}.pth')
                elif args.exp == 'FedIRM':
                    if rnd < args.rounds_FedIRM_sup - 1:
                        w_glob_fl = FedAvg(w_locals, dict_len)
                        netglob.load_state_dict(deepcopy(w_glob_fl))
                    else:
                        w_glob_fl = FedAvg(w_locals, dict_len)
                        netglob.load_state_dict(deepcopy(w_glob_fl))
                        if rnd == args.rounds_FedIRM_sup - 1:
                            print(Prototypes)
                            Prototype = FedAvg_rela(Prototypes, dict_len, class_active_client_list)
                            print(Prototype)
                        else:
                            lam = 1.0
                            Prototype = (1-lam)*Prototype + lam*FedAvg_rela(Prototypes, dict_len, class_active_client_list)
                        print('rnd: ', rnd, 'ok')
                # elif args.exp == 'RoFL':
                #     w_glob_fl = FedAvg(w_locals, dict_len)
                #     netglob.load_state_dict(deepcopy(w_glob_fl))
                #     sim = torch.nn.CosineSimilarity(dim=1)
                #     tmp = 0
                #     w_sum = 0
                #     for i in f_locals:
                #         sim_weight = sim(f_G, i).reshape(2*args.n_classes, 1)
                #         w_sum += sim_weight
                #         tmp += sim_weight * i
                #         # print(sim_weight)
                #         # print(i)
                #     for i in range(len(w_sum)):
                #         if w_sum[i, 0] == 0:
                #             w_sum[i, 0] = 1
                #     f_G = torch.div(tmp, w_sum)
                elif args.exp == 'FedNoRo':
                    if rnd < args.rounds_FedNoRo_warmup:
                        w_glob_fl = FedAvg(w_locals, dict_len)
                        netglob.load_state_dict(deepcopy(w_glob_fl))
                elif args.exp == 'CBAFed':
                    # During warm-up: plain FedAvg, with a residual (0.2/0.8) blend
                    # against the last anchor every 5th round.
                    if rnd < args.rounds_CBAFed_warmup:
                        if rnd % 5 != 0:
                            w_glob_fl = FedAvg(w_locals, dict_len)
                            netglob.load_state_dict(deepcopy(w_glob_fl))
                        else:
                            if rnd == 0:
                                w_glob_fl = FedAvg(w_locals, dict_len)
                                netglob.load_state_dict(deepcopy(w_glob_fl))
                                w_glob_res = deepcopy(w_glob_fl)
                            else:
                                w_glob_fl = FedAvg(w_locals, dict_len)
                                for k in w_glob_fl.keys():
                                    w_glob_fl[k] = 0.2*w_glob_fl[k] + 0.8*w_glob_res[k]
                                netglob.load_state_dict(deepcopy(w_glob_fl))
                                w_glob_res = deepcopy(w_glob_fl)
                        # NOTE(review): pt/tao (class priors and clamped per-class
                        # thresholds) are computed only at the last warm-up round here —
                        # confirm against the CBAFed reference implementation.
                        if rnd >= args.rounds_CBAFed_warmup - 1:
                            c_num = torch.zeros(args.n_classes)
                            d_num = 0
                            for s in user_id:
                                c_num += class_num_lists[s]
                                d_num += data_nums[s]
                            pt = c_num / d_num
                            avg_pt = pt.sum() / len(pt)
                            std_pt = torch.sqrt((1/(len(pt)-1))*(((pt-avg_pt)**2).sum()))
                            tao = pt + 0.45 - std_pt
                            tao = torch.where(tao > 0.95, 0.95, tao)
                            tao = torch.where(tao < 0.55, 0.55, tao)
                    # After warm-up: weight clients by their (pseudo-)labelled data
                    # counts, with a 0.5/0.5 residual blend every 5th round.
                    if rnd >= args.rounds_CBAFed_warmup:
                        wti = (torch.tensor(data_nums) / torch.tensor(data_nums).sum()).tolist()
                        if (rnd-args.rounds_CBAFed_warmup) % 5 != 0:
                            w_glob_fl = FedAvg(w_locals, wti)
                            netglob.load_state_dict(deepcopy(w_glob_fl))
                        else:
                            if (rnd-args.rounds_CBAFed_warmup) == 0:
                                w_glob_fl = FedAvg(w_locals, wti)
                                netglob.load_state_dict(deepcopy(w_glob_fl))
                                w_glob_res = deepcopy(w_glob_fl)
                            else:
                                w_glob_fl = FedAvg(w_locals, wti)
                                for k in w_glob_fl.keys():
                                    w_glob_fl[k] = 0.5*w_glob_fl[k] + 0.5*w_glob_res[k]
                                netglob.load_state_dict(deepcopy(w_glob_fl))
                                w_glob_res = deepcopy(w_glob_fl)
                else:
                    # Default aggregation: plain FedAvg weighted by client data size.
                    w_glob_fl = FedAvg(w_locals, dict_len)
                    netglob.load_state_dict(deepcopy(w_glob_fl))

                # validate
                if rnd % 10 == 9:
                    logging.info(
                        "\n------------------------------> testing, run: %d, round: %d <------------------------------" % (
                            run, rnd))
                    result = globaltest(deepcopy(netglob).to(args.device), test_dataset=dataset_test, args=args)
                    mAP, BACC, R, F1, auroc, P, hamming_loss = result["mAP"], result["BACC"], result["R"], result[
                        "F1"], result["auc"], result["P"], result["hamming_loss"]
                    logging.info(
                        "-----> mAP: %.2f, BACC: %.2f, R: %.2f, F1: %.2f, auc: %.2f, P: %.2f, hamming_loss: %.2f" % (
                            mAP, BACC * 100, R * 100, F1 * 100, auroc * 100, P * 100, hamming_loss))
                    writer1.add_scalar(f'test_run{run}/mAP', mAP, rnd)
                    writer1.add_scalar(f'test_run{run}/BACC', BACC, rnd)
                    writer1.add_scalar(f'test_run{run}/R', R, rnd)
                    writer1.add_scalar(f'test_run{run}/F1', F1, rnd)
                    writer1.add_scalar(f'test_run{run}/auc', auroc, rnd)
                    writer1.add_scalar(f'test_run{run}/P', P, rnd)
                    writer1.add_scalar(f'test_run{run}/hamming_loss', hamming_loss, rnd)
                    logging.info('\n')
                if rnd == 49:
                    torch.save(netglob.state_dict(), models_dir + f'/model_{rnd}.pth')

                # Second evaluation pass, logged under the 'corr-test' tags.
                if rnd % 10 == 9:
                    logging.info('test')
                    logging.info("\n------------------------------> testing, run: %d, round: %d <------------------------------" % (run, rnd))
                    result = globaltest(deepcopy(netglob).to(args.device), test_dataset=dataset_test, args=args)
                    mAP, BACC, R, F1, auroc, P, hamming_loss = result["mAP"], result["BACC"], result["R"], result["F1"], result["auc"], result["P"], result["hamming_loss"]
                    logging.info("-----> mAP: %.2f, BACC: %.2f, R: %.2f, F1: %.2f, auc: %.2f, P: %.2f, hamming_loss: %.2f" % (mAP, BACC*100, R*100, F1*100, auroc*100, P*100, hamming_loss))
                    logging.info(np.array(loss_locals))
                    writer1.add_scalar(f'corr-test_run{run}/mAP', mAP, rnd)
                    writer1.add_scalar(f'corr-test_run{run}/BACC', BACC, rnd)
                    writer1.add_scalar(f'corr-test_run{run}/R', R, rnd)
                    writer1.add_scalar(f'corr-test_run{run}/F1', F1, rnd)
                    writer1.add_scalar(f'corr-test_run{run}/auc', auroc, rnd)
                    writer1.add_scalar(f'corr-test_run{run}/P', P, rnd)
                    writer1.add_scalar(f'corr-test_run{run}/hamming_loss', hamming_loss, rnd)
                    logging.info('\n')

                # save model
                if (rnd + 1) == args.rounds_corr:
                    torch.save(netglob.state_dict(), models_dir + f'/corr_model_{run}_{rnd}.pth')
                if rnd < 50 and (rnd + 1) % 5 == 0 and run == 0:
                    torch.save(netglob.state_dict(), models_dir + f'/corr_model_{run}_{rnd}.pth')

    else: # test
        # Evaluation-only path: load a fixed checkpoint and report per-class metrics.
        netglob = build_model(args)
        netglob.load_state_dict(torch.load("/home/szb/multilabel/model_warmup/model_warmup_49.pth"))
        result = classtest(deepcopy(netglob).to(args.device), test_dataset=dataset_test, args=args, classid=1)
        BACC, R, F1, P = result["BACC"], result["R"], result["F1"], result["P"]
        logging.info(
            "-----> BACC: %.2f, R: %.2f, F1: %.2f, P: %.2f" % (BACC * 100, R * 100, F1 * 100, P * 100))
        logging.info('\n')
        result = classtest(deepcopy(netglob).to(args.device), test_dataset=dataset_test, args=args, classid=4)
        BACC, R, F1, P = result["BACC"], result["R"], result["F1"], result["P"]
        logging.info(
            "-----> BACC: %.2f, R: %.2f, F1: %.2f, P: %.2f" % (BACC * 100, R * 100, F1 * 100, P * 100))
        logging.info('\n')

    torch.cuda.empty_cache()
--------------------------------------------------------------------------------
/utils/local_training.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import logging
3 | import math
4 | import random
5 | import time
6 | from copy import deepcopy
7 |
8 | import numpy as np
9 | import pandas as pd
10 | import torch
11 | from matplotlib import pyplot as plt
12 | from sklearn.mixture import GaussianMixture
13 | from torch import nn
14 | from torch.cuda.amp import autocast
15 | from torch.utils.data import DataLoader, Dataset
16 | import torch.nn.functional as F
17 | from itertools import chain
18 | import seaborn as sns
19 |
20 | from utils.evaluations import globaltest, classtest
21 | from utils.feature_visual import tnse_Visual
22 | from utils.FedNoRo import LogitAdjust_Multilabel, LA_KD
23 | from utils.utils import max_m_indices, min_n_indices
24 |
25 |
26 | class LocalUpdate(object):
    def __init__(self, args, client_id, dataset, idxs, class_pos_idx, class_neg_idx, active_class_list=None, student=None, teacher_neg=None, teacher_act=None, dataset_test=None):
        """Per-client trainer state: local data split, loss weights and models.

        args: parsed experiment options; client_id: this client's index;
        dataset: full (copied) training set; idxs: indices owned by this client;
        class_pos_idx / class_neg_idx: per-class positive indices and the subset
        whose labels are hidden; active_class_list: classes fully labelled here;
        student/teacher_neg/teacher_act: models initialised from the global net.
        """
        self.teacher_neg = teacher_neg
        self.teacher_act = teacher_act
        self.dataset_test = dataset_test
        # The "negative" teacher doubles as the EMA model.
        self.ema_model = teacher_neg
        self.student = student
        self.args = args
        self.client_id = client_id
        self.idxs = idxs
        self.dataset = dataset
        self.active_class_list = active_class_list
        # DatasetSplit is defined elsewhere in this module; it applies the
        # partial-label masking for this client's subset.
        self.local_dataset = DatasetSplit(dataset, idxs, client_id, args, class_neg_idx, active_class_list)
        self.class_num_list = self.local_dataset.get_num_of_each_class(args)
        # Inverse-frequency positive weights per class.
        # NOTE(review): divides by the raw per-class count — a class with zero
        # positives on this client would raise ZeroDivisionError; confirm counts
        # are always >= 1.
        self.loss_w = [len(self.local_dataset) / i for i in self.class_num_list]
        # Weight only the client's own (known) class; all others weight 1.
        self.loss_w_unknown = [1] * len(self.class_num_list)
        self.loss_w_unknown[client_id] = len(self.local_dataset) / self.class_num_list[client_id]
        self.loss_balanced = [1] * len(self.class_num_list)
        logging.info(self.loss_w)
        logging.info(
            f"---> Client{client_id}, each class num: {self.class_num_list}, total num: {len(self.local_dataset)}")
        self.ldr_train = DataLoader(
            self.local_dataset, batch_size=self.args.batch_size, shuffle=True, num_workers=4)
        self.epoch = 0
        self.iter_num = 0
        self.lr = self.args.base_lr
        self.class_pos_idx = class_pos_idx
        self.class_neg_idx = class_neg_idx
        self.flag = True
        # NOTE(review): hard-coded 8x8 matches ChestXray14's class count but not
        # ICH (5 classes); also assumes a CUDA device is present.
        self.confuse_matrix = torch.zeros((8, 8)).cuda()
56 |
57 | def find_rows(self, tensor, up, down):
58 | condition = torch.all(torch.logical_or(tensor > up, tensor < down), axis=1)
59 | row_indices = torch.where(condition)[0]
60 | return row_indices
61 |
62 | def update_ema_variables(self, model, ema_model, alpha, global_step):
63 | alpha = min(1 - 1 / (global_step + 1), alpha)
64 | for ema_param, param in zip(ema_model.parameters(), model.parameters()):
65 | ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)
66 |
67 | def torch_tile(self, tensor, dim, n):
68 | if dim == 0:
69 | return tensor.unsqueeze(0).transpose(0, 1).repeat(1, n, 1).view(-1, tensor.shape[1])
70 | else:
71 | return tensor.unsqueeze(0).transpose(0, 1).repeat(1, 1, n).view(tensor.shape[0], -1)
72 |
73 | def get_confuse_matrix(self, logits, labels):
74 | source_prob = []
75 | for i in range(8):
76 | mask = self.torch_tile(torch.unsqueeze(labels[:, i], -1), 1, 8)
77 | logits_mask_out = logits * mask
78 | logits_avg = torch.sum(logits_mask_out, dim=0) / (torch.sum(labels[:, i]) + 1e-8)
79 | prob = torch.sigmoid(logits_avg / 2.0)
80 | source_prob.append(prob)
81 | return torch.stack(source_prob)
82 |
83 | def sigmoid_rampup(self, current, rampup_length):
84 | """Exponential rampup from https://arxiv.org/abs/1610.02242"""
85 | if rampup_length == 0:
86 | return 1.0
87 | else:
88 | current = np.clip(current, 0.0, rampup_length)
89 | phase = 1.0 - current / rampup_length
90 | return float(np.exp(-5.0 * phase * phase))
91 | def get_current_consistency_weight(self, epoch):
92 | return self.args.consistency * self.sigmoid_rampup(epoch, self.args.consistency_rampup)
93 |
94 | def sigmoid_mse_loss(self, input_logits, target_logits):
95 | """Takes softmax on both sides and returns MSE loss
96 |
97 | Note:
98 | - Returns the sum over all examples. Divide by the batch size afterwards
99 | if you want the mean.
100 | - Sends gradients to inputs but not the targets.
101 | """
102 | assert input_logits.size() == target_logits.size()
103 | input_softmax = torch.sigmoid(input_logits)
104 | target_softmax = torch.sigmoid(target_logits)
105 |
106 | mse_loss = (input_softmax - target_softmax) ** 2
107 | return mse_loss
108 |
109 | def kd_loss(self, source_matrix, target_matrix):
110 | Q = source_matrix
111 | P = target_matrix
112 | loss = (F.kl_div(Q.log(), P, None, None, 'batchmean') + F.kl_div(P.log(), Q, None, None, 'batchmean')) / 2.0
113 | return loss
114 |
115 | def train_FedNoRo(self, id, rnd, net, writer1, weight_kd = None, clean_clients=None, noisy_clients=None):
116 | assert len(self.ldr_train.dataset) == len(self.idxs)
117 | print(f"Client ID: {self.client_id}, Num: {len(self.ldr_train.dataset)}")
118 | if rnd < self.args.rounds_FedNoRo_warmup:
119 | student_net = copy.deepcopy(net).cuda()
120 | teacher_net = copy.deepcopy(net).cuda()
121 | student_net.train()
122 | teacher_net.eval()
123 | # set the optimizer
124 | self.optimizer = torch.optim.Adam(
125 | student_net.parameters(), lr=self.lr, betas=(0.9, 0.999), weight_decay=5e-4)
126 |
127 | # train and update
128 | epoch_loss = []
129 | for epoch in range(self.args.local_ep):
130 | batch_loss = []
131 | for k, (samples, item, active_class_list) in enumerate(self.ldr_train):
132 | if k == 0:
133 | active_class_list_client = []
134 | negetive_class_list_client = []
135 | for i in range(self.args.annotation_num):
136 | active_class_list_client.append(active_class_list[i][0].item())
137 | for i in range(self.args.n_classes):
138 | if i not in active_class_list_client:
139 | negetive_class_list_client.append(i)
140 | self.class_num_list[i] = 0
141 | criterion = LA_KD(cls_num_list=self.class_num_list, num=len(self.ldr_train.dataset), active_class_list_client=active_class_list_client, negative_class_list_client=negetive_class_list_client)
142 | images, labels = samples["image"].to(self.args.device), samples["target"].to(self.args.device)
143 |
144 | _, logits = student_net(images)
145 | # print(logits)
146 |
147 | with torch.no_grad():
148 | _, teacher_output = teacher_net(images)
149 | soft_label = torch.sigmoid(teacher_output / 0.8)
150 | # print(teacher_output)
151 | logits_sig = torch.sigmoid(logits).cuda()
152 | loss = criterion(logits_sig, labels, soft_label, weight_kd)
153 |
154 | self.optimizer.zero_grad()
155 | loss.backward()
156 | self.optimizer.step()
157 |
158 | batch_loss.append(loss.item())
159 | self.iter_num += 1
160 | self.epoch = self.epoch + 1
161 | epoch_loss.append(np.array(batch_loss).mean())
162 | else:
163 | if id in clean_clients:
164 | net.train()
165 | self.optimizer = torch.optim.Adam(
166 | net.parameters(), lr=self.lr, betas=(0.9, 0.999), weight_decay=5e-4)
167 | epoch_loss = []
168 | ce_criterion = LogitAdjust_Multilabel(cls_num_list=self.class_num_list, num=len(self.ldr_train.dataset))
169 | for epoch in range(self.args.local_ep):
170 | batch_loss = []
171 | for k, (samples, item, active_class_list) in enumerate(self.ldr_train):
172 | if k == 0:
173 | active_class_list_client = []
174 | negetive_class_list_client = []
175 | for i in range(self.args.annotation_num):
176 | active_class_list_client.append(active_class_list[i][0].item())
177 | for i in range(self.args.n_classes):
178 | if i not in active_class_list_client:
179 | negetive_class_list_client.append(i)
180 | images, labels = samples["image"].to(self.args.device), samples["target"].to(self.args.device)
181 | _, logits = net(images)
182 | logits_sig = torch.sigmoid(logits).cuda()
183 | loss = ce_criterion(logits_sig, labels)
184 | self.optimizer.zero_grad()
185 | loss.backward()
186 | self.optimizer.step()
187 | batch_loss.append(loss.item())
188 | self.iter_num += 1
189 | self.epoch = self.epoch + 1
190 | epoch_loss.append(np.array(batch_loss).mean())
191 | elif id in noisy_clients:
192 | student_net = copy.deepcopy(net).cuda()
193 | teacher_net = copy.deepcopy(net).cuda()
194 | student_net.train()
195 | teacher_net.eval()
196 | # set the optimizer
197 | self.optimizer = torch.optim.Adam(
198 | student_net.parameters(), lr=self.lr, betas=(0.9, 0.999), weight_decay=5e-4)
199 |
200 | # train and update
201 | epoch_loss = []
202 | criterion = LA_KD(cls_num_list=self.class_num_list, num=len(self.ldr_train.dataset))
203 |
204 | for epoch in range(self.args.local_ep):
205 | batch_loss = []
206 | for k, (samples, item, active_class_list) in enumerate(self.ldr_train):
207 | if k == 0:
208 | active_class_list_client = []
209 | negetive_class_list_client = []
210 | for i in range(self.args.annotation_num):
211 | active_class_list_client.append(active_class_list[i][0].item())
212 | for i in range(self.args.n_classes):
213 | if i not in active_class_list_client:
214 | negetive_class_list_client.append(i)
215 | images, labels = samples["image"].to(self.args.device), samples["target"].to(self.args.device)
216 |
217 | _, logits = student_net(images)
218 | with torch.no_grad():
219 | _, teacher_output = teacher_net(images)
220 | soft_label = torch.sigmoid(teacher_output / 0.8)
221 | logits_sig = torch.sigmoid(logits).cuda()
222 | loss = criterion(logits_sig, labels, soft_label, weight_kd)
223 |
224 | self.optimizer.zero_grad()
225 | loss.backward()
226 | self.optimizer.step()
227 |
228 | batch_loss.append(loss.item())
229 | self.iter_num += 1
230 | self.epoch = self.epoch + 1
231 | epoch_loss.append(np.array(batch_loss).mean())
232 | student_net.cpu()
233 | self.optimizer.zero_grad()
234 | return student_net.state_dict(), np.array(epoch_loss).mean(), _, _, negetive_class_list_client, active_class_list_client
235 |
    def train_CBAFed(self, rnd, net, pt=None, tao=None):
        """Local training for the CBAFed baseline.

        Stage 1 (rnd < rounds_CBAFed_warmup): supervised training on this
        client's annotated classes with class-weighted BCE; per-class sample
        counts are tallied for the caller.

        Stage 2: for each unannotated class, samples whose sigmoid probability
        exceeds ``tao[i]`` are pseudo-labeled positive, and samples below
        1 - tao[i] count as confident negatives; only these confident entries
        contribute to the unannotated-class loss. The BCE positive weight for
        each unannotated class is re-estimated per batch from the
        confident-negative/positive ratio.

        Args:
            rnd: current communication round.
            net: model to train in place.
            pt: unused here — presumably part of the caller's CBAFed interface; TODO confirm.
            tao: per-class confidence thresholds (required in stage 2).

        Returns:
            (state_dict, mean_epoch_loss, _, _, negative_class_list,
            active_class_list, class_num_list, data_num) — the two ``_``
            placeholders mirror the other train_* methods.
        """
        if rnd < self.args.rounds_CBAFed_warmup: # stage1
            class_num_list = torch.zeros(self.args.n_classes)
            data_num = 0
            net.train()
            # set the optimizer
            self.optimizer = torch.optim.Adam(
                net.parameters(), lr=self.lr, betas=(0.9, 0.999), weight_decay=5e-4)
            # train and update
            epoch_loss = []
            print(self.loss_w)
            bce_criterion_sup = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(self.loss_w).cuda(),
                                                     reduction='none')  # include sigmoid
            for epoch in range(self.args.local_ep):
                print('local_epoch:', epoch)
                batch_loss = []
                for j, (samples, item, active_class_list) in enumerate(self.ldr_train):
                    # annotated/unannotated class split is constant per client,
                    # so it is derived once from the first batch
                    if j == 0:
                        active_class_list_client = []
                        negetive_class_list_client = []
                        for i in range(self.args.annotation_num):
                            active_class_list_client.append(active_class_list[i][0].item())
                        for i in range(self.args.n_classes):
                            if i not in active_class_list_client:
                                negetive_class_list_client.append(i)
                    images, labels = samples["image"].to(self.args.device), samples["target"].to(self.args.device)
                    # class_num_list = class_num_list + torch.sum(labels, dim=0)
                    data_num = data_num + len(labels)
                    _, logits = net(images)
                    loss_sup = bce_criterion_sup(logits, labels) # tensor(32, 5)
                    # supervised loss restricted to the annotated classes only
                    loss_sup = loss_sup[:, active_class_list_client].sum() / (
                            self.args.batch_size * self.args.annotation_num) # supervised_loss
                    loss = loss_sup
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()
                    batch_loss.append(loss.item())
                    self.iter_num += 1
                # each annotated class is credited with the running sample count
                for id in active_class_list_client:
                    class_num_list[id] = data_num
                self.epoch = self.epoch + 1
                epoch_loss.append(np.array(batch_loss).mean())
            return net.state_dict(), np.array(
                epoch_loss).mean(), _, _, negetive_class_list_client, active_class_list_client, class_num_list, data_num
        else:
            class_num_list = torch.zeros(self.args.n_classes)
            data_num = 0
            net.train()
            self.optimizer = torch.optim.Adam(
                net.parameters(), lr=self.lr, betas=(0.9, 0.999), weight_decay=5e-4)
            # train and update
            epoch_loss = []

            for epoch in range(self.args.local_ep):
                print('local_epoch:', epoch)
                batch_loss = []
                for j, (samples, item, active_class_list) in enumerate(self.ldr_train):
                    # per unannotated class: indices of confidently pseudo-labeled samples
                    idx_neg = []
                    if j == 0:
                        active_class_list_client = []
                        negetive_class_list_client = []
                        for i in range(self.args.annotation_num):
                            active_class_list_client.append(active_class_list[i][0].item())
                        for i in range(self.args.n_classes):
                            if i not in active_class_list_client:
                                negetive_class_list_client.append(i)
                    images, labels = samples["image"].to(self.args.device), samples["target"].to(self.args.device)
                    _, logits = net(images)
                    prob = torch.sigmoid(logits)
                    # print(negetive_class_list_client)
                    # print(prob)
                    for i in negetive_class_list_client:
                        # confident positives (> tao) become pseudo label 1;
                        # confident negatives (< 1 - tao) keep label 0
                        noise_num = len(torch.where(prob[:, i] > tao[i])[0])
                        clean_num = len(torch.where(prob[:, i] < (1-tao[i]))[0])
                        labels[:, i] = torch.where(prob[:, i] > tao[i], 1, labels[:, i])
                        pseudo_idx = torch.where((prob[:, i] > tao[i]) | (prob[:, i] < (1-tao[i])))[0]
                        idx_neg.append(pseudo_idx)
                        class_num_list[i] = class_num_list[i] + len(pseudo_idx)
                        # print(len(pseudo_idx))
                        data_num = data_num + len(pseudo_idx)
                        # re-estimate the BCE positive weight from the
                        # confident negative/positive ratio for this class
                        if noise_num == 0:
                            self.loss_w[i] = 1
                        else:
                            self.loss_w[i] = (noise_num+clean_num) / noise_num
                    print(self.loss_w)
                    for i in active_class_list_client:
                        class_num_list[i] = class_num_list[i] + len(labels)
                        data_num = data_num + len(labels)*self.args.annotation_num
                    bce_criterion_sup = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(self.loss_w).cuda(),
                                                             reduction='none')  # include sigmoid
                    loss_sup = bce_criterion_sup(logits, labels)  # tensor(32, 5)
                    loss_sup_act = loss_sup[:, active_class_list_client].sum() / (
                            self.args.batch_size * self.args.annotation_num)  # supervised_loss
                    # loss_sup_act = loss_sup.sum() / (self.args.batch_size * self.args.n_classes) # supervised_loss
                    loss = loss_sup_act
                    # add the unannotated-class loss over confident samples only
                    for k, i in enumerate(negetive_class_list_client):
                        if len(idx_neg[k]) != 0:
                            loss += loss_sup[idx_neg[k], i].sum() / len(idx_neg[k])  # supervised_loss
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()
                    batch_loss.append(loss.item())
                    self.iter_num += 1
                self.epoch = self.epoch + 1
                epoch_loss.append(np.array(batch_loss).mean())
            return net.state_dict(), np.array(
                epoch_loss).mean(), _, _, negetive_class_list_client, active_class_list_client, class_num_list, data_num
343 |
    def train_FedIRM(self, rnd, target_matrix, writer1, negetive_class_list, active_class_list_client_i, net): # MICCAI2021
        """Local training for the FedIRM baseline (MICCAI 2021).

        Supervised phase (rnd < rounds_FedIRM_sup): weighted BCE on the two
        augmented views; on the final supervised round the class-relation
        ("confuse") matrix is accumulated and returned for server aggregation.

        Matching phase: on first entry the EMA model is initialized from the
        global weights; each batch combines an EMA consistency loss, a
        symmetric-KL match between the local relation matrix (built from
        confident, low-uncertainty pseudo-labels) and the server's
        ``target_matrix``, and the supervised BCE term.

        Returns:
            (state_dict, mean_epoch_loss, _, _, negative_class_list,
            active_class_list[, confuse_matrix]) — the relation matrix is only
            appended on the last supervised round and in the matching phase.
        """
        if rnd < self.args.rounds_FedIRM_sup:
            net.train()
            # set the optimizer
            self.optimizer = torch.optim.Adam(
                net.parameters(), lr=self.lr, betas=(0.9, 0.999), weight_decay=5e-4)
            # train and update
            epoch_loss = []
            self.confuse_matrix = torch.zeros((8, 8)).cuda()
            print(self.loss_w)
            bce_criterion_sup = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(self.loss_w).cuda(),
                                                     reduction='none')  # include sigmoid
            for epoch in range(self.args.local_ep):
                print('local_epoch:', epoch)
                batch_loss = []
                for j, (samples, item, active_class_list) in enumerate(self.ldr_train):
                    if j == 0:
                        active_class_list_client = []
                        negetive_class_list_client = []
                        for i in range(self.args.annotation_num):
                            active_class_list_client.append(active_class_list[i][0].item())
                        for i in range(self.args.n_classes):
                            if i not in active_class_list_client:
                                negetive_class_list_client.append(i)
                    images1, images2, labels = samples["image_aug_1"].to(self.args.device), samples["image_aug_2"].to(
                        self.args.device), samples["target"].to(self.args.device)
                    _, logits1 = net(images1)
                    fe2, logits2 = net(images2)
                    if rnd == self.args.rounds_FedIRM_sup - 1: # first relation matrix
                        self.confuse_matrix = self.confuse_matrix + self.get_confuse_matrix(logits1, labels)
                    # supervised BCE over both augmented views, annotated classes only
                    loss_sup = bce_criterion_sup(logits1, labels) + bce_criterion_sup(logits2, labels) # tensor(32, 5)
                    loss_sup = loss_sup[:, active_class_list_client].sum() / (
                            self.args.batch_size * self.args.annotation_num)  # supervised_loss
                    self.optimizer.zero_grad()
                    loss_sup.backward()
                    self.optimizer.step()
                    batch_loss.append(loss_sup.item())
                self.epoch = self.epoch + 1
                epoch_loss.append(np.array(batch_loss).mean())
            if rnd == self.args.rounds_FedIRM_sup - 1:
                with torch.no_grad():
                    # average the accumulated matrix over all local steps
                    self.confuse_matrix = self.confuse_matrix / (1.0 * self.args.local_ep * (j + 1))
                return net.state_dict(), np.array(
                    epoch_loss).mean(), _, _, negetive_class_list_client, active_class_list_client, self.confuse_matrix
            else:
                return net.state_dict(), np.array(
                    epoch_loss).mean(), _, _, negetive_class_list_client, active_class_list_client

        else: # Inter-client Relation Matching
            if self.flag:
                # one-time init of the EMA model from the incoming global weights
                self.ema_model.load_state_dict(net.state_dict())
                self.flag = False
                print('done')
            net.train()
            # set the optimizer
            self.optimizer = torch.optim.Adam(
                net.parameters(), lr=self.lr, betas=(0.9, 0.999), weight_decay=5e-4)
            # train and update
            epoch_loss = []
            self.confuse_matrix = torch.zeros((8, 8)).cuda()
            print(self.loss_w)
            bce_criterion_sup = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(self.loss_w).cuda(),
                                                     reduction='none')  # include sigmoid
            for epoch in range(self.args.local_ep):
                print('local_epoch:', epoch)
                batch_loss = []
                for j, (samples, item, active_class_list) in enumerate(self.ldr_train):
                    if j == 0:
                        active_class_list_client = []
                        negetive_class_list_client = []
                        for i in range(self.args.annotation_num):
                            active_class_list_client.append(active_class_list[i][0].item())
                        for i in range(self.args.n_classes):
                            if i not in active_class_list_client:
                                negetive_class_list_client.append(i)
                    images1, images2, labels = samples["image_aug_1"].to(self.args.device), samples["image_aug_2"].to(
                        self.args.device), samples["target"].to(self.args.device)
                    _, outputs = net(images1)
                    with torch.no_grad():
                        _, ema_output = self.ema_model(images2)
                    with torch.no_grad():
                        # binary-entropy-style uncertainty summed over classes
                        preds = torch.sigmoid(outputs).cuda()
                        uncertainty = -1.0 * (torch.sum(preds * torch.log(preds + 1e-6), dim=1) + torch.sum(
                            (1 - preds) * torch.log(1 - preds + 1e-6), dim=1))
                        uncertainty_mask = (uncertainty < 2.0)
                    with torch.no_grad():
                        # confident = every class prob outside [0.3, 0.7]
                        activations = torch.sigmoid(outputs).cuda()
                        confidence_mask = torch.zeros(len(uncertainty_mask), dtype=bool).cuda()
                        confidence_mask[self.find_rows(activations, 0.7, 0.3)] = True
                    mask = confidence_mask * uncertainty_mask
                    if mask.sum().item() != 0:
                        pseudo_labels = activations[mask] > 0.5
                        source_matrix = self.get_confuse_matrix(outputs[mask], pseudo_labels)
                    else:
                        # no confident samples: fall back to a uniform matrix
                        source_matrix = 0.5*torch.ones((8, 8)).cuda()
                    target_matrix = target_matrix.cuda()
                    print(source_matrix)
                    print(target_matrix)
                    consistency_weight = self.get_current_consistency_weight(rnd)
                    consistency_dist = torch.sum(self.sigmoid_mse_loss(outputs, ema_output)) / self.args.batch_size
                    consistency_loss = consistency_dist
                    # ramped EMA-consistency + relation-matrix matching losses
                    loss = consistency_weight * consistency_loss + consistency_weight * torch.sum(
                        self.kd_loss(source_matrix, target_matrix))
                    fe2, logits2 = net(images2)
                    self.confuse_matrix = self.confuse_matrix + self.get_confuse_matrix(outputs, labels)
                    loss_sup = bce_criterion_sup(outputs, labels) + bce_criterion_sup(logits2, labels) # tensor(32, 5)
                    loss_sup = loss_sup[:, active_class_list_client].sum() / (
                            self.args.batch_size * self.args.annotation_num)  # supervised_loss
                    loss = loss + loss_sup
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()
                    self.update_ema_variables(net, self.ema_model, self.args.ema_decay, self.iter_num)
                    batch_loss.append(loss.item())
                    self.iter_num = self.iter_num + 1
                self.epoch = self.epoch + 1
                epoch_loss.append(np.array(batch_loss).mean())
            with torch.no_grad():
                self.confuse_matrix = self.confuse_matrix / (1.0 * self.args.local_ep * (j + 1))
            return net.state_dict(), np.array(
                epoch_loss).mean(), _, _, negetive_class_list_client, active_class_list_client, self.confuse_matrix
465 |
    def train_RoFL(self, net, f_G, epoch):
        """Local training for the RoFL baseline (robust FL with class centroids).

        First pass (no_grad): records hard sigmoid predictions as global-guided
        pseudo labels for every local sample, and on epoch 0 initializes the
        per-class feature centroids f_k (two per class: label-0 and label-1
        averages); on later epochs f_k starts from the server's ``f_G``.

        Training pass: per batch, keeps the (1 - forget_rate) smallest-loss
        samples, agrees/disagrees each kept sample's label with a
        cosine-similarity vote against the centroids, mixes given labels with
        pseudo labels accordingly, optimizes RFLloss, and updates the local
        centroids with a similarity-weighted moving average.

        Returns:
            (state_dict, mean_epoch_loss, f_k) — f_k is sent back for
            server-side centroid aggregation.
        """
        time1 = time.time()
        sim = torch.nn.CosineSimilarity(dim=1)
        pseudo_labels = torch.zeros(len(self.dataset), self.args.n_classes, dtype=torch.float64, device=self.args.device)
        optimizer = torch.optim.Adam(
            net.parameters(), lr=self.lr, betas=(0.9, 0.999), weight_decay=5e-4)
        epoch_loss = []
        net.eval()
        # 2 centroids per class: rows 2c (label 0) and 2c+1 (label 1)
        f_k = torch.zeros(2*self.args.n_classes, self.args.feature_dim, device=self.args.device)
        n_labels = torch.zeros(2*self.args.n_classes, 1, device=self.args.device)

        # obtain global-guided pseudo labels y_hat by y_hat_k = C_G(F_G(x_k))
        # initialization of global centroids
        # obtain naive average feature
        with torch.no_grad():
            for i, (samples, item, active_class_list) in enumerate(self.ldr_train):
                # NOTE: the inner loops below reuse the name `i`, shadowing the
                # batch index — preserved as-is from the original
                if i == 0:
                    active_class_list_client = []
                    negetive_class_list_client = []
                    for i in range(self.args.annotation_num):
                        active_class_list_client.append(active_class_list[i][0].item())
                    for i in range(self.args.n_classes):
                        if i not in active_class_list_client:
                            negetive_class_list_client.append(i)
                images, labels = samples["image"].to(self.args.device), samples["target"].to(self.args.device)
                feature, logit = net(images)
                probs = torch.sigmoid(logit) # soft predict
                accuracy_th = 0.5
                preds = probs > accuracy_th # hard predict
                preds = preds.to(torch.float64)
                pseudo_labels[item] = preds
                if epoch == 0:
                    # accumulate feature sums and counts per (class, label) pair
                    for cls in range(self.args.n_classes):
                        f_k[2*cls] += torch.sum(feature[torch.where(labels[:, cls] == 0)[0]], dim=0)
                        f_k[2*cls+1] += torch.sum(feature[torch.where(labels[:, cls] == 1)[0]], dim=0)
                        n_labels[2*cls] += len(torch.where(labels[:, cls] == 0)[0])
                        n_labels[2*cls+1] += len(torch.where(labels[:, cls] == 1)[0])

        if epoch == 0:
            # avoid division by zero for (class, label) pairs never observed
            for i in range(len(n_labels)):
                if n_labels[i] == 0:
                    n_labels[i] = 1
            f_k = torch.div(f_k, n_labels)
        else:
            f_k = f_G
        time2 = time.time()
        # print('local test time: ', time2-time1)
        net.train()
        for iter in range(self.args.local_ep):
            batch_loss = []
            for samples, item, active_class_list in self.ldr_train:
                time4 = time.time()
                net.zero_grad()
                images, labels = samples["image"].to(self.args.device), samples["target"].to(self.args.device)
                feature, logit = net(images)
                feature = feature.detach()
                f_k = f_k.to(self.args.device)

                small_loss_idxs, loss_w = self.get_small_loss_samples(logit, labels, self.args.forget_rate, negetive_class_list_client)

                # y_k_tilde: centroid-voted label; mask: 1 where it agrees with
                # the given label for every class
                y_k_tilde = torch.zeros(self.args.batch_size, self.args.n_classes, device=self.args.device)
                mask = torch.zeros(self.args.batch_size, device=self.args.device)
                for i in small_loss_idxs:
                    for cls in range(self.args.n_classes):
                        f_cls = f_k[[2*cls, 2*cls+1], :]
                        y_k_tilde[i, cls] = torch.argmax(sim(f_cls, torch.reshape(feature[i], (1, self.args.feature_dim))))
                    if torch.equal(y_k_tilde[i], labels[i]):
                        mask[i] = 1

                # When to use pseudo-labels
                if epoch < self.args.T_pl:
                    for i in small_loss_idxs:
                        pseudo_labels[item[i]] = labels[i]

                # For loss calculating
                # NOTE(review): the repeat factor 5 looks like a hard-coded
                # n_classes — confirm it matches args.n_classes
                mask_resize = mask.unsqueeze(1).repeat([1, 5])
                new_labels = mask_resize[small_loss_idxs] * labels[small_loss_idxs] + (1 - mask_resize[small_loss_idxs]) * \
                             pseudo_labels[item[small_loss_idxs]]
                new_labels = new_labels.type(torch.float).to(self.args.device)
                time5 = time.time()
                # print('batch train prepare time: ', time5-time4)
                loss = self.RFLloss(logit, labels, feature, f_k, mask, small_loss_idxs, new_labels, loss_w, epoch)
                # print('loss: ', loss)
                # weight update by minimizing loss: L_total = L_c + lambda_cen * L_cen + lambda_e * L_e
                loss.backward()
                optimizer.step()

                # obtain loss based average features f_k,j_hat from small loss dataset
                f_kj_hat = torch.zeros(2*self.args.n_classes, self.args.feature_dim, device=self.args.device)
                n = torch.zeros(2*self.args.n_classes, 1, device=self.args.device)
                for i in small_loss_idxs:
                    for cls in range(self.args.n_classes):
                        if labels[i, cls] == 0:
                            f_kj_hat[2 * cls] += feature[i]
                            n[2 * cls] += 1
                        else:
                            f_kj_hat[2 * cls + 1] += feature[i]
                            n[2 * cls + 1] += 1
                for i in range(len(n)):
                    if n[i] == 0:
                        n[i] = 1
                f_kj_hat = torch.div(f_kj_hat, n)

                # update local centroid f_k
                one = torch.ones(2*self.args.n_classes, 1, device=self.args.device)
                f_k = (one - sim(f_k, f_kj_hat).reshape(2*self.args.n_classes, 1) ** 2) * f_k + (
                        sim(f_k, f_kj_hat).reshape(2*self.args.n_classes, 1) ** 2) * f_kj_hat

                batch_loss.append(loss.item())
                time6 = time.time()
                # print('batch train time: ', time6 - time5)
            epoch_loss.append(sum(batch_loss) / len(batch_loss))
        time3 = time.time()
        # print('local train time: ', time3 - time2)
        return net.state_dict(), sum(epoch_loss) / len(epoch_loss), f_k
581 |
    def RFLloss(self, logit, labels, feature, f_k, mask, small_loss_idxs, new_labels, loss_w, epoch):
        """RoFL total loss: L_c + lambda_cen * L_cen + lambda_e * L_e.

        L_c: class-weighted BCE between logits and the mixed labels on the
        small-loss samples. L_cen: feature-to-centroid MSE, gated by ``mask``
        (agreement between given and centroid-voted labels), averaged over
        classes. L_e: per-class binary entropy of the sigmoid outputs,
        averaged over classes. lambda_cen ramps linearly until epoch T_pl.
        NaN components trigger diagnostic prints but are not otherwise handled.
        """
        mse = torch.nn.MSELoss(reduction='none')
        bce = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(loss_w).cuda()) # include sigmoid
        L_c = bce(logit[small_loss_idxs], new_labels)
        prob = torch.sigmoid(logit).cuda()
        for i in range(self.args.n_classes):
            # centroid row index 2*i + label selects the label-0 or label-1 centroid
            if i == 0:
                L_cen = torch.sum(
                    mask[small_loss_idxs] * torch.sum(mse(feature[small_loss_idxs], f_k[(2*i+labels[small_loss_idxs, i]).cpu().numpy()]), 1))/(len(small_loss_idxs)*self.args.feature_dim)
            else:
                L_cen += torch.sum(
                    mask[small_loss_idxs] * torch.sum(mse(feature[small_loss_idxs], f_k[(2*i+labels[small_loss_idxs, i]).cpu().numpy()]), 1))/(len(small_loss_idxs)*self.args.feature_dim)
        L_cen = L_cen / self.args.n_classes
        for i in range(self.args.n_classes):
            # two-column (p, 1-p) distribution per class for the entropy term
            clsi_prob = prob[:, i].unsqueeze(1)
            clsi_prob = torch.cat((clsi_prob, 1-clsi_prob), dim=1).cuda()
            if i == 0:
                L_e = -torch.mean(torch.sum(clsi_prob[small_loss_idxs] * torch.log(clsi_prob[small_loss_idxs]), dim=1))
            else:
                L_e += -torch.mean(torch.sum(clsi_prob[small_loss_idxs] * torch.log(clsi_prob[small_loss_idxs]), dim=1))
        L_e = L_e / self.args.n_classes
        lambda_e = self.args.lambda_e
        lambda_cen = self.args.lambda_cen
        if epoch < self.args.T_pl:
            # linear ramp of the centroid term during the pseudo-label phase
            lambda_cen = (self.args.lambda_cen * epoch) / self.args.T_pl
        # print('L_c: ', L_c.item(), 'L_cen: ', L_cen.item(), 'L_e: ', L_e.item())
        if math.isnan(L_c.item()) or math.isnan(L_cen.item()) or math.isnan(L_e.item()):
            print(logit)
            print(feature)
            print(loss_w)
        print('loss: ', L_c + (lambda_cen * L_cen) + (lambda_e * L_e))
        return L_c + (lambda_cen * L_cen) + (lambda_e * L_e)
614 |
615 | def get_small_loss_samples(self, y_pred, y_true, forget_rate, negetive_class_list_client):
616 | loss_w = self.loss_w
617 | for i in negetive_class_list_client:
618 | loss_w[i] = 5.
619 | loss_func = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(loss_w).cuda(), reduction='none') # include sigmoid
620 | loss = torch.sum(loss_func(y_pred, y_true), dim=1)
621 | ind_sorted = np.argsort(loss.data.cpu()).cuda()
622 | loss_sorted = loss[ind_sorted]
623 | remember_rate = 1 - forget_rate
624 | num_remember = int(remember_rate * len(loss_sorted))
625 | ind_update = ind_sorted[:num_remember]
626 | return ind_update, loss_w
627 |
628 | def train(self, rnd, net, writer1):
629 | # teacher_neg = deepcopy(net).to(self.args.device) # try
630 | assert len(self.ldr_train.dataset) == len(self.idxs)
631 | print(f"Client ID: {self.client_id}, Num: {len(self.ldr_train.dataset)}")
632 |
633 | net.train()
634 | # teacher_neg.eval() # try
635 | # set the optimizer
636 | self.optimizer = torch.optim.Adam(
637 | net.parameters(), lr=self.lr, betas=(0.9, 0.999), weight_decay=5e-4)
638 |
639 | # train and update
640 | epoch_loss = []
641 | print(self.loss_w)
642 | bce_criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(self.loss_w).cuda(), reduction='none') # include sigmoid
643 | active_class_list_client = []
644 | negetive_class_list_client = []
645 | # mse = nn.MSELoss() # try
646 | for epoch in range(self.args.local_ep):
647 | print('local_epoch:', epoch)
648 | batch_loss = []
649 | for k, (samples, item, active_class_list) in enumerate(self.ldr_train):
650 | images, labels = samples["image"].to(self.args.device), samples["target"].to(self.args.device)
651 | if k == 0:
652 | for i in range(self.args.annotation_num):
653 | active_class_list_client.append(active_class_list[i][0].item())
654 | for i in range(self.args.n_classes):
655 | if i not in active_class_list_client:
656 | negetive_class_list_client.append(i)
657 | _, logits = net(images)
658 |
659 | # logits1_sig = torch.sigmoid(logits).cuda() # try
660 | # with torch.no_grad():
661 | # _, logits2 = teacher_neg(images)
662 | # logits2_sig = torch.sigmoid(logits2).cuda()
663 |
664 | loss = bce_criterion(logits, labels) # tensor(32, 5)
665 | loss = loss.sum()/(self.args.batch_size * self.args.n_classes) # all_class_loss
666 | # loss = loss[:, active_class_list_client].sum()/(self.args.batch_size * self.args.annotation_num)
667 | # mask = torch.ones_like(loss) # active_class_loss
668 | # for i in negetive_class_list_client:
669 | # mask[:, i] = 0.
670 | # loss = (loss * mask).sum()/(self.args.batch_size * self.args.annotation_num)
671 | # loss += mse(logits1_sig[:, negetive_class_list_client], logits2_sig[:, negetive_class_list_client]) # try
672 |
673 | self.optimizer.zero_grad()
674 | loss.backward()
675 | self.optimizer.step()
676 |
677 | batch_loss.append(loss.item())
678 |
679 | self.iter_num += 1
680 | # if rnd % 5 == 0:
681 | # for j in range(len(negetive_class_list_client)):
682 | # plt.title(f'round:{rnd},epoch:{epoch},client:{self.client_id},miss class:{negetive_class_list_client[j]} loss distribution')
683 | # sns.kdeplot(loss_false_negetive[j], label='FN')
684 | # sns.kdeplot(loss_true_negetive[j], label='TN')
685 | # plt.legend()
686 | # plt.savefig(f'loss_fig/round:{rnd},epoch:{epoch},client:{self.client_id},miss class:{negetive_class_list_client[j]}_loss_distribution.png')
687 | # print('ok')
688 |
689 | # for u in range(5):
690 | # result = classtest(deepcopy(net).cuda(), test_dataset=self.dataset_test, args=self.args, classid=u)
691 | # BACC, R, F1, P= result["BACC"], result["R"], result["F1"], result["P"]
692 | # logging.info(
693 | # "-----> BACC: %.2f, R: %.2f, F1: %.2f, P: %.2f" % (BACC * 100, R * 100, F1 * 100, P * 100))
694 | # writer1.add_scalar(f'test_client_class{u}/BACC', BACC, epoch)
695 | # writer1.add_scalar(f'test_client_class{u}/R', R, epoch)
696 | # writer1.add_scalar(f'test_client_class{u}/F1', F1, epoch)
697 | # writer1.add_scalar(f'test_client_class{u}/P', P, epoch)
698 |
699 | self.epoch = self.epoch + 1
700 | epoch_loss.append(np.array(batch_loss).mean())
701 | net.cpu()
702 | self.optimizer.zero_grad()
703 | return net.state_dict(), np.array(epoch_loss).mean(), _, _, negetive_class_list_client, active_class_list_client
704 |
    def train_RSCFed(self, rnd, net):
        """Local training for the RSCFed baseline (mean-teacher style).

        A student copy of the incoming global model is supervised on this
        client's annotated classes (weighted BCE) and regularized on the
        unannotated classes by an MSE between its sigmoid outputs and those of
        the persistent teacher (``self.teacher_neg``) on a second augmentation.
        After every optimizer step the teacher is updated as a 0.999/0.001 EMA
        of teacher/student state_dicts.

        Returns:
            (student_state_dict, mean_epoch_loss, _, _, negative_class_list,
            active_class_list) — the two ``_`` placeholders mirror the other
            train_* methods.
        """
        assert len(self.ldr_train.dataset) == len(self.idxs)
        print(f"Client ID: {self.client_id}, Num: {len(self.ldr_train.dataset)}")
        self.student = deepcopy(net).to(self.args.device)
        self.teacher_neg.eval()
        self.student.train()

        # set the optimizer
        self.optimizer = torch.optim.Adam(
            self.student.parameters(), lr=self.lr, betas=(0.9, 0.999), weight_decay=5e-4)

        # train and update
        epoch_loss = []
        print(self.loss_w)
        bce_criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(self.loss_w).cuda(), reduction='none') # include sigmoid
        mse_loss = nn.MSELoss()

        for epoch in range(self.args.local_ep):
            print('local_epoch:', epoch)
            batch_loss = []
            for samples, item, active_class_list in self.ldr_train:
                # (rebuilt every batch here, unlike the k == 0 pattern elsewhere)
                active_class_list_client = []
                negetive_class_list_client = []
                for i in range(self.args.annotation_num):
                    active_class_list_client.append(active_class_list[i][0].item())
                for i in range(self.args.n_classes):
                    if i not in active_class_list_client:
                        negetive_class_list_client.append(i)
                images1, images2, labels = samples["image_aug_1"].to(self.args.device), samples["image_aug_2"].to(self.args.device), samples["target"].to(self.args.device)
                _, logits1_stu = self.student(images1)
                logits1_stu_sig = torch.sigmoid(logits1_stu).cuda()

                with torch.no_grad():
                    _, logits2_tea = self.teacher_neg(images2)
                    # _, logits2_tea = self.teacher_neg(images1)
                    logits2_tea_sig = torch.sigmoid(logits2_tea).cuda()
                loss = bce_criterion(logits1_stu, labels) # tensor(32, 5)
                # supervised term on annotated classes, consistency term on the rest
                loss_sup = loss[:, active_class_list_client].sum()/(self.args.batch_size * self.args.annotation_num) # supervised_loss
                loss_unsup = mse_loss(logits1_stu_sig[:, negetive_class_list_client], logits2_tea_sig[:, negetive_class_list_client])
                loss = loss_sup + loss_unsup
                # loss = loss_sup

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # update teacher
                # per-batch EMA over full state_dicts (includes buffers)
                state_dict1 = self.teacher_neg.state_dict()
                state_dict2 = deepcopy(self.student).state_dict()
                weight1 = 1 - 0.001
                weight2 = 0.001
                weighted_state_dict = {}
                for name in state_dict1:
                    weighted_state_dict[name] = weight1 * state_dict1[name] + weight2 * state_dict2[name]
                self.teacher_neg.load_state_dict(weighted_state_dict)
                self.teacher_neg = self.teacher_neg.to(self.args.device)
                batch_loss.append(loss.item())
                self.iter_num += 1

            self.epoch = self.epoch + 1
            epoch_loss.append(np.array(batch_loss).mean())

        net.cpu()
        self.optimizer.zero_grad()
        return self.student.state_dict(), np.array(epoch_loss).mean(), _, _, negetive_class_list_client, active_class_list_client
770 |
    def train_FixMatch(self, rnd, net):
        """Run one communication round of local FixMatch-style training.

        Annotated (active) classes are supervised with class-weighted BCE.
        For the un-annotated (negative) classes, confident weak-augmentation
        predictions (sigmoid > 0.8 or < 0.2 on ALL missing classes) are
        hard-thresholded into pseudo labels that supervise the
        strong-augmentation view.

        Args:
            rnd: current communication round (unused here; kept so all local
                trainers share the same call signature).
            net: model to train; moved back to CPU before returning.

        Returns:
            (state_dict, mean epoch loss, _, _, negative class list,
            active class list). NOTE(review): the two `_` slots re-use the
            feature output of the last forward pass as placeholders.
        """
        assert len(self.ldr_train.dataset) == len(self.idxs)
        print(f"Client ID: {self.client_id}, Num: {len(self.ldr_train.dataset)}")
        net.train()

        # set the optimizer
        self.optimizer = torch.optim.Adam(
            net.parameters(), lr=self.lr, betas=(0.9, 0.999), weight_decay=5e-4)

        # train and update
        epoch_loss = []
        print(self.loss_w)
        # Separate pos_weights for supervised (annotated) and unsupervised
        # (pseudo-labelled) losses; both keep per-element losses (reduction='none').
        bce_criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(self.loss_w).cuda(), reduction='none') # include sigmoid
        bce_criterion_unsup = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(self.loss_w_unknown).cuda(), reduction='none') # include sigmoid

        for epoch in range(self.args.local_ep):
            print('local_epoch:', epoch)
            batch_loss = []
            for samples, item, active_class_list in self.ldr_train:
                # Rebuild the active / missing class split from the loader
                # (active_class_list is batched by the collate function).
                active_class_list_client = []
                negetive_class_list_client = []
                for i in range(self.args.annotation_num):
                    active_class_list_client.append(active_class_list[i][0].item())
                for i in range(self.args.n_classes):
                    if i not in active_class_list_client:
                        negetive_class_list_client.append(i)
                images1, images2, labels = samples["image_aug_1"].to(self.args.device), samples["image_aug_2"].to(self.args.device), samples["target"].to(self.args.device)
                # Weak view produces the pseudo labels.
                _, logits_weak = net(images1)
                logits_weak_sig = torch.sigmoid(logits_weak).cuda()
                # Keep only samples confidently predicted on EVERY missing class.
                # NOTE(review): uses args.batch_size, not the actual batch length —
                # confirm drop_last on the train loader.
                idx = set(range(self.args.batch_size))
                for c in negetive_class_list_client:
                    idx = idx.intersection(set(torch.where(logits_weak_sig[:, c] > 0.8)[0].tolist()).union(set(torch.where(logits_weak_sig[:, c] < 0.2)[0].tolist())))
                idx = list(idx)
                logits_weak_sig = torch.where(logits_weak_sig > 0.5, 1.0, 0.0) # turn to hard label
                # Strong view is trained against the hard pseudo labels.
                _, logits_strong = net(images2)
                loss = bce_criterion(logits_weak, labels) # tensor(32, 5)
                loss_sup = loss[:, active_class_list_client].sum()/(self.args.batch_size * self.args.annotation_num) # supervised_loss
                if len(idx) == 0 or len(negetive_class_list_client) == 0:
                    # No confident samples (or nothing missing): supervised only.
                    loss = loss_sup
                else:
                    loss = bce_criterion_unsup(logits_strong, logits_weak_sig)
                    # Unsupervised loss averaged over confident samples x missing classes.
                    loss_unsup = loss[idx, :][:, negetive_class_list_client].sum()/(len(idx) * (self.args.n_classes - self.args.annotation_num)) # supervised_loss
                    # print('loss_sup: ', loss_sup)
                    # print('loss_unsup: ', loss_unsup)
                    loss = loss_sup + 1. * loss_unsup
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                batch_loss.append(loss.item())
                self.iter_num += 1
            self.epoch = self.epoch + 1
            epoch_loss.append(np.array(batch_loss).mean())
        net.cpu()
        self.optimizer.zero_grad()
        return net.state_dict(), np.array(epoch_loss).mean(), _, _, negetive_class_list_client, active_class_list_client
826 |
827 | def mixup_criterion(self, y_a, y_b, lam):
828 | return lambda criterion, pred: (lam * criterion(pred, y_a).T).T + ((1 - lam) * criterion(pred, y_b).T).T
829 |
    def test_loss(self, rnd, net, model_name):
        """Split each missing class's BCE loss into false-negative vs
        true-negative samples and t-SNE-visualise the features.

        For every un-annotated class, a sample counts as a false negative
        (FN, flag 0) when its index is in `self.class_neg_idx` — presumably
        indices whose positive label was masked to 0; verify against
        DatasetSplit — as a true negative (TN, flag 1) otherwise; indices in
        `self.class_pos_idx` are skipped entirely.

        Returns:
            (per-missing-class mean FN loss, per-missing-class mean TN loss).
        """
        net.eval()
        bce_criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(self.loss_w).cuda(),
                                             reduction='none')  # include sigmoid
        epoch_false_negetive_loss = []
        epoch_true_negetive_loss = []
        # One loss bucket per missing class.
        loss_false_negetive = []
        loss_true_negetive = []
        for i in range(self.args.n_classes - self.args.annotation_num):
            loss_false_negetive.append([])
            loss_true_negetive.append([])

        active_class_list_client = []
        feature = torch.tensor([]).cuda()
        # Per missing class, a flag per training sample: 0 = FN, 1 = TN.
        flags = torch.zeros([self.args.n_classes - self.args.annotation_num, len(self.ldr_train.dataset)]).cuda() # 0:FN, 1:TN
        # Extra visualisation flag, hard-coded to class index 1 (1 when label==0).
        flag_class_one = torch.zeros([1, len(self.ldr_train.dataset)]).cuda()
        # print(flags.shape)
        # input()
        count = 0
        with torch.no_grad():
            for samples, item, active_class_list in self.ldr_train:
                images, labels = samples["image"].to(self.args.device), samples["target"].to(self.args.device)
                # NOTE(review): count*len(item) indexing assumes every batch
                # has the same size; a smaller final batch would misplace
                # flags — confirm the loader's drop_last setting.
                for i in range(len(item)):
                    flag_class_one[0, count*len(item)+i] = int(labels[i, 1] == 0)
                for i in range(self.args.annotation_num):
                    active_class_list_client.append(active_class_list[i][0].item())
                negetive_class_list_client = []
                feature_batch, logits = net(images)
                feature = torch.cat((feature, feature_batch), dim=0)

                loss = bce_criterion(logits, labels) # tensor(32, 5)
                # Collect per-sample losses of each missing class.
                class_miss_loss = []
                for i in range(self.args.n_classes):
                    if i not in active_class_list_client:
                        negetive_class_list_client.append(i)
                        class_miss_loss.append(loss[:, i].clone().detach().cpu())

                # Route each sample's loss into the FN or TN bucket.
                for i in range(len(item)):
                    for j in range(len(negetive_class_list_client)):
                        if item[i].item() in self.class_neg_idx[negetive_class_list_client[j]]:
                            loss_false_negetive[j].append(class_miss_loss[j][i].item())
                            flags[j, i+count*len(item)] = 0

                        elif item[i].item() in self.class_pos_idx[negetive_class_list_client[j]]:
                            pass
                        else:
                            loss_true_negetive[j].append(class_miss_loss[j][i].item())
                            flags[j, i+count*len(item)] = 1
                count += 1

        # for j in range(len(negetive_class_list_client)):
        #     plt.title(
        #         f'model:{model_name},round:{rnd},client:{self.client_id},miss class:{negetive_class_list_client[j]} test loss distribution')
        #     sns.kdeplot(loss_false_negetive[j], label='FN')
        #     sns.kdeplot(loss_true_negetive[j], label='TN')
        #     plt.legend()
        #     plt.savefig(
        #         f'loss_fig/model:{model_name},round:{rnd},client:{self.client_id},miss class:{negetive_class_list_client[j]}_test_loss_distribution.png')
        #     plt.clf()

        # t-SNE plots of the features, coloured by FN/TN flag per missing class.
        for j in range(len(negetive_class_list_client)):
            tnse_Visual(feature.cpu(), flags[j].cpu(), rnd, f'model{model_name} class{negetive_class_list_client[j]} p=1')

        tnse_Visual(feature.cpu(), flag_class_one[0].cpu(), rnd, f'model{model_name} class{1} p=1')

        for j in range(len(negetive_class_list_client)):
            epoch_false_negetive_loss.append(np.array(loss_false_negetive[j]).mean())
            epoch_true_negetive_loss.append(np.array(loss_true_negetive[j]).mean())
        net.cpu()
        return epoch_false_negetive_loss, epoch_true_negetive_loss
900 |
901 | def find_indices_in_a(self, a, b):
902 | return torch.where(a.unsqueeze(0) == b.unsqueeze(1))[1]
903 |
    def train_FedMLP(self, rnd, tao, Prototype, writer1, negetive_class_list, active_class_list_client_i, net): # my method
        """FedMLP local update (the paper's own method), in two stages.

        Stage 1 (rnd < rounds_FedMLP_stage1): warm-up — logit-adjusted loss
        on annotated classes plus MSE distillation toward the frozen global
        model on the missing classes. On the last warm-up round it also
        builds per-class feature prototypes and the confidence ratio `t`.

        Stage 2: mines pseudo labels for the missing classes by comparing
        each sample's feature to the class prototypes (cosine similarity),
        trains on a DatasetSplit_pseudo built from the mined clean/noise
        index pools, and recomputes prototypes and `t` afterwards.

        Args:
            rnd: current communication round.
            tao: per-class confidence thresholds from the server (printed;
                the threshold use is currently commented out).
            Prototype: global per-class feature prototypes
                ([cls0 proto0, cls0 proto1, cls1 proto0, ...]).
            writer1: TensorBoard writer (only used by commented-out logging).
            negetive_class_list: classes this client does NOT annotate.
            active_class_list_client_i: classes this client annotates.
            net: the local model.

        Returns:
            Stage 1 (non-final round): (state_dict, mean loss, _, _,
            negative classes, active classes); final stage-1 round and
            stage 2 additionally return (..., t, proto).
        """
        # assert len(self.ldr_train.dataset) == len(self.idxs)
        # print(f"Client ID: {self.client_id}, Num: {len(self.ldr_train.dataset)}")
        if rnd < self.args.rounds_FedMLP_stage1: # stage1
            # Frozen copy of the incoming global model for distillation.
            glob_model = deepcopy(net)
            net.train()
            glob_model.eval()
            # set the optimizer
            self.optimizer = torch.optim.Adam(
                net.parameters(), lr=self.lr, betas=(0.9, 0.999), weight_decay=5e-4)
            # train and update
            epoch_loss = []
            print(self.loss_w)
            bce_criterion_sup = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(self.loss_w).cuda(),
                                                     reduction='none')  # include sigmoid
            bce_criterion_unsup = nn.MSELoss()
            for epoch in range(self.args.local_ep):
                print('local_epoch:', epoch)
                batch_loss = []
                for j, (samples, item, active_class_list) in enumerate(self.ldr_train):
                    if j == 0:
                        # Derive the active / missing class split once per epoch
                        # and zero the class counts of missing classes.
                        active_class_list_client = []
                        negetive_class_list_client = []
                        for i in range(self.args.annotation_num):
                            active_class_list_client.append(active_class_list[i][0].item())
                        for i in range(self.args.n_classes):
                            if i not in active_class_list_client:
                                negetive_class_list_client.append(i)
                                self.class_num_list[i] = 0 # try noro
                    criterion = LogitAdjust_Multilabel(cls_num_list=self.class_num_list, num=len(self.idxs))
                    mse_loss = nn.MSELoss(reduction='none')
                    images1, images2, labels = samples["image_aug_1"].to(self.args.device), samples["image_aug_2"].to(
                        self.args.device), samples["target"].to(self.args.device)
                    # Two augmented views through the local model.
                    fe1, logits1 = net(images1)
                    logits1_sig = torch.sigmoid(logits1).cuda()
                    _, logits2 = net(images2)
                    logits2_sig = torch.sigmoid(logits2).cuda()
                    # loss_sup1 = bce_criterion_sup(logits1, labels) # tensor(32, 5)
                    # loss_sup2 = bce_criterion_sup(logits2, labels) # tensor(32, 5)
                    # Same views through the frozen global model (distillation targets).
                    with torch.no_grad():
                        _, outputs_global = glob_model(images1)
                        logits3 = torch.sigmoid(outputs_global).cuda()
                        _, outputs_global = glob_model(images2)
                        logits4 = torch.sigmoid(outputs_global).cuda()
                    loss_dis1 = mse_loss(logits1_sig, logits3).cuda()
                    loss_dis2 = mse_loss(logits2_sig, logits4).cuda()
                    loss_dis = (loss_dis1 + loss_dis2) / 2.
                    loss_sup1 = criterion(logits1_sig, labels) # tensor(32, 5)
                    loss_sup2 = criterion(logits2_sig, labels) # tensor(32, 5)
                    loss_sup = (loss_sup1 + loss_sup2) / 2.
                    # loss_sup = loss_sup.sum() / (self.args.batch_size * self.args.n_classes) # supervised_loss

                    # Supervision only on annotated classes...
                    loss_sup = loss_sup[:, active_class_list_client].sum() / (
                            self.args.batch_size * self.args.annotation_num) # supervised_loss
                    # ...distillation only on the missing classes.
                    loss_dis = loss_dis[:, negetive_class_list_client].sum() / (
                            self.args.batch_size * len(negetive_class_list_client)) # distillation loss on missing classes

                    # Consistency term between the two views (weight currently 0).
                    loss_unsup = bce_criterion_unsup(logits1_sig[:, negetive_class_list_client],
                                                     logits2_sig[:, negetive_class_list_client])
                    loss = loss_sup + 0.0*loss_unsup + loss_dis
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()
                    batch_loss.append(loss.item())
                    self.iter_num += 1
                self.epoch = self.epoch + 1
                epoch_loss.append(np.array(batch_loss).mean())
            if rnd == self.args.rounds_FedMLP_stage1 - 1: # first tao and proto
                # End of warm-up: build initial prototypes and confidence ratio t.
                print('client: ', self.client_id, active_class_list_client)
                proto = torch.zeros((self.args.n_classes * 2, len(fe1[0]))) # [cls0proto0, cls0proto1, cls1proto0...]
                # proto = np.array([torch.zeros_like(fe1[0].cpu())] * self.args.n_classes * 2)
                num_proto = [0] * self.args.n_classes * 2
                t = np.array([0] * self.args.n_classes)
                test_loader = DataLoader(dataset=self.local_dataset, batch_size=self.args.batch_size * 4, shuffle=False,
                                         num_workers=8)
                net.eval()
                with torch.no_grad():
                    for samples, _, _ in test_loader:
                        images1, labels = samples["image_aug_1"].to(self.args.device), samples["target"].to(self.args.device)
                        feature, outputs = net(images1)
                        probs = torch.sigmoid(outputs) # soft predict
                        # Accumulate per-class negative/positive feature sums.
                        for cls in active_class_list_client:
                            idx0 = torch.where(labels[:, cls] == 0)[0].tolist()
                            idx1 = torch.where(labels[:, cls] == 1)[0].tolist()
                            num_proto[2*cls] += len(idx0)
                            num_proto[2*cls+1] += len(idx1)
                            proto[2*cls] = feature[idx0, :].sum(0).cpu() + proto[2*cls]
                            proto[2*cls+1] = feature[idx1, :].sum(0).cpu() + proto[2*cls+1]
                            # t[cls] += torch.sum(probs[idx0, cls] < self.args.L).item() + torch.sum(
                            #     probs[idx1, cls] > self.args.U).item()
                        # t counts confident predictions (outside [L, U]) per missing class.
                        for cls in negetive_class_list:
                            t[cls] += torch.sum(
                                torch.logical_or(probs[:, cls] < self.args.L, probs[:, cls] > self.args.U)).item()
                for cls in active_class_list_client:
                    proto[2*cls] = proto[2*cls] / num_proto[2*cls]
                    proto[2*cls+1] = proto[2*cls+1] / num_proto[2*cls+1]
                t = t / len(self.local_dataset)
                print('local_t: ', t)
                return net.state_dict(), np.array(epoch_loss).mean(), _, _, negetive_class_list_client, active_class_list_client, t, proto
            else:
                return net.state_dict(), np.array(epoch_loss).mean(), _, _, negetive_class_list_client, active_class_list_client

        else: # stage2
            print(self.local_dataset.active_class_list)
            # find train samples for each class
            idx = [] # [[negcls1], [negcls2]]
            idxss = []
            feature = []
            similarity = []
            clean_idx = []
            noise_idx = []
            label = [] # [[negcls1], [negcls2]]
            num_train = 0
            glob_model = deepcopy(net)
            net.eval()
            class_idx = torch.tensor([])
            l = torch.tensor([]).cuda()
            f = torch.tensor([]).cuda()
            t1 = time.time()
            if rnd == self.args.rounds_FedMLP_stage1:
                # First stage-2 round: every sample is still unassigned, so all
                # features/indices/labels are candidates for every missing class.
                # print(len(self.ldr_train.dataset))
                self.traindata_idx = [] # [[negcls1_clean_train_idx], [negcls1_noise_train_idx], [negcls2_clean_train_idx], [negcls2_noise_train_idx]] idx
                for samples, item, active_class_list in self.ldr_train:
                    class_idx = torch.cat((class_idx, item), dim=0)
                    images1, labels = samples["image_aug_1"].to(self.args.device), samples["target"].to(self.args.device)
                    with torch.no_grad():
                        features, _ = net(images1)
                    f = torch.cat((f, features), dim=0)
                    l = torch.cat((l, labels), dim=0)
                for i in range(len(negetive_class_list)):
                    feature.append(f) # miss n classes
                    idx.append(class_idx)
                    label.append(l)
            else:
                # Later rounds: only re-examine the samples not yet assigned
                # in earlier rounds (self.idxss holds the leftovers per class).
                for samples, item, active_class_list in self.ldr_train:
                    class_idx = torch.cat((class_idx, item), dim=0)
                    images1, labels = samples["image_aug_1"].to(self.args.device), samples["target"].to(self.args.device)
                    with torch.no_grad():
                        features, _ = net(images1)
                    f = torch.cat((f, features), dim=0)
                    l = torch.cat((l, labels), dim=0)
                for i in range(len(self.idxss)):
                    result_indices = self.find_indices_in_a(class_idx, torch.tensor(self.idxss[i]))
                    feature.append(f[result_indices])
                    idx.append(class_idx[result_indices])
                    label.append(l[result_indices])
            t2 = time.time()
            # print('feature_label_prepare_time: ', t2-t1)
            # Similarity margin to the negative vs positive prototype:
            # >= 0 means "looks negative/clean", < 0 means "looks positive/noisy".
            for i, cls in enumerate(negetive_class_list):
                # sim = []
                proto_0 = Prototype[2*cls]
                proto_1 = Prototype[2*cls+1]
                model = CosineSimilarityFast().cuda()
                sim = (model(feature[i], torch.unsqueeze(proto_0.cuda(), dim=0)) - model(feature[i], torch.unsqueeze(proto_1.cuda(), dim=0))).tolist()
                similarity.append(sim)
            t3 = time.time()
            # print('sim_compute_time: ', t3 - t2)
            for i in range(len(negetive_class_list)):
                idx_0 = np.where(np.array(similarity[i]) >= 0)[0]
                idx_1 = np.where(np.array(similarity[i]) < 0)[0]
                clean_idx.append(idx_0.tolist())
                noise_idx.append(idx_1.tolist())
            if rnd == self.args.rounds_FedMLP_stage1:
                # Initialise traindata_idx with the most confident fraction of
                # each pool (clean_threshold / noise_threshold).
                for i, cls in enumerate(negetive_class_list):
                    print('cls', cls, 'tao: ', tao[cls])
                    num_clean_cls = int(1 * self.args.clean_threshold * len(clean_idx[i]))
                    num_noise_cls = int(1 * self.args.noise_threshold * len(noise_idx[i]))
                    # num_clean_cls = int(tao[cls] * len(clean_idx[i]))
                    # num_noise_cls = int(tao[cls] * len(noise_idx[i]))
                    num_train = num_train + num_noise_cls + num_clean_cls
                    max_m_indices_list = np.array(max_m_indices(similarity[i], num_clean_cls))
                    min_n_indices_list = np.array(min_n_indices(similarity[i], num_noise_cls))
                    # NOTE(review): the first condition tests max_m twice; it
                    # presumably meant to test min_n as well — confirm.
                    if len(max_m_indices_list) == 0 and len(max_m_indices_list) == 0:
                        negcls_clean_train_idx = []
                        negcls_noise_train_idx = []
                    elif len(min_n_indices_list) == 0 and len(max_m_indices_list) != 0:
                        negcls_noise_train_idx = []
                        negcls_clean_train_idx = np.array(idx[i])[max_m_indices_list].tolist()
                    elif len(min_n_indices_list) != 0 and len(max_m_indices_list) == 0:
                        negcls_noise_train_idx = np.array(idx[i])[min_n_indices_list].tolist()
                        negcls_clean_train_idx = []
                    else:
                        negcls_clean_train_idx = np.array(idx[i])[max_m_indices_list].tolist()
                        negcls_noise_train_idx = np.array(idx[i])[min_n_indices_list].tolist()
                    self.traindata_idx.append(negcls_clean_train_idx)
                    self.traindata_idx.append(negcls_noise_train_idx)
            else:
                # Later rounds: extend the existing clean/noise pools.
                for i, cls in enumerate(negetive_class_list):
                    print('cls', cls, 'tao: ', tao[cls])
                    num_clean_cls = int(1 * self.args.clean_threshold * len(clean_idx[i]))
                    num_noise_cls = int(1 * self.args.noise_threshold * len(noise_idx[i]))
                    num_train = num_train + num_noise_cls + num_clean_cls
                    max_m_indices_list = np.array(max_m_indices(similarity[i], num_clean_cls))
                    min_n_indices_list = np.array(min_n_indices(similarity[i], num_noise_cls))

                    if len(max_m_indices_list) == 0 and len(max_m_indices_list) == 0:
                        negcls_clean_train_idx = []
                        negcls_noise_train_idx = []
                    elif len(min_n_indices_list) == 0 and len(max_m_indices_list) != 0:
                        negcls_noise_train_idx = []
                        negcls_clean_train_idx = np.array(idx[i])[max_m_indices_list].tolist()
                    elif len(min_n_indices_list) != 0 and len(max_m_indices_list) == 0:
                        negcls_noise_train_idx = np.array(idx[i])[min_n_indices_list].tolist()
                        negcls_clean_train_idx = []
                    else:
                        negcls_clean_train_idx = np.array(idx[i])[max_m_indices_list].tolist()
                        negcls_noise_train_idx = np.array(idx[i])[min_n_indices_list].tolist()
                    self.traindata_idx[2*i].extend(negcls_clean_train_idx)
                    self.traindata_idx[2*i+1].extend(negcls_noise_train_idx)

            t4 = time.time()
            # print('traindata_split_time: ', t4 - t3)

            for i, cls in enumerate(negetive_class_list):
                print('class: ', cls, 'clean_train_samples: ', len(self.traindata_idx[2*i]))
                print('class: ', cls, 'noise_train_samples: ', len(self.traindata_idx[2 * i+1]))
                self.class_num_list[cls] = len(self.traindata_idx[2 * i+1])
            # if rnd % 10 == 0:
            #     real_clean = 0
            #     real_noise = 0
            #     # print(self.dataset[self.traindata_idx[2*i]]['target'][cls])
            #     # print(type(self.dataset[self.traindata_idx[2*i]]['target'][cls]))
            #     # input()
            #     for j in self.traindata_idx[2*i]:
            #         if self.dataset[j]['target'][cls] == 0:
            #             real_clean += 1
            #     print('clean_acc: ', real_clean/len(self.traindata_idx[2*i]))
            #     for j in self.traindata_idx[2*i+1]:
            #         if self.dataset[j]['target'][cls] == 1:
            #             real_noise += 1
            #     if len(self.traindata_idx[2 * i+1]) == 0:
            #         print('noise_acc: ', 0)
            #         writer1.add_scalar(f'noise_acc/client{self.client_id}/class{cls}',
            #                            0, rnd)
            #     else:
            #         print('noise_acc: ', real_noise/len(self.traindata_idx[2*i+1]))
            #         writer1.add_scalar(f'noise_acc/client{self.client_id}/class{cls}',
            #                            real_noise / len(self.traindata_idx[2 * i + 1]), rnd)
            #     writer1.add_scalar(f'clean_acc/client{self.client_id}/class{cls}', real_clean/len(self.traindata_idx[2*i]), rnd)
            t5 = time.time()
            # print('acc_compute_time: ', t5 - t4)
            # train
            net.train()
            glob_model.eval()
            # set the optimizer
            self.optimizer = torch.optim.Adam(
                net.parameters(), lr=self.lr, betas=(0.9, 0.999), weight_decay=5e-4)
            # train and update
            epoch_loss = []
            # Re-weight missing classes by the mined clean/noise ratio.
            loss_w = self.loss_w
            for i, cls in enumerate(negetive_class_list):
                if len(self.traindata_idx[2*i+1]) != 0:
                    loss_w[cls] = len(self.traindata_idx[2*i]) / len(self.traindata_idx[2*i+1])
                else:
                    loss_w[cls] = 5.0
            print(loss_w)
            bce_criterion_sup = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(loss_w).cuda(),
                                                     reduction='none') # include sigmoid
            mse_loss = nn.MSELoss(reduction='none')

            for epoch in range(self.args.local_ep):
                print('local_epoch:', epoch)
                batch_loss = []
                # Fresh pseudo-labelled dataset each epoch (uses the mined pools).
                dataset = DatasetSplit_pseudo(self.dataset, self.idxs, self.client_id, self.args,
                                              active_class_list_client_i, negetive_class_list, self.traindata_idx)
                dataloader = DataLoader(dataset, batch_size=self.args.batch_size, shuffle=True,
                                        num_workers=8)
                for samples, item, distill_cls in dataloader:
                    # distill_cls marks (sample, class) cells excluded from supervision.
                    distill_cls = distill_cls.cuda()
                    sup_cls = (~distill_cls.bool()).float().cuda()
                    criterion = LogitAdjust_Multilabel(cls_num_list=self.class_num_list,
                                                       num=len(self.idxs))
                    images1, images2, labels = samples["image_aug_1"].to(self.args.device), samples["image_aug_2"].to(self.args.device), samples["target"].to(
                        self.args.device)
                    feature, outputs = net(images1)
                    logits1 = torch.sigmoid(outputs).cuda()
                    with torch.no_grad():
                        _, outputs_global = glob_model(images1)
                        logits2 = torch.sigmoid(outputs_global).cuda()
                    # loss_sup = bce_criterion_sup(outputs, labels).cuda()
                    loss_sup = criterion(logits1, labels).cuda()
                    loss_dis = mse_loss(logits1, logits2).cuda()

                    # loss = ((loss_sup * sup_cls).sum() + (loss_dis * distill_cls).sum()) / (sup_cls.sum() + distill_cls.sum())
                    loss = (loss_sup * sup_cls).sum() / sup_cls.sum()
                    # print(loss)
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()
                    batch_loss.append(loss.item())
                    self.iter_num += 1
                self.epoch = self.epoch + 1
                epoch_loss.append(np.array(batch_loss).mean())
            # Remaining (unassigned) indices per missing class for next round.
            for i in range(len(self.traindata_idx) // 2):
                idxs = self.traindata_idx[2 * i] + self.traindata_idx[2 * i + 1]
                idxss.append(idxs)
            t6 = time.time()
            # print('local_train_time: ', t6 - t5)
            self.idxss = idxss
            for i in range(len(idxss)):
                self.idxss[i] = list(set(self.idxs)-set(idxss[i])) # [[negcls1_else_idx], [negcls2_else_idx]]

            # proto = np.array(
            #     [torch.zeros_like(feature[0].cpu())] * self.args.n_classes * 2) # [cls0proto0, cls0proto1, cls1proto0...]
            # Recompute prototypes and the confidence ratio t on the local set.
            proto = torch.zeros((self.args.n_classes * 2, len(feature[0]))) # [cls0proto0, cls0proto1, cls1proto0...]
            num_proto = [0] * self.args.n_classes * 2
            t = np.array([0] * self.args.n_classes)
            # test_idx = set() # partial_proto
            # for i in range(len(self.traindata_idx) // 2):
            #     test_idx = test_idx.union(set(self.traindata_idx[2 * i]))
            #     test_idx = test_idx.union(set(self.traindata_idx[2 * i + 1]))
            # test_dataset = DatasetSplit(self.dataset, list(test_idx), self.client_id, self.args, self.class_neg_idx,
            #                             active_class_list_client_i)
            # test_loader = DataLoader(dataset=test_dataset, batch_size=self.args.batch_size * 4, shuffle=False,
            #                          num_workers=8)

            test_loader = DataLoader(dataset=self.local_dataset, batch_size=self.args.batch_size * 4, shuffle=False,
                                     num_workers=8)
            net.eval()
            with torch.no_grad():
                for samples, _, _ in test_loader:
                    images1, labels = samples["image_aug_1"].to(self.args.device), samples["target"].to(
                        self.args.device)
                    feature, outputs = net(images1)
                    probs = torch.sigmoid(outputs) # soft predict
                    for cls in self.local_dataset.active_class_list:
                        idx0 = torch.where(labels[:, cls] == 0)[0]
                        idx1 = torch.where(labels[:, cls] == 1)[0]
                        num_proto[2 * cls] += len(idx0)
                        num_proto[2 * cls + 1] += len(idx1)
                        proto[2 * cls] += feature[idx0, :].sum(0).cpu()
                        proto[2 * cls + 1] += feature[idx1, :].sum(0).cpu()
                        # t[cls] += torch.sum(probs[idx0, cls] < self.args.L).item() + torch.sum(
                        #     probs[idx1, cls] > self.args.U).item()
                    for cls in negetive_class_list:
                        t[cls] += torch.sum(torch.logical_or(probs[:, cls] < self.args.L, probs[:, cls] > self.args.U)).item()
            # Average the accumulated feature sums (guarding empty classes).
            for cls in self.local_dataset.active_class_list:
                if num_proto[2 * cls] == 0:
                    proto[2 * cls] = proto[2 * cls]
                else:
                    proto[2 * cls] = proto[2 * cls] / num_proto[2 * cls]
                if num_proto[2 * cls+1] == 0:
                    proto[2 * cls+1] = proto[2 * cls+1]
                else:
                    proto[2 * cls+1] = proto[2 * cls+1] / num_proto[2 * cls+1]
            t = t / len(self.local_dataset)
            print('local_t: ', t)
            net.cpu()
            self.optimizer.zero_grad()
            t7 = time.time()
            print('local_test_proto_time: ', t7 - t6)
            return net.state_dict(), np.array(
                epoch_loss).mean(), _, _, negetive_class_list, self.local_dataset.active_class_list, t, proto
1257 |
1258 | def js(self, p_output, q_output):
1259 | """
1260 | :param predict: model prediction for original data
1261 | :param target: model prediction for mildly augmented data
1262 | :return: loss
1263 | """
1264 | KLDivLoss = nn.KLDivLoss(reduction='mean')
1265 | log_mean_output = ((p_output + q_output) / 2).log()
1266 | return (KLDivLoss(log_mean_output, p_output) + KLDivLoss(log_mean_output, q_output)) / 2
1267 |
1268 | def anti_sigmoid(self, p):
1269 | return torch.log(p / (1 - p))
    def train_FedLSR(self, rnd, net):
        """One round of local training with FedLSR-style self-regularization.

        Two augmented views are predicted; their temperature-sharpened
        sigmoid outputs (the *3 factor, i.e. T_d = 1/3) are pulled together
        with a JS-divergence penalty whose weight `betaa` ramps linearly
        from 0 to 0.4 over the first `args.t_w` rounds, while the mixed
        prediction is supervised with class-weighted BCE.

        Returns:
            (state_dict, mean epoch loss, _, _, negative class list,
            active class list) — the `_` slots are placeholders reusing the
            last forward's feature output.
        """
        assert len(self.ldr_train.dataset) == len(self.idxs)
        print(f"Client ID: {self.client_id}, Num: {len(self.ldr_train.dataset)}")
        net.train()
        # set the optimizer
        self.optimizer = torch.optim.Adam(
            net.parameters(), lr=self.lr, betas=(0.9, 0.999), weight_decay=5e-4)
        # train and update
        epoch_loss = []
        print(self.loss_w)
        bce_criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(self.loss_w).cuda()) # include sigmoid
        for epoch in range(self.args.local_ep):
            print('local_epoch:', epoch)
            batch_loss = []
            for samples, item, active_class_list in self.ldr_train:
                # Rebuild the active / missing class split for this client.
                active_class_list_client = []
                negetive_class_list_client = []
                for i in range(self.args.annotation_num):
                    active_class_list_client.append(active_class_list[i][0].item())
                for i in range(self.args.n_classes):
                    if i not in active_class_list_client:
                        negetive_class_list_client.append(i)
                images1, images2, labels = samples["image_aug_1"].to(self.args.device), samples["image_aug_2"].to(self.args.device), samples["target"].to(self.args.device)

                _, logits1_ori = net(images1)
                _, logits2_ori = net(images2)
                mix_1 = np.random.beta(1, 1) # mixing predict1 and predict2
                mix_2 = 1 - mix_1
                # to further conduct self distillation, *3 means the temperature T_d is 1/3
                logits1 = torch.sigmoid(logits1_ori * 3).cuda()
                logits2 = torch.sigmoid(logits2_ori * 3).cuda()

                # for training stability to conduct clamping to avoid exploding gradients, which is also used in Symmetric CE, ICCV 2019
                logits1, logits2 = torch.clamp(logits1, min=1e-6, max=1.0), torch.clamp(logits2, min=1e-6, max=1.0)

                # to mix up the two predictions
                p = torch.sigmoid(logits1_ori) * mix_1 + torch.sigmoid(logits2_ori) * mix_2
                p = self.anti_sigmoid(p)
                pred_mix = torch.sigmoid(p * 2).cuda()
                # JS weight ramps up during the warm-up window.
                betaa = 0.4
                if (rnd < self.args.t_w):
                    betaa = 0.4 * rnd / self.args.t_w
                # NOTE(review): pred_mix is already a probability but is fed to
                # BCEWithLogitsLoss, which applies sigmoid again — confirm intended.
                loss = bce_criterion(pred_mix, labels) # to compute cross entropy loss
                # print(loss)
                loss += self.js(logits1, logits2) * betaa
                # print(loss)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                batch_loss.append(loss.item())
                self.iter_num += 1
            self.epoch = self.epoch + 1
            epoch_loss.append(np.array(batch_loss).mean())
        net.cpu()
        self.optimizer.zero_grad()
        return net.state_dict(), np.array(epoch_loss).mean(), _, _, negetive_class_list_client, active_class_list_client
1327 |
class DatasetSplit(Dataset):
    """Per-client view of a dataset restricted to `idxs`.

    Labels of classes outside the client's annotated (active) set are set
    to 0 for indices listed in `class_neg_idx`; optionally, indices in
    `corr_idx` get the corresponding negative class re-labelled 1.
    """

    def __init__(self, dataset, idxs, client_id, args, class_neg_idx, active_class_list=None, negative_class_list=None, corr_idx=None):
        self.dataset = dataset
        self.idxs = list(idxs)
        self.client_id = client_id  # choose active classes
        self.class_neg_idx = class_neg_idx
        self.negative_class_list = negative_class_list
        self.corr_idx = corr_idx
        self.annotation_num = args.annotation_num
        if active_class_list is not None:
            self.active_class_list = active_class_list
        else:
            # No list supplied: pick the client's annotated classes at random.
            self.active_class_list = random.sample(list(range(args.n_classes)), self.annotation_num)
        logging.info(f"Client ID: {self.client_id}, active_class_list: {self.active_class_list}")

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        real_idx = self.idxs[item]
        sample = self.dataset[real_idx]
        # Mask un-annotated classes to 0 for the flagged indices.
        for cls in range(len(sample['target'])):
            if cls not in self.active_class_list and real_idx in self.class_neg_idx[cls]:
                sample['target'][cls] = 0
        # Optional label correction for mined positive indices.
        if self.corr_idx is not None:
            for pos, cls in enumerate(self.negative_class_list):
                if real_idx in self.corr_idx[pos]:
                    sample['target'][cls] = 1
        return sample, real_idx, self.active_class_list

    def get_num_of_each_class(self, args):
        """Sum the multi-hot targets over this split, one count per class."""
        counts = np.array([0.] * args.n_classes)
        for real_idx in self.idxs:
            counts += self.dataset.targets[real_idx]
        return counts.tolist()
1363 |
1364 |
class DatasetSplit_Mixup(Dataset):
    """Serve mixup-augmented samples drawn from separate clean/noise pools.

    Items at positions < len(clean_idxs) are mixed with a random clean
    sample (flag 0); the remaining positions are mixed with a random noise
    sample (flag 1) and additionally get `negative_class` pseudo-labelled
    positive. `train_ratio` < 1 sub-samples the pools by stretching the
    requested index.
    """

    def __init__(self, dataset, clean_idxs, noise_idxs, args, negative_class, negative_class_list, train_ratio):
        self.dataset = dataset
        self.negative_class = negative_class
        self.clean_idxs = clean_idxs
        self.noise_idxs = noise_idxs
        self.annotation_num = args.annotation_num
        self.negative_class_list = negative_class_list
        self.train_ratio = train_ratio

    def __len__(self):
        if self.train_ratio < 1:
            return int(self.train_ratio * (len(self.clean_idxs) + len(self.noise_idxs)))
        else:
            return int(len(self.clean_idxs) + len(self.noise_idxs))

    def __getitem__(self, item):
        """Return (mixed_image, lam, flag, sample1, sample2).

        flag is 0 for a clean-pool pair, 1 for a noise-pool pair.
        """
        if self.train_ratio < 1:
            item = int(item / self.train_ratio)
        if item < len(self.clean_idxs):  # clean sample mixup
            flag = 0
            index = random.choice(self.clean_idxs)
            sample1 = deepcopy(self.dataset[self.clean_idxs[item]])
            sample2 = deepcopy(self.dataset[index])
            for i in range(len(sample1['target'])):
                if i in self.negative_class_list:
                    sample1['target'][i] = 0
                    sample2['target'][i] = 0
            mixed_x, lam = self.mixup_data(sample1["image_aug_1"], sample2["image_aug_1"])
        else:  # noise sample mixup
            flag = 1
            index = random.choice(self.noise_idxs)
            # BUG FIX: deepcopy here too — the original mutated the underlying
            # dataset samples, permanently flipping target[negative_class] to 1.
            sample1 = deepcopy(self.dataset[self.noise_idxs[item - len(self.clean_idxs)]])
            sample2 = deepcopy(self.dataset[index])
            for i in range(len(sample1['target'])):
                if i in self.negative_class_list:
                    sample1['target'][i] = 0
                    sample2['target'][i] = 0
            mixed_x, lam = self.mixup_data(sample1["image_aug_1"], sample2["image_aug_1"])
            # Pseudo-label the missing class positive on the copies only.
            sample1['target'][self.negative_class] = 1
            sample2['target'][self.negative_class] = 1
        return mixed_x, lam, flag, sample1, sample2

    def mixup_data(self, x1, x2, alpha=1.0):
        '''Compute the mixup data. Return mixed inputs, pairs of targets, and lambda'''
        if alpha > 0.:
            lam = np.random.beta(alpha, alpha)
        else:
            lam = 1.
        mixed_x = lam * x1 + (1 - lam) * x2
        return mixed_x, lam
1416 |
class CosineSimilarityFast(torch.nn.Module):
    """Batched cosine similarity between the rows of two matrices.

    forward(x1, x2) with x1 of shape (N, D) and x2 of shape (M, D)
    computes the (N, M) cosine-similarity matrix and squeezes dim 1,
    yielding a length-N vector when M == 1.
    """

    def __init__(self):
        super(CosineSimilarityFast, self).__init__()

    def forward(self, x1, x2):
        # Dot products between every row of x1 and every row of x2.
        dots = x1.mm(x2.t())
        # Outer product of the row norms gives the normalising factors.
        norms = x1.norm(dim=1).unsqueeze(1).mm(x2.norm(dim=1).unsqueeze(0))
        cosine = dots.mul(1 / norms)
        return torch.squeeze(cosine, dim=1)
1436 |
class DatasetSplit_pseudo(Dataset):
    """Per-client dataset that injects pseudo labels for missing classes.

    `traindata_idx` is laid out as [cls_i clean ids, cls_i noise ids, ...]
    per negative class. Un-annotated class labels are zeroed; a sample in a
    noise pool gets that class pseudo-labelled 1; a sample in a clean pool
    gets the class flagged for distillation instead of supervision.
    """

    def __init__(self, dataset, idxs, client_id, args, active_class_list, negative_class_list, traindata_idx):
        self.dataset = dataset
        self.idxs = list(idxs)
        self.client_id = client_id  # choose active classes
        self.args = args
        self.annotation_num = args.annotation_num
        self.active_class_list = active_class_list
        self.negative_class_list = negative_class_list
        self.traindata_idx = traindata_idx
        # Indices with a confident clean/noise assignment...
        confident = []
        for pair in range(len(traindata_idx) // 2):
            confident += traindata_idx[2 * pair] + traindata_idx[2 * pair + 1]
        self.idx_conf = confident
        # ...and the remaining, not-yet-assigned indices.
        self.idx_nconf = list(set(self.idxs) - set(confident))
        logging.info(f"Client ID: {self.client_id}, active_class_list: {self.active_class_list}")

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        """Return (sample, dataset index, distill-class mask)."""
        real_idx = self.idxs[item]
        distill_cls = torch.zeros(self.args.n_classes)
        sample = self.dataset[real_idx]
        # Un-annotated classes default to negative.
        for cls in range(len(sample['target'])):
            if cls not in self.active_class_list:
                sample['target'][cls] = 0
        for pair, neg_cls in enumerate(self.negative_class_list):
            clean_ids = self.traindata_idx[2 * pair]
            noise_ids = self.traindata_idx[2 * pair + 1]
            if real_idx in noise_ids:
                # Mined positive: pseudo-label this class 1.
                sample['target'][neg_cls] = 1
            elif real_idx in clean_ids:
                # Mined negative: distil this class instead of supervising it.
                distill_cls[neg_cls] = 1
        # if self.idxs[item] in self.idx_nconf: # mixup
        #     mixed_sample = copy.deepcopy(sample)
        #     index = random.choice(self.idx_nconf)
        #     sample2 = self.dataset[index]
        #     mixed_sample["image_aug_1"], lam = self.mixup_data(sample["image_aug_1"], sample2["image_aug_1"])
        #     mixed_sample['target'] = lam * sample['target'] + (1 - lam) * sample2['target']
        #     return mixed_sample, self.idxs[item], distill_cls
        return sample, real_idx, distill_cls

    def get_num_of_each_class(self, args):
        """Sum the multi-hot targets over this split, one count per class."""
        counts = np.array([0.] * args.n_classes)
        for real_idx in self.idxs:
            counts += self.dataset.targets[real_idx]
        return counts.tolist()

    def mixup_data(self, x1, x2, alpha=1.0):
        '''Compute the mixup data. Return mixed inputs, pairs of targets, and lambda'''
        if alpha > 0.:
            lam = np.random.beta(alpha, alpha)
        else:
            lam = 1.
        mixed_x = lam * x1 + (1 - lam) * x2
        return mixed_x, lam
1493 |
--------------------------------------------------------------------------------