├── figure └── demo.gif ├── configs ├── linear │ ├── cub.yaml │ ├── dogs.yaml │ ├── flowers.yaml │ ├── cars.yaml │ └── nabirds.yaml ├── finetune │ ├── nabirds.yaml │ ├── cub.yaml │ ├── cars.yaml │ ├── dogs.yaml │ ├── flowers.yaml │ └── multitask.yaml ├── base-linear.yaml └── base-finetune.yaml ├── src ├── utils │ ├── file_io.py │ ├── train_utils.py │ ├── build_pruner.py │ ├── io_utils.py │ ├── pruner.py │ ├── distributed.py │ └── logging.py ├── data │ ├── vtab_datasets │ │ ├── __init__.py │ │ ├── patch_camelyon.py │ │ ├── sun397.py │ │ ├── dtd.py │ │ ├── dmlab.py │ │ ├── caltech.py │ │ ├── resisc45.py │ │ ├── oxford_iiit_pet.py │ │ ├── eurosat.py │ │ ├── cifar.py │ │ ├── smallnorb.py │ │ ├── oxford_flowers102.py │ │ ├── svhn.py │ │ ├── clevr.py │ │ ├── dsprites.py │ │ ├── registry.py │ │ └── kitti.py │ ├── transforms.py │ ├── loader.py │ └── datasets │ │ ├── tf_dataset.py │ │ ├── pt_dataset.py │ │ └── json_dataset.py ├── configs │ ├── config_node.py │ ├── vit_configs.py │ └── config.py ├── models │ ├── mlp.py │ ├── build_model.py │ ├── vit_adapter │ │ ├── adapter_block.py │ │ ├── vit_moco.py │ │ ├── vit_mae.py │ │ └── swin_adapter.py │ └── vit_backbones │ │ ├── vit_mae.py │ │ └── vit_moco.py ├── engine │ ├── eval │ │ ├── singlelabel.py │ │ └── multilabel.py │ └── evaluator.py └── solver │ ├── losses.py │ └── lr_scheduler.py ├── env_setup.sh ├── scripts └── mosa │ ├── swin │ └── FGVC │ │ ├── cars.sh │ │ ├── cub.sh │ │ ├── dogs.sh │ │ ├── nabirds.sh │ │ └── flowers.sh │ └── vit │ ├── FGVC │ ├── cars.sh │ ├── cub.sh │ ├── dogs.sh │ ├── nabirds.sh │ ├── flowers.sh │ └── multi.sh │ ├── GICD │ ├── food101.sh │ ├── aircraft.sh │ └── cifar100.sh │ └── VTAB │ ├── vtab_specialized.sh │ ├── vtab_natural.sh │ ├── vtab_structured.sh │ └── vtab.sh ├── README.md ├── .gitignore ├── launch.py └── VTAB_SETUP.md /figure/demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Theia-4869/MoSA/HEAD/figure/demo.gif -------------------------------------------------------------------------------- /configs/linear/cub.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-linear.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "CUB" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 200 7 | MULTILABEL: False 8 | MODEL: 9 | TYPE: "vit" 10 | SOLVER: 11 | BASE_LR: 0.1 12 | WEIGHT_DECAY: 0.01 -------------------------------------------------------------------------------- /configs/linear/dogs.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-linear.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "StanfordDogs" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 120 7 | MULTILABEL: False 8 | MODEL: 9 | TYPE: "vit" 10 | SOLVER: 11 | BASE_LR: 0.001 12 | WEIGHT_DECAY: 0.0001 -------------------------------------------------------------------------------- /configs/linear/flowers.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-linear.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "OxfordFlowers" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 102 7 | MULTILABEL: False 8 | MODEL: 9 | TYPE: "vit" 10 | SOLVER: 11 | BASE_LR: 0.001 12 | WEIGHT_DECAY: 0.0001 -------------------------------------------------------------------------------- /configs/finetune/nabirds.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: 
"../base-finetune.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "nabirds" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 555 7 | MULTILABEL: False 8 | MODEL: 9 | TYPE: "vit" 10 | SOLVER: 11 | BASE_LR: 0.00375 12 | WEIGHT_DECAY: 0.01 13 | -------------------------------------------------------------------------------- /configs/linear/cars.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-linear.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "StanfordCars" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 196 7 | MULTILABEL: False 8 | MODEL: 9 | TYPE: "vit" 10 | SOLVER: 11 | BASE_LR: 0.001 12 | WEIGHT_DECAY: 0.0001 13 | -------------------------------------------------------------------------------- /configs/linear/nabirds.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-linear.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "nabirds" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 555 7 | MULTILABEL: False 8 | MODEL: 9 | TYPE: "vit" 10 | SOLVER: 11 | BASE_LR: 0.1 12 | WEIGHT_DECAY: 0.01 13 | LOG_EVERY_N: 10 -------------------------------------------------------------------------------- /configs/finetune/cub.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-finetune.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "CUB" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 200 7 | MULTILABEL: False 8 | FEATURE: "imagenet_supervised" # need to tune 9 | MODEL: 10 | TYPE: "vit" 11 | SOLVER: 12 | BASE_LR: 0.00375 13 | WEIGHT_DECAY: 0.01 14 | -------------------------------------------------------------------------------- /configs/finetune/cars.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-finetune.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "StanfordCars" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 196 7 | MULTILABEL: False 8 | FEATURE: "imagenet_supervised" # need to tune 9 | MODEL: 10 | TYPE: "vit" 11 | SOLVER: 12 | BASE_LR: 0.0375 13 | WEIGHT_DECAY: 0.001 14 | -------------------------------------------------------------------------------- /configs/finetune/dogs.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-finetune.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "StanfordDogs" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 120 7 | MULTILABEL: False 8 | FEATURE: "imagenet_supervised" # need to tune 9 | MODEL: 10 | TYPE: "vit" 11 | SOLVER: 12 | BASE_LR: 0.00375 13 | WEIGHT_DECAY: 0.001 14 | -------------------------------------------------------------------------------- /configs/finetune/flowers.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-finetune.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "OxfordFlowers" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 102 7 | MULTILABEL: False 8 | FEATURE: "imagenet_supervised" # need to tune 9 | MODEL: 10 | TYPE: "vit" 11 | SOLVER: 12 | BASE_LR: 0.001 13 | WEIGHT_DECAY: 0.0001 14 | -------------------------------------------------------------------------------- /src/utils/file_io.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """ 4 | Project specific pathmanagers for a project as recommended by Detectron2 5 | """ 6 | from 
iopath.common.file_io import PathManager as PathManagerBase 7 | from iopath.common.file_io import HTTPURLHandler 8 | 9 | 10 | PathManager = PathManagerBase() 11 | PathManager.register_handler(HTTPURLHandler()) 12 | -------------------------------------------------------------------------------- /configs/finetune/multitask.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../base-finetune.yaml" 2 | RUN_N_TIMES: 1 3 | DATA: 4 | NAME: "multi" 5 | DATAPATH: "" #TODO: need to specify here 6 | NUMBER_CLASSES: 1173 # 200+55+102+120+196 7 | MULTILABEL: False 8 | FEATURE: "imagenet_supervised" # need to tune 9 | MODEL: 10 | TYPE: "vit" 11 | SOLVER: 12 | BASE_LR: 0.00375 13 | WEIGHT_DECAY: 0.01 14 | -------------------------------------------------------------------------------- /configs/base-linear.yaml: -------------------------------------------------------------------------------- 1 | NUM_GPUS: 1 2 | NUM_SHARDS: 1 3 | OUTPUT_DIR: "" 4 | RUN_N_TIMES: 1 5 | MODEL: 6 | TRANSFER_TYPE: "linear" 7 | TYPE: "vit" 8 | LINEAR: 9 | MLP_SIZES: [] 10 | SOLVER: 11 | SCHEDULER: "cosine" 12 | PATIENCE: 300 13 | LOSS: "softmax" 14 | OPTIMIZER: "sgd" 15 | MOMENTUM: 0.9 16 | WEIGHT_DECAY: 0.0001 17 | LOG_EVERY_N: 100 18 | WARMUP_EPOCH: 10 19 | TOTAL_EPOCH: 100 20 | DATA: 21 | NAME: "" 22 | NUMBER_CLASSES: -1 23 | DATAPATH: "" 24 | FEATURE: "sup_vitb16_224" 25 | BATCH_SIZE: 1024 -------------------------------------------------------------------------------- /configs/base-finetune.yaml: -------------------------------------------------------------------------------- 1 | NUM_GPUS: 1 2 | NUM_SHARDS: 1 3 | OUTPUT_DIR: "" 4 | RUN_N_TIMES: 1 5 | MODEL: 6 | TRANSFER_TYPE: "end2end" 7 | TYPE: "vit" 8 | LINEAR: 9 | MLP_SIZES: [] 10 | SOLVER: 11 | SCHEDULER: "cosine" 12 | WARMUP_EPOCH: 10 13 | LOSS: "softmax" 14 | OPTIMIZER: "adamw" 15 | MOMENTUM: 0.9 16 | WEIGHT_DECAY: 0.0001 17 | LOG_EVERY_N: 100 18 | TOTAL_EPOCH: 100 19 | PATIENCE: 300 20 | DATA: 21 | NAME: "" 22 | NUMBER_CLASSES: -1 23 | DATAPATH: "" 24 | FEATURE: "sup_vitb16_224" 25 | BATCH_SIZE: 384 26 | NUM_WORKERS: 8 -------------------------------------------------------------------------------- /src/data/vtab_datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | -------------------------------------------------------------------------------- /src/configs/config_node.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """Config system (based on Detectron's).""" 4 | 5 | from fvcore.common.config import CfgNode as _CfgNode 6 | from ..utils.file_io import PathManager 7 | 8 | 9 | class CfgNode(_CfgNode): 10 | """ 11 | The same as `fvcore.common.config.CfgNode`, but different in: 12 | 13 | support manifold path 14 | """ 15 | 16 | @classmethod 17 | def _open_cfg(cls, filename): 18 | return PathManager.open(filename, "r") 19 | 20 | def dump(self, *args, **kwargs): 21 | """ 22 | Returns: 23 | str: a yaml string representation of the config 24 | """ 25 | # to make it show up in docs 26 | return super().dump(*args, **kwargs) 27 | -------------------------------------------------------------------------------- /env_setup.sh: -------------------------------------------------------------------------------- 1 | conda create -n mosa python=3.7 2 | conda activate mosa 3 | 4 | pip install -q tensorflow 5 | # specifying tfds versions is important to reproduce our results 6 | pip install tfds-nightly==4.4.0.dev202201080107 7 | pip install opencv-python 8 | pip install tensorflow-addons 9 | pip install mock 10 | 11 | 12 | conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch 13 | 14 | python -m pip install detectron2 -f \ 15 | https://dl.fbaipublicfiles.com/detectron2/wheels/cu110/torch1.7/index.html 16 | pip install opencv-python 17 | 18 | conda install tqdm pandas matplotlib seaborn scikit-learn scipy simplejson termcolor 19 | conda install -c iopath iopath 20 | 21 | 22 | # for transformers 23 | pip install timm==0.4.12 24 | pip install ml-collections 25 | 26 | # Optional: for slurm jobs 27 | pip install submitit -U 28 | pip install slurm_gpustat -------------------------------------------------------------------------------- /scripts/mosa/swin/FGVC/cars.sh: -------------------------------------------------------------------------------- 1 | # launch final training for FGVC-cars. 2 | 3 | gpu_id=${1} 4 | data_path=/your/path/to/FGVC/StanfordCars 5 | model_root=checkpoints 6 | output_dir=${2} 7 | 8 | CUDA_VISIBLE_DEVICES=${gpu_id} python train.py \ 9 | --config-file configs/finetune/cars.yaml \ 10 | --sparse-train \ 11 | DATA.BATCH_SIZE "128" \ 12 | DATA.FEATURE "swinb_imagenet22k_224" \ 13 | MODEL.ADAPTER.BOTTLENECK_SIZE "64" \ 14 | MODEL.ADAPTER.EXPERT_NUM "2" \ 15 | MODEL.ADAPTER.MOE "True" \ 16 | MODEL.ADAPTER.MERGE "add" \ 17 | MODEL.ADAPTER.SHARE "up" \ 18 | MODEL.ADAPTER.ADDITIONAL "True" \ 19 | MODEL.ADAPTER.DEEPREG "True" \ 20 | MODEL.ADAPTER.ADD_WEIGHT "0.0" \ 21 | MODEL.ADAPTER.REG_WEIGHT "1.0" \ 22 | MODEL.TRANSFER_TYPE "mosa" \ 23 | MODEL.TYPE "swin" \ 24 | SEED "3407" \ 25 | SOLVER.BASE_LR "0.005" \ 26 | SOLVER.WEIGHT_DECAY "0.01" \ 27 | SOLVER.WARMUP_EPOCH "10" \ 28 | DATA.DATAPATH "${data_path}" \ 29 | GPU_ID "${gpu_id}" \ 30 | MODEL.MODEL_ROOT "${model_root}" \ 31 | OUTPUT_DIR "${output_dir}" 32 | -------------------------------------------------------------------------------- /scripts/mosa/swin/FGVC/cub.sh: -------------------------------------------------------------------------------- 1 | # launch final training for FGVC-cub. 
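# expected arguments (from the positional parameters below): bash scripts/mosa/swin/FGVC/cub.sh <gpu_id> <output_dir>
# edit data_path to your local CUB_200_2011 folder before launching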
2 | 3 | gpu_id=${1} 4 | data_path=/your/path/to/FGVC/CUB_200_2011 5 | model_root=checkpoints 6 | output_dir=${2} 7 | 8 | CUDA_VISIBLE_DEVICES=${gpu_id} python train.py \ 9 | --config-file configs/finetune/cub.yaml \ 10 | --sparse-train \ 11 | DATA.BATCH_SIZE "128" \ 12 | DATA.FEATURE "swinb_imagenet22k_224" \ 13 | MODEL.ADAPTER.BOTTLENECK_SIZE "64" \ 14 | MODEL.ADAPTER.EXPERT_NUM "4" \ 15 | MODEL.ADAPTER.MOE "True" \ 16 | MODEL.ADAPTER.MERGE "add" \ 17 | MODEL.ADAPTER.SHARE "down" \ 18 | MODEL.ADAPTER.ADDITIONAL "True" \ 19 | MODEL.ADAPTER.DEEPREG "False" \ 20 | MODEL.ADAPTER.ADD_WEIGHT "0.0" \ 21 | MODEL.ADAPTER.REG_WEIGHT "1.0" \ 22 | MODEL.TRANSFER_TYPE "mosa" \ 23 | MODEL.TYPE "swin" \ 24 | SEED "3407" \ 25 | SOLVER.BASE_LR "0.001" \ 26 | SOLVER.WEIGHT_DECAY "0.01" \ 27 | SOLVER.WARMUP_EPOCH "10" \ 28 | DATA.DATAPATH "${data_path}" \ 29 | GPU_ID "${gpu_id}" \ 30 | MODEL.MODEL_ROOT "${model_root}" \ 31 | OUTPUT_DIR "${output_dir}" 32 | -------------------------------------------------------------------------------- /scripts/mosa/swin/FGVC/dogs.sh: -------------------------------------------------------------------------------- 1 | # launch final training for FGVC-dogs. 2 | 3 | gpu_id=${1} 4 | data_path=/your/path/to/FGVC/StanfordDogs 5 | model_root=checkpoints 6 | output_dir=${2} 7 | 8 | CUDA_VISIBLE_DEVICES=${gpu_id} python train.py \ 9 | --config-file configs/finetune/dogs.yaml \ 10 | --sparse-train \ 11 | DATA.BATCH_SIZE "128" \ 12 | DATA.FEATURE "swinb_imagenet22k_224" \ 13 | MODEL.ADAPTER.BOTTLENECK_SIZE "64" \ 14 | MODEL.ADAPTER.EXPERT_NUM "4" \ 15 | MODEL.ADAPTER.MOE "True" \ 16 | MODEL.ADAPTER.MERGE "add" \ 17 | MODEL.ADAPTER.SHARE "down" \ 18 | MODEL.ADAPTER.ADDITIONAL "True" \ 19 | MODEL.ADAPTER.DEEPREG "False" \ 20 | MODEL.ADAPTER.ADD_WEIGHT "0.0" \ 21 | MODEL.ADAPTER.REG_WEIGHT "1.0" \ 22 | MODEL.TRANSFER_TYPE "mosa" \ 23 | MODEL.TYPE "swin" \ 24 | SEED "3407" \ 25 | SOLVER.BASE_LR "0.0005" \ 26 | SOLVER.WEIGHT_DECAY "0.01" \ 27 | SOLVER.WARMUP_EPOCH "10" \ 28 | DATA.DATAPATH "${data_path}" \ 29 | GPU_ID "${gpu_id}" \ 30 | MODEL.MODEL_ROOT "${model_root}" \ 31 | OUTPUT_DIR "${output_dir}" 32 | -------------------------------------------------------------------------------- /scripts/mosa/swin/FGVC/nabirds.sh: -------------------------------------------------------------------------------- 1 | # launch final training for FGVC-nabirds. 
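# expected arguments (from the positional parameters below): bash scripts/mosa/swin/FGVC/nabirds.sh <gpu_id> <output_dir>
# edit data_path to your local NABirds folder before launching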
2 | 3 | gpu_id=${1} 4 | data_path=/your/path/to/FGVC/nabirds 5 | model_root=checkpoints 6 | output_dir=${2} 7 | 8 | CUDA_VISIBLE_DEVICES=${gpu_id} python train.py \ 9 | --config-file configs/finetune/nabirds.yaml \ 10 | --sparse-train \ 11 | DATA.BATCH_SIZE "128" \ 12 | DATA.FEATURE "swinb_imagenet22k_224" \ 13 | MODEL.ADAPTER.BOTTLENECK_SIZE "64" \ 14 | MODEL.ADAPTER.EXPERT_NUM "4" \ 15 | MODEL.ADAPTER.MOE "True" \ 16 | MODEL.ADAPTER.MERGE "add" \ 17 | MODEL.ADAPTER.SHARE "down" \ 18 | MODEL.ADAPTER.ADDITIONAL "True" \ 19 | MODEL.ADAPTER.DEEPREG "False" \ 20 | MODEL.ADAPTER.ADD_WEIGHT "0.0" \ 21 | MODEL.ADAPTER.REG_WEIGHT "1.0" \ 22 | MODEL.TRANSFER_TYPE "mosa" \ 23 | MODEL.TYPE "swin" \ 24 | SEED "3407" \ 25 | SOLVER.BASE_LR "0.0005" \ 26 | SOLVER.WEIGHT_DECAY "0.01" \ 27 | SOLVER.WARMUP_EPOCH "10" \ 28 | DATA.DATAPATH "${data_path}" \ 29 | GPU_ID "${gpu_id}" \ 30 | MODEL.MODEL_ROOT "${model_root}" \ 31 | OUTPUT_DIR "${output_dir}" 32 | -------------------------------------------------------------------------------- /scripts/mosa/swin/FGVC/flowers.sh: -------------------------------------------------------------------------------- 1 | # launch final training for FGVC-flowers. 2 | 3 | gpu_id=${1} 4 | data_path=/your/path/to/FGVC/OxfordFlowers 5 | model_root=checkpoints 6 | output_dir=${2} 7 | 8 | CUDA_VISIBLE_DEVICES=${gpu_id} python train.py \ 9 | --config-file configs/finetune/flowers.yaml \ 10 | --sparse-train \ 11 | DATA.BATCH_SIZE "128" \ 12 | DATA.FEATURE "swinb_imagenet22k_224" \ 13 | MODEL.ADAPTER.BOTTLENECK_SIZE "64" \ 14 | MODEL.ADAPTER.EXPERT_NUM "4" \ 15 | MODEL.ADAPTER.MOE "True" \ 16 | MODEL.ADAPTER.MERGE "add" \ 17 | MODEL.ADAPTER.SHARE "down" \ 18 | MODEL.ADAPTER.ADDITIONAL "True" \ 19 | MODEL.ADAPTER.DEEPREG "False" \ 20 | MODEL.ADAPTER.ADD_WEIGHT "0.0" \ 21 | MODEL.ADAPTER.REG_WEIGHT "1.0" \ 22 | MODEL.TRANSFER_TYPE "mosa" \ 23 | MODEL.TYPE "swin" \ 24 | SEED "3407" \ 25 | SOLVER.BASE_LR "0.005" \ 26 | SOLVER.WEIGHT_DECAY "0.01" \ 27 | SOLVER.WARMUP_EPOCH "10" \ 28 | DATA.DATAPATH "${data_path}" \ 29 | GPU_ID "${gpu_id}" \ 30 | MODEL.MODEL_ROOT "${model_root}" \ 31 | OUTPUT_DIR "${output_dir}" 32 | -------------------------------------------------------------------------------- /src/utils/train_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import torch 3 | 4 | 5 | def gpu_mem_usage(): 6 | """Computes the GPU memory usage for the current device (GB).""" 7 | if not torch.cuda.is_available(): 8 | return 0 9 | # Number of bytes in a megabyte 10 | _B_IN_GB = 1024 * 1024 * 1024 11 | 12 | mem_usage_bytes = torch.cuda.max_memory_allocated() 13 | return mem_usage_bytes / _B_IN_GB 14 | 15 | 16 | class AverageMeter(object): 17 | """Computes and stores the average and current value""" 18 | def __init__(self, name, fmt=':f'): 19 | self.name = name 20 | self.fmt = fmt 21 | self.reset() 22 | 23 | def reset(self): 24 | self.val = 0 25 | self.avg = 0 26 | self.sum = 0 27 | self.count = 0 28 | 29 | def update(self, val, n=1): 30 | self.val = val 31 | self.sum += val * n 32 | self.count += n 33 | self.avg = self.sum / self.count 34 | 35 | def __str__(self): 36 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 37 | return fmtstr.format(**self.__dict__) 38 | -------------------------------------------------------------------------------- /scripts/mosa/vit/FGVC/cars.sh: -------------------------------------------------------------------------------- 1 | # launch final training for FGVC-cars. 
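# expected arguments (from the positional parameters below): bash scripts/mosa/vit/FGVC/cars.sh <gpu_id> <adapter_style> <output_dir>
# <adapter_style> is forwarded to MODEL.ADAPTER.STYLE; edit data_path to your local StanfordCars folder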
2 | 3 | gpu_id=${1} 4 | data_path=/your/path/to/FGVC/StanfordCars 5 | model_root=checkpoints 6 | style=${2} 7 | output_dir=${3} 8 | 9 | CUDA_VISIBLE_DEVICES=${gpu_id} python train.py \ 10 | --config-file configs/finetune/cars.yaml \ 11 | --sparse-train \ 12 | DATA.BATCH_SIZE "128" \ 13 | DATA.FEATURE "sup_vitb16_imagenet21k" \ 14 | MODEL.ADAPTER.BOTTLENECK_SIZE "64" \ 15 | MODEL.ADAPTER.STYLE "${style}" \ 16 | MODEL.ADAPTER.EXPERT_NUM "2" \ 17 | MODEL.ADAPTER.MOE "True" \ 18 | MODEL.ADAPTER.MERGE "add" \ 19 | MODEL.ADAPTER.SHARE "up" \ 20 | MODEL.ADAPTER.ADDITIONAL "True" \ 21 | MODEL.ADAPTER.DEEPREG "True" \ 22 | MODEL.ADAPTER.ADD_WEIGHT "0.0" \ 23 | MODEL.ADAPTER.REG_WEIGHT "1.0" \ 24 | MODEL.TRANSFER_TYPE "mosa" \ 25 | MODEL.TYPE "vit" \ 26 | SEED "3407" \ 27 | SOLVER.BASE_LR "0.005" \ 28 | SOLVER.WEIGHT_DECAY "0.01" \ 29 | SOLVER.WARMUP_EPOCH "10" \ 30 | DATA.DATAPATH "${data_path}" \ 31 | GPU_ID "${gpu_id}" \ 32 | MODEL.MODEL_ROOT "${model_root}" \ 33 | OUTPUT_DIR "${output_dir}" 34 | -------------------------------------------------------------------------------- /scripts/mosa/vit/FGVC/cub.sh: -------------------------------------------------------------------------------- 1 | # launch final training for FGVC-cub. 2 | 3 | gpu_id=${1} 4 | data_path=/your/path/to/FGVC/CUB_200_2011 5 | model_root=checkpoints 6 | style=${2} 7 | output_dir=${3} 8 | 9 | CUDA_VISIBLE_DEVICES=${gpu_id} python train.py \ 10 | --config-file configs/finetune/cub.yaml \ 11 | --sparse-train \ 12 | DATA.BATCH_SIZE "128" \ 13 | DATA.FEATURE "sup_vitb16_imagenet21k" \ 14 | MODEL.ADAPTER.BOTTLENECK_SIZE "64" \ 15 | MODEL.ADAPTER.STYLE "${style}" \ 16 | MODEL.ADAPTER.EXPERT_NUM "4" \ 17 | MODEL.ADAPTER.MOE "True" \ 18 | MODEL.ADAPTER.MERGE "add" \ 19 | MODEL.ADAPTER.SHARE "down" \ 20 | MODEL.ADAPTER.ADDITIONAL "True" \ 21 | MODEL.ADAPTER.DEEPREG "False" \ 22 | MODEL.ADAPTER.ADD_WEIGHT "0.0" \ 23 | MODEL.ADAPTER.REG_WEIGHT "1.0" \ 24 | MODEL.TRANSFER_TYPE "mosa" \ 25 | MODEL.TYPE "vit" \ 26 | SEED "3407" \ 27 | SOLVER.BASE_LR "0.001" \ 28 | SOLVER.WEIGHT_DECAY "0.01" \ 29 | SOLVER.WARMUP_EPOCH "10" \ 30 | DATA.DATAPATH "${data_path}" \ 31 | GPU_ID "${gpu_id}" \ 32 | MODEL.MODEL_ROOT "${model_root}" \ 33 | OUTPUT_DIR "${output_dir}" 34 | -------------------------------------------------------------------------------- /scripts/mosa/vit/FGVC/dogs.sh: -------------------------------------------------------------------------------- 1 | # launch final training for FGVC-dogs. 
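# expected arguments (from the positional parameters below): bash scripts/mosa/vit/FGVC/dogs.sh <gpu_id> <adapter_style> <output_dir>
# <adapter_style> is forwarded to MODEL.ADAPTER.STYLE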
2 | 3 | gpu_id=${1} 4 | data_path=/your/path/to/FGVC/StanfordDogs 5 | model_root=checkpoints 6 | style=${2} 7 | output_dir=${3} 8 | 9 | CUDA_VISIBLE_DEVICES=${gpu_id} python train.py \ 10 | --config-file configs/finetune/dogs.yaml \ 11 | --sparse-train \ 12 | DATA.BATCH_SIZE "128" \ 13 | DATA.FEATURE "sup_vitb16_imagenet21k" \ 14 | MODEL.ADAPTER.BOTTLENECK_SIZE "64" \ 15 | MODEL.ADAPTER.STYLE "${style}" \ 16 | MODEL.ADAPTER.EXPERT_NUM "4" \ 17 | MODEL.ADAPTER.MOE "True" \ 18 | MODEL.ADAPTER.MERGE "add" \ 19 | MODEL.ADAPTER.SHARE "down" \ 20 | MODEL.ADAPTER.ADDITIONAL "True" \ 21 | MODEL.ADAPTER.DEEPREG "False" \ 22 | MODEL.ADAPTER.ADD_WEIGHT "0.0" \ 23 | MODEL.ADAPTER.REG_WEIGHT "1.0" \ 24 | MODEL.TRANSFER_TYPE "mosa" \ 25 | MODEL.TYPE "vit" \ 26 | SEED "3407" \ 27 | SOLVER.BASE_LR "0.0005" \ 28 | SOLVER.WEIGHT_DECAY "0.01" \ 29 | SOLVER.WARMUP_EPOCH "10" \ 30 | DATA.DATAPATH "${data_path}" \ 31 | GPU_ID "${gpu_id}" \ 32 | MODEL.MODEL_ROOT "${model_root}" \ 33 | OUTPUT_DIR "${output_dir}" 34 | -------------------------------------------------------------------------------- /scripts/mosa/vit/FGVC/nabirds.sh: -------------------------------------------------------------------------------- 1 | # launch final training for FGVC-nabirds. 2 | 3 | gpu_id=${1} 4 | data_path=/your/path/to/FGVC/nabirds 5 | model_root=checkpoints 6 | style=${2} 7 | output_dir=${3} 8 | 9 | CUDA_VISIBLE_DEVICES=${gpu_id} python train.py \ 10 | --config-file configs/finetune/nabirds.yaml \ 11 | --sparse-train \ 12 | DATA.BATCH_SIZE "128" \ 13 | DATA.FEATURE "sup_vitb16_imagenet21k" \ 14 | MODEL.ADAPTER.BOTTLENECK_SIZE "64" \ 15 | MODEL.ADAPTER.STYLE "${style}" \ 16 | MODEL.ADAPTER.EXPERT_NUM "4" \ 17 | MODEL.ADAPTER.MOE "True" \ 18 | MODEL.ADAPTER.MERGE "add" \ 19 | MODEL.ADAPTER.SHARE "down" \ 20 | MODEL.ADAPTER.ADDITIONAL "True" \ 21 | MODEL.ADAPTER.DEEPREG "False" \ 22 | MODEL.ADAPTER.ADD_WEIGHT "0.0" \ 23 | MODEL.ADAPTER.REG_WEIGHT "1.0" \ 24 | MODEL.TRANSFER_TYPE "mosa" \ 25 | MODEL.TYPE "vit" \ 26 | SEED "3407" \ 27 | SOLVER.BASE_LR "0.0005" \ 28 | SOLVER.WEIGHT_DECAY "0.01" \ 29 | SOLVER.WARMUP_EPOCH "10" \ 30 | DATA.DATAPATH "${data_path}" \ 31 | GPU_ID "${gpu_id}" \ 32 | MODEL.MODEL_ROOT "${model_root}" \ 33 | OUTPUT_DIR "${output_dir}" 34 | -------------------------------------------------------------------------------- /scripts/mosa/vit/FGVC/flowers.sh: -------------------------------------------------------------------------------- 1 | # launch final training for FGVC-flowers. 
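# expected arguments (from the positional parameters below): bash scripts/mosa/vit/FGVC/flowers.sh <gpu_id> <adapter_style> <output_dir>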
2 | 3 | gpu_id=${1} 4 | data_path=/your/path/to/FGVC/OxfordFlowers 5 | model_root=checkpoints 6 | style=${2} 7 | output_dir=${3} 8 | 9 | CUDA_VISIBLE_DEVICES=${gpu_id} python train.py \ 10 | --config-file configs/finetune/flowers.yaml \ 11 | --sparse-train \ 12 | DATA.BATCH_SIZE "128" \ 13 | DATA.FEATURE "sup_vitb16_imagenet21k" \ 14 | MODEL.ADAPTER.BOTTLENECK_SIZE "64" \ 15 | MODEL.ADAPTER.STYLE "${style}" \ 16 | MODEL.ADAPTER.EXPERT_NUM "4" \ 17 | MODEL.ADAPTER.MOE "True" \ 18 | MODEL.ADAPTER.MERGE "add" \ 19 | MODEL.ADAPTER.SHARE "down" \ 20 | MODEL.ADAPTER.ADDITIONAL "True" \ 21 | MODEL.ADAPTER.DEEPREG "False" \ 22 | MODEL.ADAPTER.ADD_WEIGHT "0.0" \ 23 | MODEL.ADAPTER.REG_WEIGHT "1.0" \ 24 | MODEL.TRANSFER_TYPE "mosa" \ 25 | MODEL.TYPE "vit" \ 26 | SEED "3407" \ 27 | SOLVER.BASE_LR "0.005" \ 28 | SOLVER.WEIGHT_DECAY "0.01" \ 29 | SOLVER.WARMUP_EPOCH "10" \ 30 | DATA.DATAPATH "${data_path}" \ 31 | GPU_ID "${gpu_id}" \ 32 | MODEL.MODEL_ROOT "${model_root}" \ 33 | OUTPUT_DIR "${output_dir}" 34 | -------------------------------------------------------------------------------- /scripts/mosa/vit/GICD/food101.sh: -------------------------------------------------------------------------------- 1 | # launch final training for GICD-food101. 2 | 3 | gpu_id=${1} 4 | data_path=/your/path/to/GICD/Food101 5 | model_root=checkpoints 6 | output_dir=${2} 7 | 8 | CUDA_VISIBLE_DEVICES=${gpu_id} python train.py \ 9 | --config-file configs/finetune/cub.yaml \ 10 | --sparse-train \ 11 | DATA.BATCH_SIZE "128" \ 12 | DATA.FEATURE "sup_vitb16_imagenet21k" \ 13 | DATA.NAME "food101" \ 14 | DATA.NUMBER_CLASSES "101" \ 15 | DATA.NO_TEST "True" \ 16 | MODEL.ADAPTER.BOTTLENECK_SIZE "64" \ 17 | MODEL.ADAPTER.EXPERT_NUM "4" \ 18 | MODEL.ADAPTER.MOE "True" \ 19 | MODEL.ADAPTER.MERGE "add" \ 20 | MODEL.ADAPTER.SHARE "down" \ 21 | MODEL.ADAPTER.ADDITIONAL "True" \ 22 | MODEL.ADAPTER.DEEPREG "False" \ 23 | MODEL.ADAPTER.ADD_WEIGHT "0.0" \ 24 | MODEL.ADAPTER.REG_WEIGHT "1.0" \ 25 | MODEL.TRANSFER_TYPE "mosa" \ 26 | MODEL.TYPE "vit" \ 27 | SEED "3407" \ 28 | SOLVER.BASE_LR "0.0005" \ 29 | SOLVER.WEIGHT_DECAY "0.01" \ 30 | SOLVER.WARMUP_EPOCH "10" \ 31 | DATA.DATAPATH "${data_path}" \ 32 | GPU_ID "${gpu_id}" \ 33 | MODEL.MODEL_ROOT "${model_root}" \ 34 | OUTPUT_DIR "${output_dir}" 35 | -------------------------------------------------------------------------------- /scripts/mosa/vit/GICD/aircraft.sh: -------------------------------------------------------------------------------- 1 | # launch final training for GICD-aricraft. 
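# expected arguments (from the positional parameters below): bash scripts/mosa/vit/GICD/aircraft.sh <gpu_id> <output_dir>
# edit data_path to your local FGVCAircraft folder before launching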
2 | 3 | gpu_id=${1} 4 | data_path=/your/path/to/GICD/FGVCAircraft 5 | model_root=checkpoints 6 | output_dir=${2} 7 | 8 | CUDA_VISIBLE_DEVICES=${gpu_id} python train.py \ 9 | --config-file configs/finetune/cub.yaml \ 10 | --sparse-train \ 11 | DATA.BATCH_SIZE "128" \ 12 | DATA.FEATURE "sup_vitb16_imagenet21k" \ 13 | DATA.NAME "aircraft" \ 14 | DATA.NUMBER_CLASSES "100" \ 15 | DATA.NO_TEST "True" \ 16 | MODEL.ADAPTER.BOTTLENECK_SIZE "64" \ 17 | MODEL.ADAPTER.EXPERT_NUM "4" \ 18 | MODEL.ADAPTER.MOE "True" \ 19 | MODEL.ADAPTER.MERGE "add" \ 20 | MODEL.ADAPTER.SHARE "down" \ 21 | MODEL.ADAPTER.ADDITIONAL "True" \ 22 | MODEL.ADAPTER.DEEPREG "False" \ 23 | MODEL.ADAPTER.ADD_WEIGHT "0.0" \ 24 | MODEL.ADAPTER.REG_WEIGHT "1.0" \ 25 | MODEL.TRANSFER_TYPE "mosa" \ 26 | MODEL.TYPE "vit" \ 27 | SEED "3407" \ 28 | SOLVER.BASE_LR "0.005" \ 29 | SOLVER.WEIGHT_DECAY "0.01" \ 30 | SOLVER.WARMUP_EPOCH "10" \ 31 | DATA.DATAPATH "${data_path}" \ 32 | GPU_ID "${gpu_id}" \ 33 | MODEL.MODEL_ROOT "${model_root}" \ 34 | OUTPUT_DIR "${output_dir}" 35 | -------------------------------------------------------------------------------- /scripts/mosa/vit/GICD/cifar100.sh: -------------------------------------------------------------------------------- 1 | # launch final training for GICD-cifar100. 2 | 3 | gpu_id=${1} 4 | data_path=/your/path/to/GICD/CIFAR100 5 | model_root=checkpoints 6 | output_dir=${2} 7 | 8 | CUDA_VISIBLE_DEVICES=${gpu_id} python train.py \ 9 | --config-file configs/finetune/cub.yaml \ 10 | --sparse-train \ 11 | DATA.BATCH_SIZE "128" \ 12 | DATA.FEATURE "sup_vitb16_imagenet21k" \ 13 | DATA.NAME "cifar100" \ 14 | DATA.NUMBER_CLASSES "100" \ 15 | DATA.NO_TEST "True" \ 16 | MODEL.ADAPTER.BOTTLENECK_SIZE "64" \ 17 | MODEL.ADAPTER.EXPERT_NUM "4" \ 18 | MODEL.ADAPTER.MOE "True" \ 19 | MODEL.ADAPTER.MERGE "add" \ 20 | MODEL.ADAPTER.SHARE "down" \ 21 | MODEL.ADAPTER.ADDITIONAL "True" \ 22 | MODEL.ADAPTER.DEEPREG "False" \ 23 | MODEL.ADAPTER.ADD_WEIGHT "0.0" \ 24 | MODEL.ADAPTER.REG_WEIGHT "1.0" \ 25 | MODEL.TRANSFER_TYPE "mosa" \ 26 | MODEL.TYPE "vit" \ 27 | SEED "3407" \ 28 | SOLVER.BASE_LR "0.0005" \ 29 | SOLVER.WEIGHT_DECAY "0.01" \ 30 | SOLVER.WARMUP_EPOCH "10" \ 31 | DATA.DATAPATH "${data_path}" \ 32 | GPU_ID "${gpu_id}" \ 33 | MODEL.MODEL_ROOT "${model_root}" \ 34 | OUTPUT_DIR "${output_dir}" 35 | -------------------------------------------------------------------------------- /scripts/mosa/vit/FGVC/multi.sh: -------------------------------------------------------------------------------- 1 | # launch final training for FGVC-multitask. The hyperparameters are the same from our paper. 
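# expected arguments (from the positional parameters below): bash scripts/mosa/vit/FGVC/multi.sh <gpu_id> <adapter_style> <output_dir>
# <adapter_style> is forwarded to MODEL.ADAPTER.STYLE; data_path presumably points at the parent folder holding the five FGVC datasets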
2 | 3 | gpu_id=${1} 4 | data_path=/data/dataset/ 5 | model_root=checkpoints 6 | style=${2} 7 | output_dir=${3} 8 | 9 | # rm -rf ${output_dir} 10 | 11 | CUDA_VISIBLE_DEVICES=${gpu_id} python train.py \ 12 | --config-file configs/finetune/multitask.yaml \ 13 | --sparse-train \ 14 | DATA.BATCH_SIZE "128" \ 15 | DATA.FEATURE "sup_vitb16_imagenet21k" \ 16 | DATA.NO_TEST "False" \ 17 | MODEL.ADAPTER.BOTTLENECK_SIZE "64" \ 18 | MODEL.ADAPTER.STYLE "${style}" \ 19 | MODEL.ADAPTER.EXPERT_NUM "2" \ 20 | MODEL.ADAPTER.MOE "True" \ 21 | MODEL.ADAPTER.MERGE "add" \ 22 | MODEL.ADAPTER.SHARE "none" \ 23 | MODEL.ADAPTER.ADDITIONAL "True" \ 24 | MODEL.ADAPTER.ADD_WEIGHT "0.0" \ 25 | MODEL.ADAPTER.REG_WEIGHT "1.0" \ 26 | MODEL.TRANSFER_TYPE "mosa" \ 27 | MODEL.TYPE "vit" \ 28 | SEED "3407" \ 29 | SOLVER.BASE_LR "0.0005" \ 30 | SOLVER.WEIGHT_DECAY "0.01" \ 31 | SOLVER.WARMUP_EPOCH "10" \ 32 | DATA.DATAPATH "${data_path}" \ 33 | GPU_ID "${gpu_id}" \ 34 | MODEL.MODEL_ROOT "${model_root}" \ 35 | OUTPUT_DIR "${output_dir}" 36 | -------------------------------------------------------------------------------- /src/utils/build_pruner.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Pruner construction functions. 4 | """ 5 | 6 | from .pruner import Rand 7 | from . import logging 8 | logger = logging.get_logger("MOSA") 9 | # Supported pruner types 10 | _PRUNER_TYPES = { 11 | "random": Rand, 12 | } 13 | 14 | 15 | def build_pruner(cfg): 16 | """ 17 | build pruner here 18 | """ 19 | assert ( 20 | cfg.PRUNER.TYPE in _PRUNER_TYPES.keys() 21 | ), "Model type '{}' not supported".format(cfg.PRUNER.TYPE) 22 | 23 | # Construct the pruner 24 | prune_type = cfg.PRUNER.TYPE 25 | pruner = _PRUNER_TYPES[prune_type](cfg) 26 | 27 | return pruner 28 | 29 | 30 | def log_pruned_model_info(model, verbose=False): 31 | """Logs pruned model info""" 32 | if verbose: 33 | logger.info(f"Classification Model:\n{model}") 34 | model_total_params = sum(p.numel() for p in model.parameters()) 35 | model_grad_params = sum(int(p.mask.sum()) if hasattr(p, 'mask') else p.numel() for p in model.parameters() if p.requires_grad) 36 | logger.info("Total Parameters: {0}\t Gradient Parameters: {1}".format( 37 | model_total_params, model_grad_params)) 38 | logger.info("tuned percent:%.3f"%(model_grad_params/model_total_params*100)) 39 | -------------------------------------------------------------------------------- /src/data/transforms.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """Image transformations.""" 4 | import torchvision as tv 5 | 6 | 7 | def get_transforms(split, size): 8 | normalize = tv.transforms.Normalize( 9 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 10 | ) 11 | if size == 448: 12 | resize_dim = 512 13 | crop_dim = 448 14 | elif size == 384: 15 | resize_dim = 438 16 | crop_dim = 384 17 | elif size == 224: 18 | resize_dim = 256 19 | crop_dim = 224 20 | if split == "train": 21 | transform = tv.transforms.Compose( 22 | [ 23 | tv.transforms.Resize(resize_dim), 24 | tv.transforms.RandomCrop(crop_dim), 25 | tv.transforms.RandomHorizontalFlip(0.5), 26 | # tv.transforms.RandomResizedCrop(224, scale=(0.2, 1.0)), 27 | # tv.transforms.RandomHorizontalFlip(), 28 | tv.transforms.ToTensor(), 29 | normalize, 30 | ] 31 | ) 32 | else: 33 | transform = tv.transforms.Compose( 34 | [ 35 | tv.transforms.Resize(resize_dim), 36 | tv.transforms.CenterCrop(crop_dim), 37 | tv.transforms.ToTensor(), 38 | 
normalize, 39 | ] 40 | ) 41 | return transform 42 | -------------------------------------------------------------------------------- /scripts/mosa/vit/VTAB/vtab_specialized.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | gpu_id=${1} 4 | data_path=/your/path/to/VTAB 5 | model_root=checkpoints 6 | dataset=("vtab-patch_camelyon" "vtab-eurosat" "vtab-resisc45" 'vtab-diabetic_retinopathy(config="btgraham-300")') 7 | number_classes=(2 10 45 5) 8 | output_dir=${2} 9 | 10 | for idx in {0..3}; do 11 | CUDA_VISIBLE_DEVICES=${gpu_id} python train.py \ 12 | --config-file configs/finetune/cub.yaml \ 13 | --sparse-train \ 14 | DATA.BATCH_SIZE "128" \ 15 | DATA.FEATURE "sup_vitb16_imagenet21k" \ 16 | DATA.NAME ${dataset[idx]} \ 17 | DATA.NUMBER_CLASSES "${number_classes[idx]}" \ 18 | MODEL.ADAPTER.BOTTLENECK_SIZE "16" \ 19 | MODEL.ADAPTER.EXPERT_NUM "4" \ 20 | MODEL.ADAPTER.MOE "True" \ 21 | MODEL.ADAPTER.MERGE "add" \ 22 | MODEL.ADAPTER.SHARE "down" \ 23 | MODEL.ADAPTER.ADDITIONAL "True" \ 24 | MODEL.ADAPTER.DEEPREG "False" \ 25 | MODEL.ADAPTER.ADD_WEIGHT "0.0" \ 26 | MODEL.ADAPTER.REG_WEIGHT "1.0" \ 27 | MODEL.TRANSFER_TYPE "mosa" \ 28 | MODEL.TYPE "vit" \ 29 | SEED "3407" \ 30 | SOLVER.BASE_LR "0.005" \ 31 | SOLVER.WEIGHT_DECAY "0.01" \ 32 | SOLVER.WARMUP_EPOCH "10" \ 33 | DATA.DATAPATH "${data_path}" \ 34 | GPU_ID "${gpu_id}" \ 35 | MODEL.MODEL_ROOT "${model_root}" \ 36 | OUTPUT_DIR "${output_dir}" 37 | done -------------------------------------------------------------------------------- /scripts/mosa/vit/VTAB/vtab_natural.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | gpu_id=${1} 4 | data_path=/your/path/to/VTAB 5 | model_root=checkpoints 6 | dataset=("vtab-cifar(num_classes=100)" "vtab-caltech101" "vtab-dtd" "vtab-oxford_flowers102" "vtab-oxford_iiit_pet" "vtab-svhn" "vtab-sun397") 7 | number_classes=(100 102 47 102 37 10 397) 8 | output_dir=${2} 9 | 10 | for idx in {0..6}; do 11 | CUDA_VISIBLE_DEVICES=${gpu_id} python train.py \ 12 | --config-file configs/finetune/cub.yaml \ 13 | --sparse-train \ 14 | DATA.BATCH_SIZE "128" \ 15 | DATA.FEATURE "sup_vitb16_imagenet21k" \ 16 | DATA.NAME ${dataset[idx]} \ 17 | DATA.NUMBER_CLASSES "${number_classes[idx]}" \ 18 | MODEL.ADAPTER.BOTTLENECK_SIZE "16" \ 19 | MODEL.ADAPTER.EXPERT_NUM "4" \ 20 | MODEL.ADAPTER.MOE "True" \ 21 | MODEL.ADAPTER.MERGE "add" \ 22 | MODEL.ADAPTER.SHARE "down" \ 23 | MODEL.ADAPTER.ADDITIONAL "True" \ 24 | MODEL.ADAPTER.DEEPREG "False" \ 25 | MODEL.ADAPTER.ADD_WEIGHT "0.0" \ 26 | MODEL.ADAPTER.REG_WEIGHT "1.0" \ 27 | MODEL.TRANSFER_TYPE "mosa" \ 28 | MODEL.TYPE "vit" \ 29 | SEED "3407" \ 30 | SOLVER.BASE_LR "0.005" \ 31 | SOLVER.WEIGHT_DECAY "0.01" \ 32 | SOLVER.WARMUP_EPOCH "10" \ 33 | DATA.DATAPATH "${data_path}" \ 34 | GPU_ID "${gpu_id}" \ 35 | MODEL.MODEL_ROOT "${model_root}" \ 36 | OUTPUT_DIR "${output_dir}" 37 | done -------------------------------------------------------------------------------- /scripts/mosa/vit/VTAB/vtab_structured.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | gpu_id=${1} 4 | data_path=/your/path/to/VTAB 5 | model_root=checkpoints 6 | dataset=('vtab-clevr(task="count_all")' 'vtab-clevr(task="closest_object_distance")' "vtab-dmlab" 'vtab-kitti(task="closest_vehicle_distance")' 'vtab-dsprites(predicted_attribute="label_x_position",num_classes=16)' 'vtab-dsprites(predicted_attribute="label_orientation",num_classes=16)' 
'vtab-smallnorb(predicted_attribute="label_azimuth")' 'vtab-smallnorb(predicted_attribute="label_elevation")') 7 | number_classes=(8 6 6 4 16 16 18 9) 8 | output_dir=${2} 9 | 10 | for idx in {0..7}; do 11 | CUDA_VISIBLE_DEVICES=${gpu_id} python train.py \ 12 | --config-file configs/finetune/cub.yaml \ 13 | --sparse-train \ 14 | DATA.BATCH_SIZE "128" \ 15 | DATA.FEATURE "sup_vitb16_imagenet21k" \ 16 | DATA.NAME ${dataset[idx]} \ 17 | DATA.NUMBER_CLASSES "${number_classes[idx]}" \ 18 | MODEL.ADAPTER.BOTTLENECK_SIZE "16" \ 19 | MODEL.ADAPTER.EXPERT_NUM "4" \ 20 | MODEL.ADAPTER.MOE "True" \ 21 | MODEL.ADAPTER.MERGE "add" \ 22 | MODEL.ADAPTER.SHARE "down" \ 23 | MODEL.ADAPTER.ADDITIONAL "True" \ 24 | MODEL.ADAPTER.DEEPREG "False" \ 25 | MODEL.ADAPTER.ADD_WEIGHT "0.0" \ 26 | MODEL.ADAPTER.REG_WEIGHT "1.0" \ 27 | MODEL.TRANSFER_TYPE "mosa" \ 28 | MODEL.TYPE "vit" \ 29 | SEED "3407" \ 30 | SOLVER.BASE_LR "0.005" \ 31 | SOLVER.WEIGHT_DECAY "0.01" \ 32 | SOLVER.WARMUP_EPOCH "10" \ 33 | DATA.DATAPATH "${data_path}" \ 34 | GPU_ID "${gpu_id}" \ 35 | MODEL.MODEL_ROOT "${model_root}" \ 36 | OUTPUT_DIR "${output_dir}" 37 | done -------------------------------------------------------------------------------- /scripts/mosa/vit/VTAB/vtab.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | gpu_id=${1} 4 | data_path=/your/path/to/VTAB 5 | model_root=checkpoints 6 | dataset=("vtab-cifar(num_classes=100)" "vtab-caltech101" "vtab-dtd" "vtab-oxford_flowers102" "vtab-oxford_iiit_pet" "vtab-svhn" "vtab-sun397" "vtab-patch_camelyon" "vtab-eurosat" "vtab-resisc45" 'vtab-diabetic_retinopathy(config="btgraham-300")' 'vtab-clevr(task="count_all")' 'vtab-clevr(task="closest_object_distance")' "vtab-dmlab" 'vtab-kitti(task="closest_vehicle_distance")' 'vtab-dsprites(predicted_attribute="label_x_position",num_classes=16)' 'vtab-dsprites(predicted_attribute="label_orientation",num_classes=16)' 'vtab-smallnorb(predicted_attribute="label_azimuth")' 'vtab-smallnorb(predicted_attribute="label_elevation")') 7 | number_classes=(100 102 47 102 37 10 397 2 10 45 5 8 6 6 4 16 16 18 9) 8 | output_dir=${2} 9 | 10 | for idx in {0..18}; do 11 | CUDA_VISIBLE_DEVICES=${gpu_id} python train.py \ 12 | --config-file configs/finetune/cub.yaml \ 13 | --sparse-train \ 14 | DATA.BATCH_SIZE "128" \ 15 | DATA.FEATURE "sup_vitb16_imagenet21k" \ 16 | DATA.NAME ${dataset[idx]} \ 17 | DATA.NUMBER_CLASSES "${number_classes[idx]}" \ 18 | MODEL.ADAPTER.BOTTLENECK_SIZE "16" \ 19 | MODEL.ADAPTER.EXPERT_NUM "4" \ 20 | MODEL.ADAPTER.MOE "True" \ 21 | MODEL.ADAPTER.MERGE "add" \ 22 | MODEL.ADAPTER.SHARE "down" \ 23 | MODEL.ADAPTER.ADDITIONAL "True" \ 24 | MODEL.ADAPTER.DEEPREG "False" \ 25 | MODEL.ADAPTER.ADD_WEIGHT "0.0" \ 26 | MODEL.ADAPTER.REG_WEIGHT "1.0" \ 27 | MODEL.TRANSFER_TYPE "mosa" \ 28 | MODEL.TYPE "vit" \ 29 | SEED "3407" \ 30 | SOLVER.BASE_LR "0.005" \ 31 | SOLVER.WEIGHT_DECAY "0.01" \ 32 | SOLVER.WARMUP_EPOCH "10" \ 33 | DATA.DATAPATH "${data_path}" \ 34 | GPU_ID "${gpu_id}" \ 35 | MODEL.MODEL_ROOT "${model_root}" \ 36 | OUTPUT_DIR "${output_dir}" 37 | done -------------------------------------------------------------------------------- /src/utils/io_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | a bunch of helper functions for read and write data 4 | """ 5 | import os 6 | import json 7 | import numpy as np 8 | import time 9 | import pandas as pd 10 | 11 | from typing import List, Union 12 | from PIL import 
Image, ImageFile 13 | Image.MAX_IMAGE_PIXELS = None 14 | 15 | 16 | def save_or_append_df(out_path, df): 17 | if os.path.exists(out_path): 18 | previous_df = pd.read_pickle(out_path) 19 | df = pd.concat([previous_df, df], ignore_index=True) 20 | df.to_pickle(out_path) 21 | print(f"Saved output at {out_path}") 22 | 23 | 24 | class JSONEncoder(json.JSONEncoder): 25 | def default(self, obj): 26 | if isinstance(obj, np.ndarray): 27 | return obj.tolist() 28 | elif isinstance(obj, bytes): 29 | return str(obj, encoding='utf-8') 30 | elif isinstance(obj, np.integer): 31 | return int(obj) 32 | elif isinstance(obj, np.floating): 33 | return float(obj) 34 | elif isinstance(obj, np.ndarray): 35 | return obj.tolist() 36 | else: 37 | # return super(MyEncoder, self).default(obj) 38 | 39 | raise TypeError( 40 | "Unserializable object {} of type {}".format(obj, type(obj)) 41 | ) 42 | 43 | 44 | def write_json(data: Union[list, dict], outfile: str) -> None: 45 | json_dir, _ = os.path.split(outfile) 46 | if json_dir and not os.path.exists(json_dir): 47 | os.makedirs(json_dir) 48 | 49 | with open(outfile, 'w') as f: 50 | json.dump(data, f, cls=JSONEncoder, ensure_ascii=False, indent=2) 51 | 52 | 53 | def read_json(filename: str) -> Union[list, dict]: 54 | """read json files""" 55 | with open(filename, "rb") as fin: 56 | data = json.load(fin) 57 | return data 58 | 59 | 60 | def pil_loader(path: str) -> Image.Image: 61 | """load an image from path, and suppress warning""" 62 | # to avoid crashing for truncated (corrupted images) 63 | ImageFile.LOAD_TRUNCATED_IMAGES = True 64 | # open path as file to avoid ResourceWarning 65 | # (https://github.com/python-pillow/Pillow/issues/835) 66 | with open(path, 'rb') as f: 67 | img = Image.open(f) 68 | return img.convert('RGB') 69 | -------------------------------------------------------------------------------- /src/models/mlp.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Modified from: fbcode/multimo/models/encoders/mlp.py 4 | """ 5 | import math 6 | import torch 7 | 8 | from torch import nn 9 | from typing import List, Type 10 | 11 | from ..utils import logging 12 | logger = logging.get_logger("MOSA") 13 | 14 | 15 | class MLP(nn.Module): 16 | def __init__( 17 | self, 18 | input_dim: int, 19 | mlp_dims: List[int], 20 | dropout: float = 0.1, 21 | nonlinearity: Type[nn.Module] = nn.ReLU, 22 | normalization: Type[nn.Module] = nn.BatchNorm1d, # nn.LayerNorm, 23 | special_bias: bool = False, 24 | add_bn_first: bool = False, 25 | ): 26 | super(MLP, self).__init__() 27 | projection_prev_dim = input_dim 28 | projection_modulelist = [] 29 | last_dim = mlp_dims[-1] 30 | mlp_dims = mlp_dims[:-1] 31 | 32 | if add_bn_first: 33 | if normalization is not None: 34 | projection_modulelist.append(normalization(projection_prev_dim)) 35 | if dropout != 0: 36 | projection_modulelist.append(nn.Dropout(dropout)) 37 | 38 | for idx, mlp_dim in enumerate(mlp_dims): 39 | fc_layer = nn.Linear(projection_prev_dim, mlp_dim) 40 | nn.init.kaiming_normal_(fc_layer.weight, a=0, mode='fan_out') 41 | projection_modulelist.append(fc_layer) 42 | projection_modulelist.append(nonlinearity()) 43 | 44 | if normalization is not None: 45 | projection_modulelist.append(normalization(mlp_dim)) 46 | 47 | if dropout != 0: 48 | projection_modulelist.append(nn.Dropout(dropout)) 49 | projection_prev_dim = mlp_dim 50 | 51 | self.projection = nn.Sequential(*projection_modulelist) 52 | self.last_layer = nn.Linear(projection_prev_dim, last_dim) 53 | 
nn.init.kaiming_normal_(self.last_layer.weight, a=0, mode='fan_out') 54 | if special_bias: 55 | prior_prob = 0.01 56 | bias_value = -math.log((1 - prior_prob) / prior_prob) 57 | torch.nn.init.constant_(self.last_layer.bias, bias_value) 58 | 59 | def forward(self, x: torch.Tensor) -> torch.Tensor: 60 | """ 61 | input_arguments: 62 | @x: torch.FloatTensor 63 | """ 64 | x = self.projection(x) 65 | x = self.last_layer(x) 66 | return x 67 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Mixture of Sparse Adapters 2 | 3 | This repository contains the official PyTorch implementation for MoSA. 4 | 5 | ![MoSA_demo](figure/demo.gif) 6 | 7 | ## Environment setup 8 | 9 | See `env_setup.sh` 10 | 11 | ## Datasets preperation 12 | 13 | - Fine-Grained Visual Classification (FGVC): The datasets can be downloaded following the official links. We split the training data if the public validation set is not available. The splitted dataset can be found here: [Dropbox](https://cornell.box.com/v/vptfgvcsplits), [Google Drive](https://drive.google.com/drive/folders/1mnvxTkYxmOr2W9QjcgS64UBpoJ4UmKaM?usp=sharing). 14 | 15 | - [CUB200 2011](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html) 16 | 17 | - [NABirds](http://info.allaboutbirds.org/nabirds/) 18 | 19 | - [Oxford Flowers](https://www.robots.ox.ac.uk/~vgg/data/flowers/) 20 | 21 | - [Stanford Dogs](http://vision.stanford.edu/aditya86/ImageNetDogs/main.html) 22 | 23 | - [Stanford Cars](https://ai.stanford.edu/~jkrause/cars/car_dataset.html) 24 | 25 | - Visual Task Adaptation Benchmark (VTAB): See [`VTAB_SETUP.md`](https://github.com/KMnP/vpt/blob/main/VTAB_SETUP.md) for detailed instructions and tips. 26 | 27 | - General Image Classification Datasets (GICD): The datasets will be automatically downloaded when you run an experiment using them via `MoSA`. 28 | 29 | ## Pre-trained weights preperation 30 | 31 | Download and place the pre-trained Transformer-based backbones to `MODEL.MODEL_ROOT`. 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 |
| Pre-trained Backbone | Link | md5sum |
| :--- | :--- | :--- |
| ViT-B/16 | link | d9715d |
| ViT-L/16 | link | 8f39ce |
| Swin-B | link | bf9cc1 |
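The launch scripts under `scripts/mosa/` set `model_root=checkpoints` and pass it to the trainer as `MODEL.MODEL_ROOT`, so a minimal setup, assuming you keep that default, is:

```bash
mkdir -p checkpoints
# place the downloaded ViT-B/16, ViT-L/16 and Swin-B weights in this folder;
# the exact filenames expected by each DATA.FEATURE loader are not listed here
```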
53 | 54 | ## Training 55 | 56 | To fine-tune a pre-trained ViT model via MoSA on FGVC-cub, you can run: 57 | 58 | ```bash 59 | bash scripts/mosa/vit/FGVC/cub.sh 60 | ``` 61 | 62 | ## License 63 | 64 | The majority of MoSA is licensed under the CC-BY-NC 4.0 license (see [LICENSE](https://github.com/KMnP/vpt/blob/main/LICENSE) for details). 65 | -------------------------------------------------------------------------------- /src/models/build_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Model construction functions. 4 | """ 5 | from tabnanny import verbose 6 | import torch 7 | 8 | from .resnet import ResNet 9 | from .convnext import ConvNeXt 10 | from .vit_models import ViT, Swin, SSLViT 11 | from ..utils import logging 12 | logger = logging.get_logger("MOSA") 13 | # Supported model types 14 | _MODEL_TYPES = { 15 | "resnet": ResNet, 16 | "convnext": ConvNeXt, 17 | "vit": ViT, 18 | "swin": Swin, 19 | "ssl-vit": SSLViT, 20 | } 21 | 22 | 23 | def build_model(cfg): 24 | """ 25 | build model here 26 | """ 27 | assert ( 28 | cfg.MODEL.TYPE in _MODEL_TYPES.keys() 29 | ), "Model type '{}' not supported".format(cfg.MODEL.TYPE) 30 | assert ( 31 | cfg.NUM_GPUS <= torch.cuda.device_count() 32 | ), "Cannot use more GPU devices than available" 33 | 34 | # Construct the model 35 | train_type = cfg.MODEL.TYPE 36 | model = _MODEL_TYPES[train_type](cfg) 37 | 38 | log_model_info(model, verbose=cfg.DBG) 39 | model, device = load_model_to_device(model, cfg) 40 | logger.info(f"Device used for model: {device}") 41 | 42 | return model, device 43 | 44 | 45 | def log_model_info(model, verbose=False): 46 | """Logs model info""" 47 | if verbose: 48 | logger.info(f"Classification Model:\n{model}") 49 | model_total_params = sum(p.numel() for p in model.parameters()) 50 | model_grad_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 51 | logger.info("Total Parameters: {0}\t Gradient Parameters: {1}".format( 52 | model_total_params, model_grad_params)) 53 | logger.info("tuned percent:%.3f"%(model_grad_params/model_total_params*100)) 54 | 55 | 56 | def get_current_device(): 57 | if torch.cuda.is_available(): 58 | # Determine the GPU used by the current process 59 | cur_device = torch.cuda.current_device() 60 | else: 61 | cur_device = torch.device('cpu') 62 | return cur_device 63 | 64 | 65 | def load_model_to_device(model, cfg): 66 | cur_device = get_current_device() 67 | if torch.cuda.is_available(): 68 | # Transfer the model to the current GPU device 69 | model = model.cuda(device=cur_device) 70 | # Use multi-process data parallel model in the multi-gpu setting 71 | if cfg.NUM_GPUS > 1: 72 | # Make model replica operate on the current device 73 | model = torch.nn.parallel.DistributedDataParallel( 74 | module=model, device_ids=[cur_device], output_device=cur_device, 75 | find_unused_parameters=True, 76 | ) 77 | else: 78 | model = model.to(cur_device) 79 | return model, cur_device 80 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | GIT_TOKENS.md 2 | checkpoints/ 3 | hyperparameters/ 4 | outputs/ 5 | outputs_appendix/ 6 | scripts/ 7 | *.npy 8 | *.zip 9 | 10 | # General 11 | .DS_Store? 
12 | .DS_Store 13 | .AppleDouble 14 | .LSOverride 15 | 16 | # Byte-compiled / optimized / DLL files 17 | __pycache__/ 18 | *.py[cod] 19 | *$py.class 20 | 21 | # C extensions 22 | *.so 23 | 24 | # Distribution / packaging 25 | .Python 26 | build/ 27 | develop-eggs/ 28 | dist/ 29 | downloads/ 30 | eggs/ 31 | .eggs/ 32 | lib/ 33 | lib64/ 34 | parts/ 35 | sdist/ 36 | var/ 37 | wheels/ 38 | pip-wheel-metadata/ 39 | share/python-wheels/ 40 | *.egg-info/ 41 | .installed.cfg 42 | *.egg 43 | MANIFEST 44 | 45 | # PyInstaller 46 | # Usually these files are written by a python script from a template 47 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 48 | *.manifest 49 | *.spec 50 | 51 | # Installer logs 52 | pip-log.txt 53 | pip-delete-this-directory.txt 54 | 55 | # Unit test / coverage reports 56 | htmlcov/ 57 | .tox/ 58 | .nox/ 59 | .coverage 60 | .coverage.* 61 | .cache 62 | nosetests.xml 63 | coverage.xml 64 | *.cover 65 | *.py,cover 66 | .hypothesis/ 67 | .pytest_cache/ 68 | 69 | # Translations 70 | *.mo 71 | *.pot 72 | 73 | # Django stuff: 74 | *.log 75 | local_settings.py 76 | db.sqlite3 77 | db.sqlite3-journal 78 | 79 | # Flask stuff: 80 | instance/ 81 | .webassets-cache 82 | 83 | # Scrapy stuff: 84 | .scrapy 85 | 86 | # Sphinx documentation 87 | docs/_build/ 88 | 89 | # PyBuilder 90 | target/ 91 | 92 | # Jupyter Notebook 93 | .ipynb_checkpoints 94 | 95 | # IPython 96 | profile_default/ 97 | ipython_config.py 98 | 99 | # pyenv 100 | .python-version 101 | 102 | # pipenv 103 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 104 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 105 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 106 | # install all needed dependencies. 107 | #Pipfile.lock 108 | 109 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 110 | __pypackages__/ 111 | 112 | # Celery stuff 113 | celerybeat-schedule 114 | celerybeat.pid 115 | 116 | # SageMath parsed files 117 | *.sage.py 118 | 119 | # Environments 120 | .env 121 | .venv 122 | env/ 123 | venv/ 124 | ENV/ 125 | env.bak/ 126 | venv.bak/ 127 | 128 | # Spyder project settings 129 | .spyderproject 130 | .spyproject 131 | 132 | # Rope project settings 133 | .ropeproject 134 | 135 | # mkdocs documentation 136 | /site 137 | 138 | # mypy 139 | .mypy_cache/ 140 | .dmypy.json 141 | dmypy.json 142 | 143 | # Pyre type checker 144 | .pyre/ 145 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/patch_camelyon.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | """Implements PatchCamelyon data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow_datasets as tfds 24 | 25 | from . import base as base 26 | from .registry import Registry 27 | 28 | 29 | @Registry.register("data.patch_camelyon", "class") 30 | class PatchCamelyonData(base.ImageTfdsData): 31 | """Provides PatchCamelyon data.""" 32 | 33 | def __init__(self, data_dir=None): 34 | 35 | dataset_builder = tfds.builder("patch_camelyon:2.*.*", data_dir=data_dir) 36 | dataset_builder.download_and_prepare() 37 | 38 | # Defines dataset specific train/val/trainval/test splits. 39 | tfds_splits = { 40 | "test": "test", 41 | "train": "train", 42 | "val": "validation", 43 | "trainval": "train+validation", 44 | "train800": "train[:800]", 45 | "val200": "validation[:200]", 46 | "train800val200": "train[:800]+validation[:200]", 47 | } 48 | # Creates a dict with example counts. 49 | num_samples_splits = { 50 | "test": dataset_builder.info.splits["test"].num_examples, 51 | "train": dataset_builder.info.splits["train"].num_examples, 52 | "val": dataset_builder.info.splits["validation"].num_examples, 53 | "train800": 800, 54 | "val200": 200, 55 | "train800val200": 1000, 56 | } 57 | num_samples_splits["trainval"] = ( 58 | num_samples_splits["train"] + num_samples_splits["val"]) 59 | super(PatchCamelyonData, self).__init__( 60 | dataset_builder=dataset_builder, 61 | tfds_splits=tfds_splits, 62 | num_samples_splits=num_samples_splits, 63 | num_preprocessing_threads=400, 64 | shuffle_buffer_size=10000, 65 | # Note: Export only image and label tensors with their original types. 66 | base_preprocess_fn=base.make_get_tensors_fn(["image", "label"]), 67 | num_classes=dataset_builder.info.features["label"].num_classes) 68 | -------------------------------------------------------------------------------- /src/models/vit_adapter/adapter_block.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | ''' 3 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 4 | ''' 5 | import math 6 | import logging 7 | from functools import partial 8 | from collections import OrderedDict 9 | from copy import deepcopy 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | from timm.models.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_ 16 | from timm.models.vision_transformer import Attention 17 | from timm.models.vision_transformer import Block 18 | 19 | from ...utils import logging 20 | logger = logging.get_logger("MOSA") 21 | 22 | 23 | class Pfeiffer_Block(Block): 24 | 25 | def __init__(self, adapter_config, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., 26 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 27 | 28 | super(Pfeiffer_Block, self).__init__( 29 | dim=dim, 30 | num_heads=num_heads, 31 | mlp_ratio=mlp_ratio, 32 | qkv_bias=qkv_bias, 33 | drop=drop, 34 | attn_drop=attn_drop, 35 | drop_path=drop_path, 36 | act_layer=act_layer, 37 | norm_layer=norm_layer) 38 | 39 | self.adapter_config = adapter_config 40 | 41 | if adapter_config.STYLE == "Pfeiffer": 42 | self.adapter_downsample = nn.Linear( 43 | dim, 44 | dim // adapter_config.REDUCATION_FACTOR 45 | ) 46 | self.adapter_upsample = nn.Linear( 47 | dim // adapter_config.REDUCATION_FACTOR, 48 | dim 49 | ) 50 | self.adapter_act_fn = act_layer() 51 | 52 | 
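            # Zero-initializing both adapter projections makes the adapter branch a no-op
            # at the start of fine-tuning, so the block initially reproduces the behavior
            # of the pre-trained ViT.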
nn.init.zeros_(self.adapter_downsample.weight) 53 | nn.init.zeros_(self.adapter_downsample.bias) 54 | 55 | nn.init.zeros_(self.adapter_upsample.weight) 56 | nn.init.zeros_(self.adapter_upsample.bias) 57 | else: 58 | raise ValueError("Other adapter styles are not supported.") 59 | 60 | def forward(self, x): 61 | 62 | if self.adapter_config.STYLE == "Pfeiffer": 63 | # same as reguluar ViT block 64 | h = x 65 | x = self.norm1(x) 66 | x = self.attn(x) 67 | x = self.drop_path(x) 68 | x = x + h 69 | 70 | h = x 71 | x = self.norm2(x) 72 | x = self.mlp(x) 73 | 74 | # start to insert adapter layers... 75 | adpt = self.adapter_downsample(x) 76 | adpt = self.adapter_act_fn(adpt) 77 | adpt = self.adapter_upsample(adpt) 78 | x = adpt + x 79 | # ...end 80 | 81 | x = self.drop_path(x) 82 | x = x + h 83 | 84 | return x 85 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/sun397.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements Sun397 data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow_datasets as tfds 24 | 25 | from . import base as base 26 | from .registry import Registry 27 | CUSTOM_TRAIN_SPLIT_PERCENT = 50 28 | CUSTOM_VALIDATION_SPLIT_PERCENT = 20 29 | CUSTOM_TEST_SPLIT_PERCENT = 30 30 | 31 | 32 | @Registry.register("data.sun397", "class") 33 | class Sun397Data(base.ImageTfdsData): 34 | """Provides Sun397Data data.""" 35 | 36 | def __init__(self, config="tfds", data_dir=None): 37 | 38 | if config == "tfds": 39 | dataset_builder = tfds.builder("sun397/tfds:4.*.*", data_dir=data_dir) 40 | dataset_builder.download_and_prepare() 41 | 42 | tfds_splits = { 43 | "train": "train", 44 | "val": "validation", 45 | "test": "test", 46 | "trainval": "train+validation", 47 | "train800": "train[:800]", 48 | "val200": "validation[:200]", 49 | "train800val200": "train[:800]+validation[:200]", 50 | } 51 | # Creates a dict with example counts. 52 | num_samples_splits = { 53 | "test": dataset_builder.info.splits["test"].num_examples, 54 | "train": dataset_builder.info.splits["train"].num_examples, 55 | "val": dataset_builder.info.splits["validation"].num_examples, 56 | "train800": 800, 57 | "val200": 200, 58 | "train800val200": 1000, 59 | } 60 | num_samples_splits["trainval"] = ( 61 | num_samples_splits["train"] + num_samples_splits["val"]) 62 | else: 63 | 64 | raise ValueError("No supported config %r for Sun397Data." 
% config) 65 | 66 | super(Sun397Data, self).__init__( 67 | dataset_builder=dataset_builder, 68 | tfds_splits=tfds_splits, 69 | num_samples_splits=num_samples_splits, 70 | num_preprocessing_threads=400, 71 | shuffle_buffer_size=10000, 72 | # Note: Export only image and label tensors with their original types. 73 | base_preprocess_fn=base.make_get_tensors_fn(["image", "label"]), 74 | num_classes=dataset_builder.info.features["label"].num_classes) 75 | -------------------------------------------------------------------------------- /src/models/vit_backbones/vit_mae.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | """ 4 | borrowed from https://github.com/facebookresearch/mae/blob/main/models_vit.py 5 | """ 6 | from functools import partial 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | import timm.models.vision_transformer 12 | 13 | 14 | class VisionTransformer(timm.models.vision_transformer.VisionTransformer): 15 | """ Vision Transformer with support for global average pooling 16 | """ 17 | def __init__(self, global_pool=False, **kwargs): 18 | super(VisionTransformer, self).__init__(**kwargs) 19 | 20 | self.global_pool = global_pool 21 | if self.global_pool: 22 | norm_layer = kwargs['norm_layer'] 23 | embed_dim = kwargs['embed_dim'] 24 | self.fc_norm = norm_layer(embed_dim) 25 | 26 | del self.norm # remove the original norm 27 | 28 | def forward_features(self, x): 29 | B = x.shape[0] 30 | x = self.patch_embed(x) 31 | 32 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 33 | x = torch.cat((cls_tokens, x), dim=1) 34 | x = x + self.pos_embed 35 | x = self.pos_drop(x) 36 | 37 | for blk in self.blocks: 38 | x = blk(x) 39 | 40 | if self.global_pool: 41 | x = x[:, 1:, :].mean(dim=1) # global pool without cls token 42 | outcome = self.fc_norm(x) 43 | else: 44 | x = self.norm(x) 45 | outcome = x[:, 0] 46 | 47 | return outcome 48 | 49 | 50 | def build_model(model_type): 51 | if "vitb" in model_type: 52 | return vit_base_patch16() 53 | elif "vitl" in model_type: 54 | return vit_large_patch16() 55 | elif "vith" in model_type: 56 | return vit_huge_patch14() 57 | 58 | 59 | def vit_base_patch16(**kwargs): 60 | model = VisionTransformer( 61 | drop_path_rate=0.1, global_pool=True, # using default settings for mae-finetune 62 | patch_size=16, embed_dim=768, depth=12, num_heads=12, 63 | mlp_ratio=4, qkv_bias=True, 64 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 65 | return model 66 | 67 | 68 | def vit_large_patch16(**kwargs): 69 | model = VisionTransformer( 70 | drop_path_rate=0.1, global_pool=True, # using default settings for mae-finetune 71 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, 72 | mlp_ratio=4, qkv_bias=True, 73 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 74 | return model 75 | 76 | 77 | def vit_huge_patch14(**kwargs): 78 | model = VisionTransformer( 79 | drop_path_rate=0.1, global_pool=True, # using default settings for mae-finetune 80 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, 81 | mlp_ratio=4, qkv_bias=True, 82 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 83 | return model 84 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/dtd.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. 
All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements the Describable Textures Dataset (DTD) data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow_datasets as tfds 24 | 25 | from . import base as base 26 | from .registry import Registry 27 | 28 | 29 | @Registry.register("data.dtd", "class") 30 | class DTDData(base.ImageTfdsData): 31 | """Provides Describable Textures Dataset (DTD) data. 32 | 33 | As of version 1.0.0, the train/val/test splits correspond to those of the 34 | 1st fold of the official cross-validation partition. 35 | 36 | For additional details and usage, see the base class. 37 | """ 38 | 39 | def __init__(self, data_dir=None): 40 | 41 | dataset_builder = tfds.builder("dtd:3.*.*", data_dir=data_dir) 42 | dataset_builder.download_and_prepare() 43 | 44 | # Defines dataset specific train/val/trainval/test splits. 45 | tfds_splits = { 46 | "train": "train", 47 | "val": "validation", 48 | "trainval": "train+validation", 49 | "test": "test", 50 | "train800": "train[:800]", 51 | "val200": "validation[:200]", 52 | "train800val200": "train[:800]+validation[:200]", 53 | } 54 | 55 | # Creates a dict with example counts for each split. 56 | train_count = dataset_builder.info.splits["train"].num_examples 57 | val_count = dataset_builder.info.splits["validation"].num_examples 58 | test_count = dataset_builder.info.splits["test"].num_examples 59 | num_samples_splits = { 60 | "train": train_count, 61 | "val": val_count, 62 | "trainval": train_count + val_count, 63 | "test": test_count, 64 | "train800": 800, 65 | "val200": 200, 66 | "train800val200": 1000, 67 | } 68 | 69 | super(DTDData, self).__init__( 70 | dataset_builder=dataset_builder, 71 | tfds_splits=tfds_splits, 72 | num_samples_splits=num_samples_splits, 73 | num_preprocessing_threads=400, 74 | shuffle_buffer_size=10000, 75 | # Note: Export only image and label tensors with their original types. 76 | base_preprocess_fn=base.make_get_tensors_fn(["image", "label"]), 77 | num_classes=dataset_builder.info.features["label"].num_classes) 78 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/dmlab.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements Dmlab data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow_datasets as tfds 24 | 25 | from . import base as base 26 | from .registry import Registry 27 | 28 | 29 | @Registry.register("data.dmlab", "class") 30 | class DmlabData(base.ImageTfdsData): 31 | """Dmlab dataset. 32 | 33 | The Dmlab dataset contains frames observed by the agent acting in the 34 | DMLab environment, which are annotated by the distance between 35 | the agent and various objects present in the environment. The goal is to 36 | is to evaluate the ability of a visual model to reason about distances 37 | from the visual input in 3D environments. The Dmlab dataset consists of 38 | 360x480 color images in 6 classes. The classes are 39 | {close, far, very far} x {positive reward, negative reward} 40 | respectively. 41 | """ 42 | 43 | def __init__(self, data_dir=None): 44 | 45 | dataset_builder = tfds.builder("dmlab:2.0.1", data_dir=data_dir) 46 | 47 | tfds_splits = { 48 | "train": "train", 49 | "val": "validation", 50 | "trainval": "train+validation", 51 | "test": "test", 52 | "train800": "train[:800]", 53 | "val200": "validation[:200]", 54 | "train800val200": "train[:800]+validation[:200]", 55 | } 56 | 57 | # Example counts are retrieved from the tensorflow dataset info. 58 | train_count = dataset_builder.info.splits["train"].num_examples 59 | val_count = dataset_builder.info.splits["validation"].num_examples 60 | test_count = dataset_builder.info.splits["test"].num_examples 61 | 62 | # Creates a dict with example counts for each split. 63 | num_samples_splits = { 64 | "train": train_count, 65 | "val": val_count, 66 | "trainval": train_count + val_count, 67 | "test": test_count, 68 | "train800": 800, 69 | "val200": 200, 70 | "train800val200": 1000, 71 | } 72 | 73 | super(DmlabData, self).__init__( 74 | dataset_builder=dataset_builder, 75 | tfds_splits=tfds_splits, 76 | num_samples_splits=num_samples_splits, 77 | num_preprocessing_threads=400, 78 | shuffle_buffer_size=10000, 79 | base_preprocess_fn=base.make_get_and_cast_tensors_fn({ 80 | "image": ("image", None), 81 | "label": ("label", None), 82 | }), 83 | num_classes=dataset_builder.info.features["label"].num_classes, 84 | image_key="image") 85 | -------------------------------------------------------------------------------- /src/engine/eval/singlelabel.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """Functions for computing metrics. 
all metrics has range of 0-1""" 4 | 5 | import numpy as np 6 | import torch 7 | from sklearn.metrics import ( 8 | accuracy_score, average_precision_score, f1_score, roc_auc_score 9 | ) 10 | 11 | 12 | def accuracy(y_probs, y_true): 13 | # y_prob: (num_images, num_classes) 14 | y_preds = np.argmax(y_probs, axis=1) 15 | accuracy = accuracy_score(y_true, y_preds) 16 | error = 1.0 - accuracy 17 | return accuracy, error 18 | 19 | 20 | def top_n_accuracy(y_probs, truths, n=1): 21 | # y_prob: (num_images, num_classes) 22 | # truth: (num_images, num_classes) multi/one-hot encoding 23 | best_n = np.argsort(y_probs, axis=1)[:, -n:] 24 | if isinstance(truths, np.ndarray) and truths.shape == y_probs.shape: 25 | ts = np.argmax(truths, axis=1) 26 | else: 27 | # a list of GT class idx 28 | ts = truths 29 | 30 | num_input = y_probs.shape[0] 31 | successes = 0 32 | for i in range(num_input): 33 | if ts[i] in best_n[i, :]: 34 | successes += 1 35 | return float(successes) / num_input 36 | 37 | 38 | def compute_acc_auc(y_probs, y_true_ids): 39 | onehot_tgts = np.zeros_like(y_probs) 40 | for idx, t in enumerate(y_true_ids): 41 | onehot_tgts[idx, t] = 1. 42 | 43 | num_classes = y_probs.shape[1] 44 | if num_classes == 2: 45 | top1, _ = accuracy(y_probs, y_true_ids) 46 | # so precision can set all to 2 47 | try: 48 | auc = roc_auc_score(onehot_tgts, y_probs, average='macro') 49 | except ValueError as e: 50 | print(f"value error encountered {e}, set auc sccore to -1.") 51 | auc = -1 52 | return {"top1": top1, "rocauc": auc} 53 | 54 | top1, _ = accuracy(y_probs, y_true_ids) 55 | k = min([5, num_classes]) # if number of labels < 5, use the total class 56 | top5 = top_n_accuracy(y_probs, y_true_ids, k) 57 | return {"top1": top1, f"top{k}": top5} 58 | 59 | 60 | def topks_correct(preds, labels, ks): 61 | """Computes the number of top-k correct predictions for each k.""" 62 | assert preds.size(0) == labels.size( 63 | 0 64 | ), "Batch dim of predictions and labels must match" 65 | # Find the top max_k predictions for each sample 66 | _top_max_k_vals, top_max_k_inds = torch.topk( 67 | preds, max(ks), dim=1, largest=True, sorted=True 68 | ) 69 | # (batch_size, max_k) -> (max_k, batch_size) 70 | top_max_k_inds = top_max_k_inds.t() 71 | # (batch_size, ) -> (max_k, batch_size) 72 | rep_max_k_labels = labels.view(1, -1).expand_as(top_max_k_inds) 73 | # (i, j) = 1 if top i-th prediction for the j-th sample is correct 74 | top_max_k_correct = top_max_k_inds.eq(rep_max_k_labels) 75 | # Compute the number of topk correct predictions for each k 76 | topks_correct = [ 77 | top_max_k_correct[:k, :].reshape(-1).float().sum() for k in ks 78 | ] 79 | return topks_correct 80 | 81 | 82 | def topk_errors(preds, labels, ks): 83 | """Computes the top-k error for each k.""" 84 | if int(labels.min()) < 0: # has ignore 85 | keep_ids = np.where(labels.cpu() >= 0)[0] 86 | preds = preds[keep_ids, :] 87 | labels = labels[keep_ids] 88 | 89 | num_topks_correct = topks_correct(preds, labels, ks) 90 | return [(1.0 - x / preds.size(0)) for x in num_topks_correct] 91 | 92 | 93 | def topk_accuracies(preds, labels, ks): 94 | """Computes the top-k accuracy for each k.""" 95 | num_topks_correct = topks_correct(preds, labels, ks) 96 | return [(x / preds.size(0)) for x in num_topks_correct] 97 | 98 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/caltech.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. 
All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Imports the Caltech images dataset.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | from . import base as base 23 | from .registry import Registry 24 | import tensorflow_datasets as tfds 25 | 26 | 27 | # Percentage of the original training set retained for training, the rest is 28 | # used as a validation set. 29 | _TRAIN_SPLIT_PERCENT = 90 30 | 31 | 32 | @Registry.register("data.caltech101", "class") 33 | class Caltech101(base.ImageTfdsData): 34 | """Provides the Caltech101 dataset. 35 | 36 | See the base class for additional details on the class. 37 | 38 | See TFDS dataset for details on the dataset: 39 | third_party/py/tensorflow_datasets/image/caltech.py 40 | 41 | The original (TFDS) dataset contains only a train and test split. We randomly 42 | sample _TRAIN_SPLIT_PERCENT% of the train split for our "train" set. The 43 | remainder of the TFDS train split becomes our "val" set. The full TFDS train 44 | split is called "trainval". The TFDS test split is used as our test set. 45 | 46 | Note that, in the TFDS dataset, the training split is class-balanced, but not 47 | the test split. Therefore, a significant difference between performance on the 48 | "val" and "test" sets should be expected. 49 | """ 50 | 51 | def __init__(self, data_dir=None): 52 | dataset_builder = tfds.builder("caltech101:3.*.*", data_dir=data_dir) 53 | dataset_builder.download_and_prepare() 54 | 55 | # Creates a dict with example counts for each split. 56 | trainval_count = dataset_builder.info.splits["train"].num_examples 57 | train_count = (_TRAIN_SPLIT_PERCENT * trainval_count) // 100 58 | test_count = dataset_builder.info.splits["test"].num_examples 59 | num_samples_splits = dict( 60 | train=train_count, 61 | val=trainval_count - train_count, 62 | trainval=trainval_count, 63 | test=test_count, 64 | train800=800, 65 | val200=200, 66 | train800val200=1000) 67 | 68 | # Defines dataset specific train/val/trainval/test splits. 
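# The TFDS slicing strings below carve "train"/"val" out of the official
# training split using _TRAIN_SPLIT_PERCENT; the 800/200 slices correspond
# to the VTAB-1k protocol of 1,000 labeled training examples per task.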
69 | tfds_splits = { 70 | "train": "train[:{}]".format(train_count), 71 | "val": "train[{}:]".format(train_count), 72 | "trainval": "train", 73 | "test": "test", 74 | "train800": "train[:800]", 75 | "val200": "train[{}:{}]".format(train_count, train_count+200), 76 | "train800val200": ( 77 | "train[:800]+train[{}:{}]".format(train_count, train_count+200)), 78 | } 79 | 80 | super(Caltech101, self).__init__( 81 | dataset_builder=dataset_builder, 82 | tfds_splits=tfds_splits, 83 | num_samples_splits=num_samples_splits, 84 | num_preprocessing_threads=400, 85 | shuffle_buffer_size=3000, 86 | base_preprocess_fn=base.make_get_tensors_fn(("image", "label")), 87 | num_classes=dataset_builder.info.features["label"].num_classes) 88 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/resisc45.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements RESISC-45 data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow_datasets as tfds 24 | 25 | from . import base as base 26 | from .registry import Registry 27 | TRAIN_SPLIT_PERCENT = 60 28 | VALIDATION_SPLIT_PERCENT = 20 29 | TEST_SPLIT_PERCENT = 20 30 | 31 | 32 | @Registry.register("data.resisc45", "class") 33 | class Resisc45Data(base.ImageTfdsData): 34 | """Provides RESISC-45 dataset. 35 | 36 | RESISC45 dataset is a publicly available benchmark for Remote Sensing Image 37 | Scene Classification (RESISC), created by Northwestern Polytechnical 38 | University (NWPU). This dataset contains 31,500 images, covering 45 scene 39 | classes with 700 images in each class. 40 | 41 | URL: http://www.escience.cn/people/JunweiHan/NWPU-RESISC45.html 42 | """ 43 | 44 | def __init__(self, data_dir=None): 45 | dataset_builder = tfds.builder("resisc45:3.*.*", data_dir=data_dir) 46 | dataset_builder.download_and_prepare() 47 | 48 | # Example counts are retrieved from the tensorflow dataset info. 
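# RESISC-45 ships as a single TFDS "train" split, so the custom
# train/val/test splits are sliced out of it with the 60/20/20
# percentages defined above.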
49 | num_examples = dataset_builder.info.splits["train"].num_examples 50 | train_count = num_examples * TRAIN_SPLIT_PERCENT // 100 51 | val_count = num_examples * VALIDATION_SPLIT_PERCENT // 100 52 | test_count = num_examples * TEST_SPLIT_PERCENT // 100 53 | 54 | tfds_splits = { 55 | "train": 56 | "train[:{}]".format(train_count), 57 | "val": 58 | "train[{}:{}]".format(train_count, train_count + val_count), 59 | "trainval": 60 | "train[:{}]".format(train_count + val_count), 61 | "test": 62 | "train[{}:]".format(train_count + val_count), 63 | "train800": 64 | "train[:800]", 65 | "val200": 66 | "train[{}:{}]".format(train_count, train_count+200), 67 | "train800val200": 68 | "train[:800]+train[{}:{}]".format(train_count, train_count+200), 69 | } 70 | 71 | # Creates a dict with example counts for each split. 72 | num_samples_splits = { 73 | "train": train_count, 74 | "val": val_count, 75 | "trainval": train_count + val_count, 76 | "test": test_count, 77 | "train800": 800, 78 | "val200": 200, 79 | "train800val200": 1000, 80 | } 81 | 82 | super(Resisc45Data, self).__init__( 83 | dataset_builder=dataset_builder, 84 | tfds_splits=tfds_splits, 85 | num_samples_splits=num_samples_splits, 86 | num_preprocessing_threads=400, 87 | shuffle_buffer_size=10000, 88 | # Note: Rename tensors but keep their original types. 89 | base_preprocess_fn=base.make_get_and_cast_tensors_fn({ 90 | "image": ("image", None), 91 | "label": ("label", None), 92 | }), 93 | num_classes=dataset_builder.info.features["label"].num_classes) 94 | -------------------------------------------------------------------------------- /src/solver/losses.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from typing import Optional 6 | 7 | from ..utils import logging 8 | logger = logging.get_logger("MOSA") 9 | 10 | 11 | class SigmoidLoss(nn.Module): 12 | def __init__(self, cfg=None): 13 | super(SigmoidLoss, self).__init__() 14 | 15 | def is_single(self): 16 | return True 17 | 18 | def is_local(self): 19 | return False 20 | 21 | def multi_hot(self, labels: torch.Tensor, nb_classes: int) -> torch.Tensor: 22 | labels = labels.unsqueeze(1) # (batch_size, 1) 23 | target = torch.zeros( 24 | labels.size(0), nb_classes, device=labels.device 25 | ).scatter_(1, labels, 1.) 
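# Illustrative example (hypothetical values): labels = tensor([2, 0]) with
# nb_classes = 4 produces target = [[0., 0., 1., 0.], [1., 0., 0., 0.]].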
26 | # (batch_size, num_classes) 27 | return target 28 | 29 | def loss( 30 | self, logits, targets, per_cls_weights, 31 | multihot_targets: Optional[bool] = False 32 | ): 33 | # targets: 1d-tensor of integer 34 | # Only support single label at this moment 35 | # if len(targets.shape) != 2: 36 | num_classes = logits.shape[1] 37 | targets = self.multi_hot(targets, num_classes) 38 | 39 | loss = F.binary_cross_entropy_with_logits( 40 | logits, targets, reduction="none") 41 | # logger.info(f"loss shape: {loss.shape}") 42 | weight = torch.tensor( 43 | per_cls_weights, device=logits.device 44 | ).unsqueeze(0) 45 | # logger.info(f"weight shape: {weight.shape}") 46 | loss = torch.mul(loss.to(torch.float32), weight.to(torch.float32)) 47 | return torch.sum(loss) / targets.shape[0] 48 | 49 | def forward( 50 | self, pred_logits, targets, per_cls_weights, multihot_targets=False 51 | ): 52 | loss = self.loss( 53 | pred_logits, targets, per_cls_weights, multihot_targets) 54 | return loss 55 | 56 | 57 | class SoftmaxLoss(SigmoidLoss): 58 | def __init__(self, cfg=None): 59 | super(SoftmaxLoss, self).__init__() 60 | 61 | def loss(self, logits, targets, per_cls_weights, kwargs): 62 | weight = torch.tensor( 63 | per_cls_weights, device=logits.device 64 | ) 65 | loss = F.cross_entropy(logits, targets, weight, reduction="none") 66 | 67 | return torch.sum(loss) / targets.shape[0] 68 | 69 | 70 | def symmetric_KL_loss(input, target, reduction='batchmean'): 71 | """ symmetric KL-divergence 1/2*(KL(p||q)+KL(q||p)) """ 72 | 73 | input = input.float() 74 | target = target.float() 75 | loss = F.kl_div(F.log_softmax(input, dim=-1, dtype=torch.float32), 76 | F.softmax(target.detach(), dim=-1, dtype=torch.float32), reduction=reduction) + \ 77 | F.kl_div(F.log_softmax(target, dim=-1, dtype=torch.float32), 78 | F.softmax(input.detach(), dim=-1, dtype=torch.float32), reduction=reduction) 79 | return 0.5 * loss.sum() 80 | 81 | 82 | def deepreg_MSE_loss(input, target, reduction='mean'): 83 | """ deep regulerization MSE loss """ 84 | 85 | loss = 0 86 | for i in range(6): 87 | loss += F.mse_loss(input[i], target[i], reduction=reduction) 88 | return loss.sum() / len(input) * 0.01 89 | 90 | 91 | LOSS = { 92 | "softmax": SoftmaxLoss, 93 | } 94 | 95 | 96 | def build_loss(cfg): 97 | loss_name = cfg.SOLVER.LOSS 98 | assert loss_name in LOSS, \ 99 | f'loss name {loss_name} is not supported' 100 | loss_fn = LOSS[loss_name] 101 | if not loss_fn: 102 | return None 103 | else: 104 | return loss_fn(cfg) 105 | -------------------------------------------------------------------------------- /src/configs/vit_configs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Copyright (c) Meta Platforms, Inc. 
All Rights Reserved 4 | https://github.com/jeonsworld/ViT-pytorch/blob/main/models/configs.py 5 | """ 6 | import ml_collections 7 | 8 | 9 | def get_testing(): 10 | """Returns a minimal configuration for testing.""" 11 | config = ml_collections.ConfigDict() 12 | config.patches = ml_collections.ConfigDict({'size': (16, 16)}) 13 | config.hidden_size = 1 14 | config.transformer = ml_collections.ConfigDict() 15 | config.transformer.mlp_dim = 1 16 | config.transformer.num_heads = 1 17 | config.transformer.num_layers = 1 18 | config.transformer.attention_dropout_rate = 0.0 19 | config.transformer.dropout_rate = 0.1 20 | config.classifier = 'token' 21 | config.representation_size = None 22 | return config 23 | 24 | 25 | def get_b16_config(): 26 | """Returns the ViT-B/16 configuration.""" 27 | config = ml_collections.ConfigDict() 28 | config.patches = ml_collections.ConfigDict({'size': (16, 16)}) 29 | config.hidden_size = 768 30 | config.transformer = ml_collections.ConfigDict() 31 | config.transformer.mlp_dim = 3072 32 | config.transformer.num_heads = 12 33 | config.transformer.num_layers = 12 34 | config.transformer.attention_dropout_rate = 0.0 35 | config.transformer.dropout_rate = 0.1 36 | config.classifier = 'token' 37 | config.representation_size = None 38 | return config 39 | 40 | 41 | def get_r50_b16_config(): 42 | """Returns the Resnet50 + ViT-B/16 configuration.""" 43 | config = get_b16_config() 44 | del config.patches.size 45 | config.patches.grid = (14, 14) 46 | config.resnet = ml_collections.ConfigDict() 47 | config.resnet.num_layers = (3, 4, 9) 48 | config.resnet.width_factor = 1 49 | return config 50 | 51 | 52 | def get_b32_config(): 53 | """Returns the ViT-B/32 configuration.""" 54 | config = get_b16_config() 55 | config.patches.size = (32, 32) 56 | return config 57 | 58 | 59 | def get_b8_config(): 60 | """Returns the ViT-B/32 configuration.""" 61 | config = get_b16_config() 62 | config.patches.size = (8, 8) 63 | return config 64 | 65 | 66 | def get_l16_config(): 67 | """Returns the ViT-L/16 configuration.""" 68 | config = ml_collections.ConfigDict() 69 | config.patches = ml_collections.ConfigDict({'size': (16, 16)}) 70 | config.hidden_size = 1024 71 | config.transformer = ml_collections.ConfigDict() 72 | config.transformer.mlp_dim = 4096 73 | config.transformer.num_heads = 16 74 | config.transformer.num_layers = 24 75 | config.transformer.attention_dropout_rate = 0.0 76 | config.transformer.dropout_rate = 0.1 77 | config.classifier = 'token' 78 | config.representation_size = None 79 | return config 80 | 81 | 82 | def get_l32_config(): 83 | """Returns the ViT-L/32 configuration.""" 84 | config = get_l16_config() 85 | config.patches.size = (32, 32) 86 | return config 87 | 88 | 89 | def get_h14_config(): 90 | """Returns the ViT-L/16 configuration.""" 91 | config = ml_collections.ConfigDict() 92 | config.patches = ml_collections.ConfigDict({'size': (14, 14)}) 93 | config.hidden_size = 1280 94 | config.transformer = ml_collections.ConfigDict() 95 | config.transformer.mlp_dim = 5120 96 | config.transformer.num_heads = 16 97 | config.transformer.num_layers = 32 98 | config.transformer.attention_dropout_rate = 0.0 99 | config.transformer.dropout_rate = 0.1 100 | config.classifier = 'token' 101 | config.representation_size = None 102 | return config 103 | -------------------------------------------------------------------------------- /src/models/vit_adapter/vit_moco.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | 
borrow from https://github.com/facebookresearch/moco-v3/blob/main/vits.py 4 | """ 5 | import math 6 | import torch 7 | import torch.nn as nn 8 | from functools import partial, reduce 9 | from operator import mul 10 | 11 | from timm.models.vision_transformer import VisionTransformer, _cfg 12 | # from timm.layers.helpers import to_2tuple 13 | from timm.models.layers.helpers import to_2tuple 14 | from timm.models.layers import PatchEmbed 15 | 16 | from .adapter_block import Pfeiffer_Block 17 | from ..vit_backbones.vit_moco import VisionTransformerMoCo 18 | from ...utils import logging 19 | logger = logging.get_logger("MOSA") 20 | 21 | 22 | class ADPT_VisionTransformerMoCo(VisionTransformerMoCo): 23 | def __init__( 24 | self, 25 | adapter_cfg, 26 | stop_grad_conv1=False, 27 | img_size=224, 28 | patch_size=16, 29 | in_chans=3, 30 | num_classes=1000, 31 | embed_dim=768, 32 | depth=12, 33 | num_heads=12, 34 | mlp_ratio=4., 35 | qkv_bias=True, 36 | representation_size=None, 37 | distilled=False, 38 | drop_rate=0., 39 | attn_drop_rate=0., 40 | drop_path_rate=0., 41 | embed_layer=PatchEmbed, 42 | norm_layer=None, 43 | act_layer=None, 44 | weight_init='', 45 | **kwargs): 46 | 47 | super(ADPT_VisionTransformerMoCo, self).__init__( 48 | stop_grad_conv1=stop_grad_conv1, 49 | img_size=img_size, 50 | patch_size=patch_size, 51 | in_chans=in_chans, 52 | num_classes=num_classes, 53 | embed_dim=embed_dim, 54 | depth=depth, 55 | num_heads=num_heads, 56 | mlp_ratio=mlp_ratio, 57 | qkv_bias=qkv_bias, 58 | representation_size=representation_size, 59 | distilled=distilled, 60 | drop_rate=drop_rate, 61 | attn_drop_rate=attn_drop_rate, 62 | drop_path_rate=drop_path_rate, 63 | embed_layer=embed_layer, 64 | norm_layer=norm_layer, 65 | act_layer=act_layer, 66 | weight_init=weight_init, 67 | **kwargs 68 | ) 69 | 70 | self.adapter_cfg = adapter_cfg 71 | 72 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 73 | act_layer = act_layer or nn.GELU 74 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 75 | 76 | if adapter_cfg.STYLE == "Pfeiffer": 77 | self.blocks = nn.Sequential(*[ 78 | Pfeiffer_Block( 79 | adapter_config=adapter_cfg, 80 | dim=embed_dim, 81 | num_heads=num_heads, 82 | mlp_ratio=mlp_ratio, 83 | qkv_bias=qkv_bias, 84 | drop=drop_rate, 85 | attn_drop=attn_drop_rate, 86 | drop_path=dpr[i], 87 | norm_layer=norm_layer, 88 | act_layer=act_layer) for i in range(depth)]) 89 | else: 90 | raise ValueError("Other adapter styles are not supported.") 91 | 92 | 93 | 94 | def vit_base(adapter_cfg, **kwargs): 95 | model = ADPT_VisionTransformerMoCo( 96 | adapter_cfg, 97 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 98 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 99 | model.default_cfg = _cfg() 100 | return model 101 | -------------------------------------------------------------------------------- /src/solver/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import math 3 | 4 | import torch.optim as optim 5 | from fvcore.common.config import CfgNode 6 | from torch.optim.lr_scheduler import LambdaLR 7 | 8 | 9 | def make_scheduler( 10 | optimizer: optim.Optimizer, train_params: CfgNode 11 | ) -> LambdaLR: 12 | warmup = train_params.WARMUP_EPOCH 13 | total_iters = train_params.TOTAL_EPOCH 14 | 15 | if train_params.SCHEDULER == "cosine": 16 | scheduler = WarmupCosineSchedule( 17 | optimizer, 18 | warmup_steps=warmup, 19 | t_total=total_iters 20 | 
) 21 | elif train_params.SCHEDULER == "cosine_hardrestart": 22 | scheduler = WarmupCosineWithHardRestartsSchedule( 23 | optimizer, 24 | warmup_steps=warmup, 25 | t_total=total_iters 26 | ) 27 | 28 | elif train_params.SCHEDULER == "plateau": 29 | scheduler = optim.lr_scheduler.ReduceLROnPlateau( 30 | optimizer, 31 | "max", 32 | patience=5, 33 | verbose=True, 34 | factor=train_params.LR_DECAY_FACTOR, 35 | ) 36 | else: 37 | scheduler = None 38 | return scheduler 39 | 40 | 41 | class WarmupCosineSchedule(LambdaLR): 42 | """ Linear warmup and then cosine decay. 43 | Linearly increases learning rate from 0 to 1 over `warmup_steps`. 44 | Decreases learning rate from 1. to 0. over remaining 45 | `t_total - warmup_steps` steps following a cosine curve. 46 | If `cycles` (default=0.5) is different from default, learning rate 47 | follows cosine function after warmup. 48 | """ 49 | def __init__(self, optimizer, warmup_steps, t_total, cycles=.5, last_epoch=-1): 50 | self.warmup_steps = warmup_steps 51 | self.t_total = t_total 52 | self.cycles = cycles 53 | super(WarmupCosineSchedule, self).__init__( 54 | optimizer, self.lr_lambda, last_epoch=last_epoch) 55 | 56 | def lr_lambda(self, step): 57 | if step < self.warmup_steps: 58 | return float(step) / float(max(1.0, self.warmup_steps)) 59 | # progress after warmup 60 | progress = float(step - self.warmup_steps) / float(max( 61 | 1, self.t_total - self.warmup_steps)) 62 | return max( 63 | 0.0, 64 | 0.5 * (1. + math.cos(math.pi * float(self.cycles) * 2.0 * progress)) 65 | ) 66 | 67 | 68 | class WarmupCosineWithHardRestartsSchedule(LambdaLR): 69 | """ Linear warmup and then cosine cycles with hard restarts. 70 | Linearly increases learning rate from 0 to 1 over `warmup_steps`. 71 | If `cycles` (default=1.) is different from default, learning rate 72 | follows `cycles` times a cosine decaying learning rate 73 | (with hard restarts). 74 | """ 75 | def __init__(self, optimizer, warmup_steps, t_total, cycles=1., last_epoch=-1): 76 | self.warmup_steps = warmup_steps 77 | self.t_total = t_total 78 | self.cycles = cycles 79 | super(WarmupCosineWithHardRestartsSchedule, self).__init__( 80 | optimizer, self.lr_lambda, last_epoch=last_epoch) 81 | 82 | def lr_lambda(self, step): 83 | if step < self.warmup_steps: 84 | return float(step) / float(max(1, self.warmup_steps)) 85 | # progress after warmup 86 | progress = float(step - self.warmup_steps) / float( 87 | max(1, self.t_total - self.warmup_steps)) 88 | if progress >= 1.0: 89 | return 0.0 90 | return max( 91 | 0.0, 92 | 0.5 * (1. + math.cos( 93 | math.pi * ((float(self.cycles) * progress) % 1.0))) 94 | ) 95 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/oxford_iiit_pet.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | """Implements OxfordIIITPet data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow_datasets as tfds 24 | 25 | from . import base as base 26 | from .registry import Registry 27 | # This constant specifies the percentage of data that is used to create custom 28 | # train/val splits. Specifically, TRAIN_SPLIT_PERCENT% of the official training 29 | # split is used as a new training split and the rest is used for validation. 30 | TRAIN_SPLIT_PERCENT = 80 31 | 32 | 33 | @Registry.register("data.oxford_iiit_pet", "class") 34 | class OxfordIIITPetData(base.ImageTfdsData): 35 | """Provides OxfordIIITPet data. 36 | 37 | The OxfordIIITPet dataset comes only with a training and test set. 38 | Therefore, the validation set is split out of the original training set, and 39 | the remaining examples are used as the "train" split. The "trainval" split 40 | corresponds to the original training set. 41 | 42 | For additional details and usage, see the base class. 43 | """ 44 | 45 | def __init__(self, data_dir=None, train_split_percent=None): 46 | 47 | dataset_builder = tfds.builder("oxford_iiit_pet:3.*.*", data_dir=data_dir) 48 | dataset_builder.download_and_prepare() 49 | train_split_percent = train_split_percent or TRAIN_SPLIT_PERCENT 50 | 51 | # Creates a dict with example counts for each split. 52 | trainval_count = dataset_builder.info.splits[tfds.Split.TRAIN].num_examples 53 | test_count = dataset_builder.info.splits[tfds.Split.TEST].num_examples 54 | num_samples_splits = { 55 | "train": (train_split_percent * trainval_count) // 100, 56 | "val": trainval_count - (train_split_percent * trainval_count) // 100, 57 | "trainval": trainval_count, 58 | "test": test_count, 59 | "train800": 800, 60 | "val200": 200, 61 | "train800val200": 1000, 62 | } 63 | 64 | # Defines dataset specific train/val/trainval/test splits. 65 | tfds_splits = { 66 | "train": "train[:{}]".format(num_samples_splits["train"]), 67 | "val": "train[{}:]".format(num_samples_splits["train"]), 68 | "trainval": tfds.Split.TRAIN, 69 | "test": tfds.Split.TEST, 70 | "train800": "train[:800]", 71 | "val200": "train[{}:{}]".format( 72 | num_samples_splits["train"], num_samples_splits["train"]+200), 73 | "train800val200": "train[:800]+train[{}:{}]".format( 74 | num_samples_splits["train"], num_samples_splits["train"]+200), 75 | } 76 | 77 | super(OxfordIIITPetData, self).__init__( 78 | dataset_builder=dataset_builder, 79 | tfds_splits=tfds_splits, 80 | num_samples_splits=num_samples_splits, 81 | num_preprocessing_threads=400, 82 | shuffle_buffer_size=10000, 83 | # Note: Export only image and label tensors with their original types. 84 | base_preprocess_fn=base.make_get_tensors_fn(["image", "label"]), 85 | num_classes=dataset_builder.info.features["label"].num_classes) 86 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/eurosat.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements EurosatData class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow_datasets as tfds 24 | 25 | from . import base as base 26 | from .registry import Registry 27 | 28 | TRAIN_SPLIT_PERCENT = 60 29 | VALIDATION_SPLIT_PERCENT = 20 30 | TEST_SPLIT_PERCENT = 20 31 | 32 | 33 | @Registry.register("data.eurosat", "class") 34 | class EurosatData(base.ImageTfdsData): 35 | """Provides EuroSat dataset. 36 | 37 | EuroSAT dataset is based on Sentinel-2 satellite images covering 13 spectral 38 | bands and consisting of 10 classes with 27000 labeled and 39 | geo-referenced samples. 40 | 41 | URL: https://github.com/phelber/eurosat 42 | """ 43 | 44 | def __init__(self, subset="rgb", data_key="image", data_dir=None): 45 | dataset_name = "eurosat/{}:2.*.*".format(subset) 46 | dataset_builder = tfds.builder(dataset_name, data_dir=data_dir) 47 | dataset_builder.download_and_prepare() 48 | 49 | # Example counts are retrieved from the tensorflow dataset info. 50 | num_examples = dataset_builder.info.splits[tfds.Split.TRAIN].num_examples 51 | train_count = num_examples * TRAIN_SPLIT_PERCENT // 100 52 | val_count = num_examples * VALIDATION_SPLIT_PERCENT // 100 53 | test_count = num_examples * TEST_SPLIT_PERCENT // 100 54 | 55 | tfds_splits = { 56 | "train": 57 | "train[:{}]".format(train_count), 58 | "val": 59 | "train[{}:{}]".format(train_count, train_count+val_count), 60 | "trainval": 61 | "train[:{}]".format(train_count+val_count), 62 | "test": 63 | "train[{}:]".format(train_count+val_count), 64 | "train800": 65 | "train[:800]", 66 | "val200": 67 | "train[{}:{}]".format(train_count, train_count+200), 68 | "train800val200": 69 | "train[:800]+train[{}:{}]".format(train_count, train_count+200), 70 | } 71 | 72 | # Creates a dict with example counts for each split. 73 | num_samples_splits = { 74 | "train": train_count, 75 | "val": val_count, 76 | "trainval": train_count + val_count, 77 | "test": test_count, 78 | "train800": 800, 79 | "val200": 200, 80 | "train800val200": 1000, 81 | } 82 | 83 | num_channels = 3 84 | if data_key == "sentinel2": 85 | num_channels = 13 86 | 87 | super(EurosatData, self).__init__( 88 | dataset_builder=dataset_builder, 89 | tfds_splits=tfds_splits, 90 | num_samples_splits=num_samples_splits, 91 | num_preprocessing_threads=100, 92 | shuffle_buffer_size=10000, 93 | base_preprocess_fn=base.make_get_and_cast_tensors_fn({ 94 | data_key: ("image", None), 95 | "label": ("label", None), 96 | }), 97 | image_key=data_key, 98 | num_channels=num_channels, 99 | num_classes=dataset_builder.info.features["label"].num_classes) 100 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/cifar.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 
4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements Cifar data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | from . import base as base 23 | from .registry import Registry 24 | import tensorflow_datasets as tfds 25 | 26 | # This constant specifies the percentage of data that is used to create custom 27 | # train/val splits. Specifically, TRAIN_SPLIT_PERCENT% of the official training 28 | # split is used as a new training split and the rest is used for validation. 29 | TRAIN_SPLIT_PERCENT = 90 30 | 31 | 32 | @Registry.register("data.cifar", "class") 33 | class CifarData(base.ImageTfdsData): 34 | """Provides Cifar10 or Cifar100 data. 35 | 36 | Cifar comes only with a training and test set. Therefore, the validation set 37 | is split out of the original training set, and the remaining examples are used 38 | as the "train" split. The "trainval" split corresponds to the original 39 | training set. 40 | 41 | For additional details and usage, see the base class. 42 | """ 43 | 44 | def __init__(self, num_classes=10, data_dir=None, train_split_percent=None): 45 | 46 | if num_classes == 10: 47 | dataset_builder = tfds.builder("cifar10:3.*.*", data_dir=data_dir) 48 | elif num_classes == 100: 49 | dataset_builder = tfds.builder("cifar100:3.*.*", data_dir=data_dir) 50 | else: 51 | raise ValueError( 52 | "Number of classes must be 10 or 100, got {}".format(num_classes)) 53 | 54 | dataset_builder.download_and_prepare() 55 | 56 | train_split_percent = train_split_percent or TRAIN_SPLIT_PERCENT 57 | 58 | # Creates a dict with example counts for each split. 59 | trainval_count = dataset_builder.info.splits["train"].num_examples 60 | test_count = dataset_builder.info.splits["test"].num_examples 61 | num_samples_splits = { 62 | "train": (train_split_percent * trainval_count) // 100, 63 | "val": trainval_count - (train_split_percent * trainval_count) // 100, 64 | "trainval": trainval_count, 65 | "test": test_count, 66 | "train800": 800, 67 | "val200": 200, 68 | "train800val200": 1000, 69 | } 70 | 71 | # Defines dataset specific train/val/trainval/test splits. 72 | tfds_splits = { 73 | "train": "train[:{}]".format(num_samples_splits["train"]), 74 | "val": "train[{}:]".format(num_samples_splits["train"]), 75 | "trainval": "train", 76 | "test": "test", 77 | "train800": "train[:800]", 78 | "val200": "train[{}:{}]".format( 79 | num_samples_splits["train"], num_samples_splits["train"]+200), 80 | "train800val200": "train[:800]+train[{}:{}]".format( 81 | num_samples_splits["train"], num_samples_splits["train"]+200), 82 | } 83 | 84 | super(CifarData, self).__init__( 85 | dataset_builder=dataset_builder, 86 | tfds_splits=tfds_splits, 87 | num_samples_splits=num_samples_splits, 88 | num_preprocessing_threads=400, 89 | shuffle_buffer_size=10000, 90 | # Note: Export only image and label tensors with their original types. 
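# Unlike most of the other VTAB loaders, the per-example "id" tensor is
# also exported here alongside image and label.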
91 | base_preprocess_fn=base.make_get_tensors_fn(["image", "label", "id"]), 92 | num_classes=dataset_builder.info.features["label"].num_classes) 93 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/smallnorb.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements the SmallNORB data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow.compat.v1 as tf 24 | import tensorflow_datasets as tfds 25 | 26 | from . import base as base 27 | from .registry import Registry 28 | # This constant specifies the percentage of data that is used to create custom 29 | # val/test splits. Specifically, VAL_SPLIT_PERCENT% of the official testing 30 | # split is used as a new validation split and the rest is used for testing. 31 | VAL_SPLIT_PERCENT = 50 32 | 33 | 34 | @Registry.register("data.smallnorb", "class") 35 | class SmallNORBData(base.ImageTfdsData): 36 | """Provides the SmallNORB data set. 37 | 38 | SmallNORB comes only with a training and test set. Therefore, the validation 39 | set is split out of the original training set, and the remaining examples are 40 | used as the "train" split. The "trainval" split corresponds to the original 41 | training set. 42 | 43 | For additional details and usage, see the base class. 44 | 45 | The data set page is https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/. 46 | """ 47 | 48 | def __init__(self, predicted_attribute, data_dir=None): 49 | dataset_builder = tfds.builder("smallnorb:2.*.*", data_dir=data_dir) 50 | dataset_builder.download_and_prepare() 51 | 52 | if predicted_attribute not in dataset_builder.info.features: 53 | raise ValueError( 54 | "{} is not a valid attribute to predict.".format(predicted_attribute)) 55 | 56 | # Defines dataset specific train/val/trainval/test splits. 57 | tfds_splits = { 58 | "train": "train", 59 | "val": "test[:{}%]".format(VAL_SPLIT_PERCENT), 60 | "trainval": "train+test[:{}%]".format(VAL_SPLIT_PERCENT), 61 | "test": "test[{}%:]".format(VAL_SPLIT_PERCENT), 62 | "train800": "train[:800]", 63 | "val200": "test[:200]", 64 | "train800val200": "train[:800]+test[:200]", 65 | } 66 | 67 | # Creates a dict with example counts for each split. 
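# SmallNORB provides no official validation split, so VAL_SPLIT_PERCENT% of
# the TFDS test split is reserved for validation and the remainder is kept
# for testing (see the tfds_splits defined above).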
68 | train_count = dataset_builder.info.splits["train"].num_examples 69 | test_count = dataset_builder.info.splits["test"].num_examples 70 | num_samples_validation = VAL_SPLIT_PERCENT * test_count // 100 71 | num_samples_splits = { 72 | "train": train_count, 73 | "val": num_samples_validation, 74 | "trainval": train_count + num_samples_validation, 75 | "test": test_count - num_samples_validation, 76 | "train800": 800, 77 | "val200": 200, 78 | "train800val200": 1000, 79 | } 80 | 81 | def preprocess_fn(tensors): 82 | # For consistency with other datasets, image needs to have three channels. 83 | image = tf.tile(tensors["image"], [1, 1, 3]) 84 | return dict(image=image, label=tensors[predicted_attribute]) 85 | 86 | info = dataset_builder.info 87 | super(SmallNORBData, self).__init__( 88 | dataset_builder=dataset_builder, 89 | tfds_splits=tfds_splits, 90 | num_samples_splits=num_samples_splits, 91 | num_preprocessing_threads=400, 92 | shuffle_buffer_size=10000, 93 | # We extract the attribute we want to predict in the preprocessing. 94 | base_preprocess_fn=preprocess_fn, 95 | num_classes=info.features[predicted_attribute].num_classes) 96 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/oxford_flowers102.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements oxford flowers 102 data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow_datasets as tfds 24 | 25 | from . import base as base 26 | from .registry import Registry 27 | 28 | 29 | @Registry.register("data.oxford_flowers102", "class") 30 | class OxfordFlowers102Data(base.ImageTfdsData): 31 | """Provides Oxford 102 categories flowers dataset. 32 | 33 | See corresponding tfds dataset for details. 34 | 35 | URL: https://www.robots.ox.ac.uk/~vgg/data/flowers/102/ 36 | """ 37 | 38 | def __init__(self, data_dir=None, train_split_percent=None): 39 | dataset_builder = tfds.builder("oxford_flowers102:2.*.*", data_dir=data_dir) 40 | dataset_builder.download_and_prepare() 41 | 42 | # Example counts are retrieved from the tensorflow dataset info. 
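# Oxford Flowers-102 ships with official train/validation/test splits; when
# train_split_percent is given, custom train/val slices are re-drawn from
# train+validation instead (see the branch below).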
43 | train_count = dataset_builder.info.splits[tfds.Split.TRAIN].num_examples 44 | val_count = dataset_builder.info.splits[tfds.Split.VALIDATION].num_examples 45 | test_count = dataset_builder.info.splits[tfds.Split.TEST].num_examples 46 | 47 | if train_split_percent: 48 | tfds_splits = { 49 | "train": "train[:{s}%]+validation[:{s}%]".format( 50 | s=train_split_percent), 51 | "val": "train[-{s}%:]+validation[-{s}%:]".format( 52 | s=train_split_percent), 53 | "trainval": "train+validation", 54 | "test": "test", 55 | "train800": "train[:800]", 56 | "val200": "validation[:200]", 57 | "train800val200": "train[:800]+validation[:200]", 58 | } 59 | num_samples_splits = { 60 | "train": (((train_count + val_count) // 100) 61 | * train_split_percent), 62 | "val": (((train_count + val_count) // 100) * 63 | (100 - train_split_percent)), 64 | "trainval": train_count + val_count, 65 | "test": test_count, 66 | "train800": 800, 67 | "val200": 200, 68 | "train800val200": 1000, 69 | } 70 | else: 71 | tfds_splits = { 72 | "train": "train", 73 | "val": "validation", 74 | "trainval": "train+validation", 75 | "test": "test", 76 | "train800": "train[:800]", 77 | "val200": "validation[:200]", 78 | "train800val200": "train[:800]+validation[:200]", 79 | } 80 | num_samples_splits = { 81 | "train": train_count, 82 | "val": val_count, 83 | "trainval": train_count + val_count, 84 | "test": test_count, 85 | "train800": 800, 86 | "val200": 200, 87 | "train800val200": 1000, 88 | } 89 | 90 | super(OxfordFlowers102Data, self).__init__( 91 | dataset_builder=dataset_builder, 92 | tfds_splits=tfds_splits, 93 | num_samples_splits=num_samples_splits, 94 | num_preprocessing_threads=400, 95 | shuffle_buffer_size=10000, 96 | # Note: Rename tensors but keep their original types. 97 | base_preprocess_fn=base.make_get_and_cast_tensors_fn({ 98 | "image": ("image", None), 99 | "label": ("label", None), 100 | }), 101 | num_classes=dataset_builder.info.features["label"] 102 | .num_classes) 103 | -------------------------------------------------------------------------------- /launch.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | launch helper functions 4 | """ 5 | import argparse 6 | import os 7 | import sys 8 | import pprint 9 | import PIL 10 | from collections import defaultdict 11 | from tabulate import tabulate 12 | from typing import Tuple 13 | 14 | import torch 15 | from src.utils.file_io import PathManager 16 | from src.utils import logging 17 | from src.utils.distributed import get_rank, get_world_size 18 | 19 | 20 | def collect_torch_env() -> str: 21 | try: 22 | import torch.__config__ 23 | 24 | return torch.__config__.show() 25 | except ImportError: 26 | # compatible with older versions of pytorch 27 | from torch.utils.collect_env import get_pretty_env_info 28 | 29 | return get_pretty_env_info() 30 | 31 | 32 | def get_env_module() -> Tuple[str]: 33 | var_name = "ENV_MODULE" 34 | return var_name, os.environ.get(var_name, "") 35 | 36 | 37 | def collect_env_info() -> str: 38 | data = [] 39 | data.append(("Python", sys.version.replace("\n", ""))) 40 | data.append(get_env_module()) 41 | data.append(("PyTorch", torch.__version__)) 42 | data.append(("PyTorch Debug Build", torch.version.debug)) 43 | 44 | has_cuda = torch.cuda.is_available() 45 | data.append(("CUDA available", has_cuda)) 46 | if has_cuda: 47 | data.append(("CUDA ID", os.environ["CUDA_VISIBLE_DEVICES"])) 48 | devices = defaultdict(list) 49 | for k in range(torch.cuda.device_count()): 50 | 
devices[torch.cuda.get_device_name(k)].append(str(k)) 51 | for name, devids in devices.items(): 52 | data.append(("GPU " + ",".join(devids), name)) 53 | data.append(("Pillow", PIL.__version__)) 54 | 55 | try: 56 | import cv2 57 | 58 | data.append(("cv2", cv2.__version__)) 59 | except ImportError: 60 | pass 61 | env_str = tabulate(data) + "\n" 62 | env_str += collect_torch_env() 63 | return env_str 64 | 65 | 66 | def default_argument_parser(): 67 | """ 68 | create a simple parser to wrap around config file 69 | """ 70 | parser = argparse.ArgumentParser(description="visual-prompt") 71 | parser.add_argument( 72 | "--config-file", default="", metavar="FILE", help="path to config file") 73 | parser.add_argument( 74 | "--train-type", default="", help="training types") 75 | parser.add_argument( 76 | "--sparse-train", default=False, action="store_true", help="sparse training") 77 | parser.add_argument( 78 | "--use-wandb", action="store_true", help="use wandb to log") 79 | parser.add_argument( 80 | "opts", 81 | help="Modify config options using the command-line", 82 | default=None, 83 | nargs=argparse.REMAINDER, 84 | ) 85 | 86 | return parser 87 | 88 | 89 | def logging_train_setup(args, cfg) -> None: 90 | output_dir = cfg.OUTPUT_DIR 91 | if output_dir: 92 | PathManager.mkdirs(output_dir) 93 | 94 | logger = logging.setup_logging( 95 | cfg.NUM_GPUS, get_world_size(), output_dir, name="MOSA") 96 | 97 | # Log basic information about environment, cmdline arguments, and config 98 | rank = get_rank() 99 | logger.info( 100 | f"Rank of current process: {rank}. World size: {get_world_size()}") 101 | logger.info("Environment info:\n" + collect_env_info()) 102 | 103 | logger.info("Command line arguments: " + str(args)) 104 | if hasattr(args, "config_file") and args.config_file != "": 105 | logger.info( 106 | "Contents of args.config_file={}:\n{}".format( 107 | args.config_file, 108 | PathManager.open(args.config_file, "r").read() 109 | ) 110 | ) 111 | # Show the config 112 | logger.info("Training with config:") 113 | logger.info(pprint.pformat(cfg)) 114 | # cudnn benchmark has large overhead. 115 | # It shouldn't be used considering the small size of typical val set. 116 | if not (hasattr(args, "eval_only") and args.eval_only): 117 | torch.backends.cudnn.benchmark = cfg.CUDNN_BENCHMARK 118 | -------------------------------------------------------------------------------- /src/engine/evaluator.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import numpy as np 3 | 4 | from collections import defaultdict 5 | from typing import List, Union 6 | 7 | from .eval import multilabel 8 | from .eval import singlelabel 9 | from ..utils import logging 10 | logger = logging.get_logger("MOSA") 11 | 12 | 13 | class Evaluator(): 14 | """ 15 | An evaluator with below logics: 16 | 17 | 1. find which eval module to use. 18 | 2. store the eval results, pretty print it in log file as well. 
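    Typical usage (illustrative):
        evaluator = Evaluator()
        evaluator.update_iteration(epoch)           # tag results with the epoch
        evaluator.classify(probs, targets, "test")  # probs: (num_data, num_classes)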
19 | """ 20 | 21 | def __init__( 22 | self, 23 | ) -> None: 24 | self.results = defaultdict(dict) 25 | self.iteration = -1 26 | self.threshold_end = 0.5 27 | 28 | def update_iteration(self, iteration: int) -> None: 29 | """update iteration info""" 30 | self.iteration = iteration 31 | 32 | def update_result(self, metric: str, value: Union[float, dict]) -> None: 33 | if self.iteration > -1: 34 | key_name = "epoch_" + str(self.iteration) 35 | else: 36 | key_name = "final" 37 | if isinstance(value, float): 38 | self.results[key_name].update({metric: value}) 39 | else: 40 | if metric in self.results[key_name]: 41 | self.results[key_name][metric].update(value) 42 | else: 43 | self.results[key_name].update({metric: value}) 44 | 45 | def classify(self, probs, targets, test_data, multilabel=False): 46 | """ 47 | Evaluate classification result. 48 | Args: 49 | probs: np.ndarray for num_data x num_class, predicted probabilities 50 | targets: np.ndarray for multilabel, list of integers for single label 51 | test_labels: map test image ids to a list of class labels 52 | """ 53 | if not targets: 54 | raise ValueError( 55 | "When evaluating classification, need at least give targets") 56 | 57 | if multilabel: 58 | self._eval_multilabel(probs, targets, test_data) 59 | else: 60 | self._eval_singlelabel(probs, targets, test_data) 61 | 62 | def _eval_singlelabel( 63 | self, 64 | scores: np.ndarray, 65 | targets: List[int], 66 | eval_type: str 67 | ) -> None: 68 | """ 69 | if number of labels > 2: 70 | top1 and topk (5 by default) accuracy 71 | if number of labels == 2: 72 | top1 and rocauc 73 | """ 74 | acc_dict = singlelabel.compute_acc_auc(scores, targets) 75 | 76 | log_results = { 77 | k: np.around(v * 100, decimals=2) for k, v in acc_dict.items() 78 | } 79 | save_results = acc_dict 80 | 81 | self.log_and_update(log_results, save_results, eval_type) 82 | 83 | def _eval_multilabel( 84 | self, 85 | scores: np.ndarray, 86 | targets: np.ndarray, 87 | eval_type: str 88 | ) -> None: 89 | num_labels = scores.shape[-1] 90 | targets = multilabel.multihot(targets, num_labels) 91 | 92 | log_results = {} 93 | ap, ar, mAP, mAR = multilabel.compute_map(scores, targets) 94 | f1_dict = multilabel.get_best_f1_scores( 95 | targets, scores, self.threshold_end) 96 | 97 | log_results["mAP"] = np.around(mAP * 100, decimals=2) 98 | log_results["mAR"] = np.around(mAR * 100, decimals=2) 99 | log_results.update({ 100 | k: np.around(v * 100, decimals=2) for k, v in f1_dict.items()}) 101 | save_results = { 102 | "ap": ap, "ar": ar, "mAP": mAP, "mAR": mAR, "f1": f1_dict 103 | } 104 | self.log_and_update(log_results, save_results, eval_type) 105 | 106 | def log_and_update(self, log_results, save_results, eval_type): 107 | log_str = "" 108 | for k, result in log_results.items(): 109 | if not isinstance(result, np.ndarray): 110 | log_str += f"{k}: {result:.2f}\t" 111 | else: 112 | log_str += f"{k}: {list(result)}\t" 113 | logger.info(f"Classification results with {eval_type}: {log_str}") 114 | # save everything 115 | self.update_result("classification", {eval_type: save_results}) 116 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/svhn.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements Svhn data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow_datasets as tfds 24 | 25 | from . import base as base 26 | from .registry import Registry 27 | # This constant specifies the percentage of data that is used to create custom 28 | # train/val splits. Specifically, TRAIN_SPLIT_PERCENT% of the official training 29 | # split is used as a new training split and the rest is used for validation. 30 | TRAIN_SPLIT_PERCENT = 90 31 | 32 | 33 | @Registry.register("data.svhn", "class") 34 | class SvhnData(base.ImageTfdsData): 35 | """Provides SVHN data. 36 | 37 | The Street View House Numbers (SVHN) Dataset is an image digit recognition 38 | dataset of over 600,000 color digit images coming from real world data. 39 | Split size: 40 | - Training set: 73,257 images 41 | - Testing set: 26,032 images 42 | - Extra training set: 531,131 images 43 | Following the common setup on SVHN, we only use the official training and 44 | testing data. Images are cropped to 32x32. 45 | 46 | URL: http://ufldl.stanford.edu/housenumbers/ 47 | """ 48 | 49 | def __init__(self, data_dir=None): 50 | dataset_builder = tfds.builder("svhn_cropped:3.*.*", data_dir=data_dir) 51 | dataset_builder.download_and_prepare() 52 | 53 | # Example counts are retrieved from the tensorflow dataset info. 54 | trainval_count = dataset_builder.info.splits[tfds.Split.TRAIN].num_examples 55 | test_count = dataset_builder.info.splits[tfds.Split.TEST].num_examples 56 | 57 | # Creates a dict with example counts for each split. 58 | num_samples_splits = { 59 | # Calculates the train/val split example count based on percent. 60 | "train": TRAIN_SPLIT_PERCENT * trainval_count // 100, 61 | "val": trainval_count - TRAIN_SPLIT_PERCENT * trainval_count // 100, 62 | "trainval": trainval_count, 63 | "test": test_count, 64 | "train800": 800, 65 | "val200": 200, 66 | "train800val200": 1000, 67 | } 68 | 69 | # Defines dataset specific train/val/trainval/test splits. 70 | # The validation set is split out of the original training set, and the 71 | # remaining examples are used as the "train" split. The "trainval" split 72 | # corresponds to the original training set. 73 | tfds_splits = { 74 | "train": 75 | "train[:{}]".format(num_samples_splits["train"]), 76 | "val": 77 | "train[{}:]".format(num_samples_splits["train"]), 78 | "trainval": 79 | "train", 80 | "test": 81 | "test", 82 | "train800": 83 | "train[:800]", 84 | "val200": 85 | "train[{}:{}]".format(num_samples_splits["train"], 86 | num_samples_splits["train"] + 200), 87 | "train800val200": 88 | "train[:800]+train[{}:{}]".format( 89 | num_samples_splits["train"], num_samples_splits["train"] + 200), 90 | } 91 | 92 | super(SvhnData, self).__init__( 93 | dataset_builder=dataset_builder, 94 | tfds_splits=tfds_splits, 95 | num_samples_splits=num_samples_splits, 96 | num_preprocessing_threads=400, 97 | shuffle_buffer_size=10000, 98 | # Note: Rename tensors but keep their original types. 
99 | base_preprocess_fn=base.make_get_and_cast_tensors_fn({ 100 | "image": ("image", None), 101 | "label": ("label", None), 102 | }), 103 | num_classes=dataset_builder.info.features["label"] 104 | .num_classes) 105 | -------------------------------------------------------------------------------- /src/data/loader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """Data loader.""" 4 | import torch 5 | from torch.utils.data.distributed import DistributedSampler 6 | from torch.utils.data.sampler import RandomSampler 7 | 8 | from ..utils import logging 9 | from .datasets.json_dataset import ( 10 | CUB200Dataset, CarsDataset, DogsDataset, FlowersDataset, NabirdsDataset 11 | ) 12 | from .datasets.pt_dataset import ( 13 | CIFAR100Dataset, AircraftDataset, Food101Dataset 14 | ) 15 | 16 | logger = logging.get_logger("MOSA") 17 | _TORCH_BASIC_DS = { 18 | "cifar100": CIFAR100Dataset, 19 | 'aircraft': AircraftDataset, 20 | "food101": Food101Dataset, 21 | } 22 | _DATASET_CATALOG = { 23 | "CUB": CUB200Dataset, 24 | "OxfordFlowers": FlowersDataset, 25 | "StanfordCars": CarsDataset, 26 | "StanfordDogs": DogsDataset, 27 | "nabirds": NabirdsDataset, 28 | } 29 | 30 | 31 | def _construct_loader(cfg, split, batch_size, shuffle, drop_last): 32 | """Constructs the data loader for the given dataset.""" 33 | dataset_name = cfg.DATA.NAME 34 | 35 | # Construct the dataset 36 | if dataset_name.startswith("vtab-"): 37 | # import the tensorflow here only if needed 38 | from .datasets.tf_dataset import TFDataset 39 | dataset = TFDataset(cfg, split) 40 | # from .datasets.vtab_dataset import VTABDataset 41 | # dataset = VTABDataset(cfg, split) 42 | 43 | elif dataset_name == "multi": 44 | from .datasets.json_dataset import MultiDataset 45 | dataset = MultiDataset(cfg, split) 46 | 47 | else: 48 | if dataset_name in _TORCH_BASIC_DS.keys(): 49 | dataset = _TORCH_BASIC_DS[dataset_name](cfg, split) 50 | elif dataset_name in _DATASET_CATALOG.keys(): 51 | dataset = _DATASET_CATALOG[dataset_name](cfg, split) 52 | else: 53 | raise ValueError("Dataset '{}' not supported".format(dataset_name)) 54 | 55 | # Create a sampler for multi-process training 56 | sampler = DistributedSampler(dataset) if cfg.NUM_GPUS > 1 else None 57 | # Create a loader 58 | loader = torch.utils.data.DataLoader( 59 | dataset, 60 | batch_size=batch_size, 61 | shuffle=(False if sampler else shuffle), 62 | sampler=sampler, 63 | num_workers=cfg.DATA.NUM_WORKERS, 64 | pin_memory=cfg.DATA.PIN_MEMORY, 65 | drop_last=drop_last, 66 | ) 67 | return loader 68 | 69 | 70 | def construct_train_loader(cfg): 71 | """Train loader wrapper.""" 72 | if cfg.NUM_GPUS > 1: 73 | drop_last = True 74 | else: 75 | drop_last = False 76 | return _construct_loader( 77 | cfg=cfg, 78 | split="train", 79 | batch_size=int(cfg.DATA.BATCH_SIZE / cfg.NUM_GPUS), 80 | shuffle=True, 81 | drop_last=drop_last, 82 | ) 83 | 84 | 85 | def construct_trainval_loader(cfg): 86 | """Train loader wrapper.""" 87 | if cfg.NUM_GPUS > 1: 88 | drop_last = True 89 | else: 90 | drop_last = False 91 | return _construct_loader( 92 | cfg=cfg, 93 | split="trainval", 94 | batch_size=int(cfg.DATA.BATCH_SIZE / cfg.NUM_GPUS), 95 | shuffle=True, 96 | drop_last=drop_last, 97 | ) 98 | 99 | 100 | def construct_test_loader(cfg): 101 | """Test loader wrapper.""" 102 | return _construct_loader( 103 | cfg=cfg, 104 | split="test", 105 | batch_size=int(cfg.DATA.BATCH_SIZE / cfg.NUM_GPUS), 106 | shuffle=False, 107 | drop_last=False, 108 | ) 109 | 110 | 111 | def 
construct_val_loader(cfg, batch_size=None): 112 | if batch_size is None: 113 | bs = int(cfg.DATA.BATCH_SIZE / cfg.NUM_GPUS) 114 | else: 115 | bs = batch_size 116 | """Validation loader wrapper.""" 117 | return _construct_loader( 118 | cfg=cfg, 119 | split="val", 120 | batch_size=bs, 121 | shuffle=False, 122 | drop_last=False, 123 | ) 124 | 125 | 126 | def shuffle(loader, cur_epoch): 127 | """"Shuffles the data.""" 128 | assert isinstance( 129 | loader.sampler, (RandomSampler, DistributedSampler) 130 | ), "Sampler type '{}' not supported".format(type(loader.sampler)) 131 | # RandomSampler handles shuffling automatically 132 | if isinstance(loader.sampler, DistributedSampler): 133 | # DistributedSampler shuffles data based on epoch 134 | loader.sampler.set_epoch(cur_epoch) 135 | -------------------------------------------------------------------------------- /src/models/vit_adapter/vit_mae.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | borrow from https://github.com/facebookresearch/mae/blob/main/models_vit.py 4 | """ 5 | from functools import partial 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from .adapter_block import Pfeiffer_Block 11 | from ..vit_backbones.vit_mae import VisionTransformer 12 | from timm.models.layers import PatchEmbed 13 | from ...utils import logging 14 | logger = logging.get_logger("MOSA") 15 | 16 | 17 | class ADPT_VisionTransformer(VisionTransformer): 18 | """ Vision Transformer with support for global average pooling 19 | """ 20 | def __init__( 21 | self, 22 | adapter_cfg, 23 | img_size=224, 24 | patch_size=16, 25 | in_chans=3, 26 | num_classes=1000, 27 | embed_dim=768, 28 | depth=12, 29 | num_heads=12, 30 | mlp_ratio=4., 31 | qkv_bias=True, 32 | representation_size=None, 33 | distilled=False, 34 | drop_rate=0., 35 | attn_drop_rate=0., 36 | drop_path_rate=0., 37 | embed_layer=PatchEmbed, 38 | norm_layer=None, 39 | act_layer=None, 40 | weight_init='', 41 | **kwargs): 42 | 43 | super(ADPT_VisionTransformer, self).__init__( 44 | img_size=img_size, 45 | patch_size=patch_size, 46 | in_chans=in_chans, 47 | num_classes=num_classes, 48 | embed_dim=embed_dim, 49 | depth=depth, 50 | num_heads=num_heads, 51 | mlp_ratio=mlp_ratio, 52 | qkv_bias=qkv_bias, 53 | representation_size=representation_size, 54 | distilled=distilled, 55 | drop_rate=drop_rate, 56 | attn_drop_rate=attn_drop_rate, 57 | drop_path_rate=drop_path_rate, 58 | embed_layer=embed_layer, 59 | norm_layer=norm_layer, 60 | act_layer=act_layer, 61 | weight_init=weight_init, 62 | **kwargs 63 | ) 64 | 65 | self.adapter_cfg = adapter_cfg 66 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 67 | act_layer = act_layer or nn.GELU 68 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 69 | 70 | if adapter_cfg.STYLE == "Pfeiffer": 71 | self.blocks = nn.Sequential(*[ 72 | Pfeiffer_Block( 73 | adapter_config=adapter_cfg, 74 | dim=embed_dim, 75 | num_heads=num_heads, 76 | mlp_ratio=mlp_ratio, 77 | qkv_bias=qkv_bias, 78 | drop=drop_rate, 79 | attn_drop=attn_drop_rate, 80 | drop_path=dpr[i], 81 | norm_layer=norm_layer, 82 | act_layer=act_layer) for i in range(depth)]) 83 | else: 84 | raise ValueError("Other adapter styles are not supported.") 85 | 86 | 87 | 88 | def build_model(model_type, adapter_cfg): 89 | if "vitb" in model_type: 90 | return vit_base_patch16(adapter_cfg) 91 | elif "vitl" in model_type: 92 | return vit_large_patch16(adapter_cfg) 93 | elif "vith" in model_type: 94 | 
return vit_huge_patch14(adapter_cfg) 95 | 96 | 97 | def vit_base_patch16(adapter_cfg, **kwargs): 98 | model = ADPT_VisionTransformer( 99 | adapter_cfg, 100 | drop_path_rate=0.1, global_pool=True, # using default settings for mae-finetune 101 | patch_size=16, embed_dim=768, depth=12, num_heads=12, 102 | mlp_ratio=4, qkv_bias=True, 103 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 104 | return model 105 | 106 | 107 | def vit_large_patch16(adapter_cfg, **kwargs): 108 | model = ADPT_VisionTransformer( 109 | adapter_cfg, 110 | drop_path_rate=0.1, global_pool=True, # using default settings for mae-finetune 111 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, 112 | mlp_ratio=4, qkv_bias=True, 113 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 114 | return model 115 | 116 | 117 | def vit_huge_patch14(adapter_cfg, **kwargs): 118 | model = ADPT_VisionTransformer( 119 | adapter_cfg, 120 | drop_path_rate=0.1, global_pool=True, # using default settings for mae-finetune 121 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, 122 | mlp_ratio=4, qkv_bias=True, 123 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 124 | return model 125 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/clevr.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements CLEVR data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import numpy as np 24 | import tensorflow.compat.v1 as tf 25 | import tensorflow_datasets as tfds 26 | 27 | from . 
import base as base 28 | from .registry import Registry 29 | 30 | TRAIN_SPLIT_PERCENT = 90 31 | 32 | 33 | def _count_preprocess_fn(x): 34 | return {"image": x["image"], 35 | "label": tf.size(x["objects"]["size"]) - 3} 36 | 37 | 38 | def _count_cylinders_preprocess_fn(x): 39 | # Class distribution: 40 | 41 | num_cylinders = tf.reduce_sum( 42 | tf.cast(tf.equal(x["objects"]["shape"], 2), tf.int32)) 43 | return {"image": x["image"], "label": num_cylinders} 44 | 45 | 46 | def _closest_object_preprocess_fn(x): 47 | dist = tf.reduce_min(x["objects"]["pixel_coords"][:, 2]) 48 | # These thresholds are uniformly spaced and result in more or less balanced 49 | # distribution of classes, see the resulting histogram: 50 | 51 | thrs = np.array([0.0, 8.0, 8.5, 9.0, 9.5, 10.0, 100.0]) 52 | label = tf.reduce_max(tf.where((thrs - dist) < 0)) 53 | return {"image": x["image"], 54 | "label": label} 55 | 56 | 57 | _TASK_DICT = { 58 | "count_all": { 59 | "preprocess_fn": _count_preprocess_fn, 60 | "num_classes": 8 61 | }, 62 | "count_cylinders": { 63 | "preprocess_fn": _count_cylinders_preprocess_fn, 64 | "num_classes": 11 65 | }, 66 | "closest_object_distance": { 67 | "preprocess_fn": _closest_object_preprocess_fn, 68 | "num_classes": 6 69 | }, 70 | } 71 | 72 | 73 | @Registry.register("data.clevr", "class") 74 | class CLEVRData(base.ImageTfdsData): 75 | """Provides CLEVR dataset. 76 | 77 | Currently, two tasks are supported: 78 | 1. Predict number of objects. 79 | 2. Predict distnace to the closest object. 80 | """ 81 | 82 | def __init__(self, task, data_dir=None): 83 | 84 | if task not in _TASK_DICT: 85 | raise ValueError("Unknown task: %s" % task) 86 | 87 | dataset_builder = tfds.builder("clevr:3.*.*", data_dir=data_dir) 88 | dataset_builder.download_and_prepare() 89 | 90 | # Creates a dict with example counts for each split. 91 | trainval_count = dataset_builder.info.splits[tfds.Split.TRAIN].num_examples 92 | test_count = dataset_builder.info.splits[tfds.Split.TEST].num_examples 93 | num_samples_splits = { 94 | "train": (TRAIN_SPLIT_PERCENT * trainval_count) // 100, 95 | "val": trainval_count - (TRAIN_SPLIT_PERCENT * trainval_count) // 100, 96 | "trainval": trainval_count, 97 | "test": test_count, 98 | "train800": 800, 99 | "val200": 200, 100 | "train800val200": 1000, 101 | } 102 | 103 | # Defines dataset specific train/val/trainval/test splits. 104 | tfds_splits = { 105 | "train": "train[:{}]".format(num_samples_splits["train"]), 106 | "val": "train[{}:]".format(num_samples_splits["train"]), 107 | "trainval": "train", 108 | "test": "validation", 109 | "train800": "train[:800]", 110 | "val200": "train[{}:{}]".format( 111 | num_samples_splits["train"], num_samples_splits["train"]+200), 112 | "train800val200": "train[:800]+train[{}:{}]".format( 113 | num_samples_splits["train"], num_samples_splits["train"]+200), 114 | } 115 | 116 | task = _TASK_DICT[task] 117 | base_preprocess_fn = task["preprocess_fn"] 118 | 119 | super(CLEVRData, self).__init__( 120 | dataset_builder=dataset_builder, 121 | tfds_splits=tfds_splits, 122 | num_samples_splits=num_samples_splits, 123 | num_preprocessing_threads=400, 124 | shuffle_buffer_size=10000, 125 | # Note: Export only image and label tensors with their original types. 
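            # base_preprocess_fn comes from _TASK_DICT above and maps a raw CLEVR
            # example to an {"image": ..., "label": ...} pair for the chosen task.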
126 | base_preprocess_fn=base_preprocess_fn, 127 | num_classes=task["num_classes"]) 128 | -------------------------------------------------------------------------------- /src/utils/pruner.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Pruner class 4 | """ 5 | import torch 6 | 7 | 8 | class Pruner: 9 | def __init__(self, cfg): 10 | self.cfg = cfg 11 | if cfg.MODEL.TRANSFER_TYPE == "adapter" or cfg.MODEL.TRANSFER_TYPE == "mosa": 12 | self.num = cfg.MODEL.ADAPTER.EXPERT_NUM 13 | elif cfg.MODEL.TRANSFER_TYPE == "lora" or cfg.MODEL.TRANSFER_TYPE == "mosl": 14 | self.num = cfg.MODEL.LORA.EXPERT_NUM 15 | self.sparsity = 1 / self.num 16 | 17 | def score(self, param): 18 | raise NotImplementedError 19 | 20 | def prune(self, score): 21 | k = int((1.0 - self.sparsity) * score.numel()) 22 | threshold, _ = torch.kthvalue(torch.flatten(score), k) 23 | 24 | zero = torch.LongTensor([0]).to(score.device) 25 | one = torch.LongTensor([1]).to(score.device) 26 | return torch.where(score <= threshold, zero, one) 27 | 28 | # def divide(self, score): 29 | # masks = [] 30 | # zero = torch.LongTensor([0]).to(score.device) 31 | # one = torch.LongTensor([1]).to(score.device) 32 | 33 | # for i in range(self.num): 34 | # mask = torch.ones_like(score, dtype=torch.long) 35 | 36 | # lower_k = int((self.sparsity * i) * score.numel()) + 1 37 | # lower_threshold, _ = torch.kthvalue(torch.flatten(score), lower_k) 38 | # mask = torch.where(score < lower_threshold, zero, mask) 39 | 40 | # upper_k = int((self.sparsity * (i + 1)) * score.numel()) 41 | # upper_threshold, _ = torch.kthvalue(torch.flatten(score), upper_k) 42 | # mask = torch.where(score > upper_threshold, zero, mask) 43 | 44 | # masks.append(mask) 45 | 46 | # return masks 47 | 48 | def divide(self, score, mode="all"): 49 | masks = [] 50 | zero = torch.LongTensor([0]).to(score.device) 51 | one = torch.LongTensor([1]).to(score.device) 52 | 53 | for i in range(self.num): 54 | mask = torch.ones_like(score, dtype=torch.long) 55 | if self.cfg.MODEL.ADAPTER.BIAS and len(score.shape) == 1: 56 | masks.append(mask) 57 | continue 58 | 59 | if mode == "all": 60 | lower_k = int((self.sparsity * i) * score.numel()) + 1 61 | lower_threshold, _ = torch.kthvalue(torch.flatten(score), lower_k) 62 | mask = torch.where(score < lower_threshold, zero, mask) 63 | 64 | upper_k = int((self.sparsity * (i + 1)) * score.numel()) 65 | upper_threshold, _ = torch.kthvalue(torch.flatten(score), upper_k) 66 | mask = torch.where(score > upper_threshold, zero, mask) 67 | 68 | elif mode == "row": 69 | r, c = score.shape 70 | flag = torch.zeros_like(score, dtype=torch.long) 71 | flag += torch.arange(r, device=score.device).view(r, 1) 72 | 73 | for j in range(r): 74 | lower_k = int((self.sparsity * i) * score[j].numel()) + 1 75 | lower_threshold, _ = torch.kthvalue(torch.flatten(score[j]), lower_k) 76 | mask = torch.where((score < lower_threshold) & (flag == j), zero, mask) 77 | 78 | upper_k = int((self.sparsity * (i + 1)) * score[j].numel()) 79 | upper_threshold, _ = torch.kthvalue(torch.flatten(score[j]), upper_k) 80 | mask = torch.where((score > upper_threshold) & (flag == j), zero, mask) 81 | 82 | elif mode == "column": 83 | r, c = score.shape 84 | flag = torch.zeros_like(score, dtype=torch.long) 85 | flag += torch.arange(c, device=score.device) 86 | 87 | for j in range(c): 88 | lower_k = int((self.sparsity * i) * score[:, j].numel()) + 1 89 | lower_threshold, _ = torch.kthvalue(torch.flatten(score[:, j]), lower_k) 90 | 
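                    # within column j, drop entries whose score falls below expert i's
                    # lower quantile bound (the matching upper bound is applied below)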
mask = torch.where((score < lower_threshold) & (flag == j), zero, mask) 91 | 92 | upper_k = int((self.sparsity * (i + 1)) * score[:, j].numel()) 93 | upper_threshold, _ = torch.kthvalue(torch.flatten(score[:, j]), upper_k) 94 | mask = torch.where((score > upper_threshold) & (flag == j), zero, mask) 95 | 96 | masks.append(mask) 97 | 98 | return masks 99 | 100 | class Rand(Pruner): 101 | def __init__(self, cfg): 102 | super(Rand, self).__init__(cfg) 103 | 104 | def score(self, param): 105 | return torch.rand_like(param) 106 | 107 | -------------------------------------------------------------------------------- /src/configs/config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """Config system (based on Detectron's).""" 4 | 5 | from .config_node import CfgNode 6 | 7 | 8 | # Global config object 9 | _C = CfgNode() 10 | # Example usage: 11 | # from configs.config import cfg 12 | 13 | _C.DBG = False 14 | _C.GPU_ID = 0 15 | _C.OUTPUT_DIR = "./output" 16 | _C.RUN_N_TIMES = 5 17 | # Perform benchmarking to select the fastest CUDNN algorithms to use 18 | # Note that this may increase the memory usage and will likely not result 19 | # in overall speedups when variable size inputs are used (e.g. COCO training) 20 | _C.CUDNN_BENCHMARK = False 21 | 22 | # Number of GPUs to use (applies to both training and testing) 23 | _C.NUM_GPUS = 1 24 | _C.NUM_SHARDS = 1 25 | 26 | # Note that non-determinism may still be present due to non-deterministic 27 | # operator implementations in GPU operator libraries 28 | _C.SEED = None 29 | 30 | # ---------------------------------------------------------------------- 31 | # Model options 32 | # ---------------------------------------------------------------------- 33 | _C.MODEL = CfgNode() 34 | _C.MODEL.TRANSFER_TYPE = "linear" # one of linear, end2end, prompt, adapter, side, partial-1, tinytl-bias 35 | _C.MODEL.WEIGHT_PATH = "" # if resume from some checkpoint file 36 | _C.MODEL.SAVE_CKPT = False 37 | 38 | _C.MODEL.MODEL_ROOT = "" # root folder for pretrained model weights 39 | 40 | _C.MODEL.TYPE = "vit" 41 | _C.MODEL.MLP_NUM = 0 42 | 43 | _C.MODEL.LINEAR = CfgNode() 44 | _C.MODEL.LINEAR.MLP_SIZES = [] 45 | _C.MODEL.LINEAR.DROPOUT = 0.1 46 | 47 | # ---------------------------------------------------------------------- 48 | # Pruner options 49 | # ---------------------------------------------------------------------- 50 | _C.PRUNER = CfgNode() 51 | 52 | _C.PRUNER.TYPE = "random" 53 | _C.PRUNER.NUM = 4 54 | 55 | # ---------------------------------------------------------------------- 56 | # Adapter options 57 | # ---------------------------------------------------------------------- 58 | _C.MODEL.ADAPTER = CfgNode() 59 | _C.MODEL.ADAPTER.BOTTLENECK_SIZE = 64 60 | _C.MODEL.ADAPTER.STYLE = "AdaptFormer" 61 | _C.MODEL.ADAPTER.SCALAR = 0.1 62 | _C.MODEL.ADAPTER.DROPOUT = 0.1 63 | _C.MODEL.ADAPTER.MOE = False 64 | _C.MODEL.ADAPTER.BIAS = True 65 | _C.MODEL.ADAPTER.EXPERT_NUM = 4 66 | _C.MODEL.ADAPTER.MERGE = "add" 67 | _C.MODEL.ADAPTER.SHARE = None 68 | _C.MODEL.ADAPTER.ADDITIONAL = False 69 | _C.MODEL.ADAPTER.DEEPREG = False 70 | _C.MODEL.ADAPTER.ADD_WEIGHT = 0.0 71 | _C.MODEL.ADAPTER.REG_WEIGHT = 1.0 72 | 73 | # ---------------------------------------------------------------------- 74 | # LoRA options 75 | # ---------------------------------------------------------------------- 76 | _C.MODEL.LORA = CfgNode() 77 | _C.MODEL.LORA.RANK = 16 78 | _C.MODEL.LORA.MODE = "qv" 79 | _C.MODEL.LORA.SCALAR = 1.0 80 | 
_C.MODEL.LORA.DROPOUT = 0.0 81 | _C.MODEL.LORA.MOE = False 82 | _C.MODEL.LORA.BIAS = False 83 | _C.MODEL.LORA.EXPERT_NUM = 4 84 | _C.MODEL.LORA.MERGE = "add" 85 | _C.MODEL.LORA.SHARE = None 86 | _C.MODEL.LORA.ADDITIONAL = False 87 | _C.MODEL.LORA.DEEPREG = False 88 | _C.MODEL.LORA.ADD_WEIGHT = 0.0 89 | _C.MODEL.LORA.REG_WEIGHT = 1.0 90 | 91 | # ---------------------------------------------------------------------- 92 | # Solver options 93 | # ---------------------------------------------------------------------- 94 | _C.SOLVER = CfgNode() 95 | _C.SOLVER.LOSS = "softmax" 96 | _C.SOLVER.LOSS_ALPHA = 0.01 97 | 98 | _C.SOLVER.OPTIMIZER = "sgd" # or "adamw" 99 | _C.SOLVER.MOMENTUM = 0.9 100 | _C.SOLVER.WEIGHT_DECAY = 0.0001 101 | _C.SOLVER.WEIGHT_DECAY_BIAS = 0 102 | 103 | _C.SOLVER.PATIENCE = 300 104 | 105 | 106 | _C.SOLVER.SCHEDULER = "cosine" 107 | 108 | _C.SOLVER.BASE_LR = 0.01 109 | _C.SOLVER.BIAS_MULTIPLIER = 1. # for prompt + bias 110 | 111 | _C.SOLVER.WARMUP_EPOCH = 5 112 | _C.SOLVER.MERGE_EPOCH = 5 113 | _C.SOLVER.TOTAL_EPOCH = 30 114 | _C.SOLVER.LOG_EVERY_N = 1000 115 | 116 | 117 | _C.SOLVER.DBG_TRAINABLE = False # if True, will print the name of trainable params 118 | 119 | # ---------------------------------------------------------------------- 120 | # Dataset options 121 | # ---------------------------------------------------------------------- 122 | _C.DATA = CfgNode() 123 | 124 | _C.DATA.NAME = "" 125 | _C.DATA.DATAPATH = "" 126 | _C.DATA.FEATURE = "" # e.g. inat2021_supervised 127 | 128 | _C.DATA.PERCENTAGE = 1.0 129 | _C.DATA.NUMBER_CLASSES = -1 130 | _C.DATA.MULTILABEL = False 131 | _C.DATA.CLASS_WEIGHTS_TYPE = "none" 132 | 133 | _C.DATA.CROPSIZE = 224 # or 384 134 | 135 | _C.DATA.NO_TEST = False 136 | _C.DATA.BATCH_SIZE = 32 137 | # Number of data loader workers per training process 138 | _C.DATA.NUM_WORKERS = 4 139 | # Load data to pinned host memory 140 | _C.DATA.PIN_MEMORY = True 141 | 142 | _C.DIST_BACKEND = "nccl" 143 | _C.DIST_INIT_PATH = "env://" 144 | _C.DIST_INIT_FILE = "" 145 | 146 | 147 | def get_cfg(): 148 | """ 149 | Get a copy of the default config. 150 | """ 151 | return _C.clone() 152 | -------------------------------------------------------------------------------- /VTAB_SETUP.md: -------------------------------------------------------------------------------- 1 | # VTAB Preperation 2 | 3 | ## Download and prepare 4 | 5 | It is recommended to download the data before the experiments, to avoid duplicated effort if submitting experiments for multiple tuning protocols. Here are the collective command to set up the vtab data. 
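Before running the full set of commands below, you can optionally sanity-check a single dataset with a short sketch like the following (assuming `tensorflow_datasets` is installed; `data_dir` is a placeholder for the same directory you will set as `DATA.DATAPATH` in the configs):

```python
import tensorflow_datasets as tfds

data_dir = ""  # placeholder: same directory as DATA.DATAPATH in the configs
builder = tfds.builder("cifar100:3.*.*", data_dir=data_dir)
builder.download_and_prepare()
print(builder.info.splits["train"].num_examples)  # 50,000 train examples for cifar100
```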
6 | 7 | ```python 8 | import tensorflow_datasets as tfds 9 | data_dir = "" # TODO: setup the data_dir to put the the data to, the DATA.DATAPATH value in config 10 | 11 | # caltech101 12 | dataset_builder = tfds.builder("caltech101:3.*.*", data_dir=data_dir) 13 | dataset_builder.download_and_prepare() 14 | 15 | # cifar100 16 | dataset_builder = tfds.builder("cifar100:3.*.*", data_dir=data_dir) 17 | dataset_builder.download_and_prepare() 18 | 19 | # clevr 20 | dataset_builder = tfds.builder("clevr:3.*.*", data_dir=data_dir) 21 | dataset_builder.download_and_prepare() 22 | 23 | # dmlab 24 | dataset_builder = tfds.builder("dmlab:2.0.1", data_dir=data_dir) 25 | dataset_builder.download_and_prepare() 26 | 27 | # dsprites 28 | dataset_builder = tfds.builder("dsprites:2.*.*", data_dir=data_dir) 29 | dataset_builder.download_and_prepare() 30 | 31 | # dtd 32 | dataset_builder = tfds.builder("dtd:3.*.*", data_dir=data_dir) 33 | dataset_builder.download_and_prepare() 34 | 35 | # eurosat 36 | subset="rgb" 37 | dataset_name = "eurosat/{}:2.*.*".format(subset) 38 | dataset_builder = tfds.builder(dataset_name, data_dir=data_dir) 39 | dataset_builder.download_and_prepare() 40 | 41 | # oxford_flowers102 42 | dataset_builder = tfds.builder("oxford_flowers102:2.*.*", data_dir=data_dir) 43 | dataset_builder.download_and_prepare() 44 | 45 | # oxford_iiit_pet 46 | dataset_builder = tfds.builder("oxford_iiit_pet:3.*.*", data_dir=data_dir) 47 | dataset_builder.download_and_prepare() 48 | 49 | # patch_camelyon 50 | dataset_builder = tfds.builder("patch_camelyon:2.*.*", data_dir=data_dir) 51 | dataset_builder.download_and_prepare() 52 | 53 | # smallnorb 54 | dataset_builder = tfds.builder("smallnorb:2.*.*", data_dir=data_dir) 55 | dataset_builder.download_and_prepare() 56 | 57 | # svhn 58 | dataset_builder = tfds.builder("svhn_cropped:3.*.*", data_dir=data_dir) 59 | dataset_builder.download_and_prepare() 60 | ``` 61 | 62 | There are 4 datasets need special care: 63 | 64 | ```python 65 | # sun397 --> need cv2 66 | # cannot load one image, similar to issue here: https://github.com/tensorflow/datasets/issues/2889 67 | # "Image /t/track/outdoor/sun_aophkoiosslinihb.jpg could not be decoded by Tensorflow."" 68 | # sol: modify the file: "/fsx/menglin/conda/envs/prompt_tf/lib/python3.7/site-packages/tensorflow_datasets/image_classification/sun.py" to ignore those images 69 | dataset_builder = tfds.builder("sun397/tfds:4.*.*", data_dir=data_dir) 70 | dataset_builder.download_and_prepare() 71 | 72 | # kitti version is wrong from vtab repo, try 3.2.0 (https://github.com/google-research/task_adaptation/issues/18) 73 | dataset_builder = tfds.builder("kitti:3.2.0", data_dir=data_dir) 74 | dataset_builder.download_and_prepare() 75 | 76 | 77 | # diabetic_retinopathy 78 | """ 79 | Download this dataset from Kaggle. 80 | https://www.kaggle.com/c/diabetic-retinopathy-detection/data 81 | After downloading, 82 | - unpack the test.zip file into /manual_dir/. 83 | - unpack the sample.zip to sample/. 84 | - unpack the sampleSubmissions.csv and trainLabels.csv. 85 | 86 | # ==== important! ==== 87 | # 1. make sure to check that there are 5 train.zip files instead of 4 (somehow if you chose to download all from kaggle, the train.zip.005 file is missing) 88 | # 2. 
if unzip train.zip ran into issues, try to use jar xvf train.zip to handle huge zip file 89 | cat test.zip.* > test.zip 90 | cat train.zip.* > train.zip 91 | """ 92 | 93 | config_and_version = "btgraham-300" + ":3.*.*" 94 | dataset_builder = tfds.builder("diabetic_retinopathy_detection/{}".format(config_and_version), data_dir=data_dir) 95 | dataset_builder.download_and_prepare() 96 | 97 | 98 | # resisc45 99 | """ 100 | download/extract dataset artifacts manually: 101 | Dataset can be downloaded from OneDrive: https://1drv.ms/u/s!AmgKYzARBl5ca3HNaHIlzp_IXjs 102 | After downloading the rar file, please extract it to the manual_dir. 103 | """ 104 | 105 | dataset_builder = tfds.builder("resisc45:3.*.*", data_dir=data_dir) 106 | dataset_builder.download_and_prepare() 107 | ``` 108 | 109 | ## Notes 110 | 111 | ### TFDS version 112 | 113 | Note that the experimental results may be different with different API and/or dataset generation code versions. See more from [tfds documentation](https://www.tensorflow.org/datasets/datasets_versioning). Here are what we used for MoSA: 114 | 115 | ```bash 116 | tfds: 4.4.0+nightly 117 | 118 | # Natural: 119 | cifar100: 3.0.2 120 | caltech101: 3.0.1 121 | dtd: 3.0.1 122 | oxford_flowers102: 2.1.1 123 | oxford_iiit_pet: 3.2.0 124 | svhn_cropped: 3.0.0 125 | sun397: 4.0.0 126 | 127 | # Specialized: 128 | patch_camelyon: 2.0.0 129 | eurosat: 2.0.0 130 | resisc45: 3.0.0 131 | diabetic_retinopathy_detection: 3.0.0 132 | 133 | 134 | # Structured 135 | clevr: 3.1.0 136 | dmlab: 2.0.1 137 | kitti: 3.2.0 138 | dsprites: 2.0.0 139 | smallnorb: 2.0.0 140 | ``` 141 | -------------------------------------------------------------------------------- /src/engine/eval/multilabel.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | evaluate precision@1, @5 equal to Top1 and Top5 error rate 4 | """ 5 | import numpy as np 6 | from typing import List, Tuple, Dict 7 | from sklearn.metrics import ( 8 | precision_recall_curve, 9 | average_precision_score, 10 | f1_score 11 | ) 12 | 13 | 14 | def get_continuous_ids(probe_labels: List[int]) -> Dict[int, int]: 15 | sorted(probe_labels) 16 | id2continuousid = {} 17 | for idx, p_id in enumerate(probe_labels): 18 | id2continuousid[p_id] = idx 19 | return id2continuousid 20 | 21 | 22 | def multihot(x: List[List[int]], nb_classes: int) -> np.ndarray: 23 | """transform to multihot encoding 24 | 25 | Arguments: 26 | x: list of multi-class integer labels, in the range 27 | [0, nb_classes-1] 28 | nb_classes: number of classes for the multi-hot vector 29 | 30 | Returns: 31 | multihot: multihot vector of type int, (num_samples, nb_classes) 32 | """ 33 | num_samples = len(x) 34 | 35 | multihot = np.zeros((num_samples, nb_classes), dtype=np.int32) 36 | for idx, labs in enumerate(x): 37 | for lab in labs: 38 | multihot[idx, lab] = 1 39 | 40 | return multihot.astype(np.int) 41 | 42 | 43 | def compute_map( 44 | scores: np.ndarray, multihot_targets: np.ndarray 45 | ) -> Tuple[np.ndarray, np.ndarray, float, float]: 46 | """Compute the mean average precision across all class labels. 47 | 48 | Arguments: 49 | scores: matrix of per-class distances, 50 | of size num_samples x nb_classes 51 | multihot_targets: matrix of multi-hot target predictions, 52 | of size num_samples x nb_classes 53 | 54 | Returns: 55 | ap: list of average-precision scores, one for each of 56 | the nb_classes classes. 57 | ar: list of average-recall scores, one for each of 58 | the nb_classes classes. 
59 | mAP: the mean average precision score over all average 60 | precisions for all nb_classes classes. 61 | mAR: the mean average recall score over all average 62 | precisions for all nb_classes classes. 63 | """ 64 | nb_classes = scores.shape[1] 65 | 66 | ap = np.zeros((nb_classes,), dtype=np.float32) 67 | ar = np.zeros((nb_classes,), dtype=np.float32) 68 | 69 | for c in range(nb_classes): 70 | y_true = multihot_targets[:, c] 71 | y_scores = scores[:, c] 72 | 73 | # Use interpolated average precision (a la PASCAL 74 | try: 75 | ap[c] = average_precision_score(y_true, y_scores) 76 | except ValueError: 77 | ap[c] = -1 78 | 79 | # Also get the average of the recalls on the raw PR-curve 80 | try: 81 | _, rec, _ = precision_recall_curve(y_true, y_scores) 82 | ar[c] = rec.mean() 83 | except ValueError: 84 | ar[c] = -1 85 | 86 | mAP = ap.mean() 87 | mAR = ar.mean() 88 | 89 | return ap, ar, mAP, mAR 90 | 91 | 92 | def compute_f1( 93 | multihot_targets: np.ndarray, scores: np.ndarray, threshold: float = 0.5 94 | ) -> Tuple[float, float, float]: 95 | # change scores to predict_labels 96 | predict_labels = scores > threshold 97 | predict_labels = predict_labels.astype(np.int) 98 | 99 | # change targets to multihot 100 | f1 = {} 101 | f1["micro"] = f1_score( 102 | y_true=multihot_targets, 103 | y_pred=predict_labels, 104 | average="micro" 105 | ) 106 | f1["samples"] = f1_score( 107 | y_true=multihot_targets, 108 | y_pred=predict_labels, 109 | average="samples" 110 | ) 111 | f1["macro"] = f1_score( 112 | y_true=multihot_targets, 113 | y_pred=predict_labels, 114 | average="macro" 115 | ) 116 | f1["none"] = f1_score( 117 | y_true=multihot_targets, 118 | y_pred=predict_labels, 119 | average=None 120 | ) 121 | return f1["micro"], f1["samples"], f1["macro"], f1["none"] 122 | 123 | 124 | def get_best_f1_scores( 125 | multihot_targets: np.ndarray, scores: np.ndarray, threshold_end: int 126 | ) -> Dict[str, float]: 127 | end = 0.5 128 | end = 0.05 129 | end = threshold_end 130 | thrs = np.linspace( 131 | end, 0.95, int(np.round((0.95 - end) / 0.05)) + 1, endpoint=True 132 | ) 133 | f1_micros = [] 134 | f1_macros = [] 135 | f1_samples = [] 136 | f1_none = [] 137 | for thr in thrs: 138 | _micros, _samples, _macros, _none = compute_f1(multihot_targets, scores, thr) 139 | f1_micros.append(_micros) 140 | f1_samples.append(_samples) 141 | f1_macros.append(_macros) 142 | f1_none.append(_none) 143 | 144 | f1_macros_m = max(f1_macros) 145 | b_thr = np.argmax(f1_macros) 146 | 147 | f1_micros_m = f1_micros[b_thr] 148 | f1_samples_m = f1_samples[b_thr] 149 | f1_none_m = f1_none[b_thr] 150 | f1 = {} 151 | f1["micro"] = f1_micros_m 152 | f1["macro"] = f1_macros_m 153 | f1["samples"] = f1_samples_m 154 | f1["threshold"] = thrs[b_thr] 155 | f1["none"] = f1_none_m 156 | return f1 157 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/dsprites.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements the DSprites data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow.compat.v1 as tf 24 | import tensorflow_datasets as tfds 25 | 26 | from . import base as base 27 | from .registry import Registry 28 | 29 | 30 | # These constants specify the percentage of data that is used to create custom 31 | # train/val splits. Specifically, TRAIN_SPLIT_PERCENT% of the data set is used 32 | # as a new training split and VAL_SPLIT_PERCENT% is used for validation. 33 | # The rest is used for testing. 34 | TRAIN_SPLIT_PERCENT = 80 35 | VAL_SPLIT_PERCENT = 10 36 | 37 | 38 | @Registry.register("data.dsprites", "class") 39 | class DSpritesData(base.ImageTfdsData): 40 | """Provides the DSprites data set. 41 | 42 | DSprites only comes with a training set. Therefore, the training, validation, 43 | and test set are split out of the original training set. 44 | 45 | For additional details and usage, see the base class. 46 | 47 | The data set page is https://github.com/deepmind/dsprites-dataset/. 48 | """ 49 | 50 | def __init__(self, predicted_attribute, num_classes=None, data_dir=None): 51 | dataset_builder = tfds.builder("dsprites:2.*.*", data_dir=data_dir) 52 | dataset_builder.download_and_prepare() 53 | info = dataset_builder.info 54 | 55 | if predicted_attribute not in dataset_builder.info.features: 56 | raise ValueError( 57 | "{} is not a valid attribute to predict.".format(predicted_attribute)) 58 | 59 | # If num_classes is set, we group together nearby integer values to arrive 60 | # at the desired number of classes. This is useful for example for grouping 61 | # together different spatial positions. 62 | num_original_classes = info.features[predicted_attribute].num_classes 63 | if num_classes is None: 64 | num_classes = num_original_classes 65 | if not isinstance(num_classes, int) or num_classes <= 1 or ( 66 | num_classes > num_original_classes): 67 | raise ValueError( 68 | "The number of classes should be None or in [2, ..., num_classes].") 69 | class_division_factor = float(num_original_classes) / num_classes 70 | 71 | # Creates a dict with example counts for each split. 72 | num_total = dataset_builder.info.splits["train"].num_examples 73 | num_samples_train = TRAIN_SPLIT_PERCENT * num_total // 100 74 | num_samples_val = VAL_SPLIT_PERCENT * num_total // 100 75 | num_samples_splits = { 76 | "train": num_samples_train, 77 | "val": num_samples_val, 78 | "trainval": num_samples_val + num_samples_train, 79 | "test": num_total - num_samples_val - num_samples_train, 80 | "train800": 800, 81 | "val200": 200, 82 | "train800val200": 1000, 83 | } 84 | 85 | # Defines dataset specific train/val/trainval/test splits. 
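    # The split strings below use the TFDS slicing syntax: "train[:N]" takes the
    # first N examples and "train[a:b]" a half-open index range of the train split.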
86 | tfds_splits = { 87 | "train": "train[:{}]".format(num_samples_splits["train"]), 88 | "val": "train[{}:{}]".format(num_samples_splits["train"], 89 | num_samples_splits["trainval"]), 90 | "trainval": "train[:{}]".format(num_samples_splits["trainval"]), 91 | "test": "train[{}:]".format(num_samples_splits["trainval"]), 92 | "train800": "train[:800]", 93 | "val200": "train[{}:{}]".format(num_samples_splits["train"], 94 | num_samples_splits["train"]+200), 95 | "train800val200": "train[:800]+train[{}:{}]".format( 96 | num_samples_splits["train"], num_samples_splits["train"]+200), 97 | } 98 | 99 | def preprocess_fn(tensors): 100 | # For consistency with other datasets, image needs to have three channels 101 | # and be in [0, 255). 102 | images = tf.tile(tensors["image"], [1, 1, 3]) * 255 103 | label = tf.cast( 104 | tf.math.floordiv( 105 | tf.cast(tensors[predicted_attribute], tf.float32), 106 | class_division_factor), info.features[predicted_attribute].dtype) 107 | return dict(image=images, label=label) 108 | 109 | super(DSpritesData, self).__init__( 110 | dataset_builder=dataset_builder, 111 | tfds_splits=tfds_splits, 112 | num_samples_splits=num_samples_splits, 113 | num_preprocessing_threads=400, 114 | shuffle_buffer_size=10000, 115 | # We extract the attribute we want to predict in the preprocessing. 116 | base_preprocess_fn=preprocess_fn, 117 | num_classes=num_classes) 118 | -------------------------------------------------------------------------------- /src/models/vit_adapter/swin_adapter.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | vit with adapter 4 | """ 5 | import copy 6 | import math 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | 11 | from ..vit_backbones.swin_transformer import Mlp, SwinTransformerBlock, BasicLayer, PatchMerging, SwinTransformer 12 | from ...utils import logging 13 | logger = logging.get_logger("MOSA") 14 | 15 | 16 | class AdaptedMlp(Mlp): 17 | def __init__( 18 | self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., 19 | adapter_config=None, adapter_scalar=1.0, dropout=0.0 20 | ): 21 | super(AdaptedMlp, self).__init__(in_features, hidden_features, out_features, act_layer, drop) 22 | self.adapter_config = adapter_config 23 | 24 | if adapter_scalar is None: 25 | self.adapter_scale = nn.Parameter(torch.ones(1)) 26 | else: 27 | self.adapter_scale = adapter_scalar 28 | self.dropout = dropout 29 | 30 | out_features = out_features or in_features 31 | self.adapter_down = nn.Linear( 32 | in_features, 33 | adapter_config.BOTTLENECK_SIZE 34 | ) 35 | self.adapter_up = nn.Linear( 36 | adapter_config.BOTTLENECK_SIZE, 37 | out_features 38 | ) 39 | self.adapter_act_fn = nn.ReLU() 40 | 41 | nn.init.kaiming_uniform_(self.adapter_down.weight, a=math.sqrt(5)) 42 | nn.init.zeros_(self.adapter_down.bias) 43 | 44 | nn.init.zeros_(self.adapter_up.weight) 45 | nn.init.zeros_(self.adapter_up.bias) 46 | 47 | def forward(self, x): 48 | # same as reguluar Mlp block 49 | 50 | h = x 51 | x = self.fc1(x) 52 | x = self.act(x) 53 | x = self.drop(x) 54 | x = self.fc2(x) 55 | x = self.drop(x) 56 | 57 | # start to insert adapter layers... 
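        # bottleneck adapter on the residual branch: project the block input h down
        # to BOTTLENECK_SIZE, apply ReLU and dropout, project back up, and add the
        # scaled result to the MLP output below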
58 | adpt = self.adapter_down(h) 59 | adpt = self.adapter_act_fn(adpt) 60 | adpt = nn.functional.dropout(adpt, p=self.dropout, training=self.training) 61 | adpt = self.adapter_up(adpt) 62 | 63 | x = adpt * self.adapter_scale + x 64 | # ...end 65 | 66 | return x 67 | 68 | 69 | class AdaptedSwinTransformerBlock(SwinTransformerBlock): 70 | def __init__( 71 | self, dim, input_resolution, num_heads, window_size=7, shift_size=0, 72 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 73 | act_layer=nn.GELU, norm_layer=nn.LayerNorm, adapter_config=None 74 | ): 75 | super(AdaptedSwinTransformerBlock, self).__init__( 76 | dim, input_resolution, num_heads, window_size, 77 | shift_size, mlp_ratio, qkv_bias, qk_scale, drop, 78 | attn_drop, drop_path, act_layer, norm_layer 79 | ) 80 | self.adapter_config = adapter_config 81 | mlp_hidden_dim = int(dim * mlp_ratio) 82 | self.mlp = AdaptedMlp( 83 | in_features=dim, hidden_features=mlp_hidden_dim, 84 | act_layer=act_layer, drop=drop, adapter_config=adapter_config 85 | ) 86 | 87 | 88 | class AdaptedSwinTransformer(SwinTransformer): 89 | def __init__( 90 | self, adapter_config, img_size=224, patch_size=4, in_chans=3, 91 | num_classes=1000, embed_dim=96, depths=[2, 2, 6, 2], 92 | num_heads=[3, 6, 12, 24], window_size=7, mlp_ratio=4., qkv_bias=True, 93 | qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, 94 | norm_layer=nn.LayerNorm, ape=False, patch_norm=True, 95 | use_checkpoint=False, **kwargs 96 | ): 97 | super(AdaptedSwinTransformer, self).__init__( 98 | img_size, patch_size, in_chans, num_classes, embed_dim, depths, 99 | num_heads, window_size, mlp_ratio, qkv_bias, qk_scale, drop_rate, 100 | attn_drop_rate, drop_path_rate, norm_layer, ape, patch_norm, 101 | use_checkpoint, **kwargs 102 | ) 103 | self.adapter_config = adapter_config 104 | 105 | # stochastic depth 106 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 107 | 108 | # build layers 109 | self.layers = nn.ModuleList() 110 | for i_layer in range(self.num_layers): 111 | layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), 112 | input_resolution=( 113 | self.patches_resolution[0] // (2 ** i_layer), 114 | self.patches_resolution[1] // (2 ** i_layer)), 115 | depth=depths[i_layer], 116 | num_heads=num_heads[i_layer], 117 | window_size=window_size, 118 | mlp_ratio=self.mlp_ratio, 119 | qkv_bias=qkv_bias, qk_scale=qk_scale, 120 | drop=drop_rate, attn_drop=attn_drop_rate, 121 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], 122 | norm_layer=norm_layer, 123 | downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, 124 | use_checkpoint=use_checkpoint, 125 | block_module=AdaptedSwinTransformerBlock, 126 | adapter_config=adapter_config 127 | ) 128 | self.layers.append(layer) 129 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/registry.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Global Registry for the task adaptation framework. 18 | """ 19 | 20 | from __future__ import absolute_import 21 | from __future__ import division 22 | from __future__ import print_function 23 | 24 | import ast 25 | import functools 26 | 27 | 28 | def partialclass(cls, *base_args, **base_kwargs): 29 | """Builds a subclass with partial application of the given args and keywords. 30 | 31 | Equivalent to functools.partial performance, base_args are preprended to the 32 | positional arguments given during object initialization and base_kwargs are 33 | updated with the kwargs given later. 34 | 35 | Args: 36 | cls: The base class. 37 | *base_args: Positional arguments to be applied to the subclass. 38 | **base_kwargs: Keyword arguments to be applied to the subclass. 39 | 40 | Returns: 41 | A subclass of the input class. 42 | """ 43 | 44 | class _NewClass(cls): 45 | 46 | def __init__(self, *args, **kwargs): 47 | bound_args = base_args + args 48 | bound_kwargs = base_kwargs.copy() 49 | bound_kwargs.update(kwargs) 50 | super(_NewClass, self).__init__(*bound_args, **bound_kwargs) 51 | 52 | return _NewClass 53 | 54 | 55 | def parse_name(string_to_parse): 56 | """Parses input to the registry's lookup function. 57 | 58 | Args: 59 | string_to_parse: can be either an arbitrary name or function call 60 | (optionally with positional and keyword arguments). 61 | e.g. "multiclass", "resnet50_v2(filters_factor=8)". 62 | 63 | Returns: 64 | A tuple of input name and a dctinary with arguments. Examples: 65 | "multiclass" -> ("multiclass", (), {}) 66 | "resnet50_v2(9, filters_factor=4)" -> 67 | ("resnet50_v2", (9,), {"filters_factor": 4}) 68 | """ 69 | expr = ast.parse(string_to_parse, mode="eval").body # pytype: disable=attribute-error 70 | if not isinstance(expr, (ast.Attribute, ast.Call, ast.Name)): 71 | raise ValueError( 72 | "The given string should be a name or a call, but a {} was parsed from " 73 | "the string {!r}".format(type(expr), string_to_parse)) 74 | 75 | # Notes: 76 | # name="some_name" -> type(expr) = ast.Name 77 | # name="module.some_name" -> type(expr) = ast.Attribute 78 | # name="some_name()" -> type(expr) = ast.Call 79 | # name="module.some_name()" -> type(expr) = ast.Call 80 | 81 | if isinstance(expr, ast.Name): 82 | return string_to_parse, {} 83 | elif isinstance(expr, ast.Attribute): 84 | return string_to_parse, {} 85 | 86 | def _get_func_name(expr): 87 | if isinstance(expr, ast.Attribute): 88 | return _get_func_name(expr.value) + "." 
+ expr.attr 89 | elif isinstance(expr, ast.Name): 90 | return expr.id 91 | else: 92 | raise ValueError( 93 | "Type {!r} is not supported in a function name, the string to parse " 94 | "was {!r}".format(type(expr), string_to_parse)) 95 | 96 | def _get_func_args_and_kwargs(call): 97 | args = tuple([ast.literal_eval(arg) for arg in call.args]) 98 | kwargs = { 99 | kwarg.arg: ast.literal_eval(kwarg.value) for kwarg in call.keywords 100 | } 101 | return args, kwargs 102 | 103 | func_name = _get_func_name(expr.func) 104 | func_args, func_kwargs = _get_func_args_and_kwargs(expr) 105 | if func_args: 106 | raise ValueError("Positional arguments are not supported here, but these " 107 | "were found: {!r}".format(func_args)) 108 | 109 | return func_name, func_kwargs 110 | 111 | 112 | class Registry(object): 113 | """Implements global Registry.""" 114 | 115 | _GLOBAL_REGISTRY = {} 116 | 117 | @staticmethod 118 | def global_registry(): 119 | return Registry._GLOBAL_REGISTRY 120 | 121 | @staticmethod 122 | def register(name, item_type): 123 | """Creates a function that registers its input.""" 124 | if item_type not in ["function", "class"]: 125 | raise ValueError("Unknown item type: %s" % item_type) 126 | 127 | def _register(item): 128 | if name in Registry.global_registry(): 129 | raise KeyError( 130 | "The name {!r} was already registered in with type {!r}".format( 131 | name, item_type)) 132 | 133 | Registry.global_registry()[name] = (item, item_type) 134 | return item 135 | 136 | return _register 137 | 138 | @staticmethod 139 | def lookup(lookup_string, kwargs_extra=None): 140 | """Lookup a name in the registry.""" 141 | 142 | name, kwargs = parse_name(lookup_string) 143 | if kwargs_extra: 144 | kwargs.update(kwargs_extra) 145 | item, item_type = Registry.global_registry()[name] 146 | if item_type == "function": 147 | return functools.partial(item, **kwargs) 148 | elif item_type == "class": 149 | return partialclass(item, **kwargs) 150 | -------------------------------------------------------------------------------- /src/utils/distributed.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """Distributed helpers.""" 4 | 5 | import torch 6 | import torch.distributed as dist 7 | _LOCAL_PROCESS_GROUP = None 8 | 9 | 10 | def get_world_size() -> int: 11 | if not dist.is_available(): 12 | return 1 13 | if not dist.is_initialized(): 14 | return 1 15 | return dist.get_world_size() 16 | 17 | 18 | def get_rank() -> int: 19 | if not dist.is_available(): 20 | return 0 21 | if not dist.is_initialized(): 22 | return 0 23 | return dist.get_rank() 24 | 25 | 26 | def is_master_process(num_gpus=8): 27 | """ 28 | Determines if the current process is the master process. 29 | """ 30 | if torch.distributed.is_initialized(): 31 | return dist.get_rank() % num_gpus == 0 32 | else: 33 | return True 34 | 35 | 36 | def run( 37 | local_rank, 38 | num_proc, 39 | func, 40 | init_method, 41 | shard_id, 42 | num_shards, 43 | backend, 44 | cfg, 45 | args, 46 | ): 47 | """ 48 | Runs a function from a child process. 49 | Args: 50 | local_rank (int): rank of the current process on the current machine. 51 | num_proc (int): number of processes per machine. 52 | func (function): function to execute on each of the process. 53 | init_method (string): method to initialize the distributed training. 54 | TCP initialization: equiring a network address reachable from all 55 | processes followed by the port. 
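                For example, init_method="tcp://127.0.0.1:23456".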
56 | Shared file-system initialization: makes use of a file system that 57 | is shared and visible from all machines. The URL should start with 58 | file:// and contain a path to a non-existent file on a shared file 59 | system. 60 | shard_id (int): the rank of the current machine. 61 | num_shards (int): number of overall machines for the distributed 62 | training job. 63 | backend (string): three distributed backends ('nccl', 'gloo', 'mpi') are 64 | supports, each with different capabilities. Details can be found 65 | here: 66 | https://pytorch.org/docs/stable/distributed.html 67 | cfg (CfgNode): configs. Details can be found in 68 | loco/config/defaults.py 69 | """ 70 | # Initialize the process group. 71 | # shard_id = get_rank() 72 | world_size = num_proc * num_shards 73 | rank = shard_id * num_proc + local_rank 74 | 75 | try: 76 | torch.distributed.init_process_group( 77 | backend=backend, 78 | init_method=init_method, 79 | world_size=world_size, 80 | rank=rank, 81 | ) 82 | except Exception as e: 83 | raise e 84 | 85 | torch.cuda.set_device(local_rank) 86 | func(cfg, args) 87 | 88 | 89 | def destroy_process_group(): 90 | """Destroys the default process group.""" 91 | torch.distributed.destroy_process_group() 92 | 93 | 94 | def scaled_all_reduce(cfg, tensors): 95 | """Performs the scaled all_reduce operation on the provided tensors. 96 | 97 | The input tensors are modified in-place. Currently supports only the sum 98 | reduction operator. The reduced values are scaled by the inverse size of 99 | the process group (equivalent to cfg.NUM_GPUS). 100 | """ 101 | # Queue the reductions 102 | reductions = [] 103 | for tensor in tensors: 104 | reduction = torch.distributed.all_reduce(tensor, async_op=True) 105 | reductions.append(reduction) 106 | # Wait for reductions to finish 107 | for reduction in reductions: 108 | reduction.wait() 109 | # Scale the results 110 | for tensor in tensors: 111 | tensor.mul_(1.0 / cfg.NUM_GPUS / cfg.NUM_SHARDS) 112 | return tensors 113 | 114 | 115 | def cat_all_gather(tensors): 116 | """Performs the concatenated all_gather operation on the provided tensors. 117 | """ 118 | tensors_gather = [ 119 | torch.ones_like(tensors) 120 | for _ in range(torch.distributed.get_world_size()) 121 | ] 122 | torch.distributed.all_gather(tensors_gather, tensors, async_op=False) 123 | 124 | output = torch.cat(tensors_gather, dim=0) 125 | return output 126 | 127 | 128 | def local_cat_all_gather(tensors): 129 | """Performs the concatenated all_gather operation on the provided tensors. 130 | """ 131 | tensors_gather = [ 132 | torch.ones_like(tensors) 133 | for _ in range(get_local_size()) 134 | ] 135 | torch.distributed.all_gather( 136 | tensors_gather, 137 | tensors, 138 | async_op=False, 139 | group=_LOCAL_PROCESS_GROUP, 140 | ) 141 | output = torch.cat(tensors_gather, dim=0) 142 | return output 143 | 144 | 145 | def get_local_size(): 146 | """ 147 | Returns: 148 | The size of the per-machine process group, 149 | i.e. the number of processes per machine. 150 | """ 151 | if not dist.is_available(): 152 | return 1 153 | if not dist.is_initialized(): 154 | return 1 155 | return dist.get_world_size(group=_LOCAL_PROCESS_GROUP) 156 | 157 | 158 | def get_local_rank(): 159 | """ 160 | Returns: 161 | The rank of the current process within the local (per-machine) process group. 
162 | """ 163 | if not dist.is_available(): 164 | return 0 165 | if not dist.is_initialized(): 166 | return 0 167 | assert _LOCAL_PROCESS_GROUP is not None 168 | return dist.get_rank(group=_LOCAL_PROCESS_GROUP) 169 | -------------------------------------------------------------------------------- /src/models/vit_backbones/vit_moco.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | """ 4 | borrowed from https://github.com/facebookresearch/moco-v3/blob/main/vits.py 5 | """ 6 | import math 7 | import torch 8 | import torch.nn as nn 9 | from functools import partial, reduce 10 | from operator import mul 11 | 12 | from timm.models.vision_transformer import VisionTransformer, _cfg 13 | from timm.models.layers.helpers import to_2tuple 14 | from timm.models.layers import PatchEmbed 15 | 16 | __all__ = [ 17 | 'vit_small', 18 | 'vit_base', 19 | 'vit_conv_small', 20 | 'vit_conv_base', 21 | ] 22 | 23 | 24 | class VisionTransformerMoCo(VisionTransformer): 25 | def __init__(self, stop_grad_conv1=False, **kwargs): 26 | super().__init__(**kwargs) 27 | # Use fixed 2D sin-cos position embedding 28 | self.build_2d_sincos_position_embedding() 29 | 30 | # weight initialization 31 | for name, m in self.named_modules(): 32 | if isinstance(m, nn.Linear): 33 | if 'qkv' in name: 34 | # treat the weights of Q, K, V separately 35 | val = math.sqrt(6. / float(m.weight.shape[0] // 3 + m.weight.shape[1])) 36 | nn.init.uniform_(m.weight, -val, val) 37 | else: 38 | nn.init.xavier_uniform_(m.weight) 39 | nn.init.zeros_(m.bias) 40 | nn.init.normal_(self.cls_token, std=1e-6) 41 | 42 | if isinstance(self.patch_embed, PatchEmbed): 43 | # xavier_uniform initialization 44 | val = math.sqrt(6. / float(3 * reduce(mul, self.patch_embed.patch_size, 1) + self.embed_dim)) 45 | nn.init.uniform_(self.patch_embed.proj.weight, -val, val) 46 | nn.init.zeros_(self.patch_embed.proj.bias) 47 | 48 | if stop_grad_conv1: 49 | self.patch_embed.proj.weight.requires_grad = False 50 | self.patch_embed.proj.bias.requires_grad = False 51 | 52 | def build_2d_sincos_position_embedding(self, temperature=10000.): 53 | h, w = self.patch_embed.grid_size 54 | grid_w = torch.arange(w, dtype=torch.float32) 55 | grid_h = torch.arange(h, dtype=torch.float32) 56 | grid_w, grid_h = torch.meshgrid(grid_w, grid_h) 57 | assert self.embed_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding' 58 | pos_dim = self.embed_dim // 4 59 | omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim 60 | omega = 1. / (temperature**omega) 61 | out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega]) 62 | out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega]) 63 | pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :] 64 | 65 | assert self.num_tokens == 1, 'Assuming one and only one token, [cls]' 66 | pe_token = torch.zeros([1, 1, self.embed_dim], dtype=torch.float32) 67 | self.pos_embed = nn.Parameter(torch.cat([pe_token, pos_emb], dim=1)) 68 | self.pos_embed.requires_grad = False 69 | 70 | 71 | class ConvStem(nn.Module): 72 | """ 73 | ConvStem, from Early Convolutions Help Transformers See Better, Tete et al. 
https://arxiv.org/abs/2106.14881 74 | """ 75 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): 76 | super().__init__() 77 | 78 | assert patch_size == 16, 'ConvStem only supports patch size of 16' 79 | assert embed_dim % 8 == 0, 'Embed dimension must be divisible by 8 for ConvStem' 80 | 81 | img_size = to_2tuple(img_size) 82 | patch_size = to_2tuple(patch_size) 83 | self.img_size = img_size 84 | self.patch_size = patch_size 85 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 86 | self.num_patches = self.grid_size[0] * self.grid_size[1] 87 | self.flatten = flatten 88 | 89 | # build stem, similar to the design in https://arxiv.org/abs/2106.14881 90 | stem = [] 91 | input_dim, output_dim = 3, embed_dim // 8 92 | for l in range(4): 93 | stem.append(nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=2, padding=1, bias=False)) 94 | stem.append(nn.BatchNorm2d(output_dim)) 95 | stem.append(nn.ReLU(inplace=True)) 96 | input_dim = output_dim 97 | output_dim *= 2 98 | stem.append(nn.Conv2d(input_dim, embed_dim, kernel_size=1)) 99 | self.proj = nn.Sequential(*stem) 100 | 101 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 102 | 103 | def forward(self, x): 104 | B, C, H, W = x.shape 105 | assert H == self.img_size[0] and W == self.img_size[1], \ 106 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 107 | x = self.proj(x) 108 | if self.flatten: 109 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 110 | x = self.norm(x) 111 | return x 112 | 113 | 114 | def vit_small(**kwargs): 115 | model = VisionTransformerMoCo( 116 | patch_size=16, embed_dim=384, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 117 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 118 | model.default_cfg = _cfg() 119 | return model 120 | 121 | def vit_base(**kwargs): 122 | model = VisionTransformerMoCo( 123 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 124 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 125 | model.default_cfg = _cfg() 126 | return model 127 | 128 | def vit_conv_small(**kwargs): 129 | # minus one ViT block 130 | model = VisionTransformerMoCo( 131 | patch_size=16, embed_dim=384, depth=11, num_heads=12, mlp_ratio=4, qkv_bias=True, 132 | norm_layer=partial(nn.LayerNorm, eps=1e-6), embed_layer=ConvStem, **kwargs) 133 | model.default_cfg = _cfg() 134 | return model 135 | 136 | def vit_conv_base(**kwargs): 137 | # minus one ViT block 138 | model = VisionTransformerMoCo( 139 | patch_size=16, embed_dim=768, depth=11, num_heads=12, mlp_ratio=4, qkv_bias=True, 140 | norm_layer=partial(nn.LayerNorm, eps=1e-6), embed_layer=ConvStem, **kwargs) 141 | model.default_cfg = _cfg() 142 | return model 143 | -------------------------------------------------------------------------------- /src/utils/logging.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """Logging.""" 4 | 5 | import builtins 6 | import decimal 7 | import functools 8 | import logging 9 | import simplejson 10 | import sys 11 | import os 12 | from termcolor import colored 13 | 14 | from .distributed import is_master_process 15 | from .file_io import PathManager 16 | 17 | # Show filename and line number in logs 18 | _FORMAT = "[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s" 19 | 20 | 21 | def _suppress_print(): 22 | """Suppresses printing from the current process.""" 23 | 24 | def 
print_pass(*objects, sep=" ", end="\n", file=sys.stdout, flush=False): 25 | pass 26 | 27 | builtins.print = print_pass 28 | 29 | 30 | # cache the opened file object, so that different calls to `setup_logger` 31 | # with the same file name can safely write to the same file. 32 | @functools.lru_cache(maxsize=None) 33 | def _cached_log_stream(filename): 34 | return PathManager.open(filename, "a") 35 | 36 | 37 | @functools.lru_cache() # so that calling setup_logger multiple times won't add many handlers # noqa 38 | def setup_logging( 39 | num_gpu, num_shards, output="", name="MOSA", color=True): 40 | """Sets up the logging.""" 41 | # Enable logging only for the master process 42 | if is_master_process(num_gpu): 43 | # Clear the root logger to prevent any existing logging config 44 | # (e.g. set by another module) from messing with our setup 45 | logging.root.handlers = [] 46 | # Configure logging 47 | logging.basicConfig( 48 | level=logging.INFO, format=_FORMAT, stream=sys.stdout 49 | ) 50 | else: 51 | _suppress_print() 52 | 53 | if name is None: 54 | name = __name__ 55 | logger = logging.getLogger(name) 56 | # remove any lingering handler 57 | logger.handlers.clear() 58 | 59 | logger.setLevel(logging.INFO) 60 | logger.propagate = False 61 | 62 | plain_formatter = logging.Formatter( 63 | "[%(asctime)s][%(levelname)s] %(name)s: %(lineno)4d: %(message)s", 64 | datefmt="%m/%d %H:%M:%S", 65 | ) 66 | if color: 67 | formatter = _ColorfulFormatter( 68 | colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s", 69 | datefmt="%m/%d %H:%M:%S", 70 | root_name=name, 71 | abbrev_name=str(name), 72 | ) 73 | else: 74 | formatter = plain_formatter 75 | 76 | if is_master_process(num_gpu): 77 | ch = logging.StreamHandler(stream=sys.stdout) 78 | ch.setLevel(logging.DEBUG) 79 | ch.setFormatter(formatter) 80 | logger.addHandler(ch) 81 | 82 | if is_master_process(num_gpu * num_shards): 83 | if len(output) > 0: 84 | if output.endswith(".txt") or output.endswith(".log"): 85 | filename = output 86 | else: 87 | filename = os.path.join(output, "logs.txt") 88 | 89 | PathManager.mkdirs(os.path.dirname(filename)) 90 | 91 | fh = logging.StreamHandler(_cached_log_stream(filename)) 92 | fh.setLevel(logging.DEBUG) 93 | fh.setFormatter(plain_formatter) 94 | logger.addHandler(fh) 95 | return logger 96 | 97 | 98 | def setup_single_logging(name, output=""): 99 | """Sets up the logging.""" 100 | # Enable logging only for the master process 101 | # Clear the root logger to prevent any existing logging config 102 | # (e.g. 
set by another module) from messing with our setup 103 | logging.root.handlers = [] 104 | # Configure logging 105 | logging.basicConfig( 106 | level=logging.INFO, format=_FORMAT, stream=sys.stdout 107 | ) 108 | 109 | if len(name) == 0: 110 | name = __name__ 111 | logger = logging.getLogger(name) 112 | logger.setLevel(logging.INFO) 113 | logger.propagate = False 114 | 115 | plain_formatter = logging.Formatter( 116 | "[%(asctime)s][%(levelname)s] %(name)s: %(lineno)4d: %(message)s", 117 | datefmt="%m/%d %H:%M:%S", 118 | ) 119 | formatter = _ColorfulFormatter( 120 | colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s", 121 | datefmt="%m/%d %H:%M:%S", 122 | root_name=name, 123 | abbrev_name=str(name), 124 | ) 125 | 126 | ch = logging.StreamHandler(stream=sys.stdout) 127 | ch.setLevel(logging.DEBUG) 128 | ch.setFormatter(formatter) 129 | logger.addHandler(ch) 130 | 131 | if len(output) > 0: 132 | if output.endswith(".txt") or output.endswith(".log"): 133 | filename = output 134 | else: 135 | filename = os.path.join(output, "logs.txt") 136 | 137 | PathManager.mkdirs(os.path.dirname(filename)) 138 | 139 | fh = logging.StreamHandler(_cached_log_stream(filename)) 140 | fh.setLevel(logging.DEBUG) 141 | fh.setFormatter(plain_formatter) 142 | logger.addHandler(fh) 143 | 144 | return logger 145 | 146 | 147 | def get_logger(name): 148 | """Retrieves the logger.""" 149 | return logging.getLogger(name) 150 | 151 | 152 | def log_json_stats(stats, sort_keys=True): 153 | """Logs json stats.""" 154 | # It seems that in Python >= 3.6 json.encoder.FLOAT_REPR has no effect 155 | # Use decimal+string as a workaround for having fixed length values in logs 156 | logger = get_logger(__name__) 157 | stats = { 158 | k: decimal.Decimal("{:.6f}".format(v)) if isinstance(v, float) else v 159 | for k, v in stats.items() 160 | } 161 | json_stats = simplejson.dumps(stats, sort_keys=True, use_decimal=True) 162 | if stats["_type"] == "test_epoch" or stats["_type"] == "train_epoch": 163 | logger.info("json_stats: {:s}".format(json_stats)) 164 | else: 165 | logger.info("{:s}".format(json_stats)) 166 | 167 | 168 | class _ColorfulFormatter(logging.Formatter): 169 | # from detectron2 170 | def __init__(self, *args, **kwargs): 171 | self._root_name = kwargs.pop("root_name") + "." 172 | self._abbrev_name = kwargs.pop("abbrev_name", "") 173 | if len(self._abbrev_name): 174 | self._abbrev_name = self._abbrev_name + "." 
175 | super(_ColorfulFormatter, self).__init__(*args, **kwargs) 176 | 177 | def formatMessage(self, record: logging.LogRecord) -> str: 178 | record.name = record.name.replace(self._root_name, self._abbrev_name) 179 | log = super(_ColorfulFormatter, self).formatMessage(record) 180 | if record.levelno == logging.WARNING: 181 | prefix = colored("WARNING", "red", attrs=["blink"]) 182 | elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL: 183 | prefix = colored("ERROR", "red", attrs=["blink", "underline"]) 184 | else: 185 | return log 186 | return prefix + " " + log 187 | -------------------------------------------------------------------------------- /src/data/datasets/tf_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """a dataset that handles output of tf.data: support datasets from VTAB""" 4 | import functools 5 | import tensorflow.compat.v1 as tf 6 | import torch 7 | import torch.utils.data 8 | import numpy as np 9 | 10 | from collections import Counter 11 | from torch import Tensor 12 | 13 | from ..vtab_datasets import base 14 | # pylint: disable=unused-import 15 | from ..vtab_datasets import caltech 16 | from ..vtab_datasets import cifar 17 | from ..vtab_datasets import clevr 18 | from ..vtab_datasets import diabetic_retinopathy 19 | from ..vtab_datasets import dmlab 20 | from ..vtab_datasets import dsprites 21 | from ..vtab_datasets import dtd 22 | from ..vtab_datasets import eurosat 23 | from ..vtab_datasets import kitti 24 | from ..vtab_datasets import oxford_flowers102 25 | from ..vtab_datasets import oxford_iiit_pet 26 | from ..vtab_datasets import patch_camelyon 27 | from ..vtab_datasets import resisc45 28 | from ..vtab_datasets import smallnorb 29 | from ..vtab_datasets import sun397 30 | from ..vtab_datasets import svhn 31 | from ..vtab_datasets.registry import Registry 32 | 33 | from ...utils import logging 34 | logger = logging.get_logger("MOSA") 35 | tf.config.experimental.set_visible_devices([], 'GPU') # set tensorflow to not use gpu # noqa 36 | DATASETS = [ 37 | 'caltech101', 38 | 'cifar(num_classes=100)', 39 | 'dtd', 40 | 'oxford_flowers102', 41 | 'oxford_iiit_pet', 42 | 'patch_camelyon', 43 | 'sun397', 44 | 'svhn', 45 | 'resisc45', 46 | 'eurosat', 47 | 'dmlab', 48 | 'kitti(task="closest_vehicle_distance")', 49 | 'smallnorb(predicted_attribute="label_azimuth")', 50 | 'smallnorb(predicted_attribute="label_elevation")', 51 | 'dsprites(predicted_attribute="label_x_position",num_classes=16)', 52 | 'dsprites(predicted_attribute="label_orientation",num_classes=16)', 53 | 'clevr(task="closest_object_distance")', 54 | 'clevr(task="count_all")', 55 | 'diabetic_retinopathy(config="btgraham-300")' 56 | ] 57 | 58 | 59 | class TFDataset(torch.utils.data.Dataset): 60 | def __init__(self, cfg, split): 61 | assert split in { 62 | "train", 63 | "val", 64 | "test", 65 | "trainval" 66 | }, "Split '{}' not supported for {} dataset".format( 67 | split, cfg.DATA.NAME) 68 | logger.info("Constructing {} dataset {}...".format( 69 | cfg.DATA.NAME, split)) 70 | 71 | self.cfg = cfg 72 | self._split = split 73 | self.name = cfg.DATA.NAME 74 | 75 | self.img_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1) 76 | self.img_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1) 77 | 78 | self.get_data(cfg, split) 79 | 80 | def get_data(self, cfg, split): 81 | tf_data = build_tf_dataset(cfg, split) 82 | data_list = list(tf_data) # a list of tuples 83 | 84 | self._image_tensor_list = 
[t[0].numpy().squeeze() for t in data_list] 85 | self._targets = [int(t[1].numpy()[0]) for t in data_list] 86 | self._class_ids = sorted(list(set(self._targets))) 87 | 88 | logger.info("Number of images: {}".format(len(self._image_tensor_list))) 89 | logger.info("Number of classes: {} / {}".format( 90 | len(self._class_ids), self.get_class_num())) 91 | 92 | del data_list 93 | del tf_data 94 | 95 | def get_info(self): 96 | num_imgs = len(self._image_tensor_list) 97 | return num_imgs, self.get_class_num() 98 | 99 | def get_class_num(self): 100 | return self.cfg.DATA.NUMBER_CLASSES 101 | 102 | def get_class_weights(self, weight_type): 103 | """get a list of class weight, return a list float""" 104 | if "train" not in self._split: 105 | raise ValueError( 106 | "only getting training class distribution, " + \ 107 | "got split {} instead".format(self._split) 108 | ) 109 | 110 | cls_num = self.get_class_num() 111 | if weight_type == "none": 112 | return [1.0] * cls_num 113 | 114 | id2counts = Counter(self._class_ids) 115 | assert len(id2counts) == cls_num 116 | num_per_cls = np.array([id2counts[i] for i in self._class_ids]) 117 | 118 | if weight_type == 'inv': 119 | mu = -1.0 120 | elif weight_type == 'inv_sqrt': 121 | mu = -0.5 122 | weight_list = num_per_cls ** mu 123 | weight_list = np.divide( 124 | weight_list, np.linalg.norm(weight_list, 1)) * cls_num 125 | return weight_list.tolist() 126 | 127 | def __getitem__(self, index): 128 | # Load the image 129 | label = self._targets[index] 130 | im = to_torch_imgs( 131 | self._image_tensor_list[index], self.img_mean, self.img_std) 132 | 133 | if self._split == "train": 134 | index = index 135 | else: 136 | index = f"{self._split}{index}" 137 | sample = { 138 | "image": im, 139 | "label": label, 140 | # "id": index 141 | } 142 | return sample 143 | 144 | def __len__(self): 145 | return len(self._targets) 146 | 147 | 148 | def preprocess_fn(data, size=224, input_range=(0.0, 1.0)): 149 | image = data["image"] 150 | image = tf.image.resize(image, [size, size]) 151 | 152 | image = tf.cast(image, tf.float32) / 255.0 153 | image = image * (input_range[1] - input_range[0]) + input_range[0] 154 | 155 | data["image"] = image 156 | return data 157 | 158 | 159 | def build_tf_dataset(cfg, mode): 160 | """ 161 | Builds a tf data instance, then transform to a list of tensors and labels 162 | """ 163 | 164 | if mode not in ["train", "val", "test", "trainval"]: 165 | raise ValueError("The input pipeline supports `train`, `val`, `test`." 166 | "Provided mode is {}".format(mode)) 167 | 168 | vtab_dataname = cfg.DATA.NAME.split("vtab-")[-1] 169 | data_dir = cfg.DATA.DATAPATH 170 | if vtab_dataname in DATASETS: 171 | data_cls = Registry.lookup("data." 
+ vtab_dataname) 172 | vtab_tf_dataloader = data_cls(data_dir=data_dir) 173 | else: 174 | raise ValueError("Unknown type for \"dataset\" field: {}".format( 175 | type(vtab_dataname))) 176 | 177 | split_name_dict = { 178 | "dataset_train_split_name": "train800", 179 | "dataset_val_split_name": "val200", 180 | "dataset_trainval_split_name": "train800val200", 181 | "dataset_test_split_name": "test", 182 | } 183 | 184 | def _dict_to_tuple(batch): 185 | return batch['image'], batch['label'] 186 | 187 | return vtab_tf_dataloader.get_tf_data( 188 | batch_size=1, # data_params["batch_size"], 189 | drop_remainder=False, 190 | split_name=split_name_dict[f"dataset_{mode}_split_name"], 191 | preprocess_fn=functools.partial( 192 | preprocess_fn, 193 | input_range=(0.0, 1.0), 194 | size=cfg.DATA.CROPSIZE, 195 | ), 196 | for_eval=mode != "train", # handles shuffling 197 | shuffle_buffer_size=1000, 198 | prefetch=1, 199 | train_examples=None, 200 | epochs=1 # setting epochs to 1 make sure it returns one copy of the dataset 201 | ).map(_dict_to_tuple) # return a PrefetchDataset object. (which does not have much documentation to go on) 202 | 203 | 204 | def to_torch_imgs(img: np.ndarray, mean: Tensor, std: Tensor) -> Tensor: 205 | t_img: Tensor = torch.from_numpy(np.transpose(img, (2, 0, 1))) 206 | t_img -= mean 207 | t_img /= std 208 | 209 | return t_img 210 | -------------------------------------------------------------------------------- /src/data/datasets/pt_dataset.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | import numpy as np 3 | from torchvision.datasets import CIFAR100, FGVCAircraft, Food101 4 | import torchvision as tv 5 | 6 | from ..transforms import get_transforms 7 | from ...utils import logging 8 | logger = logging.get_logger("MOSA") 9 | 10 | import ssl 11 | ssl._create_default_https_context = ssl._create_unverified_context 12 | 13 | 14 | class CIFAR100Dataset(CIFAR100): 15 | def __init__(self, cfg, split): 16 | root = cfg.DATA.DATAPATH 17 | train = True if split == "train" else False 18 | transform = get_transforms(split, cfg.DATA.CROPSIZE) 19 | super(CIFAR100Dataset, self).__init__(root=root, train=train, transform=transform) 20 | self.raw_ds = CIFAR100(root=root, train=train) 21 | logger.info("Constructing {} dataset {}...".format( 22 | cfg.DATA.NAME, split)) 23 | 24 | self.cfg = cfg 25 | self._split = split 26 | self.name = cfg.DATA.NAME 27 | self._construct_imdb() 28 | 29 | def _construct_imdb(self): 30 | logger.info("Number of images: {}".format(len(self.data))) 31 | logger.info("Number of classes: {}".format(len(self.class_to_idx))) 32 | 33 | def get_info(self): 34 | num_imgs = len(self.data) 35 | return num_imgs, self.get_class_num() 36 | 37 | def get_class_num(self): 38 | return self.cfg.DATA.NUMBER_CLASSES 39 | # return len(self._class_ids) 40 | 41 | def get_class_weights(self, weight_type): 42 | """get a list of class weight, return a list float""" 43 | if "train" not in self._split: 44 | raise ValueError( 45 | "only getting training class distribution, " + \ 46 | "got split {} instead".format(self._split) 47 | ) 48 | 49 | cls_num = self.get_class_num() 50 | if weight_type == "none": 51 | return [1.0] * cls_num 52 | 53 | id2counts = Counter(self.classes) 54 | assert len(id2counts) == cls_num 55 | num_per_cls = np.array([id2counts[i] for i in self.classes]) 56 | 57 | if weight_type == 'inv': 58 | mu = -1.0 59 | elif weight_type == 'inv_sqrt': 60 | mu = -0.5 61 | weight_list = num_per_cls ** mu 62 | 
weight_list = np.divide( 63 | weight_list, np.linalg.norm(weight_list, 1)) * cls_num 64 | return weight_list.tolist() 65 | 66 | def __getitem__(self, index): 67 | img, target = super(CIFAR100Dataset, self).__getitem__(index) 68 | raw_transform = tv.transforms.Compose( 69 | [ 70 | tv.transforms.Resize(256), 71 | tv.transforms.CenterCrop(224), 72 | tv.transforms.ToTensor(), 73 | ] 74 | ) 75 | raw_img, test_target = self.raw_ds.__getitem__(index) 76 | assert(test_target == target) 77 | 78 | sample = { 79 | "image": img, 80 | "label": target, 81 | "raw": raw_transform(raw_img) 82 | } 83 | return sample 84 | 85 | 86 | class AircraftDataset(FGVCAircraft): 87 | def __init__(self, cfg, split): 88 | root = cfg.DATA.DATAPATH 89 | if split == "train": 90 | split = "trainval" 91 | if split == "val": 92 | split = "test" 93 | transform = get_transforms(split, cfg.DATA.CROPSIZE) 94 | super(AircraftDataset, self).__init__(root=root, split=split, transform=transform) 95 | logger.info("Constructing {} dataset {}...".format( 96 | cfg.DATA.NAME, split)) 97 | 98 | self.cfg = cfg 99 | self._split = split 100 | self.name = cfg.DATA.NAME 101 | self._construct_imdb() 102 | 103 | def _construct_imdb(self): 104 | logger.info("Number of images: {}".format(len(self._image_files))) 105 | logger.info("Number of classes: {}".format(len(self.class_to_idx))) 106 | 107 | def get_info(self): 108 | num_imgs = len(self._image_files) 109 | return num_imgs, self.get_class_num() 110 | 111 | def get_class_num(self): 112 | return self.cfg.DATA.NUMBER_CLASSES 113 | # return len(self._class_ids) 114 | 115 | def get_class_weights(self, weight_type): 116 | """get a list of class weight, return a list float""" 117 | if "train" not in self._split: 118 | raise ValueError( 119 | "only getting training class distribution, " + \ 120 | "got split {} instead".format(self._split) 121 | ) 122 | 123 | cls_num = self.get_class_num() 124 | if weight_type == "none": 125 | return [1.0] * cls_num 126 | 127 | id2counts = Counter(self._labels) 128 | assert len(id2counts) == cls_num 129 | num_per_cls = np.array([id2counts[i] for i in self._labels]) 130 | 131 | if weight_type == 'inv': 132 | mu = -1.0 133 | elif weight_type == 'inv_sqrt': 134 | mu = -0.5 135 | weight_list = num_per_cls ** mu 136 | weight_list = np.divide( 137 | weight_list, np.linalg.norm(weight_list, 1)) * cls_num 138 | return weight_list.tolist() 139 | 140 | def __getitem__(self, index): 141 | img, target = super(AircraftDataset, self).__getitem__(index) 142 | sample = { 143 | "image": img, 144 | "label": target, 145 | } 146 | return sample 147 | 148 | 149 | class Food101Dataset(Food101): 150 | def __init__(self, cfg, split): 151 | root = cfg.DATA.DATAPATH 152 | if split == "val": 153 | split = "test" 154 | transform = get_transforms(split, cfg.DATA.CROPSIZE) 155 | super(Food101Dataset, self).__init__(root=root, split=split, transform=transform) 156 | logger.info("Constructing {} dataset {}...".format( 157 | cfg.DATA.NAME, split)) 158 | 159 | self.cfg = cfg 160 | self._split = split 161 | self.name = cfg.DATA.NAME 162 | self._construct_imdb() 163 | 164 | def _construct_imdb(self): 165 | logger.info("Number of images: {}".format(len(self._image_files))) 166 | logger.info("Number of classes: {}".format(len(self.class_to_idx))) 167 | 168 | def get_info(self): 169 | num_imgs = len(self._image_files) 170 | return num_imgs, self.get_class_num() 171 | 172 | def get_class_num(self): 173 | return self.cfg.DATA.NUMBER_CLASSES 174 | # return len(self._class_ids) 175 | 176 | def 
get_class_weights(self, weight_type): 177 | """get a list of class weight, return a list float""" 178 | if "train" not in self._split: 179 | raise ValueError( 180 | "only getting training class distribution, " + \ 181 | "got split {} instead".format(self._split) 182 | ) 183 | 184 | cls_num = self.get_class_num() 185 | if weight_type == "none": 186 | return [1.0] * cls_num 187 | 188 | id2counts = Counter(self._labels) 189 | assert len(id2counts) == cls_num 190 | num_per_cls = np.array([id2counts[i] for i in self._labels]) 191 | 192 | if weight_type == 'inv': 193 | mu = -1.0 194 | elif weight_type == 'inv_sqrt': 195 | mu = -0.5 196 | weight_list = num_per_cls ** mu 197 | weight_list = np.divide( 198 | weight_list, np.linalg.norm(weight_list, 1)) * cls_num 199 | return weight_list.tolist() 200 | 201 | def __getitem__(self, index): 202 | img, target = super(Food101Dataset, self).__getitem__(index) 203 | sample = { 204 | "image": img, 205 | "label": target, 206 | } 207 | return sample 208 | -------------------------------------------------------------------------------- /src/data/vtab_datasets/kitti.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Implements Kitti data class.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import numpy as np 24 | 25 | import tensorflow.compat.v1 as tf 26 | import tensorflow_datasets as tfds 27 | 28 | from . import base as base 29 | from .registry import Registry 30 | 31 | 32 | def _count_all_pp(x): 33 | """Count all objects.""" 34 | # Count distribution (thresholded at 15): 35 | 36 | label = tf.math.minimum(tf.size(x["objects"]["type"]) - 1, 8) 37 | return {"image": x["image"], "label": label} 38 | 39 | 40 | def _count_vehicles_pp(x): 41 | """Counting vehicles.""" 42 | # Label distribution: 43 | 44 | vehicles = tf.where(x["objects"]["type"] < 3) # Car, Van, Truck. 45 | # Cap at 3. 46 | label = tf.math.minimum(tf.size(vehicles), 3) 47 | return {"image": x["image"], "label": label} 48 | 49 | 50 | def _count_left_pp(x): 51 | """Count objects on the left hand side of the camera.""" 52 | # Count distribution (thresholded at 15): 53 | 54 | # Location feature contains (x, y, z) in meters w.r.t. the camera. 55 | objects_on_left = tf.where(x["objects"]["location"][:, 0] < 0) 56 | label = tf.math.minimum(tf.size(objects_on_left), 8) 57 | return {"image": x["image"], "label": label} 58 | 59 | 60 | def _count_far_pp(x): 61 | """Counts objects far from the camera.""" 62 | # Threshold removes ~half of the objects. 63 | # Count distribution (thresholded at 15): 64 | 65 | # Location feature contains (x, y, z) in meters w.r.t. the camera. 
66 | distant_objects = tf.where(x["objects"]["location"][:, 2] >= 25) 67 | label = tf.math.minimum(tf.size(distant_objects), 8) 68 | return {"image": x["image"], "label": label} 69 | 70 | 71 | def _count_near_pp(x): 72 | """Counts objects close to the camera.""" 73 | # Threshold removes ~half of the objects. 74 | # Count distribution: 75 | 76 | # Location feature contains (x, y, z) in meters w.r.t. the camera. 77 | close_objects = tf.where(x["objects"]["location"][:, 2] < 25) 78 | label = tf.math.minimum(tf.size(close_objects), 8) 79 | return {"image": x["image"], "label": label} 80 | 81 | 82 | def _closest_object_distance_pp(x): 83 | """Predict the distance to the closest object.""" 84 | # Label distribution: 85 | 86 | # Location feature contains (x, y, z) in meters w.r.t. the camera. 87 | dist = tf.reduce_min(x["objects"]["location"][:, 2]) 88 | thrs = np.array([-100, 5.6, 8.4, 13.4, 23.4]) 89 | label = tf.reduce_max(tf.where((thrs - dist) < 0)) 90 | return {"image": x["image"], "label": label} 91 | 92 | 93 | def _closest_vehicle_distance_pp(x): 94 | """Predict the distance to the closest vehicle.""" 95 | # Label distribution: 96 | 97 | # Location feature contains (x, y, z) in meters w.r.t. the camera. 98 | vehicles = tf.where(x["objects"]["type"] < 3) # Car, Van, Truck. 99 | vehicle_z = tf.gather(params=x["objects"]["location"][:, 2], indices=vehicles) 100 | vehicle_z = tf.concat([vehicle_z, tf.constant([[1000.0]])], axis=0) 101 | dist = tf.reduce_min(vehicle_z) 102 | # Results in a uniform distribution over three distances, plus one class for 103 | # "no vehicle". 104 | thrs = np.array([-100.0, 8.0, 20.0, 999.0]) 105 | label = tf.reduce_max(tf.where((thrs - dist) < 0)) 106 | return {"image": x["image"], "label": label} 107 | 108 | 109 | def _closest_object_x_location_pp(x): 110 | """Predict the absolute x position of the closest object.""" 111 | # Label distribution: 112 | 113 | # Location feature contains (x, y, z) in meters w.r.t. the camera. 114 | idx = tf.math.argmin(x["objects"]["location"][:, 2]) 115 | xloc = x["objects"]["location"][idx, 0] 116 | thrs = np.array([-100, -6.4, -3.5, 0.0, 3.3, 23.9]) 117 | label = tf.reduce_max(tf.where((thrs - xloc) < 0)) 118 | return {"image": x["image"], "label": label} 119 | 120 | 121 | _TASK_DICT = { 122 | "count_all": { 123 | "preprocess_fn": _count_all_pp, 124 | "num_classes": 16, 125 | }, 126 | "count_left": { 127 | "preprocess_fn": _count_left_pp, 128 | "num_classes": 16, 129 | }, 130 | "count_far": { 131 | "preprocess_fn": _count_far_pp, 132 | "num_classes": 16, 133 | }, 134 | "count_near": { 135 | "preprocess_fn": _count_near_pp, 136 | "num_classes": 16, 137 | }, 138 | "closest_object_distance": { 139 | "preprocess_fn": _closest_object_distance_pp, 140 | "num_classes": 5, 141 | }, 142 | "closest_object_x_location": { 143 | "preprocess_fn": _closest_object_x_location_pp, 144 | "num_classes": 5, 145 | }, 146 | "count_vehicles": { 147 | "preprocess_fn": _count_vehicles_pp, 148 | "num_classes": 4, 149 | }, 150 | "closest_vehicle_distance": { 151 | "preprocess_fn": _closest_vehicle_distance_pp, 152 | "num_classes": 4, 153 | }, 154 | } 155 | 156 | 157 | @Registry.register("data.kitti", "class") 158 | class KittiData(base.ImageTfdsData): 159 | """Provides Kitti dataset. 160 | 161 | Six tasks are supported: 162 | 1. Count the number of objects. 163 | 2. Count the number of objects on the left hand side of the camera. 164 | 3. Count the number of objects in the foreground. 165 | 4. Count the number of objects in the background. 166 | 5. 
Predict the distance of the closest object. 167 | 6. Predict the x-location (w.r.t. the camera) of the closest object. 168 | """ 169 | 170 | def __init__(self, task, data_dir=None): 171 | 172 | if task not in _TASK_DICT: 173 | raise ValueError("Unknown task: %s" % task) 174 | 175 | dataset_builder = tfds.builder("kitti:3.2.0", data_dir=data_dir) 176 | dataset_builder.download_and_prepare() 177 | 178 | tfds_splits = { 179 | "train": "train", 180 | "val": "validation", 181 | "trainval": "train+validation", 182 | "test": "test", 183 | "train800": "train[:800]", 184 | "val200": "validation[:200]", 185 | "train800val200": "train[:800]+validation[:200]", 186 | } 187 | 188 | # Example counts are retrieved from the tensorflow dataset info. 189 | train_count = dataset_builder.info.splits[tfds.Split.TRAIN].num_examples 190 | val_count = dataset_builder.info.splits[tfds.Split.VALIDATION].num_examples 191 | test_count = dataset_builder.info.splits[tfds.Split.TEST].num_examples 192 | # Creates a dict with example counts for each split. 193 | num_samples_splits = { 194 | "train": train_count, 195 | "val": val_count, 196 | "trainval": train_count + val_count, 197 | "test": test_count, 198 | "train800": 800, 199 | "val200": 200, 200 | "train800val200": 1000, 201 | } 202 | 203 | task = _TASK_DICT[task] 204 | base_preprocess_fn = task["preprocess_fn"] 205 | super(KittiData, self).__init__( 206 | dataset_builder=dataset_builder, 207 | tfds_splits=tfds_splits, 208 | num_samples_splits=num_samples_splits, 209 | num_preprocessing_threads=400, 210 | shuffle_buffer_size=10000, 211 | base_preprocess_fn=base_preprocess_fn, 212 | num_classes=task["num_classes"]) 213 | -------------------------------------------------------------------------------- /src/data/datasets/json_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """JSON dataset: supports CUB, NABirds, Flowers, Dogs and Cars""" 4 | 5 | import os 6 | import torch 7 | import torch.utils.data 8 | import torchvision as tv 9 | import numpy as np 10 | from collections import Counter 11 | 12 | from ..transforms import get_transforms 13 | from ...utils import logging 14 | from ...utils.io_utils import read_json 15 | logger = logging.get_logger("MOSA") 16 | 17 | add_dict = { 18 | "FGVC-cars": "", 19 | "FGVC-cub": 'images', 20 | "FGVC-dogs" : "Images", 21 | "FGVC-flowers": "", 22 | "FGVC-nabirds": "images" 23 | } 24 | 25 | class JSONDataset(torch.utils.data.Dataset): 26 | def __init__(self, cfg, split): 27 | assert split in { 28 | "train", 29 | "val", 30 | "test", 31 | }, "Split '{}' not supported for {} dataset".format( 32 | split, cfg.DATA.NAME) 33 | logger.info("Constructing {} dataset {}...".format( 34 | cfg.DATA.NAME, split)) 35 | 36 | self.cfg = cfg 37 | self._split = split 38 | self.name = cfg.DATA.NAME 39 | self.data_dir = cfg.DATA.DATAPATH 40 | self.data_percentage = cfg.DATA.PERCENTAGE 41 | self._construct_imdb(cfg) 42 | self.transform = get_transforms(split, cfg.DATA.CROPSIZE) 43 | 44 | def get_anno(self): 45 | anno_path = os.path.join(self.data_dir, "{}.json".format(self._split)) 46 | if "train" in self._split: 47 | if self.data_percentage < 1.0: 48 | anno_path = os.path.join( 49 | self.data_dir, 50 | "{}_{}.json".format(self._split, self.data_percentage) 51 | ) 52 | assert os.path.exists(anno_path), "{} not found".format(anno_path) 53 | 54 | return read_json(anno_path) 55 | 56 | def get_imagedir(self): 57 | raise NotImplementedError() 58 | 59 | def _construct_imdb(self, cfg):
60 | """Constructs the imdb.""" 61 | if cfg.DATA.NAME == "multi": 62 | class_nbr = {"FGVC-cars": 196, "FGVC-cub": 200, "FGVC-dogs": 120, "FGVC-flowers": 102, "FGVC-nabirds": 555} 63 | label_counter = 0 64 | self._imdb = [] 65 | self._class_ids = [] 66 | for name in class_nbr.keys(): 67 | class_dir = self.get_imagedir() + name 68 | if add_dict[name] != "": 69 | img_dir = os.path.join(class_dir, add_dict[name]) 70 | else: 71 | img_dir = os.path.join(class_dir) 72 | anno_path = os.path.join(class_dir, "{}.json".format(self._split)) 73 | anno = read_json(anno_path) 74 | class_ids_list = sorted(list(set(anno.values()))) 75 | class_id_cont_id_dict = {v: i for i, v in enumerate(class_ids_list)} 76 | # label_count = len(list(set(anno.values()))) 77 | label_count = class_nbr[name] 78 | for img_name, cls_id in anno.items(): 79 | cont_id = class_id_cont_id_dict[cls_id] 80 | im_path = os.path.join(img_dir, img_name) 81 | self._imdb.append({"im_path": im_path, "class": cont_id+label_counter}) 82 | label_counter += label_count 83 | self._class_ids += class_ids_list 84 | #print(name, label_count) 85 | #print(cfg.DATA.NUMBER_CLASSES, label_counter) 86 | assert(cfg.DATA.NUMBER_CLASSES == label_counter) 87 | 88 | else: 89 | img_dir = self.get_imagedir() 90 | assert os.path.exists(img_dir), "{} dir not found".format(img_dir) 91 | 92 | anno = self.get_anno() 93 | # Map class ids to contiguous ids 94 | self._class_ids = sorted(list(set(anno.values()))) 95 | self._class_id_cont_id = {v: i for i, v in enumerate(self._class_ids)} 96 | 97 | # Construct the image db 98 | self._imdb = [] 99 | for img_name, cls_id in anno.items(): 100 | cont_id = self._class_id_cont_id[cls_id] 101 | im_path = os.path.join(img_dir, img_name) 102 | self._imdb.append({"im_path": im_path, "class": cont_id}) 103 | 104 | logger.info("Number of images: {}".format(len(self._imdb))) 105 | logger.info("Number of classes: {} / {}".format( 106 | len(self._class_ids), self.get_class_num())) 107 | 108 | def get_info(self): 109 | num_imgs = len(self._imdb) 110 | return num_imgs, self.get_class_num() 111 | 112 | def get_class_num(self): 113 | return self.cfg.DATA.NUMBER_CLASSES 114 | 115 | def get_class_weights(self, weight_type): 116 | """get a list of class weight, return a list float""" 117 | if "train" not in self._split: 118 | raise ValueError( 119 | "only getting training class distribution, " + \ 120 | "got split {} instead".format(self._split) 121 | ) 122 | 123 | cls_num = self.get_class_num() 124 | if weight_type == "none": 125 | return [1.0] * cls_num 126 | 127 | id2counts = Counter(self._class_ids) 128 | assert len(id2counts) == cls_num 129 | num_per_cls = np.array([id2counts[i] for i in self._class_ids]) 130 | 131 | if weight_type == 'inv': 132 | mu = -1.0 133 | elif weight_type == 'inv_sqrt': 134 | mu = -0.5 135 | weight_list = num_per_cls ** mu 136 | weight_list = np.divide( 137 | weight_list, np.linalg.norm(weight_list, 1)) * cls_num 138 | return weight_list.tolist() 139 | 140 | def __getitem__(self, index): 141 | # Load the image 142 | im = tv.datasets.folder.default_loader(self._imdb[index]["im_path"]) 143 | label = self._imdb[index]["class"] 144 | im = self.transform(im) 145 | if self._split == "train": 146 | index = index 147 | else: 148 | index = f"{self._split}{index}" 149 | 150 | sample = { 151 | "image": im, 152 | "label": label, 153 | # "id": index 154 | } 155 | return sample 156 | 157 | def __len__(self): 158 | return len(self._imdb) 159 | 160 | 161 | class CUB200Dataset(JSONDataset): 162 | """CUB_200 dataset.""" 163 | 164 | 
def __init__(self, cfg, split): 165 | super(CUB200Dataset, self).__init__(cfg, split) 166 | 167 | def get_imagedir(self): 168 | return os.path.join(self.data_dir, "images") 169 | 170 | 171 | class CarsDataset(JSONDataset): 172 | """stanford-cars dataset.""" 173 | 174 | def __init__(self, cfg, split): 175 | super(CarsDataset, self).__init__(cfg, split) 176 | 177 | def get_imagedir(self): 178 | return self.data_dir 179 | 180 | 181 | class DogsDataset(JSONDataset): 182 | """stanford-dogs dataset.""" 183 | 184 | def __init__(self, cfg, split): 185 | super(DogsDataset, self).__init__(cfg, split) 186 | 187 | def get_imagedir(self): 188 | return os.path.join(self.data_dir, "Images") 189 | 190 | 191 | class FlowersDataset(JSONDataset): 192 | """flowers dataset.""" 193 | 194 | def __init__(self, cfg, split): 195 | super(FlowersDataset, self).__init__(cfg, split) 196 | 197 | def get_imagedir(self): 198 | return self.data_dir 199 | 200 | 201 | class NabirdsDataset(JSONDataset): 202 | """Nabirds dataset.""" 203 | 204 | def __init__(self, cfg, split): 205 | super(NabirdsDataset, self).__init__(cfg, split) 206 | 207 | def get_imagedir(self): 208 | return os.path.join(self.data_dir, "images") 209 | 210 | 211 | class MultiDataset(JSONDataset): 212 | """Multitask dataset.""" 213 | def __init__(self, cfg, split): 214 | super(MultiDataset, self).__init__(cfg, split) 215 | 216 | def get_imagedir(self): 217 | return self.data_dir --------------------------------------------------------------------------------
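A minimal usage sketch (not a file in this repository) of how the registry in src/data/vtab_datasets/registry.py and the VTAB data classes such as KittiData are wired together by build_tf_dataset() in src/data/datasets/tf_dataset.py. The import paths and the data directory below are assumptions: they presume the code is run from the repository root, with tensorflow_datasets installed and the VTAB data prepared as described in VTAB_SETUP.md.

# Illustrative sketch only -- the paths below are assumptions, not repository code.
from src.data.vtab_datasets.registry import Registry
from src.data.vtab_datasets import kitti  # noqa: F401  importing runs @Registry.register("data.kitti", "class")

# Same lookup pattern as build_tf_dataset(): parse_name() splits the string into
# the registered name plus literal kwargs, so this returns KittiData with
# task="closest_vehicle_distance" already bound (via partialclass).
data_cls = Registry.lookup('data.kitti(task="closest_vehicle_distance")')
vtab_data = data_cls(data_dir="/path/to/vtab/kitti")  # hypothetical data directory

# TFDataset in tf_dataset.py then wraps vtab_data.get_tf_data(...) and converts each
# example into normalized torch tensors for the PyTorch training loop.

This call-string convention is what lets the DATASETS list in tf_dataset.py mix plain names such as 'dtd' with parameterized ones such as 'clevr(task="count_all")' while sharing a single lookup path.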