├── dataset ├── data │ ├── STL_client_num=10_alpha=0.02.png │ ├── STL_client_num=10_alpha=0.05.png │ ├── FMNIST_client_num=10_alpha=0.02.png │ ├── FMNIST_client_num=10_alpha=0.05.png │ ├── STL32_client_num=10_alpha=0.02.png │ ├── STL32_client_num=10_alpha=0.05.png │ ├── CIFAR10_client_num=10_alpha=0.02.png │ ├── CIFAR10_client_num=10_alpha=0.05.png │ ├── OCTMNIST_client_num=10_alpha=0.02.png │ ├── OCTMNIST_client_num=10_alpha=0.05.png │ ├── PathMNIST_client_num=10_alpha=0.02.png │ ├── PathMNIST_client_num=10_alpha=0.05.png │ ├── ImageNette_client_num=10_alpha=0.02.png │ ├── ImageNette_client_num=10_alpha=0.05.png │ ├── OrganSMNIST_client_num=10_alpha=0.02.png │ ├── OrganSMNIST_client_num=10_alpha=0.05.png │ ├── OrganCMNIST224_client_num=10_alpha=0.02.png │ ├── OrganCMNIST224_client_num=10_alpha=0.05.png │ ├── RetinaMNIST224_client_num=10_alpha=0.02.png │ ├── RetinaMNIST224_client_num=10_alpha=0.05.png │ ├── PneumoniaMNIST224_client_num=10_alpha=0.02.png │ ├── PneumoniaMNIST224_client_num=10_alpha=0.05.png │ ├── dataset.py │ └── dataset_partition.py └── split_file │ ├── RetinaMNIST224_client_num=10_alpha=0.02.json │ └── RetinaMNIST224_client_num=10_alpha=0.05.json ├── run.sh ├── LICENSE ├── README.md ├── config.py ├── main.py ├── requirements.yaml └── src ├── models.py ├── client.py ├── server.py └── utils.py /dataset/data/STL_client_num=10_alpha=0.02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Youth-49/FedVCK_2024/HEAD/dataset/data/STL_client_num=10_alpha=0.02.png -------------------------------------------------------------------------------- /dataset/data/STL_client_num=10_alpha=0.05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Youth-49/FedVCK_2024/HEAD/dataset/data/STL_client_num=10_alpha=0.05.png -------------------------------------------------------------------------------- 
/dataset/data/FMNIST_client_num=10_alpha=0.02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Youth-49/FedVCK_2024/HEAD/dataset/data/FMNIST_client_num=10_alpha=0.02.png -------------------------------------------------------------------------------- /dataset/data/FMNIST_client_num=10_alpha=0.05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Youth-49/FedVCK_2024/HEAD/dataset/data/FMNIST_client_num=10_alpha=0.05.png -------------------------------------------------------------------------------- /dataset/data/STL32_client_num=10_alpha=0.02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Youth-49/FedVCK_2024/HEAD/dataset/data/STL32_client_num=10_alpha=0.02.png -------------------------------------------------------------------------------- /dataset/data/STL32_client_num=10_alpha=0.05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Youth-49/FedVCK_2024/HEAD/dataset/data/STL32_client_num=10_alpha=0.05.png -------------------------------------------------------------------------------- /dataset/data/CIFAR10_client_num=10_alpha=0.02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Youth-49/FedVCK_2024/HEAD/dataset/data/CIFAR10_client_num=10_alpha=0.02.png -------------------------------------------------------------------------------- /dataset/data/CIFAR10_client_num=10_alpha=0.05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Youth-49/FedVCK_2024/HEAD/dataset/data/CIFAR10_client_num=10_alpha=0.05.png -------------------------------------------------------------------------------- /dataset/data/OCTMNIST_client_num=10_alpha=0.02.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Youth-49/FedVCK_2024/HEAD/dataset/data/OCTMNIST_client_num=10_alpha=0.02.png -------------------------------------------------------------------------------- /dataset/data/OCTMNIST_client_num=10_alpha=0.05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Youth-49/FedVCK_2024/HEAD/dataset/data/OCTMNIST_client_num=10_alpha=0.05.png -------------------------------------------------------------------------------- /dataset/data/PathMNIST_client_num=10_alpha=0.02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Youth-49/FedVCK_2024/HEAD/dataset/data/PathMNIST_client_num=10_alpha=0.02.png -------------------------------------------------------------------------------- /dataset/data/PathMNIST_client_num=10_alpha=0.05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Youth-49/FedVCK_2024/HEAD/dataset/data/PathMNIST_client_num=10_alpha=0.05.png -------------------------------------------------------------------------------- /dataset/data/ImageNette_client_num=10_alpha=0.02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Youth-49/FedVCK_2024/HEAD/dataset/data/ImageNette_client_num=10_alpha=0.02.png -------------------------------------------------------------------------------- /dataset/data/ImageNette_client_num=10_alpha=0.05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Youth-49/FedVCK_2024/HEAD/dataset/data/ImageNette_client_num=10_alpha=0.05.png -------------------------------------------------------------------------------- /dataset/data/OrganSMNIST_client_num=10_alpha=0.02.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Youth-49/FedVCK_2024/HEAD/dataset/data/OrganSMNIST_client_num=10_alpha=0.02.png -------------------------------------------------------------------------------- /dataset/data/OrganSMNIST_client_num=10_alpha=0.05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Youth-49/FedVCK_2024/HEAD/dataset/data/OrganSMNIST_client_num=10_alpha=0.05.png -------------------------------------------------------------------------------- /dataset/data/OrganCMNIST224_client_num=10_alpha=0.02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Youth-49/FedVCK_2024/HEAD/dataset/data/OrganCMNIST224_client_num=10_alpha=0.02.png -------------------------------------------------------------------------------- /dataset/data/OrganCMNIST224_client_num=10_alpha=0.05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Youth-49/FedVCK_2024/HEAD/dataset/data/OrganCMNIST224_client_num=10_alpha=0.05.png -------------------------------------------------------------------------------- /dataset/data/RetinaMNIST224_client_num=10_alpha=0.02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Youth-49/FedVCK_2024/HEAD/dataset/data/RetinaMNIST224_client_num=10_alpha=0.02.png -------------------------------------------------------------------------------- /dataset/data/RetinaMNIST224_client_num=10_alpha=0.05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Youth-49/FedVCK_2024/HEAD/dataset/data/RetinaMNIST224_client_num=10_alpha=0.05.png -------------------------------------------------------------------------------- 
/dataset/data/PneumoniaMNIST224_client_num=10_alpha=0.02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Youth-49/FedVCK_2024/HEAD/dataset/data/PneumoniaMNIST224_client_num=10_alpha=0.02.png -------------------------------------------------------------------------------- /dataset/data/PneumoniaMNIST224_client_num=10_alpha=0.05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Youth-49/FedVCK_2024/HEAD/dataset/data/PneumoniaMNIST224_client_num=10_alpha=0.05.png -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | 2 | python main.py --dataset PathMNIST --model ConvNetBN --lr_server 0.001 --weight_decay_server 1e-6 --compression_ratio 0.01 --dc_iterations 5000 --image_lr 0.2 --init random_noise --clip_norm 10 --weighted_mmd --b 0 --contrastive_way supcon_asym_syn --con_beta 0.05 --device cuda:0 --topk 5 --alpha 0.05 --tag 2-2-1 3 | python main.py --dataset OrganSMNIST --model ConvNetBN --compression_ratio 0.05 --lr_server 0.001 --dc_iterations 10000 --image_lr 0.1 --init random_noise --weighted_sample --contrastive_way supcon_asym_syn --con_beta 0.1 --topk 5 --con_temp 0.1 --device cuda:0 --alpha 0.05 --tag 6-2-1 4 | python main.py --model ConvNetBN --compression_ratio 0.02 --lr_server 0.005 --weight_decay_server 1e-6 --dc_iterations 10000 --image_lr 1.0 --init random_noise --weighted_mmd --b 0.3 --contrastive_way supcon_asym_syn --con_beta 0.1 --topk 5 --con_temp 0.1 --device cuda:0 --alpha 0.05 --tag 1-2-1 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Guochen Yan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a 
copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FedVCK@AAAI 2025 2 | 3 | FedVCK: Non-IID Robust and Communication-Efficient Federated Learning via Valuable Condensed Knowledge for Medical Image Analysis, Accepted by AAAI 2025 4 | 5 | 6 | 7 | ## Abstract 8 | 9 | Federated learning has become a promising solution for collaboration among medical institutions. However, data owned by each institution would be highly heterogeneous and the distribution is always non-independent and identical distribution (non-IID), resulting in client drift and unsatisfactory performance. Despite existing federated learning methods attempting to solve the non-IID problems, they still show marginal advantages but rely on frequent communication which would incur high costs and privacy concerns. 
In this paper, we propose a novel federated learning method: **Fed**erated learning via **V**aluable **C**ondensed **K**nowledge (FedVCK). We enhance the quality of condensed knowledge and select the most necessary knowledge guided by models, to tackle the non-IID problem within limited communication budgets effectively. Specifically, on the client side, we condense the knowledge of each client into a small dataset and further enhance the condensation procedure with latent distribution constraints, facilitating the effective capture of high-quality knowledge. During each round, we specifically target and condense knowledge that has not been assimilated by the current model, thereby preventing unnecessary repetition of homogeneous knowledge and minimizing the frequency of communications required. On the server side, we propose relational supervised contrastive learning to provide more supervision signals to aid the global model updating. Comprehensive experiments across various medical tasks show that FedVCK can outperform state-of-the-art methods, demonstrating that it's non-IID robust and communication-efficient. 10 | 11 | 12 | 13 | ----- 14 | 15 | 16 | 17 | Software Environment: see `requirements.yaml` 18 | 19 | Hardware Platform: Ubuntu with GeForce RTX 3090 GPU 20 | 21 | Supported datasets: {Path, OCT, OrganS, OrganC, Pneumonia}MNIST, CIFAR10, STL10, ImageNette.
22 | 23 | 24 | 25 | To run the code: 26 | 27 | ```bash 28 | bash run.sh 29 | ``` 30 | 31 | 32 | 33 | 34 | 35 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | parser = argparse.ArgumentParser() 4 | 5 | # parser.add_argument("--debug", type=bool, default=False) 6 | parser.add_argument("--seed", type=int, default=19260817) 7 | parser.add_argument("--device", type=str, default="cuda:0") 8 | 9 | parser.add_argument("--dataset_root", type=str, default="./dataset/torchvision") 10 | parser.add_argument("--split_file", type=str, default="") 11 | parser.add_argument("--dataset", type=str, default='CIFAR10') 12 | parser.add_argument("--client_num", type=int, default=10) 13 | parser.add_argument("--partition", type=str, default='dirichlet') 14 | parser.add_argument("--alpha", type=float, default=0.5) 15 | parser.add_argument('--num_classes_per_client', type=int, default=2, help="label_split") 16 | 17 | parser.add_argument("--model", type=str, default="ConvNet") 18 | parser.add_argument("--communication_rounds", type=int, default=10) 19 | parser.add_argument("--join_ratio", type=float, default=1.0) 20 | parser.add_argument("--lr_server", type=float, default=0.01) 21 | parser.add_argument("--momentum_server", type=float, default=0.9) 22 | parser.add_argument("--weight_decay_server", type=float, default=0) 23 | 24 | parser.add_argument("--batch_size", type=int, default=256) 25 | parser.add_argument("--model_epochs", type=int, default=1000) 26 | parser.add_argument("--local_ep", type=int, default=40) 27 | parser.add_argument("--ipc", type=int, default=10) 28 | parser.add_argument("--compression_ratio", type=float, default=0.) 
29 | parser.add_argument("--dc_iterations", type=int, default=1000) 30 | parser.add_argument("--dc_batch_size", type=int, default=256) 31 | parser.add_argument("--image_lr", type=float, default=1.0) 32 | parser.add_argument("--image_momentum", type=float, default=0.5) 33 | parser.add_argument("--image_weight_decay", type=float, default=0) 34 | parser.add_argument("--init", type=str, default='real') 35 | parser.add_argument("--clip_norm", type=float, default=30) 36 | parser.add_argument("--weighted_matching", action='store_true', default=False) 37 | parser.add_argument("--weighted_sample", action='store_true', default=False) 38 | parser.add_argument("--weighted_mmd", action='store_true', default=False) 39 | parser.add_argument("--contrastive_way", type=str, default='supcon_asym_syn', choices=['supcon_asym', 'supcon_asym_syn', 'supcon_relation']) 40 | parser.add_argument("--con_beta", type=float, default=0.) 41 | parser.add_argument("--con_temp", type=float, default=1.0) 42 | parser.add_argument("--topk", type=int, default=3) 43 | parser.add_argument("--lr_head", type=float, default=0.01) 44 | parser.add_argument("--momentum_head", type=float, default=0.9) 45 | parser.add_argument("--weight_decay_head", type=float, default=0) 46 | parser.add_argument("--gamma", type=float, default=1.0) 47 | parser.add_argument("--lamda", type=float, default=0.5) 48 | parser.add_argument("--b", type=float, default=0.7) 49 | parser.add_argument("--kernel", type=str, default='linear') 50 | 51 | parser.add_argument("--lr", type=float, default=0.01) 52 | parser.add_argument("--momentum", type=float, default=0.5) 53 | parser.add_argument("--weight_decay", type=float, default=0) 54 | 55 | parser.add_argument('--dsa_strategy', type=str, default='color_crop_cutout_flip_scale_rotate', help='differentiable Siamese augmentation strategy') 56 | parser.add_argument("--preserve_all", action='store_true', default=False) 57 | 58 | parser.add_argument("--eval_gap", type=int, default=1) 59 | 60 | 
parser.add_argument("--tag", type=str, default='0') 61 | parser.add_argument("--save_root_path", type=str, default='../results/') -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | import os 4 | import sys 5 | 6 | import numpy as np 7 | import torch 8 | from torch.utils.data import Subset 9 | 10 | from src.client import Client 11 | from src.server import Server 12 | from config import parser 13 | from dataset.data.dataset import get_dataset, PerLabelDatasetNonIID 14 | from src.utils import setup_seed, get_model, ParamDiffAug 15 | import logging 16 | 17 | def get_n_params(model): 18 | pp = 0 19 | for p in list(model.parameters()): 20 | nn = 1 21 | for s in list(p.size()): 22 | nn = nn * s 23 | pp += nn 24 | return pp 25 | 26 | def main(): 27 | # os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1" 28 | args = parser.parse_args() 29 | args.dsa_param = ParamDiffAug() 30 | args.dsa = False if args.dsa_strategy == 'None' else True 31 | 32 | 33 | if args.partition == 'dirichlet': 34 | split_file = f'/{args.dataset}_client_num={args.client_num}_alpha={args.alpha}.json' 35 | args.split_file = os.path.join(os.path.dirname(__file__), "dataset/split_file"+split_file) 36 | if args.compression_ratio > 0.: 37 | model_identification = f'{args.dataset}_alpha{args.alpha}_{args.client_num}clients/{args.model}_{100*args.compression_ratio}%_{args.dc_iterations}dc_{args.model_epochs}epochs_{args.tag}' 38 | else: 39 | model_identification = f'{args.dataset}_alpha{args.alpha}_{args.client_num}clients/{args.model}_{args.ipc}ipc_{args.dc_iterations}dc_{args.model_epochs}epochs_{args.tag}' 40 | # raise Exception('Compression ratio should > 0') 41 | elif args.partition == 'label': 42 | split_file = f'/{args.dataset}_client_num={args.client_num}_label={args.num_classes_per_client}.json' 43 | args.split_file = 
os.path.join(os.path.dirname(__file__), "dataset/split_file"+split_file) 44 | if args.compression_ratio > 0.: 45 | model_identification = f'{args.dataset}_label{args.num_classes_per_client}_{args.client_num}clients/{args.model}_{100*args.compression_ratio}%_{args.dc_iterations}dc_{args.model_epochs}epochs_{args.tag}' 46 | else: 47 | raise Exception('Compression ratio should > 0') 48 | elif args.partition == 'pathological': 49 | split_file = f'/{args.dataset}_client_num={args.client_num}_pathological={args.num_classes_per_client}.json' 50 | args.split_file = os.path.join(os.path.dirname(__file__), "dataset/split_file"+split_file) 51 | if args.compression_ratio > 0.: 52 | model_identification = f'{args.dataset}_pathological{args.num_classes_per_client}_{args.client_num}clients/{args.model}_{100*args.compression_ratio}%_{args.dc_iterations}dc_{args.model_epochs}epochs_{args.tag}' 53 | else: 54 | raise Exception('Compression ratio should > 0') 55 | 56 | args.save_root_path = os.path.join(os.path.dirname(__file__), 'results/') 57 | args.save_root_path = os.path.join(args.save_root_path, model_identification) 58 | os.makedirs(args.save_root_path, exist_ok=True) 59 | log_format = '%(message)s' 60 | logging.basicConfig(stream=sys.stdout, level=logging.INFO, format=log_format) 61 | log_file = 'log.txt' 62 | log_path = os.path.join(args.save_root_path, log_file) 63 | print(log_path) 64 | if os.path.exists(log_path): 65 | raise Exception('log file already exists!') 66 | fh = logging.FileHandler(log_path, mode='w') 67 | fh.setFormatter(logging.Formatter(log_format)) 68 | logging.getLogger().addHandler(fh) 69 | 70 | setup_seed(args.seed) 71 | device = torch.device(args.device) 72 | torch.cuda.set_device(device) 73 | 74 | # get dataset and init models 75 | dataset_info, train_set, test_set, test_loader = get_dataset(args.dataset, args.dataset_root, args.batch_size) 76 | print("load data: done") 77 | with open(args.split_file, 'r') as file: 78 | file_data = json.load(file) 79 | 
client_indices, client_classes = file_data['client_idx'], file_data['client_classes'] 80 | 81 | if args.dataset in ['CIFAR10', 'FMNIST',]: 82 | labels = np.array(train_set.targets, dtype='int64') 83 | elif args.dataset in ['PathMNIST', 'OCTMNIST', 'OrganSMNIST', 'OrganCMNIST', 'ImageNette', 'OrganCMNIST224', 'PneumoniaMNIST224', 'RetinaMNIST224', 'STL', 'STL32']: 84 | labels = train_set.labels 85 | net_cls_counts = {} 86 | dict_users = {i: idcs for i, idcs in enumerate(client_indices)} 87 | for net_i, dataidx in dict_users.items(): 88 | unq, unq_cnt = np.unique(labels[dataidx], return_counts=True) 89 | tmp = {unq[i]: unq_cnt[i] for i in range(len(unq))} 90 | net_cls_counts[net_i] = tmp 91 | 92 | logging.info(f'Data statistics: {net_cls_counts}') 93 | logging.info(f'client classes: {client_classes}') 94 | 95 | train_sets = [Subset(train_set, indices) for indices in client_indices] 96 | 97 | global_model = get_model(args.model, dataset_info) 98 | logging.info(global_model) 99 | logging.info(get_n_params(global_model)) 100 | logging.info(args.__dict__) 101 | 102 | # init server and clients 103 | client_list = [Client( 104 | cid=i, 105 | train_set=PerLabelDatasetNonIID( 106 | train_sets[i], 107 | client_classes[i], 108 | dataset_info['channel'], 109 | device, 110 | ), 111 | classes=client_classes[i], 112 | dataset_info=dataset_info, 113 | ipc=args.ipc, 114 | compression_ratio=args.compression_ratio, 115 | dc_iterations=args.dc_iterations, 116 | real_batch_size=args.dc_batch_size, 117 | image_lr=args.image_lr, 118 | image_momentum=args.image_momentum, 119 | image_weight_decay=args.image_weight_decay, 120 | lr=args.lr, 121 | momentum=args.momentum, 122 | weight_decay=args.weight_decay, 123 | local_ep=args.local_ep, 124 | dsa=args.dsa, 125 | dsa_strategy=args.dsa_strategy, 126 | init = args.init, 127 | clip_norm = args.clip_norm, 128 | gamma = args.gamma, 129 | lamda = args.lamda, 130 | b = args.b, 131 | con_temp = args.con_temp, 132 | kernel = args.kernel, 133 | 
save_root_path=args.save_root_path, 134 | device=device, 135 | ) for i in range(args.client_num)] 136 | 137 | server = Server( 138 | train_set = PerLabelDatasetNonIID( 139 | train_set, 140 | range(0,dataset_info['num_classes']), 141 | dataset_info['channel'], 142 | device, 143 | ), 144 | ipc = args.ipc, 145 | dataset_info=dataset_info, 146 | global_model_name=args.model, 147 | global_model=global_model, 148 | clients=client_list, 149 | communication_rounds=args.communication_rounds, 150 | join_ratio=args.join_ratio, 151 | batch_size=args.batch_size, 152 | model_epochs=args.model_epochs, 153 | lr_server=args.lr_server, 154 | momentum_server=args.momentum_server, 155 | weight_decay_server=args.weight_decay_server, 156 | lr_head=args.lr_head, 157 | momentum_head=args.momentum_head, 158 | weight_decay_head=args.weight_decay_head, 159 | weighted_matching = args.weighted_matching, 160 | weighted_sample = args.weighted_sample, 161 | weighted_mmd = args.weighted_mmd, 162 | contrastive_way = args.contrastive_way, 163 | con_beta = args.con_beta, 164 | con_temp = args.con_temp, 165 | topk = args.topk, 166 | dsa = args.dsa, 167 | dsa_strategy = args.dsa_strategy, 168 | preserve_all = args.preserve_all, 169 | eval_gap=args.eval_gap, 170 | test_set=test_set, 171 | test_loader=test_loader, 172 | device=device, 173 | model_identification=model_identification, 174 | save_root_path=args.save_root_path 175 | ) 176 | print('Server and Clients have been created.') 177 | 178 | # fit the model 179 | server.fit() 180 | 181 | if __name__ == "__main__": 182 | main() -------------------------------------------------------------------------------- /requirements.yaml: -------------------------------------------------------------------------------- 1 | name: dcfl 2 | channels: 3 | - pytorch 4 | - fastai 5 | - nvidia 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - _openmp_mutex=5.1=1_gnu 10 | - annotated-types=0.6.0=py39h06a4308_0 11 | - blas=1.0=mkl 12 | - 
blessed=1.20.0=py39h06a4308_0 13 | - bottleneck=1.3.7=py39ha9d4c09_0 14 | - brotli=1.0.9=h5eee18b_8 15 | - brotli-bin=1.0.9=h5eee18b_8 16 | - brotli-python=1.0.9=py39h6a678d5_8 17 | - bzip2=1.0.8=h5eee18b_6 18 | - ca-certificates=2024.3.11=h06a4308_0 19 | - catalogue=2.0.10=py39h06a4308_0 20 | - certifi=2024.2.2=py39h06a4308_0 21 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 22 | - click=8.1.7=py39h06a4308_0 23 | - cloudpathlib=0.16.0=py39h06a4308_1 24 | - colorama=0.4.6=py39h06a4308_0 25 | - confection=0.1.4=py39h2f386ee_0 26 | - contourpy=1.2.0=py39hdb19cb5_0 27 | - cuda-cudart=11.8.89=0 28 | - cuda-cupti=11.8.87=0 29 | - cuda-libraries=11.8.0=0 30 | - cuda-nvrtc=11.8.89=0 31 | - cuda-nvtx=11.8.86=0 32 | - cuda-runtime=11.8.0=0 33 | - cuda-version=12.4=hbda6634_3 34 | - cycler=0.11.0=pyhd3eb1b0_0 35 | - cymem=2.0.6=py39h295c915_0 36 | - cyrus-sasl=2.1.28=h52b45da_1 37 | - cython-blis=0.7.9=py39h7deecbd_0 38 | - dbus=1.13.18=hb2f20db_0 39 | - expat=2.6.2=h6a678d5_0 40 | - fastai=2.7.15=py_0 41 | - fastcore=1.5.37=py_0 42 | - fastdownload=0.0.7=py_0 43 | - fastprogress=1.0.3=py_0 44 | - ffmpeg=4.2.2=h20bf706_0 45 | - filelock=3.13.1=py39h06a4308_0 46 | - fontconfig=2.14.1=h4c34cd2_2 47 | - fonttools=4.51.0=py39h5eee18b_0 48 | - freetype=2.12.1=h4a9f257_0 49 | - glib=2.78.4=h6a678d5_0 50 | - glib-tools=2.78.4=h6a678d5_0 51 | - gmp=6.2.1=h295c915_3 52 | - gmpy2=2.1.2=py39heeb90bb_0 53 | - gnutls=3.6.15=he1e5248_0 54 | - gpustat=1.1.1=py39h06a4308_0 55 | - gst-plugins-base=1.14.1=h6a678d5_1 56 | - gstreamer=1.14.1=h5eee18b_1 57 | - icu=73.1=h6a678d5_0 58 | - idna=3.7=py39h06a4308_0 59 | - importlib_resources=6.1.1=py39h06a4308_1 60 | - intel-openmp=2023.1.0=hdb19cb5_46306 61 | - jinja2=3.1.3=py39h06a4308_0 62 | - joblib=1.4.0=py39h06a4308_0 63 | - jpeg=9e=h5eee18b_1 64 | - kiwisolver=1.4.4=py39h6a678d5_0 65 | - krb5=1.20.1=h143b758_1 66 | - lame=3.100=h7b6447c_0 67 | - langcodes=3.3.0=pyhd3eb1b0_0 68 | - lcms2=2.12=h3be6417_0 69 | - ld_impl_linux-64=2.38=h1181459_1 70 | - 
lerc=3.0=h295c915_0 71 | - libbrotlicommon=1.0.9=h5eee18b_8 72 | - libbrotlidec=1.0.9=h5eee18b_8 73 | - libbrotlienc=1.0.9=h5eee18b_8 74 | - libclang13=14.0.6=default_he11475f_1 75 | - libcublas=11.11.3.6=0 76 | - libcufft=10.9.0.58=0 77 | - libcufile=1.9.1.3=h99ab3db_1 78 | - libcups=2.4.2=h2d74bed_1 79 | - libcurand=10.3.5.147=h99ab3db_1 80 | - libcusolver=11.4.1.48=0 81 | - libcusparse=11.7.5.86=0 82 | - libdeflate=1.17=h5eee18b_1 83 | - libedit=3.1.20230828=h5eee18b_0 84 | - libffi=3.4.4=h6a678d5_1 85 | - libgcc-ng=11.2.0=h1234567_1 86 | - libgfortran-ng=11.2.0=h00389a5_1 87 | - libgfortran5=11.2.0=h1234567_1 88 | - libglib=2.78.4=hdc74915_0 89 | - libgomp=11.2.0=h1234567_1 90 | - libiconv=1.16=h5eee18b_3 91 | - libidn2=2.3.4=h5eee18b_0 92 | - libjpeg-turbo=2.0.0=h9bf148f_0 93 | - libllvm14=14.0.6=hdb19cb5_3 94 | - libnpp=11.8.0.86=0 95 | - libnvjpeg=11.9.0.86=0 96 | - libopus=1.3.1=h7b6447c_0 97 | - libpng=1.6.39=h5eee18b_0 98 | - libpq=12.17=hdbd6064_0 99 | - libstdcxx-ng=11.2.0=h1234567_1 100 | - libtasn1=4.19.0=h5eee18b_0 101 | - libtiff=4.5.1=h6a678d5_0 102 | - libunistring=0.9.10=h27cfd23_0 103 | - libuuid=1.41.5=h5eee18b_0 104 | - libvpx=1.7.0=h439df22_0 105 | - libwebp-base=1.3.2=h5eee18b_0 106 | - libxcb=1.15=h7f8727e_0 107 | - libxkbcommon=1.0.1=h5eee18b_1 108 | - libxml2=2.10.4=hfdd30dd_2 109 | - llvm-openmp=14.0.6=h9e868ea_0 110 | - lz4-c=1.9.4=h6a678d5_1 111 | - markupsafe=2.1.3=py39h5eee18b_0 112 | - matplotlib=3.8.0=py39h06a4308_0 113 | - matplotlib-base=3.8.0=py39h1128e8f_0 114 | - mkl=2023.1.0=h213fc3f_46344 115 | - mkl-service=2.4.0=py39h5eee18b_1 116 | - mkl_fft=1.3.8=py39h5eee18b_0 117 | - mkl_random=1.2.4=py39hdb19cb5_0 118 | - mpc=1.1.0=h10f8cd9_1 119 | - mpfr=4.0.2=hb69a4c5_1 120 | - mpmath=1.3.0=py39h06a4308_0 121 | - murmurhash=1.0.7=py39h295c915_0 122 | - mysql=5.7.24=h721c034_2 123 | - ncurses=6.4=h6a678d5_0 124 | - nettle=3.7.3=hbbd107a_1 125 | - networkx=3.1=py39h06a4308_0 126 | - numexpr=2.8.7=py39h85018f9_0 127 | - 
numpy=1.26.2=py39h5f9d8c6_0 128 | - numpy-base=1.26.2=py39hb5e798b_0 129 | - nvidia-ml-py=12.535.133=py39h06a4308_0 130 | - openh264=2.1.1=h4ff587b_0 131 | - openjpeg=2.4.0=h3ad879b_0 132 | - openssl=3.0.13=h7f8727e_1 133 | - packaging=23.2=py39h06a4308_0 134 | - pandas=2.2.1=py39h6a678d5_0 135 | - pcre2=10.42=hebb0a14_1 136 | - pillow=10.3.0=py39h5eee18b_0 137 | - pip=24.0=py39h06a4308_0 138 | - ply=3.11=py39h06a4308_0 139 | - preshed=3.0.6=py39h295c915_0 140 | - pydantic=2.5.3=py39h06a4308_0 141 | - pydantic-core=2.14.6=py39hb02cf49_0 142 | - pyparsing=3.0.9=py39h06a4308_0 143 | - pyqt=5.15.10=py39h6a678d5_0 144 | - pyqt5-sip=12.13.0=py39h5eee18b_0 145 | - pysocks=1.7.1=py39h06a4308_0 146 | - python=3.9.18=h955ad1f_0 147 | - python-dateutil=2.9.0post0=py39h06a4308_0 148 | - python-tzdata=2023.3=pyhd3eb1b0_0 149 | - pytorch=2.1.2=py3.9_cuda11.8_cudnn8.7.0_0 150 | - pytorch-cuda=11.8=h7e8668a_5 151 | - pytorch-mutex=1.0=cuda 152 | - pytz=2024.1=py39h06a4308_0 153 | - pyyaml=6.0.1=py39h5eee18b_0 154 | - qt-main=5.15.2=h53bd1ea_10 155 | - readline=8.2=h5eee18b_0 156 | - requests=2.31.0=py39h06a4308_1 157 | - scikit-learn=1.2.2=py39h6a678d5_1 158 | - scipy=1.11.4=py39h5f9d8c6_0 159 | - seaborn=0.12.2=py39h06a4308_0 160 | - setuptools=69.5.1=py39h06a4308_0 161 | - shellingham=1.5.0=py39h06a4308_0 162 | - sip=6.7.12=py39h6a678d5_0 163 | - six=1.16.0=pyhd3eb1b0_1 164 | - smart_open=5.2.1=py39h06a4308_0 165 | - spacy=3.7.2=py39h3c18c91_0 166 | - spacy-legacy=3.0.12=py39h06a4308_0 167 | - spacy-loggers=1.0.4=py39h06a4308_0 168 | - sqlite=3.45.3=h5eee18b_0 169 | - srsly=2.4.8=py39h6a678d5_1 170 | - sympy=1.12=py39h06a4308_0 171 | - tbb=2021.8.0=hdb19cb5_0 172 | - thinc=8.2.2=py39h3c18c91_0 173 | - threadpoolctl=2.2.0=pyh0d69192_0 174 | - tk=8.6.14=h39e8969_0 175 | - tomli=2.0.1=py39h06a4308_0 176 | - torchaudio=2.1.2=py39_cu118 177 | - torchtriton=2.1.0=py39 178 | - torchvision=0.16.2=py39_cu118 179 | - tornado=6.3.3=py39h5eee18b_0 180 | - tqdm=4.65.0=py39hb070fc8_0 181 | - 
typer=0.9.0=py39h06a4308_0 182 | - typing-extensions=4.11.0=py39h06a4308_0 183 | - typing_extensions=4.11.0=py39h06a4308_0 184 | - unicodedata2=15.1.0=py39h5eee18b_0 185 | - urllib3=2.2.1=py39h06a4308_0 186 | - wasabi=0.9.1=py39h06a4308_0 187 | - wcwidth=0.2.5=pyhd3eb1b0_0 188 | - weasel=0.3.4=py39h06a4308_0 189 | - wheel=0.43.0=py39h06a4308_0 190 | - x264=1!157.20191217=h7b6447c_0 191 | - xz=5.4.6=h5eee18b_1 192 | - yaml=0.2.5=h7b6447c_0 193 | - zipp=3.17.0=py39h06a4308_0 194 | - zlib=1.2.13=h5eee18b_1 195 | - zstd=1.5.5=hc292b87_2 196 | - pip: 197 | - absl-py==1.4.0 198 | - appdirs==1.4.4 199 | - argcomplete==3.2.2 200 | - array-record==0.5.0 201 | - astunparse==1.6.3 202 | - brokenaxes==0.6.1 203 | - cachetools==5.3.2 204 | - dm-tree==0.1.8 205 | - docker-pycreds==0.4.0 206 | - einops==0.7.0 207 | - etils==1.5.2 208 | - fast-pytorch-kmeans==0.2.0.1 209 | - fire==0.5.0 210 | - flatbuffers==24.3.7 211 | - fsspec==2023.12.2 212 | - gast==0.5.4 213 | - gitdb==4.0.11 214 | - gitpython==3.1.42 215 | - google-pasta==0.2.0 216 | - googleapis-common-protos==1.63.0 217 | - grpcio==1.62.1 218 | - h5py==3.10.0 219 | - imageio==2.34.0 220 | - importlib-metadata==7.1.0 221 | - jax==0.4.25 222 | - jaxlib==0.4.25+cuda11.cudnn86 223 | - keras==3.1.1 224 | - kornia==0.7.2 225 | - kornia-rs==0.1.1 226 | - lazy-loader==0.3 227 | - libclang==18.1.1 228 | - markdown==3.6 229 | - markdown-it-py==3.0.0 230 | - mdurl==0.1.2 231 | - medmnist==3.0.1 232 | - ml-dtypes==0.3.2 233 | - namex==0.0.7 234 | - nvidia-cublas-cu12==12.3.4.1 235 | - nvidia-cuda-cupti-cu12==12.3.101 236 | - nvidia-cuda-nvcc-cu12==12.3.107 237 | - nvidia-cuda-nvrtc-cu12==12.3.107 238 | - nvidia-cuda-runtime-cu12==12.3.101 239 | - nvidia-cudnn-cu12==8.9.7.29 240 | - nvidia-cufft-cu12==11.0.12.1 241 | - nvidia-curand-cu12==10.3.4.107 242 | - nvidia-cusolver-cu12==11.5.4.101 243 | - nvidia-cusparse-cu12==12.2.0.103 244 | - nvidia-nccl-cu12==2.19.3 245 | - nvidia-nvjitlink-cu12==12.3.101 246 | - objax==1.8.0 247 | - 
# ---- tail of requirements.yaml (pip section), carried over from the repository dump ----
# opencv-python==4.8.1.78
# - opt-einsum==3.3.0
# - optree==0.10.0
# - parameterized==0.9.0
# - pipx==1.4.3
# - platformdirs==4.2.0
# - promise==2.3
# - protobuf==3.20.3
# - psutil==5.9.8
# - pygments==2.17.2
# - pynvml==11.5.0
# - rich==13.7.1
# - scikit-image==0.22.0
# - sentry-sdk==1.43.0
# - setproctitle==1.3.3
# - smmap==5.0.1
# - tensorboard==2.16.2
# - tensorboard-data-server==0.7.2
# - tensorflow==2.16.1
# - tensorflow-datasets==4.9.3
# - tensorflow-io-gcs-filesystem==0.36.0
# - tensorflow-metadata==1.14.0
# - termcolor==2.4.0
# - tifffile==2024.2.12
# - toml==0.10.2
# - tzdata==2024.1
# - userpath==1.9.1
# - wandb==0.16.4
# - werkzeug==3.0.1
# - wrapt==1.16.0
# --------------------------------------------------------------------------------
# /src/models.py:
# --------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F


