├── clip ├── __init__.py ├── bpe_simple_vocab_16e6.txt.gz ├── simple_tokenizer.py └── clip.py ├── clip_maple ├── __init__.py ├── bpe_simple_vocab_16e6.txt.gz ├── simple_tokenizer.py └── clip.py ├── requirements.txt ├── configs ├── datasets │ ├── sun397.yaml │ ├── ucf101.yaml │ ├── eurosat.yaml │ ├── food101.yaml │ ├── imagenet.yaml │ ├── oxford_pets.yaml │ ├── caltech101.yaml │ ├── dtd.yaml │ ├── imagenet_a.yaml │ ├── imagenet_r.yaml │ ├── imagenetv2.yaml │ ├── oxford_flowers.yaml │ ├── fgvc_aircraft.yaml │ ├── stanford_cars.yaml │ └── imagenet_sketch.yaml └── trainers │ ├── CoOp │ └── vit_b16_ep10_bs4_lr35.yaml │ ├── ExtrasLinearProbeCoOp │ └── vit_b16_ep10_bs4_lr35.yaml │ ├── CoCoOp │ └── vit_b16_c4_ep10_batch1_ctxv1.yaml │ ├── ExtrasLinearProbeCoCoOp │ └── vit_b16_c4_ep10_batch1_ctxv1.yaml │ ├── MaPLe │ └── vit_b16_c2_ep10_batch4_2ctx.yaml │ ├── ExtrasLinearProbeMaPLe │ └── vit_b16_c2_ep10_batch4_2ctx.yaml │ ├── KgCoOp │ └── vit_b16_ep10_ctxv1_bs4_lr35.yaml │ └── ExtrasLinearProbeKgCoOp │ └── vit_b16_ep10_ctxv1_bs4_lr35.yaml ├── tests ├── shots.jpg ├── lambda_and_epoch.jpg ├── results.csv ├── save_stats.py └── channel_importance.py ├── examples ├── framework.png ├── performance.png ├── few_shot_performance.png ├── base_to_new_performance.png ├── cross_dataset_performance.png └── domain_generalization_performance.png ├── additional_results └── domain_generalization.md ├── utils ├── mail.py ├── clear_logs.py ├── logger.py ├── acrhive.py ├── gpu_allocater.py └── result_parser.py ├── trainers ├── __init__.py ├── coop_stats.py ├── oracle_stats.py ├── elp_coop_stats.py ├── base.py ├── optim.py ├── elp_maple.py ├── elp_cocoop.py ├── elp_kgcoop.py └── elp_coop.py ├── .gitignore ├── datasets ├── __init__.py ├── imagenet_sketch.py ├── imagenetv2.py ├── imagenet_a.py ├── imagenet_r.py ├── food101.py ├── caltech101.py ├── fgvc_aircraft.py ├── eurosat.py ├── stanford_cars.py ├── sun397.py ├── ucf101.py ├── imagenet.py ├── oxford_flowers.py ├── dtd.py ├── oxford_pets.py └── DATASETS.md ├── templates.py ├── configs.py ├── README.md ├── train.py └── parallel_runner.py /clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /clip_maple/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ftfy 2 | regex 3 | tqdm 4 | -------------------------------------------------------------------------------- /configs/datasets/sun397.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "SUN397" 3 | -------------------------------------------------------------------------------- /configs/datasets/ucf101.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "UCF101" 3 | -------------------------------------------------------------------------------- /configs/datasets/eurosat.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "EuroSAT" 3 | -------------------------------------------------------------------------------- /configs/datasets/food101.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: 
"Food101" 3 | -------------------------------------------------------------------------------- /configs/datasets/imagenet.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "ImageNet" 3 | -------------------------------------------------------------------------------- /configs/datasets/oxford_pets.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "OxfordPets" -------------------------------------------------------------------------------- /configs/datasets/caltech101.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "Caltech101" 3 | -------------------------------------------------------------------------------- /configs/datasets/dtd.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "DescribableTextures" 3 | -------------------------------------------------------------------------------- /configs/datasets/imagenet_a.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "ImageNetA" 3 | -------------------------------------------------------------------------------- /configs/datasets/imagenet_r.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "ImageNetR" 3 | -------------------------------------------------------------------------------- /configs/datasets/imagenetv2.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "ImageNetV2" 3 | -------------------------------------------------------------------------------- /configs/datasets/oxford_flowers.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "OxfordFlowers" -------------------------------------------------------------------------------- /configs/datasets/fgvc_aircraft.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "FGVCAircraft" 3 | -------------------------------------------------------------------------------- /configs/datasets/stanford_cars.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "StanfordCars" 3 | -------------------------------------------------------------------------------- /tests/shots.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Koorye/DePT/HEAD/tests/shots.jpg -------------------------------------------------------------------------------- /configs/datasets/imagenet_sketch.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "ImageNetSketch" 3 | -------------------------------------------------------------------------------- /examples/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Koorye/DePT/HEAD/examples/framework.png -------------------------------------------------------------------------------- /examples/performance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Koorye/DePT/HEAD/examples/performance.png -------------------------------------------------------------------------------- /tests/lambda_and_epoch.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Koorye/DePT/HEAD/tests/lambda_and_epoch.jpg -------------------------------------------------------------------------------- /clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Koorye/DePT/HEAD/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /examples/few_shot_performance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Koorye/DePT/HEAD/examples/few_shot_performance.png -------------------------------------------------------------------------------- /examples/base_to_new_performance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Koorye/DePT/HEAD/examples/base_to_new_performance.png -------------------------------------------------------------------------------- /examples/cross_dataset_performance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Koorye/DePT/HEAD/examples/cross_dataset_performance.png -------------------------------------------------------------------------------- /clip_maple/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Koorye/DePT/HEAD/clip_maple/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /examples/domain_generalization_performance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Koorye/DePT/HEAD/examples/domain_generalization_performance.png -------------------------------------------------------------------------------- /additional_results/domain_generalization.md: -------------------------------------------------------------------------------- 1 | **Domain Generalization Performance** 2 | 3 | ![Domain Generalization](../examples/domain_generalization_performance.png) -------------------------------------------------------------------------------- /utils/mail.py: -------------------------------------------------------------------------------- 1 | # a simple mail sending tool 2 | 3 | import yagmail 4 | 5 | 6 | class MailClient(object): 7 | def __init__(self, cfg): 8 | self.client = yagmail.SMTP(cfg['username'], cfg['password'], cfg['host']) 9 | self.to = cfg['to'] 10 | 11 | def send(self, subject, contents): 12 | self.client.send(self.to, subject, contents) 13 | -------------------------------------------------------------------------------- /tests/results.csv: -------------------------------------------------------------------------------- 1 | ,lambda,group,acc,shape 2 | 0,0.0,base,81.68,. 3 | 1,0.1,base,81.97,. 4 | 2,0.3,base,82.52,. 5 | 3,0.5,base,83.33,. 6 | 4,0.7,base,83.66,. 7 | 5,0.9,base,83.01,. 8 | 6,1.0,base,82.05,. 9 | 7,0.0,new,71.48,. 10 | 8,0.1,new,70.94,. 11 | 9,0.3,new,71.89,. 12 | 10,0.5,new,70.58,. 13 | 11,0.7,new,71.82,. 14 | 12,0.9,new,70.84,. 15 | 13,1.0,new,67.78,. 
16 | 14,0.0,H,75.83,* 17 | 15,0.1,H,75.55,* 18 | 16,0.3,H,76.45,* 19 | 17,0.5,H,75.78,* 20 | 18,0.7,H,77.29,* 21 | 19,0.9,H,75.85,* 22 | 20,1.0,H,73.64,* 23 | -------------------------------------------------------------------------------- /trainers/__init__.py: -------------------------------------------------------------------------------- 1 | from .coop import CoOp 2 | from .cocoop import CoCoOp 3 | from .kgcoop import KgCoOp 4 | from .maple import MaPLe 5 | from .elp_coop import ExtrasLinearProbeCoOp 6 | from .elp_cocoop import ExtrasLinearProbeCoCoOp 7 | from .elp_kgcoop import ExtrasLinearProbeKgCoOp 8 | from .elp_maple import ExtrasLinearProbeMaPLe 9 | 10 | __all__ = ['CoOp', 'CoCoOp', 'KgCoOp', 'MaPLe', 11 | 'ExtrasLinearProbeCoOp', 'ExtrasLinearProbeCoCoOp', 'ExtrasLinearProbeKgCoOp', 'ExtrasLinearProbeMaPLe'] 12 | -------------------------------------------------------------------------------- /configs/trainers/CoOp/vit_b16_ep10_bs4_lr35.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 4 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.0035 18 | MAX_EPOCH: 10 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 20 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" -------------------------------------------------------------------------------- /utils/clear_logs.py: -------------------------------------------------------------------------------- 1 | # a tool for clearing logs, you can ignore it 2 | import argparse 3 | import os 4 | import os.path as osp 5 | 6 | 7 | def clear_logs(root): 8 | for root, dirs, filenames in os.walk(root): 9 | for filename in filenames: 10 | if 'log.txt' in filename: 11 | path = osp.join(root, filename) 12 | print(f'Deleting file {path}') 13 | os.remove(path) 14 | 15 | def main(args): 16 | root = args.root 17 | clear_logs(root) 18 | 19 | 20 | if __name__ == '__main__': 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--root') 23 | args = parser.parse_args() 24 | main(args) 25 | -------------------------------------------------------------------------------- /configs/trainers/ExtrasLinearProbeCoOp/vit_b16_ep10_bs4_lr35.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 4 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.0035 18 | MAX_EPOCH: 10 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 20 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | outputs/ 2 | results/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 |
*.py[cod] 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | bin/ 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | *.egg-info/ 23 | .installed.cfg 24 | *.egg 25 | 26 | # Installer logs 27 | pip-log.txt 28 | pip-delete-this-directory.txt 29 | 30 | # Unit test / coverage reports 31 | .tox/ 32 | .coverage 33 | .cache 34 | nosetests.xml 35 | coverage.xml 36 | 37 | # Translations 38 | *.mo 39 | 40 | # Mr Developer 41 | .mr.developer.cfg 42 | .project 43 | .pydevproject 44 | 45 | # Rope 46 | .ropeproject 47 | 48 | # Django stuff: 49 | *.log 50 | *.pot 51 | 52 | # Sphinx documentation 53 | docs/_build/ -------------------------------------------------------------------------------- /trainers/coop_stats.py: -------------------------------------------------------------------------------- 1 | # save coop inference features, for channel importance statistics 2 | from dassl.engine import TRAINER_REGISTRY 3 | 4 | from .coop import CoOp 5 | 6 | 7 | @TRAINER_REGISTRY.register() 8 | class CoOpStats(CoOp): 9 | def model_inference(self, image): 10 | image_features = self.model.image_encoder(image.type(self.model.dtype)) 11 | 12 | prompts = self.model.prompt_learner() 13 | tokenized_prompts = self.model.tokenized_prompts 14 | text_features = self.model.text_encoder(prompts, tokenized_prompts) 15 | 16 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 17 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 18 | 19 | return text_features, image_features 20 | -------------------------------------------------------------------------------- /configs/trainers/CoCoOp/vit_b16_c4_ep10_batch1_ctxv1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 1 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 10 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 20 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COCOOP: 33 | N_CTX: 4 34 | CTX_INIT: "a photo of a" 35 | PREC: "fp16" -------------------------------------------------------------------------------- /trainers/oracle_stats.py: -------------------------------------------------------------------------------- 1 | # save oracle inference features, for channel importance statistics 2 | from dassl.engine import TRAINER_REGISTRY 3 | 4 | from .coop import CoOp 5 | 6 | 7 | @TRAINER_REGISTRY.register() 8 | class OracleStats(CoOp): 9 | def model_inference(self, image): 10 | image_features = self.model.image_encoder(image.type(self.model.dtype)) 11 | 12 | prompts = self.model.prompt_learner() 13 | tokenized_prompts = self.model.tokenized_prompts 14 | text_features = self.model.text_encoder(prompts, tokenized_prompts) 15 | 16 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 17 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 18 | 19 | return text_features, image_features 20 | -------------------------------------------------------------------------------- 
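The *Stats trainers above (CoOpStats, OracleStats, and ExtrasLinearProbeCoOpStats further below) only override model_inference so that it returns L2-normalised text and image features instead of logits; the channel-importance statistics themselves are computed offline by tests/save_stats.py and tests/channel_importance.py, which are not reproduced in this listing. A minimal sketch of how such feature pairs could be collected and reduced to one score per channel is given below. It is illustrative only: the variance criterion, the batch["img"] key and trainer.device are assumptions about the surrounding Dassl code, not the repository's exact procedure.

import torch

@torch.no_grad()
def collect_channel_stats(trainer, data_loader):
    # gather the (already normalised) image features returned by model_inference
    feats = []
    for batch in data_loader:
        image = batch["img"].to(trainer.device)
        _, image_features = trainer.model_inference(image)
        feats.append(image_features.float().cpu())
    feats = torch.cat(feats, dim=0)  # (num_images, feat_dim)
    # one simple notion of per-channel importance: the variance of each
    # feature channel across the dataset
    return feats.var(dim=0)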
/configs/trainers/ExtrasLinearProbeCoCoOp/vit_b16_c4_ep10_batch1_ctxv1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 1 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 10 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 20 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COCOOP: 33 | N_CTX: 4 34 | CTX_INIT: "a photo of a" 35 | PREC: "fp16" -------------------------------------------------------------------------------- /configs/trainers/MaPLe/vit_b16_c2_ep10_batch4_2ctx.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 4 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.0035 18 | MAX_EPOCH: 10 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 20 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | MAPLE: 33 | N_CTX: 2 34 | CTX_INIT: "a photo of a" 35 | PREC: "fp16" 36 | PROMPT_DEPTH: 9 -------------------------------------------------------------------------------- /configs/trainers/ExtrasLinearProbeMaPLe/vit_b16_c2_ep10_batch4_2ctx.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 4 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.0035 18 | MAX_EPOCH: 10 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 20 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | MAPLE: 33 | N_CTX: 2 34 | CTX_INIT: "a photo of a" 35 | PREC: "fp16" 36 | PROMPT_DEPTH: 9 -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .caltech101 import Caltech101 2 | from .dtd import DescribableTextures 3 | from .eurosat import EuroSAT 4 | from .fgvc_aircraft import FGVCAircraft 5 | from .food101 import Food101 6 | from .imagenet_a import ImageNetA 7 | from .imagenet_r import ImageNetR 8 | from .imagenet_sketch import ImageNetSketch 9 | from .imagenet import ImageNet 10 | from .imagenetv2 import ImageNetV2 11 | from .oxford_flowers import OxfordFlowers 12 | from .oxford_pets import OxfordPets 13 | from .stanford_cars import StanfordCars 14 | from .sun397 import SUN397 15 | from .ucf101 import UCF101 16 | 17 | 18 | __all__ = ['Caltech101', 
'DescribableTextures', 'EuroSAT', 'FGVCAircraft', 'Food101', 19 | 'ImageNetA', 'ImageNetR', 'ImageNetSketch', 'ImageNet', 'ImageNetV2', 20 | 'OxfordFlowers', 'OxfordPets', 'StanfordCars', 'SUN397', 'UCF101'] 21 | -------------------------------------------------------------------------------- /configs/trainers/KgCoOp/vit_b16_ep10_ctxv1_bs4_lr35.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 4 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.0035 18 | MAX_EPOCH: 10 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 20 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COOP: 33 | CTX_INIT: "a photo of a" 34 | N_CTX: 4 35 | CSC: False 36 | CLASS_TOKEN_POSITION: "end" 37 | ALPHA: 1.0 38 | W: 2.0 39 | -------------------------------------------------------------------------------- /configs/trainers/ExtrasLinearProbeKgCoOp/vit_b16_ep10_ctxv1_bs4_lr35.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 4 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.0035 18 | MAX_EPOCH: 10 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 20 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COOP: 33 | CTX_INIT: "a photo of a" 34 | N_CTX: 4 35 | CSC: False 36 | CLASS_TOKEN_POSITION: "end" 37 | ALPHA: 1.0 38 | W: 2.0 39 | -------------------------------------------------------------------------------- /datasets/imagenet_sketch.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | 4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 5 | from dassl.utils import listdir_nohidden 6 | 7 | from .imagenet import ImageNet 8 | from .oxford_pets import OxfordPets 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class ImageNetSketch(DatasetBase): 13 | """ImageNet-Sketch. 14 | 15 | This dataset is used for testing only. 
16 | """ 17 | 18 | dataset_dir = "imagenet-sketch" 19 | 20 | def __init__(self, cfg): 21 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 22 | self.dataset_dir = os.path.join(root, self.dataset_dir) 23 | self.image_dir = os.path.join(self.dataset_dir, "images") 24 | 25 | text_file = os.path.join(self.dataset_dir, "classnames.txt") 26 | classnames = ImageNet.read_classnames(text_file) 27 | 28 | data = self.read_data(classnames) 29 | 30 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 31 | train, test = copy.deepcopy(data), copy.deepcopy(data) 32 | train, test = OxfordPets.subsample_classes(train, test, subsample=subsample) 33 | 34 | super().__init__(train_x=train, test=test) 35 | 36 | def read_data(self, classnames): 37 | image_dir = self.image_dir 38 | folders = listdir_nohidden(image_dir, sort=True) 39 | items = [] 40 | 41 | for label, folder in enumerate(folders): 42 | imnames = listdir_nohidden(os.path.join(image_dir, folder)) 43 | classname = classnames[folder] 44 | for imname in imnames: 45 | impath = os.path.join(image_dir, folder, imname) 46 | item = Datum(impath=impath, label=label, classname=classname) 47 | items.append(item) 48 | 49 | return items 50 | -------------------------------------------------------------------------------- /trainers/elp_coop_stats.py: -------------------------------------------------------------------------------- 1 | # save coop w/ DePT inference features, for channel importance statistics 2 | from dassl.engine import TRAINER_REGISTRY 3 | 4 | from .elp_coop import ExtrasLinearProbeCoOp 5 | 6 | 7 | @TRAINER_REGISTRY.register() 8 | class ExtrasLinearProbeCoOpStats(ExtrasLinearProbeCoOp): 9 | def model_inference(self, input_): 10 | if self.is_base: 11 | return self._forward_base(input_) 12 | else: 13 | return self._forward_new(input_) 14 | 15 | def _forward_base(self, img, labels=None): 16 | assert not self.model.prompt_learner.training 17 | 18 | text_feats, img_feats = self.model._forward_feats(img) 19 | 20 | text_feats_norm = text_feats / text_feats.norm(dim=-1, keepdim=True) 21 | img_feats_norm = img_feats / img_feats.norm(dim=-1, keepdim=True) 22 | 23 | if self.model.film_cfg.LINEAR_PROBE: 24 | text_feats_lp = self.model.film_lp_text(text_feats) 25 | img_feats_lp = self.model.film_lp_img(img_feats) 26 | text_feats_lp_norm = text_feats_lp / text_feats_lp.norm(dim=-1, keepdim=True) 27 | img_feats_lp_norm = img_feats_lp / img_feats_lp.norm(dim=-1, keepdim=True) 28 | 29 | lambda_ = 0 30 | text_feats_norm = text_feats_norm * (1 - lambda_) + text_feats_lp_norm * lambda_ 31 | img_feats_norm = img_feats_norm * (1 - lambda_) + img_feats_lp_norm * lambda_ 32 | 33 | return text_feats_norm, img_feats_norm 34 | 35 | def _forward_new(self, img): 36 | assert not self.model.prompt_learner.training 37 | 38 | text_feats, img_feats = self.model._forward_feats(img) 39 | text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True) 40 | img_feats = img_feats / img_feats.norm(dim=-1, keepdim=True) 41 | return text_feats, img_feats 42 | -------------------------------------------------------------------------------- /datasets/imagenetv2.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | 4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 5 | from dassl.utils import listdir_nohidden 6 | 7 | from .imagenet import ImageNet 8 | from .oxford_pets import OxfordPets 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class ImageNetV2(DatasetBase): 13 | """ImageNetV2. 
14 | 15 | This dataset is used for testing only. 16 | """ 17 | 18 | dataset_dir = "imagenetv2" 19 | 20 | def __init__(self, cfg): 21 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 22 | self.dataset_dir = os.path.join(root, self.dataset_dir) 23 | image_dir = "imagenetv2-matched-frequency-format-val" 24 | self.image_dir = os.path.join(self.dataset_dir, image_dir) 25 | 26 | text_file = os.path.join(self.dataset_dir, "classnames.txt") 27 | classnames = ImageNet.read_classnames(text_file) 28 | 29 | data = self.read_data(classnames) 30 | 31 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 32 | train, test = copy.deepcopy(data), copy.deepcopy(data) 33 | train, test = OxfordPets.subsample_classes(train, test, subsample=subsample) 34 | 35 | super().__init__(train_x=train, test=test) 36 | 37 | def read_data(self, classnames): 38 | image_dir = self.image_dir 39 | folders = list(classnames.keys()) 40 | items = [] 41 | 42 | for label in range(1000): 43 | class_dir = os.path.join(image_dir, str(label)) 44 | imnames = listdir_nohidden(class_dir) 45 | folder = folders[label] 46 | classname = classnames[folder] 47 | for imname in imnames: 48 | impath = os.path.join(class_dir, imname) 49 | item = Datum(impath=impath, label=label, classname=classname) 50 | items.append(item) 51 | 52 | return items 53 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import os 3 | import sys 4 | import time 5 | import os.path as osp 6 | 7 | from dassl.utils.tools import mkdir_if_missing 8 | 9 | # modify print function which will flush instantly 10 | print = functools.partial(print, flush=True) 11 | 12 | 13 | class Logger: 14 | def __init__(self, fpath=None, write_to_console=False): 15 | self.console = sys.stdout if write_to_console else None 16 | self.file = None 17 | if fpath is not None: 18 | mkdir_if_missing(osp.dirname(fpath)) 19 | self.file = open(fpath, "w") 20 | 21 | def __del__(self): 22 | self.close() 23 | 24 | def __enter__(self): 25 | pass 26 | 27 | def __exit__(self, *args): 28 | self.close() 29 | 30 | def write(self, msg): 31 | if self.console is not None: 32 | self.console.write(msg) 33 | 34 | if self.file is not None: 35 | self.file.write(msg) 36 | 37 | def flush(self): 38 | if self.console is not None: 39 | self.console.flush() 40 | 41 | if self.file is not None: 42 | self.file.flush() 43 | os.fsync(self.file.fileno()) 44 | 45 | def close(self): 46 | if self.console is not None: 47 | self.console.close() 48 | 49 | if self.file is not None: 50 | self.file.close() 51 | 52 | 53 | def setup_logger(output=None, write_to_console=False): 54 | if output is None: 55 | return 56 | 57 | if output.endswith(".txt") or output.endswith(".log"): 58 | fpath = output 59 | else: 60 | fpath = osp.join(output, "log.txt") 61 | 62 | if osp.exists(fpath): 63 | # make sure the existing log file is not over-written 64 | fpath += time.strftime("-%Y-%m-%d-%H-%M-%S") 65 | 66 | sys.stdout = Logger(fpath, write_to_console) 67 | return fpath 68 | -------------------------------------------------------------------------------- /datasets/imagenet_a.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | 4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 5 | from dassl.utils import listdir_nohidden 6 | 7 | from .imagenet import ImageNet 8 | from .oxford_pets import OxfordPets 9 | 10 | 
TO_BE_IGNORED = ["README.txt"] 11 | 12 | 13 | @DATASET_REGISTRY.register() 14 | class ImageNetA(DatasetBase): 15 | """ImageNet-A(dversarial). 16 | 17 | This dataset is used for testing only. 18 | """ 19 | 20 | dataset_dir = "imagenet-adversarial" 21 | 22 | def __init__(self, cfg): 23 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 24 | self.dataset_dir = os.path.join(root, self.dataset_dir) 25 | self.image_dir = os.path.join(self.dataset_dir, "imagenet-a") 26 | 27 | text_file = os.path.join(self.dataset_dir, "classnames.txt") 28 | classnames = ImageNet.read_classnames(text_file) 29 | 30 | data = self.read_data(classnames) 31 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 32 | 33 | train, test = copy.deepcopy(data), copy.deepcopy(data) 34 | train, test = OxfordPets.subsample_classes(train, test, subsample=subsample) 35 | 36 | super().__init__(train_x=train, test=test) 37 | 38 | def read_data(self, classnames): 39 | image_dir = self.image_dir 40 | folders = listdir_nohidden(image_dir, sort=True) 41 | folders = [f for f in folders if f not in TO_BE_IGNORED] 42 | items = [] 43 | 44 | for label, folder in enumerate(folders): 45 | imnames = listdir_nohidden(os.path.join(image_dir, folder)) 46 | classname = classnames[folder] 47 | for imname in imnames: 48 | impath = os.path.join(image_dir, folder, imname) 49 | item = Datum(impath=impath, label=label, classname=classname) 50 | items.append(item) 51 | 52 | return items 53 | -------------------------------------------------------------------------------- /datasets/imagenet_r.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | 4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 5 | from dassl.utils import listdir_nohidden 6 | 7 | from .imagenet import ImageNet 8 | from .oxford_pets import OxfordPets 9 | 10 | TO_BE_IGNORED = ["README.txt"] 11 | 12 | 13 | @DATASET_REGISTRY.register() 14 | class ImageNetR(DatasetBase): 15 | """ImageNet-R(endition). 16 | 17 | This dataset is used for testing only. 
18 | """ 19 | 20 | dataset_dir = "imagenet-rendition" 21 | 22 | def __init__(self, cfg): 23 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 24 | self.dataset_dir = os.path.join(root, self.dataset_dir) 25 | self.image_dir = os.path.join(self.dataset_dir, "imagenet-r") 26 | 27 | text_file = os.path.join(self.dataset_dir, "classnames.txt") 28 | classnames = ImageNet.read_classnames(text_file) 29 | 30 | data = self.read_data(classnames) 31 | 32 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 33 | train, test = copy.deepcopy(data), copy.deepcopy(data) 34 | train, test = OxfordPets.subsample_classes(train, test, subsample=subsample) 35 | 36 | super().__init__(train_x=train, test=test) 37 | 38 | def read_data(self, classnames): 39 | image_dir = self.image_dir 40 | folders = listdir_nohidden(image_dir, sort=True) 41 | folders = [f for f in folders if f not in TO_BE_IGNORED] 42 | items = [] 43 | 44 | for label, folder in enumerate(folders): 45 | imnames = listdir_nohidden(os.path.join(image_dir, folder)) 46 | classname = classnames[folder] 47 | for imname in imnames: 48 | impath = os.path.join(image_dir, folder, imname) 49 | item = Datum(impath=impath, label=label, classname=classname) 50 | items.append(item) 51 | 52 | return items 53 | -------------------------------------------------------------------------------- /utils/acrhive.py: -------------------------------------------------------------------------------- 1 | # a tool for sorting out the results of experiments, you can ignore it 2 | import os 3 | import os.path as osp 4 | import pandas as pd 5 | 6 | 7 | base_dir = 'results/shots/' 8 | dfs = [] 9 | 10 | for dir_ in os.listdir(base_dir): 11 | if not osp.isdir(osp.join(base_dir, dir_)): 12 | continue 13 | 14 | filename = os.listdir(osp.join(base_dir, dir_))[0] 15 | df = pd.read_csv(osp.join(base_dir, dir_, filename)) 16 | df['method'], df['shot'] = dir_.split('_') 17 | df['shot'] = df['shot'].apply(lambda x: int(x[0])) 18 | dfs.append(df) 19 | 20 | df = pd.concat(dfs).reset_index(drop=True) 21 | df = df[df['dataset'] == 'average'].drop(columns='dataset').reset_index(drop=True) 22 | df = df[['shot', 'method', 'base_acc', 'new_acc', 'H']] 23 | 24 | df['method'] = df['method'].replace({ 25 | 'coop': 'CoOp', 26 | 'cocoop': 'CoCoOp', 27 | 'kgcoop': 'KgCoOp', 28 | 'maple': 'MaPLe', 29 | 'elpcoop': 'CoOp w/ DePT', 30 | 'elpcocoop': 'CoCoOp w/ DePT', 31 | 'elpkgcoop': 'KgCoOp w/ DePT', 32 | 'elpmaple': 'MaPLe w/ DePT', 33 | }) 34 | 35 | df = df.sort_values(['method', 'shot']).reset_index(drop=True) 36 | df.to_csv(osp.join(base_dir, 'shots.csv'), index=None) 37 | 38 | 39 | base_dir = 'results/epochs/' 40 | dfs = [] 41 | 42 | for dir_ in os.listdir(base_dir): 43 | if not osp.isdir(osp.join(base_dir, dir_)): 44 | continue 45 | 46 | filenames = os.listdir(osp.join(base_dir, dir_)) 47 | for filename in filenames: 48 | df = pd.read_csv(osp.join(base_dir, dir_, filename)) 49 | df['method'], df['epoch'] = filename.split('.')[0].split('-') 50 | df['epoch'] = df['epoch'].apply(lambda x: int(x[2:])) 51 | dfs.append(df) 52 | 53 | df = pd.concat(dfs).reset_index(drop=True) 54 | df = df[df['dataset'] == 'average'].drop(columns='dataset').reset_index(drop=True) 55 | df = df[['epoch', 'method', 'base_acc', 'new_acc', 'H']] 56 | 57 | df['method'] = df['method'].replace({ 58 | 'coop': 'CoOp', 59 | 'elpcoop': 'CoOp w/ DePT', 60 | }) 61 | 62 | df = df.sort_values('epoch').reset_index(drop=True) 63 | df.to_csv(osp.join(base_dir, 'epochs.csv'), index=None) 64 | 
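# --- Illustration (not part of acrhive.py): the transformation applied to each
# per-run CSV above, shown on a made-up frame. A directory name such as
# 'elpcoop_4shots' is an example consistent with the parsing code, not
# necessarily the exact naming used for the experiments.
import pandas as pd

raw = pd.DataFrame({
    'dataset':  ['imagenet', 'average'],
    'base_acc': [76.5, 83.6],
    'new_acc':  [70.4, 71.8],
    'H':        [73.3, 77.3],
})

# keep only the 'average' row and attach method/shot parsed from the directory name
method, shot = 'elpcoop_4shots'.split('_')
raw['method'], raw['shot'] = method, int(shot[0])
print(raw[raw['dataset'] == 'average'][['shot', 'method', 'base_acc', 'new_acc', 'H']])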
-------------------------------------------------------------------------------- /trainers/base.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import time 3 | from dassl.engine.trainer import MetricMeter, AverageMeter, SummaryWriter 4 | from dassl.engine.trainer import TrainerX as TrainerX_ 5 | 6 | from utils.logger import print # modify print function 7 | 8 | 9 | class TrainerX(TrainerX_): 10 | def run_epoch(self): 11 | self.set_model_mode("train") 12 | losses = MetricMeter() 13 | batch_time = AverageMeter() 14 | data_time = AverageMeter() 15 | self.num_batches = len(self.train_loader_x) 16 | 17 | end = time.time() 18 | for self.batch_idx, batch in enumerate(self.train_loader_x): 19 | data_time.update(time.time() - end) 20 | loss_summary = self.forward_backward(batch) 21 | batch_time.update(time.time() - end) 22 | losses.update(loss_summary) 23 | 24 | meet_freq = (self.batch_idx + 1) % self.cfg.TRAIN.PRINT_FREQ == 0 25 | only_few_batches = self.num_batches < self.cfg.TRAIN.PRINT_FREQ 26 | if meet_freq or only_few_batches: 27 | nb_remain = 0 28 | nb_remain += self.num_batches - self.batch_idx - 1 29 | nb_remain += ( 30 | self.max_epoch - self.epoch - 1 31 | ) * self.num_batches 32 | eta_seconds = batch_time.avg * nb_remain 33 | eta = str(datetime.timedelta(seconds=int(eta_seconds))) 34 | 35 | info = [] 36 | info += [f"epoch [{self.epoch + 1}/{self.max_epoch}]"] 37 | info += [f"batch [{self.batch_idx + 1}/{self.num_batches}]"] 38 | info += [f"time {batch_time.val:.3f} ({batch_time.avg:.3f})"] 39 | info += [f"data {data_time.val:.3f} ({data_time.avg:.3f})"] 40 | info += [f"{losses}"] 41 | info += [f"lr {self.get_current_lr():.4e}"] 42 | info += [f"eta {eta}"] 43 | print(" ".join(info)) 44 | 45 | n_iter = self.epoch * self.num_batches + self.batch_idx 46 | for name, meter in losses.meters.items(): 47 | self.write_scalar("train/" + name, meter.avg, n_iter) 48 | self.write_scalar("train/lr", self.get_current_lr(), n_iter) 49 | 50 | end = time.time() 51 | -------------------------------------------------------------------------------- /datasets/food101.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 5 | from dassl.utils import mkdir_if_missing 6 | 7 | from .oxford_pets import OxfordPets 8 | from .dtd import DescribableTextures as DTD 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class Food101(DatasetBase): 13 | 14 | dataset_dir = "food-101" 15 | 16 | def __init__(self, cfg): 17 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 18 | self.dataset_dir = os.path.join(root, self.dataset_dir) 19 | self.image_dir = os.path.join(self.dataset_dir, "images") 20 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_Food101.json") 21 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 22 | mkdir_if_missing(self.split_fewshot_dir) 23 | 24 | if os.path.exists(self.split_path): 25 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 26 | else: 27 | train, val, test = DTD.read_and_split_data(self.image_dir) 28 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 29 | 30 | num_shots = cfg.DATASET.NUM_SHOTS 31 | if num_shots >= 1: 32 | seed = cfg.SEED 33 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 34 | 35 | if os.path.exists(preprocessed): 36 | print(f"Loading preprocessed 
few-shot data from {preprocessed}") 37 | with open(preprocessed, "rb") as file: 38 | data = pickle.load(file) 39 | train, val = data["train"], data["val"] 40 | else: 41 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 42 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 43 | data = {"train": train, "val": val} 44 | print(f"Saving preprocessed few-shot data to {preprocessed}") 45 | with open(preprocessed, "wb") as file: 46 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 47 | 48 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 49 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 50 | 51 | super().__init__(train_x=train, val=val, test=test) 52 | -------------------------------------------------------------------------------- /datasets/caltech101.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 5 | from dassl.utils import mkdir_if_missing 6 | 7 | from .oxford_pets import OxfordPets 8 | from .dtd import DescribableTextures as DTD 9 | 10 | IGNORED = ["BACKGROUND_Google", "Faces_easy"] 11 | NEW_CNAMES = { 12 | "airplanes": "airplane", 13 | "Faces": "face", 14 | "Leopards": "leopard", 15 | "Motorbikes": "motorbike", 16 | } 17 | 18 | 19 | @DATASET_REGISTRY.register() 20 | class Caltech101(DatasetBase): 21 | 22 | dataset_dir = "caltech-101" 23 | 24 | def __init__(self, cfg): 25 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 26 | self.dataset_dir = os.path.join(root, self.dataset_dir) 27 | self.image_dir = os.path.join(self.dataset_dir, "101_ObjectCategories") 28 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_Caltech101.json") 29 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 30 | mkdir_if_missing(self.split_fewshot_dir) 31 | 32 | if os.path.exists(self.split_path): 33 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 34 | else: 35 | train, val, test = DTD.read_and_split_data(self.image_dir, ignored=IGNORED, new_cnames=NEW_CNAMES) 36 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 37 | 38 | num_shots = cfg.DATASET.NUM_SHOTS 39 | if num_shots >= 1: 40 | seed = cfg.SEED 41 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 42 | 43 | if os.path.exists(preprocessed): 44 | print(f"Loading preprocessed few-shot data from {preprocessed}") 45 | with open(preprocessed, "rb") as file: 46 | data = pickle.load(file) 47 | train, val = data["train"], data["val"] 48 | else: 49 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 50 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 51 | data = {"train": train, "val": val} 52 | print(f"Saving preprocessed few-shot data to {preprocessed}") 53 | with open(preprocessed, "wb") as file: 54 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 55 | 56 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 57 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 58 | 59 | super().__init__(train_x=train, val=val, test=test) 60 | -------------------------------------------------------------------------------- /templates.py: -------------------------------------------------------------------------------- 1 | # 0: data root 2 | # 1: seed 3 | # 2: trainer 4 | # 3: dataset 5 | # 4: cfg 6 | # 5: root 7 | # 6: shots 8 | # 7: load epoch 9 
| TRAIN_CMD_TEMPLATE_BASE_TO_NEW = r'''python train.py \ 10 | --root {0} \ 11 | --seed {1} \ 12 | --trainer {2} \ 13 | --dataset-config-file configs/datasets/{3}.yaml \ 14 | --config-file configs/trainers/{2}/{4}.yaml \ 15 | --output-dir {5}/train_base/{2}/{3}/shots{6}/{4}/seed{1} \ 16 | DATASET.NUM_SHOTS {6} DATASET.SUBSAMPLE_CLASSES base ''' 17 | 18 | TEST_CMD_TEMPLATE_BASE_TO_NEW = r'''python train.py \ 19 | --root {0} \ 20 | --seed {1} \ 21 | --trainer {2} \ 22 | --dataset-config-file configs/datasets/{3}.yaml \ 23 | --config-file configs/trainers/{2}/{4}.yaml \ 24 | --output-dir {5}/test_new/{2}/{3}/shots{6}/{4}/seed{1} \ 25 | --model-dir {5}/train_base/{2}/{3}/shots{6}/{4}/seed{1} \ 26 | --load-epoch {7} \ 27 | --eval-only \ 28 | DATASET.NUM_SHOTS {6} DATASET.SUBSAMPLE_CLASSES new ''' 29 | 30 | # 0: data root 31 | # 1: seed 32 | # 2: trainer 33 | # 3: dataset 34 | # 4: cfg 35 | # 5: root 36 | # 6: shots 37 | # 7: load dataset 38 | # 8: load epoch 39 | TRAIN_CMD_TEMPLATE_CROSS_DATASET = r'''python train.py \ 40 | --root {0} \ 41 | --seed {1} \ 42 | --trainer {2} \ 43 | --dataset-config-file configs/datasets/{3}.yaml \ 44 | --config-file configs/trainers/{2}/{4}.yaml \ 45 | --output-dir {5}/{2}/{3}/shots{6}/{4}/seed{1} \ 46 | DATASET.NUM_SHOTS {6} DATASET.SUBSAMPLE_CLASSES all ''' 47 | 48 | TEST_CMD_TEMPLATE_CROSS_DATASET = r'''python train.py \ 49 | --root {0} \ 50 | --seed {1} \ 51 | --trainer {2} \ 52 | --dataset-config-file configs/datasets/{3}.yaml \ 53 | --config-file configs/trainers/{2}/{4}.yaml \ 54 | --output-dir {5}/{2}/{3}/shots{6}/{4}/seed{1} \ 55 | --model-dir {5}/{2}/{7}/shots{6}/{4}/seed{1} \ 56 | --load-epoch {8} \ 57 | --eval-only \ 58 | DATASET.NUM_SHOTS {6} DATASET.SUBSAMPLE_CLASSES all ''' 59 | 60 | 61 | def get_command(data_root, seed, trainer, dataset, cfg, root, shots, load_dataset, load_epoch, opts=[], mode='b2n', train=True): 62 | if mode == 'b2n': 63 | if train: 64 | cmd = TRAIN_CMD_TEMPLATE_BASE_TO_NEW.format(data_root, seed, trainer, dataset, cfg, root, shots) 65 | else: 66 | cmd = TEST_CMD_TEMPLATE_BASE_TO_NEW.format(data_root, seed, trainer, dataset, cfg, root, shots, load_epoch) 67 | else: 68 | if train: 69 | cmd = TRAIN_CMD_TEMPLATE_CROSS_DATASET.format(data_root, seed, trainer, dataset, cfg, root, shots) 70 | else: 71 | cmd = TEST_CMD_TEMPLATE_CROSS_DATASET.format(data_root, seed, trainer, dataset, cfg, root, shots, load_dataset, load_epoch) 72 | 73 | for opt in opts: 74 | cmd += f'{opt} ' 75 | 76 | return cmd 77 | -------------------------------------------------------------------------------- /utils/gpu_allocater.py: -------------------------------------------------------------------------------- 1 | # a tool for parallelizing running commands, where multiple commands are assigned to multiple graphics cards 2 | import datetime 3 | import os 4 | from threading import Thread 5 | 6 | from utils.logger import print 7 | 8 | 9 | class RunCommandThread(Thread): 10 | def __init__(self, command): 11 | Thread.__init__(self) 12 | self.command = command 13 | 14 | def run(self): 15 | self.result = os.system(self.command) 16 | 17 | def get_result(self): 18 | return self.result 19 | 20 | 21 | class GPUAllocater(object): 22 | def __init__(self, gpu_ids): 23 | self.gpu_ids = gpu_ids 24 | 25 | self.num_gpus = len(gpu_ids) 26 | self.commands = [] 27 | 28 | def add_command(self, command): 29 | self.commands.append(command) 30 | 31 | def run(self): 32 | print('Summary of all commands:') 33 | for command in self.commands: 34 | command_ = command.replace('\\', 
'').replace('\n', ' ') 35 | print(command_[:75] + '...' + command_[-75:]) 36 | print('=' * 40) 37 | 38 | current_command_idx, num_commands = 0, len(self.commands) 39 | print(f'Number of commands: {num_commands}\n') 40 | 41 | while len(self.commands) > 0: 42 | commands_once, self.commands = self.commands[:self.num_gpus], self.commands[self.num_gpus:] 43 | current_command_idx += len(commands_once) 44 | print(f'[{current_command_idx} / {num_commands}] Running commands:') 45 | self.run_once(commands_once) 46 | 47 | def run_once(self, commands): 48 | tasks = [] 49 | 50 | print('=' * 40) 51 | for idx, command in enumerate(commands): 52 | gpu_id = self.gpu_ids[idx] 53 | command = f'CUDA_VISIBLE_DEVICES={gpu_id} ' + command 54 | 55 | print(command) 56 | if idx != len(commands) - 1: 57 | print('\n') 58 | 59 | t = RunCommandThread(command) 60 | tasks.append(t) 61 | 62 | print('=' * 40) 63 | print('Starting commands...') 64 | 65 | start_time = datetime.datetime.now() 66 | for t in tasks: 67 | t.start() 68 | 69 | for t in tasks: 70 | t.join() 71 | 72 | # raise exception when one of tasks does not run successfully 73 | results = [t.get_result() for t in tasks] 74 | for res in results: 75 | if res != 0: 76 | raise Exception('Commands cannot run properly!') 77 | 78 | end_time = datetime.datetime.now() 79 | print(f'Multi tasks FINISHED! Time cost: {end_time - start_time}\n') 80 | -------------------------------------------------------------------------------- /datasets/fgvc_aircraft.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 5 | from dassl.utils import mkdir_if_missing 6 | 7 | from .oxford_pets import OxfordPets 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class FGVCAircraft(DatasetBase): 12 | 13 | dataset_dir = "fgvc_aircraft" 14 | 15 | def __init__(self, cfg): 16 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 17 | self.dataset_dir = os.path.join(root, self.dataset_dir) 18 | self.image_dir = os.path.join(self.dataset_dir, "images") 19 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 20 | mkdir_if_missing(self.split_fewshot_dir) 21 | 22 | classnames = [] 23 | with open(os.path.join(self.dataset_dir, "variants.txt"), "r") as f: 24 | lines = f.readlines() 25 | for line in lines: 26 | classnames.append(line.strip()) 27 | cname2lab = {c: i for i, c in enumerate(classnames)} 28 | 29 | train = self.read_data(cname2lab, "images_variant_train.txt") 30 | val = self.read_data(cname2lab, "images_variant_val.txt") 31 | test = self.read_data(cname2lab, "images_variant_test.txt") 32 | 33 | num_shots = cfg.DATASET.NUM_SHOTS 34 | if num_shots >= 1: 35 | seed = cfg.SEED 36 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 37 | 38 | if os.path.exists(preprocessed): 39 | print(f"Loading preprocessed few-shot data from {preprocessed}") 40 | with open(preprocessed, "rb") as file: 41 | data = pickle.load(file) 42 | train, val = data["train"], data["val"] 43 | else: 44 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 45 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 46 | data = {"train": train, "val": val} 47 | print(f"Saving preprocessed few-shot data to {preprocessed}") 48 | with open(preprocessed, "wb") as file: 49 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 50 | 51 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 52 | train, val, test = 
OxfordPets.subsample_classes(train, val, test, subsample=subsample) 53 | 54 | super().__init__(train_x=train, val=val, test=test) 55 | 56 | def read_data(self, cname2lab, split_file): 57 | filepath = os.path.join(self.dataset_dir, split_file) 58 | items = [] 59 | 60 | with open(filepath, "r") as f: 61 | lines = f.readlines() 62 | for line in lines: 63 | line = line.strip().split(" ") 64 | imname = line[0] + ".jpg" 65 | classname = " ".join(line[1:]) 66 | impath = os.path.join(self.image_dir, imname) 67 | label = cname2lab[classname] 68 | item = Datum(impath=impath, label=label, classname=classname) 69 | items.append(item) 70 | 71 | return items 72 | -------------------------------------------------------------------------------- /datasets/eurosat.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 5 | from dassl.utils import mkdir_if_missing 6 | 7 | from .oxford_pets import OxfordPets 8 | from .dtd import DescribableTextures as DTD 9 | 10 | NEW_CNAMES = { 11 | "AnnualCrop": "Annual Crop Land", 12 | "Forest": "Forest", 13 | "HerbaceousVegetation": "Herbaceous Vegetation Land", 14 | "Highway": "Highway or Road", 15 | "Industrial": "Industrial Buildings", 16 | "Pasture": "Pasture Land", 17 | "PermanentCrop": "Permanent Crop Land", 18 | "Residential": "Residential Buildings", 19 | "River": "River", 20 | "SeaLake": "Sea or Lake", 21 | } 22 | 23 | 24 | @DATASET_REGISTRY.register() 25 | class EuroSAT(DatasetBase): 26 | 27 | dataset_dir = "eurosat" 28 | 29 | def __init__(self, cfg): 30 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 31 | self.dataset_dir = os.path.join(root, self.dataset_dir) 32 | self.image_dir = os.path.join(self.dataset_dir, "2750") 33 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_EuroSAT.json") 34 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 35 | mkdir_if_missing(self.split_fewshot_dir) 36 | 37 | if os.path.exists(self.split_path): 38 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 39 | else: 40 | train, val, test = DTD.read_and_split_data(self.image_dir, new_cnames=NEW_CNAMES) 41 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 42 | 43 | num_shots = cfg.DATASET.NUM_SHOTS 44 | if num_shots >= 1: 45 | seed = cfg.SEED 46 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 47 | 48 | if os.path.exists(preprocessed): 49 | print(f"Loading preprocessed few-shot data from {preprocessed}") 50 | with open(preprocessed, "rb") as file: 51 | data = pickle.load(file) 52 | train, val = data["train"], data["val"] 53 | else: 54 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 55 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 56 | data = {"train": train, "val": val} 57 | print(f"Saving preprocessed few-shot data to {preprocessed}") 58 | with open(preprocessed, "wb") as file: 59 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 60 | 61 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 62 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 63 | 64 | super().__init__(train_x=train, val=val, test=test) 65 | 66 | def update_classname(self, dataset_old): 67 | dataset_new = [] 68 | for item_old in dataset_old: 69 | cname_old = item_old.classname 70 | cname_new = NEW_CLASSNAMES[cname_old] 71 | item_new = 
Datum(impath=item_old.impath, label=item_old.label, classname=cname_new) 72 | dataset_new.append(item_new) 73 | return dataset_new 74 | -------------------------------------------------------------------------------- /datasets/stanford_cars.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from scipy.io import loadmat 4 | 5 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 6 | from dassl.utils import mkdir_if_missing 7 | 8 | from .oxford_pets import OxfordPets 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class StanfordCars(DatasetBase): 13 | 14 | dataset_dir = "stanford_cars" 15 | 16 | def __init__(self, cfg): 17 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 18 | self.dataset_dir = os.path.join(root, self.dataset_dir) 19 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_StanfordCars.json") 20 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 21 | mkdir_if_missing(self.split_fewshot_dir) 22 | 23 | if os.path.exists(self.split_path): 24 | train, val, test = OxfordPets.read_split(self.split_path, self.dataset_dir) 25 | else: 26 | trainval_file = os.path.join(self.dataset_dir, "devkit", "cars_train_annos.mat") 27 | test_file = os.path.join(self.dataset_dir, "cars_test_annos_withlabels.mat") 28 | meta_file = os.path.join(self.dataset_dir, "devkit", "cars_meta.mat") 29 | trainval = self.read_data("cars_train", trainval_file, meta_file) 30 | test = self.read_data("cars_test", test_file, meta_file) 31 | train, val = OxfordPets.split_trainval(trainval) 32 | OxfordPets.save_split(train, val, test, self.split_path, self.dataset_dir) 33 | 34 | num_shots = cfg.DATASET.NUM_SHOTS 35 | if num_shots >= 1: 36 | seed = cfg.SEED 37 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 38 | 39 | if os.path.exists(preprocessed): 40 | print(f"Loading preprocessed few-shot data from {preprocessed}") 41 | with open(preprocessed, "rb") as file: 42 | data = pickle.load(file) 43 | train, val = data["train"], data["val"] 44 | else: 45 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 46 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 47 | data = {"train": train, "val": val} 48 | print(f"Saving preprocessed few-shot data to {preprocessed}") 49 | with open(preprocessed, "wb") as file: 50 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 51 | 52 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 53 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 54 | 55 | super().__init__(train_x=train, val=val, test=test) 56 | 57 | def read_data(self, image_dir, anno_file, meta_file): 58 | anno_file = loadmat(anno_file)["annotations"][0] 59 | meta_file = loadmat(meta_file)["class_names"][0] 60 | items = [] 61 | 62 | for i in range(len(anno_file)): 63 | imname = anno_file[i]["fname"][0] 64 | impath = os.path.join(self.dataset_dir, image_dir, imname) 65 | label = anno_file[i]["class"][0, 0] 66 | label = int(label) - 1 # convert to 0-based index 67 | classname = meta_file[label][0] 68 | names = classname.split(" ") 69 | year = names.pop(-1) 70 | names.insert(0, year) 71 | classname = " ".join(names) 72 | item = Datum(impath=impath, label=label, classname=classname) 73 | items.append(item) 74 | 75 | return items 76 | -------------------------------------------------------------------------------- /datasets/sun397.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 5 | from dassl.utils import mkdir_if_missing 6 | 7 | from .oxford_pets import OxfordPets 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class SUN397(DatasetBase): 12 | 13 | dataset_dir = "sun397" 14 | 15 | def __init__(self, cfg): 16 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 17 | self.dataset_dir = os.path.join(root, self.dataset_dir) 18 | self.image_dir = os.path.join(self.dataset_dir, "SUN397") 19 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_SUN397.json") 20 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 21 | mkdir_if_missing(self.split_fewshot_dir) 22 | 23 | if os.path.exists(self.split_path): 24 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 25 | else: 26 | classnames = [] 27 | with open(os.path.join(self.dataset_dir, "ClassName.txt"), "r") as f: 28 | lines = f.readlines() 29 | for line in lines: 30 | line = line.strip()[1:] # remove / 31 | classnames.append(line) 32 | cname2lab = {c: i for i, c in enumerate(classnames)} 33 | trainval = self.read_data(cname2lab, "Training_01.txt") 34 | test = self.read_data(cname2lab, "Testing_01.txt") 35 | train, val = OxfordPets.split_trainval(trainval) 36 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 37 | 38 | num_shots = cfg.DATASET.NUM_SHOTS 39 | if num_shots >= 1: 40 | seed = cfg.SEED 41 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 42 | 43 | if os.path.exists(preprocessed): 44 | print(f"Loading preprocessed few-shot data from {preprocessed}") 45 | with open(preprocessed, "rb") as file: 46 | data = pickle.load(file) 47 | train, val = data["train"], data["val"] 48 | else: 49 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 50 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 51 | data = {"train": train, "val": val} 52 | print(f"Saving preprocessed few-shot data to {preprocessed}") 53 | with open(preprocessed, "wb") as file: 54 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 55 | 56 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 57 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 58 | 59 | super().__init__(train_x=train, val=val, test=test) 60 | 61 | def read_data(self, cname2lab, text_file): 62 | text_file = os.path.join(self.dataset_dir, text_file) 63 | items = [] 64 | 65 | with open(text_file, "r") as f: 66 | lines = f.readlines() 67 | for line in lines: 68 | imname = line.strip()[1:] # remove / 69 | classname = os.path.dirname(imname) 70 | label = cname2lab[classname] 71 | impath = os.path.join(self.image_dir, imname) 72 | 73 | names = classname.split("/")[1:] # remove 1st letter 74 | names = names[::-1] # put words like indoor/outdoor at first 75 | classname = " ".join(names) 76 | 77 | item = Datum(impath=impath, label=label, classname=classname) 78 | items.append(item) 79 | 80 | return items 81 | -------------------------------------------------------------------------------- /datasets/ucf101.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import re 4 | 5 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 6 | from dassl.utils import mkdir_if_missing 7 | 8 | from .oxford_pets import OxfordPets 9 | 10 | 11 | 
@DATASET_REGISTRY.register() 12 | class UCF101(DatasetBase): 13 | 14 | dataset_dir = "ucf101" 15 | 16 | def __init__(self, cfg): 17 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 18 | self.dataset_dir = os.path.join(root, self.dataset_dir) 19 | self.image_dir = os.path.join(self.dataset_dir, "UCF-101-midframes") 20 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_UCF101.json") 21 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 22 | mkdir_if_missing(self.split_fewshot_dir) 23 | 24 | if os.path.exists(self.split_path): 25 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 26 | else: 27 | cname2lab = {} 28 | filepath = os.path.join(self.dataset_dir, "ucfTrainTestlist/classInd.txt") 29 | with open(filepath, "r") as f: 30 | lines = f.readlines() 31 | for line in lines: 32 | label, classname = line.strip().split(" ") 33 | label = int(label) - 1 # conver to 0-based index 34 | cname2lab[classname] = label 35 | 36 | trainval = self.read_data(cname2lab, "ucfTrainTestlist/trainlist01.txt") 37 | test = self.read_data(cname2lab, "ucfTrainTestlist/testlist01.txt") 38 | train, val = OxfordPets.split_trainval(trainval) 39 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 40 | 41 | num_shots = cfg.DATASET.NUM_SHOTS 42 | if num_shots >= 1: 43 | seed = cfg.SEED 44 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 45 | 46 | if os.path.exists(preprocessed): 47 | print(f"Loading preprocessed few-shot data from {preprocessed}") 48 | with open(preprocessed, "rb") as file: 49 | data = pickle.load(file) 50 | train, val = data["train"], data["val"] 51 | else: 52 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 53 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 54 | data = {"train": train, "val": val} 55 | print(f"Saving preprocessed few-shot data to {preprocessed}") 56 | with open(preprocessed, "wb") as file: 57 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 58 | 59 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 60 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 61 | 62 | super().__init__(train_x=train, val=val, test=test) 63 | 64 | def read_data(self, cname2lab, text_file): 65 | text_file = os.path.join(self.dataset_dir, text_file) 66 | items = [] 67 | 68 | with open(text_file, "r") as f: 69 | lines = f.readlines() 70 | for line in lines: 71 | line = line.strip().split(" ")[0] # trainlist: filename, label 72 | action, filename = line.split("/") 73 | label = cname2lab[action] 74 | 75 | elements = re.findall("[A-Z][^A-Z]*", action) 76 | renamed_action = "_".join(elements) 77 | 78 | filename = filename.replace(".avi", ".jpg") 79 | impath = os.path.join(self.image_dir, renamed_action, filename) 80 | 81 | item = Datum(impath=impath, label=label, classname=renamed_action) 82 | items.append(item) 83 | 84 | return items 85 | -------------------------------------------------------------------------------- /datasets/imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from collections import OrderedDict 4 | 5 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 6 | from dassl.utils import listdir_nohidden, mkdir_if_missing 7 | 8 | from .oxford_pets import OxfordPets 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class ImageNet(DatasetBase): 13 | 14 | dataset_dir = "imagenet" 15 | 16 | def 
__init__(self, cfg): 17 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 18 | self.dataset_dir = os.path.join(root, self.dataset_dir) 19 | self.image_dir = os.path.join(self.dataset_dir, "images") 20 | self.preprocessed = os.path.join(self.dataset_dir, "preprocessed.pkl") 21 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 22 | mkdir_if_missing(self.split_fewshot_dir) 23 | 24 | if os.path.exists(self.preprocessed): 25 | with open(self.preprocessed, "rb") as f: 26 | preprocessed = pickle.load(f) 27 | train = preprocessed["train"] 28 | test = preprocessed["test"] 29 | else: 30 | text_file = os.path.join(self.dataset_dir, "classnames.txt") 31 | classnames = self.read_classnames(text_file) 32 | train = self.read_data(classnames, "train") 33 | # Follow standard practice to perform evaluation on the val set 34 | # Also used as the val set (so evaluate the last-step model) 35 | test = self.read_data(classnames, "val") 36 | 37 | preprocessed = {"train": train, "test": test} 38 | with open(self.preprocessed, "wb") as f: 39 | pickle.dump(preprocessed, f, protocol=pickle.HIGHEST_PROTOCOL) 40 | 41 | num_shots = cfg.DATASET.NUM_SHOTS 42 | if num_shots >= 1: 43 | seed = cfg.SEED 44 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 45 | 46 | if os.path.exists(preprocessed): 47 | print(f"Loading preprocessed few-shot data from {preprocessed}") 48 | with open(preprocessed, "rb") as file: 49 | data = pickle.load(file) 50 | train = data["train"] 51 | else: 52 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 53 | data = {"train": train} 54 | print(f"Saving preprocessed few-shot data to {preprocessed}") 55 | with open(preprocessed, "wb") as file: 56 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 57 | 58 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 59 | train, test = OxfordPets.subsample_classes(train, test, subsample=subsample) 60 | 61 | super().__init__(train_x=train, val=test, test=test) 62 | 63 | @staticmethod 64 | def read_classnames(text_file): 65 | """Return a dictionary containing 66 | key-value pairs of : . 
67 | """ 68 | classnames = OrderedDict() 69 | with open(text_file, "r") as f: 70 | lines = f.readlines() 71 | for line in lines: 72 | line = line.strip().split(" ") 73 | folder = line[0] 74 | classname = " ".join(line[1:]) 75 | classnames[folder] = classname 76 | return classnames 77 | 78 | def read_data(self, classnames, split_dir): 79 | split_dir = os.path.join(self.image_dir, split_dir) 80 | folders = sorted(f.name for f in os.scandir(split_dir) if f.is_dir()) 81 | items = [] 82 | 83 | for label, folder in enumerate(folders): 84 | imnames = listdir_nohidden(os.path.join(split_dir, folder)) 85 | classname = classnames[folder] 86 | for imname in imnames: 87 | impath = os.path.join(split_dir, folder, imname) 88 | item = Datum(impath=impath, label=label, classname=classname) 89 | items.append(item) 90 | 91 | return items 92 | -------------------------------------------------------------------------------- /datasets/oxford_flowers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import random 4 | from scipy.io import loadmat 5 | from collections import defaultdict 6 | 7 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 8 | from dassl.utils import read_json, mkdir_if_missing 9 | 10 | from .oxford_pets import OxfordPets 11 | 12 | 13 | @DATASET_REGISTRY.register() 14 | class OxfordFlowers(DatasetBase): 15 | 16 | dataset_dir = "oxford_flowers" 17 | 18 | def __init__(self, cfg): 19 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 20 | self.dataset_dir = os.path.join(root, self.dataset_dir) 21 | self.image_dir = os.path.join(self.dataset_dir, "jpg") 22 | self.label_file = os.path.join(self.dataset_dir, "imagelabels.mat") 23 | self.lab2cname_file = os.path.join(self.dataset_dir, "cat_to_name.json") 24 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_OxfordFlowers.json") 25 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 26 | mkdir_if_missing(self.split_fewshot_dir) 27 | 28 | if os.path.exists(self.split_path): 29 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 30 | else: 31 | train, val, test = self.read_data() 32 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 33 | 34 | num_shots = cfg.DATASET.NUM_SHOTS 35 | if num_shots >= 1: 36 | seed = cfg.SEED 37 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 38 | 39 | if os.path.exists(preprocessed): 40 | print(f"Loading preprocessed few-shot data from {preprocessed}") 41 | with open(preprocessed, "rb") as file: 42 | data = pickle.load(file) 43 | train, val = data["train"], data["val"] 44 | else: 45 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 46 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 47 | data = {"train": train, "val": val} 48 | print(f"Saving preprocessed few-shot data to {preprocessed}") 49 | with open(preprocessed, "wb") as file: 50 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 51 | 52 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 53 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 54 | 55 | super().__init__(train_x=train, val=val, test=test) 56 | 57 | def read_data(self): 58 | tracker = defaultdict(list) 59 | label_file = loadmat(self.label_file)["labels"][0] 60 | for i, label in enumerate(label_file): 61 | imname = f"image_{str(i + 1).zfill(5)}.jpg" 62 | impath = os.path.join(self.image_dir, imname) 63 | 
label = int(label) 64 | tracker[label].append(impath) 65 | 66 | print("Splitting data into 50% train, 20% val, and 30% test") 67 | 68 | def _collate(ims, y, c): 69 | items = [] 70 | for im in ims: 71 | item = Datum(impath=im, label=y - 1, classname=c) # convert to 0-based label 72 | items.append(item) 73 | return items 74 | 75 | lab2cname = read_json(self.lab2cname_file) 76 | train, val, test = [], [], [] 77 | for label, impaths in tracker.items(): 78 | random.shuffle(impaths) 79 | n_total = len(impaths) 80 | n_train = round(n_total * 0.5) 81 | n_val = round(n_total * 0.2) 82 | n_test = n_total - n_train - n_val 83 | assert n_train > 0 and n_val > 0 and n_test > 0 84 | cname = lab2cname[str(label)] 85 | train.extend(_collate(impaths[:n_train], label, cname)) 86 | val.extend(_collate(impaths[n_train : n_train + n_val], label, cname)) 87 | test.extend(_collate(impaths[n_train + n_val :], label, cname)) 88 | 89 | return train, val, test 90 | -------------------------------------------------------------------------------- /datasets/dtd.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import random 4 | 5 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 6 | from dassl.utils import listdir_nohidden, mkdir_if_missing 7 | 8 | from .oxford_pets import OxfordPets 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class DescribableTextures(DatasetBase): 13 | 14 | dataset_dir = "dtd" 15 | 16 | def __init__(self, cfg): 17 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 18 | self.dataset_dir = os.path.join(root, self.dataset_dir) 19 | self.image_dir = os.path.join(self.dataset_dir, "images") 20 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_DescribableTextures.json") 21 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 22 | mkdir_if_missing(self.split_fewshot_dir) 23 | 24 | if os.path.exists(self.split_path): 25 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 26 | else: 27 | train, val, test = self.read_and_split_data(self.image_dir) 28 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 29 | 30 | num_shots = cfg.DATASET.NUM_SHOTS 31 | if num_shots >= 1: 32 | seed = cfg.SEED 33 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 34 | 35 | if os.path.exists(preprocessed): 36 | print(f"Loading preprocessed few-shot data from {preprocessed}") 37 | with open(preprocessed, "rb") as file: 38 | data = pickle.load(file) 39 | train, val = data["train"], data["val"] 40 | else: 41 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 42 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 43 | data = {"train": train, "val": val} 44 | print(f"Saving preprocessed few-shot data to {preprocessed}") 45 | with open(preprocessed, "wb") as file: 46 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 47 | 48 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 49 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 50 | 51 | super().__init__(train_x=train, val=val, test=test) 52 | 53 | @staticmethod 54 | def read_and_split_data(image_dir, p_trn=0.5, p_val=0.2, ignored=[], new_cnames=None): 55 | # The data are supposed to be organized into the following structure 56 | # ============= 57 | # images/ 58 | # dog/ 59 | # cat/ 60 | # horse/ 61 | # ============= 62 | categories = listdir_nohidden(image_dir) 63 | categories = [c for 
c in categories if c not in ignored] 64 | categories.sort() 65 | 66 | p_tst = 1 - p_trn - p_val 67 | print(f"Splitting into {p_trn:.0%} train, {p_val:.0%} val, and {p_tst:.0%} test") 68 | 69 | def _collate(ims, y, c): 70 | items = [] 71 | for im in ims: 72 | item = Datum(impath=im, label=y, classname=c) # is already 0-based 73 | items.append(item) 74 | return items 75 | 76 | train, val, test = [], [], [] 77 | for label, category in enumerate(categories): 78 | category_dir = os.path.join(image_dir, category) 79 | images = listdir_nohidden(category_dir) 80 | images = [os.path.join(category_dir, im) for im in images] 81 | random.shuffle(images) 82 | n_total = len(images) 83 | n_train = round(n_total * p_trn) 84 | n_val = round(n_total * p_val) 85 | n_test = n_total - n_train - n_val 86 | assert n_train > 0 and n_val > 0 and n_test > 0 87 | 88 | if new_cnames is not None and category in new_cnames: 89 | category = new_cnames[category] 90 | 91 | train.extend(_collate(images[:n_train], label, category)) 92 | val.extend(_collate(images[n_train : n_train + n_val], label, category)) 93 | test.extend(_collate(images[n_train + n_val :], label, category)) 94 | 95 | return train, val, test 96 | -------------------------------------------------------------------------------- /trainers/optim.py: -------------------------------------------------------------------------------- 1 | """ modified from dassl.optim """ 2 | import warnings 3 | import torch 4 | import torch.nn as nn 5 | 6 | from dassl.optim.radam import RAdam 7 | 8 | AVAI_OPTIMS = ['adam', 'amsgrad', 'sgd', 'rmsprop', 'radam', 'adamw'] 9 | 10 | 11 | def build_optimizer(model, optim_cfg, param_groups=None): 12 | optim = optim_cfg.NAME 13 | lr = optim_cfg.LR 14 | weight_decay = optim_cfg.WEIGHT_DECAY 15 | momentum = optim_cfg.MOMENTUM 16 | sgd_dampening = optim_cfg.SGD_DAMPNING 17 | sgd_nesterov = optim_cfg.SGD_NESTEROV 18 | rmsprop_alpha = optim_cfg.RMSPROP_ALPHA 19 | adam_beta1 = optim_cfg.ADAM_BETA1 20 | adam_beta2 = optim_cfg.ADAM_BETA2 21 | staged_lr = optim_cfg.STAGED_LR 22 | new_layers = optim_cfg.NEW_LAYERS 23 | base_lr_mult = optim_cfg.BASE_LR_MULT 24 | 25 | if optim not in AVAI_OPTIMS: 26 | raise ValueError( 27 | f'optim must be one of {AVAI_OPTIMS}, but got {optim}' 28 | ) 29 | 30 | if param_groups is not None and staged_lr: 31 | warnings.warn( 32 | 'staged_lr will be ignored, if you need to use staged_lr, ' 33 | 'please bind it with param_groups yourself.' 
34 | ) 35 | 36 | if param_groups is None: 37 | if staged_lr: 38 | # modify the function of lr_mult 39 | exp = optim_cfg.LR_EXP 40 | lr *= exp 41 | base_lr_mult /= exp 42 | 43 | if not isinstance(model, nn.Module): 44 | raise TypeError( 45 | 'When staged_lr is True, model given to ' 46 | 'build_optimizer() must be an instance of nn.Module' 47 | ) 48 | 49 | if isinstance(model, nn.DataParallel): 50 | model = model.module 51 | 52 | if isinstance(new_layers, str): 53 | if new_layers is None: 54 | warnings.warn('new_layers is empty (staged_lr is useless)') 55 | new_layers = [new_layers] 56 | 57 | base_params, new_params = [], [] 58 | base_layers, new_layers_ = [], [] 59 | 60 | for name, module in model.named_children(): 61 | is_new = False 62 | 63 | for layer in new_layers: 64 | if layer in name: 65 | is_new = True 66 | break 67 | 68 | if is_new: 69 | new_params += [p for p in module.parameters()] 70 | new_layers_.append(name) 71 | else: 72 | base_params += [p for p in module.parameters()] 73 | base_layers.append(name) 74 | 75 | param_groups = [{'params': base_params, 76 | 'lr': lr * base_lr_mult}, 77 | {'params': new_params}] 78 | 79 | # return lr of each layer 80 | infos = [{'layers': base_layers, 81 | 'lr': lr * base_lr_mult}, 82 | {'layers': new_layers_, 83 | 'lr': lr}] 84 | else: 85 | if isinstance(model, nn.Module): 86 | param_groups = model.parameters() 87 | else: 88 | param_groups = model 89 | 90 | infos = None 91 | 92 | if optim == 'adam': 93 | optimizer = torch.optim.Adam( 94 | param_groups, 95 | lr=lr, 96 | weight_decay=weight_decay, 97 | betas=(adam_beta1, adam_beta2), 98 | ) 99 | 100 | elif optim == 'amsgrad': 101 | optimizer = torch.optim.Adam( 102 | param_groups, 103 | lr=lr, 104 | weight_decay=weight_decay, 105 | betas=(adam_beta1, adam_beta2), 106 | amsgrad=True, 107 | ) 108 | 109 | elif optim == 'sgd': 110 | optimizer = torch.optim.SGD( 111 | param_groups, 112 | lr=lr, 113 | momentum=momentum, 114 | weight_decay=weight_decay, 115 | dampening=sgd_dampening, 116 | nesterov=sgd_nesterov, 117 | ) 118 | 119 | elif optim == 'rmsprop': 120 | optimizer = torch.optim.RMSprop( 121 | param_groups, 122 | lr=lr, 123 | momentum=momentum, 124 | weight_decay=weight_decay, 125 | alpha=rmsprop_alpha, 126 | ) 127 | 128 | elif optim == 'radam': 129 | optimizer = RAdam( 130 | param_groups, 131 | lr=lr, 132 | weight_decay=weight_decay, 133 | betas=(adam_beta1, adam_beta2), 134 | ) 135 | 136 | elif optim == 'adamw': 137 | optimizer = torch.optim.AdamW( 138 | param_groups, 139 | lr=lr, 140 | weight_decay=weight_decay, 141 | betas=(adam_beta1, adam_beta2), 142 | ) 143 | else: 144 | raise NotImplementedError(f'Optimizer {optim} not implemented yet!') 145 | 146 | return optimizer, infos -------------------------------------------------------------------------------- /clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 
21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 
127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /clip_maple/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 
41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /tests/save_stats.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | import math 4 | import numpy as np 5 | import torch 6 | from collections import defaultdict 7 | from tqdm import tqdm 8 | 9 | from dassl.utils import setup_logger, set_random_seed, collect_env_info 10 | from dassl.engine import build_trainer 11 | 12 | import sys 13 | sys.path.append('.') 14 | 15 | # custom 16 | import 
datasets.oxford_pets 17 | import datasets.oxford_flowers 18 | import datasets.fgvc_aircraft 19 | import datasets.dtd 20 | import datasets.eurosat 21 | import datasets.stanford_cars 22 | import datasets.food101 23 | import datasets.sun397 24 | import datasets.caltech101 25 | import datasets.ucf101 26 | import datasets.imagenet 27 | 28 | import datasets.imagenet_sketch 29 | import datasets.imagenetv2 30 | import datasets.imagenet_a 31 | import datasets.imagenet_r 32 | 33 | import trainers.coop_stats 34 | import trainers.elp_coop_stats 35 | import trainers.oracle_stats 36 | import trainers.zsclip_stats 37 | 38 | from train import print_args, setup_cfg 39 | 40 | 41 | @torch.no_grad() 42 | def main(args): 43 | cfg = setup_cfg(args) 44 | print(cfg) 45 | 46 | if cfg.SEED >= 0: 47 | print("Setting fixed seed: {}".format(cfg.SEED)) 48 | set_random_seed(cfg.SEED) 49 | setup_logger(cfg.OUTPUT_DIR) 50 | 51 | if torch.cuda.is_available() and cfg.USE_CUDA: 52 | torch.backends.cudnn.benchmark = True 53 | 54 | print_args(args, cfg) 55 | print("Collecting env info ...") 56 | print("** System info **\n{}\n".format(collect_env_info())) 57 | 58 | trainer = build_trainer(cfg) 59 | 60 | trainer.load_model(args.model_dir, epoch=args.load_epoch) 61 | trainer.set_model_mode("eval") 62 | 63 | split = cfg.TEST.SPLIT 64 | if split == "val" and trainer.val_loader is not None: 65 | data_loader = trainer.val_loader 66 | else: 67 | split = "test" # in case val_loader is None 68 | data_loader = trainer.test_loader 69 | print(f"Evaluate on the *{split}* set") 70 | 71 | feats = defaultdict(list) 72 | n = trainer.dm.num_classes 73 | print('num classes:', n) 74 | labels = range(n) 75 | m = math.ceil(n / 2) 76 | base_labels = labels[:m] 77 | 78 | for batch_idx, batch in enumerate(tqdm(data_loader)): 79 | input_, labels = trainer.parse_batch_test(batch) 80 | 81 | trainer.is_base = labels[0] in base_labels 82 | text_feats, img_feats = trainer.model_inference(input_) 83 | 84 | labels = labels.detach().cpu().numpy() 85 | text_feats = text_feats.detach().cpu().numpy() 86 | img_feats = img_feats.detach().cpu().numpy() 87 | 88 | feats['label'].append(labels) 89 | feats['text'].append(text_feats) 90 | feats['img'].append(img_feats) 91 | 92 | feats['label'] = np.concatenate(feats['label']) 93 | feats['text'] = feats['text'][0] 94 | feats['img'] = np.concatenate(feats['img']) 95 | 96 | filename = f'feats_{args.trainer}_{cfg.DATASET.NAME}_seed{cfg.SEED}.pkl' 97 | with open(f'stats/{filename}', 'wb') as f: 98 | pickle.dump(feats, f) 99 | 100 | 101 | if __name__ == "__main__": 102 | parser = argparse.ArgumentParser() 103 | parser.add_argument("--root", type=str, default="", help="path to dataset") 104 | parser.add_argument("--output-dir", type=str, default="", help="output directory") 105 | parser.add_argument( 106 | "--resume", 107 | type=str, 108 | default="", 109 | help="checkpoint directory (from which the training resumes)", 110 | ) 111 | parser.add_argument( 112 | "--seed", type=int, default=-1, help="only positive value enables a fixed seed" 113 | ) 114 | parser.add_argument( 115 | "--source-domains", type=str, nargs="+", help="source domains for DA/DG" 116 | ) 117 | parser.add_argument( 118 | "--target-domains", type=str, nargs="+", help="target domains for DA/DG" 119 | ) 120 | parser.add_argument( 121 | "--transforms", type=str, nargs="+", help="data augmentation methods" 122 | ) 123 | parser.add_argument( 124 | "--config-file", type=str, default="", help="path to config file" 125 | ) 126 | parser.add_argument( 127 | 
"--dataset-config-file", 128 | type=str, 129 | default="", 130 | help="path to config file for dataset setup", 131 | ) 132 | parser.add_argument("--trainer", type=str, default="", help="name of trainer") 133 | parser.add_argument("--backbone", type=str, default="", help="name of CNN backbone") 134 | parser.add_argument("--head", type=str, default="", help="name of head") 135 | parser.add_argument("--eval-only", action="store_true", help="evaluation only") 136 | parser.add_argument( 137 | "--model-dir", 138 | type=str, 139 | default="", 140 | help="load model from this directory for eval-only mode", 141 | ) 142 | parser.add_argument( 143 | "--load-epoch", type=int, help="load model weights at this epoch for evaluation" 144 | ) 145 | parser.add_argument( 146 | "--no-train", action="store_true", help="do not call trainer.train()" 147 | ) 148 | parser.add_argument( 149 | "opts", 150 | default=None, 151 | nargs=argparse.REMAINDER, 152 | help="modify config options using the command-line", 153 | ) 154 | args = parser.parse_args() 155 | main(args) 156 | 157 | -------------------------------------------------------------------------------- /configs.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | 4 | base = dict( 5 | # dataset configs 6 | data = dict( 7 | root='root of datasets here', 8 | datasets_base_to_new=['dtd', 'caltech101', 'eurosat', 'ucf101', 'oxford_flowers', 9 | 'oxford_pets', 'stanford_cars', 'fgvc_aircraft', 'food101', 'sun397', 'imagenet'], 10 | datasets_cross_dataset=['caltech101', 'oxford_pets', 'stanford_cars', 'oxford_flowers', 'food101', 11 | 'fgvc_aircraft', 'sun397', 'dtd', 'eurosat', 'ucf101', 12 | 'imagenetv2', 'imagenet_sketch', 'imagenet_a', 'imagenet_r'], 13 | ), 14 | 15 | # mail configs 16 | mail = dict( 17 | username='somebody@example.com', 18 | password='password here', 19 | host='host here', 20 | to='somebody@example.com', 21 | ), 22 | ) 23 | 24 | ########################################################## 25 | 26 | coop = dict( 27 | # GPU ids, if you have multiple GPUs, it can be setted to [0, 1, 2, ...] 
28 | # number of GPU ids is recommanded to be a multiple of 3 29 | # because seeds are 1, 2, 3 30 | gpu_ids = [0], 31 | # gpu_ids = [0, 1, 2], 32 | # training and eval mode 33 | # 'b2n' means base to new, or 'xd' means cross dataset and domain generalization 34 | mode='b2n', 35 | 36 | # training configs 37 | train = dict( 38 | trainer='CoOp', # trainer, please see trainers 39 | cfg='vit_b16_ep10_bs4_lr35', # config, please see configs/ 40 | seeds=[1, 2, 3], # seeds 41 | loadep=-1, # load epoch, -1 to load the last epoch 42 | shots=16, # num of shots 43 | opts=[], # extra opts, if you have, please add, such as [OPTIM.MAX_EPOCH, 10] 44 | ), 45 | 46 | # grid search configs, if enable=False, grid search will not be used 47 | grid_search = dict(enable=False), 48 | 49 | # output configs 50 | output = dict( 51 | root='outputs/coop', # output root 52 | result='results/coop', # result root 53 | remove_dirs=['root'], # which directorys will be removed before training task starts 54 | ), 55 | ) 56 | 57 | cocoop = dict( 58 | gpu_ids = [0], 59 | mode='b2n', 60 | 61 | train = dict( 62 | trainer='CoCoOp', 63 | cfg='vit_b16_c4_ep10_batch1_ctxv1', 64 | seeds=[1, 2, 3], 65 | loadep=-1, 66 | shots=16, 67 | opts=[], 68 | ), 69 | 70 | grid_search = dict(enable=False), 71 | 72 | output = dict( 73 | root='outputs/cocoop', 74 | result='results/cocoop', 75 | remove_dirs=['root'], 76 | ), 77 | ) 78 | 79 | kgcoop = dict( 80 | gpu_ids = [0], 81 | mode='b2n', 82 | 83 | train = dict( 84 | trainer='KgCoOp', 85 | cfg='vit_b16_ep10_ctxv1_bs4_lr35', 86 | seeds=[1, 2, 3], 87 | loadep=-1, 88 | shots=16, 89 | opts=[], 90 | ), 91 | 92 | grid_search = dict(enable=False), 93 | 94 | output = dict( 95 | root='outputs/kgcoop', 96 | result='results/kgcoop', 97 | remove_dirs=['root'], 98 | ), 99 | ) 100 | 101 | maple = dict( 102 | gpu_ids = [0], 103 | mode='b2n', 104 | 105 | train = dict( 106 | trainer='MaPLe', 107 | cfg='vit_b16_c2_ep10_batch4_2ctx', 108 | seeds=[1, 2, 3], 109 | loadep=-1, 110 | shots=16, 111 | opts=[], 112 | ), 113 | 114 | grid_search = dict(enable=False), 115 | 116 | output = dict( 117 | root='outputs/maple', 118 | result='results/maple', 119 | remove_dirs=['root'], 120 | ), 121 | ) 122 | 123 | coop_dept = dict( 124 | gpu_ids = [0], 125 | mode='b2n', 126 | 127 | train = dict( 128 | trainer='ExtrasLinearProbeCoOp', 129 | cfg='vit_b16_ep10_bs4_lr35', 130 | seeds=[1, 2, 3], 131 | loadep=-1, 132 | shots=16, 133 | opts=[], 134 | ), 135 | 136 | grid_search = dict(enable=False), 137 | 138 | output = dict( 139 | root='outputs/coop_dept', 140 | result='results/coop_dept', 141 | remove_dirs=['root'], 142 | ), 143 | ) 144 | 145 | cocoop_dept = dict( 146 | gpu_ids = [0], 147 | mode='b2n', 148 | 149 | train = dict( 150 | trainer='ExtrasLinearProbeCoCoOp', 151 | cfg='vit_b16_c4_ep10_batch1_ctxv1', 152 | seeds=[1, 2, 3], 153 | loadep=-1, 154 | shots=16, 155 | opts=[], 156 | ), 157 | 158 | grid_search = dict(enable=False), 159 | 160 | output = dict( 161 | root='outputs/cocoop_dept', 162 | result='results/cocoop_dept', 163 | remove_dirs=['root'], 164 | ), 165 | ) 166 | 167 | kgcoop_dept = dict( 168 | gpu_ids = [0], 169 | mode='b2n', 170 | 171 | train = dict( 172 | trainer='ExtrasLinearProbeKgCoOp', 173 | cfg='vit_b16_ep10_ctxv1_bs4_lr35', 174 | seeds=[1, 2, 3], 175 | loadep=-1, 176 | shots=16, 177 | opts=[], 178 | ), 179 | 180 | grid_search = dict(enable=False), 181 | 182 | output = dict( 183 | root='outputs/kgcoop_dept', 184 | result='results/kgcoop_dept', 185 | remove_dirs=['root'], 186 | ), 187 | ) 188 | 189 | maple_dept = 
dict( 190 | gpu_ids = [0], 191 | mode='b2n', 192 | 193 | train = dict( 194 | trainer='ExtrasLinearProbeMaPLe', 195 | cfg='vit_b16_c2_ep10_batch4_2ctx', 196 | seeds=[1, 2, 3], 197 | loadep=-1, 198 | shots=16, 199 | opts=[], 200 | ), 201 | 202 | grid_search = dict(enable=False), 203 | 204 | output = dict( 205 | root='outputs/maple_dept', 206 | result='results/maple_dept', 207 | remove_dirs=['root'], 208 | ), 209 | ) 210 | 211 | def get_config(name): 212 | cfg = copy.deepcopy(base) 213 | extend_cfg = copy.deepcopy(globals()[name]) 214 | cfg.update(extend_cfg) 215 | return cfg 216 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DePT: Decoupled Prompt Tuning 2 | 3 | Our DePT establishes the following remarkable results without borrowing Extra Knowledge from stronger models (e.g., via KD) or employing extra Data Augmentation strategies. 4 | 5 | Official implementation of the paper [DePT: Decoupled Prompt Tuning](https://arxiv.org/abs/2309.07439). 6 | 7 | **Note:** We are doing our best to improve this work. If you have any questions or suggestions, please feel free to create an issue in this repo or contact us at jizhang.jim@gmail.com. 8 | 9 | ---- 10 | 11 | # News 12 | 13 | - (Feb. 27, 2024) 14 | 15 | - Our paper is accepted at CVPR 2024! 16 | 17 | - (Nov. 05, 2023) 18 | 19 | - Training and evaluation code for DePT is released. 20 | 21 | - (Sep. 14, 2023) 22 | 23 | - Our paper is published on arXiv. 24 | 25 | ---- 26 | 27 | # Highlights 28 | 29 | > **Abstract** Prompt tuning has shown great success in adapting large vision-language pre-trained models to downstream tasks. A plethora of methods have been proposed to tackle the base-new tradeoff (BNT) dilemma, i.e., the better the adapted model generalizes to the base (a.k.a. target) task, the worse it generalizes to new tasks, and vice versa. Despite this, the BNT problem is still far from being resolved and its underlying mechanisms are poorly understood. In this work, we bridge this gap by proposing Decoupled Prompt Tuning (DePT), a first framework tackling the BNT problem from a feature decoupling perspective. Specifically, through an in-depth analysis of the learned features of the base and new tasks, we observe that the BNT stems from a channel bias issue, i.e., the vast majority of feature channels are occupied by base-specific knowledge, resulting in the collapse of task-shared knowledge important to new tasks. To address this, DePT decouples base-specific knowledge from feature channels into an isolated feature space during prompt tuning, so as to maximally preserve task-shared knowledge in the original feature space for achieving better zero-shot generalization on new tasks. DePT is orthogonal to existing prompt tuning methods, hence it can tackle the BNT problem for all of them. Extensive experiments on 11 datasets show the strong flexibility and effectiveness of DePT. 30 | 31 | ![Framework](examples/framework.png) 32 | 33 | ---- 34 | 35 | # Main Contributions 36 | 37 | > 1. We provide an insightful view to analyze the BNT problem in prompt tuning, and for the first time reveal that the BNT stems from the channel bias issue. 38 | > 2. We propose the DePT framework to tackle the BNT problem from a feature decoupling perspective, and DePT is orthogonal to existing prompt tuning methods. 39 | > 3. 
We perform experiments on 11 diverse datasets and show that DePT consistently enhances the performance of a broad spectrum of baseline methods. 40 | 41 | ---- 42 | 43 | # Flexibility and Effectiveness 44 | 45 | Our DePT is orthogonal to both prompt tuning and adapter tuning approaches, and can therefore be used as a plugin to improve all of them. 46 | 47 |
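As a rough intuition for what the decoupling means in practice, the toy snippet below (not the repository's implementation; all tensors are random placeholders) keeps the usual prompt-branch logits over all classes in the original feature space, while a separate linear head is fitted on the base classes only, so base-specific learning does not have to overwrite the shared feature channels.

```
import torch
import torch.nn as nn

feat_dim, n_base, n_new = 512, 5, 5

image_feat = torch.randn(4, feat_dim)               # image features from the (prompt-tuned) encoder
text_feat = torch.randn(n_base + n_new, feat_dim)   # text features of all class prompts

# Prompt branch: cosine-similarity logits in the original, task-shared feature space.
img = image_feat / image_feat.norm(dim=-1, keepdim=True)
txt = text_feat / text_feat.norm(dim=-1, keepdim=True)
prompt_logits = img @ txt.t()                        # (4, n_base + n_new), usable for new classes

# Decoupled branch: an extra linear head trained on base classes only.
base_head = nn.Linear(feat_dim, n_base)
base_logits = base_head(image_feat)                  # (4, n_base), used for the base task

print(prompt_logits.shape, base_logits.shape)
```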
48 | 49 |
50 | 51 | **Base-to-New Generalization Performance** 52 | 53 | ![Base-to-New Generalization](examples/base_to_new_performance.png) 54 | 55 | **Cross-Dataset Generalization Performance** 56 | 57 | ![Cross-Dataset Generalization](examples/cross_dataset_performance.png) 58 | 59 | ---- 60 | 61 | # Installation 62 | 63 | This codebase is tested on Ubuntu 20.04.2 LTS with Python 3.8. Follow the steps below to create the environment and install dependencies. 64 | 65 | Set up a conda environment (recommended). 66 | 67 | **Create a conda environment** 68 | 69 | ``` 70 | conda create -y -n dept python=3.8 71 | conda activate dept 72 | ``` 73 | 74 | **Install torch (requires version >= 1.8.1) and torchvision** 75 | 76 | ``` 77 | pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html 78 | ``` 79 | 80 | **Install dassl** 81 | 82 | ``` 83 | git clone https://github.com/KaiyangZhou/Dassl.pytorch.git 84 | cd Dassl.pytorch/ 85 | pip install -r requirements.txt 86 | python setup.py develop 87 | ``` 88 | 89 | **Install DePT** 90 | 91 | ``` 92 | cd .. 93 | 94 | git clone https://github.com/koorye/DePT.git 95 | cd DePT/ 96 | 97 | pip install -r requirements.txt 98 | pip install setuptools==59.5.0 99 | ``` 100 | 101 | ---- 102 | 103 | # Data preparation 104 | 105 | Please follow the instructions at [DATASETS.md](datasets/DATASETS.md) to prepare all datasets. 106 | 107 | ---- 108 | 109 | # Training and Evaluation 110 | 111 | We provide the parallel running script `parallel_runner.py` for each prompting variant, including CoOp (w/ DePT), CoCoOp (w/ DePT), KgCoOp (w/ DePT) and MaPLe (w/ DePT). Make sure to configure the dataset paths in the environment variable DATA and run the commands from the main directory. 112 | 113 | **Base to New Generalization** 114 | 115 | ``` 116 | # Running CoOp (w/ DePT) 117 | python parallel_runner.py --cfg coop 118 | python parallel_runner.py --cfg coop_dept 119 | 120 | # Running CoCoOp (w/ DePT) 121 | python parallel_runner.py --cfg cocoop 122 | python parallel_runner.py --cfg cocoop_dept 123 | 124 | # Running KgCoOp (w/ DePT) 125 | python parallel_runner.py --cfg kgcoop 126 | python parallel_runner.py --cfg kgcoop_dept 127 | 128 | # Running MaPLe (w/ DePT) 129 | python parallel_runner.py --cfg maple 130 | python parallel_runner.py --cfg maple_dept 131 | ``` 132 | 133 | After running, the outputs will be in the `outputs/` directory, the results will be tallied in the `results/` directory as CSV files, and an email will be sent to the address configured in `configs.py`. 134 | 135 | If you want to add your own models, write them in the `trainers/` directory and register them in dassl, configure the settings in the `configs/` directory and the `train.py` file, and add your new tasks to the `configs.py` file. Then you can run `python parallel_runner.py --cfg your_model` to run your own model, as sketched below. 
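The snippet below is a minimal, hypothetical sketch of these steps, not part of the repository: `MyTrainer`, `my_trainer` and `your_cfg_name` are placeholder names, and the skeleton only assumes the dassl conventions already used by the existing trainers in `trainers/`. Remember to import the new module in `train.py` so the registration actually runs.

```
# trainers/my_trainer.py (hypothetical file)
from dassl.engine import TRAINER_REGISTRY, TrainerX


@TRAINER_REGISTRY.register()  # makes "MyTrainer" resolvable by dassl's build_trainer(cfg)
class MyTrainer(TrainerX):
    """Skeleton only: fill in the prompt-tuning logic of your method."""

    def build_model(self):
        # Build CLIP plus your learnable modules and register them
        # (e.g., via self.register_model) so optimizers and checkpoints are handled.
        raise NotImplementedError

    def forward_backward(self, batch):
        # Compute the loss and update the model, e.g. with self.model_backward_and_update(loss).
        raise NotImplementedError


# configs.py (hypothetical new task entry, mirroring the existing ones)
my_trainer = dict(
    gpu_ids=[0],
    mode='b2n',
    train=dict(trainer='MyTrainer', cfg='your_cfg_name',
               seeds=[1, 2, 3], loadep=-1, shots=16, opts=[]),
    grid_search=dict(enable=False),
    output=dict(root='outputs/my_trainer', result='results/my_trainer',
                remove_dirs=['root']),
)
```

After that, `python parallel_runner.py --cfg my_trainer` should pick up the new task, provided a matching YAML config exists under `configs/trainers/`.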
136 | 137 | ---- 138 | 139 | # Citation 140 | 141 | If you use our work, please consider citing 142 | 143 | ``` 144 | @inproceedings{zhang2024dept, 145 | title={Dept: Decoupled prompt tuning}, 146 | author={Zhang, Ji and Wu, Shihan and Gao, Lianli and Shen, Heng Tao and Song, Jingkuan}, 147 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 148 | pages={12924--12933}, 149 | year={2024} 150 | }v 151 | ``` 152 | 153 | ---- 154 | 155 | # Acknowledgements 156 | 157 | Our code is based on [CoOp, CoCoOp](https://github.com/KaiyangZhou/CoOp), [KgCoOp](https://github.com/htyao89/KgCoOp) and [MaPLe](https://github.com/muzairkhattak/multimodal-prompt-learning) repositories. We thank the authors for releasing their code. If you use our model and code, please consider citing these works as well. 158 | -------------------------------------------------------------------------------- /tests/channel_importance.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import os 4 | import os.path as osp 5 | import pandas as pd 6 | import patchworklib as pw 7 | import pickle 8 | import shutil 9 | import torch 10 | from plotnine import * 11 | from tqdm import tqdm 12 | 13 | 14 | def read_feats(root, trainer, dataset, seed): 15 | filename = f'feats_{trainer}_{dataset}_seed{seed}.pkl' 16 | path = osp.join(root, filename) 17 | with open(path, 'rb') as f: 18 | feats = pickle.load(f) 19 | return feats 20 | 21 | def filter_labels(feats, labels_reserved): 22 | labels = feats['label'].copy() 23 | text_feats = feats['text'].copy() 24 | img_feats = feats['img'].copy() 25 | 26 | flag = [label in labels_reserved for label in labels] 27 | labels = labels[flag] 28 | img_feats = img_feats[flag] 29 | text_feats = text_feats[labels_reserved] 30 | 31 | label_to_new_label = {label: idx for idx, label in enumerate(labels_reserved)} 32 | labels = np.array([label_to_new_label[label] for label in labels]) 33 | 34 | feats = dict( 35 | label=labels, 36 | text=text_feats, 37 | img=img_feats) 38 | 39 | return feats 40 | 41 | def cal_importance_per_channel(feats, channel_inds): 42 | labels = torch.from_numpy(feats['label']).long().cuda() 43 | img_feats = torch.from_numpy(feats['img']).float().cuda() 44 | text_feats = torch.from_numpy(feats['text']).float().cuda() 45 | 46 | img_feats = img_feats[:, channel_inds] 47 | text_feats = text_feats[:, channel_inds] 48 | # (N, D), (C, D) -> (N, D), (D, C) -> (N, C) 49 | similarities = img_feats @ text_feats.t() 50 | similarities = similarities.clamp_min_(0.) 
51 | # (N, C) -> (N,) 52 | similarities_gt = torch.gather(similarities, 1, labels.unsqueeze(-1)).squeeze(-1) 53 | # (N, C) -> (N,) 54 | # (N,), (N,) -> (N,) -> scaler 55 | importance = similarities_gt / (similarities.mean(dim=1) + 1e-12) 56 | return importance.mean().item() 57 | 58 | def cal_importance(feats): 59 | importances = [] 60 | for channel_ind_start in tqdm(range(0, 512)): 61 | channel_ind_end = channel_ind_start + 64 62 | channel_inds = list(range(channel_ind_start, channel_ind_end)) 63 | channel_inds = [ind - 512 if ind >= 512 else ind for ind in channel_inds] 64 | importance = cal_importance_per_channel(feats, channel_inds) 65 | importances.append(importance) 66 | 67 | return np.array(importances) 68 | 69 | 70 | ROOT = 'stats/' 71 | TRAINERS = ['OracleStats', 'CoOpStats'] 72 | DATASETS = ['EuroSAT'] 73 | SEEDS = [1, 2, 3] 74 | SAVE_ROOT = 'viz/' 75 | 76 | 77 | def main(): 78 | if osp.exists(SAVE_ROOT): 79 | shutil.rmtree(SAVE_ROOT) 80 | os.makedirs(SAVE_ROOT) 81 | 82 | all_plots_density, all_plots_point = [], [] 83 | 84 | for dataset in DATASETS: 85 | plots_density, plots_point = [], [] 86 | 87 | for seed in SEEDS: 88 | print(f'Stating on Dataset {dataset}, seed {seed}...') 89 | dfs = [] 90 | 91 | for trainer in TRAINERS: 92 | feats = read_feats(ROOT, trainer, dataset, seed) 93 | 94 | n = max(feats['label']) + 1 95 | m = math.ceil(n / 2) 96 | base_labels = list(range(0, m)) 97 | novel_labels = list(range(m, n)) 98 | 99 | base_feats = filter_labels(feats, base_labels) 100 | novel_feats = filter_labels(feats, novel_labels) 101 | 102 | base_importances = cal_importance(base_feats) 103 | novel_importances = cal_importance(novel_feats) 104 | 105 | df = pd.DataFrame({ 106 | 'channel_idx': range(len(base_importances)), 107 | 'base': base_importances, 108 | 'novel': novel_importances, 109 | 'trainer': trainer}) 110 | dfs.append(df.copy()) 111 | 112 | df = pd.concat(dfs) 113 | 114 | df['ratio'] = df['base'] / df['novel'] 115 | 116 | p_density = (ggplot(df, aes('ratio', fill='trainer', color='trainer')) 117 | + geom_density(alpha=0.5) 118 | + ggtitle(f'{dataset} seed{seed}') 119 | + theme_seaborn() 120 | + theme(axis_text_x=element_text(angle=0), 121 | axis_text_y=element_text(angle=90), 122 | plot_title=element_text(hjust=0.5), 123 | panel_spacing_x=0.02, 124 | legend_position='bottom', 125 | axis_title_x=element_text(face='bold'), 126 | axis_title_y=element_text(face='bold'), 127 | strip_background=element_blank(), 128 | strip_text_x=element_blank()) 129 | + labs(x='Ratio', y='Density', color='')) 130 | 131 | plots_density.append(pw.load_ggplot(p_density)) 132 | 133 | p_point = [] 134 | for idx, df in enumerate(dfs): 135 | trainer = df['trainer'].tolist()[0] 136 | if idx == 1: 137 | title = f'{dataset} seed{seed}\n{trainer}' 138 | else: 139 | title = trainer 140 | 141 | df = df.sort_values('base') 142 | df['order'] = range(len(df)) 143 | df = df.sort_values('channel_idx') 144 | 145 | df = pd.melt(df, id_vars=['channel_idx', 'trainer', 'order'], value_vars=['base', 'novel']) 146 | 147 | p = (ggplot(df, aes('reorder(channel_idx, order)', 'value', color='variable')) 148 | + geom_point() 149 | + ggtitle(title) 150 | + theme_seaborn() 151 | + theme(axis_text_x=element_blank(), 152 | plot_title=element_text(hjust=0.5), 153 | legend_position='bottom', 154 | axis_title_x=element_text(face='bold'), 155 | axis_title_y=element_text(face='bold')) 156 | + labs(x='Channel', y='Importance', color='')) 157 | 158 | p_point.append(pw.load_ggplot(p, figsize=(4, 6))) 159 | 160 | p_point = 
pw.stack(p_point, operator='|') 161 | plots_point.append(p_point) 162 | 163 | all_plots_density.append(pw.stack(plots_density, operator='|')) 164 | all_plots_point.append(pw.stack(plots_point, operator='|')) 165 | 166 | p_density = pw.stack(all_plots_density, operator='/') 167 | p_point = pw.stack(all_plots_point, operator='/') 168 | p_density.savefig(f'{SAVE_ROOT}/all_density_channel.jpg') 169 | p_point.savefig(f'{SAVE_ROOT}/all_point_channel.jpg') 170 | 171 | 172 | if __name__ == '__main__': 173 | main() 174 | -------------------------------------------------------------------------------- /utils/result_parser.py: -------------------------------------------------------------------------------- 1 | # result parser for base to new and cross dataset tasks, result will be saved as csv 2 | import argparse 3 | import os 4 | import os.path as osp 5 | import pandas as pd 6 | import re 7 | 8 | 9 | ORDERS_BASE_TO_NEW = ['imagenet', 'caltech101', 'oxford_pets', 'stanford_cars', 'oxford_flowers', 10 | 'food101', 'fgvc_aircraft', 'sun397', 'dtd', 'eurosat', 'ucf101'] 11 | 12 | ORDERS_CROSS_DATASET = ['imagenet', 'caltech101', 'oxford_pets', 'stanford_cars', 'oxford_flowers', 13 | 'food101', 'fgvc_aircraft', 'sun397', 'dtd', 'eurosat', 'ucf101', 14 | 'imagenetv2', 'imagenet_sketch', 'imagenet_a', 'imagenet_r',] 15 | 16 | 17 | class ResultParser(object): 18 | def __init__(self, mode, dir_, save_path): 19 | self.mode = mode 20 | self.dir_ = dir_ 21 | self.save_path = save_path 22 | 23 | def parse_and_save(self): 24 | if self.mode == 'b2n': 25 | self.read_accs_base_to_new() 26 | elif self.mode == 'xd': 27 | self.read_accs_cross_dataset() 28 | 29 | self.save() 30 | 31 | def load_property(self, dir_): 32 | """ get property (trainer, datasets, num_shots, cfg, seeds) from directory """ 33 | trainer = [subdir for subdir in os.listdir(dir_) if osp.isdir(osp.join(dir_, subdir))][0] 34 | 35 | dir_ = osp.join(dir_, trainer) 36 | datasets = os.listdir(dir_) 37 | 38 | if self.mode == 'b2n': 39 | datasets = [dataset for dataset in ORDERS_BASE_TO_NEW if dataset in datasets] 40 | elif self.mode == 'xd': 41 | datasets = [dataset for dataset in ORDERS_CROSS_DATASET if dataset in datasets] 42 | else: 43 | raise NotImplementedError 44 | 45 | dir_ = osp.join(dir_, datasets[0]) 46 | num_shots = int(os.listdir(dir_)[0][5:]) 47 | 48 | dir_ = osp.join(dir_, f'shots{num_shots}') 49 | cfg = os.listdir(dir_)[0] 50 | 51 | dir_ = osp.join(dir_, cfg) 52 | seeds = list(sorted([int(name[4:]) for name in os.listdir(dir_)])) 53 | 54 | self.prop = dict( 55 | trainer=trainer, 56 | datasets=datasets, 57 | num_shots=num_shots, 58 | cfg=cfg, 59 | seeds=seeds) 60 | 61 | def read_accs_base_to_new(self): 62 | dir_ = self.dir_ 63 | 64 | base_dir = osp.join(dir_, 'train_base') 65 | new_dir = osp.join(dir_, 'test_new') 66 | 67 | self.load_property(base_dir) 68 | prop = self.prop 69 | 70 | trainer = prop['trainer'] 71 | datasets = prop['datasets'] 72 | num_shots = prop['num_shots'] 73 | cfg = prop['cfg'] 74 | seeds = prop['seeds'] 75 | 76 | headers = ['dataset', 77 | 'base_acc_seed1', 'new_acc_seed1', 'H_seed1', 78 | 'base_acc_seed2', 'new_acc_seed2', 'H_seed2', 79 | 'base_acc_seed3', 'new_acc_seed3', 'H_seed3'] 80 | rows = [] 81 | 82 | for dataset in datasets: 83 | row = [dataset] 84 | 85 | for seed in seeds: 86 | base_path = osp.join(base_dir, trainer, dataset, f'shots{num_shots}', cfg, f'seed{seed}', 'log.txt') 87 | new_path = osp.join(new_dir, trainer, dataset, f'shots{num_shots}', cfg, f'seed{seed}', 'log.txt') 88 | 89 | base_acc = 
self._read_acc(base_path) 90 | new_acc = self._read_acc(new_path) 91 | H = 2 / (1 / base_acc + 1 / new_acc) 92 | 93 | row += [base_acc, new_acc, H] 94 | 95 | rows.append(row) 96 | 97 | df = pd.DataFrame(rows, columns=headers) 98 | df['base_acc'] = (df['base_acc_seed1'] + df['base_acc_seed2'] + df['base_acc_seed3']) / 3 99 | df['new_acc'] = (df['new_acc_seed1'] + df['new_acc_seed2'] + df['new_acc_seed3']) / 3 100 | 101 | df.loc[len(df.index)] = ['average'] + df.drop(columns=['dataset']).mean().tolist() 102 | df['H'] = 2 / (1 / df['base_acc'] + 1 / df['new_acc']) 103 | 104 | self.df = df 105 | 106 | 107 | def read_accs_cross_dataset(self): 108 | dir_ = self.dir_ 109 | 110 | self.load_property(dir_) 111 | prop = self.prop 112 | 113 | trainer = prop['trainer'] 114 | datasets = prop['datasets'] 115 | num_shots = prop['num_shots'] 116 | cfg = prop['cfg'] 117 | seeds = prop['seeds'] 118 | 119 | headers = ['dataset', 'acc_seed1', 'acc_seed2', 'acc_seed3'] 120 | rows = [] 121 | 122 | datasets = [dataset for dataset in ORDERS_CROSS_DATASET if dataset in datasets] 123 | for dataset in datasets: 124 | row = [dataset] 125 | 126 | for seed in seeds: 127 | path = osp.join(dir_, trainer, dataset, f'shots{num_shots}', cfg, f'seed{seed}', 'log.txt') 128 | acc = self._read_acc(path) 129 | row.append(acc) 130 | 131 | rows.append(row) 132 | 133 | df = pd.DataFrame(rows, columns=headers) 134 | df['acc'] = (df['acc_seed1'] + df['acc_seed2'] + df['acc_seed3']) / 3 135 | 136 | dg_datasets = [dataset for dataset in datasets 137 | if 'imagenet' in dataset and dataset != 'imagenet'] 138 | xd_datasets = [dataset for dataset in datasets 139 | if dataset not in dg_datasets and dataset != 'imagenet'] 140 | 141 | dg_df = df.loc[df['dataset'].isin(dg_datasets)].copy().reset_index(drop=True) 142 | xd_df = df.loc[df['dataset'].isin(xd_datasets)].copy().reset_index(drop=True) 143 | img_net_df = df.loc[df['dataset'] == 'imagenet'].copy().reset_index(drop=True) 144 | 145 | dg_df.loc[len(dg_df.index)] = ['average_dg'] + dg_df.drop(columns=['dataset']).mean().tolist() 146 | xd_df.loc[len(xd_df.index)] = ['average_xd'] + xd_df.drop(columns=['dataset']).mean().tolist() 147 | 148 | df = pd.concat([img_net_df, xd_df, dg_df]).reset_index(drop=True) 149 | 150 | self.df = df 151 | 152 | def save(self): 153 | save_path = self.save_path 154 | save_dir = osp.join(*save_path.replace('\\', '/').split('/')[:-1]) 155 | 156 | os.makedirs(save_dir, exist_ok=True) 157 | self.df.round(2).to_csv(save_path, index=None) 158 | 159 | def _read_acc(self, path): 160 | with open(path, encoding='utf-8') as f: 161 | content = ''.join(f.readlines()) 162 | try: 163 | acc = float(re.findall(r'accuracy\: (\d+\.\d*)\%', content)[-1]) 164 | return acc 165 | except BaseException as e: 166 | print(f'Key word "accuracy" not found in file {path}!') 167 | raise e 168 | 169 | 170 | def main(args): 171 | parser = ResultParser(args.mode, args.dir, args.save_path) 172 | parser.parse_and_save() 173 | 174 | 175 | if __name__ == '__main__': 176 | parser = argparse.ArgumentParser() 177 | parser.add_argument('--mode', type=str, help='mode, b2n or xd') 178 | parser.add_argument('--dir', type=str, help='directory which need to stats') 179 | parser.add_argument('--save-path', type=str, help='directory to save statistics') 180 | args = parser.parse_args() 181 | main(args) 182 | -------------------------------------------------------------------------------- /datasets/oxford_pets.py: -------------------------------------------------------------------------------- 1 | import os 2 | 
import pickle 3 | import math 4 | import random 5 | from collections import defaultdict 6 | 7 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 8 | from dassl.utils import read_json, write_json, mkdir_if_missing 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class OxfordPets(DatasetBase): 13 | 14 | dataset_dir = "oxford_pets" 15 | 16 | def __init__(self, cfg): 17 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 18 | self.dataset_dir = os.path.join(root, self.dataset_dir) 19 | self.image_dir = os.path.join(self.dataset_dir, "images") 20 | self.anno_dir = os.path.join(self.dataset_dir, "annotations") 21 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_OxfordPets.json") 22 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 23 | mkdir_if_missing(self.split_fewshot_dir) 24 | 25 | if os.path.exists(self.split_path): 26 | train, val, test = self.read_split(self.split_path, self.image_dir) 27 | else: 28 | trainval = self.read_data(split_file="trainval.txt") 29 | test = self.read_data(split_file="test.txt") 30 | train, val = self.split_trainval(trainval) 31 | self.save_split(train, val, test, self.split_path, self.image_dir) 32 | 33 | num_shots = cfg.DATASET.NUM_SHOTS 34 | if num_shots >= 1: 35 | seed = cfg.SEED 36 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 37 | 38 | if os.path.exists(preprocessed): 39 | print(f"Loading preprocessed few-shot data from {preprocessed}") 40 | with open(preprocessed, "rb") as file: 41 | data = pickle.load(file) 42 | train, val = data["train"], data["val"] 43 | else: 44 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 45 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 46 | data = {"train": train, "val": val} 47 | print(f"Saving preprocessed few-shot data to {preprocessed}") 48 | with open(preprocessed, "wb") as file: 49 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 50 | 51 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 52 | train, val, test = self.subsample_classes(train, val, test, subsample=subsample) 53 | 54 | super().__init__(train_x=train, val=val, test=test) 55 | 56 | def read_data(self, split_file): 57 | filepath = os.path.join(self.anno_dir, split_file) 58 | items = [] 59 | 60 | with open(filepath, "r") as f: 61 | lines = f.readlines() 62 | for line in lines: 63 | line = line.strip() 64 | imname, label, species, _ = line.split(" ") 65 | breed = imname.split("_")[:-1] 66 | breed = "_".join(breed) 67 | breed = breed.lower() 68 | imname += ".jpg" 69 | impath = os.path.join(self.image_dir, imname) 70 | label = int(label) - 1 # convert to 0-based index 71 | item = Datum(impath=impath, label=label, classname=breed) 72 | items.append(item) 73 | 74 | return items 75 | 76 | @staticmethod 77 | def split_trainval(trainval, p_val=0.2): 78 | p_trn = 1 - p_val 79 | print(f"Splitting trainval into {p_trn:.0%} train and {p_val:.0%} val") 80 | tracker = defaultdict(list) 81 | for idx, item in enumerate(trainval): 82 | label = item.label 83 | tracker[label].append(idx) 84 | 85 | train, val = [], [] 86 | for label, idxs in tracker.items(): 87 | n_val = round(len(idxs) * p_val) 88 | assert n_val > 0 89 | random.shuffle(idxs) 90 | for n, idx in enumerate(idxs): 91 | item = trainval[idx] 92 | if n < n_val: 93 | val.append(item) 94 | else: 95 | train.append(item) 96 | 97 | return train, val 98 | 99 | @staticmethod 100 | def save_split(train, val, test, filepath, path_prefix): 101 | def _extract(items): 102 | out 
= [] 103 | for item in items: 104 | impath = item.impath 105 | label = item.label 106 | classname = item.classname 107 | impath = impath.replace(path_prefix, "") 108 | if impath.startswith("/"): 109 | impath = impath[1:] 110 | out.append((impath, label, classname)) 111 | return out 112 | 113 | train = _extract(train) 114 | val = _extract(val) 115 | test = _extract(test) 116 | 117 | split = {"train": train, "val": val, "test": test} 118 | 119 | write_json(split, filepath) 120 | print(f"Saved split to {filepath}") 121 | 122 | @staticmethod 123 | def read_split(filepath, path_prefix): 124 | def _convert(items): 125 | out = [] 126 | for impath, label, classname in items: 127 | impath = os.path.join(path_prefix, impath) 128 | item = Datum(impath=impath, label=int(label), classname=classname) 129 | out.append(item) 130 | return out 131 | 132 | print(f"Reading split from {filepath}") 133 | split = read_json(filepath) 134 | train = _convert(split["train"]) 135 | val = _convert(split["val"]) 136 | test = _convert(split["test"]) 137 | 138 | return train, val, test 139 | 140 | @staticmethod 141 | def subsample_classes(*args, subsample="all"): 142 | """Divide classes into two groups. The first group 143 | represents base classes while the second group represents 144 | new classes. 145 | 146 | Args: 147 | args: a list of datasets, e.g. train, val and test. 148 | subsample (str): what classes to subsample. 149 | """ 150 | assert subsample in ["all", "base", "new"] 151 | 152 | if subsample == "all": 153 | return args 154 | 155 | dataset = args[0] 156 | labels = set() 157 | for item in dataset: 158 | labels.add(item.label) 159 | labels = list(labels) 160 | labels.sort() 161 | n = len(labels) 162 | # Divide classes into two halves 163 | m = math.ceil(n / 2) 164 | 165 | print(f"SUBSAMPLE {subsample.upper()} CLASSES!") 166 | if subsample == "base": 167 | selected = labels[:m] # take the first half 168 | else: 169 | selected = labels[m:] # take the second half 170 | relabeler = {y: y_new for y_new, y in enumerate(selected)} 171 | 172 | output = [] 173 | for dataset in args: 174 | dataset_new = [] 175 | for item in dataset: 176 | if item.label not in selected: 177 | continue 178 | item_new = Datum( 179 | impath=item.impath, 180 | label=relabeler[item.label], 181 | classname=item.classname 182 | ) 183 | dataset_new.append(item_new) 184 | output.append(dataset_new) 185 | 186 | return output 187 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('.') 3 | 4 | import argparse 5 | import os 6 | import os.path as osp 7 | import traceback 8 | import torch 9 | 10 | from dassl.utils import set_random_seed, collect_env_info 11 | from dassl.config import get_cfg_default, clean_cfg 12 | from dassl.engine import build_trainer 13 | 14 | from utils.logger import setup_logger, print 15 | 16 | # register datasets and trainers 17 | import datasets 18 | import trainers 19 | 20 | 21 | def print_args(args, cfg): 22 | print("***************") 23 | print("** Arguments **") 24 | print("***************") 25 | optkeys = list(args.__dict__.keys()) 26 | optkeys.sort() 27 | for key in optkeys: 28 | print("{}: {}".format(key, args.__dict__[key])) 29 | print("************") 30 | print("** Config **") 31 | print("************") 32 | print(cfg) 33 | 34 | 35 | def reset_cfg(cfg, args): 36 | if args.root: 37 | cfg.DATASET.ROOT = args.root 38 | 39 | if args.output_dir: 40 | cfg.OUTPUT_DIR = 
args.output_dir 41 | 42 | if args.resume: 43 | cfg.RESUME = args.resume 44 | 45 | if args.seed: 46 | cfg.SEED = args.seed 47 | 48 | if args.source_domains: 49 | cfg.DATASET.SOURCE_DOMAINS = args.source_domains 50 | 51 | if args.target_domains: 52 | cfg.DATASET.TARGET_DOMAINS = args.target_domains 53 | 54 | if args.transforms: 55 | cfg.INPUT.TRANSFORMS = args.transforms 56 | 57 | if args.trainer: 58 | cfg.TRAINER.NAME = args.trainer 59 | 60 | if args.backbone: 61 | cfg.MODEL.BACKBONE.NAME = args.backbone 62 | 63 | if args.head: 64 | cfg.MODEL.HEAD.NAME = args.head 65 | 66 | 67 | def extend_cfg(cfg): 68 | from yacs.config import CfgNode as CN 69 | 70 | # optim settings, new layers' lr will be setted to 0.0035 * 6.5 71 | cfg.OPTIM.LR_EXP = 6.5 72 | cfg.OPTIM.STAGED_LR = True 73 | cfg.OPTIM.NEW_LAYERS = ['linear_probe', 'film'] 74 | cfg.OPTIM.LR = 0.0035 75 | cfg.OPTIM.BASE_LR_MULT = 1.0 76 | cfg.OPTIM.MAX_EPOCH = 10 77 | 78 | cfg.DATASET.SUBSAMPLE_CLASSES = "all" # all, base or new 79 | cfg.TRAIN.CHECKPOINT_FREQ = -1 80 | 81 | # modules which need to update 82 | cfg.TRAINER.NAMES_TO_UPDATE = ['prompt_learner', 'linear_probe', 'film'] 83 | 84 | # linear classifier settings 85 | cfg.TRAINER.LINEAR_PROBE = CN() 86 | cfg.TRAINER.LINEAR_PROBE.TYPE = 'linear' 87 | cfg.TRAINER.LINEAR_PROBE.WEIGHT = 0.7 88 | cfg.TRAINER.LINEAR_PROBE.TEST_TIME_FUSION = True 89 | 90 | # cwT module settings 91 | cfg.TRAINER.FILM = CN() 92 | cfg.TRAINER.FILM.LINEAR_PROBE = True 93 | 94 | # CoOp settings 95 | cfg.TRAINER.COOP = CN() 96 | cfg.TRAINER.COOP.N_CTX = 16 # number of context vectors 97 | cfg.TRAINER.COOP.CSC = False # class-specific context 98 | cfg.TRAINER.COOP.CTX_INIT = "" # initialization words 99 | cfg.TRAINER.COOP.PREC = "fp16" # fp16, fp32, amp 100 | cfg.TRAINER.COOP.CLASS_TOKEN_POSITION = "end" 101 | cfg.TRAINER.COOP.ALPHA = 1.0 # for KgCoOp but NOT USE 102 | cfg.TRAINER.COOP.W = 2.0 # for KgCoOp 103 | 104 | # CoCoOp settings 105 | cfg.TRAINER.COCOOP = CN() 106 | cfg.TRAINER.COCOOP.N_CTX = 4 # number of context vectors 107 | cfg.TRAINER.COCOOP.CTX_INIT = "a photo of a" # initialization words 108 | cfg.TRAINER.COCOOP.PREC = "fp16" # fp16, fp32, amp 109 | 110 | # MaPLe settings 111 | cfg.TRAINER.MAPLE = CN() 112 | cfg.TRAINER.MAPLE.N_CTX = 2 # number of context vectors 113 | cfg.TRAINER.MAPLE.CTX_INIT = "a photo of a" # initialization words 114 | cfg.TRAINER.MAPLE.PREC = "fp16" # fp16, fp32, amp 115 | cfg.TRAINER.MAPLE.PROMPT_DEPTH = 9 # Max 12, minimum 0, for 1 it will act as shallow MaPLe (J=1) 116 | 117 | 118 | def setup_cfg(args): 119 | cfg = get_cfg_default() 120 | 121 | clean_cfg(cfg, 'COOP') 122 | 123 | extend_cfg(cfg) 124 | 125 | # 1. From the dataset config file 126 | if args.dataset_config_file: 127 | cfg.merge_from_file(args.dataset_config_file) 128 | 129 | # 2. From the method config file 130 | if args.config_file: 131 | cfg.merge_from_file(args.config_file) 132 | 133 | # 3. From input arguments 134 | reset_cfg(cfg, args) 135 | 136 | # 4. 
From optional input arguments 137 | cfg.merge_from_list(args.opts) 138 | 139 | cfg.freeze() 140 | 141 | return cfg 142 | 143 | 144 | def main(args): 145 | exception_path = osp.join(args.output_dir, 'exceptions.txt') 146 | if osp.exists(exception_path): 147 | os.remove(exception_path) 148 | 149 | try: 150 | cfg = setup_cfg(args) 151 | setup_logger(cfg.OUTPUT_DIR) 152 | 153 | if cfg.SEED >= 0: 154 | print("Setting fixed seed: {}".format(cfg.SEED)) 155 | set_random_seed(cfg.SEED) 156 | 157 | if torch.cuda.is_available() and cfg.USE_CUDA: 158 | torch.backends.cudnn.benchmark = True 159 | 160 | print_args(args, cfg) 161 | print("Collecting env info ...") 162 | print("** System info **\n{}\n".format(collect_env_info())) 163 | 164 | trainer = build_trainer(cfg) 165 | 166 | if args.eval_only: 167 | trainer.load_model(args.model_dir, epoch=args.load_epoch) 168 | trainer.test() 169 | return 170 | 171 | if not args.no_train: 172 | trainer.train() 173 | except: 174 | # handle exception, contents of exception will be saved to exception_path 175 | e = traceback.format_exc() 176 | with open(exception_path, 'w') as f: 177 | f.write(e) 178 | raise Exception('Training task does not run successfully!') 179 | 180 | 181 | if __name__ == "__main__": 182 | parser = argparse.ArgumentParser() 183 | parser.add_argument("--root", type=str, default="", help="path to dataset") 184 | parser.add_argument("--output-dir", type=str, default="", help="output directory") 185 | parser.add_argument( 186 | "--resume", 187 | type=str, 188 | default="", 189 | help="checkpoint directory (from which the training resumes)", 190 | ) 191 | parser.add_argument( 192 | "--seed", type=int, default=-1, help="only positive value enables a fixed seed" 193 | ) 194 | parser.add_argument( 195 | "--source-domains", type=str, nargs="+", help="source domains for DA/DG" 196 | ) 197 | parser.add_argument( 198 | "--target-domains", type=str, nargs="+", help="target domains for DA/DG" 199 | ) 200 | parser.add_argument( 201 | "--transforms", type=str, nargs="+", help="data augmentation methods" 202 | ) 203 | parser.add_argument( 204 | "--config-file", type=str, default="", help="path to config file" 205 | ) 206 | parser.add_argument( 207 | "--dataset-config-file", 208 | type=str, 209 | default="", 210 | help="path to config file for dataset setup", 211 | ) 212 | parser.add_argument("--trainer", type=str, default="", help="name of trainer") 213 | parser.add_argument("--backbone", type=str, default="", help="name of CNN backbone") 214 | parser.add_argument("--head", type=str, default="", help="name of head") 215 | parser.add_argument("--eval-only", action="store_true", help="evaluation only") 216 | parser.add_argument( 217 | "--model-dir", 218 | type=str, 219 | default="", 220 | help="load model from this directory for eval-only mode", 221 | ) 222 | parser.add_argument( 223 | "--load-epoch", type=int, help="load model weights at this epoch for evaluation" 224 | ) 225 | parser.add_argument( 226 | "--no-train", action="store_true", help="do not call trainer.train()" 227 | ) 228 | parser.add_argument( 229 | "opts", 230 | default=None, 231 | nargs=argparse.REMAINDER, 232 | help="modify config options using the command-line", 233 | ) 234 | args = parser.parse_args() 235 | main(args) 236 | -------------------------------------------------------------------------------- /clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import 
Union, List 6 | 7 | import torch 8 | from PIL import Image 9 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 10 | from tqdm import tqdm 11 | 12 | from .model import build_model 13 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 14 | 15 | try: 16 | from torchvision.transforms import InterpolationMode 17 | BICUBIC = InterpolationMode.BICUBIC 18 | except ImportError: 19 | BICUBIC = Image.BICUBIC 20 | 21 | 22 | if torch.__version__.split(".") < ["1", "7", "1"]: 23 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 24 | 25 | 26 | __all__ = ["available_models", "load", "tokenize"] 27 | _tokenizer = _Tokenizer() 28 | 29 | _MODELS = { 30 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 31 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 32 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 33 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 34 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 35 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 36 | } 37 | 38 | 39 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): 40 | os.makedirs(root, exist_ok=True) 41 | filename = os.path.basename(url) 42 | 43 | expected_sha256 = url.split("/")[-2] 44 | download_target = os.path.join(root, filename) 45 | 46 | if os.path.exists(download_target) and not os.path.isfile(download_target): 47 | raise RuntimeError(f"{download_target} exists and is not a regular file") 48 | 49 | if os.path.isfile(download_target): 50 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 51 | return download_target 52 | else: 53 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 54 | 55 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 56 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: 57 | while True: 58 | buffer = source.read(8192) 59 | if not buffer: 60 | break 61 | 62 | output.write(buffer) 63 | loop.update(len(buffer)) 64 | 65 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 66 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 67 | 68 | return download_target 69 | 70 | 71 | def _transform(n_px): 72 | return Compose([ 73 | Resize(n_px, interpolation=BICUBIC), 74 | CenterCrop(n_px), 75 | lambda image: image.convert("RGB"), 76 | ToTensor(), 77 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 78 | ]) 79 | 80 | 81 | def available_models() -> List[str]: 82 | """Returns the names of available CLIP models""" 83 | return list(_MODELS.keys()) 84 | 85 | 86 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=False): 87 | """Load a CLIP model 88 | 89 | Parameters 90 | ---------- 91 | name : str 92 | A model name listed by `clip.available_models()`, or the path to a model 
checkpoint containing the state_dict 93 | 94 | device : Union[str, torch.device] 95 | The device to put the loaded model 96 | 97 | jit : bool 98 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 99 | 100 | Returns 101 | ------- 102 | model : torch.nn.Module 103 | The CLIP model 104 | 105 | preprocess : Callable[[PIL.Image], torch.Tensor] 106 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 107 | """ 108 | if name in _MODELS: 109 | model_path = _download(_MODELS[name]) 110 | elif os.path.isfile(name): 111 | model_path = name 112 | else: 113 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 114 | 115 | try: 116 | # loading JIT archive 117 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 118 | state_dict = None 119 | except RuntimeError: 120 | # loading saved state dict 121 | if jit: 122 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") 123 | jit = False 124 | state_dict = torch.load(model_path, map_location="cpu") 125 | 126 | if not jit: 127 | model = build_model(state_dict or model.state_dict()).to(device) 128 | if str(device) == "cpu": 129 | model.float() 130 | return model, _transform(model.visual.input_resolution) 131 | 132 | # patch the device names 133 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 134 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 135 | 136 | def patch_device(module): 137 | try: 138 | graphs = [module.graph] if hasattr(module, "graph") else [] 139 | except RuntimeError: 140 | graphs = [] 141 | 142 | if hasattr(module, "forward1"): 143 | graphs.append(module.forward1.graph) 144 | 145 | for graph in graphs: 146 | for node in graph.findAllNodes("prim::Constant"): 147 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 148 | node.copyAttributes(device_node) 149 | 150 | model.apply(patch_device) 151 | patch_device(model.encode_image) 152 | patch_device(model.encode_text) 153 | 154 | # patch dtype to float32 on CPU 155 | if str(device) == "cpu": 156 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 157 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 158 | float_node = float_input.node() 159 | 160 | def patch_float(module): 161 | try: 162 | graphs = [module.graph] if hasattr(module, "graph") else [] 163 | except RuntimeError: 164 | graphs = [] 165 | 166 | if hasattr(module, "forward1"): 167 | graphs.append(module.forward1.graph) 168 | 169 | for graph in graphs: 170 | for node in graph.findAllNodes("aten::to"): 171 | inputs = list(node.inputs()) 172 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 173 | if inputs[i].node()["value"] == 5: 174 | inputs[i].node().copyAttributes(float_node) 175 | 176 | model.apply(patch_float) 177 | patch_float(model.encode_image) 178 | patch_float(model.encode_text) 179 | 180 | model.float() 181 | 182 | return model, _transform(model.input_resolution.item()) 183 | 184 | 185 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor: 186 | """ 187 | Returns the tokenized representation of given input string(s) 188 | 189 | Parameters 190 | ---------- 191 | texts : Union[str, List[str]] 192 | An input string or a list of input strings to tokenize 193 | 194 | 
context_length : int 195 | The context length to use; all CLIP models use 77 as the context length 196 | 197 | truncate: bool 198 | Whether to truncate the text in case its encoding is longer than the context length 199 | 200 | Returns 201 | ------- 202 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 203 | """ 204 | if isinstance(texts, str): 205 | texts = [texts] 206 | 207 | sot_token = _tokenizer.encoder["<|startoftext|>"] 208 | eot_token = _tokenizer.encoder["<|endoftext|>"] 209 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 210 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 211 | 212 | for i, tokens in enumerate(all_tokens): 213 | if len(tokens) > context_length: 214 | if truncate: 215 | tokens = tokens[:context_length] 216 | tokens[-1] = eot_token 217 | else: 218 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 219 | result[i, :len(tokens)] = torch.tensor(tokens) 220 | 221 | return result 222 | -------------------------------------------------------------------------------- /clip_maple/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Union, List 6 | 7 | import torch 8 | from PIL import Image 9 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 10 | from tqdm import tqdm 11 | 12 | from .model import build_model 13 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 14 | 15 | try: 16 | from torchvision.transforms import InterpolationMode 17 | BICUBIC = InterpolationMode.BICUBIC 18 | except ImportError: 19 | BICUBIC = Image.BICUBIC 20 | 21 | 22 | if torch.__version__.split(".") < ["1", "7", "1"]: 23 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 24 | 25 | 26 | __all__ = ["available_models", "load", "tokenize"] 27 | _tokenizer = _Tokenizer() 28 | 29 | _MODELS = { 30 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 31 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 32 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 33 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 34 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 35 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 36 | } 37 | 38 | 39 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): 40 | os.makedirs(root, exist_ok=True) 41 | filename = os.path.basename(url) 42 | 43 | expected_sha256 = url.split("/")[-2] 44 | download_target = os.path.join(root, filename) 45 | 46 | if os.path.exists(download_target) and not os.path.isfile(download_target): 47 | raise RuntimeError(f"{download_target} exists and is not a regular file") 48 | 49 | if os.path.isfile(download_target): 50 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 51 | return download_target 52 | else: 53 | 
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 54 | 55 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 56 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: 57 | while True: 58 | buffer = source.read(8192) 59 | if not buffer: 60 | break 61 | 62 | output.write(buffer) 63 | loop.update(len(buffer)) 64 | 65 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 66 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 67 | 68 | return download_target 69 | 70 | 71 | def _transform(n_px): 72 | return Compose([ 73 | Resize(n_px, interpolation=BICUBIC), 74 | CenterCrop(n_px), 75 | lambda image: image.convert("RGB"), 76 | ToTensor(), 77 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 78 | ]) 79 | 80 | 81 | def available_models() -> List[str]: 82 | """Returns the names of available CLIP models""" 83 | return list(_MODELS.keys()) 84 | 85 | 86 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=False): 87 | """Load a CLIP model 88 | 89 | Parameters 90 | ---------- 91 | name : str 92 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 93 | 94 | device : Union[str, torch.device] 95 | The device to put the loaded model 96 | 97 | jit : bool 98 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 99 | 100 | Returns 101 | ------- 102 | model : torch.nn.Module 103 | The CLIP model 104 | 105 | preprocess : Callable[[PIL.Image], torch.Tensor] 106 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 107 | """ 108 | if name in _MODELS: 109 | model_path = _download(_MODELS[name]) 110 | elif os.path.isfile(name): 111 | model_path = name 112 | else: 113 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 114 | 115 | try: 116 | # loading JIT archive 117 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 118 | state_dict = None 119 | except RuntimeError: 120 | # loading saved state dict 121 | if jit: 122 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 123 | jit = False 124 | state_dict = torch.load(model_path, map_location="cpu") 125 | 126 | if not jit: 127 | model = build_model(state_dict or model.state_dict()).to(device) 128 | if str(device) == "cpu": 129 | model.float() 130 | return model, _transform(model.visual.input_resolution) 131 | 132 | # patch the device names 133 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 134 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 135 | 136 | def patch_device(module): 137 | try: 138 | graphs = [module.graph] if hasattr(module, "graph") else [] 139 | except RuntimeError: 140 | graphs = [] 141 | 142 | if hasattr(module, "forward1"): 143 | graphs.append(module.forward1.graph) 144 | 145 | for graph in graphs: 146 | for node in graph.findAllNodes("prim::Constant"): 147 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 148 | node.copyAttributes(device_node) 149 | 150 | model.apply(patch_device) 151 | patch_device(model.encode_image) 152 | patch_device(model.encode_text) 153 | 154 | # patch dtype to float32 on CPU 155 | if str(device) == "cpu": 156 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 157 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 158 | float_node = float_input.node() 159 | 160 | def patch_float(module): 161 | try: 162 | graphs = [module.graph] if hasattr(module, "graph") else [] 163 | except RuntimeError: 164 | graphs = [] 165 | 166 | if hasattr(module, "forward1"): 167 | graphs.append(module.forward1.graph) 168 | 169 | for graph in graphs: 170 | for node in graph.findAllNodes("aten::to"): 171 | inputs = list(node.inputs()) 172 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 173 | if inputs[i].node()["value"] == 5: 174 | inputs[i].node().copyAttributes(float_node) 175 | 176 | model.apply(patch_float) 177 | patch_float(model.encode_image) 178 | patch_float(model.encode_text) 179 | 180 | model.float() 181 | 182 | return model, _transform(model.input_resolution.item()) 183 | 184 | 185 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor: 186 | """ 187 | Returns the tokenized representation of given input string(s) 188 | 189 | Parameters 190 | ---------- 191 | texts : Union[str, List[str]] 192 | An input string or a list of input strings to tokenize 193 | 194 | context_length : int 195 | The context length to use; all CLIP models use 77 as the context length 196 | 197 | truncate: bool 198 | Whether to truncate the text in case its encoding is longer than the context length 199 | 200 | Returns 201 | ------- 202 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 203 | """ 204 | if isinstance(texts, str): 205 | texts = [texts] 206 | 207 | sot_token = _tokenizer.encoder["<|startoftext|>"] 208 | eot_token = _tokenizer.encoder["<|endoftext|>"] 209 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 210 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 211 | 212 | for i, tokens in enumerate(all_tokens): 213 | if len(tokens) > context_length: 214 | if truncate: 215 | tokens = tokens[:context_length] 216 | tokens[-1] = eot_token 217 | else: 218 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 219 | result[i, 
:len(tokens)] = torch.tensor(tokens) 220 | 221 | return result 222 | -------------------------------------------------------------------------------- /parallel_runner.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import numpy as np 4 | import os 5 | import os.path as osp 6 | import shutil 7 | 8 | from utils.gpu_allocater import GPUAllocater 9 | from utils.logger import setup_logger, print 10 | from utils.mail import MailClient 11 | from utils.result_parser import ResultParser 12 | 13 | from configs import get_config 14 | from templates import get_command 15 | 16 | 17 | class ParallelRunner(object): 18 | def __init__(self, cfg): 19 | self.cfg = cfg 20 | 21 | self.data_cfg = cfg['data'] 22 | self.train_cfg = cfg['train'] 23 | self.grid_search_cfg = cfg['grid_search'] 24 | self.output_cfg = cfg['output'] 25 | self.mail_cfg = cfg['mail'] 26 | 27 | self.allocater = GPUAllocater(cfg['gpu_ids']) 28 | self.mail = MailClient(self.mail_cfg) 29 | 30 | def run(self): 31 | """ main method """ 32 | grid_search_cfg = self.grid_search_cfg 33 | output_cfg = self.output_cfg 34 | 35 | # remove useless directories 36 | remove_dirs = [output_cfg[name] for name in output_cfg['remove_dirs']] 37 | 38 | for dir_ in remove_dirs: 39 | if osp.exists(dir_): 40 | print(f'Remove directory >>> {dir_}') 41 | shutil.rmtree(dir_) 42 | 43 | os.makedirs(dir_) 44 | 45 | setup_logger(osp.join(output_cfg['root'], 'log.txt'), write_to_console=True) 46 | 47 | start_time = datetime.datetime.now() 48 | 49 | try: 50 | # main 51 | if grid_search_cfg['enable']: 52 | result_paths = self.run_grid_serach() 53 | else: 54 | result_paths = [self.run_single()] 55 | except: 56 | # handle exception, contents of exception will be sent to your email 57 | end_time = datetime.datetime.now() 58 | contents = [f'Training tasks FAILED! Time cost: {end_time - start_time}\n\n', 59 | 'Exception is following above:\n'] 60 | 61 | exception_path = osp.join(output_cfg['root'], 'exceptions.txt') 62 | with open(exception_path) as f: 63 | contents += f.readlines() 64 | 65 | print('Training tasks FAILED! Mail will be sent >>> {}'.format(self.mail_cfg['to'])) 66 | self.mail.send('Training Tasks FAILED!', contents) 67 | return 68 | 69 | # after finished, results will be sent to your email 70 | end_time = datetime.datetime.now() 71 | contents = [f'Training tasks FINISHED! Time cost: {end_time - start_time}\n\n', 72 | 'Results are following above:\n'] 73 | 74 | for result_path in result_paths: 75 | contents += [f'\n{result_path}\n'] 76 | with open(result_path) as f: 77 | contents += f.readlines() 78 | 79 | print('Training tasks FINISHED! 
Mail will be sent >>> {}'.format(self.mail_cfg['to'])) 80 | self.mail.send('Training Tasks FINISHED!', contents) 81 | 82 | def run_grid_serach(self): 83 | """ run if grid search is enabled """ 84 | output_cfg = self.output_cfg 85 | root = output_cfg['root'] 86 | 87 | # parse gird search params 88 | dirnames, opts_list = self.get_grid_search_opts() 89 | 90 | print('Grid search opts:') 91 | for opts in opts_list: 92 | print(opts) 93 | print() 94 | 95 | result_paths = [] 96 | 97 | for idx, (dirname, opts) in enumerate(zip(dirnames, opts_list)): 98 | # run single task for each grid search param group 99 | print(f'[{idx + 1} / {len(dirnames)}] Running task {opts}\n') 100 | output_cfg['root'] = osp.join(root, dirname) 101 | 102 | self.dirname = dirname 103 | result_paths.append(self.run_single(opts)) 104 | 105 | return result_paths 106 | 107 | def run_single(self, opts=[]): 108 | cfg = self.cfg 109 | train_cfg = self.train_cfg 110 | grid_search_cfg = self.grid_search_cfg 111 | output_cfg = self.output_cfg 112 | 113 | # get command 114 | if cfg['mode'] == 'b2n': 115 | commands = self.get_base_to_new_commands(opts) 116 | else: 117 | commands = self.get_cross_dataset_commands(opts) 118 | 119 | # add command 120 | for command in commands: 121 | self.allocater.add_command(command) 122 | 123 | # run command 124 | self.allocater.run() 125 | 126 | # save result 127 | if not grid_search_cfg['enable']: 128 | filename = '{}-{}-{}.csv'.format(cfg['mode'], train_cfg['trainer'], train_cfg['cfg']) 129 | else: 130 | filename = '{}-{}-{}-{}.csv'.format(cfg['mode'], train_cfg['trainer'], train_cfg['cfg'], self.dirname) 131 | 132 | os.makedirs(output_cfg['result'], exist_ok=True) 133 | result_path = osp.join(output_cfg['result'], filename) 134 | 135 | print(f'Results will be save >>> {result_path}') 136 | parser = ResultParser(cfg['mode'], output_cfg['root'], result_path) 137 | parser.parse_and_save() 138 | 139 | return result_path 140 | 141 | def get_base_to_new_commands(self, opts=[]): 142 | data_cfg = self.data_cfg 143 | train_cfg = self.train_cfg 144 | output_cfg = self.output_cfg 145 | 146 | data_root = data_cfg['root'] 147 | datasets = data_cfg['datasets_base_to_new'] 148 | 149 | trainer = train_cfg['trainer'] 150 | cfg = train_cfg['cfg'] 151 | seeds = train_cfg['seeds'] 152 | loadep = train_cfg['loadep'] 153 | shots = train_cfg['shots'] 154 | opts += train_cfg['opts'] 155 | 156 | root = output_cfg['root'] 157 | 158 | commands = [] 159 | 160 | # training on all datasets 161 | for dataset in datasets: 162 | for seed in seeds: 163 | cmd = get_command(data_root, seed, trainer, dataset, cfg, root, 164 | shots, dataset, loadep, opts, mode='b2n', train=True) 165 | commands.append(cmd) 166 | 167 | # testing on all datasets 168 | for dataset in datasets: 169 | for seed in seeds: 170 | cmd = get_command(data_root, seed, trainer, dataset, cfg, root, 171 | shots, dataset, loadep, opts, mode='b2n', train=False) 172 | commands.append(cmd) 173 | 174 | return commands 175 | 176 | def get_cross_dataset_commands(self, opts): 177 | data_cfg = self.data_cfg 178 | train_cfg = self.train_cfg 179 | output_cfg = self.output_cfg 180 | 181 | data_root = data_cfg['root'] 182 | datasets = data_cfg['datasets_cross_dataset'] 183 | 184 | trainer = train_cfg['trainer'] 185 | cfg = train_cfg['cfg'] 186 | seeds = train_cfg['seeds'] 187 | loadep = train_cfg['loadep'] 188 | shots = train_cfg['shots'] 189 | opts += train_cfg['opts'] 190 | 191 | root = output_cfg['root'] 192 | 193 | commands = [] 194 | 195 | # training on image 196 | 
load_dataset = 'imagenet' 197 | for seed in seeds: 198 | cmd = get_command(data_root, seed, trainer, load_dataset, cfg, root, 199 | shots, load_dataset, loadep, opts, mode='xd', train=True) 200 | commands.append(cmd) 201 | 202 | # testing on other datasets 203 | for dataset in datasets: 204 | for seed in seeds: 205 | cmd = get_command(data_root, seed, trainer, dataset, cfg, root, 206 | shots, load_dataset, loadep, opts, mode='xd', train=False) 207 | commands.append(cmd) 208 | 209 | return commands 210 | 211 | def get_grid_search_opts(self): 212 | grid_search_cfg = self.grid_search_cfg 213 | mode = grid_search_cfg['mode'] 214 | params = grid_search_cfg['params'] 215 | 216 | names = [param['name'] for param in params] 217 | aliases = [param['alias'] for param in params] 218 | values_list = [param['values'] for param in params] 219 | 220 | # grid to sequential 221 | if mode == 'grid' and len(names) > 1: 222 | values_list = [list(arr.flatten()) for arr in np.meshgrid(*values_list)] 223 | 224 | # build opts 225 | dirnames, grid_search_opts_list = [], [] 226 | for i in range(len(values_list[0])): 227 | values = [values[i] for values in values_list] 228 | 229 | dirname, opts = [], [] 230 | for name, alias, value in zip(names, aliases, values): 231 | dirname.append(f'{alias}{value}') 232 | opts += [name, value] 233 | 234 | dirname = '_'.join(dirname) 235 | dirnames.append(dirname) 236 | grid_search_opts_list.append(opts) 237 | 238 | return dirnames, grid_search_opts_list 239 | 240 | 241 | def main(args): 242 | cfg = get_config(args.cfg) 243 | runner = ParallelRunner(cfg) 244 | runner.run() 245 | 246 | 247 | if __name__ == '__main__': 248 | parser = argparse.ArgumentParser() 249 | parser.add_argument('--cfg', type=str) 250 | args = parser.parse_args() 251 | main(args) 252 | -------------------------------------------------------------------------------- /datasets/DATASETS.md: -------------------------------------------------------------------------------- 1 | # How to install datasets 2 | 3 | ### Acknowledgement: This readme file for installing datasets has been borrowed directly from [CoOp's](https://github.com/KaiyangZhou/CoOp/blob/main/DATASETS.md) official repository. 4 | 5 | We suggest putting all datasets under the same folder (say `$DATA`) to ease management and following the instructions below to organize datasets to avoid modifying the source code. The file structure looks like 6 | 7 | ``` 8 | $DATA/ 9 | |–– imagenet/ 10 | |–– caltech-101/ 11 | |–– oxford_pets/ 12 | |–– stanford_cars/ 13 | ``` 14 | 15 | If you have some datasets already installed somewhere else, you can create symbolic links in `$DATA/dataset_name` that point to the original data to avoid duplicate download. 16 | 17 | Datasets list: 18 | - [ImageNet](#imagenet) 19 | - [Caltech101](#caltech101) 20 | - [OxfordPets](#oxfordpets) 21 | - [StanfordCars](#stanfordcars) 22 | - [Flowers102](#flowers102) 23 | - [Food101](#food101) 24 | - [FGVCAircraft](#fgvcaircraft) 25 | - [SUN397](#sun397) 26 | - [DTD](#dtd) 27 | - [EuroSAT](#eurosat) 28 | - [UCF101](#ucf101) 29 | - [ImageNetV2](#imagenetv2) 30 | - [ImageNet-Sketch](#imagenet-sketch) 31 | - [ImageNet-A](#imagenet-a) 32 | - [ImageNet-R](#imagenet-r) 33 | 34 | The instructions to prepare each dataset are detailed below. To ensure reproducibility and fair comparison for future work, we provide fixed train/val/test splits for all datasets except ImageNet where the validation set is used as test set. 
The fixed splits are either from the original datasets (if available) or created by us. 35 | 36 | ### ImageNet 37 | - Create a folder named `imagenet/` under `$DATA`. 38 | - Create `images/` under `imagenet/`. 39 | - Download the dataset from the [official website](https://image-net.org/index.php) and extract the training and validation sets to `$DATA/imagenet/images`. The directory structure should look like 40 | ``` 41 | imagenet/ 42 | |–– images/ 43 | | |–– train/ # contains 1,000 folders like n01440764, n01443537, etc. 44 | | |–– val/ 45 | ``` 46 | - If you had downloaded the ImageNet dataset before, you can create symbolic links to map the training and validation sets to `$DATA/imagenet/images`. 47 | - Download the `classnames.txt` to `$DATA/imagenet/` from this [link](https://drive.google.com/file/d/1-61f_ol79pViBFDG_IDlUQSwoLcn2XXF/view?usp=sharing). The class names are copied from [CLIP](https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb). 48 | 49 | ### Caltech101 50 | - Create a folder named `caltech-101/` under `$DATA`. 51 | - Download `101_ObjectCategories.tar.gz` from http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz and extract the file under `$DATA/caltech-101`. 52 | - Download `split_zhou_Caltech101.json` from this [link](https://drive.google.com/file/d/1hyarUivQE36mY6jSomru6Fjd-JzwcCzN/view?usp=sharing) and put it under `$DATA/caltech-101`. 53 | 54 | The directory structure should look like 55 | ``` 56 | caltech-101/ 57 | |–– 101_ObjectCategories/ 58 | |–– split_zhou_Caltech101.json 59 | ``` 60 | 61 | ### OxfordPets 62 | - Create a folder named `oxford_pets/` under `$DATA`. 63 | - Download the images from https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz. 64 | - Download the annotations from https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz. 65 | - Download `split_zhou_OxfordPets.json` from this [link](https://drive.google.com/file/d/1501r8Ber4nNKvmlFVQZ8SeUHTcdTTEqs/view?usp=sharing). 66 | 67 | The directory structure should look like 68 | ``` 69 | oxford_pets/ 70 | |–– images/ 71 | |–– annotations/ 72 | |–– split_zhou_OxfordPets.json 73 | ``` 74 | 75 | ### StanfordCars 76 | - Create a folder named `stanford_cars/` under `$DATA`. 77 | - Download the train images http://ai.stanford.edu/~jkrause/car196/cars_train.tgz. 78 | - Download the test images http://ai.stanford.edu/~jkrause/car196/cars_test.tgz. 79 | - Download the train labels https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz. 80 | - Download the test labels http://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat. 81 | - Download `split_zhou_StanfordCars.json` from this [link](https://drive.google.com/file/d/1ObCFbaAgVu0I-k_Au-gIUcefirdAuizT/view?usp=sharing). 82 | 83 | The directory structure should look like 84 | ``` 85 | stanford_cars/ 86 | |–– cars_test\ 87 | |–– cars_test_annos_withlabels.mat 88 | |–– cars_train\ 89 | |–– devkit\ 90 | |–– split_zhou_StanfordCars.json 91 | ``` 92 | 93 | ### Flowers102 94 | - Create a folder named `oxford_flowers/` under `$DATA`. 95 | - Download the images and labels from https://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz and https://www.robots.ox.ac.uk/~vgg/data/flowers/102/imagelabels.mat respectively. 96 | - Download `cat_to_name.json` from [here](https://drive.google.com/file/d/1AkcxCXeK_RCGCEC_GvmWxjcjaNhu-at0/view?usp=sharing). 
97 | - Download `split_zhou_OxfordFlowers.json` from [here](https://drive.google.com/file/d/1Pp0sRXzZFZq15zVOzKjKBu4A9i01nozT/view?usp=sharing). 98 | 99 | The directory structure should look like 100 | ``` 101 | oxford_flowers/ 102 | |–– cat_to_name.json 103 | |–– imagelabels.mat 104 | |–– jpg/ 105 | |–– split_zhou_OxfordFlowers.json 106 | ``` 107 | 108 | ### Food101 109 | - Download the dataset from https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/ and extract the file `food-101.tar.gz` under `$DATA`, resulting in a folder named `$DATA/food-101/`. 110 | - Download `split_zhou_Food101.json` from [here](https://drive.google.com/file/d/1QK0tGi096I0Ba6kggatX1ee6dJFIcEJl/view?usp=sharing). 111 | 112 | The directory structure should look like 113 | ``` 114 | food-101/ 115 | |–– images/ 116 | |–– license_agreement.txt 117 | |–– meta/ 118 | |–– README.txt 119 | |–– split_zhou_Food101.json 120 | ``` 121 | 122 | ### FGVCAircraft 123 | - Download the data from https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz. 124 | - Extract `fgvc-aircraft-2013b.tar.gz` and keep only `data/`. 125 | - Move `data/` to `$DATA` and rename the folder to `fgvc_aircraft/`. 126 | 127 | The directory structure should look like 128 | ``` 129 | fgvc_aircraft/ 130 | |–– images/ 131 | |–– ... # a bunch of .txt files 132 | ``` 133 | 134 | ### SUN397 135 | - Create a folder named `sun397/` under `$DATA`. 136 | - Download the images http://vision.princeton.edu/projects/2010/SUN/SUN397.tar.gz. 137 | - Download the partitions https://vision.princeton.edu/projects/2010/SUN/download/Partitions.zip. 138 | - Extract these files under `$DATA/sun397/`. 139 | - Download `split_zhou_SUN397.json` from this [link](https://drive.google.com/file/d/1y2RD81BYuiyvebdN-JymPfyWYcd8_MUq/view?usp=sharing). 140 | 141 | The directory structure should look like 142 | ``` 143 | sun397/ 144 | |–– SUN397/ 145 | |–– split_zhou_SUN397.json 146 | |–– ... # a bunch of .txt files 147 | ``` 148 | 149 | ### DTD 150 | - Download the dataset from https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz and extract it to `$DATA`. This should lead to `$DATA/dtd/`. 151 | - Download `split_zhou_DescribableTextures.json` from this [link](https://drive.google.com/file/d/1u3_QfB467jqHgNXC00UIzbLZRQCg2S7x/view?usp=sharing). 152 | 153 | The directory structure should look like 154 | ``` 155 | dtd/ 156 | |–– images/ 157 | |–– imdb/ 158 | |–– labels/ 159 | |–– split_zhou_DescribableTextures.json 160 | ``` 161 | 162 | ### EuroSAT 163 | - Create a folder named `eurosat/` under `$DATA`. 164 | - Download the dataset from http://madm.dfki.de/files/sentinel/EuroSAT.zip and extract it to `$DATA/eurosat/`. 165 | - Download `split_zhou_EuroSAT.json` from [here](https://drive.google.com/file/d/1Ip7yaCWFi0eaOFUGga0lUdVi_DDQth1o/view?usp=sharing). 166 | 167 | The directory structure should look like 168 | ``` 169 | eurosat/ 170 | |–– 2750/ 171 | |–– split_zhou_EuroSAT.json 172 | ``` 173 | 174 | ### UCF101 175 | - Create a folder named `ucf101/` under `$DATA`. 176 | - Download the zip file `UCF-101-midframes.zip` from [here](https://drive.google.com/file/d/10Jqome3vtUA2keJkNanAiFpgbyC9Hc2O/view?usp=sharing) and extract it to `$DATA/ucf101/`. This zip file contains the extracted middle video frames. 177 | - Download `split_zhou_UCF101.json` from this [link](https://drive.google.com/file/d/1I0S0q91hJfsV9Gf4xDIjgDq4AqBNJb1y/view?usp=sharing). 
178 | 179 | The directory structure should look like 180 | ``` 181 | ucf101/ 182 | |–– UCF-101-midframes/ 183 | |–– split_zhou_UCF101.json 184 | ``` 185 | 186 | ### ImageNetV2 187 | - Create a folder named `imagenetv2/` under `$DATA`. 188 | - Go to this github repo https://github.com/modestyachts/ImageNetV2. 189 | - Download the matched-frequency dataset from https://s3-us-west-2.amazonaws.com/imagenetv2public/imagenetv2-matched-frequency.tar.gz and extract it to `$DATA/imagenetv2/`. 190 | - Copy `$DATA/imagenet/classnames.txt` to `$DATA/imagenetv2/`. 191 | 192 | The directory structure should look like 193 | ``` 194 | imagenetv2/ 195 | |–– imagenetv2-matched-frequency-format-val/ 196 | |–– classnames.txt 197 | ``` 198 | 199 | ### ImageNet-Sketch 200 | - Download the dataset from https://github.com/HaohanWang/ImageNet-Sketch. 201 | - Extract the dataset to `$DATA/imagenet-sketch`. 202 | - Copy `$DATA/imagenet/classnames.txt` to `$DATA/imagenet-sketch/`. 203 | 204 | The directory structure should look like 205 | ``` 206 | imagenet-sketch/ 207 | |–– images/ # contains 1,000 folders whose names have the format of n* 208 | |–– classnames.txt 209 | ``` 210 | 211 | ### ImageNet-A 212 | - Create a folder named `imagenet-adversarial/` under `$DATA`. 213 | - Download the dataset from https://github.com/hendrycks/natural-adv-examples and extract it to `$DATA/imagenet-adversarial/`. 214 | - Copy `$DATA/imagenet/classnames.txt` to `$DATA/imagenet-adversarial/`. 215 | 216 | The directory structure should look like 217 | ``` 218 | imagenet-adversarial/ 219 | |–– imagenet-a/ # contains 200 folders whose names have the format of n* 220 | |–– classnames.txt 221 | ``` 222 | 223 | ### ImageNet-R 224 | - Create a folder named `imagenet-rendition/` under `$DATA`. 225 | - Download the dataset from https://github.com/hendrycks/imagenet-r and extract it to `$DATA/imagenet-rendition/`. 226 | - Copy `$DATA/imagenet/classnames.txt` to `$DATA/imagenet-rendition/`. 
227 | 228 | The directory structure should look like 229 | ``` 230 | imagenet-rendition/ 231 | |–– imagenet-r/ # contains 200 folders whose names have the format of n* 232 | |–– classnames.txt 233 | ``` -------------------------------------------------------------------------------- /trainers/elp_maple.py: -------------------------------------------------------------------------------- 1 | # see elp_coopy.py to view comments 2 | import os 3 | import os.path as osp 4 | import torch 5 | import torch.nn.functional as F 6 | from dassl.engine import TRAINER_REGISTRY 7 | from dassl.utils import load_pretrained_weights, load_checkpoint 8 | from dassl.optim import build_lr_scheduler 9 | from plotnine import * 10 | from torch.cuda.amp import GradScaler 11 | from torch import nn 12 | 13 | from .maple import MaPLe, load_clip_to_cpu 14 | from .maple import CustomCLIP as CustomCLIP_ 15 | from .optim import build_optimizer 16 | 17 | 18 | class FiLM(nn.Module): 19 | def __init__(self, 20 | dim, 21 | bias=True, 22 | use_sigmoid=False): 23 | super().__init__() 24 | self.scale = nn.Parameter(torch.ones(dim)) 25 | self.bias = nn.Parameter(torch.zeros(dim)) if bias else None 26 | self.has_bias = bias 27 | self.use_sigmoid = use_sigmoid 28 | 29 | def forward(self, x): 30 | scale = self.scale.unsqueeze(0).type(x.dtype) 31 | bias = self.bias.unsqueeze(0).type(x.dtype) if self.has_bias else None 32 | 33 | x = scale * x 34 | if bias is not None: 35 | x = x + bias 36 | 37 | if self.use_sigmoid: 38 | return x.sigmoid() 39 | 40 | return x 41 | 42 | 43 | class CustomCLIP(CustomCLIP_): 44 | def __init__(self, cfg, classnames, clip_model): 45 | super().__init__(cfg, classnames, clip_model) 46 | self.subsample_classes = cfg.DATASET.SUBSAMPLE_CLASSES 47 | self.dataset = cfg.DATASET.NAME 48 | self.lp_cfg = cfg.TRAINER.LINEAR_PROBE 49 | self.film_cfg = cfg.TRAINER.FILM 50 | 51 | clip_dim = clip_model.text_projection.size(1) 52 | 53 | film_cfg = self.film_cfg 54 | 55 | if film_cfg.LINEAR_PROBE: 56 | self.film_lp_img = FiLM(clip_dim) 57 | self.film_lp_text = FiLM(clip_dim) 58 | 59 | if (self.subsample_classes == 'base') \ 60 | or (self.subsample_classes == 'all' and 'ImageNet' in self.dataset): 61 | assert self.lp_cfg.TYPE in ['similarity', 'linear'] 62 | 63 | if self.lp_cfg.TYPE == 'similarity': 64 | self.linear_probe_proj = nn.Identity() 65 | elif self.lp_cfg.TYPE == 'linear': 66 | self.linear_probe_proj = nn.Linear(clip_dim, len(classnames)).type(self.dtype) 67 | else: 68 | self.linear_probe_proj = nn.Identity() 69 | 70 | def forward(self, img, labels=None): 71 | if (self.subsample_classes == 'base') \ 72 | or (self.subsample_classes == 'all' and 'ImageNet' in self.dataset): 73 | return self._forward_base(img, labels) 74 | else: 75 | return self._forward_new(img) 76 | 77 | def _forward_base(self, img, labels=None): 78 | text_feats, img_feats = self._forward_feats(img) 79 | 80 | logits = self._forward_logits_similarity(text_feats, img_feats) 81 | logits_lp, labels_lp = self._forward_logits_linear_probe(text_feats, img_feats, labels) 82 | 83 | if self.prompt_learner.training: 84 | return self._loss(logits, labels, logits_lp, labels_lp) 85 | 86 | if not self.lp_cfg.TEST_TIME_FUSION: 87 | return logits_lp 88 | 89 | lp_weight = self.lp_cfg.WEIGHT 90 | logits = (1 - lp_weight) * logits + lp_weight * logits_lp 91 | return logits 92 | 93 | def _forward_new(self, img): 94 | assert not self.prompt_learner.training 95 | 96 | text_feats, img_feats = self._forward_feats(img) 97 | logits = self._forward_logits_similarity(text_feats, 
img_feats) 98 | return logits 99 | 100 | def _forward_feats(self, img): 101 | text_prompts, vision_prompts, deep_text_prompts_list, deep_vision_prompts_list = self.prompt_learner() 102 | 103 | tokenized_prompts = self.tokenized_prompts 104 | text_feats = self.text_encoder(text_prompts, tokenized_prompts, deep_text_prompts_list) 105 | img_feats = self.image_encoder(img.type(self.dtype), vision_prompts, deep_vision_prompts_list) 106 | 107 | return text_feats, img_feats 108 | 109 | def _forward_logits_similarity(self, text_feats, img_feats): 110 | text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True) 111 | img_feats = img_feats / img_feats.norm(dim=-1, keepdim=True) 112 | logit_scale = self.logit_scale.exp() 113 | logits = logit_scale * img_feats @ text_feats.t() 114 | return logits 115 | 116 | def _forward_logits_linear_probe(self, text_feats, img_feats, labels): 117 | if self.film_cfg.LINEAR_PROBE: 118 | text_feats = self.film_lp_text(text_feats) 119 | img_feats = self.film_lp_img(img_feats) 120 | 121 | if self.lp_cfg.TYPE == 'similarity': 122 | text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True) 123 | img_feats = img_feats / img_feats.norm(dim=-1, keepdim=True) 124 | logit_scale = self.logit_scale.exp() 125 | logits = logit_scale * img_feats @ text_feats.t() 126 | return logits, labels 127 | 128 | if labels is None: 129 | all_feats = img_feats 130 | all_labels = labels 131 | else: 132 | text_feats = text_feats[labels] 133 | all_feats = torch.cat([text_feats, img_feats]) 134 | all_labels = torch.cat([labels, labels]) 135 | 136 | all_logits = self.linear_probe_proj(all_feats) 137 | return all_logits, all_labels 138 | 139 | def _loss(self, logits, labels, logits_lp, labels_lp): 140 | loss_cls = F.cross_entropy(logits, labels) 141 | loss_cls_lp = F.cross_entropy(logits_lp, labels_lp) 142 | 143 | lp_weight = self.lp_cfg.WEIGHT 144 | loss = (1 - lp_weight) * loss_cls + lp_weight * loss_cls_lp 145 | return loss 146 | 147 | 148 | @TRAINER_REGISTRY.register() 149 | class ExtrasLinearProbeMaPLe(MaPLe): 150 | def build_model(self): 151 | cfg = self.cfg 152 | classnames = self.dm.dataset.classnames 153 | 154 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})") 155 | clip_model = load_clip_to_cpu(cfg) 156 | 157 | if cfg.TRAINER.MAPLE.PREC == "fp32" or cfg.TRAINER.MAPLE.PREC == "amp": 158 | clip_model.float() 159 | 160 | print("Building custom CLIP") 161 | self.model = CustomCLIP(cfg, classnames, clip_model) 162 | 163 | print("Turning off gradients in both the image and the text encoder") 164 | names_to_update = cfg.TRAINER.NAMES_TO_UPDATE 165 | 166 | for name, param in self.model.named_parameters(): 167 | update = False 168 | 169 | for name_to_update in names_to_update: 170 | if name_to_update in name: 171 | update = True 172 | break 173 | 174 | param.requires_grad_(update) 175 | 176 | enabled = [] 177 | for name, param in self.model.named_parameters(): 178 | if param.requires_grad: 179 | enabled.append(name) 180 | print(f"Parameters to be updated: {list(sorted(enabled))}") 181 | 182 | if cfg.MODEL.INIT_WEIGHTS: 183 | load_pretrained_weights(self.model, cfg.MODEL.INIT_WEIGHTS) 184 | 185 | self.model.to(self.device) 186 | self.optim, infos = build_optimizer(self.model, cfg.OPTIM) 187 | 188 | if infos is not None: 189 | print('Learning rate of parameters:') 190 | for info in infos: 191 | print('lr: {}, layers: {}'.format(info['lr'], info['layers'])) 192 | 193 | self.sched = build_lr_scheduler(self.optim, cfg.OPTIM) 194 | self.register_model("MultiModalPromptLearner", 
self.model, self.optim, self.sched) 195 | 196 | self.scaler = GradScaler() if cfg.TRAINER.MAPLE.PREC == "amp" else None 197 | 198 | device_count = torch.cuda.device_count() 199 | if device_count > 1: 200 | print(f"Multiple GPUs detected (n_gpus={device_count}), use all of them!") 201 | self.model = nn.DataParallel(self.model) 202 | 203 | def load_model(self, directory, epoch=None): 204 | if not directory: 205 | print("Note that load_model() is skipped as no pretrained model is given") 206 | return 207 | 208 | names = self.get_model_names() 209 | 210 | # By default, the best model is loaded 211 | model_file = "model-best.pth.tar" 212 | 213 | if epoch is not None: 214 | model_file = "model.pth.tar-" + str(epoch) 215 | 216 | for name in names: 217 | if epoch < 0: 218 | all_model_files = os.listdir(osp.join(directory, name)) 219 | all_model_files = [file_ for file_ in all_model_files if file_ != 'checkpoint'] 220 | model_epochs = [int(file_.split('-')[-1]) for file_ in all_model_files] 221 | last_epoch = max(model_epochs) 222 | model_file = 'model.pth.tar-' + str(last_epoch) 223 | 224 | model_path = osp.join(directory, name, model_file) 225 | 226 | if not osp.exists(model_path): 227 | raise FileNotFoundError('Model not found at "{}"'.format(model_path)) 228 | 229 | checkpoint = load_checkpoint(model_path) 230 | state_dict = checkpoint["state_dict"] 231 | epoch = checkpoint["epoch"] 232 | 233 | # Ignore fixed token vectors 234 | if "prompt_learner.token_prefix" in state_dict: 235 | del state_dict["prompt_learner.token_prefix"] 236 | 237 | if "prompt_learner.token_suffix" in state_dict: 238 | del state_dict["prompt_learner.token_suffix"] 239 | 240 | print("Loading weights to {} " 'from "{}" (epoch = {})'.format(name, model_path, epoch)) 241 | 242 | if self.cfg.DATASET.NAME in ['ImageNetA', 'ImageNetR']: 243 | from datasets.imagenet import ImageNet 244 | from dassl.utils import listdir_nohidden 245 | 246 | dataset = self.dm.dataset 247 | text_file = osp.join(dataset.dataset_dir, "classnames.txt") 248 | all_folders = ImageNet.read_classnames(text_file).keys() 249 | 250 | TO_BE_IGNORED = ["README.txt"] 251 | folders = listdir_nohidden(dataset.image_dir, sort=True) 252 | folders = [f for f in folders if f not in TO_BE_IGNORED] 253 | is_reserves = [f in folders for f in all_folders] 254 | 255 | print(f'State dict is CLIPPED to match the shape of target dataset {self.cfg.DATASET.NAME}!') 256 | state_dict['linear_probe_proj.weight'] = state_dict['linear_probe_proj.weight'][is_reserves] 257 | state_dict['linear_probe_proj.bias'] = state_dict['linear_probe_proj.bias'][is_reserves] 258 | 259 | # set strict=False 260 | self._models[name].load_state_dict(state_dict, strict=False) 261 | -------------------------------------------------------------------------------- /trainers/elp_cocoop.py: -------------------------------------------------------------------------------- 1 | # see elp_coopy.py to view comments 2 | import os 3 | import os.path as osp 4 | import torch 5 | import torch.nn.functional as F 6 | from dassl.engine import TRAINER_REGISTRY 7 | from dassl.utils import load_pretrained_weights, load_checkpoint 8 | from dassl.optim import build_lr_scheduler 9 | from plotnine import * 10 | from torch.cuda.amp import GradScaler, autocast 11 | from torch import nn 12 | 13 | from .cocoop import CoCoOp, load_clip_to_cpu 14 | from .cocoop import CustomCLIP as CustomCLIP_ 15 | from .elp_maple import FiLM 16 | from .optim import build_optimizer 17 | 18 | 19 | class CustomCLIP(CustomCLIP_): 20 | def __init__(self, 
cfg, classnames, clip_model): 21 | super().__init__(cfg, classnames, clip_model) 22 | self.subsample_classes = cfg.DATASET.SUBSAMPLE_CLASSES 23 | self.dataset = cfg.DATASET.NAME 24 | self.lp_cfg = cfg.TRAINER.LINEAR_PROBE 25 | self.film_cfg = cfg.TRAINER.FILM 26 | 27 | clip_dim = clip_model.text_projection.size(1) 28 | 29 | film_cfg = self.film_cfg 30 | 31 | if film_cfg.LINEAR_PROBE: 32 | self.film_lp_img = FiLM(clip_dim) 33 | self.film_lp_text = FiLM(clip_dim) 34 | 35 | if (self.subsample_classes == 'base') \ 36 | or (self.subsample_classes == 'all' and 'ImageNet' in self.dataset): 37 | assert self.lp_cfg.TYPE in ['similarity', 'linear'] 38 | 39 | if self.lp_cfg.TYPE == 'similarity': 40 | self.linear_probe_proj = nn.Identity() 41 | elif self.lp_cfg.TYPE == 'linear': 42 | self.linear_probe_proj = nn.Linear(clip_dim, len(classnames)).type(self.dtype) 43 | else: 44 | self.linear_probe_proj = nn.Identity() 45 | 46 | def forward(self, img, labels=None): 47 | if (self.subsample_classes == 'base') \ 48 | or (self.subsample_classes == 'all' and 'ImageNet' in self.dataset): 49 | return self._forward_base(img, labels) 50 | else: 51 | return self._forward_new(img) 52 | 53 | def _forward_base(self, img, labels=None): 54 | text_feats_list, img_feats = self._forward_feats(img) 55 | 56 | logits = self._forward_logits_similarity(text_feats_list, img_feats) 57 | logits_lp, labels_lp = self._forward_logits_linear_probe(text_feats_list, img_feats, labels) 58 | 59 | if self.prompt_learner.training: 60 | return self._loss(logits, labels, logits_lp, labels_lp) 61 | 62 | if not self.lp_cfg.TEST_TIME_FUSION: 63 | return logits_lp 64 | 65 | lp_weight = self.lp_cfg.WEIGHT 66 | logits = (1 - lp_weight) * logits + lp_weight * logits_lp 67 | return logits 68 | 69 | def _forward_new(self, img): 70 | assert not self.prompt_learner.training 71 | 72 | text_feats, img_feats = self._forward_feats(img) 73 | logits = self._forward_logits_similarity(text_feats, img_feats) 74 | return logits 75 | 76 | def _forward_feats(self, img): 77 | img_feats = self.image_encoder(img.type(self.dtype)) 78 | img_feats_norm = img_feats / img_feats.norm(dim=-1, keepdim=True) 79 | 80 | prompts = self.prompt_learner(img_feats_norm) 81 | 82 | tokenized_prompts = self.tokenized_prompts 83 | text_feats_list = [self.text_encoder(prompt, tokenized_prompts) for prompt in prompts] 84 | 85 | return text_feats_list, img_feats 86 | 87 | def _forward_logits_similarity(self, text_feats_list, img_feats): 88 | img_feats = img_feats / img_feats.norm(dim=-1, keepdim=True) 89 | 90 | logits = [] 91 | logit_scale = self.logit_scale.exp() 92 | 93 | for text_feats, img_feats_per_batch in zip(text_feats_list, img_feats): 94 | text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True) 95 | logits_per_batch = logit_scale * img_feats_per_batch @ text_feats.t() 96 | logits.append(logits_per_batch) 97 | 98 | logits = torch.stack(logits) 99 | return logits 100 | 101 | def _forward_logits_linear_probe(self, text_feats_list, img_feats, labels): 102 | if self.film_cfg.LINEAR_PROBE: 103 | text_feats_list = [self.film_lp_text(text_feats) for text_feats in text_feats_list] 104 | img_feats = self.film_lp_img(img_feats) 105 | 106 | if self.lp_cfg.TYPE == 'similarity': 107 | return self._forward_logits_similarity(text_feats_list, img_feats), labels 108 | 109 | if labels is None: 110 | all_feats = img_feats 111 | all_labels = labels 112 | else: 113 | text_feats = torch.stack([text_feats[label] for text_feats, label 114 | in zip(text_feats_list, labels)]) 115 | all_feats = 
torch.cat([text_feats, img_feats]) 116 | all_labels = torch.cat([labels, labels]) 117 | 118 | all_logits = self.linear_probe_proj(all_feats) 119 | return all_logits, all_labels 120 | 121 | def _loss(self, logits, labels, logits_lp, labels_lp): 122 | loss_cls = F.cross_entropy(logits, labels) 123 | loss_cls_lp = F.cross_entropy(logits_lp, labels_lp) 124 | 125 | lp_weight = self.lp_cfg.WEIGHT 126 | loss = (1 - lp_weight) * loss_cls + lp_weight * loss_cls_lp 127 | return loss 128 | 129 | 130 | @TRAINER_REGISTRY.register() 131 | class ExtrasLinearProbeCoCoOp(CoCoOp): 132 | def forward_backward(self, batch): 133 | image, label = self.parse_batch_train(batch) 134 | 135 | model = self.model 136 | optim = self.optim 137 | scaler = self.scaler 138 | 139 | prec = self.cfg.TRAINER.COCOOP.PREC 140 | if prec == "amp": 141 | with autocast(): 142 | loss = model(image, label) 143 | optim.zero_grad() 144 | scaler.scale(loss).backward() 145 | scaler.step(optim) 146 | scaler.update() 147 | else: 148 | loss = model(image, label) 149 | optim.zero_grad() 150 | loss.backward() 151 | optim.step() 152 | 153 | loss_summary = {"loss": loss.item()} 154 | 155 | if (self.batch_idx + 1) == self.num_batches: 156 | self.update_lr() 157 | 158 | return loss_summary 159 | 160 | def parse_batch_train(self, batch): 161 | input = batch["img"] 162 | label = batch["label"] 163 | input = input.to(self.device) 164 | label = label.to(self.device) 165 | return input, label 166 | 167 | def build_model(self): 168 | cfg = self.cfg 169 | classnames = self.dm.dataset.classnames 170 | 171 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})") 172 | clip_model = load_clip_to_cpu(cfg) 173 | 174 | if cfg.TRAINER.COCOOP.PREC == "fp32" or cfg.TRAINER.COCOOP.PREC == "amp": 175 | clip_model.float() 176 | 177 | print("Building custom CLIP") 178 | self.model = CustomCLIP(cfg, classnames, clip_model) 179 | 180 | print("Turning off gradients in both the image and the text encoder") 181 | names_to_update = cfg.TRAINER.NAMES_TO_UPDATE 182 | 183 | for name, param in self.model.named_parameters(): 184 | update = False 185 | 186 | for name_to_update in names_to_update: 187 | if name_to_update in name: 188 | update = True 189 | break 190 | 191 | param.requires_grad_(update) 192 | 193 | enabled = [] 194 | for name, param in self.model.named_parameters(): 195 | if param.requires_grad: 196 | enabled.append(name) 197 | print(f"Parameters to be updated: {list(sorted(enabled))}") 198 | 199 | if cfg.MODEL.INIT_WEIGHTS: 200 | load_pretrained_weights(self.model, cfg.MODEL.INIT_WEIGHTS) 201 | 202 | self.model.to(self.device) 203 | self.optim, infos = build_optimizer(self.model, cfg.OPTIM) 204 | 205 | if infos is not None: 206 | print('Learning rate of parameters:') 207 | for info in infos: 208 | print('lr: {}, layers: {}'.format(info['lr'], info['layers'])) 209 | 210 | self.sched = build_lr_scheduler(self.optim, cfg.OPTIM) 211 | self.register_model("PromptLearner", self.model, self.optim, self.sched) 212 | 213 | self.scaler = GradScaler() if cfg.TRAINER.COCOOP.PREC == "amp" else None 214 | 215 | device_count = torch.cuda.device_count() 216 | if device_count > 1: 217 | print(f"Multiple GPUs detected (n_gpus={device_count}), use all of them!") 218 | self.model = nn.DataParallel(self.model) 219 | 220 | def load_model(self, directory, epoch=None): 221 | if not directory: 222 | print("Note that load_model() is skipped as no pretrained model is given") 223 | return 224 | 225 | names = self.get_model_names() 226 | 227 | # By default, the best model is loaded 228 | 
model_file = "model-best.pth.tar" 229 | 230 | if epoch is not None: 231 | model_file = "model.pth.tar-" + str(epoch) 232 | 233 | for name in names: 234 | if epoch < 0: 235 | all_model_files = os.listdir(osp.join(directory, name)) 236 | all_model_files = [file_ for file_ in all_model_files if file_ != 'checkpoint'] 237 | model_epochs = [int(file_.split('-')[-1]) for file_ in all_model_files] 238 | last_epoch = max(model_epochs) 239 | model_file = 'model.pth.tar-' + str(last_epoch) 240 | 241 | model_path = osp.join(directory, name, model_file) 242 | 243 | if not osp.exists(model_path): 244 | raise FileNotFoundError('Model not found at "{}"'.format(model_path)) 245 | 246 | checkpoint = load_checkpoint(model_path) 247 | state_dict = checkpoint["state_dict"] 248 | epoch = checkpoint["epoch"] 249 | 250 | # Ignore fixed token vectors 251 | if "prompt_learner.token_prefix" in state_dict: 252 | del state_dict["prompt_learner.token_prefix"] 253 | 254 | if "prompt_learner.token_suffix" in state_dict: 255 | del state_dict["prompt_learner.token_suffix"] 256 | 257 | print("Loading weights to {} " 'from "{}" (epoch = {})'.format(name, model_path, epoch)) 258 | 259 | if self.cfg.DATASET.NAME in ['ImageNetA', 'ImageNetR']: 260 | from datasets.imagenet import ImageNet 261 | from dassl.utils import listdir_nohidden 262 | dataset = self.dm.dataset 263 | text_file = osp.join(dataset.dataset_dir, "classnames.txt") 264 | all_folders = ImageNet.read_classnames(text_file).keys() 265 | 266 | TO_BE_IGNORED = ["README.txt"] 267 | folders = listdir_nohidden(dataset.image_dir, sort=True) 268 | folders = [f for f in folders if f not in TO_BE_IGNORED] 269 | 270 | is_reserves = [f in folders for f in all_folders] 271 | 272 | print(f'State dict is CLIPPED to match the shape of target dataset {self.cfg.DATASET.NAME}!') 273 | state_dict['linear_probe_proj.weight'] = state_dict['linear_probe_proj.weight'][is_reserves] 274 | state_dict['linear_probe_proj.bias'] = state_dict['linear_probe_proj.bias'][is_reserves] 275 | 276 | # set strict=False 277 | self._models[name].load_state_dict(state_dict, strict=False) 278 | -------------------------------------------------------------------------------- /trainers/elp_kgcoop.py: -------------------------------------------------------------------------------- 1 | # see elp_coopy.py to view comments 2 | import os 3 | import os.path as osp 4 | import torch 5 | import torch.nn.functional as F 6 | from dassl.engine import TRAINER_REGISTRY 7 | from dassl.utils import load_pretrained_weights, load_checkpoint 8 | from dassl.optim import build_lr_scheduler 9 | from plotnine import * 10 | from torch.cuda.amp import GradScaler, autocast 11 | from torch import nn 12 | 13 | from .kgcoop import KgCoOp, load_clip_to_cpu 14 | from .kgcoop import CustomCLIP as CustomCLIP_ 15 | from .elp_maple import FiLM 16 | from .optim import build_optimizer 17 | 18 | 19 | class CustomCLIP(CustomCLIP_): 20 | def __init__(self, cfg, classnames, clip_model): 21 | super().__init__(cfg, classnames, clip_model) 22 | self.subsample_classes = cfg.DATASET.SUBSAMPLE_CLASSES 23 | self.dataset = cfg.DATASET.NAME 24 | self.lp_cfg = cfg.TRAINER.LINEAR_PROBE 25 | self.film_cfg = cfg.TRAINER.FILM 26 | self.kg_weight = cfg.TRAINER.COOP.W 27 | 28 | clip_dim = clip_model.text_projection.size(1) 29 | 30 | film_cfg = self.film_cfg 31 | 32 | if film_cfg.LINEAR_PROBE: 33 | self.film_lp_img = FiLM(clip_dim) 34 | self.film_lp_text = FiLM(clip_dim) 35 | 36 | if (self.subsample_classes == 'base') \ 37 | or (self.subsample_classes == 'all' and 
'ImageNet' in self.dataset): 38 | assert self.lp_cfg.TYPE in ['similarity', 'linear'] 39 | 40 | if self.lp_cfg.TYPE == 'similarity': 41 | self.linear_probe_proj = nn.Identity() 42 | elif self.lp_cfg.TYPE == 'linear': 43 | self.linear_probe_proj = nn.Linear(clip_dim, len(classnames)).type(self.dtype) 44 | else: 45 | self.linear_probe_proj = nn.Identity() 46 | 47 | def forward(self, img, labels=None): 48 | if (self.subsample_classes == 'base') \ 49 | or (self.subsample_classes == 'all' and 'ImageNet' in self.dataset): 50 | return self._forward_base(img, labels) 51 | else: 52 | return self._forward_new(img) 53 | 54 | def _forward_base(self, img, labels=None): 55 | text_feats, img_feats = self._forward_feats(img) 56 | 57 | logits = self._forward_logits_similarity(text_feats, img_feats) 58 | logits_lp, labels_lp = self._forward_logits_linear_probe(text_feats, img_feats, labels) 59 | 60 | if self.prompt_learner.training: 61 | return self._loss(logits, labels, logits_lp, labels_lp, text_feats) 62 | 63 | if not self.lp_cfg.TEST_TIME_FUSION: 64 | return logits_lp 65 | 66 | lp_weight = self.lp_cfg.WEIGHT 67 | logits = (1 - lp_weight) * logits + lp_weight * logits_lp 68 | return logits 69 | 70 | def _forward_new(self, img): 71 | assert not self.prompt_learner.training 72 | 73 | text_feats, img_feats = self._forward_feats(img) 74 | logits = self._forward_logits_similarity(text_feats, img_feats) 75 | return logits 76 | 77 | def _forward_feats(self, img): 78 | prompts = self.prompt_learner() 79 | 80 | tokenized_prompts = self.tokenized_prompts 81 | text_feats = self.text_encoder(prompts, tokenized_prompts) 82 | img_feats = self.image_encoder(img.type(self.dtype)) 83 | 84 | return text_feats, img_feats 85 | 86 | def _forward_logits_similarity(self, text_feats, img_feats): 87 | text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True) 88 | img_feats = img_feats / img_feats.norm(dim=-1, keepdim=True) 89 | logit_scale = self.logit_scale.exp() 90 | logits = logit_scale * img_feats @ text_feats.t() 91 | return logits 92 | 93 | def _forward_logits_linear_probe(self, text_feats, img_feats, labels): 94 | if self.film_cfg.LINEAR_PROBE: 95 | text_feats = self.film_lp_text(text_feats) 96 | img_feats = self.film_lp_img(img_feats) 97 | 98 | if self.lp_cfg.TYPE == 'similarity': 99 | text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True) 100 | img_feats = img_feats / img_feats.norm(dim=-1, keepdim=True) 101 | logit_scale = self.logit_scale.exp() 102 | logits = logit_scale * img_feats @ text_feats.t() 103 | return logits, labels 104 | 105 | if labels is None: 106 | all_feats = img_feats 107 | all_labels = labels 108 | else: 109 | text_feats = text_feats[labels] 110 | all_feats = torch.cat([text_feats, img_feats]) 111 | all_labels = torch.cat([labels, labels]) 112 | 113 | all_logits = self.linear_probe_proj(all_feats) 114 | return all_logits, all_labels 115 | 116 | def _loss(self, logits, labels, logits_lp, labels_lp, text_feats): 117 | loss_cls = F.cross_entropy(logits, labels) 118 | loss_cls_lp = F.cross_entropy(logits_lp, labels_lp) 119 | 120 | text_feats_old = self.ori_embedding 121 | text_feats_old = text_feats_old / text_feats_old.norm(dim=-1, keepdim=True) 122 | 123 | cos = torch.nn.CosineSimilarity(dim=1,eps=1e-07) 124 | score = cos(text_feats, text_feats_old) 125 | loss_kg = 1.0 - torch.mean(score) 126 | 127 | lp_weight = self.lp_cfg.WEIGHT 128 | loss = (1 - lp_weight) * loss_cls + lp_weight * loss_cls_lp + self.kg_weight * loss_kg 129 | return loss 130 | 131 | 132 | @TRAINER_REGISTRY.register() 
133 | class ExtrasLinearProbeKgCoOp(KgCoOp): 134 | def forward_backward(self, batch): 135 | image, label = self.parse_batch_train(batch) 136 | 137 | model = self.model 138 | optim = self.optim 139 | scaler = self.scaler 140 | 141 | prec = self.cfg.TRAINER.COOP.PREC 142 | if prec == "amp": 143 | with autocast(): 144 | loss = model(image, label) 145 | optim.zero_grad() 146 | scaler.scale(loss).backward() 147 | scaler.step(optim) 148 | scaler.update() 149 | else: 150 | loss = model(image, label) 151 | optim.zero_grad() 152 | loss.backward() 153 | optim.step() 154 | 155 | loss_summary = {"loss": loss.item()} 156 | 157 | if (self.batch_idx + 1) == self.num_batches: 158 | self.update_lr() 159 | 160 | return loss_summary 161 | 162 | def model_inference(self, input): 163 | return self.model(input) 164 | 165 | def parse_batch_train(self, batch): 166 | input = batch["img"] 167 | label = batch["label"] 168 | input = input.to(self.device) 169 | label = label.to(self.device) 170 | return input, label 171 | 172 | def build_model(self): 173 | cfg = self.cfg 174 | classnames = self.dm.dataset.classnames 175 | 176 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})") 177 | clip_model = load_clip_to_cpu(cfg) 178 | 179 | if cfg.TRAINER.COOP.PREC == "fp32" or cfg.TRAINER.COOP.PREC == "amp": 180 | clip_model.float() 181 | 182 | print("Building custom CLIP") 183 | self.model = CustomCLIP(cfg, classnames, clip_model) 184 | 185 | print("Turning off gradients in both the image and the text encoder") 186 | names_to_update = cfg.TRAINER.NAMES_TO_UPDATE 187 | 188 | for name, param in self.model.named_parameters(): 189 | update = False 190 | 191 | for name_to_update in names_to_update: 192 | if name_to_update in name: 193 | update = True 194 | break 195 | 196 | param.requires_grad_(update) 197 | 198 | enabled = [] 199 | for name, param in self.model.named_parameters(): 200 | if param.requires_grad: 201 | enabled.append(name) 202 | print(f"Parameters to be updated: {list(sorted(enabled))}") 203 | 204 | if cfg.MODEL.INIT_WEIGHTS: 205 | load_pretrained_weights(self.model, cfg.MODEL.INIT_WEIGHTS) 206 | 207 | self.model.to(self.device) 208 | self.optim, infos = build_optimizer(self.model, cfg.OPTIM) 209 | 210 | if infos is not None: 211 | print('Learning rate of parameters:') 212 | for info in infos: 213 | print('lr: {}, layers: {}'.format(info['lr'], info['layers'])) 214 | 215 | self.sched = build_lr_scheduler(self.optim, cfg.OPTIM) 216 | self.register_model("PromptLearner", self.model, self.optim, self.sched) 217 | 218 | self.scaler = GradScaler() if cfg.TRAINER.COOP.PREC == "amp" else None 219 | 220 | device_count = torch.cuda.device_count() 221 | if device_count > 1: 222 | print(f"Multiple GPUs detected (n_gpus={device_count}), use all of them!") 223 | self.model = nn.DataParallel(self.model) 224 | 225 | def load_model(self, directory, epoch=None): 226 | if not directory: 227 | print("Note that load_model() is skipped as no pretrained model is given") 228 | return 229 | 230 | names = self.get_model_names() 231 | 232 | # By default, the best model is loaded 233 | model_file = "model-best.pth.tar" 234 | 235 | if epoch is not None: 236 | model_file = "model.pth.tar-" + str(epoch) 237 | 238 | for name in names: 239 | if epoch < 0: 240 | all_model_files = os.listdir(osp.join(directory, name)) 241 | all_model_files = [file_ for file_ in all_model_files if file_ != 'checkpoint'] 242 | model_epochs = [int(file_.split('-')[-1]) for file_ in all_model_files] 243 | last_epoch = max(model_epochs) 244 | model_file = 
'model.pth.tar-' + str(last_epoch) 245 | 246 | model_path = osp.join(directory, name, model_file) 247 | 248 | if not osp.exists(model_path): 249 | raise FileNotFoundError('Model not found at "{}"'.format(model_path)) 250 | 251 | checkpoint = load_checkpoint(model_path) 252 | state_dict = checkpoint["state_dict"] 253 | epoch = checkpoint["epoch"] 254 | 255 | # Ignore fixed token vectors 256 | if "prompt_learner.token_prefix" in state_dict: 257 | del state_dict["prompt_learner.token_prefix"] 258 | 259 | if "prompt_learner.token_suffix" in state_dict: 260 | del state_dict["prompt_learner.token_suffix"] 261 | 262 | print("Loading weights to {} " 'from "{}" (epoch = {})'.format(name, model_path, epoch)) 263 | 264 | if self.cfg.DATASET.NAME in ['ImageNetA', 'ImageNetR']: 265 | from datasets.imagenet import ImageNet 266 | from dassl.utils import listdir_nohidden 267 | 268 | dataset = self.dm.dataset 269 | text_file = osp.join(dataset.dataset_dir, "classnames.txt") 270 | all_folders = ImageNet.read_classnames(text_file).keys() 271 | 272 | TO_BE_IGNORED = ["README.txt"] 273 | folders = listdir_nohidden(dataset.image_dir, sort=True) 274 | folders = [f for f in folders if f not in TO_BE_IGNORED] 275 | is_reserves = [f in folders for f in all_folders] 276 | 277 | print(f'State dict is CLIPPED to match the shape of target dataset {self.cfg.DATASET.NAME}!') 278 | state_dict['linear_probe_proj.weight'] = state_dict['linear_probe_proj.weight'][is_reserves] 279 | state_dict['linear_probe_proj.bias'] = state_dict['linear_probe_proj.bias'][is_reserves] 280 | 281 | # set strict=False 282 | self._models[name].load_state_dict(state_dict, strict=False) 283 | -------------------------------------------------------------------------------- /trainers/elp_coop.py: -------------------------------------------------------------------------------- 1 | # CoOp w/ DePT 2 | import os 3 | import os.path as osp 4 | import torch 5 | import torch.nn.functional as F 6 | from dassl.engine import TRAINER_REGISTRY 7 | from dassl.utils import load_pretrained_weights, load_checkpoint 8 | from dassl.optim import build_lr_scheduler 9 | from plotnine import * 10 | from torch.cuda.amp import GradScaler, autocast 11 | from torch import nn 12 | 13 | from .coop import CoOp, load_clip_to_cpu 14 | from .coop import CustomCLIP as CustomCLIP_ 15 | from .elp_maple import FiLM 16 | from .optim import build_optimizer 17 | 18 | 19 | class CustomCLIP(CustomCLIP_): 20 | def __init__(self, cfg, classnames, clip_model): 21 | super().__init__(cfg, classnames, clip_model) 22 | self.subsample_classes = cfg.DATASET.SUBSAMPLE_CLASSES 23 | self.dataset = cfg.DATASET.NAME 24 | self.lp_cfg = cfg.TRAINER.LINEAR_PROBE 25 | self.film_cfg = cfg.TRAINER.FILM 26 | 27 | clip_dim = clip_model.text_projection.size(1) 28 | 29 | film_cfg = self.film_cfg 30 | 31 | if film_cfg.LINEAR_PROBE: 32 | # cwT module 33 | self.film_lp_img = FiLM(clip_dim) 34 | self.film_lp_text = FiLM(clip_dim) 35 | 36 | # for base to new, base classes will be 'base' 37 | # for cross dataset, classes from ImageNet will be 'base' 38 | if (self.subsample_classes == 'base') \ 39 | or (self.subsample_classes == 'all' and 'ImageNet' in self.dataset): 40 | assert self.lp_cfg.TYPE in ['similarity', 'linear'] 41 | 42 | # linear classifier 43 | if self.lp_cfg.TYPE == 'similarity': 44 | self.linear_probe_proj = nn.Identity() 45 | elif self.lp_cfg.TYPE == 'linear': 46 | self.linear_probe_proj = nn.Linear(clip_dim, len(classnames)).type(self.dtype) 47 | else: 48 | self.linear_probe_proj = nn.Identity() 49 | 50 
| def forward(self, img, labels=None): 51 | if (self.subsample_classes == 'base') \ 52 | or (self.subsample_classes == 'all' and 'ImageNet' in self.dataset): 53 | return self._forward_base(img, labels) 54 | else: 55 | return self._forward_new(img) 56 | 57 | def _forward_base(self, img, labels=None): 58 | """ forward function for base classes """ 59 | text_feats, img_feats = self._forward_feats(img) 60 | 61 | # forward similarity and linear-probe logits 62 | logits = self._forward_logits_similarity(text_feats, img_feats) 63 | logits_lp, labels_lp = self._forward_logits_linear_probe(text_feats, img_feats, labels) 64 | 65 | if self.prompt_learner.training: 66 | # during training, return the loss over both logits 67 | return self._loss(logits, labels, logits_lp, labels_lp) 68 | 69 | if not self.lp_cfg.TEST_TIME_FUSION: 70 | return logits_lp 71 | 72 | # at inference, fuse both logits and return the result 73 | lp_weight = self.lp_cfg.WEIGHT 74 | logits = (1 - lp_weight) * logits + lp_weight * logits_lp 75 | return logits 76 | 77 | def _forward_new(self, img): 78 | """ forward function for new classes """ 79 | assert not self.prompt_learner.training 80 | 81 | # for new classes, only forward similarity logits 82 | text_feats, img_feats = self._forward_feats(img) 83 | logits = self._forward_logits_similarity(text_feats, img_feats) 84 | return logits 85 | 86 | def _forward_feats(self, img): 87 | prompts = self.prompt_learner() 88 | 89 | tokenized_prompts = self.tokenized_prompts 90 | text_feats = self.text_encoder(prompts, tokenized_prompts) 91 | img_feats = self.image_encoder(img.type(self.dtype)) 92 | 93 | return text_feats, img_feats 94 | 95 | def _forward_logits_similarity(self, text_feats, img_feats): 96 | # normalize and compute cosine similarity 97 | text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True) 98 | img_feats = img_feats / img_feats.norm(dim=-1, keepdim=True) 99 | logit_scale = self.logit_scale.exp() 100 | logits = logit_scale * img_feats @ text_feats.t() 101 | return logits 102 | 103 | def _forward_logits_linear_probe(self, text_feats, img_feats, labels): 104 | # cwT module 105 | if self.film_cfg.LINEAR_PROBE: 106 | text_feats = self.film_lp_text(text_feats) 107 | img_feats = self.film_lp_img(img_feats) 108 | 109 | # when the new head is a similarity head, use the similarity forward function 110 | if self.lp_cfg.TYPE == 'similarity': 111 | return self._forward_logits_similarity(text_feats, img_feats), labels 112 | 113 | if labels is None: 114 | # at inference, forward image features only 115 | all_feats = img_feats 116 | all_labels = labels 117 | else: 118 | # during training, image and text features are concatenated to train the classifier 119 | text_feats = text_feats[labels] 120 | all_feats = torch.cat([text_feats, img_feats]) 121 | all_labels = torch.cat([labels, labels]) 122 | 123 | all_logits = self.linear_probe_proj(all_feats) 124 | return all_logits, all_labels 125 | 126 | def _loss(self, logits, labels, logits_lp, labels_lp): 127 | # calculate similarity loss and linear loss 128 | loss_cls = F.cross_entropy(logits, labels) 129 | loss_cls_lp = F.cross_entropy(logits_lp, labels_lp) 130 | 131 | lp_weight = self.lp_cfg.WEIGHT 132 | loss = (1 - lp_weight) * loss_cls + lp_weight * loss_cls_lp 133 | return loss 134 | 135 | 136 | @TRAINER_REGISTRY.register() 137 | class ExtrasLinearProbeCoOp(CoOp): 138 | def forward_backward(self, batch): 139 | image, label = self.parse_batch_train(batch) 140 | 141 | model = self.model 142 | optim = self.optim 143 | scaler = self.scaler 144 | 145 | prec = 
self.cfg.TRAINER.COOP.PREC 146 | if prec == "amp": 147 | with autocast(): 148 | loss = model(image, label) 149 | optim.zero_grad() 150 | scaler.scale(loss).backward() 151 | scaler.step(optim) 152 | scaler.update() 153 | else: 154 | loss = model(image, label) 155 | optim.zero_grad() 156 | loss.backward() 157 | optim.step() 158 | 159 | loss_summary = {"loss": loss.item()} 160 | 161 | if (self.batch_idx + 1) == self.num_batches: 162 | self.update_lr() 163 | 164 | return loss_summary 165 | 166 | def parse_batch_train(self, batch): 167 | input = batch["img"] 168 | label = batch["label"] 169 | input = input.to(self.device) 170 | label = label.to(self.device) 171 | return input, label 172 | 173 | def build_model(self): 174 | cfg = self.cfg 175 | classnames = self.dm.dataset.classnames 176 | 177 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})") 178 | clip_model = load_clip_to_cpu(cfg) 179 | 180 | if cfg.TRAINER.COOP.PREC == "fp32" or cfg.TRAINER.COOP.PREC == "amp": 181 | clip_model.float() 182 | 183 | print("Building custom CLIP") 184 | self.model = CustomCLIP(cfg, classnames, clip_model) 185 | 186 | print("Turning off gradients in both the image and the text encoder") 187 | names_to_update = cfg.TRAINER.NAMES_TO_UPDATE 188 | 189 | for name, param in self.model.named_parameters(): 190 | update = False 191 | 192 | for name_to_update in names_to_update: 193 | if name_to_update in name: 194 | update = True 195 | break 196 | 197 | param.requires_grad_(update) 198 | 199 | enabled = [] 200 | for name, param in self.model.named_parameters(): 201 | if param.requires_grad: 202 | enabled.append(name) 203 | print(f"Parameters to be updated: {list(sorted(enabled))}") 204 | 205 | if cfg.MODEL.INIT_WEIGHTS: 206 | load_pretrained_weights(self.model, cfg.MODEL.INIT_WEIGHTS) 207 | 208 | self.model.to(self.device) 209 | self.optim, infos = build_optimizer(self.model, cfg.OPTIM) 210 | 211 | if infos is not None: 212 | print('Learning rate of parameters:') 213 | for info in infos: 214 | print('lr: {}, layers: {}'.format(info['lr'], info['layers'])) 215 | 216 | self.sched = build_lr_scheduler(self.optim, cfg.OPTIM) 217 | self.register_model("PromptLearner", self.model, self.optim, self.sched) 218 | 219 | self.scaler = GradScaler() if cfg.TRAINER.COOP.PREC == "amp" else None 220 | 221 | device_count = torch.cuda.device_count() 222 | if device_count > 1: 223 | print(f"Multiple GPUs detected (n_gpus={device_count}), use all of them!") 224 | self.model = nn.DataParallel(self.model) 225 | 226 | def load_model(self, directory, epoch=None): 227 | if not directory: 228 | print("Note that load_model() is skipped as no pretrained model is given") 229 | return 230 | 231 | names = self.get_model_names() 232 | 233 | # By default, the best model is loaded 234 | model_file = "model-best.pth.tar" 235 | 236 | if epoch is not None: 237 | model_file = "model.pth.tar-" + str(epoch) 238 | 239 | for name in names: 240 | if epoch < 0: 241 | all_model_files = os.listdir(osp.join(directory, name)) 242 | all_model_files = [file_ for file_ in all_model_files if file_ != 'checkpoint'] 243 | model_epochs = [int(file_.split('-')[-1]) for file_ in all_model_files] 244 | last_epoch = max(model_epochs) 245 | model_file = 'model.pth.tar-' + str(last_epoch) 246 | 247 | model_path = osp.join(directory, name, model_file) 248 | 249 | if not osp.exists(model_path): 250 | raise FileNotFoundError('Model not found at "{}"'.format(model_path)) 251 | 252 | checkpoint = load_checkpoint(model_path) 253 | state_dict = checkpoint["state_dict"] 254 | 
epoch = checkpoint["epoch"] 255 | 256 | # Ignore fixed token vectors 257 | if "prompt_learner.token_prefix" in state_dict: 258 | del state_dict["prompt_learner.token_prefix"] 259 | 260 | if "prompt_learner.token_suffix" in state_dict: 261 | del state_dict["prompt_learner.token_suffix"] 262 | 263 | print("Loading weights to {} " 'from "{}" (epoch = {})'.format(name, model_path, epoch)) 264 | 265 | # for some datasets in domain generalization, the number of target classes differs from the number of source classes 266 | # thus a mapping must be created to preserve the required class weights 267 | if self.cfg.DATASET.NAME in ['ImageNetA', 'ImageNetR']: 268 | from datasets.imagenet import ImageNet 269 | from dassl.utils import listdir_nohidden 270 | 271 | # read classes from source dataset 272 | dataset = self.dm.dataset 273 | text_file = osp.join(dataset.dataset_dir, "classnames.txt") 274 | all_folders = ImageNet.read_classnames(text_file).keys() 275 | 276 | # read classes from target dataset 277 | TO_BE_IGNORED = ["README.txt"] 278 | folders = listdir_nohidden(dataset.image_dir, sort=True) 279 | folders = [f for f in folders if f not in TO_BE_IGNORED] 280 | 281 | # find which source classes are present in the target dataset 282 | is_reserves = [f in folders for f in all_folders] 283 | 284 | # keep only the required class weights 285 | print(f'State dict is CLIPPED to match the shape of target dataset {self.cfg.DATASET.NAME}!') 286 | state_dict['linear_probe_proj.weight'] = state_dict['linear_probe_proj.weight'][is_reserves] 287 | state_dict['linear_probe_proj.bias'] = state_dict['linear_probe_proj.bias'][is_reserves] 288 | 289 | # set strict=False 290 | self._models[name].load_state_dict(state_dict, strict=False) 291 | --------------------------------------------------------------------------------