├── LICENSE
├── README.md
├── arguments.py
├── data_handler
│   ├── __init__.py
│   ├── celeba.py
│   ├── cifar10.py
│   ├── custom_loader.py
│   ├── dataloader_factory.py
│   ├── dataset_factory.py
│   └── utkface.py
├── main.py
├── networks
│   ├── __init__.py
│   ├── cifar_net.py
│   ├── mlp.py
│   ├── model_factory.py
│   ├── resnet.py
│   └── shufflenet.py
├── trainer
│   ├── __init__.py
│   ├── adv_debiasing.py
│   ├── kd_at.py
│   ├── kd_fitnet.py
│   ├── kd_hinton.py
│   ├── kd_mfd.py
│   ├── kd_nst.py
│   ├── loss_utils.py
│   ├── scratch_mmd.py
│   ├── trainer_factory.py
│   └── vanilla_train.py
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Donggyu Lee

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Fair-Feature-Distillation-for-Visual-Recognition
Official implementation of the paper 'Fair Feature Distillation for Visual Recognition'

## **Execution Details**

### Requirements

- Python 3
- GPU Titan XP / Pytorch 1.6 / CUDA 10.1

#### 1) Download dataset

- UTKFACE : [link](https://susanqq.github.io/UTKFace/) (We used the Aligned & Cropped Faces from the site)
- CelebA : link

#### 2) Execution command
You should first train a scratch model to be used as the teacher. (Note: `arguments.py` and `networks/model_factory.py` accept `resnet`, not `resnet18`; `resnet` maps to a ResNet-18.)
```
# Cifar10
$ python3 ./main.py --method scratch --dataset cifar10 --model cifar_net --epochs 50 --img-size 32 --batch-size 128 --optimizer Adam --lr 0.001 --date 210525

# UTKFACE
$ python3 ./main.py --method scratch --dataset utkface --model resnet --epochs 50 --img-size 176 --batch-size 128 --optimizer Adam --lr 0.001 --date 210525

# CelebA
$ python3 ./main.py --method scratch --dataset celeba --model shufflenet --epochs 50 --img-size 176 --batch-size 128 --optimizer Adam --lr 0.001 --date 210525
```

Then, using the saved teacher model, you can train a student model via the MFD algorithm.
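Internally, `main.py` rebuilds a network with the same architecture as the student, restores its weights from `--teacher-path`, and hands both networks to the trainer. A lightly condensed excerpt of that logic from `main.py` (the data-parallel branch is omitted here):
```
# Condensed from main.py: reusing the saved scratch model as the distillation teacher.
teacher = None
if (args.method.startswith('kd') or args.teacher_path is not None) and args.mode != 'eval':
    teacher = networks.ModelFactory.get_model(args.model, num_classes, args.img_size)
    teacher.load_state_dict(torch.load(args.teacher_path))  # checkpoint saved by the scratch run
    teacher.cuda('cuda:{}'.format(args.t_device))           # the teacher may sit on a separate GPU

trainer_ = trainer.TrainerFactory.get_trainer(args.method, model=model, args=args,
                                              optimizer=optimizer, teacher=teacher)
```
The commands below assume the teacher checkpoints were written to `trained_models/210525/<dataset>/scratch/` by the scratch runs above: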
```
# Cifar10
$ python3 ./main.py --method kd_mfd --dataset cifar10 --model cifar_net --epochs 50 --labelwise --lambf 3 --lambh 0 --no-annealing --img-size 32 --batch-size 128 --optimizer Adam --lr 0.001 --teacher-path trained_models/210525/cifar10/scratch/cifar_net_seed0_epochs50_bs128_lr0.001.pt

# UTKFACE
$ python3 ./main.py --method kd_mfd --dataset utkface --model resnet --epochs 50 --labelwise --lambf 3 --lambh 0 --no-annealing --img-size 176 --batch-size 128 --optimizer Adam --lr 0.001 --teacher-path trained_models/210525/utkface/scratch/resnet_seed0_epochs50_bs128_lr0.001.pt

# CelebA
$ python3 ./main.py --method kd_mfd --dataset celeba --model shufflenet --epochs 50 --labelwise --lambf 7 --lambh 0 --no-annealing --img-size 176 --batch-size 128 --optimizer Adam --lr 0.001 --teacher-path trained_models/210525/celeba/scratch/shufflenet_seed0_epochs50_bs128_lr0.001.pt
```

#### Notes

All of the datasets can be downloaded from the links above.

--------------------------------------------------------------------------------
/arguments.py:
--------------------------------------------------------------------------------
import argparse
import torch

def get_args():
    parser = argparse.ArgumentParser(description='Fairness')
    parser.add_argument('--log-dir', default='./results/',
                        help='directory to save logs (default: ./results/)')
    parser.add_argument('--data-dir', default='./data/',
                        help='data directory (default: ./data/)')
    parser.add_argument('--save-dir', default='./trained_models/',
                        help='directory to save trained models (default: ./trained_models/)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--device', default=0, type=int, help='cuda device number')
    parser.add_argument('--t-device', default=0, type=int, help='teacher cuda device number')


    parser.add_argument('--mode', default='train', choices=['train', 'eval'])
    parser.add_argument('--modelpath', default=None)
    parser.add_argument('--evalset', default='all', choices=['all', 'train', 'test'])

    parser.add_argument('--dataset', required=True, default='', choices=['utkface', 'celeba', 'cifar10'])
    parser.add_argument('--skew-ratio', default=0.8, type=float, help='skew ratio for cifar-10s')
    parser.add_argument('--img-size', default=224, type=int, help='img size for preprocessing')

    parser.add_argument('--lr', default=0.001, type=float, help='learning rate')
    parser.add_argument('--epochs', default=50, type=int, help='number of training epochs')
    parser.add_argument('--batch-size', default=128, type=int, help='mini batch size')
    parser.add_argument('--seed', default=0, type=int, help='seed for randomness')
    parser.add_argument('--date', default='20xxxxxx', type=str, help='experiment date')
    parser.add_argument('--method', default='scratch', type=str, required=True,
                        choices=['scratch', 'kd_hinton', 'kd_fitnet', 'kd_at',
                                 'kd_mfd', 'scratch_mmd', 'kd_nst', 'adv_debiasing'])

    parser.add_argument('--optimizer', default='Adam', type=str, required=False,
                        choices=['SGD', 'SGD_momentum_decay', 'Adam'],
                        help='(default=%(default)s)')

    parser.add_argument('--lambh', default=4, type=float, help='kd strength hyperparameter')
    parser.add_argument('--lambf', default=1, type=float, help='feature distill strength hyperparameter')
    parser.add_argument('--kd-temp', default=3, type=float, help='temperature for KD')

    parser.add_argument('--model', default='', required=True, choices=['resnet', 'shufflenet', 'mlp', 'cifar_net'])
    parser.add_argument('--parallel', default=False, action='store_true', help='data parallel')
    parser.add_argument('--teacher-type', default=None, choices=['resnet', 'shufflenet', 'cifar_net'])
    parser.add_argument('--teacher-path', default=None, help='teacher model path')

    parser.add_argument('--pretrained', default=False, action='store_true', help='load imagenet pretrained model')
    parser.add_argument('--num-workers', default=2, type=int, help='the number of threads used in the dataloader')
    parser.add_argument('--term', default=20, type=int, help='the period for recording train acc')
    parser.add_argument('--target', default='Attractive', type=str, help='target attribute for celeba')

    parser.add_argument('--no-annealing', action='store_true', default=False, help='do not anneal lamb during training')
    parser.add_argument('--fitnet-simul', default=False, action='store_true', help='no hint-training')

    parser.add_argument('--eta', default=0.0003, type=float, help='adversary training learning rate')
    parser.add_argument('--adv-lambda', default=2.0, type=float, help='adversary loss strength')

    parser.add_argument('--sigma', default=1.0, type=float, help='sigma for rbf kernel')
    parser.add_argument('--kernel', default='rbf', type=str, choices=['rbf', 'poly'], help='kernel for mmd')
    parser.add_argument('--labelwise', default=False, action='store_true', help='labelwise loader')
    parser.add_argument('--jointfeature', default=False, action='store_true', help='mmd with both joint')
    parser.add_argument('--get-inter', default=False, action='store_true',
                        help='get penultimate features for TSNE visualization')

    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    if args.mode == 'train' and (args.method.startswith('kd')):
        if args.teacher_path is None:
            raise Exception('A teacher model path is not specified.')

    if args.mode == 'eval' and args.modelpath is None:
        raise Exception('Model path to load is not specified!')

    return args
--------------------------------------------------------------------------------
/data_handler/__init__.py:
--------------------------------------------------------------------------------
from data_handler.dataloader_factory import *
--------------------------------------------------------------------------------
/data_handler/celeba.py:
--------------------------------------------------------------------------------
import torch
import os
from os.path import join
from torchvision.datasets.vision import VisionDataset
import PIL
import pandas
import numpy as np
import zipfile
from functools import partial
from torchvision.datasets.utils import download_file_from_google_drive, check_integrity, verify_str_arg


class CelebA(VisionDataset):
    base_folder = "celeba"
    # There currently does not appear to be an easy way to extract 7z in python (without introducing additional
    # dependencies). The "in-the-wild" (not aligned+cropped) images are only in 7z, so they are not available
    # right now.
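    # Each required asset below is listed as (Google Drive file ID, MD5 checksum, filename);
    # the commented-out entries (the 7z and png archives) are not needed for this project.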
18 | file_list = [ 19 | # File ID MD5 Hash Filename 20 | ("0B7EVK8r0v71pZjFTYXZWM3FlRnM", "00d2c5bc6d35e252742224ab0c1e8fcb", "img_align_celeba.zip"), 21 | # ("0B7EVK8r0v71pbWNEUjJKdDQ3dGc", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_align_celeba_png.7z"), 22 | # ("0B7EVK8r0v71peklHb0pGdDl6R28", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_celeba.7z"), 23 | ("0B7EVK8r0v71pblRyaVFSWGxPY0U", "75e246fa4810816ffd6ee81facbd244c", "list_attr_celeba.txt"), 24 | ("1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS", "32bd1bd63d3c78cd57e08160ec5ed1e2", "identity_CelebA.txt"), 25 | ("0B7EVK8r0v71pbThiMVRxWXZ4dU0", "00566efa6fedff7a56946cd1c10f1c16", "list_bbox_celeba.txt"), 26 | ("0B7EVK8r0v71pd0FJY3Blby1HUTQ", "cc24ecafdb5b50baae59b03474781f8c", "list_landmarks_align_celeba.txt"), 27 | # ("0B7EVK8r0v71pTzJIdlJWdHczRlU", "063ee6ddb681f96bc9ca28c6febb9d1a", "list_landmarks_celeba.txt"), 28 | ("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"), 29 | ] 30 | 31 | def __init__(self, root, split="train", target_type="attr", transform=None, 32 | target_transform=None, download=False, target_attr='Attractive', labelwise=False): 33 | super(CelebA, self).__init__(root, transform=transform, 34 | target_transform=target_transform) 35 | self.split = split 36 | if isinstance(target_type, list): 37 | self.target_type = target_type 38 | else: 39 | self.target_type = [target_type] 40 | 41 | if not self.target_type and self.target_transform is not None: 42 | raise RuntimeError('target_transform is specified but target_type is empty') 43 | 44 | if download: 45 | self.download() 46 | 47 | if not self._check_integrity(): 48 | raise RuntimeError('Dataset not found or corrupted.' + 49 | ' You can use download=True to download it') 50 | # SELECT the features 51 | self.sensitive_attr = 'Male' 52 | self.target_attr = target_attr 53 | split_map = { 54 | "train": 0, 55 | "valid": 1, 56 | "test": 2, 57 | "all": None, 58 | } 59 | split = split_map[verify_str_arg(split.lower(), "split", 60 | ("train", "valid", "test", "all" ))] 61 | 62 | fn = partial(join, self.root, self.base_folder) 63 | splits = pandas.read_csv(fn("list_eval_partition.txt"), delim_whitespace=True, header=None, index_col=0) 64 | attr = pandas.read_csv(fn("list_attr_celeba.txt"), delim_whitespace=True, header=1) 65 | 66 | mask = slice(None) if split is None else (splits[1] == split) 67 | 68 | self.filename = splits[mask].index.values 69 | self.attr = torch.as_tensor(attr[mask].values) 70 | self.attr = (self.attr + 1) // 2 # map from {-1, 1} to {0, 1} 71 | self.attr_names = list(attr.columns) 72 | print(self.attr_names) 73 | self.target_idx = self.attr_names.index(self.target_attr) 74 | self.sensi_idx = self.attr_names.index(self.sensitive_attr) 75 | self.feature_idx = [i for i in range(len(self.attr_names)) if i != self.target_idx and i!=self.sensi_idx] 76 | self.num_classes = 2 77 | self.num_groups =2 78 | print('num classes is {}'.format(self.num_classes)) 79 | self.num_data = self._data_count() 80 | if self.split == "test": 81 | self._balance_test_data() 82 | self.labelwise = labelwise 83 | if self.labelwise: 84 | self.idx_map = self._make_idx_map() 85 | 86 | def _make_idx_map(self): 87 | idx_map = [[] for i in range(self.num_groups * self.num_classes)] 88 | for j, i in enumerate(self.attr): 89 | y = self.attr[j, self.target_idx] 90 | s = self.attr[j, self.sensi_idx] 91 | pos = s*self.num_classes + y 92 | idx_map[pos].append(j) 93 | final_map = [] 94 | for l in idx_map: 95 | final_map.extend(l) 96 | return final_map 97 | 98 | def 
_check_integrity(self): 99 | for (_, md5, filename) in self.file_list: 100 | fpath = os.path.join(self.root, self.base_folder, filename) 101 | _, ext = os.path.splitext(filename) 102 | # Allow original archive to be deleted (zip and 7z) 103 | # Only need the extracted images 104 | if ext not in [".zip", ".7z"] and not check_integrity(fpath, md5): 105 | return False 106 | 107 | # Should check a hash of the images 108 | return os.path.isdir(os.path.join(self.root, self.base_folder, "img_align_celeba")) 109 | 110 | def download(self): 111 | if self._check_integrity(): 112 | print('Files already downloaded and verified') 113 | return 114 | 115 | for (file_id, md5, filename) in self.file_list: 116 | download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5) 117 | 118 | with zipfile.ZipFile(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"), "r") as f: 119 | f.extractall(os.path.join(self.root, self.base_folder)) 120 | 121 | def __getitem__(self, index): 122 | if self.labelwise: 123 | index = self.idx_map[index] 124 | img_name = self.filename[index] 125 | X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", img_name)) 126 | 127 | target = self.attr[index, self.target_idx] 128 | sensitive = self.attr[index, self.sensi_idx] 129 | feature = self.attr[index, self.feature_idx] 130 | if self.transform is not None: 131 | X = self.transform(X) 132 | 133 | if self.target_transform is not None: 134 | target = self.target_transform(target) 135 | 136 | return X, feature, sensitive, target, (index, img_name) 137 | 138 | def __len__(self): 139 | return len(self.attr) 140 | 141 | def _data_count(self): 142 | data_count = np.zeros((self.num_groups, self.num_classes), dtype=int) 143 | print(' %s mode'%self.split) 144 | for index in range(len(self.attr)): 145 | target = self.attr[index, self.target_idx] 146 | sensitive = self.attr[index, self.sensi_idx] 147 | data_count[sensitive, target] += 1 148 | for i in range(self.num_groups): 149 | print('# of %d groups data : '%i, data_count[i, :]) 150 | return data_count 151 | 152 | def _balance_test_data(self): 153 | num_data_min = np.min(self.num_data) 154 | print('min : ', num_data_min) 155 | data_count = np.zeros((self.num_groups, self.num_classes), dtype=int) 156 | new_filename = [] 157 | new_attr = [] 158 | print(len(self.attr)) 159 | for index in range(len(self.attr)): 160 | target=self.attr[index, self.target_idx] 161 | sensitive = self.attr[index, self.sensi_idx] 162 | if data_count[sensitive, target] < num_data_min: 163 | new_filename.append(self.filename[index]) 164 | new_attr.append(self.attr[index]) 165 | data_count[sensitive, target] += 1 166 | 167 | for i in range(self.num_groups): 168 | print('# of balanced %d\'s groups data : '%i, data_count[i, :]) 169 | 170 | self.filename = new_filename 171 | self.attr = torch.stack(new_attr) 172 | -------------------------------------------------------------------------------- /data_handler/cifar10.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | from PIL import Image 4 | import numpy as np 5 | import pickle 6 | 7 | from torchvision.datasets.vision import VisionDataset 8 | from torchvision.datasets.utils import check_integrity, download_and_extract_archive 9 | 10 | 11 | def rgb_to_grayscale(img): 12 | """Convert image to gray scale""" 13 | pil_gray_img = img.convert('L') 14 | np_gray_img = np.array(pil_gray_img, dtype=np.uint8) 15 | np_gray_img = 
np.dstack([np_gray_img, np_gray_img, np_gray_img]) 16 | 17 | return np_gray_img 18 | 19 | 20 | class CIFAR_10S(VisionDataset): 21 | def __init__(self, root, split='train', transform=None, target_transform=None, 22 | seed=0, skewed_ratio=0.8, labelwise=False): 23 | super(CIFAR_10S, self).__init__(root, transform=transform, target_transform=target_transform) 24 | 25 | self.split = split 26 | self.seed = seed 27 | 28 | self.num_classes = 10 29 | self.num_groups = 2 30 | 31 | imgs, labels, colors, data_count = self._make_skewed(split, seed, skewed_ratio, self.num_classes) 32 | 33 | self.dataset = {} 34 | self.dataset['image'] = np.array(imgs) 35 | self.dataset['label'] = np.array(labels) 36 | self.dataset['color'] = np.array(colors) 37 | 38 | self._get_label_list() 39 | self.labelwise = labelwise 40 | 41 | self.num_data = data_count 42 | 43 | if self.labelwise: 44 | self.idx_map = self._make_idx_map() 45 | 46 | def _make_idx_map(self): 47 | idx_map = [[] for i in range(self.num_groups * self.num_classes)] 48 | for j, i in enumerate(self.dataset['image']): 49 | y = self.dataset['label'][j] 50 | s = self.dataset['color'][j] 51 | pos = s * self.num_classes + y 52 | idx_map[int(pos)].append(j) 53 | final_map = [] 54 | for l in idx_map: 55 | final_map.extend(l) 56 | return final_map 57 | 58 | def _get_label_list(self): 59 | self.label_list = [] 60 | for i in range(self.num_classes): 61 | self.label_list.append(sum(self.dataset['label'] == i)) 62 | 63 | def _set_mapping(self): 64 | tmp = [[] for _ in range(self.num_classes)] 65 | for i in range(self.__len__()): 66 | tmp[int(self.dataset['label'][i])].append(i) 67 | self.map = [] 68 | for i in range(len(tmp)): 69 | self.map.extend(tmp[i]) 70 | 71 | def __len__(self): 72 | return len(self.dataset['image']) 73 | 74 | def __getitem__(self, index): 75 | if self.labelwise: 76 | index = self.idx_map[index] 77 | image = self.dataset['image'][index] 78 | label = self.dataset['label'][index] 79 | color = self.dataset['color'][index] 80 | 81 | if self.transform: 82 | image = self.transform(image) 83 | 84 | if self.target_transform: 85 | label = self.target_transform(label) 86 | 87 | return image, 0, np.float32(color), np.int64(label), (index, 0) 88 | 89 | def _make_skewed(self, split='train', seed=0, skewed_ratio=1., num_classes=10): 90 | 91 | train = False if split =='test' else True 92 | cifardata = CIFAR10('./data', train=train, shuffle=True, seed=seed, download=True) 93 | 94 | num_data = 50000 if split =='train' else 20000 95 | 96 | imgs = np.zeros((num_data, 32, 32, 3), dtype=np.uint8) 97 | labels = np.zeros(num_data) 98 | colors = np.zeros(num_data) 99 | data_count = np.zeros((2, 10), dtype=int) 100 | 101 | num_total_train_data = int((50000 // num_classes)) 102 | num_skewed_train_data = int((50000 * skewed_ratio) // num_classes) 103 | 104 | for i, data in enumerate(cifardata): 105 | img, target = data 106 | 107 | if split == 'test': 108 | imgs[i] = rgb_to_grayscale(img) 109 | imgs[i+10000] = np.array(img) 110 | labels[i] = target 111 | labels[i+10000] = target 112 | colors[i] = 0 113 | colors[i+10000] = 1 114 | data_count[0, target] += 1 115 | data_count[1, target] += 1 116 | else: 117 | if target < 5: 118 | if data_count[0, target] < (num_skewed_train_data): 119 | imgs[i] = rgb_to_grayscale(img) 120 | colors[i] = 0 121 | data_count[0, target] += 1 122 | else: 123 | imgs[i] = np.array(img) 124 | colors[i] = 1 125 | data_count[1, target] += 1 126 | labels[i] = target 127 | else: 128 | if data_count[0, target] < (num_total_train_data - 
num_skewed_train_data): 129 | imgs[i] = rgb_to_grayscale(img) 130 | colors[i] = 0 131 | data_count[0, target] += 1 132 | else: 133 | imgs[i] = np.array(img) 134 | colors[i] = 1 135 | data_count[1, target] += 1 136 | labels[i] = target 137 | 138 | print('<# of Skewed data>') 139 | print(data_count) 140 | 141 | return imgs, labels, colors, data_count 142 | 143 | 144 | class CIFAR10(VisionDataset): 145 | base_folder = 'cifar-10-batches-py' 146 | url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" 147 | filename = "cifar-10-python.tar.gz" 148 | tgz_md5 = 'c58f30108f718f92721af3b95e74349a' 149 | train_list = [ 150 | ['data_batch_1', 'c99cafc152244af753f735de768cd75f'], 151 | ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'], 152 | ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'], 153 | ['data_batch_4', '634d18415352ddfa80567beed471001a'], 154 | ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'], 155 | ] 156 | 157 | test_list = [ 158 | ['test_batch', '40351d587109b95175f43aff81a1287e'], 159 | ] 160 | meta = { 161 | 'filename': 'batches.meta', 162 | 'key': 'label_names', 163 | 'md5': '5ff9c542aee3614f3951f8cda6e48888', 164 | } 165 | 166 | def __init__(self, root, train=True, transform=None, target_transform=None, 167 | download=False, shuffle=False, seed=0): 168 | 169 | super(CIFAR10, self).__init__(root, transform=transform, 170 | target_transform=target_transform) 171 | 172 | self.train = train # training set or test set 173 | 174 | if download: 175 | self.download() 176 | 177 | if not self._check_integrity(): 178 | raise RuntimeError('Dataset not found or corrupted.' + 179 | ' You can use download=True to download it') 180 | 181 | if self.train: 182 | downloaded_list = self.train_list 183 | else: 184 | downloaded_list = self.test_list 185 | 186 | self.data = [] 187 | self.targets = [] 188 | 189 | # now load the picked numpy arrays 190 | for file_name, checksum in downloaded_list: 191 | file_path = os.path.join(self.root, self.base_folder, file_name) 192 | with open(file_path, 'rb') as f: 193 | entry = pickle.load(f, encoding='latin1') 194 | self.data.append(entry['data']) 195 | if 'labels' in entry: 196 | self.targets.extend(entry['labels']) 197 | else: 198 | self.targets.extend(entry['fine_labels']) 199 | 200 | self.data = np.vstack(self.data).reshape(-1, 3, 32, 32) 201 | self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC 202 | 203 | if shuffle: 204 | np.random.seed(seed) 205 | idx = np.arange(len(self.data), dtype=np.int64) 206 | np.random.shuffle(idx) 207 | self.data = self.data[idx] 208 | self.targets = np.array(self.targets)[idx] 209 | 210 | self._load_meta() 211 | 212 | def _load_meta(self): 213 | path = os.path.join(self.root, self.base_folder, self.meta['filename']) 214 | if not check_integrity(path, self.meta['md5']): 215 | raise RuntimeError('Dataset metadata file not found or corrupted.' + 216 | ' You can use download=True to download it') 217 | with open(path, 'rb') as infile: 218 | data = pickle.load(infile, encoding='latin1') 219 | self.classes = data[self.meta['key']] 220 | self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)} 221 | 222 | def __getitem__(self, index): 223 | """ 224 | Args: 225 | index (int): Index 226 | 227 | Returns: 228 | tuple: (image, target) where target is index of the target class. 
        """
        img, target = self.data[index], self.targets[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.data)

    def _check_integrity(self):
        root = self.root
        for fentry in (self.train_list + self.test_list):
            filename, md5 = fentry[0], fentry[1]
            fpath = os.path.join(root, self.base_folder, filename)
            if not check_integrity(fpath, md5):
                return False
        return True

    def download(self):
        if self._check_integrity():
            print('Files already downloaded and verified')
            return
        download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)

    def extra_repr(self):
        return "Split: {}".format("Train" if self.train is True else "Test")

--------------------------------------------------------------------------------
/data_handler/custom_loader.py:
--------------------------------------------------------------------------------
import numpy as np
from torch.utils.data.sampler import RandomSampler


class Customsampler(RandomSampler):

    def __init__(self, data_source, replacement=False, num_samples=None, batch_size=None, generator=None):
        super(Customsampler, self).__init__(data_source=data_source, replacement=replacement,
                                            num_samples=num_samples, generator=generator)

        self.l = data_source.num_classes
        self.g = data_source.num_groups
        self.nbatch_size = batch_size // (self.l * self.g)
        self.num_data = data_source.num_data
        pos = np.unravel_index(np.argmax(self.num_data), self.num_data.shape)
        # the (group, class) cells are laid out as group * num_classes + class,
        # so flatten with self.l (num_classes), not self.g
        self.max_pos = pos[0] * self.l + pos[1]

    def __iter__(self):
        final_list = []
        index_list = []
        total_num = 0
        for i in range(self.l * self.g):
            tmp = np.arange(self.num_data[i // self.l, i % self.l]) + total_num
            np.random.shuffle(tmp)
            index_list.append(list(tmp))
            # oversample every cell except the largest one so that all
            # (group, class) cells contribute the same number of indices
            if i != self.max_pos:
                while len(index_list[-1]) < np.max(self.num_data):
                    tmp = np.arange(self.num_data[i // self.l, i % self.l]) + total_num
                    np.random.shuffle(tmp)
                    index_list[-1].extend(list(tmp))
            total_num += self.num_data[i // self.l, i % self.l]

        for tmp in range(len(index_list[self.max_pos]) // self.nbatch_size):
            for list_ in index_list:
                final_list.extend(list_[tmp * self.nbatch_size:(tmp + 1) * self.nbatch_size])

        return iter(final_list)


def gen(index_list, nbatch_size):
    idx = 0
    np.random.shuffle(index_list)
    while True:
        idx += nbatch_size
        if idx > len(index_list):
            # returning ends the generator; raising StopIteration inside a
            # generator has been an error since Python 3.7 (PEP 479)
            return
        yield index_list[idx - nbatch_size:idx]

--------------------------------------------------------------------------------
/data_handler/dataloader_factory.py:
--------------------------------------------------------------------------------
from data_handler.dataset_factory import DatasetFactory

import numpy as np
from torchvision import transforms
from torch.utils.data import DataLoader


class DataloaderFactory:
    def __init__(self):
        pass

    @staticmethod
    def get_dataloader(name, img_size=224, batch_size=256, seed=0, 
num_workers=4, 14 | target='Smiling', skew_ratio=1., labelwise=False): 15 | 16 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 17 | std=[0.229, 0.224, 0.225]) 18 | 19 | if name == 'celeba': 20 | transform_list = [transforms.RandomResizedCrop(img_size), 21 | transforms.RandomHorizontalFlip(), 22 | transforms.ToTensor(), 23 | normalize 24 | ] 25 | 26 | elif 'cifar10' in name: 27 | transform_list = [transforms.ToPILImage(), 28 | transforms.RandomHorizontalFlip(), 29 | transforms.ToTensor() 30 | ] 31 | else: 32 | transform_list = [transforms.Resize((256,256)), 33 | transforms.RandomCrop(img_size), 34 | transforms.RandomHorizontalFlip(), 35 | transforms.ToTensor(), 36 | normalize 37 | ] 38 | 39 | if 'cifar10' in name: 40 | test_transform_list = [transforms.ToTensor()] 41 | else: 42 | test_transform_list = [transforms.Resize((img_size,img_size)), 43 | transforms.ToTensor(), 44 | normalize] 45 | preprocessing = transforms.Compose(transform_list) 46 | test_preprocessing = transforms.Compose(test_transform_list) 47 | 48 | test_dataset = DatasetFactory.get_dataset(name, test_preprocessing, 'test', target=target, 49 | seed=seed, skew_ratio=skew_ratio) 50 | train_dataset = DatasetFactory.get_dataset(name, preprocessing, split='train', target=target, 51 | seed=seed, skew_ratio=skew_ratio, labelwise=labelwise) 52 | 53 | def _init_fn(worker_id): 54 | np.random.seed(int(seed)) 55 | 56 | num_classes = test_dataset.num_classes 57 | num_groups = test_dataset.num_groups 58 | 59 | if labelwise: 60 | from data_handler.custom_loader import Customsampler 61 | sampler = Customsampler(train_dataset, replacement=False, batch_size=batch_size) 62 | train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler, 63 | num_workers=num_workers, worker_init_fn=_init_fn, pin_memory=True, drop_last=True) 64 | else: 65 | train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, 66 | num_workers=num_workers, worker_init_fn=_init_fn, pin_memory=True, drop_last=True) 67 | 68 | test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, 69 | num_workers=num_workers, worker_init_fn=_init_fn, pin_memory=True) 70 | 71 | print('# of test data : {}'.format(len(test_dataset))) 72 | print('Dataset loaded.') 73 | 74 | return num_classes, num_groups, train_dataloader, test_dataloader 75 | 76 | -------------------------------------------------------------------------------- /data_handler/dataset_factory.py: -------------------------------------------------------------------------------- 1 | class DatasetFactory: 2 | def __init__(self): 3 | pass 4 | 5 | @staticmethod 6 | def get_dataset(name, transform=None, split='Train', target='Attractive', seed=0, skew_ratio=1., labelwise=False): 7 | 8 | if name == "utkface": 9 | from data_handler.utkface import UTKFaceDataset 10 | root = './data/UTKFace' 11 | return UTKFaceDataset(root=root, split=split, transform=transform, 12 | labelwise=labelwise) 13 | 14 | elif name == "celeba": 15 | from data_handler.celeba import CelebA 16 | root='./data/' 17 | return CelebA(root=root, split=split, transform=transform, target_attr=target, labelwise=labelwise) 18 | 19 | elif name == "cifar10": 20 | from data_handler.cifar10 import CIFAR_10S 21 | root = './data/cifar10' 22 | return CIFAR_10S(root=root, split=split, transform=transform, seed=seed, skewed_ratio=skew_ratio, 23 | labelwise=labelwise) 24 | -------------------------------------------------------------------------------- /data_handler/utkface.py: 
-------------------------------------------------------------------------------- 1 | from os.path import join 2 | from torchvision.datasets.vision import VisionDataset 3 | from PIL import Image 4 | from utils import list_files 5 | from natsort import natsorted 6 | import random 7 | import numpy as np 8 | 9 | class UTKFaceDataset(VisionDataset): 10 | 11 | label = 'age' 12 | sensi = 'race' 13 | fea_map = { 14 | 'age' : 0, 15 | 'gender' : 1, 16 | 'race' : 2 17 | } 18 | num_map = { 19 | 'age' : 100, 20 | 'gender' : 2, 21 | 'race' : 4 22 | } 23 | 24 | def __init__(self, root, split='train', transform=None, target_transform=None, 25 | labelwise=False): 26 | 27 | super(UTKFaceDataset, self).__init__(root, transform=transform, 28 | target_transform=target_transform) 29 | 30 | self.split = split 31 | self.filename = list_files(root, '.jpg') 32 | self.filename = natsorted(self.filename) 33 | self._delete_incomplete_images() 34 | self._delete_others_n_age_filter() 35 | self.num_groups = self.num_map[self.sensi] 36 | self.num_classes = self.num_map[self.label] 37 | self.labelwise = labelwise 38 | 39 | random.seed(1) 40 | random.shuffle(self.filename) 41 | 42 | self._make_data() 43 | self.num_data = self._data_count() 44 | 45 | if self.labelwise: 46 | self.idx_map = self._make_idx_map() 47 | 48 | def __len__(self): 49 | return len(self.filename) 50 | 51 | def __getitem__(self, index): 52 | if self.labelwise: 53 | index = self.idx_map[index] 54 | img_name = self.filename[index] 55 | s, l = self._filename2SY(img_name) 56 | 57 | image_path = join(self.root, img_name) 58 | image = Image.open(image_path, mode='r').convert('RGB') 59 | 60 | if self.transform: 61 | image = self.transform(image) 62 | 63 | return image, 1, np.float32(s), np.int64(l), (index, img_name) 64 | 65 | def _make_idx_map(self): 66 | idx_map = [[] for i in range(self.num_groups * self.num_classes)] 67 | for j, i in enumerate(self.filename): 68 | s, y = self._filename2SY(i) 69 | pos = s*self.num_classes + y 70 | idx_map[pos].append(j) 71 | 72 | final_map = [] 73 | for l in idx_map: 74 | final_map.extend(l) 75 | return final_map 76 | 77 | def lg_filter(self, l, g): 78 | tmp = [] 79 | for i in self.filename: 80 | g_, l_ = self._filename2SY(i) 81 | if l == l_ and g == g_: 82 | tmp.append(i) 83 | return tmp 84 | 85 | def _delete_incomplete_images(self): 86 | self.filename = [image for image in self.filename if len(image.split('_')) == 4] 87 | 88 | def _delete_others_n_age_filter(self): 89 | 90 | self.filename = [image for image in self.filename 91 | if ((image.split('_')[self.fea_map['race']] != '4'))] 92 | ages = [self._transform_age(int(image.split('_')[self.fea_map['age']])) for image in self.filename] 93 | self.num_map['age'] = len(set(ages)) 94 | 95 | def _filename2SY(self, filename): 96 | tmp = filename.split('_') 97 | sensi = int(tmp[self.fea_map[self.sensi]]) 98 | label = int(tmp[self.fea_map[self.label]]) 99 | if self.sensi == 'age': 100 | sensi = self._transform_age(sensi) 101 | if self.label == 'age': 102 | label = self._transform_age(label) 103 | return int(sensi), int(label) 104 | 105 | def _transform_age(self, age): 106 | if age<20: 107 | label = 0 108 | elif age<40: 109 | label = 1 110 | else: 111 | label = 2 112 | return label 113 | 114 | def _make_data(self): 115 | import copy 116 | min_cnt = 100 117 | data_count = np.zeros((self.num_groups, self.num_classes), dtype=int) 118 | if self.split == 'train': 119 | tmp = copy.deepcopy(self.filename) 120 | else: 121 | tmp = [] 122 | 123 | for i in reversed(self.filename): 124 | s, l = 
self._filename2SY(i)
            data_count[s, l] += 1
            if data_count[s, l] <= min_cnt:
                if self.split == 'train':
                    tmp.remove(i)
                else:
                    tmp.append(i)

        self.filename = tmp

    def _data_count(self):
        data_count = np.zeros((self.num_groups, self.num_classes), dtype=int)
        data_set = self.filename

        for img_name in data_set:
            s, l = self._filename2SY(img_name)
            data_count[s, l] += 1

        for i in range(self.num_groups):
            print('# of %d group data : ' % i, data_count[i, :])
        return data_count
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import networks
import data_handler
import trainer
from utils import check_log_dir, make_log_name, set_seed

from arguments import get_args
import time
import os

args = get_args()


def main():

    torch.backends.cudnn.enabled = True

    seed = args.seed
    set_seed(seed)

    np.set_printoptions(precision=4)
    torch.set_printoptions(precision=4)

    log_name = make_log_name(args)
    dataset = args.dataset
    save_dir = os.path.join(args.save_dir, args.date, dataset, args.method)
    log_dir = os.path.join(args.log_dir, args.date, dataset, args.method)
    check_log_dir(save_dir)
    check_log_dir(log_dir)
    ########################## get dataloader ################################

    tmp = data_handler.DataloaderFactory.get_dataloader(args.dataset, img_size=args.img_size,
                                                        batch_size=args.batch_size, seed=args.seed,
                                                        num_workers=args.num_workers,
                                                        target=args.target,
                                                        skew_ratio=args.skew_ratio,
                                                        labelwise=args.labelwise
                                                        )
    num_classes, num_groups, train_loader, test_loader = tmp

    ########################## get model ##################################

    model = networks.ModelFactory.get_model(args.model, num_classes, args.img_size, pretrained=args.pretrained)

    if args.parallel:
        model = nn.DataParallel(model)

    model.cuda('cuda:{}'.format(args.device))

    if args.modelpath is not None:
        model.load_state_dict(torch.load(args.modelpath))

    teacher = None
    if (args.method.startswith('kd') or args.teacher_path is not None) and args.mode != 'eval':
        teacher = networks.ModelFactory.get_model(args.model, train_loader.dataset.num_classes, args.img_size)
        if args.parallel:
            teacher = nn.DataParallel(teacher)
        teacher.load_state_dict(torch.load(args.teacher_path))
        teacher.cuda('cuda:{}'.format(args.t_device))

    ########################## get trainer ##################################

    if args.optimizer == 'Adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr)
    elif 'SGD' in args.optimizer:
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)

    trainer_ = trainer.TrainerFactory.get_trainer(args.method, model=model, args=args,
                                                  optimizer=optimizer, teacher=teacher)

    ####################### start training or evaluating ####################

    if args.mode == 'train':
        start_t = time.time()
        trainer_.train(train_loader, test_loader, args.epochs)
        end_t = time.time()
        train_t = int((end_t - start_t) / 60)  # to minutes
        print('Training Time : {} hours {} 
minutes'.format(int(train_t/60), (train_t % 60))) 82 | trainer_.save_model(save_dir, log_name) 83 | 84 | else: 85 | print('Evaluation ----------------') 86 | model_to_load = args.modelpath 87 | trainer_.model.load_state_dict(torch.load(model_to_load)) 88 | print('Trained model loaded successfully') 89 | 90 | if args.evalset == 'all': 91 | trainer_.compute_confusion_matix('train', train_loader.dataset.num_classes, train_loader, log_dir, log_name) 92 | trainer_.compute_confusion_matix('test', test_loader.dataset.num_classes, test_loader, log_dir, log_name) 93 | 94 | elif args.evalset == 'train': 95 | trainer_.compute_confusion_matix('train', train_loader.dataset.num_classes, train_loader, log_dir, log_name) 96 | else: 97 | trainer_.compute_confusion_matix('test', test_loader.dataset.num_classes, test_loader, log_dir, log_name) 98 | 99 | print('Done!') 100 | 101 | 102 | if __name__ == '__main__': 103 | main() 104 | -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- 1 | from networks.model_factory import * -------------------------------------------------------------------------------- /networks/cifar_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | class Net(nn.Module): 7 | def __init__(self, num_classes=10): 8 | super().__init__() 9 | 10 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1) 11 | s = compute_conv_output_size(32, 3, padding=1) # 32 12 | self.conv2 = nn.Conv2d(32, 32, kernel_size=3, padding=1) 13 | s = compute_conv_output_size(s, 3, padding=1) # 32 14 | s = s // 2 # 16 15 | self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1) 16 | s = compute_conv_output_size(s, 3, padding=1) # 16 17 | self.conv4 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 18 | s = compute_conv_output_size(s, 3, padding=1) # 16 19 | s = s // 2 # 8 20 | self.conv5 = nn.Conv2d(64, 128, kernel_size=3, padding=1) 21 | s = compute_conv_output_size(s, 3, padding=1) # 8 22 | self.conv6 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 23 | s = compute_conv_output_size(s, 3, padding=1) # 8 24 | 25 | s = s // 2 # 4 26 | self.fc1 = nn.Linear(s * s * 128, 256) # 2048 27 | self.drop1 = nn.Dropout(0.25) 28 | self.drop2 = nn.Dropout(0.5) 29 | self.MaxPool = torch.nn.MaxPool2d(2) 30 | 31 | self.last = torch.nn.Linear(256, num_classes) 32 | self.relu = torch.nn.ReLU() 33 | 34 | def forward(self, x, get_inter=False, before_fc=False): 35 | act1 = self.relu(self.conv1(x)) 36 | act2 = self.relu(self.conv2(act1)) 37 | h = self.drop1(self.MaxPool(act2)) 38 | act3 = self.relu(self.conv3(h)) 39 | act4 = self.relu(self.conv4(act3)) 40 | h = self.drop1(self.MaxPool(act4)) 41 | act5 = self.relu(self.conv5(h)) 42 | act6 = self.relu(self.conv6(act5)) 43 | h = self.drop1(self.MaxPool(act6)) 44 | h = h.view(x.shape[0], -1) 45 | act7 = self.relu(self.fc1(h)) 46 | # h = self.drop2(act7) 47 | y=self.last(act7) 48 | 49 | if get_inter: 50 | if before_fc: 51 | return act6, y 52 | else: 53 | return act7, y 54 | else: 55 | return y 56 | 57 | 58 | def compute_conv_output_size(l_in, kernel_size, stride=1, padding=0, dilation=1): 59 | return int(np.floor((l_in + 2 * padding - dilation * (kernel_size - 1) - 1) / float(stride) + 1)) 60 | -------------------------------------------------------------------------------- /networks/mlp.py: 
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function

class MLP(nn.Module):
    def __init__(self, feature_size, hidden_dim, num_class=None, num_layer=2, adv=False, adv_lambda=1.):
        super(MLP, self).__init__()
        try:
            in_features = self.compute_input_size(feature_size)  # if list
        except TypeError:
            in_features = feature_size  # if int

        self.adv = adv
        if self.adv:
            self.adv_lambda = adv_lambda

        self.num_layer = num_layer

        fc = []
        in_dim = in_features
        for i in range(num_layer - 1):
            fc.append(nn.Linear(in_dim, hidden_dim))
            fc.append(nn.ReLU())
            in_dim = hidden_dim

        # final classification layer
        fc.append(nn.Linear(in_dim, num_class))
        self.fc = nn.Sequential(*fc)

    def forward(self, feature):
        feature = torch.flatten(feature, 1)
        if self.adv:
            feature = ReverseLayerF.apply(feature, self.adv_lambda)

        out = self.fc(feature)

        return out

    def compute_input_size(self, feature_size):
        in_features = 1
        for size in feature_size:
            in_features = in_features * size

        return in_features


class ReverseLayerF(Function):
    # identity on the forward pass; scaled gradient reversal on the backward pass

    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha

        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.neg() * ctx.alpha, None
--------------------------------------------------------------------------------
/networks/model_factory.py:
--------------------------------------------------------------------------------
import torch.nn as nn

from networks.resnet import resnet18
from networks.shufflenet import shufflenet_v2_x1_0
from networks.cifar_net import Net
from networks.mlp import MLP


class ModelFactory():
    def __init__(self):
        pass

    @staticmethod
    def get_model(target_model, num_classes, img_size, pretrained=False):

        if target_model == 'mlp':
            return MLP(feature_size=img_size, hidden_dim=40, num_class=num_classes)

        elif target_model == 'resnet':
            if pretrained:
                model = resnet18(pretrained=True)
                model.fc = nn.Linear(in_features=512, out_features=num_classes, bias=True)
            else:
                model = resnet18(pretrained=False, num_classes=num_classes)
            return model

        elif target_model == 'cifar_net':
            return Net(num_classes=num_classes)

        elif target_model == 'shufflenet':
            if pretrained:
                model = shufflenet_v2_x1_0(pretrained=True)
                model.fc = nn.Linear(in_features=1024, out_features=num_classes, bias=True)
            else:
                model = shufflenet_v2_x1_0(pretrained=False, num_classes=num_classes)
            return model

        else:
            raise NotImplementedError

--------------------------------------------------------------------------------
/networks/resnet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torchvision.models.utils import load_state_dict_from_url


__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50']


model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
}


def 
conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 17 | """3x3 convolution with padding""" 18 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 19 | padding=dilation, groups=groups, bias=False, dilation=dilation) 20 | 21 | 22 | def conv1x1(in_planes, out_planes, stride=1): 23 | """1x1 convolution""" 24 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 25 | 26 | 27 | class BasicBlock(nn.Module): 28 | expansion = 1 29 | __constants__ = ['downsample'] 30 | 31 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 32 | base_width=64, dilation=1, norm_layer=None): 33 | super(BasicBlock, self).__init__() 34 | if norm_layer is None: 35 | norm_layer = nn.BatchNorm2d 36 | if groups != 1 or base_width != 64: 37 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 38 | if dilation > 1: 39 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 40 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 41 | self.conv1 = conv3x3(inplanes, planes, stride) 42 | self.bn1 = norm_layer(planes) 43 | self.relu = nn.ReLU(inplace=True) 44 | self.conv2 = conv3x3(planes, planes) 45 | self.bn2 = norm_layer(planes) 46 | self.downsample = downsample 47 | self.stride = stride 48 | 49 | def forward(self, x): 50 | identity = x 51 | 52 | out = self.conv1(x) 53 | out = self.bn1(out) 54 | out = self.relu(out) 55 | 56 | out = self.conv2(out) 57 | out = self.bn2(out) 58 | 59 | if self.downsample is not None: 60 | identity = self.downsample(x) 61 | 62 | out += identity 63 | out = self.relu(out) 64 | 65 | return out 66 | 67 | 68 | class Bottleneck(nn.Module): 69 | expansion = 4 70 | __constants__ = ['downsample'] 71 | 72 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 73 | base_width=64, dilation=1, norm_layer=None): 74 | super(Bottleneck, self).__init__() 75 | if norm_layer is None: 76 | norm_layer = nn.BatchNorm2d 77 | width = int(planes * (base_width / 64.)) * groups 78 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 79 | self.conv1 = conv1x1(inplanes, width) 80 | self.bn1 = norm_layer(width) 81 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 82 | self.bn2 = norm_layer(width) 83 | self.conv3 = conv1x1(width, planes * self.expansion) 84 | self.bn3 = norm_layer(planes * self.expansion) 85 | self.relu = nn.ReLU(inplace=True) 86 | self.downsample = downsample 87 | self.stride = stride 88 | 89 | def forward(self, x): 90 | identity = x 91 | 92 | out = self.conv1(x) 93 | out = self.bn1(out) 94 | out = self.relu(out) 95 | 96 | out = self.conv2(out) 97 | out = self.bn2(out) 98 | out = self.relu(out) 99 | 100 | out = self.conv3(out) 101 | out = self.bn3(out) 102 | 103 | if self.downsample is not None: 104 | identity = self.downsample(x) 105 | 106 | out += identity 107 | out = self.relu(out) 108 | 109 | return out 110 | 111 | 112 | class ResNet(nn.Module): 113 | 114 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 115 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 116 | norm_layer=None): 117 | super(ResNet, self).__init__() 118 | if norm_layer is None: 119 | norm_layer = nn.BatchNorm2d 120 | self._norm_layer = norm_layer 121 | 122 | self.inplanes = 64 123 | self.dilation = 1 124 | if replace_stride_with_dilation is None: 125 | # each element in the tuple indicates if we should replace 126 | # the 2x2 stride with a dilated convolution instead 127 | 
replace_stride_with_dilation = [False, False, False] 128 | if len(replace_stride_with_dilation) != 3: 129 | raise ValueError("replace_stride_with_dilation should be None " 130 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 131 | self.groups = groups 132 | self.base_width = width_per_group 133 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 134 | bias=False) 135 | self.bn1 = norm_layer(self.inplanes) 136 | self.relu = nn.ReLU(inplace=True) 137 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 138 | self.layer1 = self._make_layer(block, 64, layers[0]) 139 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 140 | dilate=replace_stride_with_dilation[0]) 141 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 142 | dilate=replace_stride_with_dilation[1]) 143 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 144 | dilate=replace_stride_with_dilation[2]) 145 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 146 | self.fc = nn.Linear(512 * block.expansion, num_classes) 147 | 148 | 149 | for m in self.modules(): 150 | if isinstance(m, nn.Conv2d): 151 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 152 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 153 | nn.init.constant_(m.weight, 1) 154 | nn.init.constant_(m.bias, 0) 155 | 156 | # Zero-initialize the last BN in each residual branch, 157 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 158 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 159 | if zero_init_residual: 160 | for m in self.modules(): 161 | if isinstance(m, Bottleneck): 162 | nn.init.constant_(m.bn3.weight, 0) 163 | elif isinstance(m, BasicBlock): 164 | nn.init.constant_(m.bn2.weight, 0) 165 | 166 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 167 | norm_layer = self._norm_layer 168 | downsample = None 169 | previous_dilation = self.dilation 170 | if dilate: 171 | self.dilation *= stride 172 | stride = 1 173 | if stride != 1 or self.inplanes != planes * block.expansion: 174 | downsample = nn.Sequential( 175 | conv1x1(self.inplanes, planes * block.expansion, stride), 176 | norm_layer(planes * block.expansion), 177 | ) 178 | 179 | layers = [] 180 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 181 | self.base_width, previous_dilation, norm_layer)) 182 | self.inplanes = planes * block.expansion 183 | for _ in range(1, blocks): 184 | layers.append(block(self.inplanes, planes, groups=self.groups, 185 | base_width=self.base_width, dilation=self.dilation, 186 | norm_layer=norm_layer)) 187 | 188 | return nn.Sequential(*layers) 189 | 190 | def _forward_impl(self, x, get_inter=False): 191 | # See note [TorchScript super()] 192 | h = self.conv1(x) 193 | h = self.bn1(h) 194 | h = self.relu(h) 195 | h = self.maxpool(h) 196 | 197 | b1 = self.layer1(h) 198 | b2 = self.layer2(b1) 199 | b3 = self.layer3(b2) 200 | b4 = self.layer4(b3) 201 | 202 | h = self.avgpool(b4) 203 | h = torch.flatten(h, 1) 204 | h = self.fc(h) 205 | 206 | if get_inter: 207 | return b1, b2, b3, b4, h 208 | else: 209 | return h 210 | 211 | def forward(self, x, get_inter=False): 212 | return self._forward_impl(x, get_inter) 213 | 214 | 215 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 216 | model = ResNet(block, layers, **kwargs) 217 | if pretrained: 218 | state_dict = load_state_dict_from_url(model_urls[arch], 219 | 
progress=progress) 220 | model.load_state_dict(state_dict) 221 | return model 222 | 223 | 224 | def resnet18(pretrained=False, progress=True, **kwargs): 225 | r"""ResNet-18 model from 226 | `"Deep Residual Learning for Image Recognition" `_ 227 | Args: 228 | pretrained (bool): If True, returns a model pre-trained on ImageNet 229 | progress (bool): If True, displays a progress bar of the download to stderr 230 | """ 231 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 232 | **kwargs) 233 | 234 | 235 | def resnet34(pretrained=False, progress=True, **kwargs): 236 | r"""ResNet-34 model from 237 | `"Deep Residual Learning for Image Recognition" `_ 238 | Args: 239 | pretrained (bool): If True, returns a model pre-trained on ImageNet 240 | progress (bool): If True, displays a progress bar of the download to stderr 241 | """ 242 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 243 | **kwargs) 244 | 245 | 246 | def resnet50(pretrained=False, progress=True, **kwargs): 247 | r"""ResNet-50 model from 248 | `"Deep Residual Learning for Image Recognition" `_ 249 | Args: 250 | pretrained (bool): If True, returns a model pre-trained on ImageNet 251 | progress (bool): If True, displays a progress bar of the download to stderr 252 | """ 253 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 254 | **kwargs) 255 | -------------------------------------------------------------------------------- /networks/shufflenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision.models.utils import load_state_dict_from_url 4 | 5 | 6 | __all__ = [ 7 | 'ShuffleNetV2', 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0', 8 | 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0' 9 | ] 10 | 11 | model_urls = { 12 | 'shufflenetv2_x0.5': 'https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth', 13 | 'shufflenetv2_x1.0': 'https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth', 14 | 'shufflenetv2_x1.5': None, 15 | 'shufflenetv2_x2.0': None, 16 | } 17 | 18 | 19 | def channel_shuffle(x, groups): 20 | # type: (torch.Tensor, int) -> torch.Tensor 21 | batchsize, num_channels, height, width = x.data.size() 22 | channels_per_group = num_channels // groups 23 | 24 | # reshape 25 | x = x.view(batchsize, groups, 26 | channels_per_group, height, width) 27 | 28 | x = torch.transpose(x, 1, 2).contiguous() 29 | 30 | # flatten 31 | x = x.view(batchsize, -1, height, width) 32 | 33 | return x 34 | 35 | 36 | class InvertedResidual(nn.Module): 37 | def __init__(self, inp, oup, stride): 38 | super(InvertedResidual, self).__init__() 39 | 40 | if not (1 <= stride <= 3): 41 | raise ValueError('illegal stride value') 42 | self.stride = stride 43 | 44 | branch_features = oup // 2 45 | assert (self.stride != 1) or (inp == branch_features << 1) 46 | 47 | if self.stride > 1: 48 | self.branch1 = nn.Sequential( 49 | self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1), 50 | nn.BatchNorm2d(inp), 51 | nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False), 52 | nn.BatchNorm2d(branch_features), 53 | nn.ReLU(inplace=True), 54 | ) 55 | else: 56 | self.branch1 = nn.Sequential() 57 | 58 | self.branch2 = nn.Sequential( 59 | nn.Conv2d(inp if (self.stride > 1) else branch_features, 60 | branch_features, kernel_size=1, stride=1, padding=0, bias=False), 61 | nn.BatchNorm2d(branch_features), 62 | nn.ReLU(inplace=True), 63 | 
self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1), 64 | nn.BatchNorm2d(branch_features), 65 | nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False), 66 | nn.BatchNorm2d(branch_features), 67 | nn.ReLU(inplace=True), 68 | ) 69 | 70 | @staticmethod 71 | def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False): 72 | return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i) 73 | 74 | def forward(self, x): 75 | if self.stride == 1: 76 | x1, x2 = x.chunk(2, dim=1) 77 | out = torch.cat((x1, self.branch2(x2)), dim=1) 78 | else: 79 | out = torch.cat((self.branch1(x), self.branch2(x)), dim=1) 80 | 81 | out = channel_shuffle(out, 2) 82 | 83 | return out 84 | 85 | 86 | class ShuffleNetV2(nn.Module): 87 | def __init__(self, stages_repeats, stages_out_channels, num_classes=1000, inverted_residual=InvertedResidual): 88 | super(ShuffleNetV2, self).__init__() 89 | 90 | if len(stages_repeats) != 3: 91 | raise ValueError('expected stages_repeats as list of 3 positive ints') 92 | if len(stages_out_channels) != 5: 93 | raise ValueError('expected stages_out_channels as list of 5 positive ints') 94 | self._stage_out_channels = stages_out_channels 95 | 96 | input_channels = 3 97 | output_channels = self._stage_out_channels[0] 98 | self.conv1 = nn.Sequential( 99 | nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False), 100 | nn.BatchNorm2d(output_channels), 101 | nn.ReLU(inplace=True), 102 | ) 103 | input_channels = output_channels 104 | 105 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 106 | 107 | stage_names = ['stage{}'.format(i) for i in [2, 3, 4]] 108 | for name, repeats, output_channels in zip( 109 | stage_names, stages_repeats, self._stage_out_channels[1:]): 110 | seq = [inverted_residual(input_channels, output_channels, 2)] 111 | for i in range(repeats - 1): 112 | seq.append(inverted_residual(output_channels, output_channels, 1)) 113 | setattr(self, name, nn.Sequential(*seq)) 114 | input_channels = output_channels 115 | 116 | output_channels = self._stage_out_channels[-1] 117 | self.conv5 = nn.Sequential( 118 | nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False), 119 | nn.BatchNorm2d(output_channels), 120 | nn.ReLU(inplace=True), 121 | ) 122 | 123 | self.fc = nn.Linear(output_channels, num_classes) 124 | 125 | def _forward_impl(self, x, get_inter=False): 126 | # See note [TorchScript super()] 127 | x = self.conv1(x) 128 | x = self.maxpool(x) 129 | x = self.stage2(x) 130 | x = self.stage3(x) 131 | x = self.stage4(x) 132 | h = self.conv5(x) 133 | h1 = h.mean([2, 3]) # globalpool 134 | out = self.fc(h1) 135 | if get_inter: 136 | return h, out 137 | else: 138 | return out 139 | 140 | def forward(self, x, get_inter=False): 141 | return self._forward_impl(x, get_inter) 142 | 143 | 144 | def _shufflenetv2(arch, pretrained, progress, *args, **kwargs): 145 | model = ShuffleNetV2(*args, **kwargs) 146 | 147 | if pretrained: 148 | model_url = model_urls[arch] 149 | if model_url is None: 150 | raise NotImplementedError('pretrained {} is not supported as of now'.format(arch)) 151 | else: 152 | state_dict = load_state_dict_from_url(model_url, progress=progress) 153 | model.load_state_dict(state_dict) 154 | 155 | return model 156 | 157 | 158 | def shufflenet_v2_x0_5(pretrained=False, progress=True, **kwargs): 159 | """ 160 | Constructs a ShuffleNetV2 with 0.5x output channels, as described in 161 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture 
Design" 162 | `_. 163 | 164 | Args: 165 | pretrained (bool): If True, returns a model pre-trained on ImageNet 166 | progress (bool): If True, displays a progress bar of the download to stderr 167 | """ 168 | return _shufflenetv2('shufflenetv2_x0.5', pretrained, progress, 169 | [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs) 170 | 171 | 172 | 173 | def shufflenet_v2_x1_0(pretrained=False, progress=True, **kwargs): 174 | """ 175 | Constructs a ShuffleNetV2 with 1.0x output channels, as described in 176 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" 177 | `_. 178 | 179 | Args: 180 | pretrained (bool): If True, returns a model pre-trained on ImageNet 181 | progress (bool): If True, displays a progress bar of the download to stderr 182 | """ 183 | return _shufflenetv2('shufflenetv2_x1.0', pretrained, progress, 184 | [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs) 185 | 186 | 187 | 188 | def shufflenet_v2_x1_5(pretrained=False, progress=True, **kwargs): 189 | """ 190 | Constructs a ShuffleNetV2 with 1.5x output channels, as described in 191 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" 192 | `_. 193 | 194 | Args: 195 | pretrained (bool): If True, returns a model pre-trained on ImageNet 196 | progress (bool): If True, displays a progress bar of the download to stderr 197 | """ 198 | return _shufflenetv2('shufflenetv2_x1.5', pretrained, progress, 199 | [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs) 200 | 201 | 202 | 203 | def shufflenet_v2_x2_0(pretrained=False, progress=True, **kwargs): 204 | """ 205 | Constructs a ShuffleNetV2 with 2.0x output channels, as described in 206 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" 207 | `_. 208 | 209 | Args: 210 | pretrained (bool): If True, returns a model pre-trained on ImageNet 211 | progress (bool): If True, displays a progress bar of the download to stderr 212 | """ 213 | return _shufflenetv2('shufflenetv2_x2.0', pretrained, progress, 214 | [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs) 215 | -------------------------------------------------------------------------------- /trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from trainer.trainer_factory import * -------------------------------------------------------------------------------- /trainer/adv_debiasing.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | 7 | import time 8 | 9 | from utils import get_accuracy 10 | from networks.mlp import MLP 11 | from trainer.loss_utils import compute_hinton_loss, compute_feature_loss 12 | import trainer 13 | from torch.optim.lr_scheduler import ReduceLROnPlateau 14 | 15 | 16 | class Trainer(trainer.GenericTrainer): 17 | def __init__(self, args, teacher, **kwargs): 18 | super().__init__(args=args, **kwargs) 19 | self.lambh = args.lambh 20 | self.lambf = args.lambf 21 | self.kd_temp = args.kd_temp 22 | self.teacher = teacher 23 | 24 | self.adv_lambda = args.adv_lambda 25 | self.adv_lr = args.eta 26 | self.no_annealing = args.no_annealing 27 | 28 | def train(self, train_loader, test_loader, epochs): 29 | model = self.model 30 | model.train() 31 | num_groups = train_loader.dataset.num_groups 32 | num_classes = train_loader.dataset.num_classes 33 | if self.teacher is not None: 34 | self.teacher.eval() 35 | self._init_adversary(num_groups, num_classes) 36 | 
sa_clf_list = self.sa_clf_list 37 | 38 | for epoch in range(epochs): 39 | self._train_epoch(epoch, train_loader, model, sa_clf_list, self.teacher) 40 | 41 | eval_start_time = time.time() 42 | eval_loss, eval_acc, eval_adv_loss, eval_adv_acc, eval_deopp, eval_adv_loss_list = \ 43 | self.evaluate(model, sa_clf_list, test_loader, self.criterion, self.adv_criterion) 44 | eval_end_time = time.time() 45 | print('[{}/{}] Method: {} ' 46 | 'Test Loss: {:.3f} Test Acc: {:.2f} Test Adv Acc: {:.2f} Test DEopp {:.2f} [{:.2f} s]'.format 47 | (epoch + 1, epochs, self.method, 48 | eval_loss, eval_acc, eval_adv_acc, eval_deopp, (eval_end_time - eval_start_time))) 49 | 50 | if self.scheduler != None: 51 | self.scheduler.step(eval_loss) 52 | if len(self.adv_scheduler_list) != 0: 53 | for c in range(num_classes): 54 | self.adv_scheduler_list[c].step(eval_adv_loss_list[c]) 55 | 56 | print('Training Finished!') 57 | 58 | def _train_epoch(self, epoch, train_loader, model, sa_clf, teacher=None): 59 | num_classes = train_loader.dataset.num_classes 60 | 61 | model.train() 62 | if teacher is not None: 63 | teacher.eval() 64 | 65 | running_acc = 0.0 66 | running_loss = 0.0 67 | batch_start_time = time.time() 68 | 69 | for i, data in enumerate(train_loader): 70 | # Get the inputs 71 | inputs, _, groups, targets, _ = data 72 | labels = targets 73 | groups = groups.long() 74 | 75 | if self.cuda: 76 | inputs = inputs.cuda(device=self.device) 77 | labels = labels.cuda(device=self.device) 78 | groups = groups.cuda(device=self.device) 79 | 80 | kd_loss = 0 81 | feature_loss = 0 82 | if teacher is not None: 83 | t_inputs = inputs.to(self.t_device) 84 | feature_loss, outputs, tea_logits, _, _ = compute_feature_loss(inputs, t_inputs, model, teacher, 85 | device=self.device) 86 | kd_loss = compute_hinton_loss(outputs, t_outputs=tea_logits, 87 | kd_temp=self.kd_temp, device=self.device) if self.lambh != 0 else 0 88 | else: 89 | outputs = model(inputs) 90 | 91 | adv_loss = 0 92 | for c in range(num_classes): 93 | if sum(labels == c) == 0: 94 | continue 95 | adv_inputs = outputs[labels == c].clone() 96 | adv_preds = sa_clf[c](adv_inputs) 97 | adv_loss += self.adv_criterion(adv_preds, groups[labels==c]) 98 | 99 | loss = self.criterion(outputs, labels) 100 | loss = loss + self.lambh * kd_loss 101 | loss = loss + self.lambf * feature_loss 102 | 103 | running_loss += loss.item() 104 | running_acc += get_accuracy(outputs, labels) 105 | 106 | self.optimizer.zero_grad() 107 | for c in range(num_classes): 108 | self.adv_optimizer_list[c].zero_grad() 109 | 110 | loss = loss + adv_loss 111 | loss.backward() 112 | 113 | self.optimizer.step() 114 | for c in range(num_classes): 115 | self.adv_optimizer_list[c].step() 116 | 117 | if i % self.term == self.term - 1: # print every self.term mini-batches 118 | avg_batch_time = time.time() - batch_start_time 119 | print_statement = '[{}/{}, {:5d}] Method: {} Train Loss: {:.3f} Train Acc: {:.2f} [{:.2f} s/batch]'\ 120 | .format(epoch + 1, self.epochs, i + 1, self.method, running_loss / self.term, running_acc / self.term, 121 | avg_batch_time / self.term) 122 | print(print_statement) 123 | 124 | running_loss = 0.0 125 | running_acc = 0.0 126 | batch_start_time = time.time() 127 | 128 | if not self.no_annealing and teacher is not None: 129 | self.lambh = self.lambh - 3/(self.epochs-1) 130 | 131 | def evaluate(self, model, adversary, loader, criterion, adv_criterion, device=None): 132 | model.eval() 133 | num_groups = loader.dataset.num_groups 134 | num_classes = loader.dataset.num_classes 135 | device 
= self.device if device is None else device 136 | eval_acc = 0 137 | eval_adv_acc = 0 138 | eval_loss = 0 139 | eval_adv_loss = 0 140 | eval_adv_loss_list = torch.zeros(num_classes) 141 | eval_eopp_list = torch.zeros(num_groups, num_classes).cuda(device) 142 | eval_data_count = torch.zeros(num_groups, num_classes).cuda(device) 143 | 144 | if 'Custom' in type(loader).__name__: 145 | loader = loader.generate() 146 | with torch.no_grad(): 147 | for j, eval_data in enumerate(loader): 148 | # Get the inputs 149 | inputs, _, groups, classes, _ = eval_data 150 | # 151 | labels = classes 152 | groups = groups.long() 153 | if self.cuda: 154 | inputs = inputs.cuda(device) 155 | labels = labels.cuda(device) 156 | groups = groups.cuda(device) 157 | 158 | outputs = model(inputs) 159 | 160 | loss = criterion(outputs, labels) 161 | eval_loss += loss.item() 162 | eval_acc += get_accuracy(outputs, labels) 163 | preds = torch.argmax(outputs, 1) 164 | acc = (preds == labels).float().squeeze() 165 | for g in range(num_groups): 166 | for l in range(num_classes): 167 | eval_eopp_list[g, l] += acc[(groups == g) * (labels == l)].sum() 168 | eval_data_count[g, l] += torch.sum((groups == g) * (labels == l)) 169 | 170 | for c in range(num_classes): 171 | if sum(labels==c)==0: 172 | continue 173 | adv_preds = adversary[c](outputs[labels==c]) 174 | adv_loss = adv_criterion(adv_preds, groups[labels==c]) 175 | eval_adv_loss += adv_loss.item() 176 | eval_adv_loss_list[c] += adv_loss.item() 177 | # print(c, adv_preds.shape) 178 | eval_adv_acc += get_accuracy(adv_preds, groups[labels==c]) 179 | 180 | eval_loss = eval_loss / (j+1) 181 | eval_acc = eval_acc / (j+1) 182 | eval_adv_loss = eval_adv_loss / ((j+1) * num_classes) 183 | eval_adv_loss_list = eval_adv_loss_list / (j+1) 184 | eval_adv_acc = eval_adv_acc / ((j+1) * num_classes) 185 | eval_eopp_list = eval_eopp_list / eval_data_count 186 | eval_max_eopp = torch.max(eval_eopp_list, dim=0)[0] - torch.min(eval_eopp_list, dim=0)[0] 187 | eval_max_eopp = torch.max(eval_max_eopp).item() 188 | model.train() 189 | return eval_loss, eval_acc, eval_adv_loss, eval_adv_acc, eval_max_eopp, eval_adv_loss_list 190 | 191 | def _init_adversary(self, num_groups, num_classes): 192 | self.model.eval() 193 | self.sa_clf_list = [] 194 | self.adv_optimizer_list = [] 195 | self.adv_scheduler_list = [] 196 | for _ in range(num_classes): 197 | sa_clf = MLP(feature_size=num_classes, hidden_dim=32, num_class=num_groups, num_layer=2, 198 | adv=True, adv_lambda=self.adv_lambda) 199 | if self.cuda: 200 | sa_clf.cuda(device=self.device) 201 | sa_clf.train() 202 | self.sa_clf_list.append(sa_clf) 203 | adv_optimizer = optim.Adam(sa_clf.parameters(), lr=self.adv_lr) 204 | self.adv_optimizer_list.append(adv_optimizer) 205 | self.adv_scheduler_list.append(ReduceLROnPlateau(adv_optimizer)) 206 | 207 | self.adv_criterion = nn.CrossEntropyLoss() 208 | -------------------------------------------------------------------------------- /trainer/kd_at.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import time 3 | from utils import get_accuracy 4 | from trainer.loss_utils import compute_at_loss 5 | import trainer 6 | 7 | 8 | class Trainer(trainer.GenericTrainer): 9 | def __init__(self, args, **kwargs): 10 | super().__init__(args=args, **kwargs) 11 | 12 | self.model_type = args.model 13 | self.for_cifar = True if args.model == 'cifar_net' else False 14 | self.lambf = args.lambf 15 | 16 | def train(self, train_loader, test_loader, 
epochs): 17 | 18 | self.model.train() 19 | self.teacher.eval() 20 | 21 | for epoch in range(self.epochs): 22 | # train during one epoch 23 | self._train_epoch(epoch, train_loader, self.model, self.teacher) 24 | 25 | eval_start_time = time.time() 26 | eval_loss, eval_acc, eval_deopp = self.evaluate(self.model, test_loader, self.criterion) 27 | eval_end_time = time.time() 28 | print('[{}/{}] Method: {} ' 29 | 'Test Loss: {:.3f} Test Acc: {:.2f} Test DEopp {:.2f} [{:.2f} s]'.format 30 | (epoch + 1, epochs, self.method, 31 | eval_loss, eval_acc, eval_deopp, (eval_end_time - eval_start_time))) 32 | 33 | if self.scheduler != None: 34 | self.scheduler.step(eval_loss) 35 | 36 | print('Training Finished!') 37 | 38 | def _train_epoch(self, epoch, train_loader, model, teacher): 39 | model.train() 40 | teacher.eval() 41 | 42 | running_acc = 0.0 43 | running_loss = 0.0 44 | batch_start_time = time.time() 45 | for i, data in enumerate(train_loader): 46 | # Get the inputs 47 | inputs, _, _, targets, _ = data 48 | labels = targets 49 | 50 | if self.cuda: 51 | inputs = inputs.cuda(self.device) 52 | labels = labels.cuda(self.device) 53 | t_inputs = inputs.to(self.t_device) 54 | 55 | attention_loss, stu_outputs, tea_outputs, _, _ = compute_at_loss(inputs, t_inputs, model, teacher, 56 | self.device, for_cifar=self.for_cifar) 57 | loss = self.criterion(stu_outputs, labels) 58 | loss = loss + self.lambf * attention_loss 59 | 60 | running_loss += loss.item() 61 | running_acc += get_accuracy(stu_outputs, labels) 62 | 63 | self.optimizer.zero_grad() 64 | loss.backward() 65 | self.optimizer.step() 66 | 67 | if i % self.term == self.term-1: # print every self.term mini-batches 68 | avg_batch_time = time.time() - batch_start_time 69 | print('[{}/{}, {:5d}] Method: {} Train Loss: {:.3f} Train Acc: {:.2f} ' 70 | '[{:.2f} s/batch]'.format 71 | (epoch + 1, self.epochs, i+1, self.method, running_loss / self.term, running_acc / self.term, 72 | avg_batch_time/self.term)) 73 | 74 | running_loss = 0.0 75 | running_acc = 0.0 76 | batch_start_time = time.time() 77 | 78 | -------------------------------------------------------------------------------- /trainer/kd_fitnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import time 4 | from utils import get_accuracy 5 | from trainer.kd_hinton import Trainer as hinton_Trainer 6 | from trainer.loss_utils import compute_feature_loss, compute_hinton_loss 7 | 8 | 9 | class Trainer(hinton_Trainer): 10 | def __init__(self, args, **kwargs): 11 | super().__init__(args=args, **kwargs) 12 | self.model_type = args.model 13 | 14 | self.fitnet_simul = args.fitnet_simul 15 | 16 | def train(self, train_loader, test_loader, epochs): 17 | 18 | if not self.fitnet_simul: 19 | for epoch in range(int(self.epochs/2)): 20 | 21 | self._train_epoch_hint(epoch, train_loader, self.model, self.teacher) 22 | 23 | print('Hint Training Finished!') 24 | self.save_model(self.save_dir, self.log_name + '_hint') 25 | 26 | for epoch in range(self.epochs): 27 | self._train_epoch(epoch, train_loader, self.model, self.teacher) 28 | eval_start_time = time.time() 29 | eval_loss, eval_acc, eval_deopp = self.evaluate(self.model, test_loader, self.criterion) 30 | eval_end_time = time.time() 31 | print('[{}/{}] Method: {} ' 32 | 'Test Loss: {:.3f} Test Acc: {:.2f} Test DEopp {:.2f} [{:.2f} s]'.format 33 | (epoch + 1, epochs, self.method, 34 | eval_loss, eval_acc, eval_deopp, (eval_end_time - eval_start_time))) 35 | 36 | if self.scheduler != 
None: 37 | self.scheduler.step(eval_loss) 38 | 39 | print('Training Finished!') 40 | 41 | def _train_epoch_hint(self, epoch, train_loader, model, teacher): 42 | model.train() 43 | teacher.eval() 44 | 45 | running_loss = 0.0 46 | avg_batch_time = 0.0 47 | 48 | for i, data in enumerate(train_loader): 49 | batch_start_time = time.time() 50 | # Get the inputs 51 | inputs, _, groups, labels, _ = data 52 | 53 | if self.cuda: 54 | inputs = inputs.cuda(self.device) 55 | t_inputs = inputs.to(self.t_device) 56 | 57 | fitnet_loss, _, _, _, _ = compute_feature_loss(inputs, t_inputs, model, teacher, device=self.device) 58 | running_loss += fitnet_loss.item() 59 | 60 | self.optimizer.zero_grad() 61 | fitnet_loss.backward() 62 | self.optimizer.step() 63 | 64 | batch_end_time = time.time() 65 | avg_batch_time += batch_end_time - batch_start_time 66 | 67 | if i % self.term == self.term-1: # print every self.term mini-batches 68 | train_loss = running_loss / self.term 69 | print('[{}/{}, {:5d}] Method: {} FitNet Hint Train Loss: {:.3f} [{:.2f} s/batch]'.format 70 | (epoch + 1, int(self.epochs/2), i + 1, self.method, train_loss, avg_batch_time / self.term)) 71 | 72 | running_loss = 0.0 73 | avg_batch_time = 0.0 74 | 75 | def _train_epoch(self, epoch, train_loader, model, teacher, distiller=None): 76 | model.train() 77 | teacher.eval() 78 | 79 | running_acc = 0.0 80 | running_loss = 0.0 81 | batch_start_time = time.time() 82 | 83 | for i, data in enumerate(train_loader): 84 | # Get the inputs 85 | inputs, _, _, targets, _ = data 86 | labels = targets 87 | 88 | if self.cuda: 89 | inputs = inputs.cuda(self.device) 90 | labels = labels.cuda(self.device) 91 | t_inputs = inputs.to(self.t_device) 92 | 93 | if self.fitnet_simul: 94 | feature_loss, stu_logits, tea_logits, _, _ = compute_feature_loss(inputs, t_inputs, model, teacher, 95 | device=self.device) 96 | kd_loss = compute_hinton_loss(stu_logits, t_outputs=tea_logits, 97 | kd_temp=self.kd_temp, device=self.device) 98 | else: 99 | stu_logits = model(inputs) 100 | kd_loss = compute_hinton_loss(stu_logits, t_inputs=t_inputs, teacher=teacher, 101 | kd_temp=self.kd_temp, device=self.device) if self.lambh != 0 else 0 102 | feature_loss = 0 103 | 104 | loss = self.criterion(stu_logits, labels) 105 | 106 | loss = loss + self.lambh * kd_loss 107 | loss = loss + feature_loss if self.fitnet_simul else loss 108 | 109 | running_loss += loss.item() 110 | running_acc += get_accuracy(stu_logits, labels) 111 | 112 | self.optimizer.zero_grad() 113 | loss.backward() 114 | self.optimizer.step() 115 | 116 | if i % self.term == self.term - 1: # print every self.term mini-batches 117 | avg_batch_time = time.time() - batch_start_time 118 | print('[{}/{}, {:5d}] Method: {} Train Loss: {:.3f} Train Acc: {:.2f} ' 119 | '[{:.2f} s/batch]'.format 120 | (epoch + 1, self.epochs, i + 1, self.method, running_loss / self.term, running_acc / self.term, 121 | avg_batch_time / self.term)) 122 | 123 | running_loss = 0.0 124 | running_acc = 0.0 125 | batch_start_time = time.time() 126 | 127 | if not self.no_annealing: 128 | self.lambh = self.lambh - 3/(self.epochs-1) 129 | 130 | -------------------------------------------------------------------------------- /trainer/kd_hinton.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import time 3 | from utils import get_accuracy 4 | from trainer.loss_utils import compute_hinton_loss 5 | import trainer 6 | 7 | 8 | class Trainer(trainer.GenericTrainer): 9 | def __init__(self, 
args, **kwargs): 10 | super().__init__(args=args, **kwargs) 11 | self.lambh = args.lambh 12 | self.kd_temp = args.kd_temp 13 | self.seed = args.seed 14 | self.no_annealing = args.no_annealing 15 | 16 | def train(self, train_loader, test_loader, epochs): 17 | 18 | self.model.train() 19 | self.teacher.eval() 20 | 21 | for epoch in range(self.epochs): 22 | self._train_epoch(epoch, train_loader, self.model, self.teacher) 23 | 24 | eval_start_time = time.time() 25 | eval_loss, eval_acc, eval_deopp = self.evaluate(self.model, test_loader, self.criterion) 26 | eval_end_time = time.time() 27 | print('[{}/{}] Method: {} ' 28 | 'Test Loss: {:.3f} Test Acc: {:.2f} Test DEopp {:.2f} [{:.2f} s]'.format 29 | (epoch + 1, epochs, self.method, 30 | eval_loss, eval_acc, eval_deopp, (eval_end_time - eval_start_time))) 31 | 32 | if self.scheduler != None: 33 | self.scheduler.step(eval_loss) 34 | 35 | print('Training Finished!') 36 | 37 | def _train_epoch(self, epoch, train_loader, model, teacher, distiller=None): 38 | 39 | model.train() 40 | teacher.eval() 41 | 42 | running_acc = 0.0 43 | running_loss = 0.0 44 | 45 | batch_start_time = time.time() 46 | for i, data in enumerate(train_loader): 47 | # Get the inputs 48 | inputs, _, groups, targets, _ = data 49 | labels = targets 50 | 51 | if self.cuda: 52 | inputs = inputs.cuda(self.device) 53 | labels = labels.cuda(self.device) 54 | t_inputs = inputs.to(self.t_device) 55 | 56 | outputs = model(inputs) 57 | t_outputs = teacher(t_inputs) 58 | kd_loss = compute_hinton_loss(outputs, t_outputs, kd_temp=self.kd_temp, device=self.device) 59 | 60 | loss = self.criterion(outputs, labels) + self.lambh * kd_loss 61 | 62 | running_loss += loss.item() 63 | running_acc += get_accuracy(outputs, labels) 64 | 65 | self.optimizer.zero_grad() 66 | loss.backward() 67 | self.optimizer.step() 68 | 69 | if i % self.term == self.term-1: # print every self.term mini-batches 70 | avg_batch_time = time.time() - batch_start_time 71 | 72 | print('[{}/{}, {:5d}] Method: {} Train Loss: {:.3f} Train Acc: {:.2f} ' 73 | '[{:.2f} s/batch]'.format 74 | (epoch + 1, self.epochs, i+1, self.method, running_loss / self.term, running_acc / self.term, 75 | avg_batch_time/self.term)) 76 | 77 | running_loss = 0.0 78 | running_acc = 0.0 79 | batch_start_time = time.time() 80 | 81 | if not self.no_annealing: 82 | self.lambh = self.lambh - 3/(self.epochs-1) 83 | -------------------------------------------------------------------------------- /trainer/kd_mfd.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | import torch.nn as nn 6 | import time 7 | from utils import get_accuracy 8 | from trainer.kd_hinton import Trainer as hinton_Trainer 9 | from trainer.loss_utils import compute_hinton_loss 10 | 11 | 12 | class Trainer(hinton_Trainer): 13 | def __init__(self, args, **kwargs): 14 | super().__init__(args=args, **kwargs) 15 | self.lambh = args.lambh 16 | self.lambf = args.lambf 17 | self.sigma = args.sigma 18 | self.kernel = args.kernel 19 | self.jointfeature = args.jointfeature 20 | 21 | def train(self, train_loader, test_loader, epochs): 22 | 23 | num_classes = train_loader.dataset.num_classes 24 | num_groups = train_loader.dataset.num_groups 25 | 26 | distiller = MMDLoss(w_m=self.lambf, sigma=self.sigma, 27 | num_classes=num_classes, num_groups=num_groups, kernel=self.kernel) 28 | 29 | for epoch in range(self.epochs): 30 | self._train_epoch(epoch, train_loader, self.model, 
self.teacher, distiller=distiller) 31 | eval_start_time = time.time() 32 | eval_loss, eval_acc, eval_deopp = self.evaluate(self.model, test_loader, self.criterion) 33 | eval_end_time = time.time() 34 | print('[{}/{}] Method: {} ' 35 | 'Test Loss: {:.3f} Test Acc: {:.2f} Test DEopp {:.2f} [{:.2f} s]'.format 36 | (epoch + 1, epochs, self.method, 37 | eval_loss, eval_acc, eval_deopp, (eval_end_time - eval_start_time))) 38 | 39 | if self.scheduler != None: 40 | self.scheduler.step(eval_loss) 41 | 42 | print('Training Finished!') 43 | 44 | def _train_epoch(self, epoch, train_loader, model, teacher, distiller=None): 45 | model.train() 46 | teacher.eval() 47 | 48 | running_acc = 0.0 49 | running_loss = 0.0 50 | batch_start_time = time.time() 51 | 52 | for i, data in enumerate(train_loader): 53 | # Get the inputs 54 | inputs, _, groups, targets, _ = data 55 | labels = targets 56 | 57 | if self.cuda: 58 | inputs = inputs.cuda(self.device) 59 | labels = labels.cuda(self.device) 60 | groups = groups.long().cuda(self.device) 61 | t_inputs = inputs.to(self.t_device) 62 | 63 | outputs = model(inputs, get_inter=True) 64 | stu_logits = outputs[-1] 65 | 66 | t_outputs = teacher(t_inputs, get_inter=True) 67 | tea_logits = t_outputs[-1] 68 | 69 | kd_loss = compute_hinton_loss(stu_logits, t_outputs=tea_logits, 70 | kd_temp=self.kd_temp, device=self.device) if self.lambh != 0 else 0 71 | 72 | loss = self.criterion(stu_logits, labels) 73 | loss = loss + self.lambh * kd_loss 74 | 75 | 76 | f_s = outputs[-2] 77 | f_t = t_outputs[-2] 78 | mmd_loss = distiller.forward(f_s, f_t, groups=groups, labels=labels, jointfeature=self.jointfeature) 79 | 80 | loss = loss + mmd_loss 81 | running_loss += loss.item() 82 | running_acc += get_accuracy(stu_logits, labels) 83 | 84 | self.optimizer.zero_grad() 85 | loss.backward() 86 | self.optimizer.step() 87 | 88 | if i % self.term == self.term - 1: # print every self.term mini-batches 89 | avg_batch_time = time.time() - batch_start_time 90 | print('[{}/{}, {:5d}] Method: {} Train Loss: {:.3f} Train Acc: {:.2f} ' 91 | '[{:.2f} s/batch]'.format 92 | (epoch + 1, self.epochs, i + 1, self.method, running_loss / self.term, running_acc / self.term, 93 | avg_batch_time / self.term)) 94 | 95 | running_loss = 0.0 96 | running_acc = 0.0 97 | batch_start_time = time.time() 98 | 99 | if not self.no_annealing: 100 | self.lambh = self.lambh - 3 / (self.epochs - 1) 101 | 102 | 103 | class MMDLoss(nn.Module): 104 | def __init__(self, w_m, sigma, num_groups, num_classes, kernel): 105 | super(MMDLoss, self).__init__() 106 | self.w_m = w_m 107 | self.sigma = sigma 108 | self.num_groups = num_groups 109 | self.num_classes = num_classes 110 | self.kernel = kernel 111 | 112 | def forward(self, f_s, f_t, groups, labels, jointfeature=False): 113 | if self.kernel == 'poly': 114 | student = F.normalize(f_s.view(f_s.shape[0], -1), dim=1) 115 | teacher = F.normalize(f_t.view(f_t.shape[0], -1), dim=1).detach() 116 | else: 117 | student = f_s.view(f_s.shape[0], -1) 118 | teacher = f_t.view(f_t.shape[0], -1).detach() 119 | 120 | mmd_loss = 0 121 | 122 | if jointfeature: 123 | K_TS, sigma_avg = self.pdist(teacher, student, 124 | sigma_base=self.sigma, kernel=self.kernel) 125 | K_TT, _ = self.pdist(teacher, teacher, sigma_base=self.sigma, sigma_avg=sigma_avg, kernel=self.kernel) 126 | K_SS, _ = self.pdist(student, student, 127 | sigma_base=self.sigma, sigma_avg=sigma_avg, kernel=self.kernel) 128 | 129 | mmd_loss += K_TT.mean() + K_SS.mean() - 2 * K_TS.mean() 130 | 131 | else: 132 | with torch.no_grad(): 133 | _, 
sigma_avg = self.pdist(teacher, student, sigma_base=self.sigma, kernel=self.kernel) 134 | 135 | for c in range(self.num_classes): 136 | if len(teacher[labels==c]) == 0: 137 | continue 138 | for g in range(self.num_groups): 139 | if len(student[(labels==c) * (groups == g)]) == 0: 140 | continue 141 | K_TS, _ = self.pdist(teacher[labels == c], student[(labels == c) * (groups == g)], 142 | sigma_base=self.sigma, sigma_avg=sigma_avg, kernel=self.kernel) 143 | K_SS, _ = self.pdist(student[(labels == c) * (groups == g)], student[(labels == c) * (groups == g)], 144 | sigma_base=self.sigma, sigma_avg=sigma_avg, kernel=self.kernel) 145 | 146 | K_TT, _ = self.pdist(teacher[labels == c], teacher[labels == c], sigma_base=self.sigma, 147 | sigma_avg=sigma_avg, kernel=self.kernel) 148 | 149 | mmd_loss += K_TT.mean() + K_SS.mean() - 2 * K_TS.mean() 150 | 151 | loss = (1/2) * self.w_m * mmd_loss 152 | 153 | return loss 154 | 155 | @staticmethod 156 | def pdist(e1, e2, eps=1e-12, kernel='rbf', sigma_base=1.0, sigma_avg=None): 157 | if len(e1) == 0 or len(e2) == 0: 158 | res = torch.zeros(1) 159 | else: 160 | if kernel == 'rbf': 161 | e1_square = e1.pow(2).sum(dim=1) 162 | e2_square = e2.pow(2).sum(dim=1) 163 | prod = e1 @ e2.t() 164 | res = (e1_square.unsqueeze(1) + e2_square.unsqueeze(0) - 2 * prod).clamp(min=eps) 165 | res = res.clone() 166 | sigma_avg = res.mean().detach() if sigma_avg is None else sigma_avg 167 | res = torch.exp(-res / (2*(sigma_base)*sigma_avg)) 168 | elif kernel == 'poly': 169 | res = torch.matmul(e1, e2.t()).pow(2) 170 | 171 | return res, sigma_avg 172 | -------------------------------------------------------------------------------- /trainer/kd_nst.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch.nn.functional as F 4 | import torch.nn as nn 5 | import time 6 | from utils import get_accuracy 7 | import trainer 8 | 9 | 10 | class Trainer(trainer.GenericTrainer): 11 | def __init__(self, args, **kwargs): 12 | super().__init__(args=args, **kwargs) 13 | self.model_type = args.model 14 | self.lambf = args.lambf 15 | self.kernel = args.kernel 16 | 17 | def train(self, train_loader, test_loader, epochs): 18 | 19 | distiller = NST(lamb=self.lambf) 20 | 21 | for epoch in range(self.epochs): 22 | 23 | self._train_epoch(epoch, train_loader, self.model, self.teacher, distiller) 24 | 25 | eval_start_time = time.time() 26 | eval_loss, eval_acc, eval_deopp = self.evaluate(self.model, test_loader, self.criterion) 27 | eval_end_time = time.time() 28 | print('[{}/{}] Method: {} ' 29 | 'Test Loss: {:.3f} Test Acc: {:.2f} Test DEopp {:.2f} [{:.2f} s]'.format 30 | (epoch + 1, epochs, self.method, 31 | eval_loss, eval_acc, eval_deopp, (eval_end_time - eval_start_time))) 32 | 33 | if self.scheduler != None: 34 | self.scheduler.step(eval_loss) 35 | 36 | print('Training Finished!') 37 | 38 | def _train_epoch(self, epoch, train_loader, model, teacher, distiller): 39 | model.train() 40 | teacher.eval() 41 | 42 | running_acc = 0.0 43 | running_loss = 0.0 44 | batch_start_time = time.time() 45 | 46 | for i, data in enumerate(train_loader): 47 | # Get the inputs 48 | inputs, _, groups, targets, _ = data 49 | labels = targets 50 | 51 | if self.cuda: 52 | inputs = inputs.cuda(self.device) 53 | labels = labels.cuda(self.device) 54 | t_inputs = inputs.to(self.t_device) 55 | 56 | outputs = model(inputs, get_inter=True) 57 | stu_logits = outputs[-1] 58 | f_s = outputs[-2] 59 | 60 | t_outputs = teacher(t_inputs, 
get_inter=True) 61 | f_t = t_outputs[-2] 62 | 63 | loss = self.criterion(stu_logits, labels) 64 | mmd_loss = distiller.forward(f_s, f_t) 65 | 66 | loss = loss + mmd_loss 67 | 68 | running_loss += loss.item() 69 | running_acc += get_accuracy(stu_logits, labels) 70 | 71 | # zero the parameter gradients + backward + optimize 72 | self.optimizer.zero_grad() 73 | loss.backward() 74 | self.optimizer.step() 75 | 76 | if i % self.term == self.term - 1: # print every self.term mini-batches 77 | avg_batch_time = time.time() - batch_start_time 78 | print('[{}/{}, {:5d}] Method: {} Train Loss: {:.3f} Train Acc: {:.2f} ' 79 | '[{:.2f} s/batch]'.format 80 | (epoch + 1, self.epochs, i + 1, self.method, running_loss / self.term, running_acc / self.term, 81 | avg_batch_time / self.term)) 82 | 83 | running_loss = 0.0 84 | running_acc = 0.0 85 | batch_start_time = time.time() 86 | 87 | 88 | class NST(nn.Module): 89 | def __init__(self, lamb): 90 | super(NST, self).__init__() 91 | self.lamb = lamb 92 | 93 | def forward(self, fm_s, fm_t): 94 | fm_s = fm_s.view(fm_s.size(0), fm_s.size(1), -1) 95 | fm_s = F.normalize(fm_s, dim=2) 96 | 97 | fm_t = fm_t.view(fm_t.size(0), fm_t.size(1), -1) 98 | fm_t = F.normalize(fm_t, dim=2) 99 | 100 | loss = self.poly_kernel(fm_t, fm_t).mean() + self.poly_kernel(fm_s, fm_s).mean() \ 101 | - 2 * self.poly_kernel(fm_t, fm_s).mean() 102 | 103 | loss = self.lamb * loss 104 | return loss 105 | 106 | def poly_kernel(self, fm1, fm2): 107 | fm1 = fm1.unsqueeze(1) 108 | fm2 = fm2.unsqueeze(2) 109 | out = (fm1 * fm2).sum(-1).pow(2) 110 | 111 | return out 112 | -------------------------------------------------------------------------------- /trainer/loss_utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | def mse(inputs, targets): 6 | return (inputs - targets).pow(2).mean() 7 | 8 | 9 | def compute_feature_loss(inputs, t_inputs, student, teacher, device=0): 10 | stu_outputs = student(inputs, get_inter=True) 11 | f_s = stu_outputs[-2] 12 | f_s = f_s.view(f_s.shape[0], -1) 13 | stu_logits = stu_outputs[-1] 14 | 15 | tea_outputs = teacher(t_inputs, get_inter=True) 16 | f_t = tea_outputs[-2].to(device) 17 | f_t = f_t.view(f_t.shape[0], -1).detach() 18 | 19 | tea_logits = tea_outputs[-1] 20 | 21 | fitnet_loss = (1 / 2) * (mse(f_s, f_t)) 22 | 23 | return fitnet_loss, stu_logits, tea_logits, f_s, f_t 24 | 25 | 26 | def compute_hinton_loss(outputs, t_outputs=None, teacher=None, t_inputs=None, kd_temp=3, device=0): 27 | if t_outputs is None: 28 | if (t_inputs is not None and teacher is not None): 29 | t_outputs = teacher(t_inputs) 30 | else: 31 | raise Exception('Nothing is given to compute hinton loss') 32 | 33 | soft_label = F.softmax(t_outputs / kd_temp, dim=1).to(device).detach() 34 | kd_loss = nn.KLDivLoss(reduction='batchmean')(F.log_softmax(outputs / kd_temp, dim=1), 35 | soft_label) * (kd_temp * kd_temp) 36 | 37 | return kd_loss 38 | 39 | 40 | def compute_at_loss(inputs, t_inputs, student, teacher, device=0, for_cifar=False): 41 | stu_outputs = student(inputs, get_inter=True) if not for_cifar else student(inputs, get_inter=True, before_fc=True) 42 | stu_logits = stu_outputs[-1] 43 | f_s = stu_outputs[-2] 44 | 45 | tea_outputs = teacher(t_inputs, get_inter=True) if not for_cifar else teacher(t_inputs, get_inter=True, before_fc=True) # teacher always consumes t_inputs, the copy on t_device 46 | tea_logits = tea_outputs[-1].to(device) 47 | f_t = tea_outputs[-2].to(device) 48 | attention_loss = (1 / 2) * (at_loss(f_s, f_t)) 49 | return
attention_loss, stu_logits, tea_logits, f_s, f_t 50 | 51 | 52 | def at(x): 53 | return F.normalize(x.pow(2).mean(1).view(x.size(0), -1)) 54 | 55 | 56 | def at_loss(x, y): 57 | return (at(x) - at(y)).pow(2).mean() 58 | -------------------------------------------------------------------------------- /trainer/scratch_mmd.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import time 4 | import numpy as np 5 | from utils import get_accuracy 6 | import trainer 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class Trainer(trainer.GenericTrainer): 13 | def __init__(self, args, **kwargs): 14 | super().__init__(args=args, **kwargs) 15 | 16 | self.lambf = args.lambf 17 | self.sigma = args.sigma 18 | self.kernel = args.kernel 19 | 20 | def train(self, train_loader, test_loader, epochs): 21 | model = self.model 22 | model.train() 23 | 24 | num_classes = train_loader.dataset.num_classes 25 | num_groups = train_loader.dataset.num_groups 26 | 27 | distiller = MMDLoss(w_m=self.lambf, sigma=self.sigma, 28 | num_classes=num_classes, num_groups=num_groups, kernel=self.kernel) 29 | 30 | for epoch in range(epochs): 31 | self._train_epoch(epoch, train_loader, model, distiller=distiller) 32 | 33 | eval_start_time = time.time() 34 | eval_loss, eval_acc, eval_deopp = self.evaluate(model, test_loader, self.criterion) 35 | eval_end_time = time.time() 36 | print('[{}/{}] Method: {} ' 37 | 'Test Loss: {:.3f} Test Acc: {:.2f} Test DEopp {:.2f} [{:.2f} s]'.format 38 | (epoch + 1, epochs, self.method, 39 | eval_loss, eval_acc, eval_deopp, (eval_end_time - eval_start_time))) 40 | 41 | if self.scheduler != None and 'Multi' not in type(self.scheduler).__name__: 42 | self.scheduler.step(eval_loss) 43 | else: 44 | self.scheduler.step() 45 | 46 | print('Training Finished!') 47 | 48 | def _train_epoch(self, epoch, train_loader, model, distiller): 49 | 50 | model.train() 51 | 52 | running_acc = 0.0 53 | running_loss = 0.0 54 | batch_start_time = time.time() 55 | for i, data in enumerate(train_loader): 56 | # Get the inputs 57 | inputs, _, groups, targets, _ = data 58 | 59 | labels = targets 60 | 61 | if self.cuda: 62 | inputs = inputs.cuda(device=self.device) 63 | labels = labels.cuda(device=self.device) 64 | groups = groups.long().cuda(device=self.device) 65 | 66 | outputs = model(inputs, get_inter=True) 67 | f_s = outputs[-2] 68 | loss = self.criterion(outputs[-1], labels) 69 | 70 | mmd_loss = distiller.forward(f_s, groups=groups, labels=labels) 71 | loss = loss + mmd_loss 72 | 73 | running_loss += loss.item() 74 | running_acc += get_accuracy(outputs[-1], labels) 75 | 76 | self.optimizer.zero_grad() 77 | loss.backward() 78 | self.optimizer.step() 79 | 80 | if i % self.term == self.term - 1: # print every self.term mini-batches 81 | avg_batch_time = time.time() - batch_start_time 82 | print('[{}/{}, {:5d}] Method: {} Train Loss: {:.3f} Train Acc: {:.2f} ' 83 | '[{:.2f} s/batch]'.format 84 | (epoch + 1, self.epochs, i + 1, self.method, running_loss / self.term, running_acc / self.term, 85 | avg_batch_time / self.term)) 86 | 87 | running_loss = 0.0 88 | running_acc = 0.0 89 | batch_start_time = time.time() 90 | 91 | 92 | class MMDLoss(nn.Module): 93 | def __init__(self, w_m, sigma, num_groups, num_classes, kernel): 94 | super(MMDLoss, self).__init__() 95 | self.w_m = w_m 96 | self.sigma = sigma 97 | self.num_groups = num_groups 98 | self.num_classes = num_classes 99 | self.kernel = kernel 100 | 101 | 
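    # forward() below computes, for every class c, a clamped MMD^2 between the
    # detached features of the whole class ("target_joint") and the features of each
    # sensitive group inside that class, pulling every group-conditional feature
    # distribution toward the class-wide one.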
def forward(self, f_s, groups, labels): 102 | if self.kernel == 'poly': 103 | student = F.normalize(f_s.view(f_s.shape[0], -1), dim=1) 104 | else: 105 | student = f_s.view(f_s.shape[0], -1) 106 | 107 | mmd_loss = 0 108 | 109 | for c in range(self.num_classes): 110 | 111 | target_joint = student[labels == c].clone().detach() 112 | 113 | for g in range(self.num_groups): 114 | if len(student[(labels == c) * (groups == g)]) == 0: 115 | continue 116 | 117 | K_SSg, sigma_avg = self.pdist(target_joint, student[(labels == c) * (groups == g)], 118 | sigma_base=self.sigma, kernel=self.kernel) 119 | 120 | K_SgSg, _ = self.pdist(student[(labels==c) * (groups==g)], student[(labels==c) * (groups==g)], 121 | sigma_base=self.sigma, sigma_avg=sigma_avg, kernel=self.kernel) 122 | 123 | K_SS, _ = self.pdist(target_joint, target_joint, 124 | sigma_base=self.sigma, sigma_avg=sigma_avg, kernel=self.kernel) 125 | 126 | mmd_loss += torch.clamp(K_SS.mean() + K_SgSg.mean() - 2 * K_SSg.mean(), 0.0, np.inf).mean() 127 | 128 | loss = self.w_m * mmd_loss / (2*self.num_groups) 129 | 130 | return loss 131 | 132 | @staticmethod 133 | def pdist(e1, e2, eps=1e-12, kernel='rbf', sigma_base=1.0, sigma_avg=None): 134 | if len(e1) == 0 or len(e2) == 0: 135 | res = torch.zeros(1) 136 | else: 137 | if kernel == 'rbf': 138 | e1_square = e1.pow(2).sum(dim=1) 139 | e2_square = e2.pow(2).sum(dim=1) 140 | prod = e1 @ e2.t() 141 | res = (e1_square.unsqueeze(1) + e2_square.unsqueeze(0) - 2 * prod).clamp(min=eps) 142 | res = res.clone() 143 | 144 | sigma_avg = res.mean().detach() if sigma_avg is None else sigma_avg 145 | res = torch.exp(-res / (2*(sigma_base**2)*sigma_avg)) 146 | elif kernel == 'poly': 147 | res = torch.matmul(e1, e2.t()).pow(2) 148 | return res, sigma_avg 149 | -------------------------------------------------------------------------------- /trainer/trainer_factory.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | import torch.nn as nn 5 | from torch.optim.lr_scheduler import ReduceLROnPlateau, MultiStepLR 6 | from sklearn.metrics import confusion_matrix 7 | from utils import make_log_name 8 | 9 | 10 | class TrainerFactory: 11 | def __init__(self): 12 | pass 13 | 14 | @staticmethod 15 | def get_trainer(method, **kwargs): 16 | if method == 'scratch': 17 | import trainer.vanilla_train as trainer 18 | elif method == 'kd_hinton': 19 | import trainer.kd_hinton as trainer 20 | elif method == 'kd_fitnet': 21 | import trainer.kd_fitnet as trainer 22 | elif method == 'kd_at': 23 | import trainer.kd_at as trainer 24 | elif method == 'kd_nst': 25 | import trainer.kd_nst as trainer 26 | elif method == 'kd_mfd': 27 | import trainer.kd_mfd as trainer 28 | elif method == 'scratch_mmd': 29 | import trainer.scratch_mmd as trainer 30 | elif method == 'adv_debiasing': 31 | import trainer.adv_debiasing as trainer 32 | else: 33 | raise Exception('Not allowed method') 34 | return trainer.Trainer(**kwargs) 35 | 36 | 37 | class GenericTrainer: 38 | ''' 39 | Base class for trainer; to implement a new training routine, inherit from this. 
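    Subclasses override train() and _train_epoch(); this base class wires up the
    LR scheduler and provides evaluate(), which returns test loss, accuracy, and
    DEopp (the largest class-wise gap between the best and worst group-conditional
    accuracy), as well as save_model() and confusion-matrix export.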
40 | ''' 41 | def __init__(self, model, args, optimizer, teacher=None): 42 | self.get_inter = args.get_inter 43 | 44 | self.cuda = args.cuda 45 | self.device = args.device 46 | self.t_device = args.t_device 47 | self.term = args.term 48 | self.lr = args.lr 49 | self.parallel = args.parallel 50 | self.epochs = args.epochs 51 | self.method = args.method 52 | self.model = model 53 | self.teacher = teacher 54 | self.optimizer = optimizer 55 | self.optim_type = args.optimizer 56 | self.img_size = args.img_size if not 'cifar10' in args.dataset else 32 57 | self.criterion=torch.nn.CrossEntropyLoss() 58 | self.scheduler = None 59 | 60 | self.log_name = make_log_name(args) 61 | self.log_dir = os.path.join(args.log_dir, args.date, args.dataset, args.method) 62 | self.save_dir = os.path.join(args.save_dir, args.date, args.dataset, args.method) 63 | 64 | if self.optim_type == 'Adam' and self.optimizer is not None: 65 | self.scheduler = ReduceLROnPlateau(self.optimizer) 66 | else: 67 | self.scheduler = MultiStepLR(self.optimizer, [30, 60, 90], gamma=0.1) 68 | 69 | def evaluate(self, model, loader, criterion, device=None, groupwise=False): 70 | model.eval() 71 | num_groups = loader.dataset.num_groups 72 | num_classes = loader.dataset.num_classes 73 | device = self.device if device is None else device 74 | 75 | eval_acc = 0 if not groupwise else torch.zeros(num_groups, num_classes).cuda(device) 76 | eval_loss = 0 if not groupwise else torch.zeros(num_groups, num_classes).cuda(device) 77 | eval_eopp_list = torch.zeros(num_groups, num_classes).cuda(device) 78 | eval_data_count = torch.zeros(num_groups, num_classes).cuda(device) 79 | 80 | if 'Custom' in type(loader).__name__: 81 | loader = loader.generate() 82 | with torch.no_grad(): 83 | for j, eval_data in enumerate(loader): 84 | # Get the inputs 85 | inputs, _, groups, classes, _ = eval_data 86 | # 87 | labels = classes 88 | if self.cuda: 89 | inputs = inputs.cuda(device) 90 | labels = labels.cuda(device) 91 | groups = groups.cuda(device) 92 | 93 | outputs = model(inputs) 94 | 95 | if groupwise: 96 | if self.cuda: 97 | groups = groups.cuda(device) 98 | loss = nn.CrossEntropyLoss(reduction='none')(outputs, labels) 99 | preds = torch.argmax(outputs, 1) 100 | acc = (preds == labels).float().squeeze() 101 | for g in range(num_groups): 102 | for l in range(num_classes): 103 | eval_loss[g, l] += loss[(groups == g) * (labels == l)].sum() 104 | eval_acc[g, l] += acc[(groups == g) * (labels == l)].sum() 105 | eval_data_count[g, l] += torch.sum((groups == g) * (labels == l)) 106 | 107 | else: 108 | loss = criterion(outputs, labels) 109 | eval_loss += loss.item() * len(labels) 110 | preds = torch.argmax(outputs, 1) 111 | acc = (preds == labels).float().squeeze() 112 | eval_acc += acc.sum() 113 | 114 | for g in range(num_groups): 115 | for l in range(num_classes): 116 | eval_eopp_list[g, l] += acc[(groups == g) * (labels == l)].sum() 117 | eval_data_count[g, l] += torch.sum((groups == g) * (labels == l)) 118 | 119 | eval_loss = eval_loss / eval_data_count.sum() if not groupwise else eval_loss / eval_data_count 120 | eval_acc = eval_acc / eval_data_count.sum() if not groupwise else eval_acc / eval_data_count 121 | eval_eopp_list = eval_eopp_list / eval_data_count 122 | eval_max_eopp = torch.max(eval_eopp_list, dim=0)[0] - torch.min(eval_eopp_list, dim=0)[0] 123 | eval_max_eopp = torch.max(eval_max_eopp).item() 124 | model.train() 125 | return eval_loss, eval_acc, eval_max_eopp 126 | 127 | def save_model(self, save_dir, log_name="", model=None): 128 | model_to_save = 
self.model if model is None else model 129 | model_savepath = os.path.join(save_dir, log_name + '.pt') 130 | torch.save(model_to_save.state_dict(), model_savepath) 131 | 132 | print('Model saved to %s' % model_savepath) 133 | 134 | def compute_confusion_matix(self, dataset='test', num_classes=2, 135 | dataloader=None, log_dir="", log_name=""): 136 | from scipy.io import savemat 137 | from collections import defaultdict 138 | self.model.eval() 139 | confu_mat = defaultdict(lambda: np.zeros((num_classes, num_classes))) 140 | print('# of {} data : {}'.format(dataset, len(dataloader.dataset))) 141 | 142 | predict_mat = {} 143 | output_set = torch.tensor([]) 144 | group_set = torch.tensor([], dtype=torch.long) 145 | target_set = torch.tensor([], dtype=torch.long) 146 | intermediate_feature_set = torch.tensor([]) 147 | 148 | with torch.no_grad(): 149 | for i, data in enumerate(dataloader): 150 | # Get the inputs 151 | inputs, _, groups, targets, _ = data 152 | labels = targets 153 | groups = groups.long() 154 | 155 | if self.cuda: 156 | inputs = inputs.cuda(self.device) 157 | labels = labels.cuda(self.device) 158 | 159 | # forward 160 | 161 | outputs = self.model(inputs) 162 | if self.get_inter: 163 | intermediate_feature = self.model.forward(inputs, get_inter=True)[-2] 164 | 165 | group_set = torch.cat((group_set, groups)) 166 | target_set = torch.cat((target_set, targets)) 167 | output_set = torch.cat((output_set, outputs.cpu())) 168 | if self.get_inter: 169 | intermediate_feature_set = torch.cat((intermediate_feature_set, intermediate_feature.cpu())) 170 | 171 | pred = torch.argmax(outputs, 1) 172 | group_element = list(torch.unique(groups).numpy()) 173 | for i in group_element: 174 | mask = groups == i 175 | if len(labels[mask]) != 0: 176 | confu_mat[str(i)] += confusion_matrix( 177 | labels[mask].cpu().numpy(), pred[mask].cpu().numpy(), 178 | labels=[i for i in range(num_classes)]) 179 | 180 | predict_mat['group_set'] = group_set.numpy() 181 | predict_mat['target_set'] = target_set.numpy() 182 | predict_mat['output_set'] = output_set.numpy() 183 | if self.get_inter: 184 | predict_mat['intermediate_feature_set'] = intermediate_feature_set.numpy() 185 | 186 | savepath = os.path.join(log_dir, log_name + '_{}_confu'.format(dataset)) 187 | print('savepath', savepath) 188 | savemat(savepath, confu_mat, appendmat=True) 189 | 190 | savepath_pred = os.path.join(log_dir, log_name + '_{}_pred'.format(dataset)) 191 | savemat(savepath_pred, predict_mat, appendmat=True) 192 | 193 | print('Computed confusion matrix for {} dataset successfully!'.format(dataset)) 194 | return confu_mat 195 | -------------------------------------------------------------------------------- /trainer/vanilla_train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import time 4 | from utils import get_accuracy 5 | import trainer 6 | 7 | 8 | class Trainer(trainer.GenericTrainer): 9 | def __init__(self, args, **kwargs): 10 | super().__init__(args=args, **kwargs) 11 | 12 | def train(self, train_loader, test_loader, epochs): 13 | model = self.model 14 | model.train() 15 | 16 | for epoch in range(epochs): 17 | self._train_epoch(epoch, train_loader, model) 18 | 19 | eval_start_time = time.time() 20 | eval_loss, eval_acc, eval_deopp = self.evaluate(model, test_loader, self.criterion) 21 | eval_end_time = time.time() 22 | print('[{}/{}] Method: {} ' 23 | 'Test Loss: {:.3f} Test Acc: {:.2f} Test DEopp {:.2f} [{:.2f} s]'.format 24 | (epoch + 1, epochs, 
self.method, 25 | eval_loss, eval_acc, eval_deopp, (eval_end_time - eval_start_time))) 26 | 27 | if self.scheduler != None and 'Multi' not in type(self.scheduler).__name__: 28 | self.scheduler.step(eval_loss) 29 | else: 30 | self.scheduler.step() 31 | 32 | print('Training Finished!') 33 | 34 | def _train_epoch(self, epoch, train_loader, model): 35 | model.train() 36 | 37 | running_acc = 0.0 38 | running_loss = 0.0 39 | 40 | batch_start_time = time.time() 41 | for i, data in enumerate(train_loader): 42 | # Get the inputs 43 | inputs, _, groups, targets, _ = data 44 | 45 | labels = targets 46 | 47 | if self.cuda: 48 | inputs = inputs.cuda(device=self.device) 49 | labels = labels.cuda(device=self.device) 50 | outputs = model(inputs) 51 | loss = self.criterion(outputs, labels) 52 | 53 | running_loss += loss.item() 54 | running_acc += get_accuracy(outputs, labels) 55 | 56 | self.optimizer.zero_grad() 57 | loss.backward() 58 | self.optimizer.step() 59 | 60 | if i % self.term == self.term-1: # print every self.term mini-batches 61 | avg_batch_time = time.time()-batch_start_time 62 | print('[{}/{}, {:5d}] Method: {} Train Loss: {:.3f} Train Acc: {:.2f} ' 63 | '[{:.2f} s/batch]'.format 64 | (epoch + 1, self.epochs, i+1, self.method, running_loss / self.term, running_acc / self.term, 65 | avg_batch_time/self.term)) 66 | 67 | running_loss = 0.0 68 | running_acc = 0.0 69 | batch_start_time = time.time() 70 | 71 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import random 4 | import os 5 | 6 | 7 | def list_files(root, suffix, prefix=False): 8 | root = os.path.expanduser(root) 9 | files = list( 10 | filter( 11 | lambda p: os.path.isfile(os.path.join(root, p)) and p.endswith(suffix), 12 | os.listdir(root) 13 | ) 14 | ) 15 | if prefix is True: 16 | files = [os.path.join(root, d) for d in files] 17 | return files 18 | 19 | 20 | def set_seed(seed): 21 | torch.manual_seed(seed) 22 | # torch.cuda.manual_seed(seed) 23 | np.random.seed(seed) 24 | random.seed(seed) 25 | torch.backends.cudnn.benchmark = False 26 | torch.backends.cudnn.deterministic = True 27 | 28 | 29 | def get_accuracy(outputs, labels, binary=False): 30 | #if multi-label classification 31 | if len(labels.size())>1: 32 | outputs = (outputs>0.0).float() 33 | correct = ((outputs==labels)).float().sum() 34 | total = torch.tensor(labels.shape[0] * labels.shape[1], dtype=torch.float) 35 | avg = correct / total 36 | return avg.item() 37 | if binary: 38 | predictions = (torch.sigmoid(outputs) >= 0.5).float() 39 | else: 40 | predictions = torch.argmax(outputs, 1) 41 | c = (predictions == labels).float().squeeze() 42 | accuracy = torch.mean(c) 43 | return accuracy.item() 44 | 45 | 46 | def check_log_dir(log_dir): 47 | try: 48 | if not os.path.isdir(log_dir): 49 | os.makedirs(log_dir) 50 | except OSError: 51 | print("Failed to create directory!!") 52 | 53 | 54 | def make_log_name(args): 55 | log_name = args.model 56 | 57 | if args.mode == 'eval': 58 | log_name = args.modelpath.split('/')[-1] 59 | # remove .pt from name 60 | log_name = log_name[:-3] 61 | 62 | else: 63 | if args.pretrained: 64 | log_name += '_pretrained' 65 | log_name += '_seed{}_epochs{}_bs{}_lr{}'.format(args.seed, args.epochs, args.batch_size, args.lr) 66 | 67 | if args.method == 'adv_debiasing': 68 | log_name += '_advlamb{}_eta{}'.format(args.adv_lambda, args.eta) 69 | 70 | elif args.method == 'scratch_mmd' or 
args.method == 'kd_mfd': 71 | log_name += '_{}'.format(args.kernel) 72 | log_name += '_sigma{}'.format(args.sigma) if args.kernel == 'rbf' else '' 73 | log_name += '_jointfeature' if args.jointfeature else '' 74 | log_name += '_lambf{}'.format(args.lambf) if args.method == 'scratch_mmd' else '' 75 | 76 | if args.labelwise: 77 | log_name += '_labelwise' 78 | 79 | if args.teacher_path is not None: 80 | log_name += '_temp{}'.format(args.kd_temp) 81 | log_name += '_lambh{}_lambf{}'.format(args.lambh, args.lambf) 82 | 83 | if args.no_annealing: 84 | log_name += '_fixedlamb' 85 | if args.dataset == 'celeba' and args.target != 'Attractive': 86 | log_name += '_{}'.format(args.target) 87 | 88 | return log_name 89 | 90 | --------------------------------------------------------------------------------
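As a sanity check of the regularizer shared by trainer/scratch_mmd.py and trainer/kd_mfd.py, the following standalone sketch reproduces the biased RBF-kernel MMD^2 estimate that MMDLoss.pdist computes; mmd_rbf and the toy tensors are illustrative names, not part of the repository. The estimate is near zero when two feature sets come from the same distribution and grows when they differ, which is why minimizing it pulls group-conditional features together.

import torch


def mmd_rbf(e1, e2, sigma_base=1.0, eps=1e-12):
    # pairwise squared distances, clamped like MMDLoss.pdist to keep them positive
    def sqdist(a, b):
        return (a.pow(2).sum(1).unsqueeze(1) + b.pow(2).sum(1).unsqueeze(0)
                - 2 * a @ b.t()).clamp(min=eps)

    # bandwidth taken from the mean pairwise distance, as pdist does via sigma_avg
    sigma_avg = sqdist(e1, e2).mean().detach()

    def kernel(a, b):
        return torch.exp(-sqdist(a, b) / (2 * sigma_base ** 2 * sigma_avg))

    # biased MMD^2 estimate: E[k(x,x')] + E[k(y,y')] - 2 E[k(x,y)]
    return kernel(e1, e1).mean() + kernel(e2, e2).mean() - 2 * kernel(e1, e2).mean()


if __name__ == '__main__':
    torch.manual_seed(0)
    base = torch.randn(512, 16)
    matched = torch.randn(512, 16)        # same distribution as base
    shifted = torch.randn(512, 16) + 2.0  # mean-shifted distribution
    print('MMD^2 matched: {:.4f}'.format(mmd_rbf(base, matched).item()))  # close to 0
    print('MMD^2 shifted: {:.4f}'.format(mmd_rbf(base, shifted).item()))  # much larger

Note that the two trainers differ in one detail: scratch_mmd.py divides the squared distances by 2*(sigma_base**2)*sigma_avg (followed here), while kd_mfd.py's pdist uses 2*(sigma_base)*sigma_avg, which for sigma_base != 1 only amounts to a different bandwidth scale.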