class ConvNet(nn.Module):
    """Stack of ``net_depth`` [conv3x3 -> norm -> activation -> pooling] stages plus a linear head.

    Args:
        channel: number of input channels.
        num_classes: size of the classifier output.
        net_width: number of conv filters in every stage.
        net_depth: number of conv stages.
        net_act: 'sigmoid' | 'relu' | 'leakyrelu'.
        net_norm: 'batchnorm' | 'layernorm' | 'instancenorm' | 'groupnorm' | 'none'.
        net_pooling: 'maxpooling' | 'avgpooling' | 'none'.
        im_size: (height, width) of the input images.

    Unknown option strings terminate the process via ``exit``.
    """

    def __init__(self, channel, num_classes, net_width, net_depth, net_act, net_norm, net_pooling, im_size=(32, 32)):
        super(ConvNet, self).__init__()
        if net_act == 'sigmoid':
            self.net_act = nn.Sigmoid()
        elif net_act == 'relu':
            self.net_act = nn.ReLU(inplace=True)
        elif net_act == 'leakyrelu':
            self.net_act = nn.LeakyReLU(negative_slope=0.01)
        else:
            exit('unknown activation function: %s' % net_act)

        if net_pooling == 'maxpooling':
            self.net_pooling = nn.MaxPool2d(kernel_size=2, stride=2)
        elif net_pooling == 'avgpooling':
            self.net_pooling = nn.AvgPool2d(kernel_size=2, stride=2)
        elif net_pooling == 'none':
            self.net_pooling = None
        else:
            exit('unknown net_pooling: %s' % net_pooling)

        self.features, shape_feat = self._make_layers(channel, net_width, net_depth, net_norm, net_pooling, im_size)
        num_feat = shape_feat[0] * shape_feat[1] * shape_feat[2]
        self.classifier = nn.Linear(num_feat, num_classes)
        # Set to True the first time forward() runs in training mode.
        self.modified = False

    def forward(self, x, train=False, mode='dummy', normalize='dummy'):
        """Return logits, or ``(flattened_features, logits)`` when ``train`` is True.

        ``mode`` and ``normalize`` are accepted for call-site compatibility and ignored.
        """
        if self.training:
            self.modified = True
        inter_out = self.embed(x)
        out = self.classifier(inter_out)
        if train:
            return inter_out, out
        return out

    def embed(self, x):
        """Return the conv features flattened to shape (N, C*H*W)."""
        out = self.features(x)
        return out.reshape(out.size(0), -1)

    def _get_normlayer(self, net_norm, shape_feat):
        """Build the normalization layer for a feature map of shape ``shape_feat`` = [C, H, W]."""
        if net_norm == 'batchnorm':
            norm = nn.BatchNorm2d(shape_feat[0], affine=True)
        elif net_norm == 'layernorm':
            norm = nn.LayerNorm(shape_feat, elementwise_affine=True)
        elif net_norm == 'instancenorm':
            # One group per channel == instance norm with affine parameters.
            norm = nn.GroupNorm(shape_feat[0], shape_feat[0], affine=True)
        elif net_norm == 'groupnorm':
            norm = nn.GroupNorm(4, shape_feat[0], affine=True)
        elif net_norm == 'none':
            norm = None
        else:
            norm = None
            exit('unknown net_norm: %s' % net_norm)
        return norm

    def _make_layers(self, channel, net_width, net_depth, net_norm, net_pooling, im_size):
        """Build the conv stages; return ``(nn.Sequential, [C, H, W] of the final feature map)``."""
        layers = []
        in_channels = channel
        # The first conv pads by 3 instead of 1 for single-channel inputs and for 3-channel
        # 28x28 (MedMNIST) images, which enlarges each spatial side by 4 (28 -> 32).
        first_padding = 3 if (channel == 1 or (channel == 3 and im_size[0] == 28)) else 1
        # Account for that growth for every input size, not only 28x28, so the classifier
        # width (and LayerNorm shapes) match the real feature map, e.g. for 1x224x224 inputs.
        growth = 2 * (first_padding - 1) if net_depth > 0 else 0
        shape_feat = [in_channels, im_size[0] + growth, im_size[1] + growth]
        for d in range(net_depth):
            padding = first_padding if d == 0 else 1
            layers += [nn.Conv2d(in_channels, net_width, kernel_size=3, padding=padding)]
            shape_feat[0] = net_width
            if net_norm != 'none':
                layers += [self._get_normlayer(net_norm, shape_feat)]
            layers += [self.net_act]
            in_channels = net_width
            if net_pooling != 'none':
                layers += [self.net_pooling]
                shape_feat[1] //= 2
                shape_feat[2] //= 2

        return nn.Sequential(*layers), shape_feat


class Projector(nn.Module):
    """MLP head: Linear -> [BatchNorm1d] -> [ReLU], repeated ``num_hidden`` times, then a final Linear."""

    def __init__(self, input_dim, output_dim, hidden_dim=128, num_hidden=1, bn='batchnorm', activation='relu'):
        super(Projector, self).__init__()
        self.layers = [nn.Linear(input_dim, hidden_dim)]
        if bn == 'batchnorm':
            self.layers.append(nn.BatchNorm1d(hidden_dim))
        if activation == 'relu':
            self.layers.append(nn.ReLU(inplace=True))
        for _ in range(num_hidden - 1):
            self.layers.append(nn.Linear(hidden_dim, hidden_dim))
            if bn == 'batchnorm':
                self.layers.append(nn.BatchNorm1d(hidden_dim))
            if activation == 'relu':
                self.layers.append(nn.ReLU(inplace=True))
        self.layers.append(nn.Linear(hidden_dim, output_dim))
        self.layers = nn.Sequential(*self.layers)

    def forward(self, x):
        return self.layers(x)


def _make_norm2d(norm, num_channels):
    """Per-channel GroupNorm (instance norm) for 'instancenorm', otherwise BatchNorm2d."""
    if norm == 'instancenorm':
        return nn.GroupNorm(num_channels, num_channels, affine=True)
    return nn.BatchNorm2d(num_channels)


