├── LICENSE ├── README.md ├── VTAB_SETUP.md ├── configs ├── base-finetune.yaml ├── base-linear.yaml ├── base-prompt.yaml ├── finetune │ ├── cars.yaml │ ├── cub.yaml │ ├── dogs.yaml │ ├── flowers.yaml │ └── nabirds.yaml ├── linear │ ├── cars.yaml │ ├── cub.yaml │ ├── dogs.yaml │ ├── flowers.yaml │ └── nabirds.yaml └── prompt │ ├── cars.yaml │ ├── cub.yaml │ ├── dogs.yaml │ ├── flowers.yaml │ └── nabirds.yaml ├── env_install.sh ├── launch.py ├── run.sh ├── src ├── configs │ ├── config.py │ ├── config_node.py │ └── vit_configs.py ├── data │ ├── datasets │ │ ├── json_dataset.py │ │ └── tf_dataset.py │ ├── loader.py │ ├── transforms.py │ └── vtab_datasets │ │ ├── __init__.py │ │ ├── base.py │ │ ├── caltech.py │ │ ├── cifar.py │ │ ├── clevr.py │ │ ├── diabetic_retinopathy.py │ │ ├── dmlab.py │ │ ├── dsprites.py │ │ ├── dtd.py │ │ ├── eurosat.py │ │ ├── kitti.py │ │ ├── oxford_flowers102.py │ │ ├── oxford_iiit_pet.py │ │ ├── patch_camelyon.py │ │ ├── registry.py │ │ ├── resisc45.py │ │ ├── smallnorb.py │ │ ├── sun397.py │ │ └── svhn.py ├── engine │ ├── eval │ │ ├── multilabel.py │ │ └── singlelabel.py │ ├── evaluator.py │ └── trainer.py ├── models │ ├── build_model.py │ ├── build_vit_backbone.py │ ├── mlp.py │ ├── vit_backbones │ │ ├── vit_mae.py │ │ └── vit_moco.py │ ├── vit_models.py │ └── vit_prompt │ │ ├── vit_mae.py │ │ └── vit_moco.py ├── solver │ ├── losses.py │ ├── lr_scheduler.py │ └── optimizer.py └── utils │ ├── distributed.py │ ├── distributed_orig.py │ ├── file_io.py │ ├── io_utils.py │ ├── logging.py │ ├── train_utils.py │ └── vis_utils.py └── train.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 ryongithub 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Gated Prompt Tuning 2 | This is the official PyTorch implementation for "Improving Visual Prompt Tuning for Self-supervised Vision Transformers" [ICML 2023]. 3 | 4 | This repository is heavily based on the official PyTorch implementation of "Visual Prompt Tuning" [ECCV 2022] : [KMnp/vpt](https://github.com/KMnP/vpt). 
5 | 6 | 7 | 8 | # Requirements 9 | - Python 3.8.12 10 | - PyTorch 1.7.1 11 | - torchvision 0.8.2 12 | - timm 0.5.4 13 | - CUDA 11.0 14 | - RTX 8000 GPU 15 | 16 | # Environment setup 17 | ``` 18 | conda create -n [ENV_NAME] python=3.8.12 -y 19 | conda activate [ENV_NAME] 20 | bash env_install.sh 21 | ``` 22 | 23 | # Data preparation 24 | - FGVC: The datasets should be located in the 'data' folder (CUB, OxfordFlowers, StanfordCars, StanfordDogs, NABirds) 25 | - VTAB: Please refer to [`VTAB_SETUP.md`](VTAB_SETUP.md) (in accordance with [KMnP/vpt](https://github.com/KMnP/vpt)) 26 | - A more detailed guide for data preparation will be added soon. 27 | 28 | # Pretrained SSL ViTs 29 | - Pretrained checkpoints for MAE and MoCo-v3 should be located in the 'params' folder. 30 | 31 | # Run experiments 32 | ``` 33 | bash run.sh [data_name] [encoder] [batch_size] [base_lr] [num_tokens] [gate_init] 34 | ``` 35 | For example, for the CUB dataset, execute 36 | ``` 37 | bash run.sh cub mae_vitb16 64 0.1 100 5 38 | ``` -------------------------------------------------------------------------------- /VTAB_SETUP.md: -------------------------------------------------------------------------------- 1 | # VTAB Preparation 2 | 3 | ## Download and prepare 4 | 5 | It is recommended to download the data before running the experiments, to avoid duplicated effort if submitting experiments for multiple tuning protocols. Here are the collected commands to set up the VTAB data. 6 | 7 | ```python 8 | import tensorflow_datasets as tfds 9 | data_dir = "" # TODO: set the data_dir to put the data in; this is the DATA.DATAPATH value in the config 10 | 11 | # caltech101 12 | dataset_builder = tfds.builder("caltech101:3.*.*", data_dir=data_dir) 13 | dataset_builder.download_and_prepare() 14 | 15 | # cifar100 16 | dataset_builder = tfds.builder("cifar100:3.*.*", data_dir=data_dir) 17 | dataset_builder.download_and_prepare() 18 | 19 | # clevr 20 | dataset_builder = tfds.builder("clevr:3.*.*", data_dir=data_dir) 21 | dataset_builder.download_and_prepare() 22 | 23 | # dmlab 24 | dataset_builder = tfds.builder("dmlab:2.0.1", data_dir=data_dir) 25 | dataset_builder.download_and_prepare() 26 | 27 | # dsprites 28 | dataset_builder = tfds.builder("dsprites:2.*.*", data_dir=data_dir) 29 | dataset_builder.download_and_prepare() 30 | 31 | # dtd 32 | dataset_builder = tfds.builder("dtd:3.*.*", data_dir=data_dir) 33 | dataset_builder.download_and_prepare() 34 | 35 | # eurosat 36 | subset="rgb" 37 | dataset_name = "eurosat/{}:2.*.*".format(subset) 38 | dataset_builder = tfds.builder(dataset_name, data_dir=data_dir) 39 | dataset_builder.download_and_prepare() 40 | 41 | # oxford_flowers102 42 | dataset_builder = tfds.builder("oxford_flowers102:2.*.*", data_dir=data_dir) 43 | dataset_builder.download_and_prepare() 44 | 45 | # oxford_iiit_pet 46 | dataset_builder = tfds.builder("oxford_iiit_pet:3.*.*", data_dir=data_dir) 47 | dataset_builder.download_and_prepare() 48 | 49 | # patch_camelyon 50 | dataset_builder = tfds.builder("patch_camelyon:2.*.*", data_dir=data_dir) 51 | dataset_builder.download_and_prepare() 52 | 53 | # smallnorb 54 | dataset_builder = tfds.builder("smallnorb:2.*.*", data_dir=data_dir) 55 | dataset_builder.download_and_prepare() 56 | 57 | # svhn 58 | dataset_builder = tfds.builder("svhn_cropped:3.*.*", data_dir=data_dir) 59 | dataset_builder.download_and_prepare() 60 | ``` 61 | 62 | There are 4 datasets that need special care: 63 | 64 | ```python 65 | # sun397 --> need cv2 66 | # cannot load one image, similar to issue here:
https://github.com/tensorflow/datasets/issues/2889 67 | # "Image /t/track/outdoor/sun_aophkoiosslinihb.jpg could not be decoded by Tensorflow."" 68 | # sol: modify the file: "/fsx/menglin/conda/envs/prompt_tf/lib/python3.7/site-packages/tensorflow_datasets/image_classification/sun.py" to ignore those images 69 | dataset_builder = tfds.builder("sun397/tfds:4.*.*", data_dir=data_dir) 70 | dataset_builder.download_and_prepare() 71 | 72 | # kitti version is wrong from vtab repo, try 3.2.0 (https://github.com/google-research/task_adaptation/issues/18) 73 | dataset_builder = tfds.builder("kitti:3.2.0", data_dir=data_dir) 74 | dataset_builder.download_and_prepare() 75 | 76 | 77 | # diabetic_retinopathy 78 | """ 79 | Download this dataset from Kaggle. 80 | https://www.kaggle.com/c/diabetic-retinopathy-detection/data 81 | After downloading, 82 | - unpack the test.zip file into /manual_dir/. 83 | - unpack the sample.zip to sample/. 84 | - unpack the sampleSubmissions.csv and trainLabels.csv. 85 | 86 | # ==== important! ==== 87 | # 1. make sure to check that there are 5 train.zip files instead of 4 (somehow if you chose to download all from kaggle, the train.zip.005 file is missing) 88 | # 2. if unzip train.zip ran into issues, try to use jar xvf train.zip to handle huge zip file 89 | cat test.zip.* > test.zip 90 | cat train.zip.* > train.zip 91 | """ 92 | 93 | config_and_version = "btgraham-300" + ":3.*.*" 94 | dataset_builder = tfds.builder("diabetic_retinopathy_detection/{}".format(config_and_version), data_dir=data_dir) 95 | dataset_builder.download_and_prepare() 96 | 97 | 98 | # resisc45 99 | """ 100 | download/extract dataset artifacts manually: 101 | Dataset can be downloaded from OneDrive: https://1drv.ms/u/s!AmgKYzARBl5ca3HNaHIlzp_IXjs 102 | After downloading the rar file, please extract it to the manual_dir. 103 | """ 104 | 105 | dataset_builder = tfds.builder("resisc45:3.*.*", data_dir=data_dir) 106 | dataset_builder.download_and_prepare() 107 | ``` 108 | 109 | 110 | 111 | ## Notes 112 | 113 | ### TFDS version 114 | Note that the experimental results may be different with different API and/or dataset generation code versions. See more from [tfds documentation](https://www.tensorflow.org/datasets/datasets_versioning). Here are what we used for VPT: 115 | 116 | ```bash 117 | tfds: 4.4.0+nightly 118 | 119 | # Natural: 120 | cifar100: 3.0.2 121 | caltech101: 3.0.1 122 | dtd: 3.0.1 123 | oxford_flowers102: 2.1.1 124 | oxford_iiit_pet: 3.2.0 125 | svhn_cropped: 3.0.0 126 | sun397: 4.0.0 127 | 128 | # Specialized: 129 | patch_camelyon: 2.0.0 130 | eurosat: 2.0.0 131 | resisc45: 3.0.0 132 | diabetic_retinopathy_detection: 3.0.0 133 | 134 | 135 | # Structured 136 | clevr: 3.1.0 137 | dmlab: 2.0.1 138 | kitti: 3.2.0 139 | dsprites: 2.0.0 140 | smallnorb: 2.0.0 141 | ``` 142 | 143 | ### Train split 144 | As in issue https://github.com/KMnP/vpt/issues/1, we also uploaded the vtab train split info to the vtab data release [Google Drive](https://drive.google.com/drive/folders/1mnvxTkYxmOr2W9QjcgS64UBpoJ4UmKaM)/[Dropbox](https://cornell.app.box.com/v/vptfgvcsplits). In the file `vtab_trainval_splits.json`, for each dataset, you can find the filenames of the randomly selected 1k training examples used in our experiment. We got them by extracting the ‘filename’ attribute from the tensorflow dataset feature dict. 
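For reference, here is a minimal sketch of how such filenames can be read out of a TFDS feature dict. This is illustrative only, not the exact selection script: the dataset, the `take(1000)` call, and the feature key are assumptions, and the key name varies per dataset (e.g. `image/file_name` for caltech101 but `file_name` for oxford_flowers102).

```python
import tensorflow_datasets as tfds

data_dir = ""  # same data_dir as above; the dataset must already be downloaded and prepared

# Sketch: read the filename feature of the first 1k training examples.
# Only works for datasets whose TFDS feature dict actually carries a filename field.
builder = tfds.builder("caltech101:3.0.1", data_dir=data_dir)
train_ds = builder.as_dataset(split="train")
filenames = [ex["image/file_name"].numpy().decode("utf-8") for ex in train_ds.take(1000)]
print(len(filenames), filenames[:3])
```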
Unfortunately, because there’s no such info for [dsprite](https://www.tensorflow.org/datasets/catalog/dsprites), [smallnorb](https://www.tensorflow.org/datasets/catalog/smallnorb) and [svhn](https://www.tensorflow.org/datasets/catalog/svhn_cropped) in the tensorflow dataset format, we cannot provide the splits for these 3 datasets. 145 | -------------------------------------------------------------------------------- /configs/base-finetune.yaml: -------------------------------------------------------------------------------- 1 | NUM_GPUS: 1 2 | NUM_SHARDS: 1 3 | OUTPUT_DIR: "" 4 | RUN_N_TIMES: 1 5 | MODEL: 6 | TRANSFER_TYPE: "end2end" 7 | TYPE: "vit" 8 | LINEAR: 9 | MLP_SIZES: [] 10 | SOLVER: 11 | SCHEDULER: "cosine" 12 | WARMUP_EPOCH: 5 13 | LOSS: "softmax" 14 | OPTIMIZER: "adamw" 15 | MOMENTUM: 0.9 16 | WEIGHT_DECAY: 0.0001 17 | LOG_EVERY_N: 100 18 | TOTAL_EPOCH: 100 19 | PATIENCE: 300 20 | DATA: 21 | NAME: "" 22 | NUMBER_CLASSES: -1 23 | DATAPATH: "" 24 | FEATURE: "sup_vitb16_224" 25 | BATCH_SIZE: 384 26 | NUM_WORKERS: 8 -------------------------------------------------------------------------------- /configs/base-linear.yaml: -------------------------------------------------------------------------------- 1 | NUM_GPUS: 1 2 | NUM_SHARDS: 1 3 | OUTPUT_DIR: "" 4 | RUN_N_TIMES: 1 5 | MODEL: 6 | TRANSFER_TYPE: "linear" 7 | TYPE: "vit" 8 | LINEAR: 9 | MLP_SIZES: [] 10 | SOLVER: 11 | SCHEDULER: "cosine" 12 | PATIENCE: 300 13 | LOSS: "softmax" 14 | OPTIMIZER: "sgd" 15 | MOMENTUM: 0.9 16 | WEIGHT_DECAY: 0.0001 17 | LOG_EVERY_N: 1 18 | WARMUP_EPOCH: 10 19 | TOTAL_EPOCH: 100 20 | DATA: 21 | NAME: "" 22 | NUMBER_CLASSES: -1 23 | DATAPATH: "" 24 | FEATURE: "sup_vitb16_224" 25 | BATCH_SIZE: 1024 -------------------------------------------------------------------------------- /configs/base-prompt.yaml: -------------------------------------------------------------------------------- 1 | NUM_GPUS: 1 2 | NUM_SHARDS: 1 3 | OUTPUT_DIR: "" 4 | RUN_N_TIMES: 1 5 | MODEL: 6 | TRANSFER_TYPE: "prompt" 7 | TYPE: "vit" 8 | LINEAR: 9 | MLP_SIZES: [] 10 | SOLVER: 11 | SCHEDULER: "cosine" 12 | PATIENCE: 300 13 | LOSS: "softmax" 14 | OPTIMIZER: "sgd" 15 | MOMENTUM: 0.9 16 | WEIGHT_DECAY: 0.0001 17 | LOG_EVERY_N: 100 18 | WARMUP_EPOCH: 10 19 | TOTAL_EPOCH: 100 20 | DATA: 21 | NAME: "" 22 | NUMBER_CLASSES: -1 23 | DATAPATH: "" 24 | FEATURE: "sup_vitb16_224" 25 | BATCH_SIZE: 128 26 | -------------------------------------------------------------------------------- /configs/finetune/cars.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-finetune.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "StanfordCars" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 196 7 | MULTILABEL: False 8 | FEATURE: "imagenet_supervised" # need to tune 9 | MODEL: 10 | TYPE: "vit" 11 | SOLVER: 12 | BASE_LR: 0.0375 13 | WEIGHT_DECAY: 0.001 14 | -------------------------------------------------------------------------------- /configs/finetune/cub.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-finetune.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "CUB" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 200 7 | MULTILABEL: False 8 | FEATURE: "imagenet_supervised" # need to tune 9 | MODEL: 10 | TYPE: "vit" 11 | SOLVER: 12 | BASE_LR: 0.00375 13 | WEIGHT_DECAY: 0.01 14 | -------------------------------------------------------------------------------- /configs/finetune/dogs.yaml: 
-------------------------------------------------------------------------------- 1 | _BASE_: "../base-finetune.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "StanfordDogs" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 120 7 | MULTILABEL: False 8 | FEATURE: "imagenet_supervised" # need to tune 9 | MODEL: 10 | TYPE: "vit" 11 | SOLVER: 12 | BASE_LR: 0.00375 13 | WEIGHT_DECAY: 0.001 14 | -------------------------------------------------------------------------------- /configs/finetune/flowers.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-finetune.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "OxfordFlowers" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 102 7 | MULTILABEL: False 8 | FEATURE: "imagenet_supervised" # need to tune 9 | MODEL: 10 | TYPE: "vit" 11 | SOLVER: 12 | BASE_LR: 0.001 13 | WEIGHT_DECAY: 0.0001 14 | -------------------------------------------------------------------------------- /configs/finetune/nabirds.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-finetune.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "nabirds" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 555 7 | MULTILABEL: False 8 | FEATURE: "imagenet_supervised" # need to tune 9 | MODEL: 10 | TYPE: "vit" 11 | SOLVER: 12 | BASE_LR: 0.00375 13 | WEIGHT_DECAY: 0.01 14 | -------------------------------------------------------------------------------- /configs/linear/cars.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-linear.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "StanfordCars" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 196 7 | MULTILABEL: False 8 | MODEL: 9 | TYPE: "vit" 10 | SOLVER: 11 | BASE_LR: 0.001 12 | WEIGHT_DECAY: 0.0001 13 | -------------------------------------------------------------------------------- /configs/linear/cub.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-linear.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "CUB" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 200 7 | MULTILABEL: False 8 | MODEL: 9 | TYPE: "vit" 10 | SOLVER: 11 | BASE_LR: 0.1 12 | WEIGHT_DECAY: 0.01 -------------------------------------------------------------------------------- /configs/linear/dogs.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-linear.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "StanfordDogs" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 120 7 | MULTILABEL: False 8 | MODEL: 9 | TYPE: "vit" 10 | SOLVER: 11 | BASE_LR: 0.001 12 | WEIGHT_DECAY: 0.0001 -------------------------------------------------------------------------------- /configs/linear/flowers.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-linear.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "OxfordFlowers" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 102 7 | MULTILABEL: False 8 | MODEL: 9 | TYPE: "vit" 10 | SOLVER: 11 | BASE_LR: 0.001 12 | WEIGHT_DECAY: 0.0001 -------------------------------------------------------------------------------- /configs/linear/nabirds.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-linear.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "nabirds" 5 | DATAPATH: "" #TODO: need to 
specify here 6 | NUMBER_CLASSES: 555 7 | MULTILABEL: False 8 | MODEL: 9 | TYPE: "vit" 10 | SOLVER: 11 | BASE_LR: 0.1 12 | WEIGHT_DECAY: 0.01 13 | -------------------------------------------------------------------------------- /configs/prompt/cars.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-prompt.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "StanfordCars" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 196 7 | MULTILABEL: False 8 | MODEL: 9 | TYPE: "vit" 10 | SOLVER: 11 | BASE_LR: 0.001 12 | WEIGHT_DECAY: 0.0001 13 | -------------------------------------------------------------------------------- /configs/prompt/cub.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-prompt.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "CUB" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 200 7 | MULTILABEL: False 8 | MODEL: 9 | TYPE: "vit" 10 | SOLVER: 11 | BASE_LR: 0.1 12 | WEIGHT_DECAY: 0.01 -------------------------------------------------------------------------------- /configs/prompt/dogs.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-prompt.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "StanfordDogs" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 120 7 | MULTILABEL: False 8 | MODEL: 9 | TYPE: "vit" 10 | SOLVER: 11 | BASE_LR: 0.001 12 | WEIGHT_DECAY: 0.0001 -------------------------------------------------------------------------------- /configs/prompt/flowers.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-prompt.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "OxfordFlowers" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 102 7 | MULTILABEL: False 8 | MODEL: 9 | TYPE: "vit" 10 | SOLVER: 11 | BASE_LR: 0.001 12 | WEIGHT_DECAY: 0.0001 -------------------------------------------------------------------------------- /configs/prompt/nabirds.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-prompt.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "nabirds" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 555 7 | MULTILABEL: False 8 | MODEL: 9 | TYPE: "vit" 10 | SOLVER: 11 | BASE_LR: 0.1 12 | WEIGHT_DECAY: 0.01 -------------------------------------------------------------------------------- /env_install.sh: -------------------------------------------------------------------------------- 1 | pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 -f https://download.pytorch.org/whl/torch_stable.html 2 | # for VTAB datasets 3 | pip install tensorflow==2.9.1 tensorflow-addons==0.17.1 keras==2.9.0 tfds-nightly==4.4.0.dev202201080107 4 | 5 | pip install tqdm fvcore==0.1.5.post20220512 pandas six simplejson scikit-learn timm==0.5.4 ml_collections 6 | -------------------------------------------------------------------------------- /launch.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | launch helper functions 4 | """ 5 | import argparse 6 | import os 7 | import sys 8 | import pprint 9 | import PIL 10 | from collections import defaultdict 11 | from tabulate import tabulate 12 | from typing import Tuple 13 | 14 | import torch 15 | from src.utils.file_io import PathManager 16 | from src.utils import logging 17 | from src.utils.distributed import get_rank, get_world_size 18 
| 19 | 20 | def collect_torch_env() -> str: 21 | try: 22 | import torch.__config__ 23 | 24 | return torch.__config__.show() 25 | except ImportError: 26 | # compatible with older versions of pytorch 27 | from torch.utils.collect_env import get_pretty_env_info 28 | 29 | return get_pretty_env_info() 30 | 31 | 32 | def get_env_module() -> Tuple[str]: 33 | var_name = "ENV_MODULE" 34 | return var_name, os.environ.get(var_name, "") 35 | 36 | 37 | def collect_env_info() -> str: 38 | data = [] 39 | data.append(("Python", sys.version.replace("\n", ""))) 40 | data.append(get_env_module()) 41 | data.append(("PyTorch", torch.__version__)) 42 | data.append(("PyTorch Debug Build", torch.version.debug)) 43 | 44 | has_cuda = torch.cuda.is_available() 45 | data.append(("CUDA available", has_cuda)) 46 | if has_cuda: 47 | data.append(("CUDA ID", os.environ["CUDA_VISIBLE_DEVICES"])) 48 | devices = defaultdict(list) 49 | for k in range(torch.cuda.device_count()): 50 | devices[torch.cuda.get_device_name(k)].append(str(k)) 51 | for name, devids in devices.items(): 52 | data.append(("GPU " + ",".join(devids), name)) 53 | data.append(("Pillow", PIL.__version__)) 54 | 55 | try: 56 | import cv2 57 | 58 | data.append(("cv2", cv2.__version__)) 59 | except ImportError: 60 | pass 61 | env_str = tabulate(data) + "\n" 62 | env_str += collect_torch_env() 63 | return env_str 64 | 65 | 66 | def default_argument_parser(): 67 | """ 68 | create a simple parser to wrap around config file 69 | """ 70 | parser = argparse.ArgumentParser(description="visual-prompt") 71 | parser.add_argument( 72 | "--config-file", default="", metavar="FILE", help="path to config file") 73 | parser.add_argument( 74 | "--train-type", default="", help="training types") 75 | parser.add_argument( 76 | "opts", 77 | help="Modify config options using the command-line", 78 | default=None, 79 | nargs=argparse.REMAINDER, 80 | ) 81 | from datetime import datetime 82 | time_fmt = datetime.now().strftime("%y%m%d-%H-%M-%S") 83 | parser.add_argument('--id', default=time_fmt) 84 | 85 | return parser 86 | 87 | 88 | def logging_train_setup(args, cfg) -> None: 89 | output_dir = cfg.OUTPUT_DIR 90 | if output_dir: 91 | PathManager.mkdirs(output_dir) 92 | 93 | logger = logging.setup_logging( 94 | cfg.NUM_GPUS, get_world_size(), output_dir, name="visual_prompt") 95 | 96 | # Log basic information about environment, cmdline arguments, and config 97 | rank = get_rank() 98 | logger.info( 99 | f"Rank of current process: {rank}. World size: {get_world_size()}") 100 | logger.info("Environment info:\n" + collect_env_info()) 101 | 102 | logger.info("Command line arguments: " + str(args)) 103 | if hasattr(args, "config_file") and args.config_file != "": 104 | logger.info( 105 | "Contents of args.config_file={}:\n{}".format( 106 | args.config_file, 107 | PathManager.open(args.config_file, "r").read() 108 | ) 109 | ) 110 | # Show the config 111 | logger.info("Training with config:") 112 | logger.info(pprint.pformat(cfg)) 113 | # cudnn benchmark has large overhead. 114 | # It shouldn't be used considering the small size of typical val set. 
115 | if not (hasattr(args, "eval_only") and args.eval_only): 116 | torch.backends.cudnn.benchmark = cfg.CUDNN_BENCHMARK 117 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | model_root=params 2 | data_name=$1 3 | encoder=$2 4 | batch_size=$3 5 | base_lr=$4 6 | num_tokens=$5 7 | gate_init=$6 8 | output_dir=results 9 | model_type="ssl-vit" 10 | 11 | 12 | if [ $data_name = "cub" ] 13 | then 14 | data_path=./data/CUB_200_2011 15 | data_name="CUB" 16 | num_class=200 17 | elif [ $data_name = "dogs" ] 18 | then 19 | data_path=./data/Dogs 20 | data_name="StanfordDogs" 21 | num_class=120 22 | elif [ $data_name = "cars" ] 23 | then 24 | data_path=./data/Cars 25 | data_name="StanfordCars" 26 | num_class=196 27 | elif [ $data_name = "nabirds" ] 28 | then 29 | data_path=./data/nabirds 30 | data_name="nabirds" 31 | num_class=555 32 | elif [ $data_name = "flowers" ] 33 | then 34 | data_path=./data/Flowers 35 | data_name="OxfordFlowers" 36 | num_class=102 37 | elif [ $data_name = "vtab-flowers" ] 38 | then 39 | data_path=~/datasets 40 | data_name="vtab-oxford_flowers102" 41 | num_class=102 42 | elif [ $data_name = "sun397" ] 43 | then 44 | data_path=~/datasets 45 | data_name="vtab-sun397" 46 | num_class=397 47 | elif [ $data_name = "pets" ] 48 | then 49 | data_path=~/datasets/ 50 | data_name="vtab-oxford_iiit_pet" 51 | num_class=37 52 | elif [ $data_name = "dmlab" ] 53 | then 54 | data_path=~/datasets 55 | data_name="vtab-dmlab" 56 | num_class=6 57 | elif [ $data_name = "clevr-distance" ] 58 | then 59 | data_path=~/datasets 60 | data_name='vtab-clevr(task="closest_object_distance")' 61 | num_class=6 62 | elif [ $data_name = "clevr-count" ] 63 | then 64 | data_path=~/datasets 65 | data_name='vtab-clevr(task="count_all")' 66 | num_class=8 67 | elif [ $data_name = "caltech101" ] 68 | then 69 | data_path=./data 70 | data_name="vtab-caltech101" 71 | num_class=102 72 | elif [ $data_name = "cifar100" ] 73 | then 74 | data_path=./data 75 | data_name="vtab-cifar(num_classes=100)" 76 | num_class=100 77 | elif [ $data_name = "dsprites-orientation" ] 78 | then 79 | data_path=~/datasets 80 | data_name='vtab-dsprites(predicted_attribute="label_orientation",num_classes=16)' 81 | num_class=16 82 | elif [ $data_name = "dsprites-location" ] 83 | then 84 | data_path=~/datasets 85 | data_name='vtab-dsprites(predicted_attribute="label_x_position",num_classes=16)' 86 | num_class=16 87 | elif [ $data_name = "dtd" ] 88 | then 89 | data_path=./data 90 | data_name="vtab-dtd" 91 | num_class=47 92 | elif [ $data_name = "eurosat" ] 93 | then 94 | data_path=~/datasets 95 | data_name="vtab-eurosat" 96 | num_class=10 97 | elif [ $data_name = "resisc" ] 98 | then 99 | data_path=~/datasets 100 | data_name="vtab-resisc45" 101 | num_class=45 102 | elif [ $data_name = "smallnorb-azimuth" ] 103 | then 104 | data_path=~/datasets 105 | data_name='vtab-smallnorb(predicted_attribute="label_azimuth")' 106 | num_class=18 107 | elif [ $data_name = "smallnorb-elevation" ] 108 | then 109 | data_path=~/datasets 110 | data_name='vtab-smallnorb(predicted_attribute="label_elevation")' 111 | num_class=9 112 | elif [ $data_name = "patch" ] 113 | then 114 | data_path=~/datasets 115 | data_name="vtab-patch_camelyon" 116 | num_class=2 117 | elif [ $data_name = "kitti" ] 118 | then 119 | data_path=~/datasets 120 | data_name='vtab-kitti(task="closest_vehicle_distance")' 121 | num_class=4 122 | elif [ $data_name = "svhn" ] 123 | then 124 
| data_path=~/datasets 125 | data_name="vtab-svhn" 126 | num_class=10 127 | elif [ $data_name = "retino" ] 128 | then 129 | data_path=~/datasets 130 | data_name='vtab-diabetic_retinopathy(config="btgraham-300")' 131 | num_class=5 132 | fi 133 | 134 | 135 | seed=42 136 | echo $data_name 137 | echo $data_path 138 | echo $encoder 139 | 140 | 141 | python3 train.py \ 142 | --config-file configs/base-prompt.yaml \ 143 | DATA.BATCH_SIZE "${batch_size}" \ 144 | DATA.CROPSIZE "224" \ 145 | MODEL.PROMPT.NUM_TOKENS "${num_tokens}" \ 146 | SOLVER.WEIGHT_DECAY "0.0" \ 147 | SOLVER.BASE_LR "${base_lr}" \ 148 | MODEL.PROMPT.DROPOUT "0.1" \ 149 | SEED ${seed} \ 150 | MODEL.TYPE "${model_type}" \ 151 | MODEL.PROMPT.DEEP "False" \ 152 | MODEL.MODEL_ROOT "${model_root}" \ 153 | DATA.DATAPATH "${data_path}" \ 154 | DATA.NAME "${data_name}" \ 155 | DATA.FEATURE "${encoder}" \ 156 | DATA.NUMBER_CLASSES "${num_class}" \ 157 | MODEL.TRANSFER_TYPE "prompt" \ 158 | MODEL.PROMPT.INITIATION "random" \ 159 | MODEL.PROMPT.TEMP_LEARN "True" \ 160 | MODEL.PROMPT.GATE_PRIOR "True" \ 161 | MODEL.PROMPT.GATE_INIT "${gate_init}" \ 162 | MODEL.PROMPT.VIT_POOL_TYPE "original" \ 163 | OUTPUT_DIR "${output_dir}/seed${seed}" \ 164 | 165 | -------------------------------------------------------------------------------- /src/configs/config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """Config system (based on Detectron's).""" 4 | 5 | from .config_node import CfgNode 6 | 7 | 8 | # Global config object 9 | _C = CfgNode() 10 | # Example usage: 11 | # from configs.config import cfg 12 | 13 | _C.DBG = False 14 | _C.OUTPUT_DIR = "./output" 15 | _C.RUN_N_TIMES = 5 16 | # Perform benchmarking to select the fastest CUDNN algorithms to use 17 | # Note that this may increase the memory usage and will likely not result 18 | # in overall speedups when variable size inputs are used (e.g. 
COCO training) 19 | _C.CUDNN_BENCHMARK = False 20 | 21 | # Number of GPUs to use (applies to both training and testing) 22 | _C.NUM_GPUS = 1 23 | _C.NUM_SHARDS = 1 24 | 25 | # Note that non-determinism may still be present due to non-deterministic 26 | # operator implementations in GPU operator libraries 27 | _C.SEED = None 28 | 29 | # ---------------------------------------------------------------------- 30 | # Model options 31 | # ---------------------------------------------------------------------- 32 | _C.MODEL = CfgNode() 33 | _C.MODEL.TRANSFER_TYPE = "linear" # one of linear, end2end, prompt, adapter, side, partial-1, tinytl-bias 34 | _C.MODEL.WEIGHT_PATH = "" # if resume from some checkpoint file 35 | _C.MODEL.SAVE_CKPT = True 36 | 37 | _C.MODEL.MODEL_ROOT = "" # root folder for pretrained model weights 38 | 39 | _C.MODEL.TYPE = "vit" 40 | _C.MODEL.MLP_NUM = 0 41 | 42 | _C.MODEL.LINEAR = CfgNode() 43 | _C.MODEL.LINEAR.MLP_SIZES = [] 44 | _C.MODEL.LINEAR.DROPOUT = 0.1 45 | 46 | # ---------------------------------------------------------------------- 47 | # Prompt options 48 | # ---------------------------------------------------------------------- 49 | _C.MODEL.PROMPT = CfgNode() 50 | _C.MODEL.PROMPT.NUM_TOKENS = 5 51 | _C.MODEL.PROMPT.LOCATION = "prepend" 52 | 53 | # prompt initialization: 54 | # (1) default "random" 55 | # (2) "final-cls": use aggregated final [cls] embeddings from the training dataset 56 | # (3) "cls-nolastl": use the first 12 cls embeddings (excluding the final output) for deep prompt 57 | # (4) "cls-nofirstl": use the last 12 cls embeddings (excluding the input to the first layer) 58 | _C.MODEL.PROMPT.INITIATION = "random" # "final-cls", "cls-first12" 59 | _C.MODEL.PROMPT.CLSEMB_FOLDER = "" 60 | _C.MODEL.PROMPT.CLSEMB_PATH = "" 61 | _C.MODEL.PROMPT.PROJECT = -1 # projection mlp hidden dim 62 | _C.MODEL.PROMPT.DEEP = False # whether to do deep prompt or not, only for prepend location 63 | 64 | 65 | _C.MODEL.PROMPT.NUM_DEEP_LAYERS = None # if set to an int, then do partial-deep prompt tuning 66 | _C.MODEL.PROMPT.REVERSE_DEEP = False # whether to only update the last n layers, not the input layer 67 | _C.MODEL.PROMPT.DEEP_SHARED = False # if true, all deep layers will use the same prompt emb 68 | _C.MODEL.PROMPT.FORWARD_DEEP_NOEXPAND = False # if true, will not expand the input sequence for layers without prompts 69 | # how to get the output emb for the cls head: 70 | # original: follow the original backbone choice, 71 | # img_pool: image patch pool only 72 | # prompt_pool: prompt emb pool only 73 | # imgprompt_pool: pool everything but the cls token 74 | _C.MODEL.PROMPT.VIT_POOL_TYPE = "original" 75 | _C.MODEL.PROMPT.DROPOUT = 0.0 76 | _C.MODEL.PROMPT.SAVE_FOR_EACH_EPOCH = False 77 | 78 | _C.MODEL.PROMPT.GATE_PRIOR = False 79 | _C.MODEL.PROMPT.GATE_NUM = 11 80 | _C.MODEL.PROMPT.GATE_INIT = 10 81 | _C.MODEL.PROMPT.TEMP = 1.0 82 | _C.MODEL.PROMPT.TEMP_LEARN = False 83 | _C.MODEL.PROMPT.TEMP_NUM = 12 84 | _C.MODEL.PROMPT.TEMP_MIN = 0.01 85 | _C.MODEL.PROMPT.TEMP_MAX = 10.0 86 | # _C.MODEL.PROMPT.TEMP_MIN = 0.05 87 | # _C.MODEL.PROMPT.TEMP_MAX = 5.0 88 | 89 | # ---------------------------------------------------------------------- 90 | # adapter options 91 | # ---------------------------------------------------------------------- 92 | _C.MODEL.ADAPTER = CfgNode() 93 | _C.MODEL.ADAPTER.REDUCATION_FACTOR = 8 94 | _C.MODEL.ADAPTER.STYLE = "Pfeiffer" 95 | 96 | # ---------------------------------------------------------------------- 97 | # Solver options 98 | # 
---------------------------------------------------------------------- 99 | _C.SOLVER = CfgNode() 100 | _C.SOLVER.LOSS = "softmax" 101 | _C.SOLVER.LOSS_ALPHA = 0.01 102 | 103 | _C.SOLVER.OPTIMIZER = "sgd" # or "adamw" 104 | _C.SOLVER.MOMENTUM = 0.9 105 | _C.SOLVER.WEIGHT_DECAY = 0.0001 106 | _C.SOLVER.WEIGHT_DECAY_BIAS = 0 107 | 108 | _C.SOLVER.PATIENCE = 300 109 | 110 | 111 | _C.SOLVER.SCHEDULER = "cosine" 112 | 113 | _C.SOLVER.BASE_LR = 0.01 114 | _C.SOLVER.BIAS_MULTIPLIER = 1. # for prompt + bias 115 | 116 | _C.SOLVER.WARMUP_EPOCH = 5 117 | _C.SOLVER.TOTAL_EPOCH = 30 118 | _C.SOLVER.LOG_EVERY_N = 1000 119 | 120 | 121 | _C.SOLVER.DBG_TRAINABLE = False # if True, will print the name of trainable params 122 | 123 | # ---------------------------------------------------------------------- 124 | # Dataset options 125 | # ---------------------------------------------------------------------- 126 | _C.DATA = CfgNode() 127 | 128 | _C.DATA.NAME = "" 129 | _C.DATA.DATAPATH = "" 130 | _C.DATA.FEATURE = "" # e.g. inat2021_supervised 131 | 132 | _C.DATA.PERCENTAGE = 1.0 133 | _C.DATA.NUMBER_CLASSES = -1 134 | _C.DATA.MULTILABEL = False 135 | _C.DATA.CLASS_WEIGHTS_TYPE = "none" 136 | 137 | _C.DATA.CROPSIZE = 224 # or 384 138 | 139 | _C.DATA.NO_TEST = False 140 | _C.DATA.BATCH_SIZE = 32 141 | # Number of data loader workers per training process 142 | _C.DATA.NUM_WORKERS = 4 143 | # Load data to pinned host memory 144 | _C.DATA.PIN_MEMORY = True 145 | 146 | 147 | _C.DIST_BACKEND = "nccl" 148 | _C.DIST_INIT_PATH = "env://" 149 | _C.DIST_INIT_FILE = "" 150 | 151 | 152 | def get_cfg(): 153 | """ 154 | Get a copy of the default config. 155 | """ 156 | return _C.clone() 157 | -------------------------------------------------------------------------------- /src/configs/config_node.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """Config system (based on Detectron's).""" 4 | 5 | from fvcore.common.config import CfgNode as _CfgNode 6 | from ..utils.file_io import PathManager 7 | 8 | 9 | class CfgNode(_CfgNode): 10 | """ 11 | The same as `fvcore.common.config.CfgNode`, but different in: 12 | 13 | support manifold path 14 | """ 15 | 16 | @classmethod 17 | def _open_cfg(cls, filename): 18 | return PathManager.open(filename, "r") 19 | 20 | def dump(self, *args, **kwargs): 21 | """ 22 | Returns: 23 | str: a yaml string representation of the config 24 | """ 25 | # to make it show up in docs 26 | return super().dump(*args, **kwargs) 27 | -------------------------------------------------------------------------------- /src/configs/vit_configs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Copyright (c) Meta Platforms, Inc. 
All Rights Reserved 4 | https://github.com/jeonsworld/ViT-pytorch/blob/main/models/configs.py 5 | """ 6 | import ml_collections 7 | 8 | 9 | def get_testing(): 10 | """Returns a minimal configuration for testing.""" 11 | config = ml_collections.ConfigDict() 12 | config.patches = ml_collections.ConfigDict({'size': (16, 16)}) 13 | config.hidden_size = 1 14 | config.transformer = ml_collections.ConfigDict() 15 | config.transformer.mlp_dim = 1 16 | config.transformer.num_heads = 1 17 | config.transformer.num_layers = 1 18 | config.transformer.attention_dropout_rate = 0.0 19 | config.transformer.dropout_rate = 0.1 20 | config.classifier = 'token' 21 | config.representation_size = None 22 | return config 23 | 24 | 25 | def get_b16_config(): 26 | """Returns the ViT-B/16 configuration.""" 27 | config = ml_collections.ConfigDict() 28 | config.patches = ml_collections.ConfigDict({'size': (16, 16)}) 29 | config.hidden_size = 768 30 | config.transformer = ml_collections.ConfigDict() 31 | config.transformer.mlp_dim = 3072 32 | config.transformer.num_heads = 12 33 | config.transformer.num_layers = 12 34 | config.transformer.attention_dropout_rate = 0.0 35 | config.transformer.dropout_rate = 0.1 36 | config.classifier = 'token' 37 | config.representation_size = None 38 | return config 39 | 40 | 41 | def get_r50_b16_config(): 42 | """Returns the Resnet50 + ViT-B/16 configuration.""" 43 | config = get_b16_config() 44 | del config.patches.size 45 | config.patches.grid = (14, 14) 46 | config.resnet = ml_collections.ConfigDict() 47 | config.resnet.num_layers = (3, 4, 9) 48 | config.resnet.width_factor = 1 49 | return config 50 | 51 | 52 | def get_b32_config(): 53 | """Returns the ViT-B/32 configuration.""" 54 | config = get_b16_config() 55 | config.patches.size = (32, 32) 56 | return config 57 | 58 | 59 | def get_b8_config(): 60 | """Returns the ViT-B/8 configuration.""" 61 | config = get_b16_config() 62 | config.patches.size = (8, 8) 63 | return config 64 | 65 | 66 | def get_l16_config(): 67 | """Returns the ViT-L/16 configuration.""" 68 | config = ml_collections.ConfigDict() 69 | config.patches = ml_collections.ConfigDict({'size': (16, 16)}) 70 | config.hidden_size = 1024 71 | config.transformer = ml_collections.ConfigDict() 72 | config.transformer.mlp_dim = 4096 73 | config.transformer.num_heads = 16 74 | config.transformer.num_layers = 24 75 | config.transformer.attention_dropout_rate = 0.0 76 | config.transformer.dropout_rate = 0.1 77 | config.classifier = 'token' 78 | config.representation_size = None 79 | return config 80 | 81 | 82 | def get_l32_config(): 83 | """Returns the ViT-L/32 configuration.""" 84 | config = get_l16_config() 85 | config.patches.size = (32, 32) 86 | return config 87 | 88 | 89 | def get_h14_config(): 90 | """Returns the ViT-H/14 configuration.""" 91 | config = ml_collections.ConfigDict() 92 | config.patches = ml_collections.ConfigDict({'size': (14, 14)}) 93 | config.hidden_size = 1280 94 | config.transformer = ml_collections.ConfigDict() 95 | config.transformer.mlp_dim = 5120 96 | config.transformer.num_heads = 16 97 | config.transformer.num_layers = 32 98 | config.transformer.attention_dropout_rate = 0.0 99 | config.transformer.dropout_rate = 0.1 100 | config.classifier = 'token' 101 | config.representation_size = None 102 | return config 103 | -------------------------------------------------------------------------------- /src/data/datasets/json_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 
"""JSON dataset: support CUB, NABrids, Flower, Dogs and Cars""" 4 | 5 | import os 6 | import torch 7 | import torch.utils.data 8 | import torchvision as tv 9 | import numpy as np 10 | from collections import Counter 11 | 12 | from ..transforms import get_transforms 13 | from ...utils import logging 14 | from ...utils.io_utils import read_json 15 | logger = logging.get_logger("visual_prompt") 16 | 17 | 18 | class JSONDataset(torch.utils.data.Dataset): 19 | def __init__(self, cfg, split): 20 | assert split in { 21 | "train", 22 | "val", 23 | "test", 24 | }, "Split '{}' not supported for {} dataset".format( 25 | split, cfg.DATA.NAME) 26 | logger.info("Constructing {} dataset {}...".format( 27 | cfg.DATA.NAME, split)) 28 | 29 | self.cfg = cfg 30 | self._split = split 31 | self.name = cfg.DATA.NAME 32 | self.data_dir = cfg.DATA.DATAPATH 33 | self.data_percentage = cfg.DATA.PERCENTAGE 34 | self._construct_imdb(cfg) 35 | self.transform = get_transforms(split, cfg.DATA.CROPSIZE) 36 | 37 | def get_anno(self): 38 | anno_path = os.path.join(self.data_dir, "{}.json".format(self._split)) 39 | if "train" in self._split: 40 | if self.data_percentage < 1.0: 41 | anno_path = os.path.join( 42 | self.data_dir, 43 | "{}_{}.json".format(self._split, self.data_percentage) 44 | ) 45 | assert os.path.exists(anno_path), "{} dir not found".format(anno_path) 46 | 47 | return read_json(anno_path) 48 | 49 | def get_imagedir(self): 50 | raise NotImplementedError() 51 | 52 | def _construct_imdb(self, cfg): 53 | """Constructs the imdb.""" 54 | 55 | img_dir = self.get_imagedir() 56 | assert os.path.exists(img_dir), "{} dir not found".format(img_dir) 57 | 58 | anno = self.get_anno() 59 | # Map class ids to contiguous ids 60 | self._class_ids = sorted(list(set(anno.values()))) 61 | self._class_id_cont_id = {v: i for i, v in enumerate(self._class_ids)} 62 | 63 | # Construct the image db 64 | self._imdb = [] 65 | for img_name, cls_id in anno.items(): 66 | cont_id = self._class_id_cont_id[cls_id] 67 | im_path = os.path.join(img_dir, img_name) 68 | self._imdb.append({"im_path": im_path, "class": cont_id}) 69 | 70 | logger.info("Number of images: {}".format(len(self._imdb))) 71 | logger.info("Number of classes: {}".format(len(self._class_ids))) 72 | 73 | def get_info(self): 74 | num_imgs = len(self._imdb) 75 | return num_imgs, self.get_class_num(), self.name 76 | 77 | def get_class_num(self): 78 | return self.cfg.DATA.NUMBER_CLASSES 79 | # return len(self._class_ids) 80 | 81 | def get_class_weights(self, weight_type): 82 | """get a list of class weight, return a list float""" 83 | if "train" not in self._split: 84 | raise ValueError( 85 | "only getting training class distribution, " + \ 86 | "got split {} instead".format(self._split) 87 | ) 88 | 89 | cls_num = self.get_class_num() 90 | if weight_type == "none": 91 | return [1.0] * cls_num 92 | 93 | id2counts = Counter(self._class_ids) 94 | assert len(id2counts) == cls_num 95 | num_per_cls = np.array([id2counts[i] for i in self._class_ids]) 96 | 97 | if weight_type == 'inv': 98 | mu = -1.0 99 | elif weight_type == 'inv_sqrt': 100 | mu = -0.5 101 | weight_list = num_per_cls ** mu 102 | weight_list = np.divide( 103 | weight_list, np.linalg.norm(weight_list, 1)) * cls_num 104 | return weight_list.tolist() 105 | 106 | def __getitem__(self, index): 107 | # Load the image 108 | im = tv.datasets.folder.default_loader(self._imdb[index]["im_path"]) 109 | label = self._imdb[index]["class"] 110 | im = self.transform(im) 111 | if self._split == "train": 112 | index = index 113 | else: 114 | 
index = f"{self._split}{index}" 115 | sample = { 116 | "image": im, 117 | "label": label, 118 | # "id": index 119 | } 120 | return sample 121 | 122 | def __len__(self): 123 | return len(self._imdb) 124 | 125 | 126 | class CUB200Dataset(JSONDataset): 127 | """CUB_200 dataset.""" 128 | 129 | def __init__(self, cfg, split): 130 | super(CUB200Dataset, self).__init__(cfg, split) 131 | 132 | def get_imagedir(self): 133 | return os.path.join(self.data_dir, "images") 134 | 135 | 136 | class CarsDataset(JSONDataset): 137 | """stanford-cars dataset.""" 138 | 139 | def __init__(self, cfg, split): 140 | super(CarsDataset, self).__init__(cfg, split) 141 | 142 | def get_imagedir(self): 143 | return self.data_dir 144 | 145 | 146 | class DogsDataset(JSONDataset): 147 | """stanford-dogs dataset.""" 148 | 149 | def __init__(self, cfg, split): 150 | super(DogsDataset, self).__init__(cfg, split) 151 | 152 | def get_imagedir(self): 153 | return os.path.join(self.data_dir, "Images") 154 | 155 | 156 | class FlowersDataset(JSONDataset): 157 | """flowers dataset.""" 158 | 159 | def __init__(self, cfg, split): 160 | super(FlowersDataset, self).__init__(cfg, split) 161 | 162 | def get_imagedir(self): 163 | return self.data_dir 164 | 165 | 166 | class NabirdsDataset(JSONDataset): 167 | """Nabirds dataset.""" 168 | 169 | def __init__(self, cfg, split): 170 | super(NabirdsDataset, self).__init__(cfg, split) 171 | 172 | def get_imagedir(self): 173 | return os.path.join(self.data_dir, "images") 174 | -------------------------------------------------------------------------------- /src/data/datasets/tf_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """a dataset that handles output of tf.data: support datasets from VTAB""" 4 | import functools 5 | import tensorflow.compat.v1 as tf 6 | import torch 7 | import torch.utils.data 8 | import numpy as np 9 | 10 | from collections import Counter 11 | from torch import Tensor 12 | 13 | from ..vtab_datasets import base 14 | # pylint: disable=unused-import 15 | from ..vtab_datasets import caltech 16 | from ..vtab_datasets import cifar 17 | from ..vtab_datasets import clevr 18 | from ..vtab_datasets import diabetic_retinopathy 19 | from ..vtab_datasets import dmlab 20 | from ..vtab_datasets import dsprites 21 | from ..vtab_datasets import dtd 22 | from ..vtab_datasets import eurosat 23 | from ..vtab_datasets import kitti 24 | from ..vtab_datasets import oxford_flowers102 25 | from ..vtab_datasets import oxford_iiit_pet 26 | from ..vtab_datasets import patch_camelyon 27 | from ..vtab_datasets import resisc45 28 | from ..vtab_datasets import smallnorb 29 | from ..vtab_datasets import sun397 30 | from ..vtab_datasets import svhn 31 | from ..vtab_datasets.registry import Registry 32 | 33 | from ...utils import logging 34 | logger = logging.get_logger("visual_prompt") 35 | tf.config.experimental.set_visible_devices([], 'GPU') # set tensorflow to not use gpu # noqa 36 | DATASETS = [ 37 | 'caltech101', 38 | 'cifar(num_classes=100)', 39 | 'dtd', 40 | 'oxford_flowers102', 41 | 'oxford_iiit_pet', 42 | 'patch_camelyon', 43 | 'sun397', 44 | 'svhn', 45 | 'resisc45', 46 | 'eurosat', 47 | 'dmlab', 48 | 'kitti(task="closest_vehicle_distance")', 49 | 'smallnorb(predicted_attribute="label_azimuth")', 50 | 'smallnorb(predicted_attribute="label_elevation")', 51 | 'dsprites(predicted_attribute="label_x_position",num_classes=16)', 52 | 'dsprites(predicted_attribute="label_orientation",num_classes=16)', 53 | 
'clevr(task="closest_object_distance")', 54 | 'clevr(task="count_all")', 55 | 'diabetic_retinopathy(config="btgraham-300")' 56 | ] 57 | 58 | 59 | class TFDataset(torch.utils.data.Dataset): 60 | def __init__(self, cfg, split): 61 | assert split in { 62 | "train", 63 | "val", 64 | "test", 65 | "trainval" 66 | }, "Split '{}' not supported for {} dataset".format( 67 | split, cfg.DATA.NAME) 68 | logger.info("Constructing {} dataset {}...".format( 69 | cfg.DATA.NAME, split)) 70 | 71 | self.cfg = cfg 72 | self._split = split 73 | self.name = cfg.DATA.NAME 74 | 75 | self.img_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1) 76 | self.img_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1) 77 | 78 | self.get_data(cfg, split) 79 | 80 | def get_data(self, cfg, split): 81 | tf_data = build_tf_dataset(cfg, split) 82 | data_list = list(tf_data) # a list of tuples 83 | 84 | self._image_tensor_list = [t[0].numpy().squeeze() for t in data_list] 85 | self._targets = [int(t[1].numpy()[0]) for t in data_list] 86 | self._class_ids = sorted(list(set(self._targets))) 87 | 88 | logger.info("Number of images: {}".format(len(self._image_tensor_list))) 89 | logger.info("Number of classes: {} / {}".format( 90 | len(self._class_ids), self.get_class_num())) 91 | 92 | del data_list 93 | del tf_data 94 | 95 | def get_info(self): 96 | num_imgs = len(self._image_tensor_list) 97 | return num_imgs, self.get_class_num() 98 | 99 | def get_class_num(self): 100 | return self.cfg.DATA.NUMBER_CLASSES 101 | 102 | def get_class_weights(self, weight_type): 103 | """get a list of class weight, return a list float""" 104 | if "train" not in self._split: 105 | raise ValueError( 106 | "only getting training class distribution, " + \ 107 | "got split {} instead".format(self._split) 108 | ) 109 | 110 | cls_num = self.get_class_num() 111 | if weight_type == "none": 112 | return [1.0] * cls_num 113 | 114 | id2counts = Counter(self._class_ids) 115 | assert len(id2counts) == cls_num 116 | num_per_cls = np.array([id2counts[i] for i in self._class_ids]) 117 | 118 | if weight_type == 'inv': 119 | mu = -1.0 120 | elif weight_type == 'inv_sqrt': 121 | mu = -0.5 122 | weight_list = num_per_cls ** mu 123 | weight_list = np.divide( 124 | weight_list, np.linalg.norm(weight_list, 1)) * cls_num 125 | return weight_list.tolist() 126 | 127 | def __getitem__(self, index): 128 | # Load the image 129 | label = self._targets[index] 130 | im = to_torch_imgs( 131 | self._image_tensor_list[index], self.img_mean, self.img_std) 132 | 133 | if self._split == "train": 134 | index = index 135 | else: 136 | index = f"{self._split}{index}" 137 | sample = { 138 | "image": im, 139 | "label": label, 140 | # "id": index 141 | } 142 | return sample 143 | 144 | def __len__(self): 145 | return len(self._targets) 146 | 147 | 148 | def preprocess_fn(data, size=224, input_range=(0.0, 1.0)): 149 | image = data["image"] 150 | image = tf.image.resize(image, [size, size]) 151 | 152 | image = tf.cast(image, tf.float32) / 255.0 153 | image = image * (input_range[1] - input_range[0]) + input_range[0] 154 | 155 | data["image"] = image 156 | return data 157 | 158 | 159 | def build_tf_dataset(cfg, mode): 160 | """ 161 | Builds a tf data instance, then transform to a list of tensors and labels 162 | """ 163 | 164 | if mode not in ["train", "val", "test", "trainval"]: 165 | raise ValueError("The input pipeline supports `train`, `val`, `test`." 
166 | "Provided mode is {}".format(mode)) 167 | 168 | vtab_dataname = cfg.DATA.NAME.split("vtab-")[-1] 169 | data_dir = cfg.DATA.DATAPATH 170 | if vtab_dataname in DATASETS: 171 | data_cls = Registry.lookup("data." + vtab_dataname) 172 | vtab_tf_dataloader = data_cls(data_dir=data_dir) 173 | else: 174 | raise ValueError("Unknown type for \"dataset\" field: {}".format( 175 | type(vtab_dataname))) 176 | 177 | split_name_dict = { 178 | "dataset_train_split_name": "train800", 179 | "dataset_val_split_name": "val200", 180 | "dataset_trainval_split_name": "train800val200", 181 | "dataset_test_split_name": "test", 182 | } 183 | 184 | def _dict_to_tuple(batch): 185 | return batch['image'], batch['label'] 186 | 187 | return vtab_tf_dataloader.get_tf_data( 188 | batch_size=1, # data_params["batch_size"], 189 | drop_remainder=False, 190 | split_name=split_name_dict[f"dataset_{mode}_split_name"], 191 | preprocess_fn=functools.partial( 192 | preprocess_fn, 193 | input_range=(0.0, 1.0), 194 | size=cfg.DATA.CROPSIZE, 195 | ), 196 | for_eval=mode != "train", # handles shuffling 197 | shuffle_buffer_size=1000, 198 | prefetch=1, 199 | train_examples=None, 200 | epochs=1 # setting epochs to 1 make sure it returns one copy of the dataset 201 | ).map(_dict_to_tuple) # return a PrefetchDataset object. (which does not have much documentation to go on) 202 | 203 | 204 | def to_torch_imgs(img: np.ndarray, mean: Tensor, std: Tensor) -> Tensor: 205 | t_img: Tensor = torch.from_numpy(np.transpose(img, (2, 0, 1))) 206 | t_img -= mean 207 | t_img /= std 208 | 209 | return t_img 210 | -------------------------------------------------------------------------------- /src/data/loader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """Data loader.""" 4 | import torch 5 | from torch.utils.data.distributed import DistributedSampler 6 | from torch.utils.data.sampler import RandomSampler 7 | 8 | from ..utils import logging 9 | from .datasets.json_dataset import ( 10 | CUB200Dataset, CarsDataset, DogsDataset, FlowersDataset, NabirdsDataset 11 | ) 12 | 13 | logger = logging.get_logger("visual_prompt") 14 | _DATASET_CATALOG = { 15 | "CUB": CUB200Dataset, 16 | 'OxfordFlowers': FlowersDataset, 17 | 'StanfordCars': CarsDataset, 18 | 'StanfordDogs': DogsDataset, 19 | "nabirds": NabirdsDataset, 20 | } 21 | 22 | 23 | def _construct_loader(cfg, split, batch_size, shuffle, drop_last): 24 | """Constructs the data loader for the given dataset.""" 25 | dataset_name = cfg.DATA.NAME 26 | 27 | # Construct the dataset 28 | if dataset_name.startswith("vtab-"): 29 | # import the tensorflow here only if needed 30 | from .datasets.tf_dataset import TFDataset 31 | dataset = TFDataset(cfg, split) 32 | else: 33 | assert ( 34 | dataset_name in _DATASET_CATALOG.keys() 35 | ), "Dataset '{}' not supported".format(dataset_name) 36 | dataset = _DATASET_CATALOG[dataset_name](cfg, split) 37 | 38 | # Create a sampler for multi-process training 39 | sampler = DistributedSampler(dataset) if cfg.NUM_GPUS > 1 else None 40 | # Create a loader 41 | loader = torch.utils.data.DataLoader( 42 | dataset, 43 | batch_size=batch_size, 44 | shuffle=(False if sampler else shuffle), 45 | sampler=sampler, 46 | num_workers=cfg.DATA.NUM_WORKERS, 47 | pin_memory=cfg.DATA.PIN_MEMORY, 48 | drop_last=drop_last, 49 | ) 50 | return loader 51 | 52 | 53 | def construct_train_loader(cfg): 54 | """Train loader wrapper.""" 55 | if cfg.NUM_GPUS > 1: 56 | drop_last = True 57 | else: 58 | drop_last = False 59 | return 
_construct_loader( 60 | cfg=cfg, 61 | split="train", 62 | batch_size=int(cfg.DATA.BATCH_SIZE / cfg.NUM_GPUS), 63 | shuffle=True, 64 | drop_last=drop_last, 65 | ) 66 | 67 | 68 | def construct_trainval_loader(cfg): 69 | """Train loader wrapper.""" 70 | if cfg.NUM_GPUS > 1: 71 | drop_last = True 72 | else: 73 | drop_last = False 74 | return _construct_loader( 75 | cfg=cfg, 76 | split="trainval", 77 | batch_size=int(cfg.DATA.BATCH_SIZE / cfg.NUM_GPUS), 78 | shuffle=True, 79 | drop_last=drop_last, 80 | ) 81 | 82 | 83 | def construct_test_loader(cfg): 84 | """Test loader wrapper.""" 85 | return _construct_loader( 86 | cfg=cfg, 87 | split="test", 88 | batch_size=int(cfg.DATA.BATCH_SIZE / cfg.NUM_GPUS), 89 | shuffle=False, 90 | drop_last=False, 91 | ) 92 | 93 | 94 | def construct_val_loader(cfg, batch_size=None): 95 | if batch_size is None: 96 | bs = int(cfg.DATA.BATCH_SIZE / cfg.NUM_GPUS) 97 | else: 98 | bs = batch_size 99 | """Validation loader wrapper.""" 100 | return _construct_loader( 101 | cfg=cfg, 102 | split="val", 103 | batch_size=bs, 104 | shuffle=False, 105 | drop_last=False, 106 | ) 107 | 108 | 109 | def shuffle(loader, cur_epoch): 110 | """"Shuffles the data.""" 111 | assert isinstance( 112 | loader.sampler, (RandomSampler, DistributedSampler) 113 | ), "Sampler type '{}' not supported".format(type(loader.sampler)) 114 | # RandomSampler handles shuffling automatically 115 | if isinstance(loader.sampler, DistributedSampler): 116 | # DistributedSampler shuffles data based on epoch 117 | loader.sampler.set_epoch(cur_epoch) 118 | -------------------------------------------------------------------------------- /src/data/transforms.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """Image transformations.""" 4 | import torchvision as tv 5 | 6 | 7 | def get_transforms(split, size): 8 | normalize = tv.transforms.Normalize( 9 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 10 | ) 11 | if size == 448: 12 | resize_dim = 512 13 | crop_dim = 448 14 | elif size == 224: 15 | resize_dim = 256 16 | crop_dim = 224 17 | elif size == 384: 18 | resize_dim = 438 19 | crop_dim = 384 20 | if split == "train": 21 | transform = tv.transforms.Compose( 22 | [ 23 | tv.transforms.Resize(resize_dim), 24 | tv.transforms.RandomCrop(crop_dim), 25 | tv.transforms.RandomHorizontalFlip(0.5), 26 | # tv.transforms.RandomResizedCrop(224, scale=(0.2, 1.0)), 27 | # tv.transforms.RandomHorizontalFlip(), 28 | tv.transforms.ToTensor(), 29 | normalize, 30 | ] 31 | ) 32 | else: 33 | transform = tv.transforms.Compose( 34 | [ 35 | tv.transforms.Resize(resize_dim), 36 | tv.transforms.CenterCrop(crop_dim), 37 | tv.transforms.ToTensor(), 38 | normalize, 39 | ] 40 | ) 41 | return transform 42 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/caltech.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Imports the Caltech images dataset.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | from . import base as base 23 | from .registry import Registry 24 | import tensorflow_datasets as tfds 25 | 26 | 27 | # Percentage of the original training set retained for training, the rest is 28 | # used as a validation set. 29 | _TRAIN_SPLIT_PERCENT = 90 30 | 31 | 32 | @Registry.register("data.caltech101", "class") 33 | class Caltech101(base.ImageTfdsData): 34 | """Provides the Caltech101 dataset. 35 | 36 | See the base class for additional details on the class. 37 | 38 | See TFDS dataset for details on the dataset: 39 | third_party/py/tensorflow_datasets/image/caltech.py 40 | 41 | The original (TFDS) dataset contains only a train and test split. We randomly 42 | sample _TRAIN_SPLIT_PERCENT% of the train split for our "train" set. The 43 | remainder of the TFDS train split becomes our "val" set. The full TFDS train 44 | split is called "trainval". The TFDS test split is used as our test set. 45 | 46 | Note that, in the TFDS dataset, the training split is class-balanced, but not 47 | the test split. Therefore, a significant difference between performance on the 48 | "val" and "test" sets should be expected. 49 | """ 50 | 51 | def __init__(self, data_dir=None): 52 | dataset_builder = tfds.builder("caltech101:3.0.1", data_dir=data_dir) 53 | dataset_builder.download_and_prepare() 54 | 55 | # Creates a dict with example counts for each split. 56 | trainval_count = dataset_builder.info.splits["train"].num_examples 57 | train_count = (_TRAIN_SPLIT_PERCENT * trainval_count) // 100 58 | test_count = dataset_builder.info.splits["test"].num_examples 59 | num_samples_splits = dict( 60 | train=train_count, 61 | val=trainval_count - train_count, 62 | trainval=trainval_count, 63 | test=test_count, 64 | train800=800, 65 | val200=200, 66 | train800val200=1000) 67 | 68 | # Defines dataset specific train/val/trainval/test splits. 
69 | tfds_splits = { 70 | "train": "train[:{}]".format(train_count), 71 | "val": "train[{}:]".format(train_count), 72 | "trainval": "train", 73 | "test": "test", 74 | "train800": "train[:800]", 75 | "val200": "train[{}:{}]".format(train_count, train_count+200), 76 | "train800val200": ( 77 | "train[:800]+train[{}:{}]".format(train_count, train_count+200)), 78 | } 79 | 80 | super(Caltech101, self).__init__( 81 | dataset_builder=dataset_builder, 82 | tfds_splits=tfds_splits, 83 | num_samples_splits=num_samples_splits, 84 | num_preprocessing_threads=400, 85 | shuffle_buffer_size=3000, 86 | base_preprocess_fn=base.make_get_tensors_fn(("image", "label")), 87 | num_classes=dataset_builder.info.features["label"].num_classes) 88 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/cifar.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements Cifar data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | from . import base as base 23 | from .registry import Registry 24 | import tensorflow_datasets as tfds 25 | 26 | # This constant specifies the percentage of data that is used to create custom 27 | # train/val splits. Specifically, TRAIN_SPLIT_PERCENT% of the official training 28 | # split is used as a new training split and the rest is used for validation. 29 | TRAIN_SPLIT_PERCENT = 90 30 | 31 | 32 | @Registry.register("data.cifar", "class") 33 | class CifarData(base.ImageTfdsData): 34 | """Provides Cifar10 or Cifar100 data. 35 | 36 | Cifar comes only with a training and test set. Therefore, the validation set 37 | is split out of the original training set, and the remaining examples are used 38 | as the "train" split. The "trainval" split corresponds to the original 39 | training set. 40 | 41 | For additional details and usage, see the base class. 42 | """ 43 | 44 | def __init__(self, num_classes=10, data_dir=None, train_split_percent=None): 45 | 46 | if num_classes == 10: 47 | dataset_builder = tfds.builder("cifar10:3.*.*", data_dir=data_dir) 48 | elif num_classes == 100: 49 | dataset_builder = tfds.builder("cifar100:3.*.*", data_dir=data_dir) 50 | else: 51 | raise ValueError( 52 | "Number of classes must be 10 or 100, got {}".format(num_classes)) 53 | 54 | dataset_builder.download_and_prepare() 55 | 56 | train_split_percent = train_split_percent or TRAIN_SPLIT_PERCENT 57 | 58 | # Creates a dict with example counts for each split. 
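The Caltech-101 and CIFAR wrappers above share one recipe: carve a fixed percentage of the TFDS train split off as a custom validation set, and expose 800/200-example subsets for the 1,000-example low-data VTAB protocol. The stand-alone sketch below re-states that bookkeeping; the 3,060-example count is only an illustrative stand-in, not read from TFDS.

```python
def make_percent_splits(trainval_count, train_split_percent=90):
    """Mirror the train/val split-string pattern used by these VTAB wrappers."""
    train_count = (train_split_percent * trainval_count) // 100
    tfds_splits = {
        "train": "train[:{}]".format(train_count),
        "val": "train[{}:]".format(train_count),
        "trainval": "train",
        "train800": "train[:800]",
        "val200": "train[{}:{}]".format(train_count, train_count + 200),
        "train800val200": "train[:800]+train[{}:{}]".format(train_count, train_count + 200),
    }
    counts = {"train": train_count, "val": trainval_count - train_count}
    return tfds_splits, counts

splits, counts = make_percent_splits(3060)     # e.g. a Caltech-101-sized train split
print(splits["train"], splits["val"], counts)  # train[:2754] train[2754:] {'train': 2754, 'val': 306}
```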
59 | trainval_count = dataset_builder.info.splits["train"].num_examples 60 | test_count = dataset_builder.info.splits["test"].num_examples 61 | num_samples_splits = { 62 | "train": (train_split_percent * trainval_count) // 100, 63 | "val": trainval_count - (train_split_percent * trainval_count) // 100, 64 | "trainval": trainval_count, 65 | "test": test_count, 66 | "train800": 800, 67 | "val200": 200, 68 | "train800val200": 1000, 69 | } 70 | 71 | # Defines dataset specific train/val/trainval/test splits. 72 | tfds_splits = { 73 | "train": "train[:{}]".format(num_samples_splits["train"]), 74 | "val": "train[{}:]".format(num_samples_splits["train"]), 75 | "trainval": "train", 76 | "test": "test", 77 | "train800": "train[:800]", 78 | "val200": "train[{}:{}]".format( 79 | num_samples_splits["train"], num_samples_splits["train"]+200), 80 | "train800val200": "train[:800]+train[{}:{}]".format( 81 | num_samples_splits["train"], num_samples_splits["train"]+200), 82 | } 83 | 84 | super(CifarData, self).__init__( 85 | dataset_builder=dataset_builder, 86 | tfds_splits=tfds_splits, 87 | num_samples_splits=num_samples_splits, 88 | num_preprocessing_threads=400, 89 | shuffle_buffer_size=10000, 90 | # Note: Export only image and label tensors with their original types. 91 | base_preprocess_fn=base.make_get_tensors_fn(["image", "label", "id"]), 92 | num_classes=dataset_builder.info.features["label"].num_classes) 93 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/clevr.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements CLEVR data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import numpy as np 24 | import tensorflow.compat.v1 as tf 25 | import tensorflow_datasets as tfds 26 | 27 | from . 
import base as base 28 | from .registry import Registry 29 | 30 | TRAIN_SPLIT_PERCENT = 90 31 | 32 | 33 | def _count_preprocess_fn(x): 34 | return {"image": x["image"], 35 | "label": tf.size(x["objects"]["size"]) - 3} 36 | 37 | 38 | def _count_cylinders_preprocess_fn(x): 39 | # Class distribution: 40 | 41 | num_cylinders = tf.reduce_sum( 42 | tf.cast(tf.equal(x["objects"]["shape"], 2), tf.int32)) 43 | return {"image": x["image"], "label": num_cylinders} 44 | 45 | 46 | def _closest_object_preprocess_fn(x): 47 | dist = tf.reduce_min(x["objects"]["pixel_coords"][:, 2]) 48 | # These thresholds are uniformly spaced and result in more or less balanced 49 | # distribution of classes, see the resulting histogram: 50 | 51 | thrs = np.array([0.0, 8.0, 8.5, 9.0, 9.5, 10.0, 100.0]) 52 | label = tf.reduce_max(tf.where((thrs - dist) < 0)) 53 | return {"image": x["image"], 54 | "label": label} 55 | 56 | 57 | _TASK_DICT = { 58 | "count_all": { 59 | "preprocess_fn": _count_preprocess_fn, 60 | "num_classes": 8 61 | }, 62 | "count_cylinders": { 63 | "preprocess_fn": _count_cylinders_preprocess_fn, 64 | "num_classes": 11 65 | }, 66 | "closest_object_distance": { 67 | "preprocess_fn": _closest_object_preprocess_fn, 68 | "num_classes": 6 69 | }, 70 | } 71 | 72 | 73 | @Registry.register("data.clevr", "class") 74 | class CLEVRData(base.ImageTfdsData): 75 | """Provides CLEVR dataset. 76 | 77 | Currently, two tasks are supported: 78 | 1. Predict number of objects. 79 | 2. Predict distnace to the closest object. 80 | """ 81 | 82 | def __init__(self, task, data_dir=None): 83 | 84 | if task not in _TASK_DICT: 85 | raise ValueError("Unknown task: %s" % task) 86 | 87 | dataset_builder = tfds.builder("clevr:3.*.*", data_dir=data_dir) 88 | dataset_builder.download_and_prepare() 89 | 90 | # Creates a dict with example counts for each split. 91 | trainval_count = dataset_builder.info.splits[tfds.Split.TRAIN].num_examples 92 | test_count = dataset_builder.info.splits[tfds.Split.TEST].num_examples 93 | num_samples_splits = { 94 | "train": (TRAIN_SPLIT_PERCENT * trainval_count) // 100, 95 | "val": trainval_count - (TRAIN_SPLIT_PERCENT * trainval_count) // 100, 96 | "trainval": trainval_count, 97 | "test": test_count, 98 | "train800": 800, 99 | "val200": 200, 100 | "train800val200": 1000, 101 | } 102 | 103 | # Defines dataset specific train/val/trainval/test splits. 104 | tfds_splits = { 105 | "train": "train[:{}]".format(num_samples_splits["train"]), 106 | "val": "train[{}:]".format(num_samples_splits["train"]), 107 | "trainval": "train", 108 | "test": "validation", 109 | "train800": "train[:800]", 110 | "val200": "train[{}:{}]".format( 111 | num_samples_splits["train"], num_samples_splits["train"]+200), 112 | "train800val200": "train[:800]+train[{}:{}]".format( 113 | num_samples_splits["train"], num_samples_splits["train"]+200), 114 | } 115 | 116 | task = _TASK_DICT[task] 117 | base_preprocess_fn = task["preprocess_fn"] 118 | 119 | super(CLEVRData, self).__init__( 120 | dataset_builder=dataset_builder, 121 | tfds_splits=tfds_splits, 122 | num_samples_splits=num_samples_splits, 123 | num_preprocessing_threads=400, 124 | shuffle_buffer_size=10000, 125 | # Note: Export only image and label tensors with their original types. 
126 | base_preprocess_fn=base_preprocess_fn, 127 | num_classes=task["num_classes"]) 128 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/diabetic_retinopathy.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements Diabetic Retinopathy data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow.compat.v1 as tf 24 | import tensorflow_addons.image as tfa_image 25 | import tensorflow_datasets as tfds 26 | 27 | from . import base as base 28 | from .registry import Registry 29 | 30 | 31 | @Registry.register("data.diabetic_retinopathy", "class") 32 | class RetinopathyData(base.ImageTfdsData): 33 | """Provides Diabetic Retinopathy classification data. 34 | 35 | Retinopathy comes only with a training and test set. Therefore, the validation 36 | set is split out of the original training set, and the remaining examples are 37 | used as the "train" split. The "trainval" split corresponds to the original 38 | training set. 39 | 40 | For additional details and usage, see the base class. 41 | """ 42 | 43 | _CONFIGS_WITH_GREY_BACKGROUND = ["btgraham-300"] 44 | 45 | def __init__(self, config="btgraham-300", heavy_train_augmentation=False, 46 | data_dir=None): 47 | """Initializer for Diabetic Retinopathy dataset. 48 | 49 | Args: 50 | config: Name of the TFDS config to use for this dataset. 51 | heavy_train_augmentation: If True, use heavy data augmentation on the 52 | training data. Recommended to achieve SOTA. 53 | data_dir: directory for downloading and storing the data. 54 | """ 55 | config_and_version = config + ":3.*.*" 56 | dataset_builder = tfds.builder("diabetic_retinopathy_detection/{}".format( 57 | config_and_version), data_dir=data_dir) 58 | self._config = config 59 | self._heavy_train_augmentation = heavy_train_augmentation 60 | 61 | dataset_builder.download_and_prepare() 62 | 63 | # Defines dataset specific train/val/trainval/test splits. 64 | tfds_splits = { 65 | "train": "train", 66 | "val": "validation", 67 | "trainval": "train+validation", 68 | "test": "test", 69 | "train800": "train[:800]", 70 | "val200": "validation[:200]", 71 | "train800val200": "train[:800]+validation[:200]", 72 | } 73 | 74 | # Creates a dict with example counts for each split. 
75 | train_count = dataset_builder.info.splits["train"].num_examples 76 | val_count = dataset_builder.info.splits["validation"].num_examples 77 | test_count = dataset_builder.info.splits["test"].num_examples 78 | num_samples_splits = { 79 | "train": train_count, 80 | "val": val_count, 81 | "trainval": train_count + val_count, 82 | "test": test_count, 83 | "train800": 800, 84 | "val200": 200, 85 | "train800val200": 1000, 86 | } 87 | 88 | super(RetinopathyData, self).__init__( 89 | dataset_builder=dataset_builder, 90 | tfds_splits=tfds_splits, 91 | num_samples_splits=num_samples_splits, 92 | num_preprocessing_threads=400, 93 | shuffle_buffer_size=10000, 94 | # Note: Export only image and label tensors with their original types. 95 | base_preprocess_fn=base.make_get_tensors_fn(["image", "label"]), 96 | num_classes=dataset_builder.info.features["label"].num_classes) 97 | 98 | @property 99 | def config(self): 100 | return self._config 101 | 102 | @property 103 | def heavy_train_augmentation(self): 104 | return self._heavy_train_augmentation 105 | 106 | def get_tf_data(self, 107 | split_name, 108 | batch_size, 109 | preprocess_fn=None, 110 | for_eval=False, 111 | **kwargs): 112 | if self._heavy_train_augmentation and not for_eval: 113 | preprocess_fn = base.compose_preprocess_fn( 114 | self._heavy_train_augmentation, preprocess_fn) 115 | 116 | return super(RetinopathyData, self).get_tf_data( 117 | split_name=split_name, 118 | batch_size=batch_size, 119 | preprocess_fn=preprocess_fn, 120 | for_eval=for_eval, 121 | **kwargs) 122 | 123 | def _sample_heavy_data_augmentation_parameters(self): 124 | # Scale image +/- 10%. 125 | s = tf.random.uniform(shape=(), minval=-0.1, maxval=0.1) 126 | # Rotate image [0, 2pi). 127 | a = tf.random.uniform(shape=(), minval=0.0, maxval=2.0 * 3.1415926535) 128 | # Vertically shear image +/- 20%. 129 | b = tf.random.uniform(shape=(), minval=-0.2, maxval=0.2) + a 130 | # Horizontal and vertial flipping. 131 | hf = tf.random.shuffle([-1.0, 1.0])[0] 132 | vf = tf.random.shuffle([-1.0, 1.0])[0] 133 | # Relative x,y translation. 134 | dx = tf.random.uniform(shape=(), minval=-0.1, maxval=0.1) 135 | dy = tf.random.uniform(shape=(), minval=-0.1, maxval=0.1) 136 | return s, a, b, hf, vf, dx, dy 137 | 138 | def _heavy_data_augmentation_fn(self, example): 139 | """Perform heavy augmentation on a given input data example. 140 | 141 | This is the same data augmentation as the one done by Ben Graham, the winner 142 | of the 2015 Kaggle competition. See: 143 | https://github.com/btgraham/SparseConvNet/blob/a6bdb0c938b3556c1e6c23d5a014db9f404502b9/kaggleDiabetes1.cpp#L12 144 | 145 | Args: 146 | example: A dictionary containing an "image" key with the image to 147 | augment. 148 | 149 | Returns: 150 | The input dictionary with the key "image" containing the augmented image. 151 | """ 152 | image = example["image"] 153 | image_shape = tf.shape(image) 154 | if len(image.get_shape().as_list()) not in [2, 3]: 155 | raise ValueError( 156 | "Input image must be a rank-2 or rank-3 tensor, but rank-{} " 157 | "was given".format(len(image.get_shape().as_list()))) 158 | height = tf.cast(image_shape[0], dtype=tf.float32) 159 | width = tf.cast(image_shape[1], dtype=tf.float32) 160 | # Sample data augmentation parameters. 161 | s, a, b, hf, vf, dx, dy = self._sample_heavy_data_augmentation_parameters() 162 | # Rotation + scale. 163 | c00 = (1 + s) * tf.cos(a) 164 | c01 = (1 + s) * tf.sin(a) 165 | c10 = (s - 1) * tf.sin(b) 166 | c11 = (1 - s) * tf.cos(b) 167 | # Horizontal and vertial flipping. 
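# Illustrative note (not part of the original file): with s = 0, a = b = 0,
# hf = -1 and vf = +1, the flip step below turns the identity rotation/scale
# block into [[-1, 0], [0, 1]]; once the centre-correction terms are added
# (and the matrix inverted for tfa_image.transform, a no-op for a flip), this
# is exactly a horizontal mirror of the image about its centre.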
168 | c00 = c00 * hf 169 | c01 = c01 * hf 170 | c10 = c10 * vf 171 | c11 = c11 * vf 172 | # Convert x,y translation to absolute values. 173 | dx = width * dx 174 | dy = height * dy 175 | # Convert affine matrix to TF's transform. Matrix is applied w.r.t. the 176 | # center of the image. 177 | cy = height / 2.0 178 | cx = width / 2.0 179 | affine_matrix = [[c00, c01, (1.0 - c00) * cx - c01 * cy + dx], 180 | [c10, c11, (1.0 - c11) * cy - c10 * cx + dy], 181 | [0.0, 0.0, 1.0]] 182 | affine_matrix = tf.convert_to_tensor(affine_matrix, dtype=tf.float32) 183 | transform = tfa_image.transform_ops.matrices_to_flat_transforms( 184 | tf.linalg.inv(affine_matrix)) 185 | if self._config in self._CONFIGS_WITH_GREY_BACKGROUND: 186 | # Since background is grey in these configs, put in pixels in [-1, 1] 187 | # range to avoid artifacts from the affine transformation. 188 | image = tf.cast(image, dtype=tf.float32) 189 | image = (image / 127.5) - 1.0 190 | # Apply the affine transformation. 191 | image = tfa_image.transform(images=image, transforms=transform) 192 | if self._config in self._CONFIGS_WITH_GREY_BACKGROUND: 193 | # Put pixels back to [0, 255] range and cast to uint8, since this is what 194 | # our preprocessing pipeline usually expects. 195 | image = (1.0 + image) * 127.5 196 | image = tf.cast(image, dtype=tf.uint8) 197 | example["image"] = image 198 | return example 199 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/dmlab.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements Dmlab data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow_datasets as tfds 24 | 25 | from . import base as base 26 | from .registry import Registry 27 | 28 | 29 | @Registry.register("data.dmlab", "class") 30 | class DmlabData(base.ImageTfdsData): 31 | """Dmlab dataset. 32 | 33 | The Dmlab dataset contains frames observed by the agent acting in the 34 | DMLab environment, which are annotated by the distance between 35 | the agent and various objects present in the environment. The goal is to 36 | is to evaluate the ability of a visual model to reason about distances 37 | from the visual input in 3D environments. The Dmlab dataset consists of 38 | 360x480 color images in 6 classes. The classes are 39 | {close, far, very far} x {positive reward, negative reward} 40 | respectively. 
41 | """ 42 | 43 | def __init__(self, data_dir=None): 44 | 45 | dataset_builder = tfds.builder("dmlab:2.0.1", data_dir=data_dir) 46 | 47 | tfds_splits = { 48 | "train": "train", 49 | "val": "validation", 50 | "trainval": "train+validation", 51 | "test": "test", 52 | "train800": "train[:800]", 53 | "val200": "validation[:200]", 54 | "train800val200": "train[:800]+validation[:200]", 55 | } 56 | 57 | # Example counts are retrieved from the tensorflow dataset info. 58 | train_count = dataset_builder.info.splits["train"].num_examples 59 | val_count = dataset_builder.info.splits["validation"].num_examples 60 | test_count = dataset_builder.info.splits["test"].num_examples 61 | 62 | # Creates a dict with example counts for each split. 63 | num_samples_splits = { 64 | "train": train_count, 65 | "val": val_count, 66 | "trainval": train_count + val_count, 67 | "test": test_count, 68 | "train800": 800, 69 | "val200": 200, 70 | "train800val200": 1000, 71 | } 72 | 73 | super(DmlabData, self).__init__( 74 | dataset_builder=dataset_builder, 75 | tfds_splits=tfds_splits, 76 | num_samples_splits=num_samples_splits, 77 | num_preprocessing_threads=400, 78 | shuffle_buffer_size=10000, 79 | base_preprocess_fn=base.make_get_and_cast_tensors_fn({ 80 | "image": ("image", None), 81 | "label": ("label", None), 82 | }), 83 | num_classes=dataset_builder.info.features["label"].num_classes, 84 | image_key="image") 85 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/dsprites.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements the DSprites data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow.compat.v1 as tf 24 | import tensorflow_datasets as tfds 25 | 26 | from . import base as base 27 | from .registry import Registry 28 | 29 | 30 | # These constants specify the percentage of data that is used to create custom 31 | # train/val splits. Specifically, TRAIN_SPLIT_PERCENT% of the data set is used 32 | # as a new training split and VAL_SPLIT_PERCENT% is used for validation. 33 | # The rest is used for testing. 34 | TRAIN_SPLIT_PERCENT = 80 35 | VAL_SPLIT_PERCENT = 10 36 | 37 | 38 | @Registry.register("data.dsprites", "class") 39 | class DSpritesData(base.ImageTfdsData): 40 | """Provides the DSprites data set. 41 | 42 | DSprites only comes with a training set. Therefore, the training, validation, 43 | and test set are split out of the original training set. 44 | 45 | For additional details and usage, see the base class. 46 | 47 | The data set page is https://github.com/deepmind/dsprites-dataset/. 
48 | """ 49 | 50 | def __init__(self, predicted_attribute, num_classes=None, data_dir=None): 51 | dataset_builder = tfds.builder("dsprites:2.*.*", data_dir=data_dir) 52 | dataset_builder.download_and_prepare() 53 | info = dataset_builder.info 54 | 55 | if predicted_attribute not in dataset_builder.info.features: 56 | raise ValueError( 57 | "{} is not a valid attribute to predict.".format(predicted_attribute)) 58 | 59 | # If num_classes is set, we group together nearby integer values to arrive 60 | # at the desired number of classes. This is useful for example for grouping 61 | # together different spatial positions. 62 | num_original_classes = info.features[predicted_attribute].num_classes 63 | if num_classes is None: 64 | num_classes = num_original_classes 65 | if not isinstance(num_classes, int) or num_classes <= 1 or ( 66 | num_classes > num_original_classes): 67 | raise ValueError( 68 | "The number of classes should be None or in [2, ..., num_classes].") 69 | class_division_factor = float(num_original_classes) / num_classes 70 | 71 | # Creates a dict with example counts for each split. 72 | num_total = dataset_builder.info.splits["train"].num_examples 73 | num_samples_train = TRAIN_SPLIT_PERCENT * num_total // 100 74 | num_samples_val = VAL_SPLIT_PERCENT * num_total // 100 75 | num_samples_splits = { 76 | "train": num_samples_train, 77 | "val": num_samples_val, 78 | "trainval": num_samples_val + num_samples_train, 79 | "test": num_total - num_samples_val - num_samples_train, 80 | "train800": 800, 81 | "val200": 200, 82 | "train800val200": 1000, 83 | } 84 | 85 | # Defines dataset specific train/val/trainval/test splits. 86 | tfds_splits = { 87 | "train": "train[:{}]".format(num_samples_splits["train"]), 88 | "val": "train[{}:{}]".format(num_samples_splits["train"], 89 | num_samples_splits["trainval"]), 90 | "trainval": "train[:{}]".format(num_samples_splits["trainval"]), 91 | "test": "train[{}:]".format(num_samples_splits["trainval"]), 92 | "train800": "train[:800]", 93 | "val200": "train[{}:{}]".format(num_samples_splits["train"], 94 | num_samples_splits["train"]+200), 95 | "train800val200": "train[:800]+train[{}:{}]".format( 96 | num_samples_splits["train"], num_samples_splits["train"]+200), 97 | } 98 | 99 | def preprocess_fn(tensors): 100 | # For consistency with other datasets, image needs to have three channels 101 | # and be in [0, 255). 102 | images = tf.tile(tensors["image"], [1, 1, 3]) * 255 103 | label = tf.cast( 104 | tf.math.floordiv( 105 | tf.cast(tensors[predicted_attribute], tf.float32), 106 | class_division_factor), info.features[predicted_attribute].dtype) 107 | return dict(image=images, label=label) 108 | 109 | super(DSpritesData, self).__init__( 110 | dataset_builder=dataset_builder, 111 | tfds_splits=tfds_splits, 112 | num_samples_splits=num_samples_splits, 113 | num_preprocessing_threads=400, 114 | shuffle_buffer_size=10000, 115 | # We extract the attribute we want to predict in the preprocessing. 116 | base_preprocess_fn=preprocess_fn, 117 | num_classes=num_classes) 118 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/dtd.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements the Describable Textures Dataset (DTD) data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow_datasets as tfds 24 | 25 | from . import base as base 26 | from .registry import Registry 27 | 28 | 29 | @Registry.register("data.dtd", "class") 30 | class DTDData(base.ImageTfdsData): 31 | """Provides Describable Textures Dataset (DTD) data. 32 | 33 | As of version 1.0.0, the train/val/test splits correspond to those of the 34 | 1st fold of the official cross-validation partition. 35 | 36 | For additional details and usage, see the base class. 37 | """ 38 | 39 | def __init__(self, data_dir=None): 40 | 41 | dataset_builder = tfds.builder("dtd:3.*.*", data_dir=data_dir) 42 | dataset_builder.download_and_prepare() 43 | 44 | # Defines dataset specific train/val/trainval/test splits. 45 | tfds_splits = { 46 | "train": "train", 47 | "val": "validation", 48 | "trainval": "train+validation", 49 | "test": "test", 50 | "train800": "train[:800]", 51 | "val200": "validation[:200]", 52 | "train800val200": "train[:800]+validation[:200]", 53 | } 54 | 55 | # Creates a dict with example counts for each split. 56 | train_count = dataset_builder.info.splits["train"].num_examples 57 | val_count = dataset_builder.info.splits["validation"].num_examples 58 | test_count = dataset_builder.info.splits["test"].num_examples 59 | num_samples_splits = { 60 | "train": train_count, 61 | "val": val_count, 62 | "trainval": train_count + val_count, 63 | "test": test_count, 64 | "train800": 800, 65 | "val200": 200, 66 | "train800val200": 1000, 67 | } 68 | 69 | super(DTDData, self).__init__( 70 | dataset_builder=dataset_builder, 71 | tfds_splits=tfds_splits, 72 | num_samples_splits=num_samples_splits, 73 | num_preprocessing_threads=400, 74 | shuffle_buffer_size=10000, 75 | # Note: Export only image and label tensors with their original types. 76 | base_preprocess_fn=base.make_get_tensors_fn(["image", "label"]), 77 | num_classes=dataset_builder.info.features["label"].num_classes) 78 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/eurosat.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | """Implements EurosatData class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow_datasets as tfds 24 | 25 | from . import base as base 26 | from .registry import Registry 27 | 28 | TRAIN_SPLIT_PERCENT = 60 29 | VALIDATION_SPLIT_PERCENT = 20 30 | TEST_SPLIT_PERCENT = 20 31 | 32 | 33 | @Registry.register("data.eurosat", "class") 34 | class EurosatData(base.ImageTfdsData): 35 | """Provides EuroSat dataset. 36 | 37 | EuroSAT dataset is based on Sentinel-2 satellite images covering 13 spectral 38 | bands and consisting of 10 classes with 27000 labeled and 39 | geo-referenced samples. 40 | 41 | URL: https://github.com/phelber/eurosat 42 | """ 43 | 44 | def __init__(self, subset="rgb", data_key="image", data_dir=None): 45 | dataset_name = "eurosat/{}:2.*.*".format(subset) 46 | dataset_builder = tfds.builder(dataset_name, data_dir=data_dir) 47 | dataset_builder.download_and_prepare() 48 | 49 | # Example counts are retrieved from the tensorflow dataset info. 50 | num_examples = dataset_builder.info.splits[tfds.Split.TRAIN].num_examples 51 | train_count = num_examples * TRAIN_SPLIT_PERCENT // 100 52 | val_count = num_examples * VALIDATION_SPLIT_PERCENT // 100 53 | test_count = num_examples * TEST_SPLIT_PERCENT // 100 54 | 55 | tfds_splits = { 56 | "train": 57 | "train[:{}]".format(train_count), 58 | "val": 59 | "train[{}:{}]".format(train_count, train_count+val_count), 60 | "trainval": 61 | "train[:{}]".format(train_count+val_count), 62 | "test": 63 | "train[{}:]".format(train_count+val_count), 64 | "train800": 65 | "train[:800]", 66 | "val200": 67 | "train[{}:{}]".format(train_count, train_count+200), 68 | "train800val200": 69 | "train[:800]+train[{}:{}]".format(train_count, train_count+200), 70 | } 71 | 72 | # Creates a dict with example counts for each split. 73 | num_samples_splits = { 74 | "train": train_count, 75 | "val": val_count, 76 | "trainval": train_count + val_count, 77 | "test": test_count, 78 | "train800": 800, 79 | "val200": 200, 80 | "train800val200": 1000, 81 | } 82 | 83 | num_channels = 3 84 | if data_key == "sentinel2": 85 | num_channels = 13 86 | 87 | super(EurosatData, self).__init__( 88 | dataset_builder=dataset_builder, 89 | tfds_splits=tfds_splits, 90 | num_samples_splits=num_samples_splits, 91 | num_preprocessing_threads=100, 92 | shuffle_buffer_size=10000, 93 | base_preprocess_fn=base.make_get_and_cast_tensors_fn({ 94 | data_key: ("image", None), 95 | "label": ("label", None), 96 | }), 97 | image_key=data_key, 98 | num_channels=num_channels, 99 | num_classes=dataset_builder.info.features["label"].num_classes) 100 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/kitti.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
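EuroSAT (like RESISC-45 later in this directory) ships only a single TFDS `train` split, so the wrapper above carves 60/20/20 train/val/test slices out of it by example index. A quick check of that arithmetic, using the 27,000-image figure quoted in the docstring:

```python
num_examples = 27_000                      # from the EuroSAT docstring above
train_count = num_examples * 60 // 100     # 16200
val_count = num_examples * 20 // 100       # 5400
test_count = num_examples * 20 // 100      # 5400

tfds_splits = {
    "train": "train[:{}]".format(train_count),
    "val": "train[{}:{}]".format(train_count, train_count + val_count),
    "trainval": "train[:{}]".format(train_count + val_count),
    "test": "train[{}:]".format(train_count + val_count),
}
print(tfds_splits["val"])                  # train[16200:21600]
print(train_count, val_count, test_count)  # 16200 5400 5400
```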
14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements Kitti data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import numpy as np 24 | 25 | import tensorflow.compat.v1 as tf 26 | import tensorflow_datasets as tfds 27 | 28 | from . import base as base 29 | from .registry import Registry 30 | 31 | 32 | def _count_all_pp(x): 33 | """Count all objects.""" 34 | # Count distribution (thresholded at 15): 35 | 36 | label = tf.math.minimum(tf.size(x["objects"]["type"]) - 1, 8) 37 | return {"image": x["image"], "label": label} 38 | 39 | 40 | def _count_vehicles_pp(x): 41 | """Counting vehicles.""" 42 | # Label distribution: 43 | 44 | vehicles = tf.where(x["objects"]["type"] < 3) # Car, Van, Truck. 45 | # Cap at 3. 46 | label = tf.math.minimum(tf.size(vehicles), 3) 47 | return {"image": x["image"], "label": label} 48 | 49 | 50 | def _count_left_pp(x): 51 | """Count objects on the left hand side of the camera.""" 52 | # Count distribution (thresholded at 15): 53 | 54 | # Location feature contains (x, y, z) in meters w.r.t. the camera. 55 | objects_on_left = tf.where(x["objects"]["location"][:, 0] < 0) 56 | label = tf.math.minimum(tf.size(objects_on_left), 8) 57 | return {"image": x["image"], "label": label} 58 | 59 | 60 | def _count_far_pp(x): 61 | """Counts objects far from the camera.""" 62 | # Threshold removes ~half of the objects. 63 | # Count distribution (thresholded at 15): 64 | 65 | # Location feature contains (x, y, z) in meters w.r.t. the camera. 66 | distant_objects = tf.where(x["objects"]["location"][:, 2] >= 25) 67 | label = tf.math.minimum(tf.size(distant_objects), 8) 68 | return {"image": x["image"], "label": label} 69 | 70 | 71 | def _count_near_pp(x): 72 | """Counts objects close to the camera.""" 73 | # Threshold removes ~half of the objects. 74 | # Count distribution: 75 | 76 | # Location feature contains (x, y, z) in meters w.r.t. the camera. 77 | close_objects = tf.where(x["objects"]["location"][:, 2] < 25) 78 | label = tf.math.minimum(tf.size(close_objects), 8) 79 | return {"image": x["image"], "label": label} 80 | 81 | 82 | def _closest_object_distance_pp(x): 83 | """Predict the distance to the closest object.""" 84 | # Label distribution: 85 | 86 | # Location feature contains (x, y, z) in meters w.r.t. the camera. 87 | dist = tf.reduce_min(x["objects"]["location"][:, 2]) 88 | thrs = np.array([-100, 5.6, 8.4, 13.4, 23.4]) 89 | label = tf.reduce_max(tf.where((thrs - dist) < 0)) 90 | return {"image": x["image"], "label": label} 91 | 92 | 93 | def _closest_vehicle_distance_pp(x): 94 | """Predict the distance to the closest vehicle.""" 95 | # Label distribution: 96 | 97 | # Location feature contains (x, y, z) in meters w.r.t. the camera. 98 | vehicles = tf.where(x["objects"]["type"] < 3) # Car, Van, Truck. 99 | vehicle_z = tf.gather(params=x["objects"]["location"][:, 2], indices=vehicles) 100 | vehicle_z = tf.concat([vehicle_z, tf.constant([[1000.0]])], axis=0) 101 | dist = tf.reduce_min(vehicle_z) 102 | # Results in a uniform distribution over three distances, plus one class for 103 | # "no vehicle". 
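# Illustrative note (not part of the original file): with the thresholds
# below, a closest vehicle at dist = 12.0 m gives thrs - dist =
# [-112, -4, 8, 987]; the negative entries sit at indices {0, 1}, so
# reduce_max(where(...)) yields label 1. A frame with no vehicle keeps only
# the appended 1000.0 sentinel, so every threshold is negative and the frame
# maps to the dedicated "no vehicle" class 3.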
104 | thrs = np.array([-100.0, 8.0, 20.0, 999.0]) 105 | label = tf.reduce_max(tf.where((thrs - dist) < 0)) 106 | return {"image": x["image"], "label": label} 107 | 108 | 109 | def _closest_object_x_location_pp(x): 110 | """Predict the absolute x position of the closest object.""" 111 | # Label distribution: 112 | 113 | # Location feature contains (x, y, z) in meters w.r.t. the camera. 114 | idx = tf.math.argmin(x["objects"]["location"][:, 2]) 115 | xloc = x["objects"]["location"][idx, 0] 116 | thrs = np.array([-100, -6.4, -3.5, 0.0, 3.3, 23.9]) 117 | label = tf.reduce_max(tf.where((thrs - xloc) < 0)) 118 | return {"image": x["image"], "label": label} 119 | 120 | 121 | _TASK_DICT = { 122 | "count_all": { 123 | "preprocess_fn": _count_all_pp, 124 | "num_classes": 16, 125 | }, 126 | "count_left": { 127 | "preprocess_fn": _count_left_pp, 128 | "num_classes": 16, 129 | }, 130 | "count_far": { 131 | "preprocess_fn": _count_far_pp, 132 | "num_classes": 16, 133 | }, 134 | "count_near": { 135 | "preprocess_fn": _count_near_pp, 136 | "num_classes": 16, 137 | }, 138 | "closest_object_distance": { 139 | "preprocess_fn": _closest_object_distance_pp, 140 | "num_classes": 5, 141 | }, 142 | "closest_object_x_location": { 143 | "preprocess_fn": _closest_object_x_location_pp, 144 | "num_classes": 5, 145 | }, 146 | "count_vehicles": { 147 | "preprocess_fn": _count_vehicles_pp, 148 | "num_classes": 4, 149 | }, 150 | "closest_vehicle_distance": { 151 | "preprocess_fn": _closest_vehicle_distance_pp, 152 | "num_classes": 4, 153 | }, 154 | } 155 | 156 | 157 | @Registry.register("data.kitti", "class") 158 | class KittiData(base.ImageTfdsData): 159 | """Provides Kitti dataset. 160 | 161 | Six tasks are supported: 162 | 1. Count the number of objects. 163 | 2. Count the number of objects on the left hand side of the camera. 164 | 3. Count the number of objects in the foreground. 165 | 4. Count the number of objects in the background. 166 | 5. Predict the distance of the closest object. 167 | 6. Predict the x-location (w.r.t. the camera) of the closest object. 168 | """ 169 | 170 | def __init__(self, task, data_dir=None): 171 | 172 | if task not in _TASK_DICT: 173 | raise ValueError("Unknown task: %s" % task) 174 | 175 | dataset_builder = tfds.builder("kitti:3.2.0", data_dir=data_dir) 176 | dataset_builder.download_and_prepare() 177 | 178 | tfds_splits = { 179 | "train": "train", 180 | "val": "validation", 181 | "trainval": "train+validation", 182 | "test": "test", 183 | "train800": "train[:800]", 184 | "val200": "validation[:200]", 185 | "train800val200": "train[:800]+validation[:200]", 186 | } 187 | 188 | # Example counts are retrieved from the tensorflow dataset info. 189 | train_count = dataset_builder.info.splits[tfds.Split.TRAIN].num_examples 190 | val_count = dataset_builder.info.splits[tfds.Split.VALIDATION].num_examples 191 | test_count = dataset_builder.info.splits[tfds.Split.TEST].num_examples 192 | # Creates a dict with example counts for each split. 
193 | num_samples_splits = { 194 | "train": train_count, 195 | "val": val_count, 196 | "trainval": train_count + val_count, 197 | "test": test_count, 198 | "train800": 800, 199 | "val200": 200, 200 | "train800val200": 1000, 201 | } 202 | 203 | task = _TASK_DICT[task] 204 | base_preprocess_fn = task["preprocess_fn"] 205 | super(KittiData, self).__init__( 206 | dataset_builder=dataset_builder, 207 | tfds_splits=tfds_splits, 208 | num_samples_splits=num_samples_splits, 209 | num_preprocessing_threads=400, 210 | shuffle_buffer_size=10000, 211 | base_preprocess_fn=base_preprocess_fn, 212 | num_classes=task["num_classes"]) 213 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/oxford_flowers102.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements oxford flowers 102 data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow_datasets as tfds 24 | 25 | from . import base as base 26 | from .registry import Registry 27 | 28 | 29 | @Registry.register("data.oxford_flowers102", "class") 30 | class OxfordFlowers102Data(base.ImageTfdsData): 31 | """Provides Oxford 102 categories flowers dataset. 32 | 33 | See corresponding tfds dataset for details. 34 | 35 | URL: https://www.robots.ox.ac.uk/~vgg/data/flowers/102/ 36 | """ 37 | 38 | def __init__(self, data_dir=None, train_split_percent=None): 39 | dataset_builder = tfds.builder("oxford_flowers102:2.*.*", data_dir=data_dir) 40 | dataset_builder.download_and_prepare() 41 | 42 | # Example counts are retrieved from the tensorflow dataset info. 
43 | train_count = dataset_builder.info.splits[tfds.Split.TRAIN].num_examples 44 | val_count = dataset_builder.info.splits[tfds.Split.VALIDATION].num_examples 45 | test_count = dataset_builder.info.splits[tfds.Split.TEST].num_examples 46 | 47 | if train_split_percent: 48 | tfds_splits = { 49 | "train": "train[:{s}%]+validation[:{s}%]".format( 50 | s=train_split_percent), 51 | "val": "train[-{s}%:]+validation[-{s}%:]".format( 52 | s=train_split_percent), 53 | "trainval": "train+validation", 54 | "test": "test", 55 | "train800": "train[:800]", 56 | "val200": "validation[:200]", 57 | "train800val200": "train[:800]+validation[:200]", 58 | } 59 | num_samples_splits = { 60 | "train": (((train_count + val_count) // 100) 61 | * train_split_percent), 62 | "val": (((train_count + val_count) // 100) * 63 | (100 - train_split_percent)), 64 | "trainval": train_count + val_count, 65 | "test": test_count, 66 | "train800": 800, 67 | "val200": 200, 68 | "train800val200": 1000, 69 | } 70 | else: 71 | tfds_splits = { 72 | "train": "train", 73 | "val": "validation", 74 | "trainval": "train+validation", 75 | "test": "test", 76 | "train800": "train[:800]", 77 | "val200": "validation[:200]", 78 | "train800val200": "train[:800]+validation[:200]", 79 | } 80 | num_samples_splits = { 81 | "train": train_count, 82 | "val": val_count, 83 | "trainval": train_count + val_count, 84 | "test": test_count, 85 | "train800": 800, 86 | "val200": 200, 87 | "train800val200": 1000, 88 | } 89 | 90 | super(OxfordFlowers102Data, self).__init__( 91 | dataset_builder=dataset_builder, 92 | tfds_splits=tfds_splits, 93 | num_samples_splits=num_samples_splits, 94 | num_preprocessing_threads=400, 95 | shuffle_buffer_size=10000, 96 | # Note: Rename tensors but keep their original types. 97 | base_preprocess_fn=base.make_get_and_cast_tensors_fn({ 98 | "image": ("image", None), 99 | "label": ("label", None), 100 | }), 101 | num_classes=dataset_builder.info.features["label"] 102 | .num_classes) 103 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/oxford_iiit_pet.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements OxfordIIITPet data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow_datasets as tfds 24 | 25 | from . import base as base 26 | from .registry import Registry 27 | # This constant specifies the percentage of data that is used to create custom 28 | # train/val splits. Specifically, TRAIN_SPLIT_PERCENT% of the official training 29 | # split is used as a new training split and the rest is used for validation. 
30 | TRAIN_SPLIT_PERCENT = 80 31 | 32 | 33 | @Registry.register("data.oxford_iiit_pet", "class") 34 | class OxfordIIITPetData(base.ImageTfdsData): 35 | """Provides OxfordIIITPet data. 36 | 37 | The OxfordIIITPet dataset comes only with a training and test set. 38 | Therefore, the validation set is split out of the original training set, and 39 | the remaining examples are used as the "train" split. The "trainval" split 40 | corresponds to the original training set. 41 | 42 | For additional details and usage, see the base class. 43 | """ 44 | 45 | def __init__(self, data_dir=None, train_split_percent=None): 46 | 47 | dataset_builder = tfds.builder("oxford_iiit_pet:3.*.*", data_dir=data_dir) 48 | dataset_builder.download_and_prepare() 49 | train_split_percent = train_split_percent or TRAIN_SPLIT_PERCENT 50 | 51 | # Creates a dict with example counts for each split. 52 | trainval_count = dataset_builder.info.splits[tfds.Split.TRAIN].num_examples 53 | test_count = dataset_builder.info.splits[tfds.Split.TEST].num_examples 54 | num_samples_splits = { 55 | "train": (train_split_percent * trainval_count) // 100, 56 | "val": trainval_count - (train_split_percent * trainval_count) // 100, 57 | "trainval": trainval_count, 58 | "test": test_count, 59 | "train800": 800, 60 | "val200": 200, 61 | "train800val200": 1000, 62 | } 63 | 64 | # Defines dataset specific train/val/trainval/test splits. 65 | tfds_splits = { 66 | "train": "train[:{}]".format(num_samples_splits["train"]), 67 | "val": "train[{}:]".format(num_samples_splits["train"]), 68 | "trainval": tfds.Split.TRAIN, 69 | "test": tfds.Split.TEST, 70 | "train800": "train[:800]", 71 | "val200": "train[{}:{}]".format( 72 | num_samples_splits["train"], num_samples_splits["train"]+200), 73 | "train800val200": "train[:800]+train[{}:{}]".format( 74 | num_samples_splits["train"], num_samples_splits["train"]+200), 75 | } 76 | 77 | super(OxfordIIITPetData, self).__init__( 78 | dataset_builder=dataset_builder, 79 | tfds_splits=tfds_splits, 80 | num_samples_splits=num_samples_splits, 81 | num_preprocessing_threads=400, 82 | shuffle_buffer_size=10000, 83 | # Note: Export only image and label tensors with their original types. 84 | base_preprocess_fn=base.make_get_tensors_fn(["image", "label"]), 85 | num_classes=dataset_builder.info.features["label"].num_classes) 86 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/patch_camelyon.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements PatchCamelyon data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow_datasets as tfds 24 | 25 | from . 
import base as base 26 | from .registry import Registry 27 | 28 | 29 | @Registry.register("data.patch_camelyon", "class") 30 | class PatchCamelyonData(base.ImageTfdsData): 31 | """Provides PatchCamelyon data.""" 32 | 33 | def __init__(self, data_dir=None): 34 | 35 | dataset_builder = tfds.builder("patch_camelyon:2.*.*", data_dir=data_dir) 36 | dataset_builder.download_and_prepare() 37 | 38 | # Defines dataset specific train/val/trainval/test splits. 39 | tfds_splits = { 40 | "test": "test", 41 | "train": "train", 42 | "val": "validation", 43 | "trainval": "train+validation", 44 | "train800": "train[:800]", 45 | "val200": "validation[:200]", 46 | "train800val200": "train[:800]+validation[:200]", 47 | } 48 | # Creates a dict with example counts. 49 | num_samples_splits = { 50 | "test": dataset_builder.info.splits["test"].num_examples, 51 | "train": dataset_builder.info.splits["train"].num_examples, 52 | "val": dataset_builder.info.splits["validation"].num_examples, 53 | "train800": 800, 54 | "val200": 200, 55 | "train800val200": 1000, 56 | } 57 | num_samples_splits["trainval"] = ( 58 | num_samples_splits["train"] + num_samples_splits["val"]) 59 | super(PatchCamelyonData, self).__init__( 60 | dataset_builder=dataset_builder, 61 | tfds_splits=tfds_splits, 62 | num_samples_splits=num_samples_splits, 63 | num_preprocessing_threads=400, 64 | shuffle_buffer_size=10000, 65 | # Note: Export only image and label tensors with their original types. 66 | base_preprocess_fn=base.make_get_tensors_fn(["image", "label"]), 67 | num_classes=dataset_builder.info.features["label"].num_classes) 68 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/registry.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Global Registry for the task adaptation framework. 18 | """ 19 | 20 | from __future__ import absolute_import 21 | from __future__ import division 22 | from __future__ import print_function 23 | 24 | import ast 25 | import functools 26 | 27 | 28 | def partialclass(cls, *base_args, **base_kwargs): 29 | """Builds a subclass with partial application of the given args and keywords. 30 | 31 | Equivalent to functools.partial performance, base_args are preprended to the 32 | positional arguments given during object initialization and base_kwargs are 33 | updated with the kwargs given later. 34 | 35 | Args: 36 | cls: The base class. 37 | *base_args: Positional arguments to be applied to the subclass. 38 | **base_kwargs: Keyword arguments to be applied to the subclass. 39 | 40 | Returns: 41 | A subclass of the input class. 
42 | """ 43 | 44 | class _NewClass(cls): 45 | 46 | def __init__(self, *args, **kwargs): 47 | bound_args = base_args + args 48 | bound_kwargs = base_kwargs.copy() 49 | bound_kwargs.update(kwargs) 50 | super(_NewClass, self).__init__(*bound_args, **bound_kwargs) 51 | 52 | return _NewClass 53 | 54 | 55 | def parse_name(string_to_parse): 56 | """Parses input to the registry's lookup function. 57 | 58 | Args: 59 | string_to_parse: can be either an arbitrary name or function call 60 | (optionally with positional and keyword arguments). 61 | e.g. "multiclass", "resnet50_v2(filters_factor=8)". 62 | 63 | Returns: 64 | A tuple of input name and a dctinary with arguments. Examples: 65 | "multiclass" -> ("multiclass", (), {}) 66 | "resnet50_v2(9, filters_factor=4)" -> 67 | ("resnet50_v2", (9,), {"filters_factor": 4}) 68 | """ 69 | expr = ast.parse(string_to_parse, mode="eval").body # pytype: disable=attribute-error 70 | if not isinstance(expr, (ast.Attribute, ast.Call, ast.Name)): 71 | raise ValueError( 72 | "The given string should be a name or a call, but a {} was parsed from " 73 | "the string {!r}".format(type(expr), string_to_parse)) 74 | 75 | # Notes: 76 | # name="some_name" -> type(expr) = ast.Name 77 | # name="module.some_name" -> type(expr) = ast.Attribute 78 | # name="some_name()" -> type(expr) = ast.Call 79 | # name="module.some_name()" -> type(expr) = ast.Call 80 | 81 | if isinstance(expr, ast.Name): 82 | return string_to_parse, {} 83 | elif isinstance(expr, ast.Attribute): 84 | return string_to_parse, {} 85 | 86 | def _get_func_name(expr): 87 | if isinstance(expr, ast.Attribute): 88 | return _get_func_name(expr.value) + "." + expr.attr 89 | elif isinstance(expr, ast.Name): 90 | return expr.id 91 | else: 92 | raise ValueError( 93 | "Type {!r} is not supported in a function name, the string to parse " 94 | "was {!r}".format(type(expr), string_to_parse)) 95 | 96 | def _get_func_args_and_kwargs(call): 97 | args = tuple([ast.literal_eval(arg) for arg in call.args]) 98 | kwargs = { 99 | kwarg.arg: ast.literal_eval(kwarg.value) for kwarg in call.keywords 100 | } 101 | return args, kwargs 102 | 103 | func_name = _get_func_name(expr.func) 104 | func_args, func_kwargs = _get_func_args_and_kwargs(expr) 105 | if func_args: 106 | raise ValueError("Positional arguments are not supported here, but these " 107 | "were found: {!r}".format(func_args)) 108 | 109 | return func_name, func_kwargs 110 | 111 | 112 | class Registry(object): 113 | """Implements global Registry.""" 114 | 115 | _GLOBAL_REGISTRY = {} 116 | 117 | @staticmethod 118 | def global_registry(): 119 | return Registry._GLOBAL_REGISTRY 120 | 121 | @staticmethod 122 | def register(name, item_type): 123 | """Creates a function that registers its input.""" 124 | if item_type not in ["function", "class"]: 125 | raise ValueError("Unknown item type: %s" % item_type) 126 | 127 | def _register(item): 128 | if name in Registry.global_registry(): 129 | raise KeyError( 130 | "The name {!r} was already registered in with type {!r}".format( 131 | name, item_type)) 132 | 133 | Registry.global_registry()[name] = (item, item_type) 134 | return item 135 | 136 | return _register 137 | 138 | @staticmethod 139 | def lookup(lookup_string, kwargs_extra=None): 140 | """Lookup a name in the registry.""" 141 | 142 | name, kwargs = parse_name(lookup_string) 143 | if kwargs_extra: 144 | kwargs.update(kwargs_extra) 145 | item, item_type = Registry.global_registry()[name] 146 | if item_type == "function": 147 | return functools.partial(item, **kwargs) 148 | elif item_type 
== "class": 149 | return partialclass(item, **kwargs) 150 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/resisc45.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements RESISC-45 data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow_datasets as tfds 24 | 25 | from . import base as base 26 | from .registry import Registry 27 | TRAIN_SPLIT_PERCENT = 60 28 | VALIDATION_SPLIT_PERCENT = 20 29 | TEST_SPLIT_PERCENT = 20 30 | 31 | 32 | @Registry.register("data.resisc45", "class") 33 | class Resisc45Data(base.ImageTfdsData): 34 | """Provides RESISC-45 dataset. 35 | 36 | RESISC45 dataset is a publicly available benchmark for Remote Sensing Image 37 | Scene Classification (RESISC), created by Northwestern Polytechnical 38 | University (NWPU). This dataset contains 31,500 images, covering 45 scene 39 | classes with 700 images in each class. 40 | 41 | URL: http://www.escience.cn/people/JunweiHan/NWPU-RESISC45.html 42 | """ 43 | 44 | def __init__(self, data_dir=None): 45 | dataset_builder = tfds.builder("resisc45:3.*.*", data_dir=data_dir) 46 | dataset_builder.download_and_prepare() 47 | 48 | # Example counts are retrieved from the tensorflow dataset info. 49 | num_examples = dataset_builder.info.splits["train"].num_examples 50 | train_count = num_examples * TRAIN_SPLIT_PERCENT // 100 51 | val_count = num_examples * VALIDATION_SPLIT_PERCENT // 100 52 | test_count = num_examples * TEST_SPLIT_PERCENT // 100 53 | 54 | tfds_splits = { 55 | "train": 56 | "train[:{}]".format(train_count), 57 | "val": 58 | "train[{}:{}]".format(train_count, train_count + val_count), 59 | "trainval": 60 | "train[:{}]".format(train_count + val_count), 61 | "test": 62 | "train[{}:]".format(train_count + val_count), 63 | "train800": 64 | "train[:800]", 65 | "val200": 66 | "train[{}:{}]".format(train_count, train_count+200), 67 | "train800val200": 68 | "train[:800]+train[{}:{}]".format(train_count, train_count+200), 69 | } 70 | 71 | # Creates a dict with example counts for each split. 72 | num_samples_splits = { 73 | "train": train_count, 74 | "val": val_count, 75 | "trainval": train_count + val_count, 76 | "test": test_count, 77 | "train800": 800, 78 | "val200": 200, 79 | "train800val200": 1000, 80 | } 81 | 82 | super(Resisc45Data, self).__init__( 83 | dataset_builder=dataset_builder, 84 | tfds_splits=tfds_splits, 85 | num_samples_splits=num_samples_splits, 86 | num_preprocessing_threads=400, 87 | shuffle_buffer_size=10000, 88 | # Note: Rename tensors but keep their original types. 
89 | base_preprocess_fn=base.make_get_and_cast_tensors_fn({ 90 | "image": ("image", None), 91 | "label": ("label", None), 92 | }), 93 | num_classes=dataset_builder.info.features["label"].num_classes) 94 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/smallnorb.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements the SmallNORB data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow.compat.v1 as tf 24 | import tensorflow_datasets as tfds 25 | 26 | from . import base as base 27 | from .registry import Registry 28 | # This constant specifies the percentage of data that is used to create custom 29 | # val/test splits. Specifically, VAL_SPLIT_PERCENT% of the official testing 30 | # split is used as a new validation split and the rest is used for testing. 31 | VAL_SPLIT_PERCENT = 50 32 | 33 | 34 | @Registry.register("data.smallnorb", "class") 35 | class SmallNORBData(base.ImageTfdsData): 36 | """Provides the SmallNORB data set. 37 | 38 | SmallNORB comes only with a training and test set. Therefore, the validation 39 | set is split out of the original training set, and the remaining examples are 40 | used as the "train" split. The "trainval" split corresponds to the original 41 | training set. 42 | 43 | For additional details and usage, see the base class. 44 | 45 | The data set page is https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/. 46 | """ 47 | 48 | def __init__(self, predicted_attribute, data_dir=None): 49 | dataset_builder = tfds.builder("smallnorb:2.*.*", data_dir=data_dir) 50 | dataset_builder.download_and_prepare() 51 | 52 | if predicted_attribute not in dataset_builder.info.features: 53 | raise ValueError( 54 | "{} is not a valid attribute to predict.".format(predicted_attribute)) 55 | 56 | # Defines dataset specific train/val/trainval/test splits. 57 | tfds_splits = { 58 | "train": "train", 59 | "val": "test[:{}%]".format(VAL_SPLIT_PERCENT), 60 | "trainval": "train+test[:{}%]".format(VAL_SPLIT_PERCENT), 61 | "test": "test[{}%:]".format(VAL_SPLIT_PERCENT), 62 | "train800": "train[:800]", 63 | "val200": "test[:200]", 64 | "train800val200": "train[:800]+test[:200]", 65 | } 66 | 67 | # Creates a dict with example counts for each split. 
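# Illustrative note (added for clarity, not part of the original file): assuming the
# standard TFDS smallnorb splits of 24,300 train and 24,300 test examples,
# VAL_SPLIT_PERCENT = 50 gives val = 50 * 24300 // 100 = 12150,
# test = 24300 - 12150 = 12150, and trainval = 24300 + 12150 = 36450.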
68 | train_count = dataset_builder.info.splits["train"].num_examples 69 | test_count = dataset_builder.info.splits["test"].num_examples 70 | num_samples_validation = VAL_SPLIT_PERCENT * test_count // 100 71 | num_samples_splits = { 72 | "train": train_count, 73 | "val": num_samples_validation, 74 | "trainval": train_count + num_samples_validation, 75 | "test": test_count - num_samples_validation, 76 | "train800": 800, 77 | "val200": 200, 78 | "train800val200": 1000, 79 | } 80 | 81 | def preprocess_fn(tensors): 82 | # For consistency with other datasets, image needs to have three channels. 83 | image = tf.tile(tensors["image"], [1, 1, 3]) 84 | return dict(image=image, label=tensors[predicted_attribute]) 85 | 86 | info = dataset_builder.info 87 | super(SmallNORBData, self).__init__( 88 | dataset_builder=dataset_builder, 89 | tfds_splits=tfds_splits, 90 | num_samples_splits=num_samples_splits, 91 | num_preprocessing_threads=400, 92 | shuffle_buffer_size=10000, 93 | # We extract the attribute we want to predict in the preprocessing. 94 | base_preprocess_fn=preprocess_fn, 95 | num_classes=info.features[predicted_attribute].num_classes) 96 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/sun397.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements Sun397 data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow_datasets as tfds 24 | 25 | from . import base as base 26 | from .registry import Registry 27 | CUSTOM_TRAIN_SPLIT_PERCENT = 50 28 | CUSTOM_VALIDATION_SPLIT_PERCENT = 20 29 | CUSTOM_TEST_SPLIT_PERCENT = 30 30 | 31 | 32 | @Registry.register("data.sun397", "class") 33 | class Sun397Data(base.ImageTfdsData): 34 | """Provides Sun397Data data.""" 35 | 36 | def __init__(self, config="tfds", data_dir=None): 37 | 38 | if config == "tfds": 39 | dataset_builder = tfds.builder("sun397/tfds:4.*.*", data_dir=data_dir) 40 | dataset_builder.download_and_prepare() 41 | 42 | tfds_splits = { 43 | "train": "train", 44 | "val": "validation", 45 | "test": "test", 46 | "trainval": "train+validation", 47 | "train800": "train[:800]", 48 | "val200": "validation[:200]", 49 | "train800val200": "train[:800]+validation[:200]", 50 | } 51 | # Creates a dict with example counts. 
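# Note (added for clarity, not part of the original file): unlike the percentage-split
# datasets above, the "sun397/tfds" config ships official train/validation/test splits,
# so the example counts below are read directly from the TFDS builder info rather than
# derived from the CUSTOM_*_SPLIT_PERCENT constants defined at the top of this file.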
52 | num_samples_splits = { 53 | "test": dataset_builder.info.splits["test"].num_examples, 54 | "train": dataset_builder.info.splits["train"].num_examples, 55 | "val": dataset_builder.info.splits["validation"].num_examples, 56 | "train800": 800, 57 | "val200": 200, 58 | "train800val200": 1000, 59 | } 60 | num_samples_splits["trainval"] = ( 61 | num_samples_splits["train"] + num_samples_splits["val"]) 62 | else: 63 | 64 | raise ValueError("No supported config %r for Sun397Data." % config) 65 | 66 | super(Sun397Data, self).__init__( 67 | dataset_builder=dataset_builder, 68 | tfds_splits=tfds_splits, 69 | num_samples_splits=num_samples_splits, 70 | num_preprocessing_threads=400, 71 | shuffle_buffer_size=10000, 72 | # Note: Export only image and label tensors with their original types. 73 | base_preprocess_fn=base.make_get_tensors_fn(["image", "label"]), 74 | num_classes=dataset_builder.info.features["label"].num_classes) 75 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/svhn.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements Svhn data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow_datasets as tfds 24 | 25 | from . import base as base 26 | from .registry import Registry 27 | # This constant specifies the percentage of data that is used to create custom 28 | # train/val splits. Specifically, TRAIN_SPLIT_PERCENT% of the official training 29 | # split is used as a new training split and the rest is used for validation. 30 | TRAIN_SPLIT_PERCENT = 90 31 | 32 | 33 | @Registry.register("data.svhn", "class") 34 | class SvhnData(base.ImageTfdsData): 35 | """Provides SVHN data. 36 | 37 | The Street View House Numbers (SVHN) Dataset is an image digit recognition 38 | dataset of over 600,000 color digit images coming from real world data. 39 | Split size: 40 | - Training set: 73,257 images 41 | - Testing set: 26,032 images 42 | - Extra training set: 531,131 images 43 | Following the common setup on SVHN, we only use the official training and 44 | testing data. Images are cropped to 32x32. 45 | 46 | URL: http://ufldl.stanford.edu/housenumbers/ 47 | """ 48 | 49 | def __init__(self, data_dir=None): 50 | dataset_builder = tfds.builder("svhn_cropped:3.*.*", data_dir=data_dir) 51 | dataset_builder.download_and_prepare() 52 | 53 | # Example counts are retrieved from the tensorflow dataset info. 54 | trainval_count = dataset_builder.info.splits[tfds.Split.TRAIN].num_examples 55 | test_count = dataset_builder.info.splits[tfds.Split.TEST].num_examples 56 | 57 | # Creates a dict with example counts for each split. 
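# Worked example (added for clarity, not part of the original file): with the official
# SVHN counts quoted in the class docstring (73,257 train / 26,032 test) and
# TRAIN_SPLIT_PERCENT = 90, this gives train = 90 * 73257 // 100 = 65931,
# val = 73257 - 65931 = 7326, trainval = 73257, and test = 26032.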
58 | num_samples_splits = { 59 | # Calculates the train/val split example count based on percent. 60 | "train": TRAIN_SPLIT_PERCENT * trainval_count // 100, 61 | "val": trainval_count - TRAIN_SPLIT_PERCENT * trainval_count // 100, 62 | "trainval": trainval_count, 63 | "test": test_count, 64 | "train800": 800, 65 | "val200": 200, 66 | "train800val200": 1000, 67 | } 68 | 69 | # Defines dataset specific train/val/trainval/test splits. 70 | # The validation set is split out of the original training set, and the 71 | # remaining examples are used as the "train" split. The "trainval" split 72 | # corresponds to the original training set. 73 | tfds_splits = { 74 | "train": 75 | "train[:{}]".format(num_samples_splits["train"]), 76 | "val": 77 | "train[{}:]".format(num_samples_splits["train"]), 78 | "trainval": 79 | "train", 80 | "test": 81 | "test", 82 | "train800": 83 | "train[:800]", 84 | "val200": 85 | "train[{}:{}]".format(num_samples_splits["train"], 86 | num_samples_splits["train"] + 200), 87 | "train800val200": 88 | "train[:800]+train[{}:{}]".format( 89 | num_samples_splits["train"], num_samples_splits["train"] + 200), 90 | } 91 | 92 | super(SvhnData, self).__init__( 93 | dataset_builder=dataset_builder, 94 | tfds_splits=tfds_splits, 95 | num_samples_splits=num_samples_splits, 96 | num_preprocessing_threads=400, 97 | shuffle_buffer_size=10000, 98 | # Note: Rename tensors but keep their original types. 99 | base_preprocess_fn=base.make_get_and_cast_tensors_fn({ 100 | "image": ("image", None), 101 | "label": ("label", None), 102 | }), 103 | num_classes=dataset_builder.info.features["label"] 104 | .num_classes) 105 | -------------------------------------------------------------------------------- /src/engine/eval/multilabel.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | evaluate precision@1, @5 equal to Top1 and Top5 error rate 4 | """ 5 | import numpy as np 6 | from typing import List, Tuple, Dict 7 | from sklearn.metrics import ( 8 | precision_recall_curve, 9 | average_precision_score, 10 | f1_score 11 | ) 12 | 13 | 14 | def get_continuous_ids(probe_labels: List[int]) -> Dict[int, int]: 15 | sorted(probe_labels) 16 | id2continuousid = {} 17 | for idx, p_id in enumerate(probe_labels): 18 | id2continuousid[p_id] = idx 19 | return id2continuousid 20 | 21 | 22 | def multihot(x: List[List[int]], nb_classes: int) -> np.ndarray: 23 | """transform to multihot encoding 24 | 25 | Arguments: 26 | x: list of multi-class integer labels, in the range 27 | [0, nb_classes-1] 28 | nb_classes: number of classes for the multi-hot vector 29 | 30 | Returns: 31 | multihot: multihot vector of type int, (num_samples, nb_classes) 32 | """ 33 | num_samples = len(x) 34 | 35 | multihot = np.zeros((num_samples, nb_classes), dtype=np.int32) 36 | for idx, labs in enumerate(x): 37 | for lab in labs: 38 | multihot[idx, lab] = 1 39 | 40 | return multihot.astype(np.int) 41 | 42 | 43 | def compute_map( 44 | scores: np.ndarray, multihot_targets: np.ndarray 45 | ) -> Tuple[np.ndarray, np.ndarray, float, float]: 46 | """Compute the mean average precision across all class labels. 47 | 48 | Arguments: 49 | scores: matrix of per-class distances, 50 | of size num_samples x nb_classes 51 | multihot_targets: matrix of multi-hot target predictions, 52 | of size num_samples x nb_classes 53 | 54 | Returns: 55 | ap: list of average-precision scores, one for each of 56 | the nb_classes classes. 
57 | ar: list of average-recall scores, one for each of 58 | the nb_classes classes. 59 | mAP: the mean average precision score over all average 60 | precisions for all nb_classes classes. 61 | mAR: the mean average recall score over all average 62 | precisions for all nb_classes classes. 63 | """ 64 | nb_classes = scores.shape[1] 65 | 66 | ap = np.zeros((nb_classes,), dtype=np.float32) 67 | ar = np.zeros((nb_classes,), dtype=np.float32) 68 | 69 | for c in range(nb_classes): 70 | y_true = multihot_targets[:, c] 71 | y_scores = scores[:, c] 72 | 73 | # Use interpolated average precision (a la PASCAL 74 | try: 75 | ap[c] = average_precision_score(y_true, y_scores) 76 | except ValueError: 77 | ap[c] = -1 78 | 79 | # Also get the average of the recalls on the raw PR-curve 80 | try: 81 | _, rec, _ = precision_recall_curve(y_true, y_scores) 82 | ar[c] = rec.mean() 83 | except ValueError: 84 | ar[c] = -1 85 | 86 | mAP = ap.mean() 87 | mAR = ar.mean() 88 | 89 | return ap, ar, mAP, mAR 90 | 91 | 92 | def compute_f1( 93 | multihot_targets: np.ndarray, scores: np.ndarray, threshold: float = 0.5 94 | ) -> Tuple[float, float, float]: 95 | # change scores to predict_labels 96 | predict_labels = scores > threshold 97 | predict_labels = predict_labels.astype(np.int) 98 | 99 | # change targets to multihot 100 | f1 = {} 101 | f1["micro"] = f1_score( 102 | y_true=multihot_targets, 103 | y_pred=predict_labels, 104 | average="micro" 105 | ) 106 | f1["samples"] = f1_score( 107 | y_true=multihot_targets, 108 | y_pred=predict_labels, 109 | average="samples" 110 | ) 111 | f1["macro"] = f1_score( 112 | y_true=multihot_targets, 113 | y_pred=predict_labels, 114 | average="macro" 115 | ) 116 | f1["none"] = f1_score( 117 | y_true=multihot_targets, 118 | y_pred=predict_labels, 119 | average=None 120 | ) 121 | return f1["micro"], f1["samples"], f1["macro"], f1["none"] 122 | 123 | 124 | def get_best_f1_scores( 125 | multihot_targets: np.ndarray, scores: np.ndarray, threshold_end: int 126 | ) -> Dict[str, float]: 127 | end = 0.5 128 | end = 0.05 129 | end = threshold_end 130 | thrs = np.linspace( 131 | end, 0.95, int(np.round((0.95 - end) / 0.05)) + 1, endpoint=True 132 | ) 133 | f1_micros = [] 134 | f1_macros = [] 135 | f1_samples = [] 136 | f1_none = [] 137 | for thr in thrs: 138 | _micros, _samples, _macros, _none = compute_f1(multihot_targets, scores, thr) 139 | f1_micros.append(_micros) 140 | f1_samples.append(_samples) 141 | f1_macros.append(_macros) 142 | f1_none.append(_none) 143 | 144 | f1_macros_m = max(f1_macros) 145 | b_thr = np.argmax(f1_macros) 146 | 147 | f1_micros_m = f1_micros[b_thr] 148 | f1_samples_m = f1_samples[b_thr] 149 | f1_none_m = f1_none[b_thr] 150 | f1 = {} 151 | f1["micro"] = f1_micros_m 152 | f1["macro"] = f1_macros_m 153 | f1["samples"] = f1_samples_m 154 | f1["threshold"] = thrs[b_thr] 155 | f1["none"] = f1_none_m 156 | return f1 157 | -------------------------------------------------------------------------------- /src/engine/eval/singlelabel.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """Functions for computing metrics. 
all metrics has range of 0-1""" 4 | 5 | import numpy as np 6 | import torch 7 | from sklearn.metrics import ( 8 | accuracy_score, average_precision_score, f1_score, roc_auc_score 9 | ) 10 | 11 | 12 | def accuracy(y_probs, y_true): 13 | # y_prob: (num_images, num_classes) 14 | y_preds = np.argmax(y_probs, axis=1) 15 | accuracy = accuracy_score(y_true, y_preds) 16 | error = 1.0 - accuracy 17 | return accuracy, error 18 | 19 | 20 | def top_n_accuracy(y_probs, truths, n=1): 21 | # y_prob: (num_images, num_classes) 22 | # truth: (num_images, num_classes) multi/one-hot encoding 23 | best_n = np.argsort(y_probs, axis=1)[:, -n:] 24 | if isinstance(truths, np.ndarray) and truths.shape == y_probs.shape: 25 | ts = np.argmax(truths, axis=1) 26 | else: 27 | # a list of GT class idx 28 | ts = truths 29 | 30 | num_input = y_probs.shape[0] 31 | successes = 0 32 | for i in range(num_input): 33 | if ts[i] in best_n[i, :]: 34 | successes += 1 35 | return float(successes) / num_input 36 | 37 | 38 | def compute_acc_auc(y_probs, y_true_ids): 39 | onehot_tgts = np.zeros_like(y_probs) 40 | for idx, t in enumerate(y_true_ids): 41 | onehot_tgts[idx, t] = 1. 42 | 43 | num_classes = y_probs.shape[1] 44 | if num_classes == 2: 45 | top1, _ = accuracy(y_probs, y_true_ids) 46 | # so precision can set all to 2 47 | try: 48 | auc = roc_auc_score(onehot_tgts, y_probs, average='macro') 49 | except ValueError as e: 50 | print(f"value error encountered {e}, set auc sccore to -1.") 51 | auc = -1 52 | return {"top1": top1, "rocauc": auc} 53 | 54 | top1, _ = accuracy(y_probs, y_true_ids) 55 | k = min([5, num_classes]) # if number of labels < 5, use the total class 56 | top5 = top_n_accuracy(y_probs, y_true_ids, k) 57 | return {"top1": top1, f"top{k}": top5} 58 | 59 | 60 | def topks_correct(preds, labels, ks): 61 | """Computes the number of top-k correct predictions for each k.""" 62 | assert preds.size(0) == labels.size( 63 | 0 64 | ), "Batch dim of predictions and labels must match" 65 | # Find the top max_k predictions for each sample 66 | _top_max_k_vals, top_max_k_inds = torch.topk( 67 | preds, max(ks), dim=1, largest=True, sorted=True 68 | ) 69 | # (batch_size, max_k) -> (max_k, batch_size) 70 | top_max_k_inds = top_max_k_inds.t() 71 | # (batch_size, ) -> (max_k, batch_size) 72 | rep_max_k_labels = labels.view(1, -1).expand_as(top_max_k_inds) 73 | # (i, j) = 1 if top i-th prediction for the j-th sample is correct 74 | top_max_k_correct = top_max_k_inds.eq(rep_max_k_labels) 75 | # Compute the number of topk correct predictions for each k 76 | topks_correct = [ 77 | top_max_k_correct[:k, :].reshape(-1).float().sum() for k in ks 78 | ] 79 | return topks_correct 80 | 81 | 82 | def topk_errors(preds, labels, ks): 83 | """Computes the top-k error for each k.""" 84 | if int(labels.min()) < 0: # has ignore 85 | keep_ids = np.where(labels.cpu() >= 0)[0] 86 | preds = preds[keep_ids, :] 87 | labels = labels[keep_ids] 88 | 89 | num_topks_correct = topks_correct(preds, labels, ks) 90 | return [(1.0 - x / preds.size(0)) for x in num_topks_correct] 91 | 92 | 93 | def topk_accuracies(preds, labels, ks): 94 | """Computes the top-k accuracy for each k.""" 95 | num_topks_correct = topks_correct(preds, labels, ks) 96 | return [(x / preds.size(0)) for x in num_topks_correct] 97 | 98 | -------------------------------------------------------------------------------- /src/engine/evaluator.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import numpy as np 3 | 4 | from collections 
import defaultdict 5 | from typing import List, Union 6 | 7 | from .eval import multilabel 8 | from .eval import singlelabel 9 | from ..utils import logging 10 | logger = logging.get_logger("visual_prompt") 11 | 12 | 13 | class Evaluator(): 14 | """ 15 | An evaluator with below logics: 16 | 17 | 1. find which eval module to use. 18 | 2. store the eval results, pretty print it in log file as well. 19 | """ 20 | 21 | def __init__( 22 | self, 23 | ) -> None: 24 | self.results = defaultdict(dict) 25 | self.iteration = -1 26 | self.threshold_end = 0.5 27 | 28 | def update_iteration(self, iteration: int) -> None: 29 | """update iteration info""" 30 | self.iteration = iteration 31 | 32 | def update_result(self, metric: str, value: Union[float, dict]) -> None: 33 | if self.iteration > -1: 34 | key_name = "epoch_" + str(self.iteration) 35 | else: 36 | key_name = "final" 37 | if isinstance(value, float): 38 | self.results[key_name].update({metric: value}) 39 | else: 40 | if metric in self.results[key_name]: 41 | self.results[key_name][metric].update(value) 42 | else: 43 | self.results[key_name].update({metric: value}) 44 | 45 | def classify(self, probs, targets, test_data, multilabel=False): 46 | """ 47 | Evaluate classification result. 48 | Args: 49 | probs: np.ndarray for num_data x num_class, predicted probabilities 50 | targets: np.ndarray for multilabel, list of integers for single label 51 | test_labels: map test image ids to a list of class labels 52 | """ 53 | if not targets: 54 | raise ValueError( 55 | "When evaluating classification, need at least give targets") 56 | 57 | if multilabel: 58 | self._eval_multilabel(probs, targets, test_data) 59 | else: 60 | self._eval_singlelabel(probs, targets, test_data) 61 | 62 | def _eval_singlelabel( 63 | self, 64 | scores: np.ndarray, 65 | targets: List[int], 66 | eval_type: str 67 | ) -> None: 68 | """ 69 | if number of labels > 2: 70 | top1 and topk (5 by default) accuracy 71 | if number of labels == 2: 72 | top1 and rocauc 73 | """ 74 | acc_dict = singlelabel.compute_acc_auc(scores, targets) 75 | 76 | log_results = { 77 | k: np.around(v * 100, decimals=2) for k, v in acc_dict.items() 78 | } 79 | save_results = acc_dict 80 | 81 | self.log_and_update(log_results, save_results, eval_type) 82 | 83 | def _eval_multilabel( 84 | self, 85 | scores: np.ndarray, 86 | targets: np.ndarray, 87 | eval_type: str 88 | ) -> None: 89 | num_labels = scores.shape[-1] 90 | targets = multilabel.multihot(targets, num_labels) 91 | 92 | log_results = {} 93 | ap, ar, mAP, mAR = multilabel.compute_map(scores, targets) 94 | f1_dict = multilabel.get_best_f1_scores( 95 | targets, scores, self.threshold_end) 96 | 97 | log_results["mAP"] = np.around(mAP * 100, decimals=2) 98 | log_results["mAR"] = np.around(mAR * 100, decimals=2) 99 | log_results.update({ 100 | k: np.around(v * 100, decimals=2) for k, v in f1_dict.items()}) 101 | save_results = { 102 | "ap": ap, "ar": ar, "mAP": mAP, "mAR": mAR, "f1": f1_dict 103 | } 104 | self.log_and_update(log_results, save_results, eval_type) 105 | 106 | def log_and_update(self, log_results, save_results, eval_type): 107 | log_str = "" 108 | for k, result in log_results.items(): 109 | if not isinstance(result, np.ndarray): 110 | log_str += f"{k}: {result:.2f}\t" 111 | else: 112 | log_str += f"{k}: {list(result)}\t" 113 | logger.info(f"Classification results with {eval_type}: {log_str}") 114 | # save everything 115 | self.update_result("classification", {eval_type: save_results}) 116 | 
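# ----------------------------------------------------------------------------
# Usage sketch (illustrative only; not part of the original file). A training
# loop typically drives Evaluator as below, where `probs` is an (N, C) array of
# predicted probabilities and `targets` a list of ground-truth class indices:
#
#     evaluator = Evaluator()
#     evaluator.update_iteration(epoch)          # tag results as "epoch_<epoch>"
#     evaluator.classify(probs, targets, "val")  # logs top-1 / top-k (or rocauc)
#
# Passing multilabel=True to classify() instead routes evaluation through the
# multilabel mAP / best-F1 metrics.
# ----------------------------------------------------------------------------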
-------------------------------------------------------------------------------- /src/models/build_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Model construction functions. 4 | """ 5 | from tabnanny import verbose 6 | import torch 7 | 8 | from .vit_models import SSLViT 9 | from ..utils import logging 10 | logger = logging.get_logger("visual_prompt") 11 | # Supported model types 12 | _MODEL_TYPES = { 13 | "ssl-vit": SSLViT, 14 | } 15 | 16 | 17 | def build_model(cfg): 18 | """ 19 | build model here 20 | """ 21 | assert ( 22 | cfg.MODEL.TYPE in _MODEL_TYPES.keys() 23 | ), "Model type '{}' not supported".format(cfg.MODEL.TYPE) 24 | assert ( 25 | cfg.NUM_GPUS <= torch.cuda.device_count() 26 | ), "Cannot use more GPU devices than available" 27 | 28 | # Construct the model 29 | train_type = cfg.MODEL.TYPE 30 | model = _MODEL_TYPES[train_type](cfg) 31 | 32 | log_model_info(model, verbose=cfg.DBG) 33 | model, device = load_model_to_device(model, cfg) 34 | logger.info(f"Device used for model: {device}") 35 | 36 | return model, device 37 | 38 | 39 | def log_model_info(model, verbose=False): 40 | """Logs model info""" 41 | if verbose: 42 | logger.info(f"Classification Model:\n{model}") 43 | model_total_params = sum(p.numel() for p in model.parameters()) 44 | model_grad_params = sum( 45 | p.numel() for p in model.parameters() if p.requires_grad) 46 | logger.info("Total Parameters: {0}\t Gradient Parameters: {1}".format( 47 | model_total_params, model_grad_params)) 48 | logger.info("tuned percent:%.3f"%(model_grad_params/model_total_params*100)) 49 | 50 | 51 | def get_current_device(): 52 | if torch.cuda.is_available(): 53 | # Determine the GPU used by the current process 54 | cur_device = torch.cuda.current_device() 55 | else: 56 | cur_device = torch.device('cpu') 57 | return cur_device 58 | 59 | 60 | def load_model_to_device(model, cfg): 61 | cur_device = get_current_device() 62 | if torch.cuda.is_available(): 63 | # Transfer the model to the current GPU device 64 | model = model.cuda(device=cur_device) 65 | # Use multi-process data parallel model in the multi-gpu setting 66 | if cfg.NUM_GPUS > 1: 67 | # Make model replica operate on the current device 68 | model = torch.nn.parallel.DistributedDataParallel( 69 | module=model, device_ids=[cur_device], output_device=cur_device, 70 | find_unused_parameters=True, 71 | ) 72 | else: 73 | model = model.to(cur_device) 74 | return model, cur_device 75 | -------------------------------------------------------------------------------- /src/models/build_vit_backbone.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import numpy as np 3 | import torch 4 | import os 5 | from .vit_backbones.vit_moco import vit_base as moco_vit_model 6 | from .vit_backbones.vit_mae import build_model as mae_vit_model 7 | 8 | from .vit_prompt.vit_moco import build_model as prompt_moco_vit 9 | from .vit_prompt.vit_mae import build_model as prompt_mae_vit 10 | 11 | 12 | MODEL_ZOO = { 13 | "mae_vitb16": "mae-ViT-B.pth", 14 | "mae_vitl16": "mae-ViT-L.pth", 15 | "mocov3_vitb16" : "mocov3-ViT-B.pth.tar", 16 | "mocov3_vits16" : "mocov3-ViT-S.pth.tar", 17 | } 18 | 19 | 20 | 21 | def build_mae_model( 22 | model_type, crop_size, prompt_cfg, model_root, adapter_cfg=None 23 | ): 24 | if not model_type in ["mae_vitb16", "mae_vitl16"]: 25 | raise ValueError("Does not support other arch") 26 | if prompt_cfg is not None: 27 | model = 
prompt_mae_vit(model_type, prompt_cfg) 28 | else: 29 | model = mae_vit_model(model_type) 30 | out_dim = model.embed_dim 31 | 32 | ckpt = os.path.join(model_root, MODEL_ZOO[model_type]) 33 | checkpoint = torch.load(ckpt, map_location="cpu") 34 | state_dict = checkpoint['model'] 35 | 36 | msg = model.load_state_dict(state_dict, strict=False) 37 | print(msg) 38 | model.head = torch.nn.Identity() 39 | return model, out_dim 40 | 41 | 42 | def build_mocov3_model( 43 | model_type, crop_size, prompt_cfg, model_root, adapter_cfg=None 44 | ): 45 | if not model_type in ["mocov3_vitb16", "mocov3_vits16"]: 46 | raise ValueError("Does not support other arch") 47 | if prompt_cfg is not None: 48 | model = prompt_moco_vit(model_type, prompt_cfg) 49 | else: 50 | model = moco_vit_model() 51 | 52 | out_dim = 384 if model_type.endswith('s16') else 768 53 | ckpt = os.path.join(model_root, MODEL_ZOO[model_type]) 54 | checkpoint = torch.load(ckpt, map_location="cpu") 55 | state_dict = checkpoint['state_dict'] 56 | for k in list(state_dict.keys()): 57 | # retain only base_encoder up to before the embedding layer 58 | if k.startswith('module.'): 59 | # remove prefix 60 | key = k.replace('module.', '') 61 | if key.startswith('base_encoder.'): 62 | key = key.replace('base_encoder.', '') 63 | elif key.startswith('momentum'): 64 | del state_dict[k] 65 | continue 66 | state_dict[key] = state_dict[k] 67 | 68 | # delete renamed or unused k 69 | del state_dict[k] 70 | 71 | msg = model.load_state_dict(state_dict, strict=False) 72 | print(msg) 73 | model.head = torch.nn.Identity() 74 | return model, out_dim 75 | 76 | -------------------------------------------------------------------------------- /src/models/mlp.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Modified from: fbcode/multimo/models/encoders/mlp.py 4 | """ 5 | import math 6 | import torch 7 | 8 | from torch import nn 9 | from typing import List, Type 10 | 11 | from ..utils import logging 12 | logger = logging.get_logger("visual_prompt") 13 | 14 | 15 | class MLP(nn.Module): 16 | def __init__( 17 | self, 18 | input_dim: int, 19 | mlp_dims: List[int], 20 | dropout: float = 0.1, 21 | nonlinearity: Type[nn.Module] = nn.ReLU, 22 | normalization: Type[nn.Module] = nn.BatchNorm1d, # nn.LayerNorm, 23 | special_bias: bool = False, 24 | add_bn_first: bool = False, 25 | ): 26 | super(MLP, self).__init__() 27 | projection_prev_dim = input_dim 28 | projection_modulelist = [] 29 | last_dim = mlp_dims[-1] 30 | mlp_dims = mlp_dims[:-1] 31 | 32 | if add_bn_first: 33 | if normalization is not None: 34 | projection_modulelist.append(normalization(projection_prev_dim)) 35 | if dropout != 0: 36 | projection_modulelist.append(nn.Dropout(dropout)) 37 | 38 | for idx, mlp_dim in enumerate(mlp_dims): 39 | fc_layer = nn.Linear(projection_prev_dim, mlp_dim) 40 | nn.init.kaiming_normal_(fc_layer.weight, a=0, mode='fan_out') 41 | projection_modulelist.append(fc_layer) 42 | projection_modulelist.append(nonlinearity()) 43 | 44 | if normalization is not None: 45 | projection_modulelist.append(normalization(mlp_dim)) 46 | 47 | if dropout != 0: 48 | projection_modulelist.append(nn.Dropout(dropout)) 49 | projection_prev_dim = mlp_dim 50 | 51 | self.projection = nn.Sequential(*projection_modulelist) 52 | self.last_layer = nn.Linear(projection_prev_dim, last_dim) 53 | nn.init.kaiming_normal_(self.last_layer.weight, a=0, mode='fan_out') 54 | if special_bias: 55 | prior_prob = 0.01 56 | bias_value = -math.log((1 - prior_prob) / 
prior_prob) 57 | torch.nn.init.constant_(self.last_layer.bias, bias_value) 58 | 59 | def forward(self, x: torch.Tensor) -> torch.Tensor: 60 | """ 61 | input_arguments: 62 | @x: torch.FloatTensor 63 | """ 64 | x = self.projection(x) 65 | x = self.last_layer(x) 66 | return x 67 | -------------------------------------------------------------------------------- /src/models/vit_backbones/vit_mae.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | """ 4 | borrowed from https://github.com/facebookresearch/mae/blob/main/models_vit.py 5 | """ 6 | from functools import partial 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | import timm.models.vision_transformer 12 | from timm.models.layers import Mlp, DropPath 13 | 14 | 15 | # based on timm Attention implementation 16 | class Attention(nn.Module): 17 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): 18 | super().__init__() 19 | self.num_heads = num_heads 20 | head_dim = dim // num_heads 21 | self.scale = head_dim ** -0.5 22 | 23 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 24 | self.attn_drop = nn.Dropout(attn_drop) 25 | self.proj = nn.Linear(dim, dim) 26 | self.proj_drop = nn.Dropout(proj_drop) 27 | 28 | def forward(self, x, temp=1.0): 29 | """ 30 | temp = 1.0 by default or learnable scalar 31 | """ 32 | B, N, C = x.shape 33 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 34 | q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) 35 | 36 | attn = (q @ k.transpose(-2, -1)) * self.scale 37 | 38 | attn = (attn / temp).softmax(dim=-1) 39 | attn = self.attn_drop(attn) 40 | 41 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 42 | x = self.proj(x) 43 | x = self.proj_drop(x) 44 | 45 | return x 46 | 47 | 48 | class Block(nn.Module): 49 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., 50 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 51 | super().__init__() 52 | self.dim = dim 53 | self.norm1 = norm_layer(dim) 54 | self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) 55 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 56 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 57 | self.norm2 = norm_layer(dim) 58 | mlp_hidden_dim = int(dim * mlp_ratio) 59 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 60 | 61 | def forward(self, x, temp=1.0): 62 | """ 63 | temp = 1.0 by default or learnable scalar 64 | """ 65 | x = x + self.drop_path(self.attn(self.norm1(x), temp=temp)) 66 | x = x + self.drop_path(self.mlp(self.norm2(x))) 67 | return x 68 | 69 | 70 | class VisionTransformer(timm.models.vision_transformer.VisionTransformer): 71 | """ Vision Transformer with support for global average pooling 72 | """ 73 | def __init__(self, global_pool=False, **kwargs): 74 | super(VisionTransformer, self).__init__(**kwargs) 75 | 76 | self.global_pool = global_pool 77 | 78 | norm_layer = kwargs['norm_layer'] 79 | embed_dim = kwargs['embed_dim'] 80 | 81 | 82 | dpr = [x.item() for x in torch.linspace(0, kwargs['drop_path_rate'], kwargs['depth'])] # stochastic depth decay rule 83 | self.blocks = nn.Sequential(*[ 84 | Block( 85 | dim=embed_dim, num_heads=kwargs['num_heads'], mlp_ratio=kwargs['mlp_ratio'], qkv_bias=kwargs['qkv_bias'], 86 | drop_path=dpr[i], norm_layer=kwargs['norm_layer']) 87 | for i in range(kwargs['depth'])]) 88 | # if pretrained_norm: 89 | self.norm = norm_layer(embed_dim) 90 | # self.fc_norm = norm_layer(embed_dim) 91 | 92 | def forward_features(self, x): 93 | B = x.shape[0] 94 | x = self.patch_embed(x) 95 | 96 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 97 | x = torch.cat((cls_tokens, x), dim=1) 98 | x = x + self.pos_embed 99 | x = self.pos_drop(x) 100 | 101 | for blk in self.blocks: 102 | x, attn = blk(x) 103 | 104 | 105 | x = x[:, 1:, :].mean(dim=1) # global pool without cls token 106 | 107 | outcome = self.norm(x) 108 | 109 | return outcome 110 | 111 | 112 | def build_model(model_type): 113 | if "vitb" in model_type: 114 | return vit_base_patch16() 115 | elif "vitl" in model_type: 116 | return vit_large_patch16() 117 | elif "vith" in model_type: 118 | return vit_huge_patch14() 119 | 120 | 121 | def vit_base_patch16(**kwargs): 122 | model = VisionTransformer( 123 | drop_path_rate=0.1, global_pool=True, # using default settings for mae-finetune 124 | patch_size=16, embed_dim=768, depth=12, num_heads=12, 125 | mlp_ratio=4, qkv_bias=True, 126 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 127 | return model 128 | 129 | 130 | def vit_large_patch16(**kwargs): 131 | model = VisionTransformer( 132 | drop_path_rate=0.1, global_pool=True, # using default settings for mae-finetune 133 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, 134 | mlp_ratio=4, qkv_bias=True, 135 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 136 | return model 137 | 138 | 139 | def vit_huge_patch14(**kwargs): 140 | model = VisionTransformer( 141 | drop_path_rate=0.1, global_pool=True, # using default settings for mae-finetune 142 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, 143 | mlp_ratio=4, qkv_bias=True, 144 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 145 | return model 146 | -------------------------------------------------------------------------------- /src/models/vit_backbones/vit_moco.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. 
All Rights Reserved 3 | """ 4 | borrowed from https://github.com/facebookresearch/moco-v3/blob/main/vits.py 5 | """ 6 | import math 7 | import torch 8 | import torch.nn as nn 9 | from functools import partial, reduce 10 | from operator import mul 11 | 12 | from timm.models.vision_transformer import _cfg 13 | from timm.models.layers.helpers import to_2tuple 14 | from timm.models.layers import PatchEmbed 15 | from .vit_mae import VisionTransformer 16 | 17 | __all__ = [ 18 | 'vit_small', 19 | 'vit_base', 20 | 'vit_conv_small', 21 | 'vit_conv_base', 22 | ] 23 | 24 | 25 | 26 | class VisionTransformerMoCo(VisionTransformer): 27 | def __init__(self, stop_grad_conv1=False, **kwargs): 28 | super().__init__(**kwargs) 29 | # Use fixed 2D sin-cos position embedding 30 | self.build_2d_sincos_position_embedding() 31 | 32 | # weight initialization 33 | for name, m in self.named_modules(): 34 | if isinstance(m, nn.Linear): 35 | if 'qkv' in name: 36 | # treat the weights of Q, K, V separately 37 | val = math.sqrt(6. / float(m.weight.shape[0] // 3 + m.weight.shape[1])) 38 | nn.init.uniform_(m.weight, -val, val) 39 | else: 40 | nn.init.xavier_uniform_(m.weight) 41 | nn.init.zeros_(m.bias) 42 | nn.init.normal_(self.cls_token, std=1e-6) 43 | 44 | if isinstance(self.patch_embed, PatchEmbed): 45 | # xavier_uniform initialization 46 | val = math.sqrt(6. / float(3 * reduce(mul, self.patch_embed.patch_size, 1) + self.embed_dim)) 47 | nn.init.uniform_(self.patch_embed.proj.weight, -val, val) 48 | nn.init.zeros_(self.patch_embed.proj.bias) 49 | 50 | if stop_grad_conv1: 51 | self.patch_embed.proj.weight.requires_grad = False 52 | self.patch_embed.proj.bias.requires_grad = False 53 | 54 | def build_2d_sincos_position_embedding(self, temperature=10000.): 55 | h, w = self.patch_embed.grid_size 56 | grid_w = torch.arange(w, dtype=torch.float32) 57 | grid_h = torch.arange(h, dtype=torch.float32) 58 | grid_w, grid_h = torch.meshgrid(grid_w, grid_h) 59 | assert self.embed_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding' 60 | pos_dim = self.embed_dim // 4 61 | omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim 62 | omega = 1. / (temperature**omega) 63 | out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega]) 64 | out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega]) 65 | pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :] 66 | 67 | assert self.num_tokens == 1, 'Assuming one and only one token, [cls]' 68 | pe_token = torch.zeros([1, 1, self.embed_dim], dtype=torch.float32) 69 | self.pos_embed = nn.Parameter(torch.cat([pe_token, pos_emb], dim=1)) 70 | self.pos_embed.requires_grad = False 71 | 72 | 73 | class ConvStem(nn.Module): 74 | """ 75 | ConvStem, from Early Convolutions Help Transformers See Better, Tete et al. 
https://arxiv.org/abs/2106.14881 76 | """ 77 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): 78 | super().__init__() 79 | 80 | assert patch_size == 16, 'ConvStem only supports patch size of 16' 81 | assert embed_dim % 8 == 0, 'Embed dimension must be divisible by 8 for ConvStem' 82 | 83 | img_size = to_2tuple(img_size) 84 | patch_size = to_2tuple(patch_size) 85 | self.img_size = img_size 86 | self.patch_size = patch_size 87 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 88 | self.num_patches = self.grid_size[0] * self.grid_size[1] 89 | self.flatten = flatten 90 | 91 | # build stem, similar to the design in https://arxiv.org/abs/2106.14881 92 | stem = [] 93 | input_dim, output_dim = 3, embed_dim // 8 94 | for l in range(4): 95 | stem.append(nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=2, padding=1, bias=False)) 96 | stem.append(nn.BatchNorm2d(output_dim)) 97 | stem.append(nn.ReLU(inplace=True)) 98 | input_dim = output_dim 99 | output_dim *= 2 100 | stem.append(nn.Conv2d(input_dim, embed_dim, kernel_size=1)) 101 | self.proj = nn.Sequential(*stem) 102 | 103 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 104 | 105 | def forward(self, x): 106 | B, C, H, W = x.shape 107 | assert H == self.img_size[0] and W == self.img_size[1], \ 108 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 109 | x = self.proj(x) 110 | if self.flatten: 111 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 112 | x = self.norm(x) 113 | return x 114 | 115 | 116 | def vit_small(**kwargs): 117 | model = VisionTransformerMoCo( 118 | patch_size=16, embed_dim=384, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 119 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 120 | model.default_cfg = _cfg() 121 | return model 122 | 123 | def vit_base(**kwargs): 124 | model = VisionTransformerMoCo( 125 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, drop_path_rate=0.1, 126 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 127 | model.default_cfg = _cfg() 128 | return model 129 | 130 | def vit_conv_small(**kwargs): 131 | # minus one ViT block 132 | model = VisionTransformerMoCo( 133 | patch_size=16, embed_dim=384, depth=11, num_heads=12, mlp_ratio=4, qkv_bias=True, 134 | norm_layer=partial(nn.LayerNorm, eps=1e-6), embed_layer=ConvStem, **kwargs) 135 | model.default_cfg = _cfg() 136 | return model 137 | 138 | def vit_conv_base(**kwargs): 139 | # minus one ViT block 140 | model = VisionTransformerMoCo( 141 | patch_size=16, embed_dim=768, depth=11, num_heads=12, mlp_ratio=4, qkv_bias=True, 142 | norm_layer=partial(nn.LayerNorm, eps=1e-6), embed_layer=ConvStem, **kwargs) 143 | model.default_cfg = _cfg() 144 | return model 145 | -------------------------------------------------------------------------------- /src/models/vit_models.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """ 4 | ViT-related models 5 | Note: models return logits instead of prob 6 | """ 7 | import torch 8 | import torch.nn as nn 9 | 10 | from collections import OrderedDict 11 | from torchvision import models 12 | 13 | from .build_vit_backbone import ( 14 | build_mocov3_model, build_mae_model, 15 | ) 16 | from .mlp import MLP 17 | from ..utils import logging 18 | logger = logging.get_logger("visual_prompt") 19 | 20 | 21 | 22 | class SSLViT(nn.Module): 23 | """moco-v3 and mae model.""" 24 | 
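# Clarifying comment (added; not part of the original file): this wrapper builds a
# MoCo-v3 or MAE ViT backbone via build_backbone(), freezes parameters according to
# cfg.MODEL.TRANSFER_TYPE (linear / partial-k / prompt / adapter / end2end / ...),
# and attaches an MLP classification head. Gate and temperature parameters are
# always left trainable (see the final loop in build_backbone).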
25 | def __init__(self, cfg): 26 | super(SSLViT, self).__init__() 27 | 28 | if "prompt" in cfg.MODEL.TRANSFER_TYPE: 29 | prompt_cfg = cfg.MODEL.PROMPT 30 | else: 31 | prompt_cfg = None 32 | 33 | if cfg.MODEL.TRANSFER_TYPE != "end2end" and "prompt" not in cfg.MODEL.TRANSFER_TYPE: 34 | # linear, cls, tiny-tl, parital, adapter 35 | self.froze_enc = True 36 | else: 37 | # prompt, end2end, cls+prompt 38 | self.froze_enc = False 39 | 40 | if cfg.MODEL.TRANSFER_TYPE == "adapter": 41 | adapter_cfg = cfg.MODEL.ADAPTER 42 | else: 43 | adapter_cfg = None 44 | 45 | self.build_backbone( 46 | prompt_cfg, cfg, adapter_cfg) 47 | 48 | self.cfg = cfg 49 | self.setup_side() 50 | self.setup_head(cfg) 51 | 52 | def setup_side(self): 53 | if self.cfg.MODEL.TRANSFER_TYPE != "side": 54 | self.side = None 55 | else: 56 | self.side_alpha = nn.Parameter(torch.tensor(0.0)) 57 | m = models.alexnet(pretrained=True) 58 | self.side = nn.Sequential(OrderedDict([ 59 | ("features", m.features), 60 | ("avgpool", m.avgpool), 61 | ])) 62 | self.side_projection = nn.Linear(9216, self.feat_dim, bias=False) 63 | 64 | def setup_head(self, cfg): 65 | self.head = MLP( 66 | input_dim=self.feat_dim, 67 | mlp_dims=[self.feat_dim] * self.cfg.MODEL.MLP_NUM + \ 68 | [cfg.DATA.NUMBER_CLASSES], # noqa 69 | special_bias=True 70 | ) 71 | 72 | def build_backbone(self, prompt_cfg, cfg, adapter_cfg): 73 | if "moco" in cfg.DATA.FEATURE: 74 | build_fn = build_mocov3_model 75 | elif "mae" in cfg.DATA.FEATURE: 76 | build_fn = build_mae_model 77 | 78 | self.enc, self.feat_dim = build_fn( 79 | cfg.DATA.FEATURE, cfg.DATA.CROPSIZE, 80 | prompt_cfg, cfg.MODEL.MODEL_ROOT, adapter_cfg=adapter_cfg 81 | ) 82 | 83 | transfer_type = cfg.MODEL.TRANSFER_TYPE 84 | # linear, prompt, cls, cls+prompt, partial_1 85 | if transfer_type == "partial-1": 86 | total_layer = len(self.enc.blocks) 87 | for k, p in self.enc.named_parameters(): 88 | if "blocks.{}".format(total_layer - 1) not in k and "fc_norm" not in k and k != "norm": # noqa 89 | p.requires_grad = False 90 | elif transfer_type == "partial-2": 91 | total_layer = len(self.enc.blocks) 92 | for k, p in self.enc.named_parameters(): 93 | if "blocks.{}".format(total_layer - 1) not in k and "blocks.{}".format(total_layer - 2) not in k and "fc_norm" not in k and k != "norm": # noqa 94 | p.requires_grad = False 95 | 96 | elif transfer_type == "partial-4": 97 | total_layer = len(self.enc.blocks) 98 | for k, p in self.enc.named_parameters(): 99 | if "blocks.{}".format(total_layer - 1) not in k and "blocks.{}".format(total_layer - 2) not in k and "blocks.{}".format(total_layer - 3) not in k and "blocks.{}".format(total_layer - 4) not in k and "fc_norm" not in k and k != "norm": # noqa 100 | p.requires_grad = False 101 | 102 | elif transfer_type == "linear" or transfer_type == "sidetune": 103 | for k, p in self.enc.named_parameters(): 104 | p.requires_grad = False 105 | 106 | elif transfer_type == "tinytl-bias": 107 | for k, p in self.enc.named_parameters(): 108 | if 'bias' not in k: 109 | p.requires_grad = False 110 | 111 | elif transfer_type == "prompt+bias": 112 | for k, p in self.enc.named_parameters(): 113 | if "prompt" not in k and 'bias' not in k: 114 | p.requires_grad = False 115 | 116 | elif transfer_type == "prompt" and prompt_cfg.LOCATION == "below": 117 | for k, p in self.enc.named_parameters(): 118 | if "prompt" not in k and "patch_embed.proj.weight" not in k and "patch_embed.proj.bias" not in k: 119 | p.requires_grad = False 120 | 121 | elif transfer_type == "prompt": 122 | for k, p in 
self.enc.named_parameters(): 123 | if "prompt" not in k: 124 | p.requires_grad = False 125 | 126 | elif transfer_type == "end2end": 127 | logger.info("Enable all parameters update during training") 128 | 129 | # adapter 130 | elif transfer_type == "adapter": 131 | for k, p in self.enc.named_parameters(): 132 | if "adapter" not in k: 133 | p.requires_grad = False 134 | else: 135 | raise ValueError("transfer type {} is not supported".format( 136 | transfer_type)) 137 | 138 | for k, p in self.enc.named_parameters(): 139 | if 'gate' in k: 140 | p.requires_grad = True 141 | if 'temp' in k: 142 | p.requires_grad = True 143 | 144 | def forward(self, x, return_feature=False): 145 | if self.side is not None: 146 | side_output = self.side(x) 147 | side_output = side_output.view(side_output.size(0), -1) 148 | side_output = self.side_projection(side_output) 149 | 150 | if self.froze_enc and self.enc.training: 151 | self.enc.eval() 152 | x = self.enc(x) # batch_size x self.feat_dim 153 | 154 | if self.side is not None: 155 | alpha_squashed = torch.sigmoid(self.side_alpha) 156 | x = alpha_squashed * x + (1 - alpha_squashed) * side_output 157 | 158 | if return_feature: 159 | return x, x 160 | x = self.head(x) 161 | 162 | return x 163 | 164 | def forward_cls_layerwise(self, x): 165 | cls_embeds = self.enc.forward_cls_layerwise(x) 166 | return cls_embeds 167 | 168 | def get_features(self, x): 169 | """get a (batch_size, self.feat_dim) feature""" 170 | x = self.enc(x) # batch_size x self.feat_dim 171 | return x 172 | -------------------------------------------------------------------------------- /src/models/vit_prompt/vit_mae.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | vit-moco-v3 with prompt 4 | """ 5 | import math 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torchvision as tv 10 | 11 | from functools import partial, reduce 12 | from operator import mul 13 | from torch.nn import Conv2d, Dropout 14 | from timm.models.vision_transformer import _cfg 15 | 16 | from ..vit_backbones.vit_mae import VisionTransformer 17 | from ...utils import logging 18 | logger = logging.get_logger("visual_prompt") 19 | 20 | 21 | class PromptedVisionTransformer(VisionTransformer): 22 | def __init__(self, prompt_config, **kwargs): 23 | super().__init__(**kwargs) 24 | self.prompt_config = prompt_config 25 | if self.prompt_config.DEEP and self.prompt_config.LOCATION not in ["prepend", ]: 26 | raise ValueError("Deep-{} is not supported".format(self.prompt_config.LOCATION)) 27 | 28 | num_tokens = self.prompt_config.NUM_TOKENS 29 | 30 | self.num_tokens = num_tokens 31 | self.prompt_dropout = Dropout(self.prompt_config.DROPOUT) 32 | 33 | # define temperature for attention shaping 34 | self.temp = self.prompt_config.TEMP 35 | self.temp_learn = self.prompt_config.TEMP_LEARN 36 | if self.temp_learn: 37 | self.temp = nn.Parameter(torch.ones(prompt_config.TEMP_NUM)) 38 | 39 | # initiate prompt: 40 | if self.prompt_config.INITIATION == "random": 41 | val = math.sqrt(6. 
/ float(3 * reduce(mul, self.patch_embed.patch_size, 1) + self.embed_dim)) # noqa 42 | 43 | self.prompt_embeddings = nn.Parameter(torch.zeros( 44 | 1, num_tokens, self.embed_dim)) 45 | # xavier_uniform initialization 46 | nn.init.uniform_(self.prompt_embeddings.data, -val, val) 47 | 48 | if self.prompt_config.DEEP: 49 | self.deep_prompt_embeddings = nn.Parameter(torch.zeros( 50 | len(self.blocks) - 1, 51 | num_tokens, self.embed_dim 52 | )) 53 | # xavier_uniform initialization 54 | nn.init.uniform_( 55 | self.deep_prompt_embeddings.data, -val, val) 56 | 57 | else: 58 | raise ValueError("Other initiation scheme is not supported") 59 | 60 | # define block-wise learnable gate scalar 61 | if self.prompt_config.GATE_PRIOR: 62 | gate_logit = (-torch.ones(self.prompt_config.GATE_NUM) * self.prompt_config.GATE_INIT) 63 | self.gate_logit = nn.Parameter(gate_logit) 64 | print(self.gate_logit) 65 | 66 | def incorporate_prompt(self, x): 67 | # combine prompt embeddings with image-patch embeddings 68 | B = x.shape[0] 69 | if self.prompt_config.LOCATION == "prepend": 70 | # after CLS token, all before image patches 71 | x = self.embeddings(x) # (batch_size, 1 + n_pa 72 | x = torch.cat(( 73 | x[:, :1, :], 74 | self.prompt_dropout( 75 | self.prompt_embeddings.expand(B, -1, -1)), 76 | x[:, 1:, :] 77 | ), dim=1) 78 | # (batch_size, cls_token + n_prompt + n_patches, hidden_dim) 79 | else: 80 | raise ValueError("Other prompt locations are not supported") 81 | return x 82 | 83 | def embeddings(self, x): 84 | B = x.shape[0] 85 | x = self.patch_embed(x) 86 | 87 | cls_tokens = self.cls_token.expand(B, -1, -1) 88 | x = torch.cat((cls_tokens, x), dim=1) 89 | x = x + self.pos_embed 90 | x = self.pos_drop(x) 91 | return x 92 | 93 | def train(self, mode=True): 94 | # set train status for this class: disable all but the prompt-related modules 95 | if mode: 96 | # training: 97 | self.blocks.eval() 98 | self.patch_embed.eval() 99 | self.pos_drop.eval() 100 | self.prompt_dropout.train() 101 | else: 102 | # eval: 103 | for module in self.children(): 104 | module.train(mode) 105 | 106 | def reinit_temp(self): 107 | assert self.temp_learn, "reinit_temp() could be run only when config.TEMP_LEARN == True" 108 | self.temp.data.copy_(self.temp.data.clamp(min=self.prompt_config.TEMP_MIN, max=self.prompt_config.TEMP_MAX)) 109 | 110 | def forward_features(self, x): 111 | x = self.incorporate_prompt(x) 112 | 113 | # deep 114 | if self.prompt_config.DEEP: 115 | B = x.shape[0] 116 | num_layers = len(self.blocks) 117 | 118 | for i in range(num_layers): 119 | if i == 0: 120 | x = self.blocks[i](x) 121 | else: 122 | # prepend 123 | x = torch.cat(( 124 | x[:, 0:1, :], 125 | self.prompt_dropout( 126 | self.deep_prompt_embeddings[i - 1].expand(B, -1, -1) 127 | ), 128 | x[:, (1 + self.num_tokens):, :] 129 | ), dim=1) 130 | x = self.blocks[i](x) 131 | 132 | else: 133 | # clamp temperatures not to be too small or too large 134 | if self.temp_learn: 135 | self.reinit_temp() 136 | 137 | for i, blk in enumerate(self.blocks): 138 | # current block's input prompt representation 139 | if self.prompt_config.GATE_PRIOR and i < self.gate_logit.shape[0]: 140 | gate = self.gate_logit[i].sigmoid() 141 | prompt_in = x[:, 1: 1+self.prompt_config.NUM_TOKENS, :] 142 | 143 | # block-wise learnable temperature 144 | temp = self.temp if not isinstance(self.temp, nn.Parameter) else self.temp[i] 145 | 146 | x = blk(x, temp=temp) 147 | if self.prompt_config.GATE_PRIOR and i < self.gate_logit.shape[0]: 148 | # current block's output prompt representation 149 | 
prompt_out = x[:, 1: 1+self.prompt_config.NUM_TOKENS, :] 150 | # convex combinate input and output prompt representations of current block via learnalbe gate 151 | x = torch.cat([ 152 | x[:, 0:1, :], 153 | gate * prompt_out + (1 - gate) * prompt_in, 154 | x[:, 1+self.prompt_config.NUM_TOKENS:, :] 155 | ], dim=1) 156 | 157 | norm_func = self.norm 158 | if self.prompt_config.VIT_POOL_TYPE == "imgprompt_pool": 159 | assert self.prompt_config.LOCATION == "prepend" 160 | x = x[:, 1:, :].mean(dim=1) # global pool without cls token 161 | outcome = norm_func(x) 162 | elif self.prompt_config.VIT_POOL_TYPE == "original": 163 | x = x[:, 1:, :].mean(dim=1) # global pool without cls token 164 | outcome = norm_func(x) 165 | elif self.prompt_config.VIT_POOL_TYPE == "img_pool": 166 | assert self.prompt_config.LOCATION == "prepend" 167 | x = x[:, self.num_tokens+1:, :].mean(dim=1) 168 | outcome = norm_func(x) 169 | elif self.prompt_config.VIT_POOL_TYPE == "prompt_pool": 170 | assert self.prompt_config.LOCATION == "prepend" 171 | x = x[:, 1:self.num_tokens+1, :].mean(dim=1) 172 | outcome = norm_func(x) 173 | else: 174 | raise ValueError("pooling type for output is not supported") 175 | 176 | 177 | return outcome 178 | 179 | 180 | def build_model(model_type, prompt_cfg): 181 | if "vitb" in model_type: 182 | return vit_base_patch16(prompt_cfg) 183 | elif "vitl" in model_type: 184 | return vit_large_patch16(prompt_cfg) 185 | elif "vith" in model_type: 186 | return vit_huge_patch14(prompt_cfg) 187 | 188 | 189 | def vit_base_patch16(prompt_cfg, **kwargs): 190 | model = PromptedVisionTransformer( 191 | prompt_cfg, 192 | drop_path_rate=0.1, global_pool=True, # using default settings for mae-finetune 193 | patch_size=16, embed_dim=768, depth=12, num_heads=12, 194 | mlp_ratio=4, qkv_bias=True, 195 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 196 | return model 197 | 198 | 199 | def vit_large_patch16(prompt_cfg, **kwargs): 200 | model = PromptedVisionTransformer( 201 | prompt_cfg, 202 | drop_path_rate=0.1, global_pool=True, # using default settings for mae-finetune 203 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, 204 | mlp_ratio=4, qkv_bias=True, 205 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 206 | return model 207 | 208 | 209 | def vit_huge_patch14(prompt_cfg, **kwargs): 210 | model = PromptedVisionTransformer( 211 | prompt_cfg, 212 | drop_path_rate=0.1, global_pool=True, # using default settings for mae-finetune 213 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, 214 | mlp_ratio=4, qkv_bias=True, 215 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 216 | return model 217 | -------------------------------------------------------------------------------- /src/models/vit_prompt/vit_moco.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | vit-moco-v3 with prompt 4 | """ 5 | import math 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torchvision as tv 10 | 11 | from functools import partial, reduce 12 | from operator import mul 13 | from torch.nn import Conv2d, Dropout 14 | from timm.models.vision_transformer import _cfg 15 | 16 | from ..vit_backbones.vit_moco import VisionTransformerMoCo 17 | from ...utils import logging 18 | logger = logging.get_logger("visual_prompt") 19 | 20 | 21 | class PromptedVisionTransformerMoCo(VisionTransformerMoCo): 22 | def __init__(self, prompt_config, **kwargs): 23 | super().__init__(**kwargs) 24 | self.prompt_config = 
prompt_config 25 | 26 | if self.prompt_config.DEEP and self.prompt_config.LOCATION not in ["prepend", ]: 27 | raise ValueError("Deep-{} is not supported".format(self.prompt_config.LOCATION)) 28 | 29 | num_tokens = self.prompt_config.NUM_TOKENS 30 | 31 | self.num_tokens = num_tokens 32 | self.prompt_dropout = Dropout(self.prompt_config.DROPOUT) 33 | 34 | # define temperature for attention shaping 35 | self.temp = self.prompt_config.TEMP 36 | self.temp_learn = self.prompt_config.TEMP_LEARN 37 | if self.temp_learn: 38 | self.temp = nn.Parameter(torch.ones(prompt_config.TEMP_NUM)) 39 | 40 | # initiate prompt: 41 | if self.prompt_config.INITIATION == "random": 42 | val = math.sqrt(6. / float(3 * reduce(mul, self.patch_embed.patch_size, 1) + self.embed_dim)) # noqa 43 | 44 | self.prompt_embeddings = nn.Parameter(torch.zeros( 45 | 1, num_tokens, self.embed_dim)) 46 | # xavier_uniform initialization 47 | nn.init.uniform_(self.prompt_embeddings.data, -val, val) 48 | 49 | if self.prompt_config.DEEP: 50 | self.deep_prompt_embeddings = nn.Parameter(torch.zeros( 51 | len(self.blocks) - 1, 52 | num_tokens, self.embed_dim 53 | )) 54 | # xavier_uniform initialization 55 | nn.init.uniform_( 56 | self.deep_prompt_embeddings.data, -val, val) 57 | else: 58 | raise ValueError("Other initiation scheme is not supported") 59 | 60 | # define block-wise learnable gate scalar 61 | if self.prompt_config.GATE_PRIOR: 62 | gate_logit = (-torch.ones(self.prompt_config.GATE_NUM) * self.prompt_config.GATE_INIT) 63 | self.gate_logit = nn.Parameter(gate_logit) 64 | print(self.gate_logit) 65 | 66 | def incorporate_prompt(self, x): 67 | # combine prompt embeddings with image-patch embeddings 68 | B = x.shape[0] 69 | if self.prompt_config.LOCATION == "prepend": 70 | # after CLS token, all before image patches 71 | x = self.embeddings(x) # (batch_size, 1 + n_patches, hidden_dim) 72 | x = torch.cat(( 73 | x[:, :1, :], 74 | self.prompt_dropout( 75 | self.prompt_embeddings.expand(B, -1, -1)), 76 | x[:, 1:, :] 77 | ), dim=1) 78 | # (batch_size, cls_token + n_prompt + n_patches, hidden_dim) 79 | else: 80 | raise ValueError("Other prompt locations are not supported") 81 | return x 82 | 83 | def embeddings(self, x): 84 | x = self.patch_embed(x) 85 | cls_token = self.cls_token.expand(x.shape[0], -1, -1) 86 | if self.dist_token is None: 87 | x = torch.cat((cls_token, x), dim=1) 88 | else: 89 | x = torch.cat(( 90 | cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), 91 | dim=1) 92 | x = self.pos_drop(x + self.pos_embed) 93 | return x 94 | 95 | def train(self, mode=True): 96 | # set train status for this class: disable all but the prompt-related modules 97 | if mode: 98 | # training: 99 | self.blocks.eval() 100 | self.patch_embed.eval() 101 | self.pos_drop.eval() 102 | self.prompt_dropout.train() 103 | else: 104 | # eval: 105 | for module in self.children(): 106 | module.train(mode) 107 | 108 | def reinit_temp(self): 109 | assert self.temp_learn, "reinit_temp() could be run only when config.TEMP_LEARN == True" 110 | self.temp.data.copy_(self.temp.data.clamp(min=self.prompt_config.TEMP_MIN, max=self.prompt_config.TEMP_MAX)) 111 | 112 | def forward_features(self, x): 113 | x = self.incorporate_prompt(x) 114 | 115 | # deep 116 | if self.prompt_config.DEEP: 117 | B = x.shape[0] 118 | num_layers = len(self.blocks) 119 | 120 | for i in range(num_layers): 121 | if i == 0: 122 | x = self.blocks[i](x) 123 | else: 124 | # prepend 125 | x = torch.cat(( 126 | x[:, :1, :], 127 | self.prompt_dropout( 128 | self.deep_prompt_embeddings[i - 
1].expand(B, -1, -1) 129 | ), 130 | x[:, (1 + self.num_tokens):, :] 131 | ), dim=1) 132 | x = self.blocks[i](x) 133 | 134 | else: 135 | # clamp temperatures not to be too small or too large 136 | if self.temp_learn: 137 | self.reinit_temp() 138 | 139 | for i, blk in enumerate(self.blocks): 140 | # current block's input prompt representation 141 | if self.prompt_config.GATE_PRIOR and i < self.gate_logit.shape[0]: 142 | gate = self.gate_logit[i].sigmoid() 143 | prompt_in = x[:, 1: 1+self.prompt_config.NUM_TOKENS, :] 144 | 145 | # block-wise learnable temperature 146 | temp = self.temp if not isinstance(self.temp, nn.Parameter) else self.temp[i] 147 | 148 | x = blk(x, temp=temp) 149 | if self.prompt_config.GATE_PRIOR and i < self.gate_logit.shape[0]: 150 | # current block's output prompt representation 151 | prompt_out = x[:, 1: 1+self.prompt_config.NUM_TOKENS, :] 152 | # convex combinate input and output prompt representations of current block via learnalbe gate 153 | x = torch.cat([ 154 | x[:, 0:1, :], 155 | gate * prompt_out + (1 - gate) * prompt_in, 156 | x[:, 1+self.prompt_config.NUM_TOKENS:, :] 157 | ], dim=1) 158 | 159 | norm_func = self.norm 160 | if self.prompt_config.VIT_POOL_TYPE == "imgprompt_pool": 161 | assert self.prompt_config.LOCATION == "prepend" 162 | outcome = norm_func(x[:, 1:, :].mean(dim=1)) # global pool without cls token 163 | 164 | elif self.prompt_config.VIT_POOL_TYPE == "original": 165 | x = norm_func(x) 166 | outcome = x[:, 0] 167 | 168 | elif self.prompt_config.VIT_POOL_TYPE == "img_pool": 169 | assert self.prompt_config.LOCATION == "prepend" 170 | outcome = norm_func(x[:, self.num_tokens+1:, :].mean(dim=1)) 171 | elif self.prompt_config.VIT_POOL_TYPE == "prompt_pool": 172 | assert self.prompt_config.LOCATION == "prepend" 173 | outcome = norm_func(x[:, 1:self.num_tokens+1, :].mean(dim=1)) 174 | 175 | else: 176 | raise ValueError("pooling type for output is not supported") 177 | 178 | return outcome 179 | 180 | 181 | def build_model(model_type, prompt_cfg): 182 | if "vitb" in model_type: 183 | return vit_base(prompt_cfg) 184 | elif "vits" in model_type: 185 | return vit_small(prompt_cfg) 186 | 187 | 188 | def vit_small(prompt_cfg, **kwargs): 189 | model = PromptedVisionTransformerMoCo( 190 | prompt_cfg, 191 | patch_size=16, embed_dim=384, depth=12, drop_path_rate=0.1, 192 | num_heads=12, mlp_ratio=4, qkv_bias=True, 193 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 194 | model.default_cfg = _cfg() 195 | return model 196 | 197 | 198 | def vit_base(prompt_cfg, **kwargs): 199 | model = PromptedVisionTransformerMoCo( 200 | prompt_cfg, 201 | patch_size=16, embed_dim=768, depth=12, drop_path_rate=0.1, 202 | num_heads=12, mlp_ratio=4, qkv_bias=True, 203 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 204 | model.default_cfg = _cfg() 205 | return model 206 | 207 | -------------------------------------------------------------------------------- /src/solver/losses.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from typing import Optional 6 | 7 | from ..utils import logging 8 | logger = logging.get_logger("visual_prompt") 9 | 10 | 11 | class SigmoidLoss(nn.Module): 12 | def __init__(self, cfg=None): 13 | super(SigmoidLoss, self).__init__() 14 | 15 | def is_single(self): 16 | return True 17 | 18 | def is_local(self): 19 | return False 20 | 21 | def multi_hot(self, labels: torch.Tensor, nb_classes: int) -> 
torch.Tensor: 22 | labels = labels.unsqueeze(1) # (batch_size, 1) 23 | target = torch.zeros( 24 | labels.size(0), nb_classes, device=labels.device 25 | ).scatter_(1, labels, 1.) 26 | # (batch_size, num_classes) 27 | return target 28 | 29 | def loss( 30 | self, logits, targets, per_cls_weights, 31 | multihot_targets: Optional[bool] = False 32 | ): 33 | # targets: 1d-tensor of integer 34 | # Only support single label at this moment 35 | # if len(targets.shape) != 2: 36 | num_classes = logits.shape[1] 37 | targets = self.multi_hot(targets, num_classes) 38 | 39 | loss = F.binary_cross_entropy_with_logits( 40 | logits, targets, reduction="none") 41 | # logger.info(f"loss shape: {loss.shape}") 42 | weight = torch.tensor( 43 | per_cls_weights, device=logits.device 44 | ).unsqueeze(0) 45 | # logger.info(f"weight shape: {weight.shape}") 46 | loss = torch.mul(loss.to(torch.float32), weight.to(torch.float32)) 47 | return torch.sum(loss) / targets.shape[0] 48 | 49 | def forward( 50 | self, pred_logits, targets, per_cls_weights, multihot_targets=False 51 | ): 52 | loss = self.loss( 53 | pred_logits, targets, per_cls_weights, multihot_targets) 54 | return loss 55 | 56 | 57 | class SoftmaxLoss(SigmoidLoss): 58 | def __init__(self, cfg=None): 59 | super(SoftmaxLoss, self).__init__() 60 | 61 | def loss(self, logits, targets, per_cls_weights, kwargs): 62 | weight = torch.tensor( 63 | per_cls_weights, device=logits.device 64 | ) 65 | loss = F.cross_entropy(logits, targets, weight, reduction="none") 66 | 67 | return torch.sum(loss) / targets.shape[0] 68 | 69 | 70 | LOSS = { 71 | "softmax": SoftmaxLoss, 72 | } 73 | 74 | 75 | def build_loss(cfg): 76 | loss_name = cfg.SOLVER.LOSS 77 | assert loss_name in LOSS, \ 78 | f'loss name {loss_name} is not supported' 79 | loss_fn = LOSS[loss_name] 80 | if not loss_fn: 81 | return None 82 | else: 83 | return loss_fn(cfg) 84 | -------------------------------------------------------------------------------- /src/solver/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import math 3 | 4 | import torch.optim as optim 5 | from fvcore.common.config import CfgNode 6 | from torch.optim.lr_scheduler import LambdaLR 7 | 8 | 9 | def make_scheduler( 10 | optimizer: optim.Optimizer, train_params: CfgNode 11 | ) -> LambdaLR: 12 | warmup = train_params.WARMUP_EPOCH 13 | total_iters = train_params.TOTAL_EPOCH 14 | 15 | if train_params.SCHEDULER == "cosine": 16 | scheduler = WarmupCosineSchedule( 17 | optimizer, 18 | warmup_steps=warmup, 19 | t_total=total_iters 20 | ) 21 | elif train_params.SCHEDULER == "cosine_hardrestart": 22 | scheduler = WarmupCosineWithHardRestartsSchedule( 23 | optimizer, 24 | warmup_steps=warmup, 25 | t_total=total_iters 26 | ) 27 | 28 | elif train_params.SCHEDULER == "plateau": 29 | scheduler = optim.lr_scheduler.ReduceLROnPlateau( 30 | optimizer, 31 | "max", 32 | patience=5, 33 | verbose=True, 34 | factor=train_params.LR_DECAY_FACTOR, 35 | ) 36 | else: 37 | scheduler = None 38 | return scheduler 39 | 40 | 41 | class WarmupCosineSchedule(LambdaLR): 42 | """ Linear warmup and then cosine decay. 43 | Linearly increases learning rate from 0 to 1 over `warmup_steps`. 44 | Decreases learning rate from 1. to 0. over remaining 45 | `t_total - warmup_steps` steps following a cosine curve. 46 | If `cycles` (default=0.5) is different from default, learning rate 47 | follows cosine function after warmup. 
48 | """ 49 | def __init__(self, optimizer, warmup_steps, t_total, cycles=.5, last_epoch=-1): 50 | self.warmup_steps = warmup_steps 51 | self.t_total = t_total 52 | self.cycles = cycles 53 | super(WarmupCosineSchedule, self).__init__( 54 | optimizer, self.lr_lambda, last_epoch=last_epoch) 55 | 56 | def lr_lambda(self, step): 57 | if step < self.warmup_steps: 58 | return float(step) / float(max(1.0, self.warmup_steps)) 59 | # progress after warmup 60 | progress = float(step - self.warmup_steps) / float(max( 61 | 1, self.t_total - self.warmup_steps)) 62 | return max( 63 | 0.0, 64 | 0.5 * (1. + math.cos(math.pi * float(self.cycles) * 2.0 * progress)) 65 | ) 66 | 67 | 68 | class WarmupCosineWithHardRestartsSchedule(LambdaLR): 69 | """ Linear warmup and then cosine cycles with hard restarts. 70 | Linearly increases learning rate from 0 to 1 over `warmup_steps`. 71 | If `cycles` (default=1.) is different from default, learning rate 72 | follows `cycles` times a cosine decaying learning rate 73 | (with hard restarts). 74 | """ 75 | def __init__(self, optimizer, warmup_steps, t_total, cycles=1., last_epoch=-1): 76 | self.warmup_steps = warmup_steps 77 | self.t_total = t_total 78 | self.cycles = cycles 79 | super(WarmupCosineWithHardRestartsSchedule, self).__init__( 80 | optimizer, self.lr_lambda, last_epoch=last_epoch) 81 | 82 | def lr_lambda(self, step): 83 | if step < self.warmup_steps: 84 | return float(step) / float(max(1, self.warmup_steps)) 85 | # progress after warmup 86 | progress = float(step - self.warmup_steps) / float( 87 | max(1, self.t_total - self.warmup_steps)) 88 | if progress >= 1.0: 89 | return 0.0 90 | return max( 91 | 0.0, 92 | 0.5 * (1. + math.cos( 93 | math.pi * ((float(self.cycles) * progress) % 1.0))) 94 | ) 95 | -------------------------------------------------------------------------------- /src/utils/distributed.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """Distributed helpers.""" 4 | 5 | import torch 6 | import torch.distributed as dist 7 | _LOCAL_PROCESS_GROUP = None 8 | 9 | 10 | def get_world_size() -> int: 11 | if not dist.is_available(): 12 | return 1 13 | if not dist.is_initialized(): 14 | return 1 15 | return dist.get_world_size() 16 | 17 | 18 | def get_rank() -> int: 19 | if not dist.is_available(): 20 | return 0 21 | if not dist.is_initialized(): 22 | return 0 23 | return dist.get_rank() 24 | 25 | 26 | def is_master_process(num_gpus=8): 27 | """ 28 | Determines if the current process is the master process. 29 | """ 30 | if torch.distributed.is_initialized(): 31 | return dist.get_rank() % num_gpus == 0 32 | else: 33 | return True 34 | 35 | 36 | def run( 37 | local_rank, 38 | num_proc, 39 | func, 40 | init_method, 41 | shard_id, 42 | num_shards, 43 | backend, 44 | cfg, 45 | args, 46 | ): 47 | """ 48 | Runs a function from a child process. 49 | Args: 50 | local_rank (int): rank of the current process on the current machine. 51 | num_proc (int): number of processes per machine. 52 | func (function): function to execute on each of the process. 53 | init_method (string): method to initialize the distributed training. 54 | TCP initialization: equiring a network address reachable from all 55 | processes followed by the port. 56 | Shared file-system initialization: makes use of a file system that 57 | is shared and visible from all machines. The URL should start with 58 | file:// and contain a path to a non-existent file on a shared file 59 | system. 
60 | shard_id (int): the rank of the current machine. 61 | num_shards (int): number of overall machines for the distributed 62 | training job. 63 | backend (string): three distributed backends ('nccl', 'gloo', 'mpi') are 64 | supports, each with different capabilities. Details can be found 65 | here: 66 | https://pytorch.org/docs/stable/distributed.html 67 | cfg (CfgNode): configs. Details can be found in 68 | loco/config/defaults.py 69 | """ 70 | # Initialize the process group. 71 | # shard_id = get_rank() 72 | world_size = num_proc * num_shards 73 | rank = shard_id * num_proc + local_rank 74 | 75 | try: 76 | torch.distributed.init_process_group( 77 | backend=backend, 78 | init_method=init_method, 79 | world_size=world_size, 80 | rank=rank, 81 | ) 82 | except Exception as e: 83 | raise e 84 | 85 | torch.cuda.set_device(local_rank) 86 | func(cfg, args) 87 | 88 | 89 | def destroy_process_group(): 90 | """Destroys the default process group.""" 91 | torch.distributed.destroy_process_group() 92 | 93 | 94 | def scaled_all_reduce(cfg, tensors): 95 | """Performs the scaled all_reduce operation on the provided tensors. 96 | 97 | The input tensors are modified in-place. Currently supports only the sum 98 | reduction operator. The reduced values are scaled by the inverse size of 99 | the process group (equivalent to cfg.NUM_GPUS). 100 | """ 101 | # Queue the reductions 102 | reductions = [] 103 | for tensor in tensors: 104 | reduction = torch.distributed.all_reduce(tensor, async_op=True) 105 | reductions.append(reduction) 106 | # Wait for reductions to finish 107 | for reduction in reductions: 108 | reduction.wait() 109 | # Scale the results 110 | for tensor in tensors: 111 | tensor.mul_(1.0 / cfg.NUM_GPUS / cfg.NUM_SHARDS) 112 | return tensors 113 | 114 | 115 | def cat_all_gather(tensors): 116 | """Performs the concatenated all_gather operation on the provided tensors. 117 | """ 118 | tensors_gather = [ 119 | torch.ones_like(tensors) 120 | for _ in range(torch.distributed.get_world_size()) 121 | ] 122 | torch.distributed.all_gather(tensors_gather, tensors, async_op=False) 123 | 124 | output = torch.cat(tensors_gather, dim=0) 125 | return output 126 | 127 | 128 | def local_cat_all_gather(tensors): 129 | """Performs the concatenated all_gather operation on the provided tensors. 130 | """ 131 | tensors_gather = [ 132 | torch.ones_like(tensors) 133 | for _ in range(get_local_size()) 134 | ] 135 | torch.distributed.all_gather( 136 | tensors_gather, 137 | tensors, 138 | async_op=False, 139 | group=_LOCAL_PROCESS_GROUP, 140 | ) 141 | output = torch.cat(tensors_gather, dim=0) 142 | return output 143 | 144 | 145 | def get_local_size(): 146 | """ 147 | Returns: 148 | The size of the per-machine process group, 149 | i.e. the number of processes per machine. 150 | """ 151 | if not dist.is_available(): 152 | return 1 153 | if not dist.is_initialized(): 154 | return 1 155 | return dist.get_world_size(group=_LOCAL_PROCESS_GROUP) 156 | 157 | 158 | def get_local_rank(): 159 | """ 160 | Returns: 161 | The rank of the current process within the local (per-machine) process group. 
162 | """ 163 | if not dist.is_available(): 164 | return 0 165 | if not dist.is_initialized(): 166 | return 0 167 | assert _LOCAL_PROCESS_GROUP is not None 168 | return dist.get_rank(group=_LOCAL_PROCESS_GROUP) 169 | -------------------------------------------------------------------------------- /src/utils/distributed_orig.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """Distributed helpers.""" 4 | 5 | import torch 6 | import torch.distributed as dist 7 | _LOCAL_PROCESS_GROUP = None 8 | 9 | 10 | def get_world_size() -> int: 11 | if not dist.is_available(): 12 | return 1 13 | if not dist.is_initialized(): 14 | return 1 15 | return dist.get_world_size() 16 | 17 | 18 | def get_rank() -> int: 19 | if not dist.is_available(): 20 | return 0 21 | if not dist.is_initialized(): 22 | return 0 23 | return dist.get_rank() 24 | 25 | 26 | def is_master_process(num_gpus=8): 27 | """ 28 | Determines if the current process is the master process. 29 | """ 30 | if torch.distributed.is_initialized(): 31 | return dist.get_rank() % num_gpus == 0 32 | else: 33 | return True 34 | 35 | 36 | def run( 37 | local_rank, 38 | num_proc, 39 | func, 40 | init_method, 41 | shard_id, 42 | num_shards, 43 | backend, 44 | cfg, 45 | args, 46 | ): 47 | """ 48 | Runs a function from a child process. 49 | Args: 50 | local_rank (int): rank of the current process on the current machine. 51 | num_proc (int): number of processes per machine. 52 | func (function): function to execute on each of the process. 53 | init_method (string): method to initialize the distributed training. 54 | TCP initialization: equiring a network address reachable from all 55 | processes followed by the port. 56 | Shared file-system initialization: makes use of a file system that 57 | is shared and visible from all machines. The URL should start with 58 | file:// and contain a path to a non-existent file on a shared file 59 | system. 60 | shard_id (int): the rank of the current machine. 61 | num_shards (int): number of overall machines for the distributed 62 | training job. 63 | backend (string): three distributed backends ('nccl', 'gloo', 'mpi') are 64 | supports, each with different capabilities. Details can be found 65 | here: 66 | https://pytorch.org/docs/stable/distributed.html 67 | cfg (CfgNode): configs. Details can be found in 68 | loco/config/defaults.py 69 | """ 70 | # Initialize the process group. 71 | # shard_id = get_rank() 72 | world_size = num_proc * num_shards 73 | rank = shard_id * num_proc + local_rank 74 | 75 | try: 76 | torch.distributed.init_process_group( 77 | backend=backend, 78 | init_method=init_method, 79 | world_size=world_size, 80 | rank=rank, 81 | ) 82 | except Exception as e: 83 | raise e 84 | 85 | torch.cuda.set_device(local_rank) 86 | func(cfg, args) 87 | 88 | 89 | def destroy_process_group(): 90 | """Destroys the default process group.""" 91 | torch.distributed.destroy_process_group() 92 | 93 | 94 | def scaled_all_reduce(cfg, tensors): 95 | """Performs the scaled all_reduce operation on the provided tensors. 96 | 97 | The input tensors are modified in-place. Currently supports only the sum 98 | reduction operator. The reduced values are scaled by the inverse size of 99 | the process group (equivalent to cfg.NUM_GPUS). 
100 | """ 101 | # Queue the reductions 102 | reductions = [] 103 | for tensor in tensors: 104 | reduction = torch.distributed.all_reduce(tensor, async_op=True) 105 | reductions.append(reduction) 106 | # Wait for reductions to finish 107 | for reduction in reductions: 108 | reduction.wait() 109 | # Scale the results 110 | for tensor in tensors: 111 | tensor.mul_(1.0 / cfg.NUM_GPUS / cfg.NUM_SHARDS) 112 | return tensors 113 | 114 | 115 | def cat_all_gather(tensors): 116 | """Performs the concatenated all_gather operation on the provided tensors. 117 | """ 118 | tensors_gather = [ 119 | torch.ones_like(tensors) 120 | for _ in range(torch.distributed.get_world_size()) 121 | ] 122 | torch.distributed.all_gather(tensors_gather, tensors, async_op=False) 123 | 124 | output = torch.cat(tensors_gather, dim=0) 125 | return output 126 | 127 | 128 | def local_cat_all_gather(tensors): 129 | """Performs the concatenated all_gather operation on the provided tensors. 130 | """ 131 | tensors_gather = [ 132 | torch.ones_like(tensors) 133 | for _ in range(get_local_size()) 134 | ] 135 | torch.distributed.all_gather( 136 | tensors_gather, 137 | tensors, 138 | async_op=False, 139 | group=_LOCAL_PROCESS_GROUP, 140 | ) 141 | output = torch.cat(tensors_gather, dim=0) 142 | return output 143 | 144 | 145 | def get_local_size(): 146 | """ 147 | Returns: 148 | The size of the per-machine process group, 149 | i.e. the number of processes per machine. 150 | """ 151 | if not dist.is_available(): 152 | return 1 153 | if not dist.is_initialized(): 154 | return 1 155 | return dist.get_world_size(group=_LOCAL_PROCESS_GROUP) 156 | 157 | 158 | def get_local_rank(): 159 | """ 160 | Returns: 161 | The rank of the current process within the local (per-machine) process group. 162 | """ 163 | if not dist.is_available(): 164 | return 0 165 | if not dist.is_initialized(): 166 | return 0 167 | assert _LOCAL_PROCESS_GROUP is not None 168 | return dist.get_rank(group=_LOCAL_PROCESS_GROUP) 169 | -------------------------------------------------------------------------------- /src/utils/file_io.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """ 4 | Project specific pathmanagers for a project as recommended by Detectron2 5 | """ 6 | from iopath.common.file_io import PathManager as PathManagerBase 7 | from iopath.common.file_io import HTTPURLHandler 8 | 9 | 10 | PathManager = PathManagerBase() 11 | PathManager.register_handler(HTTPURLHandler()) 12 | -------------------------------------------------------------------------------- /src/utils/io_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | a bunch of helper functions for read and write data 4 | """ 5 | import os 6 | import json 7 | import numpy as np 8 | import time 9 | import pandas as pd 10 | 11 | from typing import List, Union 12 | from PIL import Image, ImageFile 13 | Image.MAX_IMAGE_PIXELS = None 14 | 15 | 16 | def save_or_append_df(out_path, df): 17 | if os.path.exists(out_path): 18 | previous_df = pd.read_pickle(out_path) 19 | df = pd.concat([previous_df, df], ignore_index=True) 20 | df.to_pickle(out_path) 21 | print(f"Saved output at {out_path}") 22 | 23 | 24 | class JSONEncoder(json.JSONEncoder): 25 | def default(self, obj): 26 | if isinstance(obj, np.ndarray): 27 | return obj.tolist() 28 | elif isinstance(obj, bytes): 29 | return str(obj, encoding='utf-8') 30 | elif isinstance(obj, np.integer): 31 | return int(obj) 32 | elif 
isinstance(obj, np.floating): 33 | return float(obj) 34 | elif isinstance(obj, np.ndarray): 35 | return obj.tolist() 36 | else: 37 | # return super(MyEncoder, self).default(obj) 38 | 39 | raise TypeError( 40 | "Unserializable object {} of type {}".format(obj, type(obj)) 41 | ) 42 | 43 | 44 | def write_json(data: Union[list, dict], outfile: str) -> None: 45 | json_dir, _ = os.path.split(outfile) 46 | if json_dir and not os.path.exists(json_dir): 47 | os.makedirs(json_dir) 48 | 49 | with open(outfile, 'w') as f: 50 | json.dump(data, f, cls=JSONEncoder, ensure_ascii=False, indent=2) 51 | 52 | 53 | def read_json(filename: str) -> Union[list, dict]: 54 | """read json files""" 55 | with open(filename, "rb") as fin: 56 | data = json.load(fin, encoding="utf-8") 57 | return data 58 | 59 | 60 | def pil_loader(path: str) -> Image.Image: 61 | """load an image from path, and suppress warning""" 62 | # to avoid crashing for truncated (corrupted images) 63 | ImageFile.LOAD_TRUNCATED_IMAGES = True 64 | # open path as file to avoid ResourceWarning 65 | # (https://github.com/python-pillow/Pillow/issues/835) 66 | with open(path, 'rb') as f: 67 | img = Image.open(f) 68 | return img.convert('RGB') 69 | -------------------------------------------------------------------------------- /src/utils/logging.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """Logging.""" 4 | 5 | import builtins 6 | import decimal 7 | import functools 8 | import logging 9 | import simplejson 10 | import sys 11 | import os 12 | from termcolor import colored 13 | 14 | from .distributed import is_master_process 15 | from .file_io import PathManager 16 | 17 | # Show filename and line number in logs 18 | _FORMAT = "[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s" 19 | 20 | 21 | def _suppress_print(): 22 | """Suppresses printing from the current process.""" 23 | 24 | def print_pass(*objects, sep=" ", end="\n", file=sys.stdout, flush=False): 25 | pass 26 | 27 | builtins.print = print_pass 28 | 29 | 30 | # cache the opened file object, so that different calls to `setup_logger` 31 | # with the same file name can safely write to the same file. 32 | @functools.lru_cache(maxsize=None) 33 | def _cached_log_stream(filename): 34 | return PathManager.open(filename, "a") 35 | 36 | 37 | @functools.lru_cache() # so that calling setup_logger multiple times won't add many handlers # noqa 38 | def setup_logging( 39 | num_gpu, num_shards, output="", name="visual_prompt", color=True): 40 | """Sets up the logging.""" 41 | # Enable logging only for the master process 42 | if is_master_process(num_gpu): 43 | # Clear the root logger to prevent any existing logging config 44 | # (e.g. 
set by another module) from messing with our setup 45 | logging.root.handlers = [] 46 | # Configure logging 47 | logging.basicConfig( 48 | level=logging.INFO, format=_FORMAT, stream=sys.stdout 49 | ) 50 | else: 51 | _suppress_print() 52 | 53 | if name is None: 54 | name = __name__ 55 | logger = logging.getLogger(name) 56 | # remove any lingering handler 57 | logger.handlers.clear() 58 | 59 | logger.setLevel(logging.INFO) 60 | logger.propagate = False 61 | 62 | plain_formatter = logging.Formatter( 63 | "[%(asctime)s][%(levelname)s] %(name)s: %(lineno)4d: %(message)s", 64 | datefmt="%m/%d %H:%M:%S", 65 | ) 66 | if color: 67 | formatter = _ColorfulFormatter( 68 | colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s", 69 | datefmt="%m/%d %H:%M:%S", 70 | root_name=name, 71 | abbrev_name=str(name), 72 | ) 73 | else: 74 | formatter = plain_formatter 75 | 76 | if is_master_process(num_gpu): 77 | ch = logging.StreamHandler(stream=sys.stdout) 78 | ch.setLevel(logging.DEBUG) 79 | ch.setFormatter(formatter) 80 | logger.addHandler(ch) 81 | 82 | if is_master_process(num_gpu * num_shards): 83 | if len(output) > 0: 84 | if output.endswith(".txt") or output.endswith(".log"): 85 | filename = output 86 | else: 87 | filename = os.path.join(output, "logs.txt") 88 | 89 | PathManager.mkdirs(os.path.dirname(filename)) 90 | 91 | fh = logging.StreamHandler(_cached_log_stream(filename)) 92 | fh.setLevel(logging.DEBUG) 93 | fh.setFormatter(plain_formatter) 94 | logger.addHandler(fh) 95 | return logger 96 | 97 | 98 | def setup_single_logging(name, output=""): 99 | """Sets up the logging.""" 100 | # Enable logging only for the master process 101 | # Clear the root logger to prevent any existing logging config 102 | # (e.g. set by another module) from messing with our setup 103 | logging.root.handlers = [] 104 | # Configure logging 105 | logging.basicConfig( 106 | level=logging.INFO, format=_FORMAT, stream=sys.stdout 107 | ) 108 | 109 | if len(name) == 0: 110 | name = __name__ 111 | logger = logging.getLogger(name) 112 | logger.setLevel(logging.INFO) 113 | logger.propagate = False 114 | 115 | plain_formatter = logging.Formatter( 116 | "[%(asctime)s][%(levelname)s] %(name)s: %(lineno)4d: %(message)s", 117 | datefmt="%m/%d %H:%M:%S", 118 | ) 119 | formatter = _ColorfulFormatter( 120 | colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s", 121 | datefmt="%m/%d %H:%M:%S", 122 | root_name=name, 123 | abbrev_name=str(name), 124 | ) 125 | 126 | ch = logging.StreamHandler(stream=sys.stdout) 127 | ch.setLevel(logging.DEBUG) 128 | ch.setFormatter(formatter) 129 | logger.addHandler(ch) 130 | 131 | if len(output) > 0: 132 | if output.endswith(".txt") or output.endswith(".log"): 133 | filename = output 134 | else: 135 | filename = os.path.join(output, "logs.txt") 136 | 137 | PathManager.mkdirs(os.path.dirname(filename)) 138 | 139 | fh = logging.StreamHandler(_cached_log_stream(filename)) 140 | fh.setLevel(logging.DEBUG) 141 | fh.setFormatter(plain_formatter) 142 | logger.addHandler(fh) 143 | 144 | return logger 145 | 146 | 147 | def get_logger(name): 148 | """Retrieves the logger.""" 149 | return logging.getLogger(name) 150 | 151 | 152 | def log_json_stats(stats, sort_keys=True): 153 | """Logs json stats.""" 154 | # It seems that in Python >= 3.6 json.encoder.FLOAT_REPR has no effect 155 | # Use decimal+string as a workaround for having fixed length values in logs 156 | logger = get_logger(__name__) 157 | stats = { 158 | k: decimal.Decimal("{:.6f}".format(v)) if isinstance(v, float) else v 159 | for k, v in 
stats.items() 160 | } 161 | json_stats = simplejson.dumps(stats, sort_keys=True, use_decimal=True) 162 | if stats["_type"] == "test_epoch" or stats["_type"] == "train_epoch": 163 | logger.info("json_stats: {:s}".format(json_stats)) 164 | else: 165 | logger.info("{:s}".format(json_stats)) 166 | 167 | 168 | class _ColorfulFormatter(logging.Formatter): 169 | # from detectron2 170 | def __init__(self, *args, **kwargs): 171 | self._root_name = kwargs.pop("root_name") + "." 172 | self._abbrev_name = kwargs.pop("abbrev_name", "") 173 | if len(self._abbrev_name): 174 | self._abbrev_name = self._abbrev_name + "." 175 | super(_ColorfulFormatter, self).__init__(*args, **kwargs) 176 | 177 | def formatMessage(self, record: logging.LogRecord) -> str: 178 | record.name = record.name.replace(self._root_name, self._abbrev_name) 179 | log = super(_ColorfulFormatter, self).formatMessage(record) 180 | if record.levelno == logging.WARNING: 181 | prefix = colored("WARNING", "red", attrs=["blink"]) 182 | elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL: 183 | prefix = colored("ERROR", "red", attrs=["blink", "underline"]) 184 | else: 185 | return log 186 | return prefix + " " + log 187 | -------------------------------------------------------------------------------- /src/utils/train_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import torch 3 | 4 | 5 | def gpu_mem_usage(): 6 | """Computes the GPU memory usage for the current device (GB).""" 7 | if not torch.cuda.is_available(): 8 | return 0 9 | # Number of bytes in a megabyte 10 | _B_IN_GB = 1024 * 1024 * 1024 11 | 12 | mem_usage_bytes = torch.cuda.max_memory_allocated() 13 | return mem_usage_bytes / _B_IN_GB 14 | 15 | 16 | class AverageMeter(object): 17 | """Computes and stores the average and current value""" 18 | def __init__(self, name, fmt=':f'): 19 | self.name = name 20 | self.fmt = fmt 21 | self.reset() 22 | 23 | def reset(self): 24 | self.val = 0 25 | self.avg = 0 26 | self.sum = 0 27 | self.count = 0 28 | 29 | def update(self, val, n=1): 30 | self.val = val 31 | self.sum += val * n 32 | self.count += n 33 | self.avg = self.sum / self.count 34 | 35 | def __str__(self): 36 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 37 | return fmtstr.format(**self.__dict__) 38 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | major actions here: fine-tune the features and evaluate different settings 4 | """ 5 | import os 6 | import torch 7 | import warnings 8 | 9 | import numpy as np 10 | import random 11 | 12 | from time import sleep 13 | from random import randint 14 | 15 | import src.utils.logging as logging 16 | from src.configs.config import get_cfg 17 | from src.data import loader as data_loader 18 | from src.engine.evaluator import Evaluator 19 | from src.engine.trainer import Trainer 20 | from src.models.build_model import build_model 21 | from src.utils.file_io import PathManager 22 | 23 | from launch import default_argument_parser, logging_train_setup 24 | warnings.filterwarnings("ignore") 25 | torch.set_num_threads(4) 26 | 27 | 28 | def setup(args): 29 | """ 30 | Create configs and perform basic setups. 
31 | """ 32 | cfg = get_cfg() 33 | cfg.merge_from_file(args.config_file) 34 | cfg.merge_from_list(args.opts) 35 | 36 | # setup dist 37 | # cfg.DIST_INIT_PATH = "tcp://{}:12399".format(os.environ["SLURMD_NODENAME"]) 38 | 39 | # setup output dir 40 | # output_dir / data_name / feature_name / lr_wd / run1 41 | output_dir = cfg.OUTPUT_DIR 42 | lr = cfg.SOLVER.BASE_LR 43 | wd = cfg.SOLVER.WEIGHT_DECAY 44 | output_folder = os.path.join( 45 | cfg.DATA.NAME, cfg.DATA.FEATURE, f"{args.id}_lr{lr}_wd{wd}") 46 | 47 | # train cfg.RUN_N_TIMES times 48 | count = 1 49 | while count <= cfg.RUN_N_TIMES: 50 | output_path = os.path.join(output_dir, output_folder, f"run{count}") 51 | # pause for a random time, so concurrent process with same setting won't interfere with each other. # noqa 52 | sleep(randint(3, 30)) 53 | if not PathManager.exists(output_path): 54 | PathManager.mkdirs(output_path) 55 | cfg.OUTPUT_DIR = output_path 56 | break 57 | else: 58 | count += 1 59 | if count > cfg.RUN_N_TIMES: 60 | raise ValueError( 61 | f"Already run {cfg.RUN_N_TIMES} times for {output_folder}, no need to run more") 62 | 63 | cfg.freeze() 64 | return cfg 65 | 66 | 67 | def get_loaders(cfg, logger): 68 | logger.info("Loading training data (final training data for vtab)...") 69 | if cfg.DATA.NAME.startswith("vtab-"): 70 | train_loader = data_loader.construct_trainval_loader(cfg) 71 | else: 72 | train_loader = data_loader.construct_train_loader(cfg) 73 | 74 | logger.info("Loading validation data...") 75 | # not really needed for vtab 76 | val_loader = data_loader.construct_val_loader(cfg) 77 | logger.info("Loading test data...") 78 | if cfg.DATA.NO_TEST: 79 | logger.info("...no test data is constructed") 80 | test_loader = None 81 | else: 82 | test_loader = data_loader.construct_test_loader(cfg) 83 | return train_loader, val_loader, test_loader 84 | 85 | 86 | def train(cfg, args): 87 | # clear up residual cache from previous runs 88 | if torch.cuda.is_available(): 89 | torch.cuda.empty_cache() 90 | 91 | # main training / eval actions here 92 | 93 | # fix the seed for reproducibility 94 | if cfg.SEED is not None: 95 | torch.manual_seed(cfg.SEED) 96 | np.random.seed(cfg.SEED) 97 | random.seed(0) 98 | 99 | # setup training env including loggers 100 | logging_train_setup(args, cfg) 101 | logger = logging.get_logger("visual_prompt") 102 | 103 | train_loader, val_loader, test_loader = get_loaders(cfg, logger) 104 | logger.info("Constructing models...") 105 | model, cur_device = build_model(cfg) 106 | 107 | trainable_params = [name for name, p in model.named_parameters() if p.requires_grad] 108 | print(trainable_params) 109 | 110 | logger.info("Setting up Evalutator...") 111 | evaluator = Evaluator() 112 | logger.info("Setting up Trainer...") 113 | trainer = Trainer(cfg, model, evaluator, cur_device) 114 | 115 | if train_loader: 116 | trainer.train_classifier(train_loader, val_loader, test_loader) 117 | else: 118 | print("No train loader presented. Exit") 119 | 120 | if cfg.SOLVER.TOTAL_EPOCH == 0: 121 | trainer.eval_classifier(test_loader, "test", 0) 122 | 123 | 124 | def main(args): 125 | """main function to call from workflow""" 126 | # set up cfg and args 127 | cfg = setup(args) 128 | with open(os.path.join(cfg.OUTPUT_DIR, 'configs.yaml'), 'w') as f: 129 | f.write(cfg.dump()) 130 | # Perform training. 131 | train(cfg, args) 132 | 133 | 134 | if __name__ == '__main__': 135 | args = default_argument_parser().parse_args() 136 | main(args) 137 | --------------------------------------------------------------------------------
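The prompted backbones above (`src/models/vit_prompt/vit_mae.py`, `src/models/vit_prompt/vit_moco.py`) re-insert the prompt tokens after every transformer block as a convex combination of the block's input and output prompt representations, weighted by a per-block gate `sigmoid(gate_logit[i])`. The sketch below isolates that update rule; the helper name `gated_prompt_update` and the toy tensor shapes are illustrative assumptions, not code from the repository.

```python
import torch

def gated_prompt_update(x_in, x_out, gate_logit, num_tokens):
    # x_in : block input,  shape (B, 1 + num_tokens + n_patches, D) = [CLS, prompts, patches]
    # x_out: block output, same layout
    # gate_logit: learnable scalar logit for this block
    gate = gate_logit.sigmoid()                      # gate in (0, 1)
    prompt_in = x_in[:, 1:1 + num_tokens, :]         # prompts fed into the block
    prompt_out = x_out[:, 1:1 + num_tokens, :]       # prompts produced by the block
    # convex combination of input/output prompt representations via the learnable gate
    prompts = gate * prompt_out + (1.0 - gate) * prompt_in
    # CLS and patch tokens are taken from the block output unchanged
    return torch.cat([x_out[:, :1, :], prompts, x_out[:, 1 + num_tokens:, :]], dim=1)

# toy shapes: batch 2, 4 prompt tokens, 16 patches, hidden size 8
B, P, N, D = 2, 4, 16, 8
x_in, x_out = torch.randn(B, 1 + P + N, D), torch.randn(B, 1 + P + N, D)
gate_logit = torch.nn.Parameter(torch.tensor(-5.0))  # -GATE_INIT, mirroring gate_logit = -ones * GATE_INIT
y = gated_prompt_update(x_in, x_out, gate_logit, P)
assert y.shape == x_out.shape
```

Because the logit is initialised at `-GATE_INIT`, a positive `GATE_INIT` starts the gate near zero, so each block initially passes its input prompts through almost unchanged and the contribution of the block output is learned during training.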
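`SigmoidLoss` in `src/solver/losses.py` first expands integer labels into a multi-hot (here effectively one-hot) target matrix, computes unreduced binary cross-entropy with logits, weights each class column by `per_cls_weights`, and averages over the batch. A small worked example with made-up numbers (three samples, four classes; the weights are illustrative):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(3, 4)                       # (batch_size, num_classes)
labels = torch.tensor([0, 2, 1])                 # integer class labels
per_cls_weights = [1.0, 2.0, 1.0, 0.5]           # illustrative per-class weights

# multi_hot: scatter the integer labels into a (batch_size, num_classes) 0/1 target
target = torch.zeros(3, 4).scatter_(1, labels.unsqueeze(1), 1.0)

loss = F.binary_cross_entropy_with_logits(logits, target, reduction="none")
loss = loss * torch.tensor(per_cls_weights).unsqueeze(0)   # weight every class column
loss = loss.sum() / labels.shape[0]                        # sum over classes, mean over batch
```

`SoftmaxLoss` reuses the same batch averaging but calls `F.cross_entropy` directly on the integer labels, and it is the only entry registered in the `LOSS` dict consumed by `build_loss`.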
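In `src/solver/lr_scheduler.py`, `make_scheduler` builds the schedulers from epoch counts, and `WarmupCosineSchedule` scales the base learning rate linearly from 0 to 1 over `warmup_steps` and then decays it with `0.5 * (1 + cos(pi * 2 * cycles * progress))`. A minimal sketch of the resulting curve, assuming a plain SGD optimizer and hypothetical epoch counts (not values taken from the configs):

```python
import math
import torch
from torch.optim.lr_scheduler import LambdaLR

warmup_steps, t_total, cycles = 10, 100, 0.5      # hypothetical WARMUP_EPOCH / TOTAL_EPOCH

def lr_lambda(step):
    # linear warmup from 0 to 1, then cosine decay to 0 (cycles=0.5 is half a cosine period)
    if step < warmup_steps:
        return float(step) / float(max(1.0, warmup_steps))
    progress = float(step - warmup_steps) / float(max(1, t_total - warmup_steps))
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * cycles * 2.0 * progress)))

optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
scheduler = LambdaLR(optimizer, lr_lambda)
for epoch in range(t_total):
    # ... one training epoch ...
    optimizer.step()
    scheduler.step()   # stepped once per epoch, matching the epoch-based warmup/total counts
```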