├── utils ├── __init__.py ├── optim.py ├── misc.py ├── masking.py └── experiman.py ├── models ├── __init__.py ├── clip │ ├── bpe_simple_vocab_16e6.txt.gz │ ├── templates │ │ ├── simple_template.py │ │ ├── iwildcam_template.py │ │ ├── __init__.py │ │ ├── utils.py │ │ ├── fmow_template.py │ │ ├── cifar_template.py │ │ └── openai_imagenet_template.py │ ├── README.md │ ├── __init__.py │ ├── zeroshot.py │ ├── simple_tokenizer.py │ ├── clip.py │ └── model.py └── classification.py ├── trainers ├── __init__.py ├── distill_trainer.py ├── standard_trainer.py └── base.py ├── data ├── imagenet_sketch.py ├── imagenet_v2.py ├── __init__.py ├── imagenet_r.py ├── imagenet_a.py ├── base.py ├── objectnet.py └── imagenet.py ├── losses └── distill.py ├── example.sh ├── README.md ├── .gitignore ├── run.sh ├── eval.py ├── LICENSE ├── main_standard.py └── main_distill.py /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import get_clip_model -------------------------------------------------------------------------------- /models/clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Coxy7/robust-finetuning/HEAD/models/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /models/clip/templates/simple_template.py: -------------------------------------------------------------------------------- 1 | from .utils import append_proper_article 2 | 3 | simple_template = [ 4 | lambda c: f"a photo of a {c}." 
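    # a single generic prompt; the zero-shot head (models/clip/zeroshot.py) averages the text embeddings of every prompt in a template, so richer templates such as openai_imagenet_template simply add more entries to a list like this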
5 | ] -------------------------------------------------------------------------------- /trainers/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import OptimizerWithSchedule 2 | from .standard_trainer import StandardTrainer, StandardLoopConfig 3 | from .distill_trainer import DistillTrainer -------------------------------------------------------------------------------- /models/clip/templates/iwildcam_template.py: -------------------------------------------------------------------------------- 1 | from .utils import append_proper_article, get_plural 2 | 3 | iwildcam_template = [ 4 | lambda c: f"a photo of {c}.", 5 | lambda c: f"{c} in the wild.", 6 | ] -------------------------------------------------------------------------------- /models/clip/README.md: -------------------------------------------------------------------------------- 1 | Modified from: 2 | https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip 3 | https://github.com/mlfoundations/wise-ft/blob/58b7a4b343b09dc06606aa929c2ef51accced8d1/src 4 | -------------------------------------------------------------------------------- /models/clip/templates/__init__.py: -------------------------------------------------------------------------------- 1 | from .cifar_template import cifar_template 2 | from .openai_imagenet_template import openai_imagenet_template 3 | from .simple_template import simple_template 4 | from .fmow_template import fmow_template 5 | from .iwildcam_template import iwildcam_template -------------------------------------------------------------------------------- /data/imagenet_sketch.py: -------------------------------------------------------------------------------- 1 | from .imagenet import EvaluationDataset 2 | 3 | 4 | class ImageNetSketchDataset(EvaluationDataset): 5 | 6 | def __init__(self, data_dir, size=224, interpolation='bicubic', transform='std'): 7 | super().__init__(data_dir, size, interpolation, transform) 8 | self.sub_dir = 'imagenet-sketch' 9 | -------------------------------------------------------------------------------- /data/imagenet_v2.py: -------------------------------------------------------------------------------- 1 | from .imagenet import EvaluationDataset 2 | 3 | 4 | # https://github.com/modestyachts/ImageNetV2/issues/6 5 | IMAGENET_V2_CLASSES = [int(s) for s in sorted([str(i) for i in range(1000)])] 6 | 7 | 8 | class ImageNetV2Dataset(EvaluationDataset): 9 | 10 | def __init__(self, data_dir, size=224, interpolation='bicubic', transform='std'): 11 | super().__init__(data_dir, size, interpolation, transform) 12 | self.sub_dir = 'imagenetv2-matched-frequency' 13 | self.target_transform = (lambda y: IMAGENET_V2_CLASSES[y]) 14 | -------------------------------------------------------------------------------- /losses/distill.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class KnowledgeDistillationLoss(nn.Module): 6 | 7 | def __init__(self, temp=10): 8 | super().__init__() 9 | self.criterion_kl = nn.KLDivLoss(reduction='batchmean') 10 | self.t = temp 11 | 12 | def forward(self, logits, soft_labels): 13 | kl = (self.t * self.t) * self.criterion_kl( 14 | F.log_softmax(logits / self.t, dim=1), 15 | F.softmax(soft_labels / self.t, dim=1) 16 | ) 17 | return kl 18 | -------------------------------------------------------------------------------- /models/clip/templates/utils.py: 
-------------------------------------------------------------------------------- 1 | 2 | def get_plural(name): 3 | name = name.replace('_', ' ') 4 | if name[-2:] == 'sh': 5 | name = name + 'es' 6 | elif name[-2:] == 'ch': 7 | name = name + 'es' 8 | elif name[-1:] == 'y': 9 | name = name[:-1] + 'ies' 10 | elif name[-1:] == 's': 11 | name = name + 'es' 12 | elif name[-1:] == 'x': 13 | name = name + 'es' 14 | elif name[-3:] == 'man': 15 | name = name[:-3] + 'men' 16 | elif name == 'mouse': 17 | name = 'mice' 18 | elif name[-1:] == 'f': 19 | name = name[:-1] + 'ves' 20 | else: 21 | name = name + 's' 22 | return name 23 | 24 | 25 | def append_proper_article(name): 26 | name = name.replace('_', ' ') 27 | if name[0] in 'aeiou': 28 | return 'an ' + name 29 | return 'a ' + name 30 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .imagenet import ImageNetDataset 2 | from .imagenet_a import ImageNetADataset 3 | from .imagenet_r import ImageNetRDataset 4 | from .imagenet_sketch import ImageNetSketchDataset 5 | from .imagenet_v2 import ImageNetV2Dataset 6 | from .objectnet import ObjectNetDataset 7 | 8 | 9 | def get_dataset(name, data_dir, **kwargs): 10 | if name == 'imagenet': 11 | return ImageNetDataset(data_dir, **kwargs) 12 | elif name == 'imagenet_a': 13 | return ImageNetADataset(data_dir, **kwargs) 14 | elif name == 'imagenet_r': 15 | return ImageNetRDataset(data_dir, **kwargs) 16 | elif name == 'imagenet_sketch': 17 | return ImageNetSketchDataset(data_dir, **kwargs) 18 | elif name == 'imagenet_v2': 19 | return ImageNetV2Dataset(data_dir, **kwargs) 20 | elif name == 'objectnet': 21 | return ObjectNetDataset(data_dir, **kwargs) 22 | else: 23 | raise NotImplementedError() 24 | -------------------------------------------------------------------------------- /models/clip/templates/fmow_template.py: -------------------------------------------------------------------------------- 1 | from .utils import append_proper_article, get_plural 2 | 3 | fmow_template = [ 4 | lambda c : f"satellite photo of a {c}.", 5 | lambda c : f"aerial photo of a {c}.", 6 | lambda c : f"satellite photo of {append_proper_article(c)}.", 7 | lambda c : f"aerial photo of {append_proper_article(c)}.", 8 | lambda c : f"satellite photo of a {c} in asia.", 9 | lambda c : f"aerial photo of a {c} in asia.", 10 | lambda c : f"satellite photo of a {c} in africa.", 11 | lambda c : f"aerial photo of a {c} in africa.", 12 | lambda c : f"satellite photo of a {c} in the americas.", 13 | lambda c : f"aerial photo of a {c} in the americas.", 14 | lambda c : f"satellite photo of a {c} in europe.", 15 | lambda c : f"aerial photo of a {c} in europe.", 16 | lambda c : f"satellite photo of a {c} in oceania.", 17 | lambda c : f"aerial photo of a {c} in oceania.", 18 | lambda c: f"a photo of a {c}.", 19 | lambda c: f"{c}.", 20 | ] 21 | -------------------------------------------------------------------------------- /models/clip/templates/cifar_template.py: -------------------------------------------------------------------------------- 1 | # https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/data/prompts.md 2 | 3 | 4 | cifar_template = [ 5 | lambda c: f'a photo of a {c}.', 6 | lambda c: f'a blurry photo of a {c}.', 7 | lambda c: f'a black and white photo of a {c}.', 8 | lambda c: f'a low contrast photo of a {c}.', 9 | lambda c: f'a high contrast photo of a {c}.', 10 | lambda c: f'a bad 
photo of a {c}.', 11 | lambda c: f'a good photo of a {c}.', 12 | lambda c: f'a photo of a small {c}.', 13 | lambda c: f'a photo of a big {c}.', 14 | lambda c: f'a photo of the {c}.', 15 | lambda c: f'a blurry photo of the {c}.', 16 | lambda c: f'a black and white photo of the {c}.', 17 | lambda c: f'a low contrast photo of the {c}.', 18 | lambda c: f'a high contrast photo of the {c}.', 19 | lambda c: f'a bad photo of the {c}.', 20 | lambda c: f'a good photo of the {c}.', 21 | lambda c: f'a photo of the small {c}.', 22 | lambda c: f'a photo of the big {c}.', 23 | ] 24 | -------------------------------------------------------------------------------- /models/classification.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class ClassificationModel(nn.Module): 4 | 5 | def __init__(self, num_classes, backbone, classifier, 6 | preprocess_fn=None, use_dataset_preprocess=True): 7 | super().__init__() 8 | self.num_classes = num_classes 9 | if preprocess_fn is not None: 10 | self.preprocess_fn = preprocess_fn 11 | else: 12 | self.preprocess_fn = lambda x: x 13 | self.use_dataset_preprocess = use_dataset_preprocess 14 | self.backbone = backbone 15 | self.classifier = classifier 16 | 17 | def forward(self, images, get_logits=True, get_features=False, **kwargs): 18 | x = self.preprocess_fn(images) 19 | f = self.backbone(x, **kwargs) 20 | if not get_logits: 21 | return f 22 | y = self.classifier(f) 23 | if get_features: 24 | return y, f 25 | else: 26 | return y 27 | 28 | def freeze_backbone(self, freeze=True): 29 | self.backbone.requires_grad_(not freeze) 30 | 31 | def set_preprocess(self, preprocess_fn): 32 | self.preprocess_fn = preprocess_fn 33 | -------------------------------------------------------------------------------- /data/imagenet_r.py: -------------------------------------------------------------------------------- 1 | from .imagenet import EvaluationDataset 2 | 3 | 4 | IMAGENET_R_CLASSES = [1, 2, 4, 6, 8, 9, 11, 13, 22, 23, 26, 29, 31, 39, 47, 63, 71, 76, 79, 84, 90, 94, 96, 97, 99, 100, 105, 107, 113, 122, 125, 130, 132, 144, 145, 147, 148, 150, 151, 155, 160, 161, 162, 163, 171, 172, 178, 187, 195, 199, 203, 207, 208, 219, 231, 232, 234, 235, 242, 245, 247, 250, 251, 254, 259, 260, 263, 265, 267, 269, 276, 277, 281, 288, 289, 291, 292, 293, 296, 299, 301, 308, 309, 310, 311, 314, 315, 319, 323, 327, 330, 334, 335, 337, 338, 340, 341, 344, 347, 353, 355, 361, 362, 365, 366, 367, 368, 372, 388, 390, 393, 397, 401, 407, 413, 414, 425, 428, 430, 435, 437, 441, 447, 448, 457, 462, 463, 469, 470, 471, 472, 476, 483, 487, 515, 546, 555, 558, 570, 579, 583, 587, 593, 594, 596, 609, 613, 617, 621, 629, 637, 657, 658, 701, 717, 724, 763, 768, 774, 776, 779, 780, 787, 805, 812, 815, 820, 824, 833, 847, 852, 866, 875, 883, 889, 895, 907, 928, 931, 932, 933, 934, 936, 937, 943, 945, 947, 948, 949, 951, 953, 954, 957, 963, 965, 967, 980, 981, 983, 988] 5 | 6 | 7 | class ImageNetRDataset(EvaluationDataset): 8 | 9 | def __init__(self, data_dir, size=224, interpolation='bicubic', transform='std'): 10 | super().__init__(data_dir, size, interpolation, transform) 11 | self.sub_dir = 'imagenet-r' 12 | concerned_classes = set(IMAGENET_R_CLASSES) 13 | self.ignored_classes = [i for i in range(self.num_classes) 14 | if i not in concerned_classes] 15 | self.target_transform = (lambda y: IMAGENET_R_CLASSES[y]) 16 | -------------------------------------------------------------------------------- /data/imagenet_a.py: 
-------------------------------------------------------------------------------- 1 | from .imagenet import EvaluationDataset 2 | 3 | 4 | IMAGENET_A_CLASSES = [6, 11, 13, 15, 17, 22, 23, 27, 30, 37, 39, 42, 47, 50, 57, 70, 71, 76, 79, 89, 90, 94, 96, 97, 99, 105, 107, 108, 110, 113, 124, 125, 130, 132, 143, 144, 150, 151, 207, 234, 235, 254, 277, 283, 287, 291, 295, 298, 301, 306, 307, 308, 309, 310, 311, 313, 314, 315, 317, 319, 323, 324, 326, 327, 330, 334, 335, 336, 347, 361, 363, 372, 378, 386, 397, 400, 401, 402, 404, 407, 411, 416, 417, 420, 425, 428, 430, 437, 438, 445, 456, 457, 461, 462, 470, 472, 483, 486, 488, 492, 496, 514, 516, 528, 530, 539, 542, 543, 549, 552, 557, 561, 562, 569, 572, 573, 575, 579, 589, 606, 607, 609, 614, 626, 627, 640, 641, 642, 643, 658, 668, 677, 682, 684, 687, 701, 704, 719, 736, 746, 749, 752, 758, 763, 765, 768, 773, 774, 776, 779, 780, 786, 792, 797, 802, 803, 804, 813, 815, 820, 823, 831, 833, 835, 839, 845, 847, 850, 859, 862, 870, 879, 880, 888, 890, 897, 900, 907, 913, 924, 932, 933, 934, 937, 943, 945, 947, 951, 954, 956, 957, 959, 971, 972, 980, 981, 984, 986, 987, 988] 5 | 6 | 7 | class ImageNetADataset(EvaluationDataset): 8 | 9 | def __init__(self, data_dir, size=224, interpolation='bicubic', transform='std'): 10 | super().__init__(data_dir, size, interpolation, transform) 11 | self.sub_dir = 'imagenet-a' 12 | concerned_classes = set(IMAGENET_A_CLASSES) 13 | self.ignored_classes = [i for i in range(self.num_classes) 14 | if i not in concerned_classes] 15 | self.target_transform = (lambda y: IMAGENET_A_CLASSES[y]) 16 | -------------------------------------------------------------------------------- /models/clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import load as load_clip 2 | from .zeroshot import get_zeroshot_classifier, ClassificationHead 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torchvision.transforms.functional as TF 7 | from models.classification import ClassificationModel 8 | 9 | CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073) 10 | CLIP_STD = (0.26862954, 0.26130258, 0.27577711) 11 | 12 | 13 | def get_clip_model(arch, dataset, variant, model_dir, device, 14 | get_zeroshot_weights=False): 15 | num_classes = dataset.num_classes 16 | clip_model, _ = load_clip( 17 | name=arch[5:], 18 | device=device, 19 | download_root=model_dir, 20 | ) 21 | clip_model.float() 22 | dtype = clip_model.dtype 23 | def clip_preprocess(images): 24 | images = images.type(dtype) 25 | return TF.normalize(images, CLIP_MEAN, CLIP_STD) 26 | if variant == 'std': # random init linear classifier 27 | backbone = clip_model.visual 28 | classifier = nn.Linear(backbone.output_dim, num_classes, device=device, dtype=dtype) 29 | model = ClassificationModel( 30 | num_classes=num_classes, 31 | backbone=backbone, 32 | classifier=classifier, 33 | preprocess_fn=clip_preprocess, 34 | use_dataset_preprocess=False, 35 | ).to(device) 36 | elif variant == 'zeroshot': 37 | backbone = clip_model.visual 38 | if get_zeroshot_weights: 39 | template = 'openai_imagenet_template' 40 | classifier = get_zeroshot_classifier( 41 | dataset, clip_model, device, 42 | template=template, 43 | ) 44 | else: 45 | classifier = ClassificationHead( 46 | normalize=True, 47 | weights=torch.zeros((num_classes, backbone.output_dim)), 48 | ) 49 | model = ClassificationModel( 50 | num_classes=num_classes, 51 | backbone=backbone, 52 | classifier=classifier, 53 | preprocess_fn=clip_preprocess, 54 | use_dataset_preprocess=False, 55 | 
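                # normalization with CLIP_MEAN / CLIP_STD is applied inside the model by clip_preprocess, so dataset-side preprocessing is disabled for this variant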
).to(device) 56 | else: 57 | raise NotImplementedError() 58 | return model 59 | -------------------------------------------------------------------------------- /models/clip/zeroshot.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | import numpy as np 4 | 5 | from . import templates, clip 6 | 7 | 8 | class ClassificationHead(torch.nn.Linear): 9 | def __init__(self, normalize, weights, biases=None): 10 | output_size, input_size = weights.shape 11 | super().__init__(input_size, output_size) 12 | self.normalize = normalize 13 | if weights is not None: 14 | self.weight = torch.nn.Parameter(weights.clone()) 15 | if biases is not None: 16 | self.bias = torch.nn.Parameter(biases.clone()) 17 | else: 18 | self.bias = torch.nn.Parameter(torch.zeros_like(self.bias)) 19 | 20 | def forward(self, inputs): 21 | if self.normalize: 22 | inputs = inputs / inputs.norm(dim=-1, keepdim=True) 23 | return super().forward(inputs) 24 | 25 | 26 | def get_zeroshot_classifier(dataset=None, clip_model=None, device=None, 27 | template='openai_imagenet_template'): 28 | template = getattr(templates, template) 29 | logit_scale = clip_model.logit_scale 30 | clip_model.eval() 31 | clip_model.to(device) 32 | 33 | print('Getting zeroshot weights.') 34 | with torch.no_grad(): 35 | zeroshot_weights = [] 36 | for class_name in tqdm(dataset.class_names): 37 | texts = [] 38 | for t in template: 39 | texts.append(t(class_name)) 40 | texts = clip.tokenize(texts).to(device) # tokenize 41 | embeddings = clip_model.encode_text(texts) # embed with text encoder 42 | embeddings /= embeddings.norm(dim=-1, keepdim=True) 43 | 44 | embeddings = embeddings.mean(dim=0, keepdim=True) 45 | embeddings /= embeddings.norm() 46 | 47 | zeroshot_weights.append(embeddings) 48 | 49 | zeroshot_weights = torch.stack(zeroshot_weights, dim=0).to(device) 50 | zeroshot_weights = torch.transpose(zeroshot_weights, 0, 2) 51 | 52 | zeroshot_weights *= logit_scale.exp() 53 | 54 | zeroshot_weights = zeroshot_weights.squeeze().float() 55 | zeroshot_weights = torch.transpose(zeroshot_weights, 0, 1) 56 | 57 | return ClassificationHead(normalize=True, weights=zeroshot_weights) 58 | -------------------------------------------------------------------------------- /data/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | from torch.utils.data.distributed import DistributedSampler 4 | from torch.utils.data.dataset import Subset 5 | import torchvision.transforms.functional as TF 6 | from sklearn.model_selection import train_test_split 7 | 8 | 9 | class DatasetWithIndex(torch.utils.data.Dataset): 10 | def __init__(self, dataset): 11 | super().__init__() 12 | self._dataset = dataset 13 | 14 | def __getitem__(self, index): 15 | item = self._dataset[index] 16 | return item, index 17 | 18 | def __len__(self): 19 | return len(self._dataset) 20 | 21 | def __getattr__(self, name): 22 | return getattr(self._dataset, name) 23 | 24 | 25 | 26 | def get_dataloader(dataset, shuffle=False, drop_last=False, with_index=False, num_replicas=1, rank=0, **kwargs): 27 | if with_index: 28 | dataset = DatasetWithIndex(dataset) 29 | if num_replicas > 1: 30 | sampler = DistributedSampler( 31 | dataset, num_replicas, rank, shuffle=shuffle) 32 | loader = torch.utils.data.DataLoader( 33 | dataset, sampler=sampler, drop_last=drop_last, **kwargs) 34 | else: 35 | loader = torch.utils.data.DataLoader( 36 | dataset, shuffle=shuffle, 
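            # single-process case: rely on the DataLoader's own shuffling instead of a DistributedSampler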
drop_last=drop_last, **kwargs) 37 | return loader 38 | 39 | 40 | def stratified_random_split(dataset, labels, train_split, seed=7): 41 | indices = list(range(len(dataset))) 42 | indices_train, indices_test = train_test_split( 43 | indices, 44 | train_size=train_split, 45 | random_state=seed, 46 | stratify=labels, 47 | ) 48 | trainset = Subset(dataset, indices_train) 49 | testset = Subset(dataset, indices_test) 50 | return trainset, testset 51 | 52 | 53 | class BaseDataset(): 54 | 55 | def __init__(self, data_dir, size=None, mean=None, std=None): 56 | self.data_dir = data_dir 57 | self.size = size 58 | self.mean = mean 59 | self.std = std 60 | 61 | def get_loader(self, batch_size, num_workers, with_index=False): 62 | raise NotImplementedError() 63 | 64 | def preprocess(self, images): 65 | if self.mean: 66 | return TF.normalize(images, self.mean, self.std) 67 | return images 68 | -------------------------------------------------------------------------------- /example.sh: -------------------------------------------------------------------------------- 1 | # This file shows some example usages of run.sh . # The following commands are for training. 2 | # Replace 'train' with 'eval' to evaluate the corresponding model. 3 | 4 | 5 | # Build zero-shot classification model for ImageNet 6 | # This model serves as the teacher model 7 | CUDA_VISIBLE_DEVICES=0 bash run.sh train 'clip_ViT-B/32' 'zeroshot' '' 0 8 | 9 | # Vanilla fine-tuning (FT) 10 | CUDA_VISIBLE_DEVICES=0,1,2,3 bash run.sh train 'clip_ViT-B/32' 'FT' '' 0 11 | 12 | # WiSE-FT (only calculates the ensemble; should be run after FT) 13 | CUDA_VISIBLE_DEVICES=0,1,2,3 bash run.sh train 'clip_ViT-B/32' 'WiSE-FT' '' 0 --wise_alpha 0.5 14 | 15 | # Fine-tuning (FT) + feature-based distillation (FD) without masking 16 | CUDA_VISIBLE_DEVICES=0,1,2,3 bash run.sh train 'clip_ViT-B/32' 'FT_FD' '' 0 17 | 18 | # The proposed method: FT + FD with masked images 19 | # The parameter in the bracket is the masking probability (for random-mask) 20 | # or CAM score threshold (for object-mask / context-mask) 21 | # (random-mask / object-mask / context-mask) + (no-fill) 22 | CUDA_VISIBLE_DEVICES=0,1,2,3 bash run.sh train 'clip_ViT-B/32' 'FT_FD_mae_mask' 'RandMaskNoFill(0.75)' 0 23 | CUDA_VISIBLE_DEVICES=0,1,2,3 bash run.sh train 'clip_ViT-B/32' 'FT_FD_attn_mask' 'ObjMaskNoFill(0.3)' 0 24 | CUDA_VISIBLE_DEVICES=0,1,2,3 bash run.sh train 'clip_ViT-B/32' 'FT_FD_attn_mask' 'CtxMaskNoFill(0.6)' 0 25 | # (random-mask / object-mask / context-mask) + (single-fill) 26 | CUDA_VISIBLE_DEVICES=0,1,2,3 bash run.sh train 'clip_ViT-B/32' 'FT_FD_image_mask' 'RandMaskSingleFill(0.5)' 0 27 | CUDA_VISIBLE_DEVICES=0,1,2,3 bash run.sh train 'clip_ViT-B/32' 'FT_FD_image_mask' 'ObjMaskSingleFill(0.3)' 0 28 | CUDA_VISIBLE_DEVICES=0,1,2,3 bash run.sh train 'clip_ViT-B/32' 'FT_FD_image_mask' 'CtxMaskSingleFill(0.5)' 0 29 | # (random-mask / object-mask / context-mask) + (multi-fill) 30 | CUDA_VISIBLE_DEVICES=0,1,2,3 bash run.sh train 'clip_ViT-B/32' 'FT_FD_image_mask' 'RandMaskMultiFill(0.5)' 0 31 | CUDA_VISIBLE_DEVICES=0,1,2,3 bash run.sh train 'clip_ViT-B/32' 'FT_FD_image_mask' 'ObjMaskMultiFill(0.6)' 0 32 | CUDA_VISIBLE_DEVICES=0,1,2,3 bash run.sh train 'clip_ViT-B/32' 'FT_FD_image_mask' 'CtxMaskMultiFill(0.3)' 0 33 | 34 | # Note: in case the default batch size 512 causes OOM for your GPU devices, 35 | # you may reduce the memory usage while keeping the effective batch size 36 | # by using gradient accumulation (batch * accum_steps = 512): 37 | CUDA_VISIBLE_DEVICES=0,1,2,3 bash run.sh 
train 'clip_ViT-B/32' 'FT_FD_image_mask' 'ObjMaskSingleFill(0.3)' 0 --batch 256 --accum_steps 2 38 | -------------------------------------------------------------------------------- /utils/optim.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch.optim as optim 3 | from trainers import OptimizerWithSchedule 4 | 5 | 6 | def get_optim(parameters, optimizer_name, lr, schedule, 7 | weight_decay, num_epochs, num_iters_train, 8 | cyclic_stepsize=None, multistep_milestones=None, 9 | onecycle_pct_start=0.25, adam_beta=0.5) -> OptimizerWithSchedule: 10 | 11 | if optimizer_name == 'sgd': 12 | optimizer = optim.SGD( 13 | parameters, lr=lr, momentum=0.9, weight_decay=weight_decay, 14 | nesterov=False) 15 | elif optimizer_name == 'sgd_nesterov': 16 | optimizer = optim.SGD( 17 | parameters, lr=lr, momentum=0.9, weight_decay=weight_decay, 18 | nesterov=True) 19 | elif optimizer_name == 'adam': 20 | optimizer = optim.Adam( 21 | parameters, lr=lr, weight_decay=weight_decay, 22 | betas=(adam_beta, 0.999)) 23 | elif optimizer_name == 'adamw': 24 | optimizer = optim.AdamW( 25 | parameters, lr=lr, weight_decay=weight_decay, 26 | betas=(adam_beta, 0.999)) 27 | 28 | if schedule == 'cos': 29 | scheduler = optim.lr_scheduler.CosineAnnealingLR( 30 | optimizer, T_max=num_epochs) 31 | schedule_step = 'epoch' 32 | elif schedule == 'sgdr': 33 | scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts( 34 | optimizer, T_0=10, T_mult=1) 35 | schedule_step = 'epoch' 36 | elif schedule == 'cyclic': 37 | if cyclic_stepsize is None: 38 | cyclic_stepsize = 0.5 * num_epochs 39 | scheduler = optim.lr_scheduler.CyclicLR( 40 | optimizer, base_lr=0, max_lr=lr, 41 | cycle_momentum=(optimizer_name == 'sgd'), 42 | step_size_up=int(cyclic_stepsize * num_iters_train)) 43 | schedule_step = 'iter' 44 | elif schedule == '1cycle': 45 | scheduler = optim.lr_scheduler.OneCycleLR( 46 | optimizer, max_lr=lr, 47 | epochs=num_epochs, steps_per_epoch=num_iters_train, 48 | pct_start=onecycle_pct_start, anneal_strategy='cos') 49 | schedule_step = 'iter' 50 | elif schedule == 'multistep': 51 | scheduler = optim.lr_scheduler.MultiStepLR( 52 | optimizer, milestones=multistep_milestones, gamma=0.1) 53 | schedule_step = 'epoch' 54 | elif schedule == 'none': 55 | scheduler = optim.lr_scheduler.LambdaLR( 56 | optimizer, (lambda epoch: 1)) 57 | schedule_step = 'epoch' 58 | 59 | config = OptimizerWithSchedule(optimizer, scheduler, schedule_step) 60 | return config 61 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Masked Images Are Counterfactual Samples for Robust Fine-tuning 2 | 3 | This repository is the official PyTorch implementation of _"Masked Images Are Counterfactual Samples for Robust Fine-tuning"_ [[paper](https://arxiv.org/abs/2303.03052)], accepted by **CVPR 2023**. 4 | 5 | ## Updates 6 | 7 | - 2023-03-24: Code released. 8 | 9 | ## Setups 10 | 11 | 12 | ### 0. System environment 13 | 14 | Our experiments are conducted on: 15 | - OS: Ubuntu 20.04.4 16 | - GPU: NVIDIA GeForce RTX 3090 17 | 18 | ### 1. Python environment 19 | 20 | - Python 3.9 21 | - PyTorch 1.11 22 | - cudatoolkit 11.3.1 23 | - torchvision 0.12.0 24 | - tensorboard 2.8.0 25 | - scikit-learn 1.0.2 26 | - [torchattacks](https://github.com/Harry24k/adversarial-attacks-pytorch) 27 | - tqdm 28 | 29 | ### 2. 
Prepare datasets 30 | 31 | The data directory (`DATA_DIR`) should contain the following sub-directories: 32 | - `ILSVRC2012`: [ImageNet](https://www.image-net.org) 33 | - `imagenet-a`: [ImageNet-A](https://github.com/hendrycks/natural-adv-examples) 34 | - `imagenet-r`: [ImageNet-R](https://github.com/hendrycks/imagenet-r) 35 | - `imagenet-sketch`: [ImageNet-Sketch](https://github.com/HaohanWang/ImageNet-Sketch) 36 | - `imagenetv2-matched-frequency`: [ImageNet-V2](https://github.com/modestyachts/ImageNetV2) 37 | - `objectnet-1.0`: [ObjectNet](https://objectnet.dev/) 38 | 39 | ### 3. Setup directories in `run.sh` 40 | 41 | Please modify lines 3-6 of the main script `run.sh` to set the proper directories: 42 | - `LOG_DIR`: root directory for the logs of all experiments and runs 43 | - `DATA_DIR`: the directory for all datasets as stated above 44 | - `MODEL_DIR`: the directory for pre-trained model weights (i.e., CLIP weights; the weights will be downloaded automatically if they do not exist) 45 | - `EXP_NAME`: experiment name, used as a sub-directory of `LOG_DIR` 46 | 47 | ## Code usage 48 | 49 | The bash script `run.sh` provides a uniform, simplified interface to the Python training and evaluation scripts. It accepts the following arguments: 50 | - script mode: whether to train or evaluate a model; can be `train`, `eval` or `train-eval` 51 | - architecture: `clip_{arch}`, where `{arch}` can be `ViT-B/32`, `ViT-B/16` or `ViT-L/14` 52 | - method: the training method (see `example.sh` or `run.sh` for available options) 53 | - masking: the masking strategy (see `example.sh`) 54 | - seed: an integer seed (note: we use three seeds (0, 1, 2) in the paper) 55 | - other arguments that are passed through to the Python scripts 56 | 57 | The following commands show an example of fine-tuning a CLIP ViT-B/32 model with our proposed method, using object-mask (threshold 0.3) & single-fill. Please refer to `example.sh` for more examples.
58 | ```bash 59 | # Build the zero-shot model 60 | CUDA_VISIBLE_DEVICES=0 bash run.sh train 'clip_ViT-B/32' 'zeroshot' '' 0 61 | # Fine-tune using our approach 62 | CUDA_VISIBLE_DEVICES=0,1,2,3 bash run.sh train 'clip_ViT-B/32' 'FT_FD_image_mask' 'ObjMaskSingleFill(0.3)' 0 63 | # Evaluate the fine-tuned model (replace `train` by `eval`) 64 | CUDA_VISIBLE_DEVICES=0,1,2,3 bash run.sh eval 'clip_ViT-B/32' 'FT_FD_image_mask' 'ObjMaskSingleFill(0.3)' 0 65 | ``` 66 | 67 | ## Results 68 | 69 | (WIP) 70 | 71 | ## Acknowledgement 72 | 73 | Some of the code in this repository is based on the following repositories: 74 | - CLIP: https://github.com/openai/CLIP 75 | - WiSE-FT: https://github.com/mlfoundations/wise-ft 76 | - CAM for ViT: https://github.com/hila-chefer/Transformer-MM-Explainability 77 | -------------------------------------------------------------------------------- /models/clip/templates/openai_imagenet_template.py: -------------------------------------------------------------------------------- 1 | # https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb 2 | 3 | 4 | openai_imagenet_template = [ 5 | lambda c: f'a bad photo of a {c}.', 6 | lambda c: f'a photo of many {c}.', 7 | lambda c: f'a sculpture of a {c}.', 8 | lambda c: f'a photo of the hard to see {c}.', 9 | lambda c: f'a low resolution photo of the {c}.', 10 | lambda c: f'a rendering of a {c}.', 11 | lambda c: f'graffiti of a {c}.', 12 | lambda c: f'a bad photo of the {c}.', 13 | lambda c: f'a cropped photo of the {c}.', 14 | lambda c: f'a tattoo of a {c}.', 15 | lambda c: f'the embroidered {c}.', 16 | lambda c: f'a photo of a hard to see {c}.', 17 | lambda c: f'a bright photo of a {c}.', 18 | lambda c: f'a photo of a clean {c}.', 19 | lambda c: f'a photo of a dirty {c}.', 20 | lambda c: f'a dark photo of the {c}.', 21 | lambda c: f'a drawing of a {c}.', 22 | lambda c: f'a photo of my {c}.', 23 | lambda c: f'the plastic {c}.', 24 | lambda c: f'a photo of the cool {c}.', 25 | lambda c: f'a close-up photo of a {c}.', 26 | lambda c: f'a black and white photo of the {c}.', 27 | lambda c: f'a painting of the {c}.', 28 | lambda c: f'a painting of a {c}.', 29 | lambda c: f'a pixelated photo of the {c}.', 30 | lambda c: f'a sculpture of the {c}.', 31 | lambda c: f'a bright photo of the {c}.', 32 | lambda c: f'a cropped photo of a {c}.', 33 | lambda c: f'a plastic {c}.', 34 | lambda c: f'a photo of the dirty {c}.', 35 | lambda c: f'a jpeg corrupted photo of a {c}.', 36 | lambda c: f'a blurry photo of the {c}.', 37 | lambda c: f'a photo of the {c}.', 38 | lambda c: f'a good photo of the {c}.', 39 | lambda c: f'a rendering of the {c}.', 40 | lambda c: f'a {c} in a video game.', 41 | lambda c: f'a photo of one {c}.', 42 | lambda c: f'a doodle of a {c}.', 43 | lambda c: f'a close-up photo of the {c}.', 44 | lambda c: f'a photo of a {c}.', 45 | lambda c: f'the origami {c}.', 46 | lambda c: f'the {c} in a video game.', 47 | lambda c: f'a sketch of a {c}.', 48 | lambda c: f'a doodle of the {c}.', 49 | lambda c: f'a origami {c}.', 50 | lambda c: f'a low resolution photo of a {c}.', 51 | lambda c: f'the toy {c}.', 52 | lambda c: f'a rendition of the {c}.', 53 | lambda c: f'a photo of the clean {c}.', 54 | lambda c: f'a photo of a large {c}.', 55 | lambda c: f'a rendition of a {c}.', 56 | lambda c: f'a photo of a nice {c}.', 57 | lambda c: f'a photo of a weird {c}.', 58 | lambda c: f'a blurry photo of a {c}.', 59 | lambda c: f'a cartoon {c}.', 60 | lambda c: f'art of a {c}.', 61 | lambda c: f'a sketch of the 
{c}.', 62 | lambda c: f'a embroidered {c}.', 63 | lambda c: f'a pixelated photo of a {c}.', 64 | lambda c: f'itap of the {c}.', 65 | lambda c: f'a jpeg corrupted photo of the {c}.', 66 | lambda c: f'a good photo of a {c}.', 67 | lambda c: f'a plushie {c}.', 68 | lambda c: f'a photo of the nice {c}.', 69 | lambda c: f'a photo of the small {c}.', 70 | lambda c: f'a photo of the weird {c}.', 71 | lambda c: f'the cartoon {c}.', 72 | lambda c: f'art of the {c}.', 73 | lambda c: f'a drawing of the {c}.', 74 | lambda c: f'a photo of the large {c}.', 75 | lambda c: f'a black and white photo of a {c}.', 76 | lambda c: f'the plushie {c}.', 77 | lambda c: f'a dark photo of a {c}.', 78 | lambda c: f'itap of a {c}.', 79 | lambda c: f'graffiti of the {c}.', 80 | lambda c: f'a toy {c}.', 81 | lambda c: f'itap of my {c}.', 82 | lambda c: f'a photo of a cool {c}.', 83 | lambda c: f'a photo of a small {c}.', 84 | lambda c: f'a tattoo of the {c}.', 85 | ] -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | .vscode/* 163 | !.vscode/settings.json 164 | !.vscode/tasks.json 165 | !.vscode/launch.json 166 | !.vscode/extensions.json 167 | !.vscode/*.code-snippets 168 | 169 | # Local History for Visual Studio Code 170 | .history/ 171 | 172 | # Built Visual Studio Code Extensions 173 | *.vsix 174 | -------------------------------------------------------------------------------- /data/objectnet.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import numpy as np 4 | from torchvision.datasets import ImageFolder 5 | import torchvision.transforms as transforms 6 | from .imagenet import ImageNetDataset 7 | from data.base import get_dataloader 8 | 9 | 10 | class ImageFolderWithSpecificClasses(ImageFolder): 11 | 12 | def __init__(self, *args, class_idx=None, **kwargs): 13 | self._specified_class_idx = class_idx 14 | super().__init__(*args, **kwargs) 15 | 16 | def find_classes(self, directory): 17 | all_classes, class_to_idx = super().find_classes(directory) 18 | classes = [all_classes[i] for i in self._specified_class_idx] 19 | class_set = set(classes) 20 | for cls in all_classes: 21 | if cls not in class_set: 22 | class_to_idx.pop(cls) 23 | return classes, class_to_idx 24 | 25 | 26 | class ObjectNetDataset(ImageNetDataset): 27 | 28 | def __init__(self, data_dir, size=224, interpolation='bicubic', transform='std'): 29 | super().__init__(data_dir, size, interpolation, transform) 30 | self.root_dir = os.path.join(self.data_dir, 'objectnet-1.0') 31 | ON_classes, IN_classes, ON_pid_to_IN_pids = \ 32 | self._parse_metadata(self.root_dir) 33 | self.ON_classes = ON_classes 34 | self.ignored_classes = [i for i in range(self.num_classes) 35 | if i not in IN_classes] 36 | self.target_transform = self._get_target_transform(ON_pid_to_IN_pids) 37 | self.transforms_test = transforms.Compose([ 38 | self.crop_red_border, 39 | 
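            # ObjectNet images carry a thin red border; crop_red_border trims 2 pixels on each side before the standard test transforms are applied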
self.transforms_test, 40 | ]) 41 | 42 | @staticmethod 43 | def crop_red_border(img): 44 | width, height = img.size 45 | cropArea = (2, 2, width - 2, height - 2) 46 | img = img.crop(cropArea) 47 | return img 48 | 49 | def _parse_metadata(self, root_dir): 50 | mapping_dir = os.path.join(root_dir, 'mappings') 51 | with open(os.path.join(mapping_dir, 'folder_to_objectnet_label.json'), 'r') as f: 52 | folder_to_ON_label = json.load(f) 53 | ON_pid_to_ON_label = [v for k, v in sorted(folder_to_ON_label.items())] 54 | with open(os.path.join(mapping_dir, 'objectnet_to_imagenet_1k.json'), 'r') as f: 55 | ON_label_to_IN_labels = {ON_label: IN_labels.split('; ') 56 | for ON_label, IN_labels in json.load(f).items()} 57 | with open(os.path.join(mapping_dir, 'imagenet_to_label_2012_v2'), 'r') as f: 58 | IN_label_to_IN_id = {v.strip(): i for i, v in enumerate(f)} 59 | with open(os.path.join(mapping_dir, 'pytorch_to_imagenet_2012_id.json'), 'r') as f: 60 | IN_pid_to_IN_id = json.load(f) 61 | IN_id_to_IN_pid = {id: int(pid) for pid, id in IN_pid_to_IN_id.items()} 62 | num_ON_classes = len(ON_pid_to_ON_label) 63 | ON_classes = [] 64 | IN_classes = set() 65 | ON_pid_to_IN_pids = [] 66 | for ON_pid in range(num_ON_classes): 67 | ON_label = ON_pid_to_ON_label[ON_pid] 68 | if ON_label in ON_label_to_IN_labels: 69 | IN_labels = ON_label_to_IN_labels[ON_label] 70 | IN_ids = [IN_label_to_IN_id[IN_label] for IN_label in IN_labels] 71 | IN_pids = [IN_id_to_IN_pid[IN_id] for IN_id in IN_ids] 72 | ON_classes.append(ON_pid) 73 | else: 74 | IN_pids = [] 75 | ON_pid_to_IN_pids.append(IN_pids) 76 | IN_classes |= set(IN_pids) 77 | return ON_classes, IN_classes, ON_pid_to_IN_pids 78 | 79 | def _get_target_transform(self, ON_pid_to_IN_pids): 80 | def transform(ON_pid): 81 | IN_pids = ON_pid_to_IN_pids[ON_pid] 82 | target = np.zeros(self.num_classes, dtype=int) 83 | target[IN_pids] = 1 84 | return target 85 | return transform 86 | 87 | def get_loader(self, batch_size, num_workers, with_index=False, 88 | train_split='original', val_size=0, split_seed=0, 89 | shuffle_test=False, augment=True, drop_last=True, 90 | world_size=1, rank=0): 91 | images_dir = os.path.join(self.root_dir, 'images') 92 | 93 | testset = ImageFolderWithSpecificClasses( 94 | root=images_dir, transform=self.transforms_test, 95 | target_transform=self.target_transform, 96 | class_idx=self.ON_classes) 97 | 98 | kwargs = dict(batch_size=batch_size, num_workers=num_workers, 99 | with_index=with_index, num_replicas=world_size, rank=rank) 100 | testloader = get_dataloader( 101 | testset, shuffle=shuffle_test, drop_last=False, **kwargs) 102 | return testloader 103 | -------------------------------------------------------------------------------- /models/clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 
22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] 
for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | trap "exit" INT 2 | 3 | LOG_DIR="$HOME/log/robust_ft" 4 | DATA_DIR="$HOME/data" 5 | MODEL_DIR="$HOME/model" 6 | EXP_NAME="release" 7 | NUM_WORKERS=6 8 | basic_args="--code_dir ./ --data_dir $DATA_DIR --log_dir $LOG_DIR --exp_name $EXP_NAME" 9 | 10 | mode=$1; shift 11 | 12 | model=$1; shift 13 | model_args="--arch $model --load_pretrained $MODEL_DIR/clip" 14 | model=${model//[-\/@]/_} # avoid '/' in filename (e.g. ViT-B/32 -> ViT_B_32) 15 | 16 | methods=$1; shift 17 | masking=$1; shift 18 | 19 | seed=$1; shift 20 | seed_args="--seed $seed --data_split_seed $seed --run_number $seed" 21 | 22 | do_train () { 23 | script_name=$1; shift 24 | num_devices=$(python -c 'import torch; print(torch.cuda.device_count())') 25 | port=$(python -c 'import socket; s=socket.socket(); s.bind(("", 0)); print(s.getsockname()[1]); s.close()') 26 | OMP_NUM_THREADS=2 \ 27 | torchrun --nnodes=1 --nproc_per_node=$num_devices --rdzv_endpoint="localhost:$port" \ 28 | "./$script_name.py" $basic_args $model_args $seed_args \ 29 | --num_workers $NUM_WORKERS --num_iters_trainset_test 0 "$@" 30 | } 31 | do_test () { 32 | num_devices=$(python -c 'import torch; print(torch.cuda.device_count())') 33 | port=$(python -c 'import socket; s=socket.socket(); s.bind(("", 0)); print(s.getsockname()[1]); s.close()') 34 | OMP_NUM_THREADS=2 \ 35 | torchrun --nnodes=1 --nproc_per_node=$num_devices --rdzv_endpoint="localhost:$port" \ 36 | ./eval.py $basic_args $model_args $seed_args --load_run_number $seed \ 37 | --num_workers $NUM_WORKERS "$@" 38 | } 39 | dummy(){ unused(){ :;} } 40 | 41 | if [[ $mode = 'train' ]]; then 42 | train=do_train 43 | test=dummy 44 | elif [[ $mode = 'eval' ]]; then 45 | train=dummy 46 | test=do_test 47 | elif [[ $mode = 'train-eval' ]]; then 48 | train=do_train 49 | test=do_test 50 | else 51 | echo "Incorrect script mode '$mode', should be: train / eval / train-eval" 52 | exit 1 53 | fi 54 | 55 | 56 | for method in $methods; do 57 | 58 | zeroshot_run_name="$model-zeroshot" 59 | wise_run_name="$model-WiSE-FT" 60 | teacher_run_name=$zeroshot_run_name 61 | 62 | if [[ $masking = '' ]]; then 63 | RUN_NAME="$model-$method" 64 | else 65 | RUN_NAME="$model-$method-$masking" 66 | fi 67 | 68 | if [[ $method = 'zeroshot' ]]; then 69 | RUN_NAME=$zeroshot_run_name 70 | $train 'main_standard' --run_name $RUN_NAME \ 71 | --epoch 1 --num_iters_train 0 \ 72 | --num_iters_trainset_test 0 --lr_schedule none \ 73 | "$@" 74 | elif [[ $method = 'LP' ]]; then # linear probe 75 | $train 'main_standard' --run_name $RUN_NAME \ 76 | --load_run_name $zeroshot_run_name \ 77 | --freeze_backbone --weight_decay 0 \ 78 | "$@" 79 | elif [[ $method = 'FT' ]]; then # end-to-end fine-tune 80 | $train 'main_standard' --run_name $RUN_NAME \ 81 | --load_run_name $zeroshot_run_name \ 82 | "$@" 83 | elif [[ $method = 'WiSE-FT' ]]; then 84 | RUN_NAME=$wise_run_name 85 | $train 'main_standard' --run_name $RUN_NAME \ 86 | --epoch 1 --num_iters_train 0 \ 87 | --num_iters_trainset_test 0 --lr_schedule none \ 88 | --load_run_name "$model-FT" --load_run_number $seed \ 89 | --wise_base_run_name $zeroshot_run_name \ 90 | "$@" 91 | elif [[ $method = 'FT_KD' ]]; then 92 | $train 'main_distill' --run_name $RUN_NAME \ 93 | --load_run_name 
$zeroshot_run_name \ 94 | --teacher_run_name $teacher_run_name \ 95 | --task std --distill_mode kd \ 96 | "$@" 97 | elif [[ $method = 'FT_KD_image_mask' ]]; then 98 | $train 'main_distill' --run_name $RUN_NAME \ 99 | --load_run_name $zeroshot_run_name \ 100 | --teacher_run_name $teacher_run_name \ 101 | --task std --distill_mode kd_image_mask \ 102 | --distill_masking $masking \ 103 | "$@" 104 | elif [[ $method = 'FT_FD' ]]; then 105 | $train 'main_distill' --run_name $RUN_NAME \ 106 | --load_run_name $zeroshot_run_name \ 107 | --teacher_run_name $teacher_run_name \ 108 | --task std --distill_mode fd \ 109 | "$@" 110 | elif [[ $method = 'FT_FD_image_mask' ]]; then 111 | $train 'main_distill' --run_name $RUN_NAME \ 112 | --load_run_name $zeroshot_run_name \ 113 | --teacher_run_name $teacher_run_name \ 114 | --task std --distill_mode fd_image_mask \ 115 | --distill_masking $masking \ 116 | "$@" 117 | elif [[ $method = 'FT_FD_mae_mask' ]]; then 118 | $train 'main_distill' --run_name $RUN_NAME \ 119 | --load_run_name $zeroshot_run_name \ 120 | --teacher_run_name $teacher_run_name \ 121 | --task std --distill_mode fd_mae_mask \ 122 | --distill_masking $masking \ 123 | "$@" 124 | elif [[ $method = 'FT_FD_attn_mask' ]]; then 125 | $train 'main_distill' --run_name $RUN_NAME \ 126 | --load_run_name $zeroshot_run_name \ 127 | --teacher_run_name $teacher_run_name \ 128 | --task std --distill_mode fd_attn_mask \ 129 | --distill_masking $masking \ 130 | "$@" 131 | fi 132 | 133 | $test --run_name "eval-$RUN_NAME" --load_run_name $RUN_NAME "$@" 134 | 135 | done 136 | -------------------------------------------------------------------------------- /trainers/distill_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | from trainers import StandardTrainer 5 | from trainers.base import inference_mode 6 | 7 | 8 | class DistillTrainer(StandardTrainer): 9 | 10 | def __init__(self, *args, teacher=None, masking=None, **kwargs): 11 | super().__init__(*args, **kwargs) 12 | assert teacher is not None 13 | self.teacher = teacher 14 | self.masking = masking 15 | self.masking = masking 16 | if self.opt.distill_mode != 'none': 17 | self.add_meter('loss_distill', 'Ld', loop_id=0, fstr_format='6.3f') 18 | self.add_meter('loss_task', 'Lt', loop_id=0, fstr_format='6.3f') 19 | if self.opt.save_mask: 20 | self.masks_to_save = [] 21 | 22 | def get_data_batch(self, loop_id, phase_id): 23 | batch = self._next_data_batch(loop_id) 24 | if self.opt.save_mask: 25 | batch = list(batch[0]) + [batch[1]] 26 | return [t.to(self.device) for t in batch] 27 | 28 | def save_mask(self, epoch_id, idx, mask): 29 | N, L = mask.size() 30 | p = torch.arange(0, L, dtype=torch.int64, device=mask.device) 31 | p = p.unsqueeze(0).repeat(N, 1) 32 | mask_int = (mask * 2 ** p).sum(1) 33 | self.masks_to_save.append(torch.stack([idx, mask_int], 1).cpu()) 34 | 35 | def do_epoch(self, epoch_id): 36 | super().do_epoch(epoch_id) 37 | if self.opt.save_mask: 38 | masks = torch.cat(self.masks_to_save, 0) 39 | path = os.path.join(self.manager.get_checkpoint_dir(), 40 | f'mask-{epoch_id}-{self.manager._rank}.pt') 41 | with open(path, 'w'): 42 | torch.save(masks, path) 43 | self.masks_to_save = [] 44 | 45 | def do_step_train(self, epoch_id, data_batch, config, n_accum_steps): 46 | model = self.models['model'] 47 | criterion_cls = self.criterions['classification'] 48 | if 'knowledge_distillation' in self.criterions: 49 | criterion_kd = 
self.criterions['knowledge_distillation'] 50 | if 'feature_distillation' in self.criterions: 51 | criterion_fd = self.criterions['feature_distillation'] 52 | if self.opt.save_mask: 53 | images, labels, idx = data_batch 54 | else: 55 | images, labels = data_batch 56 | 57 | logits = None 58 | 59 | # Distillation 60 | if self.opt.distill_mode == 'kd': 61 | with torch.no_grad(): 62 | logits_teacher = self.teacher(images) 63 | logits = model(images) 64 | loss_distill = criterion_kd(logits, logits_teacher) 65 | elif self.opt.distill_mode == 'kd_image_mask': 66 | images_masked = self.masking(images, labels) 67 | with torch.no_grad(): 68 | logits_teacher_masked = self.teacher(images_masked) 69 | logits_masked = model(images_masked) 70 | loss_distill = criterion_kd(logits_masked, logits_teacher_masked) 71 | elif self.opt.distill_mode == 'fd': 72 | with torch.no_grad(): 73 | features_teacher = self.teacher(images, get_logits=False, get_features=True) 74 | features = model(images, get_logits=False, get_features=True) 75 | loss_distill = criterion_fd(features, features_teacher) 76 | elif self.opt.distill_mode == 'fd_image_mask': 77 | images_masked = self.masking(images, labels) 78 | if self.opt.save_mask: 79 | images_masked, patch_mask = images_masked 80 | self.save_mask(epoch_id, idx, patch_mask) 81 | with torch.no_grad(): 82 | features_teacher_masked = self.teacher(images_masked, get_logits=False, get_features=True) 83 | features_masked = model(images_masked, get_logits=False, get_features=True) 84 | loss_distill = criterion_fd(features_masked, features_teacher_masked) 85 | elif self.opt.distill_mode == 'fd_mae_mask': 86 | mask = self.masking(images, labels) 87 | with torch.no_grad(): 88 | features_teacher_masked = self.teacher(images, mask=mask, get_logits=False, get_features=True) 89 | features_masked = model(images, mask=mask, get_logits=False, get_features=True) 90 | loss_distill = criterion_fd(features_masked, features_teacher_masked) 91 | elif self.opt.distill_mode == 'fd_attn_mask': 92 | attn_mask = self.masking(images, labels) 93 | with torch.no_grad(): 94 | features_teacher_masked = self.teacher(images, attn_mask=attn_mask, get_logits=False, get_features=True) 95 | features_masked = model(images, attn_mask=attn_mask, get_logits=False, get_features=True) 96 | loss_distill = criterion_fd(features_masked, features_teacher_masked) 97 | elif self.opt.distill_mode == 'none': 98 | loss_distill = 0 99 | else: 100 | raise NotImplementedError() 101 | 102 | # Task 103 | if self.opt.task == 'std': 104 | if logits is None: 105 | logits = model(images) 106 | loss_task = criterion_cls(logits, labels) 107 | else: 108 | raise NotImplementedError() 109 | 110 | loss = self.opt.w_task * loss_task + self.opt.w_distill * loss_distill 111 | loss /= n_accum_steps 112 | loss.backward() 113 | if self.opt.grad_clip: 114 | nn.utils.clip_grad_norm_(model.parameters(), self.opt.grad_clip) 115 | 116 | self.loop_meters['loss'].update(loss) 117 | self.loop_meters['loss_task'].update(loss_task) 118 | if self.opt.distill_mode != 'none': 119 | self.loop_meters['loss_distill'].update(loss_distill) 120 | self._update_acc_meter('acc', logits, labels) 121 | 122 | def do_step_test(self, data_batch, config): 123 | if self.opt.save_mask: 124 | data_batch = data_batch[:-1] 125 | super().do_step_test(data_batch, config) 126 | -------------------------------------------------------------------------------- /trainers/standard_trainer.py: -------------------------------------------------------------------------------- 1 | from __future__ 
import annotations 2 | from typing import Any, Union 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.utils.data import DataLoader 7 | 8 | from data.base import BaseDataset 9 | from trainers.base import ClassificationTrainer, LoopConfig, OptimizerWithSchedule, inference_mode 10 | from utils.experiman import ExperiMan 11 | 12 | 13 | class StandardLoopConfig(LoopConfig): 14 | 15 | def __init__( 16 | self, 17 | name: str, 18 | dataset: BaseDataset, 19 | dataloader: DataLoader, 20 | training: bool, 21 | n_iterations: int, 22 | n_phases: int = 1, 23 | n_logical_steps: Union[list[int], int] = 1, 24 | n_computation_steps: Union[list[int], int] = 1, 25 | run_every_n_epochs: int = 1, 26 | run_at_checkpoint: bool = True, 27 | run_at_last_epoch: bool = True, 28 | for_best_meter: bool = False, 29 | ): 30 | super().__init__( 31 | name, dataset, dataloader, training, n_iterations, 32 | n_phases, n_logical_steps, n_computation_steps, run_every_n_epochs, 33 | run_at_checkpoint, run_at_last_epoch, for_best_meter) 34 | 35 | 36 | class StandardTrainer(ClassificationTrainer): 37 | 38 | def __init__( 39 | self, 40 | manager: ExperiMan, 41 | models: dict[str, nn.Module], 42 | criterions: dict[str, nn.Module], 43 | n_epochs: int, 44 | loop_configs: list[LoopConfig], 45 | optimizers: dict[str, OptimizerWithSchedule], 46 | log_period: int, 47 | ckpt_period: int, 48 | device: torch.device, 49 | save_init_ckpt: bool = False, 50 | resume_ckpt: dict = None, 51 | num_classes: int = None, 52 | ignored_classes: Union[list[list], list] = None, 53 | keep_eval_mode: bool = False, 54 | acc_per_class: bool = False, 55 | ): 56 | self.opt = manager.get_opt() 57 | super().__init__( 58 | manager=manager, 59 | models=models, 60 | criterions=criterions, 61 | n_epochs=n_epochs, 62 | loop_configs=loop_configs, 63 | optimizers=optimizers, 64 | log_period=log_period, 65 | ckpt_period=ckpt_period, 66 | device=device, 67 | save_init_ckpt=save_init_ckpt, 68 | resume_ckpt=resume_ckpt, 69 | num_classes=num_classes, 70 | ignored_classes=ignored_classes, 71 | ) 72 | self.keep_eval_mode = keep_eval_mode 73 | self.acc_per_class = acc_per_class 74 | self.setup_meters() 75 | 76 | def setup_meters(self): 77 | def loops_satisfy(criterion): 78 | return [i for i, c in enumerate(self.loop_configs) if criterion(c)] 79 | self.add_meter('learning_rate', 'lr', 80 | meter_type='scaler', omit_from_results=True) 81 | self.add_meter('loss', 'L', 82 | loop_id=loops_satisfy(lambda c: c.training), 83 | fstr_format='6.3f') 84 | acc_meter_type = 'per_class_avg' if self.acc_per_class else 'avg' 85 | self.add_meter('acc', 'Acc', 86 | meter_type=acc_meter_type, fstr_format='5.2f') 87 | loop_for_best_meter = loops_satisfy(lambda c: c.for_best_meter) 88 | if loop_for_best_meter: 89 | loop_id = loop_for_best_meter[0] 90 | config = self.loop_configs[loop_id] 91 | best_meter = 'acc' 92 | self.set_meter_for_best_checkpoint( 93 | loop_id=loop_id, name=best_meter, maximum=True) 94 | 95 | def get_data_batch(self, loop_id, phase_id): 96 | batch = self._next_data_batch(loop_id) 97 | return [t.to(self.device) for t in batch] 98 | 99 | def get_active_optimizers(self, loop_id, phase_id): 100 | if self.loop_configs[loop_id].training: 101 | return [self.optimizers['optimizer']] 102 | else: 103 | return [] 104 | 105 | def get_checkpoint(self, epoch_id): 106 | checkpoint = super().get_checkpoint(epoch_id) 107 | if self._should_run_loop(epoch_id=epoch_id, loop_id=3): 108 | meters = self.meters[3] 109 | if 'acc' in meters: 110 | 
checkpoint['test_acc'] = meters['acc'].get_value() 111 | if self.acc_per_class: 112 | if 'acc' in meters: 113 | checkpoint['test_acc_per_class'] = \ 114 | meters['acc'].get_value(per_class_avg=False) 115 | return checkpoint 116 | 117 | def toggle_model_mode(self, epoch_id, loop_id): 118 | model = self.models['model'] 119 | training = self.loop_configs[loop_id].training 120 | model.train(training and not self.keep_eval_mode) 121 | bare_model = model.module if hasattr(model, 'module') else model 122 | if bare_model.use_dataset_preprocess: 123 | preprocess_fn = self.loop_configs[loop_id].dataset.preprocess 124 | bare_model.set_preprocess(preprocess_fn) 125 | 126 | def update_meters(self): 127 | if self.optimizers: 128 | lr = self.optimizers['optimizer'].get_learning_rates()[0] 129 | self.loop_meters['learning_rate'].update(lr) 130 | 131 | def do_step(self, epoch_id, loop_id, iter_id, phase_id, data_batch): 132 | config = self.loop_configs[loop_id] 133 | if config.training: 134 | self.do_step_train(epoch_id, data_batch, config, 135 | n_accum_steps=config.n_computation_steps[phase_id]) 136 | else: 137 | self.do_step_test(data_batch, config) 138 | 139 | def do_step_train(self, epoch_id, data_batch, config, n_accum_steps): 140 | model = self.models['model'] 141 | criterion_cls = self.criterions['classification'] 142 | images, labels = data_batch 143 | 144 | logits = model(images) 145 | loss = criterion_cls(logits, labels) 146 | 147 | loss /= n_accum_steps 148 | loss.backward() 149 | if self.opt.grad_clip: 150 | nn.utils.clip_grad_norm_(model.parameters(), self.opt.grad_clip) 151 | 152 | self.loop_meters['loss'].update(loss) 153 | self._update_acc_meter('acc', logits, labels) 154 | 155 | def do_step_test(self, data_batch, config): 156 | model = self.models['model'] 157 | images, labels = data_batch 158 | self._update_acc_meter('acc', model(images), labels) 159 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | 4 | 5 | class Accuracy(): 6 | 7 | def __init__(self, num_classes=None, involved_classes=None, 8 | ignored_classes=None, reduction='mean', denominator=100): 9 | if ignored_classes: 10 | assert involved_classes is None 11 | self.ignored_classes = ignored_classes 12 | ignored = set(ignored_classes) 13 | self.involved_classes = [i for i in range(num_classes) 14 | if i not in ignored] 15 | elif involved_classes: 16 | assert ignored_classes is None 17 | self.involved_classes = involved_classes 18 | involved = set(involved_classes) 19 | self.ignored_classes = [i for i in range(num_classes) 20 | if i not in involved] 21 | else: 22 | self.involved_classes = list(range(num_classes)) 23 | self.ignored_classes = [] 24 | self.involved_classes_set = set(self.involved_classes) 25 | self.reduction = reduction 26 | self.denominator = denominator 27 | 28 | def __call__(self, outputs, labels, reduction=None): 29 | outputs = outputs.detach().clone() 30 | labels = labels.detach().clone() 31 | 32 | # mask out samples with ignored classes 33 | if labels.dim() == 1: # regular labels 34 | mask = labels.clone().cpu().apply_( 35 | lambda c: c in self.involved_classes_set) 36 | mask = mask.to(dtype=bool, device=labels.device) 37 | else: # one-hot / multi labels 38 | mask = (labels[:, self.involved_classes].sum(-1) > 0) 39 | outputs = outputs[mask] 40 | labels = labels[mask] 41 | 42 | # predict from outputs 43 | if outputs.dim() == 1: # 
hard predictions 44 | predictions = outputs 45 | else: # logits / soft predictions 46 | if len(outputs): 47 | outputs[:, self.ignored_classes] = outputs.min() - 1 48 | predictions = outputs.argmax(-1) 49 | 50 | # decide correctness 51 | if labels.dim() == 1: 52 | correct = (predictions == labels) 53 | else: 54 | correct = labels.gather(1, predictions.unsqueeze(1)) 55 | 56 | # produce results 57 | correct = correct.float() * self.denominator 58 | if reduction is None: 59 | reduction = self.reduction 60 | if reduction == 'sum': 61 | return correct.sum(), len(correct) 62 | elif reduction == 'mean': 63 | return correct.mean() 64 | elif reduction == 'none': 65 | return correct 66 | 67 | 68 | class ScalerMeter(object): 69 | 70 | def __init__(self): 71 | self.x = None 72 | 73 | def update(self, x): 74 | if not isinstance(x, (int, float)): 75 | x = x.item() 76 | self.x = x 77 | 78 | def reset(self): 79 | self.x = None 80 | 81 | def get_value(self): 82 | if self.x: 83 | return self.x 84 | return 0 85 | 86 | def sync(self, device): 87 | pass 88 | 89 | 90 | class AverageMeter(object): 91 | 92 | def __init__(self): 93 | self.sum = 0 94 | self.n = 0 95 | 96 | def update(self, x, n=1): 97 | self.sum += float(x) 98 | self.n += int(n) 99 | 100 | def reset(self): 101 | self.sum = 0 102 | self.n = 0 103 | 104 | def get_value(self): 105 | if self.n: 106 | return self.sum / self.n 107 | return 0 108 | 109 | def sync(self, device): 110 | t = torch.tensor([self.sum, self.n], 111 | dtype=torch.float32, device=device) 112 | dist.all_reduce(t, op=dist.ReduceOp.SUM) 113 | self.sum = t[0].item() 114 | self.n = round(t[1].item()) 115 | 116 | 117 | class MovingAverageMeter(object): 118 | 119 | def __init__(self, decay=0.95): 120 | self.x = None 121 | self.decay = decay 122 | 123 | def update(self, x, n=1): 124 | if n > 0: 125 | x = float(x) / int(n) 126 | if self.x is None: 127 | self.x = x 128 | else: 129 | self.x = self.x * self.decay + x * (1 - self.decay) 130 | 131 | def reset(self): 132 | self.x = None 133 | 134 | def get_value(self): 135 | if self.x: 136 | return self.x 137 | return 0 138 | 139 | def sync(self, device): 140 | if self.x is not None: 141 | t = torch.tensor([self.x], dtype=torch.float32, device=device) 142 | dist.all_reduce(t, op=dist.ReduceOp.SUM) 143 | self.x = t[0].item() / dist.get_world_size() 144 | 145 | 146 | class PerClassMeter(object): 147 | 148 | def __init__(self, meter, num_classes=None, **kwargs): 149 | self.meter = meter 150 | self.num_classes = num_classes or 0 151 | self.kwargs = kwargs 152 | self.meters = [meter(**kwargs) for _ in range(self.num_classes)] 153 | 154 | def update(self, x, y): 155 | n = int(max(y)) 156 | if n > self.num_classes: 157 | self.meters += [self.meter(**self.kwargs) 158 | for _ in range(n - self.num_classes)] 159 | self.num_classes = n 160 | for i in range(self.num_classes): 161 | mask = (y == i) 162 | self.meters[i].update(sum(x[mask]), sum(mask)) 163 | 164 | def reset(self): 165 | for meter in self.meters: 166 | meter.reset() 167 | 168 | def get_value(self, per_class_avg=True): 169 | values = [meter.get_value() for meter in self.meters] 170 | if per_class_avg: 171 | return sum(values) / len(values) 172 | else: 173 | return values 174 | 175 | def sync(self, device): 176 | for meter in self.meters: 177 | meter.sync(device) 178 | 179 | 180 | def consume_prefix_in_state_dict_if_present(state_dict, prefix): 181 | r"""Strip the prefix in state_dict, if any. 
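For illustration, a minimal in-place renaming example (the key names here are hypothetical, not taken from any checkpoint in this repository):
    >>> sd = {"module.fc.weight": 0}
    >>> consume_prefix_in_state_dict_if_present(sd, "module.")
    >>> list(sd.keys())
    ['fc.weight']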
182 | ..note:: 183 | Given a `state_dict` from a DP/DDP model, a local model can load it by applying 184 | `consume_prefix_in_state_dict_if_present(state_dict, "module.")` before calling 185 | :meth:`torch.nn.Module.load_state_dict`. 186 | Args: 187 | state_dict (OrderedDict): a state-dict to be loaded to the model. 188 | prefix (str): prefix. 189 | """ 190 | keys = sorted(state_dict.keys()) 191 | for key in keys: 192 | if key.startswith(prefix): 193 | newkey = key[len(prefix) :] 194 | state_dict[newkey] = state_dict.pop(key) 195 | 196 | # also strip the prefix in metadata if any. 197 | if "_metadata" in state_dict: 198 | metadata = state_dict["_metadata"] 199 | for key in list(metadata.keys()): 200 | # for the metadata dict, the key can be: 201 | # '': for the DDP module, which we want to remove. 202 | # 'module': for the actual model. 203 | # 'module.xx.xx': for the rest. 204 | 205 | if len(key) == 0: 206 | continue 207 | newkey = key[len(prefix) :] 208 | metadata[newkey] = metadata.pop(key) 209 | 210 | def parse(arg, default): 211 | if arg is None: 212 | return default 213 | return arg 214 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.distributed as dist 5 | from torch.nn.parallel import DistributedDataParallel 6 | 7 | from utils.experiman import manager 8 | from data import * 9 | from models import get_clip_model 10 | from trainers import StandardTrainer, StandardLoopConfig 11 | from utils.misc import parse 12 | 13 | 14 | def add_parser_argument(parser): 15 | ## ======================== Data ========================== 16 | parser.add_argument('--dataset', type=str, 17 | default='imagenet,imagenet_v2,imagenet_r,imagenet_sketch,objectnet,imagenet_a') 18 | parser.add_argument('--train_split', default='original', type=str) 19 | parser.add_argument('--val_size', default=10240, type=int) 20 | parser.add_argument('--data_split_seed', default=0, type=int) 21 | parser.add_argument('--batch', default=512, type=int) 22 | parser.add_argument('--num_workers', default=1, type=int) 23 | parser.add_argument('--image_size', default=224, type=int) 24 | parser.add_argument('--transform', default='clip', type=str) 25 | parser.add_argument('--same_ignored_classes', action='store_true') 26 | ## ======================= Model ========================== 27 | parser.add_argument('--arch', type=str) 28 | parser.add_argument('--arch_variant', default='zeroshot', type=str) 29 | parser.add_argument('--load_pretrained', type=str) 30 | parser.add_argument('--load_ckpt', type=str) 31 | parser.add_argument('--load_run_name', type=str) 32 | parser.add_argument('--load_run_number', type=str) 33 | parser.add_argument('--load_run_ckpt_name', type=str, default='ckpt-best') 34 | parser.add_argument('--freeze_backbone', action='store_true') 35 | ## ===================== Evaluation ======================= 36 | parser.add_argument('--num_iters_test', type=int, 37 | help="default: len(testloader)") 38 | ## ====================== Logging ========================= 39 | parser.add_argument('--log_period', default=5, type=int, metavar='LP', 40 | help='log every LP iterations') 41 | parser.add_argument('--ckpt_period', type=int, metavar='CP', 42 | help='make checkpoints every CP epochs') 43 | parser.add_argument('--comment', default='', type=str) 44 | ## ==================== Experimental ====================== 45 | 46 | 47 | def 
main(): 48 | local_rank = int(os.environ["LOCAL_RANK"]) 49 | rank = int(os.environ["RANK"]) 50 | world_size = int(os.environ["WORLD_SIZE"]) 51 | device = torch.device(local_rank) 52 | torch.cuda.set_device(device) 53 | 54 | # Parse arguments and setup ExperiMan 55 | parser = manager.get_basic_arg_parser() 56 | add_parser_argument(parser) 57 | opt = parser.parse_args() 58 | manager.setup(opt, rank=rank, world_size=world_size, 59 | third_party_tools=('tensorboard',)) 60 | if world_size > 1: 61 | dist.init_process_group("nccl") 62 | if rank == 0: 63 | t = torch.tensor([opt.run_number + .1], device=device) 64 | else: 65 | t = torch.empty(1, device=device) 66 | dist.broadcast(t, src=0) 67 | opt.run_number = int(t.item()) 68 | manager.set_run_dir(manager.get_run_dir(opt.run_name, opt.run_number)) 69 | logger = manager.get_logger() 70 | logger.info(f'==> Number of devices: {world_size}') 71 | use_clip = opt.arch.startswith('clip') 72 | 73 | # Data 74 | logger.info('==> Preparing data') 75 | assert opt.batch % world_size == 0 76 | batch = opt.batch // world_size 77 | data_kwargs = dict( 78 | batch_size=batch, num_workers=opt.num_workers, with_index=False, 79 | train_split=opt.train_split, val_size=opt.val_size, 80 | split_seed=opt.data_split_seed, 81 | world_size=world_size, rank=rank) 82 | dataset_names = opt.dataset.split(',') 83 | datasets = [] 84 | testloaders = [] 85 | num_classes = None 86 | ignored_classes = set() # get the union of ignored classes 87 | for dataset_name in dataset_names: 88 | dataset = get_dataset( 89 | dataset_name, opt.data_dir, size=opt.image_size, transform=opt.transform) 90 | loader = dataset.get_loader(**data_kwargs) 91 | testloader = loader[-1] if isinstance(loader, (list, tuple)) else loader 92 | datasets.append(dataset) 93 | testloaders.append(testloader) 94 | num_classes = num_classes or dataset.num_classes 95 | assert num_classes == dataset.num_classes 96 | if hasattr(dataset, 'ignored_classes'): 97 | ignored_classes |= set(dataset.ignored_classes) 98 | if opt.same_ignored_classes: 99 | ignored_classes = list(ignored_classes) 100 | else: 101 | ignored_classes = None 102 | 103 | # Model 104 | logger.info('==> Building models') 105 | if use_clip: 106 | model = get_clip_model( 107 | arch=opt.arch, 108 | dataset=dataset, 109 | variant=opt.arch_variant, 110 | model_dir=opt.load_pretrained, 111 | device=device, 112 | get_zeroshot_weights=(not (opt.load_ckpt or opt.load_run_name)), 113 | ) 114 | else: 115 | raise NotImplementedError() 116 | if world_size > 1: 117 | model = DistributedDataParallel(model, device_ids=[local_rank]) 118 | 119 | # Load 120 | bare_model = model.module if world_size > 1 else model 121 | if opt.load_ckpt: 122 | load_path = opt.load_ckpt 123 | else: 124 | ckpt_dir = manager.get_checkpoint_dir( 125 | opt.load_run_name, opt.load_run_number) 126 | load_path = os.path.join(ckpt_dir, f'{opt.load_run_ckpt_name}.pt') 127 | logger.info(f'==> Loading model from {load_path}') 128 | checkpoint = torch.load(load_path, map_location='cpu') 129 | bare_model.load_state_dict(checkpoint['model']) 130 | 131 | # Trainer 132 | loop_configs = [] 133 | for dataset_name, dataset, testloader in zip(dataset_names, datasets, testloaders): 134 | if isinstance(testloader, dict): 135 | for split_name, split_loader in testloader.items(): 136 | num_iters_test = parse(opt.num_iters_test, len(split_loader)) 137 | config = StandardLoopConfig( 138 | f'{dataset_name}-{split_name}', dataset, split_loader, 139 | training=False, n_iterations=num_iters_test) 140 | 
loop_configs.append(config) 141 | else: 142 | num_iters_test = parse(opt.num_iters_test, len(testloader)) 143 | config = StandardLoopConfig( 144 | dataset_name, dataset, testloader, 145 | training=False, n_iterations=num_iters_test) 146 | loop_configs.append(config) 147 | trainer = StandardTrainer( 148 | manager=manager, 149 | models={'model': model}, 150 | criterions={}, 151 | n_epochs=1, 152 | loop_configs=loop_configs, 153 | optimizers={}, 154 | log_period=opt.log_period, 155 | ckpt_period=opt.ckpt_period, 156 | device=device, 157 | num_classes=num_classes, 158 | ignored_classes=ignored_classes, 159 | ) 160 | 161 | trainer.test() 162 | 163 | 164 | if __name__ == "__main__": 165 | # Set the environment variables if not launched by torchrun 166 | if 'RANK' not in os.environ: 167 | os.environ['RANK'] = '0' 168 | if 'LOCAL_RANK' not in os.environ: 169 | os.environ['LOCAL_RANK'] = os.environ['RANK'] 170 | if 'WORLD_SIZE' not in os.environ: 171 | os.environ['WORLD_SIZE'] = '1' 172 | if 'LOCAL_WORLD_SIZE' not in os.environ: 173 | os.environ['LOCAL_WORLD_SIZE'] = os.environ['WORLD_SIZE'] 174 | main() 175 | -------------------------------------------------------------------------------- /models/clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Any, Union, List 6 | from pkg_resources import packaging 7 | 8 | import torch 9 | from PIL import Image 10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 11 | from tqdm import tqdm 12 | 13 | from .model import build_model 14 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 15 | 16 | try: 17 | from torchvision.transforms import InterpolationMode 18 | BICUBIC = InterpolationMode.BICUBIC 19 | except ImportError: 20 | BICUBIC = Image.BICUBIC 21 | 22 | 23 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): 24 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 25 | 26 | 27 | __all__ = ["available_models", "load", "tokenize"] 28 | _tokenizer = _Tokenizer() 29 | 30 | _MODELS = { 31 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 32 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 33 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 34 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 35 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 36 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 37 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 38 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 39 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", 40 | } 41 | 42 | 43 | def _download(url: str, root: str): 44 | os.makedirs(root, 
exist_ok=True) 45 | filename = os.path.basename(url) 46 | 47 | expected_sha256 = url.split("/")[-2] 48 | download_target = os.path.join(root, filename) 49 | 50 | if os.path.exists(download_target) and not os.path.isfile(download_target): 51 | raise RuntimeError(f"{download_target} exists and is not a regular file") 52 | 53 | if os.path.isfile(download_target): 54 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 55 | return download_target 56 | else: 57 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 58 | 59 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 60 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 61 | while True: 62 | buffer = source.read(8192) 63 | if not buffer: 64 | break 65 | 66 | output.write(buffer) 67 | loop.update(len(buffer)) 68 | 69 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 70 | raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match") 71 | 72 | return download_target 73 | 74 | 75 | def _convert_image_to_rgb(image): 76 | return image.convert("RGB") 77 | 78 | 79 | def _transform(n_px): 80 | return Compose([ 81 | Resize(n_px, interpolation=BICUBIC), 82 | CenterCrop(n_px), 83 | _convert_image_to_rgb, 84 | ToTensor(), 85 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 86 | ]) 87 | 88 | 89 | def available_models() -> List[str]: 90 | """Returns the names of available CLIP models""" 91 | return list(_MODELS.keys()) 92 | 93 | 94 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): 95 | """Load a CLIP model 96 | 97 | Parameters 98 | ---------- 99 | name : str 100 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 101 | 102 | device : Union[str, torch.device] 103 | The device to put the loaded model 104 | 105 | jit : bool 106 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 107 | 108 | download_root: str 109 | path to download the model files; by default, it uses "~/.cache/clip" 110 | 111 | Returns 112 | ------- 113 | model : torch.nn.Module 114 | The CLIP model 115 | 116 | preprocess : Callable[[PIL.Image], torch.Tensor] 117 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 118 | """ 119 | if name in _MODELS: 120 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) 121 | elif os.path.isfile(name): 122 | model_path = name 123 | else: 124 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 125 | 126 | with open(model_path, 'rb') as opened_file: 127 | try: 128 | # loading JIT archive 129 | model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() 130 | state_dict = None 131 | except RuntimeError: 132 | # loading saved state dict 133 | if jit: 134 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 135 | jit = False 136 | state_dict = torch.load(opened_file, map_location="cpu") 137 | 138 | if not jit: 139 | model = build_model(state_dict or model.state_dict()).to(device) 140 | if str(device) == "cpu": 141 | model.float() 142 | return model, _transform(model.visual.input_resolution) 143 | 144 | # patch the device names 145 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 146 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 147 | 148 | def patch_device(module): 149 | try: 150 | graphs = [module.graph] if hasattr(module, "graph") else [] 151 | except RuntimeError: 152 | graphs = [] 153 | 154 | if hasattr(module, "forward1"): 155 | graphs.append(module.forward1.graph) 156 | 157 | for graph in graphs: 158 | for node in graph.findAllNodes("prim::Constant"): 159 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 160 | node.copyAttributes(device_node) 161 | 162 | model.apply(patch_device) 163 | patch_device(model.encode_image) 164 | patch_device(model.encode_text) 165 | 166 | # patch dtype to float32 on CPU 167 | if str(device) == "cpu": 168 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 169 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 170 | float_node = float_input.node() 171 | 172 | def patch_float(module): 173 | try: 174 | graphs = [module.graph] if hasattr(module, "graph") else [] 175 | except RuntimeError: 176 | graphs = [] 177 | 178 | if hasattr(module, "forward1"): 179 | graphs.append(module.forward1.graph) 180 | 181 | for graph in graphs: 182 | for node in graph.findAllNodes("aten::to"): 183 | inputs = list(node.inputs()) 184 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 185 | if inputs[i].node()["value"] == 5: 186 | inputs[i].node().copyAttributes(float_node) 187 | 188 | model.apply(patch_float) 189 | patch_float(model.encode_image) 190 | patch_float(model.encode_text) 191 | 192 | model.float() 193 | 194 | return model, _transform(model.input_resolution.item()) 195 | 196 | 197 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: 198 | """ 199 | Returns the tokenized representation of given input string(s) 200 | 201 | Parameters 202 | ---------- 203 | texts : Union[str, List[str]] 204 | An input string or a list of input strings to tokenize 205 | 206 | context_length : int 207 | The context length to use; all CLIP models use 77 as the context length 208 | 209 | truncate: bool 210 | Whether to truncate the text in case its encoding is longer than the context length 211 | 212 | Returns 213 | ------- 214 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. 215 | We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. 
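A minimal usage sketch (the shape assumes the default context_length of 77; the dtype follows the torch-version rule noted above):
    >>> tokens = tokenize(["a photo of a dog", "a photo of a cat"])
    >>> tokens.shape
    torch.Size([2, 77])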
216 | """ 217 | if isinstance(texts, str): 218 | texts = [texts] 219 | 220 | sot_token = _tokenizer.encoder["<|startoftext|>"] 221 | eot_token = _tokenizer.encoder["<|endoftext|>"] 222 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 223 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): 224 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 225 | else: 226 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) 227 | 228 | for i, tokens in enumerate(all_tokens): 229 | if len(tokens) > context_length: 230 | if truncate: 231 | tokens = tokens[:context_length] 232 | tokens[-1] = eot_token 233 | else: 234 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 235 | result[i, :len(tokens)] = torch.tensor(tokens) 236 | 237 | return result 238 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /main_standard.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.distributed as dist 5 | from torch.nn.parallel import DistributedDataParallel 6 | 7 | from utils.experiman import manager 8 | from data import * 9 | from models import get_clip_model 10 | from trainers import StandardTrainer, StandardLoopConfig 11 | from utils.misc import parse 12 | from utils.optim import get_optim 13 | 14 | 15 | def add_parser_argument(parser): 16 | ## ======================== Data ========================== 17 | parser.add_argument('--dataset', default='imagenet', type=str) 18 | parser.add_argument('--train_split', default='original', type=str) 19 | parser.add_argument('--val_size', default=10240, type=int) 20 | parser.add_argument('--data_split_seed', default=0, type=int) 21 | parser.add_argument('--batch', default=512, type=int) 22 | parser.add_argument('--num_workers', default=1, type=int) 23 | parser.add_argument('--image_size', default=224, type=int) 24 | parser.add_argument('--transform', default='clip', type=str) 25 | ## ======================= Model ========================== 26 | parser.add_argument('--arch', type=str) 27 | parser.add_argument('--arch_variant', default='zeroshot', type=str) 28 | parser.add_argument('--load_pretrained', type=str) 29 | parser.add_argument('--load_ckpt', type=str) 30 | parser.add_argument('--load_run_name', type=str) 31 | parser.add_argument('--load_run_number', type=str) 32 | parser.add_argument('--freeze_backbone', action='store_true') 33 | parser.add_argument('--sync_bn', action='store_true') 34 | ## ===================== Training ========================= 35 | parser.add_argument('--auto_resume', action='store_true') 36 | parser.add_argument('--resume_ckpt', type=str) 37 | parser.add_argument('--label_smooth', action='store_true') 38 | ## ==================== Optimization ====================== 39 | parser.add_argument('--epoch', default=10, type=int) 40 | parser.add_argument('--num_iters_train', type=int, 41 | help="default: len(trainloader)") 42 | parser.add_argument('--num_iters_test', type=int, 43 | help="default: len(testloader)") 44 | parser.add_argument('--num_iters_trainset_test', type=int, 45 | help="default: len(raw_trainloader)") 46 | parser.add_argument('--accum_steps', type=int, default=1) 47 | parser.add_argument('--lr', default=3e-5, type=float) 48 | parser.add_argument('--lr_bb', type=float) 49 | parser.add_argument('--lr_schedule', default='1cycle', type=str) 50 | parser.add_argument('--multistep_milestones', type=int, nargs='+') 51 | parser.add_argument('--optimizer', default='adamw', type=str) 52 | parser.add_argument('--adam_beta', default=0.9, type=float) 53 | parser.add_argument('--weight_decay', default=1e-1, type=float) 54 | parser.add_argument('--cyclic_step', type=float) 55 | parser.add_argument('--onecycle_pct_start', default=0.02, type=float) 56 | parser.add_argument('--grad_clip', default=1, type=float) 57 | ## ====================== Logging ========================= 58 | parser.add_argument('--log_period', default=5, type=int, metavar='LP', 59 | help='log every LP iterations') 60 | parser.add_argument('--ckpt_period', type=int, metavar='CP', 61 | help='make checkpoints every CP epochs') 62 | parser.add_argument('--test_period', default=1, type=int, metavar='TP', 63 | help='test every TP epochs') 64 | 
parser.add_argument('--trainset_test_period', type=int, metavar='TP', 65 | help='test on training set every TP epochs') 66 | parser.add_argument('--comment', default='', type=str) 67 | ## ==================== Experimental ====================== 68 | parser.add_argument('--wise_alpha', default=0.5, type=float) 69 | parser.add_argument('--wise_base_run_name', type=str) 70 | parser.add_argument('--wise_base_run_number', type=str) 71 | 72 | 73 | def main(): 74 | local_rank = int(os.environ["LOCAL_RANK"]) 75 | rank = int(os.environ["RANK"]) 76 | world_size = int(os.environ["WORLD_SIZE"]) 77 | device = torch.device(local_rank) 78 | torch.cuda.set_device(device) 79 | 80 | # Parse arguments and setup ExperiMan 81 | parser = manager.get_basic_arg_parser() 82 | add_parser_argument(parser) 83 | opt = parser.parse_args() 84 | if opt.resume_ckpt or opt.auto_resume: 85 | opt.option_for_existing_dir = 'k' 86 | manager.setup(opt, rank=rank, world_size=world_size, 87 | third_party_tools=('tensorboard',)) 88 | if world_size > 1: 89 | dist.init_process_group("nccl") 90 | if rank == 0: 91 | t = torch.tensor([opt.run_number + .1], device=device) 92 | else: 93 | t = torch.empty(1, device=device) 94 | dist.broadcast(t, src=0) 95 | opt.run_number = int(t.item()) 96 | manager.set_run_dir(manager.get_run_dir(opt.run_name, opt.run_number)) 97 | logger = manager.get_logger() 98 | logger.info(f'==> Number of devices: {world_size}') 99 | use_clip = opt.arch.startswith('clip') 100 | 101 | # Data 102 | logger.info('==> Preparing data') 103 | dataset = get_dataset(opt.dataset, opt.data_dir, size=opt.image_size, transform=opt.transform) 104 | assert opt.batch % world_size == 0 105 | batch = opt.batch // world_size 106 | data_kwargs = dict( 107 | batch_size=batch, num_workers=opt.num_workers, with_index=False, 108 | train_split=opt.train_split, val_size=opt.val_size, 109 | split_seed=opt.data_split_seed, 110 | world_size=world_size, rank=rank) 111 | if opt.val_size > 0: 112 | trainloader, raw_trainloader, valloader, testloader = \ 113 | dataset.get_loader(**data_kwargs) 114 | else: 115 | trainloader, raw_trainloader, testloader = \ 116 | dataset.get_loader(**data_kwargs) 117 | valloader = [] 118 | num_iters_train = parse(opt.num_iters_train, len(trainloader) // opt.accum_steps) 119 | num_iters_val = len(valloader) 120 | num_iters_trainset_test = parse(opt.num_iters_trainset_test, len(raw_trainloader)) 121 | num_iters_test = parse(opt.num_iters_test, len(testloader)) 122 | 123 | # Model 124 | logger.info('==> Building models') 125 | if use_clip: 126 | model = get_clip_model( 127 | arch=opt.arch, 128 | dataset=dataset, 129 | variant=opt.arch_variant, 130 | model_dir=opt.load_pretrained, 131 | device=device, 132 | get_zeroshot_weights=(not (opt.load_ckpt or opt.load_run_name)), 133 | ) 134 | else: 135 | raise NotImplementedError() 136 | if opt.freeze_backbone: 137 | model.freeze_backbone() 138 | if world_size > 1: 139 | if opt.sync_bn: 140 | logger.info('==> Using SyncBN') 141 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 142 | model = DistributedDataParallel(model, device_ids=[local_rank]) 143 | models = {'model': model} 144 | 145 | # Criterions 146 | criterions = {} 147 | criterions['classification'] = nn.CrossEntropyLoss() 148 | for criterion in criterions.values(): 149 | criterion.to(device) 150 | 151 | # Optimizer 152 | bare_model = model.module if world_size > 1 else model 153 | head_parameters = [ 154 | p for n, p in model.named_parameters() if 'backbone' not in n 155 | ] 156 | if opt.freeze_backbone: 
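# Freezing the backbone means only the classification-head parameters are optimized;
# when --lr_bb is given instead, the backbone gets its own learning-rate group below,
# and otherwise every parameter shares the single --lr value.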
157 | parameters = head_parameters 158 | elif opt.lr_bb is not None: 159 | parameters = [ 160 | {'params': bare_model.backbone.parameters(), 'lr': opt.lr_bb}, 161 | {'params': head_parameters} 162 | ] 163 | else: 164 | parameters = model.parameters() 165 | optimizer = get_optim( 166 | parameters=parameters, 167 | optimizer_name=opt.optimizer, 168 | lr=opt.lr, 169 | schedule=opt.lr_schedule, 170 | weight_decay=opt.weight_decay, 171 | num_epochs=opt.epoch, 172 | num_iters_train=num_iters_train, 173 | cyclic_stepsize=opt.cyclic_step, 174 | onecycle_pct_start=opt.onecycle_pct_start, 175 | multistep_milestones=opt.multistep_milestones, 176 | adam_beta=opt.adam_beta, 177 | ) 178 | optimizers = {'optimizer': optimizer} 179 | 180 | # Load 181 | resume_ckpt = None 182 | bare_model = model.module if world_size > 1 else model 183 | if opt.auto_resume: 184 | assert opt.resume_ckpt is None 185 | load_path = os.path.join(manager.get_checkpoint_dir(), 'ckpt-last.pt') 186 | if os.path.exists(load_path): 187 | opt.resume_ckpt = 'ckpt-last.pt' 188 | if opt.resume_ckpt: 189 | load_path = os.path.join(manager.get_checkpoint_dir(), opt.resume_ckpt) 190 | logger.info(f'==> Resume from checkpoint {load_path}') 191 | resume_ckpt = torch.load(load_path, map_location='cpu') 192 | elif opt.load_ckpt or opt.load_run_name: 193 | if opt.load_ckpt: 194 | load_path = opt.load_ckpt 195 | else: 196 | ckpt_dir = manager.get_checkpoint_dir( 197 | opt.load_run_name, opt.load_run_number) 198 | load_path = os.path.join(ckpt_dir, 'ckpt-last.pt') 199 | logger.info(f'==> Loading model from {load_path}') 200 | checkpoint = torch.load(load_path, map_location='cpu') 201 | bare_model.load_state_dict(checkpoint['model']) 202 | if opt.wise_base_run_name: 203 | ckpt_dir = manager.get_checkpoint_dir( 204 | opt.wise_base_run_name, opt.wise_base_run_number) 205 | load_path = os.path.join(ckpt_dir, 'ckpt-last.pt') 206 | logger.info(f'==> Loading WiSE base model from {load_path}') 207 | base_state_dict = torch.load(load_path, map_location='cpu')['model'] 208 | model_state_dict = checkpoint['model'] 209 | wise_state_dict = { 210 | name: (1 - opt.wise_alpha) * base_state_dict[name] + \ 211 | opt.wise_alpha * model_state_dict[name] 212 | for name in model_state_dict 213 | } 214 | bare_model.load_state_dict(wise_state_dict) 215 | elif opt.load_pretrained: 216 | if not use_clip: 217 | logger.info(f'==> Loading pretrained backbone from {opt.load_pretrained}') 218 | pretrained_dict = torch.load(opt.load_pretrained, map_location='cpu') 219 | bare_model.backbone.load_pretrained(pretrained_dict) 220 | else: 221 | logger.info(f'==> Will train from scratch') 222 | 223 | # Trainer 224 | loop_configs = [ 225 | StandardLoopConfig('train', dataset, trainloader, 226 | training=True, n_iterations=num_iters_train, 227 | n_computation_steps=opt.accum_steps), 228 | StandardLoopConfig('val', dataset, valloader, 229 | training=False, n_iterations=num_iters_val, 230 | for_best_meter=True), 231 | StandardLoopConfig('test-trainset', dataset, raw_trainloader, 232 | training=False, n_iterations=num_iters_trainset_test, 233 | run_every_n_epochs=opt.trainset_test_period, 234 | run_at_checkpoint=False), 235 | StandardLoopConfig('test-testset', dataset, testloader, 236 | training=False, n_iterations=num_iters_test, 237 | run_every_n_epochs=opt.test_period), 238 | ] 239 | trainer = StandardTrainer( 240 | manager=manager, 241 | models=models, 242 | criterions=criterions, 243 | n_epochs=opt.epoch, 244 | loop_configs=loop_configs, 245 | optimizers=optimizers, 246 | 
log_period=opt.log_period, 247 | ckpt_period=opt.ckpt_period, 248 | device=device, 249 | keep_eval_mode=opt.freeze_backbone, 250 | resume_ckpt=resume_ckpt, 251 | num_classes=dataset.num_classes, 252 | ) 253 | 254 | trainer.train() 255 | 256 | 257 | if __name__ == "__main__": 258 | # Set the environment variables if not launched by torchrun 259 | if 'RANK' not in os.environ: 260 | os.environ['RANK'] = '0' 261 | if 'LOCAL_RANK' not in os.environ: 262 | os.environ['LOCAL_RANK'] = os.environ['RANK'] 263 | if 'WORLD_SIZE' not in os.environ: 264 | os.environ['WORLD_SIZE'] = '1' 265 | if 'LOCAL_WORLD_SIZE' not in os.environ: 266 | os.environ['LOCAL_WORLD_SIZE'] = os.environ['WORLD_SIZE'] 267 | main() 268 | -------------------------------------------------------------------------------- /utils/masking.py: -------------------------------------------------------------------------------- 1 | import re 2 | from functools import partial 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torchvision.transforms.functional as TF 9 | from torchattacks.attack import Attack 10 | from torchvision.transforms import InterpolationMode 11 | 12 | 13 | def get_GradCAM(model, images, labels, upsample=None): 14 | # Ref: https://github.com/jacobgil/pytorch-grad-cam 15 | 16 | # Get activations and gradients for target layer 17 | bare_model = model.module if hasattr(model, 'module') else model 18 | target_layer = bare_model.backbone.transformer.resblocks[-1].ln_1 19 | activations = [] 20 | def save_activation(module, input, output): 21 | activations.append(output) 22 | handle = target_layer.register_forward_hook(save_activation) 23 | logits = model(images) 24 | handle.remove() 25 | loss = logits.gather(1, labels.unsqueeze(1)).sum() 26 | grad = torch.autograd.grad(loss, activations[0])[0] 27 | 28 | act = activations[0].detach() # L * N * C 29 | weights = grad.sum(dim=0, keepdim=True) # 1 * N * C 30 | cam = F.relu(torch.sum(weights * act, dim=2))[1:] # (L-1) * N 31 | if upsample: 32 | Np, N = cam.size() 33 | s = round(Np ** 0.5) 34 | cam = cam.T.reshape(N, s, s) 35 | cam = TF.resize(cam, images.size()[2:], upsample) # N * H * W 36 | # cam = cam / (cam.amax(dim=(1, 2), keepdim=True) + 1e-8) 37 | return cam 38 | 39 | 40 | def get_MMCAM(model, images, labels, upsample=None, get_patch_cam=False): 41 | # Ref: https://github.com/hila-chefer/Transformer-MM-Explainability 42 | 43 | # Get activations and gradients for each attention map 44 | attentions = [] 45 | logits = model(images, attn_out=attentions) 46 | loss = logits.gather(1, labels.unsqueeze(1)).sum() 47 | grads = torch.autograd.grad(loss, attentions) 48 | attentions = [attn.detach() for attn in attentions] 49 | 50 | # Compute CAM 51 | bare_model = model.module if hasattr(model, 'module') else model 52 | N = images.size(0) 53 | L = bare_model.backbone.positional_embedding.size(0) 54 | R = torch.eye(L, device=images.device).repeat(N, 1, 1) 55 | for attn, grad in zip(attentions, grads): 56 | A = (grad * attn).clamp(min=0) # (N * Nh) * L * L 57 | A = A.reshape(N, -1, L, L).mean(dim=1) # N * L * L 58 | R += torch.matmul(A, R) # N * L * L 59 | patch_cam = R[:, 0, 1:] 60 | cam = patch_cam.T # (L-1) * N 61 | if upsample: 62 | s = round((L - 1) ** 0.5) 63 | cam = patch_cam.reshape(N, s, s) 64 | cam = TF.resize(cam, images.size()[2:], upsample) # N * H * W 65 | if get_patch_cam: 66 | return cam, patch_cam 67 | return cam 68 | 69 | 70 | class RandMaskNoFill(Attack): 71 | 72 | def __init__(self, model, mask_rate): 73 
| super().__init__("RandMaskNoFill", model) 74 | self.mask_rate = mask_rate 75 | bare_model = model.module if hasattr(model, 'module') else model 76 | self.n_patch = bare_model.backbone.positional_embedding.size(0) - 1 77 | self.n_keep = int(self.n_patch * (1 - mask_rate)) 78 | 79 | def forward(self, images, labels=None): 80 | mask = torch.stack([torch.randperm(self.n_patch)[:self.n_keep].to(self.device) 81 | for _ in images], dim=1) 82 | mask = torch.cat([torch.zeros_like(mask[:1]), mask + 1]) # CLS token 83 | return mask 84 | 85 | 86 | class CAMMaskNoFill(Attack): 87 | 88 | def __init__(self, model, cam_method, threshold, ctx_mask=False): 89 | super().__init__("CAMMaskNoFill", model) 90 | self.method = cam_method 91 | if cam_method == 'GradCAM': 92 | self.get_CAM = get_GradCAM 93 | elif cam_method == 'MMCAM': 94 | self.get_CAM = get_MMCAM 95 | else: 96 | raise NotImplementedError() 97 | self.threshold = threshold 98 | self.ctx_mask = ctx_mask 99 | bare_model = model.module if hasattr(model, 'module') else model 100 | self.n_head = bare_model.backbone.heads 101 | self.n_patch = bare_model.backbone.positional_embedding.size(0) - 1 102 | 103 | def forward(self, images, labels): 104 | images = images.clone().detach().to(self.device) 105 | labels = labels.clone().detach().to(self.device) 106 | cam = self.get_CAM(self.model, images, labels) 107 | cam = cam / (cam.amax(dim=0, keepdim=True) + 1e-8) 108 | mask = (cam > self.threshold) # (L-1) * N 109 | # print(torch.count_nonzero(cam > self.threshold, dim=0) / mask.size(0)) 110 | mask = mask.T # N * (L-1) 111 | if self.ctx_mask: 112 | mask = ~mask 113 | attn_mask = torch.zeros_like(mask, dtype=float) 114 | attn_mask[mask] = float('-inf') # N * (L-1) 115 | attn_mask = torch.cat( 116 | [torch.zeros_like(attn_mask[:, :1]), attn_mask], dim=1) # N * L 117 | attn_mask = attn_mask.unsqueeze(1).repeat_interleave(self.n_head, dim=0).repeat(1, self.n_patch + 1, 1) 118 | return attn_mask 119 | 120 | 121 | class RandMaskSingleFill(nn.Module): 122 | 123 | def __init__(self, model, prob): 124 | super().__init__() 125 | self.prob = prob 126 | bare_model = model.module if hasattr(model, 'module') else model 127 | n_patch = bare_model.backbone.positional_embedding.size(0) - 1 128 | self.w = round(n_patch ** 0.5) 129 | 130 | def forward(self, images, labels=None): 131 | N, _, H, W = images.size() 132 | p = torch.empty(N, 1, self.w, self.w, device=images.device).uniform_() 133 | mask = TF.resize(p, (H, W), InterpolationMode.NEAREST) < self.prob 134 | idx = torch.randperm(N, device=images.device) 135 | return images * (~mask) + images[idx] * (mask) 136 | 137 | 138 | class CAMMaskSingleFill(Attack): 139 | 140 | def __init__(self, model, cam_method, threshold, ctx_mask=False, save_mask=False): 141 | super().__init__("CAMMaskSingleFill", model) 142 | self.method = cam_method 143 | if cam_method == 'GradCAM': 144 | self.get_CAM = partial( 145 | get_GradCAM, upsample=InterpolationMode.NEAREST) 146 | elif cam_method == 'MMCAM': 147 | self.get_CAM = partial( 148 | get_MMCAM, upsample=InterpolationMode.NEAREST, get_patch_cam=save_mask) 149 | else: 150 | raise NotImplementedError() 151 | self.threshold = threshold 152 | self.ctx_mask = ctx_mask 153 | self.save_mask = save_mask 154 | 155 | def forward(self, images, labels): 156 | images = images.clone().detach().to(self.device) 157 | labels = labels.clone().detach().to(self.device) 158 | cam = self.get_CAM(self.model, images, labels) 159 | if self.save_mask: 160 | cam, patch_cam = cam 161 | patch_cam = patch_cam / 
(patch_cam.amax(dim=1, keepdim=True) + 1e-8) 162 | patch_mask = (patch_cam > self.threshold) 163 | cam = cam / (cam.amax(dim=(1, 2), keepdim=True) + 1e-8) 164 | mask = (cam > self.threshold).unsqueeze(1) 165 | if self.ctx_mask: 166 | mask = ~mask 167 | if self.save_mask: 168 | patch_mask = ~patch_mask 169 | idx = torch.randperm(images.size()[0], device=self.device) 170 | images_masked = images * (~mask) + images[idx] * (mask) 171 | if self.save_mask: 172 | return images_masked, patch_mask 173 | return images_masked 174 | 175 | 176 | class RandMaskMultiFill(nn.Module): 177 | 178 | def __init__(self, model, prob): 179 | super().__init__() 180 | self.prob = prob 181 | bare_model = model.module if hasattr(model, 'module') else model 182 | n_patch = bare_model.backbone.positional_embedding.size(0) - 1 183 | self.w = round(n_patch ** 0.5) 184 | 185 | def forward(self, images, labels=None): 186 | N, C, H, W = images.size() 187 | w = self.w 188 | device = images.device 189 | p = torch.empty(N, 1, w, w, device=device).uniform_() 190 | mask = TF.resize(p, (H, W), InterpolationMode.NEAREST) < self.prob 191 | idx_shift = torch.randint(1, N, (N, 1, w, w), device=device) 192 | idx = (torch.arange(N, device=device).resize(N, 1, 1, 1) + idx_shift) % N 193 | # idx = idx.repeat_interleave(H / w, dim=2).repeat_interleave(W / w, dim=3) 194 | idx = TF.resize(idx, (H, W), InterpolationMode.NEAREST) 195 | idx = idx.repeat(1, C, 1, 1) 196 | images_m = images.gather(0, idx) 197 | return images * (~mask) + images_m * (mask) 198 | 199 | 200 | class CAMMaskMultiFill(Attack): 201 | 202 | def __init__(self, model, cam_method, threshold, ctx_mask=False): 203 | super().__init__("CAMMaskMultiFill", model) 204 | self.method = cam_method 205 | self.threshold = threshold 206 | bare_model = model.module if hasattr(model, 'module') else model 207 | n_patch = bare_model.backbone.positional_embedding.size(0) - 1 208 | self.w = round(n_patch ** 0.5) 209 | if cam_method == 'GradCAM': 210 | self.get_CAM = partial( 211 | get_GradCAM, upsample=InterpolationMode.NEAREST) 212 | elif cam_method == 'MMCAM': 213 | self.get_CAM = partial( 214 | get_MMCAM, upsample=InterpolationMode.NEAREST) 215 | else: 216 | raise NotImplementedError() 217 | self.ctx_mask = ctx_mask 218 | 219 | def forward(self, images, labels): 220 | images = images.clone().detach().to(self.device) 221 | labels = labels.clone().detach().to(self.device) 222 | 223 | cam = self.get_CAM(self.model, images, labels) 224 | cam = cam / (cam.amax(dim=(1, 2), keepdim=True) + 1e-8) 225 | N, C, H, W = images.size() 226 | w = self.w 227 | device = self.device 228 | mask = (cam > self.threshold).unsqueeze(1) 229 | if self.ctx_mask: 230 | mask = ~mask 231 | 232 | idx_shift = torch.randint(1, N, (N, 1, w, w), device=device) 233 | idx = (torch.arange(N, device=device).resize(N, 1, 1, 1) + idx_shift) % N 234 | idx = TF.resize(idx, (H, W), InterpolationMode.NEAREST) 235 | idx = idx.repeat(1, C, 1, 1) 236 | images_m = images.gather(0, idx) 237 | 238 | images_masked = images * (~mask) + images_m * (mask) 239 | return images_masked 240 | 241 | 242 | def get_masking(name: str = None, **kwargs): 243 | 244 | cam_method = 'MMCAM' 245 | def get_params(s): 246 | # return the string inside the brackets 247 | return re.search(r'\((.*?)\)', s).group(1) 248 | 249 | if name is None or name == 'none': 250 | return None 251 | 252 | elif name.startswith('RandMaskNoFill'): # RandMaskNoFill(mask_rate) 253 | mask_rate = float(get_params(name)) 254 | return RandMaskNoFill(kwargs['model'], mask_rate) 255 | elif 
name.startswith('ObjMaskNoFill'): # ObjMaskNoFill(threshold) 256 | threshold = get_params(name) 257 | return CAMMaskNoFill(kwargs['model'], cam_method, float(threshold)) 258 | elif name.startswith('CtxMaskNoFill'): # CtxMaskNoFill(threshold) 259 | threshold = get_params(name) 260 | return CAMMaskNoFill(kwargs['model'], cam_method, float(threshold), ctx_mask=True) 261 | 262 | elif name.startswith('RandMaskSingleFill'): # RandMaskSingleFill(prob) 263 | prob = float(get_params(name)) 264 | return RandMaskSingleFill(kwargs['model'], prob) 265 | elif name.startswith('ObjMaskSingleFill'): # ObjMaskSingleFill(threshold) 266 | threshold = get_params(name) 267 | return CAMMaskSingleFill(kwargs['model'], cam_method, float(threshold), save_mask=kwargs['save_mask']) 268 | elif name.startswith('CtxMaskSingleFill'): # CtxMaskSingleFill(threshold) 269 | threshold = get_params(name) 270 | return CAMMaskSingleFill(kwargs['model'], cam_method, float(threshold), ctx_mask=True) 271 | 272 | elif name.startswith('RandMaskMultiFill'): # RandMaskMultiFill(prob) 273 | prob = float(get_params(name)) 274 | return RandMaskMultiFill(kwargs['model'], prob) 275 | elif name.startswith('ObjMaskMultiFill'): # ObjMaskMultiFill(threshold) 276 | threshold = get_params(name) 277 | return CAMMaskMultiFill(kwargs['model'], cam_method, float(threshold)) 278 | elif name.startswith('CtxMaskMultiFill'): # CtxMaskMultiFill(threshold) 279 | threshold = get_params(name) 280 | return CAMMaskMultiFill(kwargs['model'], cam_method, float(threshold), ctx_mask=True) 281 | 282 | else: 283 | raise NotImplementedError() -------------------------------------------------------------------------------- /main_distill.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | import torch 4 | import torch.nn as nn 5 | import torch.distributed as dist 6 | from torch.nn.parallel import DistributedDataParallel 7 | 8 | from utils.experiman import manager 9 | from data import * 10 | from losses.distill import KnowledgeDistillationLoss 11 | from models import get_clip_model 12 | from trainers import DistillTrainer, StandardLoopConfig 13 | from utils.masking import get_masking 14 | from utils.misc import parse 15 | from utils.optim import get_optim 16 | 17 | 18 | def add_parser_argument(parser): 19 | ## ======================== Data ========================== 20 | parser.add_argument('--dataset', default='imagenet', type=str) 21 | parser.add_argument('--train_split', default='original', type=str) 22 | parser.add_argument('--val_size', default=10240, type=int) 23 | parser.add_argument('--data_split_seed', default=0, type=int) 24 | parser.add_argument('--batch', default=512, type=int) 25 | parser.add_argument('--num_workers', default=1, type=int) 26 | parser.add_argument('--image_size', default=224, type=int) 27 | parser.add_argument('--transform', default='clip', type=str) 28 | ## ======================= Model ========================== 29 | parser.add_argument('--arch', type=str) 30 | parser.add_argument('--arch_variant', default='zeroshot', type=str) 31 | parser.add_argument('--load_pretrained', type=str) 32 | parser.add_argument('--load_ckpt', type=str) 33 | parser.add_argument('--load_run_name', type=str) 34 | parser.add_argument('--load_run_number', type=str) 35 | parser.add_argument('--teacher_run_name', type=str) 36 | parser.add_argument('--teacher_run_number', type=str) 37 | parser.add_argument('--freeze_backbone', action='store_true') 38 | parser.add_argument('--sync_bn', 
action='store_true') 39 | ## ===================== Training ========================= 40 | parser.add_argument('--auto_resume', action='store_true') 41 | parser.add_argument('--resume_ckpt', type=str) 42 | parser.add_argument('--label_smooth', action='store_true') 43 | parser.add_argument('--task', type=str) 44 | parser.add_argument('--distill_mode', type=str) 45 | parser.add_argument('--kd_temp', default=10, type=float) 46 | parser.add_argument('--w_task', default=1, type=float) 47 | parser.add_argument('--w_distill', default=30, type=float) 48 | ## ==================== Optimization ====================== 49 | parser.add_argument('--epoch', default=10, type=int) 50 | parser.add_argument('--num_iters_train', type=int, 51 | help="default: len(trainloader)") 52 | parser.add_argument('--num_iters_test', type=int, 53 | help="default: len(testloader)") 54 | parser.add_argument('--num_iters_trainset_test', type=int, 55 | help="default: len(raw_trainloader)") 56 | parser.add_argument('--accum_steps', type=int, default=1) 57 | parser.add_argument('--lr', default=3e-5, type=float) 58 | parser.add_argument('--lr_bb', type=float) 59 | parser.add_argument('--lr_schedule', default='1cycle', type=str) 60 | parser.add_argument('--multistep_milestones', type=int, nargs='+') 61 | parser.add_argument('--optimizer', default='adamw', type=str) 62 | parser.add_argument('--adam_beta', default=0.9, type=float) 63 | parser.add_argument('--weight_decay', default=1e-1, type=float) 64 | parser.add_argument('--cyclic_step', type=float) 65 | parser.add_argument('--onecycle_pct_start', default=0.02, type=float) 66 | parser.add_argument('--grad_clip', default=1, type=float) 67 | ## ====================== Logging ========================= 68 | parser.add_argument('--log_period', default=5, type=int, metavar='LP', 69 | help='log every LP iterations') 70 | parser.add_argument('--ckpt_period', type=int, metavar='CP', 71 | help='make checkpoints every CP epochs') 72 | parser.add_argument('--test_period', default=1, type=int, metavar='TP', 73 | help='test every TP epochs') 74 | parser.add_argument('--trainset_test_period', type=int, metavar='TP', 75 | help='test on training set every TP epochs') 76 | parser.add_argument('--comment', default='', type=str) 77 | ## ==================== Experimental ====================== 78 | parser.add_argument('--distill_masking', type=str) 79 | parser.add_argument('--save_mask', action='store_true') 80 | 81 | 82 | def main(): 83 | local_rank = int(os.environ["LOCAL_RANK"]) 84 | rank = int(os.environ["RANK"]) 85 | world_size = int(os.environ["WORLD_SIZE"]) 86 | device = torch.device(local_rank) 87 | torch.cuda.set_device(device) 88 | 89 | # Parse arguments and setup ExperiMan 90 | parser = manager.get_basic_arg_parser() 91 | add_parser_argument(parser) 92 | opt = parser.parse_args() 93 | if opt.resume_ckpt or opt.auto_resume: 94 | opt.option_for_existing_dir = 'k' 95 | manager.setup(opt, rank=rank, world_size=world_size, 96 | third_party_tools=('tensorboard',)) 97 | if world_size > 1: 98 | dist.init_process_group("nccl") 99 | if rank == 0: 100 | t = torch.tensor([opt.run_number + .1], device=device) 101 | else: 102 | t = torch.empty(1, device=device) 103 | dist.broadcast(t, src=0) 104 | opt.run_number = int(t.item()) 105 | manager.set_run_dir(manager.get_run_dir(opt.run_name, opt.run_number)) 106 | logger = manager.get_logger() 107 | logger.info(f'==> Number of devices: {world_size}') 108 | use_clip = opt.arch.startswith('clip') 109 | 110 | # Data 111 | logger.info('==> Preparing data') 112 | 
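    # The dataset wrapper (see data/) exposes get_loader(), which returns the
    # augmented train loader, a raw (un-augmented) train loader, an optional
    # validation loader, and the test loader; opt.batch is the global batch
    # size and is split evenly across devices below.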
dataset = get_dataset(opt.dataset, opt.data_dir, size=opt.image_size, transform=opt.transform) 113 | assert opt.batch % world_size == 0 114 | batch = opt.batch // world_size 115 | data_kwargs = dict( 116 | batch_size=batch, num_workers=opt.num_workers, 117 | with_index=opt.save_mask, 118 | train_split=opt.train_split, val_size=opt.val_size, 119 | split_seed=opt.data_split_seed, 120 | world_size=world_size, rank=rank) 121 | if opt.val_size > 0: 122 | trainloader, raw_trainloader, valloader, testloader = \ 123 | dataset.get_loader(**data_kwargs) 124 | else: 125 | trainloader, raw_trainloader, testloader = \ 126 | dataset.get_loader(**data_kwargs) 127 | valloader = [] 128 | num_iters_train = parse(opt.num_iters_train, len(trainloader) // opt.accum_steps) 129 | num_iters_val = len(valloader) 130 | num_iters_trainset_test = parse(opt.num_iters_trainset_test, len(raw_trainloader)) 131 | num_iters_test = parse(opt.num_iters_test, len(testloader)) 132 | 133 | # Model 134 | logger.info('==> Building models') 135 | if use_clip: 136 | model = get_clip_model( 137 | arch=opt.arch, 138 | dataset=dataset, 139 | variant=opt.arch_variant, 140 | model_dir=opt.load_pretrained, 141 | device=device, 142 | get_zeroshot_weights=(not (opt.load_ckpt or opt.load_run_name)), 143 | ) 144 | else: 145 | raise NotImplementedError() 146 | if opt.freeze_backbone: 147 | model.freeze_backbone() 148 | if world_size > 1: 149 | if opt.sync_bn: 150 | logger.info('==> Using SyncBN') 151 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 152 | model = DistributedDataParallel(model, device_ids=[local_rank]) 153 | models = {'model': model} 154 | bare_model = model.module if world_size > 1 else model 155 | 156 | # Criterions 157 | criterions = {} 158 | criterions['classification'] = nn.CrossEntropyLoss() 159 | criterions['knowledge_distillation'] = KnowledgeDistillationLoss(opt.kd_temp) 160 | criterions['feature_distillation'] = nn.MSELoss() 161 | for criterion in criterions.values(): 162 | criterion.to(device) 163 | 164 | # Optimizer 165 | head_parameters = [ 166 | p for n, p in model.named_parameters() if 'backbone' not in n 167 | ] 168 | if opt.freeze_backbone: 169 | parameters = head_parameters 170 | elif opt.lr_bb is not None: 171 | parameters = [ 172 | {'params': bare_model.backbone.parameters(), 'lr': opt.lr_bb}, 173 | {'params': head_parameters} 174 | ] 175 | else: 176 | parameters = model.parameters() 177 | optimizer = get_optim( 178 | parameters=parameters, 179 | optimizer_name=opt.optimizer, 180 | lr=opt.lr, 181 | schedule=opt.lr_schedule, 182 | weight_decay=opt.weight_decay, 183 | num_epochs=opt.epoch, 184 | num_iters_train=num_iters_train, 185 | cyclic_stepsize=opt.cyclic_step, 186 | onecycle_pct_start=opt.onecycle_pct_start, 187 | multistep_milestones=opt.multistep_milestones, 188 | adam_beta=opt.adam_beta, 189 | ) 190 | optimizers = {'optimizer': optimizer} 191 | 192 | # Load 193 | resume_ckpt = None 194 | bare_model = model.module if world_size > 1 else model 195 | if opt.auto_resume: 196 | assert opt.resume_ckpt is None 197 | load_path = os.path.join(manager.get_checkpoint_dir(), 'ckpt-last.pt') 198 | if os.path.exists(load_path): 199 | opt.resume_ckpt = 'ckpt-last.pt' 200 | if opt.resume_ckpt: 201 | load_path = os.path.join(manager.get_checkpoint_dir(), opt.resume_ckpt) 202 | logger.info(f'==> Resume from checkpoint {load_path}') 203 | resume_ckpt = torch.load(load_path, map_location='cpu') 204 | elif opt.load_ckpt or opt.load_run_name: 205 | if opt.load_ckpt: 206 | load_path = opt.load_ckpt 207 | else: 
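            # No explicit checkpoint path was given, so fall back to the last
            # checkpoint ('ckpt-last.pt') of the run identified by
            # --load_run_name / --load_run_number.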
208 | ckpt_dir = manager.get_checkpoint_dir( 209 | opt.load_run_name, opt.load_run_number) 210 | load_path = os.path.join(ckpt_dir, 'ckpt-last.pt') 211 | logger.info(f'==> Loading model from {load_path}') 212 | checkpoint = torch.load(load_path, map_location='cpu') 213 | bare_model.load_state_dict(checkpoint['model']) 214 | elif opt.load_pretrained: 215 | if not use_clip: 216 | logger.info(f'==> Loading pretrained backbone from {opt.load_pretrained}') 217 | pretrained_dict = torch.load(opt.load_pretrained, map_location='cpu') 218 | bare_model.backbone.load_pretrained(pretrained_dict) 219 | else: 220 | logger.info(f'==> Will train from scratch') 221 | 222 | # Teacher model 223 | def load_teacher_model(run_name, run_number): 224 | teacher = copy.deepcopy(model).eval() 225 | ckpt_dir = manager.get_checkpoint_dir(run_name, run_number) 226 | load_path = os.path.join(ckpt_dir, 'ckpt-best.pt') 227 | logger.info(f'==> Loading teacher model from {load_path}') 228 | checkpoint = torch.load(load_path, map_location='cpu') 229 | logger.info(f'==> Teacher test acc: {checkpoint["test_acc"]}') 230 | bare_teacher = teacher.module if world_size > 1 else teacher 231 | bare_teacher.load_state_dict(checkpoint['model']) 232 | return teacher 233 | teacher = load_teacher_model(opt.teacher_run_name, opt.teacher_run_number) 234 | 235 | # Masking 236 | masking = get_masking(opt.distill_masking, model=model, save_mask=opt.save_mask) 237 | logger.info(f'Distillation masking: {masking}') 238 | 239 | # Trainer 240 | loop_configs = [ 241 | StandardLoopConfig('train', dataset, trainloader, 242 | training=True, n_iterations=num_iters_train, 243 | n_computation_steps=opt.accum_steps), 244 | StandardLoopConfig('val', dataset, valloader, 245 | training=False, n_iterations=num_iters_val, 246 | for_best_meter=True), 247 | StandardLoopConfig('test-trainset', dataset, raw_trainloader, 248 | training=False, n_iterations=num_iters_trainset_test, 249 | run_every_n_epochs=opt.trainset_test_period, 250 | run_at_checkpoint=False), 251 | StandardLoopConfig('test-testset', dataset, testloader, 252 | training=False, n_iterations=num_iters_test, 253 | run_every_n_epochs=opt.test_period), 254 | ] 255 | trainer = DistillTrainer( 256 | manager=manager, 257 | models=models, 258 | criterions=criterions, 259 | n_epochs=opt.epoch, 260 | loop_configs=loop_configs, 261 | optimizers=optimizers, 262 | log_period=opt.log_period, 263 | ckpt_period=opt.ckpt_period, 264 | device=device, 265 | keep_eval_mode=opt.freeze_backbone, 266 | resume_ckpt=resume_ckpt, 267 | num_classes=dataset.num_classes, 268 | teacher=teacher, 269 | masking=masking, 270 | ) 271 | 272 | trainer.train() 273 | 274 | 275 | if __name__ == "__main__": 276 | # Set the environment variables if not launched by torchrun 277 | if 'RANK' not in os.environ: 278 | os.environ['RANK'] = '0' 279 | if 'LOCAL_RANK' not in os.environ: 280 | os.environ['LOCAL_RANK'] = os.environ['RANK'] 281 | if 'WORLD_SIZE' not in os.environ: 282 | os.environ['WORLD_SIZE'] = '1' 283 | if 'LOCAL_WORLD_SIZE' not in os.environ: 284 | os.environ['LOCAL_WORLD_SIZE'] = os.environ['WORLD_SIZE'] 285 | main() 286 | -------------------------------------------------------------------------------- /utils/experiman.py: -------------------------------------------------------------------------------- 1 | """ 2 | Experiment manager, a helper aimed for deep learning code. 
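
A minimal usage sketch (mirroring how main_distill.py drives the global
`manager` instance; argument values are illustrative, and a real run needs at
least --log_dir, --exp_name and --run_name on the command line):

    from utils.experiman import manager

    parser = manager.get_basic_arg_parser()
    # add task-specific arguments to `parser` here, then parse
    opt = parser.parse_args()
    manager.setup(opt, rank=0, world_size=1,
                  third_party_tools=('tensorboard',))
    logger = manager.get_logger()
    logger.info('==> Starting run')
    manager.log_metric('loss', 0.5, global_step=100, epoch=1, split='train')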
3 | """ 4 | 5 | import argparse 6 | from collections import OrderedDict 7 | from datetime import datetime 8 | from fnmatch import fnmatch 9 | import logging 10 | import os 11 | import shutil 12 | import sys 13 | import json 14 | import base64 15 | import tarfile 16 | 17 | # import aim 18 | import torch 19 | import numpy as np 20 | from torch.utils.tensorboard import SummaryWriter 21 | from torch.utils.tensorboard.summary import hparams 22 | 23 | 24 | def _generate_short_uid(length): 25 | assert length < 43 26 | return base64.urlsafe_b64encode(os.urandom(32)).decode()[:length] 27 | 28 | 29 | class _SummaryWriter(SummaryWriter): 30 | """ 31 | Enable writing hparams and scalars using the same writer. 32 | [bug] Hparams do not show in the hparams tab of tensorboard (although they 33 | can be exported as CSV / JSON file.) 34 | """ 35 | def add_hparams(self, hparam_dict, metric_dict): 36 | torch._C._log_api_usage_once("tensorboard.logging.add_hparams") 37 | if type(hparam_dict) is not dict or type(metric_dict) is not dict: 38 | raise TypeError('hparam_dict and metric_dict should be dictionary.') 39 | exp, ssi, sei = hparams(hparam_dict, metric_dict) 40 | 41 | self.file_writer.add_summary(exp) 42 | self.file_writer.add_summary(ssi) 43 | self.file_writer.add_summary(sei) 44 | for k, v in metric_dict.items(): 45 | self.add_scalar(k, v) 46 | 47 | 48 | class _ArgumentParser(argparse.ArgumentParser): 49 | def convert_arg_line_to_args(self, arg_line): 50 | args = arg_line.split() 51 | # treat lines that starts with '#' as comments 52 | if args and args[0].startswith('#'): 53 | args = [] 54 | return args 55 | 56 | 57 | class _NullLogger(logging.Logger): 58 | def __init__(self): 59 | super().__init__('_null') 60 | self.disabled = True 61 | 62 | 63 | class ExperiMan(object): 64 | 65 | def __init__(self, name): 66 | self._name = name 67 | self._rank = 0 68 | self._world_size = 1 69 | self._logger = _NullLogger() 70 | self._opt = None 71 | self._uid = None 72 | self._exp_dir = None 73 | self._run_dir = None 74 | self._third_party_tools = [] 75 | self._keep_existing_dir = False 76 | 77 | def _get_run_number(self, run_root_dir, opt_run_number): 78 | if opt_run_number in ('new', 'last'): 79 | if os.path.exists(run_root_dir): 80 | current_numbers = [int(x) for x in os.listdir(run_root_dir) if x.isdigit()] 81 | if current_numbers: # run_root_dir not empty 82 | if opt_run_number == 'new': 83 | run_number = max(current_numbers) + 1 84 | else: 85 | run_number = max(current_numbers) 86 | else: 87 | if opt_run_number == 'new': 88 | run_number = 0 89 | else: 90 | raise OSError(f"{run_root_dir} is empty!") 91 | else: # run_root_dir does not exist 92 | if opt_run_number == 'new': 93 | run_number = 0 94 | else: 95 | raise OSError(f"{run_root_dir} does not exist!") 96 | else: # manual number 97 | assert opt_run_number.isdigit(), "`run_number` is not a valid number" 98 | run_number = int(opt_run_number) 99 | return run_number 100 | 101 | def _setup_dirs(self): 102 | opt = self._opt 103 | # exp_dir: direcotry for the experiment 104 | exp_dir = os.path.join(opt.log_dir, opt.exp_name) 105 | os.makedirs(exp_dir, exist_ok=True) 106 | self._exp_dir = exp_dir 107 | # run_dir: directory for the run 108 | run_root_dir = os.path.join(exp_dir, opt.run_name) 109 | opt.run_number = self._get_run_number(run_root_dir, opt.run_number) 110 | run_dir = os.path.join(run_root_dir, str(opt.run_number)) 111 | if os.path.exists(run_dir): 112 | if opt.option_for_existing_dir: 113 | op = opt.option_for_existing_dir 114 | else: 115 | 
print(f"Directory {run_dir} exists, please choose an option:") 116 | op = input("b (backup) / k (keep) / d (delete) / n (new) / q (quit): ") 117 | if op == 'b': 118 | with open(os.path.join(run_dir, 'args.json'), 'r') as fp: 119 | old_opt = json.load(fp) 120 | d_backup = run_dir + f"-backup-({old_opt['uid']})" 121 | shutil.move(run_dir, d_backup) 122 | print(f"Old files backuped to {d_backup}.") 123 | elif op == 'k': 124 | self._keep_existing_dir = True 125 | print("Old files kept unchanged.") 126 | elif op == 'd': 127 | shutil.rmtree(run_dir) 128 | print("Old files deleted.") 129 | # if 'aim' in self._third_party_tools: 130 | # aim_dir = os.path.join( 131 | # opt.log_dir, '.aim', opt.exp_name, old_opt['uid']) 132 | # shutil.rmtree(aim_dir) 133 | # print(f"Aim dir {aim_dir} deleted.") 134 | elif op == 'n': 135 | opt.run_number = self._get_run_number(run_root_dir, 'new') 136 | print(f"New run number: {opt.run_number}") 137 | run_dir = os.path.join(run_root_dir, str(opt.run_number)) 138 | else: 139 | raise OSError("Quit without changes.") 140 | os.makedirs(run_dir, exist_ok=True) 141 | print(f"==> Directory for this run: {run_dir}") 142 | self._run_dir = run_dir 143 | # checkpoint_dir: directory for the checkpoints of the run 144 | checkpoint_dir = self.get_checkpoint_dir() 145 | os.makedirs(checkpoint_dir, exist_ok=True) 146 | 147 | def _setup_uid(self): 148 | self._uid = '-'.join([datetime.now().strftime('%y%m%d-%H%M%S'), 149 | _generate_short_uid(length=6)]) 150 | self._opt.uid = self._uid 151 | print(f"==> UID of this run: {self._uid}") 152 | 153 | def _setup_logger(self): 154 | self._logger = logging.getLogger(name=self._name) 155 | self._logger.propagate = False 156 | self._logger.setLevel(logging.DEBUG) 157 | # Stdout handler 158 | formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s') 159 | ch = logging.StreamHandler(sys.stdout) 160 | ch.setLevel(logging.DEBUG) 161 | ch.setFormatter(formatter) 162 | self._logger.addHandler(ch) 163 | # Log file handler 164 | formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s') 165 | filename = "log.log" 166 | path = os.path.join(self._run_dir, filename) 167 | fh = logging.FileHandler(path, encoding='utf-8') 168 | fh.setLevel(logging.DEBUG) 169 | fh.setFormatter(formatter) 170 | self._logger.addHandler(fh) 171 | 172 | def _backup_code(self): 173 | code_dir = self._opt.code_dir 174 | def exclude(tarinfo): 175 | patterns = ['*__pycache__', '*.git', '*pymp-*'] 176 | path = tarinfo.name 177 | for pattern in patterns: 178 | if fnmatch(path, pattern): 179 | return None 180 | return tarinfo 181 | if code_dir is not None: 182 | arcname = f"code-{self._uid}" 183 | path = os.path.join(self._run_dir, f"{arcname}.tar") 184 | with tarfile.open(path, 'w') as tar: 185 | tar.add(code_dir, arcname=arcname, filter=exclude) 186 | else: 187 | self._logger.warning( 188 | "Argument --code_dir unspecified, code will not be backuped.") 189 | 190 | def _setup_seed(self): 191 | np.random.seed(self._opt.seed) 192 | torch.manual_seed(self._opt.seed) 193 | torch.cuda.manual_seed_all(self._opt.seed) 194 | 195 | def _setup_torch(self): 196 | torch.backends.cudnn.enabled = True 197 | torch.backends.cudnn.benchmark = True 198 | 199 | def _setup_third_party_tools(self): 200 | if 'tensorboard' in self._third_party_tools: 201 | self._tensorboard_writer = _SummaryWriter( 202 | log_dir=self._run_dir, 203 | max_queue=100, 204 | flush_secs=60, 205 | purge_step=0, 206 | ) 207 | # if 'aim' in self._third_party_tools: 208 | # self._aim_session = aim.Session( 
209 | # repo=self._opt.log_dir, 210 | # experiment=self._opt.exp_name, 211 | # flush_frequency=128, 212 | # block_termination=True, 213 | # run=self._uid, 214 | # ) 215 | 216 | def _export_arguments(self): 217 | escape_opts = ['code_dir', 'data_dir', 'log_dir', 218 | 'option_for_existing_dir'] 219 | opt = vars(self._opt).copy() 220 | for opt_name in escape_opts: 221 | opt.pop(opt_name) 222 | self._logger.info(f"Opts: {opt}") 223 | with open(os.path.join(self._run_dir, 'argv.txt'), 'a') as f: 224 | print(sys.argv, file=f) 225 | if not self._keep_existing_dir: 226 | with open(os.path.join(self._run_dir, 'args.json'), 'a') as f: 227 | json.dump(opt, fp=f, indent=4) 228 | if 'tensorboard' in self._third_party_tools: 229 | tb_opt_dict = {} 230 | for name, value in opt.items(): 231 | if type(value) is list: 232 | tb_opt_dict[name] = torch.tensor(value) 233 | else: 234 | tb_opt_dict[name] = value 235 | self._tensorboard_writer.add_hparams(tb_opt_dict, {}) 236 | # if 'aim' in self._third_party_tools: 237 | # self._aim_session.set_params(opt_dict, name='hparams') 238 | 239 | def get_basic_arg_parser(self): 240 | parser = _ArgumentParser(fromfile_prefix_chars='@') 241 | parser.add_argument('--code_dir', type=str, help="code dir (for backup)") 242 | parser.add_argument('--data_dir', type=str, help="data dir") 243 | parser.add_argument('--log_dir', type=str, help="root dir for logging") 244 | parser.add_argument('--exp_name', type=str, help="name of the experiment") 245 | parser.add_argument('--run_name', type=str, help="name of this run") 246 | parser.add_argument('--run_number', type=str, default='0', 247 | help="Number of this run. Choices: {new, last, MANUAL_NUMBER}") 248 | parser.add_argument('--seed', type=int, help="random seed") 249 | parser.add_argument('--option_for_existing_dir', '-O', type=str, 250 | help="Specify the option for existing run_dir:" + 251 | " b (backup) / k (keep) / d (delete) / n (new) / q (quit)") 252 | return parser 253 | 254 | def setup(self, opt, rank=0, world_size=1, third_party_tools=None, setup_logging=None): 255 | self._opt = opt 256 | self._rank = rank 257 | self._world_size = world_size 258 | if third_party_tools: 259 | self._third_party_tools = third_party_tools 260 | self._setup_torch() 261 | if opt.seed is not None: 262 | self._setup_seed() 263 | if setup_logging is None: 264 | setup_logging = self.is_master() 265 | if setup_logging: 266 | self._setup_uid() 267 | self._setup_dirs() 268 | self._setup_logger() 269 | self._backup_code() 270 | self._setup_third_party_tools() 271 | self._export_arguments() 272 | else: 273 | self._exp_dir = os.path.join(opt.log_dir, opt.exp_name) 274 | 275 | def set_run_dir(self, run_dir): 276 | if self._run_dir is not None and self._run_dir != run_dir: 277 | raise ValueError("Run dir is already set.") 278 | self._run_dir = run_dir 279 | 280 | def get_opt(self): 281 | return self._opt 282 | 283 | def get_run_dir(self, run_name=None, run_number=None): 284 | """ 285 | If run_name is None, return the directory for this run. 286 | Otherwise, return the directory of the specified run. 287 | (run_number defaults to 0) 288 | """ 289 | if run_name is None: 290 | run_dir = self._run_dir 291 | else: 292 | if run_number is None: 293 | run_number = '0' 294 | run_dir = os.path.join(self._exp_dir, run_name, str(run_number)) 295 | return run_dir 296 | 297 | def get_checkpoint_dir(self, run_name=None, run_number=None): 298 | """ 299 | If run_name is None, return the checkpoint directory for this run. 
300 | Otherwise, return the checkpoint directory of the specified run. 301 | (run_number defaults to 0) 302 | """ 303 | run_dir = self.get_run_dir(run_name, run_number) 304 | return os.path.join(run_dir, 'checkpoints') 305 | 306 | def get_logger(self, name=None): 307 | if name is None: 308 | logger = self._logger 309 | # logger = logging.getLogger(name=self._name) 310 | else: 311 | logger_name = self._logger.name + '.' + name 312 | logger = logging.getLogger(name=logger_name) 313 | return logger 314 | 315 | def log_metric(self, name, value, global_step, epoch, split=None): 316 | if 'tensorboard' in self._third_party_tools: 317 | writer = self._tensorboard_writer 318 | if split is None: 319 | scaler_name = name 320 | else: 321 | scaler_name = '/'.join((split, name)) 322 | writer.add_scalar(scaler_name, value, global_step) 323 | # if 'aim' in self._third_party_tools: 324 | # sess = self._aim_session 325 | # sess.track(value, name=name, epoch=epoch, split=split) 326 | 327 | def save_metrics(self, metrics, filename='results'): 328 | metric_dict = OrderedDict() 329 | for metric in metrics: 330 | name = metric['name'] 331 | if 'split' in metric: 332 | name = f"{metric['split']}:{name}" 333 | value = metric['value'] 334 | metric_dict[name] = value 335 | with open(os.path.join(self._run_dir, f'{filename}.json'), 'w') as f: 336 | json.dump(metric_dict, fp=f, indent=4) 337 | with open(os.path.join(self._run_dir, f'{filename}.csv'), 'w') as f: 338 | print(*list(metric_dict.keys()), sep=',', file=f) 339 | print(*list(metric_dict.values()), sep=',', file=f) 340 | 341 | def is_master(self): 342 | return self._rank == 0 343 | 344 | def is_distributed(self): 345 | return self._world_size > 1 346 | 347 | 348 | manager = ExperiMan(name='default') 349 | -------------------------------------------------------------------------------- /models/clip/model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | from .multihead_attention import MultiheadAttention 9 | 10 | 11 | class Sequential(nn.Sequential): 12 | def forward(self, *inputs, **kwargs): 13 | for module in self._modules.values(): 14 | if type(inputs) == tuple: 15 | inputs = module(*inputs, **kwargs) 16 | else: 17 | inputs = module(inputs, **kwargs) 18 | return inputs 19 | 20 | 21 | class Bottleneck(nn.Module): 22 | expansion = 4 23 | 24 | def __init__(self, inplanes, planes, stride=1): 25 | super().__init__() 26 | 27 | # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 28 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 29 | self.bn1 = nn.BatchNorm2d(planes) 30 | self.relu1 = nn.ReLU(inplace=True) 31 | 32 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 33 | self.bn2 = nn.BatchNorm2d(planes) 34 | self.relu2 = nn.ReLU(inplace=True) 35 | 36 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 37 | 38 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 39 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 40 | self.relu3 = nn.ReLU(inplace=True) 41 | 42 | self.downsample = None 43 | self.stride = stride 44 | 45 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 46 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 47 | self.downsample = nn.Sequential(OrderedDict([ 48 | ("-1", nn.AvgPool2d(stride)), 49 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 50 | ("1", nn.BatchNorm2d(planes * self.expansion)) 51 | ])) 52 | 53 | def forward(self, x: torch.Tensor): 54 | identity = x 55 | 56 | out = self.relu1(self.bn1(self.conv1(x))) 57 | out = self.relu2(self.bn2(self.conv2(out))) 58 | out = self.avgpool(out) 59 | out = self.bn3(self.conv3(out)) 60 | 61 | if self.downsample is not None: 62 | identity = self.downsample(x) 63 | 64 | out += identity 65 | out = self.relu3(out) 66 | return out 67 | 68 | 69 | class AttentionPool2d(nn.Module): 70 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 71 | super().__init__() 72 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 73 | self.k_proj = nn.Linear(embed_dim, embed_dim) 74 | self.q_proj = nn.Linear(embed_dim, embed_dim) 75 | self.v_proj = nn.Linear(embed_dim, embed_dim) 76 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 77 | self.num_heads = num_heads 78 | 79 | def forward(self, x): 80 | x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC 81 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 82 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 83 | x, _ = F.multi_head_attention_forward( 84 | query=x[:1], key=x, value=x, 85 | embed_dim_to_check=x.shape[-1], 86 | num_heads=self.num_heads, 87 | q_proj_weight=self.q_proj.weight, 88 | k_proj_weight=self.k_proj.weight, 89 | v_proj_weight=self.v_proj.weight, 90 | in_proj_weight=None, 91 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 92 | bias_k=None, 93 | bias_v=None, 94 | add_zero_attn=False, 95 | dropout_p=0, 96 | out_proj_weight=self.c_proj.weight, 97 | out_proj_bias=self.c_proj.bias, 98 | use_separate_proj_weight=True, 99 | training=self.training, 100 | need_weights=False 101 | ) 102 | return x.squeeze(0) 103 | 104 | 105 | class ModifiedResNet(nn.Module): 106 | """ 107 | A ResNet class that is similar to torchvision's but contains the following changes: 108 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
109 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 110 | - The final pooling layer is a QKV attention instead of an average pool 111 | """ 112 | 113 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 114 | super().__init__() 115 | self.output_dim = output_dim 116 | self.input_resolution = input_resolution 117 | 118 | # the 3-layer stem 119 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 120 | self.bn1 = nn.BatchNorm2d(width // 2) 121 | self.relu1 = nn.ReLU(inplace=True) 122 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 123 | self.bn2 = nn.BatchNorm2d(width // 2) 124 | self.relu2 = nn.ReLU(inplace=True) 125 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 126 | self.bn3 = nn.BatchNorm2d(width) 127 | self.relu3 = nn.ReLU(inplace=True) 128 | self.avgpool = nn.AvgPool2d(2) 129 | 130 | # residual layers 131 | self._inplanes = width # this is a *mutable* variable used during construction 132 | self.layer1 = self._make_layer(width, layers[0]) 133 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 134 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 135 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 136 | 137 | embed_dim = width * 32 # the ResNet feature dimension 138 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) 139 | 140 | def _make_layer(self, planes, blocks, stride=1): 141 | layers = [Bottleneck(self._inplanes, planes, stride)] 142 | 143 | self._inplanes = planes * Bottleneck.expansion 144 | for _ in range(1, blocks): 145 | layers.append(Bottleneck(self._inplanes, planes)) 146 | 147 | return nn.Sequential(*layers) 148 | 149 | def forward(self, x): 150 | def stem(x): 151 | x = self.relu1(self.bn1(self.conv1(x))) 152 | x = self.relu2(self.bn2(self.conv2(x))) 153 | x = self.relu3(self.bn3(self.conv3(x))) 154 | x = self.avgpool(x) 155 | return x 156 | 157 | x = x.type(self.conv1.weight.dtype) 158 | x = stem(x) 159 | x = self.layer1(x) 160 | x = self.layer2(x) 161 | x = self.layer3(x) 162 | x = self.layer4(x) 163 | x = self.attnpool(x) 164 | 165 | return x 166 | 167 | 168 | class LayerNorm(nn.LayerNorm): 169 | """Subclass torch's LayerNorm to handle fp16.""" 170 | 171 | def forward(self, x: torch.Tensor): 172 | orig_type = x.dtype 173 | ret = super().forward(x.type(torch.float32)) 174 | return ret.type(orig_type) 175 | 176 | 177 | class QuickGELU(nn.Module): 178 | def forward(self, x: torch.Tensor): 179 | return x * torch.sigmoid(1.702 * x) 180 | 181 | 182 | class ResidualAttentionBlock(nn.Module): 183 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 184 | super().__init__() 185 | 186 | # self.attn = nn.MultiheadAttention(d_model, n_head) 187 | self.attn = MultiheadAttention(d_model, n_head) 188 | self.ln_1 = LayerNorm(d_model) 189 | self.mlp = nn.Sequential(OrderedDict([ 190 | ("c_fc", nn.Linear(d_model, d_model * 4)), 191 | ("gelu", QuickGELU()), 192 | ("c_proj", nn.Linear(d_model * 4, d_model)) 193 | ])) 194 | self.ln_2 = LayerNorm(d_model) 195 | self.attn_mask = attn_mask 196 | 197 | def attention(self, x: torch.Tensor, attn_mask: torch.Tensor = None, attn_out: list[torch.Tensor] = None): 198 | if attn_mask is None: 199 | attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 200 | if attn_out is None: 201 | return 
self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0] 202 | else: 203 | output, weight = self.attn(x, x, x, need_weights=True, attn_mask=attn_mask) 204 | attn_out.append(weight) 205 | return output 206 | 207 | def forward(self, x: torch.Tensor, attn_mask: torch.Tensor = None, attn_out: list[torch.Tensor] = None): 208 | x = x + self.attention(self.ln_1(x), attn_mask=attn_mask, attn_out=attn_out) 209 | x = x + self.mlp(self.ln_2(x)) 210 | return x 211 | 212 | 213 | class Transformer(nn.Module): 214 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 215 | super().__init__() 216 | self.width = width 217 | self.layers = layers 218 | # self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 219 | self.resblocks = Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 220 | 221 | def forward(self, x: torch.Tensor, attn_mask: torch.Tensor = None, attn_out: list[torch.Tensor] = None): 222 | return self.resblocks(x, attn_mask=attn_mask, attn_out=attn_out) 223 | 224 | 225 | class VisionTransformer(nn.Module): 226 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): 227 | super().__init__() 228 | self.input_resolution = input_resolution 229 | self.heads = heads 230 | self.output_dim = output_dim 231 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 232 | 233 | scale = width ** -0.5 234 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 235 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 236 | self.ln_pre = LayerNorm(width) 237 | 238 | self.transformer = Transformer(width, layers, heads) 239 | 240 | self.ln_post = LayerNorm(width) 241 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 242 | 243 | def forward(self, x: torch.Tensor, mask: torch.Tensor = None, 244 | attn_mask: torch.Tensor = None, attn_out: list[torch.Tensor] = None): 245 | # mask: MAE-like masking; reduce computation 246 | # attn_mask: apply on MultiheadAttention; 247 | # allow different mask rates for images for the same batch 248 | # attn_out: a list to save the attention maps 249 | 250 | x = self.conv1(x) # shape = [*, width, grid, grid] 251 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 252 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 253 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 254 | x = x + self.positional_embedding.to(x.dtype) 255 | x = self.ln_pre(x) 256 | 257 | x = x.permute(1, 0, 2) # NLD -> LND 258 | if mask is not None: 259 | x = x.gather(0, mask.unsqueeze(-1).expand(-1, -1, x.size(2))) 260 | x = self.transformer(x, attn_mask=attn_mask, attn_out=attn_out) 261 | x = x.permute(1, 0, 2) # LND -> NLD 262 | 263 | x = self.ln_post(x[:, 0, :]) 264 | 265 | if self.proj is not None: 266 | x = x @ self.proj 267 | 268 | return x 269 | 270 | 271 | class CLIP(nn.Module): 272 | def __init__(self, 273 | embed_dim: int, 274 | # vision 275 | image_resolution: int, 276 | vision_layers: Union[Tuple[int, int, int, int], int], 277 | vision_width: int, 278 | vision_patch_size: int, 279 | # text 280 | context_length: int, 281 | vocab_size: int, 282 | transformer_width: int, 283 | transformer_heads: int, 284 | transformer_layers: 
int 285 | ): 286 | super().__init__() 287 | 288 | self.context_length = context_length 289 | 290 | if isinstance(vision_layers, (tuple, list)): 291 | vision_heads = vision_width * 32 // 64 292 | self.visual = ModifiedResNet( 293 | layers=vision_layers, 294 | output_dim=embed_dim, 295 | heads=vision_heads, 296 | input_resolution=image_resolution, 297 | width=vision_width 298 | ) 299 | else: 300 | vision_heads = vision_width // 64 301 | self.visual = VisionTransformer( 302 | input_resolution=image_resolution, 303 | patch_size=vision_patch_size, 304 | width=vision_width, 305 | layers=vision_layers, 306 | heads=vision_heads, 307 | output_dim=embed_dim 308 | ) 309 | 310 | self.transformer = Transformer( 311 | width=transformer_width, 312 | layers=transformer_layers, 313 | heads=transformer_heads, 314 | attn_mask=self.build_attention_mask() 315 | ) 316 | 317 | self.vocab_size = vocab_size 318 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 319 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 320 | self.ln_final = LayerNorm(transformer_width) 321 | 322 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 323 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 324 | 325 | self.initialize_parameters() 326 | 327 | def initialize_parameters(self): 328 | nn.init.normal_(self.token_embedding.weight, std=0.02) 329 | nn.init.normal_(self.positional_embedding, std=0.01) 330 | 331 | if isinstance(self.visual, ModifiedResNet): 332 | if self.visual.attnpool is not None: 333 | std = self.visual.attnpool.c_proj.in_features ** -0.5 334 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) 335 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) 336 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) 337 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) 338 | 339 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: 340 | for name, param in resnet_block.named_parameters(): 341 | if name.endswith("bn3.weight"): 342 | nn.init.zeros_(param) 343 | 344 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 345 | attn_std = self.transformer.width ** -0.5 346 | fc_std = (2 * self.transformer.width) ** -0.5 347 | for block in self.transformer.resblocks: 348 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 349 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 350 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 351 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 352 | 353 | if self.text_projection is not None: 354 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 355 | 356 | def build_attention_mask(self): 357 | # lazily create causal attention mask, with full attention between the vision tokens 358 | # pytorch uses additive attention mask; fill with -inf 359 | mask = torch.empty(self.context_length, self.context_length) 360 | mask.fill_(float("-inf")) 361 | mask.triu_(1) # zero out the lower diagonal 362 | return mask 363 | 364 | @property 365 | def dtype(self): 366 | return self.visual.conv1.weight.dtype 367 | 368 | def encode_image(self, image): 369 | return self.visual(image.type(self.dtype)) 370 | 371 | def encode_text(self, text): 372 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 373 | 374 | x = x + self.positional_embedding.type(self.dtype) 375 | x = x.permute(1, 0, 2) # NLD -> LND 376 | 
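        # The causal attention mask built in build_attention_mask() was passed
        # to the text Transformer at construction time and is applied inside
        # every residual attention block.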
x = self.transformer(x) 377 | x = x.permute(1, 0, 2) # LND -> NLD 378 | x = self.ln_final(x).type(self.dtype) 379 | 380 | # x.shape = [batch_size, n_ctx, transformer.width] 381 | # take features from the eot embedding (eot_token is the highest number in each sequence) 382 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 383 | 384 | return x 385 | 386 | def forward(self, image, text): 387 | image_features = self.encode_image(image) 388 | text_features = self.encode_text(text) 389 | 390 | # normalized features 391 | image_features = image_features / image_features.norm(dim=1, keepdim=True) 392 | text_features = text_features / text_features.norm(dim=1, keepdim=True) 393 | 394 | # cosine similarity as logits 395 | logit_scale = self.logit_scale.exp() 396 | logits_per_image = logit_scale * image_features @ text_features.t() 397 | logits_per_text = logits_per_image.t() 398 | 399 | # shape = [global_batch_size, global_batch_size] 400 | return logits_per_image, logits_per_text 401 | 402 | 403 | def convert_weights(model: nn.Module): 404 | """Convert applicable model parameters to fp16""" 405 | 406 | def _convert_weights_to_fp16(l): 407 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 408 | l.weight.data = l.weight.data.half() 409 | if l.bias is not None: 410 | l.bias.data = l.bias.data.half() 411 | 412 | if isinstance(l, nn.MultiheadAttention): 413 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 414 | tensor = getattr(l, attr) 415 | if tensor is not None: 416 | tensor.data = tensor.data.half() 417 | 418 | for name in ["text_projection", "proj"]: 419 | if hasattr(l, name): 420 | attr = getattr(l, name) 421 | if attr is not None: 422 | attr.data = attr.data.half() 423 | 424 | model.apply(_convert_weights_to_fp16) 425 | 426 | 427 | def build_model(state_dict: dict): 428 | vit = "visual.proj" in state_dict 429 | 430 | if vit: 431 | vision_width = state_dict["visual.conv1.weight"].shape[0] 432 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 433 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 434 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 435 | image_resolution = vision_patch_size * grid_size 436 | else: 437 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 438 | vision_layers = tuple(counts) 439 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 440 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 441 | vision_patch_size = None 442 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 443 | image_resolution = output_width * 32 444 | 445 | embed_dim = state_dict["text_projection"].shape[1] 446 | context_length = state_dict["positional_embedding"].shape[0] 447 | vocab_size = state_dict["token_embedding.weight"].shape[0] 448 | transformer_width = state_dict["ln_final.weight"].shape[0] 449 | transformer_heads = transformer_width // 64 450 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks"))) 451 | 452 | model = CLIP( 453 | embed_dim, 454 | image_resolution, vision_layers, vision_width, vision_patch_size, 455 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers 456 | ) 457 | 458 | for key in ["input_resolution", 
"context_length", "vocab_size"]: 459 | if key in state_dict: 460 | del state_dict[key] 461 | 462 | convert_weights(model) 463 | model.load_state_dict(state_dict) 464 | return model.eval() 465 | -------------------------------------------------------------------------------- /data/imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torchvision.datasets import ImageNet, ImageFolder 3 | import torchvision.transforms as transforms 4 | from torchvision.transforms import InterpolationMode 5 | import torchvision.transforms.functional as TF 6 | from data.base import BaseDataset, get_dataloader, stratified_random_split 7 | 8 | 9 | # class names taken from: https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/notebooks/Prompt_Engineering_for_ImageNet.ipynb 10 | IMAGENET_CLASS_NAMES = ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray", "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper", "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander", "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog", "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin", "box turtle", "banded gecko", "green iguana", "Carolina anole", "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard", "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile", "American alligator", "triceratops", "worm snake", "ring-necked snake", "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake", "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra", "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake", "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider", "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider", "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl", "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet", "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck", "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron", "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot", "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher", "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel", "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle", "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound", "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound", 
"Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound", "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier", "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier", "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier", "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer", "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier", "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever", "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla", "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel", "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard", "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie", "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann", "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff", "French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky", "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog", "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon", "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle", "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf", "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox", "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat", "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose", "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle", "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper", "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly", "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly", "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit", "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse", "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)", "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat", "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque", "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin", "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey", "ring-tailed lemur", "indri", 
"Asian elephant", "African bush elephant", "red panda", "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish", "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown", "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle", "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo", "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel", "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)", "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini", "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet", "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra", "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest", "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe", "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton", "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran", "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw", "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking", "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker", "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard", "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed", "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table", "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig", "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder", "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute", "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed", "freight car", "French horn", "frying pan", "fur coat", "garbage truck", "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola", "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine", "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer", "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet", "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar", "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep", "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat", "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library", "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion", "music speaker", "loupe magnifying glass", "sawmill", "magnetic 
compass", "messenger bag", "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask", "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone", "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor", "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa", "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail", "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina", "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart", "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush", "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench", "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case", "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube", "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball", "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag", "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho", "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug", "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill", "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel", "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator", "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser", "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal", "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard", "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store", "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap", "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door", "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock", "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater", "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight", "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf", "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa", "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge", "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe", "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball", "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof", "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store", "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard", "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling", "velvet fabric", "vending machine", "vestment", "viaduct", "violin", 
"volleyball", "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink", "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing", "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website", "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu", "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette", "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli", "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange", "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate", "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito", "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef", "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player", "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn", "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom", "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"] 11 | 12 | 13 | class ImageNetDataset(BaseDataset): 14 | 15 | def __init__(self, data_dir, size=224, interpolation='bicubic', transform='std'): 16 | super().__init__( 17 | data_dir=data_dir, 18 | size=size, 19 | mean=(0.485, 0.456, 0.406), 20 | std=(0.229, 0.224, 0.225), 21 | ) 22 | self.sub_dir = 'ILSVRC2012' 23 | self.num_classes = 1000 24 | self.class_names = IMAGENET_CLASS_NAMES 25 | 26 | self.interpolation = interpolation 27 | if interpolation == 'bicubic': 28 | interpolation = InterpolationMode.BICUBIC 29 | elif interpolation == 'bilinear': 30 | interpolation = InterpolationMode.BILINEAR 31 | elif interpolation == 'nearest': 32 | interpolation = InterpolationMode.NEAREST 33 | else: 34 | raise NotImplementedError( 35 | f"Unsupported interpolation mode: {interpolation}") 36 | 37 | self.target_transform = None 38 | if transform == 'std': 39 | self.transforms_train = transforms.Compose([ 40 | transforms.RandomResizedCrop( 41 | self.size, interpolation=interpolation), 42 | transforms.RandomHorizontalFlip(), 43 | transforms.ToTensor(), 44 | ]) 45 | self.transforms_test = transforms.Compose([ 46 | transforms.Resize( 47 | round(self.size / 224 * 256), interpolation=interpolation), 48 | transforms.CenterCrop(self.size), 49 | transforms.ToTensor(), 50 | ]) 51 | elif transform == 'clip': 52 | # https://github.com/mlfoundations/wise-ft/blob/58b7a4b343b09dc06606aa929c2ef51accced8d1/clip/clip.py#L67 53 | # self.transforms_train = transforms.Compose([ 54 | # transforms.RandomResizedCrop( 55 | # self.size, scale=(0.9, 1), interpolation=interpolation), 56 | # transforms.ToTensor(), 57 | # ]) 58 | self.transforms_test = transforms.Compose([ 59 | transforms.Resize(self.size, interpolation=interpolation), 60 | transforms.CenterCrop(self.size), 61 | transforms.ToTensor(), 62 | ]) 63 | self.transforms_train = self.transforms_test 64 | elif transform.startswith('RandAug'): # RandAug(N,M) 65 | N, M = [int(x) for x in transform[8:-1].split(',')] 66 | self.transforms_train = transforms.Compose([ 67 | transforms.RandAugment(N, M), 68 | 
transforms.RandomResizedCrop( 69 | self.size, scale=(0.9, 1), interpolation=interpolation), 70 | transforms.RandomHorizontalFlip(), 71 | transforms.ToTensor(), 72 | ]) 73 | self.transforms_test = transforms.Compose([ 74 | transforms.Resize(self.size, interpolation=interpolation), 75 | transforms.CenterCrop(self.size), 76 | transforms.ToTensor(), 77 | ]) 78 | else: 79 | raise NotImplementedError( 80 | f"Unsupported transform type: {transform}") 81 | 82 | def get_loader(self, batch_size, num_workers, with_index=False, 83 | train_split='original', val_size=0, split_seed=0, 84 | shuffle_test=False, augment=True, drop_last=True, 85 | world_size=1, rank=0): 86 | transforms_train = self.transforms_train if augment else self.transforms_test 87 | transforms_test = self.transforms_test 88 | data_dir = os.path.join(self.data_dir, self.sub_dir) 89 | 90 | if train_split == 'original': 91 | aug_trainset = ImageNet( 92 | root=data_dir, split='train', transform=transforms_train, 93 | target_transform=self.target_transform) 94 | raw_trainset = ImageNet( 95 | root=data_dir, split='train', transform=transforms_test, 96 | target_transform=self.target_transform) 97 | testset = ImageNet( 98 | root=data_dir, split='val', transform=transforms_test, 99 | target_transform=self.target_transform) 100 | else: 101 | raise NotImplementedError() 102 | 103 | if val_size > 0: 104 | labels = raw_trainset.targets 105 | train_size = len(raw_trainset) - val_size 106 | aug_trainset, _ = stratified_random_split( 107 | aug_trainset, labels, train_size, seed=split_seed) 108 | raw_trainset, valset = stratified_random_split( 109 | raw_trainset, labels, train_size, seed=split_seed) 110 | 111 | kwargs = dict(batch_size=batch_size, num_workers=num_workers, 112 | with_index=with_index, num_replicas=world_size, rank=rank) 113 | aug_trainloader = get_dataloader( 114 | aug_trainset, shuffle=True, drop_last=drop_last, **kwargs) 115 | raw_trainloader = get_dataloader( 116 | raw_trainset, shuffle=shuffle_test, drop_last=False, **kwargs) 117 | testloader = get_dataloader( 118 | testset, shuffle=shuffle_test, drop_last=False, **kwargs) 119 | if val_size > 0: 120 | valloader = get_dataloader( 121 | valset, shuffle=False, drop_last=False, **kwargs) 122 | return aug_trainloader, raw_trainloader, valloader, testloader 123 | else: 124 | return aug_trainloader, raw_trainloader, testloader 125 | 126 | 127 | class EvaluationDataset(ImageNetDataset): 128 | 129 | def __init__(self, data_dir, size=224, interpolation='bicubic', transform='std'): 130 | super().__init__(data_dir, size, interpolation, transform) 131 | 132 | def get_loader(self, batch_size, num_workers, with_index=False, 133 | train_split='original', val_size=0, split_seed=0, 134 | shuffle_test=False, augment=True, drop_last=True, 135 | world_size=1, rank=0): 136 | data_dir = os.path.join(self.data_dir, self.sub_dir) 137 | 138 | testset = ImageFolder( 139 | root=data_dir, transform=self.transforms_test, 140 | target_transform=self.target_transform) 141 | 142 | kwargs = dict(batch_size=batch_size, num_workers=num_workers, 143 | with_index=with_index, num_replicas=world_size, rank=rank) 144 | testloader = get_dataloader( 145 | testset, shuffle=shuffle_test, drop_last=False, **kwargs) 146 | return testloader 147 | -------------------------------------------------------------------------------- /trainers/base.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import os 3 | from typing import Any, Union 4 | from 
typing_extensions import Literal 5 | from tqdm import tqdm 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.optim 10 | from torch.utils.data import DataLoader 11 | from data.base import BaseDataset 12 | from utils.experiman import ExperiMan 13 | from utils.misc import Accuracy, AverageMeter, MovingAverageMeter, ScalerMeter, PerClassMeter 14 | 15 | 16 | class inference_mode(torch.no_grad): 17 | 18 | def __init__(self, models=None): 19 | super().__init__() 20 | if models is None: 21 | models = [] 22 | elif not isinstance(models, list): 23 | models = [models] 24 | self.models = models 25 | self.training = [False for _ in models] 26 | 27 | def __enter__(self): 28 | super().__enter__() 29 | for i, model in enumerate(self.models): 30 | self.training[i] = model.training 31 | model.eval() 32 | 33 | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: 34 | for model, training in zip(self.models, self.training): 35 | model.train(training) 36 | super().__exit__(exc_type, exc_value, traceback) 37 | 38 | 39 | class LoopConfig(): 40 | 41 | def __init__( 42 | self, 43 | name: str, 44 | dataset: BaseDataset, 45 | dataloader: DataLoader, 46 | training: bool, 47 | n_iterations: int, 48 | n_phases: int = 1, 49 | n_logical_steps: Union[list[int], int] = 1, 50 | n_computation_steps: Union[list[int], int] = 1, 51 | run_every_n_epochs: int = 1, 52 | run_at_checkpoint: bool = True, 53 | run_at_last_epoch: bool = True, 54 | for_best_meter: bool = False, 55 | ): 56 | self.name = name 57 | self.dataset = dataset 58 | self.dataloader = dataloader 59 | self.training = training 60 | self.n_iterations = n_iterations 61 | self.n_phases = n_phases 62 | if type(n_logical_steps) is int: 63 | n_logical_steps = [n_logical_steps] 64 | self.n_logical_steps = n_logical_steps 65 | assert n_phases == len(n_logical_steps) 66 | if type(n_computation_steps) is int: 67 | n_computation_steps = [n_computation_steps] 68 | self.n_computation_steps = n_computation_steps 69 | assert n_phases == len(n_computation_steps) 70 | self.run_every_n_epochs = run_every_n_epochs 71 | self.run_at_checkpoint = run_at_checkpoint 72 | self.run_at_last_epoch = run_at_last_epoch or run_at_checkpoint 73 | self.for_best_meter = for_best_meter 74 | 75 | def __str__(self): 76 | configs = self.__dict__.copy() 77 | name = configs.pop('name') 78 | configs.pop('dataset') 79 | configs.pop('dataloader') 80 | arg_strings = [f"{k}={v}" for k, v in configs.items()] 81 | return f"[{name}] ({', '.join(arg_strings)})" 82 | 83 | 84 | class OptimizerWithSchedule(): 85 | 86 | def __init__( 87 | self, 88 | optimizer: torch.optim.Optimizer, 89 | scheduler: torch.optim.lr_scheduler._LRScheduler, 90 | schedule_step: Literal['epoch', 'iter'], 91 | ): 92 | self.optimizer = optimizer 93 | self.scheduler = scheduler 94 | self.schedule_step = schedule_step 95 | 96 | def state_dict(self): 97 | return { 98 | 'optimizer': self.optimizer.state_dict(), 99 | 'scheduler': self.scheduler.state_dict(), 100 | 'schedule_step': self.schedule_step, 101 | } 102 | 103 | def load_state_dict(self, state_dict): 104 | self.optimizer.load_state_dict(state_dict['optimizer']) 105 | self.scheduler.load_state_dict(state_dict['scheduler']) 106 | 107 | def step(self, *args, **kwargs): 108 | self.optimizer.step(*args, **kwargs) 109 | 110 | def zero_grad(self, *args, **kwargs): 111 | self.optimizer.zero_grad(*args, **kwargs) 112 | 113 | def scheduler_step(self, *args, **kwargs): 114 | self.scheduler.step(*args, **kwargs) 115 | 116 | def get_learning_rates(self): 117 | return 
self.scheduler.get_last_lr() 118 | 119 | 120 | class BaseTrainer(): 121 | """ 122 | Hierarchy: 123 | 1. Epoch: consists of multiple training / testing / updating loops 124 | 2. Loop: loops (through a dataset) with repeated iterations 125 | 3. Iteration: training iteration may consist of multiple phases with 126 | different objectives 127 | 4. Phase: consists of one or more repeated (accumulated) steps 128 | 5. Logical step: update the models once (after one or more computation 129 | steps for accumulated gradients) 130 | 6. Computation step: handles a batch of data; update the meters 131 | (train: forward & backward; test: forward) 132 | """ 133 | 134 | def __init__( 135 | self, 136 | manager: ExperiMan, 137 | models: dict[str, nn.Module], 138 | criterions: dict[str, nn.Module], 139 | n_epochs: int, 140 | loop_configs: list[LoopConfig], 141 | optimizers: dict[str, OptimizerWithSchedule], 142 | log_period: int, 143 | ckpt_period: int, 144 | device: torch.device, 145 | save_init_ckpt: bool = False, 146 | resume_ckpt: dict = None, 147 | ): 148 | self.manager = manager 149 | self.is_master = manager.is_master() 150 | self.is_distributed = manager.is_distributed() 151 | self.logger = manager.get_logger() 152 | self.last_log_iter_id = -1 153 | self.tqdms = [None for _ in loop_configs] 154 | self.models = models 155 | self.criterions = criterions 156 | self.n_epochs = n_epochs 157 | self.loop_configs = loop_configs 158 | self.data_counters = [0 for _ in loop_configs] 159 | self.data_iters = [self._get_data_iter(i) for i in range(len(loop_configs))] 160 | self.optimizers = optimizers 161 | self.log_period = log_period 162 | self.ckpt_period = ckpt_period or n_epochs + 1 163 | self.device = device 164 | self.save_init_ckpt = save_init_ckpt 165 | self.meters = [{} for _ in loop_configs] 166 | self.meters_info = [{} for _ in loop_configs] 167 | self.loop_meters = self.meters[0] 168 | self.loop_meters_info = self.meters_info[0] 169 | self.meter_for_best_checkpoint = None 170 | self.best_value = None 171 | self.start_epoch = 0 172 | self.iter_count = 0 173 | if resume_ckpt: 174 | self.resume_from_checkpoint(resume_ckpt) 175 | self._default_checkpoint_name = 'ckpt-last' 176 | 177 | def master_only(func): 178 | def wrapper_master_only(self, *args, **kwargs): 179 | if self.is_master: 180 | return func(self, *args, **kwargs) 181 | else: 182 | return None 183 | return wrapper_master_only 184 | 185 | def _setup_tqdms(self, loop_id): 186 | n_iters = self.loop_configs[loop_id].n_iterations 187 | t = tqdm(total=n_iters, leave=False, dynamic_ncols=True, 188 | disable=(not self.is_master)) 189 | t.clear() 190 | self.tqdms[loop_id] = t 191 | 192 | @master_only 193 | def _manager_log_metric(self, epoch_id, loop_id): 194 | split_name = self.loop_configs[loop_id].name 195 | for name, meter in self.loop_meters.items(): 196 | self.manager.log_metric(name, meter.get_value(), 197 | self.iter_count, epoch_id, 198 | split=split_name) 199 | 200 | def _get_data_iter(self, loop_id): 201 | loader = self.loop_configs[loop_id].dataloader 202 | if hasattr(loader, 'sampler'): 203 | if hasattr(loader.sampler, 'set_epoch'): 204 | loader.sampler.set_epoch(self.data_counters[loop_id]) 205 | self.data_counters[loop_id] += 1 206 | return iter(loader) 207 | 208 | def _next_data_batch(self, loop_id): 209 | try: 210 | batch = next(self.data_iters[loop_id]) 211 | except StopIteration: 212 | self.data_iters[loop_id] = self._get_data_iter(loop_id) 213 | batch = next(self.data_iters[loop_id]) 214 | return batch 215 | 216 | def 
_should_run_loop(self, epoch_id, loop_id): 217 | config = self.loop_configs[loop_id] 218 | cnt = epoch_id + 1 219 | period = config.run_every_n_epochs 220 | periodic = (period and cnt % period == 0) 221 | at_ckpt = (config.run_at_checkpoint and cnt % self.ckpt_period == 0) 222 | at_last = (config.run_at_last_epoch and cnt == self.n_epochs) 223 | not_empty = (config.n_iterations > 0) 224 | return not_empty and (periodic or at_ckpt or at_last) 225 | 226 | def get_data_batch(self, loop_id, phase_id): 227 | """Return a batch of data for the phase.""" 228 | raise NotImplementedError 229 | 230 | def get_active_optimizers(self, loop_id, phase_id): 231 | """Return the optimizers active for the phase.""" 232 | raise NotImplementedError 233 | 234 | def get_checkpoint(self, epoch_id): 235 | """Return a checkpoint object to be saved.""" 236 | checkpoint = { 237 | 'epoch': epoch_id, 238 | 'best_value': self.best_value, 239 | '_data_counters': self.data_counters, 240 | } 241 | for name, model in self.models.items(): 242 | bare_model = model.module if hasattr(model, 'module') else model 243 | checkpoint[name] = bare_model.state_dict() 244 | for name, optimizer in self.optimizers.items(): 245 | checkpoint[name] = optimizer.state_dict() 246 | return checkpoint 247 | 248 | def resume_from_checkpoint(self, checkpoint): 249 | """Resume training from a checkpoint object.""" 250 | self.start_epoch = checkpoint['epoch'] + 1 251 | self.best_value = checkpoint['best_value'] 252 | self.data_counters = checkpoint['_data_counters'] 253 | for name, model in self.models.items(): 254 | bare_model = model.module if hasattr(model, 'module') else model 255 | bare_model.load_state_dict(checkpoint[name]) 256 | for name, optimizer in self.optimizers.items(): 257 | optimizer.load_state_dict(checkpoint[name]) 258 | for config in self.loop_configs: 259 | if config.training: 260 | self.iter_count += config.n_iterations * self.start_epoch 261 | return checkpoint 262 | 263 | def toggle_model_mode(self, epoch_id, loop_id): 264 | """Toggle train/eval mode of models.""" 265 | raise NotImplementedError 266 | 267 | def update_meters(self): 268 | """Update meters before logging.""" 269 | 270 | def do_step(self, epoch_id, loop_id, iter_id, phase_id, data_batch): 271 | """ 272 | Typical procedure: 273 | 1. Forward for predictions and losses; 274 | 2. Backward for gradients (if needed); 275 | 3. Update (avg/sum) meters. 
276 | """ 277 | raise NotImplementedError 278 | 279 | def add_meter(self, name, abbr=None, loop_id=None, meter_type='avg', 280 | fstr_format=None, reset_every_epoch=None, 281 | omit_from_results=False): 282 | if loop_id is None: 283 | loop_id = list(range(len(self.loop_configs))) 284 | elif type(loop_id) is int: 285 | loop_id = [loop_id] 286 | assert meter_type in ('avg', 'scaler', 'per_class_avg') 287 | for id in loop_id: 288 | if reset_every_epoch is None: 289 | reset = not self.loop_configs[id].training 290 | else: 291 | reset = reset_every_epoch 292 | self.meters_info[id][name] = { 293 | 'abbr': abbr if abbr is not None else name, 294 | 'type': meter_type, 295 | 'format': fstr_format, 296 | 'reset_every_epoch': reset, 297 | 'omit_from_results': omit_from_results, 298 | } 299 | 300 | def set_meter_for_best_checkpoint(self, loop_id, name, maximum=True): 301 | self.meter_for_best_checkpoint = (loop_id, name, maximum) 302 | 303 | def setup_loop_meters(self, loop_id): 304 | training = self.loop_configs[loop_id].training 305 | self.loop_meters = self.meters[loop_id] 306 | self.loop_meters_info = self.meters_info[loop_id] 307 | for name, info in self.loop_meters_info.items(): 308 | if name not in self.loop_meters: 309 | meter_type = info['type'] 310 | if meter_type == 'avg': 311 | if training: 312 | meter = MovingAverageMeter() 313 | else: 314 | meter = AverageMeter() 315 | elif meter_type == 'per_class_avg': 316 | if training: 317 | meter = PerClassMeter(MovingAverageMeter) 318 | else: 319 | meter = PerClassMeter(AverageMeter) 320 | elif meter_type == 'scaler': 321 | meter = ScalerMeter() 322 | else: 323 | raise NotImplementedError() 324 | self.loop_meters[name] = meter 325 | elif info['reset_every_epoch']: 326 | self.loop_meters[name].reset() 327 | 328 | def is_best_checkpoint(self, epoch_id): 329 | is_best = False 330 | if self.meter_for_best_checkpoint is not None: 331 | loop_id, name, maximum = self.meter_for_best_checkpoint 332 | if self._should_run_loop(epoch_id, loop_id): 333 | value = self.meters[loop_id][name].get_value() 334 | if self.best_value is None: 335 | self.best_value = value 336 | is_best = True 337 | else: 338 | delta = value - self.best_value 339 | sign = 1 if maximum else -1 340 | if delta * sign > 0: 341 | self.best_value = value 342 | is_best = True 343 | return is_best 344 | 345 | @master_only 346 | def save_checkpoint(self, epoch_id, checkpoint_names=None): 347 | checkpoint = self.get_checkpoint(epoch_id) 348 | if checkpoint_names is None: 349 | filenames = [f'ckpt-{epoch_id}.pt'] 350 | else: 351 | filenames = [f'{name}.pt' for name in checkpoint_names] 352 | for name in filenames: 353 | model_path = os.path.join(self.manager.get_checkpoint_dir(), name) 354 | torch.save(checkpoint, model_path) 355 | if name[:-3] != self._default_checkpoint_name \ 356 | or epoch_id == self.n_epochs - 1: 357 | self.logger.info(f'Checkpoint saved to: {model_path}') 358 | 359 | @master_only 360 | def save_results(self, filename='results'): 361 | metrics = [] 362 | for i, loop_config in enumerate(self.loop_configs): 363 | split = loop_config.name 364 | for name, meter in self.meters[i].items(): 365 | if not self.meters_info[i][name]['omit_from_results']: 366 | metric = dict(split=split, name=name, 367 | value=meter.get_value()) 368 | metrics.append(metric) 369 | self.manager.save_metrics(metrics, filename=filename) 370 | 371 | def update_lr(self, step): 372 | for optimizer in self.optimizers.values(): 373 | if optimizer.schedule_step == step: 374 | optimizer.scheduler_step() 375 | 376 | 
def do_iter(self, epoch_id, loop_id, iter_id): 377 | config = self.loop_configs[loop_id] 378 | for phase_id in range(config.n_phases): 379 | for _ in range(config.n_logical_steps[phase_id]): 380 | optimizers = self.get_active_optimizers(loop_id, phase_id) 381 | for _ in range(config.n_computation_steps[phase_id]): 382 | data_batch = self.get_data_batch(loop_id, phase_id) 383 | self.do_step(epoch_id, loop_id, iter_id, phase_id, data_batch) 384 | for optimizer in optimizers: 385 | optimizer.step() 386 | optimizer.zero_grad() 387 | if config.training: 388 | self.iter_count += 1 389 | 390 | def log_iter(self, epoch_id, loop_id, iter_id): 391 | self.update_meters() 392 | display_items = [""] 393 | for name, info in self.loop_meters_info.items(): 394 | fmt = info['format'] 395 | if fmt is not None: 396 | value = self.loop_meters[name].get_value() 397 | display_items.append(f"{info['abbr']} {value:{fmt}}") 398 | msg = '|'.join(display_items) 399 | loop_name = self.loop_configs[loop_id].name 400 | self.tqdms[loop_id].set_postfix_str(f"[{loop_name}] {msg}") 401 | self.tqdms[loop_id].update(iter_id - self.last_log_iter_id) 402 | self.last_log_iter_id = iter_id 403 | if self.loop_configs[loop_id].training: 404 | self._manager_log_metric(epoch_id, loop_id) 405 | 406 | def do_loop(self, epoch_id, loop_id): 407 | config = self.loop_configs[loop_id] 408 | self.toggle_model_mode(epoch_id, loop_id) 409 | self.setup_loop_meters(loop_id) 410 | for iter_id in range(config.n_iterations): 411 | self.do_iter(epoch_id, loop_id, iter_id) 412 | if (iter_id + 1) % self.log_period == 0 \ 413 | or iter_id == config.n_iterations - 1: 414 | self.log_iter(epoch_id, loop_id, iter_id) 415 | if config.training: 416 | self.update_lr('iter') 417 | 418 | def log_loop(self, epoch_id, loop_id): 419 | config = self.loop_configs[loop_id] 420 | self.last_log_iter_id = -1 421 | bar = self.tqdms[loop_id] 422 | elapsed_time = bar.format_interval(bar.format_dict['elapsed']) 423 | bar.close() 424 | self.update_meters() 425 | display_items = [] 426 | for name, info in self.loop_meters_info.items(): 427 | meter = self.loop_meters[name] 428 | if self.is_distributed: 429 | meter.sync(self.device) 430 | fmt = info['format'] 431 | if fmt is not None: 432 | value = meter.get_value() 433 | display_items.append(f"{info['abbr']} {value:{fmt}}") 434 | msg = '|'.join(display_items) 435 | self.logger.info(f"elapsed: {elapsed_time} [{config.name}] {msg}") 436 | if not config.training: 437 | self._manager_log_metric(epoch_id, loop_id) 438 | 439 | def do_epoch(self, epoch_id): 440 | for loop_id, loop_config in enumerate(self.loop_configs): 441 | if self._should_run_loop(epoch_id, loop_id): 442 | self._setup_tqdms(loop_id) 443 | with torch.set_grad_enabled(loop_config.training): 444 | self.do_loop(epoch_id, loop_id) 445 | self.log_loop(epoch_id, loop_id) 446 | 447 | def log_epoch(self, epoch_id): 448 | lrs = [optimizer.get_learning_rates()[0] 449 | for optimizer in self.optimizers.values()] 450 | lrs = "|".join([f"{lr:.5f}" for lr in lrs]) 451 | self.logger.info(f'Epoch: {epoch_id}/{self.n_epochs} lr: {lrs}') 452 | 453 | def log_loop_configs(self): 454 | n_loops = len(self.loop_configs) 455 | self.logger.info(f"Configs of {n_loops} loops:") 456 | for i, config in enumerate(self.loop_configs): 457 | self.logger.info(f"{i}: {str(config)}") 458 | 459 | def train(self): 460 | self.log_loop_configs() 461 | if self.save_init_ckpt and self.start_epoch == 0: 462 | self.save_checkpoint(-1, ['ckpt-init']) 463 | for epoch_id in range(self.start_epoch, 
self.n_epochs): 464 | self.log_epoch(epoch_id) 465 | self.do_epoch(epoch_id) 466 | self.update_lr('epoch') 467 | checkpoint_names = [self._default_checkpoint_name] 468 | if (epoch_id + 1) % self.ckpt_period == 0: 469 | checkpoint_names.append(f'ckpt-{epoch_id}') 470 | if self.is_best_checkpoint(epoch_id): 471 | checkpoint_names.append('ckpt-best') 472 | self.save_results('results-best') 473 | self.save_checkpoint(epoch_id, checkpoint_names) 474 | self.save_results() 475 | 476 | def test(self): 477 | self.log_loop_configs() 478 | self.do_epoch(0) 479 | self.save_results() 480 | 481 | 482 | class ClassificationTrainer(BaseTrainer): 483 | 484 | def __init__(self, *args, num_classes=None, ignored_classes=None, **kwargs): 485 | super().__init__(*args, **kwargs) 486 | if ignored_classes is None: 487 | self.ignored_classes = [] 488 | for config in self.loop_configs: 489 | if hasattr(config.dataset, 'ignored_classes'): 490 | self.ignored_classes.append(config.dataset.ignored_classes) 491 | else: 492 | self.ignored_classes.append([]) 493 | elif ignored_classes and isinstance(ignored_classes[0], list): 494 | assert len(ignored_classes) == len(self.loop_configs) 495 | self.ignored_classes = ignored_classes 496 | else: 497 | self.ignored_classes = [ignored_classes for _ in self.loop_configs] 498 | self.accuracies = [] 499 | for ic in self.ignored_classes: 500 | accuracy = Accuracy(num_classes=num_classes, 501 | ignored_classes=ic, reduction='sum') 502 | self.accuracies.append(accuracy) 503 | self.loop_accuracy = self.accuracies[0] 504 | 505 | def setup_loop_meters(self, loop_id): 506 | super().setup_loop_meters(loop_id) 507 | self.loop_accuracy = self.accuracies[loop_id] 508 | 509 | def _update_acc_meter(self, meter_name, outputs, labels): 510 | accuracy = self.loop_accuracy 511 | if self.loop_meters_info[meter_name]['type'] == 'per_class_avg': 512 | self.loop_meters[meter_name].update( 513 | accuracy(outputs, labels, reduction='none'), labels) 514 | else: 515 | n_correct, n = accuracy(outputs, labels) 516 | self.loop_meters[meter_name].update(n_correct, n) 517 | --------------------------------------------------------------------------------
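
The abstract hooks defined above (get_data_batch, get_active_optimizers, toggle_model_mode, do_step) are what a concrete trainer has to fill in, following the epoch / loop / iteration / phase / step hierarchy described in the BaseTrainer docstring. As a rough orientation only, the sketch below shows one plausible way a minimal classification trainer could be wired together from the classes in trainers/base.py. It is not part of the repository: `model`, `trainloader`, `train_dataset`, `manager`, and `n_epochs` are placeholders assumed to exist elsewhere, and the meter update signatures are inferred from how the meters are used above.

# Hypothetical sketch (not repository code): a minimal ClassificationTrainer
# subclass that implements the hooks BaseTrainer leaves abstract.
import torch
import torch.nn as nn

from trainers.base import (ClassificationTrainer, LoopConfig,
                           OptimizerWithSchedule)


class MinimalTrainer(ClassificationTrainer):

    def get_data_batch(self, loop_id, phase_id):
        # One (images, labels) batch per computation step.
        images, labels = self._next_data_batch(loop_id)
        return images.to(self.device), labels.to(self.device)

    def get_active_optimizers(self, loop_id, phase_id):
        # A single optimizer drives the (only) phase of the training loop;
        # evaluation loops update no parameters.
        config = self.loop_configs[loop_id]
        return [self.optimizers['sgd']] if config.training else []

    def toggle_model_mode(self, epoch_id, loop_id):
        training = self.loop_configs[loop_id].training
        self.models['classifier'].train(training)

    def do_step(self, epoch_id, loop_id, iter_id, phase_id, data_batch):
        # Forward, backward (when training), then update the loop meters.
        images, labels = data_batch
        outputs = self.models['classifier'](images)
        loss = self.criterions['ce'](outputs, labels)
        if self.loop_configs[loop_id].training:
            loss.backward()
        self.loop_meters['loss'].update(loss.item())  # meter API assumed
        self._update_acc_meter('acc', outputs, labels)


# Assumed wiring, kept as comments because the surrounding objects
# (model, dataloader, manager) are placeholders:
#
# optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)
# sgd = OptimizerWithSchedule(optimizer, scheduler, schedule_step='epoch')
# train_loop = LoopConfig('train', train_dataset, trainloader, training=True,
#                         n_iterations=len(trainloader))
# trainer = MinimalTrainer(
#     manager=manager, models={'classifier': model},
#     criterions={'ce': nn.CrossEntropyLoss()}, n_epochs=n_epochs,
#     loop_configs=[train_loop], optimizers={'sgd': sgd},
#     log_period=50, ckpt_period=0, device=torch.device('cuda'),
#     num_classes=1000)
# trainer.add_meter('loss', fstr_format='.4f')
# trainer.add_meter('acc', fstr_format='.2%')
# trainer.train()

In this sketch the distinct dictionary keys ('classifier' for the model, 'sgd' for the optimizer) are a deliberate choice: get_checkpoint stores model and optimizer state dicts under their names in the same checkpoint dictionary, so reusing a name would silently overwrite one of them.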