├── requirements.txt ├── LICENSE ├── tools ├── preprocess_diabetic_dataset.py ├── preprocess_rsna_dataset.py └── preprocess_isic_diabetic_dataset.py ├── models ├── __init__.py ├── resnet_cifar.py ├── wresnet.py └── vgg.py ├── dataset_isic2019.py ├── README.md ├── utils.py ├── train_classifier.py └── main_fedisca.py /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | medmnist==2.1.0 4 | tqdm 5 | pandas 6 | scikit-learn 7 | pydicom -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Myeongkyun Kang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /tools/preprocess_diabetic_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from shutil import copy2 3 | 4 | 5 | def read_label(csv_path): # Parse a CSV whose first line is a header and whose first two columns are the image id and an integer label; returns [(f'{image_id}.jpg', label), ...] in file order. 6 | data_list = [] 7 | with open(csv_path) as f: 8 | f.readline() # remove header 9 | while True: 10 | line = f.readline() 11 | if len(line) == 0: # readline() returns '' only at EOF. NOTE(review): a blank line is not skipped and would raise IndexError at line_split[1]. 12 | break 13 | line_split = line.strip().split(',') 14 | data_list.append((f'{line_split[0]}.jpg', int(line_split[1]))) 15 | return data_list 16 | 17 | 18 | if __name__ == '__main__': # Copy every listed image into <result_dir>/{train,test}/<label>/<filename>, one subdirectory per class label. 19 | result_dir = './dataset/diabetic2015' 20 | target_size = (224, 224) # NOTE(review): unused in this script; images are read from the already-resized '*_224' folders. 21 | 22 | data_dir = './dataset/Resized 2015 & 2019 Blindness Detection Images' 23 | train_data_dir = os.path.join(data_dir, 'resized train 15_224') 24 | test_data_dir = os.path.join(data_dir, 'resized test 15_224') 25 | 26 | label_dir = os.path.join(data_dir, 'labels') 27 | train_label_path = os.path.join(label_dir, 'trainLabels15.csv') 28 | test_label_path = os.path.join(label_dir, 'testLabels15.csv') 29 | 30 | result_train_dir = os.path.join(result_dir, 'train') 31 | result_test_dir = os.path.join(result_dir, 'test') 32 | 33 | train_list = read_label(train_label_path) 34 | test_list = read_label(test_label_path) 35 | 36 | for filename, label in train_list: 37 | img_path = os.path.join(train_data_dir, filename) 38 | save_path = os.path.join(result_train_dir, str(label), filename) 39 | os.makedirs(os.path.join(result_train_dir, str(label)), exist_ok=True) 40 | 41 | if os.path.isfile(img_path): # images listed in the CSV but missing on disk are reported and skipped 42 | copy2(img_path, save_path) 43 | else: 44 | print('skip:', img_path) 45 | 46 | for filename, label in test_list: 47 | img_path = os.path.join(test_data_dir, filename) 48 | save_path = os.path.join(result_test_dir, str(label), filename) 49 | os.makedirs(os.path.join(result_test_dir, str(label)), exist_ok=True) 50 | 51 | if os.path.isfile(img_path): 52 | copy2(img_path, save_path) 53 | else: 54
| print('skip:', img_path) 55 | -------------------------------------------------------------------------------- /tools/preprocess_rsna_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import pydicom as dcm 5 | from PIL import Image 6 | from tqdm import tqdm 7 | 8 | 9 | def read_label(csv_path): 10 | data_list = [] 11 | with open(csv_path) as f: 12 | f.readline() # remove header 13 | while True: 14 | line = f.readline() 15 | if len(line) == 0: 16 | break 17 | line_split = line.strip().split(',') 18 | filename = f'{line_split[0]}.dcm' 19 | label = 0 if line_split[1] == 'Normal' else 1 20 | data_list.append((filename, label)) 21 | return data_list 22 | 23 | 24 | def save(label_list, dcm_dir, save_dir): 25 | for filename, label in tqdm(label_list): 26 | dcm_path = os.path.join(dcm_dir, filename) 27 | 28 | # read dcm 29 | ds = dcm.dcmread(dcm_path) 30 | img = Image.fromarray(ds.pixel_array, 'L').convert('RGB') 31 | img = img.resize(target_size, resample=Image.BILINEAR) 32 | 33 | img.save(os.path.join(save_dir, str(label), filename.replace('.dcm', '.png'))) 34 | 35 | 36 | if __name__ == '__main__': 37 | target_size = (224, 224) 38 | 39 | result_dir = './dataset/rsna' 40 | result_train_dir = os.path.join(result_dir, 'train') 41 | result_test_dir = os.path.join(result_dir, 'test') 42 | os.makedirs(os.path.join(result_train_dir, '0')) 43 | os.makedirs(os.path.join(result_train_dir, '1')) 44 | os.makedirs(os.path.join(result_test_dir, '0')) 45 | os.makedirs(os.path.join(result_test_dir, '1')) 46 | 47 | data_dir = './dataset/RSNA_Pneumonia' 48 | dcm_dir = os.path.join(data_dir, 'stage_2_train_images') 49 | csv_path = os.path.join(data_dir, 'stage_2_detailed_class_info.csv') 50 | 51 | data_label_list = read_label(csv_path) 52 | random.Random(1).shuffle(data_label_list) 53 | 54 | test_num = int(len(data_label_list) * 0.1) 55 | train_label_list = data_label_list[test_num:] 56 | test_label_list = 
data_label_list[:test_num] 57 | 58 | save(train_label_list, dcm_dir, result_train_dir) 59 | save(test_label_list, dcm_dir, result_test_dir) 60 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from torch.hub import load_state_dict_from_url 2 | from torchvision import models 3 | 4 | from .resnet_cifar import ResNet18, ResNet34 5 | from .vgg import vgg8_bn, vgg11_bn 6 | from .wresnet import wrn_16_2 7 | 8 | 9 | def get_model_heter(index, num_classes, in_channels=3): 10 | if index == 0: 11 | return ResNet18(in_channels=in_channels, num_classes=num_classes) 12 | elif index == 1: 13 | return ResNet34(in_channels=in_channels, num_classes=num_classes) 14 | elif index == 2: 15 | return wrn_16_2(in_channels=in_channels, num_classes=num_classes) 16 | elif index == 3: 17 | return vgg8_bn(in_channels=in_channels, num_classes=num_classes) 18 | elif index == 4: 19 | return vgg11_bn(in_channels=in_channels, num_classes=num_classes) 20 | 21 | print('WARNING Invalid Model Index:', index) 22 | return ResNet18(in_channels=in_channels, num_classes=num_classes) 23 | 24 | 25 | def get_model_heter_224(index, num_classes, in_channels=3): 26 | if index in [0, 1, 2]: 27 | net = models.resnet18(num_classes=num_classes) 28 | state_dict = load_state_dict_from_url('https://download.pytorch.org/models/resnet18-f37072fd.pth', progress=True) 29 | del state_dict['fc.weight'] 30 | del state_dict['fc.bias'] 31 | net.load_state_dict(state_dict, strict=False) 32 | return net 33 | elif index == 3: 34 | net = models.resnet34(num_classes=num_classes) 35 | state_dict = load_state_dict_from_url('https://download.pytorch.org/models/resnet34-b627a593.pth', progress=True) 36 | del state_dict['fc.weight'] 37 | del state_dict['fc.bias'] 38 | net.load_state_dict(state_dict, strict=False) 39 | return net 40 | elif index == 4: 41 | net = models.vgg11_bn(num_classes=num_classes) 
42 | state_dict = load_state_dict_from_url('https://download.pytorch.org/models/vgg11_bn-6002323d.pth', progress=True) 43 | del state_dict['classifier.6.weight'] 44 | del state_dict['classifier.6.bias'] 45 | net.load_state_dict(state_dict, strict=False) 46 | return net 47 | 48 | print('WARNING Invalid Model Index:', index) # any other index falls back to the ImageNet-initialised ResNet18 built below 49 | net = models.resnet18(num_classes=num_classes) 50 | state_dict = load_state_dict_from_url('https://download.pytorch.org/models/resnet18-f37072fd.pth', progress=True) 51 | del state_dict['fc.weight'] 52 | del state_dict['fc.bias'] 53 | net.load_state_dict(state_dict, strict=False) 54 | return net 55 | -------------------------------------------------------------------------------- /dataset_isic2019.py: -------------------------------------------------------------------------------- 1 | # https://github.com/owkin/FLamby/blob/main/flamby/datasets/fed_isic2019/README.md 2 | # resize_images.py with center crop 3 | 4 | # https://github.com/owkin/FLamby/blob/main/flamby/datasets/fed_isic2019/dataset.py 5 | 6 | import os 7 | import random 8 | 9 | import pandas as pd 10 | import torch 11 | from PIL import Image 12 | from torch.utils.data import Dataset 13 | 14 | 15 | class FedIsic2019(torch.utils.data.Dataset): # ISIC 2019 image dataset whose rows are selected from the 'train_test_split' CSV under data_path; center=-1 pools all centers. 16 | 17 | def __init__(self, center=0, split='test', transform=None, data_path=None, val_rate=0.1): # center: 0-5, or -1 for all centers; split: 'train' / 'val' / 'test'; val_rate: fraction of one center's train rows held out as 'val' (not applied when center == -1). 18 | assert center in list(range(6)) + [-1] # NOTE(review): asserts are stripped under python -O 19 | assert split in ["train", "test", "val"] 20 | 21 | self.center = center 22 | self.split = split 23 | self.data_path = data_path 24 | self.transform = transform 25 | 26 | self.train_test_split_path = os.path.join(self.data_path, 'train_test_split') # data_path is effectively required: os.path.join(None, ...) raises TypeError 27 | self.image_dir = os.path.join(self.data_path, 'ISIC_2019_Training_Input_preprocessed') 28 | 29 | # Read train_test_split 30 | df = pd.read_csv(self.train_test_split_path) # presumably has the image, target, fold and fold2 columns used below - verify against the FLamby split file 31 | 32 | if self.center == -1: # pooled: rows are selected on the 'fold' column 33 | if self.split in ['train', 'val']: 34 | key = 'train' 35 | df2 = df.query("fold == '" + key + "' ").reset_index(drop=True) 36 | else: 37 | key
= 'test' 38 | df2 = df.query("fold == '" + key + "' ").reset_index(drop=True) 39 | else: 40 | if self.split == 'train': 41 | key = f'train_{self.center}' 42 | df2 = df.query("fold2 == '" + key + "' ").reset_index(drop=True) 43 | elif self.split == 'val': 44 | key = f'train_{self.center}' 45 | df2 = df.query("fold2 == '" + key + "' ").reset_index(drop=True) 46 | else: 47 | key = 'test' 48 | df2 = df.query("fold == '" + key + "' ").reset_index(drop=True) 49 | 50 | images, targets = df2.image.tolist(), df2.target.tolist() # always same order 51 | samples = [(os.path.join(self.image_dir, image_name + ".jpg"), target) for image_name, target in zip(images, targets)] 52 | 53 | # shuffle with fixed seed 54 | random.Random(1).shuffle(samples) 55 | if self.center == -1: 56 | self.samples = samples # val handles with user_groups 57 | else: 58 | if self.split == 'train': 59 | self.samples = samples[int(len(samples) * val_rate):] 60 | elif self.split == 'val': 61 | self.samples = samples[:int(len(samples) * val_rate)] 62 | elif self.split == 'test': 63 | self.samples = samples 64 | else: 65 | raise ValueError('') 66 | 67 | self.targets = [s[1] for s in self.samples] 68 | 69 | def __len__(self): 70 | return len(self.samples) 71 | 72 | def __getitem__(self, idx): 73 | image_path, target = self.samples[idx] 74 | image = Image.open(image_path).convert('RGB') 75 | image = self.transform(image) 76 | return image, target 77 | -------------------------------------------------------------------------------- /tools/preprocess_isic_diabetic_dataset.py: -------------------------------------------------------------------------------- 1 | # https://github.com/owkin/FLamby/blob/main/flamby/datasets/fed_isic2019/dataset_creation_scripts/resize_images.py 2 | 3 | from __future__ import division 4 | 5 | import os 6 | 7 | import numpy 8 | import numpy as np 9 | from PIL import Image 10 | from joblib import Parallel, delayed 11 | from tqdm import tqdm 12 | 13 | 14 | def color_constancy(img, 
power=6, gamma=None): 15 | import cv2 16 | 17 | """ 18 | Preprocessing step to make sure that the images appear with similar brightness 19 | and contrast. 20 | See this [link}(https://en.wikipedia.org/wiki/Color_constancy) for an explanation. 21 | Thank you to [Aman Arora](https://github.com/amaarora) for this 22 | [implementation](https://github.com/amaarora/melonama) 23 | Parameters 24 | ---------- 25 | img: 3D numpy array, the original image 26 | power: int, degree of norm 27 | gamma: float, value of gamma correction 28 | """ 29 | img_dtype = img.dtype 30 | 31 | if gamma is not None: 32 | img = img.astype("uint8") 33 | look_up_table = numpy.ones((256, 1), dtype="uint8") * 0 34 | for i in range(256): 35 | look_up_table[i][0] = 255 * pow(i / 255, 1 / gamma) 36 | img = cv2.LUT(img, look_up_table) 37 | 38 | img = img.astype("float32") 39 | img_power = numpy.power(img, power) 40 | rgb_vec = numpy.power(numpy.mean(img_power, (0, 1)), 1 / power) 41 | rgb_norm = numpy.sqrt(numpy.sum(numpy.power(rgb_vec, 2.0))) 42 | rgb_vec = rgb_vec / rgb_norm 43 | rgb_vec = 1 / (rgb_vec * numpy.sqrt(3)) 44 | img = numpy.multiply(img, rgb_vec) 45 | 46 | return img.astype(img_dtype) 47 | 48 | 49 | def resize_and_maintain(path, in_path, output_path, sz: tuple, cc): 50 | """Preprocessing of images 51 | Mantains aspect ratio fo input image. Possibility to add color constancy. 
52 | Thank you to [Aman Arora](https://github.com/amaarora) for this 53 | [implementation](https://github.com/amaarora/melonama) 54 | Parameters 55 | ---------- 56 | path : path to input image 57 | output_path : path to output image 58 | sz : tuple, shorter edge of resized image is sz[0] 59 | cc : color constancy is added if True 60 | """ 61 | try: 62 | fn = os.path.basename(path) 63 | img = Image.open(path) 64 | size = sz[0] 65 | old_size = img.size 66 | ratio = float(size) / min(old_size) 67 | new_size = tuple([int(x * ratio) for x in old_size]) 68 | img = img.resize(new_size, resample=Image.BILINEAR) 69 | if cc: 70 | img = color_constancy(np.array(img)) 71 | img = Image.fromarray(img) 72 | save_path = path.replace(in_path, output_path) 73 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 74 | img.save(save_path) 75 | except: 76 | print('Error:', path) 77 | 78 | 79 | def process(input_folder, output_folder, sz, cc): 80 | images = [] 81 | for root, dirs, files in os.walk(input_folder): 82 | for name in files: 83 | if name.endswith('.jpg'): 84 | images.append(os.path.join(root, name)) 85 | 86 | os.makedirs(output_folder, exist_ok=True) 87 | Parallel(n_jobs=48)( 88 | delayed(resize_and_maintain)(i, input_folder, output_folder, (sz, sz), cc) 89 | for i in tqdm(images) 90 | ) 91 | 92 | 93 | if __name__ == "__main__": 94 | input_folder = './dataset/fed_isic2019/ISIC_2019_Training_Input' 95 | output_folder = './dataset/fed_isic2019/ISIC_2019_Training_Input_preprocessed' 96 | cc = True # only for isic2019 97 | sz = 224 98 | process(input_folder, output_folder, sz, cc) 99 | 100 | input_folder = './dataset/Resized 2015 & 2019 Blindness Detection Images/resized train 15' 101 | output_folder = './dataset/Resized 2015 & 2019 Blindness Detection Images/resized train 15_224' 102 | cc = False 103 | sz = 224 104 | process(input_folder, output_folder, sz, cc) 105 | 106 | input_folder = './dataset/Resized 2015 & 2019 Blindness Detection Images/resized test 15' 107 | 
output_folder = './dataset/Resized 2015 & 2019 Blindness Detection Images/resized test 15_224' 108 | cc = False 109 | sz = 224 110 | process(input_folder, output_folder, sz, cc) 111 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FedISCA 2 | MICCAI2023. "One-shot Federated Learning on Medical Data using Knowledge Distillation with Image Synthesis and Client Model Adaptation" 3 | 4 | # Train Classifiers 5 | 6 | # bloodmnist, dermamnist, octmnist, pathmnist, tissuemnist 7 | CUDA_VISIBLE_DEVICES=0 python train_classifier.py \ 8 | --root ./dataset \ 9 | --output_dir ./pretrained_models \ 10 | --dataset bloodmnist --partition iid 11 | 12 | # rsna 13 | CUDA_VISIBLE_DEVICES=0 python train_classifier.py \ 14 | --root ./dataset \ 15 | --output_dir ./pretrained_models \ 16 | --dataset rsna --aug --pretrained --partition iid 17 | 18 | # diabetic2015 19 | CUDA_VISIBLE_DEVICES=0 python train_classifier.py \ 20 | --root ./dataset \ 21 | --output_dir ./pretrained_models \ 22 | --dataset diabetic2015 --aug --pretrained --partition iid 23 | 24 | # isic2019_merge 25 | CUDA_VISIBLE_DEVICES=0 python train_classifier.py \ 26 | --root ./dataset \ 27 | --output_dir ./pretrained_models \ 28 | --dataset isic2019_merge --aug --pretrained --partition iid 29 | 30 | # isic2019 31 | CUDA_VISIBLE_DEVICES=0 python train_classifier.py \ 32 | --root ./dataset \ 33 | --output_dir ./pretrained_models \ 34 | --dataset isic2019 --aug --pretrained 35 | 36 | 37 | # Run FedISCA 38 | 39 | ``` 40 | python 41 | import os 42 | GPU='0' 43 | dataset_tag_list = ['iid_5_0.6'] 44 | dataset_list = ['bloodmnist'] # 'bloodmnist', 'dermamnist', 'octmnist', 'pathmnist', 'tissuemnist' 45 | for dataset_tag in dataset_tag_list: 46 | for dataset in dataset_list: 47 | os.system(f"CUDA_VISIBLE_DEVICES={GPU} python main_fedisca.py \ 48 | --dataset {dataset} \ 49 | --root ./dataset \ 50 | 
--teacher_weights=./pretrained_models/{dataset}_{dataset_tag} \ 51 | --exp_descr=./results/oneshot_{dataset_tag}/{dataset}") 52 | 53 | python 54 | import os 55 | GPU='0' 56 | dataset_list = ['rsna'] # 'rsna', 'diabetic2015' 57 | dataset_tag_list = ['iid_5_0.6'] 58 | for dataset_tag in dataset_tag_list: 59 | for dataset in dataset_list: 60 | os.system(f"CUDA_VISIBLE_DEVICES={GPU} python main_fedisca.py \ 61 | --dataset {dataset} \ 62 | --root ./dataset \ 63 | --teacher_weights=./pretrained_models/{dataset}_{dataset_tag} --pretrained \ 64 | --exp_descr=./results/oneshot_{dataset_tag}/{dataset}_pretrained --bs 50 --iters_mi 1000") 65 | 66 | python 67 | import os 68 | GPU='0' 69 | dataset_list = ['isic2019'] 70 | dataset_tag_list = ['merge_iid_5_0.6'] # 'merge_iid_5_0.6', 'dirichlet_6_0.6' 71 | for dataset_tag in dataset_tag_list: 72 | for dataset in dataset_list: 73 | os.system(f"CUDA_VISIBLE_DEVICES={GPU} python main_fedisca.py \ 74 | --dataset {dataset} \ 75 | --root ./dataset \ 76 | --teacher_weights=./pretrained_models/{dataset}_{dataset_tag} --pretrained \ 77 | --exp_descr=./results/oneshot_{dataset_tag}/{dataset}_pretrained --bs 50 --iters_mi 1000") 78 | ``` 79 | 80 | 81 | # Datasets 82 | 83 | # small-scale datasets 84 | https://medmnist.com/ 85 | # Download the bloodmnist, dermamnist, octmnist, pathmnist, and tissuemnist npz files and move them to the ./datasets/medmnist directory 86 | 87 | # large-scale datasets 88 | https://www.kaggle.com/c/rsna-pneumonia-detection-challenge 89 | https://www.kaggle.com/datasets/benjaminwarner/resized-2015-2019-blindness-detection-images 90 | https://challenge.isic-archive.com/landing/2019/ 91 | https://github.com/owkin/FLamby/blob/main/flamby/datasets/fed_isic2019/README.md 92 | # See the scripts in tools/ for preprocessing 93 | 94 | 95 | # Environments 96 | ```bash 97 | pip install -U pip 98 | pip install -r requirements.txt 99 | ``` 100 | 101 | # Citation 102 | If you find this repository useful in your research, please cite: 103 | ``` 104 | 
@inproceedings{kang2023one, 105 | title={One-Shot Federated Learning on Medical Data Using Knowledge Distillation with Image Synthesis and Client Model Adaptation}, 106 | author={Kang, Myeongkyun and Chikontwe, Philip and Kim, Soopil and Jin, Kyong Hwan and Adeli, Ehsan and Pohl, Kilian M and Park, Sang Hyun}, 107 | booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention}, 108 | pages={521--531}, 109 | year={2023}, 110 | organization={Springer} 111 | } 112 | ``` 113 | 114 | Thanks to works below for their implementations which were useful for this work. 115 | [DeepInversion](https://github.com/NVlabs/DeepInversion) 116 | -------------------------------------------------------------------------------- /models/resnet_cifar.py: -------------------------------------------------------------------------------- 1 | # 2019.07.24-Changed output of forward function 2 | # Huawei Technologies Co., Ltd. 3 | # taken from https://github.com/huawei-noah/Data-Efficient-Model-Compression/blob/master/DAFL/resnet.py 4 | # for comparison with DAFL 5 | 6 | 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class BasicBlock(nn.Module): # Two 3x3 conv + BN layers plus a shortcut (identity, or 1x1 conv + BN when stride or width changes); ReLU after the sum. 12 | expansion = 1 # output channels = expansion * planes 13 | 14 | def __init__(self, in_planes, planes, stride=1): 15 | super(BasicBlock, self).__init__() 16 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 17 | self.bn1 = nn.BatchNorm2d(planes) 18 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 19 | self.bn2 = nn.BatchNorm2d(planes) 20 | 21 | self.shortcut = nn.Sequential() # empty Sequential acts as identity 22 | if stride != 1 or in_planes != self.expansion * planes: # projection so the shortcut matches the main path's shape 23 | self.shortcut = nn.Sequential(nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), nn.BatchNorm2d(self.expansion * planes)) 24 | 25 | def forward(self, x): 26 | out = F.relu(self.bn1(self.conv1(x))) 27 | out = self.bn2(self.conv2(out)) 28 | out += self.shortcut(x) 29 |
out = F.relu(out) 30 | return out 31 | 32 | 33 | class Bottleneck(nn.Module): # 1x1 reduce -> 3x3 (carries the stride) -> 1x1 expand to expansion * planes channels, plus shortcut; ReLU after the sum. 34 | expansion = 4 35 | 36 | def __init__(self, in_planes, planes, stride=1): 37 | super(Bottleneck, self).__init__() 38 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 39 | self.bn1 = nn.BatchNorm2d(planes) 40 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 41 | self.bn2 = nn.BatchNorm2d(planes) 42 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) 43 | self.bn3 = nn.BatchNorm2d(self.expansion * planes) 44 | 45 | self.shortcut = nn.Sequential() 46 | if stride != 1 or in_planes != self.expansion * planes: 47 | self.shortcut = nn.Sequential(nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), nn.BatchNorm2d(self.expansion * planes)) 48 | 49 | def forward(self, x): 50 | out = F.relu(self.bn1(self.conv1(x))) 51 | out = F.relu(self.bn2(self.conv2(out))) 52 | out = self.bn3(self.conv3(out)) 53 | out += self.shortcut(x) 54 | out = F.relu(out) 55 | return out 56 | 57 | 58 | class ResNet(nn.Module): # 3x3 stride-1 stem (no max-pool), four stages of `block` (stages 2-4 halve the resolution), linear classifier. 59 | def __init__(self, block, num_blocks, num_classes=10, in_channels=3): 60 | super(ResNet, self).__init__() 61 | self.in_planes = 64 62 | 63 | self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1, bias=False) 64 | self.bn1 = nn.BatchNorm2d(64) 65 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 66 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 67 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 68 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 69 | self.linear = nn.Linear(512 * block.expansion, num_classes) 70 | 71 | def _make_layer(self, block, planes, num_blocks, stride): # Only the first block of a stage uses `stride`; self.in_planes tracks the running channel count across stages. 72 | strides = [stride] + [1] * (num_blocks - 1) 73 | layers = [] 74 | for stride in strides: 75 | layers.append(block(self.in_planes, planes, stride)) 76 | self.in_planes = planes *
block.expansion 77 | return nn.Sequential(*layers) 78 | 79 | def forward(self, x, out_feature=False): # Return logits, or (logits, flattened pooled features) when out_feature is set. 80 | x = self.conv1(x) 81 | 82 | x = self.bn1(x) 83 | out = F.relu(x) 84 | 85 | out = self.layer1(out) 86 | out = self.layer2(out) 87 | out = self.layer3(out) 88 | out = self.layer4(out) 89 | out = F.avg_pool2d(out, 4) # NOTE(review): fixed 4x4 window - assumes a 4x4 final map (e.g. 28x28 or 32x32 input after three stride-2 stages); larger inputs leave >1x1 maps and the flattened size no longer matches self.linear 90 | feature = out.view(out.size(0), -1) 91 | out = self.linear(feature) 92 | if out_feature == False: 93 | return out 94 | else: 95 | return out, feature 96 | 97 | 98 | def ResNet18(in_channels=3, num_classes=10): 99 | return ResNet(BasicBlock, [2, 2, 2, 2], num_classes, in_channels=in_channels) 100 | 101 | 102 | def ResNet34(in_channels=3, num_classes=10): 103 | return ResNet(BasicBlock, [3, 4, 6, 3], num_classes, in_channels=in_channels) 104 | 105 | 106 | def ResNet50(in_channels=3, num_classes=10): 107 | return ResNet(Bottleneck, [3, 4, 6, 3], num_classes, in_channels=in_channels) 108 | 109 | 110 | def ResNet101(in_channels=3, num_classes=10): 111 | return ResNet(Bottleneck, [3, 4, 23, 3], num_classes, in_channels=in_channels) 112 | 113 | 114 | def ResNet152(in_channels=3, num_classes=10): 115 | return ResNet(Bottleneck, [3, 8, 36, 3], num_classes, in_channels=in_channels) 116 | -------------------------------------------------------------------------------- /models/wresnet.py: -------------------------------------------------------------------------------- 1 | '''https://github.com/polo5/ZeroShotKnowledgeTransfer/blob/master/models/wresnet.py 2 | ''' 3 | 4 | import math 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | """ 11 | Original Author: Wei Yang 12 | """ 13 | 14 | 15 | class BasicBlock(nn.Module): # Pre-activation wide-ResNet block: BN-ReLU-conv3x3, BN-ReLU, dropout, conv3x3, plus shortcut. 16 | def __init__(self, in_planes, out_planes, stride, dropout_rate=0.0): 17 | super(BasicBlock, self).__init__() 18 | self.bn1 = nn.BatchNorm2d(in_planes) 19 | self.relu1 = nn.ReLU(inplace=True) 20 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 21 | padding=1, bias=False)
22 | self.bn2 = nn.BatchNorm2d(out_planes) 23 | self.relu2 = nn.ReLU(inplace=True) 24 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 25 | padding=1, bias=False) 26 | self.dropout = nn.Dropout(dropout_rate) 27 | self.equalInOut = (in_planes == out_planes) # a 1x1 conv shortcut is created only when the channel counts differ 28 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 29 | padding=0, bias=False) or None 30 | 31 | def forward(self, x): 32 | if not self.equalInOut: # x itself is replaced by its pre-activation, which then feeds both conv1 and the shortcut conv 33 | x = self.relu1(self.bn1(x)) 34 | else: 35 | out = self.relu1(self.bn1(x)) 36 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 37 | out = self.dropout(out) 38 | out = self.conv2(out) 39 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 40 | 41 | 42 | class NetworkBlock(nn.Module): # A sequence of nb_layers `block`s; only the first one changes channels and applies the stride. 43 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropout_rate=0.0): 44 | super(NetworkBlock, self).__init__() 45 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropout_rate) 46 | 47 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropout_rate): 48 | layers = [] 49 | for i in range(nb_layers): 50 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropout_rate)) # the and/or idiom relies on in_planes and stride being non-zero 51 | return nn.Sequential(*layers) 52 | 53 | def forward(self, x): 54 | return self.layer(x) 55 | 56 | 57 | class WideResNet(nn.Module): 58 | def __init__(self, depth, num_classes, widen_factor=1, dropout_rate=0.0, in_channels=3): 59 | super(WideResNet, self).__init__() 60 | nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor] 61 | assert (depth - 4) % 6 == 0, 'depth should be 6n+4' 62 | n = (depth - 4) // 6 # blocks per NetworkBlock 63 | block = BasicBlock 64 | # 1st conv before any network block 65 | self.conv1 = nn.Conv2d(in_channels, nChannels[0], kernel_size=3, stride=1, 66 | padding=1, bias=False) 67 | # 1st block 68 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1],
block, 1, dropout_rate) 69 | # 2nd block 70 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropout_rate) 71 | # 3rd block 72 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropout_rate) 73 | # global average pooling and classifier 74 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 75 | self.relu = nn.ReLU(inplace=True) 76 | self.fc = nn.Linear(nChannels[3], num_classes) 77 | self.nChannels = nChannels[3] 78 | 79 | for m in self.modules(): # conv: normal(0, sqrt(2 / (k*k*out_channels))); BN: weight 1, bias 0; Linear: bias 0, weight keeps PyTorch's default init 80 | if isinstance(m, nn.Conv2d): 81 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels # reuses the name n; the blocks above are already built 82 | m.weight.data.normal_(0, math.sqrt(2. / n)) 83 | elif isinstance(m, nn.BatchNorm2d): 84 | m.weight.data.fill_(1) 85 | m.bias.data.zero_() 86 | elif isinstance(m, nn.Linear): 87 | m.bias.data.zero_() 88 | 89 | def forward(self, x, return_features=False): # Return logits, or (logits, globally pooled features) when return_features is True. 90 | out = self.conv1(x) 91 | out = self.block1(out) 92 | out = self.block2(out) 93 | out = self.block3(out) 94 | out = self.relu(self.bn1(out)) 95 | out = F.adaptive_avg_pool2d(out, (1, 1)) 96 | features = out.view(-1, self.nChannels) 97 | out = self.fc(features) 98 | 99 | if return_features: 100 | return out, features 101 | else: 102 | return out 103 | 104 | 105 | def wrn_16_1(num_classes, in_channels=3, dropout_rate=0): 106 | return WideResNet(depth=16, num_classes=num_classes, widen_factor=1, dropout_rate=dropout_rate, in_channels=in_channels) 107 | 108 | 109 | def wrn_16_2(num_classes, in_channels=3, dropout_rate=0): 110 | return WideResNet(depth=16, num_classes=num_classes, widen_factor=2, dropout_rate=dropout_rate, in_channels=in_channels) 111 | 112 | 113 | def wrn_40_1(num_classes, in_channels=3, dropout_rate=0): 114 | return WideResNet(depth=40, num_classes=num_classes, widen_factor=1, dropout_rate=dropout_rate, in_channels=in_channels) 115 | 116 | 117 | def wrn_40_2(num_classes, in_channels=3, dropout_rate=0): 118 | return WideResNet(depth=40, num_classes=num_classes, widen_factor=2, dropout_rate=dropout_rate,
in_channels=in_channels) 119 | -------------------------------------------------------------------------------- /models/vgg.py: -------------------------------------------------------------------------------- 1 | """https://github.com/HobbitLong/RepDistiller/blob/master/models/vgg.py 2 | """ 3 | import math 4 | 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class VGG(nn.Module): # Five conv blocks separated by 2x2 max-pools, adaptive average pool to 1x1, single Linear classifier. 10 | 11 | def __init__(self, cfg, batch_norm=False, num_classes=1000, in_channels=3): # cfg: five lists of conv widths, one per block; each block's input width is the previous block's last entry. 12 | super(VGG, self).__init__() 13 | self.block0 = self._make_layers(cfg[0], batch_norm, in_channels) 14 | self.block1 = self._make_layers(cfg[1], batch_norm, cfg[0][-1]) 15 | self.block2 = self._make_layers(cfg[2], batch_norm, cfg[1][-1]) 16 | self.block3 = self._make_layers(cfg[3], batch_norm, cfg[2][-1]) 17 | self.block4 = self._make_layers(cfg[4], batch_norm, cfg[3][-1]) 18 | 19 | self.pool0 = nn.MaxPool2d(kernel_size=2, stride=2) 20 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) 21 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) 22 | self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2) 23 | self.pool4 = nn.AdaptiveAvgPool2d((1, 1)) 24 | # self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2) 25 | 26 | self.classifier = nn.Linear(512, num_classes) # requires cfg[4] to end with 512 channels 27 | self._initialize_weights() 28 | 29 | def get_feat_modules(self): 30 | feat_m = nn.ModuleList([]) 31 | feat_m.append(self.block0) 32 | feat_m.append(self.pool0) 33 | feat_m.append(self.block1) 34 | feat_m.append(self.pool1) 35 | feat_m.append(self.block2) 36 | feat_m.append(self.pool2) 37 | feat_m.append(self.block3) 38 | feat_m.append(self.pool3) 39 | feat_m.append(self.block4) 40 | feat_m.append(self.pool4) 41 | return feat_m 42 | 43 | def get_bn_before_relu(self): # Last module of blocks 1-4: the BatchNorm2d when built with batch_norm=True (the trailing ReLU is dropped in _make_layers), otherwise the last Conv2d. 44 | bn1 = self.block1[-1] 45 | bn2 = self.block2[-1] 46 | bn3 = self.block3[-1] 47 | bn4 = self.block4[-1] 48 | return [bn1, bn2, bn3, bn4] 49 | 50 | def forward(self, x, return_features=False): 51 | h = x.shape[2] 52 | x = F.relu(self.block0(x)) 53 | x = 
self.pool0(x) 54 | x = self.block1(x) 55 | x = F.relu(x) 56 | x = self.pool1(x) 57 | x = self.block2(x) 58 | x = F.relu(x) 59 | x = self.pool2(x) 60 | x = self.block3(x) 61 | x = F.relu(x) 62 | if h == 64: # pool3 is applied only when the input height is 64; pool4 is adaptive, so block4's output is averaged to 1x1 either way 63 | x = self.pool3(x) 64 | x = self.block4(x) 65 | x = F.relu(x) 66 | x = self.pool4(x) 67 | features = x.view(x.size(0), -1) 68 | x = self.classifier(features) 69 | if return_features: 70 | return x, features 71 | else: 72 | return x 73 | 74 | @staticmethod 75 | def _make_layers(cfg, batch_norm=False, in_channels=3): # Build conv3x3 [+ BN] + ReLU per entry ('M' inserts a 2x2 max-pool). 76 | layers = [] 77 | for v in cfg: 78 | if v == 'M': 79 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 80 | else: 81 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 82 | if batch_norm: 83 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 84 | else: 85 | layers += [conv2d, nn.ReLU(inplace=True)] 86 | in_channels = v 87 | layers = layers[:-1] # drop the trailing ReLU; forward() applies F.relu after each block instead 88 | return nn.Sequential(*layers) 89 | 90 | def _initialize_weights(self): 91 | for m in self.modules(): 92 | if isinstance(m, nn.Conv2d): 93 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 94 | m.weight.data.normal_(0, math.sqrt(2.
/ n)) 95 | if m.bias is not None: 96 | m.bias.data.zero_() 97 | elif isinstance(m, nn.BatchNorm2d): 98 | m.weight.data.fill_(1) 99 | m.bias.data.zero_() 100 | elif isinstance(m, nn.Linear): 101 | n = m.weight.size(1) 102 | m.weight.data.normal_(0, 0.01) 103 | m.bias.data.zero_() 104 | 105 | 106 | cfg = { 107 | 'A': [[64], [128], [256, 256], [512, 512], [512, 512]], 108 | 'B': [[64, 64], [128, 128], [256, 256], [512, 512], [512, 512]], 109 | 'D': [[64, 64], [128, 128], [256, 256, 256], [512, 512, 512], [512, 512, 512]], 110 | 'E': [[64, 64], [128, 128], [256, 256, 256, 256], [512, 512, 512, 512], [512, 512, 512, 512]], 111 | 'S': [[64], [128], [256], [512], [512]], 112 | } 113 | 114 | 115 | def vgg8(**kwargs): 116 | """VGG 8-layer model (configuration "S") 117 | Args: 118 | pretrained (bool): If True, returns a model pre-trained on ImageNet 119 | """ 120 | model = VGG(cfg['S'], **kwargs) 121 | return model 122 | 123 | 124 | def vgg8_bn(**kwargs): 125 | """VGG 8-layer model (configuration "S") 126 | Args: 127 | pretrained (bool): If True, returns a model pre-trained on ImageNet 128 | """ 129 | model = VGG(cfg['S'], batch_norm=True, **kwargs) 130 | return model 131 | 132 | 133 | def vgg11(**kwargs): 134 | """VGG 11-layer model (configuration "A") 135 | Args: 136 | pretrained (bool): If True, returns a model pre-trained on ImageNet 137 | """ 138 | model = VGG(cfg['A'], **kwargs) 139 | return model 140 | 141 | 142 | def vgg11_bn(**kwargs): 143 | """VGG 11-layer model (configuration "A") with batch normalization""" 144 | model = VGG(cfg['A'], batch_norm=True, **kwargs) 145 | return model 146 | 147 | 148 | def vgg13(**kwargs): 149 | """VGG 13-layer model (configuration "B") 150 | Args: 151 | pretrained (bool): If True, returns a model pre-trained on ImageNet 152 | """ 153 | model = VGG(cfg['B'], **kwargs) 154 | return model 155 | 156 | 157 | def vgg13_bn(**kwargs): 158 | """VGG 13-layer model (configuration "B") with batch normalization""" 159 | model = VGG(cfg['B'], 
batch_norm=True, **kwargs) 160 | return model 161 | 162 | 163 | def vgg16(**kwargs): 164 | """VGG 16-layer model (configuration "D") 165 | Args: 166 | pretrained (bool): If True, returns a model pre-trained on ImageNet 167 | """ 168 | model = VGG(cfg['D'], **kwargs) 169 | return model 170 | 171 | 172 | def vgg16_bn(**kwargs): 173 | """VGG 16-layer model (configuration "D") with batch normalization""" 174 | model = VGG(cfg['D'], batch_norm=True, **kwargs) 175 | return model 176 | 177 | 178 | def vgg19(**kwargs): 179 | """VGG 19-layer model (configuration "E") 180 | Args: 181 | pretrained (bool): If True, returns a model pre-trained on ImageNet 182 | """ 183 | model = VGG(cfg['E'], **kwargs) 184 | return model 185 | 186 | 187 | def vgg19_bn(**kwargs): 188 | """VGG 19-layer model (configuration 'E') with batch normalization""" 189 | model = VGG(cfg['E'], batch_norm=True, **kwargs) 190 | return model 191 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.utils.data 6 | from PIL import Image 7 | from sklearn import metrics 8 | 9 | 10 | class Ensemble(torch.nn.Module): 11 | def __init__(self, model_list): 12 | super(Ensemble, self).__init__() 13 | self.models = nn.ModuleList(model_list) 14 | 15 | def forward(self, x): 16 | logits_total = 0 17 | for i in range(len(self.models)): 18 | logits = self.models[i](x) 19 | logits_total += logits 20 | logits_e = logits_total / len(self.models) 21 | 22 | return logits_e 23 | 24 | 25 | class DeepInversionFeatureHook(): 26 | ''' 27 | Implementation of the forward hook to track feature statistics and compute a loss on them. 
28 | Will compute mean and variance, and will use l2 as a loss 29 | ''' 30 | 31 | def __init__(self, module): 32 | self.hook = module.register_forward_hook(self.hook_fn) 33 | 34 | def hook_fn(self, module, input, output): 35 | # hook co compute deepinversion's feature distribution regularization 36 | nch = input[0].shape[1] 37 | 38 | mean = input[0].mean([0, 2, 3]) 39 | var = input[0].permute(1, 0, 2, 3).contiguous().view([nch, -1]).var(1, unbiased=False) 40 | 41 | # forcing mean and variance to match between two distributions 42 | # other ways might work better, e.g. KL divergence 43 | r_feature = torch.norm(module.running_var.data.type(var.type()) - var, 2) + torch.norm(module.running_mean.data.type(var.type()) - mean, 2) 44 | 45 | self.r_feature = r_feature # must have no output 46 | 47 | def close(self): 48 | self.hook.remove() 49 | 50 | 51 | def test(net, testloader, criterion, device): 52 | net.eval() 53 | test_loss = 0 54 | correct = 0 55 | total = 0 56 | gt_list, pred_list = [], [] 57 | img_size = 28 58 | 59 | with torch.no_grad(): 60 | for batch_idx, (inputs, targets) in enumerate(testloader): 61 | inputs, targets = inputs.to(device), targets.flatten().long().cuda() if len(targets.shape) == 2 else targets.long().cuda() 62 | outputs = net(inputs) 63 | loss = criterion(outputs, targets) 64 | 65 | test_loss += loss.item() 66 | _, predicted = outputs.max(1) 67 | total += targets.size(0) 68 | correct += predicted.eq(targets).sum().item() 69 | gt_list.extend(targets.tolist()) 70 | pred_list.extend(predicted.tolist()) 71 | img_size = inputs.shape[-1] 72 | 73 | acc = correct / total 74 | b_acc = metrics.balanced_accuracy_score(gt_list, pred_list) 75 | print('Loss: %.3f | Acc: %.3f%% (%d/%d), B. Acc: %.3f%%' % (test_loss / (batch_idx + 1), 100. 
* acc, correct, total, b_acc)) 76 | 77 | # for isic2019, diabetic2015 78 | if img_size > 128: 79 | acc = b_acc 80 | 81 | return acc 82 | 83 | 84 | class KLDiv(nn.Module): 85 | def __init__(self, T=1.0, reduction='batchmean'): 86 | super().__init__() 87 | self.T = T 88 | self.reduction = reduction 89 | 90 | def forward(self, logits, targets): 91 | return kldiv(logits, targets, T=self.T, reduction=self.reduction) 92 | 93 | 94 | def kldiv(logits, targets, T=1.0, reduction='batchmean'): 95 | q = F.log_softmax(logits / T, dim=1) 96 | p = F.softmax(targets / T, dim=1) 97 | return F.kl_div(q, p, reduction=reduction) * (T * T) 98 | 99 | 100 | def adjust_learning_rate(optimizer, epoch, lr_init=0.1, lr_step1=80, lr_step2=120): 101 | """For resnet, the lr starts from 0.1, and is divided by 10 at 80 and 120 epochs""" 102 | if epoch < lr_step1: 103 | lr = lr_init 104 | elif epoch < lr_step2: 105 | lr = lr_init * 0.1 106 | else: 107 | lr = lr_init * 0.1 * 0.1 108 | for param_group in optimizer.param_groups: 109 | param_group['lr'] = lr 110 | 111 | 112 | def get_cls_num_list(traindata_cls_counts, num_label): 113 | cls_num_list = [] 114 | for key, val in traindata_cls_counts.items(): 115 | temp = [0] * num_label 116 | for key_1, val_1 in val.items(): 117 | temp[key_1] = val_1 118 | cls_num_list.append(temp) 119 | 120 | return cls_num_list 121 | 122 | 123 | def record_net_data_stats(y_train, net_dataidx_map): 124 | net_cls_counts = {} 125 | 126 | for net_i, dataidx in net_dataidx_map.items(): 127 | unq, unq_cnt = np.unique(y_train[dataidx], return_counts=True) 128 | tmp = {unq[i]: unq_cnt[i] for i in range(len(unq))} 129 | net_cls_counts[net_i] = tmp 130 | 131 | return net_cls_counts 132 | 133 | 134 | def partition_data(y_train, num_label, partition, beta=0.4, num_users=5, debug=True): 135 | data_size = y_train.shape[0] 136 | 137 | if partition == "iid": 138 | idxs = np.random.permutation(data_size) 139 | batch_idxs = np.array_split(idxs, num_users) 140 | net_dataidx_map = {i: 
batch_idxs[i] for i in range(num_users)} 141 | 142 | elif partition == "dirichlet": 143 | min_size = 0 144 | min_require_size = 10 145 | net_dataidx_map = {} 146 | 147 | while min_size < min_require_size: 148 | idx_batch = [[] for _ in range(num_users)] 149 | for k in range(num_label): 150 | idx_k = np.where(y_train == k)[0] 151 | np.random.shuffle(idx_k) # shuffle the label 152 | proportions = np.random.dirichlet(np.repeat(beta, num_users)) 153 | proportions = np.array([p * (len(idx_j) < data_size / num_users) for p, idx_j in zip(proportions, idx_batch)]) # 0 or x 154 | proportions = proportions / proportions.sum() 155 | proportions = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1] 156 | idx_batch = [idx_j + idx.tolist() for idx_j, idx in zip(idx_batch, np.split(idx_k, proportions))] 157 | min_size = min([len(idx_j) for idx_j in idx_batch]) 158 | 159 | for j in range(num_users): 160 | np.random.shuffle(idx_batch[j]) 161 | net_dataidx_map[j] = idx_batch[j] 162 | 163 | if debug: 164 | train_data_cls_counts = record_net_data_stats(y_train, net_dataidx_map) 165 | print('Data statistics: %s' % str(train_data_cls_counts)) 166 | 167 | train_cls_num_list = get_cls_num_list(train_data_cls_counts, num_label) 168 | print('Data number: %s' % str(train_cls_num_list)) 169 | 170 | return net_dataidx_map 171 | 172 | 173 | class DatasetSplit(torch.utils.data.Dataset): 174 | 175 | def __init__(self, dataset, idxs): 176 | self.dataset = dataset 177 | self.idxs = [int(i) for i in idxs] 178 | 179 | def __len__(self): 180 | return len(self.idxs) 181 | 182 | def __getitem__(self, item): 183 | image, label = self.dataset[self.idxs[item]] 184 | return image, label 185 | 186 | 187 | class ImageDataset(torch.utils.data.Dataset): 188 | def __init__(self, images, targets, transform=None): 189 | self.images = images 190 | self.targets = targets 191 | self.transform = transform 192 | 193 | def __getitem__(self, idx): 194 | img = self.images[idx] 195 | img = Image.fromarray(img) 196 | 
target = self.targets[idx] 197 | if self.transform: 198 | img = self.transform(img) 199 | return img, target 200 | 201 | def __len__(self): 202 | return len(self.images) 203 | -------------------------------------------------------------------------------- /train_classifier.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | 5 | import medmnist 6 | import numpy as np 7 | import torch 8 | import torch.backends.cudnn as cudnn 9 | import torch.nn.parallel 10 | import torch.optim 11 | import torch.utils.data 12 | import torchvision.transforms as transforms 13 | from sklearn import metrics 14 | from sklearn.utils import class_weight 15 | from torch.autograd import Variable 16 | from torch.utils.data import DataLoader 17 | from torchvision.datasets import ImageFolder 18 | 19 | from dataset_isic2019 import FedIsic2019 20 | from models import get_model_heter, get_model_heter_224 21 | from models.resnet_cifar import ResNet18 22 | from utils import DatasetSplit, ImageDataset, adjust_learning_rate, partition_data 23 | 24 | 25 | def main(args): 26 | # set seed 27 | if args.seed is not None: 28 | random.seed(args.seed) 29 | np.random.seed(args.seed) 30 | torch.manual_seed(args.seed) 31 | torch.cuda.manual_seed(args.seed) 32 | torch.cuda.manual_seed_all(args.seed) 33 | cudnn.deterministic = True 34 | cudnn.benchmark = False 35 | 36 | if args.dataset in medmnist.INFO: 37 | info = medmnist.INFO[args.dataset] 38 | 39 | n_channels = info['n_channels'] 40 | n_classes = len(info['label']) 41 | epochs = 100 42 | lr_step1 = 50 43 | lr_step2 = 75 44 | lr_init = 0.001 45 | 46 | DataClass = getattr(medmnist, info['python_class']) 47 | 48 | # check valid dataset 49 | if 'multi-class' != info['task']: 50 | raise ValueError("Invalid Task") 51 | 52 | # preprocessing 53 | aug_list = [] 54 | if args.aug: 55 | aug_list.append(transforms.RandomCrop(28, padding=4)) 56 | 
aug_list.append(transforms.RandomHorizontalFlip()) 57 | preprocess_list = [transforms.ToTensor(), transforms.Normalize(mean=[.5], std=[.5])] 58 | 59 | transform_train = transforms.Compose(aug_list + preprocess_list) 60 | transform_test = transforms.Compose(preprocess_list) 61 | 62 | # load the data 63 | med_train_data = DataClass(split='train', root=os.path.join(args.root, 'medmnist')) 64 | med_val_data = DataClass(split='val', root=os.path.join(args.root, 'medmnist')) 65 | 66 | data_images_merge = np.concatenate([med_train_data.imgs, med_val_data.imgs]) 67 | data_targets_merge = np.concatenate([med_train_data.labels, med_val_data.labels]) 68 | 69 | data_train = ImageDataset(data_images_merge, data_targets_merge, transform=transform_train) 70 | data_val = ImageDataset(data_images_merge, data_targets_merge, transform=transform_test) 71 | 72 | y_train = np.array(data_train.targets) 73 | 74 | elif args.dataset == 'isic2019': 75 | aug_list = [] 76 | if args.aug: 77 | aug_list.append(transforms.RandomAffine(50, shear=0.1)) 78 | aug_list.append(transforms.RandomResizedCrop(224)) 79 | aug_list.append(transforms.RandomHorizontalFlip()) 80 | aug_list.append(transforms.ColorJitter(brightness=0.15, contrast=0.1)) 81 | preprocess_list = [transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[.5], std=[.5])] 82 | 83 | transform_train = transforms.Compose(aug_list + preprocess_list) 84 | transform_test = transforms.Compose(preprocess_list) 85 | 86 | n_channels = 3 87 | n_classes = 8 88 | epochs = 100 89 | lr_step1 = 50 90 | lr_step2 = 75 91 | lr_init = 0.001 92 | args.num_users = 6 93 | print('reset num_users to', args.num_users) 94 | 95 | elif args.dataset == 'isic2019_merge': 96 | aug_list = [] 97 | if args.aug: 98 | aug_list.append(transforms.RandomAffine(50, shear=0.1)) 99 | aug_list.append(transforms.RandomResizedCrop(224)) 100 | aug_list.append(transforms.RandomHorizontalFlip()) 101 | aug_list.append(transforms.ColorJitter(brightness=0.15, 
contrast=0.1)) 102 | preprocess_list = [transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[.5], std=[.5])] 103 | 104 | transform_train = transforms.Compose(aug_list + preprocess_list) 105 | transform_test = transforms.Compose(preprocess_list) 106 | 107 | # set dataset 108 | data_train = FedIsic2019(center=-1, split='train', transform=transform_train, data_path=os.path.join(args.root, 'fed_isic2019'), val_rate=args.val_rate) 109 | data_val = FedIsic2019(center=-1, split='val', transform=transform_test, data_path=os.path.join(args.root, 'fed_isic2019'), val_rate=args.val_rate) 110 | 111 | n_channels = 3 112 | n_classes = 8 113 | epochs = 100 114 | lr_step1 = 50 115 | lr_step2 = 75 116 | lr_init = 0.001 117 | 118 | y_train = np.array(data_train.targets) 119 | 120 | elif args.dataset == 'diabetic2015': 121 | aug_list = [] 122 | if args.aug: 123 | aug_list.append(transforms.RandomAffine(50)) 124 | aug_list.append(transforms.RandomResizedCrop(224)) 125 | aug_list.append(transforms.RandomHorizontalFlip()) 126 | aug_list.append(transforms.ColorJitter(brightness=0.15, contrast=0.1)) 127 | preprocess_list = [transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[.5], std=[.5])] 128 | 129 | transform_train = transforms.Compose(aug_list + preprocess_list) 130 | transform_test = transforms.Compose(preprocess_list) 131 | 132 | data_train = ImageFolder(os.path.join(args.root, 'diabetic2015', 'train'), transform=transform_train) # sorted 133 | data_val = ImageFolder(os.path.join(args.root, 'diabetic2015', 'train'), transform=transform_test) # sorted 134 | 135 | n_channels = 3 136 | n_classes = 5 137 | epochs = 100 138 | lr_step1 = 50 139 | lr_step2 = 75 140 | lr_init = 0.001 141 | 142 | y_train = np.array(data_train.targets) 143 | 144 | elif args.dataset == 'rsna': 145 | aug_list = [] 146 | if args.aug: 147 | aug_list.append(transforms.RandomResizedCrop(224)) 148 | aug_list.append(transforms.RandomHorizontalFlip()) 149 | 
aug_list.append(transforms.ColorJitter(brightness=0.15, contrast=0.1)) 150 | preprocess_list = [transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[.5], std=[.5])] 151 | 152 | transform_train = transforms.Compose(aug_list + preprocess_list) 153 | transform_test = transforms.Compose(preprocess_list) 154 | 155 | data_train = ImageFolder(os.path.join(args.root, 'rsna', 'train'), transform=transform_train) # sorted 156 | data_val = ImageFolder(os.path.join(args.root, 'rsna', 'train'), transform=transform_test) # sorted 157 | 158 | n_channels = 3 159 | n_classes = 2 160 | epochs = 100 161 | lr_step1 = 50 162 | lr_step2 = 75 163 | lr_init = 0.001 164 | 165 | y_train = np.array(data_train.targets) 166 | 167 | else: 168 | raise ValueError(f'Invalid Dataset: {args.dataset}') 169 | 170 | # ======================================== 171 | # make partition 172 | if args.dataset != 'isic2019': 173 | user_groups = partition_data(y_train, n_classes, partition=args.partition, beta=args.beta, num_users=args.num_users) 174 | 175 | train_user_groups, val_user_groups = [], [] 176 | for user_group in user_groups.values(): 177 | train_user_groups.append(user_group[int(len(user_group) * args.val_rate):]) 178 | val_user_groups.append(user_group[:int(len(user_group) * args.val_rate)]) 179 | # ======================================== 180 | 181 | for user_idx in range(args.num_users): 182 | output_dir = os.path.join(args.output_dir, f'{args.dataset}_{args.partition}_{args.num_users}_{args.beta}', f'client_{user_idx}') 183 | 184 | # make output dir 185 | os.makedirs(output_dir, exist_ok=True) 186 | if os.path.isfile(os.path.join(output_dir, 'val.csv')): 187 | os.remove(os.path.join(output_dir, 'val.csv')) 188 | 189 | # define model 190 | if args.dataset == 'isic2019': 191 | net = get_model_heter_224(user_idx, num_classes=n_classes).cuda() if args.model_heter else get_model_heter_224(0, num_classes=n_classes).cuda() 192 | 193 | # set dataset 194 | data_train_loader = 
DataLoader( 195 | FedIsic2019(center=user_idx, split='train', transform=transform_train, data_path=os.path.join(args.root, 'fed_isic2019'), val_rate=args.val_rate), 196 | batch_size=args.bs, shuffle=True, num_workers=8) 197 | data_val_loader = DataLoader( 198 | FedIsic2019(center=user_idx, split='val', transform=transform_test, data_path=os.path.join(args.root, 'fed_isic2019'), val_rate=args.val_rate), 199 | batch_size=args.bs, shuffle=True, num_workers=8) 200 | 201 | # calculate class_weights 202 | targets = data_train_loader.dataset.targets 203 | class_weights = class_weight.compute_class_weight(class_weight='balanced', classes=list(range(0, n_classes)), y=np.array(targets + list(range(0, n_classes)))) # apply smoothing 204 | class_weights = torch.tensor(class_weights, dtype=torch.float) 205 | 206 | elif args.dataset in ['diabetic2015', 'isic2019_merge', 'rsna']: 207 | net = get_model_heter_224(user_idx, num_classes=n_classes).cuda() if args.model_heter else get_model_heter_224(0, num_classes=n_classes).cuda() 208 | 209 | # set dataset 210 | data_train_loader = DataLoader(DatasetSplit(data_train, train_user_groups[user_idx]), batch_size=args.bs, shuffle=True, num_workers=8) 211 | data_val_loader = DataLoader(DatasetSplit(data_val, val_user_groups[user_idx]), batch_size=args.bs, shuffle=True, num_workers=8) 212 | 213 | # calculate class_weights 214 | targets = data_train.targets 215 | class_weights = class_weight.compute_class_weight(class_weight='balanced', classes=list(range(0, n_classes)), y=np.array(targets + list(range(0, n_classes)))) # apply smoothing 216 | class_weights = torch.tensor(class_weights, dtype=torch.float) 217 | 218 | else: 219 | net = get_model_heter(user_idx, in_channels=n_channels, num_classes=n_classes).cuda() if args.model_heter else ResNet18(in_channels=n_channels, num_classes=n_classes).cuda() 220 | 221 | # set dataset 222 | data_train_loader = DataLoader(DatasetSplit(data_train, train_user_groups[user_idx]), batch_size=args.bs, 
shuffle=True, num_workers=8) 223 | data_val_loader = DataLoader(DatasetSplit(data_val, val_user_groups[user_idx]), batch_size=args.bs, shuffle=True, num_workers=8) 224 | class_weights = None 225 | 226 | # define loss and optim 227 | criterion = torch.nn.CrossEntropyLoss(weight=class_weights).cuda() 228 | optimizer = torch.optim.SGD(net.parameters(), lr=lr_init, momentum=0.9, weight_decay=5e-4) 229 | 230 | acc_best = 0 231 | for e in range(1, epochs + 1): 232 | # ======================================== 233 | # Train 234 | adjust_learning_rate(optimizer, e, lr_init=lr_init, lr_step1=lr_step1, lr_step2=lr_step2) 235 | 236 | net.train() 237 | loss_list = [] 238 | for i, (images, labels) in enumerate(data_train_loader): 239 | images, labels = Variable(images).cuda(), Variable(labels.flatten().long() if len(labels.shape) == 2 else labels.long()).cuda() 240 | optimizer.zero_grad() 241 | output = net(images) 242 | loss = criterion(output, labels) 243 | loss.backward() 244 | optimizer.step() 245 | loss_list.append(loss.data.item()) 246 | if i == 1: 247 | print('Train - Epoch %d, Batch: %d, Loss: %f' % (e, i, loss.data.item())) 248 | # ======================================== 249 | 250 | # ======================================== 251 | # Val 252 | net.eval() 253 | total_correct, num_samples = 0, 0 254 | avg_loss = 0.0 255 | gt_list, pred_list = [], [] 256 | with torch.no_grad(): 257 | for i, (images, labels) in enumerate(data_val_loader): 258 | images, labels = Variable(images).cuda(), Variable(labels.flatten().long() if len(labels.shape) == 2 else labels.long()).cuda() 259 | output = net(images) 260 | avg_loss += criterion(output, labels).sum() 261 | pred = output.data.max(1)[1] 262 | total_correct += pred.eq(labels.data.view_as(pred)).sum() 263 | num_samples += images.shape[0] 264 | gt_list.extend(labels.tolist()) 265 | pred_list.extend(pred.tolist()) 266 | 267 | avg_loss /= num_samples 268 | acc = float(total_correct) / num_samples 269 | b_acc = 
metrics.balanced_accuracy_score(gt_list, pred_list) 270 | print('Val Avg. Loss: %f, Accuracy: %f, Balanced Accuracy: %f' % (avg_loss.data.item(), acc, b_acc)) 271 | 272 | # use balanced accuracy instead of accuracy 273 | if args.dataset in ['isic2019', 'diabetic2015']: 274 | acc = b_acc 275 | 276 | # write log 277 | with open(os.path.join(output_dir, 'val.csv'), 'at') as wf: 278 | wf.write('{},{:.4f}\n'.format(e, acc)) 279 | # ======================================== 280 | 281 | if acc_best < acc: 282 | acc_best = acc 283 | 284 | # save best 285 | torch.save(net, os.path.join(output_dir, 'best.pth')) 286 | 287 | # save last 288 | torch.save(net, os.path.join(output_dir, 'last.pth')) 289 | 290 | 291 | if __name__ == '__main__': 292 | parser = argparse.ArgumentParser(description='train-teacher-network') 293 | 294 | # Basic model parameters. 295 | parser.add_argument('--dataset', default='bloodmnist', type=str) 296 | parser.add_argument('--output_dir', default='./pretrained_models', type=str) 297 | parser.add_argument('--root', default='./dataset', type=str) 298 | parser.add_argument('--seed', default=1, type=int) 299 | parser.add_argument('--bs', default=128, type=int) 300 | parser.add_argument('--aug', action='store_true') 301 | parser.add_argument('--partition', default='dirichlet', type=str) 302 | parser.add_argument('--beta', default=0.6, type=float) 303 | parser.add_argument('--num_users', default=5, type=int) 304 | parser.add_argument('--val_rate', default=0.1, type=float) 305 | parser.add_argument('--model_heter', action='store_true') 306 | parser.add_argument('--pretrained', action='store_true') 307 | 308 | args = parser.parse_args() 309 | 310 | if args.model_heter: 311 | assert args.num_users == 5 312 | 313 | if args.dataset == 'diabetic2015': 314 | assert args.partition == 'iid' 315 | 316 | main(args) 317 | -------------------------------------------------------------------------------- /main_fedisca.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import collections 3 | import copy 4 | import gc 5 | import os 6 | import random 7 | 8 | import medmnist 9 | import numpy as np 10 | import torch 11 | import torch.backends.cudnn as cudnn 12 | import torch.nn as nn 13 | import torch.optim as optim 14 | import torchvision.transforms as transforms 15 | import torchvision.utils as vutils 16 | from torch.hub import load_state_dict_from_url 17 | from torch.utils.data import Dataset 18 | from torchvision.models.resnet import resnet18 19 | 20 | from models import get_model_heter, get_model_heter_224 21 | from models.resnet_cifar import ResNet18 22 | from utils import KLDiv, test, adjust_learning_rate, DeepInversionFeatureHook, Ensemble 23 | 24 | 25 | def main(args): 26 | if args.seed is not None: 27 | random.seed(args.seed) 28 | np.random.seed(args.seed) 29 | torch.manual_seed(args.seed) 30 | torch.cuda.manual_seed(args.seed) 31 | torch.cuda.manual_seed_all(args.seed) 32 | cudnn.deterministic = True 33 | cudnn.benchmark = False 34 | 35 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 36 | 37 | # make exp directory 38 | img_exp_descr = os.path.join(args.exp_descr, 'img') 39 | best_img_exp_descr = os.path.join(img_exp_descr, 'best') 40 | os.makedirs(best_img_exp_descr, exist_ok=True) 41 | if os.path.isfile(os.path.join(args.exp_descr, 'test.csv')): 42 | os.remove(os.path.join(args.exp_descr, 'test.csv')) 43 | 44 | if args.dataset in medmnist.INFO: 45 | info = medmnist.INFO[args.dataset] 46 | DataClass = getattr(medmnist, info['python_class']) 47 | 48 | # preprocessing 49 | transform_test = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[.5], std=[.5])]) 50 | 51 | # load the data 52 | data_test = DataClass(split='test', transform=transform_test, root=os.path.join(args.root, 'medmnist')) 53 | data_test_loader = torch.utils.data.DataLoader(dataset=data_test, batch_size=args.bs, shuffle=False, 
num_workers=8) 54 | 55 | input_size = 28 56 | n_channels = info['n_channels'] 57 | n_classes = len(info['label']) 58 | 59 | epochs = 100 60 | lr_step1 = 50 61 | lr_step2 = 75 62 | lr_init = 0.001 63 | 64 | elif args.dataset == 'isic2019': 65 | from dataset_isic2019 import FedIsic2019 66 | 67 | # preprocessing 68 | transform_test = transforms.Compose([transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[.5], std=[.5])]) 69 | 70 | data_test_loader = torch.utils.data.DataLoader( 71 | FedIsic2019(split='test', data_path=os.path.join(args.root, 'fed_isic2019'), transform=transform_test), 72 | batch_size=args.bs, shuffle=True, num_workers=8) 73 | 74 | input_size = 224 75 | n_channels = 3 76 | n_classes = 8 77 | 78 | epochs = 100 79 | lr_step1 = 50 80 | lr_step2 = 75 81 | lr_init = 0.001 82 | 83 | elif args.dataset == 'diabetic2015': 84 | from torchvision.datasets import ImageFolder 85 | 86 | # preprocessing 87 | transform_test = transforms.Compose([transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[.5], std=[.5])]) 88 | 89 | data_test_loader = torch.utils.data.DataLoader( 90 | ImageFolder(os.path.join(args.root, 'diabetic2015', 'test'), transform=transform_test), 91 | batch_size=args.bs, shuffle=True, num_workers=8) 92 | 93 | input_size = 224 94 | n_channels = 3 95 | n_classes = 5 96 | 97 | epochs = 100 98 | lr_step1 = 50 99 | lr_step2 = 75 100 | lr_init = 0.001 101 | 102 | elif args.dataset == 'rsna': 103 | from torchvision.datasets import ImageFolder 104 | 105 | # preprocessing 106 | transform_test = transforms.Compose([transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[.5], std=[.5])]) 107 | 108 | data_test_loader = torch.utils.data.DataLoader( 109 | ImageFolder(os.path.join(args.root, 'rsna', 'test'), transform=transform_test), 110 | batch_size=args.bs, shuffle=True, num_workers=8) 111 | 112 | input_size = 224 113 | n_channels = 3 114 | n_classes = 2 115 | 116 | epochs = 100 117 | lr_step1 
= 50 118 | lr_step2 = 75 119 | lr_init = 0.001 120 | 121 | else: 122 | raise ValueError(f'Invalid Dataset: {args.dataset}') 123 | 124 | if os.path.isfile(args.teacher_weights): 125 | # define networks 126 | net_teacher = resnet18(num_classes=n_classes) if args.dataset in ['isic2019', 'diabetic2015', 'rsna'] else ResNet18(in_channels=n_channels, num_classes=n_classes) 127 | net_teacher = net_teacher.to(device) 128 | 129 | # load checkpoint 130 | checkpoint = torch.load(args.teacher_weights) 131 | net_teacher.load_state_dict(checkpoint.state_dict()) 132 | net_teacher.eval() 133 | elif os.path.isdir(args.teacher_weights): 134 | model_list = [] 135 | for client_dir in sorted(os.listdir(args.teacher_weights)): 136 | weight_path = os.path.join(args.teacher_weights, client_dir, 'best.pth') 137 | 138 | if not os.path.isfile(weight_path): 139 | continue 140 | 141 | # define networks 142 | if '_heter' in args.teacher_weights: # TODO: 143 | print('load heterogeneous models') 144 | _get_model_heter = get_model_heter_224 if args.dataset in ['isic2019', 'diabetic2015', 'rsna'] else get_model_heter 145 | _net_teacher = _get_model_heter(int(client_dir.split('_')[-1]), in_channels=n_channels, num_classes=n_classes) 146 | else: 147 | _net_teacher = resnet18(num_classes=n_classes) if args.dataset in ['isic2019', 'diabetic2015', 'rsna'] else ResNet18(in_channels=n_channels, num_classes=n_classes) 148 | _net_teacher = _net_teacher.to(device) 149 | 150 | # load checkpoint 151 | checkpoint = torch.load(weight_path) 152 | _net_teacher.load_state_dict(checkpoint.state_dict()) 153 | _net_teacher.eval() 154 | 155 | model_list.append(_net_teacher) 156 | 157 | if len(model_list) == 0: 158 | raise ValueError('Invalid weights:', args.teacher_weights) 159 | 160 | # ensemble models 161 | net_teacher = Ensemble(model_list) 162 | else: 163 | raise ValueError('Invalid weights:', args.teacher_weights) 164 | 165 | # copy teacher model 166 | net_teacher_noiseadapt = copy.deepcopy(net_teacher) 167 | 
net_teacher_noiseadapt = net_teacher_noiseadapt.to(device) 168 | net_teacher_noiseadapt.train() 169 | 170 | criterion = nn.CrossEntropyLoss() 171 | 172 | # Checking teacher accuracy 173 | print('==> Teacher validation') 174 | acc_teacher = test(net_teacher, data_test_loader, criterion, device) 175 | with open(os.path.join(args.exp_descr, 'test_teacher.csv'), 'wt') as wf: 176 | wf.write('{:.4f}\n'.format(acc_teacher)) 177 | 178 | print("Starting model inversion") 179 | 180 | # placeholder for inputs 181 | inputs = torch.randn((args.bs, n_channels, input_size, input_size), requires_grad=True, device='cuda', dtype=torch.float) 182 | 183 | # target outputs to generate 184 | targets = torch.LongTensor(list(range(0, n_classes)) * (args.bs // n_classes) + list(range(0, args.bs % n_classes))).to('cuda') 185 | 186 | # define optimizer and loss 187 | optimizer_di = optim.Adam([inputs], lr=args.di_lr) 188 | 189 | # Create hooks for feature statistics catching 190 | loss_r_feature_layers = [] 191 | for module in net_teacher.modules(): 192 | if isinstance(module, nn.BatchNorm2d): 193 | loss_r_feature_layers.append(DeepInversionFeatureHook(module)) 194 | 195 | # for classifier 196 | net_cls = resnet18(num_classes=n_classes) if args.dataset in ['isic2019', 'diabetic2015', 'rsna'] else ResNet18(in_channels=n_channels, num_classes=n_classes) 197 | net_cls = net_cls.to(device) 198 | if args.pretrained: 199 | state_dict = load_state_dict_from_url('https://download.pytorch.org/models/resnet18-f37072fd.pth', progress=True) 200 | del state_dict['fc.weight'] 201 | del state_dict['fc.bias'] 202 | net_cls.load_state_dict(state_dict, strict=False) 203 | 204 | criterion_cls = KLDiv(T=args.T) 205 | optimizer_cls = torch.optim.SGD(net_cls.parameters(), lr=lr_init, momentum=0.9, weight_decay=5e-4) 206 | 207 | acc_best = 0 208 | for e in range(1, epochs + 1): 209 | # initialize gaussian inputs 210 | inputs.data = torch.randn((args.bs, n_channels, input_size, input_size), requires_grad=True, 
device='cuda') 211 | 212 | # ============================== 213 | # get_images 214 | best_cost = 1e6 215 | n_classes = targets.max().item() + 1 216 | 217 | optimizer_di.state = collections.defaultdict(dict) # Reset state of optimizer 218 | 219 | image_list = [] 220 | 221 | # empty cache 222 | torch.cuda.empty_cache() 223 | gc.collect() 224 | 225 | # setting up the range for jitter 226 | if inputs.shape[-1] > 128: 227 | lim_0, lim_1 = 30, 30 228 | else: 229 | lim_0, lim_1 = 2, 2 230 | 231 | for mi_idx in range(args.iters_mi): 232 | # apply random jitter offsets 233 | off1 = random.randint(-lim_0, lim_0) 234 | off2 = random.randint(-lim_1, lim_1) 235 | inputs_jit = torch.roll(inputs, shifts=(off1, off2), dims=(2, 3)) 236 | 237 | # forward with jit images 238 | optimizer_di.zero_grad() 239 | net_teacher.zero_grad() 240 | outputs = net_teacher(inputs_jit) 241 | loss = criterion(outputs, targets) 242 | loss_target = loss.item() 243 | 244 | # apply total variation regularization 245 | diff1 = inputs_jit[:, :, :, :-1] - inputs_jit[:, :, :, 1:] 246 | diff2 = inputs_jit[:, :, :-1, :] - inputs_jit[:, :, 1:, :] 247 | diff3 = inputs_jit[:, :, 1:, :-1] - inputs_jit[:, :, :-1, 1:] 248 | diff4 = inputs_jit[:, :, :-1, :-1] - inputs_jit[:, :, 1:, 1:] 249 | loss_var = torch.norm(diff1) + torch.norm(diff2) + torch.norm(diff3) + torch.norm(diff4) 250 | loss = loss + args.di_var_scale * loss_var 251 | 252 | # R_feature loss 253 | loss_distr = sum([mod.r_feature for mod in loss_r_feature_layers]) 254 | loss = loss + args.r_feature_weight * loss_distr # best for noise before BN 255 | 256 | # l2 loss 257 | loss = loss + args.di_l2_scale * torch.norm(inputs_jit, 2) 258 | 259 | if mi_idx % args.log_freq == 0: 260 | print(f"It {mi_idx}\t Losses: total: {loss.item():3.3f},\ttarget: {loss_target:3.3f} \tR_feature_loss unscaled:\t {loss_distr.item():3.3f}") 261 | vutils.save_image(inputs.data.clone(), '{}/output_{}_{}.png'.format(img_exp_descr, e, mi_idx), normalize=True, scale_each=True, 
nrow=n_classes) 262 | 263 | if best_cost > loss.item(): 264 | best_cost = loss.item() 265 | best_inputs = inputs.data 266 | 267 | # backward pass 268 | loss.backward() 269 | optimizer_di.step() 270 | 271 | # append inputs 272 | image_list.append(inputs.detach().cpu().data) 273 | 274 | # save last 275 | print(f"It {args.iters_mi}\t Losses: total: {loss.item():3.3f},\ttarget: {loss_target:3.3f} \tR_feature_loss unscaled:\t {loss_distr.item():3.3f}") 276 | vutils.save_image(inputs.data.clone(), '{}/output_{}_{}.png'.format(img_exp_descr, e, args.iters_mi), normalize=True, scale_each=True, nrow=n_classes) 277 | 278 | # ============================== 279 | # evaluation 280 | outputs = net_teacher(best_inputs) 281 | _, predicted_teach = outputs.max(1) 282 | 283 | print('Teacher correct out of {}: {}, loss at {}'.format(args.bs, predicted_teach.eq(targets).sum().item(), criterion(outputs, targets).item())) 284 | 285 | vutils.save_image(best_inputs.clone(), '{}/output_{}.png'.format(best_img_exp_descr, e), normalize=True, scale_each=True, nrow=n_classes) 286 | 287 | # ============================== 288 | # train classifier 289 | print('==> Train classifier') 290 | adjust_learning_rate(optimizer_cls, e, lr_init=lr_init, lr_step1=lr_step1, lr_step2=lr_step2) 291 | 292 | # set train 293 | net_cls.train() 294 | net_teacher_noiseadapt.train() 295 | 296 | # update mean and std (from real to noise) 297 | for cls_i in range(len(image_list) - 1, -1, -1): 298 | with torch.no_grad(): 299 | net_teacher_noiseadapt(image_list[cls_i].to(device)) 300 | 301 | for cls_i in range(len(image_list)): 302 | cls_inputs = image_list[cls_i].to(device) 303 | 304 | optimizer_cls.zero_grad() 305 | 306 | # calculate alpha (0 -> 1) 307 | alpha = cls_i / len(image_list) 308 | 309 | # noise KD 310 | with torch.no_grad(): 311 | # update mean and std (from real to noise) 312 | outputs_noise = net_teacher_noiseadapt(cls_inputs) 313 | 314 | # real KD 315 | with torch.no_grad(): 316 | outputs = 
net_teacher(cls_inputs) 317 | outputs_cls = net_cls(cls_inputs) 318 | loss_cls_real = criterion_cls(outputs_cls, outputs.detach()) 319 | loss_cls_noise = criterion_cls(outputs_cls, outputs_noise.detach()) 320 | 321 | # emerge losses 322 | loss_cls = alpha * loss_cls_real + (1.0 - alpha) * loss_cls_noise 323 | 324 | loss_cls.backward() 325 | optimizer_cls.step() 326 | # ============================== 327 | 328 | # test classifier 329 | acc = test(net_cls, data_test_loader, criterion, device) 330 | 331 | # write log 332 | with open(os.path.join(args.exp_descr, 'test.csv'), 'at') as wf: 333 | wf.write('{},{:.4f}\n'.format(e, acc)) 334 | 335 | # save best 336 | if acc_best < acc: 337 | acc_best = acc 338 | torch.save(net_cls, os.path.join(args.exp_descr, 'best.pth')) 339 | torch.save(net_teacher_noiseadapt, os.path.join(args.exp_descr, 'best_teacher_noiseadapt.pth')) 340 | 341 | # save last 342 | torch.save(net_cls, os.path.join(args.exp_descr, 'last.pth')) 343 | torch.save(net_teacher_noiseadapt, os.path.join(args.exp_descr, 'last_teacher_noiseadapt.pth')) 344 | 345 | 346 | if __name__ == "__main__": 347 | parser = argparse.ArgumentParser(description='') 348 | parser.add_argument('--bs', default=256, type=int, help='batch size') 349 | parser.add_argument('--iters_mi', default=500, type=int, help='number of iterations for model inversion') 350 | parser.add_argument('--di_lr', default=0.05, type=float, help='lr for deep inversion') 351 | parser.add_argument('--di_var_scale', default=2.5e-5, type=float, help='TV L2 regularization coefficient') 352 | parser.add_argument('--di_l2_scale', default=0.0, type=float, help='L2 regularization coefficient') 353 | parser.add_argument('--r_feature_weight', default=10, type=float, help='weight for BN regularization statistic') 354 | parser.add_argument('--exp_descr', default="result", type=str, help='name to be added to experiment name') 355 | parser.add_argument('--teacher_weights', 
default="./pretrained_models/bloodmnist_iid_5_0.6", type=str, help='path to load weights of the model') 356 | parser.add_argument('--dataset', default='bloodmnist', type=str) 357 | parser.add_argument('--root', default='./dataset', type=str) 358 | parser.add_argument('--seed', default=1, type=int) 359 | parser.add_argument('--T', default=20, type=float) 360 | parser.add_argument('--log_freq', default=200, type=int) 361 | parser.add_argument('--pretrained', action='store_true') 362 | 363 | args = parser.parse_args() 364 | 365 | main(args) 366 | --------------------------------------------------------------------------------