├── figure
│   └── demo.gif
├── configs
│   ├── linear
│   │   ├── cub.yaml
│   │   ├── dogs.yaml
│   │   ├── flowers.yaml
│   │   ├── cars.yaml
│   │   └── nabirds.yaml
│   ├── finetune
│   │   ├── nabirds.yaml
│   │   ├── cub.yaml
│   │   ├── cars.yaml
│   │   ├── dogs.yaml
│   │   ├── flowers.yaml
│   │   └── multitask.yaml
│   ├── base-linear.yaml
│   └── base-finetune.yaml
├── src
│   ├── utils
│   │   ├── file_io.py
│   │   ├── train_utils.py
│   │   ├── build_pruner.py
│   │   ├── io_utils.py
│   │   ├── pruner.py
│   │   ├── distributed.py
│   │   └── logging.py
│   ├── data
│   │   ├── vtab_datasets
│   │   │   ├── __init__.py
│   │   │   ├── patch_camelyon.py
│   │   │   ├── sun397.py
│   │   │   ├── dtd.py
│   │   │   ├── dmlab.py
│   │   │   ├── caltech.py
│   │   │   ├── resisc45.py
│   │   │   ├── oxford_iiit_pet.py
│   │   │   ├── eurosat.py
│   │   │   ├── cifar.py
│   │   │   ├── smallnorb.py
│   │   │   ├── oxford_flowers102.py
│   │   │   ├── svhn.py
│   │   │   ├── clevr.py
│   │   │   ├── dsprites.py
│   │   │   ├── registry.py
│   │   │   └── kitti.py
│   │   ├── transforms.py
│   │   ├── loader.py
│   │   └── datasets
│   │       ├── tf_dataset.py
│   │       ├── pt_dataset.py
│   │       └── json_dataset.py
│   ├── configs
│   │   ├── config_node.py
│   │   ├── vit_configs.py
│   │   └── config.py
│   ├── models
│   │   ├── mlp.py
│   │   ├── build_model.py
│   │   ├── vit_adapter
│   │   │   ├── adapter_block.py
│   │   │   ├── vit_moco.py
│   │   │   ├── vit_mae.py
│   │   │   └── swin_adapter.py
│   │   └── vit_backbones
│   │       ├── vit_mae.py
│   │       └── vit_moco.py
│   ├── engine
│   │   ├── eval
│   │   │   ├── singlelabel.py
│   │   │   └── multilabel.py
│   │   └── evaluator.py
│   └── solver
│       ├── losses.py
│       └── lr_scheduler.py
├── env_setup.sh
├── scripts
│   └── mosa
│       ├── swin
│       │   └── FGVC
│       │       ├── cars.sh
│       │       ├── cub.sh
│       │       ├── dogs.sh
│       │       ├── nabirds.sh
│       │       └── flowers.sh
│       └── vit
│           ├── FGVC
│           │   ├── cars.sh
│           │   ├── cub.sh
│           │   ├── dogs.sh
│           │   ├── nabirds.sh
│           │   ├── flowers.sh
│           │   └── multi.sh
│           ├── GICD
│           │   ├── food101.sh
│           │   ├── aircraft.sh
│           │   └── cifar100.sh
│           └── VTAB
│               ├── vtab_specialized.sh
│               ├── vtab_natural.sh
│               ├── vtab_structured.sh
│               └── vtab.sh
├── README.md
├── .gitignore
├── launch.py
└── VTAB_SETUP.md
/figure/demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Theia-4869/MoSA/HEAD/figure/demo.gif
--------------------------------------------------------------------------------
/configs/linear/cub.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "../base-linear.yaml"
2 | RUN_N_TIMES: 1
3 | DATA:
4 | NAME: "CUB"
5 | DATAPATH: "" #TODO: need to specify here
6 | NUMBER_CLASSES: 200
7 | MULTILABEL: False
8 | MODEL:
9 | TYPE: "vit"
10 | SOLVER:
11 | BASE_LR: 0.1
12 | WEIGHT_DECAY: 0.01
--------------------------------------------------------------------------------
/configs/linear/dogs.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "../base-linear.yaml"
2 | RUN_N_TIMES: 1
3 | DATA:
4 | NAME: "StanfordDogs"
5 | DATAPATH: "" #TODO: need to specify here
6 | NUMBER_CLASSES: 120
7 | MULTILABEL: False
8 | MODEL:
9 | TYPE: "vit"
10 | SOLVER:
11 | BASE_LR: 0.001
12 | WEIGHT_DECAY: 0.0001
--------------------------------------------------------------------------------
/configs/linear/flowers.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "../base-linear.yaml"
2 | RUN_N_TIMES: 1
3 | DATA:
4 | NAME: "OxfordFlowers"
5 | DATAPATH: "" #TODO: need to specify here
6 | NUMBER_CLASSES: 102
7 | MULTILABEL: False
8 | MODEL:
9 | TYPE: "vit"
10 | SOLVER:
11 | BASE_LR: 0.001
12 | WEIGHT_DECAY: 0.0001
--------------------------------------------------------------------------------
/configs/finetune/nabirds.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "../base-finetune.yaml"
2 | RUN_N_TIMES: 1
3 | DATA:
4 | NAME: "nabirds"
5 | DATAPATH: "" #TODO: need to specify here
6 | NUMBER_CLASSES: 555
7 | MULTILABEL: False
8 | MODEL:
9 | TYPE: "vit"
10 | SOLVER:
11 | BASE_LR: 0.00375
12 | WEIGHT_DECAY: 0.01
13 |
--------------------------------------------------------------------------------
/configs/linear/cars.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "../base-linear.yaml"
2 | RUN_N_TIMES: 1
3 | DATA:
4 | NAME: "StanfordCars"
5 | DATAPATH: "" #TODO: need to specify here
6 | NUMBER_CLASSES: 196
7 | MULTILABEL: False
8 | MODEL:
9 | TYPE: "vit"
10 | SOLVER:
11 | BASE_LR: 0.001
12 | WEIGHT_DECAY: 0.0001
13 |
--------------------------------------------------------------------------------
/configs/linear/nabirds.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "../base-linear.yaml"
2 | RUN_N_TIMES: 1
3 | DATA:
4 | NAME: "nabirds"
5 | DATAPATH: "" #TODO: need to specify here
6 | NUMBER_CLASSES: 555
7 | MULTILABEL: False
8 | MODEL:
9 | TYPE: "vit"
10 | SOLVER:
11 | BASE_LR: 0.1
12 | WEIGHT_DECAY: 0.01
13 | LOG_EVERY_N: 10
--------------------------------------------------------------------------------
/configs/finetune/cub.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "../base-finetune.yaml"
2 | RUN_N_TIMES: 1
3 | DATA:
4 | NAME: "CUB"
5 | DATAPATH: "" #TODO: need to specify here
6 | NUMBER_CLASSES: 200
7 | MULTILABEL: False
8 | FEATURE: "imagenet_supervised" # need to tune
9 | MODEL:
10 | TYPE: "vit"
11 | SOLVER:
12 | BASE_LR: 0.00375
13 | WEIGHT_DECAY: 0.01
14 |
--------------------------------------------------------------------------------
/configs/finetune/cars.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "../base-finetune.yaml"
2 | RUN_N_TIMES: 1
3 | DATA:
4 | NAME: "StanfordCars"
5 | DATAPATH: "" #TODO: need to specify here
6 | NUMBER_CLASSES: 196
7 | MULTILABEL: False
8 | FEATURE: "imagenet_supervised" # need to tune
9 | MODEL:
10 | TYPE: "vit"
11 | SOLVER:
12 | BASE_LR: 0.0375
13 | WEIGHT_DECAY: 0.001
14 |
--------------------------------------------------------------------------------
/configs/finetune/dogs.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "../base-finetune.yaml"
2 | RUN_N_TIMES: 1
3 | DATA:
4 | NAME: "StanfordDogs"
5 | DATAPATH: "" #TODO: need to specify here
6 | NUMBER_CLASSES: 120
7 | MULTILABEL: False
8 | FEATURE: "imagenet_supervised" # need to tune
9 | MODEL:
10 | TYPE: "vit"
11 | SOLVER:
12 | BASE_LR: 0.00375
13 | WEIGHT_DECAY: 0.001
14 |
--------------------------------------------------------------------------------
/configs/finetune/flowers.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "../base-finetune.yaml"
2 | RUN_N_TIMES: 1
3 | DATA:
4 | NAME: "OxfordFlowers"
5 | DATAPATH: "" #TODO: need to specify here
6 | NUMBER_CLASSES: 102
7 | MULTILABEL: False
8 | FEATURE: "imagenet_supervised" # need to tune
9 | MODEL:
10 | TYPE: "vit"
11 | SOLVER:
12 | BASE_LR: 0.001
13 | WEIGHT_DECAY: 0.0001
14 |
--------------------------------------------------------------------------------
/src/utils/file_io.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | """
4 | Project specific pathmanagers for a project as recommended by Detectron2
5 | """
6 | from iopath.common.file_io import PathManager as PathManagerBase
7 | from iopath.common.file_io import HTTPURLHandler
8 |
9 |
10 | PathManager = PathManagerBase()
11 | PathManager.register_handler(HTTPURLHandler())
12 |
--------------------------------------------------------------------------------
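`PathManager` is the file handle used across the project (for example, `src/configs/config_node.py` opens config files through it). A minimal usage sketch, with an illustrative local path; thanks to the registered `HTTPURLHandler` the same call also accepts http(s) URLs:

```python
from src.utils.file_io import PathManager

# Open a local YAML config through the project-wide PathManager.
with PathManager.open("configs/base-finetune.yaml", "r") as f:
    print(f.read()[:120])
```
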
/configs/finetune/multitask.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "../base-finetune.yaml"
2 | RUN_N_TIMES: 1
3 | DATA:
4 | NAME: "multi"
5 | DATAPATH: "" #TODO: need to specify here
6 | NUMBER_CLASSES: 1173 # 200+55+102+120+196
7 | MULTILABEL: False
8 | FEATURE: "imagenet_supervised" # need to tune
9 | MODEL:
10 | TYPE: "vit"
11 | SOLVER:
12 | BASE_LR: 0.00375
13 | WEIGHT_DECAY: 0.01
14 |
--------------------------------------------------------------------------------
/configs/base-linear.yaml:
--------------------------------------------------------------------------------
1 | NUM_GPUS: 1
2 | NUM_SHARDS: 1
3 | OUTPUT_DIR: ""
4 | RUN_N_TIMES: 1
5 | MODEL:
6 | TRANSFER_TYPE: "linear"
7 | TYPE: "vit"
8 | LINEAR:
9 | MLP_SIZES: []
10 | SOLVER:
11 | SCHEDULER: "cosine"
12 | PATIENCE: 300
13 | LOSS: "softmax"
14 | OPTIMIZER: "sgd"
15 | MOMENTUM: 0.9
16 | WEIGHT_DECAY: 0.0001
17 | LOG_EVERY_N: 100
18 | WARMUP_EPOCH: 10
19 | TOTAL_EPOCH: 100
20 | DATA:
21 | NAME: ""
22 | NUMBER_CLASSES: -1
23 | DATAPATH: ""
24 | FEATURE: "sup_vitb16_224"
25 | BATCH_SIZE: 1024
--------------------------------------------------------------------------------
/configs/base-finetune.yaml:
--------------------------------------------------------------------------------
1 | NUM_GPUS: 1
2 | NUM_SHARDS: 1
3 | OUTPUT_DIR: ""
4 | RUN_N_TIMES: 1
5 | MODEL:
6 | TRANSFER_TYPE: "end2end"
7 | TYPE: "vit"
8 | LINEAR:
9 | MLP_SIZES: []
10 | SOLVER:
11 | SCHEDULER: "cosine"
12 | WARMUP_EPOCH: 10
13 | LOSS: "softmax"
14 | OPTIMIZER: "adamw"
15 | MOMENTUM: 0.9
16 | WEIGHT_DECAY: 0.0001
17 | LOG_EVERY_N: 100
18 | TOTAL_EPOCH: 100
19 | PATIENCE: 300
20 | DATA:
21 | NAME: ""
22 | NUMBER_CLASSES: -1
23 | DATAPATH: ""
24 | FEATURE: "sup_vitb16_224"
25 | BATCH_SIZE: 384
26 | NUM_WORKERS: 8
--------------------------------------------------------------------------------
/src/data/vtab_datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved
3 | # Copyright 2019 Google LLC.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 |
--------------------------------------------------------------------------------
/src/configs/config_node.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | """Config system (based on Detectron's)."""
4 |
5 | from fvcore.common.config import CfgNode as _CfgNode
6 | from ..utils.file_io import PathManager
7 |
8 |
9 | class CfgNode(_CfgNode):
10 | """
11 | The same as `fvcore.common.config.CfgNode`, but different in:
12 |
13 | support manifold path
14 | """
15 |
16 | @classmethod
17 | def _open_cfg(cls, filename):
18 | return PathManager.open(filename, "r")
19 |
20 | def dump(self, *args, **kwargs):
21 | """
22 | Returns:
23 | str: a yaml string representation of the config
24 | """
25 | # to make it show up in docs
26 | return super().dump(*args, **kwargs)
27 |
--------------------------------------------------------------------------------
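A hedged sketch of how the YAML configs and the `KEY VALUE` overrides passed by the training scripts are typically combined with this `CfgNode`; `get_cfg()` stands in for the default config defined in `src/configs/config.py` (not shown above), and the override values are illustrative:

```python
from src.configs.config import get_cfg  # assumed accessor for the default config

cfg = get_cfg()
cfg.merge_from_file("configs/finetune/cub.yaml")   # _BASE_ is resolved first
cfg.merge_from_list(["SOLVER.BASE_LR", "0.001", "DATA.BATCH_SIZE", "128"])
cfg.freeze()
print(cfg.SOLVER.BASE_LR, cfg.DATA.BATCH_SIZE)
```
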
/env_setup.sh:
--------------------------------------------------------------------------------
1 | conda create -n mosa python=3.7
2 | conda activate mosa
3 |
4 | pip install -q tensorflow
5 | # specifying tfds versions is important to reproduce our results
6 | pip install tfds-nightly==4.4.0.dev202201080107
7 | pip install opencv-python
8 | pip install tensorflow-addons
9 | pip install mock
10 |
11 |
12 | conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch
13 |
14 | python -m pip install detectron2 -f \
15 | https://dl.fbaipublicfiles.com/detectron2/wheels/cu110/torch1.7/index.html
16 | pip install opencv-python
17 |
18 | conda install tqdm pandas matplotlib seaborn scikit-learn scipy simplejson termcolor
19 | conda install -c iopath iopath
20 |
21 |
22 | # for transformers
23 | pip install timm==0.4.12
24 | pip install ml-collections
25 |
26 | # Optional: for slurm jobs
27 | pip install submitit -U
28 | pip install slurm_gpustat
--------------------------------------------------------------------------------
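A quick sanity check after running the setup; the versions printed should match the pins above (e.g. torch 1.11.0, timm 0.4.12):

```python
import torch
import torchvision
import tensorflow_datasets as tfds
import timm

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("torchvision:", torchvision.__version__)
print("tensorflow_datasets:", tfds.__version__)
print("timm:", timm.__version__)
```
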
/scripts/mosa/swin/FGVC/cars.sh:
--------------------------------------------------------------------------------
1 | # launch final training for FGVC-cars.
2 |
3 | gpu_id=${1}
4 | data_path=/your/path/to/FGVC/StanfordCars
5 | model_root=checkpoints
6 | output_dir=${2}
7 |
8 | CUDA_VISIBLE_DEVICES=${gpu_id} python train.py \
9 | --config-file configs/finetune/cars.yaml \
10 | --sparse-train \
11 | DATA.BATCH_SIZE "128" \
12 | DATA.FEATURE "swinb_imagenet22k_224" \
13 | MODEL.ADAPTER.BOTTLENECK_SIZE "64" \
14 | MODEL.ADAPTER.EXPERT_NUM "2" \
15 | MODEL.ADAPTER.MOE "True" \
16 | MODEL.ADAPTER.MERGE "add" \
17 | MODEL.ADAPTER.SHARE "up" \
18 | MODEL.ADAPTER.ADDITIONAL "True" \
19 | MODEL.ADAPTER.DEEPREG "True" \
20 | MODEL.ADAPTER.ADD_WEIGHT "0.0" \
21 | MODEL.ADAPTER.REG_WEIGHT "1.0" \
22 | MODEL.TRANSFER_TYPE "mosa" \
23 | MODEL.TYPE "swin" \
24 | SEED "3407" \
25 | SOLVER.BASE_LR "0.005" \
26 | SOLVER.WEIGHT_DECAY "0.01" \
27 | SOLVER.WARMUP_EPOCH "10" \
28 | DATA.DATAPATH "${data_path}" \
29 | GPU_ID "${gpu_id}" \
30 | MODEL.MODEL_ROOT "${model_root}" \
31 | OUTPUT_DIR "${output_dir}"
32 |
--------------------------------------------------------------------------------
/scripts/mosa/swin/FGVC/cub.sh:
--------------------------------------------------------------------------------
1 | # launch final training for FGVC-cub.
2 |
3 | gpu_id=${1}
4 | data_path=/your/path/to/FGVC/CUB_200_2011
5 | model_root=checkpoints
6 | output_dir=${2}
7 |
8 | CUDA_VISIBLE_DEVICES=${gpu_id} python train.py \
9 | --config-file configs/finetune/cub.yaml \
10 | --sparse-train \
11 | DATA.BATCH_SIZE "128" \
12 | DATA.FEATURE "swinb_imagenet22k_224" \
13 | MODEL.ADAPTER.BOTTLENECK_SIZE "64" \
14 | MODEL.ADAPTER.EXPERT_NUM "4" \
15 | MODEL.ADAPTER.MOE "True" \
16 | MODEL.ADAPTER.MERGE "add" \
17 | MODEL.ADAPTER.SHARE "down" \
18 | MODEL.ADAPTER.ADDITIONAL "True" \
19 | MODEL.ADAPTER.DEEPREG "False" \
20 | MODEL.ADAPTER.ADD_WEIGHT "0.0" \
21 | MODEL.ADAPTER.REG_WEIGHT "1.0" \
22 | MODEL.TRANSFER_TYPE "mosa" \
23 | MODEL.TYPE "swin" \
24 | SEED "3407" \
25 | SOLVER.BASE_LR "0.001" \
26 | SOLVER.WEIGHT_DECAY "0.01" \
27 | SOLVER.WARMUP_EPOCH "10" \
28 | DATA.DATAPATH "${data_path}" \
29 | GPU_ID "${gpu_id}" \
30 | MODEL.MODEL_ROOT "${model_root}" \
31 | OUTPUT_DIR "${output_dir}"
32 |
--------------------------------------------------------------------------------
/scripts/mosa/swin/FGVC/dogs.sh:
--------------------------------------------------------------------------------
1 | # launch final training for FGVC-dogs.
2 |
3 | gpu_id=${1}
4 | data_path=/your/path/to/FGVC/StanfordDogs
5 | model_root=checkpoints
6 | output_dir=${2}
7 |
8 | CUDA_VISIBLE_DEVICES=${gpu_id} python train.py \
9 | --config-file configs/finetune/dogs.yaml \
10 | --sparse-train \
11 | DATA.BATCH_SIZE "128" \
12 | DATA.FEATURE "swinb_imagenet22k_224" \
13 | MODEL.ADAPTER.BOTTLENECK_SIZE "64" \
14 | MODEL.ADAPTER.EXPERT_NUM "4" \
15 | MODEL.ADAPTER.MOE "True" \
16 | MODEL.ADAPTER.MERGE "add" \
17 | MODEL.ADAPTER.SHARE "down" \
18 | MODEL.ADAPTER.ADDITIONAL "True" \
19 | MODEL.ADAPTER.DEEPREG "False" \
20 | MODEL.ADAPTER.ADD_WEIGHT "0.0" \
21 | MODEL.ADAPTER.REG_WEIGHT "1.0" \
22 | MODEL.TRANSFER_TYPE "mosa" \
23 | MODEL.TYPE "swin" \
24 | SEED "3407" \
25 | SOLVER.BASE_LR "0.0005" \
26 | SOLVER.WEIGHT_DECAY "0.01" \
27 | SOLVER.WARMUP_EPOCH "10" \
28 | DATA.DATAPATH "${data_path}" \
29 | GPU_ID "${gpu_id}" \
30 | MODEL.MODEL_ROOT "${model_root}" \
31 | OUTPUT_DIR "${output_dir}"
32 |
--------------------------------------------------------------------------------
/scripts/mosa/swin/FGVC/nabirds.sh:
--------------------------------------------------------------------------------
1 | # launch final training for FGVC-nabirds.
2 |
3 | gpu_id=${1}
4 | data_path=/your/path/to/FGVC/nabirds
5 | model_root=checkpoints
6 | output_dir=${2}
7 |
8 | CUDA_VISIBLE_DEVICES=${gpu_id} python train.py \
9 | --config-file configs/finetune/nabirds.yaml \
10 | --sparse-train \
11 | DATA.BATCH_SIZE "128" \
12 | DATA.FEATURE "swinb_imagenet22k_224" \
13 | MODEL.ADAPTER.BOTTLENECK_SIZE "64" \
14 | MODEL.ADAPTER.EXPERT_NUM "4" \
15 | MODEL.ADAPTER.MOE "True" \
16 | MODEL.ADAPTER.MERGE "add" \
17 | MODEL.ADAPTER.SHARE "down" \
18 | MODEL.ADAPTER.ADDITIONAL "True" \
19 | MODEL.ADAPTER.DEEPREG "False" \
20 | MODEL.ADAPTER.ADD_WEIGHT "0.0" \
21 | MODEL.ADAPTER.REG_WEIGHT "1.0" \
22 | MODEL.TRANSFER_TYPE "mosa" \
23 | MODEL.TYPE "swin" \
24 | SEED "3407" \
25 | SOLVER.BASE_LR "0.0005" \
26 | SOLVER.WEIGHT_DECAY "0.01" \
27 | SOLVER.WARMUP_EPOCH "10" \
28 | DATA.DATAPATH "${data_path}" \
29 | GPU_ID "${gpu_id}" \
30 | MODEL.MODEL_ROOT "${model_root}" \
31 | OUTPUT_DIR "${output_dir}"
32 |
--------------------------------------------------------------------------------
/scripts/mosa/swin/FGVC/flowers.sh:
--------------------------------------------------------------------------------
1 | # launch final training for FGVC-flowers.
2 |
3 | gpu_id=${1}
4 | data_path=/your/path/to/FGVC/OxfordFlowers
5 | model_root=checkpoints
6 | output_dir=${2}
7 |
8 | CUDA_VISIBLE_DEVICES=${gpu_id} python train.py \
9 | --config-file configs/finetune/flowers.yaml \
10 | --sparse-train \
11 | DATA.BATCH_SIZE "128" \
12 | DATA.FEATURE "swinb_imagenet22k_224" \
13 | MODEL.ADAPTER.BOTTLENECK_SIZE "64" \
14 | MODEL.ADAPTER.EXPERT_NUM "4" \
15 | MODEL.ADAPTER.MOE "True" \
16 | MODEL.ADAPTER.MERGE "add" \
17 | MODEL.ADAPTER.SHARE "down" \
18 | MODEL.ADAPTER.ADDITIONAL "True" \
19 | MODEL.ADAPTER.DEEPREG "False" \
20 | MODEL.ADAPTER.ADD_WEIGHT "0.0" \
21 | MODEL.ADAPTER.REG_WEIGHT "1.0" \
22 | MODEL.TRANSFER_TYPE "mosa" \
23 | MODEL.TYPE "swin" \
24 | SEED "3407" \
25 | SOLVER.BASE_LR "0.005" \
26 | SOLVER.WEIGHT_DECAY "0.01" \
27 | SOLVER.WARMUP_EPOCH "10" \
28 | DATA.DATAPATH "${data_path}" \
29 | GPU_ID "${gpu_id}" \
30 | MODEL.MODEL_ROOT "${model_root}" \
31 | OUTPUT_DIR "${output_dir}"
32 |
--------------------------------------------------------------------------------
/src/utils/train_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import torch
3 |
4 |
5 | def gpu_mem_usage():
6 | """Computes the GPU memory usage for the current device (GB)."""
7 | if not torch.cuda.is_available():
8 | return 0
9 | # Number of bytes in a megabyte
10 | _B_IN_GB = 1024 * 1024 * 1024
11 |
12 | mem_usage_bytes = torch.cuda.max_memory_allocated()
13 | return mem_usage_bytes / _B_IN_GB
14 |
15 |
16 | class AverageMeter(object):
17 | """Computes and stores the average and current value"""
18 | def __init__(self, name, fmt=':f'):
19 | self.name = name
20 | self.fmt = fmt
21 | self.reset()
22 |
23 | def reset(self):
24 | self.val = 0
25 | self.avg = 0
26 | self.sum = 0
27 | self.count = 0
28 |
29 | def update(self, val, n=1):
30 | self.val = val
31 | self.sum += val * n
32 | self.count += n
33 | self.avg = self.sum / self.count
34 |
35 | def __str__(self):
36 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
37 | return fmtstr.format(**self.__dict__)
38 |
--------------------------------------------------------------------------------
/scripts/mosa/vit/FGVC/cars.sh:
--------------------------------------------------------------------------------
1 | # launch final training for FGVC-cars.
2 |
3 | gpu_id=${1}
4 | data_path=/your/path/to/FGVC/StanfordCars
5 | model_root=checkpoints
6 | style=${2}
7 | output_dir=${3}
8 |
9 | CUDA_VISIBLE_DEVICES=${gpu_id} python train.py \
10 | --config-file configs/finetune/cars.yaml \
11 | --sparse-train \
12 | DATA.BATCH_SIZE "128" \
13 | DATA.FEATURE "sup_vitb16_imagenet21k" \
14 | MODEL.ADAPTER.BOTTLENECK_SIZE "64" \
15 | MODEL.ADAPTER.STYLE "${style}" \
16 | MODEL.ADAPTER.EXPERT_NUM "2" \
17 | MODEL.ADAPTER.MOE "True" \
18 | MODEL.ADAPTER.MERGE "add" \
19 | MODEL.ADAPTER.SHARE "up" \
20 | MODEL.ADAPTER.ADDITIONAL "True" \
21 | MODEL.ADAPTER.DEEPREG "True" \
22 | MODEL.ADAPTER.ADD_WEIGHT "0.0" \
23 | MODEL.ADAPTER.REG_WEIGHT "1.0" \
24 | MODEL.TRANSFER_TYPE "mosa" \
25 | MODEL.TYPE "vit" \
26 | SEED "3407" \
27 | SOLVER.BASE_LR "0.005" \
28 | SOLVER.WEIGHT_DECAY "0.01" \
29 | SOLVER.WARMUP_EPOCH "10" \
30 | DATA.DATAPATH "${data_path}" \
31 | GPU_ID "${gpu_id}" \
32 | MODEL.MODEL_ROOT "${model_root}" \
33 | OUTPUT_DIR "${output_dir}"
34 |
--------------------------------------------------------------------------------
/scripts/mosa/vit/FGVC/cub.sh:
--------------------------------------------------------------------------------
1 | # launch final training for FGVC-cub.
2 |
3 | gpu_id=${1}
4 | data_path=/your/path/to/FGVC/CUB_200_2011
5 | model_root=checkpoints
6 | style=${2}
7 | output_dir=${3}
8 |
9 | CUDA_VISIBLE_DEVICES=${gpu_id} python train.py \
10 | --config-file configs/finetune/cub.yaml \
11 | --sparse-train \
12 | DATA.BATCH_SIZE "128" \
13 | DATA.FEATURE "sup_vitb16_imagenet21k" \
14 | MODEL.ADAPTER.BOTTLENECK_SIZE "64" \
15 | MODEL.ADAPTER.STYLE "${style}" \
16 | MODEL.ADAPTER.EXPERT_NUM "4" \
17 | MODEL.ADAPTER.MOE "True" \
18 | MODEL.ADAPTER.MERGE "add" \
19 | MODEL.ADAPTER.SHARE "down" \
20 | MODEL.ADAPTER.ADDITIONAL "True" \
21 | MODEL.ADAPTER.DEEPREG "False" \
22 | MODEL.ADAPTER.ADD_WEIGHT "0.0" \
23 | MODEL.ADAPTER.REG_WEIGHT "1.0" \
24 | MODEL.TRANSFER_TYPE "mosa" \
25 | MODEL.TYPE "vit" \
26 | SEED "3407" \
27 | SOLVER.BASE_LR "0.001" \
28 | SOLVER.WEIGHT_DECAY "0.01" \
29 | SOLVER.WARMUP_EPOCH "10" \
30 | DATA.DATAPATH "${data_path}" \
31 | GPU_ID "${gpu_id}" \
32 | MODEL.MODEL_ROOT "${model_root}" \
33 | OUTPUT_DIR "${output_dir}"
34 |
--------------------------------------------------------------------------------
/scripts/mosa/vit/FGVC/dogs.sh:
--------------------------------------------------------------------------------
1 | # launch final training for FGVC-dogs.
2 |
3 | gpu_id=${1}
4 | data_path=/your/path/to/FGVC/StanfordDogs
5 | model_root=checkpoints
6 | style=${2}
7 | output_dir=${3}
8 |
9 | CUDA_VISIBLE_DEVICES=${gpu_id} python train.py \
10 | --config-file configs/finetune/dogs.yaml \
11 | --sparse-train \
12 | DATA.BATCH_SIZE "128" \
13 | DATA.FEATURE "sup_vitb16_imagenet21k" \
14 | MODEL.ADAPTER.BOTTLENECK_SIZE "64" \
15 | MODEL.ADAPTER.STYLE "${style}" \
16 | MODEL.ADAPTER.EXPERT_NUM "4" \
17 | MODEL.ADAPTER.MOE "True" \
18 | MODEL.ADAPTER.MERGE "add" \
19 | MODEL.ADAPTER.SHARE "down" \
20 | MODEL.ADAPTER.ADDITIONAL "True" \
21 | MODEL.ADAPTER.DEEPREG "False" \
22 | MODEL.ADAPTER.ADD_WEIGHT "0.0" \
23 | MODEL.ADAPTER.REG_WEIGHT "1.0" \
24 | MODEL.TRANSFER_TYPE "mosa" \
25 | MODEL.TYPE "vit" \
26 | SEED "3407" \
27 | SOLVER.BASE_LR "0.0005" \
28 | SOLVER.WEIGHT_DECAY "0.01" \
29 | SOLVER.WARMUP_EPOCH "10" \
30 | DATA.DATAPATH "${data_path}" \
31 | GPU_ID "${gpu_id}" \
32 | MODEL.MODEL_ROOT "${model_root}" \
33 | OUTPUT_DIR "${output_dir}"
34 |
--------------------------------------------------------------------------------
/scripts/mosa/vit/FGVC/nabirds.sh:
--------------------------------------------------------------------------------
1 | # launch final training for FGVC-nabirds.
2 |
3 | gpu_id=${1}
4 | data_path=/your/path/to/FGVC/nabirds
5 | model_root=checkpoints
6 | style=${2}
7 | output_dir=${3}
8 |
9 | CUDA_VISIBLE_DEVICES=${gpu_id} python train.py \
10 | --config-file configs/finetune/nabirds.yaml \
11 | --sparse-train \
12 | DATA.BATCH_SIZE "128" \
13 | DATA.FEATURE "sup_vitb16_imagenet21k" \
14 | MODEL.ADAPTER.BOTTLENECK_SIZE "64" \
15 | MODEL.ADAPTER.STYLE "${style}" \
16 | MODEL.ADAPTER.EXPERT_NUM "4" \
17 | MODEL.ADAPTER.MOE "True" \
18 | MODEL.ADAPTER.MERGE "add" \
19 | MODEL.ADAPTER.SHARE "down" \
20 | MODEL.ADAPTER.ADDITIONAL "True" \
21 | MODEL.ADAPTER.DEEPREG "False" \
22 | MODEL.ADAPTER.ADD_WEIGHT "0.0" \
23 | MODEL.ADAPTER.REG_WEIGHT "1.0" \
24 | MODEL.TRANSFER_TYPE "mosa" \
25 | MODEL.TYPE "vit" \
26 | SEED "3407" \
27 | SOLVER.BASE_LR "0.0005" \
28 | SOLVER.WEIGHT_DECAY "0.01" \
29 | SOLVER.WARMUP_EPOCH "10" \
30 | DATA.DATAPATH "${data_path}" \
31 | GPU_ID "${gpu_id}" \
32 | MODEL.MODEL_ROOT "${model_root}" \
33 | OUTPUT_DIR "${output_dir}"
34 |
--------------------------------------------------------------------------------
/scripts/mosa/vit/FGVC/flowers.sh:
--------------------------------------------------------------------------------
1 | # launch final training for FGVC-flowers.
2 |
3 | gpu_id=${1}
4 | data_path=/your/path/to/FGVC/OxfordFlowers
5 | model_root=checkpoints
6 | style=${2}
7 | output_dir=${3}
8 |
9 | CUDA_VISIBLE_DEVICES=${gpu_id} python train.py \
10 | --config-file configs/finetune/flowers.yaml \
11 | --sparse-train \
12 | DATA.BATCH_SIZE "128" \
13 | DATA.FEATURE "sup_vitb16_imagenet21k" \
14 | MODEL.ADAPTER.BOTTLENECK_SIZE "64" \
15 | MODEL.ADAPTER.STYLE "${style}" \
16 | MODEL.ADAPTER.EXPERT_NUM "4" \
17 | MODEL.ADAPTER.MOE "True" \
18 | MODEL.ADAPTER.MERGE "add" \
19 | MODEL.ADAPTER.SHARE "down" \
20 | MODEL.ADAPTER.ADDITIONAL "True" \
21 | MODEL.ADAPTER.DEEPREG "False" \
22 | MODEL.ADAPTER.ADD_WEIGHT "0.0" \
23 | MODEL.ADAPTER.REG_WEIGHT "1.0" \
24 | MODEL.TRANSFER_TYPE "mosa" \
25 | MODEL.TYPE "vit" \
26 | SEED "3407" \
27 | SOLVER.BASE_LR "0.005" \
28 | SOLVER.WEIGHT_DECAY "0.01" \
29 | SOLVER.WARMUP_EPOCH "10" \
30 | DATA.DATAPATH "${data_path}" \
31 | GPU_ID "${gpu_id}" \
32 | MODEL.MODEL_ROOT "${model_root}" \
33 | OUTPUT_DIR "${output_dir}"
34 |
--------------------------------------------------------------------------------
/scripts/mosa/vit/GICD/food101.sh:
--------------------------------------------------------------------------------
1 | # launch final training for GICD-food101.
2 |
3 | gpu_id=${1}
4 | data_path=/your/path/to/GICD/Food101
5 | model_root=checkpoints
6 | output_dir=${2}
7 |
8 | CUDA_VISIBLE_DEVICES=${gpu_id} python train.py \
9 | --config-file configs/finetune/cub.yaml \
10 | --sparse-train \
11 | DATA.BATCH_SIZE "128" \
12 | DATA.FEATURE "sup_vitb16_imagenet21k" \
13 | DATA.NAME "food101" \
14 | DATA.NUMBER_CLASSES "101" \
15 | DATA.NO_TEST "True" \
16 | MODEL.ADAPTER.BOTTLENECK_SIZE "64" \
17 | MODEL.ADAPTER.EXPERT_NUM "4" \
18 | MODEL.ADAPTER.MOE "True" \
19 | MODEL.ADAPTER.MERGE "add" \
20 | MODEL.ADAPTER.SHARE "down" \
21 | MODEL.ADAPTER.ADDITIONAL "True" \
22 | MODEL.ADAPTER.DEEPREG "False" \
23 | MODEL.ADAPTER.ADD_WEIGHT "0.0" \
24 | MODEL.ADAPTER.REG_WEIGHT "1.0" \
25 | MODEL.TRANSFER_TYPE "mosa" \
26 | MODEL.TYPE "vit" \
27 | SEED "3407" \
28 | SOLVER.BASE_LR "0.0005" \
29 | SOLVER.WEIGHT_DECAY "0.01" \
30 | SOLVER.WARMUP_EPOCH "10" \
31 | DATA.DATAPATH "${data_path}" \
32 | GPU_ID "${gpu_id}" \
33 | MODEL.MODEL_ROOT "${model_root}" \
34 | OUTPUT_DIR "${output_dir}"
35 |
--------------------------------------------------------------------------------
/scripts/mosa/vit/GICD/aircraft.sh:
--------------------------------------------------------------------------------
1 | # launch final training for GICD-aircraft.
2 |
3 | gpu_id=${1}
4 | data_path=/your/path/to/GICD/FGVCAircraft
5 | model_root=checkpoints
6 | output_dir=${2}
7 |
8 | CUDA_VISIBLE_DEVICES=${gpu_id} python train.py \
9 | --config-file configs/finetune/cub.yaml \
10 | --sparse-train \
11 | DATA.BATCH_SIZE "128" \
12 | DATA.FEATURE "sup_vitb16_imagenet21k" \
13 | DATA.NAME "aircraft" \
14 | DATA.NUMBER_CLASSES "100" \
15 | DATA.NO_TEST "True" \
16 | MODEL.ADAPTER.BOTTLENECK_SIZE "64" \
17 | MODEL.ADAPTER.EXPERT_NUM "4" \
18 | MODEL.ADAPTER.MOE "True" \
19 | MODEL.ADAPTER.MERGE "add" \
20 | MODEL.ADAPTER.SHARE "down" \
21 | MODEL.ADAPTER.ADDITIONAL "True" \
22 | MODEL.ADAPTER.DEEPREG "False" \
23 | MODEL.ADAPTER.ADD_WEIGHT "0.0" \
24 | MODEL.ADAPTER.REG_WEIGHT "1.0" \
25 | MODEL.TRANSFER_TYPE "mosa" \
26 | MODEL.TYPE "vit" \
27 | SEED "3407" \
28 | SOLVER.BASE_LR "0.005" \
29 | SOLVER.WEIGHT_DECAY "0.01" \
30 | SOLVER.WARMUP_EPOCH "10" \
31 | DATA.DATAPATH "${data_path}" \
32 | GPU_ID "${gpu_id}" \
33 | MODEL.MODEL_ROOT "${model_root}" \
34 | OUTPUT_DIR "${output_dir}"
35 |
--------------------------------------------------------------------------------
/scripts/mosa/vit/GICD/cifar100.sh:
--------------------------------------------------------------------------------
1 | # launch final training for GICD-cifar100.
2 |
3 | gpu_id=${1}
4 | data_path=/your/path/to/GICD/CIFAR100
5 | model_root=checkpoints
6 | output_dir=${2}
7 |
8 | CUDA_VISIBLE_DEVICES=${gpu_id} python train.py \
9 | --config-file configs/finetune/cub.yaml \
10 | --sparse-train \
11 | DATA.BATCH_SIZE "128" \
12 | DATA.FEATURE "sup_vitb16_imagenet21k" \
13 | DATA.NAME "cifar100" \
14 | DATA.NUMBER_CLASSES "100" \
15 | DATA.NO_TEST "True" \
16 | MODEL.ADAPTER.BOTTLENECK_SIZE "64" \
17 | MODEL.ADAPTER.EXPERT_NUM "4" \
18 | MODEL.ADAPTER.MOE "True" \
19 | MODEL.ADAPTER.MERGE "add" \
20 | MODEL.ADAPTER.SHARE "down" \
21 | MODEL.ADAPTER.ADDITIONAL "True" \
22 | MODEL.ADAPTER.DEEPREG "False" \
23 | MODEL.ADAPTER.ADD_WEIGHT "0.0" \
24 | MODEL.ADAPTER.REG_WEIGHT "1.0" \
25 | MODEL.TRANSFER_TYPE "mosa" \
26 | MODEL.TYPE "vit" \
27 | SEED "3407" \
28 | SOLVER.BASE_LR "0.0005" \
29 | SOLVER.WEIGHT_DECAY "0.01" \
30 | SOLVER.WARMUP_EPOCH "10" \
31 | DATA.DATAPATH "${data_path}" \
32 | GPU_ID "${gpu_id}" \
33 | MODEL.MODEL_ROOT "${model_root}" \
34 | OUTPUT_DIR "${output_dir}"
35 |
--------------------------------------------------------------------------------
/scripts/mosa/vit/FGVC/multi.sh:
--------------------------------------------------------------------------------
1 | # launch final training for FGVC-multitask. The hyperparameters are the same as in our paper.
2 |
3 | gpu_id=${1}
4 | data_path=/data/dataset/
5 | model_root=checkpoints
6 | style=${2}
7 | output_dir=${3}
8 |
9 | # rm -rf ${output_dir}
10 |
11 | CUDA_VISIBLE_DEVICES=${gpu_id} python train.py \
12 | --config-file configs/finetune/multitask.yaml \
13 | --sparse-train \
14 | DATA.BATCH_SIZE "128" \
15 | DATA.FEATURE "sup_vitb16_imagenet21k" \
16 | DATA.NO_TEST "False" \
17 | MODEL.ADAPTER.BOTTLENECK_SIZE "64" \
18 | MODEL.ADAPTER.STYLE "${style}" \
19 | MODEL.ADAPTER.EXPERT_NUM "2" \
20 | MODEL.ADAPTER.MOE "True" \
21 | MODEL.ADAPTER.MERGE "add" \
22 | MODEL.ADAPTER.SHARE "none" \
23 | MODEL.ADAPTER.ADDITIONAL "True" \
24 | MODEL.ADAPTER.ADD_WEIGHT "0.0" \
25 | MODEL.ADAPTER.REG_WEIGHT "1.0" \
26 | MODEL.TRANSFER_TYPE "mosa" \
27 | MODEL.TYPE "vit" \
28 | SEED "3407" \
29 | SOLVER.BASE_LR "0.0005" \
30 | SOLVER.WEIGHT_DECAY "0.01" \
31 | SOLVER.WARMUP_EPOCH "10" \
32 | DATA.DATAPATH "${data_path}" \
33 | GPU_ID "${gpu_id}" \
34 | MODEL.MODEL_ROOT "${model_root}" \
35 | OUTPUT_DIR "${output_dir}"
36 |
--------------------------------------------------------------------------------
/src/utils/build_pruner.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | Pruner construction functions.
4 | """
5 |
6 | from .pruner import Rand
7 | from . import logging
8 | logger = logging.get_logger("MOSA")
9 | # Supported pruner types
10 | _PRUNER_TYPES = {
11 | "random": Rand,
12 | }
13 |
14 |
15 | def build_pruner(cfg):
16 | """
17 | build pruner here
18 | """
19 | assert (
20 | cfg.PRUNER.TYPE in _PRUNER_TYPES.keys()
21 |     ), "Pruner type '{}' not supported".format(cfg.PRUNER.TYPE)
22 |
23 | # Construct the pruner
24 | prune_type = cfg.PRUNER.TYPE
25 | pruner = _PRUNER_TYPES[prune_type](cfg)
26 |
27 | return pruner
28 |
29 |
30 | def log_pruned_model_info(model, verbose=False):
31 | """Logs pruned model info"""
32 | if verbose:
33 | logger.info(f"Classification Model:\n{model}")
34 | model_total_params = sum(p.numel() for p in model.parameters())
35 | model_grad_params = sum(int(p.mask.sum()) if hasattr(p, 'mask') else p.numel() for p in model.parameters() if p.requires_grad)
36 | logger.info("Total Parameters: {0}\t Gradient Parameters: {1}".format(
37 | model_total_params, model_grad_params))
38 | logger.info("tuned percent:%.3f"%(model_grad_params/model_total_params*100))
39 |
--------------------------------------------------------------------------------
/src/data/transforms.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | """Image transformations."""
4 | import torchvision as tv
5 |
6 |
7 | def get_transforms(split, size):
8 | normalize = tv.transforms.Normalize(
9 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
10 | )
11 | if size == 448:
12 | resize_dim = 512
13 | crop_dim = 448
14 | elif size == 384:
15 | resize_dim = 438
16 | crop_dim = 384
17 | elif size == 224:
18 | resize_dim = 256
19 | crop_dim = 224
20 | if split == "train":
21 | transform = tv.transforms.Compose(
22 | [
23 | tv.transforms.Resize(resize_dim),
24 | tv.transforms.RandomCrop(crop_dim),
25 | tv.transforms.RandomHorizontalFlip(0.5),
26 | # tv.transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
27 | # tv.transforms.RandomHorizontalFlip(),
28 | tv.transforms.ToTensor(),
29 | normalize,
30 | ]
31 | )
32 | else:
33 | transform = tv.transforms.Compose(
34 | [
35 | tv.transforms.Resize(resize_dim),
36 | tv.transforms.CenterCrop(crop_dim),
37 | tv.transforms.ToTensor(),
38 | normalize,
39 | ]
40 | )
41 | return transform
42 |
--------------------------------------------------------------------------------
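`get_transforms` only covers crop sizes 224, 384 and 448 (any other size would leave `resize_dim` undefined). A small usage sketch; the `ImageFolder` wiring and dataset path are illustrative, the repository's own loaders live in `src/data/`:

```python
import torchvision as tv
from src.data.transforms import get_transforms

train_tf = get_transforms("train", 224)  # Resize(256) -> RandomCrop(224) -> flip -> normalize
val_tf = get_transforms("val", 224)      # Resize(256) -> CenterCrop(224) -> normalize

train_set = tv.datasets.ImageFolder("/your/path/to/FGVC/CUB_200_2011", transform=train_tf)
```
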
/scripts/mosa/vit/VTAB/vtab_specialized.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | gpu_id=${1}
4 | data_path=/your/path/to/VTAB
5 | model_root=checkpoints
6 | dataset=("vtab-patch_camelyon" "vtab-eurosat" "vtab-resisc45" 'vtab-diabetic_retinopathy(config="btgraham-300")')
7 | number_classes=(2 10 45 5)
8 | output_dir=${2}
9 |
10 | for idx in {0..3}; do
11 | CUDA_VISIBLE_DEVICES=${gpu_id} python train.py \
12 | --config-file configs/finetune/cub.yaml \
13 | --sparse-train \
14 | DATA.BATCH_SIZE "128" \
15 | DATA.FEATURE "sup_vitb16_imagenet21k" \
16 |         DATA.NAME "${dataset[idx]}" \
17 | DATA.NUMBER_CLASSES "${number_classes[idx]}" \
18 | MODEL.ADAPTER.BOTTLENECK_SIZE "16" \
19 | MODEL.ADAPTER.EXPERT_NUM "4" \
20 | MODEL.ADAPTER.MOE "True" \
21 | MODEL.ADAPTER.MERGE "add" \
22 | MODEL.ADAPTER.SHARE "down" \
23 | MODEL.ADAPTER.ADDITIONAL "True" \
24 | MODEL.ADAPTER.DEEPREG "False" \
25 | MODEL.ADAPTER.ADD_WEIGHT "0.0" \
26 | MODEL.ADAPTER.REG_WEIGHT "1.0" \
27 | MODEL.TRANSFER_TYPE "mosa" \
28 | MODEL.TYPE "vit" \
29 | SEED "3407" \
30 | SOLVER.BASE_LR "0.005" \
31 | SOLVER.WEIGHT_DECAY "0.01" \
32 | SOLVER.WARMUP_EPOCH "10" \
33 | DATA.DATAPATH "${data_path}" \
34 | GPU_ID "${gpu_id}" \
35 | MODEL.MODEL_ROOT "${model_root}" \
36 | OUTPUT_DIR "${output_dir}"
37 | done
--------------------------------------------------------------------------------
/scripts/mosa/vit/VTAB/vtab_natural.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | gpu_id=${1}
4 | data_path=/your/path/to/VTAB
5 | model_root=checkpoints
6 | dataset=("vtab-cifar(num_classes=100)" "vtab-caltech101" "vtab-dtd" "vtab-oxford_flowers102" "vtab-oxford_iiit_pet" "vtab-svhn" "vtab-sun397")
7 | number_classes=(100 102 47 102 37 10 397)
8 | output_dir=${2}
9 |
10 | for idx in {0..6}; do
11 | CUDA_VISIBLE_DEVICES=${gpu_id} python train.py \
12 | --config-file configs/finetune/cub.yaml \
13 | --sparse-train \
14 | DATA.BATCH_SIZE "128" \
15 | DATA.FEATURE "sup_vitb16_imagenet21k" \
16 |         DATA.NAME "${dataset[idx]}" \
17 | DATA.NUMBER_CLASSES "${number_classes[idx]}" \
18 | MODEL.ADAPTER.BOTTLENECK_SIZE "16" \
19 | MODEL.ADAPTER.EXPERT_NUM "4" \
20 | MODEL.ADAPTER.MOE "True" \
21 | MODEL.ADAPTER.MERGE "add" \
22 | MODEL.ADAPTER.SHARE "down" \
23 | MODEL.ADAPTER.ADDITIONAL "True" \
24 | MODEL.ADAPTER.DEEPREG "False" \
25 | MODEL.ADAPTER.ADD_WEIGHT "0.0" \
26 | MODEL.ADAPTER.REG_WEIGHT "1.0" \
27 | MODEL.TRANSFER_TYPE "mosa" \
28 | MODEL.TYPE "vit" \
29 | SEED "3407" \
30 | SOLVER.BASE_LR "0.005" \
31 | SOLVER.WEIGHT_DECAY "0.01" \
32 | SOLVER.WARMUP_EPOCH "10" \
33 | DATA.DATAPATH "${data_path}" \
34 | GPU_ID "${gpu_id}" \
35 | MODEL.MODEL_ROOT "${model_root}" \
36 | OUTPUT_DIR "${output_dir}"
37 | done
--------------------------------------------------------------------------------
/scripts/mosa/vit/VTAB/vtab_structured.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | gpu_id=${1}
4 | data_path=/your/path/to/VTAB
5 | model_root=checkpoints
6 | dataset=('vtab-clevr(task="count_all")' 'vtab-clevr(task="closest_object_distance")' "vtab-dmlab" 'vtab-kitti(task="closest_vehicle_distance")' 'vtab-dsprites(predicted_attribute="label_x_position",num_classes=16)' 'vtab-dsprites(predicted_attribute="label_orientation",num_classes=16)' 'vtab-smallnorb(predicted_attribute="label_azimuth")' 'vtab-smallnorb(predicted_attribute="label_elevation")')
7 | number_classes=(8 6 6 4 16 16 18 9)
8 | output_dir=${2}
9 |
10 | for idx in {0..7}; do
11 | CUDA_VISIBLE_DEVICES=${gpu_id} python train.py \
12 | --config-file configs/finetune/cub.yaml \
13 | --sparse-train \
14 | DATA.BATCH_SIZE "128" \
15 | DATA.FEATURE "sup_vitb16_imagenet21k" \
16 |         DATA.NAME "${dataset[idx]}" \
17 | DATA.NUMBER_CLASSES "${number_classes[idx]}" \
18 | MODEL.ADAPTER.BOTTLENECK_SIZE "16" \
19 | MODEL.ADAPTER.EXPERT_NUM "4" \
20 | MODEL.ADAPTER.MOE "True" \
21 | MODEL.ADAPTER.MERGE "add" \
22 | MODEL.ADAPTER.SHARE "down" \
23 | MODEL.ADAPTER.ADDITIONAL "True" \
24 | MODEL.ADAPTER.DEEPREG "False" \
25 | MODEL.ADAPTER.ADD_WEIGHT "0.0" \
26 | MODEL.ADAPTER.REG_WEIGHT "1.0" \
27 | MODEL.TRANSFER_TYPE "mosa" \
28 | MODEL.TYPE "vit" \
29 | SEED "3407" \
30 | SOLVER.BASE_LR "0.005" \
31 | SOLVER.WEIGHT_DECAY "0.01" \
32 | SOLVER.WARMUP_EPOCH "10" \
33 | DATA.DATAPATH "${data_path}" \
34 | GPU_ID "${gpu_id}" \
35 | MODEL.MODEL_ROOT "${model_root}" \
36 | OUTPUT_DIR "${output_dir}"
37 | done
--------------------------------------------------------------------------------
/scripts/mosa/vit/VTAB/vtab.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | gpu_id=${1}
4 | data_path=/your/path/to/VTAB
5 | model_root=checkpoints
6 | dataset=("vtab-cifar(num_classes=100)" "vtab-caltech101" "vtab-dtd" "vtab-oxford_flowers102" "vtab-oxford_iiit_pet" "vtab-svhn" "vtab-sun397" "vtab-patch_camelyon" "vtab-eurosat" "vtab-resisc45" 'vtab-diabetic_retinopathy(config="btgraham-300")' 'vtab-clevr(task="count_all")' 'vtab-clevr(task="closest_object_distance")' "vtab-dmlab" 'vtab-kitti(task="closest_vehicle_distance")' 'vtab-dsprites(predicted_attribute="label_x_position",num_classes=16)' 'vtab-dsprites(predicted_attribute="label_orientation",num_classes=16)' 'vtab-smallnorb(predicted_attribute="label_azimuth")' 'vtab-smallnorb(predicted_attribute="label_elevation")')
7 | number_classes=(100 102 47 102 37 10 397 2 10 45 5 8 6 6 4 16 16 18 9)
8 | output_dir=${2}
9 |
10 | for idx in {0..18}; do
11 | CUDA_VISIBLE_DEVICES=${gpu_id} python train.py \
12 | --config-file configs/finetune/cub.yaml \
13 | --sparse-train \
14 | DATA.BATCH_SIZE "128" \
15 | DATA.FEATURE "sup_vitb16_imagenet21k" \
16 |         DATA.NAME "${dataset[idx]}" \
17 | DATA.NUMBER_CLASSES "${number_classes[idx]}" \
18 | MODEL.ADAPTER.BOTTLENECK_SIZE "16" \
19 | MODEL.ADAPTER.EXPERT_NUM "4" \
20 | MODEL.ADAPTER.MOE "True" \
21 | MODEL.ADAPTER.MERGE "add" \
22 | MODEL.ADAPTER.SHARE "down" \
23 | MODEL.ADAPTER.ADDITIONAL "True" \
24 | MODEL.ADAPTER.DEEPREG "False" \
25 | MODEL.ADAPTER.ADD_WEIGHT "0.0" \
26 | MODEL.ADAPTER.REG_WEIGHT "1.0" \
27 | MODEL.TRANSFER_TYPE "mosa" \
28 | MODEL.TYPE "vit" \
29 | SEED "3407" \
30 | SOLVER.BASE_LR "0.005" \
31 | SOLVER.WEIGHT_DECAY "0.01" \
32 | SOLVER.WARMUP_EPOCH "10" \
33 | DATA.DATAPATH "${data_path}" \
34 | GPU_ID "${gpu_id}" \
35 | MODEL.MODEL_ROOT "${model_root}" \
36 | OUTPUT_DIR "${output_dir}"
37 | done
--------------------------------------------------------------------------------
/src/utils/io_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | Helper functions for reading and writing data.
4 | """
5 | import os
6 | import json
7 | import numpy as np
8 | import time
9 | import pandas as pd
10 |
11 | from typing import List, Union
12 | from PIL import Image, ImageFile
13 | Image.MAX_IMAGE_PIXELS = None
14 |
15 |
16 | def save_or_append_df(out_path, df):
17 | if os.path.exists(out_path):
18 | previous_df = pd.read_pickle(out_path)
19 | df = pd.concat([previous_df, df], ignore_index=True)
20 | df.to_pickle(out_path)
21 | print(f"Saved output at {out_path}")
22 |
23 |
24 | class JSONEncoder(json.JSONEncoder):
25 | def default(self, obj):
26 | if isinstance(obj, np.ndarray):
27 | return obj.tolist()
28 | elif isinstance(obj, bytes):
29 | return str(obj, encoding='utf-8')
30 | elif isinstance(obj, np.integer):
31 | return int(obj)
32 | elif isinstance(obj, np.floating):
33 | return float(obj)
34 |         else:
35 |             # any remaining type is not JSON-serializable
36 |             raise TypeError(
37 |                 "Unserializable object {} of type {}".format(obj, type(obj))
38 |             )
42 |
43 |
44 | def write_json(data: Union[list, dict], outfile: str) -> None:
45 | json_dir, _ = os.path.split(outfile)
46 | if json_dir and not os.path.exists(json_dir):
47 | os.makedirs(json_dir)
48 |
49 | with open(outfile, 'w') as f:
50 | json.dump(data, f, cls=JSONEncoder, ensure_ascii=False, indent=2)
51 |
52 |
53 | def read_json(filename: str) -> Union[list, dict]:
54 | """read json files"""
55 | with open(filename, "rb") as fin:
56 | data = json.load(fin)
57 | return data
58 |
59 |
60 | def pil_loader(path: str) -> Image.Image:
61 | """load an image from path, and suppress warning"""
62 | # to avoid crashing for truncated (corrupted images)
63 | ImageFile.LOAD_TRUNCATED_IMAGES = True
64 | # open path as file to avoid ResourceWarning
65 | # (https://github.com/python-pillow/Pillow/issues/835)
66 | with open(path, 'rb') as f:
67 | img = Image.open(f)
68 | return img.convert('RGB')
69 |
--------------------------------------------------------------------------------
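A minimal round trip through the JSON helpers; the custom `JSONEncoder` is what lets numpy scalars, arrays and byte strings be written directly (the output path is illustrative):

```python
import numpy as np
from src.utils.io_utils import write_json, read_json

record = {"top1": np.float32(0.913), "per_class": np.arange(3), "name": b"cub"}
write_json(record, "outputs/example_metrics.json")  # creates outputs/ if needed
restored = read_json("outputs/example_metrics.json")
print(restored["per_class"], restored["name"])      # [0, 1, 2] cub
```
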
/src/models/mlp.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | Modified from: fbcode/multimo/models/encoders/mlp.py
4 | """
5 | import math
6 | import torch
7 |
8 | from torch import nn
9 | from typing import List, Type
10 |
11 | from ..utils import logging
12 | logger = logging.get_logger("MOSA")
13 |
14 |
15 | class MLP(nn.Module):
16 | def __init__(
17 | self,
18 | input_dim: int,
19 | mlp_dims: List[int],
20 | dropout: float = 0.1,
21 | nonlinearity: Type[nn.Module] = nn.ReLU,
22 | normalization: Type[nn.Module] = nn.BatchNorm1d, # nn.LayerNorm,
23 | special_bias: bool = False,
24 | add_bn_first: bool = False,
25 | ):
26 | super(MLP, self).__init__()
27 | projection_prev_dim = input_dim
28 | projection_modulelist = []
29 | last_dim = mlp_dims[-1]
30 | mlp_dims = mlp_dims[:-1]
31 |
32 | if add_bn_first:
33 | if normalization is not None:
34 | projection_modulelist.append(normalization(projection_prev_dim))
35 | if dropout != 0:
36 | projection_modulelist.append(nn.Dropout(dropout))
37 |
38 | for idx, mlp_dim in enumerate(mlp_dims):
39 | fc_layer = nn.Linear(projection_prev_dim, mlp_dim)
40 | nn.init.kaiming_normal_(fc_layer.weight, a=0, mode='fan_out')
41 | projection_modulelist.append(fc_layer)
42 | projection_modulelist.append(nonlinearity())
43 |
44 | if normalization is not None:
45 | projection_modulelist.append(normalization(mlp_dim))
46 |
47 | if dropout != 0:
48 | projection_modulelist.append(nn.Dropout(dropout))
49 | projection_prev_dim = mlp_dim
50 |
51 | self.projection = nn.Sequential(*projection_modulelist)
52 | self.last_layer = nn.Linear(projection_prev_dim, last_dim)
53 | nn.init.kaiming_normal_(self.last_layer.weight, a=0, mode='fan_out')
54 | if special_bias:
55 | prior_prob = 0.01
56 | bias_value = -math.log((1 - prior_prob) / prior_prob)
57 | torch.nn.init.constant_(self.last_layer.bias, bias_value)
58 |
59 | def forward(self, x: torch.Tensor) -> torch.Tensor:
60 | """
61 | input_arguments:
62 | @x: torch.FloatTensor
63 | """
64 | x = self.projection(x)
65 | x = self.last_layer(x)
66 | return x
67 |
--------------------------------------------------------------------------------
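A usage sketch for the MLP head; the dimensions below (768-d ViT-B features, one 256-d hidden layer, 200 CUB classes) are illustrative:

```python
import torch
from src.models.mlp import MLP

head = MLP(input_dim=768, mlp_dims=[256, 200], dropout=0.1)  # last entry is the output dim
logits = head(torch.randn(4, 768))
print(logits.shape)  # torch.Size([4, 200])
```
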
/README.md:
--------------------------------------------------------------------------------
1 | # Mixture of Sparse Adapters
2 |
3 | This repository contains the official PyTorch implementation of MoSA (Mixture of Sparse Adapters).
4 |
5 | ![MoSA demo](figure/demo.gif)
6 |
7 | ## Environment setup
8 |
9 | See `env_setup.sh`
10 |
11 | ## Dataset preparation
12 |
13 | - Fine-Grained Visual Classification (FGVC): The datasets can be downloaded from the official links below. Where a public validation set is not available, we split the training data ourselves; the resulting splits can be found here: [Dropbox](https://cornell.box.com/v/vptfgvcsplits), [Google Drive](https://drive.google.com/drive/folders/1mnvxTkYxmOr2W9QjcgS64UBpoJ4UmKaM?usp=sharing).
14 |
15 | - [CUB200 2011](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html)
16 |
17 | - [NABirds](http://info.allaboutbirds.org/nabirds/)
18 |
19 | - [Oxford Flowers](https://www.robots.ox.ac.uk/~vgg/data/flowers/)
20 |
21 | - [Stanford Dogs](http://vision.stanford.edu/aditya86/ImageNetDogs/main.html)
22 |
23 | - [Stanford Cars](https://ai.stanford.edu/~jkrause/cars/car_dataset.html)
24 |
25 | - Visual Task Adaptation Benchmark (VTAB): See [`VTAB_SETUP.md`](VTAB_SETUP.md) for detailed instructions and tips.
26 |
27 | - General Image Classification Datasets (GICD): These datasets are downloaded automatically the first time you run a MoSA experiment that uses them.
28 |
29 | ## Pre-trained weights preparation
30 |
31 | Download the pre-trained Transformer-based backbones listed below and place them under `MODEL.MODEL_ROOT`.
32 |
33 | | Pre-trained Backbone | Link | md5sum |
34 | | :------------------: | :--: | :----: |
35 | | ViT-B/16 | link | d9715d |
36 | | ViT-L/16 | link | 8f39ce |
37 | | Swin-B | link | bf9cc1 |
38 | 
54 | ## Training
55 |
56 | To fine-tune a pre-trained ViT model via MoSA on FGVC-CUB, run the corresponding script with a GPU id, an adapter style, and an output directory:
57 | 
58 | ```bash
59 | bash scripts/mosa/vit/FGVC/cub.sh <gpu_id> <adapter_style> <output_dir>
60 | ```
61 |
62 | ## License
63 |
64 | The majority of MoSA is licensed under the CC-BY-NC 4.0 license (see [LICENSE](https://github.com/KMnP/vpt/blob/main/LICENSE) for details).
65 |
--------------------------------------------------------------------------------
/src/models/build_model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | Model construction functions.
4 | """
5 | import torch
7 |
8 | from .resnet import ResNet
9 | from .convnext import ConvNeXt
10 | from .vit_models import ViT, Swin, SSLViT
11 | from ..utils import logging
12 | logger = logging.get_logger("MOSA")
13 | # Supported model types
14 | _MODEL_TYPES = {
15 | "resnet": ResNet,
16 | "convnext": ConvNeXt,
17 | "vit": ViT,
18 | "swin": Swin,
19 | "ssl-vit": SSLViT,
20 | }
21 |
22 |
23 | def build_model(cfg):
24 | """
25 | build model here
26 | """
27 | assert (
28 | cfg.MODEL.TYPE in _MODEL_TYPES.keys()
29 | ), "Model type '{}' not supported".format(cfg.MODEL.TYPE)
30 | assert (
31 | cfg.NUM_GPUS <= torch.cuda.device_count()
32 | ), "Cannot use more GPU devices than available"
33 |
34 | # Construct the model
35 | train_type = cfg.MODEL.TYPE
36 | model = _MODEL_TYPES[train_type](cfg)
37 |
38 | log_model_info(model, verbose=cfg.DBG)
39 | model, device = load_model_to_device(model, cfg)
40 | logger.info(f"Device used for model: {device}")
41 |
42 | return model, device
43 |
44 |
45 | def log_model_info(model, verbose=False):
46 | """Logs model info"""
47 | if verbose:
48 | logger.info(f"Classification Model:\n{model}")
49 | model_total_params = sum(p.numel() for p in model.parameters())
50 | model_grad_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
51 | logger.info("Total Parameters: {0}\t Gradient Parameters: {1}".format(
52 | model_total_params, model_grad_params))
53 | logger.info("tuned percent:%.3f"%(model_grad_params/model_total_params*100))
54 |
55 |
56 | def get_current_device():
57 | if torch.cuda.is_available():
58 | # Determine the GPU used by the current process
59 | cur_device = torch.cuda.current_device()
60 | else:
61 | cur_device = torch.device('cpu')
62 | return cur_device
63 |
64 |
65 | def load_model_to_device(model, cfg):
66 | cur_device = get_current_device()
67 | if torch.cuda.is_available():
68 | # Transfer the model to the current GPU device
69 | model = model.cuda(device=cur_device)
70 | # Use multi-process data parallel model in the multi-gpu setting
71 | if cfg.NUM_GPUS > 1:
72 | # Make model replica operate on the current device
73 | model = torch.nn.parallel.DistributedDataParallel(
74 | module=model, device_ids=[cur_device], output_device=cur_device,
75 | find_unused_parameters=True,
76 | )
77 | else:
78 | model = model.to(cur_device)
79 | return model, cur_device
80 |
--------------------------------------------------------------------------------
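A hedged sketch of the model construction path, assuming `get_cfg()` from `src/configs/config.py` provides the defaults and that the pre-trained backbone weights have been placed under `MODEL.MODEL_ROOT`:

```python
from src.configs.config import get_cfg        # assumed accessor for the default config
from src.models.build_model import build_model

cfg = get_cfg()
cfg.merge_from_file("configs/finetune/cub.yaml")  # MODEL.TYPE: "vit"
cfg.MODEL.MODEL_ROOT = "checkpoints"
model, device = build_model(cfg)  # logs parameter counts, moves the model to GPU if available
```
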
/.gitignore:
--------------------------------------------------------------------------------
1 | GIT_TOKENS.md
2 | checkpoints/
3 | hyperparameters/
4 | outputs/
5 | outputs_appendix/
6 | scripts/
7 | *.npy
8 | *.zip
9 |
10 | # General
11 | .DS_Store?
12 | .DS_Store
13 | .AppleDouble
14 | .LSOverride
15 |
16 | # Byte-compiled / optimized / DLL files
17 | __pycache__/
18 | *.py[cod]
19 | *$py.class
20 |
21 | # C extensions
22 | *.so
23 |
24 | # Distribution / packaging
25 | .Python
26 | build/
27 | develop-eggs/
28 | dist/
29 | downloads/
30 | eggs/
31 | .eggs/
32 | lib/
33 | lib64/
34 | parts/
35 | sdist/
36 | var/
37 | wheels/
38 | pip-wheel-metadata/
39 | share/python-wheels/
40 | *.egg-info/
41 | .installed.cfg
42 | *.egg
43 | MANIFEST
44 |
45 | # PyInstaller
46 | # Usually these files are written by a python script from a template
47 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
48 | *.manifest
49 | *.spec
50 |
51 | # Installer logs
52 | pip-log.txt
53 | pip-delete-this-directory.txt
54 |
55 | # Unit test / coverage reports
56 | htmlcov/
57 | .tox/
58 | .nox/
59 | .coverage
60 | .coverage.*
61 | .cache
62 | nosetests.xml
63 | coverage.xml
64 | *.cover
65 | *.py,cover
66 | .hypothesis/
67 | .pytest_cache/
68 |
69 | # Translations
70 | *.mo
71 | *.pot
72 |
73 | # Django stuff:
74 | *.log
75 | local_settings.py
76 | db.sqlite3
77 | db.sqlite3-journal
78 |
79 | # Flask stuff:
80 | instance/
81 | .webassets-cache
82 |
83 | # Scrapy stuff:
84 | .scrapy
85 |
86 | # Sphinx documentation
87 | docs/_build/
88 |
89 | # PyBuilder
90 | target/
91 |
92 | # Jupyter Notebook
93 | .ipynb_checkpoints
94 |
95 | # IPython
96 | profile_default/
97 | ipython_config.py
98 |
99 | # pyenv
100 | .python-version
101 |
102 | # pipenv
103 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
104 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
105 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
106 | # install all needed dependencies.
107 | #Pipfile.lock
108 |
109 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
110 | __pypackages__/
111 |
112 | # Celery stuff
113 | celerybeat-schedule
114 | celerybeat.pid
115 |
116 | # SageMath parsed files
117 | *.sage.py
118 |
119 | # Environments
120 | .env
121 | .venv
122 | env/
123 | venv/
124 | ENV/
125 | env.bak/
126 | venv.bak/
127 |
128 | # Spyder project settings
129 | .spyderproject
130 | .spyproject
131 |
132 | # Rope project settings
133 | .ropeproject
134 |
135 | # mkdocs documentation
136 | /site
137 |
138 | # mypy
139 | .mypy_cache/
140 | .dmypy.json
141 | dmypy.json
142 |
143 | # Pyre type checker
144 | .pyre/
145 |
--------------------------------------------------------------------------------
/src/data/vtab_datasets/patch_camelyon.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved
3 | # Copyright 2019 Google LLC.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """Implements PatchCamelyon data class."""
18 |
19 | from __future__ import absolute_import
20 | from __future__ import division
21 | from __future__ import print_function
22 |
23 | import tensorflow_datasets as tfds
24 |
25 | from . import base as base
26 | from .registry import Registry
27 |
28 |
29 | @Registry.register("data.patch_camelyon", "class")
30 | class PatchCamelyonData(base.ImageTfdsData):
31 | """Provides PatchCamelyon data."""
32 |
33 | def __init__(self, data_dir=None):
34 |
35 | dataset_builder = tfds.builder("patch_camelyon:2.*.*", data_dir=data_dir)
36 | dataset_builder.download_and_prepare()
37 |
38 | # Defines dataset specific train/val/trainval/test splits.
39 | tfds_splits = {
40 | "test": "test",
41 | "train": "train",
42 | "val": "validation",
43 | "trainval": "train+validation",
44 | "train800": "train[:800]",
45 | "val200": "validation[:200]",
46 | "train800val200": "train[:800]+validation[:200]",
47 | }
48 | # Creates a dict with example counts.
49 | num_samples_splits = {
50 | "test": dataset_builder.info.splits["test"].num_examples,
51 | "train": dataset_builder.info.splits["train"].num_examples,
52 | "val": dataset_builder.info.splits["validation"].num_examples,
53 | "train800": 800,
54 | "val200": 200,
55 | "train800val200": 1000,
56 | }
57 | num_samples_splits["trainval"] = (
58 | num_samples_splits["train"] + num_samples_splits["val"])
59 | super(PatchCamelyonData, self).__init__(
60 | dataset_builder=dataset_builder,
61 | tfds_splits=tfds_splits,
62 | num_samples_splits=num_samples_splits,
63 | num_preprocessing_threads=400,
64 | shuffle_buffer_size=10000,
65 | # Note: Export only image and label tensors with their original types.
66 | base_preprocess_fn=base.make_get_tensors_fn(["image", "label"]),
67 | num_classes=dataset_builder.info.features["label"].num_classes)
68 |
--------------------------------------------------------------------------------
/src/models/vit_adapter/adapter_block.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | '''
3 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
4 | '''
5 | import math
6 | import logging
7 | from functools import partial
8 | from collections import OrderedDict
9 | from copy import deepcopy
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 |
15 | from timm.models.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_
16 | from timm.models.vision_transformer import Attention
17 | from timm.models.vision_transformer import Block
18 |
19 | from ...utils import logging
20 | logger = logging.get_logger("MOSA")
21 |
22 |
23 | class Pfeiffer_Block(Block):
24 |
25 | def __init__(self, adapter_config, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
26 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
27 |
28 | super(Pfeiffer_Block, self).__init__(
29 | dim=dim,
30 | num_heads=num_heads,
31 | mlp_ratio=mlp_ratio,
32 | qkv_bias=qkv_bias,
33 | drop=drop,
34 | attn_drop=attn_drop,
35 | drop_path=drop_path,
36 | act_layer=act_layer,
37 | norm_layer=norm_layer)
38 |
39 | self.adapter_config = adapter_config
40 |
41 | if adapter_config.STYLE == "Pfeiffer":
42 | self.adapter_downsample = nn.Linear(
43 | dim,
44 | dim // adapter_config.REDUCATION_FACTOR
45 | )
46 | self.adapter_upsample = nn.Linear(
47 | dim // adapter_config.REDUCATION_FACTOR,
48 | dim
49 | )
50 | self.adapter_act_fn = act_layer()
51 |
52 | nn.init.zeros_(self.adapter_downsample.weight)
53 | nn.init.zeros_(self.adapter_downsample.bias)
54 |
55 | nn.init.zeros_(self.adapter_upsample.weight)
56 | nn.init.zeros_(self.adapter_upsample.bias)
57 | else:
58 | raise ValueError("Other adapter styles are not supported.")
59 |
60 | def forward(self, x):
61 |
62 | if self.adapter_config.STYLE == "Pfeiffer":
63 |             # same as regular ViT block
64 | h = x
65 | x = self.norm1(x)
66 | x = self.attn(x)
67 | x = self.drop_path(x)
68 | x = x + h
69 |
70 | h = x
71 | x = self.norm2(x)
72 | x = self.mlp(x)
73 |
74 | # start to insert adapter layers...
75 | adpt = self.adapter_downsample(x)
76 | adpt = self.adapter_act_fn(adpt)
77 | adpt = self.adapter_upsample(adpt)
78 | x = adpt + x
79 | # ...end
80 |
81 | x = self.drop_path(x)
82 | x = x + h
83 |
84 | return x
85 |
--------------------------------------------------------------------------------
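Editor's note: the block above follows the Pfeiffer adapter recipe: a zero-initialized down-projection, a nonlinearity, a zero-initialized up-projection, and a residual add after the MLP branch. A minimal standalone sketch of that pattern (the names `dim` and `reduction_factor` are illustrative, not the repo's config object):

import torch
import torch.nn as nn


class BottleneckAdapter(nn.Module):
    """Down-project, activate, up-project, then a residual add."""

    def __init__(self, dim=768, reduction_factor=8):
        super().__init__()
        self.down = nn.Linear(dim, dim // reduction_factor)
        self.up = nn.Linear(dim // reduction_factor, dim)
        self.act = nn.GELU()
        # Zero init: the adapter contributes nothing until it is trained.
        for layer in (self.down, self.up):
            nn.init.zeros_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        return x + self.up(self.act(self.down(x)))


tokens = torch.randn(2, 197, 768)       # (batch, tokens, dim), as in ViT-B/16
out = BottleneckAdapter()(tokens)
print(out.shape)                        # torch.Size([2, 197, 768])
print(torch.allclose(out, tokens))      # True, thanks to the zero init

Because both projections start at zero, an adapted block initially reproduces the frozen backbone's output exactly.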
/src/data/vtab_datasets/sun397.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved
3 | # Copyright 2019 Google LLC.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """Implements Sun397 data class."""
18 |
19 | from __future__ import absolute_import
20 | from __future__ import division
21 | from __future__ import print_function
22 |
23 | import tensorflow_datasets as tfds
24 |
25 | from . import base as base
26 | from .registry import Registry
27 | CUSTOM_TRAIN_SPLIT_PERCENT = 50
28 | CUSTOM_VALIDATION_SPLIT_PERCENT = 20
29 | CUSTOM_TEST_SPLIT_PERCENT = 30
30 |
31 |
32 | @Registry.register("data.sun397", "class")
33 | class Sun397Data(base.ImageTfdsData):
34 | """Provides Sun397Data data."""
35 |
36 | def __init__(self, config="tfds", data_dir=None):
37 |
38 | if config == "tfds":
39 | dataset_builder = tfds.builder("sun397/tfds:4.*.*", data_dir=data_dir)
40 | dataset_builder.download_and_prepare()
41 |
42 | tfds_splits = {
43 | "train": "train",
44 | "val": "validation",
45 | "test": "test",
46 | "trainval": "train+validation",
47 | "train800": "train[:800]",
48 | "val200": "validation[:200]",
49 | "train800val200": "train[:800]+validation[:200]",
50 | }
51 | # Creates a dict with example counts.
52 | num_samples_splits = {
53 | "test": dataset_builder.info.splits["test"].num_examples,
54 | "train": dataset_builder.info.splits["train"].num_examples,
55 | "val": dataset_builder.info.splits["validation"].num_examples,
56 | "train800": 800,
57 | "val200": 200,
58 | "train800val200": 1000,
59 | }
60 | num_samples_splits["trainval"] = (
61 | num_samples_splits["train"] + num_samples_splits["val"])
62 | else:
63 |
64 | raise ValueError("No supported config %r for Sun397Data." % config)
65 |
66 | super(Sun397Data, self).__init__(
67 | dataset_builder=dataset_builder,
68 | tfds_splits=tfds_splits,
69 | num_samples_splits=num_samples_splits,
70 | num_preprocessing_threads=400,
71 | shuffle_buffer_size=10000,
72 | # Note: Export only image and label tensors with their original types.
73 | base_preprocess_fn=base.make_get_tensors_fn(["image", "label"]),
74 | num_classes=dataset_builder.info.features["label"].num_classes)
75 |
--------------------------------------------------------------------------------
/src/models/vit_backbones/vit_mae.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved
3 | """
4 | borrowed from https://github.com/facebookresearch/mae/blob/main/models_vit.py
5 | """
6 | from functools import partial
7 |
8 | import torch
9 | import torch.nn as nn
10 |
11 | import timm.models.vision_transformer
12 |
13 |
14 | class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
15 | """ Vision Transformer with support for global average pooling
16 | """
17 | def __init__(self, global_pool=False, **kwargs):
18 | super(VisionTransformer, self).__init__(**kwargs)
19 |
20 | self.global_pool = global_pool
21 | if self.global_pool:
22 | norm_layer = kwargs['norm_layer']
23 | embed_dim = kwargs['embed_dim']
24 | self.fc_norm = norm_layer(embed_dim)
25 |
26 | del self.norm # remove the original norm
27 |
28 | def forward_features(self, x):
29 | B = x.shape[0]
30 | x = self.patch_embed(x)
31 |
32 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
33 | x = torch.cat((cls_tokens, x), dim=1)
34 | x = x + self.pos_embed
35 | x = self.pos_drop(x)
36 |
37 | for blk in self.blocks:
38 | x = blk(x)
39 |
40 | if self.global_pool:
41 | x = x[:, 1:, :].mean(dim=1) # global pool without cls token
42 | outcome = self.fc_norm(x)
43 | else:
44 | x = self.norm(x)
45 | outcome = x[:, 0]
46 |
47 | return outcome
48 |
49 |
50 | def build_model(model_type):
51 | if "vitb" in model_type:
52 | return vit_base_patch16()
53 | elif "vitl" in model_type:
54 | return vit_large_patch16()
55 | elif "vith" in model_type:
56 | return vit_huge_patch14()
57 |     raise ValueError(f"Unsupported model_type: {model_type}")
58 |
59 | def vit_base_patch16(**kwargs):
60 | model = VisionTransformer(
61 | drop_path_rate=0.1, global_pool=True, # using default settings for mae-finetune
62 | patch_size=16, embed_dim=768, depth=12, num_heads=12,
63 | mlp_ratio=4, qkv_bias=True,
64 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
65 | return model
66 |
67 |
68 | def vit_large_patch16(**kwargs):
69 | model = VisionTransformer(
70 | drop_path_rate=0.1, global_pool=True, # using default settings for mae-finetune
71 | patch_size=16, embed_dim=1024, depth=24, num_heads=16,
72 | mlp_ratio=4, qkv_bias=True,
73 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
74 | return model
75 |
76 |
77 | def vit_huge_patch14(**kwargs):
78 | model = VisionTransformer(
79 | drop_path_rate=0.1, global_pool=True, # using default settings for mae-finetune
80 | patch_size=14, embed_dim=1280, depth=32, num_heads=16,
81 | mlp_ratio=4, qkv_bias=True,
82 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
83 | return model
84 |
--------------------------------------------------------------------------------
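Editor's note: forward_features above picks one of two heads: average the patch tokens (skipping the class token) and normalize with a separate fc_norm, or normalize everything and keep the class token. A shape-only sketch of that choice on dummy embeddings (no pretrained weights involved):

import torch
import torch.nn as nn

B, N, D = 2, 197, 768                   # batch, 1 cls token + 196 patches, embed dim
tokens = torch.randn(B, N, D)
global_pool = True

if global_pool:
    fc_norm = nn.LayerNorm(D)
    feats = fc_norm(tokens[:, 1:, :].mean(dim=1))   # average the patch tokens
else:
    norm = nn.LayerNorm(D)
    feats = norm(tokens)[:, 0]                      # keep the cls token

print(feats.shape)                      # torch.Size([2, 768])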
/src/data/vtab_datasets/dtd.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved
3 | # Copyright 2019 Google LLC.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """Implements the Describable Textures Dataset (DTD) data class."""
18 |
19 | from __future__ import absolute_import
20 | from __future__ import division
21 | from __future__ import print_function
22 |
23 | import tensorflow_datasets as tfds
24 |
25 | from . import base as base
26 | from .registry import Registry
27 |
28 |
29 | @Registry.register("data.dtd", "class")
30 | class DTDData(base.ImageTfdsData):
31 | """Provides Describable Textures Dataset (DTD) data.
32 |
33 | As of version 1.0.0, the train/val/test splits correspond to those of the
34 | 1st fold of the official cross-validation partition.
35 |
36 | For additional details and usage, see the base class.
37 | """
38 |
39 | def __init__(self, data_dir=None):
40 |
41 | dataset_builder = tfds.builder("dtd:3.*.*", data_dir=data_dir)
42 | dataset_builder.download_and_prepare()
43 |
44 | # Defines dataset specific train/val/trainval/test splits.
45 | tfds_splits = {
46 | "train": "train",
47 | "val": "validation",
48 | "trainval": "train+validation",
49 | "test": "test",
50 | "train800": "train[:800]",
51 | "val200": "validation[:200]",
52 | "train800val200": "train[:800]+validation[:200]",
53 | }
54 |
55 | # Creates a dict with example counts for each split.
56 | train_count = dataset_builder.info.splits["train"].num_examples
57 | val_count = dataset_builder.info.splits["validation"].num_examples
58 | test_count = dataset_builder.info.splits["test"].num_examples
59 | num_samples_splits = {
60 | "train": train_count,
61 | "val": val_count,
62 | "trainval": train_count + val_count,
63 | "test": test_count,
64 | "train800": 800,
65 | "val200": 200,
66 | "train800val200": 1000,
67 | }
68 |
69 | super(DTDData, self).__init__(
70 | dataset_builder=dataset_builder,
71 | tfds_splits=tfds_splits,
72 | num_samples_splits=num_samples_splits,
73 | num_preprocessing_threads=400,
74 | shuffle_buffer_size=10000,
75 | # Note: Export only image and label tensors with their original types.
76 | base_preprocess_fn=base.make_get_tensors_fn(["image", "label"]),
77 | num_classes=dataset_builder.info.features["label"].num_classes)
78 |
--------------------------------------------------------------------------------
/src/data/vtab_datasets/dmlab.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved
3 | # Copyright 2019 Google LLC.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """Implements Dmlab data class."""
18 |
19 | from __future__ import absolute_import
20 | from __future__ import division
21 | from __future__ import print_function
22 |
23 | import tensorflow_datasets as tfds
24 |
25 | from . import base as base
26 | from .registry import Registry
27 |
28 |
29 | @Registry.register("data.dmlab", "class")
30 | class DmlabData(base.ImageTfdsData):
31 | """Dmlab dataset.
32 |
33 | The Dmlab dataset contains frames observed by the agent acting in the
34 | DMLab environment, which are annotated by the distance between
35 |   the agent and various objects present in the environment. The goal is
36 |   to evaluate the ability of a visual model to reason about distances
37 | from the visual input in 3D environments. The Dmlab dataset consists of
38 | 360x480 color images in 6 classes. The classes are
39 |   {close, far, very far} x {positive reward, negative reward},
40 |   i.e. every combination of distance and reward sign.
41 | """
42 |
43 | def __init__(self, data_dir=None):
44 |
45 | dataset_builder = tfds.builder("dmlab:2.0.1", data_dir=data_dir)
46 |
47 | tfds_splits = {
48 | "train": "train",
49 | "val": "validation",
50 | "trainval": "train+validation",
51 | "test": "test",
52 | "train800": "train[:800]",
53 | "val200": "validation[:200]",
54 | "train800val200": "train[:800]+validation[:200]",
55 | }
56 |
57 | # Example counts are retrieved from the tensorflow dataset info.
58 | train_count = dataset_builder.info.splits["train"].num_examples
59 | val_count = dataset_builder.info.splits["validation"].num_examples
60 | test_count = dataset_builder.info.splits["test"].num_examples
61 |
62 | # Creates a dict with example counts for each split.
63 | num_samples_splits = {
64 | "train": train_count,
65 | "val": val_count,
66 | "trainval": train_count + val_count,
67 | "test": test_count,
68 | "train800": 800,
69 | "val200": 200,
70 | "train800val200": 1000,
71 | }
72 |
73 | super(DmlabData, self).__init__(
74 | dataset_builder=dataset_builder,
75 | tfds_splits=tfds_splits,
76 | num_samples_splits=num_samples_splits,
77 | num_preprocessing_threads=400,
78 | shuffle_buffer_size=10000,
79 | base_preprocess_fn=base.make_get_and_cast_tensors_fn({
80 | "image": ("image", None),
81 | "label": ("label", None),
82 | }),
83 | num_classes=dataset_builder.info.features["label"].num_classes,
84 | image_key="image")
85 |
--------------------------------------------------------------------------------
/src/engine/eval/singlelabel.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | """Functions for computing metrics. All metrics have a range of 0-1."""
4 |
5 | import numpy as np
6 | import torch
7 | from sklearn.metrics import (
8 | accuracy_score, average_precision_score, f1_score, roc_auc_score
9 | )
10 |
11 |
12 | def accuracy(y_probs, y_true):
13 | # y_prob: (num_images, num_classes)
14 | y_preds = np.argmax(y_probs, axis=1)
15 | accuracy = accuracy_score(y_true, y_preds)
16 | error = 1.0 - accuracy
17 | return accuracy, error
18 |
19 |
20 | def top_n_accuracy(y_probs, truths, n=1):
21 | # y_prob: (num_images, num_classes)
22 | # truth: (num_images, num_classes) multi/one-hot encoding
23 | best_n = np.argsort(y_probs, axis=1)[:, -n:]
24 | if isinstance(truths, np.ndarray) and truths.shape == y_probs.shape:
25 | ts = np.argmax(truths, axis=1)
26 | else:
27 | # a list of GT class idx
28 | ts = truths
29 |
30 | num_input = y_probs.shape[0]
31 | successes = 0
32 | for i in range(num_input):
33 | if ts[i] in best_n[i, :]:
34 | successes += 1
35 | return float(successes) / num_input
36 |
37 |
38 | def compute_acc_auc(y_probs, y_true_ids):
39 | onehot_tgts = np.zeros_like(y_probs)
40 | for idx, t in enumerate(y_true_ids):
41 | onehot_tgts[idx, t] = 1.
42 |
43 | num_classes = y_probs.shape[1]
44 | if num_classes == 2:
45 | top1, _ = accuracy(y_probs, y_true_ids)
46 |         # binary classification: additionally report ROC-AUC
47 | try:
48 | auc = roc_auc_score(onehot_tgts, y_probs, average='macro')
49 | except ValueError as e:
50 |             print(f"ValueError encountered: {e}, setting AUC score to -1.")
51 | auc = -1
52 | return {"top1": top1, "rocauc": auc}
53 |
54 | top1, _ = accuracy(y_probs, y_true_ids)
55 |     k = min([5, num_classes])  # if fewer than 5 classes, use all classes
56 | top5 = top_n_accuracy(y_probs, y_true_ids, k)
57 | return {"top1": top1, f"top{k}": top5}
58 |
59 |
60 | def topks_correct(preds, labels, ks):
61 | """Computes the number of top-k correct predictions for each k."""
62 | assert preds.size(0) == labels.size(
63 | 0
64 | ), "Batch dim of predictions and labels must match"
65 | # Find the top max_k predictions for each sample
66 | _top_max_k_vals, top_max_k_inds = torch.topk(
67 | preds, max(ks), dim=1, largest=True, sorted=True
68 | )
69 | # (batch_size, max_k) -> (max_k, batch_size)
70 | top_max_k_inds = top_max_k_inds.t()
71 | # (batch_size, ) -> (max_k, batch_size)
72 | rep_max_k_labels = labels.view(1, -1).expand_as(top_max_k_inds)
73 | # (i, j) = 1 if top i-th prediction for the j-th sample is correct
74 | top_max_k_correct = top_max_k_inds.eq(rep_max_k_labels)
75 | # Compute the number of topk correct predictions for each k
76 | topks_correct = [
77 | top_max_k_correct[:k, :].reshape(-1).float().sum() for k in ks
78 | ]
79 | return topks_correct
80 |
81 |
82 | def topk_errors(preds, labels, ks):
83 | """Computes the top-k error for each k."""
84 | if int(labels.min()) < 0: # has ignore
85 | keep_ids = np.where(labels.cpu() >= 0)[0]
86 | preds = preds[keep_ids, :]
87 | labels = labels[keep_ids]
88 |
89 | num_topks_correct = topks_correct(preds, labels, ks)
90 | return [(1.0 - x / preds.size(0)) for x in num_topks_correct]
91 |
92 |
93 | def topk_accuracies(preds, labels, ks):
94 | """Computes the top-k accuracy for each k."""
95 | num_topks_correct = topks_correct(preds, labels, ks)
96 | return [(x / preds.size(0)) for x in num_topks_correct]
97 |
98 |
--------------------------------------------------------------------------------
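Editor's note: the top-k helpers above boil down to a membership test over the k highest-scoring classes. A tiny self-contained sketch with synthetic logits that mirrors topks_correct (values chosen so that top-1 and top-2 accuracy differ):

import torch

preds = torch.tensor([[0.1, 0.7, 0.2],
                      [0.6, 0.3, 0.1],
                      [0.2, 0.2, 0.6]])          # (batch, classes)
labels = torch.tensor([1, 1, 2])

_, topk_inds = torch.topk(preds, k=2, dim=1, largest=True, sorted=True)
topk_inds = topk_inds.t()                        # (k, batch)
correct = topk_inds.eq(labels.view(1, -1).expand_as(topk_inds))
top1_acc = correct[:1].reshape(-1).float().sum() / preds.size(0)
top2_acc = correct[:2].reshape(-1).float().sum() / preds.size(0)
print(top1_acc.item(), top2_acc.item())          # 0.666..., 1.0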
/src/data/vtab_datasets/caltech.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved
3 | # Copyright 2019 Google LLC.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """Imports the Caltech images dataset."""
18 |
19 | from __future__ import absolute_import
20 | from __future__ import division
21 | from __future__ import print_function
22 | from . import base as base
23 | from .registry import Registry
24 | import tensorflow_datasets as tfds
25 |
26 |
27 | # Percentage of the original training set retained for training, the rest is
28 | # used as a validation set.
29 | _TRAIN_SPLIT_PERCENT = 90
30 |
31 |
32 | @Registry.register("data.caltech101", "class")
33 | class Caltech101(base.ImageTfdsData):
34 | """Provides the Caltech101 dataset.
35 |
36 | See the base class for additional details on the class.
37 |
38 | See TFDS dataset for details on the dataset:
39 | third_party/py/tensorflow_datasets/image/caltech.py
40 |
41 | The original (TFDS) dataset contains only a train and test split. We randomly
42 | sample _TRAIN_SPLIT_PERCENT% of the train split for our "train" set. The
43 | remainder of the TFDS train split becomes our "val" set. The full TFDS train
44 | split is called "trainval". The TFDS test split is used as our test set.
45 |
46 | Note that, in the TFDS dataset, the training split is class-balanced, but not
47 | the test split. Therefore, a significant difference between performance on the
48 | "val" and "test" sets should be expected.
49 | """
50 |
51 | def __init__(self, data_dir=None):
52 | dataset_builder = tfds.builder("caltech101:3.*.*", data_dir=data_dir)
53 | dataset_builder.download_and_prepare()
54 |
55 | # Creates a dict with example counts for each split.
56 | trainval_count = dataset_builder.info.splits["train"].num_examples
57 | train_count = (_TRAIN_SPLIT_PERCENT * trainval_count) // 100
58 | test_count = dataset_builder.info.splits["test"].num_examples
59 | num_samples_splits = dict(
60 | train=train_count,
61 | val=trainval_count - train_count,
62 | trainval=trainval_count,
63 | test=test_count,
64 | train800=800,
65 | val200=200,
66 | train800val200=1000)
67 |
68 | # Defines dataset specific train/val/trainval/test splits.
69 | tfds_splits = {
70 | "train": "train[:{}]".format(train_count),
71 | "val": "train[{}:]".format(train_count),
72 | "trainval": "train",
73 | "test": "test",
74 | "train800": "train[:800]",
75 | "val200": "train[{}:{}]".format(train_count, train_count+200),
76 | "train800val200": (
77 | "train[:800]+train[{}:{}]".format(train_count, train_count+200)),
78 | }
79 |
80 | super(Caltech101, self).__init__(
81 | dataset_builder=dataset_builder,
82 | tfds_splits=tfds_splits,
83 | num_samples_splits=num_samples_splits,
84 | num_preprocessing_threads=400,
85 | shuffle_buffer_size=3000,
86 | base_preprocess_fn=base.make_get_tensors_fn(("image", "label")),
87 | num_classes=dataset_builder.info.features["label"].num_classes)
88 |
--------------------------------------------------------------------------------
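Editor's note: like most wrappers in this directory, Caltech101 carves custom train/val splits out of the TFDS train split with integer-percent arithmetic and encodes them as TFDS slicing strings. A short sketch with a made-up example count:

TRAIN_SPLIT_PERCENT = 90
trainval_count = 3060                             # hypothetical TFDS "train" split size
train_count = (TRAIN_SPLIT_PERCENT * trainval_count) // 100
print(train_count, trainval_count - train_count)  # 2754 306
print("train[:{}]".format(train_count))           # the slicing string handed to TFDS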
/src/data/vtab_datasets/resisc45.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved
3 | # Copyright 2019 Google LLC.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """Implements RESISC-45 data class."""
18 |
19 | from __future__ import absolute_import
20 | from __future__ import division
21 | from __future__ import print_function
22 |
23 | import tensorflow_datasets as tfds
24 |
25 | from . import base as base
26 | from .registry import Registry
27 | TRAIN_SPLIT_PERCENT = 60
28 | VALIDATION_SPLIT_PERCENT = 20
29 | TEST_SPLIT_PERCENT = 20
30 |
31 |
32 | @Registry.register("data.resisc45", "class")
33 | class Resisc45Data(base.ImageTfdsData):
34 | """Provides RESISC-45 dataset.
35 |
36 | RESISC45 dataset is a publicly available benchmark for Remote Sensing Image
37 | Scene Classification (RESISC), created by Northwestern Polytechnical
38 | University (NWPU). This dataset contains 31,500 images, covering 45 scene
39 | classes with 700 images in each class.
40 |
41 | URL: http://www.escience.cn/people/JunweiHan/NWPU-RESISC45.html
42 | """
43 |
44 | def __init__(self, data_dir=None):
45 | dataset_builder = tfds.builder("resisc45:3.*.*", data_dir=data_dir)
46 | dataset_builder.download_and_prepare()
47 |
48 | # Example counts are retrieved from the tensorflow dataset info.
49 | num_examples = dataset_builder.info.splits["train"].num_examples
50 | train_count = num_examples * TRAIN_SPLIT_PERCENT // 100
51 | val_count = num_examples * VALIDATION_SPLIT_PERCENT // 100
52 | test_count = num_examples * TEST_SPLIT_PERCENT // 100
53 |
54 | tfds_splits = {
55 | "train":
56 | "train[:{}]".format(train_count),
57 | "val":
58 | "train[{}:{}]".format(train_count, train_count + val_count),
59 | "trainval":
60 | "train[:{}]".format(train_count + val_count),
61 | "test":
62 | "train[{}:]".format(train_count + val_count),
63 | "train800":
64 | "train[:800]",
65 | "val200":
66 | "train[{}:{}]".format(train_count, train_count+200),
67 | "train800val200":
68 | "train[:800]+train[{}:{}]".format(train_count, train_count+200),
69 | }
70 |
71 | # Creates a dict with example counts for each split.
72 | num_samples_splits = {
73 | "train": train_count,
74 | "val": val_count,
75 | "trainval": train_count + val_count,
76 | "test": test_count,
77 | "train800": 800,
78 | "val200": 200,
79 | "train800val200": 1000,
80 | }
81 |
82 | super(Resisc45Data, self).__init__(
83 | dataset_builder=dataset_builder,
84 | tfds_splits=tfds_splits,
85 | num_samples_splits=num_samples_splits,
86 | num_preprocessing_threads=400,
87 | shuffle_buffer_size=10000,
88 | # Note: Rename tensors but keep their original types.
89 | base_preprocess_fn=base.make_get_and_cast_tensors_fn({
90 | "image": ("image", None),
91 | "label": ("label", None),
92 | }),
93 | num_classes=dataset_builder.info.features["label"].num_classes)
94 |
--------------------------------------------------------------------------------
/src/solver/losses.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from typing import Optional
6 |
7 | from ..utils import logging
8 | logger = logging.get_logger("MOSA")
9 |
10 |
11 | class SigmoidLoss(nn.Module):
12 | def __init__(self, cfg=None):
13 | super(SigmoidLoss, self).__init__()
14 |
15 | def is_single(self):
16 | return True
17 |
18 | def is_local(self):
19 | return False
20 |
21 | def multi_hot(self, labels: torch.Tensor, nb_classes: int) -> torch.Tensor:
22 | labels = labels.unsqueeze(1) # (batch_size, 1)
23 | target = torch.zeros(
24 | labels.size(0), nb_classes, device=labels.device
25 | ).scatter_(1, labels, 1.)
26 | # (batch_size, num_classes)
27 | return target
28 |
29 | def loss(
30 | self, logits, targets, per_cls_weights,
31 | multihot_targets: Optional[bool] = False
32 | ):
33 | # targets: 1d-tensor of integer
34 | # Only support single label at this moment
35 | # if len(targets.shape) != 2:
36 | num_classes = logits.shape[1]
37 | targets = self.multi_hot(targets, num_classes)
38 |
39 | loss = F.binary_cross_entropy_with_logits(
40 | logits, targets, reduction="none")
41 | # logger.info(f"loss shape: {loss.shape}")
42 | weight = torch.tensor(
43 | per_cls_weights, device=logits.device
44 | ).unsqueeze(0)
45 | # logger.info(f"weight shape: {weight.shape}")
46 | loss = torch.mul(loss.to(torch.float32), weight.to(torch.float32))
47 | return torch.sum(loss) / targets.shape[0]
48 |
49 | def forward(
50 | self, pred_logits, targets, per_cls_weights, multihot_targets=False
51 | ):
52 | loss = self.loss(
53 | pred_logits, targets, per_cls_weights, multihot_targets)
54 | return loss
55 |
56 |
57 | class SoftmaxLoss(SigmoidLoss):
58 | def __init__(self, cfg=None):
59 | super(SoftmaxLoss, self).__init__()
60 |
61 |     def loss(self, logits, targets, per_cls_weights, multihot_targets=False):
62 | weight = torch.tensor(
63 | per_cls_weights, device=logits.device
64 | )
65 | loss = F.cross_entropy(logits, targets, weight, reduction="none")
66 |
67 | return torch.sum(loss) / targets.shape[0]
68 |
69 |
70 | def symmetric_KL_loss(input, target, reduction='batchmean'):
71 | """ symmetric KL-divergence 1/2*(KL(p||q)+KL(q||p)) """
72 |
73 | input = input.float()
74 | target = target.float()
75 | loss = F.kl_div(F.log_softmax(input, dim=-1, dtype=torch.float32),
76 | F.softmax(target.detach(), dim=-1, dtype=torch.float32), reduction=reduction) + \
77 | F.kl_div(F.log_softmax(target, dim=-1, dtype=torch.float32),
78 | F.softmax(input.detach(), dim=-1, dtype=torch.float32), reduction=reduction)
79 | return 0.5 * loss.sum()
80 |
81 |
82 | def deepreg_MSE_loss(input, target, reduction='mean'):
83 |     """ deep regularization MSE loss """
84 |
85 | loss = 0
86 |     for i in range(len(input)):
87 | loss += F.mse_loss(input[i], target[i], reduction=reduction)
88 | return loss.sum() / len(input) * 0.01
89 |
90 |
91 | LOSS = {
92 | "softmax": SoftmaxLoss,
93 | }
94 |
95 |
96 | def build_loss(cfg):
97 | loss_name = cfg.SOLVER.LOSS
98 | assert loss_name in LOSS, \
99 | f'loss name {loss_name} is not supported'
100 | loss_fn = LOSS[loss_name]
101 | if not loss_fn:
102 | return None
103 | else:
104 | return loss_fn(cfg)
105 |
--------------------------------------------------------------------------------
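Editor's note: SoftmaxLoss above is per-class-weighted cross entropy averaged over the batch. A self-contained sketch of the same computation (uniform weights here; in practice per_cls_weights would carry a class-balancing scheme):

import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)                      # (batch, classes)
targets = torch.randint(0, 10, (4,))
per_cls_weights = torch.ones(10)                 # placeholder for class-balancing weights

loss = F.cross_entropy(logits, targets, weight=per_cls_weights, reduction="none")
loss = loss.sum() / targets.shape[0]
print(loss.item())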
/src/configs/vit_configs.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | Copyright (c) Meta Platforms, Inc. All Rights Reserved
4 | https://github.com/jeonsworld/ViT-pytorch/blob/main/models/configs.py
5 | """
6 | import ml_collections
7 |
8 |
9 | def get_testing():
10 | """Returns a minimal configuration for testing."""
11 | config = ml_collections.ConfigDict()
12 | config.patches = ml_collections.ConfigDict({'size': (16, 16)})
13 | config.hidden_size = 1
14 | config.transformer = ml_collections.ConfigDict()
15 | config.transformer.mlp_dim = 1
16 | config.transformer.num_heads = 1
17 | config.transformer.num_layers = 1
18 | config.transformer.attention_dropout_rate = 0.0
19 | config.transformer.dropout_rate = 0.1
20 | config.classifier = 'token'
21 | config.representation_size = None
22 | return config
23 |
24 |
25 | def get_b16_config():
26 | """Returns the ViT-B/16 configuration."""
27 | config = ml_collections.ConfigDict()
28 | config.patches = ml_collections.ConfigDict({'size': (16, 16)})
29 | config.hidden_size = 768
30 | config.transformer = ml_collections.ConfigDict()
31 | config.transformer.mlp_dim = 3072
32 | config.transformer.num_heads = 12
33 | config.transformer.num_layers = 12
34 | config.transformer.attention_dropout_rate = 0.0
35 | config.transformer.dropout_rate = 0.1
36 | config.classifier = 'token'
37 | config.representation_size = None
38 | return config
39 |
40 |
41 | def get_r50_b16_config():
42 | """Returns the Resnet50 + ViT-B/16 configuration."""
43 | config = get_b16_config()
44 | del config.patches.size
45 | config.patches.grid = (14, 14)
46 | config.resnet = ml_collections.ConfigDict()
47 | config.resnet.num_layers = (3, 4, 9)
48 | config.resnet.width_factor = 1
49 | return config
50 |
51 |
52 | def get_b32_config():
53 | """Returns the ViT-B/32 configuration."""
54 | config = get_b16_config()
55 | config.patches.size = (32, 32)
56 | return config
57 |
58 |
59 | def get_b8_config():
60 |     """Returns the ViT-B/8 configuration."""
61 | config = get_b16_config()
62 | config.patches.size = (8, 8)
63 | return config
64 |
65 |
66 | def get_l16_config():
67 | """Returns the ViT-L/16 configuration."""
68 | config = ml_collections.ConfigDict()
69 | config.patches = ml_collections.ConfigDict({'size': (16, 16)})
70 | config.hidden_size = 1024
71 | config.transformer = ml_collections.ConfigDict()
72 | config.transformer.mlp_dim = 4096
73 | config.transformer.num_heads = 16
74 | config.transformer.num_layers = 24
75 | config.transformer.attention_dropout_rate = 0.0
76 | config.transformer.dropout_rate = 0.1
77 | config.classifier = 'token'
78 | config.representation_size = None
79 | return config
80 |
81 |
82 | def get_l32_config():
83 | """Returns the ViT-L/32 configuration."""
84 | config = get_l16_config()
85 | config.patches.size = (32, 32)
86 | return config
87 |
88 |
89 | def get_h14_config():
90 |     """Returns the ViT-H/14 configuration."""
91 | config = ml_collections.ConfigDict()
92 | config.patches = ml_collections.ConfigDict({'size': (14, 14)})
93 | config.hidden_size = 1280
94 | config.transformer = ml_collections.ConfigDict()
95 | config.transformer.mlp_dim = 5120
96 | config.transformer.num_heads = 16
97 | config.transformer.num_layers = 32
98 | config.transformer.attention_dropout_rate = 0.0
99 | config.transformer.dropout_rate = 0.1
100 | config.classifier = 'token'
101 | config.representation_size = None
102 | return config
103 |
--------------------------------------------------------------------------------
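Editor's note: these configs are plain ml_collections.ConfigDict objects that the model builders read attribute-style. A stripped-down sketch of constructing and consuming one (only a few of the ViT-B/16 fields, purely illustrative):

import ml_collections

config = ml_collections.ConfigDict()
config.hidden_size = 768
config.transformer = ml_collections.ConfigDict()
config.transformer.num_heads = 12
config.transformer.num_layers = 12

print(config.hidden_size // config.transformer.num_heads)   # per-head dim: 64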
/src/models/vit_adapter/vit_moco.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | borrowed from https://github.com/facebookresearch/moco-v3/blob/main/vits.py
4 | """
5 | import math
6 | import torch
7 | import torch.nn as nn
8 | from functools import partial, reduce
9 | from operator import mul
10 |
11 | from timm.models.vision_transformer import VisionTransformer, _cfg
12 | # from timm.layers.helpers import to_2tuple
13 | from timm.models.layers.helpers import to_2tuple
14 | from timm.models.layers import PatchEmbed
15 |
16 | from .adapter_block import Pfeiffer_Block
17 | from ..vit_backbones.vit_moco import VisionTransformerMoCo
18 | from ...utils import logging
19 | logger = logging.get_logger("MOSA")
20 |
21 |
22 | class ADPT_VisionTransformerMoCo(VisionTransformerMoCo):
23 | def __init__(
24 | self,
25 | adapter_cfg,
26 | stop_grad_conv1=False,
27 | img_size=224,
28 | patch_size=16,
29 | in_chans=3,
30 | num_classes=1000,
31 | embed_dim=768,
32 | depth=12,
33 | num_heads=12,
34 | mlp_ratio=4.,
35 | qkv_bias=True,
36 | representation_size=None,
37 | distilled=False,
38 | drop_rate=0.,
39 | attn_drop_rate=0.,
40 | drop_path_rate=0.,
41 | embed_layer=PatchEmbed,
42 | norm_layer=None,
43 | act_layer=None,
44 | weight_init='',
45 | **kwargs):
46 |
47 | super(ADPT_VisionTransformerMoCo, self).__init__(
48 | stop_grad_conv1=stop_grad_conv1,
49 | img_size=img_size,
50 | patch_size=patch_size,
51 | in_chans=in_chans,
52 | num_classes=num_classes,
53 | embed_dim=embed_dim,
54 | depth=depth,
55 | num_heads=num_heads,
56 | mlp_ratio=mlp_ratio,
57 | qkv_bias=qkv_bias,
58 | representation_size=representation_size,
59 | distilled=distilled,
60 | drop_rate=drop_rate,
61 | attn_drop_rate=attn_drop_rate,
62 | drop_path_rate=drop_path_rate,
63 | embed_layer=embed_layer,
64 | norm_layer=norm_layer,
65 | act_layer=act_layer,
66 | weight_init=weight_init,
67 | **kwargs
68 | )
69 |
70 | self.adapter_cfg = adapter_cfg
71 |
72 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
73 | act_layer = act_layer or nn.GELU
74 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
75 |
76 | if adapter_cfg.STYLE == "Pfeiffer":
77 | self.blocks = nn.Sequential(*[
78 | Pfeiffer_Block(
79 | adapter_config=adapter_cfg,
80 | dim=embed_dim,
81 | num_heads=num_heads,
82 | mlp_ratio=mlp_ratio,
83 | qkv_bias=qkv_bias,
84 | drop=drop_rate,
85 | attn_drop=attn_drop_rate,
86 | drop_path=dpr[i],
87 | norm_layer=norm_layer,
88 | act_layer=act_layer) for i in range(depth)])
89 | else:
90 | raise ValueError("Other adapter styles are not supported.")
91 |
92 |
93 |
94 | def vit_base(adapter_cfg, **kwargs):
95 | model = ADPT_VisionTransformerMoCo(
96 | adapter_cfg,
97 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
98 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
99 | model.default_cfg = _cfg()
100 | return model
101 |
--------------------------------------------------------------------------------
/src/solver/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import math
3 |
4 | import torch.optim as optim
5 | from fvcore.common.config import CfgNode
6 | from torch.optim.lr_scheduler import LambdaLR
7 |
8 |
9 | def make_scheduler(
10 | optimizer: optim.Optimizer, train_params: CfgNode
11 | ) -> LambdaLR:
12 | warmup = train_params.WARMUP_EPOCH
13 | total_iters = train_params.TOTAL_EPOCH
14 |
15 | if train_params.SCHEDULER == "cosine":
16 | scheduler = WarmupCosineSchedule(
17 | optimizer,
18 | warmup_steps=warmup,
19 | t_total=total_iters
20 | )
21 | elif train_params.SCHEDULER == "cosine_hardrestart":
22 | scheduler = WarmupCosineWithHardRestartsSchedule(
23 | optimizer,
24 | warmup_steps=warmup,
25 | t_total=total_iters
26 | )
27 |
28 | elif train_params.SCHEDULER == "plateau":
29 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(
30 | optimizer,
31 | "max",
32 | patience=5,
33 | verbose=True,
34 | factor=train_params.LR_DECAY_FACTOR,
35 | )
36 | else:
37 | scheduler = None
38 | return scheduler
39 |
40 |
41 | class WarmupCosineSchedule(LambdaLR):
42 | """ Linear warmup and then cosine decay.
43 | Linearly increases learning rate from 0 to 1 over `warmup_steps`.
44 | Decreases learning rate from 1. to 0. over remaining
45 | `t_total - warmup_steps` steps following a cosine curve.
46 |         If `cycles` (default=0.5) differs from the default, the learning
47 |         rate follows that many cosine cycles after warmup.
48 | """
49 | def __init__(self, optimizer, warmup_steps, t_total, cycles=.5, last_epoch=-1):
50 | self.warmup_steps = warmup_steps
51 | self.t_total = t_total
52 | self.cycles = cycles
53 | super(WarmupCosineSchedule, self).__init__(
54 | optimizer, self.lr_lambda, last_epoch=last_epoch)
55 |
56 | def lr_lambda(self, step):
57 | if step < self.warmup_steps:
58 | return float(step) / float(max(1.0, self.warmup_steps))
59 | # progress after warmup
60 | progress = float(step - self.warmup_steps) / float(max(
61 | 1, self.t_total - self.warmup_steps))
62 | return max(
63 | 0.0,
64 | 0.5 * (1. + math.cos(math.pi * float(self.cycles) * 2.0 * progress))
65 | )
66 |
67 |
68 | class WarmupCosineWithHardRestartsSchedule(LambdaLR):
69 | """ Linear warmup and then cosine cycles with hard restarts.
70 | Linearly increases learning rate from 0 to 1 over `warmup_steps`.
71 |         If `cycles` (default=1.) differs from the default, the learning
72 |         rate follows `cycles` cosine decay cycles, separated by hard
73 |         restarts.
74 | """
75 | def __init__(self, optimizer, warmup_steps, t_total, cycles=1., last_epoch=-1):
76 | self.warmup_steps = warmup_steps
77 | self.t_total = t_total
78 | self.cycles = cycles
79 | super(WarmupCosineWithHardRestartsSchedule, self).__init__(
80 | optimizer, self.lr_lambda, last_epoch=last_epoch)
81 |
82 | def lr_lambda(self, step):
83 | if step < self.warmup_steps:
84 | return float(step) / float(max(1, self.warmup_steps))
85 | # progress after warmup
86 | progress = float(step - self.warmup_steps) / float(
87 | max(1, self.t_total - self.warmup_steps))
88 | if progress >= 1.0:
89 | return 0.0
90 | return max(
91 | 0.0,
92 | 0.5 * (1. + math.cos(
93 | math.pi * ((float(self.cycles) * progress) % 1.0)))
94 | )
95 |
--------------------------------------------------------------------------------
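Editor's note: both schedules above are LambdaLR wrappers: the lambda returns a multiplier on the base learning rate for the current epoch. A runnable sketch of the warmup-then-cosine multiplier attached to a throwaway optimizer (warmup and total lengths here are arbitrary):

import math

import torch
from torch.optim.lr_scheduler import LambdaLR

warmup_steps, t_total = 5, 20                    # arbitrary demo lengths

def lr_lambda(step):
    if step < warmup_steps:
        return step / max(1.0, warmup_steps)
    progress = (step - warmup_steps) / max(1, t_total - warmup_steps)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)
scheduler = LambdaLR(optimizer, lr_lambda)

for step in range(t_total):
    optimizer.step()
    scheduler.step()
    print(step, round(optimizer.param_groups[0]["lr"], 4))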
/src/data/vtab_datasets/oxford_iiit_pet.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved
3 | # Copyright 2019 Google LLC.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """Implements OxfordIIITPet data class."""
18 |
19 | from __future__ import absolute_import
20 | from __future__ import division
21 | from __future__ import print_function
22 |
23 | import tensorflow_datasets as tfds
24 |
25 | from . import base as base
26 | from .registry import Registry
27 | # This constant specifies the percentage of data that is used to create custom
28 | # train/val splits. Specifically, TRAIN_SPLIT_PERCENT% of the official training
29 | # split is used as a new training split and the rest is used for validation.
30 | TRAIN_SPLIT_PERCENT = 80
31 |
32 |
33 | @Registry.register("data.oxford_iiit_pet", "class")
34 | class OxfordIIITPetData(base.ImageTfdsData):
35 | """Provides OxfordIIITPet data.
36 |
37 | The OxfordIIITPet dataset comes only with a training and test set.
38 | Therefore, the validation set is split out of the original training set, and
39 | the remaining examples are used as the "train" split. The "trainval" split
40 | corresponds to the original training set.
41 |
42 | For additional details and usage, see the base class.
43 | """
44 |
45 | def __init__(self, data_dir=None, train_split_percent=None):
46 |
47 | dataset_builder = tfds.builder("oxford_iiit_pet:3.*.*", data_dir=data_dir)
48 | dataset_builder.download_and_prepare()
49 | train_split_percent = train_split_percent or TRAIN_SPLIT_PERCENT
50 |
51 | # Creates a dict with example counts for each split.
52 | trainval_count = dataset_builder.info.splits[tfds.Split.TRAIN].num_examples
53 | test_count = dataset_builder.info.splits[tfds.Split.TEST].num_examples
54 | num_samples_splits = {
55 | "train": (train_split_percent * trainval_count) // 100,
56 | "val": trainval_count - (train_split_percent * trainval_count) // 100,
57 | "trainval": trainval_count,
58 | "test": test_count,
59 | "train800": 800,
60 | "val200": 200,
61 | "train800val200": 1000,
62 | }
63 |
64 | # Defines dataset specific train/val/trainval/test splits.
65 | tfds_splits = {
66 | "train": "train[:{}]".format(num_samples_splits["train"]),
67 | "val": "train[{}:]".format(num_samples_splits["train"]),
68 | "trainval": tfds.Split.TRAIN,
69 | "test": tfds.Split.TEST,
70 | "train800": "train[:800]",
71 | "val200": "train[{}:{}]".format(
72 | num_samples_splits["train"], num_samples_splits["train"]+200),
73 | "train800val200": "train[:800]+train[{}:{}]".format(
74 | num_samples_splits["train"], num_samples_splits["train"]+200),
75 | }
76 |
77 | super(OxfordIIITPetData, self).__init__(
78 | dataset_builder=dataset_builder,
79 | tfds_splits=tfds_splits,
80 | num_samples_splits=num_samples_splits,
81 | num_preprocessing_threads=400,
82 | shuffle_buffer_size=10000,
83 | # Note: Export only image and label tensors with their original types.
84 | base_preprocess_fn=base.make_get_tensors_fn(["image", "label"]),
85 | num_classes=dataset_builder.info.features["label"].num_classes)
86 |
--------------------------------------------------------------------------------
/src/data/vtab_datasets/eurosat.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved
3 | # Copyright 2019 Google LLC.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """Implements EurosatData class."""
18 |
19 | from __future__ import absolute_import
20 | from __future__ import division
21 | from __future__ import print_function
22 |
23 | import tensorflow_datasets as tfds
24 |
25 | from . import base as base
26 | from .registry import Registry
27 |
28 | TRAIN_SPLIT_PERCENT = 60
29 | VALIDATION_SPLIT_PERCENT = 20
30 | TEST_SPLIT_PERCENT = 20
31 |
32 |
33 | @Registry.register("data.eurosat", "class")
34 | class EurosatData(base.ImageTfdsData):
35 | """Provides EuroSat dataset.
36 |
37 | EuroSAT dataset is based on Sentinel-2 satellite images covering 13 spectral
38 | bands and consisting of 10 classes with 27000 labeled and
39 | geo-referenced samples.
40 |
41 | URL: https://github.com/phelber/eurosat
42 | """
43 |
44 | def __init__(self, subset="rgb", data_key="image", data_dir=None):
45 | dataset_name = "eurosat/{}:2.*.*".format(subset)
46 | dataset_builder = tfds.builder(dataset_name, data_dir=data_dir)
47 | dataset_builder.download_and_prepare()
48 |
49 | # Example counts are retrieved from the tensorflow dataset info.
50 | num_examples = dataset_builder.info.splits[tfds.Split.TRAIN].num_examples
51 | train_count = num_examples * TRAIN_SPLIT_PERCENT // 100
52 | val_count = num_examples * VALIDATION_SPLIT_PERCENT // 100
53 | test_count = num_examples * TEST_SPLIT_PERCENT // 100
54 |
55 | tfds_splits = {
56 | "train":
57 | "train[:{}]".format(train_count),
58 | "val":
59 | "train[{}:{}]".format(train_count, train_count+val_count),
60 | "trainval":
61 | "train[:{}]".format(train_count+val_count),
62 | "test":
63 | "train[{}:]".format(train_count+val_count),
64 | "train800":
65 | "train[:800]",
66 | "val200":
67 | "train[{}:{}]".format(train_count, train_count+200),
68 | "train800val200":
69 | "train[:800]+train[{}:{}]".format(train_count, train_count+200),
70 | }
71 |
72 | # Creates a dict with example counts for each split.
73 | num_samples_splits = {
74 | "train": train_count,
75 | "val": val_count,
76 | "trainval": train_count + val_count,
77 | "test": test_count,
78 | "train800": 800,
79 | "val200": 200,
80 | "train800val200": 1000,
81 | }
82 |
83 | num_channels = 3
84 | if data_key == "sentinel2":
85 | num_channels = 13
86 |
87 | super(EurosatData, self).__init__(
88 | dataset_builder=dataset_builder,
89 | tfds_splits=tfds_splits,
90 | num_samples_splits=num_samples_splits,
91 | num_preprocessing_threads=100,
92 | shuffle_buffer_size=10000,
93 | base_preprocess_fn=base.make_get_and_cast_tensors_fn({
94 | data_key: ("image", None),
95 | "label": ("label", None),
96 | }),
97 | image_key=data_key,
98 | num_channels=num_channels,
99 | num_classes=dataset_builder.info.features["label"].num_classes)
100 |
--------------------------------------------------------------------------------
/src/data/vtab_datasets/cifar.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved
3 | # Copyright 2019 Google LLC.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """Implements Cifar data class."""
18 |
19 | from __future__ import absolute_import
20 | from __future__ import division
21 | from __future__ import print_function
22 | from . import base as base
23 | from .registry import Registry
24 | import tensorflow_datasets as tfds
25 |
26 | # This constant specifies the percentage of data that is used to create custom
27 | # train/val splits. Specifically, TRAIN_SPLIT_PERCENT% of the official training
28 | # split is used as a new training split and the rest is used for validation.
29 | TRAIN_SPLIT_PERCENT = 90
30 |
31 |
32 | @Registry.register("data.cifar", "class")
33 | class CifarData(base.ImageTfdsData):
34 | """Provides Cifar10 or Cifar100 data.
35 |
36 | Cifar comes only with a training and test set. Therefore, the validation set
37 | is split out of the original training set, and the remaining examples are used
38 | as the "train" split. The "trainval" split corresponds to the original
39 | training set.
40 |
41 | For additional details and usage, see the base class.
42 | """
43 |
44 | def __init__(self, num_classes=10, data_dir=None, train_split_percent=None):
45 |
46 | if num_classes == 10:
47 | dataset_builder = tfds.builder("cifar10:3.*.*", data_dir=data_dir)
48 | elif num_classes == 100:
49 | dataset_builder = tfds.builder("cifar100:3.*.*", data_dir=data_dir)
50 | else:
51 | raise ValueError(
52 | "Number of classes must be 10 or 100, got {}".format(num_classes))
53 |
54 | dataset_builder.download_and_prepare()
55 |
56 | train_split_percent = train_split_percent or TRAIN_SPLIT_PERCENT
57 |
58 | # Creates a dict with example counts for each split.
59 | trainval_count = dataset_builder.info.splits["train"].num_examples
60 | test_count = dataset_builder.info.splits["test"].num_examples
61 | num_samples_splits = {
62 | "train": (train_split_percent * trainval_count) // 100,
63 | "val": trainval_count - (train_split_percent * trainval_count) // 100,
64 | "trainval": trainval_count,
65 | "test": test_count,
66 | "train800": 800,
67 | "val200": 200,
68 | "train800val200": 1000,
69 | }
70 |
71 | # Defines dataset specific train/val/trainval/test splits.
72 | tfds_splits = {
73 | "train": "train[:{}]".format(num_samples_splits["train"]),
74 | "val": "train[{}:]".format(num_samples_splits["train"]),
75 | "trainval": "train",
76 | "test": "test",
77 | "train800": "train[:800]",
78 | "val200": "train[{}:{}]".format(
79 | num_samples_splits["train"], num_samples_splits["train"]+200),
80 | "train800val200": "train[:800]+train[{}:{}]".format(
81 | num_samples_splits["train"], num_samples_splits["train"]+200),
82 | }
83 |
84 | super(CifarData, self).__init__(
85 | dataset_builder=dataset_builder,
86 | tfds_splits=tfds_splits,
87 | num_samples_splits=num_samples_splits,
88 | num_preprocessing_threads=400,
89 | shuffle_buffer_size=10000,
90 | # Note: Export only image and label tensors with their original types.
91 | base_preprocess_fn=base.make_get_tensors_fn(["image", "label", "id"]),
92 | num_classes=dataset_builder.info.features["label"].num_classes)
93 |
--------------------------------------------------------------------------------
/src/data/vtab_datasets/smallnorb.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved
3 | # Copyright 2019 Google LLC.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """Implements the SmallNORB data class."""
18 |
19 | from __future__ import absolute_import
20 | from __future__ import division
21 | from __future__ import print_function
22 |
23 | import tensorflow.compat.v1 as tf
24 | import tensorflow_datasets as tfds
25 |
26 | from . import base as base
27 | from .registry import Registry
28 | # This constant specifies the percentage of data that is used to create custom
29 | # val/test splits. Specifically, VAL_SPLIT_PERCENT% of the official testing
30 | # split is used as a new validation split and the rest is used for testing.
31 | VAL_SPLIT_PERCENT = 50
32 |
33 |
34 | @Registry.register("data.smallnorb", "class")
35 | class SmallNORBData(base.ImageTfdsData):
36 | """Provides the SmallNORB data set.
37 |
38 | SmallNORB comes only with a training and test set. Therefore, the validation
39 | set is split out of the original training set, and the remaining examples are
40 | used as the "train" split. The "trainval" split corresponds to the original
41 | training set.
42 |
43 | For additional details and usage, see the base class.
44 |
45 | The data set page is https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/.
46 | """
47 |
48 | def __init__(self, predicted_attribute, data_dir=None):
49 | dataset_builder = tfds.builder("smallnorb:2.*.*", data_dir=data_dir)
50 | dataset_builder.download_and_prepare()
51 |
52 | if predicted_attribute not in dataset_builder.info.features:
53 | raise ValueError(
54 | "{} is not a valid attribute to predict.".format(predicted_attribute))
55 |
56 | # Defines dataset specific train/val/trainval/test splits.
57 | tfds_splits = {
58 | "train": "train",
59 | "val": "test[:{}%]".format(VAL_SPLIT_PERCENT),
60 | "trainval": "train+test[:{}%]".format(VAL_SPLIT_PERCENT),
61 | "test": "test[{}%:]".format(VAL_SPLIT_PERCENT),
62 | "train800": "train[:800]",
63 | "val200": "test[:200]",
64 | "train800val200": "train[:800]+test[:200]",
65 | }
66 |
67 | # Creates a dict with example counts for each split.
68 | train_count = dataset_builder.info.splits["train"].num_examples
69 | test_count = dataset_builder.info.splits["test"].num_examples
70 | num_samples_validation = VAL_SPLIT_PERCENT * test_count // 100
71 | num_samples_splits = {
72 | "train": train_count,
73 | "val": num_samples_validation,
74 | "trainval": train_count + num_samples_validation,
75 | "test": test_count - num_samples_validation,
76 | "train800": 800,
77 | "val200": 200,
78 | "train800val200": 1000,
79 | }
80 |
81 | def preprocess_fn(tensors):
82 | # For consistency with other datasets, image needs to have three channels.
83 | image = tf.tile(tensors["image"], [1, 1, 3])
84 | return dict(image=image, label=tensors[predicted_attribute])
85 |
86 | info = dataset_builder.info
87 | super(SmallNORBData, self).__init__(
88 | dataset_builder=dataset_builder,
89 | tfds_splits=tfds_splits,
90 | num_samples_splits=num_samples_splits,
91 | num_preprocessing_threads=400,
92 | shuffle_buffer_size=10000,
93 | # We extract the attribute we want to predict in the preprocessing.
94 | base_preprocess_fn=preprocess_fn,
95 | num_classes=info.features[predicted_attribute].num_classes)
96 |
--------------------------------------------------------------------------------
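Editor's note: SmallNORB images are single-channel, so preprocess_fn above tiles them to three channels to match the other datasets. A standalone illustration of that tiling on random data (assumes TensorFlow is installed; uses the TF2 API rather than the compat.v1 alias imported above):

import tensorflow as tf

gray = tf.random.uniform((96, 96, 1))    # (height, width, 1), as in SmallNORB
rgb = tf.tile(gray, [1, 1, 3])           # repeat the single channel three times
print(rgb.shape)                         # (96, 96, 3)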
/src/data/vtab_datasets/oxford_flowers102.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved
3 | # Copyright 2019 Google LLC.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """Implements oxford flowers 102 data class."""
18 |
19 | from __future__ import absolute_import
20 | from __future__ import division
21 | from __future__ import print_function
22 |
23 | import tensorflow_datasets as tfds
24 |
25 | from . import base as base
26 | from .registry import Registry
27 |
28 |
29 | @Registry.register("data.oxford_flowers102", "class")
30 | class OxfordFlowers102Data(base.ImageTfdsData):
31 | """Provides Oxford 102 categories flowers dataset.
32 |
33 | See corresponding tfds dataset for details.
34 |
35 | URL: https://www.robots.ox.ac.uk/~vgg/data/flowers/102/
36 | """
37 |
38 | def __init__(self, data_dir=None, train_split_percent=None):
39 | dataset_builder = tfds.builder("oxford_flowers102:2.*.*", data_dir=data_dir)
40 | dataset_builder.download_and_prepare()
41 |
42 | # Example counts are retrieved from the tensorflow dataset info.
43 | train_count = dataset_builder.info.splits[tfds.Split.TRAIN].num_examples
44 | val_count = dataset_builder.info.splits[tfds.Split.VALIDATION].num_examples
45 | test_count = dataset_builder.info.splits[tfds.Split.TEST].num_examples
46 |
47 | if train_split_percent:
48 | tfds_splits = {
49 | "train": "train[:{s}%]+validation[:{s}%]".format(
50 | s=train_split_percent),
51 | "val": "train[-{s}%:]+validation[-{s}%:]".format(
52 | s=train_split_percent),
53 | "trainval": "train+validation",
54 | "test": "test",
55 | "train800": "train[:800]",
56 | "val200": "validation[:200]",
57 | "train800val200": "train[:800]+validation[:200]",
58 | }
59 | num_samples_splits = {
60 | "train": (((train_count + val_count) // 100)
61 | * train_split_percent),
62 | "val": (((train_count + val_count) // 100) *
63 | (100 - train_split_percent)),
64 | "trainval": train_count + val_count,
65 | "test": test_count,
66 | "train800": 800,
67 | "val200": 200,
68 | "train800val200": 1000,
69 | }
70 | else:
71 | tfds_splits = {
72 | "train": "train",
73 | "val": "validation",
74 | "trainval": "train+validation",
75 | "test": "test",
76 | "train800": "train[:800]",
77 | "val200": "validation[:200]",
78 | "train800val200": "train[:800]+validation[:200]",
79 | }
80 | num_samples_splits = {
81 | "train": train_count,
82 | "val": val_count,
83 | "trainval": train_count + val_count,
84 | "test": test_count,
85 | "train800": 800,
86 | "val200": 200,
87 | "train800val200": 1000,
88 | }
89 |
90 | super(OxfordFlowers102Data, self).__init__(
91 | dataset_builder=dataset_builder,
92 | tfds_splits=tfds_splits,
93 | num_samples_splits=num_samples_splits,
94 | num_preprocessing_threads=400,
95 | shuffle_buffer_size=10000,
96 | # Note: Rename tensors but keep their original types.
97 | base_preprocess_fn=base.make_get_and_cast_tensors_fn({
98 | "image": ("image", None),
99 | "label": ("label", None),
100 | }),
101 | num_classes=dataset_builder.info.features["label"]
102 | .num_classes)
103 |
--------------------------------------------------------------------------------
/launch.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | launch helper functions
4 | """
5 | import argparse
6 | import os
7 | import sys
8 | import pprint
9 | import PIL
10 | from collections import defaultdict
11 | from tabulate import tabulate
12 | from typing import Tuple
13 |
14 | import torch
15 | from src.utils.file_io import PathManager
16 | from src.utils import logging
17 | from src.utils.distributed import get_rank, get_world_size
18 |
19 |
20 | def collect_torch_env() -> str:
21 | try:
22 | import torch.__config__
23 |
24 | return torch.__config__.show()
25 | except ImportError:
26 | # compatible with older versions of pytorch
27 | from torch.utils.collect_env import get_pretty_env_info
28 |
29 | return get_pretty_env_info()
30 |
31 |
32 | def get_env_module() -> Tuple[str, str]:
33 | var_name = "ENV_MODULE"
34 | return var_name, os.environ.get(var_name, "")
35 |
36 |
37 | def collect_env_info() -> str:
38 | data = []
39 | data.append(("Python", sys.version.replace("\n", "")))
40 | data.append(get_env_module())
41 | data.append(("PyTorch", torch.__version__))
42 | data.append(("PyTorch Debug Build", torch.version.debug))
43 |
44 | has_cuda = torch.cuda.is_available()
45 | data.append(("CUDA available", has_cuda))
46 | if has_cuda:
47 | data.append(("CUDA ID", os.environ["CUDA_VISIBLE_DEVICES"]))
48 | devices = defaultdict(list)
49 | for k in range(torch.cuda.device_count()):
50 | devices[torch.cuda.get_device_name(k)].append(str(k))
51 | for name, devids in devices.items():
52 | data.append(("GPU " + ",".join(devids), name))
53 | data.append(("Pillow", PIL.__version__))
54 |
55 | try:
56 | import cv2
57 |
58 | data.append(("cv2", cv2.__version__))
59 | except ImportError:
60 | pass
61 | env_str = tabulate(data) + "\n"
62 | env_str += collect_torch_env()
63 | return env_str
64 |
65 |
66 | def default_argument_parser():
67 | """
68 |     Create a simple argument parser that wraps the config file.
69 | """
70 | parser = argparse.ArgumentParser(description="visual-prompt")
71 | parser.add_argument(
72 | "--config-file", default="", metavar="FILE", help="path to config file")
73 | parser.add_argument(
74 | "--train-type", default="", help="training types")
75 | parser.add_argument(
76 | "--sparse-train", default=False, action="store_true", help="sparse training")
77 | parser.add_argument(
78 | "--use-wandb", action="store_true", help="use wandb to log")
79 | parser.add_argument(
80 | "opts",
81 | help="Modify config options using the command-line",
82 | default=None,
83 | nargs=argparse.REMAINDER,
84 | )
85 |
86 | return parser
87 |
88 |
89 | def logging_train_setup(args, cfg) -> None:
90 | output_dir = cfg.OUTPUT_DIR
91 | if output_dir:
92 | PathManager.mkdirs(output_dir)
93 |
94 | logger = logging.setup_logging(
95 | cfg.NUM_GPUS, get_world_size(), output_dir, name="MOSA")
96 |
97 | # Log basic information about environment, cmdline arguments, and config
98 | rank = get_rank()
99 | logger.info(
100 | f"Rank of current process: {rank}. World size: {get_world_size()}")
101 | logger.info("Environment info:\n" + collect_env_info())
102 |
103 | logger.info("Command line arguments: " + str(args))
104 | if hasattr(args, "config_file") and args.config_file != "":
105 | logger.info(
106 | "Contents of args.config_file={}:\n{}".format(
107 | args.config_file,
108 | PathManager.open(args.config_file, "r").read()
109 | )
110 | )
111 | # Show the config
112 | logger.info("Training with config:")
113 | logger.info(pprint.pformat(cfg))
114 | # cudnn benchmark has large overhead.
115 | # It shouldn't be used considering the small size of typical val set.
116 | if not (hasattr(args, "eval_only") and args.eval_only):
117 | torch.backends.cudnn.benchmark = cfg.CUDNN_BENCHMARK
118 |
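
A minimal sketch of how the helpers above are typically wired together in an entry point: parse the command line, build a config, then call `logging_train_setup`. The `setup` helper and the yacs-style `merge_from_file` / `merge_from_list` calls on `CfgNode` are assumptions here, not code shown in this file.

```python
# Hypothetical entry-point sketch; the training loop itself is omitted.
from src.configs.config import get_cfg
from launch import default_argument_parser, logging_train_setup


def setup(args):
    # Build a config from the defaults, the YAML file, and CLI overrides
    # (assumes the yacs-style CfgNode API).
    cfg = get_cfg()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    return cfg


if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    cfg = setup(args)
    logging_train_setup(args, cfg)  # logs env info, args, and the full config
```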
--------------------------------------------------------------------------------
/src/engine/evaluator.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import numpy as np
3 |
4 | from collections import defaultdict
5 | from typing import List, Union
6 |
7 | from .eval import multilabel
8 | from .eval import singlelabel
9 | from ..utils import logging
10 | logger = logging.get_logger("MOSA")
11 |
12 |
13 | class Evaluator():
14 | """
15 |     An evaluator with the following logic:
16 | 
17 |     1. find which eval module to use.
18 |     2. store the eval results and pretty-print them in the log file as well.
19 | """
20 |
21 | def __init__(
22 | self,
23 | ) -> None:
24 | self.results = defaultdict(dict)
25 | self.iteration = -1
26 | self.threshold_end = 0.5
27 |
28 | def update_iteration(self, iteration: int) -> None:
29 | """update iteration info"""
30 | self.iteration = iteration
31 |
32 | def update_result(self, metric: str, value: Union[float, dict]) -> None:
33 | if self.iteration > -1:
34 | key_name = "epoch_" + str(self.iteration)
35 | else:
36 | key_name = "final"
37 | if isinstance(value, float):
38 | self.results[key_name].update({metric: value})
39 | else:
40 | if metric in self.results[key_name]:
41 | self.results[key_name][metric].update(value)
42 | else:
43 | self.results[key_name].update({metric: value})
44 |
45 | def classify(self, probs, targets, test_data, multilabel=False):
46 | """
47 | Evaluate classification result.
48 | Args:
49 | probs: np.ndarray for num_data x num_class, predicted probabilities
50 | targets: np.ndarray for multilabel, list of integers for single label
51 |             test_data: name of the evaluation split, used to tag the logged results
52 | """
53 | if not targets:
54 | raise ValueError(
55 |                 "When evaluating classification, targets must be provided")
56 |
57 | if multilabel:
58 | self._eval_multilabel(probs, targets, test_data)
59 | else:
60 | self._eval_singlelabel(probs, targets, test_data)
61 |
62 | def _eval_singlelabel(
63 | self,
64 | scores: np.ndarray,
65 | targets: List[int],
66 | eval_type: str
67 | ) -> None:
68 | """
69 | if number of labels > 2:
70 | top1 and topk (5 by default) accuracy
71 | if number of labels == 2:
72 | top1 and rocauc
73 | """
74 | acc_dict = singlelabel.compute_acc_auc(scores, targets)
75 |
76 | log_results = {
77 | k: np.around(v * 100, decimals=2) for k, v in acc_dict.items()
78 | }
79 | save_results = acc_dict
80 |
81 | self.log_and_update(log_results, save_results, eval_type)
82 |
83 | def _eval_multilabel(
84 | self,
85 | scores: np.ndarray,
86 | targets: np.ndarray,
87 | eval_type: str
88 | ) -> None:
89 | num_labels = scores.shape[-1]
90 | targets = multilabel.multihot(targets, num_labels)
91 |
92 | log_results = {}
93 | ap, ar, mAP, mAR = multilabel.compute_map(scores, targets)
94 | f1_dict = multilabel.get_best_f1_scores(
95 | targets, scores, self.threshold_end)
96 |
97 | log_results["mAP"] = np.around(mAP * 100, decimals=2)
98 | log_results["mAR"] = np.around(mAR * 100, decimals=2)
99 | log_results.update({
100 | k: np.around(v * 100, decimals=2) for k, v in f1_dict.items()})
101 | save_results = {
102 | "ap": ap, "ar": ar, "mAP": mAP, "mAR": mAR, "f1": f1_dict
103 | }
104 | self.log_and_update(log_results, save_results, eval_type)
105 |
106 | def log_and_update(self, log_results, save_results, eval_type):
107 | log_str = ""
108 | for k, result in log_results.items():
109 | if not isinstance(result, np.ndarray):
110 | log_str += f"{k}: {result:.2f}\t"
111 | else:
112 | log_str += f"{k}: {list(result)}\t"
113 | logger.info(f"Classification results with {eval_type}: {log_str}")
114 | # save everything
115 | self.update_result("classification", {eval_type: save_results})
116 |
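
A small usage sketch of the `Evaluator` above on dummy single-label outputs (assuming it is imported from `src.engine.evaluator`); the random data and the `"val"` tag are purely illustrative, and the actual metric keys come from `singlelabel.compute_acc_auc`.

```python
# Illustrative usage of Evaluator with dummy single-label predictions.
import numpy as np
from src.engine.evaluator import Evaluator

evaluator = Evaluator()
evaluator.update_iteration(10)                   # results stored under "epoch_10"

probs = np.random.rand(8, 5)                     # 8 samples, 5 classes
targets = list(np.random.randint(0, 5, size=8))  # single-label integer targets
evaluator.classify(probs, targets, "val", multilabel=False)

print(evaluator.results["epoch_10"]["classification"]["val"])
```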
--------------------------------------------------------------------------------
/src/data/vtab_datasets/svhn.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved
3 | # Copyright 2019 Google LLC.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """Implements Svhn data class."""
18 |
19 | from __future__ import absolute_import
20 | from __future__ import division
21 | from __future__ import print_function
22 |
23 | import tensorflow_datasets as tfds
24 |
25 | from . import base as base
26 | from .registry import Registry
27 | # This constant specifies the percentage of data that is used to create custom
28 | # train/val splits. Specifically, TRAIN_SPLIT_PERCENT% of the official training
29 | # split is used as a new training split and the rest is used for validation.
30 | TRAIN_SPLIT_PERCENT = 90
31 |
32 |
33 | @Registry.register("data.svhn", "class")
34 | class SvhnData(base.ImageTfdsData):
35 | """Provides SVHN data.
36 |
37 | The Street View House Numbers (SVHN) Dataset is an image digit recognition
38 | dataset of over 600,000 color digit images coming from real world data.
39 | Split size:
40 | - Training set: 73,257 images
41 | - Testing set: 26,032 images
42 | - Extra training set: 531,131 images
43 | Following the common setup on SVHN, we only use the official training and
44 | testing data. Images are cropped to 32x32.
45 |
46 | URL: http://ufldl.stanford.edu/housenumbers/
47 | """
48 |
49 | def __init__(self, data_dir=None):
50 | dataset_builder = tfds.builder("svhn_cropped:3.*.*", data_dir=data_dir)
51 | dataset_builder.download_and_prepare()
52 |
53 | # Example counts are retrieved from the tensorflow dataset info.
54 | trainval_count = dataset_builder.info.splits[tfds.Split.TRAIN].num_examples
55 | test_count = dataset_builder.info.splits[tfds.Split.TEST].num_examples
56 |
57 | # Creates a dict with example counts for each split.
58 | num_samples_splits = {
59 | # Calculates the train/val split example count based on percent.
60 | "train": TRAIN_SPLIT_PERCENT * trainval_count // 100,
61 | "val": trainval_count - TRAIN_SPLIT_PERCENT * trainval_count // 100,
62 | "trainval": trainval_count,
63 | "test": test_count,
64 | "train800": 800,
65 | "val200": 200,
66 | "train800val200": 1000,
67 | }
68 |
69 | # Defines dataset specific train/val/trainval/test splits.
70 | # The validation set is split out of the original training set, and the
71 | # remaining examples are used as the "train" split. The "trainval" split
72 | # corresponds to the original training set.
73 | tfds_splits = {
74 | "train":
75 | "train[:{}]".format(num_samples_splits["train"]),
76 | "val":
77 | "train[{}:]".format(num_samples_splits["train"]),
78 | "trainval":
79 | "train",
80 | "test":
81 | "test",
82 | "train800":
83 | "train[:800]",
84 | "val200":
85 | "train[{}:{}]".format(num_samples_splits["train"],
86 | num_samples_splits["train"] + 200),
87 | "train800val200":
88 | "train[:800]+train[{}:{}]".format(
89 | num_samples_splits["train"], num_samples_splits["train"] + 200),
90 | }
91 |
92 | super(SvhnData, self).__init__(
93 | dataset_builder=dataset_builder,
94 | tfds_splits=tfds_splits,
95 | num_samples_splits=num_samples_splits,
96 | num_preprocessing_threads=400,
97 | shuffle_buffer_size=10000,
98 | # Note: Rename tensors but keep their original types.
99 | base_preprocess_fn=base.make_get_and_cast_tensors_fn({
100 | "image": ("image", None),
101 | "label": ("label", None),
102 | }),
103 | num_classes=dataset_builder.info.features["label"]
104 | .num_classes)
105 |
--------------------------------------------------------------------------------
/src/data/loader.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | """Data loader."""
4 | import torch
5 | from torch.utils.data.distributed import DistributedSampler
6 | from torch.utils.data.sampler import RandomSampler
7 |
8 | from ..utils import logging
9 | from .datasets.json_dataset import (
10 | CUB200Dataset, CarsDataset, DogsDataset, FlowersDataset, NabirdsDataset
11 | )
12 | from .datasets.pt_dataset import (
13 | CIFAR100Dataset, AircraftDataset, Food101Dataset
14 | )
15 |
16 | logger = logging.get_logger("MOSA")
17 | _TORCH_BASIC_DS = {
18 | "cifar100": CIFAR100Dataset,
19 | 'aircraft': AircraftDataset,
20 | "food101": Food101Dataset,
21 | }
22 | _DATASET_CATALOG = {
23 | "CUB": CUB200Dataset,
24 | "OxfordFlowers": FlowersDataset,
25 | "StanfordCars": CarsDataset,
26 | "StanfordDogs": DogsDataset,
27 | "nabirds": NabirdsDataset,
28 | }
29 |
30 |
31 | def _construct_loader(cfg, split, batch_size, shuffle, drop_last):
32 | """Constructs the data loader for the given dataset."""
33 | dataset_name = cfg.DATA.NAME
34 |
35 | # Construct the dataset
36 | if dataset_name.startswith("vtab-"):
37 | # import the tensorflow here only if needed
38 | from .datasets.tf_dataset import TFDataset
39 | dataset = TFDataset(cfg, split)
40 | # from .datasets.vtab_dataset import VTABDataset
41 | # dataset = VTABDataset(cfg, split)
42 |
43 | elif dataset_name == "multi":
44 | from .datasets.json_dataset import MultiDataset
45 | dataset = MultiDataset(cfg, split)
46 |
47 | else:
48 | if dataset_name in _TORCH_BASIC_DS.keys():
49 | dataset = _TORCH_BASIC_DS[dataset_name](cfg, split)
50 | elif dataset_name in _DATASET_CATALOG.keys():
51 | dataset = _DATASET_CATALOG[dataset_name](cfg, split)
52 | else:
53 | raise ValueError("Dataset '{}' not supported".format(dataset_name))
54 |
55 | # Create a sampler for multi-process training
56 | sampler = DistributedSampler(dataset) if cfg.NUM_GPUS > 1 else None
57 | # Create a loader
58 | loader = torch.utils.data.DataLoader(
59 | dataset,
60 | batch_size=batch_size,
61 | shuffle=(False if sampler else shuffle),
62 | sampler=sampler,
63 | num_workers=cfg.DATA.NUM_WORKERS,
64 | pin_memory=cfg.DATA.PIN_MEMORY,
65 | drop_last=drop_last,
66 | )
67 | return loader
68 |
69 |
70 | def construct_train_loader(cfg):
71 | """Train loader wrapper."""
72 | if cfg.NUM_GPUS > 1:
73 | drop_last = True
74 | else:
75 | drop_last = False
76 | return _construct_loader(
77 | cfg=cfg,
78 | split="train",
79 | batch_size=int(cfg.DATA.BATCH_SIZE / cfg.NUM_GPUS),
80 | shuffle=True,
81 | drop_last=drop_last,
82 | )
83 |
84 |
85 | def construct_trainval_loader(cfg):
86 | """Train loader wrapper."""
87 | if cfg.NUM_GPUS > 1:
88 | drop_last = True
89 | else:
90 | drop_last = False
91 | return _construct_loader(
92 | cfg=cfg,
93 | split="trainval",
94 | batch_size=int(cfg.DATA.BATCH_SIZE / cfg.NUM_GPUS),
95 | shuffle=True,
96 | drop_last=drop_last,
97 | )
98 |
99 |
100 | def construct_test_loader(cfg):
101 | """Test loader wrapper."""
102 | return _construct_loader(
103 | cfg=cfg,
104 | split="test",
105 | batch_size=int(cfg.DATA.BATCH_SIZE / cfg.NUM_GPUS),
106 | shuffle=False,
107 | drop_last=False,
108 | )
109 |
110 |
111 | def construct_val_loader(cfg, batch_size=None):
112 |     """Validation loader wrapper."""
113 |     if batch_size is None:
114 |         bs = int(cfg.DATA.BATCH_SIZE / cfg.NUM_GPUS)
115 |     else:
116 |         bs = batch_size
117 | return _construct_loader(
118 | cfg=cfg,
119 | split="val",
120 | batch_size=bs,
121 | shuffle=False,
122 | drop_last=False,
123 | )
124 |
125 |
126 | def shuffle(loader, cur_epoch):
127 | """"Shuffles the data."""
128 | assert isinstance(
129 | loader.sampler, (RandomSampler, DistributedSampler)
130 | ), "Sampler type '{}' not supported".format(type(loader.sampler))
131 | # RandomSampler handles shuffling automatically
132 | if isinstance(loader.sampler, DistributedSampler):
133 | # DistributedSampler shuffles data based on epoch
134 | loader.sampler.set_epoch(cur_epoch)
135 |
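
A sketch of how the loader wrappers above are typically used in a training loop (assuming they are imported from `src.data.loader` and `cfg` is a fully populated config); the per-batch format and the model/optimizer steps are omitted since they depend on the dataset class.

```python
# Sketch of a training loop over the loaders above; cfg is assumed populated.
from src.data.loader import construct_train_loader, construct_val_loader, shuffle

train_loader = construct_train_loader(cfg)
val_loader = construct_val_loader(cfg)

for epoch in range(cfg.SOLVER.TOTAL_EPOCH):
    # no-op for RandomSampler; re-seeds DistributedSampler per epoch
    shuffle(train_loader, epoch)
    for batch in train_loader:
        ...  # forward / backward / optimizer step
```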
--------------------------------------------------------------------------------
/src/models/vit_adapter/vit_mae.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | borrowed from https://github.com/facebookresearch/mae/blob/main/models_vit.py
4 | """
5 | from functools import partial
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 | from .adapter_block import Pfeiffer_Block
11 | from ..vit_backbones.vit_mae import VisionTransformer
12 | from timm.models.layers import PatchEmbed
13 | from ...utils import logging
14 | logger = logging.get_logger("MOSA")
15 |
16 |
17 | class ADPT_VisionTransformer(VisionTransformer):
18 | """ Vision Transformer with support for global average pooling
19 | """
20 | def __init__(
21 | self,
22 | adapter_cfg,
23 | img_size=224,
24 | patch_size=16,
25 | in_chans=3,
26 | num_classes=1000,
27 | embed_dim=768,
28 | depth=12,
29 | num_heads=12,
30 | mlp_ratio=4.,
31 | qkv_bias=True,
32 | representation_size=None,
33 | distilled=False,
34 | drop_rate=0.,
35 | attn_drop_rate=0.,
36 | drop_path_rate=0.,
37 | embed_layer=PatchEmbed,
38 | norm_layer=None,
39 | act_layer=None,
40 | weight_init='',
41 | **kwargs):
42 |
43 | super(ADPT_VisionTransformer, self).__init__(
44 | img_size=img_size,
45 | patch_size=patch_size,
46 | in_chans=in_chans,
47 | num_classes=num_classes,
48 | embed_dim=embed_dim,
49 | depth=depth,
50 | num_heads=num_heads,
51 | mlp_ratio=mlp_ratio,
52 | qkv_bias=qkv_bias,
53 | representation_size=representation_size,
54 | distilled=distilled,
55 | drop_rate=drop_rate,
56 | attn_drop_rate=attn_drop_rate,
57 | drop_path_rate=drop_path_rate,
58 | embed_layer=embed_layer,
59 | norm_layer=norm_layer,
60 | act_layer=act_layer,
61 | weight_init=weight_init,
62 | **kwargs
63 | )
64 |
65 | self.adapter_cfg = adapter_cfg
66 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
67 | act_layer = act_layer or nn.GELU
68 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
69 |
70 | if adapter_cfg.STYLE == "Pfeiffer":
71 | self.blocks = nn.Sequential(*[
72 | Pfeiffer_Block(
73 | adapter_config=adapter_cfg,
74 | dim=embed_dim,
75 | num_heads=num_heads,
76 | mlp_ratio=mlp_ratio,
77 | qkv_bias=qkv_bias,
78 | drop=drop_rate,
79 | attn_drop=attn_drop_rate,
80 | drop_path=dpr[i],
81 | norm_layer=norm_layer,
82 | act_layer=act_layer) for i in range(depth)])
83 | else:
84 | raise ValueError("Other adapter styles are not supported.")
85 |
86 |
87 |
88 | def build_model(model_type, adapter_cfg):
89 | if "vitb" in model_type:
90 | return vit_base_patch16(adapter_cfg)
91 | elif "vitl" in model_type:
92 | return vit_large_patch16(adapter_cfg)
93 | elif "vith" in model_type:
94 | return vit_huge_patch14(adapter_cfg)
95 |
96 |
97 | def vit_base_patch16(adapter_cfg, **kwargs):
98 | model = ADPT_VisionTransformer(
99 | adapter_cfg,
100 | drop_path_rate=0.1, global_pool=True, # using default settings for mae-finetune
101 | patch_size=16, embed_dim=768, depth=12, num_heads=12,
102 | mlp_ratio=4, qkv_bias=True,
103 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
104 | return model
105 |
106 |
107 | def vit_large_patch16(adapter_cfg, **kwargs):
108 | model = ADPT_VisionTransformer(
109 | adapter_cfg,
110 | drop_path_rate=0.1, global_pool=True, # using default settings for mae-finetune
111 | patch_size=16, embed_dim=1024, depth=24, num_heads=16,
112 | mlp_ratio=4, qkv_bias=True,
113 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
114 | return model
115 |
116 |
117 | def vit_huge_patch14(adapter_cfg, **kwargs):
118 | model = ADPT_VisionTransformer(
119 | adapter_cfg,
120 | drop_path_rate=0.1, global_pool=True, # using default settings for mae-finetune
121 | patch_size=14, embed_dim=1280, depth=32, num_heads=16,
122 | mlp_ratio=4, qkv_bias=True,
123 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
124 | return model
125 |
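
A sketch of building an adapter-augmented MAE ViT-B/16 from this module. The adapter config node mirrors `MODEL.ADAPTER` in `src/configs/config.py`; the `"mae_vitb16"` model-type string is a hypothetical example (it only needs to contain `"vitb"`), and loading pretrained MAE weights is assumed to happen elsewhere (e.g. `src/models/build_model.py`).

```python
# Sketch: construct the adapter-equipped ViT-B/16 and run a dummy forward pass.
import torch
from src.configs.config import get_cfg

cfg = get_cfg()
adapter_cfg = cfg.MODEL.ADAPTER   # BOTTLENECK_SIZE, EXPERT_NUM, etc.
adapter_cfg.STYLE = "Pfeiffer"    # this module only supports Pfeiffer blocks

model = build_model("mae_vitb16", adapter_cfg)  # hypothetical model_type string
out = model(torch.randn(1, 3, 224, 224))        # dummy image batch
```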
--------------------------------------------------------------------------------
/src/data/vtab_datasets/clevr.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved
3 | # Copyright 2019 Google LLC.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """Implements CLEVR data class."""
18 |
19 | from __future__ import absolute_import
20 | from __future__ import division
21 | from __future__ import print_function
22 |
23 | import numpy as np
24 | import tensorflow.compat.v1 as tf
25 | import tensorflow_datasets as tfds
26 |
27 | from . import base as base
28 | from .registry import Registry
29 |
30 | TRAIN_SPLIT_PERCENT = 90
31 |
32 |
33 | def _count_preprocess_fn(x):
34 | return {"image": x["image"],
35 | "label": tf.size(x["objects"]["size"]) - 3}
36 |
37 |
38 | def _count_cylinders_preprocess_fn(x):
39 | # Class distribution:
40 |
41 | num_cylinders = tf.reduce_sum(
42 | tf.cast(tf.equal(x["objects"]["shape"], 2), tf.int32))
43 | return {"image": x["image"], "label": num_cylinders}
44 |
45 |
46 | def _closest_object_preprocess_fn(x):
47 | dist = tf.reduce_min(x["objects"]["pixel_coords"][:, 2])
48 | # These thresholds are uniformly spaced and result in more or less balanced
49 | # distribution of classes, see the resulting histogram:
50 |
51 | thrs = np.array([0.0, 8.0, 8.5, 9.0, 9.5, 10.0, 100.0])
52 | label = tf.reduce_max(tf.where((thrs - dist) < 0))
53 | return {"image": x["image"],
54 | "label": label}
55 |
56 |
57 | _TASK_DICT = {
58 | "count_all": {
59 | "preprocess_fn": _count_preprocess_fn,
60 | "num_classes": 8
61 | },
62 | "count_cylinders": {
63 | "preprocess_fn": _count_cylinders_preprocess_fn,
64 | "num_classes": 11
65 | },
66 | "closest_object_distance": {
67 | "preprocess_fn": _closest_object_preprocess_fn,
68 | "num_classes": 6
69 | },
70 | }
71 |
72 |
73 | @Registry.register("data.clevr", "class")
74 | class CLEVRData(base.ImageTfdsData):
75 | """Provides CLEVR dataset.
76 |
77 |     Currently, three tasks are supported: predicting the number of all
78 |     objects, the number of cylinders, and the distance to the closest
79 |     object.
80 | """
81 |
82 | def __init__(self, task, data_dir=None):
83 |
84 | if task not in _TASK_DICT:
85 | raise ValueError("Unknown task: %s" % task)
86 |
87 | dataset_builder = tfds.builder("clevr:3.*.*", data_dir=data_dir)
88 | dataset_builder.download_and_prepare()
89 |
90 | # Creates a dict with example counts for each split.
91 | trainval_count = dataset_builder.info.splits[tfds.Split.TRAIN].num_examples
92 | test_count = dataset_builder.info.splits[tfds.Split.TEST].num_examples
93 | num_samples_splits = {
94 | "train": (TRAIN_SPLIT_PERCENT * trainval_count) // 100,
95 | "val": trainval_count - (TRAIN_SPLIT_PERCENT * trainval_count) // 100,
96 | "trainval": trainval_count,
97 | "test": test_count,
98 | "train800": 800,
99 | "val200": 200,
100 | "train800val200": 1000,
101 | }
102 |
103 | # Defines dataset specific train/val/trainval/test splits.
104 | tfds_splits = {
105 | "train": "train[:{}]".format(num_samples_splits["train"]),
106 | "val": "train[{}:]".format(num_samples_splits["train"]),
107 | "trainval": "train",
108 | "test": "validation",
109 | "train800": "train[:800]",
110 | "val200": "train[{}:{}]".format(
111 | num_samples_splits["train"], num_samples_splits["train"]+200),
112 | "train800val200": "train[:800]+train[{}:{}]".format(
113 | num_samples_splits["train"], num_samples_splits["train"]+200),
114 | }
115 |
116 | task = _TASK_DICT[task]
117 | base_preprocess_fn = task["preprocess_fn"]
118 |
119 | super(CLEVRData, self).__init__(
120 | dataset_builder=dataset_builder,
121 | tfds_splits=tfds_splits,
122 | num_samples_splits=num_samples_splits,
123 | num_preprocessing_threads=400,
124 | shuffle_buffer_size=10000,
125 | # Note: Export only image and label tensors with their original types.
126 | base_preprocess_fn=base_preprocess_fn,
127 | num_classes=task["num_classes"])
128 |
--------------------------------------------------------------------------------
/src/utils/pruner.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | Pruner class
4 | """
5 | import torch
6 |
7 |
8 | class Pruner:
9 | def __init__(self, cfg):
10 | self.cfg = cfg
11 | if cfg.MODEL.TRANSFER_TYPE == "adapter" or cfg.MODEL.TRANSFER_TYPE == "mosa":
12 | self.num = cfg.MODEL.ADAPTER.EXPERT_NUM
13 | elif cfg.MODEL.TRANSFER_TYPE == "lora" or cfg.MODEL.TRANSFER_TYPE == "mosl":
14 | self.num = cfg.MODEL.LORA.EXPERT_NUM
15 | self.sparsity = 1 / self.num
16 |
17 | def score(self, param):
18 | raise NotImplementedError
19 |
20 | def prune(self, score):
21 | k = int((1.0 - self.sparsity) * score.numel())
22 | threshold, _ = torch.kthvalue(torch.flatten(score), k)
23 |
24 | zero = torch.LongTensor([0]).to(score.device)
25 | one = torch.LongTensor([1]).to(score.device)
26 | return torch.where(score <= threshold, zero, one)
27 |
28 | # def divide(self, score):
29 | # masks = []
30 | # zero = torch.LongTensor([0]).to(score.device)
31 | # one = torch.LongTensor([1]).to(score.device)
32 |
33 | # for i in range(self.num):
34 | # mask = torch.ones_like(score, dtype=torch.long)
35 |
36 | # lower_k = int((self.sparsity * i) * score.numel()) + 1
37 | # lower_threshold, _ = torch.kthvalue(torch.flatten(score), lower_k)
38 | # mask = torch.where(score < lower_threshold, zero, mask)
39 |
40 | # upper_k = int((self.sparsity * (i + 1)) * score.numel())
41 | # upper_threshold, _ = torch.kthvalue(torch.flatten(score), upper_k)
42 | # mask = torch.where(score > upper_threshold, zero, mask)
43 |
44 | # masks.append(mask)
45 |
46 | # return masks
47 |
48 | def divide(self, score, mode="all"):
49 | masks = []
50 | zero = torch.LongTensor([0]).to(score.device)
51 | one = torch.LongTensor([1]).to(score.device)
52 |
53 | for i in range(self.num):
54 | mask = torch.ones_like(score, dtype=torch.long)
55 | if self.cfg.MODEL.ADAPTER.BIAS and len(score.shape) == 1:
56 | masks.append(mask)
57 | continue
58 |
59 | if mode == "all":
60 | lower_k = int((self.sparsity * i) * score.numel()) + 1
61 | lower_threshold, _ = torch.kthvalue(torch.flatten(score), lower_k)
62 | mask = torch.where(score < lower_threshold, zero, mask)
63 |
64 | upper_k = int((self.sparsity * (i + 1)) * score.numel())
65 | upper_threshold, _ = torch.kthvalue(torch.flatten(score), upper_k)
66 | mask = torch.where(score > upper_threshold, zero, mask)
67 |
68 | elif mode == "row":
69 | r, c = score.shape
70 | flag = torch.zeros_like(score, dtype=torch.long)
71 | flag += torch.arange(r, device=score.device).view(r, 1)
72 |
73 | for j in range(r):
74 | lower_k = int((self.sparsity * i) * score[j].numel()) + 1
75 | lower_threshold, _ = torch.kthvalue(torch.flatten(score[j]), lower_k)
76 | mask = torch.where((score < lower_threshold) & (flag == j), zero, mask)
77 |
78 | upper_k = int((self.sparsity * (i + 1)) * score[j].numel())
79 | upper_threshold, _ = torch.kthvalue(torch.flatten(score[j]), upper_k)
80 | mask = torch.where((score > upper_threshold) & (flag == j), zero, mask)
81 |
82 | elif mode == "column":
83 | r, c = score.shape
84 | flag = torch.zeros_like(score, dtype=torch.long)
85 | flag += torch.arange(c, device=score.device)
86 |
87 | for j in range(c):
88 | lower_k = int((self.sparsity * i) * score[:, j].numel()) + 1
89 | lower_threshold, _ = torch.kthvalue(torch.flatten(score[:, j]), lower_k)
90 | mask = torch.where((score < lower_threshold) & (flag == j), zero, mask)
91 |
92 | upper_k = int((self.sparsity * (i + 1)) * score[:, j].numel())
93 | upper_threshold, _ = torch.kthvalue(torch.flatten(score[:, j]), upper_k)
94 | mask = torch.where((score > upper_threshold) & (flag == j), zero, mask)
95 |
96 | masks.append(mask)
97 |
98 | return masks
99 |
100 | class Rand(Pruner):
101 | def __init__(self, cfg):
102 | super(Rand, self).__init__(cfg)
103 |
104 | def score(self, param):
105 | return torch.rand_like(param)
106 |
107 |
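
A usage sketch of the partition logic above: with `EXPERT_NUM` experts, `divide` in `"all"` mode assigns each weight entry to exactly one expert by ranking the scores, so the masks are disjoint and together cover the tensor (assuming no tied scores). The config fields follow `src/configs/config.py`; the shapes are illustrative.

```python
# Sketch: split one adapter weight into disjoint per-expert masks with Rand.
import torch
from src.configs.config import get_cfg
from src.utils.pruner import Rand

cfg = get_cfg()
cfg.MODEL.TRANSFER_TYPE = "mosa"
cfg.MODEL.ADAPTER.EXPERT_NUM = 4      # sparsity = 1 / 4 per expert

pruner = Rand(cfg)                    # random scores -> random partition
weight = torch.randn(64, 768)
masks = pruner.divide(pruner.score(weight), mode="all")

# each entry is kept by exactly one expert (ties are virtually impossible here)
assert torch.stack(masks).sum(dim=0).eq(1).all()
```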
--------------------------------------------------------------------------------
/src/configs/config.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | """Config system (based on Detectron's)."""
4 |
5 | from .config_node import CfgNode
6 |
7 |
8 | # Global config object
9 | _C = CfgNode()
10 | # Example usage:
11 | # from configs.config import cfg
12 |
13 | _C.DBG = False
14 | _C.GPU_ID = 0
15 | _C.OUTPUT_DIR = "./output"
16 | _C.RUN_N_TIMES = 5
17 | # Perform benchmarking to select the fastest CUDNN algorithms to use
18 | # Note that this may increase the memory usage and will likely not result
19 | # in overall speedups when variable size inputs are used (e.g. COCO training)
20 | _C.CUDNN_BENCHMARK = False
21 |
22 | # Number of GPUs to use (applies to both training and testing)
23 | _C.NUM_GPUS = 1
24 | _C.NUM_SHARDS = 1
25 |
26 | # Note that non-determinism may still be present due to non-deterministic
27 | # operator implementations in GPU operator libraries
28 | _C.SEED = None
29 |
30 | # ----------------------------------------------------------------------
31 | # Model options
32 | # ----------------------------------------------------------------------
33 | _C.MODEL = CfgNode()
34 | _C.MODEL.TRANSFER_TYPE = "linear" # one of linear, end2end, prompt, adapter, mosa, lora, mosl, side, partial-1, tinytl-bias
35 | _C.MODEL.WEIGHT_PATH = "" # if resume from some checkpoint file
36 | _C.MODEL.SAVE_CKPT = False
37 |
38 | _C.MODEL.MODEL_ROOT = "" # root folder for pretrained model weights
39 |
40 | _C.MODEL.TYPE = "vit"
41 | _C.MODEL.MLP_NUM = 0
42 |
43 | _C.MODEL.LINEAR = CfgNode()
44 | _C.MODEL.LINEAR.MLP_SIZES = []
45 | _C.MODEL.LINEAR.DROPOUT = 0.1
46 |
47 | # ----------------------------------------------------------------------
48 | # Pruner options
49 | # ----------------------------------------------------------------------
50 | _C.PRUNER = CfgNode()
51 |
52 | _C.PRUNER.TYPE = "random"
53 | _C.PRUNER.NUM = 4
54 |
55 | # ----------------------------------------------------------------------
56 | # Adapter options
57 | # ----------------------------------------------------------------------
58 | _C.MODEL.ADAPTER = CfgNode()
59 | _C.MODEL.ADAPTER.BOTTLENECK_SIZE = 64
60 | _C.MODEL.ADAPTER.STYLE = "AdaptFormer"
61 | _C.MODEL.ADAPTER.SCALAR = 0.1
62 | _C.MODEL.ADAPTER.DROPOUT = 0.1
63 | _C.MODEL.ADAPTER.MOE = False
64 | _C.MODEL.ADAPTER.BIAS = True
65 | _C.MODEL.ADAPTER.EXPERT_NUM = 4
66 | _C.MODEL.ADAPTER.MERGE = "add"
67 | _C.MODEL.ADAPTER.SHARE = None
68 | _C.MODEL.ADAPTER.ADDITIONAL = False
69 | _C.MODEL.ADAPTER.DEEPREG = False
70 | _C.MODEL.ADAPTER.ADD_WEIGHT = 0.0
71 | _C.MODEL.ADAPTER.REG_WEIGHT = 1.0
72 |
73 | # ----------------------------------------------------------------------
74 | # LoRA options
75 | # ----------------------------------------------------------------------
76 | _C.MODEL.LORA = CfgNode()
77 | _C.MODEL.LORA.RANK = 16
78 | _C.MODEL.LORA.MODE = "qv"
79 | _C.MODEL.LORA.SCALAR = 1.0
80 | _C.MODEL.LORA.DROPOUT = 0.0
81 | _C.MODEL.LORA.MOE = False
82 | _C.MODEL.LORA.BIAS = False
83 | _C.MODEL.LORA.EXPERT_NUM = 4
84 | _C.MODEL.LORA.MERGE = "add"
85 | _C.MODEL.LORA.SHARE = None
86 | _C.MODEL.LORA.ADDITIONAL = False
87 | _C.MODEL.LORA.DEEPREG = False
88 | _C.MODEL.LORA.ADD_WEIGHT = 0.0
89 | _C.MODEL.LORA.REG_WEIGHT = 1.0
90 |
91 | # ----------------------------------------------------------------------
92 | # Solver options
93 | # ----------------------------------------------------------------------
94 | _C.SOLVER = CfgNode()
95 | _C.SOLVER.LOSS = "softmax"
96 | _C.SOLVER.LOSS_ALPHA = 0.01
97 |
98 | _C.SOLVER.OPTIMIZER = "sgd" # or "adamw"
99 | _C.SOLVER.MOMENTUM = 0.9
100 | _C.SOLVER.WEIGHT_DECAY = 0.0001
101 | _C.SOLVER.WEIGHT_DECAY_BIAS = 0
102 |
103 | _C.SOLVER.PATIENCE = 300
104 |
105 |
106 | _C.SOLVER.SCHEDULER = "cosine"
107 |
108 | _C.SOLVER.BASE_LR = 0.01
109 | _C.SOLVER.BIAS_MULTIPLIER = 1. # for prompt + bias
110 |
111 | _C.SOLVER.WARMUP_EPOCH = 5
112 | _C.SOLVER.MERGE_EPOCH = 5
113 | _C.SOLVER.TOTAL_EPOCH = 30
114 | _C.SOLVER.LOG_EVERY_N = 1000
115 |
116 |
117 | _C.SOLVER.DBG_TRAINABLE = False # if True, will print the name of trainable params
118 |
119 | # ----------------------------------------------------------------------
120 | # Dataset options
121 | # ----------------------------------------------------------------------
122 | _C.DATA = CfgNode()
123 |
124 | _C.DATA.NAME = ""
125 | _C.DATA.DATAPATH = ""
126 | _C.DATA.FEATURE = "" # e.g. inat2021_supervised
127 |
128 | _C.DATA.PERCENTAGE = 1.0
129 | _C.DATA.NUMBER_CLASSES = -1
130 | _C.DATA.MULTILABEL = False
131 | _C.DATA.CLASS_WEIGHTS_TYPE = "none"
132 |
133 | _C.DATA.CROPSIZE = 224 # or 384
134 |
135 | _C.DATA.NO_TEST = False
136 | _C.DATA.BATCH_SIZE = 32
137 | # Number of data loader workers per training process
138 | _C.DATA.NUM_WORKERS = 4
139 | # Load data to pinned host memory
140 | _C.DATA.PIN_MEMORY = True
141 |
142 | _C.DIST_BACKEND = "nccl"
143 | _C.DIST_INIT_PATH = "env://"
144 | _C.DIST_INIT_FILE = ""
145 |
146 |
147 | def get_cfg():
148 | """
149 | Get a copy of the default config.
150 | """
151 | return _C.clone()
152 |
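
A minimal sketch of the usual override flow on top of `get_cfg()`. It assumes the yacs-style `merge_from_file` / `merge_from_list` / `freeze` API that Detectron-style `CfgNode` implementations typically expose (the actual `config_node.py` is not shown here), and the YAML path is one of the configs shipped in this repo.

```python
# Minimal sketch, assuming the yacs-style CfgNode API.
from src.configs.config import get_cfg

cfg = get_cfg()
cfg.merge_from_file("configs/finetune/cub.yaml")        # YAML overrides
cfg.merge_from_list(["MODEL.TRANSFER_TYPE", "mosa",     # CLI-style overrides
                     "MODEL.ADAPTER.EXPERT_NUM", "4"])
cfg.freeze()
print(cfg.MODEL.TRANSFER_TYPE, cfg.SOLVER.BASE_LR)
```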
--------------------------------------------------------------------------------
/VTAB_SETUP.md:
--------------------------------------------------------------------------------
1 | # VTAB Preparation
2 |
3 | ## Download and prepare
4 |
5 | It is recommended to download the data before running the experiments, to avoid duplicated effort when submitting experiments for multiple tuning protocols. Here are the collected commands to set up the VTAB data.
6 |
7 | ```python
8 | import tensorflow_datasets as tfds
9 | data_dir = "" # TODO: set the data_dir where the data will be stored, i.e., the DATA.DATAPATH value in the config
10 |
11 | # caltech101
12 | dataset_builder = tfds.builder("caltech101:3.*.*", data_dir=data_dir)
13 | dataset_builder.download_and_prepare()
14 |
15 | # cifar100
16 | dataset_builder = tfds.builder("cifar100:3.*.*", data_dir=data_dir)
17 | dataset_builder.download_and_prepare()
18 |
19 | # clevr
20 | dataset_builder = tfds.builder("clevr:3.*.*", data_dir=data_dir)
21 | dataset_builder.download_and_prepare()
22 |
23 | # dmlab
24 | dataset_builder = tfds.builder("dmlab:2.0.1", data_dir=data_dir)
25 | dataset_builder.download_and_prepare()
26 |
27 | # dsprites
28 | dataset_builder = tfds.builder("dsprites:2.*.*", data_dir=data_dir)
29 | dataset_builder.download_and_prepare()
30 |
31 | # dtd
32 | dataset_builder = tfds.builder("dtd:3.*.*", data_dir=data_dir)
33 | dataset_builder.download_and_prepare()
34 |
35 | # eurosat
36 | subset="rgb"
37 | dataset_name = "eurosat/{}:2.*.*".format(subset)
38 | dataset_builder = tfds.builder(dataset_name, data_dir=data_dir)
39 | dataset_builder.download_and_prepare()
40 |
41 | # oxford_flowers102
42 | dataset_builder = tfds.builder("oxford_flowers102:2.*.*", data_dir=data_dir)
43 | dataset_builder.download_and_prepare()
44 |
45 | # oxford_iiit_pet
46 | dataset_builder = tfds.builder("oxford_iiit_pet:3.*.*", data_dir=data_dir)
47 | dataset_builder.download_and_prepare()
48 |
49 | # patch_camelyon
50 | dataset_builder = tfds.builder("patch_camelyon:2.*.*", data_dir=data_dir)
51 | dataset_builder.download_and_prepare()
52 |
53 | # smallnorb
54 | dataset_builder = tfds.builder("smallnorb:2.*.*", data_dir=data_dir)
55 | dataset_builder.download_and_prepare()
56 |
57 | # svhn
58 | dataset_builder = tfds.builder("svhn_cropped:3.*.*", data_dir=data_dir)
59 | dataset_builder.download_and_prepare()
60 | ```
61 |
62 | There are 4 datasets that need special care:
63 |
64 | ```python
65 | # sun397 --> need cv2
66 | # cannot load one image, similar to issue here: https://github.com/tensorflow/datasets/issues/2889
67 | # "Image /t/track/outdoor/sun_aophkoiosslinihb.jpg could not be decoded by Tensorflow."
68 | # solution: modify the file "/fsx/menglin/conda/envs/prompt_tf/lib/python3.7/site-packages/tensorflow_datasets/image_classification/sun.py" to ignore those images
69 | dataset_builder = tfds.builder("sun397/tfds:4.*.*", data_dir=data_dir)
70 | dataset_builder.download_and_prepare()
71 |
72 | # the kitti version in the vtab repo is wrong, use 3.2.0 instead (https://github.com/google-research/task_adaptation/issues/18)
73 | dataset_builder = tfds.builder("kitti:3.2.0", data_dir=data_dir)
74 | dataset_builder.download_and_prepare()
75 |
76 |
77 | # diabetic_retinopathy
78 | """
79 | Download this dataset from Kaggle.
80 | https://www.kaggle.com/c/diabetic-retinopathy-detection/data
81 | After downloading,
82 | - unpack the test.zip file into /manual_dir/.
83 | - unpack the sample.zip to sample/.
84 | - unpack the sampleSubmissions.csv and trainLabels.csv.
85 |
86 | # ==== important! ====
87 | # 1. make sure there are 5 train.zip parts instead of 4 (if you download all files from Kaggle at once, the train.zip.005 file is somehow missing)
88 | # 2. if unzipping train.zip runs into issues, try `jar xvf train.zip` to handle the huge zip file
89 | cat test.zip.* > test.zip
90 | cat train.zip.* > train.zip
91 | """
92 |
93 | config_and_version = "btgraham-300" + ":3.*.*"
94 | dataset_builder = tfds.builder("diabetic_retinopathy_detection/{}".format(config_and_version), data_dir=data_dir)
95 | dataset_builder.download_and_prepare()
96 |
97 |
98 | # resisc45
99 | """
100 | download/extract dataset artifacts manually:
101 | Dataset can be downloaded from OneDrive: https://1drv.ms/u/s!AmgKYzARBl5ca3HNaHIlzp_IXjs
102 | After downloading the rar file, please extract it to the manual_dir.
103 | """
104 |
105 | dataset_builder = tfds.builder("resisc45:3.*.*", data_dir=data_dir)
106 | dataset_builder.download_and_prepare()
107 | ```
108 |
109 | ## Notes
110 |
111 | ### TFDS version
112 |
113 | Note that the experimental results may differ across API and/or dataset generation code versions. See the [tfds documentation](https://www.tensorflow.org/datasets/datasets_versioning) for details. Here is what we used for MoSA:
114 |
115 | ```bash
116 | tfds: 4.4.0+nightly
117 |
118 | # Natural:
119 | cifar100: 3.0.2
120 | caltech101: 3.0.1
121 | dtd: 3.0.1
122 | oxford_flowers102: 2.1.1
123 | oxford_iiit_pet: 3.2.0
124 | svhn_cropped: 3.0.0
125 | sun397: 4.0.0
126 |
127 | # Specialized:
128 | patch_camelyon: 2.0.0
129 | eurosat: 2.0.0
130 | resisc45: 3.0.0
131 | diabetic_retinopathy_detection: 3.0.0
132 |
133 |
134 | # Structured
135 | clevr: 3.1.0
136 | dmlab: 2.0.1
137 | kitti: 3.2.0
138 | dsprites: 2.0.0
139 | smallnorb: 2.0.0
140 | ```
141 |
--------------------------------------------------------------------------------
/src/engine/eval/multilabel.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | multilabel evaluation metrics: mAP/mAR and best F1 scores over thresholds
4 | """
5 | import numpy as np
6 | from typing import List, Tuple, Dict
7 | from sklearn.metrics import (
8 | precision_recall_curve,
9 | average_precision_score,
10 | f1_score
11 | )
12 |
13 |
14 | def get_continuous_ids(probe_labels: List[int]) -> Dict[int, int]:
15 |     probe_labels = sorted(probe_labels)
16 | id2continuousid = {}
17 | for idx, p_id in enumerate(probe_labels):
18 | id2continuousid[p_id] = idx
19 | return id2continuousid
20 |
21 |
22 | def multihot(x: List[List[int]], nb_classes: int) -> np.ndarray:
23 | """transform to multihot encoding
24 |
25 | Arguments:
26 | x: list of multi-class integer labels, in the range
27 | [0, nb_classes-1]
28 | nb_classes: number of classes for the multi-hot vector
29 |
30 | Returns:
31 | multihot: multihot vector of type int, (num_samples, nb_classes)
32 | """
33 | num_samples = len(x)
34 |
35 | multihot = np.zeros((num_samples, nb_classes), dtype=np.int32)
36 | for idx, labs in enumerate(x):
37 | for lab in labs:
38 | multihot[idx, lab] = 1
39 |
40 |     return multihot.astype(np.int32)
41 |
42 |
43 | def compute_map(
44 | scores: np.ndarray, multihot_targets: np.ndarray
45 | ) -> Tuple[np.ndarray, np.ndarray, float, float]:
46 | """Compute the mean average precision across all class labels.
47 |
48 | Arguments:
49 | scores: matrix of per-class distances,
50 | of size num_samples x nb_classes
51 | multihot_targets: matrix of multi-hot target predictions,
52 | of size num_samples x nb_classes
53 |
54 | Returns:
55 | ap: list of average-precision scores, one for each of
56 | the nb_classes classes.
57 | ar: list of average-recall scores, one for each of
58 | the nb_classes classes.
59 | mAP: the mean average precision score over all average
60 | precisions for all nb_classes classes.
61 | mAR: the mean average recall score over all average
62 |             recalls for all nb_classes classes.
63 | """
64 | nb_classes = scores.shape[1]
65 |
66 | ap = np.zeros((nb_classes,), dtype=np.float32)
67 | ar = np.zeros((nb_classes,), dtype=np.float32)
68 |
69 | for c in range(nb_classes):
70 | y_true = multihot_targets[:, c]
71 | y_scores = scores[:, c]
72 |
73 |         # Use average precision (a la PASCAL VOC)
74 | try:
75 | ap[c] = average_precision_score(y_true, y_scores)
76 | except ValueError:
77 | ap[c] = -1
78 |
79 | # Also get the average of the recalls on the raw PR-curve
80 | try:
81 | _, rec, _ = precision_recall_curve(y_true, y_scores)
82 | ar[c] = rec.mean()
83 | except ValueError:
84 | ar[c] = -1
85 |
86 | mAP = ap.mean()
87 | mAR = ar.mean()
88 |
89 | return ap, ar, mAP, mAR
90 |
91 |
92 | def compute_f1(
93 | multihot_targets: np.ndarray, scores: np.ndarray, threshold: float = 0.5
94 | ) -> Tuple[float, float, float, np.ndarray]:
95 | # change scores to predict_labels
96 | predict_labels = scores > threshold
97 |     predict_labels = predict_labels.astype(np.int32)
98 |
99 | # change targets to multihot
100 | f1 = {}
101 | f1["micro"] = f1_score(
102 | y_true=multihot_targets,
103 | y_pred=predict_labels,
104 | average="micro"
105 | )
106 | f1["samples"] = f1_score(
107 | y_true=multihot_targets,
108 | y_pred=predict_labels,
109 | average="samples"
110 | )
111 | f1["macro"] = f1_score(
112 | y_true=multihot_targets,
113 | y_pred=predict_labels,
114 | average="macro"
115 | )
116 | f1["none"] = f1_score(
117 | y_true=multihot_targets,
118 | y_pred=predict_labels,
119 | average=None
120 | )
121 | return f1["micro"], f1["samples"], f1["macro"], f1["none"]
122 |
123 |
124 | def get_best_f1_scores(
125 |     multihot_targets: np.ndarray, scores: np.ndarray, threshold_end: float
126 | ) -> Dict[str, float]:
127 |     # sweep thresholds from `threshold_end` to 0.95 in steps of 0.05
128 |     # and keep the threshold with the best macro F1
129 |     end = threshold_end
130 | thrs = np.linspace(
131 | end, 0.95, int(np.round((0.95 - end) / 0.05)) + 1, endpoint=True
132 | )
133 | f1_micros = []
134 | f1_macros = []
135 | f1_samples = []
136 | f1_none = []
137 | for thr in thrs:
138 | _micros, _samples, _macros, _none = compute_f1(multihot_targets, scores, thr)
139 | f1_micros.append(_micros)
140 | f1_samples.append(_samples)
141 | f1_macros.append(_macros)
142 | f1_none.append(_none)
143 |
144 | f1_macros_m = max(f1_macros)
145 | b_thr = np.argmax(f1_macros)
146 |
147 | f1_micros_m = f1_micros[b_thr]
148 | f1_samples_m = f1_samples[b_thr]
149 | f1_none_m = f1_none[b_thr]
150 | f1 = {}
151 | f1["micro"] = f1_micros_m
152 | f1["macro"] = f1_macros_m
153 | f1["samples"] = f1_samples_m
154 | f1["threshold"] = thrs[b_thr]
155 | f1["none"] = f1_none_m
156 | return f1
157 |
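
A small illustrative run of the functions above on dummy multilabel data (assuming they are imported from `src.engine.eval.multilabel`); the sample counts, class count, and scores are arbitrary, so the printed values are meaningless beyond showing the call pattern.

```python
# Illustrative mAP and F1-threshold-sweep run on dummy multilabel data.
import numpy as np
from src.engine.eval.multilabel import multihot, compute_map, get_best_f1_scores

num_samples, nb_classes = 16, 6
targets = [list(np.random.choice(nb_classes, size=2, replace=False))
           for _ in range(num_samples)]          # two labels per sample
scores = np.random.rand(num_samples, nb_classes)

onehot = multihot(targets, nb_classes)           # (16, 6) {0, 1} matrix
ap, ar, mAP, mAR = compute_map(scores, onehot)
best = get_best_f1_scores(onehot, scores, threshold_end=0.5)
print(mAP, mAR, best["macro"], best["threshold"])
```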
--------------------------------------------------------------------------------
/src/data/vtab_datasets/dsprites.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved
3 | # Copyright 2019 Google LLC.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """Implements the DSprites data class."""
18 |
19 | from __future__ import absolute_import
20 | from __future__ import division
21 | from __future__ import print_function
22 |
23 | import tensorflow.compat.v1 as tf
24 | import tensorflow_datasets as tfds
25 |
26 | from . import base as base
27 | from .registry import Registry
28 |
29 |
30 | # These constants specify the percentage of data that is used to create custom
31 | # train/val splits. Specifically, TRAIN_SPLIT_PERCENT% of the data set is used
32 | # as a new training split and VAL_SPLIT_PERCENT% is used for validation.
33 | # The rest is used for testing.
34 | TRAIN_SPLIT_PERCENT = 80
35 | VAL_SPLIT_PERCENT = 10
36 |
37 |
38 | @Registry.register("data.dsprites", "class")
39 | class DSpritesData(base.ImageTfdsData):
40 | """Provides the DSprites data set.
41 |
42 | DSprites only comes with a training set. Therefore, the training, validation,
43 | and test set are split out of the original training set.
44 |
45 | For additional details and usage, see the base class.
46 |
47 | The data set page is https://github.com/deepmind/dsprites-dataset/.
48 | """
49 |
50 | def __init__(self, predicted_attribute, num_classes=None, data_dir=None):
51 | dataset_builder = tfds.builder("dsprites:2.*.*", data_dir=data_dir)
52 | dataset_builder.download_and_prepare()
53 | info = dataset_builder.info
54 |
55 | if predicted_attribute not in dataset_builder.info.features:
56 | raise ValueError(
57 | "{} is not a valid attribute to predict.".format(predicted_attribute))
58 |
59 | # If num_classes is set, we group together nearby integer values to arrive
60 | # at the desired number of classes. This is useful for example for grouping
61 | # together different spatial positions.
62 | num_original_classes = info.features[predicted_attribute].num_classes
63 | if num_classes is None:
64 | num_classes = num_original_classes
65 | if not isinstance(num_classes, int) or num_classes <= 1 or (
66 | num_classes > num_original_classes):
67 | raise ValueError(
68 | "The number of classes should be None or in [2, ..., num_classes].")
69 | class_division_factor = float(num_original_classes) / num_classes
70 |
71 | # Creates a dict with example counts for each split.
72 | num_total = dataset_builder.info.splits["train"].num_examples
73 | num_samples_train = TRAIN_SPLIT_PERCENT * num_total // 100
74 | num_samples_val = VAL_SPLIT_PERCENT * num_total // 100
75 | num_samples_splits = {
76 | "train": num_samples_train,
77 | "val": num_samples_val,
78 | "trainval": num_samples_val + num_samples_train,
79 | "test": num_total - num_samples_val - num_samples_train,
80 | "train800": 800,
81 | "val200": 200,
82 | "train800val200": 1000,
83 | }
84 |
85 | # Defines dataset specific train/val/trainval/test splits.
86 | tfds_splits = {
87 | "train": "train[:{}]".format(num_samples_splits["train"]),
88 | "val": "train[{}:{}]".format(num_samples_splits["train"],
89 | num_samples_splits["trainval"]),
90 | "trainval": "train[:{}]".format(num_samples_splits["trainval"]),
91 | "test": "train[{}:]".format(num_samples_splits["trainval"]),
92 | "train800": "train[:800]",
93 | "val200": "train[{}:{}]".format(num_samples_splits["train"],
94 | num_samples_splits["train"]+200),
95 | "train800val200": "train[:800]+train[{}:{}]".format(
96 | num_samples_splits["train"], num_samples_splits["train"]+200),
97 | }
98 |
99 | def preprocess_fn(tensors):
100 | # For consistency with other datasets, image needs to have three channels
101 | # and be in [0, 255).
102 | images = tf.tile(tensors["image"], [1, 1, 3]) * 255
103 | label = tf.cast(
104 | tf.math.floordiv(
105 | tf.cast(tensors[predicted_attribute], tf.float32),
106 | class_division_factor), info.features[predicted_attribute].dtype)
107 | return dict(image=images, label=label)
108 |
109 | super(DSpritesData, self).__init__(
110 | dataset_builder=dataset_builder,
111 | tfds_splits=tfds_splits,
112 | num_samples_splits=num_samples_splits,
113 | num_preprocessing_threads=400,
114 | shuffle_buffer_size=10000,
115 | # We extract the attribute we want to predict in the preprocessing.
116 | base_preprocess_fn=preprocess_fn,
117 | num_classes=num_classes)
118 |
--------------------------------------------------------------------------------
/src/models/vit_adapter/swin_adapter.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | swin transformer with adapter
4 | """
5 | import copy
6 | import math
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 |
11 | from ..vit_backbones.swin_transformer import Mlp, SwinTransformerBlock, BasicLayer, PatchMerging, SwinTransformer
12 | from ...utils import logging
13 | logger = logging.get_logger("MOSA")
14 |
15 |
16 | class AdaptedMlp(Mlp):
17 | def __init__(
18 | self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.,
19 | adapter_config=None, adapter_scalar=1.0, dropout=0.0
20 | ):
21 | super(AdaptedMlp, self).__init__(in_features, hidden_features, out_features, act_layer, drop)
22 | self.adapter_config = adapter_config
23 |
24 | if adapter_scalar is None:
25 | self.adapter_scale = nn.Parameter(torch.ones(1))
26 | else:
27 | self.adapter_scale = adapter_scalar
28 | self.dropout = dropout
29 |
30 | out_features = out_features or in_features
31 | self.adapter_down = nn.Linear(
32 | in_features,
33 | adapter_config.BOTTLENECK_SIZE
34 | )
35 | self.adapter_up = nn.Linear(
36 | adapter_config.BOTTLENECK_SIZE,
37 | out_features
38 | )
39 | self.adapter_act_fn = nn.ReLU()
40 |
41 | nn.init.kaiming_uniform_(self.adapter_down.weight, a=math.sqrt(5))
42 | nn.init.zeros_(self.adapter_down.bias)
43 |
44 | nn.init.zeros_(self.adapter_up.weight)
45 | nn.init.zeros_(self.adapter_up.bias)
46 |
47 | def forward(self, x):
48 |         # same as the regular Mlp block
49 |
50 | h = x
51 | x = self.fc1(x)
52 | x = self.act(x)
53 | x = self.drop(x)
54 | x = self.fc2(x)
55 | x = self.drop(x)
56 |
57 | # start to insert adapter layers...
58 | adpt = self.adapter_down(h)
59 | adpt = self.adapter_act_fn(adpt)
60 | adpt = nn.functional.dropout(adpt, p=self.dropout, training=self.training)
61 | adpt = self.adapter_up(adpt)
62 |
63 | x = adpt * self.adapter_scale + x
64 | # ...end
65 |
66 | return x
67 |
68 |
69 | class AdaptedSwinTransformerBlock(SwinTransformerBlock):
70 | def __init__(
71 | self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
72 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
73 | act_layer=nn.GELU, norm_layer=nn.LayerNorm, adapter_config=None
74 | ):
75 | super(AdaptedSwinTransformerBlock, self).__init__(
76 | dim, input_resolution, num_heads, window_size,
77 | shift_size, mlp_ratio, qkv_bias, qk_scale, drop,
78 | attn_drop, drop_path, act_layer, norm_layer
79 | )
80 | self.adapter_config = adapter_config
81 | mlp_hidden_dim = int(dim * mlp_ratio)
82 | self.mlp = AdaptedMlp(
83 | in_features=dim, hidden_features=mlp_hidden_dim,
84 | act_layer=act_layer, drop=drop, adapter_config=adapter_config
85 | )
86 |
87 |
88 | class AdaptedSwinTransformer(SwinTransformer):
89 | def __init__(
90 | self, adapter_config, img_size=224, patch_size=4, in_chans=3,
91 | num_classes=1000, embed_dim=96, depths=[2, 2, 6, 2],
92 | num_heads=[3, 6, 12, 24], window_size=7, mlp_ratio=4., qkv_bias=True,
93 | qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
94 | norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
95 | use_checkpoint=False, **kwargs
96 | ):
97 | super(AdaptedSwinTransformer, self).__init__(
98 | img_size, patch_size, in_chans, num_classes, embed_dim, depths,
99 | num_heads, window_size, mlp_ratio, qkv_bias, qk_scale, drop_rate,
100 | attn_drop_rate, drop_path_rate, norm_layer, ape, patch_norm,
101 | use_checkpoint, **kwargs
102 | )
103 | self.adapter_config = adapter_config
104 |
105 | # stochastic depth
106 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
107 |
108 | # build layers
109 | self.layers = nn.ModuleList()
110 | for i_layer in range(self.num_layers):
111 | layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
112 | input_resolution=(
113 | self.patches_resolution[0] // (2 ** i_layer),
114 | self.patches_resolution[1] // (2 ** i_layer)),
115 | depth=depths[i_layer],
116 | num_heads=num_heads[i_layer],
117 | window_size=window_size,
118 | mlp_ratio=self.mlp_ratio,
119 | qkv_bias=qkv_bias, qk_scale=qk_scale,
120 | drop=drop_rate, attn_drop=attn_drop_rate,
121 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
122 | norm_layer=norm_layer,
123 | downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
124 | use_checkpoint=use_checkpoint,
125 | block_module=AdaptedSwinTransformerBlock,
126 | adapter_config=adapter_config
127 | )
128 | self.layers.append(layer)
129 |
--------------------------------------------------------------------------------
/src/data/vtab_datasets/registry.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved
3 | # Copyright 2019 Google LLC.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """Global Registry for the task adaptation framework.
18 | """
19 |
20 | from __future__ import absolute_import
21 | from __future__ import division
22 | from __future__ import print_function
23 |
24 | import ast
25 | import functools
26 |
27 |
28 | def partialclass(cls, *base_args, **base_kwargs):
29 | """Builds a subclass with partial application of the given args and keywords.
30 |
31 |     Equivalent to functools.partial: base_args are prepended to the
32 | positional arguments given during object initialization and base_kwargs are
33 | updated with the kwargs given later.
34 |
35 | Args:
36 | cls: The base class.
37 | *base_args: Positional arguments to be applied to the subclass.
38 | **base_kwargs: Keyword arguments to be applied to the subclass.
39 |
40 | Returns:
41 | A subclass of the input class.
42 | """
43 |
44 | class _NewClass(cls):
45 |
46 | def __init__(self, *args, **kwargs):
47 | bound_args = base_args + args
48 | bound_kwargs = base_kwargs.copy()
49 | bound_kwargs.update(kwargs)
50 | super(_NewClass, self).__init__(*bound_args, **bound_kwargs)
51 |
52 | return _NewClass
53 |
54 |
55 | def parse_name(string_to_parse):
56 | """Parses input to the registry's lookup function.
57 |
58 | Args:
59 | string_to_parse: can be either an arbitrary name or function call
60 | (optionally with positional and keyword arguments).
61 | e.g. "multiclass", "resnet50_v2(filters_factor=8)".
62 |
63 | Returns:
64 |       A tuple of the input name and a dictionary with keyword arguments. Examples:
65 |       "multiclass" -> ("multiclass", {})
66 |       "resnet50_v2(filters_factor=4)" ->
67 |       ("resnet50_v2", {"filters_factor": 4})
68 | """
69 | expr = ast.parse(string_to_parse, mode="eval").body # pytype: disable=attribute-error
70 | if not isinstance(expr, (ast.Attribute, ast.Call, ast.Name)):
71 | raise ValueError(
72 | "The given string should be a name or a call, but a {} was parsed from "
73 | "the string {!r}".format(type(expr), string_to_parse))
74 |
75 | # Notes:
76 | # name="some_name" -> type(expr) = ast.Name
77 | # name="module.some_name" -> type(expr) = ast.Attribute
78 | # name="some_name()" -> type(expr) = ast.Call
79 | # name="module.some_name()" -> type(expr) = ast.Call
80 |
81 | if isinstance(expr, ast.Name):
82 | return string_to_parse, {}
83 | elif isinstance(expr, ast.Attribute):
84 | return string_to_parse, {}
85 |
86 | def _get_func_name(expr):
87 | if isinstance(expr, ast.Attribute):
88 | return _get_func_name(expr.value) + "." + expr.attr
89 | elif isinstance(expr, ast.Name):
90 | return expr.id
91 | else:
92 | raise ValueError(
93 | "Type {!r} is not supported in a function name, the string to parse "
94 | "was {!r}".format(type(expr), string_to_parse))
95 |
96 | def _get_func_args_and_kwargs(call):
97 | args = tuple([ast.literal_eval(arg) for arg in call.args])
98 | kwargs = {
99 | kwarg.arg: ast.literal_eval(kwarg.value) for kwarg in call.keywords
100 | }
101 | return args, kwargs
102 |
103 | func_name = _get_func_name(expr.func)
104 | func_args, func_kwargs = _get_func_args_and_kwargs(expr)
105 | if func_args:
106 | raise ValueError("Positional arguments are not supported here, but these "
107 | "were found: {!r}".format(func_args))
108 |
109 | return func_name, func_kwargs
110 |
111 |
112 | class Registry(object):
113 | """Implements global Registry."""
114 |
115 | _GLOBAL_REGISTRY = {}
116 |
117 | @staticmethod
118 | def global_registry():
119 | return Registry._GLOBAL_REGISTRY
120 |
121 | @staticmethod
122 | def register(name, item_type):
123 | """Creates a function that registers its input."""
124 | if item_type not in ["function", "class"]:
125 | raise ValueError("Unknown item type: %s" % item_type)
126 |
127 | def _register(item):
128 | if name in Registry.global_registry():
129 | raise KeyError(
130 |                     "The name {!r} was already registered with type {!r}".format(
131 | name, item_type))
132 |
133 | Registry.global_registry()[name] = (item, item_type)
134 | return item
135 |
136 | return _register
137 |
138 | @staticmethod
139 | def lookup(lookup_string, kwargs_extra=None):
140 | """Lookup a name in the registry."""
141 |
142 | name, kwargs = parse_name(lookup_string)
143 | if kwargs_extra:
144 | kwargs.update(kwargs_extra)
145 | item, item_type = Registry.global_registry()[name]
146 | if item_type == "function":
147 | return functools.partial(item, **kwargs)
148 | elif item_type == "class":
149 | return partialclass(item, **kwargs)
150 |
--------------------------------------------------------------------------------
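A minimal usage sketch of the registry above, assuming the repository root is on PYTHONPATH so the module is importable; the key "data.my_dataset" and the MyData class are illustrative placeholders, not part of the codebase.

# Illustrative registry usage (assumed import path; placeholder names).
from src.data.vtab_datasets.registry import Registry, parse_name

@Registry.register("data.my_dataset", "class")   # hypothetical registry key
class MyData:
    def __init__(self, data_dir=None, num_classes=10):
        self.data_dir = data_dir
        self.num_classes = num_classes

# parse_name splits a call string into (name, kwargs); positional args raise.
name, kwargs = parse_name('data.my_dataset(num_classes=100)')
assert name == "data.my_dataset" and kwargs == {"num_classes": 100}

# lookup returns a partially-applied class; more kwargs can be added at call time.
data_cls = Registry.lookup('data.my_dataset(num_classes=100)')
ds = data_cls(data_dir="/tmp/data")
print(ds.num_classes)  # 100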
/src/utils/distributed.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | """Distributed helpers."""
4 |
5 | import torch
6 | import torch.distributed as dist
7 | _LOCAL_PROCESS_GROUP = None
8 |
9 |
10 | def get_world_size() -> int:
11 | if not dist.is_available():
12 | return 1
13 | if not dist.is_initialized():
14 | return 1
15 | return dist.get_world_size()
16 |
17 |
18 | def get_rank() -> int:
19 | if not dist.is_available():
20 | return 0
21 | if not dist.is_initialized():
22 | return 0
23 | return dist.get_rank()
24 |
25 |
26 | def is_master_process(num_gpus=8):
27 | """
28 | Determines if the current process is the master process.
29 | """
30 | if torch.distributed.is_initialized():
31 | return dist.get_rank() % num_gpus == 0
32 | else:
33 | return True
34 |
35 |
36 | def run(
37 | local_rank,
38 | num_proc,
39 | func,
40 | init_method,
41 | shard_id,
42 | num_shards,
43 | backend,
44 | cfg,
45 | args,
46 | ):
47 | """
48 | Runs a function from a child process.
49 | Args:
50 | local_rank (int): rank of the current process on the current machine.
51 | num_proc (int): number of processes per machine.
52 | func (function): function to execute on each of the processes.
53 | init_method (string): method to initialize the distributed training.
54 | TCP initialization: requiring a network address reachable from all
55 | processes followed by the port.
56 | Shared file-system initialization: makes use of a file system that
57 | is shared and visible from all machines. The URL should start with
58 | file:// and contain a path to a non-existent file on a shared file
59 | system.
60 | shard_id (int): the rank of the current machine.
61 | num_shards (int): number of overall machines for the distributed
62 | training job.
63 | backend (string): three distributed backends ('nccl', 'gloo', 'mpi') are
64 | supported, each with different capabilities. Details can be found
65 | here:
66 | https://pytorch.org/docs/stable/distributed.html
67 | cfg (CfgNode): configs. Details can be found in
68 | loco/config/defaults.py
69 | """
70 | # Initialize the process group.
71 | # shard_id = get_rank()
72 | world_size = num_proc * num_shards
73 | rank = shard_id * num_proc + local_rank
74 |
75 | try:
76 | torch.distributed.init_process_group(
77 | backend=backend,
78 | init_method=init_method,
79 | world_size=world_size,
80 | rank=rank,
81 | )
82 | except Exception as e:
83 | raise e
84 |
85 | torch.cuda.set_device(local_rank)
86 | func(cfg, args)
87 |
88 |
89 | def destroy_process_group():
90 | """Destroys the default process group."""
91 | torch.distributed.destroy_process_group()
92 |
93 |
94 | def scaled_all_reduce(cfg, tensors):
95 | """Performs the scaled all_reduce operation on the provided tensors.
96 |
97 | The input tensors are modified in-place. Currently supports only the sum
98 | reduction operator. The reduced values are scaled by the inverse size of
99 | the process group (equivalent to cfg.NUM_GPUS * cfg.NUM_SHARDS).
100 | """
101 | # Queue the reductions
102 | reductions = []
103 | for tensor in tensors:
104 | reduction = torch.distributed.all_reduce(tensor, async_op=True)
105 | reductions.append(reduction)
106 | # Wait for reductions to finish
107 | for reduction in reductions:
108 | reduction.wait()
109 | # Scale the results
110 | for tensor in tensors:
111 | tensor.mul_(1.0 / cfg.NUM_GPUS / cfg.NUM_SHARDS)
112 | return tensors
113 |
114 |
115 | def cat_all_gather(tensors):
116 | """Performs the concatenated all_gather operation on the provided tensors.
117 | """
118 | tensors_gather = [
119 | torch.ones_like(tensors)
120 | for _ in range(torch.distributed.get_world_size())
121 | ]
122 | torch.distributed.all_gather(tensors_gather, tensors, async_op=False)
123 |
124 | output = torch.cat(tensors_gather, dim=0)
125 | return output
126 |
127 |
128 | def local_cat_all_gather(tensors):
129 | """Performs the concatenated all_gather operation on the provided tensors.
130 | """
131 | tensors_gather = [
132 | torch.ones_like(tensors)
133 | for _ in range(get_local_size())
134 | ]
135 | torch.distributed.all_gather(
136 | tensors_gather,
137 | tensors,
138 | async_op=False,
139 | group=_LOCAL_PROCESS_GROUP,
140 | )
141 | output = torch.cat(tensors_gather, dim=0)
142 | return output
143 |
144 |
145 | def get_local_size():
146 | """
147 | Returns:
148 | The size of the per-machine process group,
149 | i.e. the number of processes per machine.
150 | """
151 | if not dist.is_available():
152 | return 1
153 | if not dist.is_initialized():
154 | return 1
155 | return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
156 |
157 |
158 | def get_local_rank():
159 | """
160 | Returns:
161 | The rank of the current process within the local (per-machine) process group.
162 | """
163 | if not dist.is_available():
164 | return 0
165 | if not dist.is_initialized():
166 | return 0
167 | assert _LOCAL_PROCESS_GROUP is not None
168 | return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
169 |
--------------------------------------------------------------------------------
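For context, `run` above is the per-process entry point of a multi-process launch. A hedged launch sketch, assuming a single machine (shard_id=0, num_shards=1), at least one visible GPU, the repo root on PYTHONPATH, and a placeholder `_train` standing in for the real training function:

# Illustrative single-machine launch (assumed import path; `_train` is a stand-in).
from types import SimpleNamespace
import torch
from src.utils.distributed import run, get_rank, get_world_size, destroy_process_group

def _train(cfg, args):
    # each child process lands here after init_process_group / set_device
    print(f"rank {get_rank()} / world size {get_world_size()} is up")
    destroy_process_group()

if __name__ == "__main__":
    cfg = SimpleNamespace(NUM_GPUS=torch.cuda.device_count(), NUM_SHARDS=1)
    # spawn passes the process index as the first argument (local_rank).
    torch.multiprocessing.spawn(
        run,
        nprocs=cfg.NUM_GPUS,
        args=(cfg.NUM_GPUS, _train, "tcp://localhost:9999", 0, 1, "nccl", cfg, None),
    )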
/src/models/vit_backbones/vit_moco.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved
3 | """
4 | borrowed from https://github.com/facebookresearch/moco-v3/blob/main/vits.py
5 | """
6 | import math
7 | import torch
8 | import torch.nn as nn
9 | from functools import partial, reduce
10 | from operator import mul
11 |
12 | from timm.models.vision_transformer import VisionTransformer, _cfg
13 | from timm.models.layers.helpers import to_2tuple
14 | from timm.models.layers import PatchEmbed
15 |
16 | __all__ = [
17 | 'vit_small',
18 | 'vit_base',
19 | 'vit_conv_small',
20 | 'vit_conv_base',
21 | ]
22 |
23 |
24 | class VisionTransformerMoCo(VisionTransformer):
25 | def __init__(self, stop_grad_conv1=False, **kwargs):
26 | super().__init__(**kwargs)
27 | # Use fixed 2D sin-cos position embedding
28 | self.build_2d_sincos_position_embedding()
29 |
30 | # weight initialization
31 | for name, m in self.named_modules():
32 | if isinstance(m, nn.Linear):
33 | if 'qkv' in name:
34 | # treat the weights of Q, K, V separately
35 | val = math.sqrt(6. / float(m.weight.shape[0] // 3 + m.weight.shape[1]))
36 | nn.init.uniform_(m.weight, -val, val)
37 | else:
38 | nn.init.xavier_uniform_(m.weight)
39 | nn.init.zeros_(m.bias)
40 | nn.init.normal_(self.cls_token, std=1e-6)
41 |
42 | if isinstance(self.patch_embed, PatchEmbed):
43 | # xavier_uniform initialization
44 | val = math.sqrt(6. / float(3 * reduce(mul, self.patch_embed.patch_size, 1) + self.embed_dim))
45 | nn.init.uniform_(self.patch_embed.proj.weight, -val, val)
46 | nn.init.zeros_(self.patch_embed.proj.bias)
47 |
48 | if stop_grad_conv1:
49 | self.patch_embed.proj.weight.requires_grad = False
50 | self.patch_embed.proj.bias.requires_grad = False
51 |
52 | def build_2d_sincos_position_embedding(self, temperature=10000.):
53 | h, w = self.patch_embed.grid_size
54 | grid_w = torch.arange(w, dtype=torch.float32)
55 | grid_h = torch.arange(h, dtype=torch.float32)
56 | grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
57 | assert self.embed_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
58 | pos_dim = self.embed_dim // 4
59 | omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
60 | omega = 1. / (temperature**omega)
61 | out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])
62 | out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])
63 | pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]
64 |
65 | assert self.num_tokens == 1, 'Assuming one and only one token, [cls]'
66 | pe_token = torch.zeros([1, 1, self.embed_dim], dtype=torch.float32)
67 | self.pos_embed = nn.Parameter(torch.cat([pe_token, pos_emb], dim=1))
68 | self.pos_embed.requires_grad = False
69 |
70 |
71 | class ConvStem(nn.Module):
72 | """
73 | ConvStem, from Early Convolutions Help Transformers See Better, Xiao et al. https://arxiv.org/abs/2106.14881
74 | """
75 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
76 | super().__init__()
77 |
78 | assert patch_size == 16, 'ConvStem only supports patch size of 16'
79 | assert embed_dim % 8 == 0, 'Embed dimension must be divisible by 8 for ConvStem'
80 |
81 | img_size = to_2tuple(img_size)
82 | patch_size = to_2tuple(patch_size)
83 | self.img_size = img_size
84 | self.patch_size = patch_size
85 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
86 | self.num_patches = self.grid_size[0] * self.grid_size[1]
87 | self.flatten = flatten
88 |
89 | # build stem, similar to the design in https://arxiv.org/abs/2106.14881
90 | stem = []
91 | input_dim, output_dim = 3, embed_dim // 8
92 | for l in range(4):
93 | stem.append(nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=2, padding=1, bias=False))
94 | stem.append(nn.BatchNorm2d(output_dim))
95 | stem.append(nn.ReLU(inplace=True))
96 | input_dim = output_dim
97 | output_dim *= 2
98 | stem.append(nn.Conv2d(input_dim, embed_dim, kernel_size=1))
99 | self.proj = nn.Sequential(*stem)
100 |
101 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
102 |
103 | def forward(self, x):
104 | B, C, H, W = x.shape
105 | assert H == self.img_size[0] and W == self.img_size[1], \
106 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
107 | x = self.proj(x)
108 | if self.flatten:
109 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
110 | x = self.norm(x)
111 | return x
112 |
113 |
114 | def vit_small(**kwargs):
115 | model = VisionTransformerMoCo(
116 | patch_size=16, embed_dim=384, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
117 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
118 | model.default_cfg = _cfg()
119 | return model
120 |
121 | def vit_base(**kwargs):
122 | model = VisionTransformerMoCo(
123 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
124 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
125 | model.default_cfg = _cfg()
126 | return model
127 |
128 | def vit_conv_small(**kwargs):
129 | # minus one ViT block
130 | model = VisionTransformerMoCo(
131 | patch_size=16, embed_dim=384, depth=11, num_heads=12, mlp_ratio=4, qkv_bias=True,
132 | norm_layer=partial(nn.LayerNorm, eps=1e-6), embed_layer=ConvStem, **kwargs)
133 | model.default_cfg = _cfg()
134 | return model
135 |
136 | def vit_conv_base(**kwargs):
137 | # minus one ViT block
138 | model = VisionTransformerMoCo(
139 | patch_size=16, embed_dim=768, depth=11, num_heads=12, mlp_ratio=4, qkv_bias=True,
140 | norm_layer=partial(nn.LayerNorm, eps=1e-6), embed_layer=ConvStem, **kwargs)
141 | model.default_cfg = _cfg()
142 | return model
143 |
--------------------------------------------------------------------------------
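The fixed embedding built by `build_2d_sincos_position_embedding` can be re-derived standalone; the sketch below illustrates the same construction (minus the [cls] token slot) and is not code imported from the file above.

# Standalone re-derivation of the 2D sin-cos position embedding (illustrative).
import torch

def sincos_2d_pos_embed(h, w, embed_dim, temperature=10000.):
    assert embed_dim % 4 == 0
    grid_w = torch.arange(w, dtype=torch.float32)
    grid_h = torch.arange(h, dtype=torch.float32)
    grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij")
    pos_dim = embed_dim // 4
    omega = 1. / (temperature ** (torch.arange(pos_dim, dtype=torch.float32) / pos_dim))
    out_w = grid_w.flatten()[:, None] * omega[None, :]
    out_h = grid_h.flatten()[:, None] * omega[None, :]
    # each position gets sin/cos of its x index plus sin/cos of its y index
    return torch.cat(
        [torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)

pe = sincos_2d_pos_embed(14, 14, 768)  # 224/16 = 14 patches per side for ViT-B/16
print(pe.shape)                        # torch.Size([196, 768])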
/src/utils/logging.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | """Logging."""
4 |
5 | import builtins
6 | import decimal
7 | import functools
8 | import logging
9 | import simplejson
10 | import sys
11 | import os
12 | from termcolor import colored
13 |
14 | from .distributed import is_master_process
15 | from .file_io import PathManager
16 |
17 | # Show filename and line number in logs
18 | _FORMAT = "[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s"
19 |
20 |
21 | def _suppress_print():
22 | """Suppresses printing from the current process."""
23 |
24 | def print_pass(*objects, sep=" ", end="\n", file=sys.stdout, flush=False):
25 | pass
26 |
27 | builtins.print = print_pass
28 |
29 |
30 | # cache the opened file object, so that different calls to `setup_logger`
31 | # with the same file name can safely write to the same file.
32 | @functools.lru_cache(maxsize=None)
33 | def _cached_log_stream(filename):
34 | return PathManager.open(filename, "a")
35 |
36 |
37 | @functools.lru_cache() # so that calling setup_logger multiple times won't add many handlers # noqa
38 | def setup_logging(
39 | num_gpu, num_shards, output="", name="MOSA", color=True):
40 | """Sets up the logging."""
41 | # Enable logging only for the master process
42 | if is_master_process(num_gpu):
43 | # Clear the root logger to prevent any existing logging config
44 | # (e.g. set by another module) from messing with our setup
45 | logging.root.handlers = []
46 | # Configure logging
47 | logging.basicConfig(
48 | level=logging.INFO, format=_FORMAT, stream=sys.stdout
49 | )
50 | else:
51 | _suppress_print()
52 |
53 | if name is None:
54 | name = __name__
55 | logger = logging.getLogger(name)
56 | # remove any lingering handler
57 | logger.handlers.clear()
58 |
59 | logger.setLevel(logging.INFO)
60 | logger.propagate = False
61 |
62 | plain_formatter = logging.Formatter(
63 | "[%(asctime)s][%(levelname)s] %(name)s: %(lineno)4d: %(message)s",
64 | datefmt="%m/%d %H:%M:%S",
65 | )
66 | if color:
67 | formatter = _ColorfulFormatter(
68 | colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
69 | datefmt="%m/%d %H:%M:%S",
70 | root_name=name,
71 | abbrev_name=str(name),
72 | )
73 | else:
74 | formatter = plain_formatter
75 |
76 | if is_master_process(num_gpu):
77 | ch = logging.StreamHandler(stream=sys.stdout)
78 | ch.setLevel(logging.DEBUG)
79 | ch.setFormatter(formatter)
80 | logger.addHandler(ch)
81 |
82 | if is_master_process(num_gpu * num_shards):
83 | if len(output) > 0:
84 | if output.endswith(".txt") or output.endswith(".log"):
85 | filename = output
86 | else:
87 | filename = os.path.join(output, "logs.txt")
88 |
89 | PathManager.mkdirs(os.path.dirname(filename))
90 |
91 | fh = logging.StreamHandler(_cached_log_stream(filename))
92 | fh.setLevel(logging.DEBUG)
93 | fh.setFormatter(plain_formatter)
94 | logger.addHandler(fh)
95 | return logger
96 |
97 |
98 | def setup_single_logging(name, output=""):
99 | """Sets up the logging."""
100 | # Configure logging for a single (non-distributed) process
101 | # Clear the root logger to prevent any existing logging config
102 | # (e.g. set by another module) from messing with our setup
103 | logging.root.handlers = []
104 | # Configure logging
105 | logging.basicConfig(
106 | level=logging.INFO, format=_FORMAT, stream=sys.stdout
107 | )
108 |
109 | if len(name) == 0:
110 | name = __name__
111 | logger = logging.getLogger(name)
112 | logger.setLevel(logging.INFO)
113 | logger.propagate = False
114 |
115 | plain_formatter = logging.Formatter(
116 | "[%(asctime)s][%(levelname)s] %(name)s: %(lineno)4d: %(message)s",
117 | datefmt="%m/%d %H:%M:%S",
118 | )
119 | formatter = _ColorfulFormatter(
120 | colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
121 | datefmt="%m/%d %H:%M:%S",
122 | root_name=name,
123 | abbrev_name=str(name),
124 | )
125 |
126 | ch = logging.StreamHandler(stream=sys.stdout)
127 | ch.setLevel(logging.DEBUG)
128 | ch.setFormatter(formatter)
129 | logger.addHandler(ch)
130 |
131 | if len(output) > 0:
132 | if output.endswith(".txt") or output.endswith(".log"):
133 | filename = output
134 | else:
135 | filename = os.path.join(output, "logs.txt")
136 |
137 | PathManager.mkdirs(os.path.dirname(filename))
138 |
139 | fh = logging.StreamHandler(_cached_log_stream(filename))
140 | fh.setLevel(logging.DEBUG)
141 | fh.setFormatter(plain_formatter)
142 | logger.addHandler(fh)
143 |
144 | return logger
145 |
146 |
147 | def get_logger(name):
148 | """Retrieves the logger."""
149 | return logging.getLogger(name)
150 |
151 |
152 | def log_json_stats(stats, sort_keys=True):
153 | """Logs json stats."""
154 | # It seems that in Python >= 3.6 json.encoder.FLOAT_REPR has no effect
155 | # Use decimal+string as a workaround for having fixed length values in logs
156 | logger = get_logger(__name__)
157 | stats = {
158 | k: decimal.Decimal("{:.6f}".format(v)) if isinstance(v, float) else v
159 | for k, v in stats.items()
160 | }
161 | json_stats = simplejson.dumps(stats, sort_keys=sort_keys, use_decimal=True)
162 | if stats["_type"] == "test_epoch" or stats["_type"] == "train_epoch":
163 | logger.info("json_stats: {:s}".format(json_stats))
164 | else:
165 | logger.info("{:s}".format(json_stats))
166 |
167 |
168 | class _ColorfulFormatter(logging.Formatter):
169 | # from detectron2
170 | def __init__(self, *args, **kwargs):
171 | self._root_name = kwargs.pop("root_name") + "."
172 | self._abbrev_name = kwargs.pop("abbrev_name", "")
173 | if len(self._abbrev_name):
174 | self._abbrev_name = self._abbrev_name + "."
175 | super(_ColorfulFormatter, self).__init__(*args, **kwargs)
176 |
177 | def formatMessage(self, record: logging.LogRecord) -> str:
178 | record.name = record.name.replace(self._root_name, self._abbrev_name)
179 | log = super(_ColorfulFormatter, self).formatMessage(record)
180 | if record.levelno == logging.WARNING:
181 | prefix = colored("WARNING", "red", attrs=["blink"])
182 | elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
183 | prefix = colored("ERROR", "red", attrs=["blink", "underline"])
184 | else:
185 | return log
186 | return prefix + " " + log
187 |
--------------------------------------------------------------------------------
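A small usage sketch, assuming the repo root is on PYTHONPATH and that /tmp/mosa_logs is a writable scratch directory (both are assumptions for illustration):

# Illustrative logger setup and JSON stat logging (assumed import path and paths).
from src.utils import logging

logger = logging.setup_logging(
    num_gpu=1, num_shards=1, output="/tmp/mosa_logs", name="MOSA")
logger.info("training started")

# log_json_stats converts floats to fixed-precision decimals before dumping.
logging.log_json_stats({"_type": "train_epoch", "epoch": 1, "loss": 0.123456789})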
/src/data/datasets/tf_dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | """a dataset that handles output of tf.data: support datasets from VTAB"""
4 | import functools
5 | import tensorflow.compat.v1 as tf
6 | import torch
7 | import torch.utils.data
8 | import numpy as np
9 |
10 | from collections import Counter
11 | from torch import Tensor
12 |
13 | from ..vtab_datasets import base
14 | # pylint: disable=unused-import
15 | from ..vtab_datasets import caltech
16 | from ..vtab_datasets import cifar
17 | from ..vtab_datasets import clevr
18 | from ..vtab_datasets import diabetic_retinopathy
19 | from ..vtab_datasets import dmlab
20 | from ..vtab_datasets import dsprites
21 | from ..vtab_datasets import dtd
22 | from ..vtab_datasets import eurosat
23 | from ..vtab_datasets import kitti
24 | from ..vtab_datasets import oxford_flowers102
25 | from ..vtab_datasets import oxford_iiit_pet
26 | from ..vtab_datasets import patch_camelyon
27 | from ..vtab_datasets import resisc45
28 | from ..vtab_datasets import smallnorb
29 | from ..vtab_datasets import sun397
30 | from ..vtab_datasets import svhn
31 | from ..vtab_datasets.registry import Registry
32 |
33 | from ...utils import logging
34 | logger = logging.get_logger("MOSA")
35 | tf.config.experimental.set_visible_devices([], 'GPU') # set tensorflow to not use gpu # noqa
36 | DATASETS = [
37 | 'caltech101',
38 | 'cifar(num_classes=100)',
39 | 'dtd',
40 | 'oxford_flowers102',
41 | 'oxford_iiit_pet',
42 | 'patch_camelyon',
43 | 'sun397',
44 | 'svhn',
45 | 'resisc45',
46 | 'eurosat',
47 | 'dmlab',
48 | 'kitti(task="closest_vehicle_distance")',
49 | 'smallnorb(predicted_attribute="label_azimuth")',
50 | 'smallnorb(predicted_attribute="label_elevation")',
51 | 'dsprites(predicted_attribute="label_x_position",num_classes=16)',
52 | 'dsprites(predicted_attribute="label_orientation",num_classes=16)',
53 | 'clevr(task="closest_object_distance")',
54 | 'clevr(task="count_all")',
55 | 'diabetic_retinopathy(config="btgraham-300")'
56 | ]
57 |
58 |
59 | class TFDataset(torch.utils.data.Dataset):
60 | def __init__(self, cfg, split):
61 | assert split in {
62 | "train",
63 | "val",
64 | "test",
65 | "trainval"
66 | }, "Split '{}' not supported for {} dataset".format(
67 | split, cfg.DATA.NAME)
68 | logger.info("Constructing {} dataset {}...".format(
69 | cfg.DATA.NAME, split))
70 |
71 | self.cfg = cfg
72 | self._split = split
73 | self.name = cfg.DATA.NAME
74 |
75 | self.img_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
76 | self.img_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
77 |
78 | self.get_data(cfg, split)
79 |
80 | def get_data(self, cfg, split):
81 | tf_data = build_tf_dataset(cfg, split)
82 | data_list = list(tf_data) # a list of tuples
83 |
84 | self._image_tensor_list = [t[0].numpy().squeeze() for t in data_list]
85 | self._targets = [int(t[1].numpy()[0]) for t in data_list]
86 | self._class_ids = sorted(list(set(self._targets)))
87 |
88 | logger.info("Number of images: {}".format(len(self._image_tensor_list)))
89 | logger.info("Number of classes: {} / {}".format(
90 | len(self._class_ids), self.get_class_num()))
91 |
92 | del data_list
93 | del tf_data
94 |
95 | def get_info(self):
96 | num_imgs = len(self._image_tensor_list)
97 | return num_imgs, self.get_class_num()
98 |
99 | def get_class_num(self):
100 | return self.cfg.DATA.NUMBER_CLASSES
101 |
102 | def get_class_weights(self, weight_type):
103 | """get a list of class weight, return a list float"""
104 | if "train" not in self._split:
105 | raise ValueError(
106 | "only getting training class distribution, " + \
107 | "got split {} instead".format(self._split)
108 | )
109 |
110 | cls_num = self.get_class_num()
111 | if weight_type == "none":
112 | return [1.0] * cls_num
113 |
114 | id2counts = Counter(self._class_ids)
115 | assert len(id2counts) == cls_num
116 | num_per_cls = np.array([id2counts[i] for i in self._class_ids])
117 |
118 | if weight_type == 'inv':
119 | mu = -1.0
120 | elif weight_type == 'inv_sqrt':
121 | mu = -0.5
122 | weight_list = num_per_cls ** mu
123 | weight_list = np.divide(
124 | weight_list, np.linalg.norm(weight_list, 1)) * cls_num
125 | return weight_list.tolist()
126 |
127 | def __getitem__(self, index):
128 | # Load the image
129 | label = self._targets[index]
130 | im = to_torch_imgs(
131 | self._image_tensor_list[index], self.img_mean, self.img_std)
132 |
133 | if self._split == "train":
134 | index = index
135 | else:
136 | index = f"{self._split}{index}"
137 | sample = {
138 | "image": im,
139 | "label": label,
140 | # "id": index
141 | }
142 | return sample
143 |
144 | def __len__(self):
145 | return len(self._targets)
146 |
147 |
148 | def preprocess_fn(data, size=224, input_range=(0.0, 1.0)):
149 | image = data["image"]
150 | image = tf.image.resize(image, [size, size])
151 |
152 | image = tf.cast(image, tf.float32) / 255.0
153 | image = image * (input_range[1] - input_range[0]) + input_range[0]
154 |
155 | data["image"] = image
156 | return data
157 |
158 |
159 | def build_tf_dataset(cfg, mode):
160 | """
161 | Builds a tf data instance, then transforms it to a list of tensors and labels
162 | """
163 |
164 | if mode not in ["train", "val", "test", "trainval"]:
165 | raise ValueError("The input pipeline supports `train`, `val`, `test` and `trainval`. "
166 | "Provided mode is {}".format(mode))
167 |
168 | vtab_dataname = cfg.DATA.NAME.split("vtab-")[-1]
169 | data_dir = cfg.DATA.DATAPATH
170 | if vtab_dataname in DATASETS:
171 | data_cls = Registry.lookup("data." + vtab_dataname)
172 | vtab_tf_dataloader = data_cls(data_dir=data_dir)
173 | else:
174 | raise ValueError("Unknown VTAB dataset name: {}".format(
175 | vtab_dataname))
176 |
177 | split_name_dict = {
178 | "dataset_train_split_name": "train800",
179 | "dataset_val_split_name": "val200",
180 | "dataset_trainval_split_name": "train800val200",
181 | "dataset_test_split_name": "test",
182 | }
183 |
184 | def _dict_to_tuple(batch):
185 | return batch['image'], batch['label']
186 |
187 | return vtab_tf_dataloader.get_tf_data(
188 | batch_size=1, # data_params["batch_size"],
189 | drop_remainder=False,
190 | split_name=split_name_dict[f"dataset_{mode}_split_name"],
191 | preprocess_fn=functools.partial(
192 | preprocess_fn,
193 | input_range=(0.0, 1.0),
194 | size=cfg.DATA.CROPSIZE,
195 | ),
196 | for_eval=mode != "train", # handles shuffling
197 | shuffle_buffer_size=1000,
198 | prefetch=1,
199 | train_examples=None,
200 | epochs=1 # setting epochs to 1 makes sure it returns one copy of the dataset
201 | ).map(_dict_to_tuple) # return a PrefetchDataset object. (which does not have much documentation to go on)
202 |
203 |
204 | def to_torch_imgs(img: np.ndarray, mean: Tensor, std: Tensor) -> Tensor:
205 | t_img: Tensor = torch.from_numpy(np.transpose(img, (2, 0, 1)))
206 | t_img -= mean
207 | t_img /= std
208 |
209 | return t_img
210 |
--------------------------------------------------------------------------------
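The `get_class_weights` methods in the dataset classes implement inverse-frequency weighting: per-class counts raised to mu (-1 for 'inv', -0.5 for 'inv_sqrt'), L1-normalized, then rescaled to sum to the number of classes. A self-contained sketch of that rule on toy labels (illustrative numbers only):

# Illustrative inverse-frequency class weighting, mirroring get_class_weights.
import numpy as np
from collections import Counter

labels = [0, 0, 0, 1, 1, 2]                  # toy targets for 3 classes
counts = np.array([Counter(labels)[c] for c in sorted(set(labels))], dtype=float)

mu = -1.0                                     # 'inv'; use -0.5 for 'inv_sqrt'
w = counts ** mu
w = w / np.linalg.norm(w, 1) * len(counts)    # L1-normalize, rescale to num classes
print(w)                                      # rarer classes receive larger weights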
/src/data/datasets/pt_dataset.py:
--------------------------------------------------------------------------------
1 | from collections import Counter
2 | import numpy as np
3 | from torchvision.datasets import CIFAR100, FGVCAircraft, Food101
4 | import torchvision as tv
5 |
6 | from ..transforms import get_transforms
7 | from ...utils import logging
8 | logger = logging.get_logger("MOSA")
9 |
10 | import ssl
11 | ssl._create_default_https_context = ssl._create_unverified_context
12 |
13 |
14 | class CIFAR100Dataset(CIFAR100):
15 | def __init__(self, cfg, split):
16 | root = cfg.DATA.DATAPATH
17 | train = True if split == "train" else False
18 | transform = get_transforms(split, cfg.DATA.CROPSIZE)
19 | super(CIFAR100Dataset, self).__init__(root=root, train=train, transform=transform)
20 | self.raw_ds = CIFAR100(root=root, train=train)
21 | logger.info("Constructing {} dataset {}...".format(
22 | cfg.DATA.NAME, split))
23 |
24 | self.cfg = cfg
25 | self._split = split
26 | self.name = cfg.DATA.NAME
27 | self._construct_imdb()
28 |
29 | def _construct_imdb(self):
30 | logger.info("Number of images: {}".format(len(self.data)))
31 | logger.info("Number of classes: {}".format(len(self.class_to_idx)))
32 |
33 | def get_info(self):
34 | num_imgs = len(self.data)
35 | return num_imgs, self.get_class_num()
36 |
37 | def get_class_num(self):
38 | return self.cfg.DATA.NUMBER_CLASSES
39 | # return len(self._class_ids)
40 |
41 | def get_class_weights(self, weight_type):
42 | """get a list of class weight, return a list float"""
43 | if "train" not in self._split:
44 | raise ValueError(
45 | "only getting training class distribution, " + \
46 | "got split {} instead".format(self._split)
47 | )
48 |
49 | cls_num = self.get_class_num()
50 | if weight_type == "none":
51 | return [1.0] * cls_num
52 |
53 | id2counts = Counter(self.classes)
54 | assert len(id2counts) == cls_num
55 | num_per_cls = np.array([id2counts[i] for i in self.classes])
56 |
57 | if weight_type == 'inv':
58 | mu = -1.0
59 | elif weight_type == 'inv_sqrt':
60 | mu = -0.5
61 | weight_list = num_per_cls ** mu
62 | weight_list = np.divide(
63 | weight_list, np.linalg.norm(weight_list, 1)) * cls_num
64 | return weight_list.tolist()
65 |
66 | def __getitem__(self, index):
67 | img, target = super(CIFAR100Dataset, self).__getitem__(index)
68 | raw_transform = tv.transforms.Compose(
69 | [
70 | tv.transforms.Resize(256),
71 | tv.transforms.CenterCrop(224),
72 | tv.transforms.ToTensor(),
73 | ]
74 | )
75 | raw_img, test_target = self.raw_ds.__getitem__(index)
76 | assert(test_target == target)
77 |
78 | sample = {
79 | "image": img,
80 | "label": target,
81 | "raw": raw_transform(raw_img)
82 | }
83 | return sample
84 |
85 |
86 | class AircraftDataset(FGVCAircraft):
87 | def __init__(self, cfg, split):
88 | root = cfg.DATA.DATAPATH
89 | if split == "train":
90 | split = "trainval"
91 | if split == "val":
92 | split = "test"
93 | transform = get_transforms(split, cfg.DATA.CROPSIZE)
94 | super(AircraftDataset, self).__init__(root=root, split=split, transform=transform)
95 | logger.info("Constructing {} dataset {}...".format(
96 | cfg.DATA.NAME, split))
97 |
98 | self.cfg = cfg
99 | self._split = split
100 | self.name = cfg.DATA.NAME
101 | self._construct_imdb()
102 |
103 | def _construct_imdb(self):
104 | logger.info("Number of images: {}".format(len(self._image_files)))
105 | logger.info("Number of classes: {}".format(len(self.class_to_idx)))
106 |
107 | def get_info(self):
108 | num_imgs = len(self._image_files)
109 | return num_imgs, self.get_class_num()
110 |
111 | def get_class_num(self):
112 | return self.cfg.DATA.NUMBER_CLASSES
113 | # return len(self._class_ids)
114 |
115 | def get_class_weights(self, weight_type):
116 | """get a list of class weight, return a list float"""
117 | if "train" not in self._split:
118 | raise ValueError(
119 | "only getting training class distribution, " + \
120 | "got split {} instead".format(self._split)
121 | )
122 |
123 | cls_num = self.get_class_num()
124 | if weight_type == "none":
125 | return [1.0] * cls_num
126 |
127 | id2counts = Counter(self._labels)
128 | assert len(id2counts) == cls_num
129 | num_per_cls = np.array([id2counts[i] for i in self._labels])
130 |
131 | if weight_type == 'inv':
132 | mu = -1.0
133 | elif weight_type == 'inv_sqrt':
134 | mu = -0.5
135 | weight_list = num_per_cls ** mu
136 | weight_list = np.divide(
137 | weight_list, np.linalg.norm(weight_list, 1)) * cls_num
138 | return weight_list.tolist()
139 |
140 | def __getitem__(self, index):
141 | img, target = super(AircraftDataset, self).__getitem__(index)
142 | sample = {
143 | "image": img,
144 | "label": target,
145 | }
146 | return sample
147 |
148 |
149 | class Food101Dataset(Food101):
150 | def __init__(self, cfg, split):
151 | root = cfg.DATA.DATAPATH
152 | if split == "val":
153 | split = "test"
154 | transform = get_transforms(split, cfg.DATA.CROPSIZE)
155 | super(Food101Dataset, self).__init__(root=root, split=split, transform=transform)
156 | logger.info("Constructing {} dataset {}...".format(
157 | cfg.DATA.NAME, split))
158 |
159 | self.cfg = cfg
160 | self._split = split
161 | self.name = cfg.DATA.NAME
162 | self._construct_imdb()
163 |
164 | def _construct_imdb(self):
165 | logger.info("Number of images: {}".format(len(self._image_files)))
166 | logger.info("Number of classes: {}".format(len(self.class_to_idx)))
167 |
168 | def get_info(self):
169 | num_imgs = len(self._image_files)
170 | return num_imgs, self.get_class_num()
171 |
172 | def get_class_num(self):
173 | return self.cfg.DATA.NUMBER_CLASSES
174 | # return len(self._class_ids)
175 |
176 | def get_class_weights(self, weight_type):
177 | """get a list of class weight, return a list float"""
178 | if "train" not in self._split:
179 | raise ValueError(
180 | "only getting training class distribution, " + \
181 | "got split {} instead".format(self._split)
182 | )
183 |
184 | cls_num = self.get_class_num()
185 | if weight_type == "none":
186 | return [1.0] * cls_num
187 |
188 | id2counts = Counter(self._labels)
189 | assert len(id2counts) == cls_num
190 | num_per_cls = np.array([id2counts[i] for i in self._labels])
191 |
192 | if weight_type == 'inv':
193 | mu = -1.0
194 | elif weight_type == 'inv_sqrt':
195 | mu = -0.5
196 | weight_list = num_per_cls ** mu
197 | weight_list = np.divide(
198 | weight_list, np.linalg.norm(weight_list, 1)) * cls_num
199 | return weight_list.tolist()
200 |
201 | def __getitem__(self, index):
202 | img, target = super(Food101Dataset, self).__getitem__(index)
203 | sample = {
204 | "image": img,
205 | "label": target,
206 | }
207 | return sample
208 |
--------------------------------------------------------------------------------
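The datasets above return dict samples ({"image", "label", ...}) rather than tuples; a minimal, purely illustrative sketch of how such samples batch under torch's default collate function:

# Toy dict-style dataset to show default collation of dict samples (illustrative).
import torch
from torch.utils.data import Dataset, DataLoader

class ToyDictDataset(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return {"image": torch.zeros(3, 224, 224), "label": idx % 2}

loader = DataLoader(ToyDictDataset(), batch_size=4)
batch = next(iter(loader))
# default_collate stacks per-key: images -> [4, 3, 224, 224], labels -> tensor([0, 1, 0, 1])
print(batch["image"].shape, batch["label"])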
/src/data/vtab_datasets/kitti.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved
3 | # Copyright 2019 Google LLC.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """Implements Kitti data class."""
18 |
19 | from __future__ import absolute_import
20 | from __future__ import division
21 | from __future__ import print_function
22 |
23 | import numpy as np
24 |
25 | import tensorflow.compat.v1 as tf
26 | import tensorflow_datasets as tfds
27 |
28 | from . import base as base
29 | from .registry import Registry
30 |
31 |
32 | def _count_all_pp(x):
33 | """Count all objects."""
34 | # Count distribution (thresholded at 15):
35 |
36 | label = tf.math.minimum(tf.size(x["objects"]["type"]) - 1, 8)
37 | return {"image": x["image"], "label": label}
38 |
39 |
40 | def _count_vehicles_pp(x):
41 | """Counting vehicles."""
42 | # Label distribution:
43 |
44 | vehicles = tf.where(x["objects"]["type"] < 3) # Car, Van, Truck.
45 | # Cap at 3.
46 | label = tf.math.minimum(tf.size(vehicles), 3)
47 | return {"image": x["image"], "label": label}
48 |
49 |
50 | def _count_left_pp(x):
51 | """Count objects on the left hand side of the camera."""
52 | # Count distribution (thresholded at 15):
53 |
54 | # Location feature contains (x, y, z) in meters w.r.t. the camera.
55 | objects_on_left = tf.where(x["objects"]["location"][:, 0] < 0)
56 | label = tf.math.minimum(tf.size(objects_on_left), 8)
57 | return {"image": x["image"], "label": label}
58 |
59 |
60 | def _count_far_pp(x):
61 | """Counts objects far from the camera."""
62 | # Threshold removes ~half of the objects.
63 | # Count distribution (thresholded at 15):
64 |
65 | # Location feature contains (x, y, z) in meters w.r.t. the camera.
66 | distant_objects = tf.where(x["objects"]["location"][:, 2] >= 25)
67 | label = tf.math.minimum(tf.size(distant_objects), 8)
68 | return {"image": x["image"], "label": label}
69 |
70 |
71 | def _count_near_pp(x):
72 | """Counts objects close to the camera."""
73 | # Threshold removes ~half of the objects.
74 | # Count distribution:
75 |
76 | # Location feature contains (x, y, z) in meters w.r.t. the camera.
77 | close_objects = tf.where(x["objects"]["location"][:, 2] < 25)
78 | label = tf.math.minimum(tf.size(close_objects), 8)
79 | return {"image": x["image"], "label": label}
80 |
81 |
82 | def _closest_object_distance_pp(x):
83 | """Predict the distance to the closest object."""
84 | # Label distribution:
85 |
86 | # Location feature contains (x, y, z) in meters w.r.t. the camera.
87 | dist = tf.reduce_min(x["objects"]["location"][:, 2])
88 | thrs = np.array([-100, 5.6, 8.4, 13.4, 23.4])
89 | label = tf.reduce_max(tf.where((thrs - dist) < 0))
90 | return {"image": x["image"], "label": label}
91 |
92 |
93 | def _closest_vehicle_distance_pp(x):
94 | """Predict the distance to the closest vehicle."""
95 | # Label distribution:
96 |
97 | # Location feature contains (x, y, z) in meters w.r.t. the camera.
98 | vehicles = tf.where(x["objects"]["type"] < 3) # Car, Van, Truck.
99 | vehicle_z = tf.gather(params=x["objects"]["location"][:, 2], indices=vehicles)
100 | vehicle_z = tf.concat([vehicle_z, tf.constant([[1000.0]])], axis=0)
101 | dist = tf.reduce_min(vehicle_z)
102 | # Results in a uniform distribution over three distances, plus one class for
103 | # "no vehicle".
104 | thrs = np.array([-100.0, 8.0, 20.0, 999.0])
105 | label = tf.reduce_max(tf.where((thrs - dist) < 0))
106 | return {"image": x["image"], "label": label}
107 |
108 |
109 | def _closest_object_x_location_pp(x):
110 | """Predict the absolute x position of the closest object."""
111 | # Label distribution:
112 |
113 | # Location feature contains (x, y, z) in meters w.r.t. the camera.
114 | idx = tf.math.argmin(x["objects"]["location"][:, 2])
115 | xloc = x["objects"]["location"][idx, 0]
116 | thrs = np.array([-100, -6.4, -3.5, 0.0, 3.3, 23.9])
117 | label = tf.reduce_max(tf.where((thrs - xloc) < 0))
118 | return {"image": x["image"], "label": label}
119 |
120 |
121 | _TASK_DICT = {
122 | "count_all": {
123 | "preprocess_fn": _count_all_pp,
124 | "num_classes": 16,
125 | },
126 | "count_left": {
127 | "preprocess_fn": _count_left_pp,
128 | "num_classes": 16,
129 | },
130 | "count_far": {
131 | "preprocess_fn": _count_far_pp,
132 | "num_classes": 16,
133 | },
134 | "count_near": {
135 | "preprocess_fn": _count_near_pp,
136 | "num_classes": 16,
137 | },
138 | "closest_object_distance": {
139 | "preprocess_fn": _closest_object_distance_pp,
140 | "num_classes": 5,
141 | },
142 | "closest_object_x_location": {
143 | "preprocess_fn": _closest_object_x_location_pp,
144 | "num_classes": 5,
145 | },
146 | "count_vehicles": {
147 | "preprocess_fn": _count_vehicles_pp,
148 | "num_classes": 4,
149 | },
150 | "closest_vehicle_distance": {
151 | "preprocess_fn": _closest_vehicle_distance_pp,
152 | "num_classes": 4,
153 | },
154 | }
155 |
156 |
157 | @Registry.register("data.kitti", "class")
158 | class KittiData(base.ImageTfdsData):
159 | """Provides Kitti dataset.
160 |
161 | Eight tasks are supported:
162 | 1. Count the number of objects. 2. Count objects on the left of the camera.
163 | 3. Count objects in the foreground (near). 4. Count objects in the background (far).
164 | 5. Count the number of vehicles.
165 | 6. Predict the distance to the closest object.
166 | 7. Predict the distance to the closest vehicle.
167 | 8. Predict the x-location (w.r.t. the camera) of the closest object.
168 | """
169 |
170 | def __init__(self, task, data_dir=None):
171 |
172 | if task not in _TASK_DICT:
173 | raise ValueError("Unknown task: %s" % task)
174 |
175 | dataset_builder = tfds.builder("kitti:3.2.0", data_dir=data_dir)
176 | dataset_builder.download_and_prepare()
177 |
178 | tfds_splits = {
179 | "train": "train",
180 | "val": "validation",
181 | "trainval": "train+validation",
182 | "test": "test",
183 | "train800": "train[:800]",
184 | "val200": "validation[:200]",
185 | "train800val200": "train[:800]+validation[:200]",
186 | }
187 |
188 | # Example counts are retrieved from the tensorflow dataset info.
189 | train_count = dataset_builder.info.splits[tfds.Split.TRAIN].num_examples
190 | val_count = dataset_builder.info.splits[tfds.Split.VALIDATION].num_examples
191 | test_count = dataset_builder.info.splits[tfds.Split.TEST].num_examples
192 | # Creates a dict with example counts for each split.
193 | num_samples_splits = {
194 | "train": train_count,
195 | "val": val_count,
196 | "trainval": train_count + val_count,
197 | "test": test_count,
198 | "train800": 800,
199 | "val200": 200,
200 | "train800val200": 1000,
201 | }
202 |
203 | task = _TASK_DICT[task]
204 | base_preprocess_fn = task["preprocess_fn"]
205 | super(KittiData, self).__init__(
206 | dataset_builder=dataset_builder,
207 | tfds_splits=tfds_splits,
208 | num_samples_splits=num_samples_splits,
209 | num_preprocessing_threads=400,
210 | shuffle_buffer_size=10000,
211 | base_preprocess_fn=base_preprocess_fn,
212 | num_classes=task["num_classes"])
213 |
--------------------------------------------------------------------------------
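The distance-based KITTI tasks all share the same bucketing trick: the label is the index of the largest threshold lying below the measured distance. A NumPy sketch of that rule, using the closest_vehicle_distance thresholds from above:

# Illustrative threshold bucketing, mirroring tf.reduce_max(tf.where((thrs - dist) < 0)).
import numpy as np

def bucketize(dist, thrs):
    return int(np.max(np.where((thrs - dist) < 0)[0]))

thrs = np.array([-100.0, 8.0, 20.0, 999.0])   # closest_vehicle_distance bins
print(bucketize(5.0, thrs))     # 0 -> a vehicle closer than 8 m
print(bucketize(12.0, thrs))    # 1
print(bucketize(1000.0, thrs))  # 3 -> the "no vehicle" sentinel (1000 m)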
/src/data/datasets/json_dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | """JSON dataset: support CUB, NABrids, Flower, Dogs and Cars"""
4 |
5 | import os
6 | import torch
7 | import torch.utils.data
8 | import torchvision as tv
9 | import numpy as np
10 | from collections import Counter
11 |
12 | from ..transforms import get_transforms
13 | from ...utils import logging
14 | from ...utils.io_utils import read_json
15 | logger = logging.get_logger("MOSA")
16 |
17 | add_dict = {
18 | "FGVC-cars": "",
19 | "FGVC-cub": 'images',
20 | "FGVC-dogs" : "Images",
21 | "FGVC-flowers": "",
22 | "FGVC-nabirds": "images"
23 | }
24 |
25 | class JSONDataset(torch.utils.data.Dataset):
26 | def __init__(self, cfg, split):
27 | assert split in {
28 | "train",
29 | "val",
30 | "test",
31 | }, "Split '{}' not supported for {} dataset".format(
32 | split, cfg.DATA.NAME)
33 | logger.info("Constructing {} dataset {}...".format(
34 | cfg.DATA.NAME, split))
35 |
36 | self.cfg = cfg
37 | self._split = split
38 | self.name = cfg.DATA.NAME
39 | self.data_dir = cfg.DATA.DATAPATH
40 | self.data_percentage = cfg.DATA.PERCENTAGE
41 | self._construct_imdb(cfg)
42 | self.transform = get_transforms(split, cfg.DATA.CROPSIZE)
43 |
44 | def get_anno(self):
45 | anno_path = os.path.join(self.data_dir, "{}.json".format(self._split))
46 | if "train" in self._split:
47 | if self.data_percentage < 1.0:
48 | anno_path = os.path.join(
49 | self.data_dir,
50 | "{}_{}.json".format(self._split, self.data_percentage)
51 | )
52 | assert os.path.exists(anno_path), "{} file not found".format(anno_path)
53 |
54 | return read_json(anno_path)
55 |
56 | def get_imagedir(self):
57 | raise NotImplementedError()
58 |
59 | def _construct_imdb(self, cfg):
60 | """Constructs the imdb."""
61 | if cfg.DATA.NAME == "multi":
62 | class_nbr = {"FGVC-cars": 196, "FGVC-cub": 200, "FGVC-dogs": 120, "FGVC-flowers": 102, "FGVC-nabirds": 555}
63 | label_counter = 0
64 | self._imdb = []
65 | self._class_ids = []
66 | for name in class_nbr.keys():
67 | class_dir = self.get_imagedir() + name
68 | if add_dict[name] != "":
69 | img_dir = os.path.join(class_dir, add_dict[name])
70 | else:
71 | img_dir = os.path.join(class_dir)
72 | anno_path = os.path.join(class_dir, "{}.json".format(self._split))
73 | anno = read_json(anno_path)
74 | class_ids_list = sorted(list(set(anno.values())))
75 | class_id_cont_id_dict = {v: i for i, v in enumerate(class_ids_list)}
76 | # label_count = len(list(set(anno.values())))
77 | label_count = class_nbr[name]
78 | for img_name, cls_id in anno.items():
79 | cont_id = class_id_cont_id_dict[cls_id]
80 | im_path = os.path.join(img_dir, img_name)
81 | self._imdb.append({"im_path": im_path, "class": cont_id+label_counter})
82 | label_counter += label_count
83 | self._class_ids += class_ids_list
84 | #print(name, label_count)
85 | #print(cfg.DATA.NUMBER_CLASSES, label_counter)
86 | assert(cfg.DATA.NUMBER_CLASSES == label_counter)
87 |
88 | else:
89 | img_dir = self.get_imagedir()
90 | assert os.path.exists(img_dir), "{} dir not found".format(img_dir)
91 |
92 | anno = self.get_anno()
93 | # Map class ids to contiguous ids
94 | self._class_ids = sorted(list(set(anno.values())))
95 | self._class_id_cont_id = {v: i for i, v in enumerate(self._class_ids)}
96 |
97 | # Construct the image db
98 | self._imdb = []
99 | for img_name, cls_id in anno.items():
100 | cont_id = self._class_id_cont_id[cls_id]
101 | im_path = os.path.join(img_dir, img_name)
102 | self._imdb.append({"im_path": im_path, "class": cont_id})
103 |
104 | logger.info("Number of images: {}".format(len(self._imdb)))
105 | logger.info("Number of classes: {} / {}".format(
106 | len(self._class_ids), self.get_class_num()))
107 |
108 | def get_info(self):
109 | num_imgs = len(self._imdb)
110 | return num_imgs, self.get_class_num()
111 |
112 | def get_class_num(self):
113 | return self.cfg.DATA.NUMBER_CLASSES
114 |
115 | def get_class_weights(self, weight_type):
116 | """get a list of class weight, return a list float"""
117 | if "train" not in self._split:
118 | raise ValueError(
119 | "only getting training class distribution, " + \
120 | "got split {} instead".format(self._split)
121 | )
122 |
123 | cls_num = self.get_class_num()
124 | if weight_type == "none":
125 | return [1.0] * cls_num
126 |
127 | id2counts = Counter(self._class_ids)
128 | assert len(id2counts) == cls_num
129 | num_per_cls = np.array([id2counts[i] for i in self._class_ids])
130 |
131 | if weight_type == 'inv':
132 | mu = -1.0
133 | elif weight_type == 'inv_sqrt':
134 | mu = -0.5
135 | weight_list = num_per_cls ** mu
136 | weight_list = np.divide(
137 | weight_list, np.linalg.norm(weight_list, 1)) * cls_num
138 | return weight_list.tolist()
139 |
140 | def __getitem__(self, index):
141 | # Load the image
142 | im = tv.datasets.folder.default_loader(self._imdb[index]["im_path"])
143 | label = self._imdb[index]["class"]
144 | im = self.transform(im)
145 | if self._split == "train":
146 | index = index
147 | else:
148 | index = f"{self._split}{index}"
149 |
150 | sample = {
151 | "image": im,
152 | "label": label,
153 | # "id": index
154 | }
155 | return sample
156 |
157 | def __len__(self):
158 | return len(self._imdb)
159 |
160 |
161 | class CUB200Dataset(JSONDataset):
162 | """CUB_200 dataset."""
163 |
164 | def __init__(self, cfg, split):
165 | super(CUB200Dataset, self).__init__(cfg, split)
166 |
167 | def get_imagedir(self):
168 | return os.path.join(self.data_dir, "images")
169 |
170 |
171 | class CarsDataset(JSONDataset):
172 | """stanford-cars dataset."""
173 |
174 | def __init__(self, cfg, split):
175 | super(CarsDataset, self).__init__(cfg, split)
176 |
177 | def get_imagedir(self):
178 | return self.data_dir
179 |
180 |
181 | class DogsDataset(JSONDataset):
182 | """stanford-dogs dataset."""
183 |
184 | def __init__(self, cfg, split):
185 | super(DogsDataset, self).__init__(cfg, split)
186 |
187 | def get_imagedir(self):
188 | return os.path.join(self.data_dir, "Images")
189 |
190 |
191 | class FlowersDataset(JSONDataset):
192 | """flowers dataset."""
193 |
194 | def __init__(self, cfg, split):
195 | super(FlowersDataset, self).__init__(cfg, split)
196 |
197 | def get_imagedir(self):
198 | return self.data_dir
199 |
200 |
201 | class NabirdsDataset(JSONDataset):
202 | """Nabirds dataset."""
203 |
204 | def __init__(self, cfg, split):
205 | super(NabirdsDataset, self).__init__(cfg, split)
206 |
207 | def get_imagedir(self):
208 | return os.path.join(self.data_dir, "images")
209 |
210 |
211 | class MultiDataset(JSONDataset):
212 | """Multitask dataset."""
213 | def __init__(self, cfg, split):
214 | super(MultiDataset, self).__init__(cfg, split)
215 |
216 | def get_imagedir(self):
217 | return self.data_dir
--------------------------------------------------------------------------------
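For the multitask case (DATA.NAME == "multi"), `_construct_imdb` maps each sub-dataset's contiguous class ids into one shared label space by adding a running offset. A standalone sketch of that bookkeeping (the 1173-way total is 196+200+120+102+555):

# Illustrative label-offset scheme for the five FGVC datasets sharing one head.
class_nbr = {"FGVC-cars": 196, "FGVC-cub": 200, "FGVC-dogs": 120,
             "FGVC-flowers": 102, "FGVC-nabirds": 555}

offset = 0
global_id = {}
for name, n_cls in class_nbr.items():
    for local_id in range(n_cls):
        global_id[(name, local_id)] = local_id + offset  # shift by classes seen so far
    offset += n_cls

assert offset == 1173
print(global_id[("FGVC-cub", 0)])  # 196: first CUB class follows the 196 car classes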