├── .gitignore ├── LICENSE ├── README.md ├── collect_demo_exp_results.py ├── configs ├── base-finetune.yaml ├── base-linear.yaml ├── base-prompt.yaml ├── finetune │ ├── cars.yaml │ ├── cub.yaml │ ├── dogs.yaml │ ├── flowers.yaml │ └── nabirds.yaml ├── h2t-prompt │ └── vtab.yaml ├── linear │ ├── cars.yaml │ ├── cub.yaml │ ├── dogs.yaml │ ├── flowers.yaml │ └── nabirds.yaml └── prompt │ ├── cars.yaml │ ├── cub.yaml │ ├── dogs.yaml │ ├── flowers.yaml │ └── nabirds.yaml ├── head2toe_sparsity_train.py ├── head2toe_train.py ├── launch.py ├── pre-trained_weights └── .gitignore ├── run_demo_exp.sh ├── scripts └── VQT │ ├── run_vqt_vtab.sh │ ├── run_vqt_vtab_sparsity.sh │ └── run_vqt_vtab_ssl.sh ├── src ├── configs │ ├── config.py │ ├── config_node.py │ └── vit_configs.py ├── data │ ├── datasets │ │ ├── json_dataset.py │ │ └── tf_dataset.py │ ├── loader.py │ ├── transforms.py │ └── vtab_datasets │ │ ├── __init__.py │ │ ├── base.py │ │ ├── caltech.py │ │ ├── cifar.py │ │ ├── clevr.py │ │ ├── diabetic_retinopathy.py │ │ ├── dmlab.py │ │ ├── dsprites.py │ │ ├── dtd.py │ │ ├── eurosat.py │ │ ├── kitti.py │ │ ├── oxford_flowers102.py │ │ ├── oxford_iiit_pet.py │ │ ├── patch_camelyon.py │ │ ├── registry.py │ │ ├── resisc45.py │ │ ├── smallnorb.py │ │ ├── sun397.py │ │ └── svhn.py ├── engine │ ├── eval │ │ ├── multilabel.py │ │ └── singlelabel.py │ ├── evaluator.py │ ├── h2t_sparsity_trainer.py │ └── trainer.py ├── models │ ├── build_h2t_model.py │ ├── build_model.py │ ├── build_vit_backbone.py │ ├── convnext.py │ ├── convnext_backbone │ │ └── convnext.py │ ├── mlp.py │ ├── resnet.py │ ├── vit_adapter │ │ ├── adapter_block.py │ │ ├── vit.py │ │ ├── vit_mae.py │ │ └── vit_moco.py │ ├── vit_backbones │ │ ├── h2t_vit.py │ │ ├── h2t_vit_mae.py │ │ ├── swin_transformer.py │ │ ├── timm_h2t_vit.py │ │ ├── vit.py │ │ ├── vit_mae.py │ │ └── vit_moco.py │ ├── vit_models.py │ └── vit_prompt │ │ ├── swin_transformer.py │ │ ├── vit.py │ │ ├── vit_ablations.py │ │ ├── vit_mae.py │ │ └── vit_moco.py ├── solver │ ├── losses.py │ ├── lr_scheduler.py │ └── optimizer.py └── utils │ ├── distributed.py │ ├── file_io.py │ ├── io_utils.py │ ├── logging.py │ ├── train_utils.py │ └── vis_utils.py ├── tune_vtab.py └── vtab_data └── .gitignore /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2023, Cheng-Hao Tu 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 
18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Visual Query Tuning (VQT) 2 | 3 | This is an official implementation of [Visual Query Tuning: Towards Effective Usage of Intermediate Representations for Parameter and Memory Efficient Transfer Learning](https://arxiv.org/pdf/2212.03220.pdf). 4 | 5 | 6 | ## Dependencies 7 | 8 | * python3.7 9 | * torch==1.7.1 10 | * torchvision==0.8.2 11 | * tensorflow==2.9.1 12 | * tensorflow_datasets==4.4.0+nightly 13 | 14 | ## Usage 15 | 16 | We present instructions on training VQT with an ImageNet-1k pre-trained ViT-B/16. 17 | 18 | ### Preparing the data 19 | 20 | Please set up the VTAB-1k benchmark following the instructions [here](https://github.com/KMnP/vpt/blob/main/VTAB_SETUP.md). By default, our scripts access the VTAB-1k datasets from the `vtab_data/` folder. If you download the datasets to another location, modify the `DATA_PATH` variable in our scripts under the `scripts/` folder. 21 | 22 | For pre-trained ViT-B/16 models, you can download the weights of various pre-training setups as follows: 23 | * [ImageNet-1k Supervised](https://drive.google.com/file/d/1ruqA7gRkvkoM_QFrT0JI23XXZNnwqXBK/view?usp=share_link) 24 | * [ImageNet-21k Supervised](https://storage.googleapis.com/vit_models/imagenet21k/ViT-B_16.npz) 25 | * [ImageNet-1k MAE](https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth) 26 | 27 | Please place the downloaded checkpoints under the `pre-trained_weights/` folder. Note that you need to rename the ImageNet-21k supervised checkpoint from `ViT-B_16.npz` to `imagenet21k_ViT-B_16.npz`. 28 | 29 | 30 | ### Training VQT 31 | 32 | Use the following command to train a VQT model on a dataset in VTAB-1k. 33 | 34 | ```bash 35 | $ bash scripts/VQT/run_vqt_vtab.sh ${GPUIDX} ${DATA_NAME} ${NUM_CLASSES} ${Q_LEN} ${OPTIMIZER} ${FEATURE} 36 | ``` 37 | 38 | We describe the meaning of these arguments as follows: 39 | * `${GPUIDX}`: The GPU used for training. For example, it can be set to 0. 40 | * `${DATA_NAME}`: The dataset name in VTAB-1k for training and evaluation. For example, it can be set to `vtab-caltech101`. Please see `run_demo_exp.sh` for more details about the 19 datasets in VTAB-1k. 41 | * `${NUM_CLASSES}`: The number of classes in the dataset. For example, for `vtab-caltech101`, this should be set to 102. 42 | * `${Q_LEN}`: The length of the query tokens. This can simply be set to 1. 43 | * `${OPTIMIZER}`: The optimizer used for training. In our experiments, we set this to `adam`. 44 | * `${FEATURE}`: The name of the pre-trained features. 
For example, it can be set to `sup_vitb16_imagenet1k` to indicate the ImageNet-1k supervised pre-trained model. 45 | 46 | After training a VQT model, you can optionally use the following command to compress the linear classifier via feature selection. 47 | 48 | ```bash 49 | $ bash scripts/VQT/run_vqt_vtab_sparsity.sh ${GPUIDX} ${DATA_NAME} ${NUM_CLASSES} ${Q_LEN} ${OPTIMIZER} ${FEATURE} ${FRACTION} 50 | ``` 51 | 52 | The first 6 arguments, `${GPUIDX}`, `${DATA_NAME}`, `${NUM_CLASSES}`, `${Q_LEN}`, `${OPTIMIZER}`, and `${FEATURE}`, are the same as the previous command for training a VQT model, and they can be set accordingly to indicate the trained VQT model we are going to compress. The last argument `${FRACTION}` specifies the proportion of the pre-classifier features (penultimate layer features) that we want to keep after compression. For example, it can be set to 0.7 to indicate keeping 70% of the features input to the final linear classifier. 53 | 54 | 55 | ### Demo experiment 56 | 57 | For simplicity, you can use the following command for running through all the 19 datasets in VTAB-1k. 58 | 59 | ```bash 60 | $ bash run_demo_exp.sh ${GPUIDX} 61 | ``` 62 | 63 | The `${GPUIDX}` argument specifies the GPU used for training (e.g., 0). 64 | 65 | After training VQT models for all the 19 datasets, you can use the following command to collect the results. 66 | 67 | ```bash 68 | $ python collect_demo_exp_results.py 69 | ``` 70 | 71 | 72 | ## Reference 73 | 74 | This repo is modified from [Visual Prompt Tuning (VPT)](https://github.com/KMnP/vpt). 75 | 76 | ## Contact 77 | 78 | If you have any questions, please contact [Cheng-Hao Tu](https://andytu28.github.io/)(tu.343@osu.edu). 79 | -------------------------------------------------------------------------------- /collect_demo_exp_results.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import torch 4 | import numpy as np 5 | 6 | 7 | setting_list = ['VQTSup_1_adam_sparsefinal', 'VQTSup_10_adam_final', 'VQTSup_20_adam_final'] 8 | run_idx_end = 5 9 | arch_list = ['sup_vitb16_imagenet1k'] 10 | data_names = ['vtab-caltech101', 'vtab-dtd', 'vtab-cifar(num_classes=100)', 'vtab-oxford_flowers102', 'vtab-oxford_iiit_pet', 11 | 'vtab-svhn', 'vtab-patch_camelyon', 'vtab-resisc45', 'vtab-eurosat', 'vtab-diabetic_retinopathy(config="btgraham-300")', 12 | 'vtab-dmlab', 'vtab-clevr(task="closest_object_distance")', 'vtab-clevr(task="count_all")', 13 | 'vtab-dsprites(predicted_attribute="label_orientation",num_classes=16)', 'vtab-dsprites(predicted_attribute="label_x_position",num_classes=16)', 14 | 'vtab-smallnorb(predicted_attribute="label_azimuth")', 'vtab-smallnorb(predicted_attribute="label_elevation")', 15 | 'vtab-kitti(task="closest_vehicle_distance")', 'vtab-sun397'] 16 | 17 | for setting_idx, setting in enumerate(setting_list): 18 | print(f'====== {setting} ======') 19 | for data_idx, data_name in enumerate(data_names): 20 | for arch in arch_list: 21 | directory_name = os.path.join( 22 | 'h2t_vit_experiments', setting, data_name, arch) 23 | if not os.path.exists(directory_name): 24 | continue 25 | hyparam = os.listdir(directory_name)[0] 26 | hyparam_dir = os.path.join(directory_name, hyparam) 27 | if 'keep_frac_0.7' in os.listdir(hyparam_dir): 28 | prefix = 'keep_frac_0.7/' 29 | else: 30 | prefix = '' 31 | 32 | results = [] 33 | for run_idx in range(1, run_idx_end+1): 34 | path = os.path.join( 35 | hyparam_dir, f'{prefix}run{run_idx}/eval_results.pth') 36 | if not 
os.path.exists(path): 37 | continue 38 | 39 | content = torch.load(path, map_location='cpu') 40 | results.append(content['epoch_99']['classification'][f'test_{data_name}']['top1']*100.) 41 | # print(f'--- {data_name} - {arch} - {hyparam}') 42 | # print('{:.2f}'.format( 43 | # content['epoch_99']['classification'][f'test_{data_name}']['top1']*100.)) 44 | 45 | print(f'--- {data_name} - {arch} - {hyparam}') 46 | print('=====>', len(results)) 47 | print('{:.2f}'.format(np.mean(results))) 48 | print() 49 | -------------------------------------------------------------------------------- /configs/base-finetune.yaml: -------------------------------------------------------------------------------- 1 | NUM_GPUS: 1 2 | NUM_SHARDS: 1 3 | OUTPUT_DIR: "" 4 | RUN_N_TIMES: 1 5 | MODEL: 6 | TRANSFER_TYPE: "end2end" 7 | TYPE: "vit" 8 | LINEAR: 9 | MLP_SIZES: [] 10 | SOLVER: 11 | SCHEDULER: "cosine" 12 | WARMUP_EPOCH: 5 13 | LOSS: "softmax" 14 | OPTIMIZER: "adamw" 15 | MOMENTUM: 0.9 16 | WEIGHT_DECAY: 0.0001 17 | LOG_EVERY_N: 100 18 | TOTAL_EPOCH: 100 19 | PATIENCE: 300 20 | DATA: 21 | NAME: "" 22 | NUMBER_CLASSES: -1 23 | DATAPATH: "" 24 | FEATURE: "sup_vitb16_224" 25 | BATCH_SIZE: 384 26 | NUM_WORKERS: 8 -------------------------------------------------------------------------------- /configs/base-linear.yaml: -------------------------------------------------------------------------------- 1 | NUM_GPUS: 1 2 | NUM_SHARDS: 1 3 | OUTPUT_DIR: "" 4 | RUN_N_TIMES: 1 5 | MODEL: 6 | TRANSFER_TYPE: "linear" 7 | TYPE: "vit" 8 | LINEAR: 9 | MLP_SIZES: [] 10 | SOLVER: 11 | SCHEDULER: "cosine" 12 | PATIENCE: 300 13 | LOSS: "softmax" 14 | OPTIMIZER: "sgd" 15 | MOMENTUM: 0.9 16 | WEIGHT_DECAY: 0.0001 17 | LOG_EVERY_N: 1 18 | WARMUP_EPOCH: 10 19 | TOTAL_EPOCH: 100 20 | DATA: 21 | NAME: "" 22 | NUMBER_CLASSES: -1 23 | DATAPATH: "" 24 | FEATURE: "sup_vitb16_224" 25 | BATCH_SIZE: 1024 -------------------------------------------------------------------------------- /configs/base-prompt.yaml: -------------------------------------------------------------------------------- 1 | NUM_GPUS: 1 2 | NUM_SHARDS: 1 3 | OUTPUT_DIR: "" 4 | RUN_N_TIMES: 1 5 | MODEL: 6 | TRANSFER_TYPE: "prompt" 7 | TYPE: "vit" 8 | LINEAR: 9 | MLP_SIZES: [] 10 | SOLVER: 11 | SCHEDULER: "cosine" 12 | PATIENCE: 300 13 | LOSS: "softmax" 14 | OPTIMIZER: "sgd" 15 | MOMENTUM: 0.9 16 | WEIGHT_DECAY: 0.0001 17 | LOG_EVERY_N: 100 18 | WARMUP_EPOCH: 10 19 | TOTAL_EPOCH: 100 20 | DATA: 21 | NAME: "" 22 | NUMBER_CLASSES: -1 23 | DATAPATH: "" 24 | FEATURE: "sup_vitb16_224" 25 | BATCH_SIZE: 128 26 | -------------------------------------------------------------------------------- /configs/finetune/cars.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-finetune.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "StanfordCars" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 196 7 | MULTILABEL: False 8 | FEATURE: "imagenet_supervised" # need to tune 9 | MODEL: 10 | TYPE: "vit" 11 | SOLVER: 12 | BASE_LR: 0.0375 13 | WEIGHT_DECAY: 0.001 14 | -------------------------------------------------------------------------------- /configs/finetune/cub.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-finetune.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "CUB" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 200 7 | MULTILABEL: False 8 | FEATURE: "imagenet_supervised" # need to tune 9 | MODEL: 10 | TYPE: "vit" 11 | SOLVER: 12 | BASE_LR: 0.00375 13 
| WEIGHT_DECAY: 0.01 14 | -------------------------------------------------------------------------------- /configs/finetune/dogs.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-finetune.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "StanfordDogs" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 120 7 | MULTILABEL: False 8 | FEATURE: "imagenet_supervised" # need to tune 9 | MODEL: 10 | TYPE: "vit" 11 | SOLVER: 12 | BASE_LR: 0.00375 13 | WEIGHT_DECAY: 0.001 14 | -------------------------------------------------------------------------------- /configs/finetune/flowers.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-finetune.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "OxfordFlowers" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 102 7 | MULTILABEL: False 8 | FEATURE: "imagenet_supervised" # need to tune 9 | MODEL: 10 | TYPE: "vit" 11 | SOLVER: 12 | BASE_LR: 0.001 13 | WEIGHT_DECAY: 0.0001 14 | -------------------------------------------------------------------------------- /configs/finetune/nabirds.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-finetune.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "nabirds" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 555 7 | MULTILABEL: False 8 | MODEL: 9 | TYPE: "vit" 10 | SOLVER: 11 | BASE_LR: 0.00375 12 | WEIGHT_DECAY: 0.01 13 | -------------------------------------------------------------------------------- /configs/h2t-prompt/vtab.yaml: -------------------------------------------------------------------------------- 1 | NUM_GPUS: 1 2 | NUM_SHARDS: 1 3 | OUTPUT_DIR: "" 4 | RUN_N_TIMES: 1 5 | MODEL: 6 | TRANSFER_TYPE: "h2t-prompt" 7 | TYPE: "h2t-vit" 8 | LINEAR: 9 | MLP_SIZES: [] 10 | SOLVER: 11 | SCHEDULER: "cosine" 12 | PATIENCE: 300 13 | LOSS: "softmax" 14 | OPTIMIZER: "sgd" 15 | MOMENTUM: 0.9 16 | LOG_EVERY_N: 100 17 | WARMUP_EPOCH: 10 18 | TOTAL_EPOCH: 100 19 | BASE_LR: 0.1 20 | WEIGHT_DECAY: 0.0001 21 | DATA: 22 | NAME: "" 23 | NUMBER_CLASSES: -1 24 | DATAPATH: "" 25 | FEATURE: "sup_vitb16_224" 26 | BATCH_SIZE: 128 27 | -------------------------------------------------------------------------------- /configs/linear/cars.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-linear.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "StanfordCars" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 196 7 | MULTILABEL: False 8 | MODEL: 9 | TYPE: "vit" 10 | SOLVER: 11 | BASE_LR: 0.001 12 | WEIGHT_DECAY: 0.0001 13 | -------------------------------------------------------------------------------- /configs/linear/cub.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-linear.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "CUB" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 200 7 | MULTILABEL: False 8 | MODEL: 9 | TYPE: "vit" 10 | SOLVER: 11 | BASE_LR: 0.1 12 | WEIGHT_DECAY: 0.01 -------------------------------------------------------------------------------- /configs/linear/dogs.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-linear.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "StanfordDogs" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 120 7 | MULTILABEL: False 8 | MODEL: 9 | TYPE: "vit" 10 | SOLVER: 11 | BASE_LR: 
0.001 12 | WEIGHT_DECAY: 0.0001 -------------------------------------------------------------------------------- /configs/linear/flowers.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-linear.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "OxfordFlowers" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 102 7 | MULTILABEL: False 8 | MODEL: 9 | TYPE: "vit" 10 | SOLVER: 11 | BASE_LR: 0.001 12 | WEIGHT_DECAY: 0.0001 -------------------------------------------------------------------------------- /configs/linear/nabirds.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-linear.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "nabirds" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 555 7 | MULTILABEL: False 8 | MODEL: 9 | TYPE: "vit" 10 | SOLVER: 11 | BASE_LR: 0.1 12 | WEIGHT_DECAY: 0.01 13 | LOG_EVERY_N: 10 -------------------------------------------------------------------------------- /configs/prompt/cars.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-prompt.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "StanfordCars" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 196 7 | MULTILABEL: False 8 | MODEL: 9 | TYPE: "vit" 10 | SOLVER: 11 | BASE_LR: 0.001 12 | WEIGHT_DECAY: 0.0001 13 | -------------------------------------------------------------------------------- /configs/prompt/cub.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-prompt.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "CUB" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 200 7 | MULTILABEL: False 8 | MODEL: 9 | TYPE: "vit" 10 | SOLVER: 11 | BASE_LR: 0.1 12 | WEIGHT_DECAY: 0.01 -------------------------------------------------------------------------------- /configs/prompt/dogs.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-prompt.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "StanfordDogs" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 120 7 | MULTILABEL: False 8 | MODEL: 9 | TYPE: "vit" 10 | SOLVER: 11 | BASE_LR: 0.001 12 | WEIGHT_DECAY: 0.0001 -------------------------------------------------------------------------------- /configs/prompt/flowers.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-prompt.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "OxfordFlowers" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 102 7 | MULTILABEL: False 8 | MODEL: 9 | TYPE: "vit" 10 | SOLVER: 11 | BASE_LR: 0.001 12 | WEIGHT_DECAY: 0.0001 -------------------------------------------------------------------------------- /configs/prompt/nabirds.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-prompt.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "nabirds" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 555 7 | MULTILABEL: False 8 | MODEL: 9 | TYPE: "vit" 10 | SOLVER: 11 | BASE_LR: 0.1 12 | WEIGHT_DECAY: 0.01 -------------------------------------------------------------------------------- /head2toe_train.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import random 3 | from random import randint 4 | from time import sleep 5 | import os 6 | import torch 7 | import numpy as np 8 | 
import src.utils.logging as logging 9 | from src.utils.file_io import PathManager 10 | from src.configs.config import get_cfg 11 | from src.models.build_h2t_model import build_head2toe_model 12 | from launch import default_argument_parser, logging_train_setup 13 | from tune_vtab import get_loaders 14 | from src.engine.evaluator import Evaluator 15 | from src.engine.trainer import Trainer 16 | 17 | 18 | def get_lrwd_range(args): 19 | 20 | if args.train_type == "h2t-prompt": 21 | if args.optimizer == 'sgd': 22 | lr_range = [ 23 | 5.0, 2.5, 1.0, 24 | 50.0, 25., 10.0, 25 | 0.5, 0.25, 0.1, 0.05 26 | ] 27 | wd_range = [0.01, 0.001, 0.0001, 0.0] 28 | else: # For adam basically 29 | if args.arch_feats == 'sup_vitl16_imagenet21k': 30 | lr_range = [ 31 | 0.5, 0.1, 0.05 32 | ] 33 | wd_range = [0.01, 0.0001, 0.0] 34 | elif args.arch_feats == 'sup_vith14_imagenet21k': 35 | lr_range = [ 36 | 0.5, 0.1, 0.05 37 | ] 38 | wd_range = [0.01, 0.0001, 0.0] 39 | else: 40 | lr_range = [ 41 | 1.0, 0.5, 0.25, 0.1, 0.05 42 | ] 43 | wd_range = [0.01, 0.001, 0.0001, 0.0] 44 | else: 45 | raise ValueError() 46 | 47 | return lr_range, wd_range 48 | 49 | 50 | def find_best_lrwd(files, data_name): 51 | t_name = "val_" + data_name 52 | best_lr = None 53 | best_wd = None 54 | best_val_acc = -1 55 | for f in files: 56 | try: 57 | results_dict = torch.load(f, "cpu") 58 | epoch = len(results_dict) - 1 59 | val_result = results_dict[f"epoch_{epoch}"]["classification"][t_name]["top1"] 60 | val_result = float(val_result) 61 | except Exception as e: 62 | print(f"Encounter issue: {e} for file {f}") 63 | continue 64 | 65 | if val_result == best_val_acc: 66 | frag_txt = f.split("/run")[0] 67 | cur_lr = float(frag_txt.split("/lr")[-1].split("_wd")[0]) 68 | cur_wd = float(frag_txt.split("_wd")[-1].split('_bs')[0]) 69 | if best_lr is not None and cur_lr < best_lr: 70 | # get the smallest lr to break tie for stability 71 | best_lr = cur_lr 72 | best_wd = cur_wd 73 | best_val_acc = val_result 74 | 75 | elif val_result > best_val_acc: 76 | best_val_acc = val_result 77 | frag_txt = f.split("/run")[0] 78 | best_lr = float(frag_txt.split("/lr")[-1].split("_wd")[0]) 79 | best_wd = float(frag_txt.split("_wd")[-1].split('_bs')[0]) 80 | return best_lr, best_wd 81 | 82 | 83 | def setup(args, lr, wd, final_runs, run_idx=None, seed=100): 84 | """ 85 | Create configs and perform basic setups. 86 | """ 87 | cfg = get_cfg() 88 | cfg.merge_from_file(args.config_file) 89 | cfg.merge_from_list(args.opts) 90 | cfg.SOLVER.DBG_TRAINABLE = True 91 | assert(args.optimizer == cfg.SOLVER.OPTIMIZER) 92 | 93 | cfg.SEED = seed 94 | 95 | if not final_runs: 96 | cfg.RUN_N_TIMES = 1 97 | cfg.MODEL.SAVE_CKPT = False 98 | cfg.OUTPUT_DIR = cfg.OUTPUT_DIR + '_val' 99 | lr = lr / 256 * cfg.DATA.BATCH_SIZE 100 | cfg.SOLVER.BASE_LR = lr 101 | cfg.SOLVER.WEIGHT_DECAY = wd 102 | else: 103 | cfg.RUN_N_TIMES = 5 # No use. 
Just individually try out 5 seeds 104 | cfg.MODEL.SAVE_CKPT = True 105 | files = glob.glob( 106 | f'{cfg.OUTPUT_DIR}_val/{cfg.DATA.NAME}/{cfg.DATA.FEATURE}/*/' 107 | + 'run1/eval_results.pth' 108 | ) 109 | cfg.OUTPUT_DIR = cfg.OUTPUT_DIR + '_final' 110 | lr, wd = find_best_lrwd(files, cfg.DATA.NAME) 111 | cfg.SOLVER.BASE_LR = lr 112 | cfg.SOLVER.WEIGHT_DECAY = wd 113 | 114 | # Setup the output dir 115 | output_dir = cfg.OUTPUT_DIR 116 | bs = cfg.DATA.BATCH_SIZE 117 | output_folder = os.path.join( 118 | cfg.DATA.NAME, cfg.DATA.FEATURE, f'lr{lr}_wd{wd}_bs{bs}' 119 | ) 120 | 121 | # Train cfg.RUN_N_TIMES times 122 | if run_idx is None: 123 | count = 1 124 | while count <= cfg.RUN_N_TIMES: 125 | output_path = os.path.join( 126 | output_dir, output_folder, f'run{count}') 127 | sleep(randint(1, 5)) 128 | if not PathManager.exists(output_path): 129 | PathManager.mkdirs(output_path) 130 | cfg.OUTPUT_DIR = output_path 131 | break 132 | else: 133 | count += 1 134 | if count > cfg.RUN_N_TIMES: 135 | raise ValueError( 136 | f'Already run {cfg.RUN_N_TIMES} times for {output_folder}.' 137 | ) 138 | 139 | else: 140 | output_path = os.path.join( 141 | output_dir, output_folder, f'run{run_idx}') 142 | if not PathManager.exists(output_path): 143 | PathManager.mkdirs(output_path) 144 | cfg.OUTPUT_DIR = output_path 145 | else: 146 | raise ValueError( 147 | f'Already run run-{run_idx} for {output_dir}.' 148 | ) 149 | 150 | cfg.freeze() 151 | return cfg 152 | 153 | 154 | def train(cfg, args, final_runs): 155 | 156 | if torch.cuda.is_available(): 157 | torch.cuda.empty_cache() 158 | 159 | if cfg.SEED is not None: 160 | torch.manual_seed(cfg.SEED) 161 | np.random.seed(cfg.SEED) 162 | random.seed(0) 163 | 164 | # Setup training env including loggers 165 | logging_train_setup(args, cfg) 166 | logger = logging.get_logger('visual_prompt') 167 | 168 | # Setup data loaders 169 | train_loader, val_loader, test_loader = get_loaders( 170 | cfg, logger, final_runs) 171 | logger.info('Constructing models ...') 172 | model, cur_device = build_head2toe_model(cfg) 173 | 174 | # Setup the evaluator 175 | logger.info('Setting up Evaluator ...') 176 | evaluator = Evaluator() 177 | logger.info("Setting up Trainer...") 178 | trainer = Trainer(cfg, model, evaluator, cur_device) 179 | 180 | if train_loader: 181 | trainer.train_classifier(train_loader, val_loader, test_loader) 182 | torch.save( 183 | evaluator.results, 184 | os.path.join(cfg.OUTPUT_DIR, 'eval_results.pth') 185 | ) 186 | else: 187 | print('No train loader presented. 
Exit.') 188 | 189 | 190 | def main(args): 191 | 192 | # Tuning lr and wd on the validation set 193 | if not args.dont_search: 194 | lr_range, wd_range = get_lrwd_range(args) 195 | for lr in sorted(lr_range, reverse=True): 196 | for wd in sorted(wd_range, reverse=True): 197 | print(f'val ==> lr {lr}, wd {wd}') 198 | try: 199 | cfg = setup(args, lr, wd, final_runs=False) 200 | except ValueError: 201 | continue # Already run 202 | train(cfg, args, final_runs=False) 203 | 204 | # Final run 5 times with different seeds 205 | random_seeds = [42, 44, 82, 100, 800] 206 | for run_idx, seed in enumerate(random_seeds): 207 | try: 208 | cfg = setup( 209 | args, 0.1, 0.1, final_runs=True, 210 | run_idx=run_idx+1, seed=seed) 211 | except ValueError: 212 | continue # Already run 213 | train(cfg, args, final_runs=True) 214 | 215 | 216 | if __name__ == '__main__': 217 | parser = default_argument_parser() 218 | parser.add_argument('--dont_search', default=False, action='store_true', 219 | help='') 220 | parser.add_argument('--optimizer', type=str, default='sgd', 221 | help='') 222 | parser.add_argument('--arch_feats', type=str, default='', 223 | help='') 224 | args = parser.parse_args() 225 | main(args) 226 | -------------------------------------------------------------------------------- /launch.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | launch helper functions 4 | """ 5 | import argparse 6 | import os 7 | import sys 8 | import pprint 9 | import PIL 10 | from collections import defaultdict 11 | from tabulate import tabulate 12 | from typing import Tuple 13 | 14 | import torch 15 | from src.utils.file_io import PathManager 16 | from src.utils import logging 17 | from src.utils.distributed import get_rank, get_world_size 18 | 19 | 20 | def collect_torch_env() -> str: 21 | try: 22 | import torch.__config__ 23 | 24 | return torch.__config__.show() 25 | except ImportError: 26 | # compatible with older versions of pytorch 27 | from torch.utils.collect_env import get_pretty_env_info 28 | 29 | return get_pretty_env_info() 30 | 31 | 32 | def get_env_module() -> Tuple[str]: 33 | var_name = "ENV_MODULE" 34 | return var_name, os.environ.get(var_name, "") 35 | 36 | 37 | def collect_env_info() -> str: 38 | data = [] 39 | data.append(("Python", sys.version.replace("\n", ""))) 40 | data.append(get_env_module()) 41 | data.append(("PyTorch", torch.__version__)) 42 | data.append(("PyTorch Debug Build", torch.version.debug)) 43 | 44 | has_cuda = torch.cuda.is_available() 45 | data.append(("CUDA available", has_cuda)) 46 | if has_cuda: 47 | data.append(("CUDA ID", os.environ["CUDA_VISIBLE_DEVICES"])) 48 | devices = defaultdict(list) 49 | for k in range(torch.cuda.device_count()): 50 | devices[torch.cuda.get_device_name(k)].append(str(k)) 51 | for name, devids in devices.items(): 52 | data.append(("GPU " + ",".join(devids), name)) 53 | data.append(("Pillow", PIL.__version__)) 54 | 55 | try: 56 | import cv2 57 | 58 | data.append(("cv2", cv2.__version__)) 59 | except ImportError: 60 | pass 61 | env_str = tabulate(data) + "\n" 62 | env_str += collect_torch_env() 63 | return env_str 64 | 65 | 66 | def default_argument_parser(): 67 | """ 68 | create a simple parser to wrap around config file 69 | """ 70 | parser = argparse.ArgumentParser(description="visual-prompt") 71 | parser.add_argument( 72 | "--config-file", default="", metavar="FILE", help="path to config file") 73 | parser.add_argument( 74 | "--train-type", default="", help="training types") 75 | 
parser.add_argument( 76 | "opts", 77 | help="Modify config options using the command-line", 78 | default=None, 79 | nargs=argparse.REMAINDER, 80 | ) 81 | 82 | return parser 83 | 84 | 85 | def logging_train_setup(args, cfg) -> None: 86 | output_dir = cfg.OUTPUT_DIR 87 | if output_dir: 88 | PathManager.mkdirs(output_dir) 89 | 90 | logger = logging.setup_logging( 91 | cfg.NUM_GPUS, get_world_size(), output_dir, name="visual_prompt") 92 | 93 | # Log basic information about environment, cmdline arguments, and config 94 | rank = get_rank() 95 | logger.info( 96 | f"Rank of current process: {rank}. World size: {get_world_size()}") 97 | logger.info("Environment info:\n" + collect_env_info()) 98 | 99 | logger.info("Command line arguments: " + str(args)) 100 | if hasattr(args, "config_file") and args.config_file != "": 101 | logger.info( 102 | "Contents of args.config_file={}:\n{}".format( 103 | args.config_file, 104 | PathManager.open(args.config_file, "r").read() 105 | ) 106 | ) 107 | # Show the config 108 | logger.info("Training with config:") 109 | logger.info(pprint.pformat(cfg)) 110 | # cudnn benchmark has large overhead. 111 | # It shouldn't be used considering the small size of typical val set. 112 | if not (hasattr(args, "eval_only") and args.eval_only): 113 | torch.backends.cudnn.benchmark = cfg.CUDNN_BENCHMARK 114 | -------------------------------------------------------------------------------- /pre-trained_weights/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | -------------------------------------------------------------------------------- /run_demo_exp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | GPUIDX=$1 5 | 6 | # ===== Q = 1, frac = 0.7 7 | bash scripts/VQT/run_vqt_vtab.sh $GPUIDX 'vtab-caltech101' 102 1 adam sup_vitb16_imagenet1k 8 | bash scripts/VQT/run_vqt_vtab_sparsity.sh $GPUIDX 'vtab-caltech101' 102 1 adam sup_vitb16_imagenet1k 0.7 9 | bash scripts/VQT/run_vqt_vtab.sh $GPUIDX 'vtab-dtd' 47 1 adam sup_vitb16_imagenet1k 10 | bash scripts/VQT/run_vqt_vtab_sparsity.sh $GPUIDX 'vtab-dtd' 47 1 adam sup_vitb16_imagenet1k 0.7 11 | bash scripts/VQT/run_vqt_vtab.sh $GPUIDX 'vtab-oxford_flowers102' 102 1 adam sup_vitb16_imagenet1k 12 | bash scripts/VQT/run_vqt_vtab_sparsity.sh $GPUIDX 'vtab-oxford_flowers102' 102 1 adam sup_vitb16_imagenet1k 0.7 13 | bash scripts/VQT/run_vqt_vtab.sh $GPUIDX 'vtab-oxford_iiit_pet' 37 1 adam sup_vitb16_imagenet1k 14 | bash scripts/VQT/run_vqt_vtab_sparsity.sh $GPUIDX 'vtab-oxford_iiit_pet' 37 1 adam sup_vitb16_imagenet1k 0.7 15 | bash scripts/VQT/run_vqt_vtab.sh $GPUIDX 'vtab-cifar(num_classes=100)' 100 1 adam sup_vitb16_imagenet1k 16 | bash scripts/VQT/run_vqt_vtab_sparsity.sh $GPUIDX 'vtab-cifar(num_classes=100)' 100 1 adam sup_vitb16_imagenet1k 0.7 17 | bash scripts/VQT/run_vqt_vtab.sh $GPUIDX 'vtab-patch_camelyon' 2 1 adam sup_vitb16_imagenet1k 18 | bash scripts/VQT/run_vqt_vtab_sparsity.sh $GPUIDX 'vtab-patch_camelyon' 2 1 adam sup_vitb16_imagenet1k 0.7 19 | bash scripts/VQT/run_vqt_vtab.sh $GPUIDX 'vtab-resisc45' 45 1 adam sup_vitb16_imagenet1k 20 | bash scripts/VQT/run_vqt_vtab_sparsity.sh $GPUIDX 'vtab-resisc45' 45 1 adam sup_vitb16_imagenet1k 0.7 21 | bash scripts/VQT/run_vqt_vtab.sh $GPUIDX 'vtab-diabetic_retinopathy(config="btgraham-300")' 5 1 adam sup_vitb16_imagenet1k 22 | bash scripts/VQT/run_vqt_vtab_sparsity.sh $GPUIDX 'vtab-diabetic_retinopathy(config="btgraham-300")' 5 1 adam sup_vitb16_imagenet1k 0.7 
23 | bash scripts/VQT/run_vqt_vtab.sh $GPUIDX 'vtab-eurosat' 10 1 adam sup_vitb16_imagenet1k 24 | bash scripts/VQT/run_vqt_vtab_sparsity.sh $GPUIDX 'vtab-eurosat' 10 1 adam sup_vitb16_imagenet1k 0.7 25 | bash scripts/VQT/run_vqt_vtab.sh $GPUIDX 'vtab-clevr(task="count_all")' 8 1 adam sup_vitb16_imagenet1k 26 | bash scripts/VQT/run_vqt_vtab_sparsity.sh $GPUIDX 'vtab-clevr(task="count_all")' 8 1 adam sup_vitb16_imagenet1k 0.7 27 | bash scripts/VQT/run_vqt_vtab.sh $GPUIDX 'vtab-dsprites(predicted_attribute="label_orientation",num_classes=16)' 16 1 adam sup_vitb16_imagenet1k 28 | bash scripts/VQT/run_vqt_vtab_sparsity.sh $GPUIDX 'vtab-dsprites(predicted_attribute="label_orientation",num_classes=16)' 16 1 adam sup_vitb16_imagenet1k 0.7 29 | bash scripts/VQT/run_vqt_vtab.sh $GPUIDX 'vtab-sun397' 397 1 adam sup_vitb16_imagenet1k 30 | bash scripts/VQT/run_vqt_vtab_sparsity.sh $GPUIDX 'vtab-sun397' 397 1 adam sup_vitb16_imagenet1k 0.7 31 | 32 | # ===== Q = 10, frac = 1.0 33 | bash scripts/VQT/run_vqt_vtab.sh $GPUIDX 'vtab-smallnorb(predicted_attribute="label_elevation")' 9 10 adam sup_vitb16_imagenet1k 34 | bash scripts/VQT/run_vqt_vtab.sh $GPUIDX 'vtab-clevr(task="closest_object_distance")' 6 10 adam sup_vitb16_imagenet1k 35 | 36 | # ===== Q = 20, frac = 1.0 37 | bash scripts/VQT/run_vqt_vtab.sh $GPUIDX 'vtab-svhn' 10 20 adam sup_vitb16_imagenet1k 38 | bash scripts/VQT/run_vqt_vtab.sh $GPUIDX 'vtab-dmlab' 6 20 adam sup_vitb16_imagenet1k 39 | bash scripts/VQT/run_vqt_vtab.sh $GPUIDX 'vtab-dsprites(predicted_attribute="label_x_position",num_classes=16)' 16 20 adam sup_vitb16_imagenet1k 40 | bash scripts/VQT/run_vqt_vtab.sh $GPUIDX 'vtab-smallnorb(predicted_attribute="label_azimuth")' 18 20 adam sup_vitb16_imagenet1k 41 | bash scripts/VQT/run_vqt_vtab.sh $GPUIDX 'vtab-kitti(task="closest_vehicle_distance")' 4 20 adam sup_vitb16_imagenet1k 42 | -------------------------------------------------------------------------------- /scripts/VQT/run_vqt_vtab.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | GPUIDX=$1 5 | DATA_NAME=$2 6 | NUM_CLASSES=$3 7 | Q_PROMPT_LEN=$4 8 | OPTIMIZER=$5 9 | FEATURE=$6 10 | 11 | MODEL_ROOT=pre-trained_weights 12 | DATA_PATH=vtab_data 13 | OUTPUT_DIR=h2t_vit_experiments/VQTSup_${Q_PROMPT_LEN}_${OPTIMIZER} 14 | 15 | 16 | CUDA_VISIBLE_DEVICES=$GPUIDX python head2toe_train.py \ 17 | --train-type "h2t-prompt" \ 18 | --config-file configs/h2t-prompt/vtab.yaml \ 19 | --optimizer $OPTIMIZER \ 20 | MODEL.TYPE "h2t-vit" \ 21 | MODEL.TRANSFER_TYPE "h2t-prompt" \ 22 | DATA.BATCH_SIZE "128" \ 23 | MODEL.H2T.NUM_QUERY_TOKENS "$Q_PROMPT_LEN" \ 24 | MODEL.H2T.DROPOUT "0.1" \ 25 | DATA.FEATURE $FEATURE \ 26 | DATA.NAME $DATA_NAME \ 27 | DATA.NUMBER_CLASSES $NUM_CLASSES \ 28 | DATA.DATAPATH $DATA_PATH \ 29 | MODEL.MODEL_ROOT $MODEL_ROOT \ 30 | OUTPUT_DIR $OUTPUT_DIR \ 31 | SOLVER.OPTIMIZER $OPTIMIZER \ 32 | SOLVER.DBG_TRAINABLE "True" 33 | -------------------------------------------------------------------------------- /scripts/VQT/run_vqt_vtab_sparsity.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | GPUIDX=$1 5 | DATA_NAME=$2 6 | NUM_CLASSES=$3 7 | Q_PROMPT_LEN=$4 8 | OPTIMIZER=$5 9 | FEATURE=$6 10 | KEEP_FRAC=$7 11 | 12 | MODEL_ROOT=pre-trained_weights 13 | DATA_PATH=vtab_data 14 | OUTPUT_DIR=h2t_vit_experiments/VQTSup_${Q_PROMPT_LEN}_${OPTIMIZER} 15 | LRP_COEF=0.0001 16 | 17 | 18 | echo 'Compressing the model ...' 
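# Note: compression runs in two passes. This first pass ("--h2t_sparse_mode compress")
# trains with all pre-classifier features kept (MODEL.H2T.KEEP_FRAC 1.0) and the LRP
# regularization term enabled (MODEL.H2T.LRP_COEF=$LRP_COEF); the second pass below
# ("--h2t_sparse_mode featselect") then keeps only a $KEEP_FRAC fraction of those
# features (presumably selected based on this regularized run) and re-trains the
# classifier with the regularizer disabled (MODEL.H2T.LRP_COEF 0.0).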
19 | CUDA_VISIBLE_DEVICES=$GPUIDX python head2toe_sparsity_train.py \ 20 | --train-type "h2t-prompt" \ 21 | --config-file configs/h2t-prompt/vtab.yaml \ 22 | --h2t_sparse_mode compress \ 23 | MODEL.TYPE "h2t-vit" \ 24 | MODEL.TRANSFER_TYPE "h2t-prompt" \ 25 | DATA.BATCH_SIZE "128" \ 26 | MODEL.H2T.NUM_QUERY_TOKENS "$Q_PROMPT_LEN" \ 27 | MODEL.H2T.DROPOUT "0.1" \ 28 | DATA.FEATURE $FEATURE \ 29 | DATA.NAME $DATA_NAME \ 30 | DATA.NUMBER_CLASSES $NUM_CLASSES \ 31 | DATA.DATAPATH $DATA_PATH \ 32 | MODEL.MODEL_ROOT $MODEL_ROOT \ 33 | OUTPUT_DIR $OUTPUT_DIR \ 34 | SOLVER.OPTIMIZER $OPTIMIZER \ 35 | MODEL.H2T.LRP_COEF $LRP_COEF \ 36 | MODEL.H2T.KEEP_FRAC 1.0 37 | 38 | 39 | echo 'Feature selection and training with '$KEEP_FRAC' ...' 40 | CUDA_VISIBLE_DEVICES=$GPUIDX python head2toe_sparsity_train.py \ 41 | --train-type "h2t-prompt" \ 42 | --config-file configs/h2t-prompt/vtab.yaml \ 43 | --h2t_sparse_mode featselect \ 44 | MODEL.TYPE "h2t-vit" \ 45 | MODEL.TRANSFER_TYPE "h2t-prompt" \ 46 | DATA.BATCH_SIZE "128" \ 47 | MODEL.H2T.NUM_QUERY_TOKENS "$Q_PROMPT_LEN" \ 48 | MODEL.H2T.DROPOUT "0.1" \ 49 | DATA.FEATURE $FEATURE \ 50 | DATA.NAME $DATA_NAME \ 51 | DATA.NUMBER_CLASSES $NUM_CLASSES \ 52 | DATA.DATAPATH $DATA_PATH \ 53 | MODEL.MODEL_ROOT $MODEL_ROOT \ 54 | OUTPUT_DIR $OUTPUT_DIR \ 55 | SOLVER.OPTIMIZER $OPTIMIZER \ 56 | MODEL.H2T.LRP_COEF 0.0 \ 57 | MODEL.H2T.KEEP_FRAC $KEEP_FRAC 58 | -------------------------------------------------------------------------------- /scripts/VQT/run_vqt_vtab_ssl.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | GPUIDX=$1 5 | DATA_NAME=$2 6 | NUM_CLASSES=$3 7 | Q_PROMPT_LEN=$4 8 | OPTIMIZER=$5 9 | FEATURE=$6 10 | 11 | MODEL_ROOT=pre-trained_weights 12 | DATA_PATH=vtab_data 13 | OUTPUT_DIR=h2t_vit_experiments/VQTSSL_${Q_PROMPT_LEN}_${OPTIMIZER} 14 | 15 | 16 | python head2toe_train.py \ 17 | --train-type "h2t-prompt" \ 18 | --config-file configs/h2t-prompt/vtab.yaml \ 19 | --optimizer $OPTIMIZER \ 20 | MODEL.TYPE "h2t-ssl-vit" \ 21 | MODEL.TRANSFER_TYPE "h2t-prompt" \ 22 | DATA.BATCH_SIZE "128" \ 23 | MODEL.H2T.NUM_QUERY_TOKENS "$Q_PROMPT_LEN" \ 24 | MODEL.H2T.DROPOUT "0.1" \ 25 | DATA.FEATURE $FEATURE \ 26 | DATA.NAME $DATA_NAME \ 27 | DATA.NUMBER_CLASSES $NUM_CLASSES \ 28 | DATA.DATAPATH $DATA_PATH \ 29 | MODEL.MODEL_ROOT $MODEL_ROOT \ 30 | OUTPUT_DIR $OUTPUT_DIR \ 31 | SOLVER.OPTIMIZER $OPTIMIZER \ 32 | SOLVER.DBG_TRAINABLE "True" 33 | -------------------------------------------------------------------------------- /src/configs/config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """Config system (based on Detectron's).""" 4 | 5 | from .config_node import CfgNode 6 | 7 | 8 | # Global config object 9 | _C = CfgNode() 10 | # Example usage: 11 | # from configs.config import cfg 12 | 13 | _C.DBG = False 14 | _C.OUTPUT_DIR = "./output" 15 | _C.RUN_N_TIMES = 5 16 | # Perform benchmarking to select the fastest CUDNN algorithms to use 17 | # Note that this may increase the memory usage and will likely not result 18 | # in overall speedups when variable size inputs are used (e.g. 
COCO training) 19 | _C.CUDNN_BENCHMARK = False 20 | 21 | # Number of GPUs to use (applies to both training and testing) 22 | _C.NUM_GPUS = 1 23 | _C.NUM_SHARDS = 1 24 | 25 | # Note that non-determinism may still be present due to non-deterministic 26 | # operator implementations in GPU operator libraries 27 | _C.SEED = None 28 | 29 | # ---------------------------------------------------------------------- 30 | # Model options 31 | # ---------------------------------------------------------------------- 32 | _C.MODEL = CfgNode() 33 | _C.MODEL.TRANSFER_TYPE = "linear" # one of linear, end2end, prompt, adapter, side, partial-1, tinytl-bias 34 | _C.MODEL.WEIGHT_PATH = "" # if resume from some checkpoint file 35 | _C.MODEL.SAVE_CKPT = False 36 | 37 | _C.MODEL.MODEL_ROOT = "" # root folder for pretrained model weights 38 | 39 | _C.MODEL.TYPE = "vit" 40 | _C.MODEL.MLP_NUM = 0 41 | 42 | _C.MODEL.LINEAR = CfgNode() 43 | _C.MODEL.LINEAR.MLP_SIZES = [] 44 | _C.MODEL.LINEAR.DROPOUT = 0.1 45 | 46 | # ---------------------------------------------------------------------- 47 | # Prompt options 48 | # ---------------------------------------------------------------------- 49 | _C.MODEL.PROMPT = CfgNode() 50 | _C.MODEL.PROMPT.NUM_TOKENS = 5 51 | _C.MODEL.PROMPT.LOCATION = "prepend" 52 | # prompt initalizatioin: 53 | # (1) default "random" 54 | # (2) "final-cls" use aggregated final [cls] embeddings from training dataset 55 | # (3) "cls-nolastl": use first 12 cls embeddings (exclude the final output) for deep prompt 56 | # (4) "cls-nofirstl": use last 12 cls embeddings (exclude the input to first layer) 57 | _C.MODEL.PROMPT.INITIATION = "random" # "final-cls", "cls-first12" 58 | _C.MODEL.PROMPT.CLSEMB_FOLDER = "" 59 | _C.MODEL.PROMPT.CLSEMB_PATH = "" 60 | _C.MODEL.PROMPT.PROJECT = -1 # "projection mlp hidden dim" 61 | _C.MODEL.PROMPT.DEEP = False # "whether do deep prompt or not, only for prepend location" 62 | 63 | 64 | _C.MODEL.PROMPT.NUM_DEEP_LAYERS = None # if set to be an int, then do partial-deep prompt tuning 65 | _C.MODEL.PROMPT.REVERSE_DEEP = False # if to only update last n layers, not the input layer 66 | _C.MODEL.PROMPT.DEEP_SHARED = False # if true, all deep layers will be use the same prompt emb 67 | _C.MODEL.PROMPT.FORWARD_DEEP_NOEXPAND = False # if true, will not expand input sequence for layers without prompt 68 | # how to get the output emb for cls head: 69 | # original: follow the orignial backbone choice, 70 | # img_pool: image patch pool only 71 | # prompt_pool: prompt embd pool only 72 | # imgprompt_pool: pool everything but the cls token 73 | _C.MODEL.PROMPT.VIT_POOL_TYPE = "original" 74 | _C.MODEL.PROMPT.DROPOUT = 0.0 75 | _C.MODEL.PROMPT.SAVE_FOR_EACH_EPOCH = False 76 | # ---------------------------------------------------------------------- 77 | # adapter options 78 | # ---------------------------------------------------------------------- 79 | _C.MODEL.ADAPTER = CfgNode() 80 | _C.MODEL.ADAPTER.REDUCATION_FACTOR = 8 81 | _C.MODEL.ADAPTER.STYLE = "Pfeiffer" 82 | 83 | _C.MODEL.H2T = CfgNode() 84 | _C.MODEL.H2T.NUM_QUERY_TOKENS = 5 85 | _C.MODEL.H2T.DROPOUT = 0.0 86 | _C.MODEL.H2T.NORMALIZE_FEATS = True 87 | _C.MODEL.H2T.POOLING_FEATS = False 88 | _C.MODEL.H2T.WEIGHTED_SUM_FEATS = False 89 | _C.MODEL.H2T.LRP_COEF = 0.0001 90 | _C.MODEL.H2T.KEEP_FRAC = 1.0 91 | 92 | # ---------------------------------------------------------------------- 93 | # Solver options 94 | # ---------------------------------------------------------------------- 95 | _C.SOLVER = CfgNode() 96 | _C.SOLVER.LOSS = 
"softmax" 97 | _C.SOLVER.LOSS_ALPHA = 0.01 98 | 99 | _C.SOLVER.OPTIMIZER = "sgd" # or "adamw" 100 | _C.SOLVER.MOMENTUM = 0.9 101 | _C.SOLVER.WEIGHT_DECAY = 0.0001 102 | _C.SOLVER.WEIGHT_DECAY_BIAS = 0 103 | 104 | _C.SOLVER.PATIENCE = 300 105 | 106 | 107 | _C.SOLVER.SCHEDULER = "cosine" 108 | 109 | _C.SOLVER.BASE_LR = 0.01 110 | _C.SOLVER.BIAS_MULTIPLIER = 1. # for prompt + bias 111 | 112 | _C.SOLVER.WARMUP_EPOCH = 5 113 | _C.SOLVER.TOTAL_EPOCH = 30 114 | _C.SOLVER.LOG_EVERY_N = 1000 115 | 116 | 117 | _C.SOLVER.DBG_TRAINABLE = False # if True, will print the name of trainable params 118 | 119 | # ---------------------------------------------------------------------- 120 | # Dataset options 121 | # ---------------------------------------------------------------------- 122 | _C.DATA = CfgNode() 123 | 124 | _C.DATA.NAME = "" 125 | _C.DATA.DATAPATH = "" 126 | _C.DATA.FEATURE = "" # e.g. inat2021_supervised 127 | 128 | _C.DATA.PERCENTAGE = 1.0 129 | _C.DATA.NUMBER_CLASSES = -1 130 | _C.DATA.MULTILABEL = False 131 | _C.DATA.CLASS_WEIGHTS_TYPE = "none" 132 | 133 | _C.DATA.CROPSIZE = 224 # or 384 134 | 135 | _C.DATA.NO_TEST = False 136 | _C.DATA.BATCH_SIZE = 32 137 | # Number of data loader workers per training process 138 | _C.DATA.NUM_WORKERS = 4 139 | # Load data to pinned host memory 140 | _C.DATA.PIN_MEMORY = True 141 | 142 | _C.DIST_BACKEND = "nccl" 143 | _C.DIST_INIT_PATH = "env://" 144 | _C.DIST_INIT_FILE = "" 145 | 146 | 147 | def get_cfg(): 148 | """ 149 | Get a copy of the default config. 150 | """ 151 | return _C.clone() 152 | -------------------------------------------------------------------------------- /src/configs/config_node.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """Config system (based on Detectron's).""" 4 | 5 | from fvcore.common.config import CfgNode as _CfgNode 6 | from ..utils.file_io import PathManager 7 | 8 | 9 | class CfgNode(_CfgNode): 10 | """ 11 | The same as `fvcore.common.config.CfgNode`, but different in: 12 | 13 | support manifold path 14 | """ 15 | 16 | @classmethod 17 | def _open_cfg(cls, filename): 18 | return PathManager.open(filename, "r") 19 | 20 | def dump(self, *args, **kwargs): 21 | """ 22 | Returns: 23 | str: a yaml string representation of the config 24 | """ 25 | # to make it show up in docs 26 | return super().dump(*args, **kwargs) 27 | -------------------------------------------------------------------------------- /src/configs/vit_configs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Copyright (c) Meta Platforms, Inc. 
All Rights Reserved 4 | https://github.com/jeonsworld/ViT-pytorch/blob/main/models/configs.py 5 | """ 6 | import ml_collections 7 | 8 | 9 | def get_testing(): 10 | """Returns a minimal configuration for testing.""" 11 | config = ml_collections.ConfigDict() 12 | config.patches = ml_collections.ConfigDict({'size': (16, 16)}) 13 | config.hidden_size = 1 14 | config.transformer = ml_collections.ConfigDict() 15 | config.transformer.mlp_dim = 1 16 | config.transformer.num_heads = 1 17 | config.transformer.num_layers = 1 18 | config.transformer.attention_dropout_rate = 0.0 19 | config.transformer.dropout_rate = 0.1 20 | config.classifier = 'token' 21 | config.representation_size = None 22 | return config 23 | 24 | 25 | def get_b16_config(): 26 | """Returns the ViT-B/16 configuration.""" 27 | config = ml_collections.ConfigDict() 28 | config.patches = ml_collections.ConfigDict({'size': (16, 16)}) 29 | config.hidden_size = 768 30 | config.transformer = ml_collections.ConfigDict() 31 | config.transformer.mlp_dim = 3072 32 | config.transformer.num_heads = 12 33 | config.transformer.num_layers = 12 34 | config.transformer.attention_dropout_rate = 0.0 35 | config.transformer.dropout_rate = 0.1 36 | config.classifier = 'token' 37 | config.representation_size = None 38 | return config 39 | 40 | 41 | def get_r50_b16_config(): 42 | """Returns the Resnet50 + ViT-B/16 configuration.""" 43 | config = get_b16_config() 44 | del config.patches.size 45 | config.patches.grid = (14, 14) 46 | config.resnet = ml_collections.ConfigDict() 47 | config.resnet.num_layers = (3, 4, 9) 48 | config.resnet.width_factor = 1 49 | return config 50 | 51 | 52 | def get_b32_config(): 53 | """Returns the ViT-B/32 configuration.""" 54 | config = get_b16_config() 55 | config.patches.size = (32, 32) 56 | return config 57 | 58 | 59 | def get_b8_config(): 60 | """Returns the ViT-B/32 configuration.""" 61 | config = get_b16_config() 62 | config.patches.size = (8, 8) 63 | return config 64 | 65 | 66 | def get_l16_config(): 67 | """Returns the ViT-L/16 configuration.""" 68 | config = ml_collections.ConfigDict() 69 | config.patches = ml_collections.ConfigDict({'size': (16, 16)}) 70 | config.hidden_size = 1024 71 | config.transformer = ml_collections.ConfigDict() 72 | config.transformer.mlp_dim = 4096 73 | config.transformer.num_heads = 16 74 | config.transformer.num_layers = 24 75 | config.transformer.attention_dropout_rate = 0.0 76 | config.transformer.dropout_rate = 0.1 77 | config.classifier = 'token' 78 | config.representation_size = None 79 | return config 80 | 81 | 82 | def get_l32_config(): 83 | """Returns the ViT-L/32 configuration.""" 84 | config = get_l16_config() 85 | config.patches.size = (32, 32) 86 | return config 87 | 88 | 89 | def get_h14_config(): 90 | """Returns the ViT-L/16 configuration.""" 91 | config = ml_collections.ConfigDict() 92 | config.patches = ml_collections.ConfigDict({'size': (14, 14)}) 93 | config.hidden_size = 1280 94 | config.transformer = ml_collections.ConfigDict() 95 | config.transformer.mlp_dim = 5120 96 | config.transformer.num_heads = 16 97 | config.transformer.num_layers = 32 98 | config.transformer.attention_dropout_rate = 0.0 99 | config.transformer.dropout_rate = 0.1 100 | config.classifier = 'token' 101 | config.representation_size = None 102 | return config 103 | -------------------------------------------------------------------------------- /src/data/datasets/json_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 
"""JSON dataset: support CUB, NABrids, Flower, Dogs and Cars""" 4 | 5 | import os 6 | import torch 7 | import torch.utils.data 8 | import torchvision as tv 9 | import numpy as np 10 | from collections import Counter 11 | 12 | from ..transforms import get_transforms 13 | from ...utils import logging 14 | from ...utils.io_utils import read_json 15 | logger = logging.get_logger("visual_prompt") 16 | 17 | 18 | class JSONDataset(torch.utils.data.Dataset): 19 | def __init__(self, cfg, split): 20 | assert split in { 21 | "train", 22 | "val", 23 | "test", 24 | }, "Split '{}' not supported for {} dataset".format( 25 | split, cfg.DATA.NAME) 26 | logger.info("Constructing {} dataset {}...".format( 27 | cfg.DATA.NAME, split)) 28 | 29 | self.cfg = cfg 30 | self._split = split 31 | self.name = cfg.DATA.NAME 32 | self.data_dir = cfg.DATA.DATAPATH 33 | self.data_percentage = cfg.DATA.PERCENTAGE 34 | self._construct_imdb(cfg) 35 | self.transform = get_transforms(split, cfg.DATA.CROPSIZE) 36 | 37 | def get_anno(self): 38 | anno_path = os.path.join(self.data_dir, "{}.json".format(self._split)) 39 | if "train" in self._split: 40 | if self.data_percentage < 1.0: 41 | anno_path = os.path.join( 42 | self.data_dir, 43 | "{}_{}.json".format(self._split, self.data_percentage) 44 | ) 45 | assert os.path.exists(anno_path), "{} dir not found".format(anno_path) 46 | 47 | return read_json(anno_path) 48 | 49 | def get_imagedir(self): 50 | raise NotImplementedError() 51 | 52 | def _construct_imdb(self, cfg): 53 | """Constructs the imdb.""" 54 | 55 | img_dir = self.get_imagedir() 56 | assert os.path.exists(img_dir), "{} dir not found".format(img_dir) 57 | 58 | anno = self.get_anno() 59 | # Map class ids to contiguous ids 60 | self._class_ids = sorted(list(set(anno.values()))) 61 | self._class_id_cont_id = {v: i for i, v in enumerate(self._class_ids)} 62 | 63 | # Construct the image db 64 | self._imdb = [] 65 | for img_name, cls_id in anno.items(): 66 | cont_id = self._class_id_cont_id[cls_id] 67 | im_path = os.path.join(img_dir, img_name) 68 | self._imdb.append({"im_path": im_path, "class": cont_id}) 69 | 70 | logger.info("Number of images: {}".format(len(self._imdb))) 71 | logger.info("Number of classes: {}".format(len(self._class_ids))) 72 | 73 | def get_info(self): 74 | num_imgs = len(self._imdb) 75 | return num_imgs, self.get_class_num() 76 | 77 | def get_class_num(self): 78 | return self.cfg.DATA.NUMBER_CLASSES 79 | # return len(self._class_ids) 80 | 81 | def get_class_weights(self, weight_type): 82 | """get a list of class weight, return a list float""" 83 | if "train" not in self._split: 84 | raise ValueError( 85 | "only getting training class distribution, " + \ 86 | "got split {} instead".format(self._split) 87 | ) 88 | 89 | cls_num = self.get_class_num() 90 | if weight_type == "none": 91 | return [1.0] * cls_num 92 | 93 | id2counts = Counter(self._class_ids) 94 | assert len(id2counts) == cls_num 95 | num_per_cls = np.array([id2counts[i] for i in self._class_ids]) 96 | 97 | if weight_type == 'inv': 98 | mu = -1.0 99 | elif weight_type == 'inv_sqrt': 100 | mu = -0.5 101 | weight_list = num_per_cls ** mu 102 | weight_list = np.divide( 103 | weight_list, np.linalg.norm(weight_list, 1)) * cls_num 104 | return weight_list.tolist() 105 | 106 | def __getitem__(self, index): 107 | # Load the image 108 | im = tv.datasets.folder.default_loader(self._imdb[index]["im_path"]) 109 | label = self._imdb[index]["class"] 110 | im = self.transform(im) 111 | if self._split == "train": 112 | index = index 113 | else: 114 | index = 
f"{self._split}{index}" 115 | sample = { 116 | "image": im, 117 | "label": label, 118 | # "id": index 119 | } 120 | return sample 121 | 122 | def __len__(self): 123 | return len(self._imdb) 124 | 125 | 126 | class CUB200Dataset(JSONDataset): 127 | """CUB_200 dataset.""" 128 | 129 | def __init__(self, cfg, split): 130 | super(CUB200Dataset, self).__init__(cfg, split) 131 | 132 | def get_imagedir(self): 133 | return os.path.join(self.data_dir, "images") 134 | 135 | 136 | class CarsDataset(JSONDataset): 137 | """stanford-cars dataset.""" 138 | 139 | def __init__(self, cfg, split): 140 | super(CarsDataset, self).__init__(cfg, split) 141 | 142 | def get_imagedir(self): 143 | return self.data_dir 144 | 145 | 146 | class DogsDataset(JSONDataset): 147 | """stanford-dogs dataset.""" 148 | 149 | def __init__(self, cfg, split): 150 | super(DogsDataset, self).__init__(cfg, split) 151 | 152 | def get_imagedir(self): 153 | return os.path.join(self.data_dir, "Images") 154 | 155 | 156 | class FlowersDataset(JSONDataset): 157 | """flowers dataset.""" 158 | 159 | def __init__(self, cfg, split): 160 | super(FlowersDataset, self).__init__(cfg, split) 161 | 162 | def get_imagedir(self): 163 | return self.data_dir 164 | 165 | 166 | class NabirdsDataset(JSONDataset): 167 | """Nabirds dataset.""" 168 | 169 | def __init__(self, cfg, split): 170 | super(NabirdsDataset, self).__init__(cfg, split) 171 | 172 | def get_imagedir(self): 173 | return os.path.join(self.data_dir, "images") 174 | 175 | -------------------------------------------------------------------------------- /src/data/datasets/tf_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """a dataset that handles output of tf.data: support datasets from VTAB""" 4 | import functools 5 | import tensorflow.compat.v1 as tf 6 | import torch 7 | import torch.utils.data 8 | import numpy as np 9 | 10 | from collections import Counter 11 | from torch import Tensor 12 | 13 | from ..vtab_datasets import base 14 | # pylint: disable=unused-import 15 | from ..vtab_datasets import caltech 16 | from ..vtab_datasets import cifar 17 | from ..vtab_datasets import clevr 18 | from ..vtab_datasets import diabetic_retinopathy 19 | from ..vtab_datasets import dmlab 20 | from ..vtab_datasets import dsprites 21 | from ..vtab_datasets import dtd 22 | from ..vtab_datasets import eurosat 23 | from ..vtab_datasets import kitti 24 | from ..vtab_datasets import oxford_flowers102 25 | from ..vtab_datasets import oxford_iiit_pet 26 | from ..vtab_datasets import patch_camelyon 27 | from ..vtab_datasets import resisc45 28 | from ..vtab_datasets import smallnorb 29 | from ..vtab_datasets import sun397 30 | from ..vtab_datasets import svhn 31 | from ..vtab_datasets.registry import Registry 32 | 33 | from ...utils import logging 34 | logger = logging.get_logger("visual_prompt") 35 | tf.config.experimental.set_visible_devices([], 'GPU') # set tensorflow to not use gpu # noqa 36 | DATASETS = [ 37 | 'caltech101', 38 | 'cifar(num_classes=100)', 39 | 'dtd', 40 | 'oxford_flowers102', 41 | 'oxford_iiit_pet', 42 | 'patch_camelyon', 43 | 'sun397', 44 | 'svhn', 45 | 'resisc45', 46 | 'eurosat', 47 | 'dmlab', 48 | 'kitti(task="closest_vehicle_distance")', 49 | 'smallnorb(predicted_attribute="label_azimuth")', 50 | 'smallnorb(predicted_attribute="label_elevation")', 51 | 'dsprites(predicted_attribute="label_x_position",num_classes=16)', 52 | 'dsprites(predicted_attribute="label_orientation",num_classes=16)', 53 | 
'clevr(task="closest_object_distance")', 54 | 'clevr(task="count_all")', 55 | 'diabetic_retinopathy(config="btgraham-300")' 56 | ] 57 | 58 | 59 | class TFDataset(torch.utils.data.Dataset): 60 | def __init__(self, cfg, split): 61 | assert split in { 62 | "train", 63 | "val", 64 | "test", 65 | "trainval" 66 | }, "Split '{}' not supported for {} dataset".format( 67 | split, cfg.DATA.NAME) 68 | logger.info("Constructing {} dataset {}...".format( 69 | cfg.DATA.NAME, split)) 70 | 71 | self.cfg = cfg 72 | self._split = split 73 | self.name = cfg.DATA.NAME 74 | 75 | self.img_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1) 76 | self.img_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1) 77 | 78 | self.get_data(cfg, split) 79 | 80 | def get_data(self, cfg, split): 81 | tf_data = build_tf_dataset(cfg, split) 82 | data_list = list(tf_data) # a list of tuples 83 | 84 | self._image_tensor_list = [t[0].numpy().squeeze() for t in data_list] 85 | self._targets = [int(t[1].numpy()[0]) for t in data_list] 86 | self._class_ids = sorted(list(set(self._targets))) 87 | 88 | logger.info("Number of images: {}".format(len(self._image_tensor_list))) 89 | logger.info("Number of classes: {} / {}".format( 90 | len(self._class_ids), self.get_class_num())) 91 | 92 | del data_list 93 | del tf_data 94 | 95 | def get_info(self): 96 | num_imgs = len(self._image_tensor_list) 97 | return num_imgs, self.get_class_num() 98 | 99 | def get_class_num(self): 100 | return self.cfg.DATA.NUMBER_CLASSES 101 | 102 | def get_class_weights(self, weight_type): 103 | """get a list of class weight, return a list float""" 104 | if "train" not in self._split: 105 | raise ValueError( 106 | "only getting training class distribution, " + \ 107 | "got split {} instead".format(self._split) 108 | ) 109 | 110 | cls_num = self.get_class_num() 111 | if weight_type == "none": 112 | return [1.0] * cls_num 113 | 114 | id2counts = Counter(self._class_ids) 115 | assert len(id2counts) == cls_num 116 | num_per_cls = np.array([id2counts[i] for i in self._class_ids]) 117 | 118 | if weight_type == 'inv': 119 | mu = -1.0 120 | elif weight_type == 'inv_sqrt': 121 | mu = -0.5 122 | weight_list = num_per_cls ** mu 123 | weight_list = np.divide( 124 | weight_list, np.linalg.norm(weight_list, 1)) * cls_num 125 | return weight_list.tolist() 126 | 127 | def __getitem__(self, index): 128 | # Load the image 129 | label = self._targets[index] 130 | im = to_torch_imgs( 131 | self._image_tensor_list[index], self.img_mean, self.img_std) 132 | 133 | if self._split == "train": 134 | index = index 135 | else: 136 | index = f"{self._split}{index}" 137 | sample = { 138 | "image": im, 139 | "label": label, 140 | # "id": index 141 | } 142 | return sample 143 | 144 | def __len__(self): 145 | return len(self._targets) 146 | 147 | 148 | def preprocess_fn(data, size=224, input_range=(0.0, 1.0)): 149 | image = data["image"] 150 | image = tf.image.resize(image, [size, size]) 151 | 152 | image = tf.cast(image, tf.float32) / 255.0 153 | image = image * (input_range[1] - input_range[0]) + input_range[0] 154 | 155 | data["image"] = image 156 | return data 157 | 158 | 159 | def build_tf_dataset(cfg, mode): 160 | """ 161 | Builds a tf data instance, then transform to a list of tensors and labels 162 | """ 163 | 164 | if mode not in ["train", "val", "test", "trainval"]: 165 | raise ValueError("The input pipeline supports `train`, `val`, `test`." 
166 | "Provided mode is {}".format(mode)) 167 | 168 | vtab_dataname = cfg.DATA.NAME.split("vtab-")[-1] 169 | data_dir = cfg.DATA.DATAPATH 170 | if vtab_dataname in DATASETS: 171 | data_cls = Registry.lookup("data." + vtab_dataname) 172 | vtab_tf_dataloader = data_cls(data_dir=data_dir) 173 | else: 174 | raise ValueError("Unknown type for \"dataset\" field: {}".format( 175 | type(vtab_dataname))) 176 | 177 | split_name_dict = { 178 | "dataset_train_split_name": "train800", 179 | "dataset_val_split_name": "val200", 180 | "dataset_trainval_split_name": "train800val200", 181 | "dataset_test_split_name": "test", 182 | } 183 | 184 | def _dict_to_tuple(batch): 185 | return batch['image'], batch['label'] 186 | 187 | return vtab_tf_dataloader.get_tf_data( 188 | batch_size=1, # data_params["batch_size"], 189 | drop_remainder=False, 190 | split_name=split_name_dict[f"dataset_{mode}_split_name"], 191 | preprocess_fn=functools.partial( 192 | preprocess_fn, 193 | input_range=(0.0, 1.0), 194 | size=cfg.DATA.CROPSIZE, 195 | ), 196 | for_eval=mode != "train", # handles shuffling 197 | shuffle_buffer_size=1000, 198 | prefetch=1, 199 | train_examples=None, 200 | epochs=1 # setting epochs to 1 make sure it returns one copy of the dataset 201 | ).map(_dict_to_tuple) # return a PrefetchDataset object. (which does not have much documentation to go on) 202 | 203 | 204 | def to_torch_imgs(img: np.ndarray, mean: Tensor, std: Tensor) -> Tensor: 205 | t_img: Tensor = torch.from_numpy(np.transpose(img, (2, 0, 1))) 206 | t_img -= mean 207 | t_img /= std 208 | 209 | return t_img 210 | -------------------------------------------------------------------------------- /src/data/loader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """Data loader.""" 4 | import torch 5 | from torch.utils.data.distributed import DistributedSampler 6 | from torch.utils.data.sampler import RandomSampler 7 | 8 | from ..utils import logging 9 | from .datasets.json_dataset import ( 10 | CUB200Dataset, CarsDataset, DogsDataset, FlowersDataset, NabirdsDataset 11 | ) 12 | 13 | logger = logging.get_logger("visual_prompt") 14 | _DATASET_CATALOG = { 15 | "CUB": CUB200Dataset, 16 | 'OxfordFlowers': FlowersDataset, 17 | 'StanfordCars': CarsDataset, 18 | 'StanfordDogs': DogsDataset, 19 | "nabirds": NabirdsDataset, 20 | } 21 | 22 | 23 | def _construct_loader(cfg, split, batch_size, shuffle, drop_last): 24 | """Constructs the data loader for the given dataset.""" 25 | dataset_name = cfg.DATA.NAME 26 | 27 | # Construct the dataset 28 | if dataset_name.startswith("vtab-"): 29 | # import the tensorflow here only if needed 30 | from .datasets.tf_dataset import TFDataset 31 | dataset = TFDataset(cfg, split) 32 | else: 33 | assert ( 34 | dataset_name in _DATASET_CATALOG.keys() 35 | ), "Dataset '{}' not supported".format(dataset_name) 36 | dataset = _DATASET_CATALOG[dataset_name](cfg, split) 37 | 38 | # Create a sampler for multi-process training 39 | sampler = DistributedSampler(dataset) if cfg.NUM_GPUS > 1 else None 40 | # Create a loader 41 | loader = torch.utils.data.DataLoader( 42 | dataset, 43 | batch_size=batch_size, 44 | shuffle=(False if sampler else shuffle), 45 | sampler=sampler, 46 | num_workers=cfg.DATA.NUM_WORKERS, 47 | pin_memory=cfg.DATA.PIN_MEMORY, 48 | drop_last=drop_last, 49 | ) 50 | return loader 51 | 52 | 53 | def construct_train_loader(cfg): 54 | """Train loader wrapper.""" 55 | if cfg.NUM_GPUS > 1: 56 | drop_last = True 57 | else: 58 | drop_last = False 59 | return 
_construct_loader( 60 | cfg=cfg, 61 | split="train", 62 | batch_size=int(cfg.DATA.BATCH_SIZE / cfg.NUM_GPUS), 63 | shuffle=True, 64 | drop_last=drop_last, 65 | ) 66 | 67 | 68 | def construct_trainval_loader(cfg): 69 | """Train loader wrapper.""" 70 | if cfg.NUM_GPUS > 1: 71 | drop_last = True 72 | else: 73 | drop_last = False 74 | return _construct_loader( 75 | cfg=cfg, 76 | split="trainval", 77 | batch_size=int(cfg.DATA.BATCH_SIZE / cfg.NUM_GPUS), 78 | shuffle=True, 79 | drop_last=drop_last, 80 | ) 81 | 82 | 83 | def construct_test_loader(cfg): 84 | """Test loader wrapper.""" 85 | return _construct_loader( 86 | cfg=cfg, 87 | split="test", 88 | batch_size=int(cfg.DATA.BATCH_SIZE / cfg.NUM_GPUS), 89 | shuffle=False, 90 | drop_last=False, 91 | ) 92 | 93 | 94 | def construct_val_loader(cfg, batch_size=None): 95 | if batch_size is None: 96 | bs = int(cfg.DATA.BATCH_SIZE / cfg.NUM_GPUS) 97 | else: 98 | bs = batch_size 99 | """Validation loader wrapper.""" 100 | return _construct_loader( 101 | cfg=cfg, 102 | split="val", 103 | batch_size=bs, 104 | shuffle=False, 105 | drop_last=False, 106 | ) 107 | 108 | 109 | def shuffle(loader, cur_epoch): 110 | """"Shuffles the data.""" 111 | assert isinstance( 112 | loader.sampler, (RandomSampler, DistributedSampler) 113 | ), "Sampler type '{}' not supported".format(type(loader.sampler)) 114 | # RandomSampler handles shuffling automatically 115 | if isinstance(loader.sampler, DistributedSampler): 116 | # DistributedSampler shuffles data based on epoch 117 | loader.sampler.set_epoch(cur_epoch) 118 | -------------------------------------------------------------------------------- /src/data/transforms.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """Image transformations.""" 4 | import torchvision as tv 5 | 6 | 7 | def get_transforms(split, size): 8 | normalize = tv.transforms.Normalize( 9 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 10 | ) 11 | if size == 448: 12 | resize_dim = 512 13 | crop_dim = 448 14 | elif size == 224: 15 | resize_dim = 256 16 | crop_dim = 224 17 | elif size == 384: 18 | resize_dim = 438 19 | crop_dim = 384 20 | if split == "train": 21 | transform = tv.transforms.Compose( 22 | [ 23 | tv.transforms.Resize(resize_dim), 24 | tv.transforms.RandomCrop(crop_dim), 25 | tv.transforms.RandomHorizontalFlip(0.5), 26 | # tv.transforms.RandomResizedCrop(224, scale=(0.2, 1.0)), 27 | # tv.transforms.RandomHorizontalFlip(), 28 | tv.transforms.ToTensor(), 29 | normalize, 30 | ] 31 | ) 32 | else: 33 | transform = tv.transforms.Compose( 34 | [ 35 | tv.transforms.Resize(resize_dim), 36 | tv.transforms.CenterCrop(crop_dim), 37 | tv.transforms.ToTensor(), 38 | normalize, 39 | ] 40 | ) 41 | return transform 42 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
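For reference, here is a minimal sketch of how the loader wrappers above fit together. It assumes `cfg` is the repo's config node (see `src/configs/config.py`) with its `DATA.*` and `NUM_GPUS` fields already populated; `build_loaders` is only an illustrative name, not a function from the repo.

```python
from src.data.loader import construct_train_loader, construct_val_loader

def build_loaders(cfg):
    """Illustrative helper: build the train/val loaders from a populated config."""
    # cfg.DATA.NAME picks the branch inside _construct_loader:
    #   "vtab-<name>"              -> TFDataset (tf.data pipeline)
    #   "CUB", "StanfordCars", ... -> the JSON datasets defined above
    train_loader = construct_train_loader(cfg)  # shuffled; per-GPU batch = DATA.BATCH_SIZE / NUM_GPUS
    val_loader = construct_val_loader(cfg)      # fixed order; never drops the last batch
    batch = next(iter(train_loader))
    # Each batch is a dict {"image": FloatTensor[B, 3, CROPSIZE, CROPSIZE],
    #                       "label": LongTensor[B]}, with ImageNet-normalized images
    # (see get_transforms above and TFDataset.img_mean / img_std).
    return train_loader, val_loader
```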
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/caltech.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Imports the Caltech images dataset.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | from . import base as base 23 | from .registry import Registry 24 | import tensorflow_datasets as tfds 25 | 26 | 27 | # Percentage of the original training set retained for training, the rest is 28 | # used as a validation set. 29 | _TRAIN_SPLIT_PERCENT = 90 30 | 31 | 32 | @Registry.register("data.caltech101", "class") 33 | class Caltech101(base.ImageTfdsData): 34 | """Provides the Caltech101 dataset. 35 | 36 | See the base class for additional details on the class. 37 | 38 | See TFDS dataset for details on the dataset: 39 | third_party/py/tensorflow_datasets/image/caltech.py 40 | 41 | The original (TFDS) dataset contains only a train and test split. We randomly 42 | sample _TRAIN_SPLIT_PERCENT% of the train split for our "train" set. The 43 | remainder of the TFDS train split becomes our "val" set. The full TFDS train 44 | split is called "trainval". The TFDS test split is used as our test set. 45 | 46 | Note that, in the TFDS dataset, the training split is class-balanced, but not 47 | the test split. Therefore, a significant difference between performance on the 48 | "val" and "test" sets should be expected. 49 | """ 50 | 51 | def __init__(self, data_dir=None): 52 | dataset_builder = tfds.builder("caltech101:3.*.*", data_dir=data_dir) 53 | dataset_builder.download_and_prepare() 54 | 55 | # Creates a dict with example counts for each split. 56 | trainval_count = dataset_builder.info.splits["train"].num_examples 57 | train_count = (_TRAIN_SPLIT_PERCENT * trainval_count) // 100 58 | test_count = dataset_builder.info.splits["test"].num_examples 59 | num_samples_splits = dict( 60 | train=train_count, 61 | val=trainval_count - train_count, 62 | trainval=trainval_count, 63 | test=test_count, 64 | train800=800, 65 | val200=200, 66 | train800val200=1000) 67 | 68 | # Defines dataset specific train/val/trainval/test splits. 
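    # Worked example with hypothetical counts (TFDS caltech101 provides roughly
    # 3,000 training images): if trainval_count = 3060, then train_count = 2754,
    # so "train" -> "train[:2754]", "val" -> "train[2754:]",
    # "val200" -> "train[2754:2954]", and "train800val200" concatenates the first
    # 800 training images with those same 200 validation images, giving the
    # 1,000-example budget used by the VTAB-1k protocol.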
69 | tfds_splits = { 70 | "train": "train[:{}]".format(train_count), 71 | "val": "train[{}:]".format(train_count), 72 | "trainval": "train", 73 | "test": "test", 74 | "train800": "train[:800]", 75 | "val200": "train[{}:{}]".format(train_count, train_count+200), 76 | "train800val200": ( 77 | "train[:800]+train[{}:{}]".format(train_count, train_count+200)), 78 | } 79 | 80 | super(Caltech101, self).__init__( 81 | dataset_builder=dataset_builder, 82 | tfds_splits=tfds_splits, 83 | num_samples_splits=num_samples_splits, 84 | num_preprocessing_threads=400, 85 | shuffle_buffer_size=3000, 86 | base_preprocess_fn=base.make_get_tensors_fn(("image", "label")), 87 | num_classes=dataset_builder.info.features["label"].num_classes) 88 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/cifar.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements Cifar data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | from . import base as base 23 | from .registry import Registry 24 | import tensorflow_datasets as tfds 25 | 26 | # This constant specifies the percentage of data that is used to create custom 27 | # train/val splits. Specifically, TRAIN_SPLIT_PERCENT% of the official training 28 | # split is used as a new training split and the rest is used for validation. 29 | TRAIN_SPLIT_PERCENT = 90 30 | 31 | 32 | @Registry.register("data.cifar", "class") 33 | class CifarData(base.ImageTfdsData): 34 | """Provides Cifar10 or Cifar100 data. 35 | 36 | Cifar comes only with a training and test set. Therefore, the validation set 37 | is split out of the original training set, and the remaining examples are used 38 | as the "train" split. The "trainval" split corresponds to the original 39 | training set. 40 | 41 | For additional details and usage, see the base class. 42 | """ 43 | 44 | def __init__(self, num_classes=10, data_dir=None, train_split_percent=None): 45 | 46 | if num_classes == 10: 47 | dataset_builder = tfds.builder("cifar10:3.*.*", data_dir=data_dir) 48 | elif num_classes == 100: 49 | dataset_builder = tfds.builder("cifar100:3.*.*", data_dir=data_dir) 50 | else: 51 | raise ValueError( 52 | "Number of classes must be 10 or 100, got {}".format(num_classes)) 53 | 54 | dataset_builder.download_and_prepare() 55 | 56 | train_split_percent = train_split_percent or TRAIN_SPLIT_PERCENT 57 | 58 | # Creates a dict with example counts for each split. 
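    # How this constructor is reached at runtime (a sketch based on
    # build_tf_dataset in ../datasets/tf_dataset.py and on ./registry.py):
    #   Registry.lookup('data.cifar(num_classes=100)')
    # parses the string into the name "data.cifar" plus kwargs
    # {"num_classes": 100}, returns partialclass(CifarData, num_classes=100),
    # and calling that partial class with data_dir=... executes this __init__.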
59 | trainval_count = dataset_builder.info.splits["train"].num_examples 60 | test_count = dataset_builder.info.splits["test"].num_examples 61 | num_samples_splits = { 62 | "train": (train_split_percent * trainval_count) // 100, 63 | "val": trainval_count - (train_split_percent * trainval_count) // 100, 64 | "trainval": trainval_count, 65 | "test": test_count, 66 | "train800": 800, 67 | "val200": 200, 68 | "train800val200": 1000, 69 | } 70 | 71 | # Defines dataset specific train/val/trainval/test splits. 72 | tfds_splits = { 73 | "train": "train[:{}]".format(num_samples_splits["train"]), 74 | "val": "train[{}:]".format(num_samples_splits["train"]), 75 | "trainval": "train", 76 | "test": "test", 77 | "train800": "train[:800]", 78 | "val200": "train[{}:{}]".format( 79 | num_samples_splits["train"], num_samples_splits["train"]+200), 80 | "train800val200": "train[:800]+train[{}:{}]".format( 81 | num_samples_splits["train"], num_samples_splits["train"]+200), 82 | } 83 | 84 | super(CifarData, self).__init__( 85 | dataset_builder=dataset_builder, 86 | tfds_splits=tfds_splits, 87 | num_samples_splits=num_samples_splits, 88 | num_preprocessing_threads=400, 89 | shuffle_buffer_size=10000, 90 | # Note: Export only image and label tensors with their original types. 91 | base_preprocess_fn=base.make_get_tensors_fn(["image", "label", "id"]), 92 | num_classes=dataset_builder.info.features["label"].num_classes) 93 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/clevr.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements CLEVR data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import numpy as np 24 | import tensorflow.compat.v1 as tf 25 | import tensorflow_datasets as tfds 26 | 27 | from . 
import base as base 28 | from .registry import Registry 29 | 30 | TRAIN_SPLIT_PERCENT = 90 31 | 32 | 33 | def _count_preprocess_fn(x): 34 | return {"image": x["image"], 35 | "label": tf.size(x["objects"]["size"]) - 3} 36 | 37 | 38 | def _count_cylinders_preprocess_fn(x): 39 | # Class distribution: 40 | 41 | num_cylinders = tf.reduce_sum( 42 | tf.cast(tf.equal(x["objects"]["shape"], 2), tf.int32)) 43 | return {"image": x["image"], "label": num_cylinders} 44 | 45 | 46 | def _closest_object_preprocess_fn(x): 47 | dist = tf.reduce_min(x["objects"]["pixel_coords"][:, 2]) 48 | # These thresholds are uniformly spaced and result in more or less balanced 49 | # distribution of classes, see the resulting histogram: 50 | 51 | thrs = np.array([0.0, 8.0, 8.5, 9.0, 9.5, 10.0, 100.0]) 52 | label = tf.reduce_max(tf.where((thrs - dist) < 0)) 53 | return {"image": x["image"], 54 | "label": label} 55 | 56 | 57 | _TASK_DICT = { 58 | "count_all": { 59 | "preprocess_fn": _count_preprocess_fn, 60 | "num_classes": 8 61 | }, 62 | "count_cylinders": { 63 | "preprocess_fn": _count_cylinders_preprocess_fn, 64 | "num_classes": 11 65 | }, 66 | "closest_object_distance": { 67 | "preprocess_fn": _closest_object_preprocess_fn, 68 | "num_classes": 6 69 | }, 70 | } 71 | 72 | 73 | @Registry.register("data.clevr", "class") 74 | class CLEVRData(base.ImageTfdsData): 75 | """Provides CLEVR dataset. 76 | 77 | Currently, two tasks are supported: 78 | 1. Predict number of objects. 79 | 2. Predict distnace to the closest object. 80 | """ 81 | 82 | def __init__(self, task, data_dir=None): 83 | 84 | if task not in _TASK_DICT: 85 | raise ValueError("Unknown task: %s" % task) 86 | 87 | dataset_builder = tfds.builder("clevr:3.*.*", data_dir=data_dir) 88 | dataset_builder.download_and_prepare() 89 | 90 | # Creates a dict with example counts for each split. 91 | trainval_count = dataset_builder.info.splits[tfds.Split.TRAIN].num_examples 92 | test_count = dataset_builder.info.splits[tfds.Split.TEST].num_examples 93 | num_samples_splits = { 94 | "train": (TRAIN_SPLIT_PERCENT * trainval_count) // 100, 95 | "val": trainval_count - (TRAIN_SPLIT_PERCENT * trainval_count) // 100, 96 | "trainval": trainval_count, 97 | "test": test_count, 98 | "train800": 800, 99 | "val200": 200, 100 | "train800val200": 1000, 101 | } 102 | 103 | # Defines dataset specific train/val/trainval/test splits. 104 | tfds_splits = { 105 | "train": "train[:{}]".format(num_samples_splits["train"]), 106 | "val": "train[{}:]".format(num_samples_splits["train"]), 107 | "trainval": "train", 108 | "test": "validation", 109 | "train800": "train[:800]", 110 | "val200": "train[{}:{}]".format( 111 | num_samples_splits["train"], num_samples_splits["train"]+200), 112 | "train800val200": "train[:800]+train[{}:{}]".format( 113 | num_samples_splits["train"], num_samples_splits["train"]+200), 114 | } 115 | 116 | task = _TASK_DICT[task] 117 | base_preprocess_fn = task["preprocess_fn"] 118 | 119 | super(CLEVRData, self).__init__( 120 | dataset_builder=dataset_builder, 121 | tfds_splits=tfds_splits, 122 | num_samples_splits=num_samples_splits, 123 | num_preprocessing_threads=400, 124 | shuffle_buffer_size=10000, 125 | # Note: Export only image and label tensors with their original types. 
126 | base_preprocess_fn=base_preprocess_fn, 127 | num_classes=task["num_classes"]) 128 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/dmlab.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements Dmlab data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow_datasets as tfds 24 | 25 | from . import base as base 26 | from .registry import Registry 27 | 28 | 29 | @Registry.register("data.dmlab", "class") 30 | class DmlabData(base.ImageTfdsData): 31 | """Dmlab dataset. 32 | 33 | The Dmlab dataset contains frames observed by the agent acting in the 34 | DMLab environment, which are annotated by the distance between 35 | the agent and various objects present in the environment. The goal is to 36 | is to evaluate the ability of a visual model to reason about distances 37 | from the visual input in 3D environments. The Dmlab dataset consists of 38 | 360x480 color images in 6 classes. The classes are 39 | {close, far, very far} x {positive reward, negative reward} 40 | respectively. 41 | """ 42 | 43 | def __init__(self, data_dir=None): 44 | 45 | dataset_builder = tfds.builder("dmlab:2.0.1", data_dir=data_dir) 46 | 47 | tfds_splits = { 48 | "train": "train", 49 | "val": "validation", 50 | "trainval": "train+validation", 51 | "test": "test", 52 | "train800": "train[:800]", 53 | "val200": "validation[:200]", 54 | "train800val200": "train[:800]+validation[:200]", 55 | } 56 | 57 | # Example counts are retrieved from the tensorflow dataset info. 58 | train_count = dataset_builder.info.splits["train"].num_examples 59 | val_count = dataset_builder.info.splits["validation"].num_examples 60 | test_count = dataset_builder.info.splits["test"].num_examples 61 | 62 | # Creates a dict with example counts for each split. 
63 | num_samples_splits = { 64 | "train": train_count, 65 | "val": val_count, 66 | "trainval": train_count + val_count, 67 | "test": test_count, 68 | "train800": 800, 69 | "val200": 200, 70 | "train800val200": 1000, 71 | } 72 | 73 | super(DmlabData, self).__init__( 74 | dataset_builder=dataset_builder, 75 | tfds_splits=tfds_splits, 76 | num_samples_splits=num_samples_splits, 77 | num_preprocessing_threads=400, 78 | shuffle_buffer_size=10000, 79 | base_preprocess_fn=base.make_get_and_cast_tensors_fn({ 80 | "image": ("image", None), 81 | "label": ("label", None), 82 | }), 83 | num_classes=dataset_builder.info.features["label"].num_classes, 84 | image_key="image") 85 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/dsprites.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements the DSprites data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow.compat.v1 as tf 24 | import tensorflow_datasets as tfds 25 | 26 | from . import base as base 27 | from .registry import Registry 28 | 29 | 30 | # These constants specify the percentage of data that is used to create custom 31 | # train/val splits. Specifically, TRAIN_SPLIT_PERCENT% of the data set is used 32 | # as a new training split and VAL_SPLIT_PERCENT% is used for validation. 33 | # The rest is used for testing. 34 | TRAIN_SPLIT_PERCENT = 80 35 | VAL_SPLIT_PERCENT = 10 36 | 37 | 38 | @Registry.register("data.dsprites", "class") 39 | class DSpritesData(base.ImageTfdsData): 40 | """Provides the DSprites data set. 41 | 42 | DSprites only comes with a training set. Therefore, the training, validation, 43 | and test set are split out of the original training set. 44 | 45 | For additional details and usage, see the base class. 46 | 47 | The data set page is https://github.com/deepmind/dsprites-dataset/. 48 | """ 49 | 50 | def __init__(self, predicted_attribute, num_classes=None, data_dir=None): 51 | dataset_builder = tfds.builder("dsprites:2.*.*", data_dir=data_dir) 52 | dataset_builder.download_and_prepare() 53 | info = dataset_builder.info 54 | 55 | if predicted_attribute not in dataset_builder.info.features: 56 | raise ValueError( 57 | "{} is not a valid attribute to predict.".format(predicted_attribute)) 58 | 59 | # If num_classes is set, we group together nearby integer values to arrive 60 | # at the desired number of classes. This is useful for example for grouping 61 | # together different spatial positions. 
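    # Concrete illustration (the attribute size is taken from the TFDS dSprites
    # features and should be treated as an assumption): "label_x_position" has 32
    # discrete values, so requesting num_classes=16, as the VTAB task string
    # dsprites(predicted_attribute="label_x_position",num_classes=16) does, gives
    # class_division_factor = 32 / 16 = 2.0, and the preprocess_fn below maps the
    # original labels {0,1}->0, {2,3}->1, ..., {30,31}->15.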
62 | num_original_classes = info.features[predicted_attribute].num_classes 63 | if num_classes is None: 64 | num_classes = num_original_classes 65 | if not isinstance(num_classes, int) or num_classes <= 1 or ( 66 | num_classes > num_original_classes): 67 | raise ValueError( 68 | "The number of classes should be None or in [2, ..., num_classes].") 69 | class_division_factor = float(num_original_classes) / num_classes 70 | 71 | # Creates a dict with example counts for each split. 72 | num_total = dataset_builder.info.splits["train"].num_examples 73 | num_samples_train = TRAIN_SPLIT_PERCENT * num_total // 100 74 | num_samples_val = VAL_SPLIT_PERCENT * num_total // 100 75 | num_samples_splits = { 76 | "train": num_samples_train, 77 | "val": num_samples_val, 78 | "trainval": num_samples_val + num_samples_train, 79 | "test": num_total - num_samples_val - num_samples_train, 80 | "train800": 800, 81 | "val200": 200, 82 | "train800val200": 1000, 83 | } 84 | 85 | # Defines dataset specific train/val/trainval/test splits. 86 | tfds_splits = { 87 | "train": "train[:{}]".format(num_samples_splits["train"]), 88 | "val": "train[{}:{}]".format(num_samples_splits["train"], 89 | num_samples_splits["trainval"]), 90 | "trainval": "train[:{}]".format(num_samples_splits["trainval"]), 91 | "test": "train[{}:]".format(num_samples_splits["trainval"]), 92 | "train800": "train[:800]", 93 | "val200": "train[{}:{}]".format(num_samples_splits["train"], 94 | num_samples_splits["train"]+200), 95 | "train800val200": "train[:800]+train[{}:{}]".format( 96 | num_samples_splits["train"], num_samples_splits["train"]+200), 97 | } 98 | 99 | def preprocess_fn(tensors): 100 | # For consistency with other datasets, image needs to have three channels 101 | # and be in [0, 255). 102 | images = tf.tile(tensors["image"], [1, 1, 3]) * 255 103 | label = tf.cast( 104 | tf.math.floordiv( 105 | tf.cast(tensors[predicted_attribute], tf.float32), 106 | class_division_factor), info.features[predicted_attribute].dtype) 107 | return dict(image=images, label=label) 108 | 109 | super(DSpritesData, self).__init__( 110 | dataset_builder=dataset_builder, 111 | tfds_splits=tfds_splits, 112 | num_samples_splits=num_samples_splits, 113 | num_preprocessing_threads=400, 114 | shuffle_buffer_size=10000, 115 | # We extract the attribute we want to predict in the preprocessing. 116 | base_preprocess_fn=preprocess_fn, 117 | num_classes=num_classes) 118 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/dtd.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
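To make the flow through these VTAB data classes concrete, a small usage sketch follows. The helper name is hypothetical, and `cfg` is assumed to be the repo's config node with `DATA.NAME = "vtab-dtd"`, `DATA.DATAPATH`, `DATA.CROPSIZE`, and `DATA.NUMBER_CLASSES` filled in.

```python
from src.data.datasets.tf_dataset import TFDataset

def peek_vtab_split(cfg):
    """Illustrative only: materialize one VTAB-1k split and inspect a sample."""
    # For cfg.DATA.NAME = "vtab-dtd", build_tf_dataset strips the "vtab-" prefix,
    # resolves Registry.lookup("data.dtd") to the DTDData class below, and reads
    # the TFDS split "train800" (split mode "train" maps to it via split_name_dict).
    ds = TFDataset(cfg, "train")
    num_imgs, num_classes = ds.get_info()  # num_classes comes from cfg.DATA.NUMBER_CLASSES
    sample = ds[0]                         # {"image": FloatTensor[3, CROPSIZE, CROPSIZE], "label": int}
    return num_imgs, num_classes, sample
```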
16 | 17 | """Implements the Describable Textures Dataset (DTD) data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow_datasets as tfds 24 | 25 | from . import base as base 26 | from .registry import Registry 27 | 28 | 29 | @Registry.register("data.dtd", "class") 30 | class DTDData(base.ImageTfdsData): 31 | """Provides Describable Textures Dataset (DTD) data. 32 | 33 | As of version 1.0.0, the train/val/test splits correspond to those of the 34 | 1st fold of the official cross-validation partition. 35 | 36 | For additional details and usage, see the base class. 37 | """ 38 | 39 | def __init__(self, data_dir=None): 40 | 41 | dataset_builder = tfds.builder("dtd:3.*.*", data_dir=data_dir) 42 | dataset_builder.download_and_prepare() 43 | 44 | # Defines dataset specific train/val/trainval/test splits. 45 | tfds_splits = { 46 | "train": "train", 47 | "val": "validation", 48 | "trainval": "train+validation", 49 | "test": "test", 50 | "train800": "train[:800]", 51 | "val200": "validation[:200]", 52 | "train800val200": "train[:800]+validation[:200]", 53 | } 54 | 55 | # Creates a dict with example counts for each split. 56 | train_count = dataset_builder.info.splits["train"].num_examples 57 | val_count = dataset_builder.info.splits["validation"].num_examples 58 | test_count = dataset_builder.info.splits["test"].num_examples 59 | num_samples_splits = { 60 | "train": train_count, 61 | "val": val_count, 62 | "trainval": train_count + val_count, 63 | "test": test_count, 64 | "train800": 800, 65 | "val200": 200, 66 | "train800val200": 1000, 67 | } 68 | 69 | super(DTDData, self).__init__( 70 | dataset_builder=dataset_builder, 71 | tfds_splits=tfds_splits, 72 | num_samples_splits=num_samples_splits, 73 | num_preprocessing_threads=400, 74 | shuffle_buffer_size=10000, 75 | # Note: Export only image and label tensors with their original types. 76 | base_preprocess_fn=base.make_get_tensors_fn(["image", "label"]), 77 | num_classes=dataset_builder.info.features["label"].num_classes) 78 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/eurosat.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements EurosatData class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow_datasets as tfds 24 | 25 | from . import base as base 26 | from .registry import Registry 27 | 28 | TRAIN_SPLIT_PERCENT = 60 29 | VALIDATION_SPLIT_PERCENT = 20 30 | TEST_SPLIT_PERCENT = 20 31 | 32 | 33 | @Registry.register("data.eurosat", "class") 34 | class EurosatData(base.ImageTfdsData): 35 | """Provides EuroSat dataset. 
36 | 37 | EuroSAT dataset is based on Sentinel-2 satellite images covering 13 spectral 38 | bands and consisting of 10 classes with 27000 labeled and 39 | geo-referenced samples. 40 | 41 | URL: https://github.com/phelber/eurosat 42 | """ 43 | 44 | def __init__(self, subset="rgb", data_key="image", data_dir=None): 45 | dataset_name = "eurosat/{}:2.*.*".format(subset) 46 | dataset_builder = tfds.builder(dataset_name, data_dir=data_dir) 47 | dataset_builder.download_and_prepare() 48 | 49 | # Example counts are retrieved from the tensorflow dataset info. 50 | num_examples = dataset_builder.info.splits[tfds.Split.TRAIN].num_examples 51 | train_count = num_examples * TRAIN_SPLIT_PERCENT // 100 52 | val_count = num_examples * VALIDATION_SPLIT_PERCENT // 100 53 | test_count = num_examples * TEST_SPLIT_PERCENT // 100 54 | 55 | tfds_splits = { 56 | "train": 57 | "train[:{}]".format(train_count), 58 | "val": 59 | "train[{}:{}]".format(train_count, train_count+val_count), 60 | "trainval": 61 | "train[:{}]".format(train_count+val_count), 62 | "test": 63 | "train[{}:]".format(train_count+val_count), 64 | "train800": 65 | "train[:800]", 66 | "val200": 67 | "train[{}:{}]".format(train_count, train_count+200), 68 | "train800val200": 69 | "train[:800]+train[{}:{}]".format(train_count, train_count+200), 70 | } 71 | 72 | # Creates a dict with example counts for each split. 73 | num_samples_splits = { 74 | "train": train_count, 75 | "val": val_count, 76 | "trainval": train_count + val_count, 77 | "test": test_count, 78 | "train800": 800, 79 | "val200": 200, 80 | "train800val200": 1000, 81 | } 82 | 83 | num_channels = 3 84 | if data_key == "sentinel2": 85 | num_channels = 13 86 | 87 | super(EurosatData, self).__init__( 88 | dataset_builder=dataset_builder, 89 | tfds_splits=tfds_splits, 90 | num_samples_splits=num_samples_splits, 91 | num_preprocessing_threads=100, 92 | shuffle_buffer_size=10000, 93 | base_preprocess_fn=base.make_get_and_cast_tensors_fn({ 94 | data_key: ("image", None), 95 | "label": ("label", None), 96 | }), 97 | image_key=data_key, 98 | num_channels=num_channels, 99 | num_classes=dataset_builder.info.features["label"].num_classes) 100 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/kitti.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements Kitti data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import numpy as np 24 | 25 | import tensorflow.compat.v1 as tf 26 | import tensorflow_datasets as tfds 27 | 28 | from . 
import base as base 29 | from .registry import Registry 30 | 31 | 32 | def _count_all_pp(x): 33 | """Count all objects.""" 34 | # Count distribution (thresholded at 15): 35 | 36 | label = tf.math.minimum(tf.size(x["objects"]["type"]) - 1, 8) 37 | return {"image": x["image"], "label": label} 38 | 39 | 40 | def _count_vehicles_pp(x): 41 | """Counting vehicles.""" 42 | # Label distribution: 43 | 44 | vehicles = tf.where(x["objects"]["type"] < 3) # Car, Van, Truck. 45 | # Cap at 3. 46 | label = tf.math.minimum(tf.size(vehicles), 3) 47 | return {"image": x["image"], "label": label} 48 | 49 | 50 | def _count_left_pp(x): 51 | """Count objects on the left hand side of the camera.""" 52 | # Count distribution (thresholded at 15): 53 | 54 | # Location feature contains (x, y, z) in meters w.r.t. the camera. 55 | objects_on_left = tf.where(x["objects"]["location"][:, 0] < 0) 56 | label = tf.math.minimum(tf.size(objects_on_left), 8) 57 | return {"image": x["image"], "label": label} 58 | 59 | 60 | def _count_far_pp(x): 61 | """Counts objects far from the camera.""" 62 | # Threshold removes ~half of the objects. 63 | # Count distribution (thresholded at 15): 64 | 65 | # Location feature contains (x, y, z) in meters w.r.t. the camera. 66 | distant_objects = tf.where(x["objects"]["location"][:, 2] >= 25) 67 | label = tf.math.minimum(tf.size(distant_objects), 8) 68 | return {"image": x["image"], "label": label} 69 | 70 | 71 | def _count_near_pp(x): 72 | """Counts objects close to the camera.""" 73 | # Threshold removes ~half of the objects. 74 | # Count distribution: 75 | 76 | # Location feature contains (x, y, z) in meters w.r.t. the camera. 77 | close_objects = tf.where(x["objects"]["location"][:, 2] < 25) 78 | label = tf.math.minimum(tf.size(close_objects), 8) 79 | return {"image": x["image"], "label": label} 80 | 81 | 82 | def _closest_object_distance_pp(x): 83 | """Predict the distance to the closest object.""" 84 | # Label distribution: 85 | 86 | # Location feature contains (x, y, z) in meters w.r.t. the camera. 87 | dist = tf.reduce_min(x["objects"]["location"][:, 2]) 88 | thrs = np.array([-100, 5.6, 8.4, 13.4, 23.4]) 89 | label = tf.reduce_max(tf.where((thrs - dist) < 0)) 90 | return {"image": x["image"], "label": label} 91 | 92 | 93 | def _closest_vehicle_distance_pp(x): 94 | """Predict the distance to the closest vehicle.""" 95 | # Label distribution: 96 | 97 | # Location feature contains (x, y, z) in meters w.r.t. the camera. 98 | vehicles = tf.where(x["objects"]["type"] < 3) # Car, Van, Truck. 99 | vehicle_z = tf.gather(params=x["objects"]["location"][:, 2], indices=vehicles) 100 | vehicle_z = tf.concat([vehicle_z, tf.constant([[1000.0]])], axis=0) 101 | dist = tf.reduce_min(vehicle_z) 102 | # Results in a uniform distribution over three distances, plus one class for 103 | # "no vehicle". 104 | thrs = np.array([-100.0, 8.0, 20.0, 999.0]) 105 | label = tf.reduce_max(tf.where((thrs - dist) < 0)) 106 | return {"image": x["image"], "label": label} 107 | 108 | 109 | def _closest_object_x_location_pp(x): 110 | """Predict the absolute x position of the closest object.""" 111 | # Label distribution: 112 | 113 | # Location feature contains (x, y, z) in meters w.r.t. the camera. 
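  # How the thresholding below bins a continuous value into a class label, with a
  # made-up example: for xloc = 1.2 and thrs = [-100, -6.4, -3.5, 0.0, 3.3, 23.9],
  # (thrs - xloc) < 0 holds at indices 0..3, tf.where returns those indices, and
  # tf.reduce_max keeps the largest one, so label = 3. The same pattern is used by
  # the other KITTI distance tasks in this file and by the CLEVR tasks above.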
114 | idx = tf.math.argmin(x["objects"]["location"][:, 2]) 115 | xloc = x["objects"]["location"][idx, 0] 116 | thrs = np.array([-100, -6.4, -3.5, 0.0, 3.3, 23.9]) 117 | label = tf.reduce_max(tf.where((thrs - xloc) < 0)) 118 | return {"image": x["image"], "label": label} 119 | 120 | 121 | _TASK_DICT = { 122 | "count_all": { 123 | "preprocess_fn": _count_all_pp, 124 | "num_classes": 16, 125 | }, 126 | "count_left": { 127 | "preprocess_fn": _count_left_pp, 128 | "num_classes": 16, 129 | }, 130 | "count_far": { 131 | "preprocess_fn": _count_far_pp, 132 | "num_classes": 16, 133 | }, 134 | "count_near": { 135 | "preprocess_fn": _count_near_pp, 136 | "num_classes": 16, 137 | }, 138 | "closest_object_distance": { 139 | "preprocess_fn": _closest_object_distance_pp, 140 | "num_classes": 5, 141 | }, 142 | "closest_object_x_location": { 143 | "preprocess_fn": _closest_object_x_location_pp, 144 | "num_classes": 5, 145 | }, 146 | "count_vehicles": { 147 | "preprocess_fn": _count_vehicles_pp, 148 | "num_classes": 4, 149 | }, 150 | "closest_vehicle_distance": { 151 | "preprocess_fn": _closest_vehicle_distance_pp, 152 | "num_classes": 4, 153 | }, 154 | } 155 | 156 | 157 | @Registry.register("data.kitti", "class") 158 | class KittiData(base.ImageTfdsData): 159 | """Provides Kitti dataset. 160 | 161 | Six tasks are supported: 162 | 1. Count the number of objects. 163 | 2. Count the number of objects on the left hand side of the camera. 164 | 3. Count the number of objects in the foreground. 165 | 4. Count the number of objects in the background. 166 | 5. Predict the distance of the closest object. 167 | 6. Predict the x-location (w.r.t. the camera) of the closest object. 168 | """ 169 | 170 | def __init__(self, task, data_dir=None): 171 | 172 | if task not in _TASK_DICT: 173 | raise ValueError("Unknown task: %s" % task) 174 | 175 | dataset_builder = tfds.builder("kitti:3.2.0", data_dir=data_dir) 176 | dataset_builder.download_and_prepare() 177 | 178 | tfds_splits = { 179 | "train": "train", 180 | "val": "validation", 181 | "trainval": "train+validation", 182 | "test": "test", 183 | "train800": "train[:800]", 184 | "val200": "validation[:200]", 185 | "train800val200": "train[:800]+validation[:200]", 186 | } 187 | 188 | # Example counts are retrieved from the tensorflow dataset info. 189 | train_count = dataset_builder.info.splits[tfds.Split.TRAIN].num_examples 190 | val_count = dataset_builder.info.splits[tfds.Split.VALIDATION].num_examples 191 | test_count = dataset_builder.info.splits[tfds.Split.TEST].num_examples 192 | # Creates a dict with example counts for each split. 193 | num_samples_splits = { 194 | "train": train_count, 195 | "val": val_count, 196 | "trainval": train_count + val_count, 197 | "test": test_count, 198 | "train800": 800, 199 | "val200": 200, 200 | "train800val200": 1000, 201 | } 202 | 203 | task = _TASK_DICT[task] 204 | base_preprocess_fn = task["preprocess_fn"] 205 | super(KittiData, self).__init__( 206 | dataset_builder=dataset_builder, 207 | tfds_splits=tfds_splits, 208 | num_samples_splits=num_samples_splits, 209 | num_preprocessing_threads=400, 210 | shuffle_buffer_size=10000, 211 | base_preprocess_fn=base_preprocess_fn, 212 | num_classes=task["num_classes"]) 213 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/oxford_flowers102.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 
4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements oxford flowers 102 data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow_datasets as tfds 24 | 25 | from . import base as base 26 | from .registry import Registry 27 | 28 | 29 | @Registry.register("data.oxford_flowers102", "class") 30 | class OxfordFlowers102Data(base.ImageTfdsData): 31 | """Provides Oxford 102 categories flowers dataset. 32 | 33 | See corresponding tfds dataset for details. 34 | 35 | URL: https://www.robots.ox.ac.uk/~vgg/data/flowers/102/ 36 | """ 37 | 38 | def __init__(self, data_dir=None, train_split_percent=None): 39 | dataset_builder = tfds.builder("oxford_flowers102:2.*.*", data_dir=data_dir) 40 | dataset_builder.download_and_prepare() 41 | 42 | # Example counts are retrieved from the tensorflow dataset info. 43 | train_count = dataset_builder.info.splits[tfds.Split.TRAIN].num_examples 44 | val_count = dataset_builder.info.splits[tfds.Split.VALIDATION].num_examples 45 | test_count = dataset_builder.info.splits[tfds.Split.TEST].num_examples 46 | 47 | if train_split_percent: 48 | tfds_splits = { 49 | "train": "train[:{s}%]+validation[:{s}%]".format( 50 | s=train_split_percent), 51 | "val": "train[-{s}%:]+validation[-{s}%:]".format( 52 | s=train_split_percent), 53 | "trainval": "train+validation", 54 | "test": "test", 55 | "train800": "train[:800]", 56 | "val200": "validation[:200]", 57 | "train800val200": "train[:800]+validation[:200]", 58 | } 59 | num_samples_splits = { 60 | "train": (((train_count + val_count) // 100) 61 | * train_split_percent), 62 | "val": (((train_count + val_count) // 100) * 63 | (100 - train_split_percent)), 64 | "trainval": train_count + val_count, 65 | "test": test_count, 66 | "train800": 800, 67 | "val200": 200, 68 | "train800val200": 1000, 69 | } 70 | else: 71 | tfds_splits = { 72 | "train": "train", 73 | "val": "validation", 74 | "trainval": "train+validation", 75 | "test": "test", 76 | "train800": "train[:800]", 77 | "val200": "validation[:200]", 78 | "train800val200": "train[:800]+validation[:200]", 79 | } 80 | num_samples_splits = { 81 | "train": train_count, 82 | "val": val_count, 83 | "trainval": train_count + val_count, 84 | "test": test_count, 85 | "train800": 800, 86 | "val200": 200, 87 | "train800val200": 1000, 88 | } 89 | 90 | super(OxfordFlowers102Data, self).__init__( 91 | dataset_builder=dataset_builder, 92 | tfds_splits=tfds_splits, 93 | num_samples_splits=num_samples_splits, 94 | num_preprocessing_threads=400, 95 | shuffle_buffer_size=10000, 96 | # Note: Rename tensors but keep their original types. 
97 | base_preprocess_fn=base.make_get_and_cast_tensors_fn({ 98 | "image": ("image", None), 99 | "label": ("label", None), 100 | }), 101 | num_classes=dataset_builder.info.features["label"] 102 | .num_classes) 103 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/oxford_iiit_pet.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements OxfordIIITPet data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow_datasets as tfds 24 | 25 | from . import base as base 26 | from .registry import Registry 27 | # This constant specifies the percentage of data that is used to create custom 28 | # train/val splits. Specifically, TRAIN_SPLIT_PERCENT% of the official training 29 | # split is used as a new training split and the rest is used for validation. 30 | TRAIN_SPLIT_PERCENT = 80 31 | 32 | 33 | @Registry.register("data.oxford_iiit_pet", "class") 34 | class OxfordIIITPetData(base.ImageTfdsData): 35 | """Provides OxfordIIITPet data. 36 | 37 | The OxfordIIITPet dataset comes only with a training and test set. 38 | Therefore, the validation set is split out of the original training set, and 39 | the remaining examples are used as the "train" split. The "trainval" split 40 | corresponds to the original training set. 41 | 42 | For additional details and usage, see the base class. 43 | """ 44 | 45 | def __init__(self, data_dir=None, train_split_percent=None): 46 | 47 | dataset_builder = tfds.builder("oxford_iiit_pet:3.*.*", data_dir=data_dir) 48 | dataset_builder.download_and_prepare() 49 | train_split_percent = train_split_percent or TRAIN_SPLIT_PERCENT 50 | 51 | # Creates a dict with example counts for each split. 52 | trainval_count = dataset_builder.info.splits[tfds.Split.TRAIN].num_examples 53 | test_count = dataset_builder.info.splits[tfds.Split.TEST].num_examples 54 | num_samples_splits = { 55 | "train": (train_split_percent * trainval_count) // 100, 56 | "val": trainval_count - (train_split_percent * trainval_count) // 100, 57 | "trainval": trainval_count, 58 | "test": test_count, 59 | "train800": 800, 60 | "val200": 200, 61 | "train800val200": 1000, 62 | } 63 | 64 | # Defines dataset specific train/val/trainval/test splits. 
65 | tfds_splits = { 66 | "train": "train[:{}]".format(num_samples_splits["train"]), 67 | "val": "train[{}:]".format(num_samples_splits["train"]), 68 | "trainval": tfds.Split.TRAIN, 69 | "test": tfds.Split.TEST, 70 | "train800": "train[:800]", 71 | "val200": "train[{}:{}]".format( 72 | num_samples_splits["train"], num_samples_splits["train"]+200), 73 | "train800val200": "train[:800]+train[{}:{}]".format( 74 | num_samples_splits["train"], num_samples_splits["train"]+200), 75 | } 76 | 77 | super(OxfordIIITPetData, self).__init__( 78 | dataset_builder=dataset_builder, 79 | tfds_splits=tfds_splits, 80 | num_samples_splits=num_samples_splits, 81 | num_preprocessing_threads=400, 82 | shuffle_buffer_size=10000, 83 | # Note: Export only image and label tensors with their original types. 84 | base_preprocess_fn=base.make_get_tensors_fn(["image", "label"]), 85 | num_classes=dataset_builder.info.features["label"].num_classes) 86 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/patch_camelyon.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements PatchCamelyon data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow_datasets as tfds 24 | 25 | from . import base as base 26 | from .registry import Registry 27 | 28 | 29 | @Registry.register("data.patch_camelyon", "class") 30 | class PatchCamelyonData(base.ImageTfdsData): 31 | """Provides PatchCamelyon data.""" 32 | 33 | def __init__(self, data_dir=None): 34 | 35 | dataset_builder = tfds.builder("patch_camelyon:2.*.*", data_dir=data_dir) 36 | dataset_builder.download_and_prepare() 37 | 38 | # Defines dataset specific train/val/trainval/test splits. 39 | tfds_splits = { 40 | "test": "test", 41 | "train": "train", 42 | "val": "validation", 43 | "trainval": "train+validation", 44 | "train800": "train[:800]", 45 | "val200": "validation[:200]", 46 | "train800val200": "train[:800]+validation[:200]", 47 | } 48 | # Creates a dict with example counts. 49 | num_samples_splits = { 50 | "test": dataset_builder.info.splits["test"].num_examples, 51 | "train": dataset_builder.info.splits["train"].num_examples, 52 | "val": dataset_builder.info.splits["validation"].num_examples, 53 | "train800": 800, 54 | "val200": 200, 55 | "train800val200": 1000, 56 | } 57 | num_samples_splits["trainval"] = ( 58 | num_samples_splits["train"] + num_samples_splits["val"]) 59 | super(PatchCamelyonData, self).__init__( 60 | dataset_builder=dataset_builder, 61 | tfds_splits=tfds_splits, 62 | num_samples_splits=num_samples_splits, 63 | num_preprocessing_threads=400, 64 | shuffle_buffer_size=10000, 65 | # Note: Export only image and label tensors with their original types. 
66 | base_preprocess_fn=base.make_get_tensors_fn(["image", "label"]), 67 | num_classes=dataset_builder.info.features["label"].num_classes) 68 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/registry.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Global Registry for the task adaptation framework. 18 | """ 19 | 20 | from __future__ import absolute_import 21 | from __future__ import division 22 | from __future__ import print_function 23 | 24 | import ast 25 | import functools 26 | 27 | 28 | def partialclass(cls, *base_args, **base_kwargs): 29 | """Builds a subclass with partial application of the given args and keywords. 30 | 31 | Equivalent to functools.partial performance, base_args are preprended to the 32 | positional arguments given during object initialization and base_kwargs are 33 | updated with the kwargs given later. 34 | 35 | Args: 36 | cls: The base class. 37 | *base_args: Positional arguments to be applied to the subclass. 38 | **base_kwargs: Keyword arguments to be applied to the subclass. 39 | 40 | Returns: 41 | A subclass of the input class. 42 | """ 43 | 44 | class _NewClass(cls): 45 | 46 | def __init__(self, *args, **kwargs): 47 | bound_args = base_args + args 48 | bound_kwargs = base_kwargs.copy() 49 | bound_kwargs.update(kwargs) 50 | super(_NewClass, self).__init__(*bound_args, **bound_kwargs) 51 | 52 | return _NewClass 53 | 54 | 55 | def parse_name(string_to_parse): 56 | """Parses input to the registry's lookup function. 57 | 58 | Args: 59 | string_to_parse: can be either an arbitrary name or function call 60 | (optionally with positional and keyword arguments). 61 | e.g. "multiclass", "resnet50_v2(filters_factor=8)". 62 | 63 | Returns: 64 | A tuple of input name and a dctinary with arguments. Examples: 65 | "multiclass" -> ("multiclass", (), {}) 66 | "resnet50_v2(9, filters_factor=4)" -> 67 | ("resnet50_v2", (9,), {"filters_factor": 4}) 68 | """ 69 | expr = ast.parse(string_to_parse, mode="eval").body # pytype: disable=attribute-error 70 | if not isinstance(expr, (ast.Attribute, ast.Call, ast.Name)): 71 | raise ValueError( 72 | "The given string should be a name or a call, but a {} was parsed from " 73 | "the string {!r}".format(type(expr), string_to_parse)) 74 | 75 | # Notes: 76 | # name="some_name" -> type(expr) = ast.Name 77 | # name="module.some_name" -> type(expr) = ast.Attribute 78 | # name="some_name()" -> type(expr) = ast.Call 79 | # name="module.some_name()" -> type(expr) = ast.Call 80 | 81 | if isinstance(expr, ast.Name): 82 | return string_to_parse, {} 83 | elif isinstance(expr, ast.Attribute): 84 | return string_to_parse, {} 85 | 86 | def _get_func_name(expr): 87 | if isinstance(expr, ast.Attribute): 88 | return _get_func_name(expr.value) + "." 
+ expr.attr 89 | elif isinstance(expr, ast.Name): 90 | return expr.id 91 | else: 92 | raise ValueError( 93 | "Type {!r} is not supported in a function name, the string to parse " 94 | "was {!r}".format(type(expr), string_to_parse)) 95 | 96 | def _get_func_args_and_kwargs(call): 97 | args = tuple([ast.literal_eval(arg) for arg in call.args]) 98 | kwargs = { 99 | kwarg.arg: ast.literal_eval(kwarg.value) for kwarg in call.keywords 100 | } 101 | return args, kwargs 102 | 103 | func_name = _get_func_name(expr.func) 104 | func_args, func_kwargs = _get_func_args_and_kwargs(expr) 105 | if func_args: 106 | raise ValueError("Positional arguments are not supported here, but these " 107 | "were found: {!r}".format(func_args)) 108 | 109 | return func_name, func_kwargs 110 | 111 | 112 | class Registry(object): 113 | """Implements global Registry.""" 114 | 115 | _GLOBAL_REGISTRY = {} 116 | 117 | @staticmethod 118 | def global_registry(): 119 | return Registry._GLOBAL_REGISTRY 120 | 121 | @staticmethod 122 | def register(name, item_type): 123 | """Creates a function that registers its input.""" 124 | if item_type not in ["function", "class"]: 125 | raise ValueError("Unknown item type: %s" % item_type) 126 | 127 | def _register(item): 128 | if name in Registry.global_registry(): 129 | raise KeyError( 130 | "The name {!r} was already registered in with type {!r}".format( 131 | name, item_type)) 132 | 133 | Registry.global_registry()[name] = (item, item_type) 134 | return item 135 | 136 | return _register 137 | 138 | @staticmethod 139 | def lookup(lookup_string, kwargs_extra=None): 140 | """Lookup a name in the registry.""" 141 | 142 | name, kwargs = parse_name(lookup_string) 143 | if kwargs_extra: 144 | kwargs.update(kwargs_extra) 145 | item, item_type = Registry.global_registry()[name] 146 | if item_type == "function": 147 | return functools.partial(item, **kwargs) 148 | elif item_type == "class": 149 | return partialclass(item, **kwargs) 150 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/resisc45.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements RESISC-45 data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow_datasets as tfds 24 | 25 | from . import base as base 26 | from .registry import Registry 27 | TRAIN_SPLIT_PERCENT = 60 28 | VALIDATION_SPLIT_PERCENT = 20 29 | TEST_SPLIT_PERCENT = 20 30 | 31 | 32 | @Registry.register("data.resisc45", "class") 33 | class Resisc45Data(base.ImageTfdsData): 34 | """Provides RESISC-45 dataset. 
35 | 36 | RESISC45 dataset is a publicly available benchmark for Remote Sensing Image 37 | Scene Classification (RESISC), created by Northwestern Polytechnical 38 | University (NWPU). This dataset contains 31,500 images, covering 45 scene 39 | classes with 700 images in each class. 40 | 41 | URL: http://www.escience.cn/people/JunweiHan/NWPU-RESISC45.html 42 | """ 43 | 44 | def __init__(self, data_dir=None): 45 | dataset_builder = tfds.builder("resisc45:3.*.*", data_dir=data_dir) 46 | dataset_builder.download_and_prepare() 47 | 48 | # Example counts are retrieved from the tensorflow dataset info. 49 | num_examples = dataset_builder.info.splits["train"].num_examples 50 | train_count = num_examples * TRAIN_SPLIT_PERCENT // 100 51 | val_count = num_examples * VALIDATION_SPLIT_PERCENT // 100 52 | test_count = num_examples * TEST_SPLIT_PERCENT // 100 53 | 54 | tfds_splits = { 55 | "train": 56 | "train[:{}]".format(train_count), 57 | "val": 58 | "train[{}:{}]".format(train_count, train_count + val_count), 59 | "trainval": 60 | "train[:{}]".format(train_count + val_count), 61 | "test": 62 | "train[{}:]".format(train_count + val_count), 63 | "train800": 64 | "train[:800]", 65 | "val200": 66 | "train[{}:{}]".format(train_count, train_count+200), 67 | "train800val200": 68 | "train[:800]+train[{}:{}]".format(train_count, train_count+200), 69 | } 70 | 71 | # Creates a dict with example counts for each split. 72 | num_samples_splits = { 73 | "train": train_count, 74 | "val": val_count, 75 | "trainval": train_count + val_count, 76 | "test": test_count, 77 | "train800": 800, 78 | "val200": 200, 79 | "train800val200": 1000, 80 | } 81 | 82 | super(Resisc45Data, self).__init__( 83 | dataset_builder=dataset_builder, 84 | tfds_splits=tfds_splits, 85 | num_samples_splits=num_samples_splits, 86 | num_preprocessing_threads=400, 87 | shuffle_buffer_size=10000, 88 | # Note: Rename tensors but keep their original types. 89 | base_preprocess_fn=base.make_get_and_cast_tensors_fn({ 90 | "image": ("image", None), 91 | "label": ("label", None), 92 | }), 93 | num_classes=dataset_builder.info.features["label"].num_classes) 94 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/smallnorb.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements the SmallNORB data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow.compat.v1 as tf 24 | import tensorflow_datasets as tfds 25 | 26 | from . import base as base 27 | from .registry import Registry 28 | # This constant specifies the percentage of data that is used to create custom 29 | # val/test splits. 
Specifically, VAL_SPLIT_PERCENT% of the official testing 30 | # split is used as a new validation split and the rest is used for testing. 31 | VAL_SPLIT_PERCENT = 50 32 | 33 | 34 | @Registry.register("data.smallnorb", "class") 35 | class SmallNORBData(base.ImageTfdsData): 36 | """Provides the SmallNORB data set. 37 | 38 | SmallNORB comes only with a training and test set. Therefore, the validation 39 | set is split out of the original training set, and the remaining examples are 40 | used as the "train" split. The "trainval" split corresponds to the original 41 | training set. 42 | 43 | For additional details and usage, see the base class. 44 | 45 | The data set page is https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/. 46 | """ 47 | 48 | def __init__(self, predicted_attribute, data_dir=None): 49 | dataset_builder = tfds.builder("smallnorb:2.*.*", data_dir=data_dir) 50 | dataset_builder.download_and_prepare() 51 | 52 | if predicted_attribute not in dataset_builder.info.features: 53 | raise ValueError( 54 | "{} is not a valid attribute to predict.".format(predicted_attribute)) 55 | 56 | # Defines dataset specific train/val/trainval/test splits. 57 | tfds_splits = { 58 | "train": "train", 59 | "val": "test[:{}%]".format(VAL_SPLIT_PERCENT), 60 | "trainval": "train+test[:{}%]".format(VAL_SPLIT_PERCENT), 61 | "test": "test[{}%:]".format(VAL_SPLIT_PERCENT), 62 | "train800": "train[:800]", 63 | "val200": "test[:200]", 64 | "train800val200": "train[:800]+test[:200]", 65 | } 66 | 67 | # Creates a dict with example counts for each split. 68 | train_count = dataset_builder.info.splits["train"].num_examples 69 | test_count = dataset_builder.info.splits["test"].num_examples 70 | num_samples_validation = VAL_SPLIT_PERCENT * test_count // 100 71 | num_samples_splits = { 72 | "train": train_count, 73 | "val": num_samples_validation, 74 | "trainval": train_count + num_samples_validation, 75 | "test": test_count - num_samples_validation, 76 | "train800": 800, 77 | "val200": 200, 78 | "train800val200": 1000, 79 | } 80 | 81 | def preprocess_fn(tensors): 82 | # For consistency with other datasets, image needs to have three channels. 83 | image = tf.tile(tensors["image"], [1, 1, 3]) 84 | return dict(image=image, label=tensors[predicted_attribute]) 85 | 86 | info = dataset_builder.info 87 | super(SmallNORBData, self).__init__( 88 | dataset_builder=dataset_builder, 89 | tfds_splits=tfds_splits, 90 | num_samples_splits=num_samples_splits, 91 | num_preprocessing_threads=400, 92 | shuffle_buffer_size=10000, 93 | # We extract the attribute we want to predict in the preprocessing. 94 | base_preprocess_fn=preprocess_fn, 95 | num_classes=info.features[predicted_attribute].num_classes) 96 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/sun397.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements Sun397 data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow_datasets as tfds 24 | 25 | from . import base as base 26 | from .registry import Registry 27 | CUSTOM_TRAIN_SPLIT_PERCENT = 50 28 | CUSTOM_VALIDATION_SPLIT_PERCENT = 20 29 | CUSTOM_TEST_SPLIT_PERCENT = 30 30 | 31 | 32 | @Registry.register("data.sun397", "class") 33 | class Sun397Data(base.ImageTfdsData): 34 | """Provides Sun397Data data.""" 35 | 36 | def __init__(self, config="tfds", data_dir=None): 37 | 38 | if config == "tfds": 39 | dataset_builder = tfds.builder("sun397/tfds:4.*.*", data_dir=data_dir) 40 | dataset_builder.download_and_prepare() 41 | 42 | tfds_splits = { 43 | "train": "train", 44 | "val": "validation", 45 | "test": "test", 46 | "trainval": "train+validation", 47 | "train800": "train[:800]", 48 | "val200": "validation[:200]", 49 | "train800val200": "train[:800]+validation[:200]", 50 | } 51 | # Creates a dict with example counts. 52 | num_samples_splits = { 53 | "test": dataset_builder.info.splits["test"].num_examples, 54 | "train": dataset_builder.info.splits["train"].num_examples, 55 | "val": dataset_builder.info.splits["validation"].num_examples, 56 | "train800": 800, 57 | "val200": 200, 58 | "train800val200": 1000, 59 | } 60 | num_samples_splits["trainval"] = ( 61 | num_samples_splits["train"] + num_samples_splits["val"]) 62 | else: 63 | 64 | raise ValueError("No supported config %r for Sun397Data." % config) 65 | 66 | super(Sun397Data, self).__init__( 67 | dataset_builder=dataset_builder, 68 | tfds_splits=tfds_splits, 69 | num_samples_splits=num_samples_splits, 70 | num_preprocessing_threads=400, 71 | shuffle_buffer_size=10000, 72 | # Note: Export only image and label tensors with their original types. 73 | base_preprocess_fn=base.make_get_tensors_fn(["image", "label"]), 74 | num_classes=dataset_builder.info.features["label"].num_classes) 75 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/svhn.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements Svhn data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow_datasets as tfds 24 | 25 | from . import base as base 26 | from .registry import Registry 27 | # This constant specifies the percentage of data that is used to create custom 28 | # train/val splits. 
Specifically, TRAIN_SPLIT_PERCENT% of the official training 29 | # split is used as a new training split and the rest is used for validation. 30 | TRAIN_SPLIT_PERCENT = 90 31 | 32 | 33 | @Registry.register("data.svhn", "class") 34 | class SvhnData(base.ImageTfdsData): 35 | """Provides SVHN data. 36 | 37 | The Street View House Numbers (SVHN) Dataset is an image digit recognition 38 | dataset of over 600,000 color digit images coming from real world data. 39 | Split size: 40 | - Training set: 73,257 images 41 | - Testing set: 26,032 images 42 | - Extra training set: 531,131 images 43 | Following the common setup on SVHN, we only use the official training and 44 | testing data. Images are cropped to 32x32. 45 | 46 | URL: http://ufldl.stanford.edu/housenumbers/ 47 | """ 48 | 49 | def __init__(self, data_dir=None): 50 | dataset_builder = tfds.builder("svhn_cropped:3.*.*", data_dir=data_dir) 51 | dataset_builder.download_and_prepare() 52 | 53 | # Example counts are retrieved from the tensorflow dataset info. 54 | trainval_count = dataset_builder.info.splits[tfds.Split.TRAIN].num_examples 55 | test_count = dataset_builder.info.splits[tfds.Split.TEST].num_examples 56 | 57 | # Creates a dict with example counts for each split. 58 | num_samples_splits = { 59 | # Calculates the train/val split example count based on percent. 60 | "train": TRAIN_SPLIT_PERCENT * trainval_count // 100, 61 | "val": trainval_count - TRAIN_SPLIT_PERCENT * trainval_count // 100, 62 | "trainval": trainval_count, 63 | "test": test_count, 64 | "train800": 800, 65 | "val200": 200, 66 | "train800val200": 1000, 67 | } 68 | 69 | # Defines dataset specific train/val/trainval/test splits. 70 | # The validation set is split out of the original training set, and the 71 | # remaining examples are used as the "train" split. The "trainval" split 72 | # corresponds to the original training set. 73 | tfds_splits = { 74 | "train": 75 | "train[:{}]".format(num_samples_splits["train"]), 76 | "val": 77 | "train[{}:]".format(num_samples_splits["train"]), 78 | "trainval": 79 | "train", 80 | "test": 81 | "test", 82 | "train800": 83 | "train[:800]", 84 | "val200": 85 | "train[{}:{}]".format(num_samples_splits["train"], 86 | num_samples_splits["train"] + 200), 87 | "train800val200": 88 | "train[:800]+train[{}:{}]".format( 89 | num_samples_splits["train"], num_samples_splits["train"] + 200), 90 | } 91 | 92 | super(SvhnData, self).__init__( 93 | dataset_builder=dataset_builder, 94 | tfds_splits=tfds_splits, 95 | num_samples_splits=num_samples_splits, 96 | num_preprocessing_threads=400, 97 | shuffle_buffer_size=10000, 98 | # Note: Rename tensors but keep their original types. 
99 | base_preprocess_fn=base.make_get_and_cast_tensors_fn({ 100 | "image": ("image", None), 101 | "label": ("label", None), 102 | }), 103 | num_classes=dataset_builder.info.features["label"] 104 | .num_classes) 105 | -------------------------------------------------------------------------------- /src/engine/eval/multilabel.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | evaluate precision@1, @5 equal to Top1 and Top5 error rate 4 | """ 5 | import numpy as np 6 | from typing import List, Tuple, Dict 7 | from sklearn.metrics import ( 8 | precision_recall_curve, 9 | average_precision_score, 10 | f1_score 11 | ) 12 | 13 | 14 | def get_continuous_ids(probe_labels: List[int]) -> Dict[int, int]: 15 | sorted(probe_labels) 16 | id2continuousid = {} 17 | for idx, p_id in enumerate(probe_labels): 18 | id2continuousid[p_id] = idx 19 | return id2continuousid 20 | 21 | 22 | def multihot(x: List[List[int]], nb_classes: int) -> np.ndarray: 23 | """transform to multihot encoding 24 | 25 | Arguments: 26 | x: list of multi-class integer labels, in the range 27 | [0, nb_classes-1] 28 | nb_classes: number of classes for the multi-hot vector 29 | 30 | Returns: 31 | multihot: multihot vector of type int, (num_samples, nb_classes) 32 | """ 33 | num_samples = len(x) 34 | 35 | multihot = np.zeros((num_samples, nb_classes), dtype=np.int32) 36 | for idx, labs in enumerate(x): 37 | for lab in labs: 38 | multihot[idx, lab] = 1 39 | 40 | return multihot.astype(np.int) 41 | 42 | 43 | def compute_map( 44 | scores: np.ndarray, multihot_targets: np.ndarray 45 | ) -> Tuple[np.ndarray, np.ndarray, float, float]: 46 | """Compute the mean average precision across all class labels. 47 | 48 | Arguments: 49 | scores: matrix of per-class distances, 50 | of size num_samples x nb_classes 51 | multihot_targets: matrix of multi-hot target predictions, 52 | of size num_samples x nb_classes 53 | 54 | Returns: 55 | ap: list of average-precision scores, one for each of 56 | the nb_classes classes. 57 | ar: list of average-recall scores, one for each of 58 | the nb_classes classes. 59 | mAP: the mean average precision score over all average 60 | precisions for all nb_classes classes. 61 | mAR: the mean average recall score over all average 62 | precisions for all nb_classes classes. 
63 | """ 64 | nb_classes = scores.shape[1] 65 | 66 | ap = np.zeros((nb_classes,), dtype=np.float32) 67 | ar = np.zeros((nb_classes,), dtype=np.float32) 68 | 69 | for c in range(nb_classes): 70 | y_true = multihot_targets[:, c] 71 | y_scores = scores[:, c] 72 | 73 | # Use interpolated average precision (a la PASCAL 74 | try: 75 | ap[c] = average_precision_score(y_true, y_scores) 76 | except ValueError: 77 | ap[c] = -1 78 | 79 | # Also get the average of the recalls on the raw PR-curve 80 | try: 81 | _, rec, _ = precision_recall_curve(y_true, y_scores) 82 | ar[c] = rec.mean() 83 | except ValueError: 84 | ar[c] = -1 85 | 86 | mAP = ap.mean() 87 | mAR = ar.mean() 88 | 89 | return ap, ar, mAP, mAR 90 | 91 | 92 | def compute_f1( 93 | multihot_targets: np.ndarray, scores: np.ndarray, threshold: float = 0.5 94 | ) -> Tuple[float, float, float]: 95 | # change scores to predict_labels 96 | predict_labels = scores > threshold 97 | predict_labels = predict_labels.astype(np.int) 98 | 99 | # change targets to multihot 100 | f1 = {} 101 | f1["micro"] = f1_score( 102 | y_true=multihot_targets, 103 | y_pred=predict_labels, 104 | average="micro" 105 | ) 106 | f1["samples"] = f1_score( 107 | y_true=multihot_targets, 108 | y_pred=predict_labels, 109 | average="samples" 110 | ) 111 | f1["macro"] = f1_score( 112 | y_true=multihot_targets, 113 | y_pred=predict_labels, 114 | average="macro" 115 | ) 116 | f1["none"] = f1_score( 117 | y_true=multihot_targets, 118 | y_pred=predict_labels, 119 | average=None 120 | ) 121 | return f1["micro"], f1["samples"], f1["macro"], f1["none"] 122 | 123 | 124 | def get_best_f1_scores( 125 | multihot_targets: np.ndarray, scores: np.ndarray, threshold_end: int 126 | ) -> Dict[str, float]: 127 | end = 0.5 128 | end = 0.05 129 | end = threshold_end 130 | thrs = np.linspace( 131 | end, 0.95, int(np.round((0.95 - end) / 0.05)) + 1, endpoint=True 132 | ) 133 | f1_micros = [] 134 | f1_macros = [] 135 | f1_samples = [] 136 | f1_none = [] 137 | for thr in thrs: 138 | _micros, _samples, _macros, _none = compute_f1(multihot_targets, scores, thr) 139 | f1_micros.append(_micros) 140 | f1_samples.append(_samples) 141 | f1_macros.append(_macros) 142 | f1_none.append(_none) 143 | 144 | f1_macros_m = max(f1_macros) 145 | b_thr = np.argmax(f1_macros) 146 | 147 | f1_micros_m = f1_micros[b_thr] 148 | f1_samples_m = f1_samples[b_thr] 149 | f1_none_m = f1_none[b_thr] 150 | f1 = {} 151 | f1["micro"] = f1_micros_m 152 | f1["macro"] = f1_macros_m 153 | f1["samples"] = f1_samples_m 154 | f1["threshold"] = thrs[b_thr] 155 | f1["none"] = f1_none_m 156 | return f1 157 | -------------------------------------------------------------------------------- /src/engine/eval/singlelabel.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """Functions for computing metrics. 
all metrics has range of 0-1""" 4 | 5 | import numpy as np 6 | import torch 7 | from sklearn.metrics import ( 8 | accuracy_score, average_precision_score, f1_score, roc_auc_score 9 | ) 10 | 11 | 12 | def accuracy(y_probs, y_true): 13 | # y_prob: (num_images, num_classes) 14 | y_preds = np.argmax(y_probs, axis=1) 15 | accuracy = accuracy_score(y_true, y_preds) 16 | error = 1.0 - accuracy 17 | return accuracy, error 18 | 19 | 20 | def top_n_accuracy(y_probs, truths, n=1): 21 | # y_prob: (num_images, num_classes) 22 | # truth: (num_images, num_classes) multi/one-hot encoding 23 | best_n = np.argsort(y_probs, axis=1)[:, -n:] 24 | if isinstance(truths, np.ndarray) and truths.shape == y_probs.shape: 25 | ts = np.argmax(truths, axis=1) 26 | else: 27 | # a list of GT class idx 28 | ts = truths 29 | 30 | num_input = y_probs.shape[0] 31 | successes = 0 32 | for i in range(num_input): 33 | if ts[i] in best_n[i, :]: 34 | successes += 1 35 | return float(successes) / num_input 36 | 37 | 38 | def compute_acc_auc(y_probs, y_true_ids): 39 | onehot_tgts = np.zeros_like(y_probs) 40 | for idx, t in enumerate(y_true_ids): 41 | onehot_tgts[idx, t] = 1. 42 | 43 | num_classes = y_probs.shape[1] 44 | if num_classes == 2: 45 | top1, _ = accuracy(y_probs, y_true_ids) 46 | # so precision can set all to 2 47 | try: 48 | auc = roc_auc_score(onehot_tgts, y_probs, average='macro') 49 | except ValueError as e: 50 | print(f"value error encountered {e}, set auc sccore to -1.") 51 | auc = -1 52 | return {"top1": top1, "rocauc": auc} 53 | 54 | top1, _ = accuracy(y_probs, y_true_ids) 55 | k = min([5, num_classes]) # if number of labels < 5, use the total class 56 | top5 = top_n_accuracy(y_probs, y_true_ids, k) 57 | return {"top1": top1, f"top{k}": top5} 58 | 59 | 60 | def topks_correct(preds, labels, ks): 61 | """Computes the number of top-k correct predictions for each k.""" 62 | assert preds.size(0) == labels.size( 63 | 0 64 | ), "Batch dim of predictions and labels must match" 65 | # Find the top max_k predictions for each sample 66 | _top_max_k_vals, top_max_k_inds = torch.topk( 67 | preds, max(ks), dim=1, largest=True, sorted=True 68 | ) 69 | # (batch_size, max_k) -> (max_k, batch_size) 70 | top_max_k_inds = top_max_k_inds.t() 71 | # (batch_size, ) -> (max_k, batch_size) 72 | rep_max_k_labels = labels.view(1, -1).expand_as(top_max_k_inds) 73 | # (i, j) = 1 if top i-th prediction for the j-th sample is correct 74 | top_max_k_correct = top_max_k_inds.eq(rep_max_k_labels) 75 | # Compute the number of topk correct predictions for each k 76 | topks_correct = [ 77 | top_max_k_correct[:k, :].reshape(-1).float().sum() for k in ks 78 | ] 79 | return topks_correct 80 | 81 | 82 | def topk_errors(preds, labels, ks): 83 | """Computes the top-k error for each k.""" 84 | if int(labels.min()) < 0: # has ignore 85 | keep_ids = np.where(labels.cpu() >= 0)[0] 86 | preds = preds[keep_ids, :] 87 | labels = labels[keep_ids] 88 | 89 | num_topks_correct = topks_correct(preds, labels, ks) 90 | return [(1.0 - x / preds.size(0)) for x in num_topks_correct] 91 | 92 | 93 | def topk_accuracies(preds, labels, ks): 94 | """Computes the top-k accuracy for each k.""" 95 | num_topks_correct = topks_correct(preds, labels, ks) 96 | return [(x / preds.size(0)) for x in num_topks_correct] 97 | 98 | -------------------------------------------------------------------------------- /src/engine/evaluator.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import numpy as np 3 | 4 | from collections 
import defaultdict 5 | from typing import List, Union 6 | 7 | from .eval import multilabel 8 | from .eval import singlelabel 9 | from ..utils import logging 10 | logger = logging.get_logger("visual_prompt") 11 | 12 | 13 | class Evaluator(): 14 | """ 15 | An evaluator with below logics: 16 | 17 | 1. find which eval module to use. 18 | 2. store the eval results, pretty print it in log file as well. 19 | """ 20 | 21 | def __init__( 22 | self, 23 | ) -> None: 24 | self.results = defaultdict(dict) 25 | self.iteration = -1 26 | self.threshold_end = 0.5 27 | 28 | def update_iteration(self, iteration: int) -> None: 29 | """update iteration info""" 30 | self.iteration = iteration 31 | 32 | def update_result(self, metric: str, value: Union[float, dict]) -> None: 33 | if self.iteration > -1: 34 | key_name = "epoch_" + str(self.iteration) 35 | else: 36 | key_name = "final" 37 | if isinstance(value, float): 38 | self.results[key_name].update({metric: value}) 39 | else: 40 | if metric in self.results[key_name]: 41 | self.results[key_name][metric].update(value) 42 | else: 43 | self.results[key_name].update({metric: value}) 44 | 45 | def classify(self, probs, targets, test_data, multilabel=False): 46 | """ 47 | Evaluate classification result. 48 | Args: 49 | probs: np.ndarray for num_data x num_class, predicted probabilities 50 | targets: np.ndarray for multilabel, list of integers for single label 51 | test_labels: map test image ids to a list of class labels 52 | """ 53 | if not targets: 54 | raise ValueError( 55 | "When evaluating classification, need at least give targets") 56 | 57 | if multilabel: 58 | self._eval_multilabel(probs, targets, test_data) 59 | else: 60 | self._eval_singlelabel(probs, targets, test_data) 61 | 62 | def _eval_singlelabel( 63 | self, 64 | scores: np.ndarray, 65 | targets: List[int], 66 | eval_type: str 67 | ) -> None: 68 | """ 69 | if number of labels > 2: 70 | top1 and topk (5 by default) accuracy 71 | if number of labels == 2: 72 | top1 and rocauc 73 | """ 74 | acc_dict = singlelabel.compute_acc_auc(scores, targets) 75 | 76 | log_results = { 77 | k: np.around(v * 100, decimals=2) for k, v in acc_dict.items() 78 | } 79 | save_results = acc_dict 80 | 81 | self.log_and_update(log_results, save_results, eval_type) 82 | 83 | def _eval_multilabel( 84 | self, 85 | scores: np.ndarray, 86 | targets: np.ndarray, 87 | eval_type: str 88 | ) -> None: 89 | num_labels = scores.shape[-1] 90 | targets = multilabel.multihot(targets, num_labels) 91 | 92 | log_results = {} 93 | ap, ar, mAP, mAR = multilabel.compute_map(scores, targets) 94 | f1_dict = multilabel.get_best_f1_scores( 95 | targets, scores, self.threshold_end) 96 | 97 | log_results["mAP"] = np.around(mAP * 100, decimals=2) 98 | log_results["mAR"] = np.around(mAR * 100, decimals=2) 99 | log_results.update({ 100 | k: np.around(v * 100, decimals=2) for k, v in f1_dict.items()}) 101 | save_results = { 102 | "ap": ap, "ar": ar, "mAP": mAP, "mAR": mAR, "f1": f1_dict 103 | } 104 | self.log_and_update(log_results, save_results, eval_type) 105 | 106 | def log_and_update(self, log_results, save_results, eval_type): 107 | log_str = "" 108 | for k, result in log_results.items(): 109 | if not isinstance(result, np.ndarray): 110 | log_str += f"{k}: {result:.2f}\t" 111 | else: 112 | log_str += f"{k}: {list(result)}\t" 113 | logger.info(f"Classification results with {eval_type}: {log_str}") 114 | # save everything 115 | self.update_result("classification", {eval_type: save_results}) 116 | 
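
For reference, a minimal usage sketch of the `Evaluator` class above; the array shapes, the random inputs, and the `"val"` tag are illustrative assumptions rather than values from this repository:

```python
# Hypothetical single-label example (not part of the repository code).
import numpy as np

num_samples, num_classes = 8, 5
rng = np.random.default_rng(0)
logits = rng.normal(size=(num_samples, num_classes))
probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)  # row-wise softmax
targets = rng.integers(0, num_classes, size=num_samples).tolist()   # ground-truth class ids

evaluator = Evaluator()
evaluator.update_iteration(0)                                # results keyed as "epoch_0"
evaluator.classify(probs, targets, "val", multilabel=False)  # "val" is only a logging tag here
print(evaluator.results["epoch_0"]["classification"]["val"]) # e.g. {"top1": ..., "top5": ...}
```
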
--------------------------------------------------------------------------------
/src/models/build_h2t_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | from easydict import EasyDict
4 | from functools import partial, reduce
5 | import numpy as np
6 | import torch.nn as nn
7 | from ..utils import logging
8 | from .vit_backbones.h2t_vit import Head2ToeVisionTransformer
9 | from .vit_backbones.h2t_vit_mae import Head2ToeVisionTransformerMAE
10 | from .build_vit_backbone import MODEL_ZOO
11 | from .build_model import load_model_to_device
12 | from .mlp import MLP
13 | from timm.models.vision_transformer import _cfg
14 | 
15 | 
16 | logger = logging.get_logger("visual_prompt")
17 | 
18 | 
19 | def build_h2t_mae_model(model_type, crop_size, h2t_cfg, model_root,
20 |                         load_pretrain=True, vis=False, combine_method='concat'):
21 |     if combine_method == 'concat':
22 |         num_q = 1 if h2t_cfg.POOLING_FEATS or h2t_cfg.WEIGHTED_SUM_FEATS else h2t_cfg.NUM_QUERY_TOKENS
23 |         m2featdim = {
24 |             'mae_vitb16': int((768+768*12*num_q)*h2t_cfg.KEEP_FRAC),
25 |         }
26 |     else:
27 |         assert(h2t_cfg.KEEP_FRAC == 1.0)
28 |         m2featdim = {
29 |             'mae_vitb16': 768,
30 |         }
31 | 
32 |     model = Head2ToeVisionTransformerMAE(
33 |         h2t_cfg=h2t_cfg, drop_path_rate=0.1,
34 |         global_pool=True,  # default settings for mae-finetune
35 |         patch_size=16, embed_dim=768, depth=12, num_heads=12,
36 |         mlp_ratio=4, qkv_bias=True,
37 |         norm_layer=partial(nn.LayerNorm, eps=1e-6))
38 | 
39 |     ckpt = os.path.join(model_root, MODEL_ZOO[model_type])
40 |     checkpoint = torch.load(ckpt, map_location="cpu")
41 |     state_dict = checkpoint['model']
42 | 
43 |     msg = model.load_state_dict(state_dict, strict=False)
44 |     logger.info(f'Loaded the checkpoint and got the message: {str(msg)}')
45 |     model.head = nn.Identity()
46 |     return model, m2featdim[model_type]
47 | 
48 | def build_h2t_vit_sup_models(model_type, crop_size, h2t_cfg, model_root,
49 |                              load_pretrain=True, vis=False, combine_method='concat'):
50 |     if combine_method == 'concat':
51 |         num_q = 1 if h2t_cfg.POOLING_FEATS or h2t_cfg.WEIGHTED_SUM_FEATS else h2t_cfg.NUM_QUERY_TOKENS
52 |         m2featdim = {
53 |             'sup_vitb16_imagenet21k': int((768 + 768*12*num_q)*h2t_cfg.KEEP_FRAC),
54 |             'sup_vitb16_imagenet1k' : int((768 + 768*12*num_q)*h2t_cfg.KEEP_FRAC),
55 |             'sup_vitl16_imagenet21k': int((1024 + 1024*24*num_q)*h2t_cfg.KEEP_FRAC),
56 |             'sup_vith14_imagenet21k': int((1280 + 1280*32*num_q)*h2t_cfg.KEEP_FRAC),
57 |         }
58 |     else:
59 |         assert(h2t_cfg.KEEP_FRAC == 1.0)
60 |         m2featdim = {
61 |             'sup_vitb16_imagenet21k': 768,
62 |             'sup_vitb16_imagenet1k' : 768,
63 |             'sup_vitl16_imagenet21k': 1024,
64 |             'sup_vith14_imagenet21k': 1280,
65 |         }
66 |     assert(h2t_cfg is not None)
67 |     model = Head2ToeVisionTransformer(
68 |         model_type, h2t_cfg, crop_size, num_classes=-1, vis=vis, combine_method=combine_method)
69 | 
70 |     if load_pretrain:
71 |         model.load_from(np.load(os.path.join(
72 |             model_root, MODEL_ZOO[model_type])))
73 | 
74 |     return model, m2featdim[model_type]
75 | 
76 | 
77 | class H2TViT(nn.Module):
78 | 
79 |     def __init__(self, cfg, load_pretrain=True, vis=False,
80 |                  combine_method='concat'):
81 |         super(H2TViT, self).__init__()
82 |         self.cfg = cfg
83 |         h2t_cfg = cfg.MODEL.H2T
84 |         self.combine_method = combine_method
85 |         self.build_backbone(
86 |             h2t_cfg, cfg, load_pretrain, vis)
87 |         self.side = None
88 |         self.setup_head(cfg)
89 | 
90 |     def build_backbone(self, h2t_cfg, cfg, load_pretrain, vis):
91 |         assert(cfg.MODEL.TRANSFER_TYPE == 'h2t-prompt')
92 | 
self.enc, self.feat_dim = build_h2t_vit_sup_models( 93 | cfg.DATA.FEATURE, cfg.DATA.CROPSIZE, h2t_cfg, cfg.MODEL.MODEL_ROOT, 94 | load_pretrain, vis, combine_method=self.combine_method) 95 | 96 | for k, p in self.enc.named_parameters(): 97 | if 'prompt' in k: 98 | p.requires_grad = True 99 | elif 'combine_selfatten_block' in k: 100 | p.requires_grad = True 101 | elif 'layerwise_mlps' in k: 102 | p.requires_grad = True 103 | elif 'layer_position_embeddings' in k: 104 | p.requires_grad = True 105 | elif 'combine_params' in k: 106 | p.requires_grad = True 107 | else: 108 | p.requires_grad = False 109 | 110 | def setup_head(self, cfg): 111 | self.head = MLP( 112 | input_dim=self.feat_dim, 113 | mlp_dims=[self.feat_dim] * self.cfg.MODEL.MLP_NUM + \ 114 | [cfg.DATA.NUMBER_CLASSES], #noqa 115 | special_bias=True 116 | ) 117 | 118 | def forward(self, x, feat_select_ids=None): 119 | x = self.enc(x, feat_select_ids=feat_select_ids) 120 | x = self.head(x) 121 | return x 122 | 123 | 124 | class H2TSSLViT(H2TViT): 125 | 126 | def __init__(self, cfg, combine_method='concat'): 127 | super(H2TSSLViT, self).__init__( 128 | cfg=cfg, combine_method=combine_method) 129 | 130 | def build_backbone(self, h2t_cfg, cfg, load_pretrain, vis): 131 | assert(cfg.MODEL.TRANSFER_TYPE == 'h2t-prompt') 132 | 133 | if 'mae' in cfg.DATA.FEATURE: 134 | build_fn = build_h2t_mae_model 135 | else: 136 | raise NotImplementedError() 137 | 138 | self.enc, self.feat_dim = build_fn( 139 | cfg.DATA.FEATURE, cfg.DATA.CROPSIZE, h2t_cfg, cfg.MODEL.MODEL_ROOT, 140 | load_pretrain, vis, combine_method=self.combine_method) 141 | 142 | for k, p in self.enc.named_parameters(): 143 | if 'prompt' in k: 144 | p.requires_grad = True 145 | elif 'combine_selfatten_block' in k: 146 | p.requires_grad = True 147 | elif 'layerwise_mlps' in k: 148 | p.requires_grad = True 149 | elif 'layer_position_embeddings' in k: 150 | p.requires_grad = True 151 | elif 'combine_params' in k: 152 | p.requires_grad = True 153 | else: 154 | p.requires_grad = False 155 | 156 | 157 | def build_head2toe_model(cfg, combine_method='concat'): 158 | 159 | if cfg.MODEL.TYPE == 'h2t-vit': 160 | model = H2TViT( 161 | cfg, combine_method=combine_method) 162 | elif cfg.MODEL.TYPE == 'h2t-ssl-vit': 163 | model = H2TSSLViT( 164 | cfg, combine_method=combine_method) 165 | else: 166 | raise NotImplementedError() 167 | 168 | model, device = load_model_to_device(model, cfg) 169 | logger.info(f"Device used for model: {device}") 170 | return model, device 171 | -------------------------------------------------------------------------------- /src/models/build_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Model construction functions. 
4 | """ 5 | from tabnanny import verbose 6 | import torch 7 | 8 | from .resnet import ResNet 9 | from .convnext import ConvNeXt 10 | from .vit_models import ViT, Swin, SSLViT 11 | from ..utils import logging 12 | logger = logging.get_logger("visual_prompt") 13 | # Supported model types 14 | _MODEL_TYPES = { 15 | "resnet": ResNet, 16 | "convnext": ConvNeXt, 17 | "vit": ViT, 18 | "swin": Swin, 19 | "ssl-vit": SSLViT, 20 | } 21 | 22 | 23 | def build_model(cfg): 24 | """ 25 | build model here 26 | """ 27 | assert ( 28 | cfg.MODEL.TYPE in _MODEL_TYPES.keys() 29 | ), "Model type '{}' not supported".format(cfg.MODEL.TYPE) 30 | assert ( 31 | cfg.NUM_GPUS <= torch.cuda.device_count() 32 | ), "Cannot use more GPU devices than available" 33 | 34 | # Construct the model 35 | train_type = cfg.MODEL.TYPE 36 | model = _MODEL_TYPES[train_type](cfg) 37 | 38 | log_model_info(model, verbose=cfg.DBG) 39 | model, device = load_model_to_device(model, cfg) 40 | logger.info(f"Device used for model: {device}") 41 | 42 | return model, device 43 | 44 | 45 | def log_model_info(model, verbose=False): 46 | """Logs model info""" 47 | if verbose: 48 | logger.info(f"Classification Model:\n{model}") 49 | model_total_params = sum(p.numel() for p in model.parameters()) 50 | model_grad_params = sum( 51 | p.numel() for p in model.parameters() if p.requires_grad) 52 | logger.info("Total Parameters: {0}\t Gradient Parameters: {1}".format( 53 | model_total_params, model_grad_params)) 54 | logger.info("tuned percent:%.3f"%(model_grad_params/model_total_params*100)) 55 | 56 | 57 | def get_current_device(): 58 | if torch.cuda.is_available(): 59 | # Determine the GPU used by the current process 60 | cur_device = torch.cuda.current_device() 61 | else: 62 | cur_device = torch.device('cpu') 63 | return cur_device 64 | 65 | 66 | def load_model_to_device(model, cfg): 67 | cur_device = get_current_device() 68 | if torch.cuda.is_available(): 69 | # Transfer the model to the current GPU device 70 | model = model.cuda(device=cur_device) 71 | # Use multi-process data parallel model in the multi-gpu setting 72 | if cfg.NUM_GPUS > 1: 73 | # Make model replica operate on the current device 74 | model = torch.nn.parallel.DistributedDataParallel( 75 | module=model, device_ids=[cur_device], output_device=cur_device, 76 | find_unused_parameters=True, 77 | ) 78 | else: 79 | model = model.to(cur_device) 80 | return model, cur_device 81 | -------------------------------------------------------------------------------- /src/models/mlp.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Modified from: fbcode/multimo/models/encoders/mlp.py 4 | """ 5 | import math 6 | import torch 7 | 8 | from torch import nn 9 | from typing import List, Type 10 | 11 | from ..utils import logging 12 | logger = logging.get_logger("visual_prompt") 13 | 14 | 15 | class MLP(nn.Module): 16 | def __init__( 17 | self, 18 | input_dim: int, 19 | mlp_dims: List[int], 20 | dropout: float = 0.1, 21 | nonlinearity: Type[nn.Module] = nn.ReLU, 22 | normalization: Type[nn.Module] = nn.BatchNorm1d, # nn.LayerNorm, 23 | special_bias: bool = False, 24 | add_bn_first: bool = False, 25 | ): 26 | super(MLP, self).__init__() 27 | projection_prev_dim = input_dim 28 | projection_modulelist = [] 29 | last_dim = mlp_dims[-1] 30 | mlp_dims = mlp_dims[:-1] 31 | 32 | if add_bn_first: 33 | if normalization is not None: 34 | projection_modulelist.append(normalization(projection_prev_dim)) 35 | if dropout != 0: 36 | 
projection_modulelist.append(nn.Dropout(dropout)) 37 | 38 | for idx, mlp_dim in enumerate(mlp_dims): 39 | fc_layer = nn.Linear(projection_prev_dim, mlp_dim) 40 | nn.init.kaiming_normal_(fc_layer.weight, a=0, mode='fan_out') 41 | projection_modulelist.append(fc_layer) 42 | projection_modulelist.append(nonlinearity()) 43 | 44 | if normalization is not None: 45 | projection_modulelist.append(normalization(mlp_dim)) 46 | 47 | if dropout != 0: 48 | projection_modulelist.append(nn.Dropout(dropout)) 49 | projection_prev_dim = mlp_dim 50 | 51 | self.projection = nn.Sequential(*projection_modulelist) 52 | self.last_layer = nn.Linear(projection_prev_dim, last_dim) 53 | nn.init.kaiming_normal_(self.last_layer.weight, a=0, mode='fan_out') 54 | if special_bias: 55 | prior_prob = 0.01 56 | bias_value = -math.log((1 - prior_prob) / prior_prob) 57 | torch.nn.init.constant_(self.last_layer.bias, bias_value) 58 | 59 | def forward(self, x: torch.Tensor) -> torch.Tensor: 60 | """ 61 | input_arguments: 62 | @x: torch.FloatTensor 63 | """ 64 | x = self.projection(x) 65 | x = self.last_layer(x) 66 | return x 67 | -------------------------------------------------------------------------------- /src/models/vit_adapter/adapter_block.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | ''' 3 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 4 | ''' 5 | import math 6 | import logging 7 | from functools import partial 8 | from collections import OrderedDict 9 | from copy import deepcopy 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | from timm.models.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_ 16 | from timm.models.vision_transformer import Attention 17 | from timm.models.vision_transformer import Block 18 | 19 | from ...utils import logging 20 | logger = logging.get_logger("visual_prompt") 21 | 22 | 23 | class Pfeiffer_Block(Block): 24 | 25 | def __init__(self, adapter_config, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., 26 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 27 | 28 | super(Pfeiffer_Block, self).__init__( 29 | dim=dim, 30 | num_heads=num_heads, 31 | mlp_ratio=mlp_ratio, 32 | qkv_bias=qkv_bias, 33 | drop=drop, 34 | attn_drop=attn_drop, 35 | drop_path=drop_path, 36 | act_layer=act_layer, 37 | norm_layer=norm_layer) 38 | 39 | self.adapter_config = adapter_config 40 | 41 | if adapter_config.STYLE == "Pfeiffer": 42 | self.adapter_downsample = nn.Linear( 43 | dim, 44 | dim // adapter_config.REDUCATION_FACTOR 45 | ) 46 | self.adapter_upsample = nn.Linear( 47 | dim // adapter_config.REDUCATION_FACTOR, 48 | dim 49 | ) 50 | self.adapter_act_fn = act_layer() 51 | 52 | nn.init.zeros_(self.adapter_downsample.weight) 53 | nn.init.zeros_(self.adapter_downsample.bias) 54 | 55 | nn.init.zeros_(self.adapter_upsample.weight) 56 | nn.init.zeros_(self.adapter_upsample.bias) 57 | else: 58 | raise ValueError("Other adapter styles are not supported.") 59 | 60 | def forward(self, x): 61 | 62 | if self.adapter_config.STYLE == "Pfeiffer": 63 | # same as reguluar ViT block 64 | h = x 65 | x = self.norm1(x) 66 | x = self.attn(x) 67 | x = self.drop_path(x) 68 | x = x + h 69 | 70 | h = x 71 | x = self.norm2(x) 72 | x = self.mlp(x) 73 | 74 | # start to insert adapter layers... 
75 | adpt = self.adapter_downsample(x) 76 | adpt = self.adapter_act_fn(adpt) 77 | adpt = self.adapter_upsample(adpt) 78 | x = adpt + x 79 | # ...end 80 | 81 | x = self.drop_path(x) 82 | x = x + h 83 | 84 | return x 85 | -------------------------------------------------------------------------------- /src/models/vit_adapter/vit_mae.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | borrow from https://github.com/facebookresearch/mae/blob/main/models_vit.py 4 | """ 5 | from functools import partial 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from .adapter_block import Pfeiffer_Block 11 | from ..vit_backbones.vit_mae import VisionTransformer 12 | from timm.models.layers import PatchEmbed 13 | from ...utils import logging 14 | logger = logging.get_logger("visual_prompt") 15 | 16 | 17 | class ADPT_VisionTransformer(VisionTransformer): 18 | """ Vision Transformer with support for global average pooling 19 | """ 20 | def __init__( 21 | self, 22 | adapter_cfg, 23 | img_size=224, 24 | patch_size=16, 25 | in_chans=3, 26 | num_classes=1000, 27 | embed_dim=768, 28 | depth=12, 29 | num_heads=12, 30 | mlp_ratio=4., 31 | qkv_bias=True, 32 | representation_size=None, 33 | distilled=False, 34 | drop_rate=0., 35 | attn_drop_rate=0., 36 | drop_path_rate=0., 37 | embed_layer=PatchEmbed, 38 | norm_layer=None, 39 | act_layer=None, 40 | weight_init='', 41 | **kwargs): 42 | 43 | super(ADPT_VisionTransformer, self).__init__( 44 | img_size=img_size, 45 | patch_size=patch_size, 46 | in_chans=in_chans, 47 | num_classes=num_classes, 48 | embed_dim=embed_dim, 49 | depth=depth, 50 | num_heads=num_heads, 51 | mlp_ratio=mlp_ratio, 52 | qkv_bias=qkv_bias, 53 | representation_size=representation_size, 54 | distilled=distilled, 55 | drop_rate=drop_rate, 56 | attn_drop_rate=attn_drop_rate, 57 | drop_path_rate=drop_path_rate, 58 | embed_layer=embed_layer, 59 | norm_layer=norm_layer, 60 | act_layer=act_layer, 61 | weight_init=weight_init, 62 | **kwargs 63 | ) 64 | 65 | self.adapter_cfg = adapter_cfg 66 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 67 | act_layer = act_layer or nn.GELU 68 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 69 | 70 | if adapter_cfg.STYLE == "Pfeiffer": 71 | self.blocks = nn.Sequential(*[ 72 | Pfeiffer_Block( 73 | adapter_config=adapter_cfg, 74 | dim=embed_dim, 75 | num_heads=num_heads, 76 | mlp_ratio=mlp_ratio, 77 | qkv_bias=qkv_bias, 78 | drop=drop_rate, 79 | attn_drop=attn_drop_rate, 80 | drop_path=dpr[i], 81 | norm_layer=norm_layer, 82 | act_layer=act_layer) for i in range(depth)]) 83 | else: 84 | raise ValueError("Other adapter styles are not supported.") 85 | 86 | 87 | 88 | def build_model(model_type, adapter_cfg): 89 | if "vitb" in model_type: 90 | return vit_base_patch16(adapter_cfg) 91 | elif "vitl" in model_type: 92 | return vit_large_patch16(adapter_cfg) 93 | elif "vith" in model_type: 94 | return vit_huge_patch14(adapter_cfg) 95 | 96 | 97 | def vit_base_patch16(adapter_cfg, **kwargs): 98 | model = ADPT_VisionTransformer( 99 | adapter_cfg, 100 | drop_path_rate=0.1, global_pool=True, # using default settings for mae-finetune 101 | patch_size=16, embed_dim=768, depth=12, num_heads=12, 102 | mlp_ratio=4, qkv_bias=True, 103 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 104 | return model 105 | 106 | 107 | def vit_large_patch16(adapter_cfg, **kwargs): 108 | model = ADPT_VisionTransformer( 109 | adapter_cfg, 110 | drop_path_rate=0.1, 
global_pool=True, # using default settings for mae-finetune 111 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, 112 | mlp_ratio=4, qkv_bias=True, 113 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 114 | return model 115 | 116 | 117 | def vit_huge_patch14(adapter_cfg, **kwargs): 118 | model = ADPT_VisionTransformer( 119 | adapter_cfg, 120 | drop_path_rate=0.1, global_pool=True, # using default settings for mae-finetune 121 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, 122 | mlp_ratio=4, qkv_bias=True, 123 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 124 | return model 125 | -------------------------------------------------------------------------------- /src/models/vit_adapter/vit_moco.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | borrow from https://github.com/facebookresearch/moco-v3/blob/main/vits.py 4 | """ 5 | import math 6 | import torch 7 | import torch.nn as nn 8 | from functools import partial, reduce 9 | from operator import mul 10 | 11 | from timm.models.vision_transformer import VisionTransformer, _cfg 12 | from timm.models.layers.helpers import to_2tuple 13 | from timm.models.layers import PatchEmbed 14 | 15 | from .adapter_block import Pfeiffer_Block 16 | from ..vit_backbones.vit_moco import VisionTransformerMoCo 17 | from ...utils import logging 18 | logger = logging.get_logger("visual_prompt") 19 | 20 | 21 | class ADPT_VisionTransformerMoCo(VisionTransformerMoCo): 22 | def __init__( 23 | self, 24 | adapter_cfg, 25 | stop_grad_conv1=False, 26 | img_size=224, 27 | patch_size=16, 28 | in_chans=3, 29 | num_classes=1000, 30 | embed_dim=768, 31 | depth=12, 32 | num_heads=12, 33 | mlp_ratio=4., 34 | qkv_bias=True, 35 | representation_size=None, 36 | distilled=False, 37 | drop_rate=0., 38 | attn_drop_rate=0., 39 | drop_path_rate=0., 40 | embed_layer=PatchEmbed, 41 | norm_layer=None, 42 | act_layer=None, 43 | weight_init='', 44 | **kwargs): 45 | 46 | super(ADPT_VisionTransformerMoCo, self).__init__( 47 | stop_grad_conv1=stop_grad_conv1, 48 | img_size=img_size, 49 | patch_size=patch_size, 50 | in_chans=in_chans, 51 | num_classes=num_classes, 52 | embed_dim=embed_dim, 53 | depth=depth, 54 | num_heads=num_heads, 55 | mlp_ratio=mlp_ratio, 56 | qkv_bias=qkv_bias, 57 | representation_size=representation_size, 58 | distilled=distilled, 59 | drop_rate=drop_rate, 60 | attn_drop_rate=attn_drop_rate, 61 | drop_path_rate=drop_path_rate, 62 | embed_layer=embed_layer, 63 | norm_layer=norm_layer, 64 | act_layer=act_layer, 65 | weight_init=weight_init, 66 | **kwargs 67 | ) 68 | 69 | self.adapter_cfg = adapter_cfg 70 | 71 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 72 | act_layer = act_layer or nn.GELU 73 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 74 | 75 | if adapter_cfg.STYLE == "Pfeiffer": 76 | self.blocks = nn.Sequential(*[ 77 | Pfeiffer_Block( 78 | adapter_config=adapter_cfg, 79 | dim=embed_dim, 80 | num_heads=num_heads, 81 | mlp_ratio=mlp_ratio, 82 | qkv_bias=qkv_bias, 83 | drop=drop_rate, 84 | attn_drop=attn_drop_rate, 85 | drop_path=dpr[i], 86 | norm_layer=norm_layer, 87 | act_layer=act_layer) for i in range(depth)]) 88 | else: 89 | raise ValueError("Other adapter styles are not supported.") 90 | 91 | 92 | 93 | def vit_base(adapter_cfg, **kwargs): 94 | model = ADPT_VisionTransformerMoCo( 95 | adapter_cfg, 96 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 97 | 
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 98 | model.default_cfg = _cfg() 99 | return model 100 | -------------------------------------------------------------------------------- /src/models/vit_backbones/h2t_vit_mae.py: -------------------------------------------------------------------------------- 1 | import math 2 | from functools import partial, reduce 3 | from operator import mul 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from .timm_h2t_vit import TimmHead2ToeVisionTransformer 9 | 10 | 11 | class Head2ToeVisionTransformerMAE(TimmHead2ToeVisionTransformer): 12 | 13 | def __init__(self, h2t_cfg, global_pool=True, **kwargs): # Note that VPT use global-pool version of MAE pre-trained vit 14 | super().__init__(**kwargs) 15 | self.global_pool = global_pool 16 | if self.global_pool: 17 | norm_layer = kwargs['norm_layer'] 18 | embed_dim = kwargs['embed_dim'] 19 | self.fc_norm = norm_layer(embed_dim) 20 | del self.norm # Remove the original norm 21 | 22 | self.h2t_cfg = h2t_cfg 23 | self.num_query_tokens = h2t_cfg.NUM_QUERY_TOKENS 24 | self.prompt_dropout = nn.Dropout(h2t_cfg.DROPOUT) 25 | self.norm_feats = h2t_cfg.NORMALIZE_FEATS 26 | 27 | # Initiate query prompts 28 | if self.num_query_tokens > 0: 29 | patch_size = self.patch_embed.patch_size 30 | self.query_prompt_embeddings = nn.Parameter(torch.zeros( 31 | len(self.blocks), self.num_query_tokens, self.embed_dim)) 32 | 33 | prompt_dim = self.embed_dim 34 | val = math.sqrt(6./float(3*reduce(mul, patch_size, 1) + prompt_dim)) # noqa 35 | # xavier_uniform initialization 36 | nn.init.uniform_(self.query_prompt_embeddings.data, -val, val) 37 | else: 38 | self.register_parameter('query_prompt_embeddings', None) 39 | 40 | def train(self, mode=True): 41 | if mode: 42 | for module in self.children(): 43 | module.eval() 44 | self.prompt_dropout.train() 45 | else: 46 | for module in self.children(): 47 | module.train(mode) 48 | self.prompt_dropout.eval() 49 | 50 | def embeddings(self, x): 51 | B = x.shape[0] 52 | x = self.patch_embed(x) 53 | 54 | cls_token = self.cls_token.expand(B, -1, -1) 55 | x = torch.cat((cls_token, x), dim=1) 56 | x = x + self.pos_embed 57 | x = self.pos_drop(x) 58 | return x 59 | 60 | def forward_features(self, x): 61 | x = self.embeddings(x) 62 | 63 | B = x.shape[0] 64 | query_outputs = [] 65 | for layer_idx, layer_block in enumerate(self.blocks): 66 | 67 | if self.query_prompt_embeddings is not None: 68 | q_states = self.prompt_dropout( 69 | self.query_prompt_embeddings[layer_idx].expand( 70 | B, -1, -1)) 71 | x = torch.cat([q_states, x], dim=1) 72 | 73 | x = layer_block(x, query_prompt_len=self.num_query_tokens) 74 | 75 | if self.query_prompt_embeddings is not None: 76 | query_outputs.append(x[:, :self.num_query_tokens, :]) 77 | x = x[:, self.num_query_tokens:, :] 78 | 79 | # MAE by default use global pool instead of CLS 80 | x = x[:, 1:, :].mean(dim=1) 81 | x = self.fc_norm(x) 82 | return x, query_outputs 83 | 84 | def forward(self, x, feat_select_ids=None): 85 | x, query_outputs = self.forward_features(x) 86 | B = x.shape[0] 87 | if self.head_dist is not None: 88 | raise NotImplementedError() 89 | else: 90 | included_features = [x] + [q.view(B, -1) for q in query_outputs] 91 | if self.norm_feats: 92 | included_features = [F.normalize(x) for x in included_features] 93 | feats = torch.cat(included_features, dim=1) 94 | 95 | if feat_select_ids is None: 96 | logits = self.head(feats) 97 | else: 98 | logits = self.head(feats[:, feat_select_ids]) 99 | 100 | return logits 
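
The feature width that the linear head receives follows directly from this forward pass: the 768-d global-pooled output plus `num_query_tokens` query vectors of width 768 from each of the 12 blocks, optionally scaled by the `KEEP_FRAC` feature-selection ratio used in `build_h2t_model.py`. A short sketch of that arithmetic, where the helper name and its signature are illustrative assumptions rather than repository code:

```python
# Hypothetical helper mirroring m2featdim['mae_vitb16'] in build_h2t_model.py.
def h2t_mae_vitb16_feat_dim(num_query_tokens: int, keep_frac: float = 1.0) -> int:
    embed_dim, depth = 768, 12
    pooled = embed_dim                               # global-pooled output after fc_norm
    queries = embed_dim * depth * num_query_tokens   # one (num_query_tokens x 768) block per layer
    return int((pooled + queries) * keep_frac)

assert h2t_mae_vitb16_feat_dim(num_query_tokens=1) == 9984  # 768 + 768 * 12
```
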
101 | -------------------------------------------------------------------------------- /src/models/vit_backbones/vit_mae.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | """ 4 | borrowed from https://github.com/facebookresearch/mae/blob/main/models_vit.py 5 | """ 6 | from functools import partial 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | import timm.models.vision_transformer 12 | 13 | 14 | class VisionTransformer(timm.models.vision_transformer.VisionTransformer): 15 | """ Vision Transformer with support for global average pooling 16 | """ 17 | def __init__(self, global_pool=False, **kwargs): 18 | super(VisionTransformer, self).__init__(**kwargs) 19 | 20 | self.global_pool = global_pool 21 | if self.global_pool: 22 | norm_layer = kwargs['norm_layer'] 23 | embed_dim = kwargs['embed_dim'] 24 | self.fc_norm = norm_layer(embed_dim) 25 | 26 | del self.norm # remove the original norm 27 | 28 | def forward_features(self, x): 29 | B = x.shape[0] 30 | x = self.patch_embed(x) 31 | 32 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 33 | x = torch.cat((cls_tokens, x), dim=1) 34 | x = x + self.pos_embed 35 | x = self.pos_drop(x) 36 | 37 | for blk in self.blocks: 38 | x = blk(x) 39 | 40 | if self.global_pool: 41 | x = x[:, 1:, :].mean(dim=1) # global pool without cls token 42 | outcome = self.fc_norm(x) 43 | else: 44 | x = self.norm(x) 45 | outcome = x[:, 0] 46 | 47 | return outcome 48 | 49 | 50 | def build_model(model_type): 51 | if "vitb" in model_type: 52 | return vit_base_patch16() 53 | elif "vitl" in model_type: 54 | return vit_large_patch16() 55 | elif "vith" in model_type: 56 | return vit_huge_patch14() 57 | 58 | 59 | def vit_base_patch16(**kwargs): 60 | model = VisionTransformer( 61 | drop_path_rate=0.1, global_pool=True, # using default settings for mae-finetune 62 | patch_size=16, embed_dim=768, depth=12, num_heads=12, 63 | mlp_ratio=4, qkv_bias=True, 64 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 65 | return model 66 | 67 | 68 | def vit_large_patch16(**kwargs): 69 | model = VisionTransformer( 70 | drop_path_rate=0.1, global_pool=True, # using default settings for mae-finetune 71 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, 72 | mlp_ratio=4, qkv_bias=True, 73 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 74 | return model 75 | 76 | 77 | def vit_huge_patch14(**kwargs): 78 | model = VisionTransformer( 79 | drop_path_rate=0.1, global_pool=True, # using default settings for mae-finetune 80 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, 81 | mlp_ratio=4, qkv_bias=True, 82 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 83 | return model 84 | -------------------------------------------------------------------------------- /src/models/vit_backbones/vit_moco.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. 
All Rights Reserved 3 | """ 4 | borrowed from https://github.com/facebookresearch/moco-v3/blob/main/vits.py 5 | """ 6 | import math 7 | import torch 8 | import torch.nn as nn 9 | from functools import partial, reduce 10 | from operator import mul 11 | 12 | from timm.models.vision_transformer import VisionTransformer, _cfg 13 | from timm.models.layers.helpers import to_2tuple 14 | from timm.models.layers import PatchEmbed 15 | 16 | __all__ = [ 17 | 'vit_small', 18 | 'vit_base', 19 | 'vit_conv_small', 20 | 'vit_conv_base', 21 | ] 22 | 23 | 24 | class VisionTransformerMoCo(VisionTransformer): 25 | def __init__(self, stop_grad_conv1=False, **kwargs): 26 | super().__init__(**kwargs) 27 | # Use fixed 2D sin-cos position embedding 28 | self.build_2d_sincos_position_embedding() 29 | 30 | # weight initialization 31 | for name, m in self.named_modules(): 32 | if isinstance(m, nn.Linear): 33 | if 'qkv' in name: 34 | # treat the weights of Q, K, V separately 35 | val = math.sqrt(6. / float(m.weight.shape[0] // 3 + m.weight.shape[1])) 36 | nn.init.uniform_(m.weight, -val, val) 37 | else: 38 | nn.init.xavier_uniform_(m.weight) 39 | nn.init.zeros_(m.bias) 40 | nn.init.normal_(self.cls_token, std=1e-6) 41 | 42 | if isinstance(self.patch_embed, PatchEmbed): 43 | # xavier_uniform initialization 44 | val = math.sqrt(6. / float(3 * reduce(mul, self.patch_embed.patch_size, 1) + self.embed_dim)) 45 | nn.init.uniform_(self.patch_embed.proj.weight, -val, val) 46 | nn.init.zeros_(self.patch_embed.proj.bias) 47 | 48 | if stop_grad_conv1: 49 | self.patch_embed.proj.weight.requires_grad = False 50 | self.patch_embed.proj.bias.requires_grad = False 51 | 52 | def build_2d_sincos_position_embedding(self, temperature=10000.): 53 | h, w = self.patch_embed.grid_size 54 | grid_w = torch.arange(w, dtype=torch.float32) 55 | grid_h = torch.arange(h, dtype=torch.float32) 56 | grid_w, grid_h = torch.meshgrid(grid_w, grid_h) 57 | assert self.embed_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding' 58 | pos_dim = self.embed_dim // 4 59 | omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim 60 | omega = 1. / (temperature**omega) 61 | out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega]) 62 | out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega]) 63 | pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :] 64 | 65 | assert self.num_tokens == 1, 'Assuming one and only one token, [cls]' 66 | pe_token = torch.zeros([1, 1, self.embed_dim], dtype=torch.float32) 67 | self.pos_embed = nn.Parameter(torch.cat([pe_token, pos_emb], dim=1)) 68 | self.pos_embed.requires_grad = False 69 | 70 | 71 | class ConvStem(nn.Module): 72 | """ 73 | ConvStem, from Early Convolutions Help Transformers See Better, Tete et al. 
https://arxiv.org/abs/2106.14881 74 | """ 75 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): 76 | super().__init__() 77 | 78 | assert patch_size == 16, 'ConvStem only supports patch size of 16' 79 | assert embed_dim % 8 == 0, 'Embed dimension must be divisible by 8 for ConvStem' 80 | 81 | img_size = to_2tuple(img_size) 82 | patch_size = to_2tuple(patch_size) 83 | self.img_size = img_size 84 | self.patch_size = patch_size 85 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 86 | self.num_patches = self.grid_size[0] * self.grid_size[1] 87 | self.flatten = flatten 88 | 89 | # build stem, similar to the design in https://arxiv.org/abs/2106.14881 90 | stem = [] 91 | input_dim, output_dim = 3, embed_dim // 8 92 | for l in range(4): 93 | stem.append(nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=2, padding=1, bias=False)) 94 | stem.append(nn.BatchNorm2d(output_dim)) 95 | stem.append(nn.ReLU(inplace=True)) 96 | input_dim = output_dim 97 | output_dim *= 2 98 | stem.append(nn.Conv2d(input_dim, embed_dim, kernel_size=1)) 99 | self.proj = nn.Sequential(*stem) 100 | 101 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 102 | 103 | def forward(self, x): 104 | B, C, H, W = x.shape 105 | assert H == self.img_size[0] and W == self.img_size[1], \ 106 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 107 | x = self.proj(x) 108 | if self.flatten: 109 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 110 | x = self.norm(x) 111 | return x 112 | 113 | 114 | def vit_small(**kwargs): 115 | model = VisionTransformerMoCo( 116 | patch_size=16, embed_dim=384, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 117 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 118 | model.default_cfg = _cfg() 119 | return model 120 | 121 | def vit_base(**kwargs): 122 | model = VisionTransformerMoCo( 123 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 124 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 125 | model.default_cfg = _cfg() 126 | return model 127 | 128 | def vit_conv_small(**kwargs): 129 | # minus one ViT block 130 | model = VisionTransformerMoCo( 131 | patch_size=16, embed_dim=384, depth=11, num_heads=12, mlp_ratio=4, qkv_bias=True, 132 | norm_layer=partial(nn.LayerNorm, eps=1e-6), embed_layer=ConvStem, **kwargs) 133 | model.default_cfg = _cfg() 134 | return model 135 | 136 | def vit_conv_base(**kwargs): 137 | # minus one ViT block 138 | model = VisionTransformerMoCo( 139 | patch_size=16, embed_dim=768, depth=11, num_heads=12, mlp_ratio=4, qkv_bias=True, 140 | norm_layer=partial(nn.LayerNorm, eps=1e-6), embed_layer=ConvStem, **kwargs) 141 | model.default_cfg = _cfg() 142 | return model 143 | -------------------------------------------------------------------------------- /src/models/vit_prompt/vit.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | vit with prompt: a clean version with the default settings of VPT 4 | """ 5 | import math 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torchvision as tv 10 | 11 | from functools import reduce 12 | from operator import mul 13 | from torch.nn.modules.utils import _pair 14 | from torch.nn import Conv2d, Dropout 15 | from scipy import ndimage 16 | 17 | from ..vit_backbones.vit import CONFIGS, Transformer, VisionTransformer, np2th 18 | from ...utils import logging 19 | 20 | 
logger = logging.get_logger("visual_prompt") 21 | 22 | 23 | class PromptedTransformer(Transformer): 24 | def __init__(self, prompt_config, config, img_size, vis): 25 | assert prompt_config.LOCATION == "prepend" 26 | assert prompt_config.INITIATION == "random" 27 | assert prompt_config.NUM_DEEP_LAYERS is None 28 | assert not prompt_config.DEEP_SHARED 29 | super(PromptedTransformer, self).__init__( 30 | config, img_size, vis) 31 | 32 | self.prompt_config = prompt_config 33 | self.vit_config = config 34 | 35 | img_size = _pair(img_size) 36 | patch_size = _pair(config.patches["size"]) 37 | 38 | num_tokens = self.prompt_config.NUM_TOKENS 39 | self.num_tokens = num_tokens # number of prompted tokens 40 | 41 | self.prompt_dropout = Dropout(self.prompt_config.DROPOUT) 42 | 43 | # if project the prompt embeddings 44 | if self.prompt_config.PROJECT > -1: 45 | # only for prepend / add 46 | prompt_dim = self.prompt_config.PROJECT 47 | self.prompt_proj = nn.Linear( 48 | prompt_dim, config.hidden_size) 49 | nn.init.kaiming_normal_( 50 | self.prompt_proj.weight, a=0, mode='fan_out') 51 | else: 52 | prompt_dim = config.hidden_size 53 | self.prompt_proj = nn.Identity() 54 | 55 | # initiate prompt: 56 | if self.prompt_config.INITIATION == "random": 57 | val = math.sqrt(6. / float(3 * reduce(mul, patch_size, 1) + prompt_dim)) # noqa 58 | 59 | self.prompt_embeddings = nn.Parameter(torch.zeros( 60 | 1, num_tokens, prompt_dim)) 61 | # xavier_uniform initialization 62 | nn.init.uniform_(self.prompt_embeddings.data, -val, val) 63 | 64 | if self.prompt_config.DEEP: # noqa 65 | 66 | total_d_layer = config.transformer["num_layers"]-1 67 | self.deep_prompt_embeddings = nn.Parameter(torch.zeros( 68 | total_d_layer, num_tokens, prompt_dim)) 69 | # xavier_uniform initialization 70 | nn.init.uniform_(self.deep_prompt_embeddings.data, -val, val) 71 | 72 | else: 73 | raise ValueError("Other initiation scheme is not supported") 74 | 75 | def incorporate_prompt(self, x): 76 | # combine prompt embeddings with image-patch embeddings 77 | B = x.shape[0] 78 | # after CLS token, all before image patches 79 | x = self.embeddings(x) # (batch_size, 1 + n_patches, hidden_dim) 80 | x = torch.cat(( 81 | x[:, :1, :], 82 | self.prompt_dropout(self.prompt_proj(self.prompt_embeddings).expand(B, -1, -1)), 83 | x[:, 1:, :] 84 | ), dim=1) 85 | # (batch_size, cls_token + n_prompt + n_patches, hidden_dim) 86 | 87 | return x 88 | 89 | def train(self, mode=True): 90 | # set train status for this class: disable all but the prompt-related modules 91 | if mode: 92 | # training: 93 | self.encoder.eval() 94 | self.embeddings.eval() 95 | self.prompt_proj.train() 96 | self.prompt_dropout.train() 97 | else: 98 | # eval: 99 | for module in self.children(): 100 | module.train(mode) 101 | 102 | def forward_deep_prompt(self, embedding_output): 103 | attn_weights = [] 104 | hidden_states = None 105 | weights = None 106 | B = embedding_output.shape[0] 107 | num_layers = self.vit_config.transformer["num_layers"] 108 | 109 | for i in range(num_layers): 110 | if i == 0: 111 | hidden_states, weights = self.encoder.layer[i](embedding_output) 112 | else: 113 | if i <= self.deep_prompt_embeddings.shape[0]: 114 | deep_prompt_emb = self.prompt_dropout(self.prompt_proj( 115 | self.deep_prompt_embeddings[i-1]).expand(B, -1, -1)) 116 | 117 | hidden_states = torch.cat(( 118 | hidden_states[:, :1, :], 119 | deep_prompt_emb, 120 | hidden_states[:, (1+self.num_tokens):, :] 121 | ), dim=1) 122 | 123 | 124 | hidden_states, weights = self.encoder.layer[i](hidden_states) 125 | 
126 | if self.encoder.vis: 127 | attn_weights.append(weights) 128 | 129 | encoded = self.encoder.encoder_norm(hidden_states) 130 | return encoded, attn_weights 131 | 132 | def forward(self, x): 133 | # this is the default version: 134 | embedding_output = self.incorporate_prompt(x) 135 | 136 | if self.prompt_config.DEEP: 137 | encoded, attn_weights = self.forward_deep_prompt( 138 | embedding_output) 139 | else: 140 | encoded, attn_weights = self.encoder(embedding_output) 141 | 142 | return encoded, attn_weights 143 | 144 | 145 | class PromptedVisionTransformer(VisionTransformer): 146 | def __init__( 147 | self, prompt_cfg, model_type, 148 | img_size=224, num_classes=21843, vis=False 149 | ): 150 | assert prompt_cfg.VIT_POOL_TYPE == "original" 151 | super(PromptedVisionTransformer, self).__init__( 152 | model_type, img_size, num_classes, vis) 153 | if prompt_cfg is None: 154 | raise ValueError("prompt_cfg cannot be None if using PromptedVisionTransformer") 155 | self.prompt_cfg = prompt_cfg 156 | vit_cfg = CONFIGS[model_type] 157 | self.transformer = PromptedTransformer( 158 | prompt_cfg, vit_cfg, img_size, vis) 159 | 160 | def forward(self, x, vis=False): 161 | x, attn_weights = self.transformer(x) 162 | 163 | x = x[:, 0] 164 | 165 | logits = self.head(x) 166 | 167 | if not vis: 168 | return logits 169 | return logits, attn_weights 170 | -------------------------------------------------------------------------------- /src/models/vit_prompt/vit_mae.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | vit-mae with prompt 4 | """ 5 | import math 6 | import torch 7 | import torch.nn as nn 8 | import torchvision as tv 9 | 10 | from functools import partial, reduce 11 | from operator import mul 12 | from torch.nn import Conv2d, Dropout 13 | from timm.models.vision_transformer import _cfg 14 | 15 | from ..vit_backbones.vit_mae import VisionTransformer 16 | from ...utils import logging 17 | logger = logging.get_logger("visual_prompt") 18 | 19 | 20 | class PromptedVisionTransformer(VisionTransformer): 21 | def __init__(self, prompt_config, **kwargs): 22 | super().__init__(**kwargs) 23 | self.prompt_config = prompt_config 24 | if self.prompt_config.DEEP and self.prompt_config.LOCATION not in ["prepend", ]: 25 | raise ValueError("Deep-{} is not supported".format(self.prompt_config.LOCATION)) 26 | 27 | num_tokens = self.prompt_config.NUM_TOKENS 28 | 29 | self.num_tokens = num_tokens 30 | self.prompt_dropout = Dropout(self.prompt_config.DROPOUT) 31 | 32 | # initiate prompt: 33 | if self.prompt_config.INITIATION == "random": 34 | val = math.sqrt(6. 
/ float(3 * reduce(mul, self.patch_embed.patch_size, 1) + self.embed_dim)) # noqa 35 | 36 | self.prompt_embeddings = nn.Parameter(torch.zeros( 37 | 1, num_tokens, self.embed_dim)) 38 | # xavier_uniform initialization 39 | nn.init.uniform_(self.prompt_embeddings.data, -val, val) 40 | 41 | if self.prompt_config.DEEP: 42 | self.deep_prompt_embeddings = nn.Parameter(torch.zeros( 43 | len(self.blocks) - 1, 44 | num_tokens, self.embed_dim 45 | )) 46 | # xavier_uniform initialization 47 | nn.init.uniform_( 48 | self.deep_prompt_embeddings.data, -val, val) 49 | 50 | else: 51 | raise ValueError("Other initiation scheme is not supported") 52 | 53 | def incorporate_prompt(self, x): 54 | # combine prompt embeddings with image-patch embeddings 55 | B = x.shape[0] 56 | if self.prompt_config.LOCATION == "prepend": 57 | # after CLS token, all before image patches 58 | x = self.embeddings(x) # (batch_size, 1 + n_patches, hidden_dim) 59 | x = torch.cat(( 60 | x[:, :1, :], 61 | self.prompt_dropout( 62 | self.prompt_embeddings.expand(B, -1, -1)), 63 | x[:, 1:, :] 64 | ), dim=1) 65 | # (batch_size, cls_token + n_prompt + n_patches, hidden_dim) 66 | 67 | else: 68 | raise ValueError("Other prompt locations are not supported") 69 | return x 70 | 71 | def embeddings(self, x): 72 | B = x.shape[0] 73 | x = self.patch_embed(x) 74 | 75 | cls_tokens = self.cls_token.expand(B, -1, -1) 76 | x = torch.cat((cls_tokens, x), dim=1) 77 | x = x + self.pos_embed 78 | x = self.pos_drop(x) 79 | return x 80 | 81 | def train(self, mode=True): 82 | # set train status for this class: disable all but the prompt-related modules 83 | if mode: 84 | # training: 85 | self.blocks.eval() 86 | self.patch_embed.eval() 87 | self.pos_drop.eval() 88 | self.prompt_dropout.train() 89 | else: 90 | # eval: 91 | for module in self.children(): 92 | module.train(mode) 93 | 94 | def forward_features(self, x): 95 | x = self.incorporate_prompt(x) 96 | 97 | if self.prompt_config.DEEP: 98 | B = x.shape[0] 99 | num_layers = len(self.blocks) 100 | 101 | for i in range(num_layers): 102 | if i == 0: 103 | x = self.blocks[i](x) 104 | else: 105 | # prepend 106 | x = torch.cat(( 107 | x[:, :1, :], 108 | self.prompt_dropout( 109 | self.deep_prompt_embeddings[i-1].expand(B, -1, -1) 110 | ), 111 | x[:, (1 + self.num_tokens):, :] 112 | ), dim=1) 113 | x = self.blocks[i](x) 114 | else: 115 | for blk in self.blocks: 116 | x = blk(x) 117 | 118 | if self.prompt_config.VIT_POOL_TYPE == "imgprompt_pool": 119 | assert self.prompt_config.LOCATION == "prepend" 120 | x = x[:, 1:, :].mean(dim=1) # global pool without cls token 121 | outcome = self.fc_norm(x) 122 | elif self.prompt_config.VIT_POOL_TYPE == "original": 123 | x = x[:, 1:, :].mean(dim=1) # global pool without cls token 124 | outcome = self.fc_norm(x) 125 | elif self.prompt_config.VIT_POOL_TYPE == "img_pool": 126 | assert self.prompt_config.LOCATION == "prepend" 127 | x = x[:, self.num_tokens+1:, :].mean(dim=1) 128 | outcome = self.fc_norm(x) 129 | elif self.prompt_config.VIT_POOL_TYPE == "prompt_pool": 130 | assert self.prompt_config.LOCATION == "prepend" 131 | x = x[:, 1:self.num_tokens+1, :].mean(dim=1) 132 | outcome = self.fc_norm(x) 133 | else: 134 | raise ValueError("pooling type for output is not supported") 135 | 136 | return outcome 137 | 138 | 139 | def build_model(model_type, prompt_cfg): 140 | if "vitb" in model_type: 141 | return vit_base_patch16(prompt_cfg) 142 | elif "vitl" in model_type: 143 | return vit_large_patch16(prompt_cfg) 144 | elif "vith" in model_type: 145 | return vit_huge_patch14(prompt_cfg) 
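# NOTE: build_model above dispatches on a substring of model_type ("vitb" / "vitl" / "vith") and returns the matching prompted MAE backbone; an unrecognized model_type falls through and implicitly returns None, so callers are expected to pass one of these three identifiers.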
146 | 147 | 148 | def vit_base_patch16(prompt_cfg, **kwargs): 149 | model = PromptedVisionTransformer( 150 | prompt_cfg, 151 | drop_path_rate=0.1, global_pool=True, # using default settings for mae-finetune 152 | patch_size=16, embed_dim=768, depth=12, num_heads=12, 153 | mlp_ratio=4, qkv_bias=True, 154 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 155 | return model 156 | 157 | 158 | def vit_large_patch16(prompt_cfg, **kwargs): 159 | model = PromptedVisionTransformer( 160 | prompt_cfg, 161 | drop_path_rate=0.1, global_pool=True, # using default settings for mae-finetune 162 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, 163 | mlp_ratio=4, qkv_bias=True, 164 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 165 | return model 166 | 167 | 168 | def vit_huge_patch14(prompt_cfg, **kwargs): 169 | model = PromptedVisionTransformer( 170 | prompt_cfg, 171 | drop_path_rate=0.1, global_pool=True, # using default settings for mae-finetune 172 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, 173 | mlp_ratio=4, qkv_bias=True, 174 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 175 | return model 176 | 177 | 178 | -------------------------------------------------------------------------------- /src/models/vit_prompt/vit_moco.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | vit-moco-v3 with prompt 4 | """ 5 | import math 6 | import torch 7 | import torch.nn as nn 8 | import torchvision as tv 9 | 10 | from functools import partial, reduce 11 | from operator import mul 12 | from torch.nn import Conv2d, Dropout 13 | from timm.models.vision_transformer import _cfg 14 | 15 | from ..vit_backbones.vit_moco import VisionTransformerMoCo 16 | from ...utils import logging 17 | logger = logging.get_logger("visual_prompt") 18 | 19 | 20 | class PromptedVisionTransformerMoCo(VisionTransformerMoCo): 21 | def __init__(self, prompt_config, **kwargs): 22 | super().__init__(**kwargs) 23 | self.prompt_config = prompt_config 24 | 25 | if self.prompt_config.DEEP and self.prompt_config.LOCATION not in ["prepend", ]: 26 | raise ValueError("Deep-{} is not supported".format(self.prompt_config.LOCATION)) 27 | 28 | num_tokens = self.prompt_config.NUM_TOKENS 29 | 30 | self.num_tokens = num_tokens 31 | self.prompt_dropout = Dropout(self.prompt_config.DROPOUT) 32 | 33 | # initiate prompt: 34 | if self.prompt_config.INITIATION == "random": 35 | val = math.sqrt(6. 
/ float(3 * reduce(mul, self.patch_embed.patch_size, 1) + self.embed_dim)) # noqa 36 | 37 | self.prompt_embeddings = nn.Parameter(torch.zeros( 38 | 1, num_tokens, self.embed_dim)) 39 | # xavier_uniform initialization 40 | nn.init.uniform_(self.prompt_embeddings.data, -val, val) 41 | if self.prompt_config.DEEP: 42 | self.deep_prompt_embeddings = nn.Parameter(torch.zeros( 43 | len(self.blocks) - 1, 44 | num_tokens, self.embed_dim 45 | )) 46 | # xavier_uniform initialization 47 | nn.init.uniform_( 48 | self.deep_prompt_embeddings.data, -val, val) 49 | 50 | else: 51 | raise ValueError("Other initiation scheme is not supported") 52 | 53 | def incorporate_prompt(self, x): 54 | # combine prompt embeddings with image-patch embeddings 55 | B = x.shape[0] 56 | if self.prompt_config.LOCATION == "prepend": 57 | # after CLS token, all before image patches 58 | x = self.embeddings(x) # (batch_size, 1 + n_patches, hidden_dim) 59 | x = torch.cat(( 60 | x[:, :1, :], 61 | self.prompt_dropout( 62 | self.prompt_embeddings.expand(B, -1, -1)), 63 | x[:, 1:, :] 64 | ), dim=1) 65 | # (batch_size, cls_token + n_prompt + n_patches, hidden_dim) 66 | else: 67 | raise ValueError("Other prompt locations are not supported") 68 | 69 | return x 70 | 71 | def embeddings(self, x): 72 | x = self.patch_embed(x) 73 | cls_token = self.cls_token.expand(x.shape[0], -1, -1) 74 | if self.dist_token is None: 75 | x = torch.cat((cls_token, x), dim=1) 76 | else: 77 | x = torch.cat(( 78 | cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), 79 | dim=1) 80 | x = self.pos_drop(x + self.pos_embed) 81 | return x 82 | 83 | def train(self, mode=True): 84 | # set train status for this class: disable all but the prompt-related modules 85 | if mode: 86 | # training: 87 | self.blocks.eval() 88 | self.patch_embed.eval() 89 | self.pos_drop.eval() 90 | self.prompt_dropout.train() 91 | else: 92 | # eval: 93 | for module in self.children(): 94 | module.train(mode) 95 | 96 | def forward_features(self, x): 97 | x = self.incorporate_prompt(x) 98 | 99 | # deep 100 | if self.prompt_config.DEEP: 101 | B = x.shape[0] 102 | num_layers = len(self.blocks) 103 | 104 | for i in range(num_layers): 105 | if i == 0: 106 | x = self.blocks[i](x) 107 | else: 108 | # prepend 109 | x = torch.cat(( 110 | x[:, :1, :], 111 | self.prompt_dropout( 112 | self.deep_prompt_embeddings[i-1].expand(B, -1, -1) 113 | ), 114 | x[:, (1 + self.num_tokens):, :] 115 | ), dim=1) 116 | x = self.blocks[i](x) 117 | else: 118 | # not deep: 119 | x = self.blocks(x) 120 | 121 | x = self.norm(x) 122 | if self.dist_token is None: 123 | return self.pre_logits(x[:, 0]) 124 | else: 125 | return x[:, 0], x[:, 1] 126 | 127 | 128 | def vit_base(prompt_cfg, **kwargs): 129 | model = PromptedVisionTransformerMoCo( 130 | prompt_cfg, 131 | patch_size=16, embed_dim=768, depth=12, 132 | num_heads=12, mlp_ratio=4, qkv_bias=True, 133 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 134 | model.default_cfg = _cfg() 135 | return model 136 | 137 | -------------------------------------------------------------------------------- /src/solver/losses.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from typing import Optional 6 | 7 | from ..utils import logging 8 | logger = logging.get_logger("visual_prompt") 9 | 10 | 11 | class SigmoidLoss(nn.Module): 12 | def __init__(self, cfg=None): 13 | super(SigmoidLoss, self).__init__() 14 | 15 | def is_single(self): 16 | 
return True 17 | 18 | def is_local(self): 19 | return False 20 | 21 | def multi_hot(self, labels: torch.Tensor, nb_classes: int) -> torch.Tensor: 22 | labels = labels.unsqueeze(1) # (batch_size, 1) 23 | target = torch.zeros( 24 | labels.size(0), nb_classes, device=labels.device 25 | ).scatter_(1, labels, 1.) 26 | # (batch_size, num_classes) 27 | return target 28 | 29 | def loss( 30 | self, logits, targets, per_cls_weights, 31 | multihot_targets: Optional[bool] = False 32 | ): 33 | # targets: 1d-tensor of integer 34 | # Only support single label at this moment 35 | # if len(targets.shape) != 2: 36 | num_classes = logits.shape[1] 37 | targets = self.multi_hot(targets, num_classes) 38 | 39 | loss = F.binary_cross_entropy_with_logits( 40 | logits, targets, reduction="none") 41 | # logger.info(f"loss shape: {loss.shape}") 42 | weight = torch.tensor( 43 | per_cls_weights, device=logits.device 44 | ).unsqueeze(0) 45 | # logger.info(f"weight shape: {weight.shape}") 46 | loss = torch.mul(loss.to(torch.float32), weight.to(torch.float32)) 47 | return torch.sum(loss) / targets.shape[0] 48 | 49 | def forward( 50 | self, pred_logits, targets, per_cls_weights, multihot_targets=False 51 | ): 52 | loss = self.loss( 53 | pred_logits, targets, per_cls_weights, multihot_targets) 54 | return loss 55 | 56 | 57 | class SoftmaxLoss(SigmoidLoss): 58 | def __init__(self, cfg=None): 59 | super(SoftmaxLoss, self).__init__() 60 | 61 | def loss(self, logits, targets, per_cls_weights, kwargs): 62 | weight = torch.tensor( 63 | per_cls_weights, device=logits.device 64 | ) 65 | loss = F.cross_entropy(logits, targets, weight, reduction="none") 66 | 67 | return torch.sum(loss) / targets.shape[0] 68 | 69 | 70 | LOSS = { 71 | "softmax": SoftmaxLoss, 72 | } 73 | 74 | 75 | def build_loss(cfg): 76 | loss_name = cfg.SOLVER.LOSS 77 | assert loss_name in LOSS, \ 78 | f'loss name {loss_name} is not supported' 79 | loss_fn = LOSS[loss_name] 80 | if not loss_fn: 81 | return None 82 | else: 83 | return loss_fn(cfg) 84 | -------------------------------------------------------------------------------- /src/solver/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import math 3 | 4 | import torch.optim as optim 5 | from fvcore.common.config import CfgNode 6 | from torch.optim.lr_scheduler import LambdaLR 7 | 8 | 9 | def make_scheduler( 10 | optimizer: optim.Optimizer, train_params: CfgNode 11 | ) -> LambdaLR: 12 | warmup = train_params.WARMUP_EPOCH 13 | total_iters = train_params.TOTAL_EPOCH 14 | 15 | if train_params.SCHEDULER == "cosine": 16 | scheduler = WarmupCosineSchedule( 17 | optimizer, 18 | warmup_steps=warmup, 19 | t_total=total_iters 20 | ) 21 | elif train_params.SCHEDULER == "cosine_hardrestart": 22 | scheduler = WarmupCosineWithHardRestartsSchedule( 23 | optimizer, 24 | warmup_steps=warmup, 25 | t_total=total_iters 26 | ) 27 | 28 | elif train_params.SCHEDULER == "plateau": 29 | scheduler = optim.lr_scheduler.ReduceLROnPlateau( 30 | optimizer, 31 | "max", 32 | patience=5, 33 | verbose=True, 34 | factor=train_params.LR_DECAY_FACTOR, 35 | ) 36 | else: 37 | scheduler = None 38 | return scheduler 39 | 40 | 41 | class WarmupCosineSchedule(LambdaLR): 42 | """ Linear warmup and then cosine decay. 43 | Linearly increases learning rate from 0 to 1 over `warmup_steps`. 44 | Decreases learning rate from 1. to 0. over remaining 45 | `t_total - warmup_steps` steps following a cosine curve. 
46 | If `cycles` (default=0.5) is different from default, learning rate 47 | follows cosine function after warmup. 48 | """ 49 | def __init__(self, optimizer, warmup_steps, t_total, cycles=.5, last_epoch=-1): 50 | self.warmup_steps = warmup_steps 51 | self.t_total = t_total 52 | self.cycles = cycles 53 | super(WarmupCosineSchedule, self).__init__( 54 | optimizer, self.lr_lambda, last_epoch=last_epoch) 55 | 56 | def lr_lambda(self, step): 57 | if step < self.warmup_steps: 58 | return float(step) / float(max(1.0, self.warmup_steps)) 59 | # progress after warmup 60 | progress = float(step - self.warmup_steps) / float(max( 61 | 1, self.t_total - self.warmup_steps)) 62 | return max( 63 | 0.0, 64 | 0.5 * (1. + math.cos(math.pi * float(self.cycles) * 2.0 * progress)) 65 | ) 66 | 67 | 68 | class WarmupCosineWithHardRestartsSchedule(LambdaLR): 69 | """ Linear warmup and then cosine cycles with hard restarts. 70 | Linearly increases learning rate from 0 to 1 over `warmup_steps`. 71 | If `cycles` (default=1.) is different from default, learning rate 72 | follows `cycles` times a cosine decaying learning rate 73 | (with hard restarts). 74 | """ 75 | def __init__(self, optimizer, warmup_steps, t_total, cycles=1., last_epoch=-1): 76 | self.warmup_steps = warmup_steps 77 | self.t_total = t_total 78 | self.cycles = cycles 79 | super(WarmupCosineWithHardRestartsSchedule, self).__init__( 80 | optimizer, self.lr_lambda, last_epoch=last_epoch) 81 | 82 | def lr_lambda(self, step): 83 | if step < self.warmup_steps: 84 | return float(step) / float(max(1, self.warmup_steps)) 85 | # progress after warmup 86 | progress = float(step - self.warmup_steps) / float( 87 | max(1, self.t_total - self.warmup_steps)) 88 | if progress >= 1.0: 89 | return 0.0 90 | return max( 91 | 0.0, 92 | 0.5 * (1. + math.cos( 93 | math.pi * ((float(self.cycles) * progress) % 1.0))) 94 | ) 95 | -------------------------------------------------------------------------------- /src/utils/distributed.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """Distributed helpers.""" 4 | 5 | import torch 6 | import torch.distributed as dist 7 | _LOCAL_PROCESS_GROUP = None 8 | 9 | 10 | def get_world_size() -> int: 11 | if not dist.is_available(): 12 | return 1 13 | if not dist.is_initialized(): 14 | return 1 15 | return dist.get_world_size() 16 | 17 | 18 | def get_rank() -> int: 19 | if not dist.is_available(): 20 | return 0 21 | if not dist.is_initialized(): 22 | return 0 23 | return dist.get_rank() 24 | 25 | 26 | def is_master_process(num_gpus=8): 27 | """ 28 | Determines if the current process is the master process. 29 | """ 30 | if torch.distributed.is_initialized(): 31 | return dist.get_rank() % num_gpus == 0 32 | else: 33 | return True 34 | 35 | 36 | def run( 37 | local_rank, 38 | num_proc, 39 | func, 40 | init_method, 41 | shard_id, 42 | num_shards, 43 | backend, 44 | cfg, 45 | args, 46 | ): 47 | """ 48 | Runs a function from a child process. 49 | Args: 50 | local_rank (int): rank of the current process on the current machine. 51 | num_proc (int): number of processes per machine. 52 | func (function): function to execute on each of the process. 53 | init_method (string): method to initialize the distributed training. 54 | TCP initialization: equiring a network address reachable from all 55 | processes followed by the port. 56 | Shared file-system initialization: makes use of a file system that 57 | is shared and visible from all machines. 
The URL should start with 58 | file:// and contain a path to a non-existent file on a shared file 59 | system. 60 | shard_id (int): the rank of the current machine. 61 | num_shards (int): number of overall machines for the distributed 62 | training job. 63 | backend (string): three distributed backends ('nccl', 'gloo', 'mpi') are 64 | supports, each with different capabilities. Details can be found 65 | here: 66 | https://pytorch.org/docs/stable/distributed.html 67 | cfg (CfgNode): configs. Details can be found in 68 | loco/config/defaults.py 69 | """ 70 | # Initialize the process group. 71 | # shard_id = get_rank() 72 | world_size = num_proc * num_shards 73 | rank = shard_id * num_proc + local_rank 74 | 75 | try: 76 | torch.distributed.init_process_group( 77 | backend=backend, 78 | init_method=init_method, 79 | world_size=world_size, 80 | rank=rank, 81 | ) 82 | except Exception as e: 83 | raise e 84 | 85 | torch.cuda.set_device(local_rank) 86 | func(cfg, args) 87 | 88 | 89 | def destroy_process_group(): 90 | """Destroys the default process group.""" 91 | torch.distributed.destroy_process_group() 92 | 93 | 94 | def scaled_all_reduce(cfg, tensors): 95 | """Performs the scaled all_reduce operation on the provided tensors. 96 | 97 | The input tensors are modified in-place. Currently supports only the sum 98 | reduction operator. The reduced values are scaled by the inverse size of 99 | the process group (equivalent to cfg.NUM_GPUS). 100 | """ 101 | # Queue the reductions 102 | reductions = [] 103 | for tensor in tensors: 104 | reduction = torch.distributed.all_reduce(tensor, async_op=True) 105 | reductions.append(reduction) 106 | # Wait for reductions to finish 107 | for reduction in reductions: 108 | reduction.wait() 109 | # Scale the results 110 | for tensor in tensors: 111 | tensor.mul_(1.0 / cfg.NUM_GPUS / cfg.NUM_SHARDS) 112 | return tensors 113 | 114 | 115 | def cat_all_gather(tensors): 116 | """Performs the concatenated all_gather operation on the provided tensors. 117 | """ 118 | tensors_gather = [ 119 | torch.ones_like(tensors) 120 | for _ in range(torch.distributed.get_world_size()) 121 | ] 122 | torch.distributed.all_gather(tensors_gather, tensors, async_op=False) 123 | 124 | output = torch.cat(tensors_gather, dim=0) 125 | return output 126 | 127 | 128 | def local_cat_all_gather(tensors): 129 | """Performs the concatenated all_gather operation on the provided tensors. 130 | """ 131 | tensors_gather = [ 132 | torch.ones_like(tensors) 133 | for _ in range(get_local_size()) 134 | ] 135 | torch.distributed.all_gather( 136 | tensors_gather, 137 | tensors, 138 | async_op=False, 139 | group=_LOCAL_PROCESS_GROUP, 140 | ) 141 | output = torch.cat(tensors_gather, dim=0) 142 | return output 143 | 144 | 145 | def get_local_size(): 146 | """ 147 | Returns: 148 | The size of the per-machine process group, 149 | i.e. the number of processes per machine. 150 | """ 151 | if not dist.is_available(): 152 | return 1 153 | if not dist.is_initialized(): 154 | return 1 155 | return dist.get_world_size(group=_LOCAL_PROCESS_GROUP) 156 | 157 | 158 | def get_local_rank(): 159 | """ 160 | Returns: 161 | The rank of the current process within the local (per-machine) process group. 
162 | """ 163 | if not dist.is_available(): 164 | return 0 165 | if not dist.is_initialized(): 166 | return 0 167 | assert _LOCAL_PROCESS_GROUP is not None 168 | return dist.get_rank(group=_LOCAL_PROCESS_GROUP) 169 | -------------------------------------------------------------------------------- /src/utils/file_io.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """ 4 | Project specific pathmanagers for a project as recommended by Detectron2 5 | """ 6 | from iopath.common.file_io import PathManager as PathManagerBase 7 | from iopath.common.file_io import HTTPURLHandler 8 | 9 | 10 | PathManager = PathManagerBase() 11 | PathManager.register_handler(HTTPURLHandler()) 12 | -------------------------------------------------------------------------------- /src/utils/io_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | a bunch of helper functions for read and write data 4 | """ 5 | import os 6 | import json 7 | import numpy as np 8 | import time 9 | import pandas as pd 10 | 11 | from typing import List, Union 12 | from PIL import Image, ImageFile 13 | Image.MAX_IMAGE_PIXELS = None 14 | 15 | 16 | def save_or_append_df(out_path, df): 17 | if os.path.exists(out_path): 18 | previous_df = pd.read_pickle(out_path) 19 | df = pd.concat([previous_df, df], ignore_index=True) 20 | df.to_pickle(out_path) 21 | print(f"Saved output at {out_path}") 22 | 23 | 24 | class JSONEncoder(json.JSONEncoder): 25 | def default(self, obj): 26 | if isinstance(obj, np.ndarray): 27 | return obj.tolist() 28 | elif isinstance(obj, bytes): 29 | return str(obj, encoding='utf-8') 30 | elif isinstance(obj, np.integer): 31 | return int(obj) 32 | elif isinstance(obj, np.floating): 33 | return float(obj) 34 | elif isinstance(obj, np.ndarray): 35 | return obj.tolist() 36 | else: 37 | # return super(MyEncoder, self).default(obj) 38 | 39 | raise TypeError( 40 | "Unserializable object {} of type {}".format(obj, type(obj)) 41 | ) 42 | 43 | 44 | def write_json(data: Union[list, dict], outfile: str) -> None: 45 | json_dir, _ = os.path.split(outfile) 46 | if json_dir and not os.path.exists(json_dir): 47 | os.makedirs(json_dir) 48 | 49 | with open(outfile, 'w') as f: 50 | json.dump(data, f, cls=JSONEncoder, ensure_ascii=False, indent=2) 51 | 52 | 53 | def read_json(filename: str) -> Union[list, dict]: 54 | """read json files""" 55 | with open(filename, "rb") as fin: 56 | data = json.load(fin, encoding="utf-8") 57 | return data 58 | 59 | 60 | def pil_loader(path: str) -> Image.Image: 61 | """load an image from path, and suppress warning""" 62 | # to avoid crashing for truncated (corrupted images) 63 | ImageFile.LOAD_TRUNCATED_IMAGES = True 64 | # open path as file to avoid ResourceWarning 65 | # (https://github.com/python-pillow/Pillow/issues/835) 66 | with open(path, 'rb') as f: 67 | img = Image.open(f) 68 | return img.convert('RGB') 69 | -------------------------------------------------------------------------------- /src/utils/logging.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """Logging.""" 4 | 5 | import builtins 6 | import decimal 7 | import functools 8 | import logging 9 | import simplejson 10 | import sys 11 | import os 12 | from termcolor import colored 13 | 14 | from .distributed import is_master_process 15 | from .file_io import PathManager 16 | 17 | # Show filename and line number in logs 18 | _FORMAT = 
"[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s" 19 | 20 | 21 | def _suppress_print(): 22 | """Suppresses printing from the current process.""" 23 | 24 | def print_pass(*objects, sep=" ", end="\n", file=sys.stdout, flush=False): 25 | pass 26 | 27 | builtins.print = print_pass 28 | 29 | 30 | # cache the opened file object, so that different calls to `setup_logger` 31 | # with the same file name can safely write to the same file. 32 | @functools.lru_cache(maxsize=None) 33 | def _cached_log_stream(filename): 34 | return PathManager.open(filename, "a") 35 | 36 | 37 | @functools.lru_cache() # so that calling setup_logger multiple times won't add many handlers # noqa 38 | def setup_logging( 39 | num_gpu, num_shards, output="", name="visual_prompt", color=True): 40 | """Sets up the logging.""" 41 | # Enable logging only for the master process 42 | if is_master_process(num_gpu): 43 | # Clear the root logger to prevent any existing logging config 44 | # (e.g. set by another module) from messing with our setup 45 | logging.root.handlers = [] 46 | # Configure logging 47 | logging.basicConfig( 48 | level=logging.INFO, format=_FORMAT, stream=sys.stdout 49 | ) 50 | else: 51 | _suppress_print() 52 | 53 | if name is None: 54 | name = __name__ 55 | logger = logging.getLogger(name) 56 | # remove any lingering handler 57 | logger.handlers.clear() 58 | 59 | logger.setLevel(logging.INFO) 60 | logger.propagate = False 61 | 62 | plain_formatter = logging.Formatter( 63 | "[%(asctime)s][%(levelname)s] %(name)s: %(lineno)4d: %(message)s", 64 | datefmt="%m/%d %H:%M:%S", 65 | ) 66 | if color: 67 | formatter = _ColorfulFormatter( 68 | colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s", 69 | datefmt="%m/%d %H:%M:%S", 70 | root_name=name, 71 | abbrev_name=str(name), 72 | ) 73 | else: 74 | formatter = plain_formatter 75 | 76 | if is_master_process(num_gpu): 77 | ch = logging.StreamHandler(stream=sys.stdout) 78 | ch.setLevel(logging.DEBUG) 79 | ch.setFormatter(formatter) 80 | logger.addHandler(ch) 81 | 82 | if is_master_process(num_gpu * num_shards): 83 | if len(output) > 0: 84 | if output.endswith(".txt") or output.endswith(".log"): 85 | filename = output 86 | else: 87 | filename = os.path.join(output, "logs.txt") 88 | 89 | PathManager.mkdirs(os.path.dirname(filename)) 90 | 91 | fh = logging.StreamHandler(_cached_log_stream(filename)) 92 | fh.setLevel(logging.DEBUG) 93 | fh.setFormatter(plain_formatter) 94 | logger.addHandler(fh) 95 | return logger 96 | 97 | 98 | def setup_single_logging(name, output=""): 99 | """Sets up the logging.""" 100 | # Enable logging only for the master process 101 | # Clear the root logger to prevent any existing logging config 102 | # (e.g. 
set by another module) from messing with our setup 103 | logging.root.handlers = [] 104 | # Configure logging 105 | logging.basicConfig( 106 | level=logging.INFO, format=_FORMAT, stream=sys.stdout 107 | ) 108 | 109 | if len(name) == 0: 110 | name = __name__ 111 | logger = logging.getLogger(name) 112 | logger.setLevel(logging.INFO) 113 | logger.propagate = False 114 | 115 | plain_formatter = logging.Formatter( 116 | "[%(asctime)s][%(levelname)s] %(name)s: %(lineno)4d: %(message)s", 117 | datefmt="%m/%d %H:%M:%S", 118 | ) 119 | formatter = _ColorfulFormatter( 120 | colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s", 121 | datefmt="%m/%d %H:%M:%S", 122 | root_name=name, 123 | abbrev_name=str(name), 124 | ) 125 | 126 | ch = logging.StreamHandler(stream=sys.stdout) 127 | ch.setLevel(logging.DEBUG) 128 | ch.setFormatter(formatter) 129 | logger.addHandler(ch) 130 | 131 | if len(output) > 0: 132 | if output.endswith(".txt") or output.endswith(".log"): 133 | filename = output 134 | else: 135 | filename = os.path.join(output, "logs.txt") 136 | 137 | PathManager.mkdirs(os.path.dirname(filename)) 138 | 139 | fh = logging.StreamHandler(_cached_log_stream(filename)) 140 | fh.setLevel(logging.DEBUG) 141 | fh.setFormatter(plain_formatter) 142 | logger.addHandler(fh) 143 | 144 | return logger 145 | 146 | 147 | def get_logger(name): 148 | """Retrieves the logger.""" 149 | return logging.getLogger(name) 150 | 151 | 152 | def log_json_stats(stats, sort_keys=True): 153 | """Logs json stats.""" 154 | # It seems that in Python >= 3.6 json.encoder.FLOAT_REPR has no effect 155 | # Use decimal+string as a workaround for having fixed length values in logs 156 | logger = get_logger(__name__) 157 | stats = { 158 | k: decimal.Decimal("{:.6f}".format(v)) if isinstance(v, float) else v 159 | for k, v in stats.items() 160 | } 161 | json_stats = simplejson.dumps(stats, sort_keys=True, use_decimal=True) 162 | if stats["_type"] == "test_epoch" or stats["_type"] == "train_epoch": 163 | logger.info("json_stats: {:s}".format(json_stats)) 164 | else: 165 | logger.info("{:s}".format(json_stats)) 166 | 167 | 168 | class _ColorfulFormatter(logging.Formatter): 169 | # from detectron2 170 | def __init__(self, *args, **kwargs): 171 | self._root_name = kwargs.pop("root_name") + "." 172 | self._abbrev_name = kwargs.pop("abbrev_name", "") 173 | if len(self._abbrev_name): 174 | self._abbrev_name = self._abbrev_name + "." 
175 | super(_ColorfulFormatter, self).__init__(*args, **kwargs) 176 | 177 | def formatMessage(self, record: logging.LogRecord) -> str: 178 | record.name = record.name.replace(self._root_name, self._abbrev_name) 179 | log = super(_ColorfulFormatter, self).formatMessage(record) 180 | if record.levelno == logging.WARNING: 181 | prefix = colored("WARNING", "red", attrs=["blink"]) 182 | elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL: 183 | prefix = colored("ERROR", "red", attrs=["blink", "underline"]) 184 | else: 185 | return log 186 | return prefix + " " + log 187 | -------------------------------------------------------------------------------- /src/utils/train_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import torch 3 | 4 | 5 | def gpu_mem_usage(): 6 | """Computes the GPU memory usage for the current device (GB).""" 7 | if not torch.cuda.is_available(): 8 | return 0 9 | # Number of bytes in a gigabyte 10 | _B_IN_GB = 1024 * 1024 * 1024 11 | 12 | mem_usage_bytes = torch.cuda.max_memory_allocated() 13 | return mem_usage_bytes / _B_IN_GB 14 | 15 | 16 | class AverageMeter(object): 17 | """Computes and stores the average and current value""" 18 | def __init__(self, name, fmt=':f'): 19 | self.name = name 20 | self.fmt = fmt 21 | self.reset() 22 | 23 | def reset(self): 24 | self.val = 0 25 | self.avg = 0 26 | self.sum = 0 27 | self.count = 0 28 | 29 | def update(self, val, n=1): 30 | self.val = val 31 | self.sum += val * n 32 | self.count += n 33 | self.avg = self.sum / self.count 34 | 35 | def __str__(self): 36 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 37 | return fmtstr.format(**self.__dict__) 38 | -------------------------------------------------------------------------------- /vtab_data/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | --------------------------------------------------------------------------------