├── imgs └── teaser.png ├── configs ├── linear │ ├── cub.yaml │ ├── dogs.yaml │ ├── flowers.yaml │ ├── cars.yaml │ └── nabirds.yaml ├── prompt │ ├── cub.yaml │ ├── nabirds.yaml │ ├── dogs.yaml │ ├── flowers.yaml │ └── cars.yaml ├── finetune │ ├── nabirds.yaml │ ├── cub.yaml │ ├── cars.yaml │ ├── dogs.yaml │ └── flowers.yaml ├── base-linear.yaml ├── base-prompt.yaml └── base-finetune.yaml ├── src ├── utils │ ├── file_io.py │ ├── train_utils.py │ ├── io_utils.py │ ├── distributed.py │ └── logging.py ├── data │ ├── vtab_datasets │ │ ├── __init__.py │ │ ├── patch_camelyon.py │ │ ├── sun397.py │ │ ├── dtd.py │ │ ├── dmlab.py │ │ ├── caltech.py │ │ ├── resisc45.py │ │ ├── oxford_iiit_pet.py │ │ ├── eurosat.py │ │ ├── cifar.py │ │ ├── smallnorb.py │ │ ├── oxford_flowers102.py │ │ ├── svhn.py │ │ ├── clevr.py │ │ ├── dsprites.py │ │ ├── registry.py │ │ ├── kitti.py │ │ └── diabetic_retinopathy.py │ ├── transforms.py │ ├── loader.py │ └── datasets │ │ ├── json_dataset.py │ │ └── tf_dataset.py ├── configs │ ├── config_node.py │ ├── vit_configs.py │ └── config.py ├── models │ ├── mlp.py │ ├── build_model.py │ ├── vit_adapter │ │ ├── adapter_block.py │ │ ├── vit_moco.py │ │ └── vit_mae.py │ ├── vit_backbones │ │ ├── vit_mae.py │ │ └── vit_moco.py │ └── vit_prompt │ │ ├── vit_moco.py │ │ ├── vit.py │ │ └── vit_mae.py ├── solver │ ├── losses.py │ └── lr_scheduler.py └── engine │ ├── eval │ ├── singlelabel.py │ └── multilabel.py │ └── evaluator.py ├── env_setup.sh ├── .gitignore ├── launch.py ├── train.py ├── VTAB_SETUP.md ├── tune_fgvc.py └── README.md /imgs/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KMnP/vpt/HEAD/imgs/teaser.png -------------------------------------------------------------------------------- /configs/linear/cub.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-linear.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "CUB" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 200 7 | MULTILABEL: False 8 | MODEL: 9 | TYPE: "vit" 10 | SOLVER: 11 | BASE_LR: 0.1 12 | WEIGHT_DECAY: 0.01 -------------------------------------------------------------------------------- /configs/prompt/cub.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-prompt.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "CUB" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 200 7 | MULTILABEL: False 8 | MODEL: 9 | TYPE: "vit" 10 | SOLVER: 11 | BASE_LR: 0.1 12 | WEIGHT_DECAY: 0.01 -------------------------------------------------------------------------------- /configs/prompt/nabirds.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-prompt.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "nabirds" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 555 7 | MULTILABEL: False 8 | MODEL: 9 | TYPE: "vit" 10 | SOLVER: 11 | BASE_LR: 0.1 12 | WEIGHT_DECAY: 0.01 -------------------------------------------------------------------------------- /configs/linear/dogs.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-linear.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "StanfordDogs" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 120 7 | MULTILABEL: False 8 | MODEL: 9 | TYPE: "vit" 10 | SOLVER: 11 | BASE_LR: 0.001 12 | WEIGHT_DECAY: 0.0001 
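Each dataset YAML here overrides only a handful of fields (dataset name, number of classes, learning rate, weight decay); everything else is inherited from the base file named in _BASE_. A minimal sketch of that inheritance using fvcore, which src/configs/config_node.py builds on — the repo's actual entry point (train.py) is not shown in this listing, so the loading call below is an assumption used purely for illustration:

    from fvcore.common.config import CfgNode

    # Expands the relative _BASE_ reference and overlays the dataset-specific keys.
    cfg = CfgNode(CfgNode.load_yaml_with_base("configs/linear/dogs.yaml"))
    print(cfg.SOLVER.BASE_LR)    # 0.001, taken from the dataset file
    print(cfg.SOLVER.OPTIMIZER)  # "sgd", inherited from base-linear.yaml
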
-------------------------------------------------------------------------------- /configs/prompt/dogs.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-prompt.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "StanfordDogs" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 120 7 | MULTILABEL: False 8 | MODEL: 9 | TYPE: "vit" 10 | SOLVER: 11 | BASE_LR: 0.001 12 | WEIGHT_DECAY: 0.0001 -------------------------------------------------------------------------------- /configs/linear/flowers.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-linear.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "OxfordFlowers" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 102 7 | MULTILABEL: False 8 | MODEL: 9 | TYPE: "vit" 10 | SOLVER: 11 | BASE_LR: 0.001 12 | WEIGHT_DECAY: 0.0001 -------------------------------------------------------------------------------- /configs/prompt/flowers.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-prompt.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "OxfordFlowers" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 102 7 | MULTILABEL: False 8 | MODEL: 9 | TYPE: "vit" 10 | SOLVER: 11 | BASE_LR: 0.001 12 | WEIGHT_DECAY: 0.0001 -------------------------------------------------------------------------------- /configs/finetune/nabirds.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-finetune.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "nabirds" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 555 7 | MULTILABEL: False 8 | MODEL: 9 | TYPE: "vit" 10 | SOLVER: 11 | BASE_LR: 0.00375 12 | WEIGHT_DECAY: 0.01 13 | -------------------------------------------------------------------------------- /configs/linear/cars.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-linear.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "StanfordCars" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 196 7 | MULTILABEL: False 8 | MODEL: 9 | TYPE: "vit" 10 | SOLVER: 11 | BASE_LR: 0.001 12 | WEIGHT_DECAY: 0.0001 13 | -------------------------------------------------------------------------------- /configs/prompt/cars.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-prompt.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "StanfordCars" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 196 7 | MULTILABEL: False 8 | MODEL: 9 | TYPE: "vit" 10 | SOLVER: 11 | BASE_LR: 0.001 12 | WEIGHT_DECAY: 0.0001 13 | -------------------------------------------------------------------------------- /configs/linear/nabirds.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-linear.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "nabirds" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 555 7 | MULTILABEL: False 8 | MODEL: 9 | TYPE: "vit" 10 | SOLVER: 11 | BASE_LR: 0.1 12 | WEIGHT_DECAY: 0.01 13 | LOG_EVERY_N: 10 -------------------------------------------------------------------------------- /configs/finetune/cub.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-finetune.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "CUB" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 
200 7 | MULTILABEL: False 8 | FEATURE: "imagenet_supervised" # need to tune 9 | MODEL: 10 | TYPE: "vit" 11 | SOLVER: 12 | BASE_LR: 0.00375 13 | WEIGHT_DECAY: 0.01 14 | -------------------------------------------------------------------------------- /configs/finetune/cars.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-finetune.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "StanfordCars" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 196 7 | MULTILABEL: False 8 | FEATURE: "imagenet_supervised" # need to tune 9 | MODEL: 10 | TYPE: "vit" 11 | SOLVER: 12 | BASE_LR: 0.0375 13 | WEIGHT_DECAY: 0.001 14 | -------------------------------------------------------------------------------- /configs/finetune/dogs.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-finetune.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "StanfordDogs" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 120 7 | MULTILABEL: False 8 | FEATURE: "imagenet_supervised" # need to tune 9 | MODEL: 10 | TYPE: "vit" 11 | SOLVER: 12 | BASE_LR: 0.00375 13 | WEIGHT_DECAY: 0.001 14 | -------------------------------------------------------------------------------- /configs/finetune/flowers.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-finetune.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "OxfordFlowers" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 102 7 | MULTILABEL: False 8 | FEATURE: "imagenet_supervised" # need to tune 9 | MODEL: 10 | TYPE: "vit" 11 | SOLVER: 12 | BASE_LR: 0.001 13 | WEIGHT_DECAY: 0.0001 14 | -------------------------------------------------------------------------------- /src/utils/file_io.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """ 4 | Project specific pathmanagers for a project as recommended by Detectron2 5 | """ 6 | from iopath.common.file_io import PathManager as PathManagerBase 7 | from iopath.common.file_io import HTTPURLHandler 8 | 9 | 10 | PathManager = PathManagerBase() 11 | PathManager.register_handler(HTTPURLHandler()) 12 | -------------------------------------------------------------------------------- /configs/base-linear.yaml: -------------------------------------------------------------------------------- 1 | NUM_GPUS: 1 2 | NUM_SHARDS: 1 3 | OUTPUT_DIR: "" 4 | RUN_N_TIMES: 1 5 | MODEL: 6 | TRANSFER_TYPE: "linear" 7 | TYPE: "vit" 8 | LINEAR: 9 | MLP_SIZES: [] 10 | SOLVER: 11 | SCHEDULER: "cosine" 12 | PATIENCE: 300 13 | LOSS: "softmax" 14 | OPTIMIZER: "sgd" 15 | MOMENTUM: 0.9 16 | WEIGHT_DECAY: 0.0001 17 | LOG_EVERY_N: 1 18 | WARMUP_EPOCH: 10 19 | TOTAL_EPOCH: 100 20 | DATA: 21 | NAME: "" 22 | NUMBER_CLASSES: -1 23 | DATAPATH: "" 24 | FEATURE: "sup_vitb16_224" 25 | BATCH_SIZE: 1024 -------------------------------------------------------------------------------- /configs/base-prompt.yaml: -------------------------------------------------------------------------------- 1 | NUM_GPUS: 1 2 | NUM_SHARDS: 1 3 | OUTPUT_DIR: "" 4 | RUN_N_TIMES: 1 5 | MODEL: 6 | TRANSFER_TYPE: "prompt" 7 | TYPE: "vit" 8 | LINEAR: 9 | MLP_SIZES: [] 10 | SOLVER: 11 | SCHEDULER: "cosine" 12 | PATIENCE: 300 13 | LOSS: "softmax" 14 | OPTIMIZER: "sgd" 15 | MOMENTUM: 0.9 16 | WEIGHT_DECAY: 0.0001 17 | LOG_EVERY_N: 100 18 | WARMUP_EPOCH: 10 19 | TOTAL_EPOCH: 100 20 | DATA: 21 | NAME: "" 22 | NUMBER_CLASSES: -1 23 | DATAPATH: "" 24 | 
FEATURE: "sup_vitb16_224" 25 | BATCH_SIZE: 128 26 | -------------------------------------------------------------------------------- /configs/base-finetune.yaml: -------------------------------------------------------------------------------- 1 | NUM_GPUS: 1 2 | NUM_SHARDS: 1 3 | OUTPUT_DIR: "" 4 | RUN_N_TIMES: 1 5 | MODEL: 6 | TRANSFER_TYPE: "end2end" 7 | TYPE: "vit" 8 | LINEAR: 9 | MLP_SIZES: [] 10 | SOLVER: 11 | SCHEDULER: "cosine" 12 | WARMUP_EPOCH: 5 13 | LOSS: "softmax" 14 | OPTIMIZER: "adamw" 15 | MOMENTUM: 0.9 16 | WEIGHT_DECAY: 0.0001 17 | LOG_EVERY_N: 100 18 | TOTAL_EPOCH: 100 19 | PATIENCE: 300 20 | DATA: 21 | NAME: "" 22 | NUMBER_CLASSES: -1 23 | DATAPATH: "" 24 | FEATURE: "sup_vitb16_224" 25 | BATCH_SIZE: 384 26 | NUM_WORKERS: 8 -------------------------------------------------------------------------------- /src/data/vtab_datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | -------------------------------------------------------------------------------- /src/configs/config_node.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """Config system (based on Detectron's).""" 4 | 5 | from fvcore.common.config import CfgNode as _CfgNode 6 | from ..utils.file_io import PathManager 7 | 8 | 9 | class CfgNode(_CfgNode): 10 | """ 11 | The same as `fvcore.common.config.CfgNode`, but different in: 12 | 13 | support manifold path 14 | """ 15 | 16 | @classmethod 17 | def _open_cfg(cls, filename): 18 | return PathManager.open(filename, "r") 19 | 20 | def dump(self, *args, **kwargs): 21 | """ 22 | Returns: 23 | str: a yaml string representation of the config 24 | """ 25 | # to make it show up in docs 26 | return super().dump(*args, **kwargs) 27 | -------------------------------------------------------------------------------- /env_setup.sh: -------------------------------------------------------------------------------- 1 | conda create -n prompt python=3.7 2 | conda activate prompt 3 | 4 | pip install -q tensorflow 5 | # specifying tfds versions is important to reproduce our results 6 | pip install tfds-nightly==4.4.0.dev202201080107 7 | pip install opencv-python 8 | pip install tensorflow-addons 9 | pip install mock 10 | 11 | 12 | conda install pytorch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 cudatoolkit=11.0 -c pytorch 13 | 14 | python -m pip install detectron2 -f \ 15 | https://dl.fbaipublicfiles.com/detectron2/wheels/cu110/torch1.7/index.html 16 | pip install opencv-python 17 | 18 | conda install tqdm pandas matplotlib seaborn scikit-learn scipy simplejson termcolor 19 | conda install -c iopath iopath 20 | 21 | 22 | # for transformers 23 | pip install timm==0.4.12 24 | pip install ml-collections 25 | 26 | # Optional: for slurm jobs 27 | pip install submitit -U 28 
| pip install slurm_gpustat -------------------------------------------------------------------------------- /src/utils/train_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import torch 3 | 4 | 5 | def gpu_mem_usage(): 6 | """Computes the GPU memory usage for the current device (GB).""" 7 | if not torch.cuda.is_available(): 8 | return 0 9 | # Number of bytes in a megabyte 10 | _B_IN_GB = 1024 * 1024 * 1024 11 | 12 | mem_usage_bytes = torch.cuda.max_memory_allocated() 13 | return mem_usage_bytes / _B_IN_GB 14 | 15 | 16 | class AverageMeter(object): 17 | """Computes and stores the average and current value""" 18 | def __init__(self, name, fmt=':f'): 19 | self.name = name 20 | self.fmt = fmt 21 | self.reset() 22 | 23 | def reset(self): 24 | self.val = 0 25 | self.avg = 0 26 | self.sum = 0 27 | self.count = 0 28 | 29 | def update(self, val, n=1): 30 | self.val = val 31 | self.sum += val * n 32 | self.count += n 33 | self.avg = self.sum / self.count 34 | 35 | def __str__(self): 36 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 37 | return fmtstr.format(**self.__dict__) 38 | -------------------------------------------------------------------------------- /src/data/transforms.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """Image transformations.""" 4 | import torchvision as tv 5 | 6 | 7 | def get_transforms(split, size): 8 | normalize = tv.transforms.Normalize( 9 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 10 | ) 11 | if size == 448: 12 | resize_dim = 512 13 | crop_dim = 448 14 | elif size == 224: 15 | resize_dim = 256 16 | crop_dim = 224 17 | elif size == 384: 18 | resize_dim = 438 19 | crop_dim = 384 20 | if split == "train": 21 | transform = tv.transforms.Compose( 22 | [ 23 | tv.transforms.Resize(resize_dim), 24 | tv.transforms.RandomCrop(crop_dim), 25 | tv.transforms.RandomHorizontalFlip(0.5), 26 | # tv.transforms.RandomResizedCrop(224, scale=(0.2, 1.0)), 27 | # tv.transforms.RandomHorizontalFlip(), 28 | tv.transforms.ToTensor(), 29 | normalize, 30 | ] 31 | ) 32 | else: 33 | transform = tv.transforms.Compose( 34 | [ 35 | tv.transforms.Resize(resize_dim), 36 | tv.transforms.CenterCrop(crop_dim), 37 | tv.transforms.ToTensor(), 38 | normalize, 39 | ] 40 | ) 41 | return transform 42 | -------------------------------------------------------------------------------- /src/utils/io_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | a bunch of helper functions for read and write data 4 | """ 5 | import os 6 | import json 7 | import numpy as np 8 | import time 9 | import pandas as pd 10 | 11 | from typing import List, Union 12 | from PIL import Image, ImageFile 13 | Image.MAX_IMAGE_PIXELS = None 14 | 15 | 16 | def save_or_append_df(out_path, df): 17 | if os.path.exists(out_path): 18 | previous_df = pd.read_pickle(out_path) 19 | df = pd.concat([previous_df, df], ignore_index=True) 20 | df.to_pickle(out_path) 21 | print(f"Saved output at {out_path}") 22 | 23 | 24 | class JSONEncoder(json.JSONEncoder): 25 | def default(self, obj): 26 | if isinstance(obj, np.ndarray): 27 | return obj.tolist() 28 | elif isinstance(obj, bytes): 29 | return str(obj, encoding='utf-8') 30 | elif isinstance(obj, np.integer): 31 | return int(obj) 32 | elif isinstance(obj, np.floating): 33 | return float(obj) 34 | elif isinstance(obj, np.ndarray): 35 | return obj.tolist() 36 
| else: 37 | # return super(MyEncoder, self).default(obj) 38 | 39 | raise TypeError( 40 | "Unserializable object {} of type {}".format(obj, type(obj)) 41 | ) 42 | 43 | 44 | def write_json(data: Union[list, dict], outfile: str) -> None: 45 | json_dir, _ = os.path.split(outfile) 46 | if json_dir and not os.path.exists(json_dir): 47 | os.makedirs(json_dir) 48 | 49 | with open(outfile, 'w') as f: 50 | json.dump(data, f, cls=JSONEncoder, ensure_ascii=False, indent=2) 51 | 52 | 53 | def read_json(filename: str) -> Union[list, dict]: 54 | """read json files""" 55 | with open(filename, "rb") as fin: 56 | data = json.load(fin, encoding="utf-8") 57 | return data 58 | 59 | 60 | def pil_loader(path: str) -> Image.Image: 61 | """load an image from path, and suppress warning""" 62 | # to avoid crashing for truncated (corrupted images) 63 | ImageFile.LOAD_TRUNCATED_IMAGES = True 64 | # open path as file to avoid ResourceWarning 65 | # (https://github.com/python-pillow/Pillow/issues/835) 66 | with open(path, 'rb') as f: 67 | img = Image.open(f) 68 | return img.convert('RGB') 69 | -------------------------------------------------------------------------------- /src/models/mlp.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Modified from: fbcode/multimo/models/encoders/mlp.py 4 | """ 5 | import math 6 | import torch 7 | 8 | from torch import nn 9 | from typing import List, Type 10 | 11 | from ..utils import logging 12 | logger = logging.get_logger("visual_prompt") 13 | 14 | 15 | class MLP(nn.Module): 16 | def __init__( 17 | self, 18 | input_dim: int, 19 | mlp_dims: List[int], 20 | dropout: float = 0.1, 21 | nonlinearity: Type[nn.Module] = nn.ReLU, 22 | normalization: Type[nn.Module] = nn.BatchNorm1d, # nn.LayerNorm, 23 | special_bias: bool = False, 24 | add_bn_first: bool = False, 25 | ): 26 | super(MLP, self).__init__() 27 | projection_prev_dim = input_dim 28 | projection_modulelist = [] 29 | last_dim = mlp_dims[-1] 30 | mlp_dims = mlp_dims[:-1] 31 | 32 | if add_bn_first: 33 | if normalization is not None: 34 | projection_modulelist.append(normalization(projection_prev_dim)) 35 | if dropout != 0: 36 | projection_modulelist.append(nn.Dropout(dropout)) 37 | 38 | for idx, mlp_dim in enumerate(mlp_dims): 39 | fc_layer = nn.Linear(projection_prev_dim, mlp_dim) 40 | nn.init.kaiming_normal_(fc_layer.weight, a=0, mode='fan_out') 41 | projection_modulelist.append(fc_layer) 42 | projection_modulelist.append(nonlinearity()) 43 | 44 | if normalization is not None: 45 | projection_modulelist.append(normalization(mlp_dim)) 46 | 47 | if dropout != 0: 48 | projection_modulelist.append(nn.Dropout(dropout)) 49 | projection_prev_dim = mlp_dim 50 | 51 | self.projection = nn.Sequential(*projection_modulelist) 52 | self.last_layer = nn.Linear(projection_prev_dim, last_dim) 53 | nn.init.kaiming_normal_(self.last_layer.weight, a=0, mode='fan_out') 54 | if special_bias: 55 | prior_prob = 0.01 56 | bias_value = -math.log((1 - prior_prob) / prior_prob) 57 | torch.nn.init.constant_(self.last_layer.bias, bias_value) 58 | 59 | def forward(self, x: torch.Tensor) -> torch.Tensor: 60 | """ 61 | input_arguments: 62 | @x: torch.FloatTensor 63 | """ 64 | x = self.projection(x) 65 | x = self.last_layer(x) 66 | return x 67 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | GIT_TOKENS.md 2 | run_examples.sh 3 | *.npy 4 | *.zip 5 | 6 | # 
General 7 | .DS_Store? 8 | .DS_Store 9 | .AppleDouble 10 | .LSOverride 11 | 12 | # Byte-compiled / optimized / DLL files 13 | __pycache__/ 14 | *.py[cod] 15 | *$py.class 16 | 17 | # C extensions 18 | *.so 19 | 20 | # Distribution / packaging 21 | .Python 22 | build/ 23 | develop-eggs/ 24 | dist/ 25 | downloads/ 26 | eggs/ 27 | .eggs/ 28 | lib/ 29 | lib64/ 30 | parts/ 31 | sdist/ 32 | var/ 33 | wheels/ 34 | pip-wheel-metadata/ 35 | share/python-wheels/ 36 | *.egg-info/ 37 | .installed.cfg 38 | *.egg 39 | MANIFEST 40 | 41 | # PyInstaller 42 | # Usually these files are written by a python script from a template 43 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 44 | *.manifest 45 | *.spec 46 | 47 | # Installer logs 48 | pip-log.txt 49 | pip-delete-this-directory.txt 50 | 51 | # Unit test / coverage reports 52 | htmlcov/ 53 | .tox/ 54 | .nox/ 55 | .coverage 56 | .coverage.* 57 | .cache 58 | nosetests.xml 59 | coverage.xml 60 | *.cover 61 | *.py,cover 62 | .hypothesis/ 63 | .pytest_cache/ 64 | 65 | # Translations 66 | *.mo 67 | *.pot 68 | 69 | # Django stuff: 70 | *.log 71 | local_settings.py 72 | db.sqlite3 73 | db.sqlite3-journal 74 | 75 | # Flask stuff: 76 | instance/ 77 | .webassets-cache 78 | 79 | # Scrapy stuff: 80 | .scrapy 81 | 82 | # Sphinx documentation 83 | docs/_build/ 84 | 85 | # PyBuilder 86 | target/ 87 | 88 | # Jupyter Notebook 89 | .ipynb_checkpoints 90 | 91 | # IPython 92 | profile_default/ 93 | ipython_config.py 94 | 95 | # pyenv 96 | .python-version 97 | 98 | # pipenv 99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 102 | # install all needed dependencies. 103 | #Pipfile.lock 104 | 105 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 106 | __pypackages__/ 107 | 108 | # Celery stuff 109 | celerybeat-schedule 110 | celerybeat.pid 111 | 112 | # SageMath parsed files 113 | *.sage.py 114 | 115 | # Environments 116 | .env 117 | .venv 118 | env/ 119 | venv/ 120 | ENV/ 121 | env.bak/ 122 | venv.bak/ 123 | 124 | # Spyder project settings 125 | .spyderproject 126 | .spyproject 127 | 128 | # Rope project settings 129 | .ropeproject 130 | 131 | # mkdocs documentation 132 | /site 133 | 134 | # mypy 135 | .mypy_cache/ 136 | .dmypy.json 137 | dmypy.json 138 | 139 | # Pyre type checker 140 | .pyre/ 141 | -------------------------------------------------------------------------------- /src/solver/losses.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from typing import Optional 6 | 7 | from ..utils import logging 8 | logger = logging.get_logger("visual_prompt") 9 | 10 | 11 | class SigmoidLoss(nn.Module): 12 | def __init__(self, cfg=None): 13 | super(SigmoidLoss, self).__init__() 14 | 15 | def is_single(self): 16 | return True 17 | 18 | def is_local(self): 19 | return False 20 | 21 | def multi_hot(self, labels: torch.Tensor, nb_classes: int) -> torch.Tensor: 22 | labels = labels.unsqueeze(1) # (batch_size, 1) 23 | target = torch.zeros( 24 | labels.size(0), nb_classes, device=labels.device 25 | ).scatter_(1, labels, 1.) 
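# Worked example (values are illustrative): labels = tensor([2, 0]) with
# nb_classes = 3 becomes [[2], [0]] after unsqueeze, and scatter_ writes a 1
# at each label's column, producing [[0., 0., 1.], [1., 0., 0.]].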
26 | # (batch_size, num_classes) 27 | return target 28 | 29 | def loss( 30 | self, logits, targets, per_cls_weights, 31 | multihot_targets: Optional[bool] = False 32 | ): 33 | # targets: 1d-tensor of integer 34 | # Only support single label at this moment 35 | # if len(targets.shape) != 2: 36 | num_classes = logits.shape[1] 37 | targets = self.multi_hot(targets, num_classes) 38 | 39 | loss = F.binary_cross_entropy_with_logits( 40 | logits, targets, reduction="none") 41 | # logger.info(f"loss shape: {loss.shape}") 42 | weight = torch.tensor( 43 | per_cls_weights, device=logits.device 44 | ).unsqueeze(0) 45 | # logger.info(f"weight shape: {weight.shape}") 46 | loss = torch.mul(loss.to(torch.float32), weight.to(torch.float32)) 47 | return torch.sum(loss) / targets.shape[0] 48 | 49 | def forward( 50 | self, pred_logits, targets, per_cls_weights, multihot_targets=False 51 | ): 52 | loss = self.loss( 53 | pred_logits, targets, per_cls_weights, multihot_targets) 54 | return loss 55 | 56 | 57 | class SoftmaxLoss(SigmoidLoss): 58 | def __init__(self, cfg=None): 59 | super(SoftmaxLoss, self).__init__() 60 | 61 | def loss(self, logits, targets, per_cls_weights, kwargs): 62 | weight = torch.tensor( 63 | per_cls_weights, device=logits.device 64 | ) 65 | loss = F.cross_entropy(logits, targets, weight, reduction="none") 66 | 67 | return torch.sum(loss) / targets.shape[0] 68 | 69 | 70 | LOSS = { 71 | "softmax": SoftmaxLoss, 72 | } 73 | 74 | 75 | def build_loss(cfg): 76 | loss_name = cfg.SOLVER.LOSS 77 | assert loss_name in LOSS, \ 78 | f'loss name {loss_name} is not supported' 79 | loss_fn = LOSS[loss_name] 80 | if not loss_fn: 81 | return None 82 | else: 83 | return loss_fn(cfg) 84 | -------------------------------------------------------------------------------- /src/models/build_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Model construction functions. 
4 | """ 5 | from tabnanny import verbose 6 | import torch 7 | 8 | from .resnet import ResNet 9 | from .convnext import ConvNeXt 10 | from .vit_models import ViT, Swin, SSLViT 11 | from ..utils import logging 12 | logger = logging.get_logger("visual_prompt") 13 | # Supported model types 14 | _MODEL_TYPES = { 15 | "resnet": ResNet, 16 | "convnext": ConvNeXt, 17 | "vit": ViT, 18 | "swin": Swin, 19 | "ssl-vit": SSLViT, 20 | } 21 | 22 | 23 | def build_model(cfg): 24 | """ 25 | build model here 26 | """ 27 | assert ( 28 | cfg.MODEL.TYPE in _MODEL_TYPES.keys() 29 | ), "Model type '{}' not supported".format(cfg.MODEL.TYPE) 30 | assert ( 31 | cfg.NUM_GPUS <= torch.cuda.device_count() 32 | ), "Cannot use more GPU devices than available" 33 | 34 | # Construct the model 35 | train_type = cfg.MODEL.TYPE 36 | model = _MODEL_TYPES[train_type](cfg) 37 | 38 | log_model_info(model, verbose=cfg.DBG) 39 | model, device = load_model_to_device(model, cfg) 40 | logger.info(f"Device used for model: {device}") 41 | 42 | return model, device 43 | 44 | 45 | def log_model_info(model, verbose=False): 46 | """Logs model info""" 47 | if verbose: 48 | logger.info(f"Classification Model:\n{model}") 49 | model_total_params = sum(p.numel() for p in model.parameters()) 50 | model_grad_params = sum( 51 | p.numel() for p in model.parameters() if p.requires_grad) 52 | logger.info("Total Parameters: {0}\t Gradient Parameters: {1}".format( 53 | model_total_params, model_grad_params)) 54 | logger.info("tuned percent:%.3f"%(model_grad_params/model_total_params*100)) 55 | 56 | 57 | def get_current_device(): 58 | if torch.cuda.is_available(): 59 | # Determine the GPU used by the current process 60 | cur_device = torch.cuda.current_device() 61 | else: 62 | cur_device = torch.device('cpu') 63 | return cur_device 64 | 65 | 66 | def load_model_to_device(model, cfg): 67 | cur_device = get_current_device() 68 | if torch.cuda.is_available(): 69 | # Transfer the model to the current GPU device 70 | model = model.cuda(device=cur_device) 71 | # Use multi-process data parallel model in the multi-gpu setting 72 | if cfg.NUM_GPUS > 1: 73 | # Make model replica operate on the current device 74 | model = torch.nn.parallel.DistributedDataParallel( 75 | module=model, device_ids=[cur_device], output_device=cur_device, 76 | find_unused_parameters=True, 77 | ) 78 | else: 79 | model = model.to(cur_device) 80 | return model, cur_device 81 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/patch_camelyon.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | """Implements PatchCamelyon data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow_datasets as tfds 24 | 25 | from . import base as base 26 | from .registry import Registry 27 | 28 | 29 | @Registry.register("data.patch_camelyon", "class") 30 | class PatchCamelyonData(base.ImageTfdsData): 31 | """Provides PatchCamelyon data.""" 32 | 33 | def __init__(self, data_dir=None): 34 | 35 | dataset_builder = tfds.builder("patch_camelyon:2.*.*", data_dir=data_dir) 36 | dataset_builder.download_and_prepare() 37 | 38 | # Defines dataset specific train/val/trainval/test splits. 39 | tfds_splits = { 40 | "test": "test", 41 | "train": "train", 42 | "val": "validation", 43 | "trainval": "train+validation", 44 | "train800": "train[:800]", 45 | "val200": "validation[:200]", 46 | "train800val200": "train[:800]+validation[:200]", 47 | } 48 | # Creates a dict with example counts. 49 | num_samples_splits = { 50 | "test": dataset_builder.info.splits["test"].num_examples, 51 | "train": dataset_builder.info.splits["train"].num_examples, 52 | "val": dataset_builder.info.splits["validation"].num_examples, 53 | "train800": 800, 54 | "val200": 200, 55 | "train800val200": 1000, 56 | } 57 | num_samples_splits["trainval"] = ( 58 | num_samples_splits["train"] + num_samples_splits["val"]) 59 | super(PatchCamelyonData, self).__init__( 60 | dataset_builder=dataset_builder, 61 | tfds_splits=tfds_splits, 62 | num_samples_splits=num_samples_splits, 63 | num_preprocessing_threads=400, 64 | shuffle_buffer_size=10000, 65 | # Note: Export only image and label tensors with their original types. 66 | base_preprocess_fn=base.make_get_tensors_fn(["image", "label"]), 67 | num_classes=dataset_builder.info.features["label"].num_classes) 68 | -------------------------------------------------------------------------------- /src/models/vit_adapter/adapter_block.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | ''' 3 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 4 | ''' 5 | import math 6 | import logging 7 | from functools import partial 8 | from collections import OrderedDict 9 | from copy import deepcopy 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | from timm.models.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_ 16 | from timm.models.vision_transformer import Attention 17 | from timm.models.vision_transformer import Block 18 | 19 | from ...utils import logging 20 | logger = logging.get_logger("visual_prompt") 21 | 22 | 23 | class Pfeiffer_Block(Block): 24 | 25 | def __init__(self, adapter_config, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., 26 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 27 | 28 | super(Pfeiffer_Block, self).__init__( 29 | dim=dim, 30 | num_heads=num_heads, 31 | mlp_ratio=mlp_ratio, 32 | qkv_bias=qkv_bias, 33 | drop=drop, 34 | attn_drop=attn_drop, 35 | drop_path=drop_path, 36 | act_layer=act_layer, 37 | norm_layer=norm_layer) 38 | 39 | self.adapter_config = adapter_config 40 | 41 | if adapter_config.STYLE == "Pfeiffer": 42 | self.adapter_downsample = nn.Linear( 43 | dim, 44 | dim // adapter_config.REDUCATION_FACTOR 45 | ) 46 | self.adapter_upsample = nn.Linear( 47 | dim // adapter_config.REDUCATION_FACTOR, 48 | dim 49 | ) 50 | self.adapter_act_fn = act_layer() 51 | 52 | 
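# Zero-initializing both adapter projections makes the adapter branch output
# zero at the start of training, so each block initially reproduces the
# pretrained ViT block's output and the adapter only gradually learns a
# residual update.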
nn.init.zeros_(self.adapter_downsample.weight) 53 | nn.init.zeros_(self.adapter_downsample.bias) 54 | 55 | nn.init.zeros_(self.adapter_upsample.weight) 56 | nn.init.zeros_(self.adapter_upsample.bias) 57 | else: 58 | raise ValueError("Other adapter styles are not supported.") 59 | 60 | def forward(self, x): 61 | 62 | if self.adapter_config.STYLE == "Pfeiffer": 63 | # same as reguluar ViT block 64 | h = x 65 | x = self.norm1(x) 66 | x = self.attn(x) 67 | x = self.drop_path(x) 68 | x = x + h 69 | 70 | h = x 71 | x = self.norm2(x) 72 | x = self.mlp(x) 73 | 74 | # start to insert adapter layers... 75 | adpt = self.adapter_downsample(x) 76 | adpt = self.adapter_act_fn(adpt) 77 | adpt = self.adapter_upsample(adpt) 78 | x = adpt + x 79 | # ...end 80 | 81 | x = self.drop_path(x) 82 | x = x + h 83 | 84 | return x 85 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/sun397.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements Sun397 data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow_datasets as tfds 24 | 25 | from . import base as base 26 | from .registry import Registry 27 | CUSTOM_TRAIN_SPLIT_PERCENT = 50 28 | CUSTOM_VALIDATION_SPLIT_PERCENT = 20 29 | CUSTOM_TEST_SPLIT_PERCENT = 30 30 | 31 | 32 | @Registry.register("data.sun397", "class") 33 | class Sun397Data(base.ImageTfdsData): 34 | """Provides Sun397Data data.""" 35 | 36 | def __init__(self, config="tfds", data_dir=None): 37 | 38 | if config == "tfds": 39 | dataset_builder = tfds.builder("sun397/tfds:4.*.*", data_dir=data_dir) 40 | dataset_builder.download_and_prepare() 41 | 42 | tfds_splits = { 43 | "train": "train", 44 | "val": "validation", 45 | "test": "test", 46 | "trainval": "train+validation", 47 | "train800": "train[:800]", 48 | "val200": "validation[:200]", 49 | "train800val200": "train[:800]+validation[:200]", 50 | } 51 | # Creates a dict with example counts. 52 | num_samples_splits = { 53 | "test": dataset_builder.info.splits["test"].num_examples, 54 | "train": dataset_builder.info.splits["train"].num_examples, 55 | "val": dataset_builder.info.splits["validation"].num_examples, 56 | "train800": 800, 57 | "val200": 200, 58 | "train800val200": 1000, 59 | } 60 | num_samples_splits["trainval"] = ( 61 | num_samples_splits["train"] + num_samples_splits["val"]) 62 | else: 63 | 64 | raise ValueError("No supported config %r for Sun397Data." 
% config) 65 | 66 | super(Sun397Data, self).__init__( 67 | dataset_builder=dataset_builder, 68 | tfds_splits=tfds_splits, 69 | num_samples_splits=num_samples_splits, 70 | num_preprocessing_threads=400, 71 | shuffle_buffer_size=10000, 72 | # Note: Export only image and label tensors with their original types. 73 | base_preprocess_fn=base.make_get_tensors_fn(["image", "label"]), 74 | num_classes=dataset_builder.info.features["label"].num_classes) 75 | -------------------------------------------------------------------------------- /src/models/vit_backbones/vit_mae.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | """ 4 | borrowed from https://github.com/facebookresearch/mae/blob/main/models_vit.py 5 | """ 6 | from functools import partial 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | import timm.models.vision_transformer 12 | 13 | 14 | class VisionTransformer(timm.models.vision_transformer.VisionTransformer): 15 | """ Vision Transformer with support for global average pooling 16 | """ 17 | def __init__(self, global_pool=False, **kwargs): 18 | super(VisionTransformer, self).__init__(**kwargs) 19 | 20 | self.global_pool = global_pool 21 | if self.global_pool: 22 | norm_layer = kwargs['norm_layer'] 23 | embed_dim = kwargs['embed_dim'] 24 | self.fc_norm = norm_layer(embed_dim) 25 | 26 | del self.norm # remove the original norm 27 | 28 | def forward_features(self, x): 29 | B = x.shape[0] 30 | x = self.patch_embed(x) 31 | 32 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 33 | x = torch.cat((cls_tokens, x), dim=1) 34 | x = x + self.pos_embed 35 | x = self.pos_drop(x) 36 | 37 | for blk in self.blocks: 38 | x = blk(x) 39 | 40 | if self.global_pool: 41 | x = x[:, 1:, :].mean(dim=1) # global pool without cls token 42 | outcome = self.fc_norm(x) 43 | else: 44 | x = self.norm(x) 45 | outcome = x[:, 0] 46 | 47 | return outcome 48 | 49 | 50 | def build_model(model_type): 51 | if "vitb" in model_type: 52 | return vit_base_patch16() 53 | elif "vitl" in model_type: 54 | return vit_large_patch16() 55 | elif "vith" in model_type: 56 | return vit_huge_patch14() 57 | 58 | 59 | def vit_base_patch16(**kwargs): 60 | model = VisionTransformer( 61 | drop_path_rate=0.1, global_pool=True, # using default settings for mae-finetune 62 | patch_size=16, embed_dim=768, depth=12, num_heads=12, 63 | mlp_ratio=4, qkv_bias=True, 64 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 65 | return model 66 | 67 | 68 | def vit_large_patch16(**kwargs): 69 | model = VisionTransformer( 70 | drop_path_rate=0.1, global_pool=True, # using default settings for mae-finetune 71 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, 72 | mlp_ratio=4, qkv_bias=True, 73 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 74 | return model 75 | 76 | 77 | def vit_huge_patch14(**kwargs): 78 | model = VisionTransformer( 79 | drop_path_rate=0.1, global_pool=True, # using default settings for mae-finetune 80 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, 81 | mlp_ratio=4, qkv_bias=True, 82 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 83 | return model 84 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/dtd.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. 
All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements the Describable Textures Dataset (DTD) data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow_datasets as tfds 24 | 25 | from . import base as base 26 | from .registry import Registry 27 | 28 | 29 | @Registry.register("data.dtd", "class") 30 | class DTDData(base.ImageTfdsData): 31 | """Provides Describable Textures Dataset (DTD) data. 32 | 33 | As of version 1.0.0, the train/val/test splits correspond to those of the 34 | 1st fold of the official cross-validation partition. 35 | 36 | For additional details and usage, see the base class. 37 | """ 38 | 39 | def __init__(self, data_dir=None): 40 | 41 | dataset_builder = tfds.builder("dtd:3.*.*", data_dir=data_dir) 42 | dataset_builder.download_and_prepare() 43 | 44 | # Defines dataset specific train/val/trainval/test splits. 45 | tfds_splits = { 46 | "train": "train", 47 | "val": "validation", 48 | "trainval": "train+validation", 49 | "test": "test", 50 | "train800": "train[:800]", 51 | "val200": "validation[:200]", 52 | "train800val200": "train[:800]+validation[:200]", 53 | } 54 | 55 | # Creates a dict with example counts for each split. 56 | train_count = dataset_builder.info.splits["train"].num_examples 57 | val_count = dataset_builder.info.splits["validation"].num_examples 58 | test_count = dataset_builder.info.splits["test"].num_examples 59 | num_samples_splits = { 60 | "train": train_count, 61 | "val": val_count, 62 | "trainval": train_count + val_count, 63 | "test": test_count, 64 | "train800": 800, 65 | "val200": 200, 66 | "train800val200": 1000, 67 | } 68 | 69 | super(DTDData, self).__init__( 70 | dataset_builder=dataset_builder, 71 | tfds_splits=tfds_splits, 72 | num_samples_splits=num_samples_splits, 73 | num_preprocessing_threads=400, 74 | shuffle_buffer_size=10000, 75 | # Note: Export only image and label tensors with their original types. 76 | base_preprocess_fn=base.make_get_tensors_fn(["image", "label"]), 77 | num_classes=dataset_builder.info.features["label"].num_classes) 78 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/dmlab.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements Dmlab data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow_datasets as tfds 24 | 25 | from . import base as base 26 | from .registry import Registry 27 | 28 | 29 | @Registry.register("data.dmlab", "class") 30 | class DmlabData(base.ImageTfdsData): 31 | """Dmlab dataset. 32 | 33 | The Dmlab dataset contains frames observed by the agent acting in the 34 | DMLab environment, which are annotated by the distance between 35 | the agent and various objects present in the environment. The goal is to 36 | is to evaluate the ability of a visual model to reason about distances 37 | from the visual input in 3D environments. The Dmlab dataset consists of 38 | 360x480 color images in 6 classes. The classes are 39 | {close, far, very far} x {positive reward, negative reward} 40 | respectively. 41 | """ 42 | 43 | def __init__(self, data_dir=None): 44 | 45 | dataset_builder = tfds.builder("dmlab:2.0.1", data_dir=data_dir) 46 | 47 | tfds_splits = { 48 | "train": "train", 49 | "val": "validation", 50 | "trainval": "train+validation", 51 | "test": "test", 52 | "train800": "train[:800]", 53 | "val200": "validation[:200]", 54 | "train800val200": "train[:800]+validation[:200]", 55 | } 56 | 57 | # Example counts are retrieved from the tensorflow dataset info. 58 | train_count = dataset_builder.info.splits["train"].num_examples 59 | val_count = dataset_builder.info.splits["validation"].num_examples 60 | test_count = dataset_builder.info.splits["test"].num_examples 61 | 62 | # Creates a dict with example counts for each split. 63 | num_samples_splits = { 64 | "train": train_count, 65 | "val": val_count, 66 | "trainval": train_count + val_count, 67 | "test": test_count, 68 | "train800": 800, 69 | "val200": 200, 70 | "train800val200": 1000, 71 | } 72 | 73 | super(DmlabData, self).__init__( 74 | dataset_builder=dataset_builder, 75 | tfds_splits=tfds_splits, 76 | num_samples_splits=num_samples_splits, 77 | num_preprocessing_threads=400, 78 | shuffle_buffer_size=10000, 79 | base_preprocess_fn=base.make_get_and_cast_tensors_fn({ 80 | "image": ("image", None), 81 | "label": ("label", None), 82 | }), 83 | num_classes=dataset_builder.info.features["label"].num_classes, 84 | image_key="image") 85 | -------------------------------------------------------------------------------- /src/engine/eval/singlelabel.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """Functions for computing metrics. 
all metrics has range of 0-1""" 4 | 5 | import numpy as np 6 | import torch 7 | from sklearn.metrics import ( 8 | accuracy_score, average_precision_score, f1_score, roc_auc_score 9 | ) 10 | 11 | 12 | def accuracy(y_probs, y_true): 13 | # y_prob: (num_images, num_classes) 14 | y_preds = np.argmax(y_probs, axis=1) 15 | accuracy = accuracy_score(y_true, y_preds) 16 | error = 1.0 - accuracy 17 | return accuracy, error 18 | 19 | 20 | def top_n_accuracy(y_probs, truths, n=1): 21 | # y_prob: (num_images, num_classes) 22 | # truth: (num_images, num_classes) multi/one-hot encoding 23 | best_n = np.argsort(y_probs, axis=1)[:, -n:] 24 | if isinstance(truths, np.ndarray) and truths.shape == y_probs.shape: 25 | ts = np.argmax(truths, axis=1) 26 | else: 27 | # a list of GT class idx 28 | ts = truths 29 | 30 | num_input = y_probs.shape[0] 31 | successes = 0 32 | for i in range(num_input): 33 | if ts[i] in best_n[i, :]: 34 | successes += 1 35 | return float(successes) / num_input 36 | 37 | 38 | def compute_acc_auc(y_probs, y_true_ids): 39 | onehot_tgts = np.zeros_like(y_probs) 40 | for idx, t in enumerate(y_true_ids): 41 | onehot_tgts[idx, t] = 1. 42 | 43 | num_classes = y_probs.shape[1] 44 | if num_classes == 2: 45 | top1, _ = accuracy(y_probs, y_true_ids) 46 | # so precision can set all to 2 47 | try: 48 | auc = roc_auc_score(onehot_tgts, y_probs, average='macro') 49 | except ValueError as e: 50 | print(f"value error encountered {e}, set auc sccore to -1.") 51 | auc = -1 52 | return {"top1": top1, "rocauc": auc} 53 | 54 | top1, _ = accuracy(y_probs, y_true_ids) 55 | k = min([5, num_classes]) # if number of labels < 5, use the total class 56 | top5 = top_n_accuracy(y_probs, y_true_ids, k) 57 | return {"top1": top1, f"top{k}": top5} 58 | 59 | 60 | def topks_correct(preds, labels, ks): 61 | """Computes the number of top-k correct predictions for each k.""" 62 | assert preds.size(0) == labels.size( 63 | 0 64 | ), "Batch dim of predictions and labels must match" 65 | # Find the top max_k predictions for each sample 66 | _top_max_k_vals, top_max_k_inds = torch.topk( 67 | preds, max(ks), dim=1, largest=True, sorted=True 68 | ) 69 | # (batch_size, max_k) -> (max_k, batch_size) 70 | top_max_k_inds = top_max_k_inds.t() 71 | # (batch_size, ) -> (max_k, batch_size) 72 | rep_max_k_labels = labels.view(1, -1).expand_as(top_max_k_inds) 73 | # (i, j) = 1 if top i-th prediction for the j-th sample is correct 74 | top_max_k_correct = top_max_k_inds.eq(rep_max_k_labels) 75 | # Compute the number of topk correct predictions for each k 76 | topks_correct = [ 77 | top_max_k_correct[:k, :].reshape(-1).float().sum() for k in ks 78 | ] 79 | return topks_correct 80 | 81 | 82 | def topk_errors(preds, labels, ks): 83 | """Computes the top-k error for each k.""" 84 | if int(labels.min()) < 0: # has ignore 85 | keep_ids = np.where(labels.cpu() >= 0)[0] 86 | preds = preds[keep_ids, :] 87 | labels = labels[keep_ids] 88 | 89 | num_topks_correct = topks_correct(preds, labels, ks) 90 | return [(1.0 - x / preds.size(0)) for x in num_topks_correct] 91 | 92 | 93 | def topk_accuracies(preds, labels, ks): 94 | """Computes the top-k accuracy for each k.""" 95 | num_topks_correct = topks_correct(preds, labels, ks) 96 | return [(x / preds.size(0)) for x in num_topks_correct] 97 | 98 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/caltech.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. 
All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Imports the Caltech images dataset.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | from . import base as base 23 | from .registry import Registry 24 | import tensorflow_datasets as tfds 25 | 26 | 27 | # Percentage of the original training set retained for training, the rest is 28 | # used as a validation set. 29 | _TRAIN_SPLIT_PERCENT = 90 30 | 31 | 32 | @Registry.register("data.caltech101", "class") 33 | class Caltech101(base.ImageTfdsData): 34 | """Provides the Caltech101 dataset. 35 | 36 | See the base class for additional details on the class. 37 | 38 | See TFDS dataset for details on the dataset: 39 | third_party/py/tensorflow_datasets/image/caltech.py 40 | 41 | The original (TFDS) dataset contains only a train and test split. We randomly 42 | sample _TRAIN_SPLIT_PERCENT% of the train split for our "train" set. The 43 | remainder of the TFDS train split becomes our "val" set. The full TFDS train 44 | split is called "trainval". The TFDS test split is used as our test set. 45 | 46 | Note that, in the TFDS dataset, the training split is class-balanced, but not 47 | the test split. Therefore, a significant difference between performance on the 48 | "val" and "test" sets should be expected. 49 | """ 50 | 51 | def __init__(self, data_dir=None): 52 | dataset_builder = tfds.builder("caltech101:3.*.*", data_dir=data_dir) 53 | dataset_builder.download_and_prepare() 54 | 55 | # Creates a dict with example counts for each split. 56 | trainval_count = dataset_builder.info.splits["train"].num_examples 57 | train_count = (_TRAIN_SPLIT_PERCENT * trainval_count) // 100 58 | test_count = dataset_builder.info.splits["test"].num_examples 59 | num_samples_splits = dict( 60 | train=train_count, 61 | val=trainval_count - train_count, 62 | trainval=trainval_count, 63 | test=test_count, 64 | train800=800, 65 | val200=200, 66 | train800val200=1000) 67 | 68 | # Defines dataset specific train/val/trainval/test splits. 
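# Worked example with a hypothetical count: if the TFDS train split held
# 3,000 examples, train_count above would be (90 * 3000) // 100 = 2700, so
# "train" below maps to train[:2700] and "val" to train[2700:].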
69 | tfds_splits = { 70 | "train": "train[:{}]".format(train_count), 71 | "val": "train[{}:]".format(train_count), 72 | "trainval": "train", 73 | "test": "test", 74 | "train800": "train[:800]", 75 | "val200": "train[{}:{}]".format(train_count, train_count+200), 76 | "train800val200": ( 77 | "train[:800]+train[{}:{}]".format(train_count, train_count+200)), 78 | } 79 | 80 | super(Caltech101, self).__init__( 81 | dataset_builder=dataset_builder, 82 | tfds_splits=tfds_splits, 83 | num_samples_splits=num_samples_splits, 84 | num_preprocessing_threads=400, 85 | shuffle_buffer_size=3000, 86 | base_preprocess_fn=base.make_get_tensors_fn(("image", "label")), 87 | num_classes=dataset_builder.info.features["label"].num_classes) 88 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/resisc45.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements RESISC-45 data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow_datasets as tfds 24 | 25 | from . import base as base 26 | from .registry import Registry 27 | TRAIN_SPLIT_PERCENT = 60 28 | VALIDATION_SPLIT_PERCENT = 20 29 | TEST_SPLIT_PERCENT = 20 30 | 31 | 32 | @Registry.register("data.resisc45", "class") 33 | class Resisc45Data(base.ImageTfdsData): 34 | """Provides RESISC-45 dataset. 35 | 36 | RESISC45 dataset is a publicly available benchmark for Remote Sensing Image 37 | Scene Classification (RESISC), created by Northwestern Polytechnical 38 | University (NWPU). This dataset contains 31,500 images, covering 45 scene 39 | classes with 700 images in each class. 40 | 41 | URL: http://www.escience.cn/people/JunweiHan/NWPU-RESISC45.html 42 | """ 43 | 44 | def __init__(self, data_dir=None): 45 | dataset_builder = tfds.builder("resisc45:3.*.*", data_dir=data_dir) 46 | dataset_builder.download_and_prepare() 47 | 48 | # Example counts are retrieved from the tensorflow dataset info. 
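# With the 31,500 images mentioned in the class docstring, the 60/20/20 split
# computed below yields 18,900 train, 6,300 val, and 6,300 test examples.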
49 | num_examples = dataset_builder.info.splits["train"].num_examples 50 | train_count = num_examples * TRAIN_SPLIT_PERCENT // 100 51 | val_count = num_examples * VALIDATION_SPLIT_PERCENT // 100 52 | test_count = num_examples * TEST_SPLIT_PERCENT // 100 53 | 54 | tfds_splits = { 55 | "train": 56 | "train[:{}]".format(train_count), 57 | "val": 58 | "train[{}:{}]".format(train_count, train_count + val_count), 59 | "trainval": 60 | "train[:{}]".format(train_count + val_count), 61 | "test": 62 | "train[{}:]".format(train_count + val_count), 63 | "train800": 64 | "train[:800]", 65 | "val200": 66 | "train[{}:{}]".format(train_count, train_count+200), 67 | "train800val200": 68 | "train[:800]+train[{}:{}]".format(train_count, train_count+200), 69 | } 70 | 71 | # Creates a dict with example counts for each split. 72 | num_samples_splits = { 73 | "train": train_count, 74 | "val": val_count, 75 | "trainval": train_count + val_count, 76 | "test": test_count, 77 | "train800": 800, 78 | "val200": 200, 79 | "train800val200": 1000, 80 | } 81 | 82 | super(Resisc45Data, self).__init__( 83 | dataset_builder=dataset_builder, 84 | tfds_splits=tfds_splits, 85 | num_samples_splits=num_samples_splits, 86 | num_preprocessing_threads=400, 87 | shuffle_buffer_size=10000, 88 | # Note: Rename tensors but keep their original types. 89 | base_preprocess_fn=base.make_get_and_cast_tensors_fn({ 90 | "image": ("image", None), 91 | "label": ("label", None), 92 | }), 93 | num_classes=dataset_builder.info.features["label"].num_classes) 94 | -------------------------------------------------------------------------------- /src/models/vit_adapter/vit_moco.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | borrow from https://github.com/facebookresearch/moco-v3/blob/main/vits.py 4 | """ 5 | import math 6 | import torch 7 | import torch.nn as nn 8 | from functools import partial, reduce 9 | from operator import mul 10 | 11 | from timm.models.vision_transformer import VisionTransformer, _cfg 12 | from timm.models.layers.helpers import to_2tuple 13 | from timm.models.layers import PatchEmbed 14 | 15 | from .adapter_block import Pfeiffer_Block 16 | from ..vit_backbones.vit_moco import VisionTransformerMoCo 17 | from ...utils import logging 18 | logger = logging.get_logger("visual_prompt") 19 | 20 | 21 | class ADPT_VisionTransformerMoCo(VisionTransformerMoCo): 22 | def __init__( 23 | self, 24 | adapter_cfg, 25 | stop_grad_conv1=False, 26 | img_size=224, 27 | patch_size=16, 28 | in_chans=3, 29 | num_classes=1000, 30 | embed_dim=768, 31 | depth=12, 32 | num_heads=12, 33 | mlp_ratio=4., 34 | qkv_bias=True, 35 | representation_size=None, 36 | distilled=False, 37 | drop_rate=0., 38 | attn_drop_rate=0., 39 | drop_path_rate=0., 40 | embed_layer=PatchEmbed, 41 | norm_layer=None, 42 | act_layer=None, 43 | weight_init='', 44 | **kwargs): 45 | 46 | super(ADPT_VisionTransformerMoCo, self).__init__( 47 | stop_grad_conv1=stop_grad_conv1, 48 | img_size=img_size, 49 | patch_size=patch_size, 50 | in_chans=in_chans, 51 | num_classes=num_classes, 52 | embed_dim=embed_dim, 53 | depth=depth, 54 | num_heads=num_heads, 55 | mlp_ratio=mlp_ratio, 56 | qkv_bias=qkv_bias, 57 | representation_size=representation_size, 58 | distilled=distilled, 59 | drop_rate=drop_rate, 60 | attn_drop_rate=attn_drop_rate, 61 | drop_path_rate=drop_path_rate, 62 | embed_layer=embed_layer, 63 | norm_layer=norm_layer, 64 | act_layer=act_layer, 65 | weight_init=weight_init, 66 | **kwargs 67 | ) 68 | 69 | 
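# Note: the parent constructor above builds the standard timm transformer
# blocks; when adapter_cfg.STYLE == "Pfeiffer", self.blocks is rebuilt below
# with Pfeiffer_Block so every layer carries an adapter.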
self.adapter_cfg = adapter_cfg 70 | 71 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 72 | act_layer = act_layer or nn.GELU 73 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 74 | 75 | if adapter_cfg.STYLE == "Pfeiffer": 76 | self.blocks = nn.Sequential(*[ 77 | Pfeiffer_Block( 78 | adapter_config=adapter_cfg, 79 | dim=embed_dim, 80 | num_heads=num_heads, 81 | mlp_ratio=mlp_ratio, 82 | qkv_bias=qkv_bias, 83 | drop=drop_rate, 84 | attn_drop=attn_drop_rate, 85 | drop_path=dpr[i], 86 | norm_layer=norm_layer, 87 | act_layer=act_layer) for i in range(depth)]) 88 | else: 89 | raise ValueError("Other adapter styles are not supported.") 90 | 91 | 92 | 93 | def vit_base(adapter_cfg, **kwargs): 94 | model = ADPT_VisionTransformerMoCo( 95 | adapter_cfg, 96 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 97 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 98 | model.default_cfg = _cfg() 99 | return model 100 | -------------------------------------------------------------------------------- /src/configs/vit_configs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Copyright (c) Meta Platforms, Inc. All Rights Reserved 4 | https://github.com/jeonsworld/ViT-pytorch/blob/main/models/configs.py 5 | """ 6 | import ml_collections 7 | 8 | 9 | def get_testing(): 10 | """Returns a minimal configuration for testing.""" 11 | config = ml_collections.ConfigDict() 12 | config.patches = ml_collections.ConfigDict({'size': (16, 16)}) 13 | config.hidden_size = 1 14 | config.transformer = ml_collections.ConfigDict() 15 | config.transformer.mlp_dim = 1 16 | config.transformer.num_heads = 1 17 | config.transformer.num_layers = 1 18 | config.transformer.attention_dropout_rate = 0.0 19 | config.transformer.dropout_rate = 0.1 20 | config.classifier = 'token' 21 | config.representation_size = None 22 | return config 23 | 24 | 25 | def get_b16_config(): 26 | """Returns the ViT-B/16 configuration.""" 27 | config = ml_collections.ConfigDict() 28 | config.patches = ml_collections.ConfigDict({'size': (16, 16)}) 29 | config.hidden_size = 768 30 | config.transformer = ml_collections.ConfigDict() 31 | config.transformer.mlp_dim = 3072 32 | config.transformer.num_heads = 12 33 | config.transformer.num_layers = 12 34 | config.transformer.attention_dropout_rate = 0.0 35 | config.transformer.dropout_rate = 0.1 36 | config.classifier = 'token' 37 | config.representation_size = None 38 | return config 39 | 40 | 41 | def get_r50_b16_config(): 42 | """Returns the Resnet50 + ViT-B/16 configuration.""" 43 | config = get_b16_config() 44 | del config.patches.size 45 | config.patches.grid = (14, 14) 46 | config.resnet = ml_collections.ConfigDict() 47 | config.resnet.num_layers = (3, 4, 9) 48 | config.resnet.width_factor = 1 49 | return config 50 | 51 | 52 | def get_b32_config(): 53 | """Returns the ViT-B/32 configuration.""" 54 | config = get_b16_config() 55 | config.patches.size = (32, 32) 56 | return config 57 | 58 | 59 | def get_b8_config(): 60 | """Returns the ViT-B/32 configuration.""" 61 | config = get_b16_config() 62 | config.patches.size = (8, 8) 63 | return config 64 | 65 | 66 | def get_l16_config(): 67 | """Returns the ViT-L/16 configuration.""" 68 | config = ml_collections.ConfigDict() 69 | config.patches = ml_collections.ConfigDict({'size': (16, 16)}) 70 | config.hidden_size = 1024 71 | config.transformer = ml_collections.ConfigDict() 72 | 
config.transformer.mlp_dim = 4096 73 | config.transformer.num_heads = 16 74 | config.transformer.num_layers = 24 75 | config.transformer.attention_dropout_rate = 0.0 76 | config.transformer.dropout_rate = 0.1 77 | config.classifier = 'token' 78 | config.representation_size = None 79 | return config 80 | 81 | 82 | def get_l32_config(): 83 | """Returns the ViT-L/32 configuration.""" 84 | config = get_l16_config() 85 | config.patches.size = (32, 32) 86 | return config 87 | 88 | 89 | def get_h14_config(): 90 | """Returns the ViT-L/16 configuration.""" 91 | config = ml_collections.ConfigDict() 92 | config.patches = ml_collections.ConfigDict({'size': (14, 14)}) 93 | config.hidden_size = 1280 94 | config.transformer = ml_collections.ConfigDict() 95 | config.transformer.mlp_dim = 5120 96 | config.transformer.num_heads = 16 97 | config.transformer.num_layers = 32 98 | config.transformer.attention_dropout_rate = 0.0 99 | config.transformer.dropout_rate = 0.1 100 | config.classifier = 'token' 101 | config.representation_size = None 102 | return config 103 | -------------------------------------------------------------------------------- /src/solver/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import math 3 | 4 | import torch.optim as optim 5 | from fvcore.common.config import CfgNode 6 | from torch.optim.lr_scheduler import LambdaLR 7 | 8 | 9 | def make_scheduler( 10 | optimizer: optim.Optimizer, train_params: CfgNode 11 | ) -> LambdaLR: 12 | warmup = train_params.WARMUP_EPOCH 13 | total_iters = train_params.TOTAL_EPOCH 14 | 15 | if train_params.SCHEDULER == "cosine": 16 | scheduler = WarmupCosineSchedule( 17 | optimizer, 18 | warmup_steps=warmup, 19 | t_total=total_iters 20 | ) 21 | elif train_params.SCHEDULER == "cosine_hardrestart": 22 | scheduler = WarmupCosineWithHardRestartsSchedule( 23 | optimizer, 24 | warmup_steps=warmup, 25 | t_total=total_iters 26 | ) 27 | 28 | elif train_params.SCHEDULER == "plateau": 29 | scheduler = optim.lr_scheduler.ReduceLROnPlateau( 30 | optimizer, 31 | "max", 32 | patience=5, 33 | verbose=True, 34 | factor=train_params.LR_DECAY_FACTOR, 35 | ) 36 | else: 37 | scheduler = None 38 | return scheduler 39 | 40 | 41 | class WarmupCosineSchedule(LambdaLR): 42 | """ Linear warmup and then cosine decay. 43 | Linearly increases learning rate from 0 to 1 over `warmup_steps`. 44 | Decreases learning rate from 1. to 0. over remaining 45 | `t_total - warmup_steps` steps following a cosine curve. 46 | If `cycles` (default=0.5) is different from default, learning rate 47 | follows cosine function after warmup. 48 | """ 49 | def __init__(self, optimizer, warmup_steps, t_total, cycles=.5, last_epoch=-1): 50 | self.warmup_steps = warmup_steps 51 | self.t_total = t_total 52 | self.cycles = cycles 53 | super(WarmupCosineSchedule, self).__init__( 54 | optimizer, self.lr_lambda, last_epoch=last_epoch) 55 | 56 | def lr_lambda(self, step): 57 | if step < self.warmup_steps: 58 | return float(step) / float(max(1.0, self.warmup_steps)) 59 | # progress after warmup 60 | progress = float(step - self.warmup_steps) / float(max( 61 | 1, self.t_total - self.warmup_steps)) 62 | return max( 63 | 0.0, 64 | 0.5 * (1. + math.cos(math.pi * float(self.cycles) * 2.0 * progress)) 65 | ) 66 | 67 | 68 | class WarmupCosineWithHardRestartsSchedule(LambdaLR): 69 | """ Linear warmup and then cosine cycles with hard restarts. 70 | Linearly increases learning rate from 0 to 1 over `warmup_steps`. 71 | If `cycles` (default=1.) 
is different from default, learning rate 72 | follows `cycles` times a cosine decaying learning rate 73 | (with hard restarts). 74 | """ 75 | def __init__(self, optimizer, warmup_steps, t_total, cycles=1., last_epoch=-1): 76 | self.warmup_steps = warmup_steps 77 | self.t_total = t_total 78 | self.cycles = cycles 79 | super(WarmupCosineWithHardRestartsSchedule, self).__init__( 80 | optimizer, self.lr_lambda, last_epoch=last_epoch) 81 | 82 | def lr_lambda(self, step): 83 | if step < self.warmup_steps: 84 | return float(step) / float(max(1, self.warmup_steps)) 85 | # progress after warmup 86 | progress = float(step - self.warmup_steps) / float( 87 | max(1, self.t_total - self.warmup_steps)) 88 | if progress >= 1.0: 89 | return 0.0 90 | return max( 91 | 0.0, 92 | 0.5 * (1. + math.cos( 93 | math.pi * ((float(self.cycles) * progress) % 1.0))) 94 | ) 95 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/oxford_iiit_pet.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements OxfordIIITPet data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow_datasets as tfds 24 | 25 | from . import base as base 26 | from .registry import Registry 27 | # This constant specifies the percentage of data that is used to create custom 28 | # train/val splits. Specifically, TRAIN_SPLIT_PERCENT% of the official training 29 | # split is used as a new training split and the rest is used for validation. 30 | TRAIN_SPLIT_PERCENT = 80 31 | 32 | 33 | @Registry.register("data.oxford_iiit_pet", "class") 34 | class OxfordIIITPetData(base.ImageTfdsData): 35 | """Provides OxfordIIITPet data. 36 | 37 | The OxfordIIITPet dataset comes only with a training and test set. 38 | Therefore, the validation set is split out of the original training set, and 39 | the remaining examples are used as the "train" split. The "trainval" split 40 | corresponds to the original training set. 41 | 42 | For additional details and usage, see the base class. 43 | """ 44 | 45 | def __init__(self, data_dir=None, train_split_percent=None): 46 | 47 | dataset_builder = tfds.builder("oxford_iiit_pet:3.*.*", data_dir=data_dir) 48 | dataset_builder.download_and_prepare() 49 | train_split_percent = train_split_percent or TRAIN_SPLIT_PERCENT 50 | 51 | # Creates a dict with example counts for each split. 
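# (Worked example: the official Oxford-IIIT Pet training split has 3,680 images, so with train_split_percent=80 the counts below come out to 2,944 train / 736 val examples.)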
52 | trainval_count = dataset_builder.info.splits[tfds.Split.TRAIN].num_examples 53 | test_count = dataset_builder.info.splits[tfds.Split.TEST].num_examples 54 | num_samples_splits = { 55 | "train": (train_split_percent * trainval_count) // 100, 56 | "val": trainval_count - (train_split_percent * trainval_count) // 100, 57 | "trainval": trainval_count, 58 | "test": test_count, 59 | "train800": 800, 60 | "val200": 200, 61 | "train800val200": 1000, 62 | } 63 | 64 | # Defines dataset specific train/val/trainval/test splits. 65 | tfds_splits = { 66 | "train": "train[:{}]".format(num_samples_splits["train"]), 67 | "val": "train[{}:]".format(num_samples_splits["train"]), 68 | "trainval": tfds.Split.TRAIN, 69 | "test": tfds.Split.TEST, 70 | "train800": "train[:800]", 71 | "val200": "train[{}:{}]".format( 72 | num_samples_splits["train"], num_samples_splits["train"]+200), 73 | "train800val200": "train[:800]+train[{}:{}]".format( 74 | num_samples_splits["train"], num_samples_splits["train"]+200), 75 | } 76 | 77 | super(OxfordIIITPetData, self).__init__( 78 | dataset_builder=dataset_builder, 79 | tfds_splits=tfds_splits, 80 | num_samples_splits=num_samples_splits, 81 | num_preprocessing_threads=400, 82 | shuffle_buffer_size=10000, 83 | # Note: Export only image and label tensors with their original types. 84 | base_preprocess_fn=base.make_get_tensors_fn(["image", "label"]), 85 | num_classes=dataset_builder.info.features["label"].num_classes) 86 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/eurosat.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements EurosatData class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow_datasets as tfds 24 | 25 | from . import base as base 26 | from .registry import Registry 27 | 28 | TRAIN_SPLIT_PERCENT = 60 29 | VALIDATION_SPLIT_PERCENT = 20 30 | TEST_SPLIT_PERCENT = 20 31 | 32 | 33 | @Registry.register("data.eurosat", "class") 34 | class EurosatData(base.ImageTfdsData): 35 | """Provides EuroSat dataset. 36 | 37 | EuroSAT dataset is based on Sentinel-2 satellite images covering 13 spectral 38 | bands and consisting of 10 classes with 27000 labeled and 39 | geo-referenced samples. 40 | 41 | URL: https://github.com/phelber/eurosat 42 | """ 43 | 44 | def __init__(self, subset="rgb", data_key="image", data_dir=None): 45 | dataset_name = "eurosat/{}:2.*.*".format(subset) 46 | dataset_builder = tfds.builder(dataset_name, data_dir=data_dir) 47 | dataset_builder.download_and_prepare() 48 | 49 | # Example counts are retrieved from the tensorflow dataset info. 
50 | num_examples = dataset_builder.info.splits[tfds.Split.TRAIN].num_examples 51 | train_count = num_examples * TRAIN_SPLIT_PERCENT // 100 52 | val_count = num_examples * VALIDATION_SPLIT_PERCENT // 100 53 | test_count = num_examples * TEST_SPLIT_PERCENT // 100 54 | 55 | tfds_splits = { 56 | "train": 57 | "train[:{}]".format(train_count), 58 | "val": 59 | "train[{}:{}]".format(train_count, train_count+val_count), 60 | "trainval": 61 | "train[:{}]".format(train_count+val_count), 62 | "test": 63 | "train[{}:]".format(train_count+val_count), 64 | "train800": 65 | "train[:800]", 66 | "val200": 67 | "train[{}:{}]".format(train_count, train_count+200), 68 | "train800val200": 69 | "train[:800]+train[{}:{}]".format(train_count, train_count+200), 70 | } 71 | 72 | # Creates a dict with example counts for each split. 73 | num_samples_splits = { 74 | "train": train_count, 75 | "val": val_count, 76 | "trainval": train_count + val_count, 77 | "test": test_count, 78 | "train800": 800, 79 | "val200": 200, 80 | "train800val200": 1000, 81 | } 82 | 83 | num_channels = 3 84 | if data_key == "sentinel2": 85 | num_channels = 13 86 | 87 | super(EurosatData, self).__init__( 88 | dataset_builder=dataset_builder, 89 | tfds_splits=tfds_splits, 90 | num_samples_splits=num_samples_splits, 91 | num_preprocessing_threads=100, 92 | shuffle_buffer_size=10000, 93 | base_preprocess_fn=base.make_get_and_cast_tensors_fn({ 94 | data_key: ("image", None), 95 | "label": ("label", None), 96 | }), 97 | image_key=data_key, 98 | num_channels=num_channels, 99 | num_classes=dataset_builder.info.features["label"].num_classes) 100 | -------------------------------------------------------------------------------- /src/data/loader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """Data loader.""" 4 | import torch 5 | from torch.utils.data.distributed import DistributedSampler 6 | from torch.utils.data.sampler import RandomSampler 7 | 8 | from ..utils import logging 9 | from .datasets.json_dataset import ( 10 | CUB200Dataset, CarsDataset, DogsDataset, FlowersDataset, NabirdsDataset 11 | ) 12 | 13 | logger = logging.get_logger("visual_prompt") 14 | _DATASET_CATALOG = { 15 | "CUB": CUB200Dataset, 16 | 'OxfordFlowers': FlowersDataset, 17 | 'StanfordCars': CarsDataset, 18 | 'StanfordDogs': DogsDataset, 19 | "nabirds": NabirdsDataset, 20 | } 21 | 22 | 23 | def _construct_loader(cfg, split, batch_size, shuffle, drop_last): 24 | """Constructs the data loader for the given dataset.""" 25 | dataset_name = cfg.DATA.NAME 26 | 27 | # Construct the dataset 28 | if dataset_name.startswith("vtab-"): 29 | # import the tensorflow here only if needed 30 | from .datasets.tf_dataset import TFDataset 31 | dataset = TFDataset(cfg, split) 32 | else: 33 | assert ( 34 | dataset_name in _DATASET_CATALOG.keys() 35 | ), "Dataset '{}' not supported".format(dataset_name) 36 | dataset = _DATASET_CATALOG[dataset_name](cfg, split) 37 | 38 | # Create a sampler for multi-process training 39 | sampler = DistributedSampler(dataset) if cfg.NUM_GPUS > 1 else None 40 | # Create a loader 41 | loader = torch.utils.data.DataLoader( 42 | dataset, 43 | batch_size=batch_size, 44 | shuffle=(False if sampler else shuffle), 45 | sampler=sampler, 46 | num_workers=cfg.DATA.NUM_WORKERS, 47 | pin_memory=cfg.DATA.PIN_MEMORY, 48 | drop_last=drop_last, 49 | ) 50 | return loader 51 | 52 | 53 | def construct_train_loader(cfg): 54 | """Train loader wrapper.""" 55 | if cfg.NUM_GPUS > 1: 56 | drop_last = True 57 | else: 58 | 
drop_last = False 59 | return _construct_loader( 60 | cfg=cfg, 61 | split="train", 62 | batch_size=int(cfg.DATA.BATCH_SIZE / cfg.NUM_GPUS), 63 | shuffle=True, 64 | drop_last=drop_last, 65 | ) 66 | 67 | 68 | def construct_trainval_loader(cfg): 69 | """Train loader wrapper.""" 70 | if cfg.NUM_GPUS > 1: 71 | drop_last = True 72 | else: 73 | drop_last = False 74 | return _construct_loader( 75 | cfg=cfg, 76 | split="trainval", 77 | batch_size=int(cfg.DATA.BATCH_SIZE / cfg.NUM_GPUS), 78 | shuffle=True, 79 | drop_last=drop_last, 80 | ) 81 | 82 | 83 | def construct_test_loader(cfg): 84 | """Test loader wrapper.""" 85 | return _construct_loader( 86 | cfg=cfg, 87 | split="test", 88 | batch_size=int(cfg.DATA.BATCH_SIZE / cfg.NUM_GPUS), 89 | shuffle=False, 90 | drop_last=False, 91 | ) 92 | 93 | 94 | def construct_val_loader(cfg, batch_size=None): 95 | if batch_size is None: 96 | bs = int(cfg.DATA.BATCH_SIZE / cfg.NUM_GPUS) 97 | else: 98 | bs = batch_size 99 | """Validation loader wrapper.""" 100 | return _construct_loader( 101 | cfg=cfg, 102 | split="val", 103 | batch_size=bs, 104 | shuffle=False, 105 | drop_last=False, 106 | ) 107 | 108 | 109 | def shuffle(loader, cur_epoch): 110 | """"Shuffles the data.""" 111 | assert isinstance( 112 | loader.sampler, (RandomSampler, DistributedSampler) 113 | ), "Sampler type '{}' not supported".format(type(loader.sampler)) 114 | # RandomSampler handles shuffling automatically 115 | if isinstance(loader.sampler, DistributedSampler): 116 | # DistributedSampler shuffles data based on epoch 117 | loader.sampler.set_epoch(cur_epoch) 118 | -------------------------------------------------------------------------------- /launch.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | launch helper functions 4 | """ 5 | import argparse 6 | import os 7 | import sys 8 | import pprint 9 | import PIL 10 | from collections import defaultdict 11 | from tabulate import tabulate 12 | from typing import Tuple 13 | 14 | import torch 15 | from src.utils.file_io import PathManager 16 | from src.utils import logging 17 | from src.utils.distributed import get_rank, get_world_size 18 | 19 | 20 | def collect_torch_env() -> str: 21 | try: 22 | import torch.__config__ 23 | 24 | return torch.__config__.show() 25 | except ImportError: 26 | # compatible with older versions of pytorch 27 | from torch.utils.collect_env import get_pretty_env_info 28 | 29 | return get_pretty_env_info() 30 | 31 | 32 | def get_env_module() -> Tuple[str]: 33 | var_name = "ENV_MODULE" 34 | return var_name, os.environ.get(var_name, "") 35 | 36 | 37 | def collect_env_info() -> str: 38 | data = [] 39 | data.append(("Python", sys.version.replace("\n", ""))) 40 | data.append(get_env_module()) 41 | data.append(("PyTorch", torch.__version__)) 42 | data.append(("PyTorch Debug Build", torch.version.debug)) 43 | 44 | has_cuda = torch.cuda.is_available() 45 | data.append(("CUDA available", has_cuda)) 46 | if has_cuda: 47 | data.append(("CUDA ID", os.environ["CUDA_VISIBLE_DEVICES"])) 48 | devices = defaultdict(list) 49 | for k in range(torch.cuda.device_count()): 50 | devices[torch.cuda.get_device_name(k)].append(str(k)) 51 | for name, devids in devices.items(): 52 | data.append(("GPU " + ",".join(devids), name)) 53 | data.append(("Pillow", PIL.__version__)) 54 | 55 | try: 56 | import cv2 57 | 58 | data.append(("cv2", cv2.__version__)) 59 | except ImportError: 60 | pass 61 | env_str = tabulate(data) + "\n" 62 | env_str += collect_torch_env() 63 | return env_str 
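# Usage sketch (for orientation only; see train.py later in this listing for the real wiring):
#   args = default_argument_parser().parse_args()
#   cfg = setup(args)               # setup() is defined in train.py, not here
#   logging_train_setup(args, cfg)  # called from train.py's train()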
64 | 65 | 66 | def default_argument_parser(): 67 | """ 68 | create a simple parser to wrap around config file 69 | """ 70 | parser = argparse.ArgumentParser(description="visual-prompt") 71 | parser.add_argument( 72 | "--config-file", default="", metavar="FILE", help="path to config file") 73 | parser.add_argument( 74 | "--train-type", default="", help="training types") 75 | parser.add_argument( 76 | "opts", 77 | help="Modify config options using the command-line", 78 | default=None, 79 | nargs=argparse.REMAINDER, 80 | ) 81 | 82 | return parser 83 | 84 | 85 | def logging_train_setup(args, cfg) -> None: 86 | output_dir = cfg.OUTPUT_DIR 87 | if output_dir: 88 | PathManager.mkdirs(output_dir) 89 | 90 | logger = logging.setup_logging( 91 | cfg.NUM_GPUS, get_world_size(), output_dir, name="visual_prompt") 92 | 93 | # Log basic information about environment, cmdline arguments, and config 94 | rank = get_rank() 95 | logger.info( 96 | f"Rank of current process: {rank}. World size: {get_world_size()}") 97 | logger.info("Environment info:\n" + collect_env_info()) 98 | 99 | logger.info("Command line arguments: " + str(args)) 100 | if hasattr(args, "config_file") and args.config_file != "": 101 | logger.info( 102 | "Contents of args.config_file={}:\n{}".format( 103 | args.config_file, 104 | PathManager.open(args.config_file, "r").read() 105 | ) 106 | ) 107 | # Show the config 108 | logger.info("Training with config:") 109 | logger.info(pprint.pformat(cfg)) 110 | # cudnn benchmark has large overhead. 111 | # It shouldn't be used considering the small size of typical val set. 112 | if not (hasattr(args, "eval_only") and args.eval_only): 113 | torch.backends.cudnn.benchmark = cfg.CUDNN_BENCHMARK 114 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/cifar.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements Cifar data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | from . import base as base 23 | from .registry import Registry 24 | import tensorflow_datasets as tfds 25 | 26 | # This constant specifies the percentage of data that is used to create custom 27 | # train/val splits. Specifically, TRAIN_SPLIT_PERCENT% of the official training 28 | # split is used as a new training split and the rest is used for validation. 29 | TRAIN_SPLIT_PERCENT = 90 30 | 31 | 32 | @Registry.register("data.cifar", "class") 33 | class CifarData(base.ImageTfdsData): 34 | """Provides Cifar10 or Cifar100 data. 35 | 36 | Cifar comes only with a training and test set. Therefore, the validation set 37 | is split out of the original training set, and the remaining examples are used 38 | as the "train" split. 
The "trainval" split corresponds to the original 39 | training set. 40 | 41 | For additional details and usage, see the base class. 42 | """ 43 | 44 | def __init__(self, num_classes=10, data_dir=None, train_split_percent=None): 45 | 46 | if num_classes == 10: 47 | dataset_builder = tfds.builder("cifar10:3.*.*", data_dir=data_dir) 48 | elif num_classes == 100: 49 | dataset_builder = tfds.builder("cifar100:3.*.*", data_dir=data_dir) 50 | else: 51 | raise ValueError( 52 | "Number of classes must be 10 or 100, got {}".format(num_classes)) 53 | 54 | dataset_builder.download_and_prepare() 55 | 56 | train_split_percent = train_split_percent or TRAIN_SPLIT_PERCENT 57 | 58 | # Creates a dict with example counts for each split. 59 | trainval_count = dataset_builder.info.splits["train"].num_examples 60 | test_count = dataset_builder.info.splits["test"].num_examples 61 | num_samples_splits = { 62 | "train": (train_split_percent * trainval_count) // 100, 63 | "val": trainval_count - (train_split_percent * trainval_count) // 100, 64 | "trainval": trainval_count, 65 | "test": test_count, 66 | "train800": 800, 67 | "val200": 200, 68 | "train800val200": 1000, 69 | } 70 | 71 | # Defines dataset specific train/val/trainval/test splits. 72 | tfds_splits = { 73 | "train": "train[:{}]".format(num_samples_splits["train"]), 74 | "val": "train[{}:]".format(num_samples_splits["train"]), 75 | "trainval": "train", 76 | "test": "test", 77 | "train800": "train[:800]", 78 | "val200": "train[{}:{}]".format( 79 | num_samples_splits["train"], num_samples_splits["train"]+200), 80 | "train800val200": "train[:800]+train[{}:{}]".format( 81 | num_samples_splits["train"], num_samples_splits["train"]+200), 82 | } 83 | 84 | super(CifarData, self).__init__( 85 | dataset_builder=dataset_builder, 86 | tfds_splits=tfds_splits, 87 | num_samples_splits=num_samples_splits, 88 | num_preprocessing_threads=400, 89 | shuffle_buffer_size=10000, 90 | # Note: Export only image and label tensors with their original types. 91 | base_preprocess_fn=base.make_get_tensors_fn(["image", "label", "id"]), 92 | num_classes=dataset_builder.info.features["label"].num_classes) 93 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/smallnorb.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements the SmallNORB data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow.compat.v1 as tf 24 | import tensorflow_datasets as tfds 25 | 26 | from . import base as base 27 | from .registry import Registry 28 | # This constant specifies the percentage of data that is used to create custom 29 | # val/test splits. 
Specifically, VAL_SPLIT_PERCENT% of the official testing 30 | # split is used as a new validation split and the rest is used for testing. 31 | VAL_SPLIT_PERCENT = 50 32 | 33 | 34 | @Registry.register("data.smallnorb", "class") 35 | class SmallNORBData(base.ImageTfdsData): 36 | """Provides the SmallNORB data set. 37 | 38 | SmallNORB comes only with a training and test set. Therefore, the validation 39 | set is split out of the original training set, and the remaining examples are 40 | used as the "train" split. The "trainval" split corresponds to the original 41 | training set. 42 | 43 | For additional details and usage, see the base class. 44 | 45 | The data set page is https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/. 46 | """ 47 | 48 | def __init__(self, predicted_attribute, data_dir=None): 49 | dataset_builder = tfds.builder("smallnorb:2.*.*", data_dir=data_dir) 50 | dataset_builder.download_and_prepare() 51 | 52 | if predicted_attribute not in dataset_builder.info.features: 53 | raise ValueError( 54 | "{} is not a valid attribute to predict.".format(predicted_attribute)) 55 | 56 | # Defines dataset specific train/val/trainval/test splits. 57 | tfds_splits = { 58 | "train": "train", 59 | "val": "test[:{}%]".format(VAL_SPLIT_PERCENT), 60 | "trainval": "train+test[:{}%]".format(VAL_SPLIT_PERCENT), 61 | "test": "test[{}%:]".format(VAL_SPLIT_PERCENT), 62 | "train800": "train[:800]", 63 | "val200": "test[:200]", 64 | "train800val200": "train[:800]+test[:200]", 65 | } 66 | 67 | # Creates a dict with example counts for each split. 68 | train_count = dataset_builder.info.splits["train"].num_examples 69 | test_count = dataset_builder.info.splits["test"].num_examples 70 | num_samples_validation = VAL_SPLIT_PERCENT * test_count // 100 71 | num_samples_splits = { 72 | "train": train_count, 73 | "val": num_samples_validation, 74 | "trainval": train_count + num_samples_validation, 75 | "test": test_count - num_samples_validation, 76 | "train800": 800, 77 | "val200": 200, 78 | "train800val200": 1000, 79 | } 80 | 81 | def preprocess_fn(tensors): 82 | # For consistency with other datasets, image needs to have three channels. 83 | image = tf.tile(tensors["image"], [1, 1, 3]) 84 | return dict(image=image, label=tensors[predicted_attribute]) 85 | 86 | info = dataset_builder.info 87 | super(SmallNORBData, self).__init__( 88 | dataset_builder=dataset_builder, 89 | tfds_splits=tfds_splits, 90 | num_samples_splits=num_samples_splits, 91 | num_preprocessing_threads=400, 92 | shuffle_buffer_size=10000, 93 | # We extract the attribute we want to predict in the preprocessing. 94 | base_preprocess_fn=preprocess_fn, 95 | num_classes=info.features[predicted_attribute].num_classes) 96 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/oxford_flowers102.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements oxford flowers 102 data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow_datasets as tfds 24 | 25 | from . import base as base 26 | from .registry import Registry 27 | 28 | 29 | @Registry.register("data.oxford_flowers102", "class") 30 | class OxfordFlowers102Data(base.ImageTfdsData): 31 | """Provides Oxford 102 categories flowers dataset. 32 | 33 | See corresponding tfds dataset for details. 34 | 35 | URL: https://www.robots.ox.ac.uk/~vgg/data/flowers/102/ 36 | """ 37 | 38 | def __init__(self, data_dir=None, train_split_percent=None): 39 | dataset_builder = tfds.builder("oxford_flowers102:2.*.*", data_dir=data_dir) 40 | dataset_builder.download_and_prepare() 41 | 42 | # Example counts are retrieved from the tensorflow dataset info. 43 | train_count = dataset_builder.info.splits[tfds.Split.TRAIN].num_examples 44 | val_count = dataset_builder.info.splits[tfds.Split.VALIDATION].num_examples 45 | test_count = dataset_builder.info.splits[tfds.Split.TEST].num_examples 46 | 47 | if train_split_percent: 48 | tfds_splits = { 49 | "train": "train[:{s}%]+validation[:{s}%]".format( 50 | s=train_split_percent), 51 | "val": "train[-{s}%:]+validation[-{s}%:]".format( 52 | s=train_split_percent), 53 | "trainval": "train+validation", 54 | "test": "test", 55 | "train800": "train[:800]", 56 | "val200": "validation[:200]", 57 | "train800val200": "train[:800]+validation[:200]", 58 | } 59 | num_samples_splits = { 60 | "train": (((train_count + val_count) // 100) 61 | * train_split_percent), 62 | "val": (((train_count + val_count) // 100) * 63 | (100 - train_split_percent)), 64 | "trainval": train_count + val_count, 65 | "test": test_count, 66 | "train800": 800, 67 | "val200": 200, 68 | "train800val200": 1000, 69 | } 70 | else: 71 | tfds_splits = { 72 | "train": "train", 73 | "val": "validation", 74 | "trainval": "train+validation", 75 | "test": "test", 76 | "train800": "train[:800]", 77 | "val200": "validation[:200]", 78 | "train800val200": "train[:800]+validation[:200]", 79 | } 80 | num_samples_splits = { 81 | "train": train_count, 82 | "val": val_count, 83 | "trainval": train_count + val_count, 84 | "test": test_count, 85 | "train800": 800, 86 | "val200": 200, 87 | "train800val200": 1000, 88 | } 89 | 90 | super(OxfordFlowers102Data, self).__init__( 91 | dataset_builder=dataset_builder, 92 | tfds_splits=tfds_splits, 93 | num_samples_splits=num_samples_splits, 94 | num_preprocessing_threads=400, 95 | shuffle_buffer_size=10000, 96 | # Note: Rename tensors but keep their original types. 97 | base_preprocess_fn=base.make_get_and_cast_tensors_fn({ 98 | "image": ("image", None), 99 | "label": ("label", None), 100 | }), 101 | num_classes=dataset_builder.info.features["label"] 102 | .num_classes) 103 | -------------------------------------------------------------------------------- /src/engine/evaluator.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import numpy as np 3 | 4 | from collections import defaultdict 5 | from typing import List, Union 6 | 7 | from .eval import multilabel 8 | from .eval import singlelabel 9 | from ..utils import logging 10 | logger = logging.get_logger("visual_prompt") 11 | 12 | 13 | class Evaluator(): 14 | """ 15 | An evaluator with below logics: 16 | 17 | 1. 
find which eval module to use. 18 | 2. store the eval results, pretty print it in log file as well. 19 | """ 20 | 21 | def __init__( 22 | self, 23 | ) -> None: 24 | self.results = defaultdict(dict) 25 | self.iteration = -1 26 | self.threshold_end = 0.5 27 | 28 | def update_iteration(self, iteration: int) -> None: 29 | """update iteration info""" 30 | self.iteration = iteration 31 | 32 | def update_result(self, metric: str, value: Union[float, dict]) -> None: 33 | if self.iteration > -1: 34 | key_name = "epoch_" + str(self.iteration) 35 | else: 36 | key_name = "final" 37 | if isinstance(value, float): 38 | self.results[key_name].update({metric: value}) 39 | else: 40 | if metric in self.results[key_name]: 41 | self.results[key_name][metric].update(value) 42 | else: 43 | self.results[key_name].update({metric: value}) 44 | 45 | def classify(self, probs, targets, test_data, multilabel=False): 46 | """ 47 | Evaluate classification result. 48 | Args: 49 | probs: np.ndarray for num_data x num_class, predicted probabilities 50 | targets: np.ndarray for multilabel, list of integers for single label 51 | test_labels: map test image ids to a list of class labels 52 | """ 53 | if not targets: 54 | raise ValueError( 55 | "When evaluating classification, need at least give targets") 56 | 57 | if multilabel: 58 | self._eval_multilabel(probs, targets, test_data) 59 | else: 60 | self._eval_singlelabel(probs, targets, test_data) 61 | 62 | def _eval_singlelabel( 63 | self, 64 | scores: np.ndarray, 65 | targets: List[int], 66 | eval_type: str 67 | ) -> None: 68 | """ 69 | if number of labels > 2: 70 | top1 and topk (5 by default) accuracy 71 | if number of labels == 2: 72 | top1 and rocauc 73 | """ 74 | acc_dict = singlelabel.compute_acc_auc(scores, targets) 75 | 76 | log_results = { 77 | k: np.around(v * 100, decimals=2) for k, v in acc_dict.items() 78 | } 79 | save_results = acc_dict 80 | 81 | self.log_and_update(log_results, save_results, eval_type) 82 | 83 | def _eval_multilabel( 84 | self, 85 | scores: np.ndarray, 86 | targets: np.ndarray, 87 | eval_type: str 88 | ) -> None: 89 | num_labels = scores.shape[-1] 90 | targets = multilabel.multihot(targets, num_labels) 91 | 92 | log_results = {} 93 | ap, ar, mAP, mAR = multilabel.compute_map(scores, targets) 94 | f1_dict = multilabel.get_best_f1_scores( 95 | targets, scores, self.threshold_end) 96 | 97 | log_results["mAP"] = np.around(mAP * 100, decimals=2) 98 | log_results["mAR"] = np.around(mAR * 100, decimals=2) 99 | log_results.update({ 100 | k: np.around(v * 100, decimals=2) for k, v in f1_dict.items()}) 101 | save_results = { 102 | "ap": ap, "ar": ar, "mAP": mAP, "mAR": mAR, "f1": f1_dict 103 | } 104 | self.log_and_update(log_results, save_results, eval_type) 105 | 106 | def log_and_update(self, log_results, save_results, eval_type): 107 | log_str = "" 108 | for k, result in log_results.items(): 109 | if not isinstance(result, np.ndarray): 110 | log_str += f"{k}: {result:.2f}\t" 111 | else: 112 | log_str += f"{k}: {list(result)}\t" 113 | logger.info(f"Classification results with {eval_type}: {log_str}") 114 | # save everything 115 | self.update_result("classification", {eval_type: save_results}) 116 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/svhn.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 
4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements Svhn data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow_datasets as tfds 24 | 25 | from . import base as base 26 | from .registry import Registry 27 | # This constant specifies the percentage of data that is used to create custom 28 | # train/val splits. Specifically, TRAIN_SPLIT_PERCENT% of the official training 29 | # split is used as a new training split and the rest is used for validation. 30 | TRAIN_SPLIT_PERCENT = 90 31 | 32 | 33 | @Registry.register("data.svhn", "class") 34 | class SvhnData(base.ImageTfdsData): 35 | """Provides SVHN data. 36 | 37 | The Street View House Numbers (SVHN) Dataset is an image digit recognition 38 | dataset of over 600,000 color digit images coming from real world data. 39 | Split size: 40 | - Training set: 73,257 images 41 | - Testing set: 26,032 images 42 | - Extra training set: 531,131 images 43 | Following the common setup on SVHN, we only use the official training and 44 | testing data. Images are cropped to 32x32. 45 | 46 | URL: http://ufldl.stanford.edu/housenumbers/ 47 | """ 48 | 49 | def __init__(self, data_dir=None): 50 | dataset_builder = tfds.builder("svhn_cropped:3.*.*", data_dir=data_dir) 51 | dataset_builder.download_and_prepare() 52 | 53 | # Example counts are retrieved from the tensorflow dataset info. 54 | trainval_count = dataset_builder.info.splits[tfds.Split.TRAIN].num_examples 55 | test_count = dataset_builder.info.splits[tfds.Split.TEST].num_examples 56 | 57 | # Creates a dict with example counts for each split. 58 | num_samples_splits = { 59 | # Calculates the train/val split example count based on percent. 60 | "train": TRAIN_SPLIT_PERCENT * trainval_count // 100, 61 | "val": trainval_count - TRAIN_SPLIT_PERCENT * trainval_count // 100, 62 | "trainval": trainval_count, 63 | "test": test_count, 64 | "train800": 800, 65 | "val200": 200, 66 | "train800val200": 1000, 67 | } 68 | 69 | # Defines dataset specific train/val/trainval/test splits. 70 | # The validation set is split out of the original training set, and the 71 | # remaining examples are used as the "train" split. The "trainval" split 72 | # corresponds to the original training set. 
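# (Worked example: with the 73,257 official SVHN training images listed above, TRAIN_SPLIT_PERCENT=90 yields 65,931 train / 7,326 val examples.)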
73 | tfds_splits = { 74 | "train": 75 | "train[:{}]".format(num_samples_splits["train"]), 76 | "val": 77 | "train[{}:]".format(num_samples_splits["train"]), 78 | "trainval": 79 | "train", 80 | "test": 81 | "test", 82 | "train800": 83 | "train[:800]", 84 | "val200": 85 | "train[{}:{}]".format(num_samples_splits["train"], 86 | num_samples_splits["train"] + 200), 87 | "train800val200": 88 | "train[:800]+train[{}:{}]".format( 89 | num_samples_splits["train"], num_samples_splits["train"] + 200), 90 | } 91 | 92 | super(SvhnData, self).__init__( 93 | dataset_builder=dataset_builder, 94 | tfds_splits=tfds_splits, 95 | num_samples_splits=num_samples_splits, 96 | num_preprocessing_threads=400, 97 | shuffle_buffer_size=10000, 98 | # Note: Rename tensors but keep their original types. 99 | base_preprocess_fn=base.make_get_and_cast_tensors_fn({ 100 | "image": ("image", None), 101 | "label": ("label", None), 102 | }), 103 | num_classes=dataset_builder.info.features["label"] 104 | .num_classes) 105 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | major actions here: fine-tune the features and evaluate different settings 4 | """ 5 | import os 6 | import torch 7 | import warnings 8 | 9 | import numpy as np 10 | import random 11 | 12 | from time import sleep 13 | from random import randint 14 | 15 | import src.utils.logging as logging 16 | from src.configs.config import get_cfg 17 | from src.data import loader as data_loader 18 | from src.engine.evaluator import Evaluator 19 | from src.engine.trainer import Trainer 20 | from src.models.build_model import build_model 21 | from src.utils.file_io import PathManager 22 | 23 | from launch import default_argument_parser, logging_train_setup 24 | warnings.filterwarnings("ignore") 25 | 26 | 27 | def setup(args): 28 | """ 29 | Create configs and perform basic setups. 30 | """ 31 | cfg = get_cfg() 32 | cfg.merge_from_file(args.config_file) 33 | cfg.merge_from_list(args.opts) 34 | 35 | # setup dist 36 | cfg.DIST_INIT_PATH = "tcp://{}:12399".format(os.environ["SLURMD_NODENAME"]) 37 | 38 | # setup output dir 39 | # output_dir / data_name / feature_name / lr_wd / run1 40 | output_dir = cfg.OUTPUT_DIR 41 | lr = cfg.SOLVER.BASE_LR 42 | wd = cfg.SOLVER.WEIGHT_DECAY 43 | output_folder = os.path.join( 44 | cfg.DATA.NAME, cfg.DATA.FEATURE, f"lr{lr}_wd{wd}") 45 | 46 | # train cfg.RUN_N_TIMES times 47 | count = 1 48 | while count <= cfg.RUN_N_TIMES: 49 | output_path = os.path.join(output_dir, output_folder, f"run{count}") 50 | # pause for a random time, so concurrent process with same setting won't interfere with each other. 
# noqa 51 | sleep(randint(3, 30)) 52 | if not PathManager.exists(output_path): 53 | PathManager.mkdirs(output_path) 54 | cfg.OUTPUT_DIR = output_path 55 | break 56 | else: 57 | count += 1 58 | if count > cfg.RUN_N_TIMES: 59 | raise ValueError( 60 | f"Already run {cfg.RUN_N_TIMES} times for {output_folder}, no need to run more") 61 | 62 | cfg.freeze() 63 | return cfg 64 | 65 | 66 | def get_loaders(cfg, logger): 67 | logger.info("Loading training data (final training data for vtab)...") 68 | if cfg.DATA.NAME.startswith("vtab-"): 69 | train_loader = data_loader.construct_trainval_loader(cfg) 70 | else: 71 | train_loader = data_loader.construct_train_loader(cfg) 72 | 73 | logger.info("Loading validation data...") 74 | # not really needed for vtab 75 | val_loader = data_loader.construct_val_loader(cfg) 76 | logger.info("Loading test data...") 77 | if cfg.DATA.NO_TEST: 78 | logger.info("...no test data is constructed") 79 | test_loader = None 80 | else: 81 | test_loader = data_loader.construct_test_loader(cfg) 82 | return train_loader, val_loader, test_loader 83 | 84 | 85 | def train(cfg, args): 86 | # clear up residual cache from previous runs 87 | if torch.cuda.is_available(): 88 | torch.cuda.empty_cache() 89 | 90 | # main training / eval actions here 91 | 92 | # fix the seed for reproducibility 93 | if cfg.SEED is not None: 94 | torch.manual_seed(cfg.SEED) 95 | np.random.seed(cfg.SEED) 96 | random.seed(0) 97 | 98 | # setup training env including loggers 99 | logging_train_setup(args, cfg) 100 | logger = logging.get_logger("visual_prompt") 101 | 102 | train_loader, val_loader, test_loader = get_loaders(cfg, logger) 103 | logger.info("Constructing models...") 104 | model, cur_device = build_model(cfg) 105 | 106 | logger.info("Setting up Evalutator...") 107 | evaluator = Evaluator() 108 | logger.info("Setting up Trainer...") 109 | trainer = Trainer(cfg, model, evaluator, cur_device) 110 | 111 | if train_loader: 112 | trainer.train_classifier(train_loader, val_loader, test_loader) 113 | else: 114 | print("No train loader presented. Exit") 115 | 116 | if cfg.SOLVER.TOTAL_EPOCH == 0: 117 | trainer.eval_classifier(test_loader, "test", 0) 118 | 119 | 120 | def main(args): 121 | """main function to call from workflow""" 122 | 123 | # set up cfg and args 124 | cfg = setup(args) 125 | 126 | # Perform training. 
127 | train(cfg, args) 128 | 129 | 130 | if __name__ == '__main__': 131 | args = default_argument_parser().parse_args() 132 | main(args) 133 | -------------------------------------------------------------------------------- /src/models/vit_adapter/vit_mae.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | borrow from https://github.com/facebookresearch/mae/blob/main/models_vit.py 4 | """ 5 | from functools import partial 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from .adapter_block import Pfeiffer_Block 11 | from ..vit_backbones.vit_mae import VisionTransformer 12 | from timm.models.layers import PatchEmbed 13 | from ...utils import logging 14 | logger = logging.get_logger("visual_prompt") 15 | 16 | 17 | class ADPT_VisionTransformer(VisionTransformer): 18 | """ Vision Transformer with support for global average pooling 19 | """ 20 | def __init__( 21 | self, 22 | adapter_cfg, 23 | img_size=224, 24 | patch_size=16, 25 | in_chans=3, 26 | num_classes=1000, 27 | embed_dim=768, 28 | depth=12, 29 | num_heads=12, 30 | mlp_ratio=4., 31 | qkv_bias=True, 32 | representation_size=None, 33 | distilled=False, 34 | drop_rate=0., 35 | attn_drop_rate=0., 36 | drop_path_rate=0., 37 | embed_layer=PatchEmbed, 38 | norm_layer=None, 39 | act_layer=None, 40 | weight_init='', 41 | **kwargs): 42 | 43 | super(ADPT_VisionTransformer, self).__init__( 44 | img_size=img_size, 45 | patch_size=patch_size, 46 | in_chans=in_chans, 47 | num_classes=num_classes, 48 | embed_dim=embed_dim, 49 | depth=depth, 50 | num_heads=num_heads, 51 | mlp_ratio=mlp_ratio, 52 | qkv_bias=qkv_bias, 53 | representation_size=representation_size, 54 | distilled=distilled, 55 | drop_rate=drop_rate, 56 | attn_drop_rate=attn_drop_rate, 57 | drop_path_rate=drop_path_rate, 58 | embed_layer=embed_layer, 59 | norm_layer=norm_layer, 60 | act_layer=act_layer, 61 | weight_init=weight_init, 62 | **kwargs 63 | ) 64 | 65 | self.adapter_cfg = adapter_cfg 66 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 67 | act_layer = act_layer or nn.GELU 68 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 69 | 70 | if adapter_cfg.STYLE == "Pfeiffer": 71 | self.blocks = nn.Sequential(*[ 72 | Pfeiffer_Block( 73 | adapter_config=adapter_cfg, 74 | dim=embed_dim, 75 | num_heads=num_heads, 76 | mlp_ratio=mlp_ratio, 77 | qkv_bias=qkv_bias, 78 | drop=drop_rate, 79 | attn_drop=attn_drop_rate, 80 | drop_path=dpr[i], 81 | norm_layer=norm_layer, 82 | act_layer=act_layer) for i in range(depth)]) 83 | else: 84 | raise ValueError("Other adapter styles are not supported.") 85 | 86 | 87 | 88 | def build_model(model_type, adapter_cfg): 89 | if "vitb" in model_type: 90 | return vit_base_patch16(adapter_cfg) 91 | elif "vitl" in model_type: 92 | return vit_large_patch16(adapter_cfg) 93 | elif "vith" in model_type: 94 | return vit_huge_patch14(adapter_cfg) 95 | 96 | 97 | def vit_base_patch16(adapter_cfg, **kwargs): 98 | model = ADPT_VisionTransformer( 99 | adapter_cfg, 100 | drop_path_rate=0.1, global_pool=True, # using default settings for mae-finetune 101 | patch_size=16, embed_dim=768, depth=12, num_heads=12, 102 | mlp_ratio=4, qkv_bias=True, 103 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 104 | return model 105 | 106 | 107 | def vit_large_patch16(adapter_cfg, **kwargs): 108 | model = ADPT_VisionTransformer( 109 | adapter_cfg, 110 | drop_path_rate=0.1, global_pool=True, # using default settings for mae-finetune 111 | patch_size=16, 
embed_dim=1024, depth=24, num_heads=16, 112 | mlp_ratio=4, qkv_bias=True, 113 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 114 | return model 115 | 116 | 117 | def vit_huge_patch14(adapter_cfg, **kwargs): 118 | model = ADPT_VisionTransformer( 119 | adapter_cfg, 120 | drop_path_rate=0.1, global_pool=True, # using default settings for mae-finetune 121 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, 122 | mlp_ratio=4, qkv_bias=True, 123 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 124 | return model 125 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/clevr.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements CLEVR data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import numpy as np 24 | import tensorflow.compat.v1 as tf 25 | import tensorflow_datasets as tfds 26 | 27 | from . import base as base 28 | from .registry import Registry 29 | 30 | TRAIN_SPLIT_PERCENT = 90 31 | 32 | 33 | def _count_preprocess_fn(x): 34 | return {"image": x["image"], 35 | "label": tf.size(x["objects"]["size"]) - 3} 36 | 37 | 38 | def _count_cylinders_preprocess_fn(x): 39 | # Class distribution: 40 | 41 | num_cylinders = tf.reduce_sum( 42 | tf.cast(tf.equal(x["objects"]["shape"], 2), tf.int32)) 43 | return {"image": x["image"], "label": num_cylinders} 44 | 45 | 46 | def _closest_object_preprocess_fn(x): 47 | dist = tf.reduce_min(x["objects"]["pixel_coords"][:, 2]) 48 | # These thresholds are uniformly spaced and result in more or less balanced 49 | # distribution of classes, see the resulting histogram: 50 | 51 | thrs = np.array([0.0, 8.0, 8.5, 9.0, 9.5, 10.0, 100.0]) 52 | label = tf.reduce_max(tf.where((thrs - dist) < 0)) 53 | return {"image": x["image"], 54 | "label": label} 55 | 56 | 57 | _TASK_DICT = { 58 | "count_all": { 59 | "preprocess_fn": _count_preprocess_fn, 60 | "num_classes": 8 61 | }, 62 | "count_cylinders": { 63 | "preprocess_fn": _count_cylinders_preprocess_fn, 64 | "num_classes": 11 65 | }, 66 | "closest_object_distance": { 67 | "preprocess_fn": _closest_object_preprocess_fn, 68 | "num_classes": 6 69 | }, 70 | } 71 | 72 | 73 | @Registry.register("data.clevr", "class") 74 | class CLEVRData(base.ImageTfdsData): 75 | """Provides CLEVR dataset. 76 | 77 | Currently, two tasks are supported: 78 | 1. Predict number of objects. 79 | 2. Predict distnace to the closest object. 
80 | """ 81 | 82 | def __init__(self, task, data_dir=None): 83 | 84 | if task not in _TASK_DICT: 85 | raise ValueError("Unknown task: %s" % task) 86 | 87 | dataset_builder = tfds.builder("clevr:3.*.*", data_dir=data_dir) 88 | dataset_builder.download_and_prepare() 89 | 90 | # Creates a dict with example counts for each split. 91 | trainval_count = dataset_builder.info.splits[tfds.Split.TRAIN].num_examples 92 | test_count = dataset_builder.info.splits[tfds.Split.TEST].num_examples 93 | num_samples_splits = { 94 | "train": (TRAIN_SPLIT_PERCENT * trainval_count) // 100, 95 | "val": trainval_count - (TRAIN_SPLIT_PERCENT * trainval_count) // 100, 96 | "trainval": trainval_count, 97 | "test": test_count, 98 | "train800": 800, 99 | "val200": 200, 100 | "train800val200": 1000, 101 | } 102 | 103 | # Defines dataset specific train/val/trainval/test splits. 104 | tfds_splits = { 105 | "train": "train[:{}]".format(num_samples_splits["train"]), 106 | "val": "train[{}:]".format(num_samples_splits["train"]), 107 | "trainval": "train", 108 | "test": "validation", 109 | "train800": "train[:800]", 110 | "val200": "train[{}:{}]".format( 111 | num_samples_splits["train"], num_samples_splits["train"]+200), 112 | "train800val200": "train[:800]+train[{}:{}]".format( 113 | num_samples_splits["train"], num_samples_splits["train"]+200), 114 | } 115 | 116 | task = _TASK_DICT[task] 117 | base_preprocess_fn = task["preprocess_fn"] 118 | 119 | super(CLEVRData, self).__init__( 120 | dataset_builder=dataset_builder, 121 | tfds_splits=tfds_splits, 122 | num_samples_splits=num_samples_splits, 123 | num_preprocessing_threads=400, 124 | shuffle_buffer_size=10000, 125 | # Note: Export only image and label tensors with their original types. 126 | base_preprocess_fn=base_preprocess_fn, 127 | num_classes=task["num_classes"]) 128 | -------------------------------------------------------------------------------- /src/engine/eval/multilabel.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | evaluate precision@1, @5 equal to Top1 and Top5 error rate 4 | """ 5 | import numpy as np 6 | from typing import List, Tuple, Dict 7 | from sklearn.metrics import ( 8 | precision_recall_curve, 9 | average_precision_score, 10 | f1_score 11 | ) 12 | 13 | 14 | def get_continuous_ids(probe_labels: List[int]) -> Dict[int, int]: 15 | sorted(probe_labels) 16 | id2continuousid = {} 17 | for idx, p_id in enumerate(probe_labels): 18 | id2continuousid[p_id] = idx 19 | return id2continuousid 20 | 21 | 22 | def multihot(x: List[List[int]], nb_classes: int) -> np.ndarray: 23 | """transform to multihot encoding 24 | 25 | Arguments: 26 | x: list of multi-class integer labels, in the range 27 | [0, nb_classes-1] 28 | nb_classes: number of classes for the multi-hot vector 29 | 30 | Returns: 31 | multihot: multihot vector of type int, (num_samples, nb_classes) 32 | """ 33 | num_samples = len(x) 34 | 35 | multihot = np.zeros((num_samples, nb_classes), dtype=np.int32) 36 | for idx, labs in enumerate(x): 37 | for lab in labs: 38 | multihot[idx, lab] = 1 39 | 40 | return multihot.astype(np.int) 41 | 42 | 43 | def compute_map( 44 | scores: np.ndarray, multihot_targets: np.ndarray 45 | ) -> Tuple[np.ndarray, np.ndarray, float, float]: 46 | """Compute the mean average precision across all class labels. 
47 | 48 | Arguments: 49 | scores: matrix of per-class distances, 50 | of size num_samples x nb_classes 51 | multihot_targets: matrix of multi-hot target predictions, 52 | of size num_samples x nb_classes 53 | 54 | Returns: 55 | ap: list of average-precision scores, one for each of 56 | the nb_classes classes. 57 | ar: list of average-recall scores, one for each of 58 | the nb_classes classes. 59 | mAP: the mean average precision score over all average 60 | precisions for all nb_classes classes. 61 | mAR: the mean average recall score over all average 62 | recalls for all nb_classes classes. 63 | """ 64 | nb_classes = scores.shape[1] 65 | 66 | ap = np.zeros((nb_classes,), dtype=np.float32) 67 | ar = np.zeros((nb_classes,), dtype=np.float32) 68 | 69 | for c in range(nb_classes): 70 | y_true = multihot_targets[:, c] 71 | y_scores = scores[:, c] 72 | 73 | # Use interpolated average precision (a la PASCAL VOC) 74 | try: 75 | ap[c] = average_precision_score(y_true, y_scores) 76 | except ValueError: 77 | ap[c] = -1 78 | 79 | # Also get the average of the recalls on the raw PR-curve 80 | try: 81 | _, rec, _ = precision_recall_curve(y_true, y_scores) 82 | ar[c] = rec.mean() 83 | except ValueError: 84 | ar[c] = -1 85 | 86 | mAP = ap.mean() 87 | mAR = ar.mean() 88 | 89 | return ap, ar, mAP, mAR 90 | 91 | 92 | def compute_f1( 93 | multihot_targets: np.ndarray, scores: np.ndarray, threshold: float = 0.5 94 | ) -> Tuple[float, float, float, np.ndarray]: 95 | # change scores to predict_labels 96 | predict_labels = scores > threshold 97 | predict_labels = predict_labels.astype(int) 98 | 99 | # compute F1 under several averaging schemes 100 | f1 = {} 101 | f1["micro"] = f1_score( 102 | y_true=multihot_targets, 103 | y_pred=predict_labels, 104 | average="micro" 105 | ) 106 | f1["samples"] = f1_score( 107 | y_true=multihot_targets, 108 | y_pred=predict_labels, 109 | average="samples" 110 | ) 111 | f1["macro"] = f1_score( 112 | y_true=multihot_targets, 113 | y_pred=predict_labels, 114 | average="macro" 115 | ) 116 | f1["none"] = f1_score( 117 | y_true=multihot_targets, 118 | y_pred=predict_labels, 119 | average=None 120 | ) 121 | return f1["micro"], f1["samples"], f1["macro"], f1["none"] 122 | 123 | 124 | def get_best_f1_scores( 125 | multihot_targets: np.ndarray, scores: np.ndarray, threshold_end: float 126 | ) -> Dict[str, float]: 127 | end = 0.5 128 | end = 0.05 129 | end = threshold_end 130 | thrs = np.linspace( 131 | end, 0.95, int(np.round((0.95 - end) / 0.05)) + 1, endpoint=True 132 | ) 133 | f1_micros = [] 134 | f1_macros = [] 135 | f1_samples = [] 136 | f1_none = [] 137 | for thr in thrs: 138 | _micros, _samples, _macros, _none = compute_f1(multihot_targets, scores, thr) 139 | f1_micros.append(_micros) 140 | f1_samples.append(_samples) 141 | f1_macros.append(_macros) 142 | f1_none.append(_none) 143 | 144 | f1_macros_m = max(f1_macros) 145 | b_thr = np.argmax(f1_macros) 146 | 147 | f1_micros_m = f1_micros[b_thr] 148 | f1_samples_m = f1_samples[b_thr] 149 | f1_none_m = f1_none[b_thr] 150 | f1 = {} 151 | f1["micro"] = f1_micros_m 152 | f1["macro"] = f1_macros_m 153 | f1["samples"] = f1_samples_m 154 | f1["threshold"] = thrs[b_thr] 155 | f1["none"] = f1_none_m 156 | return f1 157 | -------------------------------------------------------------------------------- /src/models/vit_prompt/vit_moco.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | vit-moco-v3 with prompt 4 | """ 5 | import math 6 | import torch 7 | import torch.nn as nn 8 | import torchvision as tv
9 | 10 | from functools import partial, reduce 11 | from operator import mul 12 | from torch.nn import Conv2d, Dropout 13 | from timm.models.vision_transformer import _cfg 14 | 15 | from ..vit_backbones.vit_moco import VisionTransformerMoCo 16 | from ...utils import logging 17 | logger = logging.get_logger("visual_prompt") 18 | 19 | 20 | class PromptedVisionTransformerMoCo(VisionTransformerMoCo): 21 | def __init__(self, prompt_config, **kwargs): 22 | super().__init__(**kwargs) 23 | self.prompt_config = prompt_config 24 | 25 | if self.prompt_config.DEEP and self.prompt_config.LOCATION not in ["prepend", ]: 26 | raise ValueError("Deep-{} is not supported".format(self.prompt_config.LOCATION)) 27 | 28 | num_tokens = self.prompt_config.NUM_TOKENS 29 | 30 | self.num_tokens = num_tokens 31 | self.prompt_dropout = Dropout(self.prompt_config.DROPOUT) 32 | 33 | # initiate prompt: 34 | if self.prompt_config.INITIATION == "random": 35 | val = math.sqrt(6. / float(3 * reduce(mul, self.patch_embed.patch_size, 1) + self.embed_dim)) # noqa 36 | 37 | self.prompt_embeddings = nn.Parameter(torch.zeros( 38 | 1, num_tokens, self.embed_dim)) 39 | # xavier_uniform initialization 40 | nn.init.uniform_(self.prompt_embeddings.data, -val, val) 41 | if self.prompt_config.DEEP: 42 | self.deep_prompt_embeddings = nn.Parameter(torch.zeros( 43 | len(self.blocks) - 1, 44 | num_tokens, self.embed_dim 45 | )) 46 | # xavier_uniform initialization 47 | nn.init.uniform_( 48 | self.deep_prompt_embeddings.data, -val, val) 49 | 50 | else: 51 | raise ValueError("Other initiation scheme is not supported") 52 | 53 | def incorporate_prompt(self, x): 54 | # combine prompt embeddings with image-patch embeddings 55 | B = x.shape[0] 56 | if self.prompt_config.LOCATION == "prepend": 57 | # after CLS token, all before image patches 58 | x = self.embeddings(x) # (batch_size, 1 + n_patches, hidden_dim) 59 | x = torch.cat(( 60 | x[:, :1, :], 61 | self.prompt_dropout( 62 | self.prompt_embeddings.expand(B, -1, -1)), 63 | x[:, 1:, :] 64 | ), dim=1) 65 | # (batch_size, cls_token + n_prompt + n_patches, hidden_dim) 66 | else: 67 | raise ValueError("Other prompt locations are not supported") 68 | 69 | return x 70 | 71 | def embeddings(self, x): 72 | x = self.patch_embed(x) 73 | cls_token = self.cls_token.expand(x.shape[0], -1, -1) 74 | if self.dist_token is None: 75 | x = torch.cat((cls_token, x), dim=1) 76 | else: 77 | x = torch.cat(( 78 | cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), 79 | dim=1) 80 | x = self.pos_drop(x + self.pos_embed) 81 | return x 82 | 83 | def train(self, mode=True): 84 | # set train status for this class: disable all but the prompt-related modules 85 | if mode: 86 | # training: 87 | self.blocks.eval() 88 | self.patch_embed.eval() 89 | self.pos_drop.eval() 90 | self.prompt_dropout.train() 91 | else: 92 | # eval: 93 | for module in self.children(): 94 | module.train(mode) 95 | 96 | def forward_features(self, x): 97 | x = self.incorporate_prompt(x) 98 | 99 | # deep 100 | if self.prompt_config.DEEP: 101 | B = x.shape[0] 102 | num_layers = len(self.blocks) 103 | 104 | for i in range(num_layers): 105 | if i == 0: 106 | x = self.blocks[i](x) 107 | else: 108 | # prepend 109 | x = torch.cat(( 110 | x[:, :1, :], 111 | self.prompt_dropout( 112 | self.deep_prompt_embeddings[i-1].expand(B, -1, -1) 113 | ), 114 | x[:, (1 + self.num_tokens):, :] 115 | ), dim=1) 116 | x = self.blocks[i](x) 117 | else: 118 | # not deep: 119 | x = self.blocks(x) 120 | 121 | x = self.norm(x) 122 | if self.dist_token is None: 123 | return 
self.pre_logits(x[:, 0]) 124 | else: 125 | return x[:, 0], x[:, 1] 126 | 127 | 128 | def vit_base(prompt_cfg, **kwargs): 129 | model = PromptedVisionTransformerMoCo( 130 | prompt_cfg, 131 | patch_size=16, embed_dim=768, depth=12, 132 | num_heads=12, mlp_ratio=4, qkv_bias=True, 133 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 134 | model.default_cfg = _cfg() 135 | return model 136 | 137 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/dsprites.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements the DSprites data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow.compat.v1 as tf 24 | import tensorflow_datasets as tfds 25 | 26 | from . import base as base 27 | from .registry import Registry 28 | 29 | 30 | # These constants specify the percentage of data that is used to create custom 31 | # train/val splits. Specifically, TRAIN_SPLIT_PERCENT% of the data set is used 32 | # as a new training split and VAL_SPLIT_PERCENT% is used for validation. 33 | # The rest is used for testing. 34 | TRAIN_SPLIT_PERCENT = 80 35 | VAL_SPLIT_PERCENT = 10 36 | 37 | 38 | @Registry.register("data.dsprites", "class") 39 | class DSpritesData(base.ImageTfdsData): 40 | """Provides the DSprites data set. 41 | 42 | DSprites only comes with a training set. Therefore, the training, validation, 43 | and test set are split out of the original training set. 44 | 45 | For additional details and usage, see the base class. 46 | 47 | The data set page is https://github.com/deepmind/dsprites-dataset/. 48 | """ 49 | 50 | def __init__(self, predicted_attribute, num_classes=None, data_dir=None): 51 | dataset_builder = tfds.builder("dsprites:2.*.*", data_dir=data_dir) 52 | dataset_builder.download_and_prepare() 53 | info = dataset_builder.info 54 | 55 | if predicted_attribute not in dataset_builder.info.features: 56 | raise ValueError( 57 | "{} is not a valid attribute to predict.".format(predicted_attribute)) 58 | 59 | # If num_classes is set, we group together nearby integer values to arrive 60 | # at the desired number of classes. This is useful for example for grouping 61 | # together different spatial positions. 62 | num_original_classes = info.features[predicted_attribute].num_classes 63 | if num_classes is None: 64 | num_classes = num_original_classes 65 | if not isinstance(num_classes, int) or num_classes <= 1 or ( 66 | num_classes > num_original_classes): 67 | raise ValueError( 68 | "The number of classes should be None or in [2, ..., num_classes].") 69 | class_division_factor = float(num_original_classes) / num_classes 70 | 71 | # Creates a dict with example counts for each split. 
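# (with TRAIN_SPLIT_PERCENT = 80 and VAL_SPLIT_PERCENT = 10 above, "train" takes 80% and "val" 10% of the original TFDS train split, and the remainder becomes "test")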
72 | num_total = dataset_builder.info.splits["train"].num_examples 73 | num_samples_train = TRAIN_SPLIT_PERCENT * num_total // 100 74 | num_samples_val = VAL_SPLIT_PERCENT * num_total // 100 75 | num_samples_splits = { 76 | "train": num_samples_train, 77 | "val": num_samples_val, 78 | "trainval": num_samples_val + num_samples_train, 79 | "test": num_total - num_samples_val - num_samples_train, 80 | "train800": 800, 81 | "val200": 200, 82 | "train800val200": 1000, 83 | } 84 | 85 | # Defines dataset specific train/val/trainval/test splits. 86 | tfds_splits = { 87 | "train": "train[:{}]".format(num_samples_splits["train"]), 88 | "val": "train[{}:{}]".format(num_samples_splits["train"], 89 | num_samples_splits["trainval"]), 90 | "trainval": "train[:{}]".format(num_samples_splits["trainval"]), 91 | "test": "train[{}:]".format(num_samples_splits["trainval"]), 92 | "train800": "train[:800]", 93 | "val200": "train[{}:{}]".format(num_samples_splits["train"], 94 | num_samples_splits["train"]+200), 95 | "train800val200": "train[:800]+train[{}:{}]".format( 96 | num_samples_splits["train"], num_samples_splits["train"]+200), 97 | } 98 | 99 | def preprocess_fn(tensors): 100 | # For consistency with other datasets, image needs to have three channels 101 | # and be in [0, 255). 102 | images = tf.tile(tensors["image"], [1, 1, 3]) * 255 103 | label = tf.cast( 104 | tf.math.floordiv( 105 | tf.cast(tensors[predicted_attribute], tf.float32), 106 | class_division_factor), info.features[predicted_attribute].dtype) 107 | return dict(image=images, label=label) 108 | 109 | super(DSpritesData, self).__init__( 110 | dataset_builder=dataset_builder, 111 | tfds_splits=tfds_splits, 112 | num_samples_splits=num_samples_splits, 113 | num_preprocessing_threads=400, 114 | shuffle_buffer_size=10000, 115 | # We extract the attribute we want to predict in the preprocessing. 116 | base_preprocess_fn=preprocess_fn, 117 | num_classes=num_classes) 118 | -------------------------------------------------------------------------------- /src/configs/config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """Config system (based on Detectron's).""" 4 | 5 | from .config_node import CfgNode 6 | 7 | 8 | # Global config object 9 | _C = CfgNode() 10 | # Example usage: 11 | # from configs.config import cfg 12 | 13 | _C.DBG = False 14 | _C.OUTPUT_DIR = "./output" 15 | _C.RUN_N_TIMES = 5 16 | # Perform benchmarking to select the fastest CUDNN algorithms to use 17 | # Note that this may increase the memory usage and will likely not result 18 | # in overall speedups when variable size inputs are used (e.g. 
COCO training) 19 | _C.CUDNN_BENCHMARK = False 20 | 21 | # Number of GPUs to use (applies to both training and testing) 22 | _C.NUM_GPUS = 1 23 | _C.NUM_SHARDS = 1 24 | 25 | # Note that non-determinism may still be present due to non-deterministic 26 | # operator implementations in GPU operator libraries 27 | _C.SEED = None 28 | 29 | # ---------------------------------------------------------------------- 30 | # Model options 31 | # ---------------------------------------------------------------------- 32 | _C.MODEL = CfgNode() 33 | _C.MODEL.TRANSFER_TYPE = "linear" # one of linear, end2end, prompt, adapter, side, partial-1, tinytl-bias 34 | _C.MODEL.WEIGHT_PATH = "" # if resume from some checkpoint file 35 | _C.MODEL.SAVE_CKPT = False 36 | 37 | _C.MODEL.MODEL_ROOT = "" # root folder for pretrained model weights 38 | 39 | _C.MODEL.TYPE = "vit" 40 | _C.MODEL.MLP_NUM = 0 41 | 42 | _C.MODEL.LINEAR = CfgNode() 43 | _C.MODEL.LINEAR.MLP_SIZES = [] 44 | _C.MODEL.LINEAR.DROPOUT = 0.1 45 | 46 | # ---------------------------------------------------------------------- 47 | # Prompt options 48 | # ---------------------------------------------------------------------- 49 | _C.MODEL.PROMPT = CfgNode() 50 | _C.MODEL.PROMPT.NUM_TOKENS = 5 51 | _C.MODEL.PROMPT.LOCATION = "prepend" 52 | # prompt initalizatioin: 53 | # (1) default "random" 54 | # (2) "final-cls" use aggregated final [cls] embeddings from training dataset 55 | # (3) "cls-nolastl": use first 12 cls embeddings (exclude the final output) for deep prompt 56 | # (4) "cls-nofirstl": use last 12 cls embeddings (exclude the input to first layer) 57 | _C.MODEL.PROMPT.INITIATION = "random" # "final-cls", "cls-first12" 58 | _C.MODEL.PROMPT.CLSEMB_FOLDER = "" 59 | _C.MODEL.PROMPT.CLSEMB_PATH = "" 60 | _C.MODEL.PROMPT.PROJECT = -1 # "projection mlp hidden dim" 61 | _C.MODEL.PROMPT.DEEP = False # "whether do deep prompt or not, only for prepend location" 62 | 63 | 64 | _C.MODEL.PROMPT.NUM_DEEP_LAYERS = None # if set to be an int, then do partial-deep prompt tuning 65 | _C.MODEL.PROMPT.REVERSE_DEEP = False # if to only update last n layers, not the input layer 66 | _C.MODEL.PROMPT.DEEP_SHARED = False # if true, all deep layers will be use the same prompt emb 67 | _C.MODEL.PROMPT.FORWARD_DEEP_NOEXPAND = False # if true, will not expand input sequence for layers without prompt 68 | # how to get the output emb for cls head: 69 | # original: follow the orignial backbone choice, 70 | # img_pool: image patch pool only 71 | # prompt_pool: prompt embd pool only 72 | # imgprompt_pool: pool everything but the cls token 73 | _C.MODEL.PROMPT.VIT_POOL_TYPE = "original" 74 | _C.MODEL.PROMPT.DROPOUT = 0.0 75 | _C.MODEL.PROMPT.SAVE_FOR_EACH_EPOCH = False 76 | # ---------------------------------------------------------------------- 77 | # adapter options 78 | # ---------------------------------------------------------------------- 79 | _C.MODEL.ADAPTER = CfgNode() 80 | _C.MODEL.ADAPTER.REDUCATION_FACTOR = 8 81 | _C.MODEL.ADAPTER.STYLE = "Pfeiffer" 82 | 83 | # ---------------------------------------------------------------------- 84 | # Solver options 85 | # ---------------------------------------------------------------------- 86 | _C.SOLVER = CfgNode() 87 | _C.SOLVER.LOSS = "softmax" 88 | _C.SOLVER.LOSS_ALPHA = 0.01 89 | 90 | _C.SOLVER.OPTIMIZER = "sgd" # or "adamw" 91 | _C.SOLVER.MOMENTUM = 0.9 92 | _C.SOLVER.WEIGHT_DECAY = 0.0001 93 | _C.SOLVER.WEIGHT_DECAY_BIAS = 0 94 | 95 | _C.SOLVER.PATIENCE = 300 96 | 97 | 98 | _C.SOLVER.SCHEDULER = "cosine" 99 | 100 | 
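# NOTE: tune_fgvc.py rescales each swept learning rate by DATA.BATCH_SIZE / 256 before writing it into SOLVER.BASE_LR, so the value below is only the default for runs launched directly from a config.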
_C.SOLVER.BASE_LR = 0.01 101 | _C.SOLVER.BIAS_MULTIPLIER = 1. # for prompt + bias 102 | 103 | _C.SOLVER.WARMUP_EPOCH = 5 104 | _C.SOLVER.TOTAL_EPOCH = 30 105 | _C.SOLVER.LOG_EVERY_N = 1000 106 | 107 | 108 | _C.SOLVER.DBG_TRAINABLE = False # if True, will print the name of trainable params 109 | 110 | # ---------------------------------------------------------------------- 111 | # Dataset options 112 | # ---------------------------------------------------------------------- 113 | _C.DATA = CfgNode() 114 | 115 | _C.DATA.NAME = "" 116 | _C.DATA.DATAPATH = "" 117 | _C.DATA.FEATURE = "" # e.g. inat2021_supervised 118 | 119 | _C.DATA.PERCENTAGE = 1.0 120 | _C.DATA.NUMBER_CLASSES = -1 121 | _C.DATA.MULTILABEL = False 122 | _C.DATA.CLASS_WEIGHTS_TYPE = "none" 123 | 124 | _C.DATA.CROPSIZE = 224 # or 384 125 | 126 | _C.DATA.NO_TEST = False 127 | _C.DATA.BATCH_SIZE = 32 128 | # Number of data loader workers per training process 129 | _C.DATA.NUM_WORKERS = 4 130 | # Load data to pinned host memory 131 | _C.DATA.PIN_MEMORY = True 132 | 133 | _C.DIST_BACKEND = "nccl" 134 | _C.DIST_INIT_PATH = "env://" 135 | _C.DIST_INIT_FILE = "" 136 | 137 | 138 | def get_cfg(): 139 | """ 140 | Get a copy of the default config. 141 | """ 142 | return _C.clone() 143 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/registry.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Global Registry for the task adaptation framework. 18 | """ 19 | 20 | from __future__ import absolute_import 21 | from __future__ import division 22 | from __future__ import print_function 23 | 24 | import ast 25 | import functools 26 | 27 | 28 | def partialclass(cls, *base_args, **base_kwargs): 29 | """Builds a subclass with partial application of the given args and keywords. 30 | 31 | Equivalent to functools.partial performance, base_args are preprended to the 32 | positional arguments given during object initialization and base_kwargs are 33 | updated with the kwargs given later. 34 | 35 | Args: 36 | cls: The base class. 37 | *base_args: Positional arguments to be applied to the subclass. 38 | **base_kwargs: Keyword arguments to be applied to the subclass. 39 | 40 | Returns: 41 | A subclass of the input class. 42 | """ 43 | 44 | class _NewClass(cls): 45 | 46 | def __init__(self, *args, **kwargs): 47 | bound_args = base_args + args 48 | bound_kwargs = base_kwargs.copy() 49 | bound_kwargs.update(kwargs) 50 | super(_NewClass, self).__init__(*bound_args, **bound_kwargs) 51 | 52 | return _NewClass 53 | 54 | 55 | def parse_name(string_to_parse): 56 | """Parses input to the registry's lookup function. 57 | 58 | Args: 59 | string_to_parse: can be either an arbitrary name or function call 60 | (optionally with positional and keyword arguments). 61 | e.g. 
"multiclass", "resnet50_v2(filters_factor=8)". 62 | 63 | Returns: 64 | A tuple of input name and a dctinary with arguments. Examples: 65 | "multiclass" -> ("multiclass", (), {}) 66 | "resnet50_v2(9, filters_factor=4)" -> 67 | ("resnet50_v2", (9,), {"filters_factor": 4}) 68 | """ 69 | expr = ast.parse(string_to_parse, mode="eval").body # pytype: disable=attribute-error 70 | if not isinstance(expr, (ast.Attribute, ast.Call, ast.Name)): 71 | raise ValueError( 72 | "The given string should be a name or a call, but a {} was parsed from " 73 | "the string {!r}".format(type(expr), string_to_parse)) 74 | 75 | # Notes: 76 | # name="some_name" -> type(expr) = ast.Name 77 | # name="module.some_name" -> type(expr) = ast.Attribute 78 | # name="some_name()" -> type(expr) = ast.Call 79 | # name="module.some_name()" -> type(expr) = ast.Call 80 | 81 | if isinstance(expr, ast.Name): 82 | return string_to_parse, {} 83 | elif isinstance(expr, ast.Attribute): 84 | return string_to_parse, {} 85 | 86 | def _get_func_name(expr): 87 | if isinstance(expr, ast.Attribute): 88 | return _get_func_name(expr.value) + "." + expr.attr 89 | elif isinstance(expr, ast.Name): 90 | return expr.id 91 | else: 92 | raise ValueError( 93 | "Type {!r} is not supported in a function name, the string to parse " 94 | "was {!r}".format(type(expr), string_to_parse)) 95 | 96 | def _get_func_args_and_kwargs(call): 97 | args = tuple([ast.literal_eval(arg) for arg in call.args]) 98 | kwargs = { 99 | kwarg.arg: ast.literal_eval(kwarg.value) for kwarg in call.keywords 100 | } 101 | return args, kwargs 102 | 103 | func_name = _get_func_name(expr.func) 104 | func_args, func_kwargs = _get_func_args_and_kwargs(expr) 105 | if func_args: 106 | raise ValueError("Positional arguments are not supported here, but these " 107 | "were found: {!r}".format(func_args)) 108 | 109 | return func_name, func_kwargs 110 | 111 | 112 | class Registry(object): 113 | """Implements global Registry.""" 114 | 115 | _GLOBAL_REGISTRY = {} 116 | 117 | @staticmethod 118 | def global_registry(): 119 | return Registry._GLOBAL_REGISTRY 120 | 121 | @staticmethod 122 | def register(name, item_type): 123 | """Creates a function that registers its input.""" 124 | if item_type not in ["function", "class"]: 125 | raise ValueError("Unknown item type: %s" % item_type) 126 | 127 | def _register(item): 128 | if name in Registry.global_registry(): 129 | raise KeyError( 130 | "The name {!r} was already registered in with type {!r}".format( 131 | name, item_type)) 132 | 133 | Registry.global_registry()[name] = (item, item_type) 134 | return item 135 | 136 | return _register 137 | 138 | @staticmethod 139 | def lookup(lookup_string, kwargs_extra=None): 140 | """Lookup a name in the registry.""" 141 | 142 | name, kwargs = parse_name(lookup_string) 143 | if kwargs_extra: 144 | kwargs.update(kwargs_extra) 145 | item, item_type = Registry.global_registry()[name] 146 | if item_type == "function": 147 | return functools.partial(item, **kwargs) 148 | elif item_type == "class": 149 | return partialclass(item, **kwargs) 150 | -------------------------------------------------------------------------------- /src/utils/distributed.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """Distributed helpers.""" 4 | 5 | import torch 6 | import torch.distributed as dist 7 | _LOCAL_PROCESS_GROUP = None 8 | 9 | 10 | def get_world_size() -> int: 11 | if not dist.is_available(): 12 | return 1 13 | if not dist.is_initialized(): 14 | return 1 15 | 
return dist.get_world_size() 16 | 17 | 18 | def get_rank() -> int: 19 | if not dist.is_available(): 20 | return 0 21 | if not dist.is_initialized(): 22 | return 0 23 | return dist.get_rank() 24 | 25 | 26 | def is_master_process(num_gpus=8): 27 | """ 28 | Determines if the current process is the master process. 29 | """ 30 | if torch.distributed.is_initialized(): 31 | return dist.get_rank() % num_gpus == 0 32 | else: 33 | return True 34 | 35 | 36 | def run( 37 | local_rank, 38 | num_proc, 39 | func, 40 | init_method, 41 | shard_id, 42 | num_shards, 43 | backend, 44 | cfg, 45 | args, 46 | ): 47 | """ 48 | Runs a function from a child process. 49 | Args: 50 | local_rank (int): rank of the current process on the current machine. 51 | num_proc (int): number of processes per machine. 52 | func (function): function to execute on each of the process. 53 | init_method (string): method to initialize the distributed training. 54 | TCP initialization: equiring a network address reachable from all 55 | processes followed by the port. 56 | Shared file-system initialization: makes use of a file system that 57 | is shared and visible from all machines. The URL should start with 58 | file:// and contain a path to a non-existent file on a shared file 59 | system. 60 | shard_id (int): the rank of the current machine. 61 | num_shards (int): number of overall machines for the distributed 62 | training job. 63 | backend (string): three distributed backends ('nccl', 'gloo', 'mpi') are 64 | supports, each with different capabilities. Details can be found 65 | here: 66 | https://pytorch.org/docs/stable/distributed.html 67 | cfg (CfgNode): configs. Details can be found in 68 | loco/config/defaults.py 69 | """ 70 | # Initialize the process group. 71 | # shard_id = get_rank() 72 | world_size = num_proc * num_shards 73 | rank = shard_id * num_proc + local_rank 74 | 75 | try: 76 | torch.distributed.init_process_group( 77 | backend=backend, 78 | init_method=init_method, 79 | world_size=world_size, 80 | rank=rank, 81 | ) 82 | except Exception as e: 83 | raise e 84 | 85 | torch.cuda.set_device(local_rank) 86 | func(cfg, args) 87 | 88 | 89 | def destroy_process_group(): 90 | """Destroys the default process group.""" 91 | torch.distributed.destroy_process_group() 92 | 93 | 94 | def scaled_all_reduce(cfg, tensors): 95 | """Performs the scaled all_reduce operation on the provided tensors. 96 | 97 | The input tensors are modified in-place. Currently supports only the sum 98 | reduction operator. The reduced values are scaled by the inverse size of 99 | the process group (equivalent to cfg.NUM_GPUS). 100 | """ 101 | # Queue the reductions 102 | reductions = [] 103 | for tensor in tensors: 104 | reduction = torch.distributed.all_reduce(tensor, async_op=True) 105 | reductions.append(reduction) 106 | # Wait for reductions to finish 107 | for reduction in reductions: 108 | reduction.wait() 109 | # Scale the results 110 | for tensor in tensors: 111 | tensor.mul_(1.0 / cfg.NUM_GPUS / cfg.NUM_SHARDS) 112 | return tensors 113 | 114 | 115 | def cat_all_gather(tensors): 116 | """Performs the concatenated all_gather operation on the provided tensors. 
117 | """ 118 | tensors_gather = [ 119 | torch.ones_like(tensors) 120 | for _ in range(torch.distributed.get_world_size()) 121 | ] 122 | torch.distributed.all_gather(tensors_gather, tensors, async_op=False) 123 | 124 | output = torch.cat(tensors_gather, dim=0) 125 | return output 126 | 127 | 128 | def local_cat_all_gather(tensors): 129 | """Performs the concatenated all_gather operation within the local (per-machine) process group. 130 | """ 131 | tensors_gather = [ 132 | torch.ones_like(tensors) 133 | for _ in range(get_local_size()) 134 | ] 135 | torch.distributed.all_gather( 136 | tensors_gather, 137 | tensors, 138 | async_op=False, 139 | group=_LOCAL_PROCESS_GROUP, 140 | ) 141 | output = torch.cat(tensors_gather, dim=0) 142 | return output 143 | 144 | 145 | def get_local_size(): 146 | """ 147 | Returns: 148 | The size of the per-machine process group, 149 | i.e. the number of processes per machine. 150 | """ 151 | if not dist.is_available(): 152 | return 1 153 | if not dist.is_initialized(): 154 | return 1 155 | return dist.get_world_size(group=_LOCAL_PROCESS_GROUP) 156 | 157 | 158 | def get_local_rank(): 159 | """ 160 | Returns: 161 | The rank of the current process within the local (per-machine) process group. 162 | """ 163 | if not dist.is_available(): 164 | return 0 165 | if not dist.is_initialized(): 166 | return 0 167 | assert _LOCAL_PROCESS_GROUP is not None 168 | return dist.get_rank(group=_LOCAL_PROCESS_GROUP) 169 | -------------------------------------------------------------------------------- /VTAB_SETUP.md: -------------------------------------------------------------------------------- 1 | # VTAB Preparation 2 | 3 | ## Download and prepare 4 | 5 | It is recommended to download the data before running experiments, to avoid duplicated effort when submitting experiments for multiple tuning protocols. Here are the collected commands to set up the VTAB data.
6 | 7 | ```python 8 | import tensorflow_datasets as tfds 9 | data_dir = "" # TODO: setup the data_dir to put the the data to, the DATA.DATAPATH value in config 10 | 11 | # caltech101 12 | dataset_builder = tfds.builder("caltech101:3.*.*", data_dir=data_dir) 13 | dataset_builder.download_and_prepare() 14 | 15 | # cifar100 16 | dataset_builder = tfds.builder("cifar100:3.*.*", data_dir=data_dir) 17 | dataset_builder.download_and_prepare() 18 | 19 | # clevr 20 | dataset_builder = tfds.builder("clevr:3.*.*", data_dir=data_dir) 21 | dataset_builder.download_and_prepare() 22 | 23 | # dmlab 24 | dataset_builder = tfds.builder("dmlab:2.0.1", data_dir=data_dir) 25 | dataset_builder.download_and_prepare() 26 | 27 | # dsprites 28 | dataset_builder = tfds.builder("dsprites:2.*.*", data_dir=data_dir) 29 | dataset_builder.download_and_prepare() 30 | 31 | # dtd 32 | dataset_builder = tfds.builder("dtd:3.*.*", data_dir=data_dir) 33 | dataset_builder.download_and_prepare() 34 | 35 | # eurosat 36 | subset="rgb" 37 | dataset_name = "eurosat/{}:2.*.*".format(subset) 38 | dataset_builder = tfds.builder(dataset_name, data_dir=data_dir) 39 | dataset_builder.download_and_prepare() 40 | 41 | # oxford_flowers102 42 | dataset_builder = tfds.builder("oxford_flowers102:2.*.*", data_dir=data_dir) 43 | dataset_builder.download_and_prepare() 44 | 45 | # oxford_iiit_pet 46 | dataset_builder = tfds.builder("oxford_iiit_pet:3.*.*", data_dir=data_dir) 47 | dataset_builder.download_and_prepare() 48 | 49 | # patch_camelyon 50 | dataset_builder = tfds.builder("patch_camelyon:2.*.*", data_dir=data_dir) 51 | dataset_builder.download_and_prepare() 52 | 53 | # smallnorb 54 | dataset_builder = tfds.builder("smallnorb:2.*.*", data_dir=data_dir) 55 | dataset_builder.download_and_prepare() 56 | 57 | # svhn 58 | dataset_builder = tfds.builder("svhn_cropped:3.*.*", data_dir=data_dir) 59 | dataset_builder.download_and_prepare() 60 | ``` 61 | 62 | There are 4 datasets need special care: 63 | 64 | ```python 65 | # sun397 --> need cv2 66 | # cannot load one image, similar to issue here: https://github.com/tensorflow/datasets/issues/2889 67 | # "Image /t/track/outdoor/sun_aophkoiosslinihb.jpg could not be decoded by Tensorflow."" 68 | # sol: modify the file: "/fsx/menglin/conda/envs/prompt_tf/lib/python3.7/site-packages/tensorflow_datasets/image_classification/sun.py" to ignore those images 69 | dataset_builder = tfds.builder("sun397/tfds:4.*.*", data_dir=data_dir) 70 | dataset_builder.download_and_prepare() 71 | 72 | # kitti version is wrong from vtab repo, try 3.2.0 (https://github.com/google-research/task_adaptation/issues/18) 73 | dataset_builder = tfds.builder("kitti:3.2.0", data_dir=data_dir) 74 | dataset_builder.download_and_prepare() 75 | 76 | 77 | # diabetic_retinopathy 78 | """ 79 | Download this dataset from Kaggle. 80 | https://www.kaggle.com/c/diabetic-retinopathy-detection/data 81 | After downloading, 82 | - unpack the test.zip file into /manual_dir/. 83 | - unpack the sample.zip to sample/. 84 | - unpack the sampleSubmissions.csv and trainLabels.csv. 85 | 86 | # ==== important! ==== 87 | # 1. make sure to check that there are 5 train.zip files instead of 4 (somehow if you chose to download all from kaggle, the train.zip.005 file is missing) 88 | # 2. 
if unzip train.zip ran into issues, try to use jar xvf train.zip to handle huge zip file 89 | cat test.zip.* > test.zip 90 | cat train.zip.* > train.zip 91 | """ 92 | 93 | config_and_version = "btgraham-300" + ":3.*.*" 94 | dataset_builder = tfds.builder("diabetic_retinopathy_detection/{}".format(config_and_version), data_dir=data_dir) 95 | dataset_builder.download_and_prepare() 96 | 97 | 98 | # resisc45 99 | """ 100 | download/extract dataset artifacts manually: 101 | Dataset can be downloaded from OneDrive: https://1drv.ms/u/s!AmgKYzARBl5ca3HNaHIlzp_IXjs 102 | After downloading the rar file, please extract it to the manual_dir. 103 | """ 104 | 105 | dataset_builder = tfds.builder("resisc45:3.*.*", data_dir=data_dir) 106 | dataset_builder.download_and_prepare() 107 | ``` 108 | 109 | 110 | 111 | ## Notes 112 | 113 | ### TFDS version 114 | Note that the experimental results may be different with different API and/or dataset generation code versions. See more from [tfds documentation](https://www.tensorflow.org/datasets/datasets_versioning). Here are what we used for VPT: 115 | 116 | ```bash 117 | tfds: 4.4.0+nightly 118 | 119 | # Natural: 120 | cifar100: 3.0.2 121 | caltech101: 3.0.1 122 | dtd: 3.0.1 123 | oxford_flowers102: 2.1.1 124 | oxford_iiit_pet: 3.2.0 125 | svhn_cropped: 3.0.0 126 | sun397: 4.0.0 127 | 128 | # Specialized: 129 | patch_camelyon: 2.0.0 130 | eurosat: 2.0.0 131 | resisc45: 3.0.0 132 | diabetic_retinopathy_detection: 3.0.0 133 | 134 | 135 | # Structured 136 | clevr: 3.1.0 137 | dmlab: 2.0.1 138 | kitti: 3.2.0 139 | dsprites: 2.0.0 140 | smallnorb: 2.0.0 141 | ``` 142 | 143 | ### Train split 144 | As in issue https://github.com/KMnP/vpt/issues/1, we also uploaded the vtab train split info to the vtab data release [Google Drive](https://drive.google.com/drive/folders/1mnvxTkYxmOr2W9QjcgS64UBpoJ4UmKaM)/[Dropbox](https://cornell.app.box.com/v/vptfgvcsplits). In the file `vtab_trainval_splits.json`, for each dataset, you can find the filenames of the randomly selected 1k training examples used in our experiment. We got them by extracting the ‘filename’ attribute from the tensorflow dataset feature dict. Unfortunately, because there’s no such info for [dsprite](https://www.tensorflow.org/datasets/catalog/dsprites), [smallnorb](https://www.tensorflow.org/datasets/catalog/smallnorb) and [svhn](https://www.tensorflow.org/datasets/catalog/svhn_cropped) in the tensorflow dataset format, we cannot provide the splits for these 3 datasets. 
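As a rough sketch (not part of the original release), the split file could be used to rebuild the same training subset by filtering a prepared TFDS split on its filename feature. The JSON layout and the exact feature key (`image/file_name` below) are assumptions; please check them against the downloaded `vtab_trainval_splits.json` and the feature dict of each dataset.

```python
import json
import tensorflow_datasets as tfds

data_dir = ""  # same location as DATA.DATAPATH in the configs

# hypothetical layout: {dataset_name: [filename, filename, ...], ...}
with open("vtab_trainval_splits.json") as f:
    kept = set(json.load(f)["caltech101"])

builder = tfds.builder("caltech101:3.*.*", data_dir=data_dir)
ds = builder.as_dataset(split="train")  # assumes download_and_prepare() already ran

# keep only the examples whose filename appears in the released split;
# the feature key may differ across datasets
subset = [ex for ex in tfds.as_numpy(ds)
          if ex["image/file_name"].decode("utf-8") in kept]
```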
145 | -------------------------------------------------------------------------------- /src/data/datasets/json_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """JSON dataset: support CUB, NABrids, Flower, Dogs and Cars""" 4 | 5 | import os 6 | import torch 7 | import torch.utils.data 8 | import torchvision as tv 9 | import numpy as np 10 | from collections import Counter 11 | 12 | from ..transforms import get_transforms 13 | from ...utils import logging 14 | from ...utils.io_utils import read_json 15 | logger = logging.get_logger("visual_prompt") 16 | 17 | 18 | class JSONDataset(torch.utils.data.Dataset): 19 | def __init__(self, cfg, split): 20 | assert split in { 21 | "train", 22 | "val", 23 | "test", 24 | }, "Split '{}' not supported for {} dataset".format( 25 | split, cfg.DATA.NAME) 26 | logger.info("Constructing {} dataset {}...".format( 27 | cfg.DATA.NAME, split)) 28 | 29 | self.cfg = cfg 30 | self._split = split 31 | self.name = cfg.DATA.NAME 32 | self.data_dir = cfg.DATA.DATAPATH 33 | self.data_percentage = cfg.DATA.PERCENTAGE 34 | self._construct_imdb(cfg) 35 | self.transform = get_transforms(split, cfg.DATA.CROPSIZE) 36 | 37 | def get_anno(self): 38 | anno_path = os.path.join(self.data_dir, "{}.json".format(self._split)) 39 | if "train" in self._split: 40 | if self.data_percentage < 1.0: 41 | anno_path = os.path.join( 42 | self.data_dir, 43 | "{}_{}.json".format(self._split, self.data_percentage) 44 | ) 45 | assert os.path.exists(anno_path), "{} dir not found".format(anno_path) 46 | 47 | return read_json(anno_path) 48 | 49 | def get_imagedir(self): 50 | raise NotImplementedError() 51 | 52 | def _construct_imdb(self, cfg): 53 | """Constructs the imdb.""" 54 | 55 | img_dir = self.get_imagedir() 56 | assert os.path.exists(img_dir), "{} dir not found".format(img_dir) 57 | 58 | anno = self.get_anno() 59 | # Map class ids to contiguous ids 60 | self._class_ids = sorted(list(set(anno.values()))) 61 | self._class_id_cont_id = {v: i for i, v in enumerate(self._class_ids)} 62 | 63 | # Construct the image db 64 | self._imdb = [] 65 | for img_name, cls_id in anno.items(): 66 | cont_id = self._class_id_cont_id[cls_id] 67 | im_path = os.path.join(img_dir, img_name) 68 | self._imdb.append({"im_path": im_path, "class": cont_id}) 69 | 70 | logger.info("Number of images: {}".format(len(self._imdb))) 71 | logger.info("Number of classes: {}".format(len(self._class_ids))) 72 | 73 | def get_info(self): 74 | num_imgs = len(self._imdb) 75 | return num_imgs, self.get_class_num() 76 | 77 | def get_class_num(self): 78 | return self.cfg.DATA.NUMBER_CLASSES 79 | # return len(self._class_ids) 80 | 81 | def get_class_weights(self, weight_type): 82 | """get a list of class weight, return a list float""" 83 | if "train" not in self._split: 84 | raise ValueError( 85 | "only getting training class distribution, " + \ 86 | "got split {} instead".format(self._split) 87 | ) 88 | 89 | cls_num = self.get_class_num() 90 | if weight_type == "none": 91 | return [1.0] * cls_num 92 | 93 | id2counts = Counter(self._class_ids) 94 | assert len(id2counts) == cls_num 95 | num_per_cls = np.array([id2counts[i] for i in self._class_ids]) 96 | 97 | if weight_type == 'inv': 98 | mu = -1.0 99 | elif weight_type == 'inv_sqrt': 100 | mu = -0.5 101 | weight_list = num_per_cls ** mu 102 | weight_list = np.divide( 103 | weight_list, np.linalg.norm(weight_list, 1)) * cls_num 104 | return weight_list.tolist() 105 | 106 | def __getitem__(self, index): 107 | # Load 
the image 108 | im = tv.datasets.folder.default_loader(self._imdb[index]["im_path"]) 109 | label = self._imdb[index]["class"] 110 | im = self.transform(im) 111 | if self._split == "train": 112 | index = index 113 | else: 114 | index = f"{self._split}{index}" 115 | sample = { 116 | "image": im, 117 | "label": label, 118 | # "id": index 119 | } 120 | return sample 121 | 122 | def __len__(self): 123 | return len(self._imdb) 124 | 125 | 126 | class CUB200Dataset(JSONDataset): 127 | """CUB_200 dataset.""" 128 | 129 | def __init__(self, cfg, split): 130 | super(CUB200Dataset, self).__init__(cfg, split) 131 | 132 | def get_imagedir(self): 133 | return os.path.join(self.data_dir, "images") 134 | 135 | 136 | class CarsDataset(JSONDataset): 137 | """stanford-cars dataset.""" 138 | 139 | def __init__(self, cfg, split): 140 | super(CarsDataset, self).__init__(cfg, split) 141 | 142 | def get_imagedir(self): 143 | return self.data_dir 144 | 145 | 146 | class DogsDataset(JSONDataset): 147 | """stanford-dogs dataset.""" 148 | 149 | def __init__(self, cfg, split): 150 | super(DogsDataset, self).__init__(cfg, split) 151 | 152 | def get_imagedir(self): 153 | return os.path.join(self.data_dir, "Images") 154 | 155 | 156 | class FlowersDataset(JSONDataset): 157 | """flowers dataset.""" 158 | 159 | def __init__(self, cfg, split): 160 | super(FlowersDataset, self).__init__(cfg, split) 161 | 162 | def get_imagedir(self): 163 | return self.data_dir 164 | 165 | 166 | class NabirdsDataset(JSONDataset): 167 | """Nabirds dataset.""" 168 | 169 | def __init__(self, cfg, split): 170 | super(NabirdsDataset, self).__init__(cfg, split) 171 | 172 | def get_imagedir(self): 173 | return os.path.join(self.data_dir, "images") 174 | 175 | -------------------------------------------------------------------------------- /src/models/vit_backbones/vit_moco.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | """ 4 | borrowed from https://github.com/facebookresearch/moco-v3/blob/main/vits.py 5 | """ 6 | import math 7 | import torch 8 | import torch.nn as nn 9 | from functools import partial, reduce 10 | from operator import mul 11 | 12 | from timm.models.vision_transformer import VisionTransformer, _cfg 13 | from timm.models.layers.helpers import to_2tuple 14 | from timm.models.layers import PatchEmbed 15 | 16 | __all__ = [ 17 | 'vit_small', 18 | 'vit_base', 19 | 'vit_conv_small', 20 | 'vit_conv_base', 21 | ] 22 | 23 | 24 | class VisionTransformerMoCo(VisionTransformer): 25 | def __init__(self, stop_grad_conv1=False, **kwargs): 26 | super().__init__(**kwargs) 27 | # Use fixed 2D sin-cos position embedding 28 | self.build_2d_sincos_position_embedding() 29 | 30 | # weight initialization 31 | for name, m in self.named_modules(): 32 | if isinstance(m, nn.Linear): 33 | if 'qkv' in name: 34 | # treat the weights of Q, K, V separately 35 | val = math.sqrt(6. / float(m.weight.shape[0] // 3 + m.weight.shape[1])) 36 | nn.init.uniform_(m.weight, -val, val) 37 | else: 38 | nn.init.xavier_uniform_(m.weight) 39 | nn.init.zeros_(m.bias) 40 | nn.init.normal_(self.cls_token, std=1e-6) 41 | 42 | if isinstance(self.patch_embed, PatchEmbed): 43 | # xavier_uniform initialization 44 | val = math.sqrt(6. 
/ float(3 * reduce(mul, self.patch_embed.patch_size, 1) + self.embed_dim)) 45 | nn.init.uniform_(self.patch_embed.proj.weight, -val, val) 46 | nn.init.zeros_(self.patch_embed.proj.bias) 47 | 48 | if stop_grad_conv1: 49 | self.patch_embed.proj.weight.requires_grad = False 50 | self.patch_embed.proj.bias.requires_grad = False 51 | 52 | def build_2d_sincos_position_embedding(self, temperature=10000.): 53 | h, w = self.patch_embed.grid_size 54 | grid_w = torch.arange(w, dtype=torch.float32) 55 | grid_h = torch.arange(h, dtype=torch.float32) 56 | grid_w, grid_h = torch.meshgrid(grid_w, grid_h) 57 | assert self.embed_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding' 58 | pos_dim = self.embed_dim // 4 59 | omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim 60 | omega = 1. / (temperature**omega) 61 | out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega]) 62 | out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega]) 63 | pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :] 64 | 65 | assert self.num_tokens == 1, 'Assuming one and only one token, [cls]' 66 | pe_token = torch.zeros([1, 1, self.embed_dim], dtype=torch.float32) 67 | self.pos_embed = nn.Parameter(torch.cat([pe_token, pos_emb], dim=1)) 68 | self.pos_embed.requires_grad = False 69 | 70 | 71 | class ConvStem(nn.Module): 72 | """ 73 | ConvStem, from Early Convolutions Help Transformers See Better, Tete et al. https://arxiv.org/abs/2106.14881 74 | """ 75 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): 76 | super().__init__() 77 | 78 | assert patch_size == 16, 'ConvStem only supports patch size of 16' 79 | assert embed_dim % 8 == 0, 'Embed dimension must be divisible by 8 for ConvStem' 80 | 81 | img_size = to_2tuple(img_size) 82 | patch_size = to_2tuple(patch_size) 83 | self.img_size = img_size 84 | self.patch_size = patch_size 85 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 86 | self.num_patches = self.grid_size[0] * self.grid_size[1] 87 | self.flatten = flatten 88 | 89 | # build stem, similar to the design in https://arxiv.org/abs/2106.14881 90 | stem = [] 91 | input_dim, output_dim = 3, embed_dim // 8 92 | for l in range(4): 93 | stem.append(nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=2, padding=1, bias=False)) 94 | stem.append(nn.BatchNorm2d(output_dim)) 95 | stem.append(nn.ReLU(inplace=True)) 96 | input_dim = output_dim 97 | output_dim *= 2 98 | stem.append(nn.Conv2d(input_dim, embed_dim, kernel_size=1)) 99 | self.proj = nn.Sequential(*stem) 100 | 101 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 102 | 103 | def forward(self, x): 104 | B, C, H, W = x.shape 105 | assert H == self.img_size[0] and W == self.img_size[1], \ 106 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
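# the stem built in __init__ is four stride-2 3x3 convs plus a 1x1 projection, so a 224x224 input is downsampled by 16 to the same 14x14 token grid as a standard 16x16 patch embedding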
107 | x = self.proj(x) 108 | if self.flatten: 109 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 110 | x = self.norm(x) 111 | return x 112 | 113 | 114 | def vit_small(**kwargs): 115 | model = VisionTransformerMoCo( 116 | patch_size=16, embed_dim=384, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 117 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 118 | model.default_cfg = _cfg() 119 | return model 120 | 121 | def vit_base(**kwargs): 122 | model = VisionTransformerMoCo( 123 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 124 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 125 | model.default_cfg = _cfg() 126 | return model 127 | 128 | def vit_conv_small(**kwargs): 129 | # minus one ViT block 130 | model = VisionTransformerMoCo( 131 | patch_size=16, embed_dim=384, depth=11, num_heads=12, mlp_ratio=4, qkv_bias=True, 132 | norm_layer=partial(nn.LayerNorm, eps=1e-6), embed_layer=ConvStem, **kwargs) 133 | model.default_cfg = _cfg() 134 | return model 135 | 136 | def vit_conv_base(**kwargs): 137 | # minus one ViT block 138 | model = VisionTransformerMoCo( 139 | patch_size=16, embed_dim=768, depth=11, num_heads=12, mlp_ratio=4, qkv_bias=True, 140 | norm_layer=partial(nn.LayerNorm, eps=1e-6), embed_layer=ConvStem, **kwargs) 141 | model.default_cfg = _cfg() 142 | return model 143 | -------------------------------------------------------------------------------- /src/models/vit_prompt/vit.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | vit with prompt: a clean version with the default settings of VPT 4 | """ 5 | import math 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torchvision as tv 10 | 11 | from functools import reduce 12 | from operator import mul 13 | from torch.nn.modules.utils import _pair 14 | from torch.nn import Conv2d, Dropout 15 | from scipy import ndimage 16 | 17 | from ..vit_backbones.vit import CONFIGS, Transformer, VisionTransformer, np2th 18 | from ...utils import logging 19 | 20 | logger = logging.get_logger("visual_prompt") 21 | 22 | 23 | class PromptedTransformer(Transformer): 24 | def __init__(self, prompt_config, config, img_size, vis): 25 | assert prompt_config.LOCATION == "prepend" 26 | assert prompt_config.INITIATION == "random" 27 | assert prompt_config.NUM_DEEP_LAYERS is None 28 | assert not prompt_config.DEEP_SHARED 29 | super(PromptedTransformer, self).__init__( 30 | config, img_size, vis) 31 | 32 | self.prompt_config = prompt_config 33 | self.vit_config = config 34 | 35 | img_size = _pair(img_size) 36 | patch_size = _pair(config.patches["size"]) 37 | 38 | num_tokens = self.prompt_config.NUM_TOKENS 39 | self.num_tokens = num_tokens # number of prompted tokens 40 | 41 | self.prompt_dropout = Dropout(self.prompt_config.DROPOUT) 42 | 43 | # if project the prompt embeddings 44 | if self.prompt_config.PROJECT > -1: 45 | # only for prepend / add 46 | prompt_dim = self.prompt_config.PROJECT 47 | self.prompt_proj = nn.Linear( 48 | prompt_dim, config.hidden_size) 49 | nn.init.kaiming_normal_( 50 | self.prompt_proj.weight, a=0, mode='fan_out') 51 | else: 52 | prompt_dim = config.hidden_size 53 | self.prompt_proj = nn.Identity() 54 | 55 | # initiate prompt: 56 | if self.prompt_config.INITIATION == "random": 57 | val = math.sqrt(6. 
/ float(3 * reduce(mul, patch_size, 1) + prompt_dim)) # noqa 58 | 59 | self.prompt_embeddings = nn.Parameter(torch.zeros( 60 | 1, num_tokens, prompt_dim)) 61 | # xavier_uniform initialization 62 | nn.init.uniform_(self.prompt_embeddings.data, -val, val) 63 | 64 | if self.prompt_config.DEEP: # noqa 65 | 66 | total_d_layer = config.transformer["num_layers"]-1 67 | self.deep_prompt_embeddings = nn.Parameter(torch.zeros( 68 | total_d_layer, num_tokens, prompt_dim)) 69 | # xavier_uniform initialization 70 | nn.init.uniform_(self.deep_prompt_embeddings.data, -val, val) 71 | 72 | else: 73 | raise ValueError("Other initiation scheme is not supported") 74 | 75 | def incorporate_prompt(self, x): 76 | # combine prompt embeddings with image-patch embeddings 77 | B = x.shape[0] 78 | # after CLS token, all before image patches 79 | x = self.embeddings(x) # (batch_size, 1 + n_patches, hidden_dim) 80 | x = torch.cat(( 81 | x[:, :1, :], 82 | self.prompt_dropout(self.prompt_proj(self.prompt_embeddings).expand(B, -1, -1)), 83 | x[:, 1:, :] 84 | ), dim=1) 85 | # (batch_size, cls_token + n_prompt + n_patches, hidden_dim) 86 | 87 | return x 88 | 89 | def train(self, mode=True): 90 | # set train status for this class: disable all but the prompt-related modules 91 | if mode: 92 | # training: 93 | self.encoder.eval() 94 | self.embeddings.eval() 95 | self.prompt_proj.train() 96 | self.prompt_dropout.train() 97 | else: 98 | # eval: 99 | for module in self.children(): 100 | module.train(mode) 101 | 102 | def forward_deep_prompt(self, embedding_output): 103 | attn_weights = [] 104 | hidden_states = None 105 | weights = None 106 | B = embedding_output.shape[0] 107 | num_layers = self.vit_config.transformer["num_layers"] 108 | 109 | for i in range(num_layers): 110 | if i == 0: 111 | hidden_states, weights = self.encoder.layer[i](embedding_output) 112 | else: 113 | if i <= self.deep_prompt_embeddings.shape[0]: 114 | deep_prompt_emb = self.prompt_dropout(self.prompt_proj( 115 | self.deep_prompt_embeddings[i-1]).expand(B, -1, -1)) 116 | 117 | hidden_states = torch.cat(( 118 | hidden_states[:, :1, :], 119 | deep_prompt_emb, 120 | hidden_states[:, (1+self.num_tokens):, :] 121 | ), dim=1) 122 | 123 | 124 | hidden_states, weights = self.encoder.layer[i](hidden_states) 125 | 126 | if self.encoder.vis: 127 | attn_weights.append(weights) 128 | 129 | encoded = self.encoder.encoder_norm(hidden_states) 130 | return encoded, attn_weights 131 | 132 | def forward(self, x): 133 | # this is the default version: 134 | embedding_output = self.incorporate_prompt(x) 135 | 136 | if self.prompt_config.DEEP: 137 | encoded, attn_weights = self.forward_deep_prompt( 138 | embedding_output) 139 | else: 140 | encoded, attn_weights = self.encoder(embedding_output) 141 | 142 | return encoded, attn_weights 143 | 144 | 145 | class PromptedVisionTransformer(VisionTransformer): 146 | def __init__( 147 | self, prompt_cfg, model_type, 148 | img_size=224, num_classes=21843, vis=False 149 | ): 150 | assert prompt_cfg.VIT_POOL_TYPE == "original" 151 | super(PromptedVisionTransformer, self).__init__( 152 | model_type, img_size, num_classes, vis) 153 | if prompt_cfg is None: 154 | raise ValueError("prompt_cfg cannot be None if using PromptedVisionTransformer") 155 | self.prompt_cfg = prompt_cfg 156 | vit_cfg = CONFIGS[model_type] 157 | self.transformer = PromptedTransformer( 158 | prompt_cfg, vit_cfg, img_size, vis) 159 | 160 | def forward(self, x, vis=False): 161 | x, attn_weights = self.transformer(x) 162 | 163 | x = x[:, 0] 164 | 165 | logits = self.head(x) 
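# only the [CLS] embedding (x[:, 0] above) is fed to the classification head, matching the asserted VIT_POOL_TYPE == "original"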
166 | 167 | if not vis: 168 | return logits 169 | return logits, attn_weights 170 | -------------------------------------------------------------------------------- /tune_fgvc.py: -------------------------------------------------------------------------------- 1 | """ 2 | tune lr, wd for fgvc datasets and other datasets with train / val / test splits 3 | """ 4 | import os 5 | import warnings 6 | 7 | from time import sleep 8 | from random import randint 9 | 10 | from src.configs.config import get_cfg 11 | from src.utils.file_io import PathManager 12 | 13 | from train import train as train_main 14 | from launch import default_argument_parser 15 | warnings.filterwarnings("ignore") 16 | 17 | 18 | def setup(args, lr, wd, check_runtime=True): 19 | """ 20 | Create configs and perform basic setups. 21 | overwrite the 2 parameters in cfg and args 22 | """ 23 | cfg = get_cfg() 24 | cfg.merge_from_file(args.config_file) 25 | cfg.merge_from_list(args.opts) 26 | 27 | # setup dist 28 | cfg.DIST_INIT_PATH = "tcp://{}:4000".format(os.environ["SLURMD_NODENAME"]) 29 | 30 | # overwrite below four parameters 31 | lr = lr / 256 * cfg.DATA.BATCH_SIZE # update lr based on the batchsize 32 | cfg.SOLVER.BASE_LR = lr 33 | cfg.SOLVER.WEIGHT_DECAY = wd 34 | 35 | # setup output dir 36 | # output_dir / data_name / feature_name / lr_wd / run1 37 | output_dir = cfg.OUTPUT_DIR 38 | output_folder = os.path.join( 39 | cfg.DATA.NAME, cfg.DATA.FEATURE, f"lr{lr}_wd{wd}" 40 | ) 41 | # output_folder = os.path.splitext(os.path.basename(args.config_file))[0] 42 | 43 | # train cfg.RUN_N_TIMES times 44 | if check_runtime: 45 | count = 1 46 | while count <= cfg.RUN_N_TIMES: 47 | output_path = os.path.join(output_dir, output_folder, f"run{count}") 48 | # pause for a random time, so concurrent process with same setting won't interfere with each other. # noqa 49 | sleep(randint(1, 5)) 50 | if not PathManager.exists(output_path): 51 | PathManager.mkdirs(output_path) 52 | cfg.OUTPUT_DIR = output_path 53 | break 54 | else: 55 | count += 1 56 | if count > cfg.RUN_N_TIMES: 57 | raise ValueError( 58 | f"Already run {cfg.RUN_N_TIMES} times for {output_folder}, no need to run more") 59 | else: 60 | # only used for dummy config file 61 | output_path = os.path.join(output_dir, output_folder, f"run1") 62 | cfg.OUTPUT_DIR = output_path 63 | 64 | cfg.freeze() 65 | return cfg 66 | 67 | 68 | def finetune_main(args): 69 | lr_range = [0.001, 0.0001, 0.0005, 0.005] 70 | wd_range = [0.01, 0.001, 0.0001, 0.0] 71 | for wd in wd_range: 72 | for lr in lr_range: 73 | # set up cfg and args 74 | try: 75 | cfg = setup(args, lr, wd) 76 | except ValueError: 77 | continue 78 | train_main(cfg, args) 79 | 80 | 81 | def finetune_rn_main(args): 82 | lr_range = [ 83 | 0.05, 0.025, 0.005, 0.0025 84 | ] 85 | wd_range = [0.01, 0.001, 0.0001, 0.0] 86 | for wd in wd_range: 87 | for lr in lr_range: 88 | # set up cfg and args 89 | try: 90 | cfg = setup(args, lr, wd) 91 | except ValueError as e: 92 | print(e) 93 | continue 94 | train_main(cfg, args) 95 | 96 | 97 | def prompt_rn_main(args): 98 | lr_range = [ 99 | 0.05, 0.025, 0.01, 0.5, 0.25, 0.1, 100 | 1.0, 2.5, 5. 
101 | ] 102 | wd_range = [0.01, 0.001, 0.0001, 0.0] 103 | for lr in sorted(lr_range, reverse=True): 104 | for wd in wd_range: 105 | # set up cfg and args 106 | try: 107 | cfg = setup(args, lr, wd) 108 | except ValueError as e: 109 | print(e) 110 | continue 111 | train_main(cfg, args) 112 | 113 | 114 | def linear_main(args): 115 | lr_range = [ 116 | 50.0, 25., 10.0, 117 | 5.0, 2.5, 1.0, 118 | 0.5, 0.25, 0.1, 0.05 119 | ] 120 | wd_range = [0.01, 0.001, 0.0001, 0.0] 121 | for lr in lr_range: 122 | for wd in wd_range: 123 | # set up cfg and args 124 | try: 125 | cfg = setup(args, lr, wd) 126 | except ValueError: 127 | continue 128 | train_main(cfg, args) 129 | sleep(randint(1, 10)) 130 | 131 | 132 | def linear_mae_main(args): 133 | lr_range = [ 134 | 50.0, 25., 10.0, 135 | 5.0, 2.5, 1.0, 136 | 0.5, 0.25, 0.1, 0.05, 137 | 0.025, 0.005, 0.0025, 138 | ] 139 | wd_range = [0.01, 0.001, 0.0001, 0.0] 140 | for lr in lr_range: 141 | for wd in wd_range: 142 | # set up cfg and args 143 | try: 144 | cfg = setup(args, lr, wd) 145 | except ValueError: 146 | continue 147 | train_main(cfg, args) 148 | sleep(randint(1, 10)) 149 | 150 | 151 | def prompt_main(args): 152 | lr_range = [ 153 | 5.0, 2.5, 1.0, 154 | 50.0, 25., 10.0, 155 | 0.5, 0.25, 0.1, 156 | ] 157 | wd_range = [0.01, 0.001, 0.0001, 0.0] 158 | for lr in lr_range: 159 | for wd in wd_range: 160 | # set up cfg and args 161 | try: 162 | cfg = setup(args, lr, wd) 163 | except ValueError: 164 | continue 165 | train_main(cfg, args) 166 | sleep(randint(1, 10)) 167 | 168 | 169 | def prompt_main_largerrange(args): 170 | lr_range = [ 171 | 500, 1000, # for parralel-based prompt for stanford cars 172 | 250., 100.0, # for parralel-based prompt for stanford cars 173 | ] 174 | wd_range = [0.0, 0.01, 0.001, 0.0001] 175 | for lr in lr_range: 176 | for wd in wd_range: 177 | # set up cfg and args 178 | try: 179 | cfg = setup(args, lr, wd) 180 | except ValueError: 181 | continue 182 | train_main(cfg, args) 183 | sleep(randint(1, 10)) 184 | 185 | 186 | def main(args): 187 | """main function to call from workflow""" 188 | if args.train_type == "finetune": 189 | finetune_main(args) 190 | elif args.train_type == "finetune_resnet": 191 | finetune_rn_main(args) 192 | 193 | elif args.train_type == "linear": 194 | linear_main(args) 195 | elif args.train_type == "linear_mae": 196 | linear_mae_main(args) 197 | 198 | elif args.train_type == "prompt": 199 | prompt_main(args) 200 | elif args.train_type == "prompt_resnet": 201 | prompt_rn_main(args) 202 | elif args.train_type == "prompt_largerrange" or args.train_type == "prompt_largerlr": # noqa 203 | prompt_main_largerrange(args) 204 | 205 | 206 | if __name__ == '__main__': 207 | args = default_argument_parser().parse_args() 208 | main(args) 209 | -------------------------------------------------------------------------------- /src/utils/logging.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """Logging.""" 4 | 5 | import builtins 6 | import decimal 7 | import functools 8 | import logging 9 | import simplejson 10 | import sys 11 | import os 12 | from termcolor import colored 13 | 14 | from .distributed import is_master_process 15 | from .file_io import PathManager 16 | 17 | # Show filename and line number in logs 18 | _FORMAT = "[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s" 19 | 20 | 21 | def _suppress_print(): 22 | """Suppresses printing from the current process.""" 23 | 24 | def print_pass(*objects, sep=" ", end="\n", file=sys.stdout, 
flush=False): 25 | pass 26 | 27 | builtins.print = print_pass 28 | 29 | 30 | # cache the opened file object, so that different calls to `setup_logger` 31 | # with the same file name can safely write to the same file. 32 | @functools.lru_cache(maxsize=None) 33 | def _cached_log_stream(filename): 34 | return PathManager.open(filename, "a") 35 | 36 | 37 | @functools.lru_cache() # so that calling setup_logger multiple times won't add many handlers # noqa 38 | def setup_logging( 39 | num_gpu, num_shards, output="", name="visual_prompt", color=True): 40 | """Sets up the logging.""" 41 | # Enable logging only for the master process 42 | if is_master_process(num_gpu): 43 | # Clear the root logger to prevent any existing logging config 44 | # (e.g. set by another module) from messing with our setup 45 | logging.root.handlers = [] 46 | # Configure logging 47 | logging.basicConfig( 48 | level=logging.INFO, format=_FORMAT, stream=sys.stdout 49 | ) 50 | else: 51 | _suppress_print() 52 | 53 | if name is None: 54 | name = __name__ 55 | logger = logging.getLogger(name) 56 | # remove any lingering handler 57 | logger.handlers.clear() 58 | 59 | logger.setLevel(logging.INFO) 60 | logger.propagate = False 61 | 62 | plain_formatter = logging.Formatter( 63 | "[%(asctime)s][%(levelname)s] %(name)s: %(lineno)4d: %(message)s", 64 | datefmt="%m/%d %H:%M:%S", 65 | ) 66 | if color: 67 | formatter = _ColorfulFormatter( 68 | colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s", 69 | datefmt="%m/%d %H:%M:%S", 70 | root_name=name, 71 | abbrev_name=str(name), 72 | ) 73 | else: 74 | formatter = plain_formatter 75 | 76 | if is_master_process(num_gpu): 77 | ch = logging.StreamHandler(stream=sys.stdout) 78 | ch.setLevel(logging.DEBUG) 79 | ch.setFormatter(formatter) 80 | logger.addHandler(ch) 81 | 82 | if is_master_process(num_gpu * num_shards): 83 | if len(output) > 0: 84 | if output.endswith(".txt") or output.endswith(".log"): 85 | filename = output 86 | else: 87 | filename = os.path.join(output, "logs.txt") 88 | 89 | PathManager.mkdirs(os.path.dirname(filename)) 90 | 91 | fh = logging.StreamHandler(_cached_log_stream(filename)) 92 | fh.setLevel(logging.DEBUG) 93 | fh.setFormatter(plain_formatter) 94 | logger.addHandler(fh) 95 | return logger 96 | 97 | 98 | def setup_single_logging(name, output=""): 99 | """Sets up the logging.""" 100 | # Enable logging only for the master process 101 | # Clear the root logger to prevent any existing logging config 102 | # (e.g. 
set by another module) from messing with our setup 103 | logging.root.handlers = [] 104 | # Configure logging 105 | logging.basicConfig( 106 | level=logging.INFO, format=_FORMAT, stream=sys.stdout 107 | ) 108 | 109 | if len(name) == 0: 110 | name = __name__ 111 | logger = logging.getLogger(name) 112 | logger.setLevel(logging.INFO) 113 | logger.propagate = False 114 | 115 | plain_formatter = logging.Formatter( 116 | "[%(asctime)s][%(levelname)s] %(name)s: %(lineno)4d: %(message)s", 117 | datefmt="%m/%d %H:%M:%S", 118 | ) 119 | formatter = _ColorfulFormatter( 120 | colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s", 121 | datefmt="%m/%d %H:%M:%S", 122 | root_name=name, 123 | abbrev_name=str(name), 124 | ) 125 | 126 | ch = logging.StreamHandler(stream=sys.stdout) 127 | ch.setLevel(logging.DEBUG) 128 | ch.setFormatter(formatter) 129 | logger.addHandler(ch) 130 | 131 | if len(output) > 0: 132 | if output.endswith(".txt") or output.endswith(".log"): 133 | filename = output 134 | else: 135 | filename = os.path.join(output, "logs.txt") 136 | 137 | PathManager.mkdirs(os.path.dirname(filename)) 138 | 139 | fh = logging.StreamHandler(_cached_log_stream(filename)) 140 | fh.setLevel(logging.DEBUG) 141 | fh.setFormatter(plain_formatter) 142 | logger.addHandler(fh) 143 | 144 | return logger 145 | 146 | 147 | def get_logger(name): 148 | """Retrieves the logger.""" 149 | return logging.getLogger(name) 150 | 151 | 152 | def log_json_stats(stats, sort_keys=True): 153 | """Logs json stats.""" 154 | # It seems that in Python >= 3.6 json.encoder.FLOAT_REPR has no effect 155 | # Use decimal+string as a workaround for having fixed length values in logs 156 | logger = get_logger(__name__) 157 | stats = { 158 | k: decimal.Decimal("{:.6f}".format(v)) if isinstance(v, float) else v 159 | for k, v in stats.items() 160 | } 161 | json_stats = simplejson.dumps(stats, sort_keys=True, use_decimal=True) 162 | if stats["_type"] == "test_epoch" or stats["_type"] == "train_epoch": 163 | logger.info("json_stats: {:s}".format(json_stats)) 164 | else: 165 | logger.info("{:s}".format(json_stats)) 166 | 167 | 168 | class _ColorfulFormatter(logging.Formatter): 169 | # from detectron2 170 | def __init__(self, *args, **kwargs): 171 | self._root_name = kwargs.pop("root_name") + "." 172 | self._abbrev_name = kwargs.pop("abbrev_name", "") 173 | if len(self._abbrev_name): 174 | self._abbrev_name = self._abbrev_name + "." 
175 | super(_ColorfulFormatter, self).__init__(*args, **kwargs) 176 | 177 | def formatMessage(self, record: logging.LogRecord) -> str: 178 | record.name = record.name.replace(self._root_name, self._abbrev_name) 179 | log = super(_ColorfulFormatter, self).formatMessage(record) 180 | if record.levelno == logging.WARNING: 181 | prefix = colored("WARNING", "red", attrs=["blink"]) 182 | elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL: 183 | prefix = colored("ERROR", "red", attrs=["blink", "underline"]) 184 | else: 185 | return log 186 | return prefix + " " + log 187 | -------------------------------------------------------------------------------- /src/models/vit_prompt/vit_mae.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | vit-moco-v3 with prompt 4 | """ 5 | import math 6 | import torch 7 | import torch.nn as nn 8 | import torchvision as tv 9 | 10 | from functools import partial, reduce 11 | from operator import mul 12 | from torch.nn import Conv2d, Dropout 13 | from timm.models.vision_transformer import _cfg 14 | 15 | from ..vit_backbones.vit_mae import VisionTransformer 16 | from ...utils import logging 17 | logger = logging.get_logger("visual_prompt") 18 | 19 | 20 | class PromptedVisionTransformer(VisionTransformer): 21 | def __init__(self, prompt_config, **kwargs): 22 | super().__init__(**kwargs) 23 | self.prompt_config = prompt_config 24 | if self.prompt_config.DEEP and self.prompt_config.LOCATION not in ["prepend", ]: 25 | raise ValueError("Deep-{} is not supported".format(self.prompt_config.LOCATION)) 26 | 27 | num_tokens = self.prompt_config.NUM_TOKENS 28 | 29 | self.num_tokens = num_tokens 30 | self.prompt_dropout = Dropout(self.prompt_config.DROPOUT) 31 | 32 | # initiate prompt: 33 | if self.prompt_config.INITIATION == "random": 34 | val = math.sqrt(6. 
/ float(3 * reduce(mul, self.patch_embed.patch_size, 1) + self.embed_dim)) # noqa 35 | 36 | self.prompt_embeddings = nn.Parameter(torch.zeros( 37 | 1, num_tokens, self.embed_dim)) 38 | # xavier_uniform initialization 39 | nn.init.uniform_(self.prompt_embeddings.data, -val, val) 40 | 41 | if self.prompt_config.DEEP: 42 | self.deep_prompt_embeddings = nn.Parameter(torch.zeros( 43 | len(self.blocks) - 1, 44 | num_tokens, self.embed_dim 45 | )) 46 | # xavier_uniform initialization 47 | nn.init.uniform_( 48 | self.deep_prompt_embeddings.data, -val, val) 49 | 50 | else: 51 | raise ValueError("Other initiation scheme is not supported") 52 | 53 | def incorporate_prompt(self, x): 54 | # combine prompt embeddings with image-patch embeddings 55 | B = x.shape[0] 56 | if self.prompt_config.LOCATION == "prepend": 57 | # after CLS token, all before image patches 58 | x = self.embeddings(x) # (batch_size, 1 + n_patches, hidden_dim) 59 | x = torch.cat(( 60 | x[:, :1, :], 61 | self.prompt_dropout( 62 | self.prompt_embeddings.expand(B, -1, -1)), 63 | x[:, 1:, :] 64 | ), dim=1) 65 | # (batch_size, cls_token + n_prompt + n_patches, hidden_dim) 66 | 67 | else: 68 | raise ValueError("Other prompt locations are not supported") 69 | return x 70 | 71 | def embeddings(self, x): 72 | B = x.shape[0] 73 | x = self.patch_embed(x) 74 | 75 | cls_tokens = self.cls_token.expand(B, -1, -1) 76 | x = torch.cat((cls_tokens, x), dim=1) 77 | x = x + self.pos_embed 78 | x = self.pos_drop(x) 79 | return x 80 | 81 | def train(self, mode=True): 82 | # set train status for this class: disable all but the prompt-related modules 83 | if mode: 84 | # training: 85 | self.blocks.eval() 86 | self.patch_embed.eval() 87 | self.pos_drop.eval() 88 | self.prompt_dropout.train() 89 | else: 90 | # eval: 91 | for module in self.children(): 92 | module.train(mode) 93 | 94 | def forward_features(self, x): 95 | x = self.incorporate_prompt(x) 96 | 97 | if self.prompt_config.DEEP: 98 | B = x.shape[0] 99 | num_layers = len(self.blocks) 100 | 101 | for i in range(num_layers): 102 | if i == 0: 103 | x = self.blocks[i](x) 104 | else: 105 | # prepend 106 | x = torch.cat(( 107 | x[:, :1, :], 108 | self.prompt_dropout( 109 | self.deep_prompt_embeddings[i-1].expand(B, -1, -1) 110 | ), 111 | x[:, (1 + self.num_tokens):, :] 112 | ), dim=1) 113 | x = self.blocks[i](x) 114 | else: 115 | for blk in self.blocks: 116 | x = blk(x) 117 | 118 | if self.prompt_config.VIT_POOL_TYPE == "imgprompt_pool": 119 | assert self.prompt_config.LOCATION == "prepend" 120 | x = x[:, 1:, :].mean(dim=1) # global pool without cls token 121 | outcome = self.fc_norm(x) 122 | elif self.prompt_config.VIT_POOL_TYPE == "original": 123 | x = x[:, 1:, :].mean(dim=1) # global pool without cls token 124 | outcome = self.fc_norm(x) 125 | elif self.prompt_config.VIT_POOL_TYPE == "img_pool": 126 | assert self.prompt_config.LOCATION == "prepend" 127 | x = x[:, self.num_tokens+1:, :].mean(dim=1) 128 | outcome = self.fc_norm(x) 129 | elif self.prompt_config.VIT_POOL_TYPE == "prompt_pool": 130 | assert self.prompt_config.LOCATION == "prepend" 131 | x = x[:, 1:self.num_tokens+1, :].mean(dim=1) 132 | outcome = self.fc_norm(x) 133 | else: 134 | raise ValueError("pooling type for output is not supported") 135 | 136 | return outcome 137 | 138 | 139 | def build_model(model_type, prompt_cfg): 140 | if "vitb" in model_type: 141 | return vit_base_patch16(prompt_cfg) 142 | elif "vitl" in model_type: 143 | return vit_large_patch16(prompt_cfg) 144 | elif "vith" in model_type: 145 | return vit_huge_patch14(prompt_cfg) 
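# Note: there is no trailing else here, so an unrecognized model_type falls
# through and build_model returns None. Callers that prefer to fail fast could
# add, for example:
#     else:
#         raise ValueError("Unknown MAE prompt model type: {}".format(model_type))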
146 | 147 | 148 | def vit_base_patch16(prompt_cfg, **kwargs): 149 | model = PromptedVisionTransformer( 150 | prompt_cfg, 151 | drop_path_rate=0.1, global_pool=True, # using default settings for mae-finetune 152 | patch_size=16, embed_dim=768, depth=12, num_heads=12, 153 | mlp_ratio=4, qkv_bias=True, 154 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 155 | return model 156 | 157 | 158 | def vit_large_patch16(prompt_cfg, **kwargs): 159 | model = PromptedVisionTransformer( 160 | prompt_cfg, 161 | drop_path_rate=0.1, global_pool=True, # using default settings for mae-finetune 162 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, 163 | mlp_ratio=4, qkv_bias=True, 164 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 165 | return model 166 | 167 | 168 | def vit_huge_patch14(prompt_cfg, **kwargs): 169 | model = PromptedVisionTransformer( 170 | prompt_cfg, 171 | drop_path_rate=0.1, global_pool=True, # using default settings for mae-finetune 172 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, 173 | mlp_ratio=4, qkv_bias=True, 174 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 175 | return model 176 | 177 | 178 | -------------------------------------------------------------------------------- /src/data/datasets/tf_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """a dataset that handles output of tf.data: support datasets from VTAB""" 4 | import functools 5 | import tensorflow.compat.v1 as tf 6 | import torch 7 | import torch.utils.data 8 | import numpy as np 9 | 10 | from collections import Counter 11 | from torch import Tensor 12 | 13 | from ..vtab_datasets import base 14 | # pylint: disable=unused-import 15 | from ..vtab_datasets import caltech 16 | from ..vtab_datasets import cifar 17 | from ..vtab_datasets import clevr 18 | from ..vtab_datasets import diabetic_retinopathy 19 | from ..vtab_datasets import dmlab 20 | from ..vtab_datasets import dsprites 21 | from ..vtab_datasets import dtd 22 | from ..vtab_datasets import eurosat 23 | from ..vtab_datasets import kitti 24 | from ..vtab_datasets import oxford_flowers102 25 | from ..vtab_datasets import oxford_iiit_pet 26 | from ..vtab_datasets import patch_camelyon 27 | from ..vtab_datasets import resisc45 28 | from ..vtab_datasets import smallnorb 29 | from ..vtab_datasets import sun397 30 | from ..vtab_datasets import svhn 31 | from ..vtab_datasets.registry import Registry 32 | 33 | from ...utils import logging 34 | logger = logging.get_logger("visual_prompt") 35 | tf.config.experimental.set_visible_devices([], 'GPU') # set tensorflow to not use gpu # noqa 36 | DATASETS = [ 37 | 'caltech101', 38 | 'cifar(num_classes=100)', 39 | 'dtd', 40 | 'oxford_flowers102', 41 | 'oxford_iiit_pet', 42 | 'patch_camelyon', 43 | 'sun397', 44 | 'svhn', 45 | 'resisc45', 46 | 'eurosat', 47 | 'dmlab', 48 | 'kitti(task="closest_vehicle_distance")', 49 | 'smallnorb(predicted_attribute="label_azimuth")', 50 | 'smallnorb(predicted_attribute="label_elevation")', 51 | 'dsprites(predicted_attribute="label_x_position",num_classes=16)', 52 | 'dsprites(predicted_attribute="label_orientation",num_classes=16)', 53 | 'clevr(task="closest_object_distance")', 54 | 'clevr(task="count_all")', 55 | 'diabetic_retinopathy(config="btgraham-300")' 56 | ] 57 | 58 | 59 | class TFDataset(torch.utils.data.Dataset): 60 | def __init__(self, cfg, split): 61 | assert split in { 62 | "train", 63 | "val", 64 | "test", 65 | "trainval" 66 | }, "Split '{}' not supported for {} 
dataset".format( 67 | split, cfg.DATA.NAME) 68 | logger.info("Constructing {} dataset {}...".format( 69 | cfg.DATA.NAME, split)) 70 | 71 | self.cfg = cfg 72 | self._split = split 73 | self.name = cfg.DATA.NAME 74 | 75 | self.img_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1) 76 | self.img_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1) 77 | 78 | self.get_data(cfg, split) 79 | 80 | def get_data(self, cfg, split): 81 | tf_data = build_tf_dataset(cfg, split) 82 | data_list = list(tf_data) # a list of tuples 83 | 84 | self._image_tensor_list = [t[0].numpy().squeeze() for t in data_list] 85 | self._targets = [int(t[1].numpy()[0]) for t in data_list] 86 | self._class_ids = sorted(list(set(self._targets))) 87 | 88 | logger.info("Number of images: {}".format(len(self._image_tensor_list))) 89 | logger.info("Number of classes: {} / {}".format( 90 | len(self._class_ids), self.get_class_num())) 91 | 92 | del data_list 93 | del tf_data 94 | 95 | def get_info(self): 96 | num_imgs = len(self._image_tensor_list) 97 | return num_imgs, self.get_class_num() 98 | 99 | def get_class_num(self): 100 | return self.cfg.DATA.NUMBER_CLASSES 101 | 102 | def get_class_weights(self, weight_type): 103 | """get a list of class weight, return a list float""" 104 | if "train" not in self._split: 105 | raise ValueError( 106 | "only getting training class distribution, " + \ 107 | "got split {} instead".format(self._split) 108 | ) 109 | 110 | cls_num = self.get_class_num() 111 | if weight_type == "none": 112 | return [1.0] * cls_num 113 | 114 | id2counts = Counter(self._class_ids) 115 | assert len(id2counts) == cls_num 116 | num_per_cls = np.array([id2counts[i] for i in self._class_ids]) 117 | 118 | if weight_type == 'inv': 119 | mu = -1.0 120 | elif weight_type == 'inv_sqrt': 121 | mu = -0.5 122 | weight_list = num_per_cls ** mu 123 | weight_list = np.divide( 124 | weight_list, np.linalg.norm(weight_list, 1)) * cls_num 125 | return weight_list.tolist() 126 | 127 | def __getitem__(self, index): 128 | # Load the image 129 | label = self._targets[index] 130 | im = to_torch_imgs( 131 | self._image_tensor_list[index], self.img_mean, self.img_std) 132 | 133 | if self._split == "train": 134 | index = index 135 | else: 136 | index = f"{self._split}{index}" 137 | sample = { 138 | "image": im, 139 | "label": label, 140 | # "id": index 141 | } 142 | return sample 143 | 144 | def __len__(self): 145 | return len(self._targets) 146 | 147 | 148 | def preprocess_fn(data, size=224, input_range=(0.0, 1.0)): 149 | image = data["image"] 150 | image = tf.image.resize(image, [size, size]) 151 | 152 | image = tf.cast(image, tf.float32) / 255.0 153 | image = image * (input_range[1] - input_range[0]) + input_range[0] 154 | 155 | data["image"] = image 156 | return data 157 | 158 | 159 | def build_tf_dataset(cfg, mode): 160 | """ 161 | Builds a tf data instance, then transform to a list of tensors and labels 162 | """ 163 | 164 | if mode not in ["train", "val", "test", "trainval"]: 165 | raise ValueError("The input pipeline supports `train`, `val`, `test`." 166 | "Provided mode is {}".format(mode)) 167 | 168 | vtab_dataname = cfg.DATA.NAME.split("vtab-")[-1] 169 | data_dir = cfg.DATA.DATAPATH 170 | if vtab_dataname in DATASETS: 171 | data_cls = Registry.lookup("data." 
+ vtab_dataname) 172 | vtab_tf_dataloader = data_cls(data_dir=data_dir) 173 | else: 174 | raise ValueError("Unknown type for \"dataset\" field: {}".format( 175 | type(vtab_dataname))) 176 | 177 | split_name_dict = { 178 | "dataset_train_split_name": "train800", 179 | "dataset_val_split_name": "val200", 180 | "dataset_trainval_split_name": "train800val200", 181 | "dataset_test_split_name": "test", 182 | } 183 | 184 | def _dict_to_tuple(batch): 185 | return batch['image'], batch['label'] 186 | 187 | return vtab_tf_dataloader.get_tf_data( 188 | batch_size=1, # data_params["batch_size"], 189 | drop_remainder=False, 190 | split_name=split_name_dict[f"dataset_{mode}_split_name"], 191 | preprocess_fn=functools.partial( 192 | preprocess_fn, 193 | input_range=(0.0, 1.0), 194 | size=cfg.DATA.CROPSIZE, 195 | ), 196 | for_eval=mode != "train", # handles shuffling 197 | shuffle_buffer_size=1000, 198 | prefetch=1, 199 | train_examples=None, 200 | epochs=1 # setting epochs to 1 make sure it returns one copy of the dataset 201 | ).map(_dict_to_tuple) # return a PrefetchDataset object. (which does not have much documentation to go on) 202 | 203 | 204 | def to_torch_imgs(img: np.ndarray, mean: Tensor, std: Tensor) -> Tensor: 205 | t_img: Tensor = torch.from_numpy(np.transpose(img, (2, 0, 1))) 206 | t_img -= mean 207 | t_img /= std 208 | 209 | return t_img 210 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/kitti.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements Kitti data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import numpy as np 24 | 25 | import tensorflow.compat.v1 as tf 26 | import tensorflow_datasets as tfds 27 | 28 | from . import base as base 29 | from .registry import Registry 30 | 31 | 32 | def _count_all_pp(x): 33 | """Count all objects.""" 34 | # Count distribution (thresholded at 15): 35 | 36 | label = tf.math.minimum(tf.size(x["objects"]["type"]) - 1, 8) 37 | return {"image": x["image"], "label": label} 38 | 39 | 40 | def _count_vehicles_pp(x): 41 | """Counting vehicles.""" 42 | # Label distribution: 43 | 44 | vehicles = tf.where(x["objects"]["type"] < 3) # Car, Van, Truck. 45 | # Cap at 3. 46 | label = tf.math.minimum(tf.size(vehicles), 3) 47 | return {"image": x["image"], "label": label} 48 | 49 | 50 | def _count_left_pp(x): 51 | """Count objects on the left hand side of the camera.""" 52 | # Count distribution (thresholded at 15): 53 | 54 | # Location feature contains (x, y, z) in meters w.r.t. the camera. 
55 | objects_on_left = tf.where(x["objects"]["location"][:, 0] < 0) 56 | label = tf.math.minimum(tf.size(objects_on_left), 8) 57 | return {"image": x["image"], "label": label} 58 | 59 | 60 | def _count_far_pp(x): 61 | """Counts objects far from the camera.""" 62 | # Threshold removes ~half of the objects. 63 | # Count distribution (thresholded at 15): 64 | 65 | # Location feature contains (x, y, z) in meters w.r.t. the camera. 66 | distant_objects = tf.where(x["objects"]["location"][:, 2] >= 25) 67 | label = tf.math.minimum(tf.size(distant_objects), 8) 68 | return {"image": x["image"], "label": label} 69 | 70 | 71 | def _count_near_pp(x): 72 | """Counts objects close to the camera.""" 73 | # Threshold removes ~half of the objects. 74 | # Count distribution: 75 | 76 | # Location feature contains (x, y, z) in meters w.r.t. the camera. 77 | close_objects = tf.where(x["objects"]["location"][:, 2] < 25) 78 | label = tf.math.minimum(tf.size(close_objects), 8) 79 | return {"image": x["image"], "label": label} 80 | 81 | 82 | def _closest_object_distance_pp(x): 83 | """Predict the distance to the closest object.""" 84 | # Label distribution: 85 | 86 | # Location feature contains (x, y, z) in meters w.r.t. the camera. 87 | dist = tf.reduce_min(x["objects"]["location"][:, 2]) 88 | thrs = np.array([-100, 5.6, 8.4, 13.4, 23.4]) 89 | label = tf.reduce_max(tf.where((thrs - dist) < 0)) 90 | return {"image": x["image"], "label": label} 91 | 92 | 93 | def _closest_vehicle_distance_pp(x): 94 | """Predict the distance to the closest vehicle.""" 95 | # Label distribution: 96 | 97 | # Location feature contains (x, y, z) in meters w.r.t. the camera. 98 | vehicles = tf.where(x["objects"]["type"] < 3) # Car, Van, Truck. 99 | vehicle_z = tf.gather(params=x["objects"]["location"][:, 2], indices=vehicles) 100 | vehicle_z = tf.concat([vehicle_z, tf.constant([[1000.0]])], axis=0) 101 | dist = tf.reduce_min(vehicle_z) 102 | # Results in a uniform distribution over three distances, plus one class for 103 | # "no vehicle". 104 | thrs = np.array([-100.0, 8.0, 20.0, 999.0]) 105 | label = tf.reduce_max(tf.where((thrs - dist) < 0)) 106 | return {"image": x["image"], "label": label} 107 | 108 | 109 | def _closest_object_x_location_pp(x): 110 | """Predict the absolute x position of the closest object.""" 111 | # Label distribution: 112 | 113 | # Location feature contains (x, y, z) in meters w.r.t. the camera. 
114 | idx = tf.math.argmin(x["objects"]["location"][:, 2]) 115 | xloc = x["objects"]["location"][idx, 0] 116 | thrs = np.array([-100, -6.4, -3.5, 0.0, 3.3, 23.9]) 117 | label = tf.reduce_max(tf.where((thrs - xloc) < 0)) 118 | return {"image": x["image"], "label": label} 119 | 120 | 121 | _TASK_DICT = { 122 | "count_all": { 123 | "preprocess_fn": _count_all_pp, 124 | "num_classes": 16, 125 | }, 126 | "count_left": { 127 | "preprocess_fn": _count_left_pp, 128 | "num_classes": 16, 129 | }, 130 | "count_far": { 131 | "preprocess_fn": _count_far_pp, 132 | "num_classes": 16, 133 | }, 134 | "count_near": { 135 | "preprocess_fn": _count_near_pp, 136 | "num_classes": 16, 137 | }, 138 | "closest_object_distance": { 139 | "preprocess_fn": _closest_object_distance_pp, 140 | "num_classes": 5, 141 | }, 142 | "closest_object_x_location": { 143 | "preprocess_fn": _closest_object_x_location_pp, 144 | "num_classes": 5, 145 | }, 146 | "count_vehicles": { 147 | "preprocess_fn": _count_vehicles_pp, 148 | "num_classes": 4, 149 | }, 150 | "closest_vehicle_distance": { 151 | "preprocess_fn": _closest_vehicle_distance_pp, 152 | "num_classes": 4, 153 | }, 154 | } 155 | 156 | 157 | @Registry.register("data.kitti", "class") 158 | class KittiData(base.ImageTfdsData): 159 | """Provides Kitti dataset. 160 | 161 | Six tasks are supported: 162 | 1. Count the number of objects. 163 | 2. Count the number of objects on the left hand side of the camera. 164 | 3. Count the number of objects in the foreground. 165 | 4. Count the number of objects in the background. 166 | 5. Predict the distance of the closest object. 167 | 6. Predict the x-location (w.r.t. the camera) of the closest object. 168 | """ 169 | 170 | def __init__(self, task, data_dir=None): 171 | 172 | if task not in _TASK_DICT: 173 | raise ValueError("Unknown task: %s" % task) 174 | 175 | dataset_builder = tfds.builder("kitti:3.2.0", data_dir=data_dir) 176 | dataset_builder.download_and_prepare() 177 | 178 | tfds_splits = { 179 | "train": "train", 180 | "val": "validation", 181 | "trainval": "train+validation", 182 | "test": "test", 183 | "train800": "train[:800]", 184 | "val200": "validation[:200]", 185 | "train800val200": "train[:800]+validation[:200]", 186 | } 187 | 188 | # Example counts are retrieved from the tensorflow dataset info. 189 | train_count = dataset_builder.info.splits[tfds.Split.TRAIN].num_examples 190 | val_count = dataset_builder.info.splits[tfds.Split.VALIDATION].num_examples 191 | test_count = dataset_builder.info.splits[tfds.Split.TEST].num_examples 192 | # Creates a dict with example counts for each split. 
193 | num_samples_splits = { 194 | "train": train_count, 195 | "val": val_count, 196 | "trainval": train_count + val_count, 197 | "test": test_count, 198 | "train800": 800, 199 | "val200": 200, 200 | "train800val200": 1000, 201 | } 202 | 203 | task = _TASK_DICT[task] 204 | base_preprocess_fn = task["preprocess_fn"] 205 | super(KittiData, self).__init__( 206 | dataset_builder=dataset_builder, 207 | tfds_splits=tfds_splits, 208 | num_samples_splits=num_samples_splits, 209 | num_preprocessing_threads=400, 210 | shuffle_buffer_size=10000, 211 | base_preprocess_fn=base_preprocess_fn, 212 | num_classes=task["num_classes"]) 213 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Visual Prompt Tuning 2 | 3 | https://arxiv.org/abs/2203.12119 4 | 5 | ------ 6 | 7 | This repository contains the official PyTorch implementation for Visual Prompt Tuning. 8 | 9 | ![vpt_teaser](https://github.com/KMnP/vpt/blob/main/imgs/teaser.png) 10 | 11 | ## Environment settings 12 | 13 | See `env_setup.sh` 14 | 15 | ## Structure of the this repo (key files are marked with 👉): 16 | 17 | - `src/configs`: handles config parameters for the experiments. 18 | 19 | * 👉 `src/config/config.py`: main config setups for experiments and explanation for each of them. 20 | 21 | - `src/data`: loading and setup input datasets. The `src/data/vtab_datasets` are borrowed from 22 | 23 | [VTAB github repo](https://github.com/google-research/task_adaptation/tree/master/task_adaptation/data). 24 | 25 | 26 | - `src/engine`: main training and eval actions here. 27 | 28 | - `src/models`: handles backbone archs and heads for different fine-tuning protocols 29 | 30 | * 👉`src/models/vit_prompt`: a folder contains the same backbones in `vit_backbones` folder, specified for VPT. This folder should contain the same file names as those in `vit_backbones` 31 | 32 | * 👉 `src/models/vit_models.py`: main model for transformer-based models ❗️Note❗️: Current version only support ViT, Swin and ViT with mae, moco-v3 33 | 34 | * `src/models/build_model.py`: main action here to utilize the config and build the model to train / eval. 35 | 36 | - `src/solver`: optimization, losses and learning rate schedules. 37 | - `src/utils`: helper functions for io, loggings, training, visualizations. 38 | - 👉`train.py`: call this one for training and eval a model with a specified transfer type. 39 | - 👉`tune_fgvc.py`: call this one for tuning learning rate and weight decay for a model with a specified transfer type. We used this script for FGVC tasks. 40 | - 👉`tune_vtab.py`: call this one for tuning vtab tasks: use 800/200 split to find the best lr and wd, and use the best lr/wd for the final runs 41 | - `launch.py`: contains functions used to launch the job. 
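All of these entry points read a single yacs-style config tree (see `src/configs/config.py`) and merge a dataset/method-specific YAML file from `configs/` on top of it. As a rough sketch of how a run is parameterized programmatically (the `get_cfg` accessor name, the file path, and the override values below are illustrative assumptions, not verbatim repo API):

```python
# Sketch only: names and paths are assumptions; see src/configs/config.py for the real setup.
from src.configs.config import get_cfg

cfg = get_cfg()
cfg.merge_from_file("configs/prompt/cub.yaml")   # method- and dataset-specific values
cfg.merge_from_list([
    "MODEL.PROMPT.NUM_TOKENS", "10",             # prompt length (see key configs below)
    "SOLVER.BASE_LR", "0.1",                     # tuned per task, e.g. via tune_fgvc.py
    "DATA.DATAPATH", "/path/to/CUB_200_2011",    # where the dataset lives
])
cfg.freeze()
```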
42 | 43 | ## Experiments 44 | 45 | ### Key configs: 46 | 47 | - 🔥VPT related: 48 | - MODEL.PROMPT.NUM_TOKENS: prompt length 49 | - MODEL.PROMPT.DEEP: deep or shallow prompt 50 | - Fine-tuning method specification: 51 | - MODEL.TRANSFER_TYPE 52 | - Vision backbones: 53 | - DATA.FEATURE: specify which representation to use 54 | - MODEL.TYPE: the general backbone type, e.g., "vit" or "swin" 55 | - MODEL.MODEL_ROOT: folder with pre-trained model checkpoints 56 | - Optimization related: 57 | - SOLVER.BASE_LR: learning rate for the experiment 58 | - SOLVER.WEIGHT_DECAY: weight decay value for the experiment 59 | - DATA.BATCH_SIZE 60 | - Datasets related: 61 | - DATA.NAME 62 | - DATA.DATAPATH: where you put the datasets 63 | - DATA.NUMBER_CLASSES 64 | - Others: 65 | - RUN_N_TIMES: ensure only run once in case for duplicated submision, not used during vtab runs 66 | - OUTPUT_DIR: output dir of the final model and logs 67 | - MODEL.SAVE_CKPT: if set to `True`, will save model ckpts and final output of both val and test set 68 | 69 | ### Datasets preperation: 70 | 71 | See Table 8 in the Appendix for dataset details. 72 | 73 | - Fine-Grained Visual Classification tasks (FGVC): The datasets can be downloaded following the official links. We split the training data if the public validation set is not available. The splitted dataset can be found here: [Dropbox](https://cornell.box.com/v/vptfgvcsplits), [Google Drive](https://drive.google.com/drive/folders/1mnvxTkYxmOr2W9QjcgS64UBpoJ4UmKaM?usp=sharing). 74 | 75 | - [CUB200 2011](https://data.caltech.edu/records/65de6-vp158) 76 | 77 | - [NABirds](http://info.allaboutbirds.org/nabirds/) 78 | 79 | - [Oxford Flowers](https://www.robots.ox.ac.uk/~vgg/data/flowers/) 80 | 81 | - [Stanford Dogs](http://vision.stanford.edu/aditya86/ImageNetDogs/main.html) 82 | 83 | - [Stanford Cars](https://ai.stanford.edu/~jkrause/cars/car_dataset.html) 84 | 85 | - [Visual Task Adaptation Benchmark](https://google-research.github.io/task_adaptation/) (VTAB): see [`VTAB_SETUP.md`](https://github.com/KMnP/vpt/blob/main/VTAB_SETUP.md) for detailed instructions and tips. 86 | 87 | ### Pre-trained model preperation 88 | 89 | Download and place the pre-trained Transformer-based backbones to `MODEL.MODEL_ROOT` (ConvNeXt-Base and ResNet50 would be automatically downloaded via the links in the code). Note that you also need to rename the downloaded ViT-B/16 ckpt from `ViT-B_16.npz` to `imagenet21k_ViT-B_16.npz`. 90 | 91 | See Table 9 in the Appendix for more details about pre-trained backbones. 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 |
| Pre-trained Backbone | Pre-trained Objective | Link | md5sum |
| --- | --- | --- | --- |
| ViT-B/16 | Supervised | link | d9715d |
| ViT-B/16 | MoCo v3 | link | 8f39ce |
| ViT-B/16 | MAE | link | 8cad7c |
| Swin-B | Supervised | link | bf9cc1 |
| ConvNeXt-Base | Supervised | link | - |
| ResNet-50 | Supervised | link | - |
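To sanity-check a downloaded checkpoint against the md5sum column (the short values above are assumed to be the leading characters of the full hex digest), a small helper such as the following can be used:

```python
import hashlib

def file_md5(path, chunk_size=1 << 20):
    """Stream a file through MD5 and return its hex digest."""
    md5 = hashlib.md5()
    with open(path, "rb") as f:
        for block in iter(lambda: f.read(chunk_size), b""):
            md5.update(block)
    return md5.hexdigest()

# Compare the beginning of the digest with the table entry,
# e.g. "d9715d" for the supervised ViT-B/16 checkpoint.
print(file_md5("imagenet21k_ViT-B_16.npz"))
```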
133 | 134 | ### Examples for training and aggregating results 135 | 136 | See [`demo.ipynb`](https://github.com/KMnP/vpt/blob/main/demo.ipynb) for how to use this repo. 137 | 138 | ### Hyperparameters for experiments in paper 139 | 140 | The hyperparameter values used (prompt length for VPT / reduction rate for Adapters, base learning rate, weight decay values) in Table 1-2, Fig. 3-4, Table 4-5 can be found here: [Dropbox](https://cornell.box.com/s/lv10kptgyrm8uxb6v6ctugrhao24rs2z) / [Google Drive](https://drive.google.com/drive/folders/1ldhqkXelHDXq4bG7qpKn5YEfU6sRehJH?usp=sharing). 141 | 142 | ## Citation 143 | 144 | If you find our work helpful in your research, please cite it as: 145 | 146 | ``` 147 | @inproceedings{jia2022vpt, 148 | title={Visual Prompt Tuning}, 149 | author={Jia, Menglin and Tang, Luming and Chen, Bor-Chun and Cardie, Claire and Belongie, Serge and Hariharan, Bharath and Lim, Ser-Nam}, 150 | booktitle={European Conference on Computer Vision (ECCV)}, 151 | year={2022} 152 | } 153 | ``` 154 | 155 | ## License 156 | 157 | The majority of VPT is licensed under the CC-BY-NC 4.0 license (see [LICENSE](https://github.com/KMnP/vpt/blob/main/LICENSE) for details). Portions of the project are available under separate license terms: GitHub - [google-research/task_adaptation](https://github.com/google-research/task_adaptation) and [huggingface/transformers](https://github.com/huggingface/transformers) are licensed under the Apache 2.0 license; [Swin-Transformer](https://github.com/microsoft/Swin-Transformer), [ConvNeXt](https://github.com/facebookresearch/ConvNeXt) and [ViT-pytorch](https://github.com/jeonsworld/ViT-pytorch) are licensed under the MIT license; and [MoCo-v3](https://github.com/facebookresearch/moco-v3) and [MAE](https://github.com/facebookresearch/mae) are licensed under the Attribution-NonCommercial 4.0 International license. 158 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/diabetic_retinopathy.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements Diabetic Retinopathy data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow.compat.v1 as tf 24 | import tensorflow_addons.image as tfa_image 25 | import tensorflow_datasets as tfds 26 | 27 | from . import base as base 28 | from .registry import Registry 29 | 30 | 31 | @Registry.register("data.diabetic_retinopathy", "class") 32 | class RetinopathyData(base.ImageTfdsData): 33 | """Provides Diabetic Retinopathy classification data. 34 | 35 | Retinopathy comes only with a training and test set. 
Therefore, the validation 36 | set is split out of the original training set, and the remaining examples are 37 | used as the "train" split. The "trainval" split corresponds to the original 38 | training set. 39 | 40 | For additional details and usage, see the base class. 41 | """ 42 | 43 | _CONFIGS_WITH_GREY_BACKGROUND = ["btgraham-300"] 44 | 45 | def __init__(self, config="btgraham-300", heavy_train_augmentation=False, 46 | data_dir=None): 47 | """Initializer for Diabetic Retinopathy dataset. 48 | 49 | Args: 50 | config: Name of the TFDS config to use for this dataset. 51 | heavy_train_augmentation: If True, use heavy data augmentation on the 52 | training data. Recommended to achieve SOTA. 53 | data_dir: directory for downloading and storing the data. 54 | """ 55 | config_and_version = config + ":3.*.*" 56 | dataset_builder = tfds.builder("diabetic_retinopathy_detection/{}".format( 57 | config_and_version), data_dir=data_dir) 58 | self._config = config 59 | self._heavy_train_augmentation = heavy_train_augmentation 60 | 61 | dataset_builder.download_and_prepare() 62 | 63 | # Defines dataset specific train/val/trainval/test splits. 64 | tfds_splits = { 65 | "train": "train", 66 | "val": "validation", 67 | "trainval": "train+validation", 68 | "test": "test", 69 | "train800": "train[:800]", 70 | "val200": "validation[:200]", 71 | "train800val200": "train[:800]+validation[:200]", 72 | } 73 | 74 | # Creates a dict with example counts for each split. 75 | train_count = dataset_builder.info.splits["train"].num_examples 76 | val_count = dataset_builder.info.splits["validation"].num_examples 77 | test_count = dataset_builder.info.splits["test"].num_examples 78 | num_samples_splits = { 79 | "train": train_count, 80 | "val": val_count, 81 | "trainval": train_count + val_count, 82 | "test": test_count, 83 | "train800": 800, 84 | "val200": 200, 85 | "train800val200": 1000, 86 | } 87 | 88 | super(RetinopathyData, self).__init__( 89 | dataset_builder=dataset_builder, 90 | tfds_splits=tfds_splits, 91 | num_samples_splits=num_samples_splits, 92 | num_preprocessing_threads=400, 93 | shuffle_buffer_size=10000, 94 | # Note: Export only image and label tensors with their original types. 95 | base_preprocess_fn=base.make_get_tensors_fn(["image", "label"]), 96 | num_classes=dataset_builder.info.features["label"].num_classes) 97 | 98 | @property 99 | def config(self): 100 | return self._config 101 | 102 | @property 103 | def heavy_train_augmentation(self): 104 | return self._heavy_train_augmentation 105 | 106 | def get_tf_data(self, 107 | split_name, 108 | batch_size, 109 | preprocess_fn=None, 110 | for_eval=False, 111 | **kwargs): 112 | if self._heavy_train_augmentation and not for_eval: 113 | preprocess_fn = base.compose_preprocess_fn( 114 | self._heavy_train_augmentation, preprocess_fn) 115 | 116 | return super(RetinopathyData, self).get_tf_data( 117 | split_name=split_name, 118 | batch_size=batch_size, 119 | preprocess_fn=preprocess_fn, 120 | for_eval=for_eval, 121 | **kwargs) 122 | 123 | def _sample_heavy_data_augmentation_parameters(self): 124 | # Scale image +/- 10%. 125 | s = tf.random.uniform(shape=(), minval=-0.1, maxval=0.1) 126 | # Rotate image [0, 2pi). 127 | a = tf.random.uniform(shape=(), minval=0.0, maxval=2.0 * 3.1415926535) 128 | # Vertically shear image +/- 20%. 129 | b = tf.random.uniform(shape=(), minval=-0.2, maxval=0.2) + a 130 | # Horizontal and vertial flipping. 131 | hf = tf.random.shuffle([-1.0, 1.0])[0] 132 | vf = tf.random.shuffle([-1.0, 1.0])[0] 133 | # Relative x,y translation. 
134 | dx = tf.random.uniform(shape=(), minval=-0.1, maxval=0.1) 135 | dy = tf.random.uniform(shape=(), minval=-0.1, maxval=0.1) 136 | return s, a, b, hf, vf, dx, dy 137 | 138 | def _heavy_data_augmentation_fn(self, example): 139 | """Perform heavy augmentation on a given input data example. 140 | 141 | This is the same data augmentation as the one done by Ben Graham, the winner 142 | of the 2015 Kaggle competition. See: 143 | https://github.com/btgraham/SparseConvNet/blob/a6bdb0c938b3556c1e6c23d5a014db9f404502b9/kaggleDiabetes1.cpp#L12 144 | 145 | Args: 146 | example: A dictionary containing an "image" key with the image to 147 | augment. 148 | 149 | Returns: 150 | The input dictionary with the key "image" containing the augmented image. 151 | """ 152 | image = example["image"] 153 | image_shape = tf.shape(image) 154 | if len(image.get_shape().as_list()) not in [2, 3]: 155 | raise ValueError( 156 | "Input image must be a rank-2 or rank-3 tensor, but rank-{} " 157 | "was given".format(len(image.get_shape().as_list()))) 158 | height = tf.cast(image_shape[0], dtype=tf.float32) 159 | width = tf.cast(image_shape[1], dtype=tf.float32) 160 | # Sample data augmentation parameters. 161 | s, a, b, hf, vf, dx, dy = self._sample_heavy_data_augmentation_parameters() 162 | # Rotation + scale. 163 | c00 = (1 + s) * tf.cos(a) 164 | c01 = (1 + s) * tf.sin(a) 165 | c10 = (s - 1) * tf.sin(b) 166 | c11 = (1 - s) * tf.cos(b) 167 | # Horizontal and vertial flipping. 168 | c00 = c00 * hf 169 | c01 = c01 * hf 170 | c10 = c10 * vf 171 | c11 = c11 * vf 172 | # Convert x,y translation to absolute values. 173 | dx = width * dx 174 | dy = height * dy 175 | # Convert affine matrix to TF's transform. Matrix is applied w.r.t. the 176 | # center of the image. 177 | cy = height / 2.0 178 | cx = width / 2.0 179 | affine_matrix = [[c00, c01, (1.0 - c00) * cx - c01 * cy + dx], 180 | [c10, c11, (1.0 - c11) * cy - c10 * cx + dy], 181 | [0.0, 0.0, 1.0]] 182 | affine_matrix = tf.convert_to_tensor(affine_matrix, dtype=tf.float32) 183 | transform = tfa_image.transform_ops.matrices_to_flat_transforms( 184 | tf.linalg.inv(affine_matrix)) 185 | if self._config in self._CONFIGS_WITH_GREY_BACKGROUND: 186 | # Since background is grey in these configs, put in pixels in [-1, 1] 187 | # range to avoid artifacts from the affine transformation. 188 | image = tf.cast(image, dtype=tf.float32) 189 | image = (image / 127.5) - 1.0 190 | # Apply the affine transformation. 191 | image = tfa_image.transform(images=image, transforms=transform) 192 | if self._config in self._CONFIGS_WITH_GREY_BACKGROUND: 193 | # Put pixels back to [0, 255] range and cast to uint8, since this is what 194 | # our preprocessing pipeline usually expects. 195 | image = (1.0 + image) * 127.5 196 | image = tf.cast(image, dtype=tf.uint8) 197 | example["image"] = image 198 | return example 199 | --------------------------------------------------------------------------------
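For intuition about the affine construction in `_heavy_data_augmentation_fn` above: scale (`s`), rotation (`a`), shear (`b`), flips (`hf`, `vf`) and relative translation (`dx`, `dy`) are folded into one 3x3 matrix that, with zero translation, keeps the image center fixed; the matrix is then inverted because `tfa_image.transform` expects the output-to-input mapping. A minimal NumPy sketch of the same construction (illustration only, not repo code):

```python
import numpy as np

def heavy_aug_matrix(s, a, b, hf, vf, dx, dy, height, width):
    """NumPy mirror of the affine matrix built in _heavy_data_augmentation_fn."""
    c00 = (1 + s) * np.cos(a) * hf
    c01 = (1 + s) * np.sin(a) * hf
    c10 = (s - 1) * np.sin(b) * vf
    c11 = (1 - s) * np.cos(b) * vf
    cx, cy = width / 2.0, height / 2.0
    return np.array([
        [c00, c01, (1.0 - c00) * cx - c01 * cy + dx * width],
        [c10, c11, (1.0 - c11) * cy - c10 * cx + dy * height],
        [0.0, 0.0, 1.0],
    ])

# With zero translation the image center is a fixed point, even under flips:
m = heavy_aug_matrix(s=0.1, a=0.3, b=-0.2, hf=-1.0, vf=1.0, dx=0.0, dy=0.0,
                     height=300, width=300)
print(m @ np.array([150.0, 150.0, 1.0]))  # -> [150. 150.   1.]
```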