class BasicBlock(nn.Module):
    """Two 3x3 convs with a residual connection (ResNet-18/34 block)."""

    expansion = 1

    def __init__(self, in_planes, planes, stride=1, norm='instancenorm'):
        super(BasicBlock, self).__init__()
        self.norm = norm
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = _make_norm2d(norm, planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = _make_norm2d(norm, planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            # Project the identity path when the spatial size or channel count changes.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                _make_norm2d(norm, self.expansion * planes),
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        return F.relu(out)


class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 conv block with a residual connection (ResNet-50+ block)."""

    expansion = 4

    def __init__(self, in_planes, planes, stride=1, norm='instancenorm'):
        super(Bottleneck, self).__init__()
        self.norm = norm
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = _make_norm2d(norm, planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = _make_norm2d(norm, planes)
        self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
        self.bn3 = _make_norm2d(norm, self.expansion * planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                _make_norm2d(norm, self.expansion * planes),
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        return F.relu(out)


class ResNet(nn.Module):
    """ResNet with a 7x7/stride-2 stem + max-pool, four residual stages and a linear classifier."""

    def __init__(self, block, num_blocks, channel=3, num_classes=10, norm='instancenorm'):
        super(ResNet, self).__init__()
        self.in_planes = 64
        self.norm = norm

        # Plain list kept for attribute compatibility; not used by forward/embed.
        self.features = []
        self.conv1 = nn.Conv2d(channel, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = _make_norm2d(norm, 64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.classifier = nn.Linear(512 * block.expansion, num_classes)
        # Set to True the first time forward() runs in training mode.
        self.modified = False

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for s in strides:
            layers.append(block(self.in_planes, planes, s, self.norm))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def _pooled_features(self, x):
        """Stem -> residual stages -> global average pool, flattened to (N, 512 * expansion)."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.maxpool(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Global average pooling. A fixed ``F.avg_pool2d(out, 4)`` raises for final maps
        # smaller than 4x4 (e.g. 28/32/64/96-px inputs with this stride-2 stem), averages
        # only the top-left 4x4 window of a 7x7 map (224-px inputs) and yields the wrong
        # feature width for maps of 8x8 or larger.
        out = F.adaptive_avg_pool2d(out, 1)
        return out.view(out.size(0), -1)

    def forward(self, x, train=False):
        """Return logits, or ``(pooled_features, logits)`` when ``train`` is True."""
        if self.training:
            self.modified = True
        inter_out = self._pooled_features(x)
        out = self.classifier(inter_out)
        if train:
            return inter_out, out
        return out

    def embed(self, x):
        """Return the pooled, flattened features (input to the classifier)."""
        return self._pooled_features(x)


def ResNet18BN(channel, num_classes):
    return ResNet(BasicBlock, [2, 2, 2, 2], channel=channel, num_classes=num_classes, norm='batchnorm')


def ResNet18(channel, num_classes):
    return ResNet(BasicBlock, [2, 2, 2, 2], channel=channel, num_classes=num_classes)


def ResNet34(channel, num_classes):
    return ResNet(BasicBlock, [3, 4, 6, 3], channel=channel, num_classes=num_classes)


def ResNet50(channel, num_classes):
    return ResNet(Bottleneck, [3, 4, 6, 3], channel=channel, num_classes=num_classes)


def ResNet101(channel, num_classes):
    return ResNet(Bottleneck, [3, 4, 23, 3], channel=channel, num_classes=num_classes)


def ResNet152(channel, num_classes):
    return ResNet(Bottleneck, [3, 8, 36, 3], channel=channel, num_classes=num_classes)

# --------------------------------------------------------------------------------
# /dataset/data/dataset.py:
# --------------------------------------------------------------------------------
import torch
import numpy as np
import torch.nn.functional as F
from torchvision import datasets, transforms
import copy
import random
import math
import logging
from torchvision.utils import save_image
import os
from medmnist import PathMNIST, OCTMNIST, OrganSMNIST, OrganCMNIST, PneumoniaMNIST, RetinaMNIST


def _load_medmnist(dataset_cls, dataset_root, transform, **kwargs):
    """Load the train and test splits of a MedMNIST dataset class.

    Each split's ``labels`` are squeezed and converted to a 1-D int64 numpy array.
    Extra keyword arguments (e.g. ``size=224``) are forwarded to ``dataset_cls``.
    """
    splits = []
    for split in ('train', 'test'):
        dst = dataset_cls(split=split, download=True, root=dataset_root, transform=transform, **kwargs)
        dst.labels = np.array(np.squeeze(dst.labels).tolist(), dtype='int64')
        splits.append(dst)
    return splits[0], splits[1]


def get_dataset(dataset, dataset_root, batch_size):
    """Build the train/test sets for ``dataset`` and a non-shuffled test DataLoader.

    No augmentation is applied; images are only resized (where noted) and normalized.

    Returns:
        ``(dataset_info, trainset, testset, testloader)`` where ``dataset_info`` holds the
        name, channel count, image size, number of classes, class names and the
        normalization mean/std. Unknown dataset names terminate the process via ``exit``.
    """
    if dataset == 'MNIST':
        channel = 1
        im_size = (28, 28)
        num_classes = 10
        mean = [0.1307]
        std = [0.3081]
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
        trainset = datasets.MNIST(dataset_root, train=True, download=True, transform=transform)  # no augmentation
        testset = datasets.MNIST(dataset_root, train=False, download=True, transform=transform)
        class_names = [str(c) for c in range(num_classes)]
    elif dataset == 'FMNIST':
        # Same statistics as used by dataset_partition.py for FMNIST.
        channel = 1
        im_size = (28, 28)
        num_classes = 10
        mean = [0.2861]
        std = [0.3530]
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
        trainset = datasets.FashionMNIST(dataset_root, train=True, download=True, transform=transform)  # no augmentation
        testset = datasets.FashionMNIST(dataset_root, train=False, download=True, transform=transform)
        class_names = trainset.classes
    elif dataset == 'CIFAR10':
        channel = 3
        im_size = (32, 32)
        num_classes = 10
        mean = [0.4914, 0.4822, 0.4465]
        std = [0.2023, 0.1994, 0.2010]
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
        trainset = datasets.CIFAR10(dataset_root, train=True, download=True, transform=transform)  # no augmentation
        testset = datasets.CIFAR10(dataset_root, train=False, download=True, transform=transform)
        class_names = trainset.classes
    elif dataset == 'STL':
        channel = 3
        im_size = (96, 96)
        num_classes = 10
        mean = [0.5, 0.5, 0.5]
        std = [0.5, 0.5, 0.5]
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
        trainset = datasets.STL10(dataset_root, split='train', download=True, transform=transform)  # no augmentation
        testset = datasets.STL10(dataset_root, split='test', download=True, transform=transform)
        class_names = None
    elif dataset == 'STL32':
        channel = 3
        im_size = (32, 32)
        num_classes = 10
        mean = [0.5, 0.5, 0.5]
        std = [0.5, 0.5, 0.5]
        transform = transforms.Compose([transforms.Resize([32, 32]), transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
        trainset = datasets.STL10(dataset_root, split='train', download=True, transform=transform)  # no augmentation
        testset = datasets.STL10(dataset_root, split='test', download=True, transform=transform)
        class_names = None
    elif dataset == 'PathMNIST':
        channel = 3
        im_size = (28, 28)
        num_classes = 9
        mean = [0.5, 0.5, 0.5]
        std = [0.5, 0.5, 0.5]
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
        trainset, testset = _load_medmnist(PathMNIST, dataset_root, transform)
        class_names = trainset.info['label'].values()
    elif dataset == 'OrganSMNIST':
        channel = 1
        im_size = (28, 28)
        num_classes = 11
        mean = [0.5,]
        std = [0.5,]
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
        trainset, testset = _load_medmnist(OrganSMNIST, dataset_root, transform)
        class_names = trainset.info['label'].values()
    elif dataset == 'OCTMNIST':
        channel = 1
        im_size = (28, 28)
        num_classes = 4
        mean = [0.5,]
        std = [0.5,]
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
        trainset, testset = _load_medmnist(OCTMNIST, dataset_root, transform)
        class_names = trainset.info['label'].values()
    elif dataset == 'ImageNette':
        from fastai.vision.all import untar_data, URLs
        channel = 3
        num_classes = 10
        im_size = (64, 64)
        mean = [0.5, 0.5, 0.5]
        std = [0.5, 0.5, 0.5]
        path = untar_data(URLs.IMAGENETTE)
        print(path)
        transform = transforms.Compose([transforms.Resize([64, 64]), transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
        trainset = datasets.ImageFolder(root=f'{path}/train', transform=transform)  # no augmentation
        testset = datasets.ImageFolder(root=f'{path}/val', transform=transform)
        trainset.labels = np.array(np.squeeze(trainset.targets).tolist(), dtype='int64')
        testset.labels = np.array(np.squeeze(testset.targets).tolist(), dtype='int64')
        class_names = range(10)
    elif dataset == 'OrganCMNIST224':
        channel = 1
        num_classes = 11
        im_size = (224, 224)
        mean = [0.5,]
        std = [0.5,]
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
        trainset, testset = _load_medmnist(OrganCMNIST, dataset_root, transform, size=224)
        class_names = trainset.info['label'].values()
    elif dataset == 'PneumoniaMNIST224':
        channel = 1
        num_classes = 2
        im_size = (224, 224)
        mean = [0.5,]
        std = [0.5,]
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
        trainset, testset = _load_medmnist(PneumoniaMNIST, dataset_root, transform, size=224)
        class_names = trainset.info['label'].values()
    elif dataset == 'RetinaMNIST224':
        channel = 3
        num_classes = 5
        im_size = (224, 224)
        mean = [0.5, 0.5, 0.5]
        std = [0.5, 0.5, 0.5]
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
        trainset, testset = _load_medmnist(RetinaMNIST, dataset_root, transform, size=224)
        class_names = trainset.info['label'].values()
    else:
        exit(f'unknown dataset: {dataset}')

    dataset_info = {
        'name': dataset,
        'channel': channel,
        'im_size': im_size,
        'num_classes': num_classes,
        'classes_names': class_names,
        'mean': mean,
        'std': std,
    }

    # Single-process loading; memory is not pinned.
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=0)

    return dataset_info, trainset, testset, testloader


class PerLabelDatasetNonIID():
    """Materialize ``dst_train`` as tensors on ``device`` and index sample positions per class.

    Besides uniform per-class sampling, it supports loss-guided sampling: ``cal_loss``
    scores every sample, and ``pre_sample``/``weighted_sample`` draw indices with
    probabilities given by a per-class softmax over those scores.
    """

    def __init__(self, dst_train, classes, channel, device):  # images: n x c x h x w tensor
        self.images_all = []
        self.labels_all = []
        self.indices_class = {c: [] for c in classes}
        self.device = device

        # Index each sample once (the dataset's transform runs once per sample).
        samples = [dst_train[i] for i in range(len(dst_train))]
        self.images_all = [torch.unsqueeze(s[0], dim=0) for s in samples]
        self.labels_all = [s[1] for s in samples]
        del samples
        for i, lab in enumerate(self.labels_all):
            if lab not in classes:
                continue
            self.indices_class[lab].append(i)
        if len(self.images_all) > 0:
            self.images_all = torch.cat(self.images_all, dim=0).to(device)
        self.labels_all = torch.tensor(self.labels_all, dtype=torch.long, device=device)
        # Per-sample scores filled by cal_loss() (float64, CPU).
        self.loss_all = None
        # Per-class indices selected by sort_image_by_model().
        self.sorted_indices_class = {c: [] for c in classes}
        # Per-class sampling distributions and pre-drawn indices (norm_loss/pre_sample).
        self.sample_prob = {c: [] for c in classes}
        self.sample_indices = {c: [] for c in classes}

    def __len__(self):
        return self.images_all.shape[0]

    def get_random_images(self, n):
        """Return up to ``n`` images drawn without replacement from all classes."""
        idx_shuffle = np.random.permutation(range(self.images_all.shape[0]))[:n]
        return self.images_all[idx_shuffle]

    def get_images(self, c, n, avg=False):
        """Return ``n`` images of class ``c``.

        With ``avg=False``: ``n`` random distinct images, or all of the class if it has
        fewer than ``n``. With ``avg=True``: ``n`` images, each the mean of 5 class-``c``
        images (sampled with replacement when the class has fewer than 5).
        """
        if not avg:
            if len(self.indices_class[c]) >= n:
                idx_shuffle = np.random.permutation(self.indices_class[c])[:n]
            else:
                idx_shuffle = self.indices_class[c]
            return self.images_all[idx_shuffle]
        sampled_imgs = []
        for _ in range(n):
            if len(self.indices_class[c]) >= 5:
                idx = np.random.choice(self.indices_class[c], 5, replace=False)
            else:
                idx = np.random.choice(self.indices_class[c], 5, replace=True)
            sampled_imgs.append(torch.mean(self.images_all[idx], dim=0, keepdim=True))
        return torch.cat(sampled_imgs, dim=0).to(self.device)

    def get_all_images(self, c):
        """Return every image of class ``c``."""
        return self.images_all[self.indices_class[c]]

    def sort_image_by_model(self, model, thres=0.5, rounds=None, cid=None, save_root_path=None):
        """Keep, per class, the ``ceil(thres * n_c)`` samples with the highest loss under ``model``.

        The selected dataset indices (highest loss first) are stored in
        ``self.sorted_indices_class[c]``. ``rounds``, ``cid`` and ``save_root_path`` are
        unused and kept for interface compatibility.
        """
        loss_function = torch.nn.CrossEntropyLoss(reduction='none')
        model.eval()
        for c in self.indices_class.keys():
            all_images_c = self.get_all_images(c)
            all_labels_c = (torch.ones(all_images_c.shape[0]) * c).long().to(self.device)
            with torch.no_grad():
                all_pred_c = model(all_images_c)
                loss = loss_function(all_pred_c, all_labels_c)
            _, sorted_indices = torch.sort(loss, descending=True, dim=0)

            # Use a separate local: assigning back to ``thres`` turned the fraction into a
            # count after the first class, so later classes kept (almost) all samples.
            num_keep = int(math.ceil(len(sorted_indices) * thres))
            self.sorted_indices_class[c] = [self.indices_class[c][idx] for idx in sorted_indices[:num_keep].tolist()]
            del all_images_c, all_labels_c, all_pred_c
            torch.cuda.empty_cache()

    def cal_loss(self, model, prev_model, lamda=0.5, gamma=1.0, b=0.7, rounds=None, cid=None, save_root_path=None):
        """Score every sample by a sigmoid of its loss under a blend of two models.

        Logits are mixed as ``(1 - lamda) * model(x) + lamda * prev_model(x)``; the
        per-sample cross-entropy ``l`` is mapped to ``1 / (1 + exp(-gamma * (l - b)))``
        and stored (float64, on CPU) in ``self.loss_all``. Inference runs in batches of
        500. ``rounds``, ``cid`` and ``save_root_path`` are unused.
        """
        loss_function = torch.nn.CrossEntropyLoss(reduction='none')
        model.eval()
        prev_model.eval()
        batch_size = 500
        total_num = self.images_all.shape[0]
        with torch.no_grad():
            blended = []
            for batch_st in range(0, total_num, batch_size):
                batch = self.images_all[batch_st: batch_st + batch_size]
                blended.append((1 - lamda) * model(batch) + lamda * prev_model(batch))
            all_preds = torch.cat(blended, dim=0)
            del blended
            logging.debug(f"blended predictions shape: {tuple(all_preds.shape)}")
            self.loss_all = loss_function(all_preds, self.labels_all).type(torch.float64)
            self.loss_all = 1.0 / (1.0 + torch.exp(-gamma * (self.loss_all - b))).cpu()
        del all_preds
        torch.cuda.empty_cache()

    def norm_loss(self):
        """Set ``sample_prob[c]`` to the softmax of class ``c``'s scores and log a histogram."""
        for c in self.indices_class.keys():
            self.sample_prob[c] = F.softmax(self.loss_all[self.indices_class[c]], dim=0)
            hist, _ = np.histogram(self.sample_prob[c], bins=10)
            logging.info(f"class {c} have {len(self.indices_class[c])} samples, histogram: {hist}")

    def pre_sample(self, it, bs):
        """Pre-draw ``it * bs`` indices per class (with replacement) from the score softmax."""
        for c in self.indices_class.keys():
            self.sample_prob[c] = F.softmax(self.loss_all[self.indices_class[c]], dim=0)
            self.sample_indices[c] = np.random.choice(self.indices_class[c], size=it * bs, replace=True, p=self.sample_prob[c])
            hist, bin_edges = np.histogram(self.sample_prob[c], bins=10)
            logging.info(f"class {c} have {len(self.indices_class[c])} samples, histogram: {hist}, bin edged: {bin_edges}")

    def weighted_sample(self, c, it, bs):
        """Return the images at pre-drawn positions ``[it, it + bs)`` for class ``c``."""
        # NOTE(review): consecutive ``it`` values give overlapping windows. If callers pass
        # the iteration index (not ``it * bs``), this should probably slice
        # ``[it * bs:(it + 1) * bs]`` -- confirm against the caller.
        return self.images_all[self.sample_indices[c][it:it + bs]]

    def get_images_loss(self, c, n):
        """Return up to ``n`` random class-``c`` images together with their ``loss_all`` scores."""
        if len(self.indices_class[c]) >= n:
            idx_shuffle = np.random.permutation(self.indices_class[c])[:n]
        else:
            idx_shuffle = self.indices_class[c]
        return self.images_all[idx_shuffle], self.loss_all[idx_shuffle]

    def get_sorted_images(self, c, n):
        """Return up to ``n`` random images from the subset chosen by ``sort_image_by_model``."""
        if len(self.sorted_indices_class[c]) >= n:
            idx_shuffle = np.random.permutation(self.sorted_indices_class[c])[:n]
        else:
            idx_shuffle = self.sorted_indices_class[c]
        return self.images_all[idx_shuffle]

# --------------------------------------------------------------------------------
# /dataset/data/dataset_partition.py:
# --------------------------------------------------------------------------------
import argparse
import json
import os

import numpy as np
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from medmnist import PathMNIST, OCTMNIST, OrganSMNIST, OrganCMNIST, RetinaMNIST, PneumoniaMNIST
from fastai.vision.all import untar_data, URLs
import torch


def plot_client_data_distribution(num_classes, num_users, dict_users, labels, save_path):
    """Print each client's sample count and save a stacked per-class histogram per client.

    ``dict_users`` maps client id -> iterable of sample indices into ``labels``.
    """
    for client_id in dict_users.keys():
        print(len(dict_users[client_id]))

    fig = plt.figure(figsize=(12, 8))
    label_distribution = [[] for _ in range(num_classes)]
    for client_id, client_data in dict_users.items():
        for idx in client_data:
            label_distribution[labels[idx]].append(client_id)

    # Exactly one bin per client, centered on the integer client id.
    plt.hist(label_distribution, stacked=True,
             bins=np.arange(-0.5, num_users + 0.5, 1),
             label=range(num_classes), rwidth=0.5)
    plt.xticks(np.arange(num_users), ["Client %d" % c_id for c_id in range(num_users)])
    plt.xlabel("Client ID")
    plt.ylabel("Number of samples")
    plt.legend(loc="upper right")
    plt.title("Label Distribution on Different Clients")
    plt.savefig(save_path)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

def partition(args):
    np.random.seed(args.seed)

    # prepare datasets for then partition latter
    if args.dataset == 'MNIST':
        num_classes = 10
        mean = [0.1307]
        std = [0.3081]
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
        dataset = datasets.MNIST(args.dataset_root, train=True, download=True, transform=transform) # no augmentation
        class_names = [str(c) for c in range(num_classes)]
    elif args.dataset == 'CIFAR10':
        num_classes = 10
        mean = [0.4914, 0.4822, 0.4465]
48 | std = [0.2023, 0.1994, 0.2010] 49 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]) 50 | dataset = datasets.CIFAR10(args.dataset_root, train=True, download=True, transform=transform) # no augmentation 51 | class_names = dataset.classes 52 | elif args.dataset == 'FMNIST': 53 | num_classes = 10 54 | mean = [0.2861] 55 | std = [0.3530] 56 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]) 57 | dataset = datasets.FashionMNIST(args.dataset_root, train=True, download=True, transform=transform) # no augmentation 58 | class_names = dataset.classes 59 | elif args.dataset == 'STL': 60 | num_classes = 10 61 | mean = [0.5, 0.5, 0.5] 62 | std = [0.5, 0.5, 0.5] 63 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]) 64 | dataset = datasets.STL10(args.dataset_root, split='train', download=True, transform=transform) # no augmentation 65 | class_names = dataset.classes 66 | elif args.dataset == 'STL32': 67 | num_classes = 10 68 | mean = [0.5, 0.5, 0.5] 69 | std = [0.5, 0.5, 0.5] 70 | transform = transforms.Compose([transforms.Resize([32, 32]), transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]) 71 | dataset = datasets.STL10(args.dataset_root, split='train', download=True, transform=transform) # no augmentation 72 | class_names = dataset.classes 73 | elif args.dataset == 'PathMNIST': 74 | num_classes = 9 75 | mean = [0.5, 0.5, 0.5] 76 | std = [0.5, 0.5, 0.5] 77 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]) 78 | dataset = PathMNIST(split="train", download=True, root=args.dataset_root, transform=transform) # no transformation 79 | dataset.labels = np.array(np.squeeze(dataset.labels).tolist(), dtype='int64') 80 | class_names = dataset.info['label'].values() 81 | for lbl in dataset.info['label'].keys(): 82 | print(f'class {lbl} have {(dataset.labels == int(lbl)).sum()} 
samples') 83 | elif args.dataset == 'OrganSMNIST': 84 | num_classes = 11 85 | mean = [0.5, 0.5, 0.5] 86 | std = [0.5, 0.5, 0.5] 87 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]) 88 | dataset = OrganSMNIST(split="train", download=True, root=args.dataset_root, transform=transform) # no transformation 89 | dataset.labels = np.array(np.squeeze(dataset.labels).tolist(), dtype='int64') 90 | class_names = dataset.info['label'].values() 91 | for lbl in dataset.info['label'].keys(): 92 | print(f'class {lbl} have {(dataset.labels == int(lbl)).sum()} samples') 93 | elif args.dataset == 'OCTMNIST': 94 | num_channels = 1 95 | num_classes = 4 96 | mean = [0.5,] 97 | std = [0.5,] 98 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]) 99 | dataset = OCTMNIST(split="train", download=True, root=args.dataset_root, transform=transform) # no transformation 100 | dataset.labels = np.array(np.squeeze(dataset.labels).tolist(), dtype='int64') 101 | for lbl in dataset.info['label'].keys(): 102 | print(f'class {lbl} have {(dataset.labels == int(lbl)).sum()} samples') 103 | elif args.dataset == 'ImageNette': 104 | num_channels = 3 105 | num_classes = 10 106 | mean = [0.5, 0.5, 0.5] 107 | std = [0.5, 0.5, 0.5] 108 | path = untar_data(URLs.IMAGENETTE) 109 | print(path) 110 | transform = transforms.Compose([transforms.Resize([64, 64]), transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]) 111 | dataset = datasets.ImageFolder(root=f'{path}/train', transform=transform) # cancel augment 112 | dataset.labels = np.array(np.squeeze(dataset.targets).tolist(), dtype='int64') 113 | for lbl in range(10): 114 | print(f'class {lbl} have {(dataset.labels == int(lbl)).sum()} samples') 115 | elif args.dataset == 'OrganCMNIST224': 116 | num_classes = 11 117 | mean = [0.5,] 118 | std = [0.5,] 119 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]) 120 | 
dataset = OrganCMNIST(split="train", download=True, root=args.dataset_root, transform=transform, size=224) # no transformation 121 | dataset.labels = np.array(np.squeeze(dataset.labels).tolist(), dtype='int64') 122 | class_names = dataset.info['label'].values() 123 | for lbl in dataset.info['label'].keys(): 124 | print(f'class {lbl} have {(dataset.labels == int(lbl)).sum()} samples') 125 | elif args.dataset == 'PneumoniaMNIST224': 126 | num_classes = 2 127 | mean = [0.5,] 128 | std = [0.5,] 129 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]) 130 | dataset = PneumoniaMNIST(split="train", download=True, root=args.dataset_root, transform=transform, size=224) # no transformation 131 | dataset.labels = np.array(np.squeeze(dataset.labels).tolist(), dtype='int64') 132 | class_names = dataset.info['label'].values() 133 | for lbl in dataset.info['label'].keys(): 134 | print(f'class {lbl} have {(dataset.labels == int(lbl)).sum()} samples') 135 | elif args.dataset == 'RetinaMNIST224': 136 | num_classes = 5 137 | mean = [0.5, 0.5, 0.5] 138 | std = [0.5, 0.5, 0.5] 139 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]) 140 | dataset = RetinaMNIST(split="train", download=True, root=args.dataset_root, transform=transform, size=224) # no transformation 141 | dataset.labels = np.array(np.squeeze(dataset.labels).tolist(), dtype='int64') 142 | class_names = dataset.info['label'].values() 143 | for lbl in dataset.info['label'].keys(): 144 | print(f'class {lbl} have {(dataset.labels == int(lbl)).sum()} samples') 145 | else: 146 | exit(f'unknown dataset: f{args.dataset}') 147 | 148 | if args.dataset in ['CIFAR10', 'CIFAR100', 'FMNIST', 'CIFAR100C']: 149 | labels = np.array(dataset.targets, dtype='int64') 150 | elif args.dataset in ['PathMNIST', 'OrganAMNIST', 'OCTMNIST','OrganSMNIST', 'ImageNette', 'OrganCMNIST224', 'PneumoniaMNIST224', 'RetinaMNIST224', 'STL', 'STL32']: 151 | labels = 
dataset.labels 152 | 153 | dict_users = {} 154 | dict_classes = {} 155 | 156 | def dirichlet_split(): 157 | min_size = -1 158 | min_require_size = 0 159 | K = num_classes 160 | if args.dataset in ['CIFAR10', 'FMNIST']: 161 | labels = np.array(dataset.targets, dtype='int64') 162 | elif args.dataset in ['PathMNIST', 'OCTMNIST', 'OrganSMNIST', 'OrganCMNIST', 'ImageNette', 'OrganCMNIST224', 'PneumoniaMNIST224', 'RetinaMNIST224', 'STL', 'STL32']: 163 | labels = dataset.labels 164 | N = labels.shape[0] 165 | while min_size < min_require_size: 166 | idx_batch = [[] for _ in range(args.client_num)] 167 | for k in range(K): 168 | idx_k = np.where(labels == k)[0] 169 | np.random.shuffle(idx_k) 170 | proportions = np.random.dirichlet(np.repeat(args.alpha, args.client_num)) 171 | # print(proportions) 172 | proportions = np.array([p * (len(idx_j) < N / args.client_num) for p, idx_j in zip(proportions, idx_batch)]) 173 | proportions = proportions / proportions.sum() 174 | proportions = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1] 175 | idx_batch = [idx_j + idx.tolist() for idx_j, idx in zip(idx_batch, np.split(idx_k, proportions))] 176 | min_size = min([len(idx_j) for idx_j in idx_batch]) 177 | print(min_size) 178 | 179 | for j in range(args.client_num): 180 | np.random.shuffle(idx_batch[j]) 181 | dict_users[j] = idx_batch[j] 182 | 183 | return dict_users 184 | 185 | def label_split(): 186 | if args.dataset in ['CIFAR10', 'FMNIST']: 187 | labels = np.array(dataset.targets, dtype='int64') 188 | elif args.dataset in ['PathMNIST', 'OCTMNIST', 'OrganSMNIST', 'OrganCMNIST', 'ImageNette', 'OrganCMNIST224', 'PneumoniaMNIST224', 'RetinaMNIST224', 'STL', 'STL32']: 189 | labels = dataset.labels 190 | times = [0 for i in range(num_classes)] 191 | contain = [] 192 | for i in range(args.client_num): 193 | current = [i % num_classes] 194 | times[i % num_classes] += 1 195 | j = 1 196 | while (j < args.num_classes_per_client): 197 | ind = np.random.randint(0, num_classes-1) 198 | if 
(ind not in current): 199 | j = j+1 200 | current.append(ind) 201 | times[ind] += 1 202 | contain.append(current) 203 | 204 | dict_users = {i: np.ndarray(0,dtype=np.int64) for i in range(args.client_num)} 205 | for i in range(num_classes): 206 | idx_k = np.where(labels == i)[0] 207 | np.random.shuffle(idx_k) 208 | split = np.array_split(idx_k, times[i]) 209 | ids = 0 210 | for j in range(args.client_num): 211 | if i in contain[j]: 212 | dict_users[j] = np.append(dict_users[j], split[ids]) 213 | ids+=1 214 | 215 | for client_id in dict_users.keys(): 216 | dict_users[client_id] = dict_users[client_id].tolist() 217 | return dict_users 218 | 219 | 220 | def pathological_split(): 221 | if args.dataset in ['CIFAR10', 'FMNIST']: 222 | labels = np.array(dataset.targets, dtype='int64') 223 | elif args.dataset in ['PathMNIST', 'OCTMNIST', 'OrganSMNIST', 'OrganCMNIST', 'ImageNette', 'OrganCMNIST224', 'PneumoniaMNIST224', 'RetinaMNIST224', 'STL', 'STL32']: 224 | labels = dataset.labels 225 | num_samples = labels.shape[0] 226 | num_shards = args.num_classes_per_client * args.client_num 227 | assert num_samples % num_shards == 0 228 | num_imgs_per_shard = int(num_samples / num_shards) 229 | print(f"total sample: {num_samples}, num_shards: {num_shards}, num_imgs_per_shard: {num_imgs_per_shard}") 230 | 231 | idx_shard = [i for i in range(num_shards)] 232 | dict_users = {i: np.array([], dtype='int64') for i in range(args.client_num)} 233 | idxs = np.arange(num_samples) 234 | 235 | # sort labels 236 | idxs_labels = np.vstack((idxs, labels)) 237 | idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()] 238 | idxs = idxs_labels[0,:] 239 | 240 | # divide and assign 241 | for i in range(args.client_num): 242 | rand_set = set(np.random.choice(idx_shard, args.num_classes_per_client, replace=False)) 243 | idx_shard = list(set(idx_shard) - rand_set) 244 | for rand in rand_set: 245 | dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs_per_shard:(rand+1)*num_imgs_per_shard]), 
axis=0) 246 | 247 | for client_id in dict_users.keys(): 248 | dict_users[client_id] = dict_users[client_id].tolist() 249 | return dict_users 250 | 251 | 252 | if args.method == 'dirichlet': 253 | dict_users = dirichlet_split() 254 | elif args.method == 'label': 255 | dict_users = label_split() 256 | elif args.method == 'pathological': 257 | dict_users = pathological_split() 258 | 259 | net_cls_counts = {} 260 | 261 | for net_i, dataidx in dict_users.items(): 262 | dict_classes[net_i] = [] 263 | unq, unq_cnt = np.unique(labels[dataidx], return_counts=True) 264 | tmp = {unq[i]: unq_cnt[i] for i in range(len(unq))} 265 | net_cls_counts[net_i] = tmp 266 | for c, cnt in tmp.items(): 267 | if cnt >= 10: 268 | dict_classes[net_i].append(int(c)) 269 | 270 | print('Data statistics: %s' % str(net_cls_counts)) 271 | 272 | save_path = os.path.join(os.path.dirname(__file__), '../', 'split_file') 273 | if args.method == 'dirichlet': 274 | file_name = f'{args.dataset}_client_num={args.client_num}_alpha={args.alpha}.json' 275 | plot_client_data_distribution(num_classes, args.client_num, dict_users, labels, save_path=f'{args.dataset}_client_num={args.client_num}_alpha={args.alpha}.png') 276 | elif args.method == 'label': 277 | file_name = f'{args.dataset}_client_num={args.client_num}_label={args.num_classes_per_client}.json' 278 | plot_client_data_distribution(num_classes, args.client_num, dict_users, labels, save_path=f'{args.dataset}_client_num={args.client_num}_label={args.num_classes_per_client}.png') 279 | elif args.method == 'pathological': 280 | file_name = f'{args.dataset}_client_num={args.client_num}_pathological={args.num_classes_per_client}.json' 281 | plot_client_data_distribution(num_classes, args.client_num, dict_users, labels, save_path=f'{args.dataset}_client_num={args.client_num}_pathological={args.num_classes_per_client}.png') 282 | 283 | os.makedirs(save_path, exist_ok=True) 284 | with open(os.path.join(save_path, file_name), 'w') as json_file: 285 | json.dump({ 
286 | "client_idx": [dict_users[i] for i in range(args.client_num)], 287 | "client_classes": [dict_classes[i] for i in range(args.client_num)], 288 | }, json_file, indent=4) 289 | 290 | if __name__ == "__main__": 291 | partition_parser = argparse.ArgumentParser() 292 | 293 | partition_parser.add_argument("--dataset", type=str, default='CIFAR10') 294 | partition_parser.add_argument("--method", type=str, default='dirichlet') 295 | partition_parser.add_argument("--client_num", type=int, default=10) 296 | partition_parser.add_argument("--alpha", type=float, default=0.2) 297 | partition_parser.add_argument("--num_classes_per_client", type=int, default=2) 298 | partition_parser.add_argument("--dataset_root", type=str, default='../torchvision') 299 | partition_parser.add_argument("--seed", type=int, default=42) 300 | args = partition_parser.parse_args() 301 | partition(args) -------------------------------------------------------------------------------- /src/client.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | import gc 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import numpy as np 9 | # from sklearn.cluster import KMeans 10 | import time 11 | import logging 12 | import math 13 | from torchvision.utils import save_image 14 | from torch.utils.data import DataLoader, TensorDataset 15 | 16 | from dataset.data.dataset import PerLabelDatasetNonIID 17 | from src.utils import sample_random_model, random_pertube, DiffAugment, ParamDiffAug, get_model, MMDLoss, M3DLoss 18 | 19 | def get_gpu_mem_info(gpu_id=0): 20 | import pynvml 21 | pynvml.nvmlInit() 22 | gpu_id = int(str(gpu_id)[-1]) 23 | if gpu_id < 0 or gpu_id >= pynvml.nvmlDeviceGetCount(): 24 | logging.info(f'gpu_id {gpu_id} does not exsit!'.format(gpu_id)) 25 | return 0, 0, 0 26 | 27 | handler = pynvml.nvmlDeviceGetHandleByIndex(gpu_id) 28 | meminfo = pynvml.nvmlDeviceGetMemoryInfo(handler) 29 | total = 
round(meminfo.total / 1024 / 1024, 2) 30 | used = round(meminfo.used / 1024 / 1024, 2) 31 | free = round(meminfo.free / 1024 / 1024, 2) 32 | logging.info(f"total {total}MB, used {used}MB, free {free}MB") 33 | return total, used, free 34 | 35 | class Client: 36 | def __init__( 37 | self, 38 | cid: int, 39 | # --- dataset information --- 40 | train_set: PerLabelDatasetNonIID, 41 | classes: list[int], 42 | dataset_info: dict, 43 | # --- data condensation params --- 44 | ipc: int, 45 | compression_ratio: float, 46 | dc_iterations: int, 47 | real_batch_size: int, 48 | image_lr: float, 49 | image_momentum: float, 50 | image_weight_decay: float, 51 | lr: float, 52 | momentum: float, 53 | weight_decay: float, 54 | local_ep: int, 55 | dsa: bool, 56 | dsa_strategy: str, 57 | init: str, 58 | clip_norm: float, 59 | gamma: float, 60 | lamda: float, 61 | b: float, 62 | con_temp: float, 63 | kernel: str, 64 | save_root_path: str, 65 | device: torch.device, 66 | ): 67 | self.cid = cid 68 | 69 | self.train_set = train_set 70 | self.classes = classes 71 | self.dataset_info = dataset_info 72 | 73 | self.ipc = ipc 74 | self.compression_ratio = compression_ratio 75 | self.dc_iterations = dc_iterations 76 | self.real_batch_size = real_batch_size 77 | self.image_lr = image_lr 78 | self.image_momentum = image_momentum 79 | self.image_weight_decay = image_weight_decay 80 | self.lr = lr 81 | self.momentum = momentum 82 | self.weight_decay = weight_decay 83 | self.round = -1 84 | self.local_ep = local_ep 85 | self.dsa = dsa 86 | self.dsa_strategy = dsa_strategy 87 | self.model_name = None 88 | self.global_model = None 89 | self.prev_global_model = None 90 | self.dsa_param = ParamDiffAug() 91 | self.init = init 92 | self.clip_norm = clip_norm 93 | self.gamma = gamma 94 | self.lamda = lamda 95 | self.b = b 96 | self.con_temp = con_temp 97 | self.kernel = kernel 98 | self.save_root_path = save_root_path 99 | self.device = device 100 | 101 | if len(self.classes) > 0: 102 | if 
self.compression_ratio > 0.: 103 | self.ipc_dict = {c: max(5, int(math.ceil(len(self.train_set.indices_class[c])*self.compression_ratio))) for c in self.classes} 104 | else: 105 | self.ipc_dict = {c: self.ipc for c in self.classes} 106 | num_synthetic_images = sum(self.ipc_dict.values()) 107 | self.accumulate_num_syn_imgs = [0,] 108 | for i, c in enumerate(self.classes): 109 | self.accumulate_num_syn_imgs.append(self.accumulate_num_syn_imgs[-1] + self.ipc_dict[c]) 110 | 111 | self.synthetic_images = torch.randn( 112 | size=( 113 | num_synthetic_images, 114 | dataset_info['channel'], 115 | dataset_info['im_size'][0], 116 | dataset_info['im_size'][1], 117 | ), 118 | dtype=torch.float, 119 | requires_grad=True, 120 | device=self.device, 121 | ) 122 | self.synthetic_labels = torch.cat([torch.ones(self.ipc_dict[c]) * c for c in self.classes]).long().to(self.device) 123 | 124 | 125 | def train_weighted_sample(self): 126 | self.round += 1 127 | # initialize S_k and initialize optimizer 128 | self.initialization() 129 | logging.info("synthesize from random noise") 130 | optimizer_image = torch.optim.SGD([self.synthetic_images,], lr=self.image_lr, momentum=self.image_momentum, weight_decay=self.image_weight_decay) 131 | optimizer_image.zero_grad() 132 | logging.info(f"client {self.cid} have real samples {[len(self.train_set.indices_class[c]) for c in self.classes]}") 133 | logging.info(f"client {self.cid} will condense {self.ipc_dict} samples for each class it owns") 134 | 135 | if self.round == 0: 136 | self.global_model = get_model(self.model_name, self.dataset_info).to(self.device) 137 | prototypes = self.get_feature_prototype() 138 | logit_prototypes = self.get_logit_prototype() 139 | 140 | logging.info(f"loss weighted matching the samples") 141 | self.train_set.cal_loss(copy.deepcopy(self.global_model), copy.deepcopy(self.prev_global_model), lamda=self.lamda, gamma=self.gamma, b=self.b, rounds=self.round, cid=self.cid, save_root_path=self.save_root_path) 142 | 
self.train_set.pre_sample(it=self.dc_iterations+1, bs=self.real_batch_size) 143 | 144 | total_loss = 0. 145 | self.global_model.train() 146 | for param in list(self.global_model.parameters()): 147 | param.requires_grad = False 148 | 149 | for dc_iteration in range(self.dc_iterations+1): 150 | loss = torch.tensor(0.0).to(self.device) 151 | images_real_all = [] 152 | images_syn_all = [] 153 | num_real_image = [0, ] 154 | for i, c in enumerate(self.classes): 155 | real_image = self.train_set.images_all[self.train_set.sample_indices[c][dc_iteration:dc_iteration+self.real_batch_size]] 156 | # real_image = self.train_set.weighted_sample(c, dc_iteration, self.real_batch_size) 157 | num_real_image.append(num_real_image[-1] + real_image.shape[0]) 158 | synthetic_image = self.synthetic_images[self.accumulate_num_syn_imgs[i] : self.accumulate_num_syn_imgs[i+1]].reshape( 159 | (self.ipc_dict[c], self.dataset_info['channel'], self.dataset_info['im_size'][0], self.dataset_info['im_size'][1])) 160 | 161 | if self.dsa: 162 | seed = int(time.time() * 1000) % 100000 163 | real_image = DiffAugment(real_image, self.dsa_strategy, seed=seed, param=self.dsa_param) 164 | synthetic_image = DiffAugment(synthetic_image, self.dsa_strategy, seed=seed, param=self.dsa_param) 165 | 166 | images_real_all.append(real_image) 167 | images_syn_all.append(synthetic_image) 168 | 169 | images_real_all = torch.cat(images_real_all, dim=0) 170 | images_syn_all = torch.cat(images_syn_all, dim=0) 171 | self.global_model.train() 172 | real_feature = self.global_model.embed(images_real_all).detach() 173 | self.global_model.eval() 174 | synthetic_feature = self.global_model.embed(images_syn_all) 175 | 176 | for i, c in enumerate(self.classes): 177 | mean_real_feature = torch.mean(real_feature[num_real_image[i] : num_real_image[i+1]], dim=0) 178 | mean_synthetic_feature = torch.mean(synthetic_feature[self.accumulate_num_syn_imgs[i] : self.accumulate_num_syn_imgs[i+1]], dim=0) 179 | loss += 
torch.sum((mean_real_feature - mean_synthetic_feature)**2) 180 | 181 | total_loss += loss.item() 182 | optimizer_image.zero_grad() 183 | loss.backward() 184 | total_norm = nn.utils.clip_grad_norm_([self.synthetic_images,], max_norm=self.clip_norm) 185 | optimizer_image.step() 186 | 187 | if dc_iteration % 200 == 0 or dc_iteration == self.dc_iterations: 188 | logging.info(f'client {self.cid}, data condensation {dc_iteration}, total loss = {loss.item()}, avg loss = {loss.item() / len(self.classes)}') 189 | 190 | # return S_k 191 | synthetic_labels = torch.cat([torch.ones(self.ipc_dict[c]) * c for c in self.classes]) 192 | # torch.save({'data': self.synthetic_images.detach().cpu(), 'label': synthetic_labels.detach().cpu()}, os.path.join(self.save_root_path, f"round{self.round}_client{self.cid}.pt")) 193 | return copy.deepcopy(self.synthetic_images.detach()), copy.deepcopy(synthetic_labels), total_loss/(len(self.classes)*self.dc_iterations), self.ipc_dict, self.accumulate_num_syn_imgs, prototypes, logit_prototypes 194 | 195 | def train_weighted_MMD(self): 196 | self.round += 1 197 | 198 | # initialize S_k and initialize optimizer 199 | self.initialization() 200 | optimizer_image = torch.optim.SGD([self.synthetic_images,], lr=self.image_lr, momentum=self.image_momentum, weight_decay=self.image_weight_decay) 201 | optimizer_image.zero_grad() 202 | loss_fn = nn.CrossEntropyLoss() 203 | logging.info(f"client {self.cid} have real samples {[len(self.train_set.indices_class[c]) for c in self.classes]}") 204 | logging.info(f"client {self.cid} will condense {self.ipc_dict} samples for each class it owns") 205 | 206 | if self.round == 0: 207 | self.global_model = get_model(self.model_name, self.dataset_info).to(self.device) 208 | prototypes = self.get_feature_prototype() 209 | logit_prototypes = self.get_logit_prototype() 210 | 211 | logging.info(f"loss weighted matching the samples") 212 | self.train_set.cal_loss(copy.deepcopy(self.global_model), 
copy.deepcopy(self.prev_global_model), lamda=self.lamda, gamma=self.gamma, b=self.b, rounds=self.round, cid=self.cid, save_root_path=self.save_root_path) 213 | self.train_set.pre_sample(it=self.dc_iterations+1, bs=self.real_batch_size) 214 | 215 | total_loss = 0. 216 | mmd_criterion = M3DLoss(kernel_type=self.kernel, device=self.device) 217 | self.global_model.train() 218 | for param in list(self.global_model.parameters()): 219 | param.requires_grad = False 220 | 221 | for dc_iteration in range(self.dc_iterations+1): 222 | loss = torch.tensor(0.0).to(self.device) 223 | images_real_all = [] 224 | images_syn_all = [] 225 | num_real_image = [0, ] 226 | for i, c in enumerate(self.classes): 227 | real_image = self.train_set.images_all[self.train_set.sample_indices[c][dc_iteration:dc_iteration+self.real_batch_size]] 228 | # real_image = self.train_set.weighted_sample(c, dc_iteration, self.real_batch_size) 229 | num_real_image.append(num_real_image[-1] + real_image.shape[0]) 230 | synthetic_image = self.synthetic_images[self.accumulate_num_syn_imgs[i] : self.accumulate_num_syn_imgs[i+1]].reshape( 231 | (self.ipc_dict[c], self.dataset_info['channel'], self.dataset_info['im_size'][0], self.dataset_info['im_size'][1])) 232 | 233 | if self.dsa: 234 | seed = int(time.time() * 1000) % 100000 235 | real_image = DiffAugment(real_image, self.dsa_strategy, seed=seed, param=self.dsa_param) 236 | synthetic_image = DiffAugment(synthetic_image, self.dsa_strategy, seed=seed, param=self.dsa_param) 237 | 238 | images_real_all.append(real_image) 239 | images_syn_all.append(synthetic_image) 240 | 241 | images_real_all = torch.cat(images_real_all, dim=0) 242 | images_syn_all = torch.cat(images_syn_all, dim=0) 243 | self.global_model.train() 244 | real_feature = self.global_model.embed(images_real_all).detach() 245 | self.global_model.eval() 246 | synthetic_feature = self.global_model.embed(images_syn_all) 247 | 248 | for i, c in enumerate(self.classes): 249 | loss += 
mmd_criterion(real_feature[num_real_image[i] : num_real_image[i+1]], synthetic_feature[self.accumulate_num_syn_imgs[i] : self.accumulate_num_syn_imgs[i+1]]) 250 | 251 | total_loss += loss.item() 252 | optimizer_image.zero_grad() 253 | loss.backward() 254 | total_norm = nn.utils.clip_grad_norm_([self.synthetic_images,], max_norm=self.clip_norm) 255 | optimizer_image.step() 256 | 257 | if dc_iteration % 200 == 0 or dc_iteration == self.dc_iterations: 258 | logging.info(f'client {self.cid}, data condensation {dc_iteration}, total loss = {loss.item()}, avg loss = {loss.item() / len(self.classes)}') 259 | 260 | # return S_k 261 | synthetic_labels = torch.cat([torch.ones(self.ipc_dict[c]) * c for c in self.classes]) 262 | # torch.save({'data': self.synthetic_images.detach().cpu(), 'label': synthetic_labels.detach().cpu()}, os.path.join(self.save_root_path, f"round{self.round}_client{self.cid}.pt")) 263 | return copy.deepcopy(self.synthetic_images.detach()), copy.deepcopy(synthetic_labels), total_loss/(len(self.classes)*self.dc_iterations), self.ipc_dict, self.accumulate_num_syn_imgs, prototypes, logit_prototypes 264 | 265 | def get_feature_prototype(self): 266 | logging.info(f"get_feature_prototype") 267 | prototypes = {c: None for c in self.classes} 268 | self.global_model.eval() 269 | for param in list(self.global_model.parameters()): 270 | param.requires_grad = False 271 | for c in self.classes: 272 | tot_num_c = len(self.train_set.indices_class[c]) 273 | if tot_num_c > 500: 274 | real_feature_c = [] 275 | batch_size = 0 276 | for it in range(0, tot_num_c, 500): 277 | if it + 500 >= tot_num_c: 278 | real_feature_c_batch = self.global_model.embed(self.train_set.images_all[self.train_set.indices_class[c][it: tot_num_c]]).detach() 279 | real_feature_c.append(torch.sum(real_feature_c_batch, dim=0)) 280 | else: 281 | real_feature_c_batch = self.global_model.embed(self.train_set.images_all[self.train_set.indices_class[c][it: it+500]]).detach() 282 | 
real_feature_c.append(torch.sum(real_feature_c_batch, dim=0)) 283 | real_feature_c = torch.vstack(real_feature_c) 284 | real_feature_c = torch.sum(real_feature_c, dim=0) / tot_num_c 285 | prototypes[c] = (real_feature_c, tot_num_c) 286 | del real_feature_c 287 | else: 288 | real_images_c = self.train_set.get_all_images(c) 289 | real_feature_c = self.global_model.embed(real_images_c) 290 | prototypes[c] = (torch.mean(real_feature_c, dim=0), tot_num_c) 291 | del real_feature_c, real_images_c 292 | torch.cuda.empty_cache() 293 | 294 | return prototypes 295 | 296 | def get_logit_prototype(self): 297 | logging.info(f"get_logit_prototype") 298 | prototypes = {c: None for c in self.classes} 299 | self.global_model.eval() 300 | for param in list(self.global_model.parameters()): 301 | param.requires_grad = False 302 | for c in self.classes: 303 | tot_num_c = len(self.train_set.indices_class[c]) 304 | if tot_num_c > 500: 305 | real_logit_c = [] 306 | real_score_c = [] 307 | for it in range(0, tot_num_c, 500): 308 | if it + 500 >= tot_num_c: 309 | real_logit_c_batch = self.global_model(self.train_set.images_all[self.train_set.indices_class[c][it: tot_num_c]]).detach() 310 | real_logit_c_batch_sm = F.softmax(real_logit_c_batch, dim=1) 311 | real_score_c.append(torch.log((real_logit_c_batch_sm[:, c]+1e-5) / (1 - real_logit_c_batch_sm[:, c]+1e-5))) 312 | real_logit_c.append(torch.sum(real_logit_c_batch, dim=0)) 313 | else: 314 | real_logit_c_batch = self.global_model(self.train_set.images_all[self.train_set.indices_class[c][it: it+500]]).detach() 315 | real_logit_c_batch_sm = F.softmax(real_logit_c_batch, dim=1) 316 | real_score_c.append(torch.log((real_logit_c_batch_sm[:, c]+1e-5) / (1 - real_logit_c_batch_sm[:, c]+1e-5))) 317 | real_logit_c.append(torch.sum(real_logit_c_batch, dim=0)) 318 | real_logit_c = torch.vstack(real_logit_c) 319 | real_score_c = torch.cat(real_score_c) 320 | real_logit_c = torch.sum(real_logit_c, dim=0) / tot_num_c 321 | prototypes[c] = (real_logit_c, 
tot_num_c) 322 | del real_logit_c 323 | else: 324 | real_images_c = self.train_set.get_all_images(c) 325 | real_logit_c = self.global_model(real_images_c) 326 | real_logit_c_sm = F.softmax(real_logit_c, dim=1) 327 | real_score_c = torch.log((real_logit_c_sm[:, c]+1e-5) / (1 - real_logit_c_sm[:, c]+1e-5)) 328 | prototypes[c] = (torch.mean(real_logit_c, dim=0), tot_num_c) 329 | del real_logit_c, real_images_c 330 | torch.cuda.empty_cache() 331 | 332 | return prototypes 333 | 334 | def recieve_model(self, model_name, global_model=None): 335 | self.model_name = model_name 336 | if global_model is not None: 337 | if self.round == -1: 338 | self.prev_global_model = copy.deepcopy(global_model) 339 | else: 340 | self.prev_global_model = copy.deepcopy(self.global_model) 341 | self.global_model = copy.deepcopy(global_model) 342 | self.global_model.eval() 343 | 344 | def initialization(self): 345 | if self.init == 'real': 346 | logging.info("initialized by real images") 347 | for i, c in enumerate(self.classes): 348 | self.synthetic_images.data[self.accumulate_num_syn_imgs[i] : self.accumulate_num_syn_imgs[i+1]] = self.train_set.get_images(c, self.ipc_dict[c], avg=False).detach().data 349 | elif self.init == 'real_avg': 350 | logging.info("initialized by average real images") 351 | for i, c in enumerate(self.classes): 352 | self.synthetic_images.data[self.accumulate_num_syn_imgs[i] : self.accumulate_num_syn_imgs[i+1]] = self.train_set.get_images(c, self.ipc_dict[c], avg=True).detach().data 353 | elif self.init == 'random_noise': 354 | logging.info("initialized by random noise") 355 | pass 356 | -------------------------------------------------------------------------------- /src/server.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | import random 4 | import time 5 | import gc 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.utils.data import DataLoader, TensorDataset 10 | from tqdm import tqdm 
11 | from sklearn.manifold import TSNE 12 | import numpy as np 13 | from torchvision.utils import save_image 14 | 15 | from src.client import Client 16 | from src.utils import DiffAugment, ParamDiffAug, MMDLoss, ContrastiveLoss, SupervisedContrastiveLoss 17 | from .models import Projector 18 | 19 | import matplotlib.pyplot as plt 20 | import torch.nn.functional as F 21 | import json 22 | import logging 23 | 24 | def get_gpu_mem_info(gpu_id=0): 25 | import pynvml 26 | pynvml.nvmlInit() 27 | gpu_id = int(str(gpu_id)[-1]) 28 | if gpu_id < 0 or gpu_id >= pynvml.nvmlDeviceGetCount(): 29 | logging.info(f'gpu_id {gpu_id} does not exsit!'.format(gpu_id)) 30 | return 0, 0, 0 31 | 32 | handler = pynvml.nvmlDeviceGetHandleByIndex(gpu_id) 33 | meminfo = pynvml.nvmlDeviceGetMemoryInfo(handler) 34 | total = round(meminfo.total / 1024 / 1024, 2) 35 | used = round(meminfo.used / 1024 / 1024, 2) 36 | free = round(meminfo.free / 1024 / 1024, 2) 37 | logging.info(f"total {total}MB, used {used}MB, free {free}MB") 38 | return total, used, free 39 | 40 | def get_embedding(model, data_input, device, batch_size=1024, detach=True): 41 | embedding_list = [] 42 | total_num = data_input.shape[0] 43 | for idx in range(0, total_num, batch_size): 44 | batch_st = idx 45 | if batch_st + batch_size >= total_num: 46 | batch_ed = total_num 47 | else: 48 | batch_ed = idx+batch_size 49 | if detach: 50 | embedding_list.append(model.embed(data_input[batch_st: batch_ed]).detach()) 51 | else: 52 | embedding_list.append(model.embed(data_input[batch_st: batch_ed])) 53 | 54 | print(f"get embedding > 5000 with batchsize={batch_size}...") 55 | get_gpu_mem_info(device) 56 | embedding_list = torch.cat(embedding_list, dim=0) 57 | return embedding_list 58 | 59 | class Server: 60 | def __init__( 61 | self, 62 | train_set, 63 | ipc, 64 | dataset_info, 65 | global_model_name: str, 66 | global_model: nn.Module, 67 | clients: list[Client], 68 | # --- model training params --- 69 | communication_rounds: int, 70 | 
join_ratio: float, 71 | batch_size: int, 72 | model_epochs: int, 73 | lr_server: float, 74 | momentum_server: float, 75 | weight_decay_server: float, 76 | lr_head: float, 77 | momentum_head: float, 78 | weight_decay_head: float, 79 | weighted_matching: bool, 80 | weighted_sample: bool, 81 | weighted_mmd: bool, 82 | contrastive_way: str, 83 | con_beta: float, 84 | con_temp: float, 85 | topk: int, 86 | dsa: bool, 87 | dsa_strategy: str, 88 | preserve_all: bool, 89 | # --- test and evaluation information --- 90 | eval_gap: int, 91 | test_set: object, 92 | test_loader: DataLoader, 93 | device: torch.device, 94 | # --- save model and synthetic images --- 95 | model_identification: str, 96 | save_root_path: str 97 | ): 98 | self.train_set = train_set 99 | self.ipc = ipc 100 | self.dataset_info = dataset_info 101 | self.global_model_name = global_model_name 102 | self.global_model = global_model.to(device) 103 | self.clients = clients 104 | 105 | self.communication_rounds = communication_rounds 106 | self.join_ratio = join_ratio 107 | self.batch_size = batch_size 108 | self.model_epochs = model_epochs 109 | self.lr_server = lr_server 110 | self.momentum_server = momentum_server 111 | self.weight_decay_server = weight_decay_server 112 | self.lr_head = lr_head 113 | self.momentum_head = momentum_head 114 | self.weight_decay_head = weight_decay_head 115 | self.weighted_matching = weighted_matching 116 | self.weighted_sample = weighted_sample 117 | self.weighted_mmd = weighted_mmd 118 | self.contrastive_way = contrastive_way 119 | self.con_beta = con_beta 120 | self.con_temp = con_temp 121 | self.topk = topk 122 | self.dsa = dsa 123 | self.dsa_strategy = dsa_strategy 124 | self.dsa_param = ParamDiffAug() 125 | self.preserve_all = preserve_all 126 | 127 | self.eval_gap = eval_gap 128 | self.test_set = test_set 129 | self.test_loader = test_loader 130 | self.device = device 131 | 132 | self.model_identification = model_identification 133 | self.save_root_path = save_root_path 
134 | 135 | def fit(self): 136 | evaluate_acc = 0 137 | round_list = [] 138 | evaluate_acc_list = [] 139 | img_syn_loss = {idx: [] for idx in range(len(self.clients))} 140 | 141 | all_synthetic_data = [] 142 | all_synthetic_label = [] 143 | all_syn_imgs_c = {c: [] for c in range(0, self.dataset_info['num_classes'])} 144 | mmd_gap = {c: [] for c in range(0, self.dataset_info['num_classes'])} 145 | accumlate_mmd = {c: [] for c in range(0, self.dataset_info['num_classes'])} 146 | prev_syn_proto = None 147 | 148 | for rounds in range(self.communication_rounds): 149 | logging.info(f' ====== round {rounds} ======') 150 | start_time = time.time() 151 | logging.info('---------- client training ----------') 152 | 153 | selected_clients = self.select_clients() 154 | selected_clients_id = [selected_client.cid for selected_client in selected_clients] 155 | logging.info(f'selected clients: {selected_clients_id}') 156 | 157 | server_prototypes = {c: 0 for c in range(0, self.dataset_info['num_classes'])} 158 | server_proto_tensor = [] 159 | server_logit_prototypes = {c: 0 for c in range(0, self.dataset_info['num_classes'])} 160 | server_logit_proto_tensor = [] 161 | 162 | num_samples = {c: 0 for c in range(0, self.dataset_info['num_classes'])} 163 | syn_imgs_all = {c: [] for c in range(0, self.dataset_info['num_classes'])} 164 | syn_imgs_num_cur = {c: 0 for c in range(0, self.dataset_info['num_classes'])} 165 | idx_client = {c: {client.cid: [] for client in selected_clients} for c in range(0,self.dataset_info['num_classes'])} 166 | for client in selected_clients: 167 | print(f"Round {rounds}, client {client.cid} start training...") 168 | get_gpu_mem_info(self.device) 169 | client.recieve_model(self.global_model_name, self.global_model) 170 | # if len(client.classes) == 0: 171 | # logging.info(f"skip client {client.cid}") 172 | # continue 173 | condense_st_time = time.time() 174 | if self.weighted_sample: 175 | imgs, labels, syn_loss, ipc_dict, accmulate_num_syn_imgs, prototypes, 
logit_prototypes = client.train_weighted_sample() 176 | elif self.weighted_mmd: 177 | imgs, labels, syn_loss, ipc_dict, accmulate_num_syn_imgs, prototypes, logit_prototypes = client.train_weighted_MMD() 178 | condense_ed_time = time.time() 179 | logging.info(f"Round {rounds}, client {client.cid} condense time: {condense_ed_time - condense_st_time}") 180 | 181 | img_syn_loss[client.cid].append(syn_loss) 182 | for i, c in enumerate(client.classes): 183 | synthetic_image_c = imgs[accmulate_num_syn_imgs[i] : accmulate_num_syn_imgs[i+1]].reshape( 184 | (ipc_dict[c], self.dataset_info['channel'], self.dataset_info['im_size'][0], self.dataset_info['im_size'][1])) 185 | syn_imgs_all[c].append(synthetic_image_c) 186 | idx_client[c][client.cid] = range(syn_imgs_num_cur[c], syn_imgs_num_cur[c] + ipc_dict[c]) 187 | syn_imgs_num_cur[c] += ipc_dict[c] 188 | 189 | for i, c in enumerate(client.classes): 190 | logging.info(f"client {client.cid}, class {c} have {prototypes[c][1]} samples") 191 | server_prototypes[c] += prototypes[c][0] * prototypes[c][1] 192 | num_samples[c] += prototypes[c][1] 193 | server_logit_prototypes[c] += logit_prototypes[c][0] * logit_prototypes[c][1] 194 | 195 | print(f"Round {rounds}, client {client.cid} finish training...") 196 | get_gpu_mem_info(self.device) 197 | 198 | logging.info(f"server receives {syn_imgs_num_cur} condensed samples for each class") 199 | 200 | for c in range(self.dataset_info['num_classes']): 201 | server_prototypes[c] /= num_samples[c] 202 | server_logit_prototypes[c] /= num_samples[c] 203 | server_proto_tensor.append(server_prototypes[c]) 204 | server_logit_proto_tensor.append(server_logit_prototypes[c]) 205 | 206 | server_proto_tensor = torch.vstack(server_proto_tensor).to(self.device).detach() 207 | server_proto_tensor = F.normalize(server_proto_tensor, dim=1) # 是不是应该先norm再平均 208 | server_logit_proto_tensor = torch.vstack(server_logit_proto_tensor).to(self.device).detach() 209 | logging.info(f"logit_proto before softmax: 
{server_logit_proto_tensor}") 210 | _, relation_class = self.get_mask(server_logit_proto_tensor, k = self.topk) 211 | if rounds > 0: 212 | for c in range(self.dataset_info['num_classes']): 213 | if c not in relation_class[c]: 214 | logging.info(f"class {c} not in relation_class, manually added") 215 | relation_class[c][-1] = c 216 | 217 | logging.info(f"shape of prototypes in tensor: {server_proto_tensor.shape}") 218 | logging.info(f"shape of logit prototypes in tensor: {server_logit_proto_tensor.shape}") 219 | logging.info(f"relation tensor: {relation_class}") 220 | 221 | for c in range(0, self.dataset_info['num_classes']): 222 | syn_imgs_all[c] = torch.vstack(syn_imgs_all[c]) 223 | 224 | synthetic_data = [] 225 | synthetic_label = [] 226 | for c in range(0, self.dataset_info['num_classes']): 227 | all_syn_imgs_c[c].append(syn_imgs_all[c]) 228 | synthetic_data.append(syn_imgs_all[c]) 229 | synthetic_label.append(torch.ones(syn_imgs_all[c].shape[0])*c) 230 | 231 | synthetic_data = torch.vstack(synthetic_data) 232 | synthetic_label = torch.cat(synthetic_label, dim=0) 233 | 234 | logging.info('---------- update global model ----------') 235 | # update model parameters by SGD 236 | all_synthetic_data.append(synthetic_data) 237 | all_synthetic_label.append(synthetic_label) 238 | logging.info(len(synthetic_data)) 239 | 240 | preserve_thres = max(10, self.communication_rounds // 2) 241 | logging.info(f"preserve threshold: {preserve_thres}") 242 | if (not self.preserve_all) and (len(all_synthetic_data) > preserve_thres): 243 | all_synthetic_data = all_synthetic_data[-preserve_thres: ] 244 | all_synthetic_label = all_synthetic_label[-preserve_thres: ] 245 | 246 | logging.info(len(all_synthetic_data)) 247 | all_synthetic_data_eval = torch.cat(all_synthetic_data, dim=0).cpu() 248 | all_synthetic_label_eval = torch.cat(all_synthetic_label, dim=0).cpu() 249 | synthetic_dataset = TensorDataset(all_synthetic_data_eval, all_synthetic_label_eval) 250 | logging.info(f"Round 
{rounds}: # synthetic sample: {len(synthetic_dataset)}") 251 | synthetic_dataloader = DataLoader(synthetic_dataset, self.batch_size, shuffle=True, num_workers=2) 252 | 253 | self.global_model.train() 254 | model_optimizer = torch.optim.SGD( 255 | self.global_model.parameters(), 256 | lr=self.lr_server, 257 | weight_decay=self.weight_decay_server, 258 | momentum=self.momentum_server, 259 | ) 260 | model_optimizer.zero_grad() 261 | lr_schedule = torch.optim.lr_scheduler.StepLR(model_optimizer, step_size=(self.model_epochs//2), gamma=0.1) 262 | loss_function = torch.nn.CrossEntropyLoss() 263 | z_dim = server_proto_tensor.shape[1] 264 | relation_sup_con_loss = SupervisedContrastiveLoss(num_classes=self.dataset_info['num_classes'], device=self.device, temperature=self.con_temp, z_dim=z_dim, relation_class=relation_class) 265 | # con_loss = ContrastiveLoss(z_dim, device=self.device, temperature=self.con_temp) 266 | mlp_head_optimizer = torch.optim.Adam( 267 | relation_sup_con_loss.head.parameters(), 268 | lr=self.lr_head, 269 | weight_decay=self.weight_decay_head, 270 | # momentum=self.momentum_head, 271 | ) 272 | mlp_head_optimizer.zero_grad() 273 | head_lr_schedule = torch.optim.lr_scheduler.StepLR(mlp_head_optimizer, step_size=(self.model_epochs//2), gamma=0.1) 274 | 275 | print(f"Round {rounds}, global model start training...") 276 | get_gpu_mem_info(self.device) 277 | 278 | # evaluate ahead training 279 | acc, test_loss = self.evaluate() 280 | logging.info(f'round {rounds} evaluation: test acc is {acc:.4f}, test loss = {test_loss:.6f}') 281 | self.global_model.train() 282 | for param in list(self.global_model.parameters()): 283 | param.requires_grad = True 284 | for epoch in range(self.model_epochs+1): 285 | total_loss = 0 286 | total_con_loss = 0 287 | total_sample = 0 288 | for x, target in synthetic_dataloader: 289 | n_sample = target.shape[0] 290 | x, target = x.to(self.device), target.to(self.device) 291 | if self.con_beta > 0.: 292 | features, _ = 
self.global_model(x, train=True) 293 | if self.dsa: 294 | x = DiffAugment(x, self.dsa_strategy, param=self.dsa_param) 295 | 296 | target = target.long() 297 | _, pred = self.global_model(x, train=True) 298 | loss = loss_function(pred, target) 299 | total_loss += loss.item() * n_sample 300 | 301 | if self.con_beta > 0. and rounds > 0 and x.shape[0] > 1: 302 | if self.contrastive_way == 'supcon_asym_syn': 303 | assert prev_syn_proto is not None 304 | positive_proto = prev_syn_proto[target, :] 305 | loss_con = relation_sup_con_loss(features, target, positive_proto, asymmetric=True) 306 | total_con_loss += loss_con.item() * n_sample 307 | loss += self.con_beta * loss_con 308 | 309 | model_optimizer.zero_grad() 310 | loss.backward() 311 | model_optimizer.step() 312 | total_sample += n_sample 313 | if self.con_beta > 0. and self.contrastive_way in ['supcon_asym', 'supcon_asym_syn'] and rounds > 0: 314 | mlp_head_optimizer.step() 315 | 316 | total_loss /= total_sample 317 | total_con_loss /= total_sample 318 | # if self.dataset_info['name'] not in ['OCTMNIST']: 319 | lr_schedule.step() 320 | if self.con_beta > 0. 
and self.contrastive_way in ['supcon_asym', 'supcon_asym_syn'] and rounds > 0: 321 | head_lr_schedule.step() 322 | if epoch == (self.model_epochs // 2): 323 | logging.info(f"At epoch {epoch}, decay the con_beta with 0.1 factor") 324 | self.con_beta *= 0.1 325 | 326 | if epoch%100 == 0 or epoch == self.model_epochs: 327 | acc, test_loss = self.evaluate() 328 | self.global_model.train() 329 | logging.info(f"epoch {epoch}, train loss avg now = {total_loss:.6f}, train contrast loss now = {total_con_loss:.6f}, test acc now = {acc:.4f}, test loss now = {test_loss:.6f}") 330 | 331 | round_time = time.time() - start_time 332 | logging.info(f'epoch avg loss = {total_loss / self.model_epochs}, total time = {round_time}') 333 | 334 | print(f"Round {rounds}, global model finish training...") 335 | get_gpu_mem_info(self.device) 336 | 337 | logging.info(f"Round {rounds} finish, update the prev_syn_proto") 338 | prev_syn_proto = torch.zeros_like(server_proto_tensor).to(self.device) 339 | self.global_model.eval() 340 | with torch.no_grad(): 341 | for c in range(0, self.dataset_info['num_classes']): 342 | all_syn_cat = torch.cat(all_syn_imgs_c[c], dim=0) 343 | logging.info(f"{all_syn_cat.shape}") 344 | if all_syn_cat.shape[0] > 128: 345 | for it in range(0, all_syn_cat.shape[0], 128): 346 | if it + 128 >= all_syn_cat.shape[0]: 347 | prev_syn_proto[c, :] += torch.sum(self.global_model.embed(all_syn_cat[it: ]).detach(), dim=0) 348 | else: 349 | prev_syn_proto[c, :] += torch.sum(self.global_model.embed(all_syn_cat[it: it+128]).detach(), dim=0) 350 | prev_syn_proto[c, :] /= all_syn_cat.shape[0] 351 | else: 352 | prev_syn_proto[c, :] = torch.mean(self.global_model.embed(all_syn_cat).detach(), dim=0) 353 | prev_syn_proto = F.normalize(prev_syn_proto, dim=1).detach() 354 | logging.info(f"shape of prev_syn_proto: {prev_syn_proto.shape}") 355 | 356 | if rounds % self.eval_gap == 0: 357 | acc, test_loss = self.evaluate() 358 | logging.info(f'round {rounds} evaluation: test acc is {acc:.4f}, 
test loss = {test_loss:.6f}') 359 | evaluate_acc = acc 360 | round_list.append(rounds) 361 | evaluate_acc_list.append(evaluate_acc) 362 | 363 | # self.save_model(path=save_root_path, rounds=rounds, include_image=False) 364 | # torch.save(self.global_model.state_dict(), os.path.join(self.save_root_path, f"model_{rounds}.pt")) 365 | 366 | logging.info(evaluate_acc_list) 367 | logging.info(img_syn_loss) 368 | logging.info(mmd_gap) 369 | logging.info(accumlate_mmd) 370 | 371 | def get_mask(self, matrix, k=3, largest=True): 372 | min_val, min_idx = torch.topk(matrix, k=k, dim=-1, largest=largest) 373 | mask = torch.zeros_like(matrix) 374 | rows = torch.arange(min_idx.size(0)).unsqueeze(1) 375 | mask[rows, min_idx] = 1 376 | mask = mask.bool() 377 | return mask, min_idx 378 | 379 | def select_clients(self): 380 | return ( 381 | self.clients if self.join_ratio == 1.0 382 | else random.sample(self.clients, int(round(len(self.clients) * self.join_ratio))) 383 | ) 384 | 385 | def evaluate(self): 386 | prediction_matrix = {c: {c: 0 for c in range(self.dataset_info['num_classes'])} for c in range(self.dataset_info['num_classes'])} 387 | self.global_model.eval() 388 | with torch.no_grad(): 389 | correct, total, test_loss = 0, 0, 0. 390 | for x, target in self.test_loader: 391 | x, target = x.to(self.device), target.to(self.device, dtype=torch.int64) 392 | pred = self.global_model(x) 393 | test_loss += F.cross_entropy(pred, target, reduction='sum').item() 394 | _, pred_label = torch.max(pred.data, 1) 395 | total += x.data.size()[0] 396 | correct += (pred_label == target.data).sum().item() 397 | for i in range(target.shape[0]): 398 | prediction_matrix[target[i].item()][pred_label[i].item()] += 1 399 | 400 | logging.info(f"{prediction_matrix}") 401 | return correct / float(total), test_loss / float(total) 402 | 403 | def evaluate_model(self, model): 404 | model.eval() 405 | with torch.no_grad(): 406 | correct, total, test_loss = 0, 0, 0. 
407 | for x, target in self.test_loader: 408 | x, target = x.to(self.device), target.to(self.device, dtype=torch.int64) 409 | pred = model(x) 410 | test_loss += F.cross_entropy(pred, target, reduction='sum').item() 411 | _, pred_label = torch.max(pred.data, 1) 412 | total += x.data.size()[0] 413 | correct += (pred_label == target.data).sum().item() 414 | 415 | return correct / float(total), test_loss / float(total) 416 | 417 | def make_checkpoint(self, rounds): 418 | checkpoint = { 419 | 'current_round': rounds, 420 | 'model': self.global_model.state_dict() 421 | } 422 | return checkpoint 423 | 424 | def save_model(self, path, rounds, include_image): 425 | # torch.save(self.make_checkpoint(rounds), os.path.join(path, f'model_{rounds}.pt')) 426 | torch.save(self.make_checkpoint(rounds), os.path.join(path, 'model.pt')) 427 | if include_image: 428 | raise NotImplemented('not implement yet') -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import random 3 | 4 | import numpy as np 5 | import torch 6 | from torch.optim.optimizer import Optimizer, required 7 | from torch.distributions.multivariate_normal import MultivariateNormal 8 | from src.models import ResNet18, ConvNet, ResNet18BN 9 | from .models import Projector 10 | import torch.nn.functional as F 11 | 12 | import torch 13 | import torch.nn as nn 14 | import logging 15 | 16 | def get_gpu_mem_info(gpu_id=0): 17 | import pynvml 18 | pynvml.nvmlInit() 19 | gpu_id = int(str(gpu_id)[-1]) 20 | if gpu_id < 0 or gpu_id >= pynvml.nvmlDeviceGetCount(): 21 | print(f'gpu_id {gpu_id} does not exsit!'.format(gpu_id)) 22 | return 0, 0, 0 23 | 24 | handler = pynvml.nvmlDeviceGetHandleByIndex(gpu_id) 25 | meminfo = pynvml.nvmlDeviceGetMemoryInfo(handler) 26 | total = round(meminfo.total / 1024 / 1024, 2) 27 | used = round(meminfo.used / 1024 / 1024, 2) 28 | free = round(meminfo.free / 
1024 / 1024, 2) 29 | print(f"total {total}MB, used {used}MB, free {free}MB") 30 | return total, used, free 31 | 32 | class SupervisedContrastiveLoss(torch.nn.Module): 33 | def __init__(self, num_classes, device, temperature=0.07, z_dim=10, relation_class=None): 34 | super(SupervisedContrastiveLoss, self).__init__() 35 | self.device = device 36 | self.head = Projector(input_dim=z_dim, output_dim=z_dim).to(self.device) 37 | self.head.train() 38 | self.temperature = temperature 39 | self.relation_class = relation_class 40 | self.num_classes = num_classes 41 | 42 | def forward(self, x, y, proto=None, asymmetric=False): 43 | if asymmetric: 44 | x = self.head(x) 45 | x = F.normalize(x, dim=1) 46 | if proto is not None: 47 | sim_matrix = torch.exp(torch.matmul(x, proto.t()) / self.temperature) 48 | else: 49 | sim_matrix = torch.exp(torch.matmul(x, x.t()) / self.temperature) 50 | # generate the mask for positive and negative pairs 51 | mask = torch.eq(y.unsqueeze(0), y.unsqueeze(1)).float() 52 | relation_mask = torch.zeros_like(mask) 53 | for i in range(self.num_classes): 54 | class_mask = (y == i).unsqueeze(1).float() 55 | for cls in self.relation_class[i]: 56 | relation_mask += class_mask * (y == cls).unsqueeze(0).float() 57 | loss = -torch.log((mask * sim_matrix).sum(1) / (relation_mask * sim_matrix).sum(1)).mean() 58 | return loss 59 | 60 | class ContrastiveLoss(nn.Module): 61 | def __init__(self, z_dim, device, temperature=1.0): 62 | super(ContrastiveLoss, self).__init__() 63 | self.device = device 64 | self.head = Projector(input_dim=z_dim, output_dim=z_dim).to(self.device) 65 | self.head.train() 66 | self.temperature = temperature 67 | 68 | def forward(self, x, proto, y, asymmetric=False): 69 | if asymmetric: 70 | x = self.head(x) 71 | x = F.normalize(x, dim=1) 72 | sim_matrix = torch.matmul(x, proto.t()) / self.temperature 73 | mask = torch.eq(y.unsqueeze(0), y.unsqueeze(1)).float() 74 | mask = mask / mask.sum(dim=1, keepdim=True) 75 | loss = 
-(torch.log_softmax(sim_matrix, dim=1) * mask).sum(dim=1).mean() 76 | return loss 77 | 78 | class RBF(nn.Module): 79 | def __init__(self, device='cpu', n_kernels=5, mul_factor=2.0, bandwidth=None): 80 | super().__init__() 81 | self.device = device 82 | self.bandwidth_multipliers = mul_factor ** (torch.arange(n_kernels) - n_kernels // 2) 83 | self.bandwidth_multipliers = self.bandwidth_multipliers.to(device) 84 | self.bandwidth = bandwidth 85 | 86 | def get_bandwidth(self, L2_distances): 87 | if self.bandwidth is None: 88 | n_samples = L2_distances.shape[0] 89 | return L2_distances.data.sum() / (n_samples ** 2 - n_samples) 90 | 91 | return self.bandwidth 92 | 93 | def forward(self, X): 94 | L2_distances = torch.cdist(X, X) ** 2 95 | return torch.exp(-L2_distances[None, ...] / (self.get_bandwidth(L2_distances) * self.bandwidth_multipliers)[:, None, None]).sum(dim=0) 96 | 97 | class PoliKernel(nn.Module): 98 | def __init__(self, constant_term=1, degree=2): 99 | super().__init__() 100 | self.constant_term = constant_term 101 | self.degree = degree 102 | 103 | def forward(self, X): 104 | K = (torch.matmul(X, X.t()) + self.constant_term) ** self.degree 105 | return K 106 | 107 | class LinearKernel(nn.Module): 108 | def __init__(self): 109 | super().__init__() 110 | 111 | def forward(self, X): 112 | K = torch.matmul(X, X.t()) 113 | return K 114 | 115 | class LaplaceKernel(nn.Module): 116 | def __init__(self): 117 | super().__init__() 118 | self.gammas = torch.FloatTensor([0.1, 1, 5]).cuda() 119 | 120 | def forward(self, X): 121 | L2_distances = torch.cdist(X, X) ** 2 122 | return torch.exp(-L2_distances[None, ...] 
* (self.gammas)[:, None, None]).sum(dim=0) 123 | 124 | class M3DLoss(nn.Module): 125 | def __init__(self, kernel_type, device): 126 | super().__init__() 127 | self.device = device 128 | if kernel_type == 'gaussian': 129 | self.kernel = RBF(device = self.device) 130 | elif kernel_type == 'linear': 131 | self.kernel = LinearKernel() 132 | elif kernel_type == 'polinominal': 133 | self.kernel = PoliKernel() 134 | elif kernel_type == 'laplace': 135 | self.kernel = LaplaceKernel() 136 | 137 | def forward(self, X, Y): 138 | K = self.kernel(torch.vstack([X, Y])) 139 | X_size = X.shape[0] 140 | XX = K[:X_size, :X_size].mean() 141 | XY = K[:X_size, X_size:].mean() 142 | YY = K[X_size:, X_size:].mean() 143 | return XX - 2 * XY + YY 144 | 145 | class MMDLoss(nn.Module): 146 | ''' 147 | https://github.com/jindongwang/transferlearning/blob/master/code/distance/mmd_pytorch.py 148 | 计算源域数据和目标域数据的MMD距离 149 | Params: 150 | source: 源域数据(n * len(x)) 151 | target: 目标域数据(m * len(y)) 152 | kernel_mul: 153 | kernel_num: 取不同高斯核的数量 154 | fix_sigma: 不同高斯核的sigma值 155 | Return: 156 | loss: MMD loss 157 | ''' 158 | def __init__(self, kernel_type='rbf', kernel_mul=2.0, kernel_num=5, fix_sigma=None, **kwargs): 159 | super(MMDLoss, self).__init__() 160 | self.kernel_num = kernel_num 161 | self.kernel_mul = kernel_mul 162 | self.fix_sigma = None 163 | self.kernel_type = kernel_type 164 | 165 | def guassian_kernel(self, source, target, kernel_mul, kernel_num, fix_sigma): 166 | n_samples = int(source.size()[0]) + int(target.size()[0]) 167 | total = torch.cat([source, target], dim=0) 168 | total0 = total.unsqueeze(0).expand( 169 | int(total.size(0)), int(total.size(0)), int(total.size(1))) 170 | total1 = total.unsqueeze(1).expand( 171 | int(total.size(0)), int(total.size(0)), int(total.size(1))) 172 | 173 | L2_distance = ((total0-total1)**2).sum(2) 174 | if fix_sigma: 175 | bandwidth = fix_sigma 176 | else: 177 | bandwidth = torch.sum(L2_distance.data) / (n_samples**2-n_samples) 178 | bandwidth /= 
kernel_mul ** (kernel_num // 2) 179 | bandwidth_list = [bandwidth * (kernel_mul**i) 180 | for i in range(kernel_num)] 181 | kernel_val = [torch.exp(-L2_distance / bandwidth_temp) 182 | for bandwidth_temp in bandwidth_list] 183 | return sum(kernel_val) 184 | 185 | def linear_mmd2(self, f_of_X, f_of_Y): 186 | loss = 0.0 187 | delta = f_of_X.float().mean(0) - f_of_Y.float().mean(0) 188 | loss = delta.dot(delta.T) 189 | return loss 190 | 191 | def forward(self, source, target): 192 | if self.kernel_type == 'linear': 193 | return self.linear_mmd2(source, target) 194 | elif self.kernel_type == 'rbf': 195 | batch_size = int(source.size()[0]) 196 | kernels = self.guassian_kernel( 197 | source, target, kernel_mul=self.kernel_mul, kernel_num=self.kernel_num, fix_sigma=self.fix_sigma) 198 | XX = torch.mean(kernels[:batch_size, :batch_size]) 199 | YY = torch.mean(kernels[batch_size:, batch_size:]) 200 | XY = torch.mean(kernels[:batch_size, batch_size:]) 201 | YX = torch.mean(kernels[batch_size:, :batch_size]) 202 | loss = torch.mean(XX + YY - XY - YX) 203 | return loss 204 | 205 | 206 | def get_model(model_name, dataset_info): 207 | if model_name == "ConvNet": 208 | model = ConvNet( 209 | channel=dataset_info['channel'], 210 | num_classes=dataset_info['num_classes'], 211 | net_width=128, 212 | net_depth=3, 213 | net_act='relu', 214 | net_norm='instancenorm', 215 | net_pooling='avgpooling', 216 | im_size=dataset_info['im_size'] 217 | ) 218 | elif model_name == "ConvNetBN": 219 | model = ConvNet( 220 | channel=dataset_info['channel'], 221 | num_classes=dataset_info['num_classes'], 222 | net_width=128, 223 | net_depth=3, 224 | net_act='relu', 225 | net_norm='batchnorm', 226 | net_pooling='avgpooling', 227 | im_size=dataset_info['im_size'] 228 | ) 229 | elif model_name == "ResNet": 230 | model = ResNet18( 231 | channel=dataset_info['channel'], 232 | num_classes=dataset_info['num_classes'] 233 | ) 234 | elif model_name == 'ResNet18BN': 235 | model = ResNet18BN( 236 | 
channel=dataset_info['channel'], 237 | num_classes=dataset_info['num_classes'] 238 | ) 239 | else: 240 | raise NotImplementedError("only support ConvNet and ResNet") 241 | 242 | return model 243 | 244 | def setup_seed(seed): 245 | torch.manual_seed(seed) 246 | torch.cuda.manual_seed(seed) 247 | torch.cuda.manual_seed_all(seed) 248 | np.random.seed(seed) 249 | random.seed(seed) 250 | torch.backends.cudnn.deterministic = True 251 | torch.backends.cudnn.benchmark = False 252 | torch.backends.cudnn.enabled = False 253 | 254 | def sample_random_model(model, rho): 255 | new_model = copy.deepcopy(model) 256 | parameters = new_model.parameters() 257 | 258 | mean = parameters.view(-1) 259 | multivariate_normal = MultivariateNormal(mean, torch.eye(mean.shape[0])) 260 | distance = rho + 1 261 | while distance > rho: 262 | sample = multivariate_normal.sample() 263 | distance = torch.sqrt(torch.sum((mean - sample)**2)) 264 | 265 | new_parameters = sample.view(parameters.shape) 266 | for old_param, new_param in zip(parameters, new_parameters): 267 | with torch.no_grad(): 268 | old_param.fill_(new_param) 269 | 270 | return new_model 271 | 272 | def random_pertube(model, rho): 273 | new_model = copy.deepcopy(model) 274 | for p in new_model.parameters(): 275 | gauss = torch.normal(mean=torch.zeros_like(p), std=1) 276 | if p.grad is None: 277 | p.grad = gauss 278 | else: 279 | p.grad.data.copy_(gauss.data) 280 | 281 | norm = torch.norm(torch.stack([p.grad.norm(p=2) for p in new_model.parameters() if p.grad is not None]), p=2) 282 | 283 | with torch.no_grad(): 284 | scale = rho / (norm + 1e-12) 285 | scale = torch.clamp(scale, max=1.0) 286 | for p in new_model.parameters(): 287 | if p.grad is not None: 288 | e_w = 1.0 * p.grad * scale.to(p) 289 | p.add_(e_w) 290 | 291 | new_model.zero_grad() 292 | return new_model 293 | 294 | 295 | def augment(images, dc_aug_param, device): 296 | # This can be sped up in the future. 
297 | if dc_aug_param != None and dc_aug_param['strategy'] != 'none': 298 | scale = dc_aug_param['scale'] 299 | crop = dc_aug_param['crop'] 300 | rotate = dc_aug_param['rotate'] 301 | noise = dc_aug_param['noise'] 302 | strategy = dc_aug_param['strategy'] 303 | 304 | shape = images.shape 305 | mean = [] 306 | for c in range(shape[1]): 307 | mean.append(float(torch.mean(images[:,c]))) 308 | 309 | def cropfun(i): 310 | im_ = torch.zeros(shape[1],shape[2]+crop*2,shape[3]+crop*2, dtype=torch.float, device=device) 311 | for c in range(shape[1]): 312 | im_[c] = mean[c] 313 | im_[:, crop:crop+shape[2], crop:crop+shape[3]] = images[i] 314 | r, c = np.random.permutation(crop*2)[0], np.random.permutation(crop*2)[0] 315 | images[i] = im_[:, r:r+shape[2], c:c+shape[3]] 316 | 317 | def scalefun(i): 318 | h = int((np.random.uniform(1 - scale, 1 + scale)) * shape[2]) 319 | w = int((np.random.uniform(1 - scale, 1 + scale)) * shape[2]) 320 | tmp = F.interpolate(images[i:i + 1], [h, w], )[0] 321 | mhw = max(h, w, shape[2], shape[3]) 322 | im_ = torch.zeros(shape[1], mhw, mhw, dtype=torch.float, device=device) 323 | r = int((mhw - h) / 2) 324 | c = int((mhw - w) / 2) 325 | im_[:, r:r + h, c:c + w] = tmp 326 | r = int((mhw - shape[2]) / 2) 327 | c = int((mhw - shape[3]) / 2) 328 | images[i] = im_[:, r:r + shape[2], c:c + shape[3]] 329 | 330 | def rotatefun(i): 331 | im_ = scipyrotate(images[i].cpu().data.numpy(), angle=np.random.randint(-rotate, rotate), axes=(-2, -1), cval=np.mean(mean)) 332 | r = int((im_.shape[-2] - shape[-2]) / 2) 333 | c = int((im_.shape[-1] - shape[-1]) / 2) 334 | images[i] = torch.tensor(im_[:, r:r + shape[-2], c:c + shape[-1]], dtype=torch.float, device=device) 335 | 336 | def noisefun(i): 337 | images[i] = images[i] + noise * torch.randn(shape[1:], dtype=torch.float, device=device) 338 | 339 | 340 | augs = strategy.split('_') 341 | 342 | for i in range(shape[0]): 343 | choice = np.random.permutation(augs)[0] # randomly implement one augmentation 344 | if 
choice == 'crop': 345 | cropfun(i) 346 | elif choice == 'scale': 347 | scalefun(i) 348 | elif choice == 'rotate': 349 | rotatefun(i) 350 | elif choice == 'noise': 351 | noisefun(i) 352 | 353 | return images 354 | 355 | 356 | def get_daparam(dataset, model, model_eval): 357 | # We find that augmentation doesn't always benefit the performance. 358 | # So we do augmentation for some of the settings. 359 | 360 | dc_aug_param = dict() 361 | dc_aug_param['crop'] = 4 362 | dc_aug_param['scale'] = 0.2 363 | dc_aug_param['rotate'] = 45 364 | dc_aug_param['noise'] = 0.001 365 | dc_aug_param['strategy'] = 'none' 366 | 367 | if dataset == 'MNIST': 368 | dc_aug_param['strategy'] = 'crop_scale_rotate' 369 | 370 | if model_eval in ['ConvNetBN', 'ConvNet']: # Data augmentation makes model training with Batch Norm layer easier. 371 | dc_aug_param['strategy'] = 'crop_noise' 372 | 373 | return dc_aug_param 374 | 375 | 376 | class ParamDiffAug(): 377 | def __init__(self): 378 | self.aug_mode = 'S' #'multiple or single' 379 | self.prob_flip = 0.5 380 | self.ratio_scale = 1.2 381 | self.ratio_rotate = 15.0 382 | self.ratio_crop_pad = 0.125 383 | self.ratio_cutout = 0.5 # the size would be 0.5x0.5 384 | self.ratio_noise = 0.05 385 | self.brightness = 1.0 386 | self.saturation = 2.0 387 | self.contrast = 0.5 388 | 389 | 390 | def set_seed_DiffAug(param): 391 | if param.latestseed == -1: 392 | return 393 | else: 394 | torch.random.manual_seed(param.latestseed) 395 | param.latestseed += 1 396 | 397 | 398 | def DiffAugment(x, strategy='', seed = -1, param = None): 399 | if seed == -1: 400 | param.Siamese = False 401 | else: 402 | param.Siamese = True 403 | 404 | param.latestseed = seed 405 | 406 | if strategy == 'None' or strategy == 'none': 407 | return x 408 | 409 | if strategy: 410 | if param.aug_mode == 'M': # original 411 | for p in strategy.split('_'): 412 | for f in AUGMENT_FNS[p]: 413 | x = f(x, param) 414 | elif param.aug_mode == 'S': 415 | pbties = strategy.split('_') 416 | 
set_seed_DiffAug(param) 417 | p = pbties[torch.randint(0, len(pbties), size=(1,)).item()] 418 | for f in AUGMENT_FNS[p]: 419 | x = f(x, param) 420 | else: 421 | exit('unknown augmentation mode: %s'%param.aug_mode) 422 | x = x.contiguous() 423 | return x 424 | 425 | 426 | # We implement the following differentiable augmentation strategies based on the code provided in https://github.com/mit-han-lab/data-efficient-gans. 427 | def rand_scale(x, param): 428 | # x>1, max scale 429 | # sx, sy: (0, +oo), 1: orignial size, 0.5: enlarge 2 times 430 | ratio = param.ratio_scale 431 | set_seed_DiffAug(param) 432 | sx = torch.rand(x.shape[0]) * (ratio - 1.0/ratio) + 1.0/ratio 433 | set_seed_DiffAug(param) 434 | sy = torch.rand(x.shape[0]) * (ratio - 1.0/ratio) + 1.0/ratio 435 | theta = [[[sx[i], 0, 0], 436 | [0, sy[i], 0],] for i in range(x.shape[0])] 437 | theta = torch.tensor(theta, dtype=torch.float) 438 | if param.Siamese: # Siamese augmentation: 439 | theta[:] = theta[0].clone() 440 | grid = F.affine_grid(theta, x.shape).to(x.device) 441 | x = F.grid_sample(x, grid) 442 | return x 443 | 444 | 445 | def rand_rotate(x, param): # [-180, 180], 90: anticlockwise 90 degree 446 | ratio = param.ratio_rotate 447 | set_seed_DiffAug(param) 448 | theta = (torch.rand(x.shape[0]) - 0.5) * 2 * ratio / 180 * float(np.pi) 449 | theta = [[[torch.cos(theta[i]), torch.sin(-theta[i]), 0], 450 | [torch.sin(theta[i]), torch.cos(theta[i]), 0],] for i in range(x.shape[0])] 451 | theta = torch.tensor(theta, dtype=torch.float) 452 | if param.Siamese: # Siamese augmentation: 453 | theta[:] = theta[0].clone() 454 | grid = F.affine_grid(theta, x.shape).to(x.device) 455 | x = F.grid_sample(x, grid) 456 | return x 457 | 458 | 459 | def rand_flip(x, param): 460 | prob = param.prob_flip 461 | set_seed_DiffAug(param) 462 | randf = torch.rand(x.size(0), 1, 1, 1, device=x.device) 463 | if param.Siamese: # Siamese augmentation: 464 | randf[:] = randf[0].clone() 465 | return torch.where(randf < prob, x.flip(3), 
x) 466 | 467 | 468 | def rand_brightness(x, param): 469 | ratio = param.brightness 470 | set_seed_DiffAug(param) 471 | randb = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) 472 | if param.Siamese: # Siamese augmentation: 473 | randb[:] = randb[0].clone() 474 | x = x + (randb - 0.5)*ratio 475 | return x 476 | 477 | 478 | def rand_saturation(x, param): 479 | ratio = param.saturation 480 | x_mean = x.mean(dim=1, keepdim=True) 481 | set_seed_DiffAug(param) 482 | rands = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) 483 | if param.Siamese: # Siamese augmentation: 484 | rands[:] = rands[0].clone() 485 | x = (x - x_mean) * (rands * ratio) + x_mean 486 | return x 487 | 488 | 489 | def rand_contrast(x, param): 490 | ratio = param.contrast 491 | x_mean = x.mean(dim=[1, 2, 3], keepdim=True) 492 | set_seed_DiffAug(param) 493 | randc = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) 494 | if param.Siamese: # Siamese augmentation: 495 | randc[:] = randc[0].clone() 496 | x = (x - x_mean) * (randc + ratio) + x_mean 497 | return x 498 | 499 | 500 | def rand_crop(x, param): 501 | # The image is padded on its surrounding and then cropped. 
502 | ratio = param.ratio_crop_pad 503 | shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) 504 | set_seed_DiffAug(param) 505 | translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device) 506 | set_seed_DiffAug(param) 507 | translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device) 508 | if param.Siamese: # Siamese augmentation: 509 | translation_x[:] = translation_x[0].clone() 510 | translation_y[:] = translation_y[0].clone() 511 | grid_batch, grid_x, grid_y = torch.meshgrid( 512 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 513 | torch.arange(x.size(2), dtype=torch.long, device=x.device), 514 | torch.arange(x.size(3), dtype=torch.long, device=x.device), 515 | ) 516 | grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1) 517 | grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1) 518 | x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0]) 519 | x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2) 520 | return x 521 | 522 | 523 | def rand_cutout(x, param): 524 | ratio = param.ratio_cutout 525 | cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) 526 | set_seed_DiffAug(param) 527 | offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device) 528 | set_seed_DiffAug(param) 529 | offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device) 530 | if param.Siamese: # Siamese augmentation: 531 | offset_x[:] = offset_x[0].clone() 532 | offset_y[:] = offset_y[0].clone() 533 | grid_batch, grid_x, grid_y = torch.meshgrid( 534 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 535 | torch.arange(cutout_size[0], dtype=torch.long, device=x.device), 536 | torch.arange(cutout_size[1], dtype=torch.long, device=x.device), 537 | ) 538 | grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, 
min=0, max=x.size(2) - 1) 539 | grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1) 540 | mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device) 541 | mask[grid_batch, grid_x, grid_y] = 0 542 | x = x * mask.unsqueeze(1) 543 | return x 544 | 545 | 546 | AUGMENT_FNS = { 547 | 'color': [rand_brightness, rand_saturation, rand_contrast], 548 | 'crop': [rand_crop], 549 | 'cutout': [rand_cutout], 550 | 'flip': [rand_flip], 551 | 'scale': [rand_scale], 552 | 'rotate': [rand_rotate], 553 | } -------------------------------------------------------------------------------- /dataset/split_file/RetinaMNIST224_client_num=10_alpha=0.02.json: -------------------------------------------------------------------------------- 1 | { 2 | "client_idx": [ 3 | [ 4 | 530, 5 | 635, 6 | 357, 7 | 145, 8 | 709, 9 | 10, 10 | 398, 11 | 818, 12 | 170, 13 | 703, 14 | 120, 15 | 786, 16 | 127, 17 | 57, 18 | 1036, 19 | 291, 20 | 785, 21 | 741, 22 | 433, 23 | 1049, 24 | 847, 25 | 1004, 26 | 354, 27 | 243, 28 | 53, 29 | 459, 30 | 186, 31 | 392, 32 | 15, 33 | 3, 34 | 777, 35 | 594, 36 | 434, 37 | 963, 38 | 844, 39 | 861, 40 | 890, 41 | 190 42 | ], 43 | [ 44 | 277, 45 | 386, 46 | 1021, 47 | 743, 48 | 582, 49 | 230, 50 | 465, 51 | 1008, 52 | 130, 53 | 146, 54 | 224, 55 | 1068, 56 | 268, 57 | 85, 58 | 163, 59 | 672, 60 | 337, 61 | 202 62 | ], 63 | [ 64 | 125, 65 | 922, 66 | 913, 67 | 184, 68 | 90, 69 | 588, 70 | 701, 71 | 316, 72 | 437, 73 | 301, 74 | 653, 75 | 346, 76 | 20, 77 | 1032, 78 | 227, 79 | 399, 80 | 46, 81 | 1079, 82 | 869, 83 | 258, 84 | 946, 85 | 340, 86 | 526, 87 | 794, 88 | 918 89 | ], 90 | [ 91 | 757, 92 | 397, 93 | 56, 94 | 803, 95 | 925, 96 | 237, 97 | 103, 98 | 744, 99 | 564, 100 | 59, 101 | 659, 102 | 934, 103 | 280, 104 | 1028, 105 | 446, 106 | 121, 107 | 625, 108 | 87, 109 | 114, 110 | 867, 111 | 77, 112 | 390, 113 | 423, 114 | 677, 115 | 61, 116 | 259, 117 | 78, 118 | 762, 119 | 658, 120 | 997, 121 | 362, 122 | 872, 123 | 
240, 124 | 642, 125 | 64, 126 | 35, 127 | 568, 128 | 147, 129 | 31, 130 | 515, 131 | 315, 132 | 630, 133 | 816, 134 | 308, 135 | 457, 136 | 961, 137 | 640, 138 | 950, 139 | 622, 140 | 931, 141 | 466, 142 | 838, 143 | 1016, 144 | 807, 145 | 84, 146 | 541, 147 | 155, 148 | 247 149 | ], 150 | [ 151 | 866, 152 | 662, 153 | 91, 154 | 726, 155 | 82, 156 | 998, 157 | 1072, 158 | 408, 159 | 80, 160 | 282, 161 | 812, 162 | 21, 163 | 384, 164 | 527, 165 | 639, 166 | 272, 167 | 813, 168 | 205, 169 | 597, 170 | 571, 171 | 45, 172 | 244, 173 | 754, 174 | 560, 175 | 245, 176 | 302, 177 | 376, 178 | 561, 179 | 162, 180 | 554, 181 | 575, 182 | 853, 183 | 326, 184 | 945, 185 | 532, 186 | 822, 187 | 1050, 188 | 705, 189 | 674, 190 | 584, 191 | 474, 192 | 112, 193 | 645, 194 | 669, 195 | 576, 196 | 281, 197 | 1060, 198 | 445, 199 | 996, 200 | 740, 201 | 254, 202 | 430, 203 | 902, 204 | 356, 205 | 965, 206 | 678, 207 | 32, 208 | 13, 209 | 738, 210 | 759, 211 | 1077, 212 | 118, 213 | 841, 214 | 485, 215 | 556, 216 | 210, 217 | 158, 218 | 623, 219 | 60, 220 | 614, 221 | 74, 222 | 478, 223 | 196, 224 | 938, 225 | 359, 226 | 140, 227 | 297, 228 | 276, 229 | 1053, 230 | 1065, 231 | 286, 232 | 241, 233 | 1078, 234 | 94, 235 | 949, 236 | 937, 237 | 293, 238 | 55, 239 | 769, 240 | 368, 241 | 707, 242 | 393, 243 | 893, 244 | 1029, 245 | 443, 246 | 9, 247 | 187, 248 | 494, 249 | 784, 250 | 739, 251 | 100, 252 | 1040, 253 | 195, 254 | 908, 255 | 878, 256 | 706, 257 | 968, 258 | 481, 259 | 52, 260 | 1048, 261 | 1070, 262 | 475, 263 | 763, 264 | 856, 265 | 577, 266 | 440, 267 | 613, 268 | 1035, 269 | 497, 270 | 463, 271 | 675, 272 | 128, 273 | 27, 274 | 591, 275 | 251, 276 | 425, 277 | 99, 278 | 898, 279 | 667, 280 | 676, 281 | 298, 282 | 929, 283 | 536, 284 | 543, 285 | 660, 286 | 104, 287 | 820, 288 | 1056, 289 | 419, 290 | 310, 291 | 115, 292 | 520, 293 | 656, 294 | 414, 295 | 159, 296 | 101, 297 | 113, 298 | 126, 299 | 995, 300 | 871, 301 | 263, 302 | 42, 303 | 647, 304 | 587, 305 | 717, 306 | 
801, 307 | 318, 308 | 129, 309 | 460, 310 | 1074, 311 | 501, 312 | 508, 313 | 502, 314 | 490, 315 | 529, 316 | 882, 317 | 654, 318 | 365, 319 | 304, 320 | 722, 321 | 98, 322 | 47, 323 | 36, 324 | 312, 325 | 292, 326 | 694, 327 | 924, 328 | 874, 329 | 1007, 330 | 283, 331 | 58, 332 | 1057, 333 | 604, 334 | 850, 335 | 67, 336 | 143, 337 | 519, 338 | 348, 339 | 110, 340 | 391, 341 | 436 342 | ], 343 | [ 344 | 34, 345 | 83, 346 | 661, 347 | 109, 348 | 620, 349 | 1015, 350 | 471, 351 | 670, 352 | 137, 353 | 417, 354 | 727, 355 | 933, 356 | 458, 357 | 713, 358 | 967, 359 | 742, 360 | 358, 361 | 54, 362 | 200, 363 | 349, 364 | 966, 365 | 714, 366 | 439, 367 | 787, 368 | 504, 369 | 894, 370 | 332, 371 | 418, 372 | 17, 373 | 355, 374 | 1003, 375 | 374, 376 | 795, 377 | 546, 378 | 852, 379 | 956, 380 | 796, 381 | 498, 382 | 686, 383 | 102, 384 | 624, 385 | 1045, 386 | 208, 387 | 565, 388 | 775, 389 | 1054, 390 | 404, 391 | 51, 392 | 779, 393 | 752, 394 | 523, 395 | 524, 396 | 395, 397 | 958, 398 | 30, 399 | 798, 400 | 695, 401 | 119, 402 | 178, 403 | 603, 404 | 987, 405 | 715, 406 | 1059, 407 | 168, 408 | 95, 409 | 943, 410 | 730, 411 | 232, 412 | 631, 413 | 855, 414 | 765, 415 | 644, 416 | 778, 417 | 203, 418 | 680, 419 | 935, 420 | 271, 421 | 284, 422 | 849, 423 | 24, 424 | 201, 425 | 454, 426 | 223, 427 | 846, 428 | 724, 429 | 1020, 430 | 107, 431 | 692, 432 | 761, 433 | 260, 434 | 960, 435 | 438, 436 | 780, 437 | 444, 438 | 864, 439 | 451, 440 | 236, 441 | 904, 442 | 789, 443 | 747, 444 | 409, 445 | 959, 446 | 710, 447 | 936, 448 | 1058, 449 | 562, 450 | 214, 451 | 790, 452 | 977, 453 | 167, 454 | 11, 455 | 132, 456 | 947, 457 | 71, 458 | 828, 459 | 242, 460 | 1073, 461 | 421, 462 | 40, 463 | 772, 464 | 192 465 | ], 466 | [ 467 | 638, 468 | 388, 469 | 942, 470 | 197, 471 | 863, 472 | 736, 473 | 410, 474 | 257, 475 | 189, 476 | 1062, 477 | 540, 478 | 228, 479 | 319, 480 | 704, 481 | 1047, 482 | 371, 483 | 583, 484 | 317, 485 | 926, 486 | 811, 487 | 978, 488 | 885, 489 | 
650, 490 | 174, 491 | 627, 492 | 684, 493 | 370, 494 | 749, 495 | 823, 496 | 274, 497 | 461, 498 | 22, 499 | 580, 500 | 830, 501 | 737, 502 | 865, 503 | 767, 504 | 941, 505 | 601, 506 | 599, 507 | 66, 508 | 188, 509 | 138, 510 | 1055, 511 | 834, 512 | 671, 513 | 300, 514 | 331, 515 | 116, 516 | 206, 517 | 209, 518 | 428, 519 | 586, 520 | 333, 521 | 679, 522 | 514, 523 | 992, 524 | 427, 525 | 688, 526 | 1017, 527 | 1039, 528 | 26, 529 | 877, 530 | 868, 531 | 464, 532 | 89, 533 | 122, 534 | 932, 535 | 229, 536 | 442, 537 | 666, 538 | 626, 539 | 1018, 540 | 148, 541 | 550, 542 | 1019, 543 | 97, 544 | 719, 545 | 29, 546 | 360, 547 | 344, 548 | 792, 549 | 858, 550 | 221, 551 | 48, 552 | 305, 553 | 234, 554 | 951, 555 | 364, 556 | 216, 557 | 810, 558 | 345, 559 | 339, 560 | 1010, 561 | 982, 562 | 469, 563 | 336, 564 | 1002, 565 | 262, 566 | 1037, 567 | 839, 568 | 728, 569 | 766, 570 | 862, 571 | 1064, 572 | 897, 573 | 509, 574 | 955, 575 | 194, 576 | 1046, 577 | 589, 578 | 817, 579 | 422, 580 | 1043, 581 | 153, 582 | 330, 583 | 350, 584 | 1005, 585 | 1000, 586 | 289, 587 | 307, 588 | 618, 589 | 716, 590 | 4, 591 | 848, 592 | 164, 593 | 615, 594 | 415, 595 | 512, 596 | 712, 597 | 75, 598 | 8, 599 | 1042, 600 | 403, 601 | 191, 602 | 488, 603 | 988, 604 | 593, 605 | 881, 606 | 411, 607 | 598, 608 | 689, 609 | 788, 610 | 843, 611 | 637, 612 | 991, 613 | 735, 614 | 921, 615 | 781, 616 | 729, 617 | 770, 618 | 486, 619 | 429, 620 | 513, 621 | 939, 622 | 821, 623 | 151, 624 | 199, 625 | 875, 626 | 495, 627 | 447, 628 | 33, 629 | 558, 630 | 135, 631 | 999, 632 | 182, 633 | 69, 634 | 37, 635 | 976 636 | ], 637 | [ 638 | 373, 639 | 746, 640 | 377, 641 | 269, 642 | 366, 643 | 572, 644 | 1, 645 | 628, 646 | 923, 647 | 343, 648 | 238, 649 | 906, 650 | 774, 651 | 369, 652 | 690, 653 | 673, 654 | 845, 655 | 341, 656 | 944, 657 | 306, 658 | 590, 659 | 334, 660 | 682, 661 | 708, 662 | 1044, 663 | 28, 664 | 859, 665 | 1030, 666 | 957, 667 | 165, 668 | 804, 669 | 914, 670 | 699, 671 | 909, 
672 | 711, 673 | 381, 674 | 160, 675 | 154, 676 | 648, 677 | 506, 678 | 363, 679 | 969, 680 | 606, 681 | 619, 682 | 222, 683 | 1051, 684 | 275, 685 | 964, 686 | 538, 687 | 782, 688 | 111, 689 | 309, 690 | 321, 691 | 687, 692 | 81, 693 | 570, 694 | 518, 695 | 753, 696 | 915, 697 | 383, 698 | 467, 699 | 910, 700 | 824, 701 | 611, 702 | 322, 703 | 180, 704 | 23, 705 | 889, 706 | 351, 707 | 63, 708 | 919, 709 | 252, 710 | 424, 711 | 984, 712 | 296, 713 | 552, 714 | 750, 715 | 401, 716 | 723, 717 | 832, 718 | 226, 719 | 141, 720 | 940, 721 | 1033, 722 | 235, 723 | 265, 724 | 62, 725 | 920, 726 | 905, 727 | 917, 728 | 493, 729 | 566, 730 | 632, 731 | 511, 732 | 1024, 733 | 751, 734 | 883, 735 | 607, 736 | 479, 737 | 539, 738 | 651, 739 | 702, 740 | 870, 741 | 492, 742 | 448, 743 | 574, 744 | 1052, 745 | 412, 746 | 172, 747 | 484, 748 | 149, 749 | 385, 750 | 643, 751 | 152, 752 | 86, 753 | 181, 754 | 755, 755 | 974, 756 | 117, 757 | 1069, 758 | 347, 759 | 557, 760 | 857, 761 | 617, 762 | 545, 763 | 578, 764 | 157, 765 | 903, 766 | 19, 767 | 592, 768 | 394, 769 | 860, 770 | 953, 771 | 1067, 772 | 329, 773 | 544, 774 | 522, 775 | 239, 776 | 314, 777 | 161, 778 | 854, 779 | 517, 780 | 891, 781 | 696, 782 | 139, 783 | 1038, 784 | 718, 785 | 449, 786 | 88, 787 | 649, 788 | 733, 789 | 831, 790 | 610, 791 | 815, 792 | 68, 793 | 456, 794 | 215, 795 | 826, 796 | 768, 797 | 873, 798 | 92, 799 | 14, 800 | 900, 801 | 175, 802 | 496, 803 | 962, 804 | 324, 805 | 895, 806 | 500, 807 | 758, 808 | 426, 809 | 299, 810 | 204, 811 | 970, 812 | 930, 813 | 420, 814 | 5, 815 | 12, 816 | 528, 817 | 569, 818 | 207, 819 | 171, 820 | 681, 821 | 1023, 822 | 972, 823 | 551, 824 | 952, 825 | 231, 826 | 1001, 827 | 793, 828 | 776, 829 | 218, 830 | 825, 831 | 96, 832 | 1034, 833 | 323, 834 | 549, 835 | 954, 836 | 169, 837 | 842, 838 | 16, 839 | 267, 840 | 452, 841 | 989, 842 | 655, 843 | 683, 844 | 896, 845 | 791, 846 | 685, 847 | 665, 848 | 342, 849 | 563, 850 | 213, 851 | 720, 852 | 389, 853 | 596, 
854 | 507, 855 | 413, 856 | 616, 857 | 248, 858 | 612, 859 | 25, 860 | 573, 861 | 79, 862 | 320, 863 | 983, 864 | 646, 865 | 41, 866 | 585, 867 | 50, 868 | 725, 869 | 489, 870 | 179, 871 | 535, 872 | 1066, 873 | 975, 874 | 986, 875 | 748, 876 | 279, 877 | 106, 878 | 76, 879 | 879, 880 | 691, 881 | 985, 882 | 827, 883 | 1026, 884 | 407, 885 | 948, 886 | 450, 887 | 911, 888 | 800, 889 | 605, 890 | 39, 891 | 375, 892 | 432, 893 | 912, 894 | 270, 895 | 0, 896 | 693, 897 | 534, 898 | 327, 899 | 455, 900 | 142, 901 | 166, 902 | 220, 903 | 1076, 904 | 211, 905 | 441, 906 | 608, 907 | 303, 908 | 290, 909 | 567, 910 | 837, 911 | 7, 912 | 328, 913 | 219, 914 | 380, 915 | 198, 916 | 1011, 917 | 809, 918 | 668, 919 | 483, 920 | 387, 921 | 311, 922 | 295, 923 | 808, 924 | 468, 925 | 802, 926 | 806, 927 | 212, 928 | 249, 929 | 73, 930 | 731, 931 | 105, 932 | 1022, 933 | 246, 934 | 510, 935 | 253, 936 | 185, 937 | 416, 938 | 981, 939 | 287, 940 | 1025, 941 | 124, 942 | 353, 943 | 993, 944 | 1027, 945 | 435, 946 | 773, 947 | 851, 948 | 261, 949 | 833, 950 | 70, 951 | 503, 952 | 559, 953 | 473, 954 | 732, 955 | 472, 956 | 548, 957 | 480, 958 | 176, 959 | 65, 960 | 256, 961 | 487, 962 | 136, 963 | 928, 964 | 980, 965 | 595, 966 | 1075, 967 | 1063, 968 | 193, 969 | 783, 970 | 1061, 971 | 177, 972 | 335, 973 | 352, 974 | 829, 975 | 313, 976 | 378, 977 | 233, 978 | 771, 979 | 372, 980 | 641, 981 | 892, 982 | 537, 983 | 134, 984 | 664, 985 | 760, 986 | 183, 987 | 131, 988 | 1031, 989 | 382, 990 | 994, 991 | 462, 992 | 453, 993 | 721, 994 | 431, 995 | 379, 996 | 516, 997 | 1013, 998 | 266, 999 | 1014, 1000 | 18, 1001 | 173, 1002 | 927, 1003 | 657, 1004 | 609, 1005 | 805, 1006 | 499, 1007 | 2, 1008 | 525, 1009 | 1071, 1010 | 916, 1011 | 836, 1012 | 884, 1013 | 361, 1014 | 907, 1015 | 979, 1016 | 797, 1017 | 971, 1018 | 150, 1019 | 888, 1020 | 338, 1021 | 636, 1022 | 325, 1023 | 123, 1024 | 49, 1025 | 531, 1026 | 700, 1027 | 1041, 1028 | 899, 1029 | 43, 1030 | 476, 1031 | 579, 1032 | 547, 
1033 | 294, 1034 | 600, 1035 | 886, 1036 | 400, 1037 | 217, 1038 | 887, 1039 | 405, 1040 | 482, 1041 | 880, 1042 | 697, 1043 | 819, 1044 | 156, 1045 | 396, 1046 | 814, 1047 | 72, 1048 | 555, 1049 | 1006, 1050 | 285, 1051 | 367, 1052 | 93, 1053 | 652, 1054 | 273, 1055 | 144, 1056 | 901, 1057 | 44, 1058 | 1012, 1059 | 491 1060 | ], 1061 | [ 1062 | 756, 1063 | 505, 1064 | 581, 1065 | 133, 1066 | 533, 1067 | 470, 1068 | 108, 1069 | 406, 1070 | 876, 1071 | 698, 1072 | 402, 1073 | 840, 1074 | 477, 1075 | 38, 1076 | 278, 1077 | 629, 1078 | 521, 1079 | 288 1080 | ], 1081 | [ 1082 | 835, 1083 | 799, 1084 | 1009, 1085 | 745, 1086 | 264, 1087 | 973, 1088 | 255, 1089 | 634, 1090 | 734, 1091 | 663, 1092 | 621, 1093 | 990, 1094 | 542, 1095 | 553, 1096 | 602, 1097 | 250, 1098 | 764, 1099 | 6, 1100 | 225, 1101 | 633 1102 | ] 1103 | ], 1104 | "client_classes": [ 1105 | [ 1106 | 4 1107 | ], 1108 | [ 1109 | 2 1110 | ], 1111 | [ 1112 | 3 1113 | ], 1114 | [ 1115 | 0 1116 | ], 1117 | [ 1118 | 2 1119 | ], 1120 | [ 1121 | 1 1122 | ], 1123 | [ 1124 | 3 1125 | ], 1126 | [ 1127 | 0 1128 | ], 1129 | [ 1130 | 4 1131 | ], 1132 | [ 1133 | 4 1134 | ] 1135 | ] 1136 | } -------------------------------------------------------------------------------- /dataset/split_file/RetinaMNIST224_client_num=10_alpha=0.05.json: -------------------------------------------------------------------------------- 1 | { 2 | "client_idx": [ 3 | [ 4 | 764, 5 | 477, 6 | 621, 7 | 243, 8 | 278, 9 | 542, 10 | 973, 11 | 861, 12 | 777, 13 | 756, 14 | 3, 15 | 1009, 16 | 145, 17 | 120, 18 | 629, 19 | 264, 20 | 38, 21 | 398, 22 | 127, 23 | 709, 24 | 15, 25 | 1036, 26 | 890, 27 | 392, 28 | 505, 29 | 470, 30 | 602, 31 | 745, 32 | 288, 33 | 703, 34 | 10, 35 | 847, 36 | 876, 37 | 741, 38 | 357, 39 | 698, 40 | 57, 41 | 190 42 | ], 43 | [ 44 | 735, 45 | 437, 46 | 301, 47 | 173, 48 | 837, 49 | 347, 50 | 448, 51 | 955, 52 | 338, 53 | 615, 54 | 688, 55 | 123, 56 | 1005, 57 | 732, 58 | 234, 59 | 97, 60 | 593, 61 | 330, 62 | 689, 63 | 431, 
64 | 1017, 65 | 29, 66 | 4, 67 | 1043, 68 | 191, 69 | 650, 70 | 728, 71 | 599, 72 | 188, 73 | 618, 74 | 306, 75 | 394, 76 | 1042, 77 | 991, 78 | 579, 79 | 719, 80 | 410, 81 | 1000, 82 | 344, 83 | 1055, 84 | 309, 85 | 289, 86 | 427, 87 | 993, 88 | 317, 89 | 1064, 90 | 92, 91 | 583, 92 | 215, 93 | 464, 94 | 69, 95 | 939, 96 | 83, 97 | 388, 98 | 839, 99 | 875, 100 | 189, 101 | 228, 102 | 742, 103 | 954, 104 | 371, 105 | 346, 106 | 679, 107 | 858, 108 | 951, 109 | 601, 110 | 793, 111 | 804, 112 | 626, 113 | 1047, 114 | 885, 115 | 865, 116 | 66, 117 | 781, 118 | 336, 119 | 1039, 120 | 319, 121 | 229, 122 | 648, 123 | 26, 124 | 447, 125 | 311, 126 | 429, 127 | 922, 128 | 588, 129 | 684, 130 | 151, 131 | 580, 132 | 862, 133 | 258, 134 | 227, 135 | 863, 136 | 881, 137 | 197, 138 | 854, 139 | 442, 140 | 193, 141 | 792, 142 | 868, 143 | 897, 144 | 221, 145 | 253, 146 | 486, 147 | 823, 148 | 999, 149 | 831, 150 | 153, 151 | 46, 152 | 729, 153 | 75, 154 | 403, 155 | 209, 156 | 199, 157 | 540, 158 | 836, 159 | 350, 160 | 671, 161 | 7, 162 | 449, 163 | 422, 164 | 921, 165 | 495, 166 | 37, 167 | 737, 168 | 360, 169 | 788, 170 | 843, 171 | 934, 172 | 28, 173 | 316, 174 | 659, 175 | 194, 176 | 1032, 177 | 830, 178 | 90, 179 | 1018, 180 | 415, 181 | 736, 182 | 514, 183 | 54, 184 | 406, 185 | 333, 186 | 834, 187 | 89, 188 | 165, 189 | 262, 190 | 1079, 191 | 8, 192 | 488, 193 | 574, 194 | 1024, 195 | 1026, 196 | 489, 197 | 869, 198 | 327, 199 | 56, 200 | 627, 201 | 370, 202 | 657, 203 | 666, 204 | 181, 205 | 164, 206 | 926, 207 | 682, 208 | 817, 209 | 941, 210 | 1063, 211 | 976, 212 | 767, 213 | 138, 214 | 848, 215 | 122, 216 | 550, 217 | 596, 218 | 461, 219 | 913, 220 | 749, 221 | 450, 222 | 62, 223 | 988, 224 | 331, 225 | 345, 226 | 932, 227 | 632, 228 | 766, 229 | 504, 230 | 86, 231 | 378, 232 | 340, 233 | 216, 234 | 566, 235 | 1049, 236 | 135, 237 | 978, 238 | 712, 239 | 358, 240 | 364, 241 | 600, 242 | 1038, 243 | 322, 244 | 638, 245 | 117, 246 | 155, 247 | 809, 248 | 810, 249 | 
701 250 | ], 251 | [ 252 | 133, 253 | 835, 254 | 291, 255 | 721, 256 | 961, 257 | 108, 258 | 785, 259 | 799, 260 | 250, 261 | 402, 262 | 844, 263 | 840, 264 | 963, 265 | 186, 266 | 296, 267 | 498, 268 | 380, 269 | 977, 270 | 225, 271 | 818, 272 | 6, 273 | 1004, 274 | 786, 275 | 521, 276 | 53, 277 | 990, 278 | 663, 279 | 768, 280 | 581, 281 | 683, 282 | 533, 283 | 720, 284 | 170, 285 | 354, 286 | 974, 287 | 594, 288 | 633, 289 | 631, 290 | 255, 291 | 634, 292 | 695 293 | ], 294 | [ 295 | 383, 296 | 351, 297 | 270, 298 | 517, 299 | 642, 300 | 369, 301 | 423, 302 | 696, 303 | 612, 304 | 226, 305 | 687, 306 | 303, 307 | 1071, 308 | 14, 309 | 655, 310 | 177, 311 | 500, 312 | 111, 313 | 872, 314 | 980, 315 | 400, 316 | 549, 317 | 387, 318 | 914, 319 | 420, 320 | 1052, 321 | 916, 322 | 308, 323 | 744, 324 | 1, 325 | 903, 326 | 944, 327 | 5, 328 | 59, 329 | 851, 330 | 235, 331 | 870, 332 | 610, 333 | 261, 334 | 239, 335 | 473, 336 | 295, 337 | 972, 338 | 131, 339 | 1027, 340 | 506, 341 | 538, 342 | 219, 343 | 931, 344 | 152, 345 | 664, 346 | 748, 347 | 212, 348 | 791, 349 | 510, 350 | 389, 351 | 285, 352 | 43, 353 | 880, 354 | 909, 355 | 557, 356 | 609, 357 | 1061, 358 | 342, 359 | 487, 360 | 64, 361 | 157, 362 | 213, 363 | 70, 364 | 516, 365 | 681, 366 | 329, 367 | 552, 368 | 149, 369 | 962, 370 | 528, 371 | 827, 372 | 325, 373 | 567, 374 | 412, 375 | 757, 376 | 702, 377 | 482, 378 | 857, 379 | 294, 380 | 455, 381 | 341, 382 | 940, 383 | 279, 384 | 280, 385 | 873, 386 | 511, 387 | 541, 388 | 381, 389 | 171, 390 | 240, 391 | 435, 392 | 19, 393 | 25, 394 | 622, 395 | 537, 396 | 334, 397 | 725, 398 | 829, 399 | 805, 400 | 453, 401 | 114, 402 | 953, 403 | 611, 404 | 237, 405 | 211, 406 | 592, 407 | 838, 408 | 323, 409 | 507, 410 | 2, 411 | 539, 412 | 986, 413 | 753, 414 | 979, 415 | 78, 416 | 619, 417 | 751, 418 | 63, 419 | 204, 420 | 466, 421 | 1030, 422 | 459, 423 | 535, 424 | 172, 425 | 139, 426 | 207, 427 | 265, 428 | 366, 429 | 96, 430 | 407, 431 | 314, 432 | 39, 433 | 
372, 434 | 983, 435 | 652, 436 | 617, 437 | 917, 438 | 1069, 439 | 668, 440 | 313, 441 | 147, 442 | 31, 443 | 249, 444 | 891, 445 | 23, 446 | 760, 447 | 1013, 448 | 361, 449 | 673, 450 | 452, 451 | 630, 452 | 564, 453 | 77, 454 | 432, 455 | 832, 456 | 997, 457 | 887, 458 | 665, 459 | 222, 460 | 927, 461 | 981, 462 | 491, 463 | 985, 464 | 900, 465 | 776, 466 | 658, 467 | 957, 468 | 545, 469 | 467, 470 | 930, 471 | 103, 472 | 750, 473 | 44, 474 | 1011, 475 | 734, 476 | 1023, 477 | 915, 478 | 693, 479 | 802, 480 | 651, 481 | 867, 482 | 824, 483 | 989, 484 | 559, 485 | 1025, 486 | 231, 487 | 807, 488 | 269, 489 | 61, 490 | 625, 491 | 150, 492 | 826, 493 | 969, 494 | 905, 495 | 771, 496 | 925, 497 | 320 498 | ], 499 | [ 500 | 960, 501 | 236, 502 | 167, 503 | 201, 504 | 796, 505 | 451, 506 | 292, 507 | 34, 508 | 168, 509 | 761, 510 | 1015, 511 | 1045, 512 | 251, 513 | 775, 514 | 1003, 515 | 787, 516 | 271, 517 | 242, 518 | 894, 519 | 421, 520 | 710, 521 | 444, 522 | 935, 523 | 714, 524 | 772, 525 | 192, 526 | 943, 527 | 119, 528 | 730, 529 | 109, 530 | 562, 531 | 132, 532 | 223, 533 | 418, 534 | 752, 535 | 51, 536 | 680, 537 | 780, 538 | 523, 539 | 417, 540 | 603, 541 | 727, 542 | 71, 543 | 284, 544 | 933, 545 | 620, 546 | 355, 547 | 214, 548 | 107, 549 | 438, 550 | 404, 551 | 458, 552 | 1058, 553 | 524, 554 | 692, 555 | 178, 556 | 789, 557 | 1073, 558 | 1054, 559 | 661, 560 | 849, 561 | 686, 562 | 956, 563 | 553, 564 | 40, 565 | 332, 566 | 644, 567 | 795, 568 | 30, 569 | 536, 570 | 374, 571 | 95, 572 | 779, 573 | 471, 574 | 765, 575 | 967, 576 | 137, 577 | 958, 578 | 260, 579 | 546, 580 | 713, 581 | 24, 582 | 852, 583 | 395, 584 | 790, 585 | 798, 586 | 208, 587 | 565, 588 | 828, 589 | 200, 590 | 409, 591 | 454, 592 | 778, 593 | 987, 594 | 959, 595 | 1059, 596 | 102, 597 | 813, 598 | 966, 599 | 11, 600 | 715, 601 | 349, 602 | 947, 603 | 747, 604 | 232, 605 | 904, 606 | 1020, 607 | 846, 608 | 17, 609 | 855, 610 | 864, 611 | 439, 612 | 624 613 | ], 614 | [ 615 | 196, 616 | 
1056, 617 | 717, 618 | 995, 619 | 463, 620 | 478, 621 | 866, 622 | 263, 623 | 60, 624 | 527, 625 | 55, 626 | 393, 627 | 27, 628 | 662, 629 | 743, 630 | 571, 631 | 694, 632 | 440, 633 | 582, 634 | 1070, 635 | 326, 636 | 91, 637 | 604, 638 | 672, 639 | 705, 640 | 113, 641 | 298, 642 | 654, 643 | 639, 644 | 244, 645 | 669, 646 | 759, 647 | 1008, 648 | 1053, 649 | 707, 650 | 938, 651 | 245, 652 | 532, 653 | 112, 654 | 419, 655 | 501, 656 | 98, 657 | 304, 658 | 856, 659 | 195, 660 | 898, 661 | 276, 662 | 202, 663 | 365, 664 | 726, 665 | 965, 666 | 445, 667 | 187, 668 | 128, 669 | 556, 670 | 902, 671 | 575, 672 | 801 673 | ], 674 | [ 675 | 84, 676 | 716, 677 | 1046, 678 | 125, 679 | 411, 680 | 339, 681 | 946, 682 | 918, 683 | 428, 684 | 144, 685 | 526, 686 | 456, 687 | 1062, 688 | 816, 689 | 877, 690 | 307, 691 | 746, 692 | 48, 693 | 950, 694 | 305, 695 | 982, 696 | 992, 697 | 413, 698 | 257, 699 | 206, 700 | 724, 701 | 910, 702 | 586, 703 | 518, 704 | 800, 705 | 396, 706 | 148, 707 | 184, 708 | 513, 709 | 399, 710 | 606, 711 | 20, 712 | 352, 713 | 483, 714 | 635, 715 | 274, 716 | 1019, 717 | 1022, 718 | 1002, 719 | 589, 720 | 87, 721 | 994, 722 | 670, 723 | 33, 724 | 704, 725 | 558, 726 | 116, 727 | 179, 728 | 203, 729 | 174, 730 | 770, 731 | 711, 732 | 154, 733 | 68, 734 | 182, 735 | 300, 736 | 821, 737 | 530, 738 | 472, 739 | 886, 740 | 273, 741 | 901, 742 | 794, 743 | 879, 744 | 1010, 745 | 22, 746 | 105, 747 | 1067, 748 | 774, 749 | 469, 750 | 628, 751 | 811, 752 | 1076, 753 | 942, 754 | 156, 755 | 637, 756 | 217, 757 | 493, 758 | 362, 759 | 889, 760 | 1037, 761 | 718, 762 | 512, 763 | 18, 764 | 72, 765 | 88, 766 | 653 767 | ], 768 | [ 769 | 812, 770 | 224, 771 | 126, 772 | 241, 773 | 348, 774 | 302, 775 | 129, 776 | 42, 777 | 508, 778 | 769, 779 | 376, 780 | 36, 781 | 591, 782 | 509, 783 | 968, 784 | 1029, 785 | 67, 786 | 99, 787 | 45, 788 | 882, 789 | 465, 790 | 32, 791 | 822, 792 | 554, 793 | 318, 794 | 118, 795 | 210, 796 | 1040, 797 | 763, 798 | 475, 799 | 784, 
800 | 850, 801 | 163, 802 | 356, 803 | 660, 804 | 1057, 805 | 115, 806 | 443, 807 | 878, 808 | 597, 809 | 359, 810 | 1078, 811 | 1048, 812 | 871, 813 | 485, 814 | 614, 815 | 310, 816 | 460, 817 | 425, 818 | 13, 819 | 853, 820 | 494, 821 | 1072, 822 | 408, 823 | 436, 824 | 268, 825 | 841, 826 | 293, 827 | 577, 828 | 576, 829 | 497, 830 | 1077, 831 | 277, 832 | 230, 833 | 391, 834 | 520, 835 | 82, 836 | 996, 837 | 297, 838 | 613, 839 | 1035, 840 | 908, 841 | 158, 842 | 52, 843 | 85, 844 | 820, 845 | 502 846 | ], 847 | [ 848 | 373, 849 | 134, 850 | 1016, 851 | 782, 852 | 607, 853 | 363, 854 | 842, 855 | 825, 856 | 1041, 857 | 971, 858 | 321, 859 | 531, 860 | 515, 861 | 646, 862 | 1051, 863 | 382, 864 | 183, 865 | 643, 866 | 476, 867 | 259, 868 | 248, 869 | 568, 870 | 815, 871 | 136, 872 | 595, 873 | 16, 874 | 608, 875 | 895, 876 | 1066, 877 | 569, 878 | 479, 879 | 920, 880 | 1075, 881 | 65, 882 | 238, 883 | 12, 884 | 335, 885 | 911, 886 | 1034, 887 | 731, 888 | 685, 889 | 81, 890 | 50, 891 | 275, 892 | 121, 893 | 563, 894 | 700, 895 | 252, 896 | 952, 897 | 697, 898 | 0, 899 | 397, 900 | 970, 901 | 1001, 902 | 551, 903 | 773, 904 | 462, 905 | 416, 906 | 570, 907 | 1031, 908 | 884, 909 | 548, 910 | 964, 911 | 585, 912 | 457, 913 | 79, 914 | 572, 915 | 547, 916 | 343, 917 | 578, 918 | 49, 919 | 1006, 920 | 808, 921 | 324, 922 | 35, 923 | 896, 924 | 246, 925 | 446, 926 | 928, 927 | 783, 928 | 401, 929 | 892, 930 | 1033, 931 | 912, 932 | 975, 933 | 484, 934 | 267, 935 | 73, 936 | 496, 937 | 256, 938 | 160, 939 | 185, 940 | 677, 941 | 124, 942 | 860, 943 | 424, 944 | 433, 945 | 522, 946 | 41, 947 | 176, 948 | 906, 949 | 218, 950 | 441, 951 | 733, 952 | 590, 953 | 480, 954 | 723, 955 | 797, 956 | 375, 957 | 555, 958 | 141, 959 | 492, 960 | 198, 961 | 1028, 962 | 353, 963 | 266, 964 | 699, 965 | 93, 966 | 690, 967 | 819, 968 | 175, 969 | 640, 970 | 106, 971 | 233, 972 | 426, 973 | 755, 974 | 1014, 975 | 636, 976 | 76, 977 | 616, 978 | 169, 979 | 923, 980 | 379, 981 | 1044, 
982 | 573, 983 | 641, 984 | 534, 985 | 142, 986 | 385, 987 | 405, 988 | 907, 989 | 762, 990 | 287, 991 | 290, 992 | 468, 993 | 859, 994 | 328, 995 | 315, 996 | 833, 997 | 161, 998 | 919, 999 | 247, 1000 | 691, 1001 | 806, 1002 | 220, 1003 | 649, 1004 | 299, 1005 | 948, 1006 | 605, 1007 | 845, 1008 | 499, 1009 | 503, 1010 | 525, 1011 | 888, 1012 | 708, 1013 | 758, 1014 | 390, 1015 | 544, 1016 | 899, 1017 | 1012, 1018 | 377, 1019 | 984, 1020 | 166, 1021 | 367, 1022 | 803, 1023 | 883, 1024 | 180, 1025 | 814 1026 | ], 1027 | [ 1028 | 645, 1029 | 254, 1030 | 490, 1031 | 1065, 1032 | 587, 1033 | 474, 1034 | 282, 1035 | 94, 1036 | 676, 1037 | 998, 1038 | 706, 1039 | 598, 1040 | 937, 1041 | 312, 1042 | 110, 1043 | 945, 1044 | 893, 1045 | 519, 1046 | 205, 1047 | 74, 1048 | 146, 1049 | 949, 1050 | 159, 1051 | 283, 1052 | 47, 1053 | 678, 1054 | 674, 1055 | 667, 1056 | 100, 1057 | 58, 1058 | 754, 1059 | 368, 1060 | 561, 1061 | 80, 1062 | 272, 1063 | 21, 1064 | 140, 1065 | 481, 1066 | 584, 1067 | 143, 1068 | 1050, 1069 | 384, 1070 | 529, 1071 | 874, 1072 | 130, 1073 | 647, 1074 | 386, 1075 | 337, 1076 | 929, 1077 | 722, 1078 | 543, 1079 | 738, 1080 | 430, 1081 | 1021, 1082 | 162, 1083 | 281, 1084 | 1060, 1085 | 286, 1086 | 936, 1087 | 675, 1088 | 1068, 1089 | 623, 1090 | 740, 1091 | 9, 1092 | 414, 1093 | 560, 1094 | 1074, 1095 | 101, 1096 | 739, 1097 | 104, 1098 | 924, 1099 | 1007, 1100 | 656, 1101 | 434 1102 | ] 1103 | ], 1104 | "client_classes": [ 1105 | [ 1106 | 4 1107 | ], 1108 | [ 1109 | 0, 1110 | 3 1111 | ], 1112 | [ 1113 | 4 1114 | ], 1115 | [ 1116 | 0 1117 | ], 1118 | [ 1119 | 1 1120 | ], 1121 | [ 1122 | 2 1123 | ], 1124 | [ 1125 | 0, 1126 | 3 1127 | ], 1128 | [ 1129 | 2 1130 | ], 1131 | [ 1132 | 0 1133 | ], 1134 | [ 1135 | 2 1136 | ] 1137 | ] 1138 | } --------------------------------------------------------------------------------