├── clip ├── prompt.txt ├── __init__.py ├── bpe_simple_vocab_16e6.txt.gz ├── __pycache__ │ ├── clip.cpython-37.pyc │ ├── model.cpython-37.pyc │ ├── __init__.cpython-37.pyc │ └── simple_tokenizer.cpython-37.pyc ├── simple_tokenizer.py ├── clip.py └── model.py ├── data ├── datasets ├── __init__.py ├── __pycache__ │ ├── dtd.cpython-37.pyc │ ├── sun397.cpython-37.pyc │ ├── ucf101.cpython-37.pyc │ ├── __init__.cpython-37.pyc │ ├── eurosat.cpython-37.pyc │ ├── food101.cpython-37.pyc │ ├── imagenet.cpython-37.pyc │ ├── caltech101.cpython-37.pyc │ ├── datasetbase.cpython-37.pyc │ ├── imagenet_a.cpython-37.pyc │ ├── imagenet_r.cpython-37.pyc │ ├── imagenetv2.cpython-37.pyc │ ├── oxford_pets.cpython-37.pyc │ ├── data_manager.cpython-37.pyc │ ├── fgvc_aircraft.cpython-37.pyc │ ├── oxford_flowers.cpython-37.pyc │ ├── stanford_cars.cpython-37.pyc │ └── imagenet_sketch.cpython-37.pyc ├── food101.py ├── eurosat.py ├── stanford_cars.py ├── sun397.py ├── oxford_flowers.py ├── dtd.py ├── ucf101.py ├── fgvc_aircraft.py ├── caltech101.py ├── oxford_pets.py ├── data_manager.py ├── imagenet.py └── datasetbase.py ├── trainers ├── __init__.py ├── __pycache__ │ ├── coop.cpython-37.pyc │ ├── utils.cpython-37.pyc │ ├── __init__.cpython-37.pyc │ ├── hhzsclip.cpython-37.pyc │ ├── hhsstrainer.cpython-37.pyc │ └── upltrainer.cpython-37.pyc ├── zsclip.py ├── hhzsclip.py ├── coop.py └── utils.py ├── requirements.txt ├── configs ├── datasets │ ├── sseurosat.yaml │ ├── ssfood101.yaml │ ├── ssimagenet.yaml │ ├── sssun397.yaml │ ├── ssucf101.yaml │ ├── sscaltech101.yaml │ ├── ssoxford_pets.yaml │ ├── ssdtd.yaml │ ├── ssoxford_flowers.yaml │ ├── ssfgvc_aircraft.yaml │ └── ssstanford_cars.yaml ├── upl_default_config │ ├── __pycache__ │ │ └── upl_default.cpython-37.pyc │ └── upl_default.py └── trainers │ └── UPLTrainer │ ├── anay_rn101.yaml │ ├── anay_rn50.yaml │ ├── anay_rn50x4.yaml │ ├── anay_ViT_B_16.yaml │ ├── anay_ViT_B_32.yaml │ ├── anay_ViT_L_14.yaml │ ├── anay_rn50x16.yaml │ ├── anay_rn50x64.yaml │ └── rn50_ep50.yaml ├── .gitignore ├── temp_analyze_results_miltiple ├── figures └── overview.png ├── evaluation ├── __pycache__ │ └── evaluator.cpython-37.pyc └── evaluator.py ├── scripts ├── get_all_info.sh ├── get_info.sh ├── upl_test_existing_logits.sh └── upl_train.sh ├── readme.md ├── get_info.py ├── upl_train.py └── upl_test.py /clip/prompt.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data: -------------------------------------------------------------------------------- 1 | ../CoOp/data -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /trainers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ftfy 2 | regex 3 | tqdm 4 | matplotlib -------------------------------------------------------------------------------- /configs/datasets/sseurosat.yaml: 
-------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "SSEuroSAT" -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | analysis_results_test/* 2 | analyze_results/* 3 | output/* -------------------------------------------------------------------------------- /configs/datasets/ssfood101.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "SSFood101" 3 | -------------------------------------------------------------------------------- /configs/datasets/ssimagenet.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "SSImageNet" -------------------------------------------------------------------------------- /configs/datasets/sssun397.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "SSSUN397" 3 | -------------------------------------------------------------------------------- /configs/datasets/ssucf101.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "SSUCF101" 3 | -------------------------------------------------------------------------------- /temp_analyze_results_miltiple: -------------------------------------------------------------------------------- 1 | ../CoOp/temp_analyze_results_miltiple/ -------------------------------------------------------------------------------- /configs/datasets/sscaltech101.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "SSCaltech101" -------------------------------------------------------------------------------- /configs/datasets/ssoxford_pets.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "SSOxfordPets" -------------------------------------------------------------------------------- /configs/datasets/ssdtd.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "SSDescribableTextures" 3 | -------------------------------------------------------------------------------- /configs/datasets/ssoxford_flowers.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "SSOxfordFlowers" -------------------------------------------------------------------------------- /configs/datasets/ssfgvc_aircraft.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "SSFGVCAircraft" 3 | -------------------------------------------------------------------------------- /configs/datasets/ssstanford_cars.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "SSStanfordCars" 3 | -------------------------------------------------------------------------------- /figures/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tonyhuang2022/UPL/HEAD/figures/overview.png -------------------------------------------------------------------------------- /clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tonyhuang2022/UPL/HEAD/clip/bpe_simple_vocab_16e6.txt.gz 
-------------------------------------------------------------------------------- /scripts/get_all_info.sh: -------------------------------------------------------------------------------- 1 | 2 | DATASET=$1 3 | 4 | 5 | bash get_info.sh ${DATASET} anay_rn50 end 16 -1 False 6 | bash get_info.sh ${DATASET} anay_rn50x4 end 16 -1 False 7 | bash get_info.sh ${DATASET} anay_rn50x16 end 16 -1 False 8 | bash get_info.sh ${DATASET} anay_rn50x64 end 16 -1 False 9 | bash get_info.sh ${DATASET} anay_rn101 end 16 -1 False 10 | bash get_info.sh ${DATASET} anay_ViT_B_16 end 16 -1 False 11 | bash get_info.sh ${DATASET} anay_ViT_B_32 end 16 -1 False 12 | bash get_info.sh ${DATASET} anay_ViT_L_14 end 16 -1 False -------------------------------------------------------------------------------- /configs/trainers/UPLTrainer/anay_rn101.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258,
0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | # TRANSFORMS: ["center_crop", "normalize"]"cutout", 15 | 16 | OPTIM: 17 | NAME: "sgd" 18 | LR: 0.002 19 | MAX_EPOCH: 50 20 | LR_SCHEDULER: "cosine" 21 | WARMUP_EPOCH: 1 22 | WARMUP_TYPE: "constant" 23 | WARMUP_CONS_LR: 1e-5 24 | 25 | TRAIN: 26 | PRINT_FREQ: 5 27 | 28 | TEST: 29 | Analyze_Result_Path: './analysis_results_test/' 30 | FINAL_MODEL: "last_val" 31 | PER_CLASS_RESULT: True 32 | 33 | MODEL: 34 | BACKBONE: 35 | NAME: "RN101" 36 | 37 | TRAINER: 38 | UPLTrainer: 39 | CTX_INIT: "a photo of a" -------------------------------------------------------------------------------- /configs/trainers/UPLTrainer/anay_rn50.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | # TRANSFORMS: ["center_crop", "normalize"]"cutout", 15 | 16 | OPTIM: 17 | NAME: "sgd" 18 | LR: 0.002 19 | MAX_EPOCH: 50 20 | LR_SCHEDULER: "cosine" 21 | WARMUP_EPOCH: 1 22 | WARMUP_TYPE: "constant" 23 | WARMUP_CONS_LR: 1e-5 24 | 25 | TRAIN: 26 | PRINT_FREQ: 5 27 | 28 | TEST: 29 | Analyze_Result_Path: './analysis_results_test/' 30 | FINAL_MODEL: "last_val" 31 | PER_CLASS_RESULT: True 32 | 33 | MODEL: 34 | BACKBONE: 35 | NAME: "RN50" 36 | 37 | TRAINER: 38 | UPLTrainer: 39 | CTX_INIT: "a photo of a" -------------------------------------------------------------------------------- /configs/trainers/UPLTrainer/anay_rn50x4.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (288, 288) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | # TRANSFORMS: ["center_crop", "normalize"]"cutout", 15 | 16 | OPTIM: 17 | NAME: "sgd" 18 | LR: 0.002 19 | MAX_EPOCH: 50 20 | LR_SCHEDULER: "cosine" 21 | WARMUP_EPOCH: 1 22 | WARMUP_TYPE: "constant" 23 | WARMUP_CONS_LR: 1e-5 24 | 25 | TRAIN: 26 | PRINT_FREQ: 5 27 | 28 | TEST: 29 | Analyze_Result_Path: './analysis_results_test/' 30 | FINAL_MODEL: "last_val" 31 | PER_CLASS_RESULT: True 32 | 33 | MODEL: 34 | BACKBONE: 35 | NAME: "RN50x4" 36 | 37 | TRAINER: 38 | UPLTrainer: 39 | CTX_INIT: "a photo of a" -------------------------------------------------------------------------------- /configs/trainers/UPLTrainer/anay_ViT_B_16.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | # TRANSFORMS: ["center_crop", "normalize"]"cutout", 15 | 16 | OPTIM: 17 | NAME: "sgd" 18 | LR: 0.002 19 | MAX_EPOCH: 50 20 | LR_SCHEDULER: "cosine" 21 | WARMUP_EPOCH: 1 22 | WARMUP_TYPE: "constant" 23 | WARMUP_CONS_LR: 1e-5 24 | 25 | TRAIN: 26 | PRINT_FREQ: 5 27 | 28 | TEST: 29 | Analyze_Result_Path: 
'./analysis_results_test/' 30 | FINAL_MODEL: "last_val" 31 | PER_CLASS_RESULT: True 32 | 33 | MODEL: 34 | BACKBONE: 35 | NAME: "ViT-B/16" 36 | 37 | TRAINER: 38 | UPLTrainer: 39 | CTX_INIT: "a photo of a" -------------------------------------------------------------------------------- /configs/trainers/UPLTrainer/anay_ViT_B_32.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | # TRANSFORMS: ["center_crop", "normalize"]"cutout", 15 | 16 | OPTIM: 17 | NAME: "sgd" 18 | LR: 0.002 19 | MAX_EPOCH: 50 20 | LR_SCHEDULER: "cosine" 21 | WARMUP_EPOCH: 1 22 | WARMUP_TYPE: "constant" 23 | WARMUP_CONS_LR: 1e-5 24 | 25 | TRAIN: 26 | PRINT_FREQ: 5 27 | 28 | TEST: 29 | Analyze_Result_Path: './analysis_results_test/' 30 | FINAL_MODEL: "last_val" 31 | PER_CLASS_RESULT: True 32 | 33 | MODEL: 34 | BACKBONE: 35 | NAME: "ViT-B/32" 36 | 37 | TRAINER: 38 | UPLTrainer: 39 | CTX_INIT: "a photo of a" -------------------------------------------------------------------------------- /configs/trainers/UPLTrainer/anay_ViT_L_14.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | # TRANSFORMS: ["center_crop", "normalize"]"cutout", 15 | 16 | OPTIM: 17 | NAME: "sgd" 18 | LR: 0.002 19 | MAX_EPOCH: 50 20 | LR_SCHEDULER: "cosine" 21 | WARMUP_EPOCH: 1 22 | WARMUP_TYPE: "constant" 23 | WARMUP_CONS_LR: 1e-5 24 | 25 | TRAIN: 26 | PRINT_FREQ: 5 27 | 28 | TEST: 29 | Analyze_Result_Path: './analysis_results_test/' 30 | FINAL_MODEL: "last_val" 31 | PER_CLASS_RESULT: True 32 | 33 | MODEL: 34 | BACKBONE: 35 | NAME: "ViT-L/14" 36 | 37 | TRAINER: 38 | UPLTrainer: 39 | CTX_INIT: "a photo of a" -------------------------------------------------------------------------------- /configs/trainers/UPLTrainer/anay_rn50x16.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (384, 384) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | # TRANSFORMS: ["center_crop", "normalize"]"cutout", 15 | 16 | OPTIM: 17 | NAME: "sgd" 18 | LR: 0.002 19 | MAX_EPOCH: 50 20 | LR_SCHEDULER: "cosine" 21 | WARMUP_EPOCH: 1 22 | WARMUP_TYPE: "constant" 23 | WARMUP_CONS_LR: 1e-5 24 | 25 | TRAIN: 26 | PRINT_FREQ: 5 27 | 28 | TEST: 29 | Analyze_Result_Path: './analysis_results_test/' 30 | FINAL_MODEL: "last_val" 31 | PER_CLASS_RESULT: True 32 | 33 | MODEL: 34 | BACKBONE: 35 | NAME: "RN50x16" 36 | 37 | TRAINER: 38 | UPLTrainer: 39 | CTX_INIT: "a photo of a" -------------------------------------------------------------------------------- /configs/trainers/UPLTrainer/anay_rn50x64.yaml: 
-------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (448, 448) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | # TRANSFORMS: ["center_crop", "normalize"]"cutout", 15 | 16 | OPTIM: 17 | NAME: "sgd" 18 | LR: 0.002 19 | MAX_EPOCH: 50 20 | LR_SCHEDULER: "cosine" 21 | WARMUP_EPOCH: 1 22 | WARMUP_TYPE: "constant" 23 | WARMUP_CONS_LR: 1e-5 24 | 25 | TRAIN: 26 | PRINT_FREQ: 5 27 | 28 | TEST: 29 | Analyze_Result_Path: './analysis_results_test/' 30 | FINAL_MODEL: "last_val" 31 | PER_CLASS_RESULT: True 32 | 33 | MODEL: 34 | BACKBONE: 35 | NAME: "RN50x64" 36 | 37 | TRAINER: 38 | UPLTrainer: 39 | CTX_INIT: "a photo of a" -------------------------------------------------------------------------------- /configs/trainers/UPLTrainer/rn50_ep50.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | # TRANSFORMS: ["center_crop", "normalize"]"cutout", 15 | 16 | OPTIM: 17 | NAME: "sgd" 18 | LR: 0.002 19 | MAX_EPOCH: 50 20 | LR_SCHEDULER: "cosine" 21 | WARMUP_EPOCH: 1 22 | WARMUP_TYPE: "constant" 23 | WARMUP_CONS_LR: 1e-5 24 | 25 | TRAIN: 26 | PRINT_FREQ: 5 27 | 28 | TEST: 29 | Analyze_Result_Path: './analysis_results_test/' 30 | FINAL_MODEL: "last_val" 31 | PER_CLASS_RESULT: True 32 | 33 | 34 | MODEL: 35 | BACKBONE: 36 | NAME: "RN50" 37 | # PSEUDO_LABEL_MODELS: ['RN50', 'RN50x4', 'RN50x16', 'RN50x64', 'RN101', 'ViT-B-16', 'ViT-B-32', 'ViT-L-14'] 38 | PSEUDO_LABEL_MODELS: ['RN50'] 39 | -------------------------------------------------------------------------------- /scripts/get_info.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd .. 4 | 5 | # custom config 6 | DATA=./data 7 | TRAINER=UPLTrainer 8 | 9 | DATASET=$1 10 | CFG=$2 # config file 11 | CTP=$3 # class token position (end or middle) 12 | NCTX=$4 # number of context tokens 13 | SHOTS=$5 # number of shots (1, 2, 4, 8, 16) 14 | CSC=$6 # class-specific context (False or True) 15 | 16 | 17 | for SEED in 1 18 | do 19 | DIR=./output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots_random_init/nctx${NCTX}_csc${CSC}_ctp${CTP}/seed${SEED} 20 | 21 | echo "Run this job and save the output to ${DIR}" 22 | python get_info.py \ 23 | --root ${DATA} \ 24 | --seed ${SEED} \ 25 | --trainer ${TRAINER} \ 26 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 27 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 28 | --output-dir ${DIR} \ 29 | TRAINER.UPLTrainer.N_CTX ${NCTX} \ 30 | TRAINER.UPLTrainer.CSC ${CSC} \ 31 | TRAINER.UPLTrainer.CLASS_TOKEN_POSITION ${CTP} \ 32 | DATASET.NUM_SHOTS ${SHOTS} 33 | done 34 | -------------------------------------------------------------------------------- /scripts/upl_test_existing_logits.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd .. 
4 | 5 | # custom config 6 | DATA=./data 7 | TRAINER=UPLTrainer 8 | 9 | DATASET=$1 10 | CFG=$2 # config file 11 | CTP=$3 # class token position (end or middle) 12 | NCTX=$4 # number of context tokens 13 | SHOTS=$5 # number of shots (1, 2, 4, 8, 16) 14 | CSC=$6 # class-specific context (False or True) 15 | CLASS_EQULE=$7 # CLASS_EQULE True of False 16 | 17 | 18 | DIR=./output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots_EQULE_${CLASS_EQULE}_CONF_THRESHOLD_${CONF_THRESHOLD}_RN50_temp/nctx${NCTX}_csc${CSC}_ctp${CTP}/seed${SEED} 19 | 20 | python upl_test.py \ 21 | --root ${DATA} \ 22 | --seed 1 \ 23 | --trainer ${TRAINER} \ 24 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 25 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 26 | --output-dir ${DIR} \ 27 | TRAINER.UPLTrainer.N_CTX ${NCTX} \ 28 | TRAINER.UPLTrainer.CSC ${CSC} \ 29 | TRAINER.UPLTrainer.CLASS_TOKEN_POSITION ${CTP} \ 30 | DATASET.NUM_SHOTS ${SHOTS} \ 31 | DATASET.CLASS_EQULE ${CLASS_EQULE} 32 | 33 | 34 | 35 | -------------------------------------------------------------------------------- /scripts/upl_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd .. 4 | 5 | # custom config 6 | DATA=./data 7 | TRAINER=UPLTrainer 8 | 9 | DATASET=$1 10 | CFG=$2 # config file 11 | CTP=$3 # class token position (end or middle) 12 | NCTX=$4 # number of context tokens 13 | SHOTS=$5 # number of shots (1, 2, 4, 8, 16) 14 | CSC=$6 # class-specific context (False or True) 15 | CLASS_EQULE=$7 # CLASS_EQULE True of False 16 | TAG=$8 # log tag (multiple_models_random_init or rn50_random_init) 17 | 18 | 19 | for SEED in {1..16} 20 | do 21 | DIR=./output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots_EQULE_${CLASS_EQULE}_${CONF_THRESHOLD}_${TAG}/nctx${NCTX}_csc${CSC}_ctp${CTP}/seed${SEED} 22 | if [ -d "$DIR" ]; then 23 | echo "Results are available in ${DIR}. Skip this job" 24 | else 25 | echo "Run this job and save the output to ${DIR}" 26 | python upl_train.py \ 27 | --root ${DATA} \ 28 | --seed ${SEED} \ 29 | --trainer ${TRAINER} \ 30 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 31 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 32 | --output-dir ${DIR} \ 33 | TRAINER.UPLTrainer.N_CTX ${NCTX} \ 34 | TRAINER.UPLTrainer.CSC ${CSC} \ 35 | TRAINER.UPLTrainer.CLASS_TOKEN_POSITION ${CTP} \ 36 | DATASET.NUM_SHOTS ${SHOTS} \ 37 | DATASET.CLASS_EQULE ${CLASS_EQULE} 38 | fi 39 | done 40 | 41 | 42 | # #!/bin/bash 43 | 44 | # cd .. 
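# Example invocation of the active upl_train.sh script above (illustrative sketch:
# assumes the shipped config names and the argument order
# DATASET CFG CTP NCTX SHOTS CSC CLASS_EQULE TAG, with TAG as described above):
#   bash upl_train.sh ssucf101 rn50_ep50 end 16 16 False True rn50_random_init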
45 | 46 | # # custom config 47 | # DATA=./data 48 | # TRAINER=ZeroshotCLIP 49 | # DATASET=$1 50 | # CFG=$2 # rn50, rn101, vit_b32 or vit_b16 51 | 52 | # python sstrain.py \ 53 | # --root ${DATA} \ 54 | # --trainer ${TRAINER} \ 55 | # --dataset-config-file configs/datasets/${DATASET}.yaml \ 56 | # --config-file configs/trainers/HHTrainer/${CFG}.yaml \ 57 | # --output-dir output/${TRAINER}/${CFG}/zeroshot/${DATASET} \ 58 | # --eval-only -------------------------------------------------------------------------------- /datasets/food101.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 4 | 5 | from .oxford_pets import OxfordPets 6 | from .dtd import DescribableTextures as DTD 7 | from .datasetbase import UPLDatasetBase 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class Food101(DatasetBase): 12 | 13 | dataset_dir = "food-101" 14 | 15 | def __init__(self, cfg): 16 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 17 | self.dataset_dir = os.path.join(root, self.dataset_dir) 18 | self.image_dir = os.path.join(self.dataset_dir, "images") 19 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_Food101.json") 20 | 21 | if os.path.exists(self.split_path): 22 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 23 | else: 24 | train, val, test = DTD.read_and_split_data(self.image_dir) 25 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 26 | 27 | num_shots = cfg.DATASET.NUM_SHOTS 28 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 29 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 30 | 31 | super().__init__(train_x=train, val=val, test=test) 32 | 33 | 34 | @DATASET_REGISTRY.register() 35 | class SSFood101(UPLDatasetBase): 36 | 37 | dataset_dir = "food-101" 38 | 39 | def __init__(self, cfg): 40 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 41 | self.dataset_dir = os.path.join(root, self.dataset_dir) 42 | self.image_dir = os.path.join(self.dataset_dir, "images") 43 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_Food101.json") 44 | 45 | if os.path.exists(self.split_path): 46 | train, val, test = self.read_split(self.split_path, self.image_dir) 47 | else: 48 | train, val, test = DTD.read_and_split_data(self.image_dir) 49 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 50 | 51 | sstrain = self.read_sstrain_data(self.split_path, self.image_dir) 52 | num_shots = cfg.DATASET.NUM_SHOTS 53 | train = self.generate_fewshot_dataset(train, num_shots=-1) 54 | val = self.generate_fewshot_dataset(val, num_shots=-1) 55 | super().__init__(train_x=train, val = val, test=test, sstrain=sstrain) -------------------------------------------------------------------------------- /datasets/eurosat.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 4 | 5 | from .oxford_pets import OxfordPets 6 | from .dtd import DescribableTextures as DTD 7 | from .datasetbase import UPLDatasetBase 8 | 9 | NEW_CNAMES = { 10 | "AnnualCrop": "Annual Crop Land", 11 | "Forest": "Forest", 12 | "HerbaceousVegetation": "Herbaceous Vegetation Land", 13 | "Highway": "Highway or Road", 14 | "Industrial": "Industrial Buildings", 15 | "Pasture": "Pasture Land", 16 | "PermanentCrop": "Permanent Crop Land", 17 | "Residential": "Residential 
Buildings", 18 | "River": "River", 19 | "SeaLake": "Sea or Lake", 20 | } 21 | 22 | 23 | @DATASET_REGISTRY.register() 24 | class EuroSAT(DatasetBase): 25 | 26 | dataset_dir = "eurosat" 27 | 28 | def __init__(self, cfg): 29 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 30 | self.dataset_dir = os.path.join(root, self.dataset_dir) 31 | self.image_dir = os.path.join(self.dataset_dir, "2750") 32 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_EuroSAT.json") 33 | 34 | if os.path.exists(self.split_path): 35 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 36 | else: 37 | train, val, test = DTD.read_and_split_data(self.image_dir, new_cnames=NEW_CNAMES) 38 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 39 | 40 | num_shots = cfg.DATASET.NUM_SHOTS 41 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 42 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 43 | 44 | super().__init__(train_x=train, val=val, test=test) 45 | 46 | def update_classname(self, dataset_old): 47 | dataset_new = [] 48 | for item_old in dataset_old: 49 | cname_old = item_old.classname 50 | cname_new = NEW_CLASSNAMES[cname_old] 51 | item_new = Datum(impath=item_old.impath, label=item_old.label, classname=cname_new) 52 | dataset_new.append(item_new) 53 | return dataset_new 54 | 55 | @DATASET_REGISTRY.register() 56 | class SSEuroSAT(UPLDatasetBase): 57 | dataset_dir = 'eurosat' 58 | def __init__(self, cfg): 59 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 60 | self.dataset_dir = os.path.join(root, self.dataset_dir) 61 | self.image_dir = os.path.join(self.dataset_dir, "2750") 62 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_EuroSAT.json") 63 | 64 | 65 | if os.path.exists(self.split_path): 66 | train, val, test = self.read_split(self.split_path, self.image_dir) 67 | else: 68 | train, val, test = DTD.read_and_split_data( 69 | self.image_dir, 70 | ignored=IGNORED, 71 | new_cnames=NEW_CNAMES 72 | ) 73 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 74 | sstrain = self.read_sstrain_data(self.split_path, self.image_dir) 75 | num_shots = cfg.DATASET.NUM_SHOTS 76 | train = self.generate_fewshot_dataset(train, num_shots=-1) 77 | 78 | val = self.generate_fewshot_dataset(val, num_shots=-1) 79 | super().__init__(train_x=train, val = val, test=test, sstrain=sstrain) 80 | 81 | -------------------------------------------------------------------------------- /trainers/zsclip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from dassl.engine import TRAINER_REGISTRY, TrainerX 5 | from dassl.optim import build_optimizer, build_lr_scheduler 6 | 7 | from clip import clip 8 | from clip.model import convert_weights 9 | 10 | from .coop import load_clip_to_cpu 11 | 12 | CUSTOM_TEMPLATES = { 13 | "OxfordPets": "a photo of a {}, a type of pet.", 14 | "OxfordFlowers": "a photo of a {}, a type of flower.", 15 | "FGVCAircraft": "a photo of a {}, a type of aircraft.", 16 | "DescribableTextures": "{} texture.", 17 | "EuroSAT": "a centered satellite photo of {}.", 18 | "StanfordCars": "a photo of a {}.", 19 | "Food101": "a photo of {}, a type of food.", 20 | "SUN397": "a photo of a {}.", 21 | "Caltech101": "a photo of a {}.", 22 | "UCF101": "a photo of a person doing {}.", 23 | "ImageNet": "a photo of a {}.", 24 | "ImageNetSketch": "a photo of a {}.", 25 | "ImageNetV2": "a photo of a {}.", 26 | 
"ImageNetA": "a photo of a {}.", 27 | "ImageNetR": "a photo of a {}.", 28 | } 29 | 30 | 31 | @TRAINER_REGISTRY.register() 32 | class ZeroshotCLIP(TrainerX): 33 | def build_model(self): 34 | cfg = self.cfg 35 | classnames = self.dm.dataset.classnames 36 | 37 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})") 38 | clip_model = load_clip_to_cpu(cfg) 39 | clip_model.to(self.device) 40 | 41 | temp = CUSTOM_TEMPLATES[cfg.DATASET.NAME] 42 | prompts = [temp.format(c.replace("_", " ")) for c in classnames] 43 | print(f"Prompts: {prompts}") 44 | prompts = torch.cat([clip.tokenize(p) for p in prompts]) 45 | prompts = prompts.to(self.device) 46 | 47 | with torch.no_grad(): 48 | text_features = clip_model.encode_text(prompts) 49 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 50 | 51 | self.text_features = text_features 52 | self.clip_model = clip_model 53 | 54 | def model_inference(self, image): 55 | image_features = self.clip_model.encode_image(image) 56 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 57 | logit_scale = self.clip_model.logit_scale.exp() 58 | logits = logit_scale * image_features @ self.text_features.t() 59 | return logits 60 | 61 | 62 | @TRAINER_REGISTRY.register() 63 | class ZeroshotCLIP2(ZeroshotCLIP): 64 | """Prompt ensembling.""" 65 | 66 | # templates = IMAGENET_TEMPLATES 67 | templates = IMAGENET_TEMPLATES_SELECT 68 | 69 | def build_model(self): 70 | cfg = self.cfg 71 | classnames = self.dm.dataset.classnames 72 | 73 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})") 74 | clip_model = load_clip_to_cpu(cfg) 75 | clip_model.to(self.device) 76 | 77 | for params in clip_model.parameters(): 78 | params.requires_grad_(False) 79 | 80 | # add custom-made prompt 81 | if cfg.DATASET.NAME != "ImageNet": 82 | self.templates += [CUSTOM_TEMPLATES[cfg.DATASET.NAME]] 83 | 84 | num_temp = len(self.templates) 85 | print(f"Prompt ensembling (n={num_temp})") 86 | 87 | mean_text_features = 0 88 | for i, temp in enumerate(self.templates): 89 | prompts = [temp.format(c.replace("_", " ")) for c in classnames] 90 | prompts = torch.cat([clip.tokenize(p) for p in prompts]).to(self.device) 91 | text_features = clip_model.encode_text(prompts) 92 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 93 | mean_text_features = mean_text_features + text_features 94 | mean_text_features = mean_text_features / num_temp 95 | mean_text_features = mean_text_features / mean_text_features.norm(dim=-1, keepdim=True) 96 | 97 | self.text_features = mean_text_features 98 | self.clip_model = clip_model 99 | -------------------------------------------------------------------------------- /datasets/stanford_cars.py: -------------------------------------------------------------------------------- 1 | import os 2 | from scipy.io import loadmat 3 | 4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 5 | 6 | from .oxford_pets import OxfordPets 7 | from .datasetbase import UPLDatasetBase 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class StanfordCars(DatasetBase): 12 | 13 | dataset_dir = "stanford_cars" 14 | 15 | def __init__(self, cfg): 16 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 17 | self.dataset_dir = os.path.join(root, self.dataset_dir) 18 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_StanfordCars.json") 19 | 20 | if os.path.exists(self.split_path): 21 | train, val, test = OxfordPets.read_split(self.split_path, self.dataset_dir) 22 | else: 23 | trainval_file = 
os.path.join(self.dataset_dir, "devkit", "cars_train_annos.mat") 24 | test_file = os.path.join(self.dataset_dir, "cars_test_annos_withlabels.mat") 25 | meta_file = os.path.join(self.dataset_dir, "devkit", "cars_meta.mat") 26 | trainval = self.read_data("cars_train", trainval_file, meta_file) 27 | test = self.read_data("cars_test", test_file, meta_file) 28 | train, val = OxfordPets.split_trainval(trainval) 29 | OxfordPets.save_split(train, val, test, self.split_path, self.dataset_dir) 30 | 31 | num_shots = cfg.DATASET.NUM_SHOTS 32 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 33 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 34 | 35 | super().__init__(train_x=train, val=val, test=test) 36 | 37 | def read_data(self, image_dir, anno_file, meta_file): 38 | anno_file = loadmat(anno_file)["annotations"][0] 39 | meta_file = loadmat(meta_file)["class_names"][0] 40 | items = [] 41 | 42 | for i in range(len(anno_file)): 43 | imname = anno_file[i]["fname"][0] 44 | impath = os.path.join(self.dataset_dir, image_dir, imname) 45 | label = anno_file[i]["class"][0, 0] 46 | label = int(label) - 1 # convert to 0-based index 47 | classname = meta_file[label][0] 48 | names = classname.split(" ") 49 | year = names.pop(-1) 50 | names.insert(0, year) 51 | classname = " ".join(names) 52 | item = Datum(impath=impath, label=label, classname=classname) 53 | items.append(item) 54 | 55 | return items 56 | 57 | @DATASET_REGISTRY.register() 58 | class SSStanfordCars(UPLDatasetBase): 59 | 60 | dataset_dir = 'stanford_cars' 61 | 62 | def __init__(self, cfg): 63 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 64 | self.dataset_dir = os.path.join(root, self.dataset_dir) 65 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_StanfordCars.json") 66 | self.image_dir = self.dataset_dir 67 | 68 | if os.path.exists(self.split_path): 69 | train, val, test = self.read_split(self.split_path, self.dataset_dir) 70 | else: 71 | trainval_file = os.path.join(self.dataset_dir, "devkit", "cars_train_annos.mat") 72 | test_file = os.path.join(self.dataset_dir, "cars_test_annos_withlabels.mat") 73 | meta_file = os.path.join(self.dataset_dir, "devkit", "cars_meta.mat") 74 | trainval = self.read_data("cars_train", trainval_file, meta_file) 75 | test = self.read_data("cars_test", test_file, meta_file) 76 | train, val = OxfordPets.split_trainval(trainval) 77 | OxfordPets.save_split(train, val, test, self.split_path, self.dataset_dir) 78 | 79 | sstrain = self.read_sstrain_data(self.split_path, self.dataset_dir) 80 | num_shots = cfg.DATASET.NUM_SHOTS 81 | train = self.generate_fewshot_dataset(train, num_shots=-1) 82 | val = self.generate_fewshot_dataset(val, num_shots=-1) 83 | super().__init__(train_x=train, val = val, test=test, sstrain=sstrain) -------------------------------------------------------------------------------- /evaluation/evaluator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os.path as osp 3 | from collections import OrderedDict, defaultdict 4 | import torch 5 | from sklearn.metrics import f1_score, confusion_matrix 6 | 7 | from dassl.evaluation.evaluator import Classification 8 | from dassl.evaluation.build import EVALUATOR_REGISTRY 9 | 10 | @EVALUATOR_REGISTRY.register() 11 | class UPLClassification(Classification): 12 | 13 | def process(self, mo, gt, per_image_txt_writer, per_class_txt_writer): 14 | # mo (torch.Tensor): model output [batch, num_classes] 15 | # gt (torch.LongTensor): 
ground truth [batch] 16 | self.per_image_txt_writer = per_image_txt_writer 17 | self.per_class_txt_writer = per_class_txt_writer 18 | pred = mo.max(1)[1] 19 | matches = pred.eq(gt).float() 20 | self._correct += int(matches.sum().item()) 21 | self._total += gt.shape[0] 22 | 23 | self._y_true.extend(gt.data.cpu().numpy().tolist()) 24 | self._y_pred.extend(pred.data.cpu().numpy().tolist()) 25 | 26 | if self._per_class_res is not None: 27 | for i, (label, pred) in enumerate(zip(gt, mo)): 28 | label = label.item() 29 | label_name = self._lab2cname[label] 30 | matches_i = int(matches[i].item()) 31 | self._per_class_res[label].append(matches_i) 32 | # print(i, label_name, label, matches_i, pred.data.cpu().numpy().tolist()) 33 | write_line = str(i) + "," + str(label_name) + "," + str(label) + "," + str(matches_i) + "," + str(pred.data.cpu().numpy().tolist()) 34 | self.per_image_txt_writer.write(write_line+'\n') 35 | 36 | def evaluate(self): 37 | results = OrderedDict() 38 | acc = 100.0 * self._correct / self._total 39 | err = 100.0 - acc 40 | macro_f1 = 100.0 * f1_score( 41 | self._y_true, 42 | self._y_pred, 43 | average="macro", 44 | labels=np.unique(self._y_true) 45 | ) 46 | 47 | # The first value will be returned by trainer.test() 48 | results["accuracy"] = acc 49 | results["error_rate"] = err 50 | results["macro_f1"] = macro_f1 51 | 52 | print( 53 | "=> result\n" 54 | f"* total: {self._total:,}\n" 55 | f"* correct: {self._correct:,}\n" 56 | f"* accuracy: {acc:.2f}%\n" 57 | f"* error: {err:.2f}%\n" 58 | f"* macro_f1: {macro_f1:.2f}%" 59 | ) 60 | 61 | if self._per_class_res is not None: 62 | labels = list(self._per_class_res.keys()) 63 | labels.sort() 64 | 65 | print("=> per-class result") 66 | accs = [] 67 | 68 | for label in self._lab2cname: 69 | classname = self._lab2cname[label] 70 | res = self._per_class_res[label] 71 | correct = sum(res) 72 | total = len(res) 73 | acc = 100.0 * correct / total 74 | accs.append(acc) 75 | print( 76 | "* class: {} ({})\t" 77 | "total: {:,}\t" 78 | "correct: {:,}\t" 79 | "acc: {:.2f}%".format( 80 | label, classname, total, correct, acc 81 | ) 82 | ) 83 | write_line = "* class: {} ({}), total: {:,}, correct: {:,}, acc: {:.2f}% \n".format(label, classname, total, correct, acc) 84 | self.per_class_txt_writer.write(write_line) 85 | mean_acc = np.mean(accs) 86 | print("* average: {:.2f}%".format(mean_acc)) 87 | 88 | results["perclass_accuracy"] = mean_acc 89 | 90 | if self.cfg.TEST.COMPUTE_CMAT: 91 | cmat = confusion_matrix( 92 | self._y_true, self._y_pred, normalize="true" 93 | ) 94 | save_path = osp.join(self.cfg.OUTPUT_DIR, "cmat.pt") 95 | torch.save(cmat, save_path) 96 | print('Confusion matrix is saved to "{}"'.format(save_path)) 97 | 98 | return results 99 | -------------------------------------------------------------------------------- /datasets/sun397.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 4 | 5 | from .oxford_pets import OxfordPets 6 | from .datasetbase import UPLDatasetBase 7 | 8 | 9 | @DATASET_REGISTRY.register() 10 | class SUN397(DatasetBase): 11 | 12 | dataset_dir = "sun397" 13 | 14 | def __init__(self, cfg): 15 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 16 | self.dataset_dir = os.path.join(root, self.dataset_dir) 17 | self.image_dir = os.path.join(self.dataset_dir, "SUN397") 18 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_SUN397.json") 19 | 20 | if 
os.path.exists(self.split_path): 21 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 22 | else: 23 | classnames = [] 24 | with open(os.path.join(self.dataset_dir, "ClassName.txt"), "r") as f: 25 | lines = f.readlines() 26 | for line in lines: 27 | line = line.strip()[1:] # remove / 28 | classnames.append(line) 29 | cname2lab = {c: i for i, c in enumerate(classnames)} 30 | trainval = self.read_data(cname2lab, "Training_01.txt") 31 | test = self.read_data(cname2lab, "Testing_01.txt") 32 | train, val = OxfordPets.split_trainval(trainval) 33 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 34 | 35 | num_shots = cfg.DATASET.NUM_SHOTS 36 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 37 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 38 | 39 | super().__init__(train_x=train, val=val, test=test) 40 | 41 | def read_data(self, cname2lab, text_file): 42 | text_file = os.path.join(self.dataset_dir, text_file) 43 | items = [] 44 | 45 | with open(text_file, "r") as f: 46 | lines = f.readlines() 47 | for line in lines: 48 | imname = line.strip()[1:] # remove / 49 | classname = os.path.dirname(imname) 50 | label = cname2lab[classname] 51 | impath = os.path.join(self.image_dir, imname) 52 | 53 | names = classname.split("/")[1:] # remove 1st letter 54 | names = names[::-1] # put words like indoor/outdoor at first 55 | classname = " ".join(names) 56 | 57 | item = Datum(impath=impath, label=label, classname=classname) 58 | items.append(item) 59 | 60 | return items 61 | 62 | 63 | 64 | @DATASET_REGISTRY.register() 65 | class SSSUN397(UPLDatasetBase): 66 | 67 | dataset_dir = "sun397" 68 | 69 | def __init__(self, cfg): 70 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 71 | self.dataset_dir = os.path.join(root, self.dataset_dir) 72 | self.image_dir = os.path.join(self.dataset_dir, "SUN397") 73 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_SUN397.json") 74 | 75 | if os.path.exists(self.split_path): 76 | train, val, test = self.read_split(self.split_path, self.image_dir) 77 | else: 78 | classnames = [] 79 | with open(os.path.join(self.dataset_dir, "ClassName.txt"), "r") as f: 80 | lines = f.readlines() 81 | for line in lines: 82 | line = line.strip()[1:] # remove / 83 | classnames.append(line) 84 | cname2lab = {c: i for i, c in enumerate(classnames)} 85 | trainval = self.read_data(cname2lab, "Training_01.txt") 86 | test = self.read_data(cname2lab, "Testing_01.txt") 87 | train, val = OxfordPets.split_trainval(trainval) 88 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 89 | 90 | sstrain = self.read_sstrain_data(self.split_path, self.image_dir) 91 | num_shots = cfg.DATASET.NUM_SHOTS 92 | train = self.generate_fewshot_dataset(train, num_shots=-1) 93 | val = self.generate_fewshot_dataset(val, num_shots=-1) 94 | super().__init__(train_x=train, val = val, test=test, sstrain=sstrain) -------------------------------------------------------------------------------- /datasets/oxford_flowers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from scipy.io import loadmat 4 | from collections import defaultdict 5 | 6 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 7 | from dassl.utils import read_json 8 | 9 | from .oxford_pets import OxfordPets 10 | from .datasetbase import UPLDatasetBase 11 | 12 | 13 | @DATASET_REGISTRY.register() 14 | class OxfordFlowers(DatasetBase): 15 | 16 | 
dataset_dir = "oxford_flowers" 17 | 18 | def __init__(self, cfg): 19 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 20 | self.dataset_dir = os.path.join(root, self.dataset_dir) 21 | self.image_dir = os.path.join(self.dataset_dir, "jpg") 22 | self.label_file = os.path.join(self.dataset_dir, "imagelabels.mat") 23 | self.lab2cname_file = os.path.join(self.dataset_dir, "cat_to_name.json") 24 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_OxfordFlowers.json") 25 | 26 | if os.path.exists(self.split_path): 27 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 28 | else: 29 | train, val, test = self.read_data() 30 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 31 | 32 | num_shots = cfg.DATASET.NUM_SHOTS 33 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 34 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 35 | 36 | super().__init__(train_x=train, val=val, test=test) 37 | 38 | def read_data(self): 39 | tracker = defaultdict(list) 40 | label_file = loadmat(self.label_file)["labels"][0] 41 | for i, label in enumerate(label_file): 42 | imname = f"image_{str(i + 1).zfill(5)}.jpg" 43 | impath = os.path.join(self.image_dir, imname) 44 | label = int(label) 45 | tracker[label].append(impath) 46 | 47 | print("Splitting data into 50% train, 20% val, and 30% test") 48 | 49 | def _collate(ims, y, c): 50 | items = [] 51 | for im in ims: 52 | item = Datum(impath=im, label=y - 1, classname=c) # convert to 0-based label 53 | items.append(item) 54 | return items 55 | 56 | lab2cname = read_json(self.lab2cname_file) 57 | train, val, test = [], [], [] 58 | for label, impaths in tracker.items(): 59 | random.shuffle(impaths) 60 | n_total = len(impaths) 61 | n_train = round(n_total * 0.5) 62 | n_val = round(n_total * 0.2) 63 | n_test = n_total - n_train - n_val 64 | assert n_train > 0 and n_val > 0 and n_test > 0 65 | cname = lab2cname[str(label)] 66 | train.extend(_collate(impaths[:n_train], label, cname)) 67 | val.extend(_collate(impaths[n_train : n_train + n_val], label, cname)) 68 | test.extend(_collate(impaths[n_train + n_val :], label, cname)) 69 | 70 | return train, val, test 71 | 72 | 73 | @DATASET_REGISTRY.register() 74 | class SSOxfordFlowers(UPLDatasetBase): 75 | dataset_dir = "oxford_flowers" 76 | def __init__(self, cfg): 77 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 78 | self.dataset_dir = os.path.join(root, self.dataset_dir) 79 | self.image_dir = os.path.join(self.dataset_dir, "jpg") 80 | self.label_file = os.path.join(self.dataset_dir, "imagelabels.mat") 81 | self.lab2cname_file = os.path.join(self.dataset_dir, "cat_to_name.json") 82 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_OxfordFlowers.json") 83 | 84 | print(self.split_path, 23333333) 85 | 86 | if os.path.exists(self.split_path): 87 | train, val, test = self.read_split(self.split_path, self.image_dir) 88 | else: 89 | train, val, test = self.read_data() 90 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 91 | 92 | sstrain = self.read_sstrain_data(self.split_path, self.image_dir) 93 | num_shots = cfg.DATASET.NUM_SHOTS 94 | train = self.generate_fewshot_dataset(train, num_shots=-1) 95 | val = self.generate_fewshot_dataset(val, num_shots=-1) 96 | super().__init__(train_x=train, val = val, test=test, sstrain=sstrain) -------------------------------------------------------------------------------- /datasets/dtd.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 5 | from dassl.utils import listdir_nohidden 6 | 7 | from .oxford_pets import OxfordPets 8 | from .datasetbase import UPLDatasetBase 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class DescribableTextures(DatasetBase): 13 | 14 | dataset_dir = "dtd" 15 | 16 | def __init__(self, cfg): 17 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 18 | self.dataset_dir = os.path.join(root, self.dataset_dir) 19 | self.image_dir = os.path.join(self.dataset_dir, "images") 20 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_DescribableTextures.json") 21 | 22 | if os.path.exists(self.split_path): 23 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 24 | else: 25 | train, val, test = self.read_and_split_data(self.image_dir) 26 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 27 | 28 | num_shots = cfg.DATASET.NUM_SHOTS 29 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 30 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 31 | 32 | super().__init__(train_x=train, val=val, test=test) 33 | 34 | @staticmethod 35 | def read_and_split_data(image_dir, p_trn=0.5, p_val=0.2, ignored=[], new_cnames=None): 36 | # The data are supposed to be organized into the following structure 37 | # ============= 38 | # images/ 39 | # dog/ 40 | # cat/ 41 | # horse/ 42 | # ============= 43 | categories = listdir_nohidden(image_dir) 44 | categories = [c for c in categories if c not in ignored] 45 | categories.sort() 46 | 47 | p_tst = 1 - p_trn - p_val 48 | print(f"Splitting into {p_trn:.0%} train, {p_val:.0%} val, and {p_tst:.0%} test") 49 | 50 | def _collate(ims, y, c): 51 | items = [] 52 | for im in ims: 53 | item = Datum(impath=im, label=y, classname=c) # is already 0-based 54 | items.append(item) 55 | return items 56 | 57 | train, val, test = [], [], [] 58 | for label, category in enumerate(categories): 59 | category_dir = os.path.join(image_dir, category) 60 | images = listdir_nohidden(category_dir) 61 | images = [os.path.join(category_dir, im) for im in images] 62 | random.shuffle(images) 63 | n_total = len(images) 64 | n_train = round(n_total * p_trn) 65 | n_val = round(n_total * p_val) 66 | n_test = n_total - n_train - n_val 67 | assert n_train > 0 and n_val > 0 and n_test > 0 68 | 69 | if new_cnames is not None and category in new_cnames: 70 | category = new_cnames[category] 71 | 72 | train.extend(_collate(images[:n_train], label, category)) 73 | val.extend(_collate(images[n_train : n_train + n_val], label, category)) 74 | test.extend(_collate(images[n_train + n_val :], label, category)) 75 | 76 | return train, val, test 77 | 78 | @DATASET_REGISTRY.register() 79 | class SSDescribableTextures(UPLDatasetBase): 80 | 81 | dataset_dir = "dtd" 82 | 83 | def __init__(self, cfg): 84 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 85 | self.dataset_dir = os.path.join(root, self.dataset_dir) 86 | self.image_dir = os.path.join(self.dataset_dir, "images") 87 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_DescribableTextures.json") 88 | 89 | if os.path.exists(self.split_path): 90 | train, val, test = self.read_split(self.split_path, self.image_dir) 91 | else: 92 | train, val, test = self.read_and_split_data(self.image_dir) 93 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 94 
| sstrain = self.read_sstrain_data(self.split_path, self.image_dir) 95 | num_shots = cfg.DATASET.NUM_SHOTS 96 | train = self.generate_fewshot_dataset(train, num_shots=-1) 97 | val = self.generate_fewshot_dataset(val, num_shots=-1) 98 | super().__init__(train_x=train, val = val, test=test, sstrain=sstrain) -------------------------------------------------------------------------------- /datasets/ucf101.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | 4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 5 | 6 | from .oxford_pets import OxfordPets 7 | from .datasetbase import UPLDatasetBase 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class UCF101(DatasetBase): 12 | 13 | dataset_dir = "ucf101" 14 | 15 | def __init__(self, cfg): 16 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 17 | self.dataset_dir = os.path.join(root, self.dataset_dir) 18 | self.image_dir = os.path.join(self.dataset_dir, "UCF-101-midframes") 19 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_UCF101.json") 20 | 21 | if os.path.exists(self.split_path): 22 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 23 | else: 24 | cname2lab = {} 25 | filepath = os.path.join(self.dataset_dir, "ucfTrainTestlist/classInd.txt") 26 | with open(filepath, "r") as f: 27 | lines = f.readlines() 28 | for line in lines: 29 | label, classname = line.strip().split(" ") 30 | label = int(label) - 1 # conver to 0-based index 31 | cname2lab[classname] = label 32 | 33 | trainval = self.read_data(cname2lab, "ucfTrainTestlist/trainlist01.txt") 34 | test = self.read_data(cname2lab, "ucfTrainTestlist/testlist01.txt") 35 | train, val = OxfordPets.split_trainval(trainval) 36 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 37 | 38 | num_shots = cfg.DATASET.NUM_SHOTS 39 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 40 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 41 | 42 | super().__init__(train_x=train, val=val, test=test) 43 | 44 | def read_data(self, cname2lab, text_file): 45 | text_file = os.path.join(self.dataset_dir, text_file) 46 | items = [] 47 | 48 | with open(text_file, "r") as f: 49 | lines = f.readlines() 50 | for line in lines: 51 | line = line.strip().split(" ")[0] # trainlist: filename, label 52 | action, filename = line.split("/") 53 | label = cname2lab[action] 54 | 55 | elements = re.findall("[A-Z][^A-Z]*", action) 56 | renamed_action = "_".join(elements) 57 | 58 | filename = filename.replace(".avi", ".jpg") 59 | impath = os.path.join(self.image_dir, renamed_action, filename) 60 | 61 | item = Datum(impath=impath, label=label, classname=renamed_action) 62 | items.append(item) 63 | 64 | return items 65 | 66 | @DATASET_REGISTRY.register() 67 | class SSUCF101(UPLDatasetBase): 68 | 69 | dataset_dir = "ucf101" 70 | 71 | def __init__(self, cfg): 72 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 73 | self.dataset_dir = os.path.join(root, self.dataset_dir) 74 | self.image_dir = os.path.join(self.dataset_dir, "UCF-101-midframes") 75 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_UCF101.json") 76 | 77 | if os.path.exists(self.split_path): 78 | train, val, test = self.read_split(self.split_path, self.image_dir) 79 | else: 80 | cname2lab = {} 81 | filepath = os.path.join(self.dataset_dir, "ucfTrainTestlist/classInd.txt") 82 | with open(filepath, "r") as f: 83 | lines = f.readlines() 84 | for line in lines: 85 
| label, classname = line.strip().split(" ") 86 | label = int(label) - 1 # conver to 0-based index 87 | cname2lab[classname] = label 88 | 89 | trainval = self.read_data(cname2lab, "ucfTrainTestlist/trainlist01.txt") 90 | test = self.read_data(cname2lab, "ucfTrainTestlist/testlist01.txt") 91 | train, val = OxfordPets.split_trainval(trainval) 92 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 93 | 94 | sstrain = self.read_sstrain_data(self.split_path, self.image_dir) 95 | num_shots = cfg.DATASET.NUM_SHOTS 96 | train = self.generate_fewshot_dataset(train, num_shots=-1) 97 | val = self.generate_fewshot_dataset(val, num_shots=-1) 98 | super().__init__(train_x=train, val = val, test=test, sstrain=sstrain) -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | 2 | # Unsupervised Prompt Learning for Vision-Language Models 3 | 4 | 5 | ## Introduction 6 | 7 | This repo is the official implementation of **Unsupervised Prompt Learning for Vision-Language Models**. UPL is the first framework to introduce unsupervised learning into prompt learning of vision-language models to avoid time-consuming prompt engineering and better adapt vision-language models for the downstream image recognition task. 8 | 9 | Contact us with tonyhuang_pku@outlook.com or fawe@microsoft.com. 10 | 11 | 12 |
13 | 14 | 15 | Fig.1 Overview of Unsupervised Prompt Learning (UPL) Framework. 16 |
17 | 18 | ## Install 19 | 20 | The code is built on [CoOp](https://github.com/KaiyangZhou/CoOp) and [Dassl](https://github.com/KaiyangZhou/Dassl.pytorch) at commit `ac6e44194b2f90e325f477aadd6d9bc3a92ce255`, so you need to install the Dassl environment first. You can follow the [instructions](https://github.com/KaiyangZhou/Dassl.pytorch#installation) to install *dassl* as well as *PyTorch*. After that, run `pip install -r requirements.txt` under `UPL/` to install a few more packages required by [CLIP](https://github.com/openai/CLIP). 21 | 22 | **We also prepare all installation commands for you**: 23 | 24 | ```bash 25 | 26 | ############ install Dassl ############ 27 | 28 | # Clone this repo 29 | git clone https://github.com/KaiyangZhou/Dassl.pytorch.git 30 | cd Dassl.pytorch/ 31 | git reset --hard ac6e44194b2f90e325f477aadd6d9bc3a92ce255 32 | 33 | # Create a conda environment 34 | conda create -n dassl python=3.7 35 | 36 | # Activate the environment 37 | conda activate dassl 38 | 39 | # Install dependencies 40 | pip install -r requirements.txt 41 | 42 | # Install torch (version >= 1.7.1) and torchvision 43 | conda install pytorch==1.8.1 torchvision==0.9.1 cudatoolkit=10.1 -c pytorch 44 | 45 | # Install this library (no need to re-build if the source code is modified) 46 | python setup.py develop 47 | 48 | ############ install UPL ############ 49 | 50 | # Enter the directory at the same level as Dassl 51 | cd .. 52 | 53 | # Clone this repo 54 | git clone https://github.com/tonyhuang2022/UPL.git 55 | cd UPL/ 56 | 57 | # Install CLIP dependencies 58 | pip install -r requirements.txt 59 | 60 | ######## attention ######## 61 | # We have two soft links that you can redirect. 62 | # `data` is linked to the datasets, and `temp_analyze_results_miltiple` is linked to the `info` files. 63 | # We strongly recommend creating these two paths on a disk with enough space, and then running: 64 | 65 | rm data temp_analyze_results_miltiple # remove the existing links 66 | ln -s ${your_data_path} ./data 67 | ln -s ${your_temp_analyze_results_miltiple_path} ./temp_analyze_results_miltiple 68 | 69 | # Finished 70 | ``` 71 | 72 | 73 | ## Datasets 74 | 75 | After that, you can follow the [CoOp Datasets Instructions](https://github.com/KaiyangZhou/CoOp/blob/main/DATASETS.md) to prepare the datasets. 76 | 77 | Once the datasets are ready, you can run the code. 78 | 79 | ## Training 80 | 81 | Note that we need to obtain the pseudo labels of the train sets, so you need to run the `get info` step under `UPL/scripts` before training to obtain the needed files (e.g., logits). 82 | 83 | ### get info of RN50 84 | 85 | 86 | ```bash 87 | CUDA_VISIBLE_DEVICES=0 bash get_info.sh sscaltech101 anay_rn50 end 16 -1 False 88 | ``` 89 | 90 | You can change the config file to get the info from other models (e.g., anay_rn101, anay_rn50x4). We also prepare a script for getting info from all released CLIP models. 91 | 92 | ```bash 93 | CUDA_VISIBLE_DEVICES=0 bash get_all_info.sh sscaltech101 94 | ``` 95 | 96 | 97 | 98 | ### UPL train 99 | 100 | After the `get info` step, we can train the prompt (by default it runs 16 seeds; you can change this in `UPL/configs/UPLTrainer/rn50_ep50.yaml`): 101 | 102 | ```bash 103 | CUDA_VISIBLE_DEVICES=0 bash upl_train.sh sscaltech101 rn50_ep50 end 16 16 False True rn50_random_init 104 | ``` 105 | 106 | If you want to use UPL*, please change the `PSEUDO_LABEL_MODELS` in `UPL/configs/UPLTrainer/rn50_ep50.yaml`. Please ensure that you have obtained info from all released models (the sketch below illustrates how these per-model outputs are combined).
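For reference, UPL* builds its pseudo labels by combining the `get info` outputs of several CLIP backbones. The snippet below is only a minimal sketch of that ensembling idea (the file names, tensor layout, and selection rule are illustrative assumptions, not the exact implementation in `trainers/upltrainer.py`); it shows how per-backbone logits can be averaged into the `predict_label_dict` and `predict_conf_dict` that the trainer consumes.

```python
# Minimal sketch only: paths, file names, and tensor shapes are assumptions,
# not the exact format written by get_info.py.
import torch

def ensemble_pseudo_labels(logit_files, image_paths):
    """Average zero-shot logits from several CLIP backbones and keep the
    most confident class per image as its pseudo label."""
    # each file is assumed to hold a [num_images, num_classes] logit tensor
    probs = [torch.load(f).softmax(dim=-1) for f in logit_files]
    avg_prob = torch.stack(probs).mean(dim=0)
    conf, label = avg_prob.max(dim=-1)
    predict_label_dict = {p: int(l) for p, l in zip(image_paths, label.tolist())}
    predict_conf_dict = {p: float(c) for p, c in zip(image_paths, conf.tolist())}
    return predict_label_dict, predict_conf_dict
```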
Then, you can run 107 | 108 | ```python 109 | CUDA_VISIBLE_DEVICES=0 bash upl_train.sh sscaltech101 rn50_ep50 end 16 16 False True multiple_models_random_init 110 | ``` 111 | 112 | 113 | ## Testing 114 | 115 | ### Test with existing files after UPL training 116 | 117 | ```python 118 | bash upl_test_existing_logits.sh sscaltech101 rn50_ep50 end 16 16 False True 119 | ``` 120 | 121 | ## Thanks 122 | 123 | We use code from [CoOp](https://github.com/KaiyangZhou/CoOp) and [Dassl](https://github.com/KaiyangZhou/Dassl.pytorch), which are great repositories and we encourage you to check them out and cite them in your work. -------------------------------------------------------------------------------- /clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 
41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /datasets/fgvc_aircraft.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 4 | 5 | from .oxford_pets import OxfordPets 6 | from .datasetbase import UPLDatasetBase 7 | 8 | 9 | @DATASET_REGISTRY.register() 10 | class FGVCAircraft(DatasetBase): 11 | 12 | dataset_dir = "fgvc_aircraft" 13 | 14 | def __init__(self, cfg): 15 | root = 
os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 16 | self.dataset_dir = os.path.join(root, self.dataset_dir) 17 | self.image_dir = os.path.join(self.dataset_dir, "images") 18 | 19 | classnames = [] 20 | with open(os.path.join(self.dataset_dir, "variants.txt"), "r") as f: 21 | lines = f.readlines() 22 | for line in lines: 23 | classnames.append(line.strip()) 24 | cname2lab = {c: i for i, c in enumerate(classnames)} 25 | 26 | train = self.read_data(cname2lab, "images_variant_train.txt") 27 | val = self.read_data(cname2lab, "images_variant_val.txt") 28 | test = self.read_data(cname2lab, "images_variant_test.txt") 29 | 30 | num_shots = cfg.DATASET.NUM_SHOTS 31 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 32 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 33 | 34 | super().__init__(train_x=train, val=val, test=test) 35 | 36 | def read_data(self, cname2lab, split_file): 37 | filepath = os.path.join(self.dataset_dir, split_file) 38 | items = [] 39 | 40 | with open(filepath, "r") as f: 41 | lines = f.readlines() 42 | for line in lines: 43 | line = line.strip().split(" ") 44 | imname = line[0] + ".jpg" 45 | classname = " ".join(line[1:]) 46 | impath = os.path.join(self.image_dir, imname) 47 | label = cname2lab[classname] 48 | item = Datum(impath=impath, label=label, classname=classname) 49 | items.append(item) 50 | 51 | return items 52 | 53 | 54 | @DATASET_REGISTRY.register() 55 | class SSFGVCAircraft(UPLDatasetBase): 56 | dataset_dir = 'fgvc_aircraft' 57 | def __init__(self, cfg): 58 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 59 | self.dataset_dir = os.path.join(root, self.dataset_dir) 60 | self.image_dir = os.path.join(self.dataset_dir, "images") 61 | 62 | classnames = [] 63 | with open(os.path.join(self.dataset_dir, "variants.txt"), "r") as f: 64 | lines = f.readlines() 65 | for line in lines: 66 | classnames.append(line.strip()) 67 | cname2lab = {c: i for i, c in enumerate(classnames)} 68 | self.cname2lab = cname2lab 69 | 70 | train = self.read_data(cname2lab, "images_variant_train.txt") 71 | val = self.read_data(cname2lab, "images_variant_val.txt") 72 | test = self.read_data(cname2lab, "images_variant_test.txt") 73 | 74 | sstrain = self.read_data_without_label(cname2lab, "images_variant_train.txt") 75 | num_shots = cfg.DATASET.NUM_SHOTS 76 | train = self.generate_fewshot_dataset(train, num_shots=-1) 77 | val = self.generate_fewshot_dataset(val, num_shots=-1) 78 | super().__init__(train_x=train, val=val, test=test, sstrain=sstrain) 79 | 80 | def read_data(self, cname2lab, split_file): 81 | filepath = os.path.join(self.dataset_dir, split_file) 82 | items = [] 83 | 84 | with open(filepath, "r") as f: 85 | lines = f.readlines() 86 | for line in lines: 87 | line = line.strip().split(" ") 88 | imname = line[0] + ".jpg" 89 | classname = " ".join(line[1:]) 90 | impath = os.path.join(self.image_dir, imname) 91 | label = cname2lab[classname] 92 | item = Datum(impath=impath, label=label, classname=classname) 93 | items.append(item) 94 | 95 | return items 96 | 97 | def read_data_without_label(self, cname2lab, split_file, predict_label_dict=None): 98 | filepath = os.path.join(self.dataset_dir, split_file) 99 | items = [] 100 | if predict_label_dict is None: 101 | with open(filepath, "r") as f: 102 | lines = f.readlines() 103 | for line in lines: 104 | line = line.strip().split(" ") 105 | imname = line[0] + ".jpg" 106 | classname = " ".join(line[1:]) 107 | impath = os.path.join(self.image_dir, imname) 108 | label = cname2lab[classname] 
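# NOTE: the ground-truth label parsed above is not attached to the item created below; every sstrain entry is built with label=-1 and classname=None, so training supervision comes only from the pseudo labels assigned later via the predict_label_dict branch of this function.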
109 | item = Datum(impath=impath, label=-1, classname=None) 110 | items.append(item) 111 | else: 112 | with open(filepath, "r") as f: 113 | lines = f.readlines() 114 | for line in lines: 115 | line = line.strip().split(" ") 116 | imname = line[0] + ".jpg" 117 | classname = " ".join(line[1:]) 118 | impath = os.path.join(self.image_dir, imname) 119 | sub_impath = './data/' + impath.split('/data/')[1] 120 | if sub_impath in predict_label_dict: 121 | item = Datum(impath=impath, label=predict_label_dict[sub_impath], classname=self._lab2cname[predict_label_dict[sub_impath]]) 122 | items.append(item) 123 | return items 124 | -------------------------------------------------------------------------------- /datasets/caltech101.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 4 | from .datasetbase import UPLDatasetBase 5 | from dassl.utils import read_json 6 | 7 | from .oxford_pets import OxfordPets 8 | from .dtd import DescribableTextures as DTD 9 | import random 10 | 11 | IGNORED = ["BACKGROUND_Google", "Faces_easy"] 12 | NEW_CNAMES = { 13 | "airplanes": "airplane", 14 | "Faces": "face", 15 | "Leopards": "leopard", 16 | "Motorbikes": "motorbike", 17 | } 18 | 19 | 20 | @DATASET_REGISTRY.register() 21 | class Caltech101(DatasetBase): 22 | 23 | dataset_dir = "caltech-101" 24 | 25 | def __init__(self, cfg): 26 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 27 | self.dataset_dir = os.path.join(root, self.dataset_dir) 28 | self.image_dir = os.path.join(self.dataset_dir, "101_ObjectCategories") 29 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_Caltech101.json") 30 | 31 | if os.path.exists(self.split_path): 32 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 33 | else: 34 | train, val, test = DTD.read_and_split_data(self.image_dir, ignored=IGNORED, new_cnames=NEW_CNAMES) 35 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 36 | 37 | num_shots = cfg.DATASET.NUM_SHOTS 38 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 39 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 40 | 41 | super().__init__(train_x=train, val=val, test=test) 42 | 43 | @DATASET_REGISTRY.register() 44 | class OpensetCaltech101(UPLDatasetBase): 45 | 46 | dataset_dir = 'caltech-101' 47 | 48 | def __init__(self, cfg): 49 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 50 | self.dataset_dir = os.path.join(root, self.dataset_dir) 51 | self.image_dir = os.path.join(self.dataset_dir, '101_ObjectCategories') 52 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_Caltech101.json') 53 | 54 | if os.path.exists(self.split_path): 55 | train, val, test = self.read_split(self.split_path, self.image_dir) 56 | else: 57 | train, val, test = DTD.read_and_split_data( 58 | self.image_dir, 59 | ignored=IGNORED, 60 | new_cnames=NEW_CNAMES 61 | ) 62 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 63 | 64 | 65 | # if IGNORE_NUM <=0 use IGNORE_FILE 66 | ignore_label_num = int(cfg.DATASET.IGNORE_NUM) 67 | if ignore_label_num > 0: 68 | ignore_labels = [i for i in range(100)] 69 | ignore_labels = random.sample(ignore_labels, ignore_label_num) 70 | else: 71 | ignore_labels = [] 72 | caltech101_ignore_file = open(cfg.DATASET.IGNORE_FILE, 'r') 73 | lines = caltech101_ignore_file.readlines() 74 | for line_id in range(0, len(lines)): 75 | class_name, is_ignore = 
lines[line_id].split(',') 76 | if int(is_ignore) == 1: 77 | ignore_labels.append(class_name) 78 | 79 | 80 | num_shots = cfg.DATASET.NUM_SHOTS 81 | ignore_label_num = int(cfg.DATASET.IGNORE_NUM) 82 | train = self.generate_fewshot_dataset(train, num_shots=num_shots, mode='train') 83 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4), mode='val') 84 | novel, base = self.split_base_and_novel(test, ignore_labels) 85 | self.novel = novel 86 | self.base = base 87 | 88 | super().__init__(train_x=train, val=val, test=test, novel=novel, base=base) 89 | 90 | def read_split(self, filepath, path_prefix): 91 | def _convert(items): 92 | out = [] 93 | for impath, label, classname in items: 94 | impath = os.path.join(path_prefix, impath) 95 | item = Datum( 96 | impath=impath, 97 | label=int(label), 98 | classname=classname 99 | ) 100 | out.append(item) 101 | return out 102 | 103 | print(f'Reading split from {filepath}') 104 | split = read_json(filepath) 105 | train = _convert(split['train']) 106 | val = _convert(split['val']) 107 | test = _convert(split['test']) 108 | 109 | return train, val, test 110 | 111 | 112 | 113 | @DATASET_REGISTRY.register() 114 | class SSCaltech101(UPLDatasetBase): 115 | dataset_dir = 'caltech-101' 116 | def __init__(self, cfg): 117 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 118 | self.dataset_dir = os.path.join(root, self.dataset_dir) 119 | self.image_dir = os.path.join(self.dataset_dir, '101_ObjectCategories') 120 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_Caltech101.json') 121 | 122 | if os.path.exists(self.split_path): 123 | train, val, test = self.read_split(self.split_path, self.image_dir) 124 | else: 125 | train, val, test = DTD.read_and_split_data( 126 | self.image_dir, 127 | ignored=IGNORED, 128 | new_cnames=NEW_CNAMES 129 | ) 130 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 131 | sstrain = self.read_sstrain_data(self.split_path, self.image_dir) 132 | num_shots = cfg.DATASET.NUM_SHOTS 133 | train = self.generate_fewshot_dataset(train, num_shots=-1) 134 | val = self.generate_fewshot_dataset(val, num_shots=-1) 135 | super().__init__(train_x=train, val = val, test=test, sstrain=sstrain) 136 | 137 | -------------------------------------------------------------------------------- /datasets/oxford_pets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import random 4 | from collections import defaultdict 5 | 6 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 7 | from dassl.utils import read_json, write_json 8 | 9 | from .datasetbase import UPLDatasetBase 10 | 11 | 12 | @DATASET_REGISTRY.register() 13 | class OxfordPets(DatasetBase): 14 | 15 | dataset_dir = "oxford_pets" 16 | 17 | def __init__(self, cfg): 18 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 19 | self.dataset_dir = os.path.join(root, self.dataset_dir) 20 | self.image_dir = os.path.join(self.dataset_dir, "images") 21 | self.anno_dir = os.path.join(self.dataset_dir, "annotations") 22 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_OxfordPets.json") 23 | 24 | if os.path.exists(self.split_path): 25 | train, val, test = self.read_split(self.split_path, self.image_dir) 26 | else: 27 | trainval = self.read_data(split_file="trainval.txt") 28 | test = self.read_data(split_file="test.txt") 29 | train, val = self.split_trainval(trainval) 30 | self.save_split(train, val, test, self.split_path, self.image_dir) 31 | 32 | 
num_shots = cfg.DATASET.NUM_SHOTS 33 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 34 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 35 | 36 | super().__init__(train_x=train, val=val, test=test) 37 | 38 | def read_data(self, split_file): 39 | filepath = os.path.join(self.anno_dir, split_file) 40 | items = [] 41 | 42 | with open(filepath, "r") as f: 43 | lines = f.readlines() 44 | for line in lines: 45 | line = line.strip() 46 | imname, label, species, _ = line.split(" ") 47 | breed = imname.split("_")[:-1] 48 | breed = "_".join(breed) 49 | breed = breed.lower() 50 | imname += ".jpg" 51 | impath = os.path.join(self.image_dir, imname) 52 | label = int(label) - 1 # convert to 0-based index 53 | item = Datum(impath=impath, label=label, classname=breed) 54 | items.append(item) 55 | 56 | return items 57 | 58 | @staticmethod 59 | def split_trainval(trainval, p_val=0.2): 60 | p_trn = 1 - p_val 61 | print(f"Splitting trainval into {p_trn:.0%} train and {p_val:.0%} val") 62 | tracker = defaultdict(list) 63 | for idx, item in enumerate(trainval): 64 | label = item.label 65 | tracker[label].append(idx) 66 | 67 | train, val = [], [] 68 | for label, idxs in tracker.items(): 69 | n_val = round(len(idxs) * p_val) 70 | assert n_val > 0 71 | random.shuffle(idxs) 72 | for n, idx in enumerate(idxs): 73 | item = trainval[idx] 74 | if n < n_val: 75 | val.append(item) 76 | else: 77 | train.append(item) 78 | 79 | return train, val 80 | 81 | @staticmethod 82 | def save_split(train, val, test, filepath, path_prefix): 83 | def _extract(items): 84 | out = [] 85 | for item in items: 86 | impath = item.impath 87 | label = item.label 88 | classname = item.classname 89 | impath = impath.replace(path_prefix, "") 90 | if impath.startswith("/"): 91 | impath = impath[1:] 92 | out.append((impath, label, classname)) 93 | return out 94 | 95 | train = _extract(train) 96 | val = _extract(val) 97 | test = _extract(test) 98 | 99 | split = {"train": train, "val": val, "test": test} 100 | 101 | write_json(split, filepath) 102 | print(f"Saved split to {filepath}") 103 | 104 | @staticmethod 105 | def read_split(filepath, path_prefix): 106 | def _convert(items): 107 | out = [] 108 | for impath, label, classname in items: 109 | impath = os.path.join(path_prefix, impath) 110 | item = Datum(impath=impath, label=int(label), classname=classname) 111 | out.append(item) 112 | return out 113 | 114 | print(f"Reading split from {filepath}") 115 | split = read_json(filepath) 116 | train = _convert(split["train"]) 117 | val = _convert(split["val"]) 118 | test = _convert(split["test"]) 119 | 120 | return train, val, test 121 | 122 | 123 | @DATASET_REGISTRY.register() 124 | class SSOxfordPets(UPLDatasetBase): 125 | dataset_dir = "oxford_pets" 126 | def __init__(self, cfg): 127 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 128 | self.dataset_dir = os.path.join(root, self.dataset_dir) 129 | self.image_dir = os.path.join(self.dataset_dir, "images") 130 | self.anno_dir = os.path.join(self.dataset_dir, "annotations") 131 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_OxfordPets.json") 132 | 133 | if os.path.exists(self.split_path): 134 | train, val, test = self.read_split(self.split_path, self.image_dir) 135 | else: 136 | trainval = self.read_data(split_file="trainval.txt") 137 | test = self.read_data(split_file="test.txt") 138 | train, val = self.split_trainval(trainval) 139 | self.save_split(train, val, test, self.split_path, self.image_dir) 140 | 141 | sstrain = 
self.read_sstrain_data(self.split_path, self.image_dir) 142 | num_shots = cfg.DATASET.NUM_SHOTS 143 | train = self.generate_fewshot_dataset(train, num_shots=-1) 144 | val = self.generate_fewshot_dataset(val, num_shots=-1) 145 | super().__init__(train_x=train, val = val, test=test, sstrain=sstrain) 146 | -------------------------------------------------------------------------------- /datasets/data_manager.py: -------------------------------------------------------------------------------- 1 | from dassl.data import DataManager 2 | from dassl.data.data_manager import DatasetWrapper 3 | from dassl.data.transforms import build_transform 4 | from dassl.data.samplers import build_sampler 5 | 6 | from torch.utils.data import DataLoader, WeightedRandomSampler 7 | import torch 8 | 9 | def build_data_loader( 10 | cfg, 11 | sampler_type="RandomSampler", 12 | sampler=None, 13 | data_source=None, 14 | batch_size=64, 15 | n_domain=0, 16 | n_ins=2, 17 | tfm=None, 18 | is_train=True, 19 | dataset_wrapper=None, 20 | tag=None 21 | ): 22 | # Build sampler 23 | if sampler_type is not None: 24 | sampler = build_sampler( 25 | sampler_type, 26 | cfg=cfg, 27 | data_source=data_source, 28 | batch_size=batch_size, 29 | n_domain=n_domain, 30 | n_ins=n_ins, 31 | ) 32 | else: 33 | sampler = sampler 34 | 35 | if dataset_wrapper is None: 36 | dataset_wrapper = DatasetWrapper 37 | 38 | # Build data loader 39 | if tag is None: 40 | data_loader = torch.utils.data.DataLoader( 41 | dataset_wrapper(cfg, data_source, transform=tfm, is_train=is_train), 42 | batch_size=batch_size, 43 | sampler=sampler, 44 | num_workers=cfg.DATALOADER.NUM_WORKERS, 45 | drop_last=is_train and len(data_source) >= batch_size, 46 | pin_memory=(torch.cuda.is_available() and cfg.USE_CUDA), 47 | ) 48 | else: 49 | data_loader = torch.utils.data.DataLoader( 50 | dataset_wrapper(cfg, data_source, transform=tfm, is_train=is_train), 51 | batch_size=batch_size, 52 | sampler=sampler, 53 | num_workers=cfg.DATALOADER.NUM_WORKERS, 54 | drop_last=False, 55 | pin_memory=(torch.cuda.is_available() and cfg.USE_CUDA), 56 | ) 57 | 58 | 59 | 60 | 61 | return data_loader 62 | 63 | class UPLDataManager(DataManager): 64 | def __init__(self, 65 | cfg, 66 | custom_tfm_train=None, 67 | custom_tfm_test=None, 68 | dataset_wrapper=None): 69 | super().__init__(cfg, custom_tfm_train, custom_tfm_test, dataset_wrapper) 70 | 71 | 72 | if custom_tfm_test is None: 73 | tfm_test = build_transform(cfg, is_train=False) 74 | else: 75 | print("* Using custom transform for testing") 76 | tfm_test = custom_tfm_test 77 | 78 | # Build transform 79 | if custom_tfm_train is None: 80 | tfm_train = build_transform(cfg, is_train=True) 81 | else: 82 | print("* Using custom transform for training") 83 | tfm_train = custom_tfm_train 84 | 85 | # save cfg 86 | self.cfg = cfg 87 | self.tfm_train = tfm_train 88 | self.dataset_wrapper = dataset_wrapper 89 | 90 | if cfg.DATALOADER.OPEN_SETTING: 91 | test_novel_loader = build_data_loader( 92 | cfg, 93 | sampler_type=cfg.DATALOADER.TEST.SAMPLER, 94 | data_source=self.dataset.novel, 95 | batch_size=cfg.DATALOADER.TEST.BATCH_SIZE, 96 | tfm=tfm_test, 97 | is_train=False, 98 | dataset_wrapper=dataset_wrapper, 99 | ) 100 | 101 | test_base_loader = build_data_loader( 102 | cfg, 103 | sampler_type=cfg.DATALOADER.TEST.SAMPLER, 104 | data_source=self.dataset.base, 105 | batch_size=cfg.DATALOADER.TEST.BATCH_SIZE, 106 | tfm=tfm_test, 107 | is_train=False, 108 | dataset_wrapper=dataset_wrapper, 109 | ) 110 | 111 | self.test_novel_loader = test_novel_loader 112 | 
self.test_base_loader = test_base_loader 113 | 114 | try: 115 | if self.dataset.sstrain: 116 | # 除了dataset的source是不一样的,其他跟trian都是一样的 117 | train_loader_sstrain = build_data_loader( 118 | cfg, 119 | sampler_type="SequentialSampler", 120 | data_source=self.dataset.sstrain, 121 | batch_size=cfg.DATALOADER.TRAIN_X.BATCH_SIZE, 122 | tfm=tfm_test, 123 | is_train=False, 124 | dataset_wrapper=dataset_wrapper, 125 | tag='sstrain' # 初始化的时候需要设置这个来保证所有样本的载入 126 | ) 127 | self.train_loader_sstrain = train_loader_sstrain 128 | 129 | # Build train_loader_x 130 | train_loader_x = build_data_loader( 131 | cfg, 132 | sampler_type="SequentialSampler", 133 | data_source=self.dataset.train_x, 134 | batch_size=cfg.DATALOADER.TRAIN_X.BATCH_SIZE, 135 | tfm=tfm_test, # 这个是不训练的,所以要用测试的配置, 136 | is_train=False, 137 | dataset_wrapper=dataset_wrapper, 138 | tag='sstrain' 139 | ) 140 | self.train_loader_x = train_loader_x 141 | except: 142 | pass 143 | 144 | def update_ssdateloader(self, predict_label_dict, predict_conf_dict): 145 | """update the train_loader_sstrain to add labels 146 | 147 | Args: 148 | predict_label_dict ([dict]): [a dict {'imagepath': 'label'}] 149 | """ 150 | 151 | 152 | sstrain = self.dataset.add_label(predict_label_dict, self.cfg.DATASET.NAME) 153 | print('sstrain', len(sstrain)) 154 | 155 | 156 | # train_sampler = WeightedRandomSampler(weights, len(sstrain)) 157 | train_loader_sstrain = build_data_loader( 158 | self.cfg, 159 | sampler_type="RandomSampler", 160 | sampler=None, 161 | data_source=sstrain, 162 | batch_size=self.cfg.DATALOADER.TRAIN_X.BATCH_SIZE, 163 | n_domain=self.cfg.DATALOADER.TRAIN_X.N_DOMAIN, 164 | n_ins=1, # 每个类别n_ins个instance 165 | tfm=self.tfm_train, 166 | is_train=True, 167 | dataset_wrapper=self.dataset_wrapper, 168 | ) 169 | self.train_loader_sstrain = train_loader_sstrain 170 | -------------------------------------------------------------------------------- /get_info.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | # torch.set_printoptions(profile="full") 4 | import os 5 | 6 | from dassl.utils import setup_logger, set_random_seed, collect_env_info 7 | from configs.upl_default_config.upl_default import get_cfg_default 8 | from dassl.engine import build_trainer 9 | 10 | # custom 11 | import datasets.oxford_pets 12 | import datasets.oxford_flowers 13 | import datasets.fgvc_aircraft 14 | import datasets.dtd 15 | import datasets.eurosat 16 | import datasets.stanford_cars 17 | import datasets.food101 18 | import datasets.sun397 19 | import datasets.caltech101 20 | import datasets.ucf101 21 | import datasets.imagenet 22 | 23 | import trainers.upltrainer 24 | import trainers.hhzsclip 25 | 26 | 27 | def print_args(args, cfg): 28 | print('***************') 29 | print('** Arguments **') 30 | print('***************') 31 | optkeys = list(args.__dict__.keys()) 32 | optkeys.sort() 33 | for key in optkeys: 34 | print('{}: {}'.format(key, args.__dict__[key])) 35 | print('************') 36 | print('** Config **') 37 | print('************') 38 | print(cfg) 39 | 40 | 41 | def reset_cfg(cfg, args): 42 | if args.root: 43 | cfg.DATASET.ROOT = args.root 44 | 45 | if args.output_dir: 46 | cfg.OUTPUT_DIR = args.output_dir 47 | 48 | if args.resume: 49 | cfg.RESUME = args.resume 50 | 51 | if args.seed: 52 | cfg.SEED = args.seed 53 | 54 | if args.source_domains: 55 | cfg.DATASET.SOURCE_DOMAINS = args.source_domains 56 | 57 | if args.target_domains: 58 | cfg.DATASET.TARGET_DOMAINS = args.target_domains 59 | 60 | if 
args.transforms: 61 | cfg.INPUT.TRANSFORMS = args.transforms 62 | 63 | if args.trainer: 64 | cfg.TRAINER.NAME = args.trainer 65 | 66 | if args.backbone: 67 | cfg.MODEL.BACKBONE.NAME = args.backbone 68 | 69 | if args.head: 70 | cfg.MODEL.HEAD.NAME = args.head 71 | 72 | 73 | def extend_cfg(cfg): 74 | """ 75 | Add new config variables. 76 | 77 | E.g. 78 | from yacs.config import CfgNode as CN 79 | cfg.TRAINER.MY_MODEL = CN() 80 | cfg.TRAINER.MY_MODEL.PARAM_A = 1. 81 | cfg.TRAINER.MY_MODEL.PARAM_B = 0.5 82 | cfg.TRAINER.MY_MODEL.PARAM_C = False 83 | """ 84 | from yacs.config import CfgNode as CN 85 | 86 | cfg.TRAINER.UPLTrainer = CN() 87 | cfg.TRAINER.UPLTrainer.N_CTX = 16 # number of context vectors 88 | cfg.TRAINER.UPLTrainer.CSC = False # class-specific context 89 | cfg.TRAINER.UPLTrainer.CTX_INIT = "" # initialization words 90 | cfg.TRAINER.UPLTrainer.PREC = "fp16" # fp16, fp32, amp 91 | cfg.TRAINER.UPLTrainer.CLASS_TOKEN_POSITION = "end" # 'middle' or 'end' or 'front' 92 | 93 | 94 | def setup_cfg(args): 95 | cfg = get_cfg_default() 96 | extend_cfg(cfg) 97 | 98 | # 1. From the dataset config file 99 | if args.dataset_config_file: 100 | cfg.merge_from_file(args.dataset_config_file) 101 | 102 | # 2. From the method config file 103 | if args.config_file: 104 | cfg.merge_from_file(args.config_file) 105 | 106 | # 3. hh config 107 | if args.hh_config_file: 108 | cfg.merge_from_file(args.hh_config_file) 109 | 110 | # 3. From input arguments 111 | reset_cfg(cfg, args) 112 | 113 | # 4. From optional input arguments 114 | cfg.merge_from_list(args.opts) 115 | 116 | cfg.freeze() 117 | 118 | return cfg 119 | 120 | 121 | def main(args): 122 | cfg = setup_cfg(args) 123 | if cfg.SEED >= 0: 124 | print('Setting fixed seed: {}'.format(cfg.SEED)) 125 | set_random_seed(cfg.SEED) 126 | setup_logger(cfg.OUTPUT_DIR) 127 | 128 | if torch.cuda.is_available() and cfg.USE_CUDA: 129 | torch.backends.cudnn.benchmark = True 130 | 131 | print_args(args, cfg) 132 | print('Collecting env info ...') 133 | print('** System info **\n{}\n'.format(collect_env_info())) 134 | 135 | trainer = build_trainer(cfg) 136 | if args.model_dir: 137 | trainer.load_model_by_id(args.model_dir, epoch=args.load_epoch, model_id=0) 138 | trainer.zero_shot_analyze() 139 | 140 | 141 | if __name__ == '__main__': 142 | parser = argparse.ArgumentParser() 143 | parser.add_argument('--root', type=str, default='', help='path to dataset') 144 | parser.add_argument( 145 | '--output-dir', type=str, default='', help='output directory' 146 | ) 147 | parser.add_argument( 148 | '--resume', 149 | type=str, 150 | default='', 151 | help='checkpoint directory (from which the training resumes)' 152 | ) 153 | parser.add_argument( 154 | '--seed', 155 | type=int, 156 | default=-1, 157 | help='only positive value enables a fixed seed' 158 | ) 159 | parser.add_argument( 160 | '--source-domains', 161 | type=str, 162 | nargs='+', 163 | help='source domains for DA/DG' 164 | ) 165 | parser.add_argument( 166 | '--target-domains', 167 | type=str, 168 | nargs='+', 169 | help='target domains for DA/DG' 170 | ) 171 | parser.add_argument( 172 | '--transforms', type=str, nargs='+', help='data augmentation methods' 173 | ) 174 | parser.add_argument( 175 | '--config-file', type=str, default='', help='path to config file' 176 | ) 177 | parser.add_argument( 178 | '--dataset-config-file', 179 | type=str, 180 | default='', 181 | help='path to config file for dataset setup' 182 | ) 183 | parser.add_argument( 184 | '--hh-config-file', type=str, default='', help='path to config file' 185 | 
) 186 | parser.add_argument( 187 | '--trainer', type=str, default='', help='name of trainer' 188 | ) 189 | parser.add_argument( 190 | '--backbone', type=str, default='', help='name of CNN backbone' 191 | ) 192 | parser.add_argument('--head', type=str, default='', help='name of head') 193 | parser.add_argument( 194 | '--eval-only', action='store_true', help='evaluation only' 195 | ) 196 | parser.add_argument( 197 | '--model-dir', 198 | type=str, 199 | default='', 200 | help='load model from this directory for eval-only mode' 201 | ) 202 | parser.add_argument( 203 | '--load-epoch', 204 | type=int, 205 | help='load model weights at this epoch for evaluation' 206 | ) 207 | parser.add_argument( 208 | '--no-train', action='store_true', help='do not call trainer.train()' 209 | ) 210 | parser.add_argument( 211 | 'opts', 212 | default=None, 213 | nargs=argparse.REMAINDER, 214 | help='modify config options using the command-line' 215 | ) 216 | args = parser.parse_args() 217 | main(args) 218 | -------------------------------------------------------------------------------- /datasets/imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from collections import OrderedDict 4 | from tqdm import tqdm 5 | 6 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 7 | from dassl.utils import listdir_nohidden 8 | from .datasetbase import UPLDatasetBase 9 | 10 | TO_BE_IGNORED = ['new_val_prep.sh', 'train1.txt', 'train.txt'] 11 | 12 | @DATASET_REGISTRY.register() 13 | class SSImageNet(UPLDatasetBase): 14 | 15 | dataset_dir = "imagenet" 16 | 17 | def __init__(self, cfg): 18 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 19 | self.dataset_dir = os.path.join(root, self.dataset_dir) 20 | self.image_dir = os.path.join(self.dataset_dir, "images") 21 | self.preprocessed = os.path.join(self.dataset_dir, "preprocessed.pkl") 22 | self.img_dir = [] 23 | 24 | if os.path.exists(self.preprocessed): 25 | with open(self.preprocessed, "rb") as f: 26 | preprocessed = pickle.load(f) 27 | train = preprocessed["train"] 28 | test = preprocessed["test"] 29 | else: 30 | text_file = os.path.join(self.dataset_dir, "classnames.txt") 31 | classnames = self.read_classnames(text_file) 32 | train = self.read_data(classnames, "train") 33 | # Follow standard practice to perform evaluation on the val set 34 | # Also used as the val set (so evaluate the last-step model) 35 | test = self.read_data(classnames, "val") 36 | 37 | preprocessed = {"train": train, "test": test} 38 | with open(self.preprocessed, "wb") as f: 39 | pickle.dump(preprocessed, f, protocol=pickle.HIGHEST_PROTOCOL) 40 | 41 | num_shots = cfg.DATASET.NUM_SHOTS 42 | train = self.generate_fewshot_dataset(train, num_shots=-1) 43 | 44 | sstrain = self.read_sstrain_data(train) 45 | super().__init__(train_x=train, test=test, sstrain=sstrain) 46 | 47 | @staticmethod 48 | def read_classnames(text_file): 49 | """Return a dictionary containing 50 | key-value pairs of : . 
51 | """ 52 | classnames = OrderedDict() 53 | with open(text_file, "r") as f: 54 | lines = f.readlines() 55 | for line in lines: 56 | line = line.strip().split(" ") 57 | folder = line[0] 58 | classname = " ".join(line[1:]) 59 | classnames[folder] = classname 60 | return classnames 61 | 62 | def read_data(self, classnames, split_dir): 63 | split_dir = os.path.join(self.image_dir, split_dir) 64 | folders = sorted(f.name for f in os.scandir(split_dir) if f.is_dir()) 65 | items = [] 66 | 67 | for label, folder in enumerate(folders): 68 | imnames = listdir_nohidden(os.path.join(split_dir, folder)) 69 | classname = classnames[folder] 70 | for imname in imnames: 71 | impath = os.path.join(split_dir, folder, imname) 72 | item = Datum(impath=impath, label=label, classname=classname) 73 | items.append(item) 74 | return items 75 | 76 | def read_sstrain_data(self, train, predict_label_dict=None): 77 | items = [] 78 | if predict_label_dict is None: 79 | for item in tqdm(train): 80 | new_item = Datum(impath=item.impath, label=-1, classname=None) 81 | items.append(new_item) 82 | self.img_dir.append(item.impath) 83 | else: 84 | for item in tqdm(train): 85 | sub_impath = './data/' + item.impath.split('/data/')[1] 86 | if sub_impath in predict_label_dict: 87 | new_item = Datum(impath=item.impath, label=predict_label_dict[sub_impath], classname=self._lab2cname[predict_label_dict[sub_impath]]) 88 | items.append(new_item) 89 | self.img_dir.append(item.impath) 90 | return items 91 | 92 | 93 | 94 | 95 | @DATASET_REGISTRY.register() 96 | class ImageNet(DatasetBase): 97 | 98 | dataset_dir = "imagenet" 99 | 100 | def __init__(self, cfg): 101 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 102 | self.dataset_dir = os.path.join(root, self.dataset_dir) 103 | self.image_dir = os.path.join(self.dataset_dir, "images") 104 | self.preprocessed = os.path.join(self.dataset_dir, "preprocessed.pkl") 105 | 106 | if os.path.exists(self.preprocessed): 107 | with open(self.preprocessed, "rb") as f: 108 | preprocessed = pickle.load(f) 109 | train = preprocessed["train"] 110 | test = preprocessed["test"] 111 | else: 112 | text_file = os.path.join(self.dataset_dir, "classnames.txt") 113 | classnames = self.read_classnames(text_file) 114 | train = self.read_data(classnames, "train") 115 | # Follow standard practice to perform evaluation on the val set 116 | # Also used as the val set (so evaluate the last-step model) 117 | test = self.read_data(classnames, "val") 118 | 119 | preprocessed = {"train": train, "test": test} 120 | with open(self.preprocessed, "wb") as f: 121 | pickle.dump(preprocessed, f, protocol=pickle.HIGHEST_PROTOCOL) 122 | 123 | num_shots = cfg.DATASET.NUM_SHOTS 124 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 125 | 126 | super().__init__(train_x=train, val=test, test=test) 127 | 128 | @staticmethod 129 | def read_classnames(text_file): 130 | """Return a dictionary containing 131 | key-value pairs of : . 
132 | """ 133 | classnames = OrderedDict() 134 | with open(text_file, "r") as f: 135 | lines = f.readlines() 136 | for line in lines: 137 | line = line.strip().split(" ") 138 | folder = line[0] 139 | classname = " ".join(line[1:]) 140 | classnames[folder] = classname 141 | return classnames 142 | 143 | def read_data(self, classnames, split_dir): 144 | split_dir = os.path.join(self.image_dir, split_dir) 145 | folders = sorted(f.name for f in os.scandir(split_dir) if f.is_dir()) 146 | items = [] 147 | 148 | for label, folder in enumerate(folders): 149 | imnames = listdir_nohidden(os.path.join(split_dir, folder)) 150 | classname = classnames[folder] 151 | for imname in imnames: 152 | impath = os.path.join(split_dir, folder, imname) 153 | item = Datum(impath=impath, label=label, classname=classname) 154 | items.append(item) 155 | 156 | return items -------------------------------------------------------------------------------- /upl_train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | # torch.set_printoptions(profile="full") 4 | import os 5 | 6 | from dassl.utils import setup_logger, set_random_seed, collect_env_info 7 | from configs.upl_default_config.upl_default import get_cfg_default 8 | from dassl.engine import build_trainer 9 | 10 | # custom 11 | import datasets.oxford_pets 12 | import datasets.oxford_flowers 13 | import datasets.fgvc_aircraft 14 | import datasets.dtd 15 | import datasets.eurosat 16 | import datasets.stanford_cars 17 | import datasets.food101 18 | import datasets.sun397 19 | import datasets.caltech101 20 | import datasets.ucf101 21 | import datasets.imagenet 22 | 23 | import trainers.upltrainer 24 | import trainers.hhzsclip 25 | 26 | 27 | def print_args(args, cfg): 28 | print('***************') 29 | print('** Arguments **') 30 | print('***************') 31 | optkeys = list(args.__dict__.keys()) 32 | optkeys.sort() 33 | for key in optkeys: 34 | print('{}: {}'.format(key, args.__dict__[key])) 35 | print('************') 36 | print('** Config **') 37 | print('************') 38 | print(cfg) 39 | 40 | 41 | def reset_cfg(cfg, args): 42 | if args.root: 43 | cfg.DATASET.ROOT = args.root 44 | 45 | if args.output_dir: 46 | cfg.OUTPUT_DIR = args.output_dir 47 | 48 | if args.resume: 49 | cfg.RESUME = args.resume 50 | 51 | if args.seed: 52 | cfg.SEED = args.seed 53 | 54 | if args.source_domains: 55 | cfg.DATASET.SOURCE_DOMAINS = args.source_domains 56 | 57 | if args.target_domains: 58 | cfg.DATASET.TARGET_DOMAINS = args.target_domains 59 | 60 | if args.transforms: 61 | cfg.INPUT.TRANSFORMS = args.transforms 62 | 63 | if args.trainer: 64 | cfg.TRAINER.NAME = args.trainer 65 | 66 | if args.backbone: 67 | cfg.MODEL.BACKBONE.NAME = args.backbone 68 | 69 | if args.head: 70 | cfg.MODEL.HEAD.NAME = args.head 71 | 72 | 73 | def extend_cfg(cfg): 74 | """ 75 | Add new config variables. 76 | 77 | E.g. 78 | from yacs.config import CfgNode as CN 79 | cfg.TRAINER.MY_MODEL = CN() 80 | cfg.TRAINER.MY_MODEL.PARAM_A = 1. 
81 | cfg.TRAINER.MY_MODEL.PARAM_B = 0.5 82 | cfg.TRAINER.MY_MODEL.PARAM_C = False 83 | """ 84 | from yacs.config import CfgNode as CN 85 | 86 | cfg.TRAINER.UPLTrainer = CN() 87 | cfg.TRAINER.UPLTrainer.N_CTX = 16 # number of context vectors 88 | cfg.TRAINER.UPLTrainer.CSC = False # class-specific context 89 | cfg.TRAINER.UPLTrainer.CTX_INIT = "" # initialization words 90 | cfg.TRAINER.UPLTrainer.PREC = "fp16" # fp16, fp32, amp 91 | cfg.TRAINER.UPLTrainer.CLASS_TOKEN_POSITION = "end" # 'middle' or 'end' or 'front' 92 | 93 | 94 | def setup_cfg(args): 95 | cfg = get_cfg_default() 96 | extend_cfg(cfg) 97 | 98 | # 1. From the dataset config file 99 | if args.dataset_config_file: 100 | cfg.merge_from_file(args.dataset_config_file) 101 | 102 | # 2. From the method config file 103 | if args.config_file: 104 | cfg.merge_from_file(args.config_file) 105 | 106 | # 3. hh config 107 | if args.hh_config_file: 108 | cfg.merge_from_file(args.hh_config_file) 109 | 110 | # 3. From input arguments 111 | reset_cfg(cfg, args) 112 | 113 | # 4. From optional input arguments 114 | cfg.merge_from_list(args.opts) 115 | 116 | cfg.freeze() 117 | 118 | return cfg 119 | 120 | 121 | def main(args): 122 | cfg = setup_cfg(args) 123 | if cfg.SEED >= 0: 124 | print('Setting fixed seed: {}'.format(cfg.SEED)) 125 | set_random_seed(cfg.SEED) 126 | setup_logger(cfg.OUTPUT_DIR) 127 | 128 | if torch.cuda.is_available() and cfg.USE_CUDA: 129 | torch.backends.cudnn.benchmark = True 130 | 131 | print_args(args, cfg) 132 | print('Collecting env info ...') 133 | print('** System info **\n{}\n'.format(collect_env_info())) 134 | 135 | trainer_list = [] 136 | for i in range(int(cfg.TRAINER.ENSEMBLE_NUM)): 137 | trainer = build_trainer(cfg) 138 | if args.model_dir: 139 | trainer.load_model_by_id(args.model_dir, epoch=args.load_epoch, model_id=i) 140 | trainer_list.append(trainer) 141 | 142 | predict_label_dict, predict_conf_dict = trainer.load_from_exist_file(file_path='./analyze_results', 143 | model_names=cfg.MODEL.PSEUDO_LABEL_MODELS) 144 | 145 | trainer.dm.update_ssdateloader(predict_label_dict, predict_conf_dict) 146 | trainer.train_loader_sstrain = trainer.dm.train_loader_sstrain 147 | trainer.sstrain_with_id(model_id=i) 148 | 149 | if __name__ == '__main__': 150 | parser = argparse.ArgumentParser() 151 | parser.add_argument('--root', type=str, default='', help='path to dataset') 152 | parser.add_argument( 153 | '--output-dir', type=str, default='', help='output directory' 154 | ) 155 | parser.add_argument( 156 | '--resume', 157 | type=str, 158 | default='', 159 | help='checkpoint directory (from which the training resumes)' 160 | ) 161 | parser.add_argument( 162 | '--seed', 163 | type=int, 164 | default=-1, 165 | help='only positive value enables a fixed seed' 166 | ) 167 | parser.add_argument( 168 | '--source-domains', 169 | type=str, 170 | nargs='+', 171 | help='source domains for DA/DG' 172 | ) 173 | parser.add_argument( 174 | '--target-domains', 175 | type=str, 176 | nargs='+', 177 | help='target domains for DA/DG' 178 | ) 179 | parser.add_argument( 180 | '--transforms', type=str, nargs='+', help='data augmentation methods' 181 | ) 182 | parser.add_argument( 183 | '--config-file', type=str, default='', help='path to config file' 184 | ) 185 | parser.add_argument( 186 | '--dataset-config-file', 187 | type=str, 188 | default='', 189 | help='path to config file for dataset setup' 190 | ) 191 | parser.add_argument( 192 | '--hh-config-file', type=str, default='', help='path to config file' 193 | ) 194 | parser.add_argument( 195 | 
'--trainer', type=str, default='', help='name of trainer' 196 | ) 197 | parser.add_argument( 198 | '--backbone', type=str, default='', help='name of CNN backbone' 199 | ) 200 | parser.add_argument('--head', type=str, default='', help='name of head') 201 | parser.add_argument( 202 | '--eval-only', action='store_true', help='evaluation only' 203 | ) 204 | parser.add_argument( 205 | '--model-dir', 206 | type=str, 207 | default='', 208 | help='load model from this directory for eval-only mode' 209 | ) 210 | parser.add_argument( 211 | '--load-epoch', 212 | type=int, 213 | help='load model weights at this epoch for evaluation' 214 | ) 215 | parser.add_argument( 216 | '--no-train', action='store_true', help='do not call trainer.train()' 217 | ) 218 | parser.add_argument( 219 | 'opts', 220 | default=None, 221 | nargs=argparse.REMAINDER, 222 | help='modify config options using the command-line' 223 | ) 224 | args = parser.parse_args() 225 | main(args) 226 | -------------------------------------------------------------------------------- /datasets/datasetbase.py: -------------------------------------------------------------------------------- 1 | from dassl.data.datasets import DatasetBase 2 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 3 | from dassl.utils import read_json, write_json 4 | from dassl.utils import listdir_nohidden 5 | 6 | import random 7 | import os 8 | 9 | 10 | class UPLDatasetBase(DatasetBase): 11 | def __init__(self, train_x=None, train_u=None, val=None, test=None, novel=None, base=None, sstrain=None): 12 | super().__init__(train_x=train_x, val=val, test=test) 13 | self._novel = novel 14 | self._base = base 15 | self.sstrain = sstrain 16 | self._lab2cname_novel, _ = self.get_lab2cname(novel) 17 | self._lab2cname_base, _ = self.get_lab2cname(base) 18 | 19 | def generate_fewshot_dataset(self, *data_sources, num_shots=-1, mode=None, ignore_labels=None, repeat=False): 20 | """Generate a few-shot dataset (typically for the training set). 21 | 22 | This function is useful when one wants to evaluate a model 23 | in a few-shot learning setting where each class only contains 24 | a few number of images. 25 | 26 | Args: 27 | data_sources: each individual is a list containing Datum objects. 28 | num_shots (int): number of instances per class to sample. 29 | repeat (bool): repeat images if needed (default: False). 
30 | """ 31 | if num_shots < 1: 32 | if len(data_sources) == 1: 33 | return data_sources[0] 34 | return data_sources 35 | 36 | print(f"Creating a {num_shots}-shot dataset") 37 | 38 | output = [] 39 | 40 | for data_source in data_sources: 41 | tracker = self.split_dataset_by_label(data_source) 42 | dataset = [] 43 | 44 | for label, items in tracker.items(): 45 | # print(label, items[0].classname) 46 | if mode == 'train': 47 | if ignore_labels: 48 | if items[0].classname not in ignore_labels: 49 | if len(items) >= num_shots: 50 | sampled_items = random.sample(items, num_shots) 51 | else: 52 | if repeat: 53 | sampled_items = random.choices(items, k=num_shots) 54 | else: 55 | sampled_items = items 56 | dataset.extend(sampled_items) 57 | else: 58 | if len(items) >= num_shots: 59 | sampled_items = random.sample(items, num_shots) 60 | else: 61 | if repeat: 62 | sampled_items = random.choices(items, k=num_shots) 63 | else: 64 | sampled_items = items 65 | dataset.extend(sampled_items) 66 | else: 67 | if len(items) >= num_shots: 68 | sampled_items = random.sample(items, num_shots) 69 | else: 70 | if repeat: 71 | sampled_items = random.choices(items, k=num_shots) 72 | else: 73 | sampled_items = items 74 | dataset.extend(sampled_items) 75 | output.append(dataset) 76 | if len(output) == 1: 77 | return output[0] 78 | 79 | return output 80 | 81 | def split_base_and_novel(self, test, ignore_labels): 82 | tracker = self.split_dataset_by_label(test) 83 | dataset_base = [] 84 | dataset_novel = [] 85 | for label, items in tracker.items(): 86 | if items[0].classname not in ignore_labels: 87 | dataset_novel.extend(items) 88 | else: 89 | dataset_base.extend(items) 90 | return dataset_novel, dataset_base 91 | 92 | def get_lab2cname(self, data_source): 93 | """Get a label-to-classname mapping (dict). 94 | 95 | Args: 96 | data_source (list): a list of Datum objects. 
97 | """ 98 | container = set() 99 | if data_source is not None: 100 | for item in data_source: 101 | container.add((item.label, item.classname)) 102 | mapping = {label: classname for label, classname in container} 103 | labels = list(mapping.keys()) 104 | labels.sort() 105 | classnames = [mapping[label] for label in labels] 106 | return mapping, classnames 107 | else: 108 | return None, None 109 | 110 | def read_split(self, filepath, path_prefix): 111 | def _convert(items): 112 | out = [] 113 | for impath, label, classname in items: 114 | impath = os.path.join(path_prefix, impath) 115 | item = Datum(impath=impath, label=label, classname=classname) 116 | out.append(item) 117 | return out 118 | 119 | print(f"Reading split from {filepath}") 120 | split = read_json(filepath) 121 | train = _convert(split["train"]) 122 | val = _convert(split["val"]) 123 | test = _convert(split["test"]) 124 | return train, val, test 125 | 126 | def read_sstrain_data(self, filepath, path_prefix, predict_label_dict=None): 127 | 128 | def _convert(items): 129 | out = [] 130 | for impath, _, _ in items: 131 | impath = os.path.join(path_prefix, impath) 132 | sub_impath = './data/' + impath.split('/data/')[1] 133 | if sub_impath in predict_label_dict: 134 | item = Datum(impath=impath, label=predict_label_dict[sub_impath], classname=self._lab2cname[predict_label_dict[sub_impath]]) 135 | out.append(item) 136 | return out 137 | 138 | def _convert_no_label(items): 139 | out = [] 140 | for impath, _, _ in items: 141 | impath = os.path.join(path_prefix, impath) 142 | item = Datum(impath=impath, label=-1, classname=None) 143 | out.append(item) 144 | return out 145 | 146 | print(f"Reading split from {filepath}") 147 | split = read_json(filepath) 148 | if predict_label_dict is not None: 149 | train = _convert(split["train"]) 150 | else: 151 | train = _convert_no_label(split["train"]) 152 | return train 153 | 154 | # dtd会报错 需要单独处理,它里面有处理的函数,需要copy过来,详见原版的database 155 | 156 | def add_label(self, predict_label_dict, dataset_name): 157 | """add label when training for self-supervised learning 158 | 159 | Args: 160 | predict_label_dict ([dict]): [a dict {'imagepath': 'label'}] 161 | """ 162 | # print(predict_label_dict, 'predict_label_dict') 163 | print(dataset_name) 164 | if dataset_name == 'SSFGVCAircraft': 165 | sstrain = self.read_data_without_label(self.cname2lab, "images_variant_train.txt", predict_label_dict) 166 | elif dataset_name == 'SSImageNet': 167 | sstrain = self.read_sstrain_data(self.train_x, predict_label_dict) 168 | else: 169 | sstrain = self.read_sstrain_data(self.split_path, self.image_dir, predict_label_dict) 170 | 171 | self.sstrain = sstrain 172 | return sstrain -------------------------------------------------------------------------------- /upl_test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | # torch.set_printoptions(profile="full") 4 | import os 5 | import random 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | 9 | 10 | from dassl.utils import setup_logger, set_random_seed, collect_env_info 11 | from configs.upl_default_config.upl_default import get_cfg_default 12 | from dassl.engine import build_trainer 13 | 14 | # custom 15 | import datasets.oxford_pets 16 | import datasets.oxford_flowers 17 | import datasets.fgvc_aircraft 18 | import datasets.dtd 19 | import datasets.eurosat 20 | import datasets.stanford_cars 21 | import datasets.food101 22 | import datasets.sun397 23 | import datasets.caltech101 24 | import 
datasets.ucf101 25 | import datasets.imagenet 26 | 27 | import trainers.upltrainer 28 | import trainers.hhzsclip 29 | 30 | 31 | def print_args(args, cfg): 32 | print('***************') 33 | print('** Arguments **') 34 | print('***************') 35 | optkeys = list(args.__dict__.keys()) 36 | optkeys.sort() 37 | for key in optkeys: 38 | print('{}: {}'.format(key, args.__dict__[key])) 39 | print('************') 40 | print('** Config **') 41 | print('************') 42 | print(cfg) 43 | 44 | def reset_cfg(cfg, args): 45 | if args.root: 46 | cfg.DATASET.ROOT = args.root 47 | 48 | if args.output_dir: 49 | cfg.OUTPUT_DIR = args.output_dir 50 | 51 | if args.resume: 52 | cfg.RESUME = args.resume 53 | 54 | if args.seed: 55 | cfg.SEED = args.seed 56 | 57 | if args.source_domains: 58 | cfg.DATASET.SOURCE_DOMAINS = args.source_domains 59 | 60 | if args.target_domains: 61 | cfg.DATASET.TARGET_DOMAINS = args.target_domains 62 | 63 | if args.transforms: 64 | cfg.INPUT.TRANSFORMS = args.transforms 65 | 66 | if args.trainer: 67 | cfg.TRAINER.NAME = args.trainer 68 | 69 | if args.backbone: 70 | cfg.MODEL.BACKBONE.NAME = args.backbone 71 | 72 | if args.head: 73 | cfg.MODEL.HEAD.NAME = args.head 74 | 75 | 76 | def extend_cfg(cfg): 77 | """ 78 | Add new config variables. 79 | 80 | E.g. 81 | from yacs.config import CfgNode as CN 82 | cfg.TRAINER.MY_MODEL = CN() 83 | cfg.TRAINER.MY_MODEL.PARAM_A = 1. 84 | cfg.TRAINER.MY_MODEL.PARAM_B = 0.5 85 | cfg.TRAINER.MY_MODEL.PARAM_C = False 86 | """ 87 | from yacs.config import CfgNode as CN 88 | 89 | cfg.TRAINER.UPLTrainer = CN() 90 | cfg.TRAINER.UPLTrainer.N_CTX = 16 # number of context vectors 91 | cfg.TRAINER.UPLTrainer.CSC = False # class-specific context 92 | cfg.TRAINER.UPLTrainer.CTX_INIT = "" # initialization words 93 | cfg.TRAINER.UPLTrainer.PREC = "fp16" # fp16, fp32, amp 94 | cfg.TRAINER.UPLTrainer.CLASS_TOKEN_POSITION = "end" # 'middle' or 'end' or 'front' 95 | 96 | 97 | def setup_cfg(args): 98 | cfg = get_cfg_default() 99 | extend_cfg(cfg) 100 | 101 | # 1. From the dataset config file 102 | if args.dataset_config_file: 103 | cfg.merge_from_file(args.dataset_config_file) 104 | 105 | # 2. From the method config file 106 | if args.config_file: 107 | cfg.merge_from_file(args.config_file) 108 | 109 | # 3. hh config 110 | if args.hh_config_file: 111 | cfg.merge_from_file(args.hh_config_file) 112 | 113 | # 3. From input arguments 114 | reset_cfg(cfg, args) 115 | 116 | # 4. 
From optional input arguments 117 | cfg.merge_from_list(args.opts) 118 | 119 | cfg.freeze() 120 | 121 | return cfg 122 | 123 | 124 | def main(args): 125 | cfg = setup_cfg(args) 126 | if cfg.SEED >= 0: 127 | print('Setting fixed seed: {}'.format(cfg.SEED)) 128 | set_random_seed(cfg.SEED) 129 | setup_logger(cfg.OUTPUT_DIR) 130 | 131 | if torch.cuda.is_available() and cfg.USE_CUDA: 132 | torch.backends.cudnn.benchmark = True 133 | 134 | print_args(args, cfg) 135 | print('Collecting env info ...') 136 | print('** System info **\n{}\n'.format(collect_env_info())) 137 | 138 | trainer_list = [] 139 | for i in range(int(cfg.TRAINER.ENSEMBLE_NUM)): 140 | trainer = build_trainer(cfg) 141 | if args.model_dir: 142 | trainer.load_model_by_id(args.model_dir, epoch=args.load_epoch, model_id=i) 143 | trainer_list.append(trainer) 144 | 145 | 146 | prob_end = [] 147 | results_dict = {} 148 | for SHOTS in [16]: 149 | print('SHOTS:', SHOTS, '\n\n\n\n\n') 150 | for seed in range(1, 17): 151 | try: 152 | prob = torch.load('./analysis_results_test/{}/50_{}_{}_random_initend/test_logits.pt'.format(cfg.DATASET.NAME, seed, SHOTS)) 153 | prob_end.append(prob) 154 | except: 155 | print('loss') 156 | 157 | print(len(prob_end), ' shots ensemble') 158 | prob_test = sum(prob_end) / len(prob_end) 159 | results_end = trainer_list[0].test_with_existing_logits(prob_test) 160 | 161 | print('ensemble: {:.2f}'.format((results_end['accuracy']))) 162 | results_dict[len(prob_end)] = results_end['accuracy'] 163 | 164 | 165 | 166 | if __name__ == '__main__': 167 | parser = argparse.ArgumentParser() 168 | parser.add_argument('--root', type=str, default='', help='path to dataset') 169 | parser.add_argument( 170 | '--output-dir', type=str, default='', help='output directory' 171 | ) 172 | parser.add_argument( 173 | '--resume', 174 | type=str, 175 | default='', 176 | help='checkpoint directory (from which the training resumes)' 177 | ) 178 | parser.add_argument( 179 | '--seed', 180 | type=int, 181 | default=-1, 182 | help='only positive value enables a fixed seed' 183 | ) 184 | parser.add_argument( 185 | '--source-domains', 186 | type=str, 187 | nargs='+', 188 | help='source domains for DA/DG' 189 | ) 190 | parser.add_argument( 191 | '--target-domains', 192 | type=str, 193 | nargs='+', 194 | help='target domains for DA/DG' 195 | ) 196 | parser.add_argument( 197 | '--transforms', type=str, nargs='+', help='data augmentation methods' 198 | ) 199 | parser.add_argument( 200 | '--config-file', type=str, default='', help='path to config file' 201 | ) 202 | parser.add_argument( 203 | '--dataset-config-file', 204 | type=str, 205 | default='', 206 | help='path to config file for dataset setup' 207 | ) 208 | parser.add_argument( 209 | '--hh-config-file', type=str, default='', help='path to config file' 210 | ) 211 | parser.add_argument( 212 | '--trainer', type=str, default='', help='name of trainer' 213 | ) 214 | parser.add_argument( 215 | '--backbone', type=str, default='', help='name of CNN backbone' 216 | ) 217 | parser.add_argument('--head', type=str, default='', help='name of head') 218 | parser.add_argument( 219 | '--eval-only', action='store_true', help='evaluation only' 220 | ) 221 | parser.add_argument( 222 | '--model-dir', 223 | type=str, 224 | default='', 225 | help='load model from this directory for eval-only mode' 226 | ) 227 | parser.add_argument( 228 | '--load-epoch', 229 | type=int, 230 | help='load model weights at this epoch for evaluation' 231 | ) 232 | parser.add_argument( 233 | '--no-train', action='store_true', help='do not 
call trainer.train()' 234 | ) 235 | parser.add_argument( 236 | 'opts', 237 | default=None, 238 | nargs=argparse.REMAINDER, 239 | help='modify config options using the command-line' 240 | ) 241 | args = parser.parse_args() 242 | main(args) 243 | -------------------------------------------------------------------------------- /trainers/hhzsclip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | 5 | from dassl.engine import TRAINER_REGISTRY, TrainerX 6 | from dassl.optim import build_optimizer, build_lr_scheduler 7 | 8 | from clip import clip 9 | from clip.model import convert_weights 10 | 11 | from .coop import load_clip_to_cpu 12 | 13 | from datasets.data_manager import UPLDataManager 14 | 15 | from .utils import plotLogitsMap, plotPRMap 16 | 17 | CUSTOM_TEMPLATES = { 18 | "OxfordPets": "a photo of a {}, a type of pet.", 19 | "OxfordFlowers": "a photo of a {}, a type of flower.", 20 | "FGVCAircraft": "a photo of a {}, a type of aircraft.", 21 | "DescribableTextures": "{} texture.", 22 | "EuroSAT": "a centered satellite photo of {}.", 23 | "StanfordCars": "a photo of a {}.", 24 | "Food101": "a photo of {}, a type of food.", 25 | "SUN397": "a photo of a {}.", 26 | "Caltech101": "a photo of a {}.", 27 | "UCF101": "a photo of a person doing {}.", 28 | "ImageNet": "a photo of a {}.", 29 | "ImageNetSketch": "a photo of a {}.", 30 | "ImageNetV2": "a photo of a {}.", 31 | "ImageNetA": "a photo of a {}.", 32 | "ImageNetR": "a photo of a {}.", 33 | # semi-supervised templates 34 | "SSOxfordPets": "a photo of a {}, a type of pet.", 35 | "SSOxfordFlowers": "a photo of a {}, a type of flower.", 36 | "SSFGVCAircraft": "a photo of a {}, a type of aircraft.", 37 | "SSDescribableTextures": "{} texture.", 38 | "SSEuroSAT": "a centered satellite photo of {}.", 39 | "SSStanfordCars": "a photo of a {}.", 40 | "SSFood101": "a photo of {}, a type of food.", 41 | "SSSUN397": "a photo of a {}.", 42 | "SSCaltech101": "a photo of a {}.", 43 | "SSUCF101": "a photo of a person doing {}.", 44 | "SSImageNet": "a photo of a {}.", 45 | } 46 | 47 | 48 | 49 | 50 | @TRAINER_REGISTRY.register() 51 | class ZeroshotCLIP(TrainerX): 52 | 53 | def build_data_loader(self): 54 | """Create essential data-related attributes. 55 | 56 | A re-implementation of this method must create the 57 | same attributes (except self.dm). 
58 | """ 59 | dm = UPLDataManager(self.cfg) 60 | 61 | self.train_loader_x = dm.train_loader_x 62 | self.train_loader_u = dm.train_loader_u # optional, can be None 63 | self.val_loader = dm.val_loader # optional, can be None 64 | self.test_loader = dm.test_loader 65 | self.num_classes = dm.num_classes 66 | self.num_source_domains = dm.num_source_domains 67 | self.lab2cname = dm.lab2cname # dict {label: classname} 68 | 69 | if self.cfg.DATALOADER.OPEN_SETTING: 70 | self.test_novel_loader = dm.test_novel_loader 71 | self.test_base_loader = dm.test_base_loader 72 | 73 | self.dm = dm 74 | 75 | def build_model(self): 76 | cfg = self.cfg 77 | classnames = self.dm.dataset.classnames 78 | 79 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})") 80 | clip_model = load_clip_to_cpu(cfg) 81 | clip_model.to(self.device) 82 | 83 | temp = CUSTOM_TEMPLATES[cfg.DATASET.NAME] 84 | prompts = [temp.format(c.replace("_", " ")) for c in classnames] 85 | print(f"Prompts: {prompts}") 86 | prompts = torch.cat([clip.tokenize(p) for p in prompts]) 87 | prompts = prompts.to(self.device) 88 | 89 | with torch.no_grad(): 90 | text_features = clip_model.encode_text(prompts) 91 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 92 | 93 | self.text_features = text_features 94 | self.clip_model = clip_model 95 | 96 | def model_inference(self, image): 97 | image_features = self.clip_model.encode_image(image) 98 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 99 | logit_scale = self.clip_model.logit_scale.exp() 100 | logits = logit_scale * image_features @ self.text_features.t() 101 | return logits, image_features, self.text_features 102 | 103 | @torch.no_grad() 104 | def test(self, split=None, trainer_list=None): 105 | """A generic testing pipeline.""" 106 | # 如果是ensemble的需要添加tag 在实验记录中方便区分 107 | # if trainer_list is not None and "ENSEMBLE:{}".format(self.cfg.TRAINER.ENSEMBLE_NUM) not in self.online_logger.tags: 108 | # self.online_logger.tags += ("ENSEMBLE:{}".format(self.cfg.TRAINER.ENSEMBLE_NUM),) 109 | 110 | # if self.cfg.DATASET.NAME not in self.online_logger.tags: 111 | # self.online_logger.tags += ("DATASET INFERENCE:{}".format(self.cfg.DATASET.NAME),) 112 | 113 | self.set_model_mode("eval") 114 | self.evaluator.reset() 115 | 116 | save_path = os.path.join(self.cfg.TEST.Analyze_Result_Path, self.cfg.DATASET.NAME, 117 | str(self.cfg.OPTIM.MAX_EPOCH)+'_'+str(self.cfg.SEED)+'_'+str(self.cfg.DATASET.NUM_SHOTS)) 118 | if not os.path.exists(save_path): 119 | os.makedirs(save_path) 120 | 121 | results_id = 0 122 | while os.path.exists(os.path.join(save_path, 'per_image_results_{}_{}.txt'.format(split, results_id))): 123 | results_id += 1 124 | self.per_image_txt_writer = open(os.path.join(save_path, 'per_image_results_{}_{}.txt'.format(split, results_id)), 'w') 125 | self.per_class_txt_writer = open(os.path.join(save_path, 'per_class_results_{}_{}.txt'.format(split, results_id)), 'w') 126 | 127 | if split is None: 128 | split = self.cfg.TEST.SPLIT 129 | 130 | if split == "val" and self.val_loader is not None: 131 | data_loader = self.val_loader 132 | print("Do evaluation on {} set".format(split)) 133 | elif split=="novel": 134 | data_loader = self.test_novel_loader 135 | print("Do evaluation on test novel set") 136 | elif split=="base": 137 | data_loader = self.test_base_loader 138 | print("Do evaluation on test base set") 139 | elif split=="train": 140 | data_loader = self.train_loader_x 141 | print("Do evaluation on train set") 142 | elif split=="all": 143 | 
data_loader = self.test_loader 144 | print("Do evaluation on test set") 145 | else: 146 | data_loader = self.test_loader 147 | print("Do evaluation on test set") 148 | 149 | outputs_for_tsne = [] 150 | label_for_tsne = [] 151 | image_features_for_tsne = [] 152 | text_features_for_tsne = [] 153 | for batch_idx, batch in enumerate(data_loader): 154 | input, label = self.parse_batch_test(batch) 155 | if trainer_list is None or len(trainer_list)==1: 156 | # single-model (non-ensemble) inference 157 | output, image_features, text_features = self.model_inference(input) 158 | image_features_for_tsne.append(image_features) 159 | text_features_for_tsne.append(text_features) 160 | else: 161 | # ensemble inference: average the logits of all trainers 162 | outputs = [t.model_inference(input)[0] for t in trainer_list] 163 | output = sum(outputs) / len(outputs) 164 | self.evaluator.process(output, label, self.per_image_txt_writer, self.per_class_txt_writer) 165 | outputs_for_tsne.append(output) 166 | label_for_tsne.append(label) 167 | results = self.evaluator.evaluate() 168 | if split in ['all', 'train', 'test', 'novel', 'base']: 169 | if len(outputs_for_tsne) != 0: 170 | outputs_for_tsne = torch.cat(outputs_for_tsne, dim=0) 171 | print('outputs_for_tsne', outputs_for_tsne.shape) 172 | label_for_tsne = torch.cat(label_for_tsne, dim=0) 173 | print('label_for_tsne', label_for_tsne.shape) 174 | image_features_for_tsne = torch.cat(image_features_for_tsne, dim=0) 175 | text_features_for_tsne = text_features_for_tsne[0] 176 | torch.save(image_features_for_tsne, os.path.join(save_path, '{}_v_features.pt'.format(split))) 177 | torch.save(label_for_tsne, os.path.join(save_path, '{}_targets.pt'.format(split))) 178 | torch.save(outputs_for_tsne, os.path.join(save_path, '{}_logits.pt'.format(split))) 179 | torch.save(text_features_for_tsne, os.path.join(save_path, '{}_l_logits.pt'.format(split))) 180 | # plotLogitsMap(outputs_for_tsne, label_for_tsne, os.path.join(save_path, '{}_map.png'.format(split)), self.cfg.DATASET.NAME+'_'+split) 181 | plotPRMap(outputs_for_tsne, label_for_tsne, os.path.join(save_path, '{}_PR.png'.format(split)), self.cfg.DATASET.NAME+'_'+split) 182 | 183 | 184 | # self.online_logger.log({"inputs": wandb.Image('./{}.png'.format(split)), "captions": wandb.Html(split)}) 185 | # wandb.run.summary["{}_accuracy".format(split)] = results['accuracy'] 186 | self.per_image_txt_writer.close() 187 | self.per_class_txt_writer.close() 188 | 189 | 190 | for k, v in results.items(): 191 | tag = "{}/{}".format(split, k) 192 | self.write_scalar(tag, v, self.epoch) 193 | 194 | return list(results.values())[0] 195 | 196 | -------------------------------------------------------------------------------- /clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Union, List 6 | 7 | import torch 8 | from PIL import Image 9 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 10 | from tqdm import tqdm 11 | 12 | from .model import build_model 13 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 14 | 15 | try: 16 | from torchvision.transforms import InterpolationMode 17 | BICUBIC = InterpolationMode.BICUBIC 18 | except ImportError: 19 | BICUBIC = Image.BICUBIC 20 | 21 | 22 | if torch.__version__.split(".") < ["1", "7", "1"]: 23 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 24 | 25 | 26 | __all__ = ["available_models", "load", "tokenize"] 27 | _tokenizer = _Tokenizer() 28 |
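# Usage sketch for the helpers below (a minimal example; it assumes the "RN50"
# checkpoint can be downloaded to the default cache directory ~/.cache/clip):
#
#     import torch
#     from clip import clip
#
#     model, preprocess = clip.load("RN50", device="cpu", jit=False)
#     tokens = clip.tokenize(["a photo of a dog", "a photo of a cat"])
#     with torch.no_grad():
#         text_features = model.encode_text(tokens)
#         text_features = text_features / text_features.norm(dim=-1, keepdim=True)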
_MODELS = { 30 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 31 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 32 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 33 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 34 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 35 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 36 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 37 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 38 | } 39 | 40 | 41 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): 42 | os.makedirs(root, exist_ok=True) 43 | filename = os.path.basename(url) 44 | 45 | expected_sha256 = url.split("/")[-2] 46 | download_target = os.path.join(root, filename) 47 | 48 | if os.path.exists(download_target) and not os.path.isfile(download_target): 49 | raise RuntimeError(f"{download_target} exists and is not a regular file") 50 | 51 | if os.path.isfile(download_target): 52 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 53 | return download_target 54 | else: 55 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 56 | 57 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 58 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: 59 | while True: 60 | buffer = source.read(8192) 61 | if not buffer: 62 | break 63 | 64 | output.write(buffer) 65 | loop.update(len(buffer)) 66 | 67 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 68 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 69 | 70 | return download_target 71 | 72 | 73 | def _transform(n_px): 74 | return Compose([ 75 | Resize(n_px, interpolation=BICUBIC), 76 | CenterCrop(n_px), 77 | lambda image: image.convert("RGB"), 78 | ToTensor(), 79 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 80 | ]) 81 | 82 | 83 | def available_models() -> List[str]: 84 | """Returns the names of available CLIP models""" 85 | return list(_MODELS.keys()) 86 | 87 | 88 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=False): 89 | """Load a CLIP model 90 | 91 | Parameters 92 | ---------- 93 | name : str 94 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 95 | 96 | device : Union[str, torch.device] 97 | The device to put the loaded model 98 | 99 | jit : bool 100 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 
101 | 102 | Returns 103 | ------- 104 | model : torch.nn.Module 105 | The CLIP model 106 | 107 | preprocess : Callable[[PIL.Image], torch.Tensor] 108 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 109 | """ 110 | if name in _MODELS: 111 | model_path = _download(_MODELS[name]) 112 | elif os.path.isfile(name): 113 | model_path = name 114 | else: 115 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 116 | 117 | try: 118 | # loading JIT archive 119 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 120 | state_dict = None 121 | except RuntimeError: 122 | # loading saved state dict 123 | if jit: 124 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") 125 | jit = False 126 | state_dict = torch.load(model_path, map_location="cpu") 127 | 128 | if not jit: 129 | model = build_model(state_dict or model.state_dict()).to(device) 130 | if str(device) == "cpu": 131 | model.float() 132 | return model, _transform(model.visual.input_resolution) 133 | 134 | # patch the device names 135 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 136 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 137 | 138 | def patch_device(module): 139 | try: 140 | graphs = [module.graph] if hasattr(module, "graph") else [] 141 | except RuntimeError: 142 | graphs = [] 143 | 144 | if hasattr(module, "forward1"): 145 | graphs.append(module.forward1.graph) 146 | 147 | for graph in graphs: 148 | for node in graph.findAllNodes("prim::Constant"): 149 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 150 | node.copyAttributes(device_node) 151 | 152 | model.apply(patch_device) 153 | patch_device(model.encode_image) 154 | patch_device(model.encode_text) 155 | 156 | # patch dtype to float32 on CPU 157 | if str(device) == "cpu": 158 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 159 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 160 | float_node = float_input.node() 161 | 162 | def patch_float(module): 163 | try: 164 | graphs = [module.graph] if hasattr(module, "graph") else [] 165 | except RuntimeError: 166 | graphs = [] 167 | 168 | if hasattr(module, "forward1"): 169 | graphs.append(module.forward1.graph) 170 | 171 | for graph in graphs: 172 | for node in graph.findAllNodes("aten::to"): 173 | inputs = list(node.inputs()) 174 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 175 | if inputs[i].node()["value"] == 5: 176 | inputs[i].node().copyAttributes(float_node) 177 | 178 | model.apply(patch_float) 179 | patch_float(model.encode_image) 180 | patch_float(model.encode_text) 181 | 182 | model.float() 183 | 184 | return model, _transform(model.input_resolution.item()) 185 | 186 | 187 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor: 188 | """ 189 | Returns the tokenized representation of given input string(s) 190 | 191 | Parameters 192 | ---------- 193 | texts : Union[str, List[str]] 194 | An input string or a list of input strings to tokenize 195 | 196 | context_length : int 197 | The context length to use; all CLIP models use 77 as the context length 198 | 199 | truncate: bool 200 | Whether to truncate the text in case its encoding is longer than the context length 201 | 202 | 
Returns 203 | ------- 204 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 205 | """ 206 | if isinstance(texts, str): 207 | texts = [texts] 208 | 209 | sot_token = _tokenizer.encoder["<|startoftext|>"] 210 | eot_token = _tokenizer.encoder["<|endoftext|>"] 211 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 212 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 213 | 214 | for i, tokens in enumerate(all_tokens): 215 | if len(tokens) > context_length: 216 | if truncate: 217 | tokens = tokens[:context_length] 218 | tokens[-1] = eot_token 219 | else: 220 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 221 | result[i, :len(tokens)] = torch.tensor(tokens) 222 | 223 | return result 224 | -------------------------------------------------------------------------------- /configs/upl_default_config/upl_default.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | ########################### 4 | # Config definition 5 | ########################### 6 | 7 | _C = CN() 8 | 9 | _C.VERSION = 1 10 | 11 | # Directory to save the output files (like log.txt and model weights) 12 | _C.OUTPUT_DIR = "./output" 13 | # Path to a directory where the files were saved previously 14 | _C.RESUME = "" 15 | # Set seed to negative value to randomize everything 16 | # Set seed to positive value to use a fixed seed 17 | _C.SEED = -1 18 | _C.USE_CUDA = True 19 | # Print detailed information 20 | # E.g. trainer, dataset, and backbone 21 | _C.VERBOSE = True 22 | 23 | ########################### 24 | # Input 25 | ########################### 26 | _C.INPUT = CN() 27 | _C.INPUT.SIZE = (224, 224) 28 | # Mode of interpolation in resize functions 29 | _C.INPUT.INTERPOLATION = "bilinear" 30 | # For available choices please refer to transforms.py 31 | _C.INPUT.TRANSFORMS = () 32 | # If True, tfm_train and tfm_test will be None 33 | _C.INPUT.NO_TRANSFORM = False 34 | # Default mean and std come from ImageNet 35 | _C.INPUT.PIXEL_MEAN = [0.485, 0.456, 0.406] 36 | _C.INPUT.PIXEL_STD = [0.229, 0.224, 0.225] 37 | # Padding for random crop 38 | _C.INPUT.CROP_PADDING = 4 39 | # Cutout 40 | _C.INPUT.CUTOUT_N = 1 41 | _C.INPUT.CUTOUT_LEN = 16 42 | # Gaussian noise 43 | _C.INPUT.GN_MEAN = 0.0 44 | _C.INPUT.GN_STD = 0.15 45 | # RandomAugment 46 | _C.INPUT.RANDAUGMENT_N = 2 47 | _C.INPUT.RANDAUGMENT_M = 10 48 | # ColorJitter (brightness, contrast, saturation, hue) 49 | _C.INPUT.COLORJITTER_B = 0.4 50 | _C.INPUT.COLORJITTER_C = 0.4 51 | _C.INPUT.COLORJITTER_S = 0.4 52 | _C.INPUT.COLORJITTER_H = 0.1 53 | # Random gray scale's probability 54 | _C.INPUT.RGS_P = 0.2 55 | # Gaussian blur 56 | _C.INPUT.GB_P = 0.5 # propability of applying this operation 57 | _C.INPUT.GB_K = 21 # kernel size (should be an odd number) 58 | 59 | ########################### 60 | # Dataset 61 | ########################### 62 | _C.DATASET = CN() 63 | # Directory where datasets are stored 64 | _C.DATASET.ROOT = "" 65 | _C.DATASET.NAME = "" 66 | # List of names of source domains 67 | _C.DATASET.SOURCE_DOMAINS = () 68 | # List of names of target domains 69 | _C.DATASET.TARGET_DOMAINS = () 70 | # Number of labeled instances in total 71 | # Useful for the semi-supervised learning 72 | _C.DATASET.NUM_LABELED = -1 73 | # Number of images per class 74 | _C.DATASET.NUM_SHOTS = -1 75 | # Percentage of validation data (only used for SSL datasets) 76 
| # Set to 0 if do not want to use val data 77 | # Using val data for hyperparameter tuning was done in Oliver et al. 2018 78 | _C.DATASET.VAL_PERCENT = 0.1 79 | # Fold index for STL-10 dataset (normal range is 0 - 9) 80 | # Negative number means None 81 | _C.DATASET.STL10_FOLD = -1 82 | # CIFAR-10/100-C's corruption type and intensity level 83 | _C.DATASET.CIFAR_C_TYPE = "" 84 | _C.DATASET.CIFAR_C_LEVEL = 1 85 | # Use all data in the unlabeled data set (e.g. FixMatch) 86 | _C.DATASET.ALL_AS_UNLABELED = False 87 | _C.DATASET.IGNORE_NUM = 0 88 | _C.DATASET.IGNORE_FILE = "" 89 | _C.DATASET.CLASS_EQULE = True 90 | _C.DATASET.CONF_THRESHOLD = 0.9 91 | 92 | ########################### 93 | # Dataloader 94 | ########################### 95 | _C.DATALOADER = CN() 96 | _C.DATALOADER.NUM_WORKERS = 4 97 | # Apply transformations to an image K times (during training) 98 | _C.DATALOADER.K_TRANSFORMS = 1 99 | # img0 denotes image tensor without augmentation 100 | # Useful for consistency learning 101 | _C.DATALOADER.RETURN_IMG0 = False 102 | # Setting for the train_x data-loader 103 | _C.DATALOADER.TRAIN_X = CN() 104 | _C.DATALOADER.TRAIN_X.SAMPLER = "RandomSampler" 105 | _C.DATALOADER.TRAIN_X.BATCH_SIZE = 32 106 | # Parameter for RandomDomainSampler 107 | # 0 or -1 means sampling from all domains 108 | _C.DATALOADER.TRAIN_X.N_DOMAIN = 0 109 | # Parameter of RandomClassSampler 110 | # Number of instances per class 111 | _C.DATALOADER.TRAIN_X.N_INS = 16 112 | 113 | # Setting for the train_u data-loader 114 | _C.DATALOADER.TRAIN_U = CN() 115 | # Set to false if you want to have unique 116 | # data loader params for train_u 117 | _C.DATALOADER.TRAIN_U.SAME_AS_X = True 118 | _C.DATALOADER.TRAIN_U.SAMPLER = "RandomSampler" 119 | _C.DATALOADER.TRAIN_U.BATCH_SIZE = 32 120 | _C.DATALOADER.TRAIN_U.N_DOMAIN = 0 121 | _C.DATALOADER.TRAIN_U.N_INS = 16 122 | 123 | # Setting for the test data-loader 124 | _C.DATALOADER.TEST = CN() 125 | _C.DATALOADER.TEST.SAMPLER = "SequentialSampler" 126 | _C.DATALOADER.TEST.BATCH_SIZE = 32 127 | 128 | _C.DATALOADER.OPEN_SETTING = False 129 | 130 | ########################### 131 | # Model 132 | ########################### 133 | _C.MODEL = CN() 134 | # Path to model weights (for initialization) 135 | _C.MODEL.INIT_WEIGHTS = "" 136 | _C.MODEL.BACKBONE = CN() 137 | _C.MODEL.BACKBONE.NAME = "" 138 | _C.MODEL.BACKBONE.PRETRAINED = True 139 | # Definition of embedding layers 140 | _C.MODEL.HEAD = CN() 141 | # If none, do not construct embedding layers, the 142 | # backbone's output will be passed to the classifier 143 | _C.MODEL.HEAD.NAME = "" 144 | # Structure of hidden layers (a list), e.g. [512, 512] 145 | # If undefined, no embedding layer will be constructed 146 | _C.MODEL.HEAD.HIDDEN_LAYERS = () 147 | _C.MODEL.HEAD.ACTIVATION = "relu" 148 | _C.MODEL.HEAD.BN = True 149 | _C.MODEL.HEAD.DROPOUT = 0.0 150 | 151 | _C.MODEL.PSEUDO_LABEL_MODELS =[] 152 | 153 | ########################### 154 | # Optimization 155 | ########################### 156 | _C.OPTIM = CN() 157 | _C.OPTIM.NAME = "adam" 158 | _C.OPTIM.LR = 0.0003 159 | _C.OPTIM.WEIGHT_DECAY = 5e-4 160 | _C.OPTIM.MOMENTUM = 0.9 161 | _C.OPTIM.SGD_DAMPNING = 0 162 | _C.OPTIM.SGD_NESTEROV = False 163 | _C.OPTIM.RMSPROP_ALPHA = 0.99 164 | _C.OPTIM.ADAM_BETA1 = 0.9 165 | _C.OPTIM.ADAM_BETA2 = 0.999 166 | # STAGED_LR allows different layers to have 167 | # different lr, e.g. 
pre-trained base layers 168 | # can be assigned a smaller lr than the new 169 | # classification layer 170 | _C.OPTIM.STAGED_LR = False 171 | _C.OPTIM.NEW_LAYERS = () 172 | _C.OPTIM.BASE_LR_MULT = 0.1 173 | # Learning rate scheduler 174 | _C.OPTIM.LR_SCHEDULER = "single_step" 175 | # -1 or 0 means the stepsize is equal to max_epoch 176 | _C.OPTIM.STEPSIZE = (-1, ) 177 | _C.OPTIM.GAMMA = 0.1 178 | _C.OPTIM.MAX_EPOCH = 10 179 | # Set WARMUP_EPOCH larger than 0 to activate warmup training 180 | _C.OPTIM.WARMUP_EPOCH = -1 181 | # Either linear or constant 182 | _C.OPTIM.WARMUP_TYPE = "linear" 183 | # Constant learning rate when type=constant 184 | _C.OPTIM.WARMUP_CONS_LR = 1e-5 185 | # Minimum learning rate when type=linear 186 | _C.OPTIM.WARMUP_MIN_LR = 1e-5 187 | # Recount epoch for the next scheduler (last_epoch=-1) 188 | # Otherwise last_epoch=warmup_epoch 189 | _C.OPTIM.WARMUP_RECOUNT = True 190 | 191 | ########################### 192 | # Train 193 | ########################### 194 | _C.TRAIN = CN() 195 | # How often (epoch) to save model during training 196 | # Set to 0 or negative value to only save the last one 197 | _C.TRAIN.CHECKPOINT_FREQ = 0 198 | # How often (batch) to print training information 199 | _C.TRAIN.PRINT_FREQ = 10 200 | # Use 'train_x', 'train_u' or 'smaller_one' to count 201 | # the number of iterations in an epoch (for DA and SSL) 202 | _C.TRAIN.COUNT_ITER = "train_x" 203 | 204 | ########################### 205 | # Test 206 | ########################### 207 | _C.TEST = CN() 208 | _C.TEST.EVALUATOR = "UPLClassification" 209 | _C.TEST.PER_CLASS_RESULT = False 210 | # Compute confusion matrix, which will be saved 211 | # to $OUTPUT_DIR/cmat.pt 212 | _C.TEST.COMPUTE_CMAT = False 213 | # If NO_TEST=True, no testing will be conducted 214 | _C.TEST.NO_TEST = False 215 | # Use test or val set for FINAL evaluation 216 | _C.TEST.SPLIT = "test" 217 | # Which model to test after training 218 | # Either last_step or best_val 219 | _C.TEST.FINAL_MODEL = "last_step" 220 | _C.TEST.Analyze_Result_Path = None 221 | 222 | ########################### 223 | # Trainer specifics 224 | ########################### 225 | _C.TRAINER = CN() 226 | _C.TRAINER.NAME = "" 227 | 228 | # MCD 229 | _C.TRAINER.MCD = CN() 230 | _C.TRAINER.MCD.N_STEP_F = 4 # number of steps to train F 231 | # MME 232 | _C.TRAINER.MME = CN() 233 | _C.TRAINER.MME.LMDA = 0.1 # weight for the entropy loss 234 | # SelfEnsembling 235 | _C.TRAINER.SE = CN() 236 | _C.TRAINER.SE.EMA_ALPHA = 0.999 237 | _C.TRAINER.SE.CONF_THRE = 0.95 238 | _C.TRAINER.SE.RAMPUP = 300 239 | 240 | # M3SDA 241 | _C.TRAINER.M3SDA = CN() 242 | _C.TRAINER.M3SDA.LMDA = 0.5 # weight for the moment distance loss 243 | _C.TRAINER.M3SDA.N_STEP_F = 4 # follow MCD 244 | # DAEL 245 | _C.TRAINER.DAEL = CN() 246 | _C.TRAINER.DAEL.WEIGHT_U = 0.5 # weight on the unlabeled loss 247 | _C.TRAINER.DAEL.CONF_THRE = 0.95 # confidence threshold 248 | _C.TRAINER.DAEL.STRONG_TRANSFORMS = () 249 | 250 | # CrossGrad 251 | _C.TRAINER.CG = CN() 252 | _C.TRAINER.CG.EPS_F = 1.0 # scaling parameter for D's gradients 253 | _C.TRAINER.CG.EPS_D = 1.0 # scaling parameter for F's gradients 254 | _C.TRAINER.CG.ALPHA_F = 0.5 # balancing weight for the label net's loss 255 | _C.TRAINER.CG.ALPHA_D = 0.5 # balancing weight for the domain net's loss 256 | # DDAIG 257 | _C.TRAINER.DDAIG = CN() 258 | _C.TRAINER.DDAIG.G_ARCH = "" # generator's architecture 259 | _C.TRAINER.DDAIG.LMDA = 0.3 # perturbation weight 260 | _C.TRAINER.DDAIG.CLAMP = False # clamp perturbation values 261 | 
_C.TRAINER.DDAIG.CLAMP_MIN = -1.0 262 | _C.TRAINER.DDAIG.CLAMP_MAX = 1.0 263 | _C.TRAINER.DDAIG.WARMUP = 0 264 | _C.TRAINER.DDAIG.ALPHA = 0.5 # balancing weight for the losses 265 | 266 | # EntMin 267 | _C.TRAINER.ENTMIN = CN() 268 | _C.TRAINER.ENTMIN.LMDA = 1e-3 # weight on the entropy loss 269 | # Mean Teacher 270 | _C.TRAINER.MEANTEA = CN() 271 | _C.TRAINER.MEANTEA.WEIGHT_U = 1.0 # weight on the unlabeled loss 272 | _C.TRAINER.MEANTEA.EMA_ALPHA = 0.999 273 | _C.TRAINER.MEANTEA.RAMPUP = 5 # epochs used to ramp up the loss_u weight 274 | # MixMatch 275 | _C.TRAINER.MIXMATCH = CN() 276 | _C.TRAINER.MIXMATCH.WEIGHT_U = 100.0 # weight on the unlabeled loss 277 | _C.TRAINER.MIXMATCH.TEMP = 2.0 # temperature for sharpening the probability 278 | _C.TRAINER.MIXMATCH.MIXUP_BETA = 0.75 279 | _C.TRAINER.MIXMATCH.RAMPUP = 20000 # steps used to ramp up the loss_u weight 280 | # FixMatch 281 | _C.TRAINER.FIXMATCH = CN() 282 | _C.TRAINER.FIXMATCH.WEIGHT_U = 1.0 # weight on the unlabeled loss 283 | _C.TRAINER.FIXMATCH.CONF_THRE = 0.95 # confidence threshold 284 | _C.TRAINER.FIXMATCH.STRONG_TRANSFORMS = () 285 | 286 | _C.TRAINER.ENSEMBLE_NUM = 1 287 | 288 | 289 | def get_cfg_default(): 290 | return _C.clone() -------------------------------------------------------------------------------- /trainers/coop.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | from torch.cuda.amp import GradScaler, autocast 7 | 8 | from dassl.engine import TRAINER_REGISTRY, TrainerX 9 | from dassl.metrics import compute_accuracy 10 | from dassl.utils import load_pretrained_weights, load_checkpoint 11 | from dassl.optim import build_optimizer, build_lr_scheduler 12 | 13 | from clip import clip 14 | from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer 15 | 16 | _tokenizer = _Tokenizer() 17 | 18 | 19 | def load_clip_to_cpu(cfg): 20 | backbone_name = cfg.MODEL.BACKBONE.NAME 21 | url = clip._MODELS[backbone_name] 22 | model_path = clip._download(url) 23 | 24 | try: 25 | # loading JIT archive 26 | model = torch.jit.load(model_path, map_location="cpu").eval() 27 | state_dict = None 28 | 29 | except RuntimeError: 30 | state_dict = torch.load(model_path, map_location="cpu") 31 | 32 | model = clip.build_model(state_dict or model.state_dict()) 33 | 34 | return model 35 | 36 | 37 | class TextEncoder(nn.Module): 38 | def __init__(self, clip_model): 39 | super().__init__() 40 | self.transformer = clip_model.transformer 41 | self.positional_embedding = clip_model.positional_embedding 42 | self.ln_final = clip_model.ln_final 43 | self.text_projection = clip_model.text_projection 44 | self.dtype = clip_model.dtype 45 | 46 | def forward(self, prompts, tokenized_prompts): 47 | x = prompts + self.positional_embedding.type(self.dtype) 48 | x = x.permute(1, 0, 2) # NLD -> LND 49 | x = self.transformer(x) 50 | x = x.permute(1, 0, 2) # LND -> NLD 51 | x = self.ln_final(x).type(self.dtype) 52 | 53 | # x.shape = [batch_size, n_ctx, transformer.width] 54 | # take features from the eot embedding (eot_token is the highest number in each sequence) 55 | x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection 56 | 57 | return x 58 | 59 | 60 | class PromptLearner(nn.Module): 61 | def __init__(self, cfg, classnames, clip_model): 62 | super().__init__() 63 | n_cls = len(classnames) 64 | n_ctx = cfg.TRAINER.COOP.N_CTX 65 | ctx_init = cfg.TRAINER.COOP.CTX_INIT 66 | dtype 
= clip_model.dtype 67 | ctx_dim = clip_model.ln_final.weight.shape[0] 68 | clip_imsize = clip_model.visual.input_resolution 69 | cfg_imsize = cfg.INPUT.SIZE[0] 70 | assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})" 71 | 72 | if ctx_init: 73 | # use given words to initialize context vectors 74 | ctx_init = ctx_init.replace("_", " ") 75 | n_ctx = len(ctx_init.split(" ")) 76 | prompt = clip.tokenize(ctx_init) 77 | with torch.no_grad(): 78 | embedding = clip_model.token_embedding(prompt).type(dtype) 79 | ctx_vectors = embedding[0, 1 : 1 + n_ctx, :] 80 | prompt_prefix = ctx_init 81 | 82 | else: 83 | # random initialization 84 | if cfg.TRAINER.COOP.CSC: 85 | print("Initializing class-specific contexts") 86 | ctx_vectors = torch.empty(n_cls, n_ctx, ctx_dim, dtype=dtype) 87 | else: 88 | print("Initializing a generic context") 89 | ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype) 90 | nn.init.normal_(ctx_vectors, std=0.02) 91 | prompt_prefix = " ".join(["X"] * n_ctx) 92 | 93 | print(f'Initial context: "{prompt_prefix}"') 94 | print(f"Number of context words (tokens): {n_ctx}") 95 | 96 | self.ctx = nn.Parameter(ctx_vectors) # to be optimized 97 | 98 | classnames = [name.replace("_", " ") for name in classnames] 99 | name_lens = [len(_tokenizer.encode(name)) for name in classnames] 100 | prompts = [prompt_prefix + " " + name + "." for name in classnames] 101 | 102 | tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]) 103 | with torch.no_grad(): 104 | embedding = clip_model.token_embedding(tokenized_prompts).type(dtype) 105 | 106 | # These token vectors will be saved when in save_model(), 107 | # but they should be ignored in load_model() as we want to use 108 | # those computed using the current class names 109 | self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS 110 | self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :]) # CLS, EOS 111 | 112 | self.n_cls = n_cls 113 | self.n_ctx = n_ctx 114 | self.tokenized_prompts = tokenized_prompts # torch.Tensor 115 | self.name_lens = name_lens 116 | self.class_token_position = cfg.TRAINER.COOP.CLASS_TOKEN_POSITION 117 | 118 | def forward(self): 119 | ctx = self.ctx 120 | if ctx.dim() == 2: 121 | ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1) 122 | 123 | prefix = self.token_prefix 124 | suffix = self.token_suffix 125 | 126 | if self.class_token_position == "end": 127 | prompts = torch.cat( 128 | [ 129 | prefix, # (n_cls, 1, dim) 130 | ctx, # (n_cls, n_ctx, dim) 131 | suffix, # (n_cls, *, dim) 132 | ], 133 | dim=1, 134 | ) 135 | 136 | elif self.class_token_position == "middle": 137 | half_n_ctx = self.n_ctx // 2 138 | prompts = [] 139 | for i in range(self.n_cls): 140 | name_len = self.name_lens[i] 141 | prefix_i = prefix[i : i + 1, :, :] 142 | class_i = suffix[i : i + 1, :name_len, :] 143 | suffix_i = suffix[i : i + 1, name_len:, :] 144 | ctx_i_half1 = ctx[i : i + 1, :half_n_ctx, :] 145 | ctx_i_half2 = ctx[i : i + 1, half_n_ctx:, :] 146 | prompt = torch.cat( 147 | [ 148 | prefix_i, # (1, 1, dim) 149 | ctx_i_half1, # (1, n_ctx//2, dim) 150 | class_i, # (1, name_len, dim) 151 | ctx_i_half2, # (1, n_ctx//2, dim) 152 | suffix_i, # (1, *, dim) 153 | ], 154 | dim=1, 155 | ) 156 | prompts.append(prompt) 157 | prompts = torch.cat(prompts, dim=0) 158 | 159 | elif self.class_token_position == "front": 160 | prompts = [] 161 | for i in range(self.n_cls): 162 | name_len = self.name_lens[i] 163 | prefix_i = prefix[i : i + 1, :, :] 164 | class_i = suffix[i : i + 1, 
:name_len, :] 165 | suffix_i = suffix[i : i + 1, name_len:, :] 166 | ctx_i = ctx[i : i + 1, :, :] 167 | prompt = torch.cat( 168 | [ 169 | prefix_i, # (1, 1, dim) 170 | class_i, # (1, name_len, dim) 171 | ctx_i, # (1, n_ctx, dim) 172 | suffix_i, # (1, *, dim) 173 | ], 174 | dim=1, 175 | ) 176 | prompts.append(prompt) 177 | prompts = torch.cat(prompts, dim=0) 178 | 179 | else: 180 | raise ValueError 181 | 182 | return prompts 183 | 184 | 185 | class CustomCLIP(nn.Module): 186 | def __init__(self, cfg, classnames, clip_model): 187 | super().__init__() 188 | self.prompt_learner = PromptLearner(cfg, classnames, clip_model) 189 | self.tokenized_prompts = self.prompt_learner.tokenized_prompts 190 | self.image_encoder = clip_model.visual 191 | self.text_encoder = TextEncoder(clip_model) 192 | self.logit_scale = clip_model.logit_scale 193 | self.dtype = clip_model.dtype 194 | 195 | def forward(self, image): 196 | image_features = self.image_encoder(image.type(self.dtype)) 197 | 198 | prompts = self.prompt_learner() 199 | tokenized_prompts = self.tokenized_prompts 200 | text_features = self.text_encoder(prompts, tokenized_prompts) 201 | 202 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 203 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 204 | 205 | logit_scale = self.logit_scale.exp() 206 | logits = logit_scale * image_features @ text_features.t() 207 | 208 | return logits 209 | 210 | 211 | @TRAINER_REGISTRY.register() 212 | class CoOp(TrainerX): 213 | """Context Optimization (CoOp). 214 | 215 | Learning to Prompt for Vision-Language Models 216 | https://arxiv.org/abs/2109.01134 217 | """ 218 | 219 | def check_cfg(self, cfg): 220 | assert cfg.TRAINER.COOP.PREC in ["fp16", "fp32", "amp"] 221 | 222 | def build_model(self): 223 | cfg = self.cfg 224 | classnames = self.dm.dataset.classnames 225 | 226 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})") 227 | clip_model = load_clip_to_cpu(cfg) 228 | 229 | if cfg.TRAINER.COOP.PREC == "fp32" or cfg.TRAINER.COOP.PREC == "amp": 230 | # CLIP's default precision is fp16 231 | clip_model.float() 232 | 233 | print("Building custom CLIP") 234 | self.model = CustomCLIP(cfg, classnames, clip_model) 235 | 236 | print("Turning off gradients in both the image and the text encoder") 237 | for name, param in self.model.named_parameters(): 238 | if "prompt_learner" not in name: 239 | param.requires_grad_(False) 240 | 241 | if cfg.MODEL.INIT_WEIGHTS: 242 | load_pretrained_weights(self.model.prompt_learner, cfg.MODEL.INIT_WEIGHTS) 243 | 244 | self.model.to(self.device) 245 | # NOTE: only give prompt_learner to the optimizer 246 | self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM) 247 | self.sched = build_lr_scheduler(self.optim, cfg.OPTIM) 248 | self.register_model("prompt_learner", self.model.prompt_learner, self.optim, self.sched) 249 | 250 | self.scaler = GradScaler() if cfg.TRAINER.COOP.PREC == "amp" else None 251 | 252 | # Note that multi-gpu training could be slow because CLIP's size is 253 | # big, which slows down the copy operation in DataParallel 254 | device_count = torch.cuda.device_count() 255 | if device_count > 1: 256 | print(f"Multiple GPUs detected (n_gpus={device_count}), use all of them!") 257 | self.model = nn.DataParallel(self.model) 258 | 259 | def forward_backward(self, batch): 260 | image, label = self.parse_batch_train(batch) 261 | 262 | prec = self.cfg.TRAINER.COOP.PREC 263 | if prec == "amp": 264 | with autocast(): 265 | output, image_features, text_features = 
self.model(image) 266 | loss = F.cross_entropy(output, label) 267 | self.optim.zero_grad() 268 | self.scaler.scale(loss).backward() 269 | self.scaler.step(self.optim) 270 | self.scaler.update() 271 | else: 272 | output = self.model(image) 273 | loss = F.cross_entropy(output, label) 274 | self.model_backward_and_update(loss) 275 | 276 | loss_summary = { 277 | "loss": loss.item(), 278 | "acc": compute_accuracy(output, label)[0].item(), 279 | } 280 | 281 | if (self.batch_idx + 1) == self.num_batches: 282 | self.update_lr() 283 | 284 | return loss_summary 285 | 286 | def parse_batch_train(self, batch): 287 | input = batch["img"] 288 | label = batch["label"] 289 | input = input.to(self.device) 290 | label = label.to(self.device) 291 | return input, label 292 | 293 | def load_model(self, directory, epoch=None): 294 | if not directory: 295 | print("Note that load_model() is skipped as no pretrained model is given") 296 | return 297 | 298 | names = self.get_model_names() 299 | 300 | # By default, the best model is loaded 301 | model_file = "model-best.pth.tar" 302 | 303 | if epoch is not None: 304 | model_file = "model.pth.tar-" + str(epoch) 305 | 306 | for name in names: 307 | model_path = osp.join(directory, name, model_file) 308 | 309 | if not osp.exists(model_path): 310 | raise FileNotFoundError('Model not found at "{}"'.format(model_path)) 311 | 312 | checkpoint = load_checkpoint(model_path) 313 | state_dict = checkpoint["state_dict"] 314 | epoch = checkpoint["epoch"] 315 | 316 | # Ignore fixed token vectors 317 | if "token_prefix" in state_dict: 318 | del state_dict["token_prefix"] 319 | 320 | if "token_suffix" in state_dict: 321 | del state_dict["token_suffix"] 322 | 323 | print("Loading weights to {} " 'from "{}" (epoch = {})'.format(name, model_path, epoch)) 324 | # set strict=False 325 | self._models[name].load_state_dict(state_dict, strict=False) 326 | 327 | def load_model_by_id(self, directory, model_id, epoch=None): 328 | if not directory: 329 | print( 330 | 'Note that load_model() is skipped as no pretrained model is given' 331 | ) 332 | return 333 | 334 | names = self.get_model_names() 335 | 336 | # By default, the best model is loaded 337 | model_file = 'model-best-{}.pth.tar'.format(model_id) 338 | 339 | if epoch is not None: 340 | model_file = 'model.pth.tar-' + str(epoch) 341 | 342 | for name in names: 343 | model_path = osp.join(directory, name, model_file) 344 | 345 | if not osp.exists(model_path): 346 | raise FileNotFoundError( 347 | 'Model not found at "{}"'.format(model_path) 348 | ) 349 | 350 | checkpoint = load_checkpoint(model_path) 351 | state_dict = checkpoint['state_dict'] 352 | epoch = checkpoint['epoch'] 353 | 354 | # Ignore fixed token vectors 355 | if 'token_prefix' in state_dict: 356 | del state_dict['token_prefix'] 357 | 358 | if 'token_suffix' in state_dict: 359 | del state_dict['token_suffix'] 360 | 361 | print( 362 | 'Loading weights to {} ' 363 | 'from "{}" (epoch = {})'.format(name, model_path, epoch) 364 | ) 365 | # set strict=False 366 | self._models[name].load_state_dict(state_dict, strict=False) 367 | 368 | -------------------------------------------------------------------------------- /clip/model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, 
inplanes, planes, stride=1): 14 | super().__init__() 15 | 16 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | 20 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 21 | self.bn2 = nn.BatchNorm2d(planes) 22 | 23 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 24 | 25 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 26 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 27 | 28 | self.relu = nn.ReLU(inplace=True) 29 | self.downsample = None 30 | self.stride = stride 31 | 32 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 33 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 34 | self.downsample = nn.Sequential(OrderedDict([ 35 | ("-1", nn.AvgPool2d(stride)), 36 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 37 | ("1", nn.BatchNorm2d(planes * self.expansion)) 38 | ])) 39 | 40 | def forward(self, x: torch.Tensor): 41 | identity = x 42 | 43 | out = self.relu(self.bn1(self.conv1(x))) 44 | out = self.relu(self.bn2(self.conv2(out))) 45 | out = self.avgpool(out) 46 | out = self.bn3(self.conv3(out)) 47 | 48 | if self.downsample is not None: 49 | identity = self.downsample(x) 50 | 51 | out += identity 52 | out = self.relu(out) 53 | return out 54 | 55 | 56 | class AttentionPool2d(nn.Module): 57 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 58 | super().__init__() 59 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 60 | self.k_proj = nn.Linear(embed_dim, embed_dim) 61 | self.q_proj = nn.Linear(embed_dim, embed_dim) 62 | self.v_proj = nn.Linear(embed_dim, embed_dim) 63 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 64 | self.num_heads = num_heads 65 | 66 | def forward(self, x): 67 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 68 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 69 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 70 | x, _ = F.multi_head_attention_forward( 71 | query=x, key=x, value=x, 72 | embed_dim_to_check=x.shape[-1], 73 | num_heads=self.num_heads, 74 | q_proj_weight=self.q_proj.weight, 75 | k_proj_weight=self.k_proj.weight, 76 | v_proj_weight=self.v_proj.weight, 77 | in_proj_weight=None, 78 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 79 | bias_k=None, 80 | bias_v=None, 81 | add_zero_attn=False, 82 | dropout_p=0, 83 | out_proj_weight=self.c_proj.weight, 84 | out_proj_bias=self.c_proj.bias, 85 | use_separate_proj_weight=True, 86 | training=self.training, 87 | need_weights=False 88 | ) 89 | 90 | return x[0] 91 | 92 | 93 | class ModifiedResNet(nn.Module): 94 | """ 95 | A ResNet class that is similar to torchvision's but contains the following changes: 96 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
97 |     - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
98 |     - The final pooling layer is a QKV attention instead of an average pool
99 |     """
100 | 
101 |     def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
102 |         super().__init__()
103 |         self.output_dim = output_dim
104 |         self.input_resolution = input_resolution
105 | 
106 |         # the 3-layer stem
107 |         self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
108 |         self.bn1 = nn.BatchNorm2d(width // 2)
109 |         self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
110 |         self.bn2 = nn.BatchNorm2d(width // 2)
111 |         self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
112 |         self.bn3 = nn.BatchNorm2d(width)
113 |         self.avgpool = nn.AvgPool2d(2)
114 |         self.relu = nn.ReLU(inplace=True)
115 | 
116 |         # residual layers
117 |         self._inplanes = width  # this is a *mutable* variable used during construction
118 |         self.layer1 = self._make_layer(width, layers[0])
119 |         self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
120 |         self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
121 |         self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
122 | 
123 |         embed_dim = width * 32  # the ResNet feature dimension
124 |         self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
125 | 
126 |     def _make_layer(self, planes, blocks, stride=1):
127 |         layers = [Bottleneck(self._inplanes, planes, stride)]
128 | 
129 |         self._inplanes = planes * Bottleneck.expansion
130 |         for _ in range(1, blocks):
131 |             layers.append(Bottleneck(self._inplanes, planes))
132 | 
133 |         return nn.Sequential(*layers)
134 | 
135 |     def forward(self, x):
136 |         def stem(x):
137 |             for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
138 |                 x = self.relu(bn(conv(x)))
139 |             x = self.avgpool(x)
140 |             return x
141 | 
142 |         x = x.type(self.conv1.weight.dtype)
143 |         x = stem(x)
144 |         x = self.layer1(x)
145 |         x = self.layer2(x)
146 |         x = self.layer3(x)
147 |         x = self.layer4(x)
148 |         x = self.attnpool(x)
149 | 
150 |         return x
151 | 
152 | 
153 | class LayerNorm(nn.LayerNorm):
154 |     """Subclass torch's LayerNorm to handle fp16."""
155 | 
156 |     def forward(self, x: torch.Tensor):
157 |         orig_type = x.dtype
158 |         ret = super().forward(x.type(torch.float32))
159 |         return ret.type(orig_type)
160 | 
161 | 
162 | class QuickGELU(nn.Module):
163 |     def forward(self, x: torch.Tensor):
164 |         return x * torch.sigmoid(1.702 * x)
165 | 
166 | 
167 | class ResidualAttentionBlock(nn.Module):
168 |     def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
169 |         super().__init__()
170 | 
171 |         self.attn = nn.MultiheadAttention(d_model, n_head)
172 |         self.ln_1 = LayerNorm(d_model)
173 |         self.mlp = nn.Sequential(OrderedDict([
174 |             ("c_fc", nn.Linear(d_model, d_model * 4)),
175 |             ("gelu", QuickGELU()),
176 |             ("c_proj", nn.Linear(d_model * 4, d_model))
177 |         ]))
178 |         self.ln_2 = LayerNorm(d_model)
179 |         self.attn_mask = attn_mask
180 | 
181 |     def attention(self, x: torch.Tensor):
182 |         self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
183 |         return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
184 | 
185 |     def forward(self, x: torch.Tensor):
186 |         x = x + self.attention(self.ln_1(x))
187 |         x = x + self.mlp(self.ln_2(x))
188 |         return x
189 | 
190 | 
191 | class Transformer(nn.Module):
192 |     def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
193 |         super().__init__()
194 |         self.width = width
195 |         self.layers = layers
196 |         self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
197 | 
198 |     def forward(self, x: torch.Tensor):
199 |         return self.resblocks(x)
200 | 
201 | 
202 | class VisionTransformer(nn.Module):
203 |     def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
204 |         super().__init__()
205 |         self.input_resolution = input_resolution
206 |         self.output_dim = output_dim
207 |         self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
208 | 
209 |         scale = width ** -0.5
210 |         self.class_embedding = nn.Parameter(scale * torch.randn(width))
211 |         self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
212 |         self.ln_pre = LayerNorm(width)
213 | 
214 |         self.transformer = Transformer(width, layers, heads)
215 | 
216 |         self.ln_post = LayerNorm(width)
217 |         self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
218 | 
219 |     def forward(self, x: torch.Tensor):
220 |         x = self.conv1(x)  # shape = [*, width, grid, grid]
221 |         x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
222 |         x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
223 |         x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
224 |         x = x + self.positional_embedding.to(x.dtype)
225 |         x = self.ln_pre(x)
226 | 
227 |         x = x.permute(1, 0, 2)  # NLD -> LND
228 |         x = self.transformer(x)
229 |         x = x.permute(1, 0, 2)  # LND -> NLD
230 | 
231 |         x = self.ln_post(x[:, 0, :])
232 | 
233 |         if self.proj is not None:
234 |             x = x @ self.proj
235 | 
236 |         return x
237 | 
238 | 
239 | class CLIP(nn.Module):
240 |     def __init__(self,
241 |                  embed_dim: int,
242 |                  # vision
243 |                  image_resolution: int,
244 |                  vision_layers: Union[Tuple[int, int, int, int], int],
245 |                  vision_width: int,
246 |                  vision_patch_size: int,
247 |                  # text
248 |                  context_length: int,
249 |                  vocab_size: int,
250 |                  transformer_width: int,
251 |                  transformer_heads: int,
252 |                  transformer_layers: int
253 |                  ):
254 |         super().__init__()
255 | 
256 |         self.context_length = context_length
257 | 
258 |         if isinstance(vision_layers, (tuple, list)):
259 |             vision_heads = vision_width * 32 // 64
260 |             self.visual = ModifiedResNet(
261 |                 layers=vision_layers,
262 |                 output_dim=embed_dim,
263 |                 heads=vision_heads,
264 |                 input_resolution=image_resolution,
265 |                 width=vision_width
266 |             )
267 |         else:
268 |             vision_heads = vision_width // 64
269 |             self.visual = VisionTransformer(
270 |                 input_resolution=image_resolution,
271 |                 patch_size=vision_patch_size,
272 |                 width=vision_width,
273 |                 layers=vision_layers,
274 |                 heads=vision_heads,
275 |                 output_dim=embed_dim
276 |             )
277 | 
278 |         self.transformer = Transformer(
279 |             width=transformer_width,
280 |             layers=transformer_layers,
281 |             heads=transformer_heads,
282 |             attn_mask=self.build_attention_mask()
283 |         )
284 | 
285 |         self.vocab_size = vocab_size
286 |         self.token_embedding = nn.Embedding(vocab_size, transformer_width)
287 |         self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
288 |         self.ln_final = LayerNorm(transformer_width)
289 | 
290 |         self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
291 |         self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
292 | 
293 |         self.initialize_parameters()
294 | 
295 |     def initialize_parameters(self):
296 |         nn.init.normal_(self.token_embedding.weight, std=0.02)
297 |         nn.init.normal_(self.positional_embedding, std=0.01)
298 | 
299 |         if isinstance(self.visual, ModifiedResNet):
300 |             if self.visual.attnpool is not None:
301 |                 std = self.visual.attnpool.c_proj.in_features ** -0.5
302 |                 nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
303 |                 nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
304 |                 nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
305 |                 nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
306 | 
307 |             for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
308 |                 for name, param in resnet_block.named_parameters():
309 |                     if name.endswith("bn3.weight"):
310 |                         nn.init.zeros_(param)
311 | 
312 |         proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
313 |         attn_std = self.transformer.width ** -0.5
314 |         fc_std = (2 * self.transformer.width) ** -0.5
315 |         for block in self.transformer.resblocks:
316 |             nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
317 |             nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
318 |             nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
319 |             nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
320 | 
321 |         if self.text_projection is not None:
322 |             nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
323 | 
324 |     def build_attention_mask(self):
325 |         # lazily create causal attention mask, with full attention between the vision tokens
326 |         # pytorch uses additive attention mask; fill with -inf
327 |         mask = torch.empty(self.context_length, self.context_length)
328 |         mask.fill_(float("-inf"))
329 |         mask.triu_(1)  # zero out the lower diagonal
330 |         return mask
331 | 
332 |     @property
333 |     def dtype(self):
334 |         return self.visual.conv1.weight.dtype
335 | 
336 |     def encode_image(self, image):
337 |         return self.visual(image.type(self.dtype))
338 | 
339 |     def encode_text(self, text):
340 |         x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]
341 | 
342 |         x = x + self.positional_embedding.type(self.dtype)
343 |         x = x.permute(1, 0, 2)  # NLD -> LND
344 |         x = self.transformer(x)
345 |         x = x.permute(1, 0, 2)  # LND -> NLD
346 |         x = self.ln_final(x).type(self.dtype)
347 | 
348 |         # x.shape = [batch_size, n_ctx, transformer.width]
349 |         # take features from the eot embedding (eot_token is the highest number in each sequence)
350 |         x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
351 | 
352 |         return x
353 | 
354 |     def forward(self, image, text):
355 |         image_features = self.encode_image(image)
356 |         text_features = self.encode_text(text)
357 | 
358 |         # normalized features
359 |         image_features = image_features / image_features.norm(dim=-1, keepdim=True)
360 |         text_features = text_features / text_features.norm(dim=-1, keepdim=True)
361 | 
362 |         # cosine similarity as logits
363 |         logit_scale = self.logit_scale.exp()
364 |         logits_per_image = logit_scale * image_features @ text_features.t()
365 |         logits_per_text = logit_scale * text_features @ image_features.t()
366 | 
367 |         # shape = [global_batch_size, global_batch_size]
368 |         return logits_per_image, logits_per_text
369 | 
370 | 
371 | def convert_weights(model: nn.Module):
372 |     """Convert applicable model parameters to fp16"""
373 | 
374 |     def _convert_weights_to_fp16(l):
375 |         if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
376 |             l.weight.data = l.weight.data.half()
377 |             if l.bias is not None:
378 |                 l.bias.data = l.bias.data.half()
379 | 
380 |         if isinstance(l, nn.MultiheadAttention):
381 |             for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
382 |                 tensor = getattr(l, attr)
383 |                 if tensor is not None:
384 |                     tensor.data = tensor.data.half()
385 | 
386 |         for name in ["text_projection", "proj"]:
387 |             if hasattr(l, name):
388 |                 attr = getattr(l, name)
389 |                 if attr is not None:
390 |                     attr.data = attr.data.half()
391 | 
392 |     model.apply(_convert_weights_to_fp16)
393 | 
394 | 
395 | def build_model(state_dict: dict):
396 |     vit = "visual.proj" in state_dict
397 | 
398 |     if vit:
399 |         vision_width = state_dict["visual.conv1.weight"].shape[0]
400 |         vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
401 |         vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
402 |         grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
403 |         image_resolution = vision_patch_size * grid_size
404 |     else:
405 |         counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
406 |         vision_layers = tuple(counts)
407 |         vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
408 |         output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
409 |         vision_patch_size = None
410 |         assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
411 |         image_resolution = output_width * 32
412 | 
413 |     embed_dim = state_dict["text_projection"].shape[1]
414 |     context_length = state_dict["positional_embedding"].shape[0]
415 |     vocab_size = state_dict["token_embedding.weight"].shape[0]
416 |     transformer_width = state_dict["ln_final.weight"].shape[0]
417 |     transformer_heads = transformer_width // 64
418 |     transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
419 | 
420 |     model = CLIP(
421 |         embed_dim,
422 |         image_resolution, vision_layers, vision_width, vision_patch_size,
423 |         context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
424 |     )
425 | 
426 |     for key in ["input_resolution", "context_length", "vocab_size"]:
427 |         if key in state_dict:
428 |             del state_dict[key]
429 | 
430 |     convert_weights(model)
431 |     model.load_state_dict(state_dict)
432 |     return model.eval()
433 | 
--------------------------------------------------------------------------------
/trainers/utils.py:
--------------------------------------------------------------------------------
1 | 
2 | import matplotlib.pyplot as plt
3 | import numpy as np
4 | import os
5 | import json
6 | import torch
7 | from sklearn.metrics import precision_recall_curve
8 | 
9 | 
10 | def plotLogitsMap(outputs, label, save_path, fig_title, max_lines=1000):
11 |     fig, ax = plt.subplots(figsize=(5, 200))
12 |     Softmax = torch.nn.Softmax(dim=1)
13 |     output_m = Softmax(outputs)
14 |     output_m = outputs.cpu().detach().numpy()
15 | 
16 |     pred = outputs.max(1)[1]
17 |     matches = pred.eq(label).float()
18 |     output_m = np.sort(output_m)
19 |     output_m = output_m[:,::-1]  # sort each row in descending order
20 |     output_m = output_m[:,:5]  # keep only the top-5 values per row
21 |     output_m_index = output_m[:,0].argsort()
22 |     output_m = output_m[output_m_index]
23 |     output_m = output_m[::-1,:]  # order rows by the first (largest) value, descending
24 |     matches = matches[output_m_index]
25 |     matches = torch.flip(matches, dims=[0])
26 |     matches = matches.cpu().detach().numpy()
27 | 
28 | 
29 |     if len(matches) > max_lines:
30 |         gap = int(len(matches) / 1000)
31 |         index = np.arange(0, gap*1000, gap, int)
32 |         output_m = output_m[index]
33 |         matches = matches[index]
34 |     print(save_path)
35 |     matches = matches.tolist()
36 | 
37 | 
38 |     im = ax.imshow(output_m, aspect='auto')
39 |     ax.set_yticks(np.arange(output_m.shape[0]), labels=matches)
40 |     for i, label in enumerate(ax.get_yticklabels()):
41 |         if (int(matches[i])==0):
42 |             label.set_color('red')
43 |         elif (int(matches[i])==1):
44 |             label.set_color('green')
45 | 
46 |     for i in range(output_m.shape[0]):
47 |         for j in range(output_m.shape[1]):
48 |             text = ax.text(j, i, str(round(output_m[i, j],2)),
49 |                            ha="center", va="center", color="w")
50 |     plt.title(fig_title)
51 |     plt.savefig(save_path)
52 |     plt.close()
53 | 
54 | def plotPRMap(outputs, label, save_path, fig_title):
55 |     plt.figure(figsize=(15,15))
56 |     plt.title('{} Precision/Recall Curve'.format(fig_title))
57 |     plt.xlabel('Recall')
58 |     plt.ylabel('Precision')
59 |     output_m = outputs.cpu().detach().numpy()
60 |     pred = outputs.max(1)[1]
61 |     matches = pred.eq(label).float()
62 |     output_m = np.sort(output_m)
63 |     output_m = output_m[:,::-1]  # sort each row in descending order
64 |     output_m = output_m[:,:5]  # keep only the top-5 values per row
65 |     output_m_index = output_m[:,0].argsort()
66 |     output_m = output_m[output_m_index]
67 |     output_m = output_m[::-1,:]  # order rows by the first (largest) value, descending
68 |     matches = matches[output_m_index]
69 |     matches = torch.flip(matches, dims=[0])
70 |     matches = matches.cpu().detach().numpy()
71 |     # print(output_m[:,0].shape, matches.shape)
72 |     precision, recall, thresholds = precision_recall_curve(matches, output_m[:,0])
73 |     # print(precision)
74 |     # print(recall)
75 |     # print(thresholds, len(thresholds))
76 |     plt.plot(recall, precision)
77 | 
78 |     step = 0
79 |     for a, b, text in zip(recall, precision, thresholds):
80 |         # if float(text) % 0.05 == 0:
81 |         if step % 40 == 0:
82 |             plt.text(a, b, text, ha='center', va='bottom', fontsize=10, color='blue')
83 |             plt.plot(a, b, marker='o', color='red')
84 |         step += 1
85 |     plt.grid(ls='--')
86 |     plt.savefig(save_path)
87 |     plt.close()
88 | 
89 | def select_top_k_similarity_per_class(outputs, img_paths, K=1, image_features=None, is_softmax=True):
90 |     # print(outputs.shape)
91 |     if is_softmax:
92 |         outputs = torch.nn.Softmax(dim=1)(outputs)
93 |     output_m = outputs.cpu().detach().numpy()
94 |     output_ori = outputs.cpu().detach()
95 |     output_m_max = output_m.max(axis=1)
96 |     output_m_max_id = np.argsort(-output_m_max)
97 |     output_m = output_m[output_m_max_id]
98 |     img_paths = img_paths[output_m_max_id]
99 |     output_m_max = output_m_max[output_m_max_id]
100 |     output_ori = output_ori[output_m_max_id]
101 |     ids = (-output_m).argsort()[:, 0]  # predicted class label for each row
102 | 
103 |     if image_features is not None:
104 |         image_features = image_features.cpu().detach()
105 |         image_features = image_features[output_m_max_id]
106 | 
107 |     predict_label_dict = {}
108 |     predict_conf_dict = {}
109 |     from tqdm import tqdm
110 |     for id in tqdm(list(set(ids.tolist()))):  # iterate over the unique predicted labels
111 |         index = np.where(ids==id)
112 |         conf_class = output_m_max[index]  # confidence scores for this class
113 |         output_class = output_ori[index]
114 |         img_paths_class = img_paths[index]  # image paths predicted as this class
115 | 
116 |         if image_features is not None:
117 |             img_features = image_features[index]
118 |             if K >= 0:
119 |                 for img_path, img_feature, conf, logit in zip(img_paths_class[:K], img_features[:K], conf_class[:K], output_class):
120 |                     if '/data/' in img_path:
121 |                         img_path = './data/' + img_path.split('/data/')[1]
122 |                     predict_label_dict[img_path] = [id, img_feature, conf, logit]
123 |             else:
124 |                 for img_path, img_feature, conf, logit in zip(img_paths_class, img_features, conf_class, output_class):
125 |                     if '/data/' in img_path:
126 |                         img_path = './data/' + img_path.split('/data/')[1]
127 |                     predict_label_dict[img_path] = [id, img_feature, conf, logit]
128 |         else:
129 |             if K >= 0:
130 |                 for img_path, conf in zip(img_paths_class[:K], conf_class):
131 |                     if '/data/' in img_path:
132 |                         img_path = './data/' + img_path.split('/data/')[1]
133 |                     predict_label_dict[img_path] = id
134 |                     predict_conf_dict[img_path] = conf
135 |             else:
136 |                 for img_path, conf in zip(img_paths_class, conf_class):
137 |                     if '/data/' in img_path:
138 |                         img_path = './data/' + img_path.split('/data/')[1]
139 |                     predict_label_dict[img_path] = id
140 |                     predict_conf_dict[img_path] = conf
141 |     return predict_label_dict, predict_conf_dict
142 | 
143 | def select_by_conf(outputs, img_paths, K=1, conf_threshold=None, is_softmax=True):
144 |     # print(outputs.shape)
145 |     if is_softmax:
146 |         outputs = torch.nn.Softmax(dim=1)(outputs)
147 |     output_m = outputs.cpu().detach().numpy()
148 |     output_ori = outputs.cpu().detach()
149 |     output_m_max = output_m.max(axis=1)
150 |     output_m_max_id = np.argsort(-output_m_max)
151 |     output_m = output_m[output_m_max_id]
152 |     img_paths = img_paths[output_m_max_id]
153 |     output_m_max = output_m_max[output_m_max_id]
154 |     output_ori = output_ori[output_m_max_id]
155 |     ids = (-output_m).argsort()[:, 0]  # predicted class label for each row
156 | 
157 | 
158 |     predict_label_dict = {}
159 |     predict_conf_dict = {}
160 |     from tqdm import tqdm
161 |     for id in tqdm(list(set(ids.tolist()))):  # iterate over the unique predicted labels
162 |         index = np.where(ids==id)
163 |         conf_class = output_m_max[index]  # confidence scores for this class
164 |         output_class = output_ori[index]
165 |         img_paths_class = img_paths[index]  # image paths predicted as this class
166 | 
167 |         for img_path, conf in zip(img_paths_class, conf_class):
168 |             if conf > conf_threshold:
169 |                 predict_label_dict[img_path] = id
170 |                 predict_conf_dict[img_path] = conf
171 |     return predict_label_dict, predict_conf_dict
172 | 
173 | def select_top_k_similarity(outputs, img_paths, K=1, image_features=None, repeat=False):
174 |     # print(outputs.shape)
175 |     outputs = torch.nn.Softmax(dim=1)(outputs)
176 |     output_m = outputs.cpu().detach().numpy()
177 |     output_ori = outputs.cpu().detach()
178 |     output_m_max = output_m.max(axis=1)
179 |     output_m_max_id = np.argsort(-output_m_max)
180 |     output_m = output_m[output_m_max_id]
181 |     img_paths = img_paths[output_m_max_id]
182 |     # output_m_max = output_m_max[output_m_max_id]
183 |     output_ori = output_ori[output_m_max_id]
184 |     conf_class = output_m_max[output_m_max_id]  # confidence scores
185 |     ids = (-output_m).argsort()[:, 0]  # predicted class label for each row
186 | 
187 |     if image_features is not None:
188 |         image_features = image_features.cpu().detach()
189 |         image_features = image_features[output_m_max_id]
190 | 
191 |     predict_label_dict = {}
192 |     predict_conf_dict = {}
193 |     if image_features is not None:
194 |         img_features = image_features
195 |         if K >= 0:
196 |             for img_path, img_feature, conf, logit, id in zip(img_paths[:K], img_features[:K], conf_class[:K], output_ori[:K], ids[:K]):
197 |                 predict_label_dict[img_path] = [id, img_feature, conf, logit]
198 |         else:
199 |             for img_path, img_feature, conf, logit, id in zip(img_paths, img_features, conf_class, output_ori, ids):
200 |                 predict_label_dict[img_path] = [id, img_feature, conf, logit]
201 |     else:
202 |         if K >= 0:
203 |             for img_path, conf, id in zip(img_paths[:K], conf_class[:K], ids[:K]):
204 |                 predict_label_dict[img_path] = id
205 |                 predict_conf_dict[img_path] = conf
206 |         else:
207 |             for img_path, conf, id in zip(img_paths, conf_class, ids):
208 |                 predict_label_dict[img_path] = id
209 |                 predict_conf_dict[img_path] = conf
210 |     return predict_label_dict, predict_conf_dict
211 | 
212 | 
213 | def select_top_by_value(outputs, img_paths, conf_threshold=0.95, image_features=None, repeat=False):
214 |     outputs = torch.nn.Softmax(dim=1)(outputs)
215 |     output_m = outputs.cpu().detach().numpy()
216 |     output_ori = outputs.cpu().detach()
217 |     output_m_max = output_m.max(axis=1)
218 |     output_m_max_id = np.argsort(-output_m_max)
219 |     output_m = output_m[output_m_max_id]
220 |     img_paths = img_paths[output_m_max_id]
221 |     # output_m_max = output_m_max[output_m_max_id]
222 |     output_ori = output_ori[output_m_max_id]
223 |     conf_class = output_m_max[output_m_max_id]  # confidence scores
224 |     ids = (-output_m).argsort()[:, 0]  # predicted class label for each row
225 | 
226 |     if image_features is not None:
227 |         image_features = image_features.cpu().detach()
228 |         image_features = image_features[output_m_max_id]
229 | 
230 |     predict_label_dict = {}
231 |     predict_conf_dict = {}
232 |     if image_features is not None:
233 |         img_features = image_features
234 |         for img_path, img_feature, conf, logit, id in zip(img_paths, img_features, conf_class, output_ori, ids):
235 |             if conf > conf_threshold:
236 |                 predict_label_dict[img_path] = [id, img_feature, conf, logit]
237 |     else:
238 |         for img_path, id, conf in zip(img_paths, ids, conf_class):
239 |             if conf > conf_threshold:
240 |                 predict_label_dict[img_path] = id
241 |                 predict_conf_dict[img_path] = conf
242 |     return predict_label_dict, predict_conf_dict
243 | 
244 | 
245 | def caculate_noise_rate(predict_label_dict, train_loader, trainer, sample_level=False):
246 |     gt_label_dict = {}
247 |     for batch_idx, batch in enumerate(train_loader):
248 |         input, label, impath = trainer.parse_batch_test_with_impath(batch)
249 |         for l, ip in zip(label, impath):
250 |             if '/data/' in ip:
251 |                 ip = './data/' + ip.split('/data/')[1]
252 |             gt_label_dict[ip] = l
253 | 
254 |     # print('gt_label_dict', len(gt_label_dict))
255 |     # print('gt_label_dict', gt_label_dict)
256 |     total = 0
257 |     correct = 0
258 |     for item in predict_label_dict:
259 |         if '/data/' in item:
260 |             item = './data/' + item.split('/data/')[1]
261 |         if gt_label_dict[item] == predict_label_dict[item]:
262 |             correct += 1
263 |         total += 1
264 |     print('Acc Rate {:.4f}'.format(correct/total))
265 | 
266 | 
267 | def caculate_noise_rate_analyze(predict_label_dict, train_loader, trainer, sample_level=False):
268 |     gt_label_dict = {}
269 |     for batch_idx, batch in enumerate(train_loader):
270 |         input, label, impath = trainer.parse_batch_test_with_impath(batch)
271 |         for l, ip in zip(label, impath):
272 |             ip = './data/' + ip.split('/data/')[1]
273 |             gt_label_dict[ip] = l
274 |     total = 0
275 |     correct = 0
276 |     for item in predict_label_dict:
277 |         if gt_label_dict[item] == predict_label_dict[item][0]:
278 |             correct += 1
279 |             if sample_level is True:
280 |                 print(gt_label_dict[item], 1)
281 |         total += 1
282 |     print('Acc Rate {:.4f}'.format(correct/total))
283 | 
284 | 
285 | def save_outputs(train_loader, trainer, predict_label_dict, dataset_name, text_features, backbone_name=None):
286 |     backbone_name = backbone_name.replace('/', '-')
287 |     gt_pred_label_dict = {}
288 |     for batch_idx, batch in enumerate(train_loader):
289 |         input, label, impath = trainer.parse_batch_test_with_impath(batch)
290 |         for l, ip in zip(label, impath):
291 |             l = l.item()
292 |             ip = './data/' + ip.split('/data/')[1]
293 |             if l not in gt_pred_label_dict:
294 |                 gt_pred_label_dict[l] = []
295 |                 pred_label = predict_label_dict[ip][0]
296 |                 pred_v_feature = predict_label_dict[ip][1]
297 | 
298 |                 conf = predict_label_dict[ip][2]
299 |                 logits = predict_label_dict[ip][3]
300 |                 gt_pred_label_dict[l].append([ip, pred_label, pred_v_feature, conf, logits])
301 |             else:
302 |                 pred_label = predict_label_dict[ip][0]
303 |                 pred_v_feature = predict_label_dict[ip][1]
304 |                 conf = predict_label_dict[ip][2]
305 |                 logits = predict_label_dict[ip][3]
306 |                 gt_pred_label_dict[l].append([ip, pred_label, pred_v_feature, conf, logits])
307 | 
308 |     idx = 0
309 |     v_distance_dict = {}
310 |     v_features = []
311 |     logits_list = []
312 |     for label in gt_pred_label_dict:
313 |         avg_feature = None
314 |         for item in gt_pred_label_dict[label]:
315 |             impath, pred_label, pred_v_feature = item[0], item[1], item[2]
316 |             if avg_feature is None:
317 |                 avg_feature = pred_v_feature.clone()
318 |             else:
319 |                 avg_feature += pred_v_feature.clone()
320 |         avg_feature /= len(gt_pred_label_dict[label])  # class center
321 |         v_distance_dict_per_class = {}
322 |         for item in gt_pred_label_dict[label]:
323 |             impath, pred_label, pred_v_feature, conf, logits = item[0], item[1], item[2], item[3], item[4]
324 |             v_features.append(pred_v_feature)
325 |             logits_list.append(logits)
326 |             v_dis = torch.dist(avg_feature, pred_v_feature, p=2)
327 |             v_distance_dict_per_class[impath] = [idx, v_dis.item(), conf.item(), pred_label]  # id, visual distance, confidence, predicted label
328 |             idx += 1
329 |         v_distance_dict[label] = v_distance_dict_per_class
330 | 
331 |     v_features = torch.vstack(v_features)
332 |     logits_tensor = torch.vstack(logits_list)
333 | 
334 |     if not os.path.exists('./analyze_results/{}/'.format(backbone_name)):
335 |         os.makedirs('./analyze_results/{}/'.format(backbone_name))
336 | 
337 |     torch.save(v_features, './analyze_results/{}/{}_v_feature.pt'.format(backbone_name, dataset_name))
338 |     torch.save(text_features, './analyze_results/{}/{}_l_feature.pt'.format(backbone_name, dataset_name))
339 |     torch.save(logits_tensor, './analyze_results/{}/{}_logits.pt'.format(backbone_name, dataset_name))
340 | 
341 | 
342 |     with open("./analyze_results/{}/{}.json".format(backbone_name, dataset_name), "w") as outfile:
343 |         json.dump(v_distance_dict, outfile)
344 | 
345 | 
346 | def select_top_k_similarity_per_class_with_high_conf(outputs, img_paths, K=1, image_features=None, repeat=False):
347 |     # print(outputs.shape)
348 |     outputs = torch.nn.Softmax(dim=1)(outputs)
349 |     output_m = outputs.cpu().detach().numpy()
350 |     output_ori = outputs.cpu().detach()
351 |     output_m_max = output_m.max(axis=1)
352 |     output_m_max_id = np.argsort(-output_m_max)
353 |     output_m = output_m[output_m_max_id]
354 |     img_paths = img_paths[output_m_max_id]
355 |     output_m_max = output_m_max[output_m_max_id]
356 |     output_ori = output_ori[output_m_max_id]
357 |     ids = (-output_m).argsort()[:, 0]  # predicted class label for each row
358 | 
359 | 
360 |     # compute the average confidence of each class
361 |     class_avg_conf = {}
362 |     for id in list(set(ids.tolist())):  # iterate over the unique predicted labels
363 |         index = np.where(ids==id)
364 |         conf_class = output_m_max[index]  # confidence scores for this class
365 |         class_avg_conf[id] = conf_class.sum() / conf_class.size
366 |         # print(class_avg_conf[id])
367 | 
368 |     selected_ids = sorted(class_avg_conf.items(), key = lambda kv:(kv[1], kv[0]), reverse=True)[:int(0.8*len(class_avg_conf))]
369 |     remain_ids = sorted(class_avg_conf.items(), key = lambda kv:(kv[1], kv[0]), reverse=True)[int(0.8*len(class_avg_conf)):]
370 | 
371 |     selected_ids = [id[0] for id in selected_ids]
372 |     remain_ids = [id[0] for id in remain_ids]
373 | 
374 |     if image_features is not None:
375 |         image_features = image_features.cpu().detach()
376 |         image_features = image_features[output_m_max_id]
377 | 
378 |     predict_label_dict = {}
379 |     predict_conf_dict = {}
380 | 
381 |     # selected_ids.append(0)
382 | 
383 |     for id in selected_ids:  # iterate over the high-confidence class ids
384 |         index = np.where(ids==id)
385 |         conf_class = output_m_max[index]  # confidence scores for this class
386 |         output_class = output_ori[index]
387 |         img_paths_class = img_paths[index]  # image paths predicted as this class
388 | 
389 |         if image_features is not None:
390 |             img_features = image_features[index]
391 |             if K >= 0:
392 |                 for img_path, img_feature, conf, logit in zip(img_paths_class[:K], img_features[:K], conf_class[:K], output_class[:K]):
393 |                     img_path = './data/' + img_path.split('/data/')[1]
394 |                     predict_label_dict[img_path] = [id, img_feature, conf, logit]
395 |             else:
396 |                 for img_path, img_feature, conf, logit in zip(img_paths_class, img_features, conf_class, output_class):
397 |                     img_path = './data/' + img_path.split('/data/')[1]
398 |                     predict_label_dict[img_path] = [id, img_feature, conf, logit]
399 |         else:
400 |             if K >= 0:
401 |                 for img_path, conf in zip(img_paths_class[:K], conf_class):
402 |                     img_path = './data/' + img_path.split('/data/')[1]
403 |                     predict_label_dict[img_path] = id
404 |                     predict_conf_dict[img_path] = conf
405 |             else:
406 |                 for img_path, conf in zip(img_paths_class, conf_class):
407 |                     img_path = './data/' + img_path.split('/data/')[1]
408 |                     predict_label_dict[img_path] = id
409 |                     predict_conf_dict[img_path] = conf
410 |     return predict_label_dict, predict_conf_dict, remain_ids, selected_ids
411 | 
412 | 
413 | def select_top_k_similarity_per_class_with_low_conf(outputs, img_paths, conf_threshold, remain_ids, selected_ids, K=2):
414 |     # print(outputs.shape)
415 |     outputs = torch.nn.Softmax(dim=1)(outputs)
416 |     remain_ids_list = remain_ids
417 |     remain_ids = np.sort(np.array(remain_ids).astype(int))
418 |     remain_logits = -100*torch.ones(outputs.shape).half().cuda()
419 |     remain_logits[:, remain_ids] = outputs[:, remain_ids] * 5
420 |     remain_logits = torch.nn.Softmax(dim=1)(remain_logits.float())
421 |     outputs = remain_logits
422 | 
423 | 
424 |     output_m = outputs.cpu().detach().numpy()
425 |     output_ori = outputs.cpu().detach()
426 |     output_m_max = output_m.max(axis=1)
427 |     output_m_max_id = np.argsort(-output_m_max)
428 |     output_m = output_m[output_m_max_id]
429 |     img_paths = img_paths[output_m_max_id]
430 |     output_m_max = output_m_max[output_m_max_id]
431 |     output_ori = output_ori[output_m_max_id]
432 |     ids = (-output_m).argsort()[:, 0]  # predicted class label for each row
433 | 
434 | 
435 |     predict_label_dict = {}
436 |     predict_conf_dict = {}
437 |     no_sample_ids = []
438 | 
439 |     for id in remain_ids_list:  # iterate over the remaining (low-confidence) class ids
440 |         # print(id)
441 |         is_id_have_sample = False
442 |         index = np.where(ids==id)
443 |         conf_class = output_m_max[index]  # confidence scores for this class
444 |         output_class = output_ori[index]
445 |         img_paths_class = img_paths[index]  # image paths predicted as this class
446 | 
447 |         if K >= 0:
448 |             for img_path, conf in zip(img_paths_class[:K], conf_class[:K]):
449 |                 print(conf)
450 |                 if conf > 0.4:
451 |                     predict_label_dict[img_path] = id
452 |                     predict_conf_dict[img_path] = conf
453 |                     is_id_have_sample = True
454 |         else:
455 |             for img_path, conf in zip(img_paths_class, conf_class):
56 |                 predict_label_dict[img_path] = id
457 |                 predict_conf_dict[img_path] = conf
458 |         # print(is_id_have_sample)
459 |         if is_id_have_sample is False:
460 |             no_sample_ids.append(id)
461 | 
462 |     print(no_sample_ids)
463 |     return predict_label_dict, predict_conf_dict, no_sample_ids
464 | 
465 | def select_top_k_similarity_per_class_no_smaple(outputs, img_paths, no_sample_ids, K=16):
466 |     # print(outputs.shape)
467 |     outputs = torch.nn.Softmax(dim=1)(outputs)
468 |     output_m = outputs.cpu().detach().numpy()
469 |     output_ori = outputs.cpu().detach()
470 |     output_m_max = output_m.max(axis=1)
471 |     output_m_max_id = np.argsort(-output_m_max)
472 |     output_m = output_m[output_m_max_id]
473 |     img_paths = img_paths[output_m_max_id]
474 |     output_m_max = output_m_max[output_m_max_id]
475 |     output_ori = output_ori[output_m_max_id]
476 |     ids = (-output_m).argsort()[:, 0]  # predicted class label for each row
477 | 
478 | 
479 |     predict_label_dict = {}
480 |     predict_conf_dict = {}
481 | 
482 |     for id in no_sample_ids:  # iterate over the classes that received no pseudo-labeled samples
483 |         print(id)
484 |         index = np.where(ids==id)
485 |         conf_class = output_m_max[index]  # confidence scores for this class
486 |         output_class = output_ori[index]
487 |         img_paths_class = img_paths[index]  # image paths predicted as this class
488 | 
489 |         if K >= 0:
490 |             for img_path, conf in zip(img_paths_class[:K], conf_class[:K]):
491 |                 predict_label_dict[img_path] = id
492 |                 predict_conf_dict[img_path] = conf
493 |         else:
494 |             for img_path, conf in zip(img_paths_class, conf_class):
495 |                 predict_label_dict[img_path] = id
496 |                 predict_conf_dict[img_path] = conf
497 |     return predict_label_dict, predict_conf_dict
498 | 
499 | 
500 | 
--------------------------------------------------------------------------------
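The selection utilities above are typically chained with the CLIP model to produce pseudo-labels. The snippet below is a minimal sketch, not a file from the repository: it assumes `ckpt_state_dict` (a CLIP checkpoint already loaded as a plain state dict), `image_tensor` (a preprocessed [N, 3, H, W] batch on which pseudo-labels are wanted), `text_tokens` (already-tokenized class prompts, e.g. from clip.tokenize), and `img_path_list` (image paths in the same order as the batch) exist; only `build_model` from clip/model.py and `select_top_k_similarity_per_class` from trainers/utils.py come from the files shown above.

# Minimal usage sketch (illustration only; ckpt_state_dict, image_tensor,
# text_tokens and img_path_list are assumed inputs, not repository code).
import numpy as np
import torch

from clip.model import build_model
from trainers.utils import select_top_k_similarity_per_class

# build_model infers the visual architecture (ModifiedResNet vs. VisionTransformer)
# from the checkpoint keys and returns the model in eval mode with fp16 weights.
model = build_model(ckpt_state_dict).cuda()

with torch.no_grad():
    image_features = model.encode_image(image_tensor.cuda())   # [N, D]
    text_features = model.encode_text(text_tokens.cuda())      # [C, D]
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    logits = model.logit_scale.exp() * image_features @ text_features.t()  # [N, C]

# img_paths must be a NumPy array so the selection code can fancy-index it.
img_paths = np.array(img_path_list)

# Keep the 16 most confident images per predicted class as pseudo-labels.
pseudo_labels, pseudo_confs = select_top_k_similarity_per_class(
    logits, img_paths, K=16, is_softmax=True)

The returned dictionaries map image path to pseudo-label id and to its confidence, which is the shape of input that caculate_noise_rate expects for its predict_label_dict argument.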