├── .gitignore ├── Attack ├── __pycache__ │ ├── datasets.cpython-310.pyc │ ├── engine.cpython-310.pyc │ ├── losses.cpython-310.pyc │ ├── models.cpython-310.pyc │ ├── samplers.cpython-310.pyc │ ├── softmax.cpython-310.pyc │ └── utils.cpython-310.pyc ├── datasets.py ├── engine.py ├── hubconf.py ├── log_data.py ├── losses.py ├── main.py ├── models.py ├── requirements.txt ├── resmlp_models.py ├── run_fgm.sh ├── run_pgd.sh ├── run_with_submitit.py ├── samplers.py ├── softmax.py ├── test_init.py ├── tox.ini └── utils.py ├── README.md ├── Reconstruction ├── datasets.py ├── engine.py ├── losses.py ├── main_train.py ├── models.py ├── requirements.txt ├── run.sh ├── samplers.py ├── softmax.py └── utils.py ├── Robust ├── __pycache__ │ ├── calibration_tools.cpython-310.pyc │ ├── datasets.cpython-310.pyc │ ├── engine.cpython-310.pyc │ ├── losses.cpython-310.pyc │ ├── models.cpython-310.pyc │ ├── samplers.cpython-310.pyc │ ├── softmax.cpython-310.pyc │ ├── utils.cpython-310.pyc │ └── utils_robust.cpython-310.pyc ├── calibration_tools.py ├── datasets.py ├── engine.py ├── eval_OOD.py ├── losses.py ├── main_train.py ├── models.py ├── requirements.txt ├── run.sh ├── samplers.py ├── softmax.py ├── utils.py └── utils_robust.py └── Scaled_Attention ├── datasets.py ├── engine.py ├── losses.py ├── main_train.py ├── models.py ├── requirements.txt ├── run.sh ├── samplers.py ├── softmax.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled source # 2 | ################### 3 | *.com 4 | *.class 5 | *.dll 6 | *.exe 7 | *.o 8 | *.so 9 | 10 | # Packages # 11 | ############ 12 | # it's better to unpack these files and commit the raw source 13 | # git has its own built in compression methods 14 | *.7z 15 | *.dmg 16 | *.gz 17 | *.iso 18 | *.jar 19 | *.rar 20 | *.tar 21 | *.zip 22 | 23 | # Logs and databases # 24 | ###################### 25 | *.log 26 | *.sql 27 | *.sqlite 28 | 29 | # OS generated files # 30 | ###################### 31 | .DS_Store 32 | .DS_Store? 
33 | ._* 34 | .Spotlight-V100 35 | .Trashes 36 | ehthumbs.db 37 | Thumbs.db 38 | 39 | _pycache_ 40 | wandb/ -------------------------------------------------------------------------------- /Attack/__pycache__/datasets.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rachtsy/KPCA_code/017fa4e26f511d1c11c1025f8301552d0a6c1545/Attack/__pycache__/datasets.cpython-310.pyc -------------------------------------------------------------------------------- /Attack/__pycache__/engine.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rachtsy/KPCA_code/017fa4e26f511d1c11c1025f8301552d0a6c1545/Attack/__pycache__/engine.cpython-310.pyc -------------------------------------------------------------------------------- /Attack/__pycache__/losses.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rachtsy/KPCA_code/017fa4e26f511d1c11c1025f8301552d0a6c1545/Attack/__pycache__/losses.cpython-310.pyc -------------------------------------------------------------------------------- /Attack/__pycache__/models.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rachtsy/KPCA_code/017fa4e26f511d1c11c1025f8301552d0a6c1545/Attack/__pycache__/models.cpython-310.pyc -------------------------------------------------------------------------------- /Attack/__pycache__/samplers.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rachtsy/KPCA_code/017fa4e26f511d1c11c1025f8301552d0a6c1545/Attack/__pycache__/samplers.cpython-310.pyc -------------------------------------------------------------------------------- /Attack/__pycache__/softmax.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rachtsy/KPCA_code/017fa4e26f511d1c11c1025f8301552d0a6c1545/Attack/__pycache__/softmax.cpython-310.pyc -------------------------------------------------------------------------------- /Attack/__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rachtsy/KPCA_code/017fa4e26f511d1c11c1025f8301552d0a6c1545/Attack/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /Attack/datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from torchvision import datasets, transforms 5 | from torchvision.datasets.folder import ImageFolder, default_loader 6 | 7 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 8 | from timm.data import create_transform 9 | import pdb 10 | 11 | 12 | class INatDataset(ImageFolder): 13 | def __init__(self, root, train=True, year=2018, transform=None, target_transform=None, 14 | category='name', loader=default_loader): 15 | self.transform = transform 16 | self.loader = loader 17 | self.target_transform = target_transform 18 | self.year = year 19 | # assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name'] 20 | path_json = os.path.join(root, f'{"train" if train else "val"}{year}.json') 21 | with open(path_json) as json_file: 22 | data = json.load(json_file) 23 
| 24 | with open(os.path.join(root, 'categories.json')) as json_file: 25 | data_catg = json.load(json_file) 26 | 27 | path_json_for_targeter = os.path.join(root, f"train{year}.json") 28 | 29 | with open(path_json_for_targeter) as json_file: 30 | data_for_targeter = json.load(json_file) 31 | 32 | targeter = {} 33 | indexer = 0 34 | for elem in data_for_targeter['annotations']: 35 | king = [] 36 | king.append(data_catg[int(elem['category_id'])][category]) 37 | if king[0] not in targeter.keys(): 38 | targeter[king[0]] = indexer 39 | indexer += 1 40 | self.nb_classes = len(targeter) 41 | 42 | self.samples = [] 43 | for elem in data['images']: 44 | cut = elem['file_name'].split('/') 45 | target_current = int(cut[2]) 46 | path_current = os.path.join(root, cut[0], cut[2], cut[3]) 47 | 48 | categors = data_catg[target_current] 49 | target_current_true = targeter[categors[category]] 50 | self.samples.append((path_current, target_current_true)) 51 | 52 | # __getitem__ and __len__ inherited from ImageFolder 53 | 54 | 55 | def build_dataset(is_train, args): 56 | transform = build_transform(is_train, args) 57 | 58 | if args.data_set == 'CIFAR': 59 | dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform) 60 | nb_classes = 100 61 | elif args.data_set == 'IMNET': 62 | root = os.path.join(args.data_path, 'train' if is_train else 'val') 63 | dataset = datasets.ImageFolder(root, transform=transform) 64 | nb_classes = 1000 65 | elif args.data_set == 'INAT': 66 | dataset = INatDataset(args.data_path, train=is_train, year=2018, 67 | category=args.inat_category, transform=transform) 68 | nb_classes = dataset.nb_classes 69 | elif args.data_set == 'INAT19': 70 | dataset = INatDataset(args.data_path, train=is_train, year=2019, 71 | category=args.inat_category, transform=transform) 72 | nb_classes = dataset.nb_classes 73 | 74 | return dataset, nb_classes 75 | 76 | 77 | def build_transform(is_train, args): 78 | resize_im = args.input_size > 32 79 | if is_train: 80 | # this should always dispatch to transforms_imagenet_train 81 | transform = create_transform( 82 | input_size=args.input_size, 83 | is_training=True, 84 | color_jitter=args.color_jitter, 85 | auto_augment=args.aa, 86 | interpolation=args.train_interpolation, 87 | re_prob=args.reprob, 88 | re_mode=args.remode, 89 | re_count=args.recount, 90 | ) 91 | if not resize_im: 92 | # replace RandomResizedCropAndInterpolation with 93 | # RandomCrop 94 | transform.transforms[0] = transforms.RandomCrop( 95 | args.input_size, padding=4) 96 | return transform 97 | 98 | t = [] 99 | if resize_im: 100 | size = int((256 / 224) * args.input_size) 101 | t.append( 102 | transforms.Resize(size, interpolation=3), # to maintain same ratio w.r.t. 
224 images 103 | ) 104 | t.append(transforms.CenterCrop(args.input_size)) 105 | 106 | t.append(transforms.ToTensor()) 107 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 108 | return transforms.Compose(t) 109 | -------------------------------------------------------------------------------- /Attack/engine.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train and eval functions used in main.py 3 | """ 4 | import math 5 | import sys 6 | from typing import Iterable, Optional 7 | import pdb 8 | 9 | import torch 10 | import wandb 11 | 12 | from timm.data import Mixup 13 | from timm.utils import accuracy, ModelEma 14 | import numpy as np 15 | 16 | from losses import DistillationLoss 17 | import utils 18 | from cleverhans.torch.attacks.fast_gradient_method import fast_gradient_method 19 | from cleverhans.torch.attacks.projected_gradient_descent import projected_gradient_descent 20 | from cleverhans.torch.attacks.spsa import spsa 21 | from cleverhans.torch.attacks.spsa import spsa 22 | from cleverhans.torch.attacks.sparse_l1_descent import sparse_l1_descent 23 | from cleverhans.torch.attacks.noise import noise 24 | from cleverhans.torch.attacks.hop_skip_jump_attack import hop_skip_jump_attack 25 | from cleverhans.torch.attacks.carlini_wagner_l2 import carlini_wagner_l2 26 | 27 | 28 | def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss, 29 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 30 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 31 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, 32 | set_training_mode=True): 33 | model.train(set_training_mode) 34 | metric_logger = utils.MetricLogger(delimiter=" ") 35 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 36 | header = 'Epoch: [{}]'.format(epoch) 37 | print_freq = 10 38 | 39 | for samples, targets in metric_logger.log_every(data_loader, print_freq, header): 40 | samples = samples.to(device, non_blocking=True) 41 | targets = targets.to(device, non_blocking=True) 42 | 43 | if mixup_fn is not None: 44 | samples, targets = mixup_fn(samples, targets) 45 | 46 | with torch.cuda.amp.autocast(): 47 | # pdb.set_trace() 48 | outputs = model(samples) 49 | loss = criterion(samples, outputs, targets) 50 | 51 | loss_value = loss.item() 52 | 53 | if not math.isfinite(loss_value): 54 | print("Loss is {}, ".format(loss_value)) 55 | sys.exit(1) 56 | 57 | optimizer.zero_grad() 58 | 59 | # this attribute is added by timm on one optimizer (adahessian) 60 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 61 | loss_scaler(loss, optimizer, clip_grad=max_norm, 62 | parameters=model.parameters(), create_graph=is_second_order) 63 | 64 | torch.cuda.synchronize() 65 | if model_ema is not None: 66 | model_ema.update(model) 67 | 68 | metric_logger.update(loss=loss_value) 69 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 70 | # gather the stats from all processes 71 | metric_logger.synchronize_between_processes() 72 | print("Averaged stats:", metric_logger) 73 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 74 | 75 | 76 | # @torch.no_grad() 77 | def evaluate(data_loader, model, device, attack='none', eps=0.03): 78 | criterion = torch.nn.CrossEntropyLoss() 79 | 80 | metric_logger = utils.MetricLogger(delimiter=" ") 81 | header = 'Test:' 82 | 83 | # switch to evaluation mode 84 | model.eval() 85 | 86 | i=0 87 | 
for images, target in metric_logger.log_every(data_loader, 10, header): 88 | images = images.to(device, non_blocking=True) 89 | target = target.to(device, non_blocking=True) 90 | bs = images.shape[0] 91 | if attack != 'none': 92 | # bad_indices = np.random.choice(bs, bs, replace=False) 93 | if attack == 'fgm': 94 | images = fast_gradient_method(model, images, eps, np.inf) 95 | elif attack == 'pgd': 96 | images = projected_gradient_descent(model, images, eps, 0.15 * eps, 20, np.inf) 97 | elif attack == 'sld': 98 | images = sparse_l1_descent(model, images) 99 | elif attack == 'noise': 100 | images = noise(images) 101 | elif attack == 'cw': 102 | images = carlini_wagner_l2(model, images, 1000, confidence=eps) 103 | elif attack == 'spsa': 104 | images = spsa(model, images, eps, 10) 105 | print("here") 106 | elif attack == 'hsja': 107 | # can do targeted attack 108 | images = hop_skip_jump_attack(model, images, np.inf) 109 | # compute output 110 | with torch.cuda.amp.autocast(): 111 | with torch.no_grad(): 112 | output = model(images) 113 | loss = criterion(output, target) 114 | 115 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 116 | 117 | metric_logger.update(loss=loss.item()) 118 | metric_logger.meters['acc1'].update(acc1.item(), n=bs) 119 | metric_logger.meters['acc5'].update(acc5.item(), n=bs) 120 | for k, meter in metric_logger.meters.items(): 121 | wandb.log({f"val_{k}": meter.global_avg}) 122 | 123 | # gather the stats from all processes 124 | metric_logger.synchronize_between_processes() 125 | if attack != 'none': 126 | print(f'Evaluating attack method {attack} with perturbation budget {eps}:') 127 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 128 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 129 | 130 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 131 | -------------------------------------------------------------------------------- /Attack/hubconf.py: -------------------------------------------------------------------------------- 1 | from models import * 2 | 3 | dependencies = ["torch", "torchvision", "timm"] 4 | -------------------------------------------------------------------------------- /Attack/log_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import wandb 3 | 4 | # to transfer text file of attack results to wandb 5 | attack = "fgm" 6 | wandb.init(project="project_name") 7 | wandb.run.name = f"run_name_{attack}" 8 | for i in range(1,7): 9 | eps = str(i/255)[0:-2] 10 | f = open(f'/path/to/data/wandb_run_name_{attack}_{attack}_{str(i)}.txt') 11 | data = json.load(f) 12 | wandb.log({f"loss_{attack}":data['loss'], "eps":float(eps)}) 13 | wandb.log({f"acc1_{attack}":data['acc1'], "eps":float(eps)}) 14 | wandb.log({f"acc5_{attack}":data['acc5'], "eps":float(eps)}) -------------------------------------------------------------------------------- /Attack/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implements the knowledge distillation loss 3 | """ 4 | import torch 5 | from torch.nn import functional as F 6 | 7 | 8 | class DistillationLoss(torch.nn.Module): 9 | """ 10 | This module wraps a standard criterion and adds an extra knowledge distillation loss by 11 | taking a teacher model prediction and using it as additional supervision. 
12 | """ 13 | def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module, 14 | distillation_type: str, alpha: float, tau: float): 15 | super().__init__() 16 | self.base_criterion = base_criterion 17 | self.teacher_model = teacher_model 18 | assert distillation_type in ['none', 'soft', 'hard'] 19 | self.distillation_type = distillation_type 20 | self.alpha = alpha 21 | self.tau = tau 22 | 23 | def forward(self, inputs, outputs, labels): 24 | """ 25 | Args: 26 | inputs: The original inputs that are feed to the teacher model 27 | outputs: the outputs of the model to be trained. It is expected to be 28 | either a Tensor, or a Tuple[Tensor, Tensor], with the original output 29 | in the first position and the distillation predictions as the second output 30 | labels: the labels for the base criterion 31 | """ 32 | outputs_kd = None 33 | if not isinstance(outputs, torch.Tensor): 34 | # assume that the model outputs a tuple of [outputs, outputs_kd] 35 | outputs, outputs_kd = outputs 36 | base_loss = self.base_criterion(outputs, labels) 37 | if self.distillation_type == 'none': 38 | return base_loss 39 | 40 | if outputs_kd is None: 41 | raise ValueError("When knowledge distillation is enabled, the model is " 42 | "expected to return a Tuple[Tensor, Tensor] with the output of the " 43 | "class_token and the dist_token") 44 | # don't backprop throught the teacher 45 | with torch.no_grad(): 46 | teacher_outputs = self.teacher_model(inputs) 47 | 48 | if self.distillation_type == 'soft': 49 | T = self.tau 50 | # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100 51 | # with slight modifications 52 | distillation_loss = F.kl_div( 53 | F.log_softmax(outputs_kd / T, dim=1), 54 | F.log_softmax(teacher_outputs / T, dim=1), 55 | reduction='sum', 56 | log_target=True 57 | ) * (T * T) / outputs_kd.numel() 58 | elif self.distillation_type == 'hard': 59 | distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1)) 60 | 61 | loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha 62 | return loss 63 | -------------------------------------------------------------------------------- /Attack/models.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 
3 | import torch 4 | import torch.nn as nn 5 | from functools import partial 6 | 7 | 8 | from timm.models.vision_transformer import _cfg 9 | from softmax import VisionTransformer 10 | from timm.models.registry import register_model 11 | from timm.models.layers import trunc_normal_ 12 | # from xcit import XCiT, HDPXCiT 13 | 14 | class DistilledVisionTransformer(VisionTransformer): 15 | def __init__(self, *args, **kwargs): 16 | super().__init__(*args, **kwargs) 17 | self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) 18 | num_patches = self.patch_embed.num_patches 19 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim)) 20 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity() 21 | 22 | trunc_normal_(self.dist_token, std=.02) 23 | trunc_normal_(self.pos_embed, std=.02) 24 | self.head_dist.apply(self._init_weights) 25 | 26 | def forward_features(self, x): 27 | # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 28 | # with slight modifications to add the dist_token 29 | B = x.shape[0] 30 | x = self.patch_embed(x) 31 | 32 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 33 | dist_token = self.dist_token.expand(B, -1, -1) 34 | x = torch.cat((cls_tokens, dist_token, x), dim=1) 35 | 36 | x = x + self.pos_embed 37 | x = self.pos_drop(x) 38 | 39 | for blk in self.blocks: 40 | x = blk(x) 41 | 42 | 43 | x = self.norm(x) 44 | return x[:, 0], x[:, 1] 45 | 46 | def forward(self, x): 47 | x, x_dist = self.forward_features(x) 48 | x = self.head(x) 49 | x_dist = self.head_dist(x_dist) 50 | if self.training: 51 | return x, x_dist 52 | else: 53 | # during inference, return the average of both classifier predictions 54 | return (x + x_dist) / 2 55 | 56 | # register model with timms to be able to call it from "create_model" using its function name 57 | # but mainly edit the model from softmax.py 58 | @register_model 59 | def deit_tiny_patch16_224(pretrained=False, **kwargs): 60 | from softmax import VisionTransformer 61 | model = VisionTransformer( 62 | patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True, 63 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) # Tan's NOTE: in the original code, num_heads = 3 here 64 | model.default_cfg = _cfg() 65 | return model 66 | 67 | @register_model 68 | def deit_base_patch16_224(pretrained=False, **kwargs): 69 | from softmax import VisionTransformer 70 | model = VisionTransformer( 71 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 72 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) # Tan's NOTE: in the original code, num_heads = 3 here 73 | model.default_cfg = _cfg() 74 | return model 75 | 76 | 77 | -------------------------------------------------------------------------------- /Attack/requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.11.0 2 | torchvision==0.12.0 3 | timm==0.6.7 -------------------------------------------------------------------------------- /Attack/resmlp_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from functools import partial 4 | 5 | from timm.models.vision_transformer import Mlp, PatchEmbed , _cfg 6 | from timm.models.registry import register_model 7 | from timm.models.layers import trunc_normal_, DropPath 8 | 9 | 10 | __all__ = [ 
11 | 'resMLP_12', 'resMLP_24', 'resMLP_36', 'resmlpB_24' 12 | ] 13 | 14 | class Affine(nn.Module): 15 | def __init__(self, dim): 16 | super().__init__() 17 | self.alpha = nn.Parameter(torch.ones(dim)) 18 | self.beta = nn.Parameter(torch.zeros(dim)) 19 | 20 | def forward(self, x): 21 | return self.alpha * x + self.beta 22 | 23 | class layers_scale_mlp_blocks(nn.Module): 24 | 25 | def __init__(self, dim, drop=0., drop_path=0., act_layer=nn.GELU,init_values=1e-4,num_patches = 196): 26 | super().__init__() 27 | self.norm1 = Affine(dim) 28 | self.attn = nn.Linear(num_patches, num_patches) 29 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 30 | self.norm2 = Affine(dim) 31 | self.mlp = Mlp(in_features=dim, hidden_features=int(4.0 * dim), act_layer=act_layer, drop=drop) 32 | self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) 33 | self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) 34 | 35 | def forward(self, x): 36 | x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x).transpose(1,2)).transpose(1,2)) 37 | x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) 38 | return x 39 | 40 | 41 | class resmlp_models(nn.Module): 42 | 43 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,drop_rate=0., 44 | Patch_layer=PatchEmbed,act_layer=nn.GELU, 45 | drop_path_rate=0.0,init_scale=1e-4): 46 | super().__init__() 47 | 48 | 49 | 50 | self.num_classes = num_classes 51 | self.num_features = self.embed_dim = embed_dim 52 | 53 | self.patch_embed = Patch_layer( 54 | img_size=img_size, patch_size=patch_size, in_chans=int(in_chans), embed_dim=embed_dim) 55 | num_patches = self.patch_embed.num_patches 56 | dpr = [drop_path_rate for i in range(depth)] 57 | 58 | self.blocks = nn.ModuleList([ 59 | layers_scale_mlp_blocks( 60 | dim=embed_dim,drop=drop_rate,drop_path=dpr[i], 61 | act_layer=act_layer,init_values=init_scale, 62 | num_patches=num_patches) 63 | for i in range(depth)]) 64 | 65 | 66 | self.norm = Affine(embed_dim) 67 | 68 | 69 | 70 | self.feature_info = [dict(num_chs=embed_dim, reduction=0, module='head')] 71 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 72 | self.apply(self._init_weights) 73 | 74 | def _init_weights(self, m): 75 | if isinstance(m, nn.Linear): 76 | trunc_normal_(m.weight, std=0.02) 77 | if isinstance(m, nn.Linear) and m.bias is not None: 78 | nn.init.constant_(m.bias, 0) 79 | elif isinstance(m, nn.LayerNorm): 80 | nn.init.constant_(m.bias, 0) 81 | nn.init.constant_(m.weight, 1.0) 82 | 83 | 84 | 85 | def get_classifier(self): 86 | return self.head 87 | 88 | def reset_classifier(self, num_classes, global_pool=''): 89 | self.num_classes = num_classes 90 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 91 | 92 | def forward_features(self, x): 93 | B = x.shape[0] 94 | 95 | x = self.patch_embed(x) 96 | 97 | for i , blk in enumerate(self.blocks): 98 | x = blk(x) 99 | 100 | x = self.norm(x) 101 | x = x.mean(dim=1).reshape(B,1,-1) 102 | 103 | return x[:, 0] 104 | 105 | def forward(self, x): 106 | x = self.forward_features(x) 107 | x = self.head(x) 108 | return x 109 | 110 | @register_model 111 | def resmlp_12(pretrained=False,dist=False, **kwargs): 112 | model = resmlp_models( 113 | patch_size=16, embed_dim=384, depth=12, 114 | Patch_layer=PatchEmbed, 115 | init_scale=0.1,**kwargs) 116 | 117 | model.default_cfg = _cfg() 118 | if pretrained: 119 | if dist: 120 | url_path = 
"https://dl.fbaipublicfiles.com/deit/resmlp_12_dist.pth" 121 | else: 122 | url_path = "https://dl.fbaipublicfiles.com/deit/resmlp_12_no_dist.pth" 123 | checkpoint = torch.hub.load_state_dict_from_url( 124 | url=url_path, 125 | map_location="cpu", check_hash=True 126 | ) 127 | 128 | model.load_state_dict(checkpoint) 129 | return model 130 | 131 | @register_model 132 | def resmlp_24(pretrained=False,dist=False,dino=False, **kwargs): 133 | model = resmlp_models( 134 | patch_size=16, embed_dim=384, depth=24, 135 | Patch_layer=PatchEmbed, 136 | init_scale=1e-5,**kwargs) 137 | model.default_cfg = _cfg() 138 | if pretrained: 139 | if dist: 140 | url_path = "https://dl.fbaipublicfiles.com/deit/resmlp_24_dist.pth" 141 | elif dino: 142 | url_path = "https://dl.fbaipublicfiles.com/deit/resmlp_24_dino.pth" 143 | else: 144 | url_path = "https://dl.fbaipublicfiles.com/deit/resmlp_24_no_dist.pth" 145 | checkpoint = torch.hub.load_state_dict_from_url( 146 | url=url_path, 147 | map_location="cpu", check_hash=True 148 | ) 149 | 150 | model.load_state_dict(checkpoint) 151 | return model 152 | 153 | @register_model 154 | def resmlp_36(pretrained=False,dist=False, **kwargs): 155 | model = resmlp_models( 156 | patch_size=16, embed_dim=384, depth=36, 157 | Patch_layer=PatchEmbed, 158 | init_scale=1e-6,**kwargs) 159 | model.default_cfg = _cfg() 160 | if pretrained: 161 | if dist: 162 | url_path = "https://dl.fbaipublicfiles.com/deit/resmlp_36_dist.pth" 163 | else: 164 | url_path = "https://dl.fbaipublicfiles.com/deit/resmlp_36_no_dist.pth" 165 | checkpoint = torch.hub.load_state_dict_from_url( 166 | url=url_path, 167 | map_location="cpu", check_hash=True 168 | ) 169 | 170 | model.load_state_dict(checkpoint) 171 | return model 172 | 173 | @register_model 174 | def resmlpB_24(pretrained=False,dist=False, in_22k = False, **kwargs): 175 | model = resmlp_models( 176 | patch_size=8, embed_dim=768, depth=24, 177 | Patch_layer=PatchEmbed, 178 | init_scale=1e-6,**kwargs) 179 | model.default_cfg = _cfg() 180 | if pretrained: 181 | if dist: 182 | url_path = "https://dl.fbaipublicfiles.com/deit/resmlpB_24_dist.pth" 183 | elif in_22k: 184 | url_path = "https://dl.fbaipublicfiles.com/deit/resmlpB_24_22k.pth" 185 | else: 186 | url_path = "https://dl.fbaipublicfiles.com/deit/resmlpB_24_no_dist.pth" 187 | 188 | checkpoint = torch.hub.load_state_dict_from_url( 189 | url=url_path, 190 | map_location="cpu", check_hash=True 191 | ) 192 | 193 | model.load_state_dict(checkpoint) 194 | 195 | return model 196 | -------------------------------------------------------------------------------- /Attack/run_fgm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | count=0 4 | offset=1800 5 | 6 | for i in {1..6} 7 | do 8 | (( count++ )) 9 | port_num=`expr $count + $offset` 10 | eps=$(perl -e "print $i / 255") 11 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port=$port_num --use_env main.py --model deit_tiny_patch16_224 --batch-size 48 --data-path /path/to/data/imagenet --output_dir /path/to/output/directory/ --project_name 'project_name' --job_name job_name --attack 'fgm' --eps $eps --finetune /path/to/trained/model/ --eval 1 --robust --num_iter 2 --layer -1 --lambd 4 12 | done -------------------------------------------------------------------------------- /Attack/run_pgd.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | count=0 4 | offset=1800 5 | 6 | for i in {1..6} 7 | do 8 | (( 
count++ )) 9 | port_num=`expr $count + $offset` 10 | eps=$(perl -e "print $i / 255") 11 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port=$port_num --use_env main.py --model deit_tiny_patch16_224 --batch-size 48 --data-path /path/to/data/imagenet --output_dir /path/to/output/directory/ --project_name 'project_name' --job_name job_name --attack 'pgd' --eps $eps --finetune /path/to/trained/model/ --eval 1 --robust --num_iter 2 --layer -1 --lambd 4 12 | done -------------------------------------------------------------------------------- /Attack/run_with_submitit.py: -------------------------------------------------------------------------------- 1 | """ 2 | A script to run multinode training with submitit. 3 | """ 4 | import argparse 5 | import os 6 | import uuid 7 | from pathlib import Path 8 | 9 | import main as classification 10 | import submitit 11 | 12 | 13 | def parse_args(): 14 | classification_parser = classification.get_args_parser() 15 | parser = argparse.ArgumentParser("Submitit for DeiT", parents=[classification_parser]) 16 | parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node") 17 | parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request") 18 | parser.add_argument("--timeout", default=2800, type=int, help="Duration of the job") 19 | parser.add_argument("--job_dir", default="", type=str, help="Job dir. Leave empty for automatic.") 20 | 21 | parser.add_argument("--partition", default="learnfair", type=str, help="Partition where to submit") 22 | parser.add_argument("--use_volta32", action='store_true', help="Big models? Use this") 23 | parser.add_argument('--comment', default="", type=str, 24 | help='Comment to pass to scheduler, e.g. priority message') 25 | return parser.parse_args() 26 | 27 | 28 | def get_shared_folder() -> Path: 29 | user = os.getenv("USER") 30 | if Path("/checkpoint/").is_dir(): 31 | p = Path(f"/checkpoint/{user}/experiments") 32 | p.mkdir(exist_ok=True) 33 | return p 34 | raise RuntimeError("No shared folder available") 35 | 36 | 37 | def get_init_file(): 38 | # Init file must not exist, but it's parent dir must exist. 
39 | os.makedirs(str(get_shared_folder()), exist_ok=True) 40 | init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init" 41 | if init_file.exists(): 42 | os.remove(str(init_file)) 43 | return init_file 44 | 45 | 46 | class Trainer(object): 47 | def __init__(self, args): 48 | self.args = args 49 | 50 | def __call__(self): 51 | import main as classification 52 | 53 | self._setup_gpu_args() 54 | classification.main(self.args) 55 | 56 | def checkpoint(self): 57 | import os 58 | import submitit 59 | 60 | self.args.dist_url = get_init_file().as_uri() 61 | checkpoint_file = os.path.join(self.args.output_dir, "checkpoint.pth") 62 | if os.path.exists(checkpoint_file): 63 | self.args.resume = checkpoint_file 64 | print("Requeuing ", self.args) 65 | empty_trainer = type(self)(self.args) 66 | return submitit.helpers.DelayedSubmission(empty_trainer) 67 | 68 | def _setup_gpu_args(self): 69 | import submitit 70 | from pathlib import Path 71 | 72 | job_env = submitit.JobEnvironment() 73 | self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id))) 74 | self.args.gpu = job_env.local_rank 75 | self.args.rank = job_env.global_rank 76 | self.args.world_size = job_env.num_tasks 77 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 78 | 79 | 80 | def main(): 81 | args = parse_args() 82 | if args.job_dir == "": 83 | args.job_dir = get_shared_folder() / "%j" 84 | 85 | # Note that the folder will depend on the job_id, to easily track experiments 86 | executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30) 87 | 88 | num_gpus_per_node = args.ngpus 89 | nodes = args.nodes 90 | timeout_min = args.timeout 91 | 92 | partition = args.partition 93 | kwargs = {} 94 | if args.use_volta32: 95 | kwargs['slurm_constraint'] = 'volta32gb' 96 | if args.comment: 97 | kwargs['slurm_comment'] = args.comment 98 | 99 | executor.update_parameters( 100 | mem_gb=40 * num_gpus_per_node, 101 | gpus_per_node=num_gpus_per_node, 102 | tasks_per_node=num_gpus_per_node, # one task per GPU 103 | cpus_per_task=10, 104 | nodes=nodes, 105 | timeout_min=timeout_min, # max is 60 * 72 106 | # Below are cluster dependent parameters 107 | slurm_partition=partition, 108 | slurm_signal_delay_s=120, 109 | **kwargs 110 | ) 111 | 112 | executor.update_parameters(name="deit") 113 | 114 | args.dist_url = get_init_file().as_uri() 115 | args.output_dir = args.job_dir 116 | 117 | trainer = Trainer(args) 118 | job = executor.submit(trainer) 119 | 120 | print("Submitted job_id:", job.job_id) 121 | 122 | 123 | if __name__ == "__main__": 124 | main() 125 | -------------------------------------------------------------------------------- /Attack/samplers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | import math 4 | 5 | 6 | class RASampler(torch.utils.data.Sampler): 7 | """Sampler that restricts data loading to a subset of the dataset for distributed, 8 | with repeated augmentation. 
9 | It ensures that different each augmented version of a sample will be visible to a 10 | different process (GPU) 11 | Heavily based on torch.utils.data.DistributedSampler 12 | """ 13 | 14 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): 15 | if num_replicas is None: 16 | if not dist.is_available(): 17 | raise RuntimeError("Requires distributed package to be available") 18 | num_replicas = dist.get_world_size() 19 | if rank is None: 20 | if not dist.is_available(): 21 | raise RuntimeError("Requires distributed package to be available") 22 | rank = dist.get_rank() 23 | self.dataset = dataset 24 | self.num_replicas = num_replicas 25 | self.rank = rank 26 | self.epoch = 0 27 | self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas)) 28 | self.total_size = self.num_samples * self.num_replicas 29 | # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas)) 30 | self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas)) 31 | self.shuffle = shuffle 32 | 33 | def __iter__(self): 34 | # deterministically shuffle based on epoch 35 | g = torch.Generator() 36 | g.manual_seed(self.epoch) 37 | if self.shuffle: 38 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 39 | else: 40 | indices = list(range(len(self.dataset))) 41 | 42 | # add extra samples to make it evenly divisible 43 | indices = [ele for ele in indices for i in range(3)] 44 | indices += indices[:(self.total_size - len(indices))] 45 | assert len(indices) == self.total_size 46 | 47 | # subsample 48 | indices = indices[self.rank:self.total_size:self.num_replicas] 49 | assert len(indices) == self.num_samples 50 | 51 | return iter(indices[:self.num_selected_samples]) 52 | 53 | def __len__(self): 54 | return self.num_selected_samples 55 | 56 | def set_epoch(self, epoch): 57 | self.epoch = epoch 58 | -------------------------------------------------------------------------------- /Attack/softmax.py: -------------------------------------------------------------------------------- 1 | import math 2 | import logging 3 | from functools import partial 4 | from collections import OrderedDict 5 | from copy import deepcopy 6 | from statistics import mean 7 | import numpy as np 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | from timm.models.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_ 14 | from timm.models.vision_transformer import init_weights_vit_timm, init_weights_vit_jax, _load_weights 15 | from timm.models.helpers import build_model_with_cfg, named_apply, adapt_input_conv 16 | from utils import named_apply 17 | import copy 18 | import wandb 19 | 20 | 21 | 22 | class Attention(nn.Module): 23 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., 24 | robust=False, layerth=0, n=1, lambd=0, layer=0): 25 | super().__init__() 26 | assert dim % num_heads == 0, 'dim should be divisible by num_heads' 27 | self.num_heads = num_heads 28 | head_dim = dim // num_heads 29 | self.n = n 30 | self.lambd = lambd 31 | self.layer = layer 32 | # sqrt (D) 33 | self.scale = head_dim ** -0.5 34 | self.layerth = layerth 35 | 36 | self.qkv = nn.Linear(dim, dim * 2, bias=qkv_bias) 37 | 38 | self.attn_drop = nn.Dropout(attn_drop) 39 | 40 | self.proj = nn.Linear(dim, dim) 41 | self.proj_drop = nn.Dropout(proj_drop) 42 | self.robust = robust 43 | 44 | def forward(self, x): 45 | B, N, C = x.shape 46 | # q,k -> B -> heads -> n -> features 47 | qkv = 
self.qkv(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 48 | k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) 49 | 50 | if self.robust and self.layer < 0: 51 | l = torch.zeros((B,self.num_heads,N,C // self.num_heads)).to(torch.device("cuda"), non_blocking=True) 52 | y = torch.zeros((B,self.num_heads,N,C // self.num_heads)).to(torch.device("cuda"), non_blocking=True) 53 | 54 | mu=N*C/4/k.norm(p=1,dim=[-1,-2],keepdim=True) 55 | 56 | for i in range(0,self.n-1): 57 | s = k-l+y/mu 58 | s_less = s.le(-self.lambd*mu).int() 59 | s_more = s.ge(self.lambd*mu).int() 60 | s = (s-self.lambd*mu)*s_more + (s+self.lambd*mu)*s_less 61 | k2 = k-s-y/mu 62 | l = (k2 @ k2.transpose(-2, -1)) * self.scale 63 | l = l.softmax(dim=-1) 64 | l = l @ v 65 | y = y+mu*(k-l-s) 66 | 67 | s = k-l+y/mu 68 | s_less = s.le(-self.lambd*mu).int() 69 | s_more = s.ge(self.lambd*mu).int() 70 | s = (s-self.lambd*mu)*s_more + (s+self.lambd*mu)*s_less 71 | k2 = k-s-y/mu 72 | l = (k2 @ k2.transpose(-2, -1)) * self.scale 73 | l = l.softmax(dim=-1) 74 | l = self.attn_drop(l) 75 | x = l @ v 76 | y = y+mu*(k-x-s) 77 | 78 | elif self.robust and self.layerth==self.layer: 79 | l = torch.zeros((B,self.num_heads,N,C // self.num_heads)).to(torch.device("cuda"), non_blocking=True) 80 | y = torch.zeros((B,self.num_heads,N,C // self.num_heads)).to(torch.device("cuda"), non_blocking=True) 81 | 82 | mu=N*C/4/k.norm(p=1,dim=[-1,-2],keepdim=True) 83 | 84 | for i in range(0,self.n-1): 85 | s = k-l+y/mu 86 | s_less = s.le(-self.lambd*mu).int() 87 | s_more = s.ge(self.lambd*mu).int() 88 | s = (s-self.lambd*mu)*s_more + (s+self.lambd*mu)*s_less 89 | k2 = k-s-y/mu 90 | l = (k2 @ k2.transpose(-2, -1)) * self.scale 91 | l = l.softmax(dim=-1) 92 | l = l @ v 93 | y = y+mu*(k-l-s) 94 | 95 | s = k-l+y/mu 96 | s_less = s.le(-self.lambd*mu).int() 97 | s_more = s.ge(self.lambd*mu).int() 98 | s = (s-self.lambd*mu)*s_more + (s+self.lambd*mu)*s_less 99 | k2 = k-s-y/mu 100 | l = (k2 @ k2.transpose(-2, -1)) * self.scale 101 | l = l.softmax(dim=-1) 102 | l = self.attn_drop(l) 103 | x = l @ v 104 | y = y+mu*(k-x-s) 105 | 106 | else: 107 | attn = (k @ k.transpose(-2, -1)) * self.scale 108 | attn = attn.softmax(dim=-1) 109 | 110 | attn = self.attn_drop(attn) 111 | 112 | # @ is a matrix multiplication 113 | x = (attn @ v) 114 | 115 | # @ is a matrix multiplication 116 | x = x.transpose(1, 2).reshape(B,N,C) 117 | 118 | x = self.proj(x) 119 | x = self.proj_drop(x) 120 | 121 | ################ COSINE SIMILARITY MEASURE 122 | # n = x.shape[1] #x is in shape of (batchsize, length, dim) 123 | # sqaure norm across features 124 | # x_norm = torch.norm(x, 2, dim = -1, keepdim= True) 125 | # x_ = x/x_norm 126 | # x_cossim = torch.tril((x_ @ x_.transpose(-2, -1)), diagonal= -1).sum(dim = (-1, -2))/(n*(n - 1)/2) 127 | # x_cossim = x_cossim.mean() 128 | # python debugger breakpoint 129 | # import pdb;pdb.set_trace() 130 | ################ 131 | 132 | return x 133 | 134 | 135 | class Block(nn.Module): 136 | 137 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., 138 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, layerth=None, 139 | robust=False, n=1, lambd=0, layer=0): 140 | super().__init__() 141 | self.norm1 = norm_layer(dim) 142 | self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, 143 | attn_drop=attn_drop, proj_drop=drop, robust=robust, 144 | layerth=layerth, n=n, lambd=lambd, layer=layer) 145 | # NOTE: drop path for stochastic depth, we shall see if this is 
better than dropout here 146 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 147 | self.norm2 = norm_layer(dim) 148 | mlp_hidden_dim = int(dim * mlp_ratio) 149 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 150 | self.layerth = layerth 151 | 152 | def forward(self, x): 153 | x = x + self.drop_path(self.attn(self.norm1(x))) 154 | x = x + self.drop_path(self.mlp(self.norm2(x))) 155 | return x 156 | 157 | 158 | class VisionTransformer(nn.Module): 159 | """ Vision Transformer 160 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` 161 | - https://arxiv.org/abs/2010.11929 162 | Includes distillation token & head support for `DeiT: Data-efficient Image Transformers` 163 | - https://arxiv.org/abs/2012.12877 164 | """ 165 | 166 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 167 | num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False, 168 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None, 169 | act_layer=None, weight_init='',pretrained_cfg=None,pretrained_cfg_overlay=None,robust=False,n=1,lambd=0,layer=0): 170 | """ 171 | Args: 172 | img_size (int, tuple): input image size 173 | patch_size (int, tuple): patch size 174 | in_chans (int): number of input channels 175 | num_classes (int): number of classes for classification head 176 | embed_dim (int): embedding dimension 177 | depth (int): depth of transformer 178 | num_heads (int): number of attention heads 179 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 180 | qkv_bias (bool): enable bias for qkv if True 181 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 182 | distilled (bool): model includes a distillation token and head as in DeiT models 183 | drop_rate (float): dropout rate 184 | attn_drop_rate (float): attention dropout rate 185 | drop_path_rate (float): stochastic depth rate 186 | embed_layer (nn.Module): patch embedding layer 187 | norm_layer: (nn.Module): normalization layer 188 | weight_init: (str): weight init scheme 189 | """ 190 | super().__init__() 191 | self.num_classes = num_classes 192 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 193 | self.num_tokens = 2 if distilled else 1 194 | self.lambd = lambd 195 | self.n = n 196 | self.layer = layer 197 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 198 | act_layer = act_layer or nn.GELU 199 | 200 | self.patch_embed = embed_layer( 201 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 202 | num_patches = self.patch_embed.num_patches 203 | # print(img_size,patch_size,in_chans,num_patches) 204 | 205 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 206 | self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None 207 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) 208 | self.pos_drop = nn.Dropout(p=drop_rate) 209 | 210 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 211 | self.blocks = nn.Sequential(*[ 212 | Block( 213 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, 214 | attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer, 215 | layerth = i, robust=robust, 
n=self.n, lambd=self.lambd, layer=self.layer) 216 | for i in range(depth)]) 217 | self.norm = norm_layer(embed_dim) 218 | 219 | # Representation layer 220 | if representation_size and not distilled: 221 | self.num_features = representation_size 222 | self.pre_logits = nn.Sequential(OrderedDict([ 223 | ('fc', nn.Linear(embed_dim, representation_size)), 224 | ('act', nn.Tanh()) 225 | ])) 226 | else: 227 | self.pre_logits = nn.Identity() 228 | 229 | # Classifier head(s) 230 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() 231 | self.head_dist = None 232 | if distilled: 233 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() 234 | 235 | self.init_weights(weight_init) 236 | 237 | def init_weights(self, mode=''): 238 | assert mode in ('jax', 'jax_nlhb', 'nlhb', '') 239 | head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. 240 | trunc_normal_(self.pos_embed, std=.02) 241 | if self.dist_token is not None: 242 | trunc_normal_(self.dist_token, std=.02) 243 | if mode.startswith('jax'): 244 | # leave cls token as zeros to match jax impl 245 | partial(init_weights_vit_jax(mode, head_bias), head_bias=head_bias, jax_impl=True) 246 | else: 247 | trunc_normal_(self.cls_token, std=.02) 248 | init_weights_vit_timm 249 | 250 | def _init_weights(self, m): 251 | # this fn left here for compat with downstream users 252 | init_weights(m) 253 | 254 | @torch.jit.ignore() 255 | def load_pretrained(self, checkpoint_path, prefix=''): 256 | _load_weights(self, checkpoint_path, prefix) 257 | 258 | @torch.jit.ignore 259 | def no_weight_decay(self): 260 | return {'pos_embed', 'cls_token', 'dist_token'} 261 | 262 | def get_classifier(self): 263 | if self.dist_token is None: 264 | return self.head 265 | else: 266 | return self.head, self.head_dist 267 | 268 | def reset_classifier(self, num_classes, global_pool=''): 269 | self.num_classes = num_classes 270 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 271 | if self.num_tokens == 2: 272 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() 273 | 274 | def forward_features(self, x): 275 | x = self.patch_embed(x) 276 | cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks 277 | if self.dist_token is None: 278 | x = torch.cat((cls_token, x), dim=1) 279 | else: 280 | x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1) 281 | # add the same pos_emb token to each sample? broadcasting... 
282 | x = self.pos_drop(x + self.pos_embed) 283 | x = self.blocks(x) 284 | x = self.norm(x) 285 | 286 | if self.dist_token is None: 287 | return self.pre_logits(x[:, 0]) 288 | else: 289 | return x[:, 0], x[:, 1] 290 | 291 | def forward(self, x): 292 | x = self.forward_features(x) 293 | if self.head_dist is not None: 294 | x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple 295 | if self.training and not torch.jit.is_scripting(): 296 | # during inference, return the average of both classifier predictions 297 | return x, x_dist 298 | else: 299 | return (x + x_dist) / 2 300 | else: 301 | x = self.head(x) 302 | return x 303 | 304 | 305 | 306 | 307 | -------------------------------------------------------------------------------- /Attack/test_init.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pdb 3 | 4 | 5 | def initialization(num_samples, num_cluster): 6 | cluster_init_size = [num_samples // num_cluster for _ in range(num_cluster)] 7 | density = np.ones(num_samples) / num_samples 8 | all_indices = np.arange(num_samples) 9 | cluster_indices = [] 10 | for i in range(num_cluster): 11 | indices = np.random.choice(a=all_indices, size=cluster_init_size[i], replace=False, p=density) 12 | cluster_indices.append(np.sort(indices)) 13 | if i == num_cluster - 1: 14 | break 15 | density[indices] = 0 16 | nonzero = density > 0 17 | density[nonzero] = 1 / np.sum(nonzero) 18 | return cluster_indices 19 | 20 | indices = initialization(197, 5) 21 | print(indices) -------------------------------------------------------------------------------- /Attack/tox.ini: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 120 3 | ignore = F401,E402,F403,W503,W504 4 | -------------------------------------------------------------------------------- /Attack/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Misc functions, including distributed helpers. 3 | 4 | Mostly copy-paste from torchvision references. 
5 | """ 6 | import io 7 | import os 8 | import time 9 | from collections import defaultdict, deque 10 | import datetime 11 | import matplotlib.pyplot as plt 12 | import numpy as np 13 | from typing import Any, Callable, Dict, Iterator, Tuple, Type, Union 14 | 15 | import torch 16 | import torch.distributed as dist 17 | from torch import nn as nn 18 | 19 | 20 | data_loaders_names = { 21 | 'Brightness': 'brightness', 22 | 'Contrast': 'contrast', 23 | 'Defocus Blur': 'defocus_blur', 24 | 'Elastic Transform': 'elastic_transform', 25 | 'Fog': 'fog', 26 | 'Frost': 'frost', 27 | 'Gaussian Noise': 'gaussian_noise', 28 | 'Glass Blur': 'glass_blur', 29 | 'Impulse Noise': 'impulse_noise', 30 | 'JPEG Compression': 'jpeg_compression', 31 | 'Motion Blur': 'motion_blur', 32 | 'Pixelate': 'pixelate', 33 | 'Shot Noise': 'shot_noise', 34 | 'Snow': 'snow', 35 | 'Zoom Blur': 'zoom_blur' 36 | } 37 | 38 | 39 | def get_ce_alexnet(): 40 | """Returns Corruption Error values for AlexNet""" 41 | 42 | ce_alexnet = dict() 43 | ce_alexnet['Gaussian Noise'] = 0.886428 44 | ce_alexnet['Shot Noise'] = 0.894468 45 | ce_alexnet['Impulse Noise'] = 0.922640 46 | ce_alexnet['Defocus Blur'] = 0.819880 47 | ce_alexnet['Glass Blur'] = 0.826268 48 | ce_alexnet['Motion Blur'] = 0.785948 49 | ce_alexnet['Zoom Blur'] = 0.798360 50 | ce_alexnet['Snow'] = 0.866816 51 | ce_alexnet['Frost'] = 0.826572 52 | ce_alexnet['Fog'] = 0.819324 53 | ce_alexnet['Brightness'] = 0.564592 54 | ce_alexnet['Contrast'] = 0.853204 55 | ce_alexnet['Elastic Transform'] = 0.646056 56 | ce_alexnet['Pixelate'] = 0.717840 57 | ce_alexnet['JPEG Compression'] = 0.606500 58 | 59 | return ce_alexnet 60 | 61 | 62 | def get_mce_from_accuracy(accuracy, error_alexnet): 63 | """Computes mean Corruption Error from accuracy""" 64 | error = 100. - accuracy 65 | ce = error / (error_alexnet * 100.) 66 | 67 | return ce 68 | 69 | 70 | class SmoothedValue(object): 71 | """Track a series of values and provide access to smoothed values over a 72 | window or the global series average. 73 | """ 74 | 75 | def __init__(self, window_size=20, fmt=None): 76 | if fmt is None: 77 | fmt = "{median:.4f} ({global_avg:.4f})" 78 | self.deque = deque(maxlen=window_size) 79 | self.total = 0.0 80 | self.count = 0 81 | self.fmt = fmt 82 | 83 | def update(self, value, n=1): 84 | self.deque.append(value) 85 | self.count += n 86 | self.total += value * n 87 | 88 | def synchronize_between_processes(self): 89 | """ 90 | Warning: does not synchronize the deque! 
91 | """ 92 | if not is_dist_avail_and_initialized(): 93 | return 94 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 95 | dist.barrier() 96 | dist.all_reduce(t) 97 | t = t.tolist() 98 | self.count = int(t[0]) 99 | self.total = t[1] 100 | 101 | @property 102 | def median(self): 103 | d = torch.tensor(list(self.deque)) 104 | return d.median().item() 105 | 106 | @property 107 | def avg(self): 108 | d = torch.tensor(list(self.deque), dtype=torch.float32) 109 | return d.mean().item() 110 | 111 | @property 112 | def global_avg(self): 113 | return self.total / self.count 114 | 115 | @property 116 | def max(self): 117 | return max(self.deque) 118 | 119 | @property 120 | def value(self): 121 | return self.deque[-1] 122 | 123 | def __str__(self): 124 | return self.fmt.format( 125 | median=self.median, 126 | avg=self.avg, 127 | global_avg=self.global_avg, 128 | max=self.max, 129 | value=self.value) 130 | 131 | 132 | class MetricLogger(object): 133 | def __init__(self, delimiter="\t"): 134 | self.meters = defaultdict(SmoothedValue) 135 | self.delimiter = delimiter 136 | 137 | def update(self, **kwargs): 138 | for k, v in kwargs.items(): 139 | if isinstance(v, torch.Tensor): 140 | v = v.item() 141 | assert isinstance(v, (float, int)) 142 | self.meters[k].update(v) 143 | 144 | def __getattr__(self, attr): 145 | if attr in self.meters: 146 | return self.meters[attr] 147 | if attr in self.__dict__: 148 | return self.__dict__[attr] 149 | raise AttributeError("'{}' object has no attribute '{}'".format( 150 | type(self).__name__, attr)) 151 | 152 | def __str__(self): 153 | loss_str = [] 154 | for name, meter in self.meters.items(): 155 | loss_str.append( 156 | "{}: {}".format(name, str(meter)) 157 | ) 158 | return self.delimiter.join(loss_str) 159 | 160 | def synchronize_between_processes(self): 161 | for meter in self.meters.values(): 162 | meter.synchronize_between_processes() 163 | 164 | def add_meter(self, name, meter): 165 | self.meters[name] = meter 166 | 167 | def log_every(self, iterable, print_freq, header=None): 168 | i = 0 169 | if not header: 170 | header = '' 171 | start_time = time.time() 172 | end = time.time() 173 | iter_time = SmoothedValue(fmt='{avg:.4f}') 174 | data_time = SmoothedValue(fmt='{avg:.4f}') 175 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 176 | log_msg = [ 177 | header, 178 | '[{0' + space_fmt + '}/{1}]', 179 | 'eta: {eta}', 180 | '{meters}', 181 | 'time: {time}', 182 | 'data: {data}' 183 | ] 184 | if torch.cuda.is_available(): 185 | log_msg.append('max mem: {memory:.0f}') 186 | log_msg = self.delimiter.join(log_msg) 187 | MB = 1024.0 * 1024.0 188 | for obj in iterable: 189 | data_time.update(time.time() - end) 190 | yield obj 191 | iter_time.update(time.time() - end) 192 | if i % print_freq == 0 or i == len(iterable) - 1: 193 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 194 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 195 | if torch.cuda.is_available(): 196 | print(log_msg.format( 197 | i, len(iterable), eta=eta_string, 198 | meters=str(self), 199 | time=str(iter_time), data=str(data_time), 200 | memory=torch.cuda.max_memory_allocated() / MB)) 201 | else: 202 | print(log_msg.format( 203 | i, len(iterable), eta=eta_string, 204 | meters=str(self), 205 | time=str(iter_time), data=str(data_time))) 206 | i += 1 207 | end = time.time() 208 | total_time = time.time() - start_time 209 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 210 | print('{} Total time: {} ({:.4f} s / it)'.format( 
211 | header, total_time_str, total_time / len(iterable))) 212 | 213 | 214 | def _load_checkpoint_for_ema(model_ema, checkpoint): 215 | """ 216 | Workaround for ModelEma._load_checkpoint to accept an already-loaded object 217 | """ 218 | mem_file = io.BytesIO() 219 | torch.save(checkpoint, mem_file) 220 | mem_file.seek(0) 221 | model_ema._load_checkpoint(mem_file) 222 | 223 | 224 | def setup_for_distributed(is_master): 225 | """ 226 | This function disables printing when not in master process 227 | """ 228 | import builtins as __builtin__ 229 | builtin_print = __builtin__.print 230 | 231 | def print(*args, **kwargs): 232 | force = kwargs.pop('force', False) 233 | if is_master or force: 234 | builtin_print(*args, **kwargs) 235 | 236 | __builtin__.print = print 237 | 238 | 239 | def is_dist_avail_and_initialized(): 240 | if not dist.is_available(): 241 | return False 242 | if not dist.is_initialized(): 243 | return False 244 | return True 245 | 246 | 247 | def get_world_size(): 248 | if not is_dist_avail_and_initialized(): 249 | return 1 250 | return dist.get_world_size() 251 | 252 | 253 | def get_rank(): 254 | if not is_dist_avail_and_initialized(): 255 | return 0 256 | return dist.get_rank() 257 | 258 | 259 | def is_main_process(): 260 | return get_rank() == 0 261 | 262 | 263 | def save_on_master(*args, **kwargs): 264 | if is_main_process(): 265 | torch.save(*args, **kwargs) 266 | 267 | 268 | def init_distributed_mode(args): 269 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 270 | args.rank = int(os.environ["RANK"]) 271 | args.world_size = int(os.environ['WORLD_SIZE']) 272 | args.gpu = int(os.environ['LOCAL_RANK']) 273 | elif 'SLURM_PROCID' in os.environ: 274 | args.rank = int(os.environ['SLURM_PROCID']) 275 | args.gpu = args.rank % torch.cuda.device_count() 276 | else: 277 | print('Not using distributed mode') 278 | args.distributed = False 279 | return 280 | 281 | args.distributed = True 282 | 283 | torch.cuda.set_device(args.gpu) 284 | args.dist_backend = 'nccl' 285 | print('| distributed init (rank {}): {}'.format( 286 | args.rank, args.dist_url), flush=True) 287 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 288 | world_size=args.world_size, rank=args.rank) 289 | torch.distributed.barrier() 290 | setup_for_distributed(args.rank == 0) 291 | 292 | def named_apply( 293 | fn: Callable, 294 | module: nn.Module, name='', 295 | depth_first: bool = True, 296 | include_root: bool = False, 297 | ) -> nn.Module: 298 | if not depth_first and include_root: 299 | fn(module=module, name=name) 300 | for child_name, child_module in module.named_children(): 301 | child_name = '.'.join((name, child_name)) if name else child_name 302 | named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) 303 | if depth_first and include_root: 304 | fn(module=module, name=name) 305 | return module 306 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Code for RPC-SymViT and Scaled Attention in Our Paper 2 | Unveiling the Hidden Structure of Self-Attention via Kernel Principal Component Analysis 3 | 4 | https://arxiv.org/abs/2406.13762 5 | 6 | ## Requirements 7 | This toolkit requires PyTorch `torch`. 8 | 9 | The experiments for the paper were conducted with Python 3.10.12, timm 0.9.12 and PyTorch >= 1.4.0. 
10 | 
11 | The toolkit supports [Weights & Biases](https://docs.wandb.ai/) for monitoring jobs. If you use it, also install `wandb`. 
12 | 
13 | ## Instructions 
14 | Please run each command from its respective folder; a run.sh script is provided in each folder as well. 
15 | 
16 | The hyperparameters that can be tuned are: 
17 | 1. --num_iter: the number of iterations of the PAP algorithm to run in an RPC-Attention layer 
18 | 2. --lambd: the regularization parameter that controls the sparsity of the corruption matrix S 
19 | 3. --layer: the layer in which to apply RPC-Attention; choose -1 for all layers 
20 | 
21 | ### Training 
22 | 
23 | RPC-SymViT 
24 | ``` 
25 | CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --master_port 1 --nproc_per_node=4 --use_env main_train.py --model deit_tiny_patch16_224 --batch-size 256 --data-path /path/to/data/imagenet --output_dir /path/to/checkpoints/ --robust --num_iter 4 --lambd 4 --layer 0 
26 | ``` 
27 | 
28 | Scaled Attention *S* 
29 | ``` 
30 | CUDA_VISIBLE_DEVICES='1,2,3,0' python -m torch.distributed.launch --master_port 1 --nproc_per_node=4 --use_env main_train.py --model deit_tiny_patch16_224 --batch-size 256 --data-path /path/to/imagenet/ --output_dir /path/to/output/directory/ 
31 | ``` 
32 | 
33 | Scaled Attention $\alpha$ $\times$ Asym \ 
34 | Running this script without --s_scalar will default to training Scaled Attention *S*. 
35 | ``` 
36 | CUDA_VISIBLE_DEVICES='1,2,3,0' python -m torch.distributed.launch --master_port 1 --nproc_per_node=4 --use_env main_train.py --model deit_tiny_patch16_224 --batch-size 256 --data-path /path/to/imagenet/ --output_dir /path/to/output/directory/ --s_scalar 
37 | ``` 
38 | 
39 | ### Robustness Evaluation 
40 | ``` 
41 | CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --master_port 1 --nproc_per_node=4 --use_env eval_OOD.py --model deit_tiny_patch16_224 --data-path /path/to/data/imagenet/ --output_dir /path/to/checkpoints/ --robust --num_iter 4 --lambd 4 --layer 0 --resume /path/to/model/checkpoint/ 
42 | ``` 
43 | 
44 | ### Attack Evaluation 
45 | Run with --attack 'fgm' for the FGSM attack and adjust --eps for the severity of the perturbation. 
46 | ``` 
47 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port=$port_num --use_env main.py --model deit_tiny_patch16_224 --batch-size 48 --data-path /path/to/data/imagenet --output_dir /path/to/output/directory/ --project_name 'project_name' --job_name job_name --attack 'pgd' --eps 0.1 --finetune /path/to/trained/model/ --eval 1 --robust --num_iter 2 --layer -1 --lambd 4 
48 | ``` 
49 | 
50 | ### Reconstruction Error Code (Sec 2.2.1) 
51 | Run the run.sh script in the Reconstruction folder with wandb to reproduce the plot in Sec 2.2.1 of the paper. \ 
52 | The loss is calculated on the same batch of images at every epoch. \ 
53 | The logged loss may increase over the first 5-10 epochs but should decrease afterwards. 
54 | 
55 | ``` 
56 | CUDA_VISIBLE_DEVICES='4,5,6,7' python -m torch.distributed.launch --master_port 1 --nproc_per_node=4 --use_env main_train.py --model deit_tiny_patch16_224 --batch-size 256 --data-path /path/to/imagenet/ --output_dir /path/to/output/directory/ --lr 1e-4 --warmup-epochs 0 
57 | ``` -------------------------------------------------------------------------------- /Reconstruction/datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 
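# This file builds the train/val datasets and their transforms: build_dataset(is_train, args)
# returns (dataset, nb_classes) for CIFAR10/CIFAR100, ImageNet ('IMNET') and iNaturalist
# ('INAT'/'INAT19'), while build_transform applies timm's training augmentation pipeline or
# the resize + center-crop evaluation pipeline.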
3 | import os 4 | import json 5 | 6 | 7 | from torchvision import datasets, transforms 8 | from torchvision.datasets.folder import ImageFolder, default_loader 9 | 10 | 11 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 12 | from timm.data import create_transform 13 | 14 | # dont think this is in use anymore 15 | PATH_TO_IMAGENET_VAL = '/path/data/imagenet/val' 16 | 17 | # dont think this is in use anymore 18 | def create_symlinks_to_imagenet(imagenet_folder, folder_to_scan): 19 | if not os.path.exists(imagenet_folder): 20 | os.makedirs(imagenet_folder) 21 | folders_of_interest = os.listdir(folder_to_scan) 22 | path_prefix = PATH_TO_IMAGENET_VAL 23 | for folder in folders_of_interest: 24 | os.symlink(path_prefix + folder, imagenet_folder+folder, target_is_directory=True) 25 | 26 | 27 | 28 | 29 | class INatDataset(ImageFolder): 30 | def __init__(self, root, train=True, year=2018, transform=None, target_transform=None, 31 | category='name', loader=default_loader): 32 | self.transform = transform 33 | self.loader = loader 34 | self.target_transform = target_transform 35 | self.year = year 36 | # assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name'] 37 | path_json = os.path.join(root, f'{"train" if train else "val"}{year}.json') 38 | with open(path_json) as json_file: 39 | data = json.load(json_file) 40 | 41 | 42 | with open(os.path.join(root, 'categories.json')) as json_file: 43 | data_catg = json.load(json_file) 44 | 45 | 46 | path_json_for_targeter = os.path.join(root, f"train{year}.json") 47 | 48 | 49 | with open(path_json_for_targeter) as json_file: 50 | data_for_targeter = json.load(json_file) 51 | 52 | 53 | targeter = {} 54 | indexer = 0 55 | for elem in data_for_targeter['annotations']: 56 | king = [] 57 | king.append(data_catg[int(elem['category_id'])][category]) 58 | if king[0] not in targeter.keys(): 59 | targeter[king[0]] = indexer 60 | indexer += 1 61 | self.nb_classes = len(targeter) 62 | 63 | 64 | self.samples = [] 65 | for elem in data['images']: 66 | cut = elem['file_name'].split('/') 67 | target_current = int(cut[2]) 68 | path_current = os.path.join(root, cut[0], cut[2], cut[3]) 69 | 70 | 71 | categors = data_catg[target_current] 72 | target_current_true = targeter[categors[category]] 73 | self.samples.append((path_current, target_current_true)) 74 | 75 | 76 | # __getitem__ and __len__ inherited from ImageFolder 77 | 78 | 79 | 80 | # called from main twice, once for training, once for val 81 | def build_dataset(is_train, args): 82 | transform = build_transform(is_train, args) 83 | 84 | 85 | if args.data_set == 'CIFAR100': 86 | dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform) 87 | nb_classes = 100 88 | if args.data_set == 'CIFAR10': 89 | dataset = datasets.CIFAR10(args.data_path, train=is_train, transform=transform) 90 | nb_classes = 10 91 | elif args.data_set == 'IMNET': 92 | root = os.path.join(args.data_path, 'train' if is_train else 'val') 93 | dataset = datasets.ImageFolder(root, transform=transform) 94 | class_names = dataset.classes 95 | nb_classes = 1000 96 | elif args.data_set == 'INAT': 97 | dataset = INatDataset(args.data_path, train=is_train, year=2018, 98 | category=args.inat_category, transform=transform) 99 | nb_classes = dataset.nb_classes 100 | elif args.data_set == 'INAT19': 101 | dataset = INatDataset(args.data_path, train=is_train, year=2019, 102 | category=args.inat_category, transform=transform) 103 | nb_classes = dataset.nb_classes 104 | 105 | 106 | 107 
| return dataset, nb_classes 108 | 109 | 110 | 111 | 112 | def build_transform(is_train, args): 113 | resize_im = args.input_size > 32 114 | if is_train: 115 | # this should always dispatch to transforms_imagenet_train 116 | transform = create_transform( 117 | input_size=args.input_size, 118 | is_training=True, 119 | color_jitter=args.color_jitter, 120 | auto_augment=args.aa, 121 | interpolation=args.train_interpolation, 122 | re_prob=args.reprob, 123 | re_mode=args.remode, 124 | re_count=args.recount, 125 | ) 126 | if not resize_im: 127 | # replace RandomResizedCropAndInterpolation with 128 | # RandomCrop 129 | transform.transforms[0] = transforms.RandomCrop( 130 | args.input_size, padding=4) 131 | return transform 132 | 133 | 134 | t = [] 135 | if resize_im: 136 | size = int((256 / 224) * args.input_size) 137 | t.append( 138 | transforms.Resize(size, interpolation=3), # to maintain same ratio w.r.t. 224 images 139 | ) 140 | t.append(transforms.CenterCrop(args.input_size)) 141 | 142 | 143 | t.append(transforms.ToTensor()) 144 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 145 | return transforms.Compose(t) 146 | 147 | 148 | 149 | -------------------------------------------------------------------------------- /Reconstruction/engine.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | """ 4 | Train and eval functions used in main.py 5 | """ 6 | import math 7 | import sys 8 | from typing import Iterable, Optional 9 | from fvcore.nn import FlopCountAnalysis 10 | import wandb 11 | import numpy as np 12 | 13 | 14 | import torch 15 | 16 | 17 | from timm.data import Mixup 18 | from timm.utils import accuracy, ModelEma 19 | 20 | from losses import DistillationLoss 21 | import utils 22 | import pdb 23 | 24 | def train_one_epoch(reconstruct_samples, model: torch.nn.Module, criterion: DistillationLoss, 25 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 26 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 27 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, 28 | set_training_mode=True,wandb=False): 29 | # put our model in training mode... 
so that drop out and batch normalisation does not affect it 30 | model.train(set_training_mode) 31 | metric_logger = utils.MetricLogger(delimiter=" ") 32 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 33 | header = 'Epoch: [{}]'.format(epoch) 34 | print_freq = 10 35 | print(reconstruct_samples) 36 | 37 | for i, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 38 | if reconstruct_samples == None and i==2: 39 | print("assigning reconstruction samples idx") 40 | reconstruct_samples = samples 41 | samples = samples.to(device, non_blocking=True) 42 | targets = targets.to(device, non_blocking=True) 43 | # if i == 50: 44 | # break 45 | if mixup_fn is not None: 46 | samples, targets = mixup_fn(samples, targets) 47 | 48 | 49 | with torch.cuda.amp.autocast(): 50 | # flops = FlopCountAnalysis(model,samples) 51 | # print(flops.total()/1e9) 52 | # assert 1==2 53 | outputs = model(samples) 54 | loss = criterion(samples, outputs, targets) 55 | 56 | # break 57 | 58 | # loss is a tensor, averaged over the mini batch 59 | loss_value = loss.item() 60 | 61 | 62 | if not math.isfinite(loss_value): 63 | print("Loss is {}, stopping training".format(loss_value)) 64 | f = open("error.txt", "a") 65 | # writing in the file 66 | f.write("Loss is {}, stopping training".format(loss_value)) 67 | # closing the file 68 | f.close() 69 | sys.exit(1) 70 | 71 | 72 | optimizer.zero_grad() 73 | 74 | 75 | # this attribute is added by timm on one optimizer (adahessian) 76 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 77 | # provides optimisation step for model 78 | loss_scaler(loss, optimizer, clip_grad=max_norm, 79 | parameters=model.parameters(), create_graph=is_second_order) 80 | 81 | 82 | torch.cuda.synchronize() 83 | if model_ema is not None: 84 | model_ema.update(model) 85 | 86 | 87 | metric_logger.update(loss=loss_value) 88 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 89 | 90 | # gather the stats from all processes 91 | metric_logger.synchronize_between_processes() 92 | print("Averaged stats:", metric_logger) 93 | if wandb: 94 | for k, meter in metric_logger.meters.items(): 95 | wandb.log({k: meter.global_avg, 'epoch': epoch}) 96 | model.module.reconstruct = True 97 | reconstruct_samples = reconstruct_samples.to(device, non_blocking=True) 98 | outputs = model(reconstruct_samples) 99 | model.module.reconstruct = False 100 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}, reconstruct_samples 101 | 102 | # evaluate on 1000 images in imagenet/val folder 103 | @torch.no_grad() 104 | def evaluate(reconstruct_samples, data_loader, model, device, attn_only=False, batch_limit=0,epoch=0,wandb=False): 105 | criterion = torch.nn.CrossEntropyLoss() 106 | 107 | metric_logger = utils.MetricLogger(delimiter=" ") 108 | header = 'Test:' 109 | 110 | # switch to evaluation mode 111 | model.eval() 112 | model.module.test() 113 | # i = 0 114 | if not isinstance(batch_limit, int) or batch_limit < 0: 115 | batch_limit = 0 116 | attn = [] 117 | pi = [] 118 | for i, (images, target) in enumerate(metric_logger.log_every(data_loader, 10, header)): 119 | if reconstruct_samples == None: 120 | print("assigning reconstruction samples idx") 121 | reconstruct_samples = images 122 | images = images.to(device, non_blocking=True) 123 | target = target.to(device, non_blocking=True) 124 | 125 | 126 | with torch.cuda.amp.autocast(): 127 | if attn_only: 128 | output, _aux = model(images) 129 | 
attn.append(_aux[0].detach().cpu().numpy()) 
130 | pi.append(_aux[1].detach().cpu().numpy()) 
131 | del _aux 
132 | else: 
133 | output = model(images) 
134 | loss = criterion(output, target) 
135 | 
136 | # print(output.shape,target.shape) 
137 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 
138 | 
139 | batch_size = images.shape[0] 
140 | metric_logger.update(loss=loss.item()) 
141 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 
142 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 
143 | r = {k: meter.global_avg for k, meter in metric_logger.meters.items()} 
144 | # gather the stats from all processes 
145 | metric_logger.synchronize_between_processes() 
146 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 
147 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 
148 | if wandb: 
149 | for k, meter in metric_logger.meters.items(): 
150 | wandb.log({f'test_{k}': meter.global_avg , 'epoch':epoch}) 
151 | 
152 | model.module.reconstruct = True 
153 | model.module.test_check = True 
154 | reconstruct_samples = reconstruct_samples.to(device, non_blocking=True) 
155 | outputs = model(reconstruct_samples) 
156 | model.module.reconstruct = False 
157 | model.module.test_check = False 
158 | if attn_only: 
159 | return r, (attn, pi) 
160 | return r, reconstruct_samples 
161 | 
162 | 
163 | -------------------------------------------------------------------------------- /Reconstruction/losses.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | """ 4 | Implements the knowledge distillation loss 5 | """ 6 | import torch 7 | from torch.nn import functional as F 8 | import timm 9 | 
10 | class SoftTargetCrossEntropy(torch.nn.Module): 
11 | 
12 | def __init__(self): 
13 | super(SoftTargetCrossEntropy, self).__init__() 
14 | 
15 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 
16 | loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1) 
17 | return loss.mean() 
18 | 
19 | class DistillationLoss(torch.nn.Module): 
20 | """ 
21 | This module wraps a standard criterion and adds an extra knowledge distillation loss by 
22 | taking a teacher model prediction and using it as additional supervision. 
23 | """ 
24 | def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module, 
25 | distillation_type: str, alpha: float, tau: float): 
26 | super().__init__() 
27 | self.base_criterion = base_criterion 
28 | self.teacher_model = teacher_model 
29 | assert distillation_type in ['none', 'soft', 'hard'] 
30 | self.distillation_type = distillation_type 
31 | self.alpha = alpha 
32 | self.tau = tau 
33 | 
34 | 
35 | def forward(self, inputs, outputs, labels): 
36 | """ 
37 | Args: 
38 | inputs: The original inputs that are fed to the teacher model 
39 | outputs: the outputs of the model to be trained. 
It is expected to be 
40 | either a Tensor, or a Tuple[Tensor, Tensor], with the original output 
41 | in the first position and the distillation predictions as the second output 
42 | labels: the labels for the base criterion 
43 | """ 
44 | outputs_kd = None 
45 | if not isinstance(outputs, torch.Tensor): 
46 | # assume that the model outputs a tuple of [outputs, outputs_kd] 
47 | outputs, outputs_kd = outputs 
48 | base_loss = self.base_criterion(outputs, labels) 
49 | if self.distillation_type == 'none': 
50 | return base_loss 
51 | 
52 | 
53 | if outputs_kd is None: 
54 | raise ValueError("When knowledge distillation is enabled, the model is " 
55 | "expected to return a Tuple[Tensor, Tensor] with the output of the " 
56 | "class_token and the dist_token") 
57 | # don't backprop through the teacher 
58 | with torch.no_grad(): 
59 | teacher_outputs = self.teacher_model(inputs) 
60 | 
61 | 
62 | if self.distillation_type == 'soft': 
63 | T = self.tau 
64 | # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100 
65 | # with slight modifications 
66 | distillation_loss = F.kl_div( 
67 | F.log_softmax(outputs_kd / T, dim=1), 
68 | F.log_softmax(teacher_outputs / T, dim=1), 
69 | reduction='sum', 
70 | log_target=True 
71 | ) * (T * T) / outputs_kd.numel() 
72 | elif self.distillation_type == 'hard': 
73 | distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1)) 
74 | 
75 | 
76 | loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha 
77 | return loss 
78 | 
79 | 
80 | 
81 | -------------------------------------------------------------------------------- /Reconstruction/models.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 
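# This file registers DeiT variants with timm via @register_model so they can be built
# by name. A minimal usage sketch (assumes timm and torch are installed; importing this
# module is what triggers the registration, and it overrides timm's built-in model of
# the same name):
#
#   import torch
#   import timm
#   import models  # registers deit_tiny_patch16_224 defined below
#
#   model = timm.create_model('deit_tiny_patch16_224', pretrained=False)
#   logits = model(torch.randn(1, 3, 224, 224))  # (1, 1000) class logits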
3 | import torch 4 | import torch.nn as nn 5 | from functools import partial 6 | 7 | 8 | from timm.models.vision_transformer import _cfg 9 | from softmax import VisionTransformer 10 | from timm.models.registry import register_model 11 | from timm.models.layers import trunc_normal_ 12 | # from xcit import XCiT, HDPXCiT 13 | 14 | class DistilledVisionTransformer(VisionTransformer): 15 | def __init__(self, *args, **kwargs): 16 | super().__init__(*args, **kwargs) 17 | self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) 18 | num_patches = self.patch_embed.num_patches 19 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim)) 20 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity() 21 | 22 | trunc_normal_(self.dist_token, std=.02) 23 | trunc_normal_(self.pos_embed, std=.02) 24 | self.head_dist.apply(self._init_weights) 25 | 26 | def forward_features(self, x): 27 | # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 28 | # with slight modifications to add the dist_token 29 | B = x.shape[0] 30 | x = self.patch_embed(x) 31 | 32 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 33 | dist_token = self.dist_token.expand(B, -1, -1) 34 | x = torch.cat((cls_tokens, dist_token, x), dim=1) 35 | 36 | x = x + self.pos_embed 37 | x = self.pos_drop(x) 38 | 39 | for blk in self.blocks: 40 | x = blk(x) 41 | 42 | 43 | x = self.norm(x) 44 | return x[:, 0], x[:, 1] 45 | 46 | def forward(self, x): 47 | x, x_dist = self.forward_features(x) 48 | x = self.head(x) 49 | x_dist = self.head_dist(x_dist) 50 | if self.training: 51 | return x, x_dist 52 | else: 53 | # during inference, return the average of both classifier predictions 54 | return (x + x_dist) / 2 55 | 56 | # register model with timms to be able to call it from "create_model" using its function name 57 | # but mainly edit the model from softmax.py 58 | @register_model 59 | def deit_tiny_patch16_224(pretrained=False, **kwargs): 60 | from softmax import VisionTransformer 61 | model = VisionTransformer( 62 | patch_size=16, embed_dim=192, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, 63 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) # Tan's NOTE: in the original code, num_heads = 3 here 64 | model.default_cfg = _cfg() 65 | return model 66 | 67 | 68 | -------------------------------------------------------------------------------- /Reconstruction/requirements.txt: -------------------------------------------------------------------------------- 1 | fvcore==0.1.5.post20221221 2 | numpy==1.25.2 3 | timm==0.9.7 4 | torch==2.0.1 5 | torchvision==0.15.2 6 | -------------------------------------------------------------------------------- /Reconstruction/run.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES='4,5,6,7' python -m torch.distributed.launch --master_port 1 --nproc_per_node=4 --use_env main_train.py \ 2 | --model deit_tiny_patch16_224 --batch-size 256 --data-path /path/to/imagenet/ --output_dir /path/to/output/directory/ \ 3 | --lr 1e-4 --warmup-epochs 0 4 | 5 | -------------------------------------------------------------------------------- /Reconstruction/samplers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 
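# RASampler below implements repeated augmentation for distributed training: every epoch,
# each sample index is repeated 3x and the copies are routed to different processes, so
# each GPU sees a differently augmented version of the same image. A minimal usage sketch
# (rank and world_size come from torch.distributed):
#
#   sampler = RASampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
#   loader = torch.utils.data.DataLoader(dataset, batch_size=256, sampler=sampler)
#   sampler.set_epoch(epoch)  # call once per epoch to reshuffle deterministically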
3 | import torch 4 | import torch.distributed as dist 5 | import math 6 | 7 | 8 | 9 | 10 | class RASampler(torch.utils.data.Sampler): 11 | """Sampler that restricts data loading to a subset of the dataset for distributed, 12 | with repeated augmentation. 13 | It ensures that different each augmented version of a sample will be visible to a 14 | different process (GPU) 15 | Heavily based on torch.utils.data.DistributedSampler 16 | """ 17 | 18 | # num_replicas = world size, rank = global rank 19 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): 20 | if num_replicas is None: 21 | if not dist.is_available(): 22 | raise RuntimeError("Requires distributed package to be available") 23 | num_replicas = dist.get_world_size() 24 | if rank is None: 25 | if not dist.is_available(): 26 | raise RuntimeError("Requires distributed package to be available") 27 | rank = dist.get_rank() 28 | self.dataset = dataset 29 | self.num_replicas = num_replicas 30 | self.rank = rank 31 | self.epoch = 0 32 | self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas)) 33 | self.total_size = self.num_samples * self.num_replicas 34 | # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas)) 35 | self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas)) 36 | self.shuffle = shuffle 37 | 38 | 39 | def __iter__(self): 40 | # deterministically shuffle based on epoch 41 | g = torch.Generator() 42 | g.manual_seed(self.epoch) 43 | if self.shuffle: 44 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 45 | else: 46 | indices = list(range(len(self.dataset))) 47 | 48 | 49 | # add extra samples to make it evenly divisible by 3 50 | indices = [ele for ele in indices for i in range(3)] 51 | indices += indices[:(self.total_size - len(indices))] 52 | assert len(indices) == self.total_size 53 | 54 | 55 | # subsample 56 | indices = indices[self.rank:self.total_size:self.num_replicas] 57 | assert len(indices) == self.num_samples 58 | 59 | 60 | return iter(indices[:self.num_selected_samples]) 61 | 62 | 63 | def __len__(self): 64 | return self.num_selected_samples 65 | 66 | 67 | def set_epoch(self, epoch): 68 | self.epoch = epoch 69 | 70 | 71 | 72 | -------------------------------------------------------------------------------- /Reconstruction/softmax.py: -------------------------------------------------------------------------------- 1 | import math 2 | import logging 3 | from functools import partial 4 | from collections import OrderedDict 5 | from copy import deepcopy 6 | from statistics import mean 7 | import numpy as np 8 | import wandb 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | from timm.models.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_ 15 | from timm.models.vision_transformer import init_weights_vit_timm, _load_weights, init_weights_vit_jax 16 | from timm.models.helpers import build_model_with_cfg, named_apply, adapt_input_conv 17 | from utils import named_apply 18 | import copy 19 | 20 | class Attention(nn.Module): 21 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., layerth=0): 22 | super().__init__() 23 | assert dim % num_heads == 0, 'dim should be divisible by num_heads' 24 | self.num_heads = num_heads 25 | self.layerth = layerth 26 | head_dim = dim // num_heads 27 | # sqrt (D) 28 | self.scale = head_dim ** -0.5 29 | self.reconstruct = False 30 | self.test = False 31 | self.train_error = 0 32 | 
self.test_error = 0 33 | 34 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 35 | 36 | self.attn_drop = nn.Dropout(attn_drop) 37 | 38 | self.proj = nn.Linear(dim, dim) 39 | self.proj_drop = nn.Dropout(proj_drop) 40 | 41 | def forward(self, x): 42 | B, N, C = x.shape 43 | # B, heads, N, features 44 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 45 | q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) 46 | # normalise q,k 47 | q = (q - q.mean(dim=-1).unsqueeze(dim=-1).expand(B, self.num_heads, N, C // self.num_heads)) \ 48 | /q.std(dim=-1).unsqueeze(dim=-1).expand(B, self.num_heads, N, C // self.num_heads) 49 | k = (k - k.mean(dim=-1).unsqueeze(dim=-1).expand(B, self.num_heads, N, C // self.num_heads)) \ 50 | /k.std(dim=-1).unsqueeze(dim=-1).expand(B, self.num_heads, N, C // self.num_heads) 51 | 52 | qk = q @ k.transpose(-2, -1) 53 | qq = q @ q.transpose(-2, -1) 54 | attn = (qk) * self.scale 55 | attn = attn.softmax(dim=-1) 56 | 57 | attn = self.attn_drop(attn) 58 | 59 | # @ is a matrix multiplication 60 | x = (attn @ v) 61 | 62 | if self.reconstruct: 63 | num = torch.exp(torch.diagonal(qq,dim1=-2, dim2=-1)) 64 | den = torch.exp(2*qk[:,:,:,0]-torch.log(attn[:,:,:,0].pow(2))) 65 | # B, heads, N 66 | phiq_norm = num/den 67 | print(phiq_norm.shape) 68 | # B, heads, N, [C/heads] 69 | proj = x.pow(2).sum(dim=-1) 70 | print(proj.shape) 71 | error = phiq_norm - proj 72 | # print(error) 73 | error_ave = torch.log(error.mean()) 74 | if self.test: 75 | self.test_error=error_ave 76 | else: 77 | self.train_error=error_ave 78 | self.test = False 79 | self.reconstruct = False 80 | 81 | x = x.transpose(1, 2).reshape(B,N,C) 82 | 83 | x = self.proj(x) 84 | x = self.proj_drop(x) 85 | ################ COSINE SIMILARITY MEASURE 86 | # n = x.shape[1] #x is in shape of (batchsize, length, dim) 87 | # sqaure norm across features 88 | # x_norm = torch.norm(x, 2, dim = -1, keepdim= True) 89 | # x_ = x/x_norm 90 | # x_cossim = torch.tril((x_ @ x_.transpose(-2, -1)), diagonal= -1).sum(dim = (-1, -2))/(n*(n - 1)/2) 91 | # x_cossim = x_cossim.mean() 92 | # python debugger breakpoint 93 | # import pdb;pdb.set_trace() 94 | ################ 95 | 96 | return x 97 | 98 | 99 | class Block(nn.Module): 100 | 101 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., 102 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, layerth = None): 103 | super().__init__() 104 | self.norm1 = norm_layer(dim) 105 | self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, 106 | attn_drop=attn_drop, proj_drop=drop,layerth=layerth) 107 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 108 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 109 | self.norm2 = norm_layer(dim) 110 | mlp_hidden_dim = int(dim * mlp_ratio) 111 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 112 | self.layerth = layerth 113 | 114 | def forward(self, x): 115 | 116 | x = x + self.drop_path(self.attn(self.norm1(x))) 117 | x = x + self.drop_path(self.mlp(self.norm2(x))) 118 | return x 119 | 120 | 121 | class VisionTransformer(nn.Module): 122 | """ Vision Transformer 123 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` 124 | - https://arxiv.org/abs/2010.11929 125 | Includes distillation token & head support for `DeiT: Data-efficient Image Transformers` 126 | - https://arxiv.org/abs/2012.12877 127 | """ 128 | 129 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 130 | num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False, 131 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None, 132 | act_layer=None, weight_init='',pretrained_cfg=None,pretrained_cfg_overlay=None,wandb=False): 133 | """ 134 | Args: 135 | img_size (int, tuple): input image size 136 | patch_size (int, tuple): patch size 137 | in_chans (int): number of input channels 138 | num_classes (int): number of classes for classification head 139 | embed_dim (int): embedding dimension 140 | depth (int): depth of transformer 141 | num_heads (int): number of attention heads 142 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 143 | qkv_bias (bool): enable bias for qkv if True 144 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 145 | distilled (bool): model includes a distillation token and head as in DeiT models 146 | drop_rate (float): dropout rate 147 | attn_drop_rate (float): attention dropout rate 148 | drop_path_rate (float): stochastic depth rate 149 | embed_layer (nn.Module): patch embedding layer 150 | norm_layer: (nn.Module): normalization layer 151 | weight_init: (str): weight init scheme 152 | """ 153 | super().__init__() 154 | self.wandb = wandb 155 | self.num_classes = num_classes 156 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 157 | self.num_tokens = 2 if distilled else 1 158 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 159 | act_layer = act_layer or nn.GELU 160 | self.reconstruct = False 161 | self.test_check = False 162 | self.depth = depth 163 | 164 | # how does embedding conv2d update its weights? 
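# (it is an ordinary nn.Conv2d inside timm's PatchEmbed, so its weights receive gradients
# through the patch projection and are updated by the optimizer like any other layer)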
165 | self.patch_embed = embed_layer( 
166 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 
167 | num_patches = self.patch_embed.num_patches 
168 | # print(img_size,patch_size,in_chans,num_patches) 
169 | 
170 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 
171 | self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None 
172 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) 
173 | self.pos_drop = nn.Dropout(p=drop_rate) 
174 | 
175 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 
176 | self.blocks = nn.Sequential(*[ 
177 | Block( 
178 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, 
179 | attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer, 
180 | layerth = i) 
181 | for i in range(depth)]) 
182 | self.norm = norm_layer(embed_dim) 
183 | 
184 | # Representation layer 
185 | if representation_size and not distilled: 
186 | self.num_features = representation_size 
187 | self.pre_logits = nn.Sequential(OrderedDict([ 
188 | ('fc', nn.Linear(embed_dim, representation_size)), 
189 | ('act', nn.Tanh()) 
190 | ])) 
191 | else: 
192 | self.pre_logits = nn.Identity() 
193 | 
194 | # Classifier head(s) 
195 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() 
196 | self.head_dist = None 
197 | if distilled: 
198 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() 
199 | 
200 | self.init_weights(weight_init) 
201 | 
202 | def init_weights(self, mode=''): 
203 | assert mode in ('jax', 'jax_nlhb', 'nlhb', '') 
204 | head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. 
205 | trunc_normal_(self.pos_embed, std=.02) 
206 | if self.dist_token is not None: 
207 | trunc_normal_(self.dist_token, std=.02) 
208 | if mode.startswith('jax'): 
209 | # leave cls token as zeros to match jax impl 
210 | named_apply(partial(init_weights_vit_jax, head_bias=head_bias), self) 
211 | else: 
212 | trunc_normal_(self.cls_token, std=.02) 
213 | named_apply(init_weights_vit_timm, self) 
214 | 
215 | def _init_weights(self, m): 
216 | # this fn left here for compat with downstream users 
217 | init_weights_vit_timm(m) 
218 | 
219 | @torch.jit.ignore() 
220 | def load_pretrained(self, checkpoint_path, prefix=''): 
221 | _load_weights(self, checkpoint_path, prefix) 
222 | 
223 | @torch.jit.ignore 
224 | def no_weight_decay(self): 
225 | return {'pos_embed', 'cls_token', 'dist_token'} 
226 | 
227 | def get_classifier(self): 
228 | if self.dist_token is None: 
229 | return self.head 
230 | else: 
231 | return self.head, self.head_dist 
232 | 
233 | def reset_classifier(self, num_classes, global_pool=''): 
234 | self.num_classes = num_classes 
235 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 
236 | if self.num_tokens == 2: 
237 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() 
238 | 
239 | def forward_features(self, x): 
240 | x = self.patch_embed(x) 
241 | cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks 
242 | if self.dist_token is None: 
243 | x = torch.cat((cls_token, x), dim=1) 
244 | else: 
245 | x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1) 
246 | # add the same pos_emb token to each sample? broadcasting... 
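# (yes: pos_embed has shape (1, num_patches + num_tokens, embed_dim), so its leading
# batch dimension of 1 broadcasts and every sample in the batch gets the same learned
# positional embedding added)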
247 | x = self.pos_drop(x + self.pos_embed) 248 | if self.reconstruct: 249 | for i in range(0,self.depth): 250 | self.blocks[i].attn.reconstruct=True 251 | x = self.blocks(x) 252 | x = self.norm(x) 253 | 254 | if self.dist_token is None: 255 | return self.pre_logits(x[:, 0]) 256 | else: 257 | return x[:, 0], x[:, 1] 258 | 259 | def forward(self, x): 260 | x = self.forward_features(x) 261 | if self.head_dist is not None: 262 | x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple 263 | if self.training and not torch.jit.is_scripting(): 264 | # during inference, return the average of both classifier predictions 265 | return x, x_dist 266 | else: 267 | return (x + x_dist) / 2 268 | else: 269 | x = self.head(x) 270 | train_error = [] 271 | test_error = [] 272 | if self.reconstruct and self.test_check: 273 | for i in range(0,self.depth): 274 | test_error.append(self.blocks[i].attn.test_error.data.cpu()) 275 | print(test_error) 276 | if self.wandb: 277 | wandb.log({"test_error":np.average(test_error)}) 278 | elif self.reconstruct: 279 | for i in range(0,self.depth): 280 | train_error.append(self.blocks[i].attn.train_error.data.cpu()) 281 | print(train_error) 282 | if self.wandb: 283 | wandb.log({"train_error":np.average(train_error)}) 284 | return x 285 | 286 | def test(self): 287 | for i in range(0,self.depth): 288 | self.blocks[i].attn.test=True 289 | 290 | 291 | 292 | -------------------------------------------------------------------------------- /Reconstruction/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | """ 4 | Misc functions, including distributed helpers. 5 | 6 | 7 | Mostly copy-paste from torchvision references. 8 | """ 9 | import io 10 | import os 11 | import time 12 | from collections import defaultdict, deque 13 | import datetime 14 | from typing import Any, Callable, Dict, Iterator, Tuple, Type, Union 15 | 16 | 17 | import torch 18 | import torch.distributed as dist 19 | from torch import nn as nn 20 | 21 | 22 | 23 | class SmoothedValue(object): 24 | """Track a series of values and provide access to smoothed values over a 25 | window or the global series average. 26 | """ 27 | 28 | 29 | def __init__(self, window_size=20, fmt=None): 30 | if fmt is None: 31 | fmt = "{median:.4f} ({global_avg:.4f})" 32 | self.deque = deque(maxlen=window_size) 33 | self.total = 0.0 34 | self.count = 0 35 | self.fmt = fmt 36 | 37 | 38 | def update(self, value, n=1): 39 | self.deque.append(value) 40 | self.count += n 41 | self.total += value * n 42 | 43 | 44 | def synchronize_between_processes(self): 45 | """ 46 | Warning: does not synchronize the deque! 
47 | """ 48 | if not is_dist_avail_and_initialized(): 49 | return 50 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 51 | dist.barrier() 52 | dist.all_reduce(t) 53 | t = t.tolist() 54 | self.count = int(t[0]) 55 | self.total = t[1] 56 | 57 | 58 | @property 59 | def median(self): 60 | d = torch.tensor(list(self.deque)) 61 | return d.median().item() 62 | 63 | 64 | @property 65 | def avg(self): 66 | d = torch.tensor(list(self.deque), dtype=torch.float32) 67 | return d.mean().item() 68 | 69 | 70 | @property 71 | def global_avg(self): 72 | return self.total / self.count 73 | 74 | 75 | @property 76 | def max(self): 77 | return max(self.deque) 78 | 79 | 80 | @property 81 | def value(self): 82 | return self.deque[-1] 83 | 84 | 85 | def __str__(self): 86 | return self.fmt.format( 87 | median=self.median, 88 | avg=self.avg, 89 | global_avg=self.global_avg, 90 | max=self.max, 91 | value=self.value) 92 | 93 | 94 | 95 | 96 | class MetricLogger(object): 97 | def __init__(self, delimiter="\t"): 98 | self.meters = defaultdict(SmoothedValue) 99 | self.delimiter = delimiter 100 | 101 | 102 | def update(self, **kwargs): 103 | for k, v in kwargs.items(): 104 | if isinstance(v, torch.Tensor): 105 | v = v.item() 106 | assert isinstance(v, (float, int)) 107 | self.meters[k].update(v) 108 | 109 | 110 | def __getattr__(self, attr): 111 | if attr in self.meters: 112 | return self.meters[attr] 113 | if attr in self.__dict__: 114 | return self.__dict__[attr] 115 | raise AttributeError("'{}' object has no attribute '{}'".format( 116 | type(self).__name__, attr)) 117 | 118 | 119 | def __str__(self): 120 | loss_str = [] 121 | for name, meter in self.meters.items(): 122 | loss_str.append( 123 | "{}: {}".format(name, str(meter)) 124 | ) 125 | return self.delimiter.join(loss_str) 126 | 127 | 128 | def synchronize_between_processes(self): 129 | for meter in self.meters.values(): 130 | meter.synchronize_between_processes() 131 | 132 | 133 | def add_meter(self, name, meter): 134 | self.meters[name] = meter 135 | 136 | 137 | #iterable is our data_loader which is pytorch data loader with our dataset_train obj and RA sampler 138 | def log_every(self, iterable, print_freq, header=None): 139 | i = 0 140 | if not header: 141 | header = '' 142 | start_time = time.time() 143 | end = time.time() 144 | iter_time = SmoothedValue(fmt='{avg:.4f}') 145 | data_time = SmoothedValue(fmt='{avg:.4f}') 146 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 147 | log_msg = [ 148 | header, 149 | '[{0' + space_fmt + '}/{1}]', 150 | 'eta: {eta}', 151 | '{meters}', 152 | 'time: {time}', 153 | 'data: {data}' 154 | ] 155 | if torch.cuda.is_available(): 156 | log_msg.append('max mem: {memory:.0f}') 157 | log_msg = self.delimiter.join(log_msg) 158 | MB = 1024.0 * 1024.0 159 | for obj in iterable: 160 | data_time.update(time.time() - end) 161 | # returns obj to caller, then continues loop 162 | yield obj 163 | iter_time.update(time.time() - end) 164 | if i % print_freq == 0 or i == len(iterable) - 1: 165 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 166 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 167 | if torch.cuda.is_available(): 168 | print(log_msg.format( 169 | i, len(iterable), eta=eta_string, 170 | meters=str(self), 171 | time=str(iter_time), data=str(data_time), 172 | memory=torch.cuda.max_memory_allocated() / MB)) 173 | else: 174 | print(log_msg.format( 175 | i, len(iterable), eta=eta_string, 176 | meters=str(self), 177 | time=str(iter_time), data=str(data_time))) 178 | i += 1 
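# reset the timer here so iter_time measures the full body of the next iteration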
179 | end = time.time() 180 | # need to remove this! 181 | total_time = time.time() - start_time 182 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 183 | print('{} Total time: {} ({:.4f} s / it)'.format( 184 | header, total_time_str, total_time / len(iterable))) 185 | 186 | 187 | 188 | 189 | def _load_checkpoint_for_ema(model_ema, checkpoint): 190 | """ 191 | Workaround for ModelEma._load_checkpoint to accept an already-loaded object 192 | """ 193 | mem_file = io.BytesIO() 194 | torch.save(checkpoint, mem_file) 195 | mem_file.seek(0) 196 | model_ema._load_checkpoint(mem_file) 197 | 198 | 199 | 200 | 201 | def setup_for_distributed(is_master): 202 | """ 203 | This function disables printing when not in master process 204 | """ 205 | import builtins as __builtin__ 206 | builtin_print = __builtin__.print 207 | 208 | 209 | def print(*args, **kwargs): 210 | force = kwargs.pop('force', False) 211 | if is_master or force: 212 | builtin_print(*args, **kwargs) 213 | 214 | 215 | __builtin__.print = print 216 | 217 | 218 | 219 | 220 | def is_dist_avail_and_initialized(): 221 | if not dist.is_available(): 222 | return False 223 | if not dist.is_initialized(): 224 | return False 225 | return True 226 | 227 | 228 | 229 | 230 | def get_world_size(): 231 | if not is_dist_avail_and_initialized(): 232 | return 1 233 | return dist.get_world_size() 234 | 235 | 236 | 237 | 238 | def get_rank(): 239 | if not is_dist_avail_and_initialized(): 240 | return 0 241 | return dist.get_rank() 242 | 243 | 244 | 245 | 246 | def is_main_process(): 247 | return get_rank() == 0 248 | 249 | 250 | 251 | 252 | def save_on_master(*args, **kwargs): 253 | if is_main_process(): 254 | torch.save(*args, **kwargs) 255 | 256 | 257 | 258 | 259 | def init_distributed_mode(args): 260 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 261 | args.rank = int(os.environ["RANK"]) 262 | args.world_size = int(os.environ['WORLD_SIZE']) 263 | args.gpu = int(os.environ['LOCAL_RANK']) 264 | elif 'SLURM_PROCID' in os.environ: 265 | args.rank = int(os.environ['SLURM_PROCID']) 266 | args.gpu = args.rank % torch.cuda.device_count() 267 | else: 268 | print('Not using distributed mode') 269 | args.distributed = False 270 | return 271 | 272 | 273 | args.distributed = True 274 | 275 | 276 | torch.cuda.set_device(args.gpu) 277 | args.dist_backend = 'nccl' 278 | print('| distributed init (rank {}): {}'.format( 279 | args.rank, args.dist_url), flush=True) 280 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 281 | world_size=args.world_size, rank=args.rank) 282 | torch.distributed.barrier() 283 | setup_for_distributed(args.rank == 0) 284 | 285 | def named_apply( 286 | fn: Callable, 287 | module: nn.Module, name='', 288 | depth_first: bool = True, 289 | include_root: bool = False, 290 | ) -> nn.Module: 291 | if not depth_first and include_root: 292 | fn(module=module, name=name) 293 | for child_name, child_module in module.named_children(): 294 | child_name = '.'.join((name, child_name)) if name else child_name 295 | named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) 296 | if depth_first and include_root: 297 | fn(module=module, name=name) 298 | return module 299 | -------------------------------------------------------------------------------- /Robust/__pycache__/calibration_tools.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/rachtsy/KPCA_code/017fa4e26f511d1c11c1025f8301552d0a6c1545/Robust/__pycache__/calibration_tools.cpython-310.pyc -------------------------------------------------------------------------------- /Robust/__pycache__/datasets.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rachtsy/KPCA_code/017fa4e26f511d1c11c1025f8301552d0a6c1545/Robust/__pycache__/datasets.cpython-310.pyc -------------------------------------------------------------------------------- /Robust/__pycache__/engine.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rachtsy/KPCA_code/017fa4e26f511d1c11c1025f8301552d0a6c1545/Robust/__pycache__/engine.cpython-310.pyc -------------------------------------------------------------------------------- /Robust/__pycache__/losses.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rachtsy/KPCA_code/017fa4e26f511d1c11c1025f8301552d0a6c1545/Robust/__pycache__/losses.cpython-310.pyc -------------------------------------------------------------------------------- /Robust/__pycache__/models.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rachtsy/KPCA_code/017fa4e26f511d1c11c1025f8301552d0a6c1545/Robust/__pycache__/models.cpython-310.pyc -------------------------------------------------------------------------------- /Robust/__pycache__/samplers.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rachtsy/KPCA_code/017fa4e26f511d1c11c1025f8301552d0a6c1545/Robust/__pycache__/samplers.cpython-310.pyc -------------------------------------------------------------------------------- /Robust/__pycache__/softmax.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rachtsy/KPCA_code/017fa4e26f511d1c11c1025f8301552d0a6c1545/Robust/__pycache__/softmax.cpython-310.pyc -------------------------------------------------------------------------------- /Robust/__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rachtsy/KPCA_code/017fa4e26f511d1c11c1025f8301552d0a6c1545/Robust/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /Robust/__pycache__/utils_robust.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rachtsy/KPCA_code/017fa4e26f511d1c11c1025f8301552d0a6c1545/Robust/__pycache__/utils_robust.cpython-310.pyc -------------------------------------------------------------------------------- /Robust/calibration_tools.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sklearn.metrics as sk 3 | from sklearn.utils.extmath import stable_cumsum 4 | 5 | recall_level_default = 0.95 6 | 7 | def calib_err(confidence, correct, p='2', beta=100): 8 | # beta is target bin size 9 | idxs = np.argsort(confidence) 10 | confidence = confidence[idxs] 11 | correct = correct[idxs] 12 | bins = [[i * beta, (i + 1) * beta] for i in range(len(confidence) // beta)] 13 | bins[-1] = [bins[-1][0], len(confidence)] 14 | 15 
| cerr = 0 
16 | total_examples = len(confidence) 
17 | for i in range(len(bins) - 1): 
18 | bin_confidence = confidence[bins[i][0]:bins[i][1]] 
19 | bin_correct = correct[bins[i][0]:bins[i][1]] 
20 | num_examples_in_bin = len(bin_confidence) 
21 | 
22 | if num_examples_in_bin > 0: 
23 | difference = np.abs(np.nanmean(bin_confidence) - np.nanmean(bin_correct)) 
24 | 
25 | if p == '2': 
26 | cerr += num_examples_in_bin / total_examples * np.square(difference) 
27 | elif p == '1': 
28 | cerr += num_examples_in_bin / total_examples * difference 
29 | elif p == 'infty' or p == 'infinity' or p == 'max': 
30 | cerr = np.maximum(cerr, difference) 
31 | else: 
32 | assert False, "p must be '1', '2', or 'infty'" 
33 | 
34 | if p == '2': 
35 | cerr = np.sqrt(cerr) 
36 | 
37 | return cerr 
38 | 
39 | 
40 | def aurra(confidence, correct): 
41 | conf_ranks = np.argsort(confidence)[::-1] # indices from greatest to least confidence 
42 | rra_curve = np.cumsum(np.asarray(correct)[conf_ranks]) 
43 | rra_curve = rra_curve / np.arange(1, len(rra_curve) + 1) # accuracy at each response rate 
44 | return np.mean(rra_curve) 
45 | 
46 | 
47 | def soft_f1(confidence, correct): 
48 | wrong = 1 - correct 
49 | 
50 | # # the incorrectly classified samples are our interest 
51 | # # so they make the positive class 
52 | # tp_soft = np.sum((1 - confidence) * wrong) 
53 | # fp_soft = np.sum((1 - confidence) * correct) 
54 | # fn_soft = np.sum(confidence * wrong) 
55 | 
56 | # return 2 * tp_soft / (2 * tp_soft + fn_soft + fp_soft) 
57 | return 2 * ((1 - confidence) * wrong).sum()/(1 - confidence + wrong).sum() 
58 | 
59 | 
60 | def tune_temp(logits, labels, binary_search=True, lower=0.2, upper=5.0, eps=0.0001): 
61 | logits = np.array(logits) 
62 | 
63 | if binary_search: 
64 | import torch 
65 | import torch.nn.functional as F 
66 | 
67 | logits = torch.FloatTensor(logits) 
68 | labels = torch.LongTensor(labels) 
69 | t_guess = torch.FloatTensor([0.5*(lower + upper)]).requires_grad_() 
70 | 
71 | while upper - lower > eps: 
72 | if torch.autograd.grad(F.cross_entropy(logits / t_guess, labels), t_guess)[0] > 0: 
73 | upper = 0.5 * (lower + upper) 
74 | else: 
75 | lower = 0.5 * (lower + upper) 
76 | t_guess = t_guess * 0 + 0.5 * (lower + upper) 
77 | 
78 | t = min([lower, 0.5 * (lower + upper), upper], key=lambda x: float(F.cross_entropy(logits / x, labels))) 
79 | else: 
80 | import cvxpy as cx 
81 | 
82 | set_size = np.array(logits).shape[0] 
83 | 
84 | t = cx.Variable() 
85 | 
86 | expr = sum((cx.Minimize(cx.log_sum_exp(logits[i, :] * t) - logits[i, labels[i]] * t) 
87 | for i in range(set_size))) 
88 | p = cx.Problem(expr, [lower <= t, t <= upper]) 
89 | 
90 | p.solve() # p.solve(solver=cx.SCS) 
91 | t = 1 / t.value 
92 | 
93 | return t 
94 | 
95 | 
96 | def print_measures(rms, aurra_metric, mad, sf1, method_name='Baseline'): 
97 | print('\t\t\t\t\t\t\t' + method_name) 
98 | print('RMS Calib Error (%): \t\t{:.2f}'.format(100 * rms)) 
99 | print('AURRA (%): \t\t\t{:.2f}'.format(100 * aurra_metric)) 
100 | # print('MAD Calib Error (%): \t\t{:.2f}'.format(100 * mad)) 
101 | # print('Soft F1 Score (%): \t\t{:.2f}'.format(100 * sf1)) 
102 | 
103 | 
104 | def show_calibration_results(confidence, correct, method_name='Baseline'): 
105 | 
106 | print('\t\t\t\t' + method_name) 
107 | print('RMS Calib Error (%): \t\t{:.2f}'.format( 
108 | 100 * calib_err(confidence, correct, p='2'))) 
109 | 
110 | print('AURRA (%): \t\t\t{:.2f}'.format( 
111 | 100 * aurra(confidence, correct))) 
112 | 
113 | # print('MAD Calib Error (%): \t\t{:.2f}'.format( 
114 | # 100 * calib_err(confidence, correct, p='1'))) 
115 | 
116 | # print('Soft F1-Score 
(%): \t\t{:.2f}'.format( 117 | # 100 * soft_f1(confidence, correct))) 118 | 119 | def fpr_and_fdr_at_recall(y_true, y_score, recall_level=recall_level_default, pos_label=None): 120 | classes = np.unique(y_true) 121 | if (pos_label is None and 122 | not (np.array_equal(classes, [0, 1]) or 123 | np.array_equal(classes, [-1, 1]) or 124 | np.array_equal(classes, [0]) or 125 | np.array_equal(classes, [-1]) or 126 | np.array_equal(classes, [1]))): 127 | raise ValueError("Data is not binary and pos_label is not specified") 128 | elif pos_label is None: 129 | pos_label = 1. 130 | 131 | # make y_true a boolean vector 132 | y_true = (y_true == pos_label) 133 | 134 | # sort scores and corresponding truth values 135 | desc_score_indices = np.argsort(y_score, kind="mergesort")[::-1] 136 | y_score = y_score[desc_score_indices] 137 | y_true = y_true[desc_score_indices] 138 | 139 | # y_score typically has many tied values. Here we extract 140 | # the indices associated with the distinct values. We also 141 | # concatenate a value for the end of the curve. 142 | distinct_value_indices = np.where(np.diff(y_score))[0] 143 | threshold_idxs = np.r_[distinct_value_indices, y_true.size - 1] 144 | 145 | # accumulate the true positives with decreasing threshold 146 | tps = stable_cumsum(y_true)[threshold_idxs] 147 | fps = 1 + threshold_idxs - tps # add one because of zero-based indexing 148 | 149 | thresholds = y_score[threshold_idxs] 150 | 151 | recall = tps / tps[-1] 152 | 153 | last_ind = tps.searchsorted(tps[-1]) 154 | sl = slice(last_ind, None, -1) # [last_ind::-1] 155 | recall, fps, tps, thresholds = np.r_[recall[sl], 1], np.r_[fps[sl], 0], np.r_[tps[sl], 0], thresholds[sl] 156 | 157 | cutoff = np.argmin(np.abs(recall - recall_level)) 158 | 159 | return fps[cutoff] / (np.sum(np.logical_not(y_true))) # , fps[cutoff]/(fps[cutoff] + tps[cutoff]) 160 | 161 | def get_measures(_pos, _neg, recall_level=recall_level_default): 162 | pos = np.array(_pos[:]).reshape((-1, 1)) 163 | neg = np.array(_neg[:]).reshape((-1, 1)) 164 | examples = np.squeeze(np.vstack((pos, neg))) 165 | labels = np.zeros(len(examples), dtype=np.int32) 166 | labels[:len(pos)] += 1 167 | 168 | auroc = sk.roc_auc_score(labels, examples) 169 | aupr = sk.average_precision_score(labels, examples) 170 | fpr = fpr_and_fdr_at_recall(labels, examples, recall_level) 171 | 172 | return auroc, aupr, fpr 173 | 174 | 175 | def print_measures_old(auroc, aupr, fpr, method_name='Ours', recall_level=recall_level_default): 176 | print('\t\t\t' + method_name) 177 | print('FPR{:d}:\t{:.2f}'.format(int(100 * recall_level), 100 * fpr)) 178 | print('AUROC: \t{:.2f}'.format(100 * auroc)) 179 | print('AUPR: \t{:.2f}'.format(100 * aupr)) 180 | 181 | 182 | def print_measures_with_std(aurocs, auprs, fprs, method_name='Ours', recall_level=recall_level_default): 183 | print('\t\t\t' + method_name) 184 | print('FPR{:d}:\t{:.2f}\t+/- {:.2f}'.format(int(100 * recall_level), 100 * np.mean(fprs), 100 * np.std(fprs))) 185 | print('AUROC: \t{:.2f}\t+/- {:.2f}'.format(100 * np.mean(aurocs), 100 * np.std(aurocs))) 186 | print('AUPR: \t{:.2f}\t+/- {:.2f}'.format(100 * np.mean(auprs), 100 * np.std(auprs))) 187 | 188 | 189 | def get_and_print_results(out_score, in_score, num_to_avg=1): 190 | 191 | aurocs, auprs, fprs = [], [], [] 192 | #for _ in range(num_to_avg): 193 | # out_score = get_ood_scores(ood_loader) 194 | measures = get_measures(out_score, in_score) 195 | aurocs.append(measures[0]); auprs.append(measures[1]); fprs.append(measures[2]) 196 | 197 | auroc = np.mean(aurocs); aupr = 
np.mean(auprs); fpr = np.mean(fprs) 198 | #auroc_list.append(auroc); aupr_list.append(aupr); fpr_list.append(fpr) 199 | 200 | #if num_to_avg >= 5: 201 | # print_measures_with_std(aurocs, auprs, fprs, method_name='Ours') 202 | #else: 203 | # print_measures(auroc, aupr, fpr, method_name='Ours') 204 | return auroc, aupr, fpr -------------------------------------------------------------------------------- /Robust/datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | import os 4 | import json 5 | 6 | 7 | from torchvision import datasets, transforms 8 | from torchvision.datasets.folder import ImageFolder, default_loader 9 | 10 | 11 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 12 | from timm.data import create_transform 13 | 14 | # dont think this is in use anymore 15 | PATH_TO_IMAGENET_VAL = '/path/to/data/imagenet/val' 16 | 17 | # dont think this is in use anymore 18 | def create_symlinks_to_imagenet(imagenet_folder, folder_to_scan): 19 | if not os.path.exists(imagenet_folder): 20 | os.makedirs(imagenet_folder) 21 | folders_of_interest = os.listdir(folder_to_scan) 22 | path_prefix = PATH_TO_IMAGENET_VAL 23 | for folder in folders_of_interest: 24 | os.symlink(path_prefix + folder, imagenet_folder+folder, target_is_directory=True) 25 | 26 | 27 | 28 | 29 | class INatDataset(ImageFolder): 30 | def __init__(self, root, train=True, year=2018, transform=None, target_transform=None, 31 | category='name', loader=default_loader): 32 | self.transform = transform 33 | self.loader = loader 34 | self.target_transform = target_transform 35 | self.year = year 36 | # assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name'] 37 | path_json = os.path.join(root, f'{"train" if train else "val"}{year}.json') 38 | with open(path_json) as json_file: 39 | data = json.load(json_file) 40 | 41 | 42 | with open(os.path.join(root, 'categories.json')) as json_file: 43 | data_catg = json.load(json_file) 44 | 45 | 46 | path_json_for_targeter = os.path.join(root, f"train{year}.json") 47 | 48 | 49 | with open(path_json_for_targeter) as json_file: 50 | data_for_targeter = json.load(json_file) 51 | 52 | 53 | targeter = {} 54 | indexer = 0 55 | for elem in data_for_targeter['annotations']: 56 | king = [] 57 | king.append(data_catg[int(elem['category_id'])][category]) 58 | if king[0] not in targeter.keys(): 59 | targeter[king[0]] = indexer 60 | indexer += 1 61 | self.nb_classes = len(targeter) 62 | 63 | 64 | self.samples = [] 65 | for elem in data['images']: 66 | cut = elem['file_name'].split('/') 67 | target_current = int(cut[2]) 68 | path_current = os.path.join(root, cut[0], cut[2], cut[3]) 69 | 70 | 71 | categors = data_catg[target_current] 72 | target_current_true = targeter[categors[category]] 73 | self.samples.append((path_current, target_current_true)) 74 | 75 | 76 | # __getitem__ and __len__ inherited from ImageFolder 77 | 78 | 79 | 80 | # called from main twice, once for training, once for val 81 | def build_dataset(is_train, args): 82 | transform = build_transform(is_train, args) 83 | 84 | 85 | if args.data_set == 'CIFAR100': 86 | dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform) 87 | nb_classes = 100 88 | if args.data_set == 'CIFAR10': 89 | dataset = datasets.CIFAR10(args.data_path, train=is_train, transform=transform) 90 | nb_classes = 10 91 | elif args.data_set == 'IMNET': 92 | root = 
os.path.join(args.data_path, 'train' if is_train else 'val') 93 | dataset = datasets.ImageFolder(root, transform=transform) 94 | class_names = dataset.classes 95 | nb_classes = 1000 96 | elif args.data_set == 'INAT': 97 | dataset = INatDataset(args.data_path, train=is_train, year=2018, 98 | category=args.inat_category, transform=transform) 99 | nb_classes = dataset.nb_classes 100 | elif args.data_set == 'INAT19': 101 | dataset = INatDataset(args.data_path, train=is_train, year=2019, 102 | category=args.inat_category, transform=transform) 103 | nb_classes = dataset.nb_classes 104 | 105 | 106 | 107 | return dataset, nb_classes 108 | 109 | 110 | 111 | 112 | def build_transform(is_train, args): 113 | resize_im = args.input_size > 32 114 | if is_train: 115 | # this should always dispatch to transforms_imagenet_train 116 | transform = create_transform( 117 | input_size=args.input_size, 118 | is_training=True, 119 | color_jitter=args.color_jitter, 120 | auto_augment=args.aa, 121 | interpolation=args.train_interpolation, 122 | re_prob=args.reprob, 123 | re_mode=args.remode, 124 | re_count=args.recount, 125 | ) 126 | if not resize_im: 127 | # replace RandomResizedCropAndInterpolation with 128 | # RandomCrop 129 | transform.transforms[0] = transforms.RandomCrop( 130 | args.input_size, padding=4) 131 | return transform 132 | 133 | 134 | t = [] 135 | if resize_im: 136 | size = int((256 / 224) * args.input_size) 137 | t.append( 138 | transforms.Resize(size, interpolation=3), # to maintain same ratio w.r.t. 224 images 139 | ) 140 | t.append(transforms.CenterCrop(args.input_size)) 141 | 142 | 143 | t.append(transforms.ToTensor()) 144 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 145 | return transforms.Compose(t) 146 | 147 | 148 | 149 | -------------------------------------------------------------------------------- /Robust/engine.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | """ 4 | Train and eval functions used in main.py 5 | """ 6 | import math 7 | import sys 8 | from typing import Iterable, Optional 9 | from fvcore.nn import FlopCountAnalysis 10 | import wandb 11 | import numpy as np 12 | 13 | 14 | import torch 15 | 16 | 17 | from timm.data import Mixup 18 | from timm.utils import accuracy, ModelEma 19 | 20 | from losses import DistillationLoss 21 | import utils 22 | import pdb 23 | 24 | def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss, 25 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 26 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 27 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, 28 | set_training_mode=True, wandb_flag=False): 29 | # put our model in training mode... so that drop out and batch normalisation does not affect it 30 | model.train(set_training_mode) 31 | metric_logger = utils.MetricLogger(delimiter=" ") 32 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 33 | header = 'Epoch: [{}]'.format(epoch) 34 | print_freq = 10 35 | 36 | # i = 0. 
37 | for i, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 38 | samples = samples.to(device, non_blocking=True) 39 | targets = targets.to(device, non_blocking=True) 40 | # if i == 50: 41 | # break 42 | if mixup_fn is not None: 43 | samples, targets = mixup_fn(samples, targets) 44 | 45 | 46 | with torch.cuda.amp.autocast(): 47 | # flops = FlopCountAnalysis(model,samples) 48 | # print(flops.total()/1e9) 49 | # assert 1==2 50 | outputs = model(samples) 51 | loss = criterion(samples, outputs, targets) 52 | 53 | # loss is a tensor, averaged over the mini batch 54 | loss_value = loss.item() 55 | 56 | 57 | if not math.isfinite(loss_value): 58 | # print("Loss is {}, stopping training".format(loss_value)) 59 | f = open("error.txt", "a") 60 | # writing in the file 61 | f.write("Loss is {}, stopping training".format(loss_value)) 62 | # closing the file 63 | f.close() 64 | sys.exit(1) 65 | 66 | 67 | optimizer.zero_grad() 68 | 69 | 70 | # this attribute is added by timm on one optimizer (adahessian) 71 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 72 | # provides optimisation step for model 73 | loss_scaler(loss, optimizer, clip_grad=max_norm, 74 | parameters=model.parameters(), create_graph=is_second_order) 75 | 76 | 77 | torch.cuda.synchronize() 78 | if model_ema is not None: 79 | model_ema.update(model) 80 | 81 | 82 | metric_logger.update(loss=loss_value) 83 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 84 | # gather the stats from all processes 85 | metric_logger.synchronize_between_processes() 86 | print("Averaged stats:", metric_logger) 87 | if wandb_flag: 88 | for k, meter in metric_logger.meters.items(): 89 | wandb.log({k: meter.global_avg, 'epoch':epoch}) 90 | 91 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 92 | 93 | # evaluate on 1000 images in imagenet/val folder 94 | @torch.no_grad() 95 | def evaluate(data_loader, model, device, attn_only=False, batch_limit=0, epoch=0, wandb_flag=False): 96 | criterion = torch.nn.CrossEntropyLoss() 97 | 98 | metric_logger = utils.MetricLogger(delimiter=" ") 99 | header = 'Test:' 100 | 101 | # switch to evaluation mode 102 | model.eval() 103 | # i = 0 104 | if not isinstance(batch_limit, int) or batch_limit < 0: 105 | batch_limit = 0 106 | attn = [] 107 | pi = [] 108 | for i, (images, target) in enumerate(metric_logger.log_every(data_loader, 10, header)): 109 | if i >= batch_limit > 0: 110 | break 111 | images = images.to(device, non_blocking=True) 112 | target = target.to(device, non_blocking=True) 113 | 114 | 115 | with torch.cuda.amp.autocast(): 116 | if attn_only: 117 | output, _aux = model(images) 118 | attn.append(_aux[0].detach().cpu().numpy()) 119 | pi.append(_aux[1].detach().cpu().numpy()) 120 | del _aux 121 | else: 122 | output = model(images) 123 | loss = criterion(output, target) 124 | 125 | # print(output.shape,target.shape) 126 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 127 | 128 | batch_size = images.shape[0] 129 | metric_logger.update(loss=loss.item()) 130 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 131 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 132 | r = {k: meter.global_avg for k, meter in metric_logger.meters.items()} 133 | # gather the stats from all processes 134 | metric_logger.synchronize_between_processes() 135 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 136 | .format(top1=metric_logger.acc1, 
top5=metric_logger.acc5, losses=metric_logger.loss)) 137 | if wandb_flag: 138 | for k, meter in metric_logger.meters.items(): 139 | wandb.log({f'test_{k}': meter.global_avg, 'epoch':epoch}) 140 | 141 | if attn_only: 142 | return r, (attn, pi) 143 | return r 144 | 145 | 146 | -------------------------------------------------------------------------------- /Robust/losses.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | """ 4 | Implements the knowledge distillation loss 5 | """ 6 | import torch 7 | from torch.nn import functional as F 8 | 9 | 10 | 11 | 12 | class DistillationLoss(torch.nn.Module): 13 | """ 14 | This module wraps a standard criterion and adds an extra knowledge distillation loss by 15 | taking a teacher model prediction and using it as additional supervision. 16 | """ 17 | def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module, 18 | distillation_type: str, alpha: float, tau: float): 19 | super().__init__() 20 | self.base_criterion = base_criterion 21 | self.teacher_model = teacher_model 22 | assert distillation_type in ['none', 'soft', 'hard'] 23 | self.distillation_type = distillation_type 24 | self.alpha = alpha 25 | self.tau = tau 26 | 27 | 28 | def forward(self, inputs, outputs, labels): 29 | """ 30 | Args: 31 | inputs: The original inputs that are feed to the teacher model 32 | outputs: the outputs of the model to be trained. It is expected to be 33 | either a Tensor, or a Tuple[Tensor, Tensor], with the original output 34 | in the first position and the distillation predictions as the second output 35 | labels: the labels for the base criterion 36 | """ 37 | outputs_kd = None 38 | if not isinstance(outputs, torch.Tensor): 39 | # assume that the model outputs a tuple of [outputs, outputs_kd] 40 | outputs, outputs_kd = outputs 41 | base_loss = self.base_criterion(outputs, labels) 42 | if self.distillation_type == 'none': 43 | return base_loss 44 | 45 | 46 | if outputs_kd is None: 47 | raise ValueError("When knowledge distillation is enabled, the model is " 48 | "expected to return a Tuple[Tensor, Tensor] with the output of the " 49 | "class_token and the dist_token") 50 | # don't backprop throught the teacher 51 | with torch.no_grad(): 52 | teacher_outputs = self.teacher_model(inputs) 53 | 54 | 55 | if self.distillation_type == 'soft': 56 | T = self.tau 57 | # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100 58 | # with slight modifications 59 | distillation_loss = F.kl_div( 60 | F.log_softmax(outputs_kd / T, dim=1), 61 | F.log_softmax(teacher_outputs / T, dim=1), 62 | reduction='sum', 63 | log_target=True 64 | ) * (T * T) / outputs_kd.numel() 65 | elif self.distillation_type == 'hard': 66 | distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1)) 67 | 68 | 69 | loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha 70 | return loss 71 | 72 | 73 | 74 | -------------------------------------------------------------------------------- /Robust/models.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 
3 | import torch 4 | import torch.nn as nn 5 | from functools import partial 6 | 7 | 8 | from timm.models.vision_transformer import _cfg 9 | from softmax import VisionTransformer 10 | from timm.models.registry import register_model 11 | from timm.models.layers import trunc_normal_ 12 | # from xcit import XCiT, HDPXCiT 13 | 14 | class DistilledVisionTransformer(VisionTransformer): 15 | def __init__(self, *args, **kwargs): 16 | super().__init__(*args, **kwargs) 17 | self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) 18 | num_patches = self.patch_embed.num_patches 19 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim)) 20 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity() 21 | 22 | trunc_normal_(self.dist_token, std=.02) 23 | trunc_normal_(self.pos_embed, std=.02) 24 | self.head_dist.apply(self._init_weights) 25 | 26 | def forward_features(self, x): 27 | # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 28 | # with slight modifications to add the dist_token 29 | B = x.shape[0] 30 | x = self.patch_embed(x) 31 | 32 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 33 | dist_token = self.dist_token.expand(B, -1, -1) 34 | x = torch.cat((cls_tokens, dist_token, x), dim=1) 35 | 36 | x = x + self.pos_embed 37 | x = self.pos_drop(x) 38 | 39 | for blk in self.blocks: 40 | x = blk(x) 41 | 42 | 43 | x = self.norm(x) 44 | return x[:, 0], x[:, 1] 45 | 46 | def forward(self, x): 47 | x, x_dist = self.forward_features(x) 48 | x = self.head(x) 49 | x_dist = self.head_dist(x_dist) 50 | if self.training: 51 | return x, x_dist 52 | else: 53 | # during inference, return the average of both classifier predictions 54 | return (x + x_dist) / 2 55 | 56 | # register model with timms to be able to call it from "create_model" using its function name 57 | # but mainly edit the model from softmax.py 58 | @register_model 59 | def deit_tiny_patch16_224(pretrained=False, **kwargs): 60 | from softmax import VisionTransformer 61 | model = VisionTransformer( 62 | patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True, 63 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) # Tan's NOTE: in the original code, num_heads = 3 here 64 | model.default_cfg = _cfg() 65 | return model 66 | 67 | @register_model 68 | def deit_small_patch16_224(pretrained=False, **kwargs): 69 | from softmax import VisionTransformer 70 | model = VisionTransformer( 71 | patch_size=16, embed_dim=348, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, 72 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) # Tan's NOTE: in the original code, num_heads = 3 here 73 | model.default_cfg = _cfg() 74 | return model 75 | 76 | @register_model 77 | def deit_base_patch16_224(pretrained=False, **kwargs): 78 | from softmax import VisionTransformer 79 | model = VisionTransformer( 80 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 81 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) # Tan's NOTE: in the original code, num_heads = 3 here 82 | model.default_cfg = _cfg() 83 | return model 84 | 85 | -------------------------------------------------------------------------------- /Robust/requirements.txt: -------------------------------------------------------------------------------- 1 | fvcore==0.1.5.post20221221 2 | numpy==1.24.4 3 | timm==0.9.12 4 | torch==2.2.0a0+81ea7a4 5 | torchvision==0.17.0a0 6 | 
tqdm==4.66.1 7 | wandb==0.16.2 8 | -------------------------------------------------------------------------------- /Robust/run.sh: -------------------------------------------------------------------------------- 1 | ###### FOR TRAINING 2 | CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --master_port 1 --nproc_per_node=4 \ 3 | --use_env main_train.py \ 4 | --model deit_base_patch16_224 --batch-size 256 --data-path /path/to/data/imagenet/ \ 5 | --output_dir /path/to/checkpoints/ \ 6 | --clip-grad 1.0 \ 7 | --robust --num_iter 1 --lambd 4.0 --layer 0 8 | 9 | ###### FOR ROBUSTNESS EVAL 10 | CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --master_port 1 --nproc_per_node=4 --use_env eval_OOD.py \ 11 | --model deit_tiny_patch16_224 --data-path /path/to/data/imagenet/ --output_dir /path/to/checkpoints/ \ 12 | --robust --num_iter 4 --lambd 4 --layer 0 --resume /path/to/model/checkpoint/ -------------------------------------------------------------------------------- /Robust/samplers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | import torch 4 | import torch.distributed as dist 5 | import math 6 | 7 | 8 | 9 | 10 | class RASampler(torch.utils.data.Sampler): 11 | """Sampler that restricts data loading to a subset of the dataset for distributed, 12 | with repeated augmentation. 13 | It ensures that different each augmented version of a sample will be visible to a 14 | different process (GPU) 15 | Heavily based on torch.utils.data.DistributedSampler 16 | """ 17 | 18 | # num_replicas = world size, rank = global rank 19 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): 20 | if num_replicas is None: 21 | if not dist.is_available(): 22 | raise RuntimeError("Requires distributed package to be available") 23 | num_replicas = dist.get_world_size() 24 | if rank is None: 25 | if not dist.is_available(): 26 | raise RuntimeError("Requires distributed package to be available") 27 | rank = dist.get_rank() 28 | self.dataset = dataset 29 | self.num_replicas = num_replicas 30 | self.rank = rank 31 | self.epoch = 0 32 | self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas)) 33 | self.total_size = self.num_samples * self.num_replicas 34 | # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas)) 35 | self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas)) 36 | self.shuffle = shuffle 37 | 38 | 39 | def __iter__(self): 40 | # deterministically shuffle based on epoch 41 | g = torch.Generator() 42 | g.manual_seed(self.epoch) 43 | if self.shuffle: 44 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 45 | else: 46 | indices = list(range(len(self.dataset))) 47 | 48 | 49 | # add extra samples to make it evenly divisible by 3 50 | indices = [ele for ele in indices for i in range(3)] 51 | indices += indices[:(self.total_size - len(indices))] 52 | assert len(indices) == self.total_size 53 | 54 | 55 | # subsample 56 | indices = indices[self.rank:self.total_size:self.num_replicas] 57 | assert len(indices) == self.num_samples 58 | 59 | 60 | return iter(indices[:self.num_selected_samples]) 61 | 62 | 63 | def __len__(self): 64 | return self.num_selected_samples 65 | 66 | 67 | def set_epoch(self, epoch): 68 | self.epoch = epoch 69 | 70 | 71 | 72 | -------------------------------------------------------------------------------- /Robust/softmax.py: 
-------------------------------------------------------------------------------- 1 | import math 2 | import logging 3 | from functools import partial 4 | from collections import OrderedDict 5 | from copy import deepcopy 6 | from statistics import mean 7 | import numpy as np 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | from timm.models.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_ 14 | from timm.models.vision_transformer import init_weights_vit_timm, init_weights_vit_jax, _load_weights 15 | from timm.models.helpers import build_model_with_cfg, named_apply, adapt_input_conv 16 | from utils import named_apply 17 | import copy 18 | import wandb 19 | 20 | 21 | 22 | class Attention(nn.Module): 23 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., 24 | robust=False, layerth=0, n=1, lambd=0, layer=0): 25 | super().__init__() 26 | assert dim % num_heads == 0, 'dim should be divisible by num_heads' 27 | self.num_heads = num_heads 28 | head_dim = dim // num_heads 29 | self.n = n 30 | self.lambd = lambd 31 | self.layer = layer 32 | # sqrt (D) 33 | self.scale = head_dim ** -0.5 34 | self.layerth = layerth 35 | 36 | self.qkv = nn.Linear(dim, dim * 2, bias=qkv_bias) 37 | 38 | self.attn_drop = nn.Dropout(attn_drop) 39 | 40 | self.proj = nn.Linear(dim, dim) 41 | self.proj_drop = nn.Dropout(proj_drop) 42 | self.robust = robust 43 | 44 | def forward(self, x): 45 | B, N, C = x.shape 46 | # q,k -> B -> heads -> n -> features 47 | qkv = self.qkv(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 48 | k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) 49 | 50 | if self.robust and self.layer < 0: 51 | l = torch.zeros((B,self.num_heads,N,C // self.num_heads)).to(torch.device("cuda"), non_blocking=True) 52 | y = torch.zeros((B,self.num_heads,N,C // self.num_heads)).to(torch.device("cuda"), non_blocking=True) 53 | 54 | mu=N*C/4/k.norm(p=1,dim=[-1,-2],keepdim=True) 55 | 56 | for i in range(0,self.n-1): 57 | s = k-l+y/mu 58 | s_less = s.le(-self.lambd*mu).int() 59 | s_more = s.ge(self.lambd*mu).int() 60 | s = (s-self.lambd*mu)*s_more + (s+self.lambd*mu)*s_less 61 | k2 = k-s-y/mu 62 | l = (k2 @ k2.transpose(-2, -1)) * self.scale 63 | l = l.softmax(dim=-1) 64 | l = l @ v 65 | y = y+mu*(k-l-s) 66 | 67 | s = k-l+y/mu 68 | s_less = s.le(-self.lambd*mu).int() 69 | s_more = s.ge(self.lambd*mu).int() 70 | s = (s-self.lambd*mu)*s_more + (s+self.lambd*mu)*s_less 71 | k2 = k-s-y/mu 72 | l = (k2 @ k2.transpose(-2, -1)) * self.scale 73 | l = l.softmax(dim=-1) 74 | l = self.attn_drop(l) 75 | x = l @ v 76 | y = y+mu*(k-x-s) 77 | 78 | elif self.robust and self.layerth==self.layer: 79 | l = torch.zeros((B,self.num_heads,N,C // self.num_heads)).to(torch.device("cuda"), non_blocking=True) 80 | y = torch.zeros((B,self.num_heads,N,C // self.num_heads)).to(torch.device("cuda"), non_blocking=True) 81 | 82 | mu=N*C/4/k.norm(p=1,dim=[-1,-2],keepdim=True) 83 | 84 | for i in range(0,self.n-1): 85 | s = k-l+y/mu 86 | s_less = s.le(-self.lambd*mu).int() 87 | s_more = s.ge(self.lambd*mu).int() 88 | s = (s-self.lambd*mu)*s_more + (s+self.lambd*mu)*s_less 89 | k2 = k-s-y/mu 90 | l = (k2 @ k2.transpose(-2, -1)) * self.scale 91 | l = l.softmax(dim=-1) 92 | l = l @ v 93 | y = y+mu*(k-l-s) 94 | 95 | s = k-l+y/mu 96 | s_less = s.le(-self.lambd*mu).int() 97 | s_more = s.ge(self.lambd*mu).int() 98 | s = (s-self.lambd*mu)*s_more + (s+self.lambd*mu)*s_less 99 | k2 = k-s-y/mu 100 | l = (k2 @ 
k2.transpose(-2, -1)) * self.scale 101 | l = l.softmax(dim=-1) 102 | l = self.attn_drop(l) 103 | x = l @ v 104 | y = y+mu*(k-x-s) 105 | 106 | else: 107 | attn = (k @ k.transpose(-2, -1)) * self.scale 108 | attn = attn.softmax(dim=-1) 109 | 110 | attn = self.attn_drop(attn) 111 | 112 | # @ is a matrix multiplication 113 | x = (attn @ v) 114 | 115 | # @ is a matrix multiplication 116 | x = x.transpose(1, 2).reshape(B,N,C) 117 | 118 | x = self.proj(x) 119 | x = self.proj_drop(x) 120 | 121 | ################ COSINE SIMILARITY MEASURE 122 | # n = x.shape[1] #x is in shape of (batchsize, length, dim) 123 | # sqaure norm across features 124 | # x_norm = torch.norm(x, 2, dim = -1, keepdim= True) 125 | # x_ = x/x_norm 126 | # x_cossim = torch.tril((x_ @ x_.transpose(-2, -1)), diagonal= -1).sum(dim = (-1, -2))/(n*(n - 1)/2) 127 | # x_cossim = x_cossim.mean() 128 | # python debugger breakpoint 129 | # import pdb;pdb.set_trace() 130 | ################ 131 | 132 | return x 133 | 134 | 135 | class Block(nn.Module): 136 | 137 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., 138 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, layerth=None, 139 | robust=False, n=1, lambd=0, layer=0): 140 | super().__init__() 141 | self.norm1 = norm_layer(dim) 142 | self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, 143 | attn_drop=attn_drop, proj_drop=drop, robust=robust, 144 | layerth=layerth, n=n, lambd=lambd, layer=layer) 145 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 146 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 147 | self.norm2 = norm_layer(dim) 148 | mlp_hidden_dim = int(dim * mlp_ratio) 149 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 150 | self.layerth = layerth 151 | 152 | def forward(self, x): 153 | x = x + self.drop_path(self.attn(self.norm1(x))) 154 | x = x + self.drop_path(self.mlp(self.norm2(x))) 155 | return x 156 | 157 | 158 | class VisionTransformer(nn.Module): 159 | """ Vision Transformer 160 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` 161 | - https://arxiv.org/abs/2010.11929 162 | Includes distillation token & head support for `DeiT: Data-efficient Image Transformers` 163 | - https://arxiv.org/abs/2012.12877 164 | """ 165 | 166 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 167 | num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False, 168 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None, 169 | act_layer=None, weight_init='',pretrained_cfg=None,pretrained_cfg_overlay=None,robust=False,n=1,lambd=0,layer=0): 170 | """ 171 | Args: 172 | img_size (int, tuple): input image size 173 | patch_size (int, tuple): patch size 174 | in_chans (int): number of input channels 175 | num_classes (int): number of classes for classification head 176 | embed_dim (int): embedding dimension 177 | depth (int): depth of transformer 178 | num_heads (int): number of attention heads 179 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 180 | qkv_bias (bool): enable bias for qkv if True 181 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 182 | distilled (bool): model includes a distillation token and head as in DeiT models 183 | drop_rate (float): dropout rate 184 | 
attn_drop_rate (float): attention dropout rate 185 | drop_path_rate (float): stochastic depth rate 186 | embed_layer (nn.Module): patch embedding layer 187 | norm_layer: (nn.Module): normalization layer 188 | weight_init: (str): weight init scheme 189 | """ 190 | super().__init__() 191 | self.num_classes = num_classes 192 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 193 | self.num_tokens = 2 if distilled else 1 194 | self.lambd = lambd 195 | self.n = n 196 | self.layer = layer 197 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 198 | act_layer = act_layer or nn.GELU 199 | 200 | self.patch_embed = embed_layer( 201 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 202 | num_patches = self.patch_embed.num_patches 203 | # print(img_size,patch_size,in_chans,num_patches) 204 | 205 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 206 | self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None 207 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) 208 | self.pos_drop = nn.Dropout(p=drop_rate) 209 | 210 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 211 | self.blocks = nn.Sequential(*[ 212 | Block( 213 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, 214 | attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer, 215 | layerth = i, robust=robust, n=self.n, lambd=self.lambd, layer=self.layer) 216 | for i in range(depth)]) 217 | self.norm = norm_layer(embed_dim) 218 | 219 | # Representation layer 220 | if representation_size and not distilled: 221 | self.num_features = representation_size 222 | self.pre_logits = nn.Sequential(OrderedDict([ 223 | ('fc', nn.Linear(embed_dim, representation_size)), 224 | ('act', nn.Tanh()) 225 | ])) 226 | else: 227 | self.pre_logits = nn.Identity() 228 | 229 | # Classifier head(s) 230 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() 231 | self.head_dist = None 232 | if distilled: 233 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() 234 | 235 | self.init_weights(weight_init) 236 | 237 | def init_weights(self, mode=''): 238 | assert mode in ('jax', 'jax_nlhb', 'nlhb', '') 239 | head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. 
240 | trunc_normal_(self.pos_embed, std=.02) 241 | if self.dist_token is not None: 242 | trunc_normal_(self.dist_token, std=.02) 243 | if mode.startswith('jax'): 244 | # leave cls token as zeros to match jax impl 245 | partial(init_weights_vit_jax(mode, head_bias), head_bias=head_bias, jax_impl=True) 246 | else: 247 | trunc_normal_(self.cls_token, std=.02) 248 | init_weights_vit_timm 249 | 250 | def _init_weights(self, m): 251 | # this fn left here for compat with downstream users 252 | init_weights(m) 253 | 254 | @torch.jit.ignore() 255 | def load_pretrained(self, checkpoint_path, prefix=''): 256 | _load_weights(self, checkpoint_path, prefix) 257 | 258 | @torch.jit.ignore 259 | def no_weight_decay(self): 260 | return {'pos_embed', 'cls_token', 'dist_token'} 261 | 262 | def get_classifier(self): 263 | if self.dist_token is None: 264 | return self.head 265 | else: 266 | return self.head, self.head_dist 267 | 268 | def reset_classifier(self, num_classes, global_pool=''): 269 | self.num_classes = num_classes 270 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 271 | if self.num_tokens == 2: 272 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() 273 | 274 | def forward_features(self, x): 275 | x = self.patch_embed(x) 276 | cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks 277 | if self.dist_token is None: 278 | x = torch.cat((cls_token, x), dim=1) 279 | else: 280 | x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1) 281 | # add the same pos_emb token to each sample? broadcasting... 282 | x = self.pos_drop(x + self.pos_embed) 283 | x = self.blocks(x) 284 | x = self.norm(x) 285 | 286 | if self.dist_token is None: 287 | return self.pre_logits(x[:, 0]) 288 | else: 289 | return x[:, 0], x[:, 1] 290 | 291 | def forward(self, x): 292 | x = self.forward_features(x) 293 | if self.head_dist is not None: 294 | x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple 295 | if self.training and not torch.jit.is_scripting(): 296 | # during inference, return the average of both classifier predictions 297 | return x, x_dist 298 | else: 299 | return (x + x_dist) / 2 300 | else: 301 | x = self.head(x) 302 | return x 303 | -------------------------------------------------------------------------------- /Robust/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | """ 4 | Misc functions, including distributed helpers. 5 | 6 | 7 | Mostly copy-paste from torchvision references. 8 | """ 9 | import io 10 | import os 11 | import time 12 | from collections import defaultdict, deque 13 | import datetime 14 | from typing import Any, Callable, Dict, Iterator, Tuple, Type, Union 15 | 16 | 17 | import torch 18 | import torch.distributed as dist 19 | from torch import nn as nn 20 | 21 | 22 | 23 | class SmoothedValue(object): 24 | """Track a series of values and provide access to smoothed values over a 25 | window or the global series average. 
26 | """ 27 | 28 | 29 | def __init__(self, window_size=20, fmt=None): 30 | if fmt is None: 31 | fmt = "{median:.4f} ({global_avg:.4f})" 32 | self.deque = deque(maxlen=window_size) 33 | self.total = 0.0 34 | self.count = 0 35 | self.fmt = fmt 36 | 37 | 38 | def update(self, value, n=1): 39 | self.deque.append(value) 40 | self.count += n 41 | self.total += value * n 42 | 43 | 44 | def synchronize_between_processes(self): 45 | """ 46 | Warning: does not synchronize the deque! 47 | """ 48 | if not is_dist_avail_and_initialized(): 49 | return 50 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 51 | dist.barrier() 52 | dist.all_reduce(t) 53 | t = t.tolist() 54 | self.count = int(t[0]) 55 | self.total = t[1] 56 | 57 | 58 | @property 59 | def median(self): 60 | d = torch.tensor(list(self.deque)) 61 | return d.median().item() 62 | 63 | 64 | @property 65 | def avg(self): 66 | d = torch.tensor(list(self.deque), dtype=torch.float32) 67 | return d.mean().item() 68 | 69 | 70 | @property 71 | def global_avg(self): 72 | return self.total / self.count 73 | 74 | 75 | @property 76 | def max(self): 77 | return max(self.deque) 78 | 79 | 80 | @property 81 | def value(self): 82 | return self.deque[-1] 83 | 84 | 85 | def __str__(self): 86 | return self.fmt.format( 87 | median=self.median, 88 | avg=self.avg, 89 | global_avg=self.global_avg, 90 | max=self.max, 91 | value=self.value) 92 | 93 | 94 | 95 | 96 | class MetricLogger(object): 97 | def __init__(self, delimiter="\t"): 98 | self.meters = defaultdict(SmoothedValue) 99 | self.delimiter = delimiter 100 | 101 | 102 | def update(self, **kwargs): 103 | for k, v in kwargs.items(): 104 | if isinstance(v, torch.Tensor): 105 | v = v.item() 106 | assert isinstance(v, (float, int)) 107 | self.meters[k].update(v) 108 | 109 | 110 | def __getattr__(self, attr): 111 | if attr in self.meters: 112 | return self.meters[attr] 113 | if attr in self.__dict__: 114 | return self.__dict__[attr] 115 | raise AttributeError("'{}' object has no attribute '{}'".format( 116 | type(self).__name__, attr)) 117 | 118 | 119 | def __str__(self): 120 | loss_str = [] 121 | for name, meter in self.meters.items(): 122 | loss_str.append( 123 | "{}: {}".format(name, str(meter)) 124 | ) 125 | return self.delimiter.join(loss_str) 126 | 127 | 128 | def synchronize_between_processes(self): 129 | for meter in self.meters.values(): 130 | meter.synchronize_between_processes() 131 | 132 | 133 | def add_meter(self, name, meter): 134 | self.meters[name] = meter 135 | 136 | 137 | #iterable is our data_loader which is pytorch data loader with our dataset_train obj and RA sampler 138 | def log_every(self, iterable, print_freq, header=None): 139 | i = 0 140 | if not header: 141 | header = '' 142 | start_time = time.time() 143 | end = time.time() 144 | iter_time = SmoothedValue(fmt='{avg:.4f}') 145 | data_time = SmoothedValue(fmt='{avg:.4f}') 146 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 147 | log_msg = [ 148 | header, 149 | '[{0' + space_fmt + '}/{1}]', 150 | 'eta: {eta}', 151 | '{meters}', 152 | 'time: {time}', 153 | 'data: {data}' 154 | ] 155 | if torch.cuda.is_available(): 156 | log_msg.append('max mem: {memory:.0f}') 157 | log_msg = self.delimiter.join(log_msg) 158 | MB = 1024.0 * 1024.0 159 | for obj in iterable: 160 | data_time.update(time.time() - end) 161 | # returns obj to caller, then continues loop 162 | yield obj 163 | iter_time.update(time.time() - end) 164 | if i % print_freq == 0 or i == len(iterable) - 1: 165 | eta_seconds = iter_time.global_avg * 
(len(iterable) - i) 166 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 167 | if torch.cuda.is_available(): 168 | print(log_msg.format( 169 | i, len(iterable), eta=eta_string, 170 | meters=str(self), 171 | time=str(iter_time), data=str(data_time), 172 | memory=torch.cuda.max_memory_allocated() / MB)) 173 | else: 174 | print(log_msg.format( 175 | i, len(iterable), eta=eta_string, 176 | meters=str(self), 177 | time=str(iter_time), data=str(data_time))) 178 | i += 1 179 | end = time.time() 180 | # need to remove this! 181 | total_time = time.time() - start_time 182 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 183 | print('{} Total time: {} ({:.4f} s / it)'.format( 184 | header, total_time_str, total_time / len(iterable))) 185 | 186 | 187 | 188 | 189 | def _load_checkpoint_for_ema(model_ema, checkpoint): 190 | """ 191 | Workaround for ModelEma._load_checkpoint to accept an already-loaded object 192 | """ 193 | mem_file = io.BytesIO() 194 | torch.save(checkpoint, mem_file) 195 | mem_file.seek(0) 196 | model_ema._load_checkpoint(mem_file) 197 | 198 | 199 | 200 | 201 | def setup_for_distributed(is_master): 202 | """ 203 | This function disables printing when not in master process 204 | """ 205 | import builtins as __builtin__ 206 | builtin_print = __builtin__.print 207 | 208 | 209 | def print(*args, **kwargs): 210 | force = kwargs.pop('force', False) 211 | if is_master or force: 212 | builtin_print(*args, **kwargs) 213 | 214 | 215 | __builtin__.print = print 216 | 217 | 218 | 219 | 220 | def is_dist_avail_and_initialized(): 221 | if not dist.is_available(): 222 | return False 223 | if not dist.is_initialized(): 224 | return False 225 | return True 226 | 227 | 228 | 229 | 230 | def get_world_size(): 231 | if not is_dist_avail_and_initialized(): 232 | return 1 233 | return dist.get_world_size() 234 | 235 | 236 | 237 | 238 | def get_rank(): 239 | if not is_dist_avail_and_initialized(): 240 | return 0 241 | return dist.get_rank() 242 | 243 | 244 | 245 | 246 | def is_main_process(): 247 | return get_rank() == 0 248 | 249 | 250 | 251 | 252 | def save_on_master(*args, **kwargs): 253 | if is_main_process(): 254 | torch.save(*args, **kwargs) 255 | 256 | 257 | 258 | 259 | def init_distributed_mode(args): 260 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 261 | args.rank = int(os.environ["RANK"]) 262 | args.world_size = int(os.environ['WORLD_SIZE']) 263 | args.gpu = int(os.environ['LOCAL_RANK']) 264 | elif 'SLURM_PROCID' in os.environ: 265 | args.rank = int(os.environ['SLURM_PROCID']) 266 | args.gpu = args.rank % torch.cuda.device_count() 267 | else: 268 | print('Not using distributed mode') 269 | args.distributed = False 270 | return 271 | 272 | 273 | args.distributed = True 274 | 275 | 276 | torch.cuda.set_device(args.gpu) 277 | args.dist_backend = 'nccl' 278 | print('| distributed init (rank {}): {}'.format( 279 | args.rank, args.dist_url), flush=True) 280 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 281 | world_size=args.world_size, rank=args.rank) 282 | torch.distributed.barrier() 283 | setup_for_distributed(args.rank == 0) 284 | 285 | def named_apply( 286 | fn: Callable, 287 | module: nn.Module, name='', 288 | depth_first: bool = True, 289 | include_root: bool = False, 290 | ) -> nn.Module: 291 | if not depth_first and include_root: 292 | fn(module=module, name=name) 293 | for child_name, child_module in module.named_children(): 294 | child_name = '.'.join((name, child_name)) if name else child_name 295 
| named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) 296 | if depth_first and include_root: 297 | fn(module=module, name=name) 298 | return module 299 | -------------------------------------------------------------------------------- /Robust/utils_robust.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | """ 4 | Misc functions, including distributed helpers. 5 | 6 | Mostly copy-paste from torchvision references. 7 | """ 8 | import io 9 | import os 10 | import time 11 | from collections import defaultdict, deque 12 | import datetime 13 | 14 | import torch 15 | import torch.distributed as dist 16 | 17 | data_loaders_names = { 18 | 'Brightness': 'brightness', 19 | 'Contrast': 'contrast', 20 | 'Defocus Blur': 'defocus_blur', 21 | 'Elastic Transform': 'elastic_transform', 22 | 'Fog': 'fog', 23 | 'Frost': 'frost', 24 | 'Gaussian Noise': 'gaussian_noise', 25 | 'Glass Blur': 'glass_blur', 26 | 'Impulse Noise': 'impulse_noise', 27 | 'JPEG Compression': 'jpeg_compression', 28 | 'Motion Blur': 'motion_blur', 29 | 'Pixelate': 'pixelate', 30 | 'Shot Noise': 'shot_noise', 31 | 'Snow': 'snow', 32 | 'Zoom Blur': 'zoom_blur' 33 | } 34 | 35 | def get_ce_alexnet(): 36 | """Returns Corruption Error values for AlexNet""" 37 | 38 | ce_alexnet = dict() 39 | ce_alexnet['Gaussian Noise'] = 0.886428 40 | ce_alexnet['Shot Noise'] = 0.894468 41 | ce_alexnet['Impulse Noise'] = 0.922640 42 | ce_alexnet['Defocus Blur'] = 0.819880 43 | ce_alexnet['Glass Blur'] = 0.826268 44 | ce_alexnet['Motion Blur'] = 0.785948 45 | ce_alexnet['Zoom Blur'] = 0.798360 46 | ce_alexnet['Snow'] = 0.866816 47 | ce_alexnet['Frost'] = 0.826572 48 | ce_alexnet['Fog'] = 0.819324 49 | ce_alexnet['Brightness'] = 0.564592 50 | ce_alexnet['Contrast'] = 0.853204 51 | ce_alexnet['Elastic Transform'] = 0.646056 52 | ce_alexnet['Pixelate'] = 0.717840 53 | ce_alexnet['JPEG Compression'] = 0.606500 54 | 55 | return ce_alexnet 56 | 57 | def get_mce_from_accuracy(accuracy, error_alexnet): 58 | """Computes mean Corruption Error from accuracy""" 59 | error = 100. - accuracy 60 | ce = error / (error_alexnet * 100.) 61 | 62 | return ce 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | -------------------------------------------------------------------------------- /Scaled_Attention/datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 
3 | import os 4 | import json 5 | 6 | 7 | from torchvision import datasets, transforms 8 | from torchvision.datasets.folder import ImageFolder, default_loader 9 | 10 | 11 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 12 | from timm.data import create_transform 13 | 14 | # dont think this is in use anymore 15 | PATH_TO_IMAGENET_VAL = '/path/data/imagenet/val' 16 | 17 | # dont think this is in use anymore 18 | def create_symlinks_to_imagenet(imagenet_folder, folder_to_scan): 19 | if not os.path.exists(imagenet_folder): 20 | os.makedirs(imagenet_folder) 21 | folders_of_interest = os.listdir(folder_to_scan) 22 | path_prefix = PATH_TO_IMAGENET_VAL 23 | for folder in folders_of_interest: 24 | os.symlink(path_prefix + folder, imagenet_folder+folder, target_is_directory=True) 25 | 26 | 27 | 28 | 29 | class INatDataset(ImageFolder): 30 | def __init__(self, root, train=True, year=2018, transform=None, target_transform=None, 31 | category='name', loader=default_loader): 32 | self.transform = transform 33 | self.loader = loader 34 | self.target_transform = target_transform 35 | self.year = year 36 | # assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name'] 37 | path_json = os.path.join(root, f'{"train" if train else "val"}{year}.json') 38 | with open(path_json) as json_file: 39 | data = json.load(json_file) 40 | 41 | 42 | with open(os.path.join(root, 'categories.json')) as json_file: 43 | data_catg = json.load(json_file) 44 | 45 | 46 | path_json_for_targeter = os.path.join(root, f"train{year}.json") 47 | 48 | 49 | with open(path_json_for_targeter) as json_file: 50 | data_for_targeter = json.load(json_file) 51 | 52 | 53 | targeter = {} 54 | indexer = 0 55 | for elem in data_for_targeter['annotations']: 56 | king = [] 57 | king.append(data_catg[int(elem['category_id'])][category]) 58 | if king[0] not in targeter.keys(): 59 | targeter[king[0]] = indexer 60 | indexer += 1 61 | self.nb_classes = len(targeter) 62 | 63 | 64 | self.samples = [] 65 | for elem in data['images']: 66 | cut = elem['file_name'].split('/') 67 | target_current = int(cut[2]) 68 | path_current = os.path.join(root, cut[0], cut[2], cut[3]) 69 | 70 | 71 | categors = data_catg[target_current] 72 | target_current_true = targeter[categors[category]] 73 | self.samples.append((path_current, target_current_true)) 74 | 75 | 76 | # __getitem__ and __len__ inherited from ImageFolder 77 | 78 | 79 | 80 | # called from main twice, once for training, once for val 81 | def build_dataset(is_train, args): 82 | transform = build_transform(is_train, args) 83 | 84 | 85 | if args.data_set == 'CIFAR100': 86 | dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform) 87 | nb_classes = 100 88 | if args.data_set == 'CIFAR10': 89 | dataset = datasets.CIFAR10(args.data_path, train=is_train, transform=transform) 90 | nb_classes = 10 91 | elif args.data_set == 'IMNET': 92 | root = os.path.join(args.data_path, 'train' if is_train else 'val') 93 | dataset = datasets.ImageFolder(root, transform=transform) 94 | class_names = dataset.classes 95 | nb_classes = 1000 96 | elif args.data_set == 'INAT': 97 | dataset = INatDataset(args.data_path, train=is_train, year=2018, 98 | category=args.inat_category, transform=transform) 99 | nb_classes = dataset.nb_classes 100 | elif args.data_set == 'INAT19': 101 | dataset = INatDataset(args.data_path, train=is_train, year=2019, 102 | category=args.inat_category, transform=transform) 103 | nb_classes = dataset.nb_classes 104 | 105 | 106 | 107 
| return dataset, nb_classes 108 | 109 | 110 | 111 | 112 | def build_transform(is_train, args): 113 | resize_im = args.input_size > 32 114 | if is_train: 115 | # this should always dispatch to transforms_imagenet_train 116 | transform = create_transform( 117 | input_size=args.input_size, 118 | is_training=True, 119 | color_jitter=args.color_jitter, 120 | auto_augment=args.aa, 121 | interpolation=args.train_interpolation, 122 | re_prob=args.reprob, 123 | re_mode=args.remode, 124 | re_count=args.recount, 125 | ) 126 | if not resize_im: 127 | # replace RandomResizedCropAndInterpolation with 128 | # RandomCrop 129 | transform.transforms[0] = transforms.RandomCrop( 130 | args.input_size, padding=4) 131 | return transform 132 | 133 | 134 | t = [] 135 | if resize_im: 136 | size = int((256 / 224) * args.input_size) 137 | t.append( 138 | transforms.Resize(size, interpolation=3), # to maintain same ratio w.r.t. 224 images 139 | ) 140 | t.append(transforms.CenterCrop(args.input_size)) 141 | 142 | 143 | t.append(transforms.ToTensor()) 144 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 145 | return transforms.Compose(t) 146 | 147 | 148 | 149 | -------------------------------------------------------------------------------- /Scaled_Attention/engine.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | """ 4 | Train and eval functions used in main_train.py 5 | """ 6 | import math 7 | import sys 8 | from typing import Iterable, Optional 9 | from fvcore.nn import FlopCountAnalysis 10 | import wandb 11 | import numpy as np 12 | 13 | 14 | import torch 15 | 16 | 17 | from timm.data import Mixup 18 | from timm.utils import accuracy, ModelEma 19 | 20 | from losses import DistillationLoss 21 | import utils 22 | import pdb 23 | 24 | def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss, 25 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 26 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 27 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, 28 | set_training_mode=True,wandb=False): 29 | # put our model in training mode... so that drop out and batch normalisation does not affect it 30 | model.train(set_training_mode) 31 | metric_logger = utils.MetricLogger(delimiter=" ") 32 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 33 | header = 'Epoch: [{}]'.format(epoch) 34 | print_freq = 10 35 | 36 | # i = 0. 
37 | for i, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 38 | samples = samples.to(device, non_blocking=True) 39 | targets = targets.to(device, non_blocking=True) 40 | # if i == 50: 41 | # break 42 | if mixup_fn is not None: 43 | samples, targets = mixup_fn(samples, targets) 44 | 45 | 46 | with torch.cuda.amp.autocast(): 47 | # flops = FlopCountAnalysis(model,samples) 48 | # print(flops.total()/1e9) 49 | # assert 1==2 50 | outputs = model(samples) 51 | loss = criterion(samples, outputs, targets) 52 | 53 | loss_value = loss.item() 54 | 55 | if not math.isfinite(loss_value): 56 | print("Loss is {}, stopping training".format(loss_value)) 57 | f = open("error.txt", "a") 58 | # writing in the file 59 | f.write("Loss is {}, stopping training".format(loss_value)) 60 | # closing the file 61 | f.close() 62 | sys.exit(1) 63 | 64 | 65 | optimizer.zero_grad() 66 | 67 | 68 | # this attribute is added by timm on one optimizer (adahessian) 69 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 70 | # provides optimisation step for model 71 | loss_scaler(loss, optimizer, clip_grad=max_norm, 72 | parameters=model.parameters(), create_graph=is_second_order) 73 | 74 | 75 | torch.cuda.synchronize() 76 | if model_ema is not None: 77 | model_ema.update(model) 78 | 79 | 80 | metric_logger.update(loss=loss_value) 81 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 82 | # gather the stats from all processes 83 | metric_logger.synchronize_between_processes() 84 | print("Averaged stats:", metric_logger) 85 | if wandb: 86 | for k, meter in metric_logger.meters.items(): 87 | wandb.log({k: meter.global_avg, 'epoch': epoch}) 88 | 89 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 90 | 91 | @torch.no_grad() 92 | def evaluate(data_loader, model, device, attn_only=False, batch_limit=0, epoch=0, wandb=False): 93 | criterion = torch.nn.CrossEntropyLoss() 94 | 95 | 96 | metric_logger = utils.MetricLogger(delimiter=" ") 97 | header = 'Test:' 98 | 99 | # switch to evaluation mode 100 | model.eval() 101 | # i = 0 102 | if not isinstance(batch_limit, int) or batch_limit < 0: 103 | batch_limit = 0 104 | attn = [] 105 | pi = [] 106 | for i, (images, target) in enumerate(metric_logger.log_every(data_loader, 10, header)): 107 | if i >= batch_limit > 0: 108 | break 109 | images = images.to(device, non_blocking=True) 110 | target = target.to(device, non_blocking=True) 111 | 112 | 113 | with torch.cuda.amp.autocast(): 114 | if attn_only: 115 | output, _aux = model(images) 116 | attn.append(_aux[0].detach().cpu().numpy()) 117 | pi.append(_aux[1].detach().cpu().numpy()) 118 | del _aux 119 | else: 120 | output = model(images) 121 | loss = criterion(output, target) 122 | 123 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 124 | 125 | batch_size = images.shape[0] 126 | metric_logger.update(loss=loss.item()) 127 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 128 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 129 | r = {k: meter.global_avg for k, meter in metric_logger.meters.items()} 130 | # gather the stats from all processes 131 | metric_logger.synchronize_between_processes() 132 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 133 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 134 | if wandb: 135 | for k, meter in metric_logger.meters.items(): 136 | wandb.log({f'test_{k}': meter.global_avg , 
'epoch':epoch}) 137 | 138 | if attn_only: 139 | return r, (attn, pi) 140 | return r 141 | 142 | 143 | -------------------------------------------------------------------------------- /Scaled_Attention/losses.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | """ 4 | Implements the knowledge distillation loss 5 | """ 6 | import torch 7 | from torch.nn import functional as F 8 | import timm 9 | 10 | class SoftTargetCrossEntropy(torch.nn.Module): 11 | 12 | def __init__(self): 13 | super(SoftTargetCrossEntropy, self).__init__() 14 | 15 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 16 | loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1) 17 | return loss.mean() 18 | 19 | class DistillationLoss(torch.nn.Module): 20 | """ 21 | This module wraps a standard criterion and adds an extra knowledge distillation loss by 22 | taking a teacher model prediction and using it as additional supervision. 23 | """ 24 | def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module, 25 | distillation_type: str, alpha: float, tau: float): 26 | super().__init__() 27 | self.base_criterion = base_criterion 28 | self.teacher_model = teacher_model 29 | assert distillation_type in ['none', 'soft', 'hard'] 30 | self.distillation_type = distillation_type 31 | self.alpha = alpha 32 | self.tau = tau 33 | 34 | 35 | def forward(self, inputs, outputs, labels): 36 | """ 37 | Args: 38 | inputs: The original inputs that are feed to the teacher model 39 | outputs: the outputs of the model to be trained. It is expected to be 40 | either a Tensor, or a Tuple[Tensor, Tensor], with the original output 41 | in the first position and the distillation predictions as the second output 42 | labels: the labels for the base criterion 43 | """ 44 | outputs_kd = None 45 | if not isinstance(outputs, torch.Tensor): 46 | # assume that the model outputs a tuple of [outputs, outputs_kd] 47 | outputs, outputs_kd = outputs 48 | base_loss = self.base_criterion(outputs, labels) 49 | if self.distillation_type == 'none': 50 | return base_loss 51 | 52 | 53 | if outputs_kd is None: 54 | raise ValueError("When knowledge distillation is enabled, the model is " 55 | "expected to return a Tuple[Tensor, Tensor] with the output of the " 56 | "class_token and the dist_token") 57 | # don't backprop throught the teacher 58 | with torch.no_grad(): 59 | teacher_outputs = self.teacher_model(inputs) 60 | 61 | 62 | if self.distillation_type == 'soft': 63 | T = self.tau 64 | # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100 65 | # with slight modifications 66 | distillation_loss = F.kl_div( 67 | F.log_softmax(outputs_kd / T, dim=1), 68 | F.log_softmax(teacher_outputs / T, dim=1), 69 | reduction='sum', 70 | log_target=True 71 | ) * (T * T) / outputs_kd.numel() 72 | elif self.distillation_type == 'hard': 73 | distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1)) 74 | 75 | 76 | loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha 77 | return loss 78 | 79 | 80 | 81 | -------------------------------------------------------------------------------- /Scaled_Attention/models.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 
3 | import torch 4 | import torch.nn as nn 5 | from functools import partial 6 | 7 | 8 | from timm.models.vision_transformer import _cfg 9 | from softmax import VisionTransformer 10 | from timm.models.registry import register_model 11 | from timm.models.layers import trunc_normal_ 12 | # from xcit import XCiT, HDPXCiT 13 | 14 | class DistilledVisionTransformer(VisionTransformer): 15 | def __init__(self, *args, **kwargs): 16 | super().__init__(*args, **kwargs) 17 | self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) 18 | num_patches = self.patch_embed.num_patches 19 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim)) 20 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity() 21 | 22 | trunc_normal_(self.dist_token, std=.02) 23 | trunc_normal_(self.pos_embed, std=.02) 24 | self.head_dist.apply(self._init_weights) 25 | 26 | def forward_features(self, x): 27 | # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 28 | # with slight modifications to add the dist_token 29 | B = x.shape[0] 30 | x = self.patch_embed(x) 31 | 32 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 33 | dist_token = self.dist_token.expand(B, -1, -1) 34 | x = torch.cat((cls_tokens, dist_token, x), dim=1) 35 | 36 | x = x + self.pos_embed 37 | x = self.pos_drop(x) 38 | 39 | for blk in self.blocks: 40 | x = blk(x) 41 | 42 | 43 | x = self.norm(x) 44 | return x[:, 0], x[:, 1] 45 | 46 | def forward(self, x): 47 | x, x_dist = self.forward_features(x) 48 | x = self.head(x) 49 | x_dist = self.head_dist(x_dist) 50 | if self.training: 51 | return x, x_dist 52 | else: 53 | # during inference, return the average of both classifier predictions 54 | return (x + x_dist) / 2 55 | 56 | # register model with timms to be able to call it from "create_model" using its function name 57 | # but mainly edit the model from softmax.py 58 | @register_model 59 | def deit_tiny_patch16_224(pretrained=True, **kwargs): 60 | from softmax import VisionTransformer 61 | model = VisionTransformer( 62 | patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True, 63 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) # Tan's NOTE: in the original code, num_heads = 3 here 64 | model.default_cfg = _cfg() 65 | return model 66 | 67 | 68 | -------------------------------------------------------------------------------- /Scaled_Attention/requirements.txt: -------------------------------------------------------------------------------- 1 | fvcore==0.1.5.post20221221 2 | numpy==1.25.2 3 | timm==0.9.7 4 | torch==2.0.1 5 | torchvision==0.15.2 6 | -------------------------------------------------------------------------------- /Scaled_Attention/run.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch --master_port 1 --nproc_per_node=4 --use_env main_train.py \ 2 | --model deit_tiny_patch16_224 --batch-size 256 --data-path /path/to/imagenet/ --output_dir /path/to/output/directory/ 3 | -------------------------------------------------------------------------------- /Scaled_Attention/samplers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 
3 | import torch 4 | import torch.distributed as dist 5 | import math 6 | 7 | 8 | 9 | 10 | class RASampler(torch.utils.data.Sampler): 11 | """Sampler that restricts data loading to a subset of the dataset for distributed, 12 | with repeated augmentation. 13 | It ensures that different each augmented version of a sample will be visible to a 14 | different process (GPU) 15 | Heavily based on torch.utils.data.DistributedSampler 16 | """ 17 | 18 | # num_replicas = world size, rank = global rank 19 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): 20 | if num_replicas is None: 21 | if not dist.is_available(): 22 | raise RuntimeError("Requires distributed package to be available") 23 | num_replicas = dist.get_world_size() 24 | if rank is None: 25 | if not dist.is_available(): 26 | raise RuntimeError("Requires distributed package to be available") 27 | rank = dist.get_rank() 28 | self.dataset = dataset 29 | self.num_replicas = num_replicas 30 | self.rank = rank 31 | self.epoch = 0 32 | self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas)) 33 | self.total_size = self.num_samples * self.num_replicas 34 | # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas)) 35 | self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas)) 36 | self.shuffle = shuffle 37 | 38 | 39 | def __iter__(self): 40 | # deterministically shuffle based on epoch 41 | g = torch.Generator() 42 | g.manual_seed(self.epoch) 43 | if self.shuffle: 44 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 45 | else: 46 | indices = list(range(len(self.dataset))) 47 | 48 | 49 | # add extra samples to make it evenly divisible by 3 50 | indices = [ele for ele in indices for i in range(3)] 51 | indices += indices[:(self.total_size - len(indices))] 52 | assert len(indices) == self.total_size 53 | 54 | 55 | # subsample 56 | indices = indices[self.rank:self.total_size:self.num_replicas] 57 | assert len(indices) == self.num_samples 58 | 59 | 60 | return iter(indices[:self.num_selected_samples]) 61 | 62 | 63 | def __len__(self): 64 | return self.num_selected_samples 65 | 66 | 67 | def set_epoch(self, epoch): 68 | self.epoch = epoch 69 | 70 | 71 | 72 | -------------------------------------------------------------------------------- /Scaled_Attention/softmax.py: -------------------------------------------------------------------------------- 1 | import math 2 | import logging 3 | from functools import partial 4 | from collections import OrderedDict 5 | from copy import deepcopy 6 | from statistics import mean 7 | import numpy as np 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | from timm.models.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_ 14 | from timm.models.vision_transformer import init_weights_vit_timm, _load_weights, init_weights_vit_jax 15 | from timm.models.helpers import build_model_with_cfg, named_apply, adapt_input_conv 16 | from utils import named_apply 17 | import copy 18 | 19 | 20 | 21 | class Attention(nn.Module): 22 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., layerth=0, ttl_tokens=0,s_scalar=False): 23 | super().__init__() 24 | assert dim % num_heads == 0, 'dim should be divisible by num_heads' 25 | self.num_heads = num_heads 26 | self.ttl_tokens = ttl_tokens 27 | self.layerth = layerth 28 | self.s_scalar = s_scalar 29 | head_dim = dim // num_heads 30 | # sqrt (D) 31 | self.scale = head_dim ** -0.5 
32 | 33 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 34 | if self.s_scalar: 35 | self.s = nn.Parameter(torch.zeros(1)) 36 | else: 37 | self.s = nn.Parameter(torch.zeros(self.num_heads, self.ttl_tokens, self.ttl_tokens)) 38 | 39 | self.attn_drop = nn.Dropout(attn_drop) 40 | 41 | self.proj = nn.Linear(dim, dim) 42 | self.proj_drop = nn.Dropout(proj_drop) 43 | 44 | def forward(self, x): 45 | B, N, C = x.shape 46 | # q,k -> B -> heads -> n -> features 47 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 48 | q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) 49 | 50 | attn = (q @ k.transpose(-2, -1)) * self.scale 51 | attn = attn.softmax(dim=-1) 52 | 53 | attn = self.attn_drop(attn) 54 | 55 | I = torch.eye(N,N).unsqueeze(dim=0).unsqueeze(dim=0).expand(B,self.num_heads,N,N).to(torch.device("cuda"), non_blocking=True) 56 | if self.s_scalar: 57 | sym_attn = (k @ k.transpose(-2, -1)) * self.scale 58 | sym_attn = sym_attn.softmax(dim=-1) 59 | v = (I-sym_attn * self.s) @ v 60 | else: 61 | s = self.s.unsqueeze(dim=0).expand(B,self.num_heads,N,N) 62 | v = (I-s) @ v 63 | 64 | x = (attn @ v) 65 | x = x.transpose(1, 2).reshape(B,N,C) 66 | 67 | x = self.proj(x) 68 | x = self.proj_drop(x) 69 | 70 | return x 71 | 72 | 73 | class Block(nn.Module): 74 | 75 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., 76 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, layerth = None,ttl_tokens=0,s_scalar=False): 77 | super().__init__() 78 | self.norm1 = norm_layer(dim) 79 | self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, 80 | attn_drop=attn_drop, proj_drop=drop,layerth=layerth,ttl_tokens=ttl_tokens,s_scalar=s_scalar) 81 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 82 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 83 | self.norm2 = norm_layer(dim) 84 | mlp_hidden_dim = int(dim * mlp_ratio) 85 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 86 | self.layerth = layerth 87 | 88 | def forward(self, x): 89 | 90 | x = x + self.drop_path(self.attn(self.norm1(x))) 91 | x = x + self.drop_path(self.mlp(self.norm2(x))) 92 | return x 93 | 94 | 95 | class VisionTransformer(nn.Module): 96 | """ Vision Transformer 97 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` 98 | - https://arxiv.org/abs/2010.11929 99 | Includes distillation token & head support for `DeiT: Data-efficient Image Transformers` 100 | - https://arxiv.org/abs/2012.12877 101 | """ 102 | 103 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 104 | num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False, 105 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None, 106 | act_layer=None, weight_init='',pretrained_cfg=None,pretrained_cfg_overlay=None,s_scalar=False): 107 | """ 108 | Args: 109 | img_size (int, tuple): input image size 110 | patch_size (int, tuple): patch size 111 | in_chans (int): number of input channels 112 | num_classes (int): number of classes for classification head 113 | embed_dim (int): embedding dimension 114 | depth (int): depth of transformer 115 | num_heads (int): number of attention heads 116 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 117 | qkv_bias (bool): enable bias for qkv if True 118 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 119 | distilled (bool): model includes a distillation token and head as in DeiT models 120 | drop_rate (float): dropout rate 121 | attn_drop_rate (float): attention dropout rate 122 | drop_path_rate (float): stochastic depth rate 123 | embed_layer (nn.Module): patch embedding layer 124 | norm_layer: (nn.Module): normalization layer 125 | weight_init: (str): weight init scheme 126 | """ 127 | super().__init__() 128 | self.num_classes = num_classes 129 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 130 | self.num_tokens = 2 if distilled else 1 131 | self.s_scalar = s_scalar 132 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 133 | act_layer = act_layer or nn.GELU 134 | 135 | self.patch_embed = embed_layer( 136 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 137 | num_patches = self.patch_embed.num_patches 138 | 139 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 140 | self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None 141 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) 142 | self.pos_drop = nn.Dropout(p=drop_rate) 143 | 144 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 145 | self.blocks = nn.Sequential(*[ 146 | Block( 147 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, 148 | attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer, layerth = i, 149 | ttl_tokens=num_patches+self.num_tokens,s_scalar=self.s_scalar) 150 | for i in range(depth)]) 151 | self.norm = norm_layer(embed_dim) 152 | 153 | # Representation layer 154 | if representation_size and not distilled: 155 | 
self.num_features = representation_size 156 | self.pre_logits = nn.Sequential(OrderedDict([f 157 | ('fc', nn.Linear(embed_dim, representation_size)), 158 | ('act', nn.Tanh()) 159 | ])) 160 | else: 161 | self.pre_logits = nn.Identity() 162 | 163 | # Classifier head(s) 164 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() 165 | self.head_dist = None 166 | if distilled: 167 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() 168 | 169 | self.init_weights(weight_init) 170 | 171 | def init_weights(self, mode=''): 172 | assert mode in ('jax', 'jax_nlhb', 'nlhb', '') 173 | head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. 174 | trunc_normal_(self.pos_embed, std=.02) 175 | if self.dist_token is not None: 176 | trunc_normal_(self.dist_token, std=.02) 177 | if mode.startswith('jax'): 178 | # leave cls token as zeros to match jax impl 179 | partial(init_weights_vit_jax(mode, head_bias), head_bias=head_bias, jax_impl=True) 180 | else: 181 | trunc_normal_(self.cls_token, std=.02) 182 | init_weights_vit_timm 183 | 184 | def _init_weights(self, m): 185 | # this fn left here for compat with downstream users 186 | init_weights(m) 187 | 188 | @torch.jit.ignore() 189 | def load_pretrained(self, checkpoint_path, prefix=''): 190 | _load_weights(self, checkpoint_path, prefix) 191 | 192 | @torch.jit.ignore 193 | def no_weight_decay(self): 194 | return {'pos_embed', 'cls_token', 'dist_token'} 195 | 196 | def get_classifier(self): 197 | if self.dist_token is None: 198 | return self.head 199 | else: 200 | return self.head, self.head_dist 201 | 202 | def reset_classifier(self, num_classes, global_pool=''): 203 | self.num_classes = num_classes 204 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 205 | if self.num_tokens == 2: 206 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() 207 | 208 | def forward_features(self, x): 209 | x = self.patch_embed(x) 210 | cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks 211 | if self.dist_token is None: 212 | x = torch.cat((cls_token, x), dim=1) 213 | else: 214 | x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1) 215 | x = self.pos_drop(x + self.pos_embed) 216 | x = self.blocks(x) 217 | x = self.norm(x) 218 | 219 | if self.dist_token is None: 220 | return self.pre_logits(x[:, 0]) 221 | else: 222 | return x[:, 0], x[:, 1] 223 | 224 | def forward(self, x): 225 | x = self.forward_features(x) 226 | if self.head_dist is not None: 227 | x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple 228 | if self.training and not torch.jit.is_scripting(): 229 | # during inference, return the average of both classifier predictions 230 | return x, x_dist 231 | else: 232 | return (x + x_dist) / 2 233 | else: 234 | x = self.head(x) 235 | return x 236 | 237 | 238 | 239 | 240 | -------------------------------------------------------------------------------- /Scaled_Attention/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | """ 4 | Misc functions, including distributed helpers. 5 | 6 | 7 | Mostly copy-paste from torchvision references. 
8 | """ 9 | import io 10 | import os 11 | import time 12 | from collections import defaultdict, deque 13 | import datetime 14 | from typing import Any, Callable, Dict, Iterator, Tuple, Type, Union 15 | 16 | 17 | import torch 18 | import torch.distributed as dist 19 | from torch import nn as nn 20 | 21 | 22 | 23 | class SmoothedValue(object): 24 | """Track a series of values and provide access to smoothed values over a 25 | window or the global series average. 26 | """ 27 | 28 | 29 | def __init__(self, window_size=20, fmt=None): 30 | if fmt is None: 31 | fmt = "{median:.4f} ({global_avg:.4f})" 32 | self.deque = deque(maxlen=window_size) 33 | self.total = 0.0 34 | self.count = 0 35 | self.fmt = fmt 36 | 37 | 38 | def update(self, value, n=1): 39 | self.deque.append(value) 40 | self.count += n 41 | self.total += value * n 42 | 43 | 44 | def synchronize_between_processes(self): 45 | """ 46 | Warning: does not synchronize the deque! 47 | """ 48 | if not is_dist_avail_and_initialized(): 49 | return 50 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 51 | dist.barrier() 52 | dist.all_reduce(t) 53 | t = t.tolist() 54 | self.count = int(t[0]) 55 | self.total = t[1] 56 | 57 | 58 | @property 59 | def median(self): 60 | d = torch.tensor(list(self.deque)) 61 | return d.median().item() 62 | 63 | 64 | @property 65 | def avg(self): 66 | d = torch.tensor(list(self.deque), dtype=torch.float32) 67 | return d.mean().item() 68 | 69 | 70 | @property 71 | def global_avg(self): 72 | return self.total / self.count 73 | 74 | 75 | @property 76 | def max(self): 77 | return max(self.deque) 78 | 79 | 80 | @property 81 | def value(self): 82 | return self.deque[-1] 83 | 84 | 85 | def __str__(self): 86 | return self.fmt.format( 87 | median=self.median, 88 | avg=self.avg, 89 | global_avg=self.global_avg, 90 | max=self.max, 91 | value=self.value) 92 | 93 | 94 | 95 | 96 | class MetricLogger(object): 97 | def __init__(self, delimiter="\t"): 98 | self.meters = defaultdict(SmoothedValue) 99 | self.delimiter = delimiter 100 | 101 | 102 | def update(self, **kwargs): 103 | for k, v in kwargs.items(): 104 | if isinstance(v, torch.Tensor): 105 | v = v.item() 106 | assert isinstance(v, (float, int)) 107 | self.meters[k].update(v) 108 | 109 | 110 | def __getattr__(self, attr): 111 | if attr in self.meters: 112 | return self.meters[attr] 113 | if attr in self.__dict__: 114 | return self.__dict__[attr] 115 | raise AttributeError("'{}' object has no attribute '{}'".format( 116 | type(self).__name__, attr)) 117 | 118 | 119 | def __str__(self): 120 | loss_str = [] 121 | for name, meter in self.meters.items(): 122 | loss_str.append( 123 | "{}: {}".format(name, str(meter)) 124 | ) 125 | return self.delimiter.join(loss_str) 126 | 127 | 128 | def synchronize_between_processes(self): 129 | for meter in self.meters.values(): 130 | meter.synchronize_between_processes() 131 | 132 | 133 | def add_meter(self, name, meter): 134 | self.meters[name] = meter 135 | 136 | 137 | #iterable is our data_loader which is pytorch data loader with our dataset_train obj and RA sampler 138 | def log_every(self, iterable, print_freq, header=None): 139 | i = 0 140 | if not header: 141 | header = '' 142 | start_time = time.time() 143 | end = time.time() 144 | iter_time = SmoothedValue(fmt='{avg:.4f}') 145 | data_time = SmoothedValue(fmt='{avg:.4f}') 146 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 147 | log_msg = [ 148 | header, 149 | '[{0' + space_fmt + '}/{1}]', 150 | 'eta: {eta}', 151 | '{meters}', 152 | 'time: {time}', 153 
| 'data: {data}' 154 | ] 155 | if torch.cuda.is_available(): 156 | log_msg.append('max mem: {memory:.0f}') 157 | log_msg = self.delimiter.join(log_msg) 158 | MB = 1024.0 * 1024.0 159 | for obj in iterable: 160 | data_time.update(time.time() - end) 161 | # returns obj to caller, then continues loop 162 | yield obj 163 | iter_time.update(time.time() - end) 164 | if i % print_freq == 0 or i == len(iterable) - 1: 165 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 166 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 167 | if torch.cuda.is_available(): 168 | print(log_msg.format( 169 | i, len(iterable), eta=eta_string, 170 | meters=str(self), 171 | time=str(iter_time), data=str(data_time), 172 | memory=torch.cuda.max_memory_allocated() / MB)) 173 | else: 174 | print(log_msg.format( 175 | i, len(iterable), eta=eta_string, 176 | meters=str(self), 177 | time=str(iter_time), data=str(data_time))) 178 | i += 1 179 | end = time.time() 180 | # need to remove this! 181 | total_time = time.time() - start_time 182 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 183 | print('{} Total time: {} ({:.4f} s / it)'.format( 184 | header, total_time_str, total_time / len(iterable))) 185 | 186 | 187 | 188 | 189 | def _load_checkpoint_for_ema(model_ema, checkpoint): 190 | """ 191 | Workaround for ModelEma._load_checkpoint to accept an already-loaded object 192 | """ 193 | mem_file = io.BytesIO() 194 | torch.save(checkpoint, mem_file) 195 | mem_file.seek(0) 196 | model_ema._load_checkpoint(mem_file) 197 | 198 | 199 | 200 | 201 | def setup_for_distributed(is_master): 202 | """ 203 | This function disables printing when not in master process 204 | """ 205 | import builtins as __builtin__ 206 | builtin_print = __builtin__.print 207 | 208 | 209 | def print(*args, **kwargs): 210 | force = kwargs.pop('force', False) 211 | if is_master or force: 212 | builtin_print(*args, **kwargs) 213 | 214 | 215 | __builtin__.print = print 216 | 217 | 218 | 219 | 220 | def is_dist_avail_and_initialized(): 221 | if not dist.is_available(): 222 | return False 223 | if not dist.is_initialized(): 224 | return False 225 | return True 226 | 227 | 228 | 229 | 230 | def get_world_size(): 231 | if not is_dist_avail_and_initialized(): 232 | return 1 233 | return dist.get_world_size() 234 | 235 | 236 | 237 | 238 | def get_rank(): 239 | if not is_dist_avail_and_initialized(): 240 | return 0 241 | return dist.get_rank() 242 | 243 | 244 | 245 | 246 | def is_main_process(): 247 | return get_rank() == 0 248 | 249 | 250 | 251 | 252 | def save_on_master(*args, **kwargs): 253 | if is_main_process(): 254 | torch.save(*args, **kwargs) 255 | 256 | 257 | 258 | 259 | def init_distributed_mode(args): 260 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 261 | args.rank = int(os.environ["RANK"]) 262 | args.world_size = int(os.environ['WORLD_SIZE']) 263 | args.gpu = int(os.environ['LOCAL_RANK']) 264 | elif 'SLURM_PROCID' in os.environ: 265 | args.rank = int(os.environ['SLURM_PROCID']) 266 | args.gpu = args.rank % torch.cuda.device_count() 267 | else: 268 | print('Not using distributed mode') 269 | args.distributed = False 270 | return 271 | 272 | 273 | args.distributed = True 274 | 275 | 276 | torch.cuda.set_device(args.gpu) 277 | args.dist_backend = 'nccl' 278 | print('| distributed init (rank {}): {}'.format( 279 | args.rank, args.dist_url), flush=True) 280 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 281 | world_size=args.world_size, rank=args.rank) 282 | 
torch.distributed.barrier() 283 | setup_for_distributed(args.rank == 0) 284 | 285 | def named_apply( 286 | fn: Callable, 287 | module: nn.Module, name='', 288 | depth_first: bool = True, 289 | include_root: bool = False, 290 | ) -> nn.Module: 291 | if not depth_first and include_root: 292 | fn(module=module, name=name) 293 | for child_name, child_module in module.named_children(): 294 | child_name = '.'.join((name, child_name)) if name else child_name 295 | named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) 296 | if depth_first and include_root: 297 | fn(module=module, name=name) 298 | return module 299 | --------------------------------------------------------------------------------