├── clip
│   ├── prompt.txt
│   ├── __init__.py
│   ├── bpe_simple_vocab_16e6.txt.gz
│   ├── __pycache__
│   │   ├── clip.cpython-37.pyc
│   │   ├── model.cpython-37.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   └── simple_tokenizer.cpython-37.pyc
│   ├── simple_tokenizer.py
│   ├── clip.py
│   └── model.py
├── data
├── datasets
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── dtd.cpython-37.pyc
│   │   ├── sun397.cpython-37.pyc
│   │   ├── ucf101.cpython-37.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── eurosat.cpython-37.pyc
│   │   ├── food101.cpython-37.pyc
│   │   ├── imagenet.cpython-37.pyc
│   │   ├── caltech101.cpython-37.pyc
│   │   ├── datasetbase.cpython-37.pyc
│   │   ├── imagenet_a.cpython-37.pyc
│   │   ├── imagenet_r.cpython-37.pyc
│   │   ├── imagenetv2.cpython-37.pyc
│   │   ├── oxford_pets.cpython-37.pyc
│   │   ├── data_manager.cpython-37.pyc
│   │   ├── fgvc_aircraft.cpython-37.pyc
│   │   ├── oxford_flowers.cpython-37.pyc
│   │   ├── stanford_cars.cpython-37.pyc
│   │   └── imagenet_sketch.cpython-37.pyc
│   ├── food101.py
│   ├── eurosat.py
│   ├── stanford_cars.py
│   ├── sun397.py
│   ├── oxford_flowers.py
│   ├── dtd.py
│   ├── ucf101.py
│   ├── fgvc_aircraft.py
│   ├── caltech101.py
│   ├── oxford_pets.py
│   ├── data_manager.py
│   ├── imagenet.py
│   └── datasetbase.py
├── trainers
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── coop.cpython-37.pyc
│   │   ├── utils.cpython-37.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── hhzsclip.cpython-37.pyc
│   │   ├── hhsstrainer.cpython-37.pyc
│   │   └── upltrainer.cpython-37.pyc
│   ├── zsclip.py
│   ├── hhzsclip.py
│   ├── coop.py
│   └── utils.py
├── requirements.txt
├── configs
│   ├── datasets
│   │   ├── sseurosat.yaml
│   │   ├── ssfood101.yaml
│   │   ├── ssimagenet.yaml
│   │   ├── sssun397.yaml
│   │   ├── ssucf101.yaml
│   │   ├── sscaltech101.yaml
│   │   ├── ssoxford_pets.yaml
│   │   ├── ssdtd.yaml
│   │   ├── ssoxford_flowers.yaml
│   │   ├── ssfgvc_aircraft.yaml
│   │   └── ssstanford_cars.yaml
│   ├── upl_default_config
│   │   ├── __pycache__
│   │   │   └── upl_default.cpython-37.pyc
│   │   └── upl_default.py
│   └── trainers
│       └── UPLTrainer
│           ├── anay_rn101.yaml
│           ├── anay_rn50.yaml
│           ├── anay_rn50x4.yaml
│           ├── anay_ViT_B_16.yaml
│           ├── anay_ViT_B_32.yaml
│           ├── anay_ViT_L_14.yaml
│           ├── anay_rn50x16.yaml
│           ├── anay_rn50x64.yaml
│           └── rn50_ep50.yaml
├── .gitignore
├── temp_analyze_results_miltiple
├── figures
│   └── overview.png
├── evaluation
│   ├── __pycache__
│   │   └── evaluator.cpython-37.pyc
│   └── evaluator.py
├── scripts
│   ├── get_all_info.sh
│   ├── get_info.sh
│   ├── upl_test_existing_logits.sh
│   └── upl_train.sh
├── readme.md
├── get_info.py
├── upl_train.py
└── upl_test.py
/clip/prompt.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/data:
--------------------------------------------------------------------------------
1 | ../CoOp/data
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/trainers/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/clip/__init__.py:
--------------------------------------------------------------------------------
1 | from .clip import *
2 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | ftfy
2 | regex
3 | tqdm
4 | matplotlib
--------------------------------------------------------------------------------
/configs/datasets/sseurosat.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "SSEuroSAT"
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | analysis_results_test/*
2 | analyze_results/*
3 | output/*
--------------------------------------------------------------------------------
/configs/datasets/ssfood101.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "SSFood101"
3 |
--------------------------------------------------------------------------------
/configs/datasets/ssimagenet.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "SSImageNet"
--------------------------------------------------------------------------------
/configs/datasets/sssun397.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "SSSUN397"
3 |
--------------------------------------------------------------------------------
/configs/datasets/ssucf101.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "SSUCF101"
3 |
--------------------------------------------------------------------------------
/temp_analyze_results_miltiple:
--------------------------------------------------------------------------------
1 | ../CoOp/temp_analyze_results_miltiple/
--------------------------------------------------------------------------------
/configs/datasets/sscaltech101.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "SSCaltech101"
--------------------------------------------------------------------------------
/configs/datasets/ssoxford_pets.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "SSOxfordPets"
--------------------------------------------------------------------------------
/configs/datasets/ssdtd.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "SSDescribableTextures"
3 |
--------------------------------------------------------------------------------
/configs/datasets/ssoxford_flowers.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "SSOxfordFlowers"
--------------------------------------------------------------------------------
/configs/datasets/ssfgvc_aircraft.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "SSFGVCAircraft"
3 |
--------------------------------------------------------------------------------
/configs/datasets/ssstanford_cars.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "SSStanfordCars"
3 |
--------------------------------------------------------------------------------
/figures/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyhuang2022/UPL/HEAD/figures/overview.png
--------------------------------------------------------------------------------
/clip/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyhuang2022/UPL/HEAD/clip/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/clip/__pycache__/clip.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyhuang2022/UPL/HEAD/clip/__pycache__/clip.cpython-37.pyc
--------------------------------------------------------------------------------
/clip/__pycache__/model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyhuang2022/UPL/HEAD/clip/__pycache__/model.cpython-37.pyc
--------------------------------------------------------------------------------
/clip/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyhuang2022/UPL/HEAD/clip/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/dtd.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyhuang2022/UPL/HEAD/datasets/__pycache__/dtd.cpython-37.pyc
--------------------------------------------------------------------------------
/trainers/__pycache__/coop.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyhuang2022/UPL/HEAD/trainers/__pycache__/coop.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/sun397.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyhuang2022/UPL/HEAD/datasets/__pycache__/sun397.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/ucf101.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyhuang2022/UPL/HEAD/datasets/__pycache__/ucf101.cpython-37.pyc
--------------------------------------------------------------------------------
/trainers/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyhuang2022/UPL/HEAD/trainers/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyhuang2022/UPL/HEAD/datasets/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/eurosat.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyhuang2022/UPL/HEAD/datasets/__pycache__/eurosat.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/food101.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyhuang2022/UPL/HEAD/datasets/__pycache__/food101.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/imagenet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyhuang2022/UPL/HEAD/datasets/__pycache__/imagenet.cpython-37.pyc
--------------------------------------------------------------------------------
/trainers/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyhuang2022/UPL/HEAD/trainers/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/trainers/__pycache__/hhzsclip.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyhuang2022/UPL/HEAD/trainers/__pycache__/hhzsclip.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/caltech101.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyhuang2022/UPL/HEAD/datasets/__pycache__/caltech101.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/datasetbase.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyhuang2022/UPL/HEAD/datasets/__pycache__/datasetbase.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/imagenet_a.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyhuang2022/UPL/HEAD/datasets/__pycache__/imagenet_a.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/imagenet_r.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyhuang2022/UPL/HEAD/datasets/__pycache__/imagenet_r.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/imagenetv2.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyhuang2022/UPL/HEAD/datasets/__pycache__/imagenetv2.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/oxford_pets.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyhuang2022/UPL/HEAD/datasets/__pycache__/oxford_pets.cpython-37.pyc
--------------------------------------------------------------------------------
/evaluation/__pycache__/evaluator.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyhuang2022/UPL/HEAD/evaluation/__pycache__/evaluator.cpython-37.pyc
--------------------------------------------------------------------------------
/trainers/__pycache__/hhsstrainer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyhuang2022/UPL/HEAD/trainers/__pycache__/hhsstrainer.cpython-37.pyc
--------------------------------------------------------------------------------
/trainers/__pycache__/upltrainer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyhuang2022/UPL/HEAD/trainers/__pycache__/upltrainer.cpython-37.pyc
--------------------------------------------------------------------------------
/clip/__pycache__/simple_tokenizer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyhuang2022/UPL/HEAD/clip/__pycache__/simple_tokenizer.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/data_manager.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyhuang2022/UPL/HEAD/datasets/__pycache__/data_manager.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/fgvc_aircraft.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyhuang2022/UPL/HEAD/datasets/__pycache__/fgvc_aircraft.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/oxford_flowers.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyhuang2022/UPL/HEAD/datasets/__pycache__/oxford_flowers.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/stanford_cars.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyhuang2022/UPL/HEAD/datasets/__pycache__/stanford_cars.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/imagenet_sketch.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyhuang2022/UPL/HEAD/datasets/__pycache__/imagenet_sketch.cpython-37.pyc
--------------------------------------------------------------------------------
/configs/upl_default_config/__pycache__/upl_default.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyhuang2022/UPL/HEAD/configs/upl_default_config/__pycache__/upl_default.cpython-37.pyc
--------------------------------------------------------------------------------
/scripts/get_all_info.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | DATASET=$1
3 |
4 |
5 | bash get_info.sh ${DATASET} anay_rn50 end 16 -1 False
6 | bash get_info.sh ${DATASET} anay_rn50x4 end 16 -1 False
7 | bash get_info.sh ${DATASET} anay_rn50x16 end 16 -1 False
8 | bash get_info.sh ${DATASET} anay_rn50x64 end 16 -1 False
9 | bash get_info.sh ${DATASET} anay_rn101 end 16 -1 False
10 | bash get_info.sh ${DATASET} anay_ViT_B_16 end 16 -1 False
11 | bash get_info.sh ${DATASET} anay_ViT_B_32 end 16 -1 False
12 | bash get_info.sh ${DATASET} anay_ViT_L_14 end 16 -1 False
--------------------------------------------------------------------------------
/configs/trainers/UPLTrainer/anay_rn101.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |   # TRANSFORMS: ["center_crop", "normalize"]
15 |
16 | OPTIM:
17 | NAME: "sgd"
18 | LR: 0.002
19 | MAX_EPOCH: 50
20 | LR_SCHEDULER: "cosine"
21 | WARMUP_EPOCH: 1
22 | WARMUP_TYPE: "constant"
23 | WARMUP_CONS_LR: 1e-5
24 |
25 | TRAIN:
26 | PRINT_FREQ: 5
27 |
28 | TEST:
29 | Analyze_Result_Path: './analysis_results_test/'
30 | FINAL_MODEL: "last_val"
31 | PER_CLASS_RESULT: True
32 |
33 | MODEL:
34 | BACKBONE:
35 | NAME: "RN101"
36 |
37 | TRAINER:
38 | UPLTrainer:
39 | CTX_INIT: "a photo of a"
--------------------------------------------------------------------------------
/configs/trainers/UPLTrainer/anay_rn50.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |   # TRANSFORMS: ["center_crop", "normalize"]
15 |
16 | OPTIM:
17 | NAME: "sgd"
18 | LR: 0.002
19 | MAX_EPOCH: 50
20 | LR_SCHEDULER: "cosine"
21 | WARMUP_EPOCH: 1
22 | WARMUP_TYPE: "constant"
23 | WARMUP_CONS_LR: 1e-5
24 |
25 | TRAIN:
26 | PRINT_FREQ: 5
27 |
28 | TEST:
29 | Analyze_Result_Path: './analysis_results_test/'
30 | FINAL_MODEL: "last_val"
31 | PER_CLASS_RESULT: True
32 |
33 | MODEL:
34 | BACKBONE:
35 | NAME: "RN50"
36 |
37 | TRAINER:
38 | UPLTrainer:
39 | CTX_INIT: "a photo of a"
--------------------------------------------------------------------------------
/configs/trainers/UPLTrainer/anay_rn50x4.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (288, 288)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |   # TRANSFORMS: ["center_crop", "normalize"]
15 |
16 | OPTIM:
17 | NAME: "sgd"
18 | LR: 0.002
19 | MAX_EPOCH: 50
20 | LR_SCHEDULER: "cosine"
21 | WARMUP_EPOCH: 1
22 | WARMUP_TYPE: "constant"
23 | WARMUP_CONS_LR: 1e-5
24 |
25 | TRAIN:
26 | PRINT_FREQ: 5
27 |
28 | TEST:
29 | Analyze_Result_Path: './analysis_results_test/'
30 | FINAL_MODEL: "last_val"
31 | PER_CLASS_RESULT: True
32 |
33 | MODEL:
34 | BACKBONE:
35 | NAME: "RN50x4"
36 |
37 | TRAINER:
38 | UPLTrainer:
39 | CTX_INIT: "a photo of a"
--------------------------------------------------------------------------------
/configs/trainers/UPLTrainer/anay_ViT_B_16.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |   # TRANSFORMS: ["center_crop", "normalize"]
15 |
16 | OPTIM:
17 | NAME: "sgd"
18 | LR: 0.002
19 | MAX_EPOCH: 50
20 | LR_SCHEDULER: "cosine"
21 | WARMUP_EPOCH: 1
22 | WARMUP_TYPE: "constant"
23 | WARMUP_CONS_LR: 1e-5
24 |
25 | TRAIN:
26 | PRINT_FREQ: 5
27 |
28 | TEST:
29 | Analyze_Result_Path: './analysis_results_test/'
30 | FINAL_MODEL: "last_val"
31 | PER_CLASS_RESULT: True
32 |
33 | MODEL:
34 | BACKBONE:
35 | NAME: "ViT-B/16"
36 |
37 | TRAINER:
38 | UPLTrainer:
39 | CTX_INIT: "a photo of a"
--------------------------------------------------------------------------------
/configs/trainers/UPLTrainer/anay_ViT_B_32.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |   # TRANSFORMS: ["center_crop", "normalize"]
15 |
16 | OPTIM:
17 | NAME: "sgd"
18 | LR: 0.002
19 | MAX_EPOCH: 50
20 | LR_SCHEDULER: "cosine"
21 | WARMUP_EPOCH: 1
22 | WARMUP_TYPE: "constant"
23 | WARMUP_CONS_LR: 1e-5
24 |
25 | TRAIN:
26 | PRINT_FREQ: 5
27 |
28 | TEST:
29 | Analyze_Result_Path: './analysis_results_test/'
30 | FINAL_MODEL: "last_val"
31 | PER_CLASS_RESULT: True
32 |
33 | MODEL:
34 | BACKBONE:
35 | NAME: "ViT-B/32"
36 |
37 | TRAINER:
38 | UPLTrainer:
39 | CTX_INIT: "a photo of a"
--------------------------------------------------------------------------------
/configs/trainers/UPLTrainer/anay_ViT_L_14.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |   # TRANSFORMS: ["center_crop", "normalize"]
15 |
16 | OPTIM:
17 | NAME: "sgd"
18 | LR: 0.002
19 | MAX_EPOCH: 50
20 | LR_SCHEDULER: "cosine"
21 | WARMUP_EPOCH: 1
22 | WARMUP_TYPE: "constant"
23 | WARMUP_CONS_LR: 1e-5
24 |
25 | TRAIN:
26 | PRINT_FREQ: 5
27 |
28 | TEST:
29 | Analyze_Result_Path: './analysis_results_test/'
30 | FINAL_MODEL: "last_val"
31 | PER_CLASS_RESULT: True
32 |
33 | MODEL:
34 | BACKBONE:
35 | NAME: "ViT-L/14"
36 |
37 | TRAINER:
38 | UPLTrainer:
39 | CTX_INIT: "a photo of a"
--------------------------------------------------------------------------------
/configs/trainers/UPLTrainer/anay_rn50x16.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (384, 384)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |   # TRANSFORMS: ["center_crop", "normalize"]
15 |
16 | OPTIM:
17 | NAME: "sgd"
18 | LR: 0.002
19 | MAX_EPOCH: 50
20 | LR_SCHEDULER: "cosine"
21 | WARMUP_EPOCH: 1
22 | WARMUP_TYPE: "constant"
23 | WARMUP_CONS_LR: 1e-5
24 |
25 | TRAIN:
26 | PRINT_FREQ: 5
27 |
28 | TEST:
29 | Analyze_Result_Path: './analysis_results_test/'
30 | FINAL_MODEL: "last_val"
31 | PER_CLASS_RESULT: True
32 |
33 | MODEL:
34 | BACKBONE:
35 | NAME: "RN50x16"
36 |
37 | TRAINER:
38 | UPLTrainer:
39 | CTX_INIT: "a photo of a"
--------------------------------------------------------------------------------
/configs/trainers/UPLTrainer/anay_rn50x64.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (448, 448)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |   # TRANSFORMS: ["center_crop", "normalize"]
15 |
16 | OPTIM:
17 | NAME: "sgd"
18 | LR: 0.002
19 | MAX_EPOCH: 50
20 | LR_SCHEDULER: "cosine"
21 | WARMUP_EPOCH: 1
22 | WARMUP_TYPE: "constant"
23 | WARMUP_CONS_LR: 1e-5
24 |
25 | TRAIN:
26 | PRINT_FREQ: 5
27 |
28 | TEST:
29 | Analyze_Result_Path: './analysis_results_test/'
30 | FINAL_MODEL: "last_val"
31 | PER_CLASS_RESULT: True
32 |
33 | MODEL:
34 | BACKBONE:
35 | NAME: "RN50x64"
36 |
37 | TRAINER:
38 | UPLTrainer:
39 | CTX_INIT: "a photo of a"
--------------------------------------------------------------------------------
/configs/trainers/UPLTrainer/rn50_ep50.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |   # TRANSFORMS: ["center_crop", "normalize"]
15 |
16 | OPTIM:
17 | NAME: "sgd"
18 | LR: 0.002
19 | MAX_EPOCH: 50
20 | LR_SCHEDULER: "cosine"
21 | WARMUP_EPOCH: 1
22 | WARMUP_TYPE: "constant"
23 | WARMUP_CONS_LR: 1e-5
24 |
25 | TRAIN:
26 | PRINT_FREQ: 5
27 |
28 | TEST:
29 | Analyze_Result_Path: './analysis_results_test/'
30 | FINAL_MODEL: "last_val"
31 | PER_CLASS_RESULT: True
32 |
33 |
34 | MODEL:
35 | BACKBONE:
36 | NAME: "RN50"
37 | # PSEUDO_LABEL_MODELS: ['RN50', 'RN50x4', 'RN50x16', 'RN50x64', 'RN101', 'ViT-B-16', 'ViT-B-32', 'ViT-L-14']
38 | PSEUDO_LABEL_MODELS: ['RN50']
39 |
--------------------------------------------------------------------------------
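
Example (not a file in the repository): the UPLTrainer configs above are consumed by the Dassl config system at run time, but they are plain YAML and can be inspected directly. A minimal sketch, assuming PyYAML is installed and the path is taken relative to the repository root:

    import yaml  # PyYAML

    with open("configs/trainers/UPLTrainer/rn50_ep50.yaml") as f:
        cfg = yaml.safe_load(f)

    print(cfg["MODEL"]["BACKBONE"]["NAME"])               # RN50
    print(cfg["OPTIM"]["LR"], cfg["OPTIM"]["MAX_EPOCH"])  # 0.002 50
    print(cfg["INPUT"]["SIZE"])                           # "(224, 224)" (parsed as a string)
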
/scripts/get_info.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd ..
4 |
5 | # custom config
6 | DATA=./data
7 | TRAINER=UPLTrainer
8 |
9 | DATASET=$1
10 | CFG=$2 # config file
11 | CTP=$3 # class token position (end or middle)
12 | NCTX=$4 # number of context tokens
13 | SHOTS=$5 # number of shots (1, 2, 4, 8, 16)
14 | CSC=$6 # class-specific context (False or True)
15 |
16 |
17 | for SEED in 1
18 | do
19 | DIR=./output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots_random_init/nctx${NCTX}_csc${CSC}_ctp${CTP}/seed${SEED}
20 |
21 | echo "Run this job and save the output to ${DIR}"
22 | python get_info.py \
23 | --root ${DATA} \
24 | --seed ${SEED} \
25 | --trainer ${TRAINER} \
26 | --dataset-config-file configs/datasets/${DATASET}.yaml \
27 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
28 | --output-dir ${DIR} \
29 | TRAINER.UPLTrainer.N_CTX ${NCTX} \
30 | TRAINER.UPLTrainer.CSC ${CSC} \
31 | TRAINER.UPLTrainer.CLASS_TOKEN_POSITION ${CTP} \
32 | DATASET.NUM_SHOTS ${SHOTS}
33 | done
34 |
--------------------------------------------------------------------------------
/scripts/upl_test_existing_logits.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd ..
4 |
5 | # custom config
6 | DATA=./data
7 | TRAINER=UPLTrainer
8 |
9 | DATASET=$1
10 | CFG=$2 # config file
11 | CTP=$3 # class token position (end or middle)
12 | NCTX=$4 # number of context tokens
13 | SHOTS=$5 # number of shots (1, 2, 4, 8, 16)
14 | CSC=$6 # class-specific context (False or True)
15 | CLASS_EQULE=$7 # CLASS_EQULE True or False
16 |
17 |
18 | DIR=./output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots_EQULE_${CLASS_EQULE}_CONF_THRESHOLD_${CONF_THRESHOLD}_RN50_temp/nctx${NCTX}_csc${CSC}_ctp${CTP}/seed${SEED}
19 |
20 | python upl_test.py \
21 | --root ${DATA} \
22 | --seed 1 \
23 | --trainer ${TRAINER} \
24 | --dataset-config-file configs/datasets/${DATASET}.yaml \
25 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
26 | --output-dir ${DIR} \
27 | TRAINER.UPLTrainer.N_CTX ${NCTX} \
28 | TRAINER.UPLTrainer.CSC ${CSC} \
29 | TRAINER.UPLTrainer.CLASS_TOKEN_POSITION ${CTP} \
30 | DATASET.NUM_SHOTS ${SHOTS} \
31 | DATASET.CLASS_EQULE ${CLASS_EQULE}
32 |
33 |
34 |
35 |
--------------------------------------------------------------------------------
/scripts/upl_train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd ..
4 |
5 | # custom config
6 | DATA=./data
7 | TRAINER=UPLTrainer
8 |
9 | DATASET=$1
10 | CFG=$2 # config file
11 | CTP=$3 # class token position (end or middle)
12 | NCTX=$4 # number of context tokens
13 | SHOTS=$5 # number of shots (1, 2, 4, 8, 16)
14 | CSC=$6 # class-specific context (False or True)
15 | CLASS_EQULE=$7 # CLASS_EQULE True or False
16 | TAG=$8 # log tag (multiple_models_random_init or rn50_random_init)
17 |
18 |
19 | for SEED in {1..16}
20 | do
21 | DIR=./output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots_EQULE_${CLASS_EQULE}_${CONF_THRESHOLD}_${TAG}/nctx${NCTX}_csc${CSC}_ctp${CTP}/seed${SEED}
22 | if [ -d "$DIR" ]; then
23 | echo "Results are available in ${DIR}. Skip this job"
24 | else
25 | echo "Run this job and save the output to ${DIR}"
26 | python upl_train.py \
27 | --root ${DATA} \
28 | --seed ${SEED} \
29 | --trainer ${TRAINER} \
30 | --dataset-config-file configs/datasets/${DATASET}.yaml \
31 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
32 | --output-dir ${DIR} \
33 | TRAINER.UPLTrainer.N_CTX ${NCTX} \
34 | TRAINER.UPLTrainer.CSC ${CSC} \
35 | TRAINER.UPLTrainer.CLASS_TOKEN_POSITION ${CTP} \
36 | DATASET.NUM_SHOTS ${SHOTS} \
37 | DATASET.CLASS_EQULE ${CLASS_EQULE}
38 | fi
39 | done
40 |
41 |
42 | # #!/bin/bash
43 |
44 | # cd ..
45 |
46 | # # custom config
47 | # DATA=./data
48 | # TRAINER=ZeroshotCLIP
49 | # DATASET=$1
50 | # CFG=$2 # rn50, rn101, vit_b32 or vit_b16
51 |
52 | # python sstrain.py \
53 | # --root ${DATA} \
54 | # --trainer ${TRAINER} \
55 | # --dataset-config-file configs/datasets/${DATASET}.yaml \
56 | # --config-file configs/trainers/HHTrainer/${CFG}.yaml \
57 | # --output-dir output/${TRAINER}/${CFG}/zeroshot/${DATASET} \
58 | # --eval-only
--------------------------------------------------------------------------------
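
Example (not part of the repository): upl_train.sh takes eight positional arguments, in the order documented by the comments above. A hedged sketch of launching it from Python; the concrete values are illustrative, and any dataset/config pair under configs/ can be substituted:

    import subprocess

    # Argument order, as documented in upl_train.sh:
    # DATASET CFG CTP NCTX SHOTS CSC CLASS_EQULE TAG
    subprocess.run(
        [
            "bash", "upl_train.sh",
            "ssdtd",              # $1 dataset config  -> configs/datasets/ssdtd.yaml
            "rn50_ep50",          # $2 trainer config  -> configs/trainers/UPLTrainer/rn50_ep50.yaml
            "end",                # $3 class token position (end or middle)
            "16",                 # $4 number of context tokens
            "16",                 # $5 number of shots
            "False",              # $6 class-specific context
            "True",               # $7 CLASS_EQULE
            "rn50_random_init",   # $8 log tag
        ],
        cwd="scripts",
        check=True,
    )
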
/datasets/food101.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
4 |
5 | from .oxford_pets import OxfordPets
6 | from .dtd import DescribableTextures as DTD
7 | from .datasetbase import UPLDatasetBase
8 |
9 |
10 | @DATASET_REGISTRY.register()
11 | class Food101(DatasetBase):
12 |
13 | dataset_dir = "food-101"
14 |
15 | def __init__(self, cfg):
16 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
17 | self.dataset_dir = os.path.join(root, self.dataset_dir)
18 | self.image_dir = os.path.join(self.dataset_dir, "images")
19 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_Food101.json")
20 |
21 | if os.path.exists(self.split_path):
22 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
23 | else:
24 | train, val, test = DTD.read_and_split_data(self.image_dir)
25 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir)
26 |
27 | num_shots = cfg.DATASET.NUM_SHOTS
28 | train = self.generate_fewshot_dataset(train, num_shots=num_shots)
29 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4))
30 |
31 | super().__init__(train_x=train, val=val, test=test)
32 |
33 |
34 | @DATASET_REGISTRY.register()
35 | class SSFood101(UPLDatasetBase):
36 |
37 | dataset_dir = "food-101"
38 |
39 | def __init__(self, cfg):
40 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
41 | self.dataset_dir = os.path.join(root, self.dataset_dir)
42 | self.image_dir = os.path.join(self.dataset_dir, "images")
43 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_Food101.json")
44 |
45 | if os.path.exists(self.split_path):
46 | train, val, test = self.read_split(self.split_path, self.image_dir)
47 | else:
48 | train, val, test = DTD.read_and_split_data(self.image_dir)
49 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir)
50 |
51 | sstrain = self.read_sstrain_data(self.split_path, self.image_dir)
52 | num_shots = cfg.DATASET.NUM_SHOTS
53 | train = self.generate_fewshot_dataset(train, num_shots=-1)
54 | val = self.generate_fewshot_dataset(val, num_shots=-1)
55 | super().__init__(train_x=train, val = val, test=test, sstrain=sstrain)
--------------------------------------------------------------------------------
/datasets/eurosat.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
4 |
5 | from .oxford_pets import OxfordPets
6 | from .dtd import DescribableTextures as DTD
7 | from .datasetbase import UPLDatasetBase
8 |
9 | NEW_CNAMES = {
10 | "AnnualCrop": "Annual Crop Land",
11 | "Forest": "Forest",
12 | "HerbaceousVegetation": "Herbaceous Vegetation Land",
13 | "Highway": "Highway or Road",
14 | "Industrial": "Industrial Buildings",
15 | "Pasture": "Pasture Land",
16 | "PermanentCrop": "Permanent Crop Land",
17 | "Residential": "Residential Buildings",
18 | "River": "River",
19 | "SeaLake": "Sea or Lake",
20 | }
21 |
22 |
23 | @DATASET_REGISTRY.register()
24 | class EuroSAT(DatasetBase):
25 |
26 | dataset_dir = "eurosat"
27 |
28 | def __init__(self, cfg):
29 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
30 | self.dataset_dir = os.path.join(root, self.dataset_dir)
31 | self.image_dir = os.path.join(self.dataset_dir, "2750")
32 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_EuroSAT.json")
33 |
34 | if os.path.exists(self.split_path):
35 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
36 | else:
37 | train, val, test = DTD.read_and_split_data(self.image_dir, new_cnames=NEW_CNAMES)
38 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir)
39 |
40 | num_shots = cfg.DATASET.NUM_SHOTS
41 | train = self.generate_fewshot_dataset(train, num_shots=num_shots)
42 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4))
43 |
44 | super().__init__(train_x=train, val=val, test=test)
45 |
46 | def update_classname(self, dataset_old):
47 | dataset_new = []
48 | for item_old in dataset_old:
49 | cname_old = item_old.classname
50 |             cname_new = NEW_CNAMES[cname_old]
51 | item_new = Datum(impath=item_old.impath, label=item_old.label, classname=cname_new)
52 | dataset_new.append(item_new)
53 | return dataset_new
54 |
55 | @DATASET_REGISTRY.register()
56 | class SSEuroSAT(UPLDatasetBase):
57 | dataset_dir = 'eurosat'
58 | def __init__(self, cfg):
59 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
60 | self.dataset_dir = os.path.join(root, self.dataset_dir)
61 | self.image_dir = os.path.join(self.dataset_dir, "2750")
62 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_EuroSAT.json")
63 |
64 |
65 | if os.path.exists(self.split_path):
66 | train, val, test = self.read_split(self.split_path, self.image_dir)
67 | else:
68 | train, val, test = DTD.read_and_split_data(
69 | self.image_dir,
70 |                 ignored=[],
71 | new_cnames=NEW_CNAMES
72 | )
73 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir)
74 | sstrain = self.read_sstrain_data(self.split_path, self.image_dir)
75 | num_shots = cfg.DATASET.NUM_SHOTS
76 | train = self.generate_fewshot_dataset(train, num_shots=-1)
77 |
78 | val = self.generate_fewshot_dataset(val, num_shots=-1)
79 | super().__init__(train_x=train, val = val, test=test, sstrain=sstrain)
80 |
81 |
--------------------------------------------------------------------------------
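
Illustrative sketch (not part of the repository): the NEW_CNAMES table above maps raw EuroSAT folder names to the readable class names used in prompts, and update_classname applies exactly this lookup to each Datum. The same remapping on plain strings:

    # Remap raw folder names to readable class names, as NEW_CNAMES does above.
    NEW_CNAMES = {
        "AnnualCrop": "Annual Crop Land",
        "HerbaceousVegetation": "Herbaceous Vegetation Land",
        "SeaLake": "Sea or Lake",
        # ... remaining entries as listed in eurosat.py
    }

    folders = ["AnnualCrop", "SeaLake", "Forest"]
    print([NEW_CNAMES.get(name, name) for name in folders])
    # ['Annual Crop Land', 'Sea or Lake', 'Forest']
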
/trainers/zsclip.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from dassl.engine import TRAINER_REGISTRY, TrainerX
5 | from dassl.optim import build_optimizer, build_lr_scheduler
6 |
7 | from clip import clip
8 | from clip.model import convert_weights
9 |
10 | from .coop import load_clip_to_cpu
11 |
12 | CUSTOM_TEMPLATES = {
13 | "OxfordPets": "a photo of a {}, a type of pet.",
14 | "OxfordFlowers": "a photo of a {}, a type of flower.",
15 | "FGVCAircraft": "a photo of a {}, a type of aircraft.",
16 | "DescribableTextures": "{} texture.",
17 | "EuroSAT": "a centered satellite photo of {}.",
18 | "StanfordCars": "a photo of a {}.",
19 | "Food101": "a photo of {}, a type of food.",
20 | "SUN397": "a photo of a {}.",
21 | "Caltech101": "a photo of a {}.",
22 | "UCF101": "a photo of a person doing {}.",
23 | "ImageNet": "a photo of a {}.",
24 | "ImageNetSketch": "a photo of a {}.",
25 | "ImageNetV2": "a photo of a {}.",
26 | "ImageNetA": "a photo of a {}.",
27 | "ImageNetR": "a photo of a {}.",
28 | }
29 |
30 |
31 | @TRAINER_REGISTRY.register()
32 | class ZeroshotCLIP(TrainerX):
33 | def build_model(self):
34 | cfg = self.cfg
35 | classnames = self.dm.dataset.classnames
36 |
37 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
38 | clip_model = load_clip_to_cpu(cfg)
39 | clip_model.to(self.device)
40 |
41 | temp = CUSTOM_TEMPLATES[cfg.DATASET.NAME]
42 | prompts = [temp.format(c.replace("_", " ")) for c in classnames]
43 | print(f"Prompts: {prompts}")
44 | prompts = torch.cat([clip.tokenize(p) for p in prompts])
45 | prompts = prompts.to(self.device)
46 |
47 | with torch.no_grad():
48 | text_features = clip_model.encode_text(prompts)
49 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
50 |
51 | self.text_features = text_features
52 | self.clip_model = clip_model
53 |
54 | def model_inference(self, image):
55 | image_features = self.clip_model.encode_image(image)
56 | image_features = image_features / image_features.norm(dim=-1, keepdim=True)
57 | logit_scale = self.clip_model.logit_scale.exp()
58 | logits = logit_scale * image_features @ self.text_features.t()
59 | return logits
60 |
61 |
62 | @TRAINER_REGISTRY.register()
63 | class ZeroshotCLIP2(ZeroshotCLIP):
64 | """Prompt ensembling."""
65 |
66 | # templates = IMAGENET_TEMPLATES
67 |     templates = IMAGENET_TEMPLATES_SELECT  # NOTE: not imported in this file (defined in CoOp's trainers/imagenet_templates.py)
68 |
69 | def build_model(self):
70 | cfg = self.cfg
71 | classnames = self.dm.dataset.classnames
72 |
73 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
74 | clip_model = load_clip_to_cpu(cfg)
75 | clip_model.to(self.device)
76 |
77 | for params in clip_model.parameters():
78 | params.requires_grad_(False)
79 |
80 | # add custom-made prompt
81 | if cfg.DATASET.NAME != "ImageNet":
82 | self.templates += [CUSTOM_TEMPLATES[cfg.DATASET.NAME]]
83 |
84 | num_temp = len(self.templates)
85 | print(f"Prompt ensembling (n={num_temp})")
86 |
87 | mean_text_features = 0
88 | for i, temp in enumerate(self.templates):
89 | prompts = [temp.format(c.replace("_", " ")) for c in classnames]
90 | prompts = torch.cat([clip.tokenize(p) for p in prompts]).to(self.device)
91 | text_features = clip_model.encode_text(prompts)
92 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
93 | mean_text_features = mean_text_features + text_features
94 | mean_text_features = mean_text_features / num_temp
95 | mean_text_features = mean_text_features / mean_text_features.norm(dim=-1, keepdim=True)
96 |
97 | self.text_features = mean_text_features
98 | self.clip_model = clip_model
99 |
--------------------------------------------------------------------------------
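
Example (not part of the repository): both trainers above reduce to the same scoring step: encode the per-class prompts, L2-normalize, optionally average over several templates (ZeroshotCLIP2), and take a scaled dot product with the normalized image features. A self-contained sketch of that step with random tensors standing in for CLIP's encoder outputs, so it runs without model weights:

    import torch

    num_classes, num_templates, dim = 5, 3, 512
    logit_scale = torch.tensor(100.0)  # plays the role of clip_model.logit_scale.exp()

    # Stand-ins for encode_text / encode_image outputs.
    text_features = torch.randn(num_templates, num_classes, dim)
    image_features = torch.randn(2, dim)

    # Normalize per template, average, then re-normalize (prompt ensembling).
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    mean_text_features = text_features.mean(dim=0)
    mean_text_features = mean_text_features / mean_text_features.norm(dim=-1, keepdim=True)

    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    logits = logit_scale * image_features @ mean_text_features.t()  # [2, num_classes]
    print(logits.argmax(dim=1))
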
/datasets/stanford_cars.py:
--------------------------------------------------------------------------------
1 | import os
2 | from scipy.io import loadmat
3 |
4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
5 |
6 | from .oxford_pets import OxfordPets
7 | from .datasetbase import UPLDatasetBase
8 |
9 |
10 | @DATASET_REGISTRY.register()
11 | class StanfordCars(DatasetBase):
12 |
13 | dataset_dir = "stanford_cars"
14 |
15 | def __init__(self, cfg):
16 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
17 | self.dataset_dir = os.path.join(root, self.dataset_dir)
18 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_StanfordCars.json")
19 |
20 | if os.path.exists(self.split_path):
21 | train, val, test = OxfordPets.read_split(self.split_path, self.dataset_dir)
22 | else:
23 | trainval_file = os.path.join(self.dataset_dir, "devkit", "cars_train_annos.mat")
24 | test_file = os.path.join(self.dataset_dir, "cars_test_annos_withlabels.mat")
25 | meta_file = os.path.join(self.dataset_dir, "devkit", "cars_meta.mat")
26 | trainval = self.read_data("cars_train", trainval_file, meta_file)
27 | test = self.read_data("cars_test", test_file, meta_file)
28 | train, val = OxfordPets.split_trainval(trainval)
29 | OxfordPets.save_split(train, val, test, self.split_path, self.dataset_dir)
30 |
31 | num_shots = cfg.DATASET.NUM_SHOTS
32 | train = self.generate_fewshot_dataset(train, num_shots=num_shots)
33 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4))
34 |
35 | super().__init__(train_x=train, val=val, test=test)
36 |
37 | def read_data(self, image_dir, anno_file, meta_file):
38 | anno_file = loadmat(anno_file)["annotations"][0]
39 | meta_file = loadmat(meta_file)["class_names"][0]
40 | items = []
41 |
42 | for i in range(len(anno_file)):
43 | imname = anno_file[i]["fname"][0]
44 | impath = os.path.join(self.dataset_dir, image_dir, imname)
45 | label = anno_file[i]["class"][0, 0]
46 | label = int(label) - 1 # convert to 0-based index
47 | classname = meta_file[label][0]
48 | names = classname.split(" ")
49 | year = names.pop(-1)
50 | names.insert(0, year)
51 | classname = " ".join(names)
52 | item = Datum(impath=impath, label=label, classname=classname)
53 | items.append(item)
54 |
55 | return items
56 |
57 | @DATASET_REGISTRY.register()
58 | class SSStanfordCars(UPLDatasetBase):
59 |
60 | dataset_dir = 'stanford_cars'
61 |
62 | def __init__(self, cfg):
63 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
64 | self.dataset_dir = os.path.join(root, self.dataset_dir)
65 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_StanfordCars.json")
66 | self.image_dir = self.dataset_dir
67 |
68 | if os.path.exists(self.split_path):
69 | train, val, test = self.read_split(self.split_path, self.dataset_dir)
70 | else:
71 | trainval_file = os.path.join(self.dataset_dir, "devkit", "cars_train_annos.mat")
72 | test_file = os.path.join(self.dataset_dir, "cars_test_annos_withlabels.mat")
73 | meta_file = os.path.join(self.dataset_dir, "devkit", "cars_meta.mat")
74 | trainval = self.read_data("cars_train", trainval_file, meta_file)
75 | test = self.read_data("cars_test", test_file, meta_file)
76 | train, val = OxfordPets.split_trainval(trainval)
77 | OxfordPets.save_split(train, val, test, self.split_path, self.dataset_dir)
78 |
79 | sstrain = self.read_sstrain_data(self.split_path, self.dataset_dir)
80 | num_shots = cfg.DATASET.NUM_SHOTS
81 | train = self.generate_fewshot_dataset(train, num_shots=-1)
82 | val = self.generate_fewshot_dataset(val, num_shots=-1)
83 | super().__init__(train_x=train, val = val, test=test, sstrain=sstrain)
--------------------------------------------------------------------------------
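
Example (not part of the repository): read_data above rewrites each Stanford Cars class name so that the model year comes first (the devkit stores names with the year last, e.g. "Audi TT Coupe 2012"). The reordering in isolation:

    # Move the trailing model year to the front, as read_data does above.
    classname = "Audi TT Coupe 2012"
    names = classname.split(" ")
    year = names.pop(-1)
    names.insert(0, year)
    print(" ".join(names))  # 2012 Audi TT Coupe
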
/evaluation/evaluator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os.path as osp
3 | from collections import OrderedDict, defaultdict
4 | import torch
5 | from sklearn.metrics import f1_score, confusion_matrix
6 |
7 | from dassl.evaluation.evaluator import Classification
8 | from dassl.evaluation.build import EVALUATOR_REGISTRY
9 |
10 | @EVALUATOR_REGISTRY.register()
11 | class UPLClassification(Classification):
12 |
13 | def process(self, mo, gt, per_image_txt_writer, per_class_txt_writer):
14 | # mo (torch.Tensor): model output [batch, num_classes]
15 | # gt (torch.LongTensor): ground truth [batch]
16 | self.per_image_txt_writer = per_image_txt_writer
17 | self.per_class_txt_writer = per_class_txt_writer
18 | pred = mo.max(1)[1]
19 | matches = pred.eq(gt).float()
20 | self._correct += int(matches.sum().item())
21 | self._total += gt.shape[0]
22 |
23 | self._y_true.extend(gt.data.cpu().numpy().tolist())
24 | self._y_pred.extend(pred.data.cpu().numpy().tolist())
25 |
26 | if self._per_class_res is not None:
27 | for i, (label, pred) in enumerate(zip(gt, mo)):
28 | label = label.item()
29 | label_name = self._lab2cname[label]
30 | matches_i = int(matches[i].item())
31 | self._per_class_res[label].append(matches_i)
32 | # print(i, label_name, label, matches_i, pred.data.cpu().numpy().tolist())
33 | write_line = str(i) + "," + str(label_name) + "," + str(label) + "," + str(matches_i) + "," + str(pred.data.cpu().numpy().tolist())
34 | self.per_image_txt_writer.write(write_line+'\n')
35 |
36 | def evaluate(self):
37 | results = OrderedDict()
38 | acc = 100.0 * self._correct / self._total
39 | err = 100.0 - acc
40 | macro_f1 = 100.0 * f1_score(
41 | self._y_true,
42 | self._y_pred,
43 | average="macro",
44 | labels=np.unique(self._y_true)
45 | )
46 |
47 | # The first value will be returned by trainer.test()
48 | results["accuracy"] = acc
49 | results["error_rate"] = err
50 | results["macro_f1"] = macro_f1
51 |
52 | print(
53 | "=> result\n"
54 | f"* total: {self._total:,}\n"
55 | f"* correct: {self._correct:,}\n"
56 | f"* accuracy: {acc:.2f}%\n"
57 | f"* error: {err:.2f}%\n"
58 | f"* macro_f1: {macro_f1:.2f}%"
59 | )
60 |
61 | if self._per_class_res is not None:
62 | labels = list(self._per_class_res.keys())
63 | labels.sort()
64 |
65 | print("=> per-class result")
66 | accs = []
67 |
68 | for label in self._lab2cname:
69 | classname = self._lab2cname[label]
70 | res = self._per_class_res[label]
71 | correct = sum(res)
72 | total = len(res)
73 | acc = 100.0 * correct / total
74 | accs.append(acc)
75 | print(
76 | "* class: {} ({})\t"
77 | "total: {:,}\t"
78 | "correct: {:,}\t"
79 | "acc: {:.2f}%".format(
80 | label, classname, total, correct, acc
81 | )
82 | )
83 | write_line = "* class: {} ({}), total: {:,}, correct: {:,}, acc: {:.2f}% \n".format(label, classname, total, correct, acc)
84 | self.per_class_txt_writer.write(write_line)
85 | mean_acc = np.mean(accs)
86 | print("* average: {:.2f}%".format(mean_acc))
87 |
88 | results["perclass_accuracy"] = mean_acc
89 |
90 | if self.cfg.TEST.COMPUTE_CMAT:
91 | cmat = confusion_matrix(
92 | self._y_true, self._y_pred, normalize="true"
93 | )
94 | save_path = osp.join(self.cfg.OUTPUT_DIR, "cmat.pt")
95 | torch.save(cmat, save_path)
96 | print('Confusion matrix is saved to "{}"'.format(save_path))
97 |
98 | return results
99 |
--------------------------------------------------------------------------------
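
Example (not part of the repository): evaluate() above reports accuracy, error rate and macro-F1 (computed only over the labels that occur in the ground truth), plus optional per-class accuracy. The same quantities on a toy prediction list, assuming NumPy and scikit-learn are available as in evaluator.py:

    import numpy as np
    from sklearn.metrics import f1_score

    y_true = [0, 0, 1, 1, 2, 2]
    y_pred = [0, 1, 1, 1, 2, 0]

    correct = sum(int(p == t) for p, t in zip(y_pred, y_true))
    acc = 100.0 * correct / len(y_true)
    err = 100.0 - acc
    macro_f1 = 100.0 * f1_score(y_true, y_pred, average="macro", labels=np.unique(y_true))
    print(f"accuracy: {acc:.2f}%  error: {err:.2f}%  macro_f1: {macro_f1:.2f}%")

    # Per-class accuracy, as in the PER_CLASS_RESULT branch above.
    for label in sorted(set(y_true)):
        matches = [int(p == t) for p, t in zip(y_pred, y_true) if t == label]
        print(f"class {label}: {100.0 * sum(matches) / len(matches):.2f}%")
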
/datasets/sun397.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
4 |
5 | from .oxford_pets import OxfordPets
6 | from .datasetbase import UPLDatasetBase
7 |
8 |
9 | @DATASET_REGISTRY.register()
10 | class SUN397(DatasetBase):
11 |
12 | dataset_dir = "sun397"
13 |
14 | def __init__(self, cfg):
15 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
16 | self.dataset_dir = os.path.join(root, self.dataset_dir)
17 | self.image_dir = os.path.join(self.dataset_dir, "SUN397")
18 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_SUN397.json")
19 |
20 | if os.path.exists(self.split_path):
21 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
22 | else:
23 | classnames = []
24 | with open(os.path.join(self.dataset_dir, "ClassName.txt"), "r") as f:
25 | lines = f.readlines()
26 | for line in lines:
27 | line = line.strip()[1:] # remove /
28 | classnames.append(line)
29 | cname2lab = {c: i for i, c in enumerate(classnames)}
30 | trainval = self.read_data(cname2lab, "Training_01.txt")
31 | test = self.read_data(cname2lab, "Testing_01.txt")
32 | train, val = OxfordPets.split_trainval(trainval)
33 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir)
34 |
35 | num_shots = cfg.DATASET.NUM_SHOTS
36 | train = self.generate_fewshot_dataset(train, num_shots=num_shots)
37 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4))
38 |
39 | super().__init__(train_x=train, val=val, test=test)
40 |
41 | def read_data(self, cname2lab, text_file):
42 | text_file = os.path.join(self.dataset_dir, text_file)
43 | items = []
44 |
45 | with open(text_file, "r") as f:
46 | lines = f.readlines()
47 | for line in lines:
48 | imname = line.strip()[1:] # remove /
49 | classname = os.path.dirname(imname)
50 | label = cname2lab[classname]
51 | impath = os.path.join(self.image_dir, imname)
52 |
53 |                 names = classname.split("/")[1:]  # drop the leading single-letter directory
54 |                 names = names[::-1]  # put words like indoor/outdoor first
55 | classname = " ".join(names)
56 |
57 | item = Datum(impath=impath, label=label, classname=classname)
58 | items.append(item)
59 |
60 | return items
61 |
62 |
63 |
64 | @DATASET_REGISTRY.register()
65 | class SSSUN397(UPLDatasetBase):
66 |
67 | dataset_dir = "sun397"
68 |
69 | def __init__(self, cfg):
70 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
71 | self.dataset_dir = os.path.join(root, self.dataset_dir)
72 | self.image_dir = os.path.join(self.dataset_dir, "SUN397")
73 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_SUN397.json")
74 |
75 | if os.path.exists(self.split_path):
76 | train, val, test = self.read_split(self.split_path, self.image_dir)
77 | else:
78 | classnames = []
79 | with open(os.path.join(self.dataset_dir, "ClassName.txt"), "r") as f:
80 | lines = f.readlines()
81 | for line in lines:
82 | line = line.strip()[1:] # remove /
83 | classnames.append(line)
84 | cname2lab = {c: i for i, c in enumerate(classnames)}
85 | trainval = self.read_data(cname2lab, "Training_01.txt")
86 | test = self.read_data(cname2lab, "Testing_01.txt")
87 | train, val = OxfordPets.split_trainval(trainval)
88 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir)
89 |
90 | sstrain = self.read_sstrain_data(self.split_path, self.image_dir)
91 | num_shots = cfg.DATASET.NUM_SHOTS
92 | train = self.generate_fewshot_dataset(train, num_shots=-1)
93 | val = self.generate_fewshot_dataset(val, num_shots=-1)
94 | super().__init__(train_x=train, val = val, test=test, sstrain=sstrain)
--------------------------------------------------------------------------------
/datasets/oxford_flowers.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | from scipy.io import loadmat
4 | from collections import defaultdict
5 |
6 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
7 | from dassl.utils import read_json
8 |
9 | from .oxford_pets import OxfordPets
10 | from .datasetbase import UPLDatasetBase
11 |
12 |
13 | @DATASET_REGISTRY.register()
14 | class OxfordFlowers(DatasetBase):
15 |
16 | dataset_dir = "oxford_flowers"
17 |
18 | def __init__(self, cfg):
19 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
20 | self.dataset_dir = os.path.join(root, self.dataset_dir)
21 | self.image_dir = os.path.join(self.dataset_dir, "jpg")
22 | self.label_file = os.path.join(self.dataset_dir, "imagelabels.mat")
23 | self.lab2cname_file = os.path.join(self.dataset_dir, "cat_to_name.json")
24 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_OxfordFlowers.json")
25 |
26 | if os.path.exists(self.split_path):
27 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
28 | else:
29 | train, val, test = self.read_data()
30 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir)
31 |
32 | num_shots = cfg.DATASET.NUM_SHOTS
33 | train = self.generate_fewshot_dataset(train, num_shots=num_shots)
34 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4))
35 |
36 | super().__init__(train_x=train, val=val, test=test)
37 |
38 | def read_data(self):
39 | tracker = defaultdict(list)
40 | label_file = loadmat(self.label_file)["labels"][0]
41 | for i, label in enumerate(label_file):
42 | imname = f"image_{str(i + 1).zfill(5)}.jpg"
43 | impath = os.path.join(self.image_dir, imname)
44 | label = int(label)
45 | tracker[label].append(impath)
46 |
47 | print("Splitting data into 50% train, 20% val, and 30% test")
48 |
49 | def _collate(ims, y, c):
50 | items = []
51 | for im in ims:
52 | item = Datum(impath=im, label=y - 1, classname=c) # convert to 0-based label
53 | items.append(item)
54 | return items
55 |
56 | lab2cname = read_json(self.lab2cname_file)
57 | train, val, test = [], [], []
58 | for label, impaths in tracker.items():
59 | random.shuffle(impaths)
60 | n_total = len(impaths)
61 | n_train = round(n_total * 0.5)
62 | n_val = round(n_total * 0.2)
63 | n_test = n_total - n_train - n_val
64 | assert n_train > 0 and n_val > 0 and n_test > 0
65 | cname = lab2cname[str(label)]
66 | train.extend(_collate(impaths[:n_train], label, cname))
67 | val.extend(_collate(impaths[n_train : n_train + n_val], label, cname))
68 | test.extend(_collate(impaths[n_train + n_val :], label, cname))
69 |
70 | return train, val, test
71 |
72 |
73 | @DATASET_REGISTRY.register()
74 | class SSOxfordFlowers(UPLDatasetBase):
75 | dataset_dir = "oxford_flowers"
76 | def __init__(self, cfg):
77 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
78 | self.dataset_dir = os.path.join(root, self.dataset_dir)
79 | self.image_dir = os.path.join(self.dataset_dir, "jpg")
80 | self.label_file = os.path.join(self.dataset_dir, "imagelabels.mat")
81 | self.lab2cname_file = os.path.join(self.dataset_dir, "cat_to_name.json")
82 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_OxfordFlowers.json")
83 |
84 |
85 |
86 | if os.path.exists(self.split_path):
87 | train, val, test = self.read_split(self.split_path, self.image_dir)
88 | else:
89 | train, val, test = self.read_data()
90 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir)
91 |
92 | sstrain = self.read_sstrain_data(self.split_path, self.image_dir)
93 | num_shots = cfg.DATASET.NUM_SHOTS
94 | train = self.generate_fewshot_dataset(train, num_shots=-1)
95 | val = self.generate_fewshot_dataset(val, num_shots=-1)
96 | super().__init__(train_x=train, val = val, test=test, sstrain=sstrain)
--------------------------------------------------------------------------------
/datasets/dtd.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 |
4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
5 | from dassl.utils import listdir_nohidden
6 |
7 | from .oxford_pets import OxfordPets
8 | from .datasetbase import UPLDatasetBase
9 |
10 |
11 | @DATASET_REGISTRY.register()
12 | class DescribableTextures(DatasetBase):
13 |
14 | dataset_dir = "dtd"
15 |
16 | def __init__(self, cfg):
17 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
18 | self.dataset_dir = os.path.join(root, self.dataset_dir)
19 | self.image_dir = os.path.join(self.dataset_dir, "images")
20 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_DescribableTextures.json")
21 |
22 | if os.path.exists(self.split_path):
23 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
24 | else:
25 | train, val, test = self.read_and_split_data(self.image_dir)
26 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir)
27 |
28 | num_shots = cfg.DATASET.NUM_SHOTS
29 | train = self.generate_fewshot_dataset(train, num_shots=num_shots)
30 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4))
31 |
32 | super().__init__(train_x=train, val=val, test=test)
33 |
34 | @staticmethod
35 | def read_and_split_data(image_dir, p_trn=0.5, p_val=0.2, ignored=[], new_cnames=None):
36 | # The data are supposed to be organized into the following structure
37 | # =============
38 | # images/
39 | # dog/
40 | # cat/
41 | # horse/
42 | # =============
43 | categories = listdir_nohidden(image_dir)
44 | categories = [c for c in categories if c not in ignored]
45 | categories.sort()
46 |
47 | p_tst = 1 - p_trn - p_val
48 | print(f"Splitting into {p_trn:.0%} train, {p_val:.0%} val, and {p_tst:.0%} test")
49 |
50 | def _collate(ims, y, c):
51 | items = []
52 | for im in ims:
53 | item = Datum(impath=im, label=y, classname=c) # is already 0-based
54 | items.append(item)
55 | return items
56 |
57 | train, val, test = [], [], []
58 | for label, category in enumerate(categories):
59 | category_dir = os.path.join(image_dir, category)
60 | images = listdir_nohidden(category_dir)
61 | images = [os.path.join(category_dir, im) for im in images]
62 | random.shuffle(images)
63 | n_total = len(images)
64 | n_train = round(n_total * p_trn)
65 | n_val = round(n_total * p_val)
66 | n_test = n_total - n_train - n_val
67 | assert n_train > 0 and n_val > 0 and n_test > 0
68 |
69 | if new_cnames is not None and category in new_cnames:
70 | category = new_cnames[category]
71 |
72 | train.extend(_collate(images[:n_train], label, category))
73 | val.extend(_collate(images[n_train : n_train + n_val], label, category))
74 | test.extend(_collate(images[n_train + n_val :], label, category))
75 |
76 | return train, val, test
77 |
78 | @DATASET_REGISTRY.register()
79 | class SSDescribableTextures(UPLDatasetBase):
80 |
81 | dataset_dir = "dtd"
82 |
83 | def __init__(self, cfg):
84 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
85 | self.dataset_dir = os.path.join(root, self.dataset_dir)
86 | self.image_dir = os.path.join(self.dataset_dir, "images")
87 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_DescribableTextures.json")
88 |
89 | if os.path.exists(self.split_path):
90 | train, val, test = self.read_split(self.split_path, self.image_dir)
91 | else:
92 | train, val, test = self.read_and_split_data(self.image_dir)
93 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir)
94 | sstrain = self.read_sstrain_data(self.split_path, self.image_dir)
95 | num_shots = cfg.DATASET.NUM_SHOTS
96 | train = self.generate_fewshot_dataset(train, num_shots=-1)
97 | val = self.generate_fewshot_dataset(val, num_shots=-1)
98 | super().__init__(train_x=train, val = val, test=test, sstrain=sstrain)
--------------------------------------------------------------------------------
/datasets/ucf101.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 |
4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
5 |
6 | from .oxford_pets import OxfordPets
7 | from .datasetbase import UPLDatasetBase
8 |
9 |
10 | @DATASET_REGISTRY.register()
11 | class UCF101(DatasetBase):
12 |
13 | dataset_dir = "ucf101"
14 |
15 | def __init__(self, cfg):
16 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
17 | self.dataset_dir = os.path.join(root, self.dataset_dir)
18 | self.image_dir = os.path.join(self.dataset_dir, "UCF-101-midframes")
19 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_UCF101.json")
20 |
21 | if os.path.exists(self.split_path):
22 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
23 | else:
24 | cname2lab = {}
25 | filepath = os.path.join(self.dataset_dir, "ucfTrainTestlist/classInd.txt")
26 | with open(filepath, "r") as f:
27 | lines = f.readlines()
28 | for line in lines:
29 | label, classname = line.strip().split(" ")
30 |                     label = int(label) - 1 # convert to 0-based index
31 | cname2lab[classname] = label
32 |
33 | trainval = self.read_data(cname2lab, "ucfTrainTestlist/trainlist01.txt")
34 | test = self.read_data(cname2lab, "ucfTrainTestlist/testlist01.txt")
35 | train, val = OxfordPets.split_trainval(trainval)
36 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir)
37 |
38 | num_shots = cfg.DATASET.NUM_SHOTS
39 | train = self.generate_fewshot_dataset(train, num_shots=num_shots)
40 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4))
41 |
42 | super().__init__(train_x=train, val=val, test=test)
43 |
44 | def read_data(self, cname2lab, text_file):
45 | text_file = os.path.join(self.dataset_dir, text_file)
46 | items = []
47 |
48 | with open(text_file, "r") as f:
49 | lines = f.readlines()
50 | for line in lines:
51 | line = line.strip().split(" ")[0] # trainlist: filename, label
52 | action, filename = line.split("/")
53 | label = cname2lab[action]
54 |
55 | elements = re.findall("[A-Z][^A-Z]*", action)
56 | renamed_action = "_".join(elements)
57 |
58 | filename = filename.replace(".avi", ".jpg")
59 | impath = os.path.join(self.image_dir, renamed_action, filename)
60 |
61 | item = Datum(impath=impath, label=label, classname=renamed_action)
62 | items.append(item)
63 |
64 | return items
65 |
66 | @DATASET_REGISTRY.register()
67 | class SSUCF101(UPLDatasetBase):
68 |
69 | dataset_dir = "ucf101"
70 |
71 | def __init__(self, cfg):
72 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
73 | self.dataset_dir = os.path.join(root, self.dataset_dir)
74 | self.image_dir = os.path.join(self.dataset_dir, "UCF-101-midframes")
75 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_UCF101.json")
76 |
77 | if os.path.exists(self.split_path):
78 | train, val, test = self.read_split(self.split_path, self.image_dir)
79 | else:
80 | cname2lab = {}
81 | filepath = os.path.join(self.dataset_dir, "ucfTrainTestlist/classInd.txt")
82 | with open(filepath, "r") as f:
83 | lines = f.readlines()
84 | for line in lines:
85 | label, classname = line.strip().split(" ")
86 |                     label = int(label) - 1 # convert to 0-based index
87 | cname2lab[classname] = label
88 |
89 | trainval = self.read_data(cname2lab, "ucfTrainTestlist/trainlist01.txt")
90 | test = self.read_data(cname2lab, "ucfTrainTestlist/testlist01.txt")
91 | train, val = OxfordPets.split_trainval(trainval)
92 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir)
93 |
94 | sstrain = self.read_sstrain_data(self.split_path, self.image_dir)
95 | num_shots = cfg.DATASET.NUM_SHOTS
96 | train = self.generate_fewshot_dataset(train, num_shots=-1)
97 | val = self.generate_fewshot_dataset(val, num_shots=-1)
98 | super().__init__(train_x=train, val = val, test=test, sstrain=sstrain)
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 |
2 | # Unsupervised Prompt Learning for Vision-Language Models
3 |
4 |
5 | ## Introduction
6 |
7 | This repo is the official implementation of **Unsupervised Prompt Learning for Vision-Language Models**. UPL is the first framework to introduce unsupervised learning into prompt learning for vision-language models, avoiding time-consuming prompt engineering and better adapting vision-language models to downstream image recognition tasks.
8 |
9 | Contact us at tonyhuang_pku@outlook.com or fawe@microsoft.com.
10 |
11 |
12 |
13 |
14 | ![Overview of the UPL framework](figures/overview.png)
15 | Fig.1 Overview of Unsupervised Prompt Learning (UPL) Framework.
16 |
17 |
18 | ## Install
19 |
20 | The code is built on the [CoOp](https://github.com/KaiyangZhou/CoOp) and [Dassl](https://github.com/KaiyangZhou/Dassl.pytorch) with commit `ac6e44194b2f90e325f477aadd6d9bc3a92ce255`, so you need to install the dassl environment first. You can follow the [instructions](https://github.com/KaiyangZhou/Dassl.pytorch#installation) to install *dassl* as well as *PyTorch*. After that, run `pip install -r requirements.txt` under `UPL/` to install a few more packages required by [CLIP](https://github.com/openai/CLIP).
21 |
22 | **We also prepare all installation commands for you**:
23 |
24 | ```bash
25 |
26 | ############ install Dassl ############
27 |
28 | # Clone this repo
29 | git clone https://github.com/KaiyangZhou/Dassl.pytorch.git
30 | cd Dassl.pytorch/
31 | git reset --hard ac6e44194b2f90e325f477aadd6d9bc3a92ce255
32 |
33 | # Create a conda environment
34 | conda create -n dassl python=3.7
35 |
36 | # Activate the environment
37 | conda activate dassl
38 |
39 | # Install dependencies
40 | pip install -r requirements.txt
41 |
42 | # Install torch (version >= 1.7.1) and torchvision
43 | conda install pytorch==1.8.1 torchvision==0.9.1 cudatoolkit=10.1 -c pytorch
44 |
45 | # Install this library (no need to re-build if the source code is modified)
46 | python setup.py develop
47 |
48 | ############ install UPL ############
49 |
50 | # Enter the directory at the same level as Dassl
51 | cd ..
52 |
53 | # Clone this repo
54 | git clone https://github.com/tonyhuang2022/UPL.git
55 | cd UPL/
56 |
57 | # Install CLIP dependencies
58 | pip install -r requirements.txt
59 |
60 | ######## attention ########
61 | # We have two soft links, and you can redirect them!
62 | # The `data` link points to the datasets, and `temp_analyze_results_miltiple` points to the directory storing the analysis info (e.g., logits).
63 | # We strongly recommend creating these two directories on a disk with enough space, and then running:
64 |
65 | rm data temp_analyze_results_miltiple # remove the existing links
66 | ln -s ${your_data_path} ./data
67 | ln -s ${your_temp_analyze_results_miltiple_path} ./temp_analyze_results_miltiple
68 |
69 | # Finished
70 | ```
71 |
72 |
73 | ## Datasets
74 |
75 | After that, you can follow the [CoOp Datasets Instructions](https://github.com/KaiyangZhou/CoOp/blob/main/DATASETS.md) to prepare the datasets.
76 |
77 | Then you are ready to run the code.
78 |
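For orientation, the dataset classes under `datasets/` resolve their files relative to `DATASET.ROOT` (the `data` soft link). A rough sketch of the expected layout, with directory and split-file names taken from the `dataset_dir`/`image_dir`/`split_path` fields in this repo (only a few datasets shown; see the CoOp instructions above for the complete preparation steps):

```
data/
├── caltech-101/
│   ├── 101_ObjectCategories/
│   └── split_zhou_Caltech101.json
├── oxford_pets/
│   ├── images/
│   ├── annotations/
│   └── split_zhou_OxfordPets.json
├── dtd/
│   ├── images/
│   └── split_zhou_DescribableTextures.json
└── ucf101/
    ├── UCF-101-midframes/
    ├── ucfTrainTestlist/
    └── split_zhou_UCF101.json
```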
79 | ## Training
80 |
81 | Note that UPL needs pseudo labels for the training sets, so you must run the `get info` step under `UPL/scripts` before training to produce the required files (e.g., logits).
82 |
83 | ### get info of RN50
84 |
85 |
86 | ```bash
87 | CUDA_VISIBLE_DEVICES=0 bash get_info.sh sscaltech101 anay_rn50 end 16 -1 False
88 | ```
89 |
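Under the hood, `scripts/get_info.sh` wraps `get_info.py`, which builds a trainer and calls `zero_shot_analyze()` to dump the CLIP zero-shot predictions used as pseudo-label info. A minimal sketch of a roughly equivalent direct call, assuming the default directory layout (the output directory and exact argument values here are illustrative, not taken from the script):

```bash
python get_info.py \
    --root ./data \
    --trainer UPLTrainer \
    --dataset-config-file configs/datasets/sscaltech101.yaml \
    --config-file configs/trainers/UPLTrainer/anay_rn50.yaml \
    --output-dir ./output/get_info_sscaltech101_rn50
```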
90 | You can change the config file to get the info from other models (e.g., `anay_rn101`, `anay_rn50x4`). We also provide a script that gets the info from all released CLIP models.
91 |
92 | ```bash
93 | CUDA_VISIBLE_DEVICES=0 bash get_all_info.sh sscaltech101
94 | ```
95 |
96 |
97 |
98 | ### UPL train
99 |
100 | After the `get info` step, we can train the prompts (16 seeds are run by default; you can change this in `UPL/configs/trainers/UPLTrainer/rn50_ep50.yaml`):
101 |
102 | ```bash
103 | CUDA_VISIBLE_DEVICES=0 bash upl_train.sh sscaltech101 rn50_ep50 end 16 16 False True rn50_random_init
104 | ```
105 |
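`scripts/upl_train.sh` similarly wraps `upl_train.py`, which builds `cfg.TRAINER.ENSEMBLE_NUM` trainers, loads pseudo labels from `./analyze_results` via `load_from_exist_file`, and runs `sstrain_with_id` for each prompt. A minimal sketch of a roughly equivalent direct call (argument values are illustrative, not taken from the script):

```bash
python upl_train.py \
    --root ./data \
    --seed 1 \
    --trainer UPLTrainer \
    --dataset-config-file configs/datasets/sscaltech101.yaml \
    --config-file configs/trainers/UPLTrainer/rn50_ep50.yaml \
    --output-dir ./output/upl_train_sscaltech101_rn50_ep50
```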
106 | If you want to use UPL* (pseudo labels from multiple models), please change `PSEUDO_LABEL_MODELS` in `UPL/configs/trainers/UPLTrainer/rn50_ep50.yaml`. Please ensure that you have obtained info from all released models. Then, you can run
107 |
108 | ```bash
109 | CUDA_VISIBLE_DEVICES=0 bash upl_train.sh sscaltech101 rn50_ep50 end 16 16 False True multiple_models_random_init
110 | ```
111 |
112 |
113 | ## Testing
114 |
115 | ### Test with existing files after UPL training
116 |
117 | ```bash
118 | bash upl_test_existing_logits.sh sscaltech101 rn50_ep50 end 16 16 False True
119 | ```
120 |
121 | ## Thanks
122 |
123 | We use code from [CoOp](https://github.com/KaiyangZhou/CoOp) and [Dassl](https://github.com/KaiyangZhou/Dassl.pytorch), which are great repositories and we encourage you to check them out and cite them in your work.
--------------------------------------------------------------------------------
/clip/simple_tokenizer.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import html
3 | import os
4 | from functools import lru_cache
5 |
6 | import ftfy
7 | import regex as re
8 |
9 |
10 | @lru_cache()
11 | def default_bpe():
12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13 |
14 |
15 | @lru_cache()
16 | def bytes_to_unicode():
17 | """
18 | Returns list of utf-8 byte and a corresponding list of unicode strings.
19 | The reversible bpe codes work on unicode strings.
20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22 |     This is a significant percentage of your normal, say, 32K bpe vocab.
23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24 | And avoids mapping to whitespace/control characters the bpe code barfs on.
25 | """
26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27 | cs = bs[:]
28 | n = 0
29 | for b in range(2**8):
30 | if b not in bs:
31 | bs.append(b)
32 | cs.append(2**8+n)
33 | n += 1
34 | cs = [chr(n) for n in cs]
35 | return dict(zip(bs, cs))
36 |
37 |
38 | def get_pairs(word):
39 | """Return set of symbol pairs in a word.
40 | Word is represented as tuple of symbols (symbols being variable-length strings).
41 | """
42 | pairs = set()
43 | prev_char = word[0]
44 | for char in word[1:]:
45 | pairs.add((prev_char, char))
46 | prev_char = char
47 | return pairs
48 |
49 |
50 | def basic_clean(text):
51 | text = ftfy.fix_text(text)
52 | text = html.unescape(html.unescape(text))
53 | return text.strip()
54 |
55 |
56 | def whitespace_clean(text):
57 | text = re.sub(r'\s+', ' ', text)
58 | text = text.strip()
59 | return text
60 |
61 |
62 | class SimpleTokenizer(object):
63 | def __init__(self, bpe_path: str = default_bpe()):
64 | self.byte_encoder = bytes_to_unicode()
65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67 | merges = merges[1:49152-256-2+1]
68 | merges = [tuple(merge.split()) for merge in merges]
69 | vocab = list(bytes_to_unicode().values())
70 |         vocab = vocab + [v+'</w>' for v in vocab]
71 | for merge in merges:
72 | vocab.append(''.join(merge))
73 | vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74 | self.encoder = dict(zip(vocab, range(len(vocab))))
75 | self.decoder = {v: k for k, v in self.encoder.items()}
76 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79 |
80 | def bpe(self, token):
81 | if token in self.cache:
82 | return self.cache[token]
83 |         word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84 | pairs = get_pairs(word)
85 |
86 | if not pairs:
87 |             return token+'</w>'
88 |
89 | while True:
90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91 | if bigram not in self.bpe_ranks:
92 | break
93 | first, second = bigram
94 | new_word = []
95 | i = 0
96 | while i < len(word):
97 | try:
98 | j = word.index(first, i)
99 | new_word.extend(word[i:j])
100 | i = j
101 | except:
102 | new_word.extend(word[i:])
103 | break
104 |
105 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
106 | new_word.append(first+second)
107 | i += 2
108 | else:
109 | new_word.append(word[i])
110 | i += 1
111 | new_word = tuple(new_word)
112 | word = new_word
113 | if len(word) == 1:
114 | break
115 | else:
116 | pairs = get_pairs(word)
117 | word = ' '.join(word)
118 | self.cache[token] = word
119 | return word
120 |
121 | def encode(self, text):
122 | bpe_tokens = []
123 | text = whitespace_clean(basic_clean(text)).lower()
124 | for token in re.findall(self.pat, text):
125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127 | return bpe_tokens
128 |
129 | def decode(self, tokens):
130 | text = ''.join([self.decoder[token] for token in tokens])
131 |         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132 | return text
133 |
--------------------------------------------------------------------------------
/datasets/fgvc_aircraft.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
4 |
5 | from .oxford_pets import OxfordPets
6 | from .datasetbase import UPLDatasetBase
7 |
8 |
9 | @DATASET_REGISTRY.register()
10 | class FGVCAircraft(DatasetBase):
11 |
12 | dataset_dir = "fgvc_aircraft"
13 |
14 | def __init__(self, cfg):
15 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
16 | self.dataset_dir = os.path.join(root, self.dataset_dir)
17 | self.image_dir = os.path.join(self.dataset_dir, "images")
18 |
19 | classnames = []
20 | with open(os.path.join(self.dataset_dir, "variants.txt"), "r") as f:
21 | lines = f.readlines()
22 | for line in lines:
23 | classnames.append(line.strip())
24 | cname2lab = {c: i for i, c in enumerate(classnames)}
25 |
26 | train = self.read_data(cname2lab, "images_variant_train.txt")
27 | val = self.read_data(cname2lab, "images_variant_val.txt")
28 | test = self.read_data(cname2lab, "images_variant_test.txt")
29 |
30 | num_shots = cfg.DATASET.NUM_SHOTS
31 | train = self.generate_fewshot_dataset(train, num_shots=num_shots)
32 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4))
33 |
34 | super().__init__(train_x=train, val=val, test=test)
35 |
36 | def read_data(self, cname2lab, split_file):
37 | filepath = os.path.join(self.dataset_dir, split_file)
38 | items = []
39 |
40 | with open(filepath, "r") as f:
41 | lines = f.readlines()
42 | for line in lines:
43 | line = line.strip().split(" ")
44 | imname = line[0] + ".jpg"
45 | classname = " ".join(line[1:])
46 | impath = os.path.join(self.image_dir, imname)
47 | label = cname2lab[classname]
48 | item = Datum(impath=impath, label=label, classname=classname)
49 | items.append(item)
50 |
51 | return items
52 |
53 |
54 | @DATASET_REGISTRY.register()
55 | class SSFGVCAircraft(UPLDatasetBase):
56 | dataset_dir = 'fgvc_aircraft'
57 | def __init__(self, cfg):
58 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
59 | self.dataset_dir = os.path.join(root, self.dataset_dir)
60 | self.image_dir = os.path.join(self.dataset_dir, "images")
61 |
62 | classnames = []
63 | with open(os.path.join(self.dataset_dir, "variants.txt"), "r") as f:
64 | lines = f.readlines()
65 | for line in lines:
66 | classnames.append(line.strip())
67 | cname2lab = {c: i for i, c in enumerate(classnames)}
68 | self.cname2lab = cname2lab
69 |
70 | train = self.read_data(cname2lab, "images_variant_train.txt")
71 | val = self.read_data(cname2lab, "images_variant_val.txt")
72 | test = self.read_data(cname2lab, "images_variant_test.txt")
73 |
74 | sstrain = self.read_data_without_label(cname2lab, "images_variant_train.txt")
75 | num_shots = cfg.DATASET.NUM_SHOTS
76 | train = self.generate_fewshot_dataset(train, num_shots=-1)
77 | val = self.generate_fewshot_dataset(val, num_shots=-1)
78 | super().__init__(train_x=train, val=val, test=test, sstrain=sstrain)
79 |
80 | def read_data(self, cname2lab, split_file):
81 | filepath = os.path.join(self.dataset_dir, split_file)
82 | items = []
83 |
84 | with open(filepath, "r") as f:
85 | lines = f.readlines()
86 | for line in lines:
87 | line = line.strip().split(" ")
88 | imname = line[0] + ".jpg"
89 | classname = " ".join(line[1:])
90 | impath = os.path.join(self.image_dir, imname)
91 | label = cname2lab[classname]
92 | item = Datum(impath=impath, label=label, classname=classname)
93 | items.append(item)
94 |
95 | return items
96 |
97 | def read_data_without_label(self, cname2lab, split_file, predict_label_dict=None):
98 | filepath = os.path.join(self.dataset_dir, split_file)
99 | items = []
100 | if predict_label_dict is None:
101 | with open(filepath, "r") as f:
102 | lines = f.readlines()
103 | for line in lines:
104 | line = line.strip().split(" ")
105 | imname = line[0] + ".jpg"
106 | classname = " ".join(line[1:])
107 | impath = os.path.join(self.image_dir, imname)
108 | label = cname2lab[classname]
109 | item = Datum(impath=impath, label=-1, classname=None)
110 | items.append(item)
111 | else:
112 | with open(filepath, "r") as f:
113 | lines = f.readlines()
114 | for line in lines:
115 | line = line.strip().split(" ")
116 | imname = line[0] + ".jpg"
117 | classname = " ".join(line[1:])
118 | impath = os.path.join(self.image_dir, imname)
119 | sub_impath = './data/' + impath.split('/data/')[1]
120 | if sub_impath in predict_label_dict:
121 | item = Datum(impath=impath, label=predict_label_dict[sub_impath], classname=self._lab2cname[predict_label_dict[sub_impath]])
122 | items.append(item)
123 | return items
124 |
--------------------------------------------------------------------------------
/datasets/caltech101.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
4 | from .datasetbase import UPLDatasetBase
5 | from dassl.utils import read_json
6 |
7 | from .oxford_pets import OxfordPets
8 | from .dtd import DescribableTextures as DTD
9 | import random
10 |
11 | IGNORED = ["BACKGROUND_Google", "Faces_easy"]
12 | NEW_CNAMES = {
13 | "airplanes": "airplane",
14 | "Faces": "face",
15 | "Leopards": "leopard",
16 | "Motorbikes": "motorbike",
17 | }
18 |
19 |
20 | @DATASET_REGISTRY.register()
21 | class Caltech101(DatasetBase):
22 |
23 | dataset_dir = "caltech-101"
24 |
25 | def __init__(self, cfg):
26 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
27 | self.dataset_dir = os.path.join(root, self.dataset_dir)
28 | self.image_dir = os.path.join(self.dataset_dir, "101_ObjectCategories")
29 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_Caltech101.json")
30 |
31 | if os.path.exists(self.split_path):
32 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
33 | else:
34 | train, val, test = DTD.read_and_split_data(self.image_dir, ignored=IGNORED, new_cnames=NEW_CNAMES)
35 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir)
36 |
37 | num_shots = cfg.DATASET.NUM_SHOTS
38 | train = self.generate_fewshot_dataset(train, num_shots=num_shots)
39 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4))
40 |
41 | super().__init__(train_x=train, val=val, test=test)
42 |
43 | @DATASET_REGISTRY.register()
44 | class OpensetCaltech101(UPLDatasetBase):
45 |
46 | dataset_dir = 'caltech-101'
47 |
48 | def __init__(self, cfg):
49 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
50 | self.dataset_dir = os.path.join(root, self.dataset_dir)
51 | self.image_dir = os.path.join(self.dataset_dir, '101_ObjectCategories')
52 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_Caltech101.json')
53 |
54 | if os.path.exists(self.split_path):
55 | train, val, test = self.read_split(self.split_path, self.image_dir)
56 | else:
57 | train, val, test = DTD.read_and_split_data(
58 | self.image_dir,
59 | ignored=IGNORED,
60 | new_cnames=NEW_CNAMES
61 | )
62 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir)
63 |
64 |
65 | # if IGNORE_NUM <=0 use IGNORE_FILE
66 | ignore_label_num = int(cfg.DATASET.IGNORE_NUM)
67 | if ignore_label_num > 0:
68 | ignore_labels = [i for i in range(100)]
69 | ignore_labels = random.sample(ignore_labels, ignore_label_num)
70 | else:
71 | ignore_labels = []
72 | caltech101_ignore_file = open(cfg.DATASET.IGNORE_FILE, 'r')
73 | lines = caltech101_ignore_file.readlines()
74 | for line_id in range(0, len(lines)):
75 | class_name, is_ignore = lines[line_id].split(',')
76 | if int(is_ignore) == 1:
77 | ignore_labels.append(class_name)
78 |
79 |
80 | num_shots = cfg.DATASET.NUM_SHOTS
81 | ignore_label_num = int(cfg.DATASET.IGNORE_NUM)
82 | train = self.generate_fewshot_dataset(train, num_shots=num_shots, mode='train')
83 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4), mode='val')
84 | novel, base = self.split_base_and_novel(test, ignore_labels)
85 | self.novel = novel
86 | self.base = base
87 |
88 | super().__init__(train_x=train, val=val, test=test, novel=novel, base=base)
89 |
90 | def read_split(self, filepath, path_prefix):
91 | def _convert(items):
92 | out = []
93 | for impath, label, classname in items:
94 | impath = os.path.join(path_prefix, impath)
95 | item = Datum(
96 | impath=impath,
97 | label=int(label),
98 | classname=classname
99 | )
100 | out.append(item)
101 | return out
102 |
103 | print(f'Reading split from {filepath}')
104 | split = read_json(filepath)
105 | train = _convert(split['train'])
106 | val = _convert(split['val'])
107 | test = _convert(split['test'])
108 |
109 | return train, val, test
110 |
111 |
112 |
113 | @DATASET_REGISTRY.register()
114 | class SSCaltech101(UPLDatasetBase):
115 | dataset_dir = 'caltech-101'
116 | def __init__(self, cfg):
117 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
118 | self.dataset_dir = os.path.join(root, self.dataset_dir)
119 | self.image_dir = os.path.join(self.dataset_dir, '101_ObjectCategories')
120 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_Caltech101.json')
121 |
122 | if os.path.exists(self.split_path):
123 | train, val, test = self.read_split(self.split_path, self.image_dir)
124 | else:
125 | train, val, test = DTD.read_and_split_data(
126 | self.image_dir,
127 | ignored=IGNORED,
128 | new_cnames=NEW_CNAMES
129 | )
130 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir)
131 | sstrain = self.read_sstrain_data(self.split_path, self.image_dir)
132 | num_shots = cfg.DATASET.NUM_SHOTS
133 | train = self.generate_fewshot_dataset(train, num_shots=-1)
134 | val = self.generate_fewshot_dataset(val, num_shots=-1)
135 | super().__init__(train_x=train, val = val, test=test, sstrain=sstrain)
136 |
137 |
--------------------------------------------------------------------------------
/datasets/oxford_pets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import random
4 | from collections import defaultdict
5 |
6 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
7 | from dassl.utils import read_json, write_json
8 |
9 | from .datasetbase import UPLDatasetBase
10 |
11 |
12 | @DATASET_REGISTRY.register()
13 | class OxfordPets(DatasetBase):
14 |
15 | dataset_dir = "oxford_pets"
16 |
17 | def __init__(self, cfg):
18 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
19 | self.dataset_dir = os.path.join(root, self.dataset_dir)
20 | self.image_dir = os.path.join(self.dataset_dir, "images")
21 | self.anno_dir = os.path.join(self.dataset_dir, "annotations")
22 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_OxfordPets.json")
23 |
24 | if os.path.exists(self.split_path):
25 | train, val, test = self.read_split(self.split_path, self.image_dir)
26 | else:
27 | trainval = self.read_data(split_file="trainval.txt")
28 | test = self.read_data(split_file="test.txt")
29 | train, val = self.split_trainval(trainval)
30 | self.save_split(train, val, test, self.split_path, self.image_dir)
31 |
32 | num_shots = cfg.DATASET.NUM_SHOTS
33 | train = self.generate_fewshot_dataset(train, num_shots=num_shots)
34 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4))
35 |
36 | super().__init__(train_x=train, val=val, test=test)
37 |
38 | def read_data(self, split_file):
39 | filepath = os.path.join(self.anno_dir, split_file)
40 | items = []
41 |
42 | with open(filepath, "r") as f:
43 | lines = f.readlines()
44 | for line in lines:
45 | line = line.strip()
46 | imname, label, species, _ = line.split(" ")
47 | breed = imname.split("_")[:-1]
48 | breed = "_".join(breed)
49 | breed = breed.lower()
50 | imname += ".jpg"
51 | impath = os.path.join(self.image_dir, imname)
52 | label = int(label) - 1 # convert to 0-based index
53 | item = Datum(impath=impath, label=label, classname=breed)
54 | items.append(item)
55 |
56 | return items
57 |
58 | @staticmethod
59 | def split_trainval(trainval, p_val=0.2):
60 | p_trn = 1 - p_val
61 | print(f"Splitting trainval into {p_trn:.0%} train and {p_val:.0%} val")
62 | tracker = defaultdict(list)
63 | for idx, item in enumerate(trainval):
64 | label = item.label
65 | tracker[label].append(idx)
66 |
67 | train, val = [], []
68 | for label, idxs in tracker.items():
69 | n_val = round(len(idxs) * p_val)
70 | assert n_val > 0
71 | random.shuffle(idxs)
72 | for n, idx in enumerate(idxs):
73 | item = trainval[idx]
74 | if n < n_val:
75 | val.append(item)
76 | else:
77 | train.append(item)
78 |
79 | return train, val
80 |
81 | @staticmethod
82 | def save_split(train, val, test, filepath, path_prefix):
83 | def _extract(items):
84 | out = []
85 | for item in items:
86 | impath = item.impath
87 | label = item.label
88 | classname = item.classname
89 | impath = impath.replace(path_prefix, "")
90 | if impath.startswith("/"):
91 | impath = impath[1:]
92 | out.append((impath, label, classname))
93 | return out
94 |
95 | train = _extract(train)
96 | val = _extract(val)
97 | test = _extract(test)
98 |
99 | split = {"train": train, "val": val, "test": test}
100 |
101 | write_json(split, filepath)
102 | print(f"Saved split to {filepath}")
103 |
104 | @staticmethod
105 | def read_split(filepath, path_prefix):
106 | def _convert(items):
107 | out = []
108 | for impath, label, classname in items:
109 | impath = os.path.join(path_prefix, impath)
110 | item = Datum(impath=impath, label=int(label), classname=classname)
111 | out.append(item)
112 | return out
113 |
114 | print(f"Reading split from {filepath}")
115 | split = read_json(filepath)
116 | train = _convert(split["train"])
117 | val = _convert(split["val"])
118 | test = _convert(split["test"])
119 |
120 | return train, val, test
121 |
122 |
123 | @DATASET_REGISTRY.register()
124 | class SSOxfordPets(UPLDatasetBase):
125 | dataset_dir = "oxford_pets"
126 | def __init__(self, cfg):
127 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
128 | self.dataset_dir = os.path.join(root, self.dataset_dir)
129 | self.image_dir = os.path.join(self.dataset_dir, "images")
130 | self.anno_dir = os.path.join(self.dataset_dir, "annotations")
131 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_OxfordPets.json")
132 |
133 | if os.path.exists(self.split_path):
134 | train, val, test = self.read_split(self.split_path, self.image_dir)
135 | else:
136 | trainval = self.read_data(split_file="trainval.txt")
137 | test = self.read_data(split_file="test.txt")
138 | train, val = self.split_trainval(trainval)
139 | self.save_split(train, val, test, self.split_path, self.image_dir)
140 |
141 | sstrain = self.read_sstrain_data(self.split_path, self.image_dir)
142 | num_shots = cfg.DATASET.NUM_SHOTS
143 | train = self.generate_fewshot_dataset(train, num_shots=-1)
144 | val = self.generate_fewshot_dataset(val, num_shots=-1)
145 | super().__init__(train_x=train, val = val, test=test, sstrain=sstrain)
146 |
--------------------------------------------------------------------------------
/datasets/data_manager.py:
--------------------------------------------------------------------------------
1 | from dassl.data import DataManager
2 | from dassl.data.data_manager import DatasetWrapper
3 | from dassl.data.transforms import build_transform
4 | from dassl.data.samplers import build_sampler
5 |
6 | from torch.utils.data import DataLoader, WeightedRandomSampler
7 | import torch
8 |
9 | def build_data_loader(
10 | cfg,
11 | sampler_type="RandomSampler",
12 | sampler=None,
13 | data_source=None,
14 | batch_size=64,
15 | n_domain=0,
16 | n_ins=2,
17 | tfm=None,
18 | is_train=True,
19 | dataset_wrapper=None,
20 | tag=None
21 | ):
22 | # Build sampler
23 | if sampler_type is not None:
24 | sampler = build_sampler(
25 | sampler_type,
26 | cfg=cfg,
27 | data_source=data_source,
28 | batch_size=batch_size,
29 | n_domain=n_domain,
30 | n_ins=n_ins,
31 | )
32 | else:
33 | sampler = sampler
34 |
35 | if dataset_wrapper is None:
36 | dataset_wrapper = DatasetWrapper
37 |
38 | # Build data loader
39 | if tag is None:
40 | data_loader = torch.utils.data.DataLoader(
41 | dataset_wrapper(cfg, data_source, transform=tfm, is_train=is_train),
42 | batch_size=batch_size,
43 | sampler=sampler,
44 | num_workers=cfg.DATALOADER.NUM_WORKERS,
45 | drop_last=is_train and len(data_source) >= batch_size,
46 | pin_memory=(torch.cuda.is_available() and cfg.USE_CUDA),
47 | )
48 | else:
49 | data_loader = torch.utils.data.DataLoader(
50 | dataset_wrapper(cfg, data_source, transform=tfm, is_train=is_train),
51 | batch_size=batch_size,
52 | sampler=sampler,
53 | num_workers=cfg.DATALOADER.NUM_WORKERS,
54 | drop_last=False,
55 | pin_memory=(torch.cuda.is_available() and cfg.USE_CUDA),
56 | )
57 |
58 |
59 |
60 |
61 | return data_loader
62 |
63 | class UPLDataManager(DataManager):
64 | def __init__(self,
65 | cfg,
66 | custom_tfm_train=None,
67 | custom_tfm_test=None,
68 | dataset_wrapper=None):
69 | super().__init__(cfg, custom_tfm_train, custom_tfm_test, dataset_wrapper)
70 |
71 |
72 | if custom_tfm_test is None:
73 | tfm_test = build_transform(cfg, is_train=False)
74 | else:
75 | print("* Using custom transform for testing")
76 | tfm_test = custom_tfm_test
77 |
78 | # Build transform
79 | if custom_tfm_train is None:
80 | tfm_train = build_transform(cfg, is_train=True)
81 | else:
82 | print("* Using custom transform for training")
83 | tfm_train = custom_tfm_train
84 |
85 | # save cfg
86 | self.cfg = cfg
87 | self.tfm_train = tfm_train
88 | self.dataset_wrapper = dataset_wrapper
89 |
90 | if cfg.DATALOADER.OPEN_SETTING:
91 | test_novel_loader = build_data_loader(
92 | cfg,
93 | sampler_type=cfg.DATALOADER.TEST.SAMPLER,
94 | data_source=self.dataset.novel,
95 | batch_size=cfg.DATALOADER.TEST.BATCH_SIZE,
96 | tfm=tfm_test,
97 | is_train=False,
98 | dataset_wrapper=dataset_wrapper,
99 | )
100 |
101 | test_base_loader = build_data_loader(
102 | cfg,
103 | sampler_type=cfg.DATALOADER.TEST.SAMPLER,
104 | data_source=self.dataset.base,
105 | batch_size=cfg.DATALOADER.TEST.BATCH_SIZE,
106 | tfm=tfm_test,
107 | is_train=False,
108 | dataset_wrapper=dataset_wrapper,
109 | )
110 |
111 | self.test_novel_loader = test_novel_loader
112 | self.test_base_loader = test_base_loader
113 |
114 | try:
115 | if self.dataset.sstrain:
116 |                 # Same as the train loader except that the data source is different
117 | train_loader_sstrain = build_data_loader(
118 | cfg,
119 | sampler_type="SequentialSampler",
120 | data_source=self.dataset.sstrain,
121 | batch_size=cfg.DATALOADER.TRAIN_X.BATCH_SIZE,
122 | tfm=tfm_test,
123 | is_train=False,
124 | dataset_wrapper=dataset_wrapper,
125 |                     tag='sstrain' # set this tag at initialization so that every sample is loaded
126 | )
127 | self.train_loader_sstrain = train_loader_sstrain
128 |
129 | # Build train_loader_x
130 | train_loader_x = build_data_loader(
131 | cfg,
132 | sampler_type="SequentialSampler",
133 | data_source=self.dataset.train_x,
134 | batch_size=cfg.DATALOADER.TRAIN_X.BATCH_SIZE,
135 |                     tfm=tfm_test, # this loader is not used for training, so use the test-time transform
136 | is_train=False,
137 | dataset_wrapper=dataset_wrapper,
138 | tag='sstrain'
139 | )
140 | self.train_loader_x = train_loader_x
141 | except:
142 | pass
143 |
144 | def update_ssdateloader(self, predict_label_dict, predict_conf_dict):
145 |         """Update train_loader_sstrain with pseudo labels.
146 |
147 | Args:
148 |             predict_label_dict (dict): maps image path to its pseudo label
149 | """
150 |
151 |
152 | sstrain = self.dataset.add_label(predict_label_dict, self.cfg.DATASET.NAME)
153 | print('sstrain', len(sstrain))
154 |
155 |
156 | # train_sampler = WeightedRandomSampler(weights, len(sstrain))
157 | train_loader_sstrain = build_data_loader(
158 | self.cfg,
159 | sampler_type="RandomSampler",
160 | sampler=None,
161 | data_source=sstrain,
162 | batch_size=self.cfg.DATALOADER.TRAIN_X.BATCH_SIZE,
163 | n_domain=self.cfg.DATALOADER.TRAIN_X.N_DOMAIN,
164 |             n_ins=1, # n_ins instances per class
165 | tfm=self.tfm_train,
166 | is_train=True,
167 | dataset_wrapper=self.dataset_wrapper,
168 | )
169 | self.train_loader_sstrain = train_loader_sstrain
170 |
--------------------------------------------------------------------------------
/get_info.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | # torch.set_printoptions(profile="full")
4 | import os
5 |
6 | from dassl.utils import setup_logger, set_random_seed, collect_env_info
7 | from configs.upl_default_config.upl_default import get_cfg_default
8 | from dassl.engine import build_trainer
9 |
10 | # custom
11 | import datasets.oxford_pets
12 | import datasets.oxford_flowers
13 | import datasets.fgvc_aircraft
14 | import datasets.dtd
15 | import datasets.eurosat
16 | import datasets.stanford_cars
17 | import datasets.food101
18 | import datasets.sun397
19 | import datasets.caltech101
20 | import datasets.ucf101
21 | import datasets.imagenet
22 |
23 | import trainers.upltrainer
24 | import trainers.hhzsclip
25 |
26 |
27 | def print_args(args, cfg):
28 | print('***************')
29 | print('** Arguments **')
30 | print('***************')
31 | optkeys = list(args.__dict__.keys())
32 | optkeys.sort()
33 | for key in optkeys:
34 | print('{}: {}'.format(key, args.__dict__[key]))
35 | print('************')
36 | print('** Config **')
37 | print('************')
38 | print(cfg)
39 |
40 |
41 | def reset_cfg(cfg, args):
42 | if args.root:
43 | cfg.DATASET.ROOT = args.root
44 |
45 | if args.output_dir:
46 | cfg.OUTPUT_DIR = args.output_dir
47 |
48 | if args.resume:
49 | cfg.RESUME = args.resume
50 |
51 | if args.seed:
52 | cfg.SEED = args.seed
53 |
54 | if args.source_domains:
55 | cfg.DATASET.SOURCE_DOMAINS = args.source_domains
56 |
57 | if args.target_domains:
58 | cfg.DATASET.TARGET_DOMAINS = args.target_domains
59 |
60 | if args.transforms:
61 | cfg.INPUT.TRANSFORMS = args.transforms
62 |
63 | if args.trainer:
64 | cfg.TRAINER.NAME = args.trainer
65 |
66 | if args.backbone:
67 | cfg.MODEL.BACKBONE.NAME = args.backbone
68 |
69 | if args.head:
70 | cfg.MODEL.HEAD.NAME = args.head
71 |
72 |
73 | def extend_cfg(cfg):
74 | """
75 | Add new config variables.
76 |
77 | E.g.
78 | from yacs.config import CfgNode as CN
79 | cfg.TRAINER.MY_MODEL = CN()
80 | cfg.TRAINER.MY_MODEL.PARAM_A = 1.
81 | cfg.TRAINER.MY_MODEL.PARAM_B = 0.5
82 | cfg.TRAINER.MY_MODEL.PARAM_C = False
83 | """
84 | from yacs.config import CfgNode as CN
85 |
86 | cfg.TRAINER.UPLTrainer = CN()
87 | cfg.TRAINER.UPLTrainer.N_CTX = 16 # number of context vectors
88 | cfg.TRAINER.UPLTrainer.CSC = False # class-specific context
89 | cfg.TRAINER.UPLTrainer.CTX_INIT = "" # initialization words
90 | cfg.TRAINER.UPLTrainer.PREC = "fp16" # fp16, fp32, amp
91 | cfg.TRAINER.UPLTrainer.CLASS_TOKEN_POSITION = "end" # 'middle' or 'end' or 'front'
92 |
93 |
94 | def setup_cfg(args):
95 | cfg = get_cfg_default()
96 | extend_cfg(cfg)
97 |
98 | # 1. From the dataset config file
99 | if args.dataset_config_file:
100 | cfg.merge_from_file(args.dataset_config_file)
101 |
102 | # 2. From the method config file
103 | if args.config_file:
104 | cfg.merge_from_file(args.config_file)
105 |
106 |     # 3. From the hh config file
107 | if args.hh_config_file:
108 | cfg.merge_from_file(args.hh_config_file)
109 |
110 |     # 4. From input arguments
111 | reset_cfg(cfg, args)
112 |
113 |     # 5. From optional input arguments
114 | cfg.merge_from_list(args.opts)
115 |
116 | cfg.freeze()
117 |
118 | return cfg
119 |
120 |
121 | def main(args):
122 | cfg = setup_cfg(args)
123 | if cfg.SEED >= 0:
124 | print('Setting fixed seed: {}'.format(cfg.SEED))
125 | set_random_seed(cfg.SEED)
126 | setup_logger(cfg.OUTPUT_DIR)
127 |
128 | if torch.cuda.is_available() and cfg.USE_CUDA:
129 | torch.backends.cudnn.benchmark = True
130 |
131 | print_args(args, cfg)
132 | print('Collecting env info ...')
133 | print('** System info **\n{}\n'.format(collect_env_info()))
134 |
135 | trainer = build_trainer(cfg)
136 | if args.model_dir:
137 | trainer.load_model_by_id(args.model_dir, epoch=args.load_epoch, model_id=0)
138 | trainer.zero_shot_analyze()
139 |
140 |
141 | if __name__ == '__main__':
142 | parser = argparse.ArgumentParser()
143 | parser.add_argument('--root', type=str, default='', help='path to dataset')
144 | parser.add_argument(
145 | '--output-dir', type=str, default='', help='output directory'
146 | )
147 | parser.add_argument(
148 | '--resume',
149 | type=str,
150 | default='',
151 | help='checkpoint directory (from which the training resumes)'
152 | )
153 | parser.add_argument(
154 | '--seed',
155 | type=int,
156 | default=-1,
157 | help='only positive value enables a fixed seed'
158 | )
159 | parser.add_argument(
160 | '--source-domains',
161 | type=str,
162 | nargs='+',
163 | help='source domains for DA/DG'
164 | )
165 | parser.add_argument(
166 | '--target-domains',
167 | type=str,
168 | nargs='+',
169 | help='target domains for DA/DG'
170 | )
171 | parser.add_argument(
172 | '--transforms', type=str, nargs='+', help='data augmentation methods'
173 | )
174 | parser.add_argument(
175 | '--config-file', type=str, default='', help='path to config file'
176 | )
177 | parser.add_argument(
178 | '--dataset-config-file',
179 | type=str,
180 | default='',
181 | help='path to config file for dataset setup'
182 | )
183 | parser.add_argument(
184 | '--hh-config-file', type=str, default='', help='path to config file'
185 | )
186 | parser.add_argument(
187 | '--trainer', type=str, default='', help='name of trainer'
188 | )
189 | parser.add_argument(
190 | '--backbone', type=str, default='', help='name of CNN backbone'
191 | )
192 | parser.add_argument('--head', type=str, default='', help='name of head')
193 | parser.add_argument(
194 | '--eval-only', action='store_true', help='evaluation only'
195 | )
196 | parser.add_argument(
197 | '--model-dir',
198 | type=str,
199 | default='',
200 | help='load model from this directory for eval-only mode'
201 | )
202 | parser.add_argument(
203 | '--load-epoch',
204 | type=int,
205 | help='load model weights at this epoch for evaluation'
206 | )
207 | parser.add_argument(
208 | '--no-train', action='store_true', help='do not call trainer.train()'
209 | )
210 | parser.add_argument(
211 | 'opts',
212 | default=None,
213 | nargs=argparse.REMAINDER,
214 | help='modify config options using the command-line'
215 | )
216 | args = parser.parse_args()
217 | main(args)
218 |
--------------------------------------------------------------------------------
/datasets/imagenet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | from collections import OrderedDict
4 | from tqdm import tqdm
5 |
6 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
7 | from dassl.utils import listdir_nohidden
8 | from .datasetbase import UPLDatasetBase
9 |
10 | TO_BE_IGNORED = ['new_val_prep.sh', 'train1.txt', 'train.txt']
11 |
12 | @DATASET_REGISTRY.register()
13 | class SSImageNet(UPLDatasetBase):
14 |
15 | dataset_dir = "imagenet"
16 |
17 | def __init__(self, cfg):
18 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
19 | self.dataset_dir = os.path.join(root, self.dataset_dir)
20 | self.image_dir = os.path.join(self.dataset_dir, "images")
21 | self.preprocessed = os.path.join(self.dataset_dir, "preprocessed.pkl")
22 | self.img_dir = []
23 |
24 | if os.path.exists(self.preprocessed):
25 | with open(self.preprocessed, "rb") as f:
26 | preprocessed = pickle.load(f)
27 | train = preprocessed["train"]
28 | test = preprocessed["test"]
29 | else:
30 | text_file = os.path.join(self.dataset_dir, "classnames.txt")
31 | classnames = self.read_classnames(text_file)
32 | train = self.read_data(classnames, "train")
33 | # Follow standard practice to perform evaluation on the val set
34 | # Also used as the val set (so evaluate the last-step model)
35 | test = self.read_data(classnames, "val")
36 |
37 | preprocessed = {"train": train, "test": test}
38 | with open(self.preprocessed, "wb") as f:
39 | pickle.dump(preprocessed, f, protocol=pickle.HIGHEST_PROTOCOL)
40 |
41 | num_shots = cfg.DATASET.NUM_SHOTS
42 | train = self.generate_fewshot_dataset(train, num_shots=-1)
43 |
44 | sstrain = self.read_sstrain_data(train)
45 | super().__init__(train_x=train, test=test, sstrain=sstrain)
46 |
47 | @staticmethod
48 | def read_classnames(text_file):
49 | """Return a dictionary containing
50 |         key-value pairs of <folder name>: <class name>.
51 | """
52 | classnames = OrderedDict()
53 | with open(text_file, "r") as f:
54 | lines = f.readlines()
55 | for line in lines:
56 | line = line.strip().split(" ")
57 | folder = line[0]
58 | classname = " ".join(line[1:])
59 | classnames[folder] = classname
60 | return classnames
61 |
62 | def read_data(self, classnames, split_dir):
63 | split_dir = os.path.join(self.image_dir, split_dir)
64 | folders = sorted(f.name for f in os.scandir(split_dir) if f.is_dir())
65 | items = []
66 |
67 | for label, folder in enumerate(folders):
68 | imnames = listdir_nohidden(os.path.join(split_dir, folder))
69 | classname = classnames[folder]
70 | for imname in imnames:
71 | impath = os.path.join(split_dir, folder, imname)
72 | item = Datum(impath=impath, label=label, classname=classname)
73 | items.append(item)
74 | return items
75 |
76 | def read_sstrain_data(self, train, predict_label_dict=None):
77 | items = []
78 | if predict_label_dict is None:
79 | for item in tqdm(train):
80 | new_item = Datum(impath=item.impath, label=-1, classname=None)
81 | items.append(new_item)
82 | self.img_dir.append(item.impath)
83 | else:
84 | for item in tqdm(train):
85 | sub_impath = './data/' + item.impath.split('/data/')[1]
86 | if sub_impath in predict_label_dict:
87 | new_item = Datum(impath=item.impath, label=predict_label_dict[sub_impath], classname=self._lab2cname[predict_label_dict[sub_impath]])
88 | items.append(new_item)
89 | self.img_dir.append(item.impath)
90 | return items
91 |
92 |
93 |
94 |
95 | @DATASET_REGISTRY.register()
96 | class ImageNet(DatasetBase):
97 |
98 | dataset_dir = "imagenet"
99 |
100 | def __init__(self, cfg):
101 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
102 | self.dataset_dir = os.path.join(root, self.dataset_dir)
103 | self.image_dir = os.path.join(self.dataset_dir, "images")
104 | self.preprocessed = os.path.join(self.dataset_dir, "preprocessed.pkl")
105 |
106 | if os.path.exists(self.preprocessed):
107 | with open(self.preprocessed, "rb") as f:
108 | preprocessed = pickle.load(f)
109 | train = preprocessed["train"]
110 | test = preprocessed["test"]
111 | else:
112 | text_file = os.path.join(self.dataset_dir, "classnames.txt")
113 | classnames = self.read_classnames(text_file)
114 | train = self.read_data(classnames, "train")
115 | # Follow standard practice to perform evaluation on the val set
116 | # Also used as the val set (so evaluate the last-step model)
117 | test = self.read_data(classnames, "val")
118 |
119 | preprocessed = {"train": train, "test": test}
120 | with open(self.preprocessed, "wb") as f:
121 | pickle.dump(preprocessed, f, protocol=pickle.HIGHEST_PROTOCOL)
122 |
123 | num_shots = cfg.DATASET.NUM_SHOTS
124 | train = self.generate_fewshot_dataset(train, num_shots=num_shots)
125 |
126 | super().__init__(train_x=train, val=test, test=test)
127 |
128 | @staticmethod
129 | def read_classnames(text_file):
130 | """Return a dictionary containing
131 |         key-value pairs of <folder name>: <class name>.
132 | """
133 | classnames = OrderedDict()
134 | with open(text_file, "r") as f:
135 | lines = f.readlines()
136 | for line in lines:
137 | line = line.strip().split(" ")
138 | folder = line[0]
139 | classname = " ".join(line[1:])
140 | classnames[folder] = classname
141 | return classnames
142 |
143 | def read_data(self, classnames, split_dir):
144 | split_dir = os.path.join(self.image_dir, split_dir)
145 | folders = sorted(f.name for f in os.scandir(split_dir) if f.is_dir())
146 | items = []
147 |
148 | for label, folder in enumerate(folders):
149 | imnames = listdir_nohidden(os.path.join(split_dir, folder))
150 | classname = classnames[folder]
151 | for imname in imnames:
152 | impath = os.path.join(split_dir, folder, imname)
153 | item = Datum(impath=impath, label=label, classname=classname)
154 | items.append(item)
155 |
156 | return items
--------------------------------------------------------------------------------
/upl_train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | # torch.set_printoptions(profile="full")
4 | import os
5 |
6 | from dassl.utils import setup_logger, set_random_seed, collect_env_info
7 | from configs.upl_default_config.upl_default import get_cfg_default
8 | from dassl.engine import build_trainer
9 |
10 | # custom
11 | import datasets.oxford_pets
12 | import datasets.oxford_flowers
13 | import datasets.fgvc_aircraft
14 | import datasets.dtd
15 | import datasets.eurosat
16 | import datasets.stanford_cars
17 | import datasets.food101
18 | import datasets.sun397
19 | import datasets.caltech101
20 | import datasets.ucf101
21 | import datasets.imagenet
22 |
23 | import trainers.upltrainer
24 | import trainers.hhzsclip
25 |
26 |
27 | def print_args(args, cfg):
28 | print('***************')
29 | print('** Arguments **')
30 | print('***************')
31 | optkeys = list(args.__dict__.keys())
32 | optkeys.sort()
33 | for key in optkeys:
34 | print('{}: {}'.format(key, args.__dict__[key]))
35 | print('************')
36 | print('** Config **')
37 | print('************')
38 | print(cfg)
39 |
40 |
41 | def reset_cfg(cfg, args):
42 | if args.root:
43 | cfg.DATASET.ROOT = args.root
44 |
45 | if args.output_dir:
46 | cfg.OUTPUT_DIR = args.output_dir
47 |
48 | if args.resume:
49 | cfg.RESUME = args.resume
50 |
51 | if args.seed:
52 | cfg.SEED = args.seed
53 |
54 | if args.source_domains:
55 | cfg.DATASET.SOURCE_DOMAINS = args.source_domains
56 |
57 | if args.target_domains:
58 | cfg.DATASET.TARGET_DOMAINS = args.target_domains
59 |
60 | if args.transforms:
61 | cfg.INPUT.TRANSFORMS = args.transforms
62 |
63 | if args.trainer:
64 | cfg.TRAINER.NAME = args.trainer
65 |
66 | if args.backbone:
67 | cfg.MODEL.BACKBONE.NAME = args.backbone
68 |
69 | if args.head:
70 | cfg.MODEL.HEAD.NAME = args.head
71 |
72 |
73 | def extend_cfg(cfg):
74 | """
75 | Add new config variables.
76 |
77 | E.g.
78 | from yacs.config import CfgNode as CN
79 | cfg.TRAINER.MY_MODEL = CN()
80 | cfg.TRAINER.MY_MODEL.PARAM_A = 1.
81 | cfg.TRAINER.MY_MODEL.PARAM_B = 0.5
82 | cfg.TRAINER.MY_MODEL.PARAM_C = False
83 | """
84 | from yacs.config import CfgNode as CN
85 |
86 | cfg.TRAINER.UPLTrainer = CN()
87 | cfg.TRAINER.UPLTrainer.N_CTX = 16 # number of context vectors
88 | cfg.TRAINER.UPLTrainer.CSC = False # class-specific context
89 | cfg.TRAINER.UPLTrainer.CTX_INIT = "" # initialization words
90 | cfg.TRAINER.UPLTrainer.PREC = "fp16" # fp16, fp32, amp
91 | cfg.TRAINER.UPLTrainer.CLASS_TOKEN_POSITION = "end" # 'middle' or 'end' or 'front'
92 |
93 |
94 | def setup_cfg(args):
95 | cfg = get_cfg_default()
96 | extend_cfg(cfg)
97 |
98 | # 1. From the dataset config file
99 | if args.dataset_config_file:
100 | cfg.merge_from_file(args.dataset_config_file)
101 |
102 | # 2. From the method config file
103 | if args.config_file:
104 | cfg.merge_from_file(args.config_file)
105 |
106 |     # 3. From the hh config file
107 | if args.hh_config_file:
108 | cfg.merge_from_file(args.hh_config_file)
109 |
110 |     # 4. From input arguments
111 | reset_cfg(cfg, args)
112 |
113 |     # 5. From optional input arguments
114 | cfg.merge_from_list(args.opts)
115 |
116 | cfg.freeze()
117 |
118 | return cfg
119 |
120 |
121 | def main(args):
122 | cfg = setup_cfg(args)
123 | if cfg.SEED >= 0:
124 | print('Setting fixed seed: {}'.format(cfg.SEED))
125 | set_random_seed(cfg.SEED)
126 | setup_logger(cfg.OUTPUT_DIR)
127 |
128 | if torch.cuda.is_available() and cfg.USE_CUDA:
129 | torch.backends.cudnn.benchmark = True
130 |
131 | print_args(args, cfg)
132 | print('Collecting env info ...')
133 | print('** System info **\n{}\n'.format(collect_env_info()))
134 |
135 | trainer_list = []
136 | for i in range(int(cfg.TRAINER.ENSEMBLE_NUM)):
137 | trainer = build_trainer(cfg)
138 | if args.model_dir:
139 | trainer.load_model_by_id(args.model_dir, epoch=args.load_epoch, model_id=i)
140 | trainer_list.append(trainer)
141 |
142 | predict_label_dict, predict_conf_dict = trainer.load_from_exist_file(file_path='./analyze_results',
143 | model_names=cfg.MODEL.PSEUDO_LABEL_MODELS)
144 |
145 | trainer.dm.update_ssdateloader(predict_label_dict, predict_conf_dict)
146 | trainer.train_loader_sstrain = trainer.dm.train_loader_sstrain
147 | trainer.sstrain_with_id(model_id=i)
148 |
149 | if __name__ == '__main__':
150 | parser = argparse.ArgumentParser()
151 | parser.add_argument('--root', type=str, default='', help='path to dataset')
152 | parser.add_argument(
153 | '--output-dir', type=str, default='', help='output directory'
154 | )
155 | parser.add_argument(
156 | '--resume',
157 | type=str,
158 | default='',
159 | help='checkpoint directory (from which the training resumes)'
160 | )
161 | parser.add_argument(
162 | '--seed',
163 | type=int,
164 | default=-1,
165 | help='only positive value enables a fixed seed'
166 | )
167 | parser.add_argument(
168 | '--source-domains',
169 | type=str,
170 | nargs='+',
171 | help='source domains for DA/DG'
172 | )
173 | parser.add_argument(
174 | '--target-domains',
175 | type=str,
176 | nargs='+',
177 | help='target domains for DA/DG'
178 | )
179 | parser.add_argument(
180 | '--transforms', type=str, nargs='+', help='data augmentation methods'
181 | )
182 | parser.add_argument(
183 | '--config-file', type=str, default='', help='path to config file'
184 | )
185 | parser.add_argument(
186 | '--dataset-config-file',
187 | type=str,
188 | default='',
189 | help='path to config file for dataset setup'
190 | )
191 | parser.add_argument(
192 | '--hh-config-file', type=str, default='', help='path to config file'
193 | )
194 | parser.add_argument(
195 | '--trainer', type=str, default='', help='name of trainer'
196 | )
197 | parser.add_argument(
198 | '--backbone', type=str, default='', help='name of CNN backbone'
199 | )
200 | parser.add_argument('--head', type=str, default='', help='name of head')
201 | parser.add_argument(
202 | '--eval-only', action='store_true', help='evaluation only'
203 | )
204 | parser.add_argument(
205 | '--model-dir',
206 | type=str,
207 | default='',
208 | help='load model from this directory for eval-only mode'
209 | )
210 | parser.add_argument(
211 | '--load-epoch',
212 | type=int,
213 | help='load model weights at this epoch for evaluation'
214 | )
215 | parser.add_argument(
216 | '--no-train', action='store_true', help='do not call trainer.train()'
217 | )
218 | parser.add_argument(
219 | 'opts',
220 | default=None,
221 | nargs=argparse.REMAINDER,
222 | help='modify config options using the command-line'
223 | )
224 | args = parser.parse_args()
225 | main(args)
226 |
--------------------------------------------------------------------------------
/datasets/datasetbase.py:
--------------------------------------------------------------------------------
1 | from dassl.data.datasets import DatasetBase
2 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
3 | from dassl.utils import read_json, write_json
4 | from dassl.utils import listdir_nohidden
5 |
6 | import random
7 | import os
8 |
9 |
10 | class UPLDatasetBase(DatasetBase):
11 | def __init__(self, train_x=None, train_u=None, val=None, test=None, novel=None, base=None, sstrain=None):
12 | super().__init__(train_x=train_x, val=val, test=test)
13 | self._novel = novel
14 | self._base = base
15 | self.sstrain = sstrain
16 | self._lab2cname_novel, _ = self.get_lab2cname(novel)
17 | self._lab2cname_base, _ = self.get_lab2cname(base)
18 |
19 | def generate_fewshot_dataset(self, *data_sources, num_shots=-1, mode=None, ignore_labels=None, repeat=False):
20 | """Generate a few-shot dataset (typically for the training set).
21 |
22 | This function is useful when one wants to evaluate a model
23 | in a few-shot learning setting where each class only contains
24 |         a small number of images.
25 |
26 | Args:
27 |             data_sources: each element is a list containing Datum objects.
28 | num_shots (int): number of instances per class to sample.
29 | repeat (bool): repeat images if needed (default: False).
30 | """
31 | if num_shots < 1:
32 | if len(data_sources) == 1:
33 | return data_sources[0]
34 | return data_sources
35 |
36 | print(f"Creating a {num_shots}-shot dataset")
37 |
38 | output = []
39 |
40 | for data_source in data_sources:
41 | tracker = self.split_dataset_by_label(data_source)
42 | dataset = []
43 |
44 | for label, items in tracker.items():
45 | # print(label, items[0].classname)
46 | if mode == 'train':
47 | if ignore_labels:
48 | if items[0].classname not in ignore_labels:
49 | if len(items) >= num_shots:
50 | sampled_items = random.sample(items, num_shots)
51 | else:
52 | if repeat:
53 | sampled_items = random.choices(items, k=num_shots)
54 | else:
55 | sampled_items = items
56 | dataset.extend(sampled_items)
57 | else:
58 | if len(items) >= num_shots:
59 | sampled_items = random.sample(items, num_shots)
60 | else:
61 | if repeat:
62 | sampled_items = random.choices(items, k=num_shots)
63 | else:
64 | sampled_items = items
65 | dataset.extend(sampled_items)
66 | else:
67 | if len(items) >= num_shots:
68 | sampled_items = random.sample(items, num_shots)
69 | else:
70 | if repeat:
71 | sampled_items = random.choices(items, k=num_shots)
72 | else:
73 | sampled_items = items
74 | dataset.extend(sampled_items)
75 | output.append(dataset)
76 | if len(output) == 1:
77 | return output[0]
78 |
79 | return output
80 |
81 | def split_base_and_novel(self, test, ignore_labels):
82 | tracker = self.split_dataset_by_label(test)
83 | dataset_base = []
84 | dataset_novel = []
85 | for label, items in tracker.items():
86 | if items[0].classname not in ignore_labels:
87 | dataset_novel.extend(items)
88 | else:
89 | dataset_base.extend(items)
90 | return dataset_novel, dataset_base
91 |
92 | def get_lab2cname(self, data_source):
93 | """Get a label-to-classname mapping (dict).
94 |
95 | Args:
96 | data_source (list): a list of Datum objects.
97 | """
98 | container = set()
99 | if data_source is not None:
100 | for item in data_source:
101 | container.add((item.label, item.classname))
102 | mapping = {label: classname for label, classname in container}
103 | labels = list(mapping.keys())
104 | labels.sort()
105 | classnames = [mapping[label] for label in labels]
106 | return mapping, classnames
107 | else:
108 | return None, None
109 |
110 | def read_split(self, filepath, path_prefix):
111 | def _convert(items):
112 | out = []
113 | for impath, label, classname in items:
114 | impath = os.path.join(path_prefix, impath)
115 | item = Datum(impath=impath, label=label, classname=classname)
116 | out.append(item)
117 | return out
118 |
119 | print(f"Reading split from {filepath}")
120 | split = read_json(filepath)
121 | train = _convert(split["train"])
122 | val = _convert(split["val"])
123 | test = _convert(split["test"])
124 | return train, val, test
125 |
126 | def read_sstrain_data(self, filepath, path_prefix, predict_label_dict=None):
127 |
128 | def _convert(items):
129 | out = []
130 | for impath, _, _ in items:
131 | impath = os.path.join(path_prefix, impath)
132 | sub_impath = './data/' + impath.split('/data/')[1]
133 | if sub_impath in predict_label_dict:
134 | item = Datum(impath=impath, label=predict_label_dict[sub_impath], classname=self._lab2cname[predict_label_dict[sub_impath]])
135 | out.append(item)
136 | return out
137 |
138 | def _convert_no_label(items):
139 | out = []
140 | for impath, _, _ in items:
141 | impath = os.path.join(path_prefix, impath)
142 | item = Datum(impath=impath, label=-1, classname=None)
143 | out.append(item)
144 | return out
145 |
146 | print(f"Reading split from {filepath}")
147 | split = read_json(filepath)
148 | if predict_label_dict is not None:
149 | train = _convert(split["train"])
150 | else:
151 | train = _convert_no_label(split["train"])
152 | return train
153 |
154 | # DTD raises an error here and needs special handling; it has its own processing function that should be copied over (see the original datasetbase for details).
155 |
156 | def add_label(self, predict_label_dict, dataset_name):
157 | """add label when training for self-supervised learning
158 |
159 | Args:
160 | predict_label_dict (dict): a dict mapping image paths to predicted labels, e.g. {'imagepath': label}
161 | """
162 | # print(predict_label_dict, 'predict_label_dict')
163 | print(dataset_name)
164 | if dataset_name == 'SSFGVCAircraft':
165 | sstrain = self.read_data_without_label(self.cname2lab, "images_variant_train.txt", predict_label_dict)
166 | elif dataset_name == 'SSImageNet':
167 | sstrain = self.read_sstrain_data(self.train_x, predict_label_dict)
168 | else:
169 | sstrain = self.read_sstrain_data(self.split_path, self.image_dir, predict_label_dict)
170 |
171 | self.sstrain = sstrain
172 | return sstrain
--------------------------------------------------------------------------------
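Note on the pseudo-label format consumed above: read_sstrain_data() keys into predict_label_dict by the image path relative to the './data/' root, and add_label() dispatches per dataset. The following minimal sketch (file paths, label values, and the dataset name are hypothetical placeholders, not taken from the repository) shows the expected dictionary shape:

# Hypothetical sketch: pseudo labels are keyed by the './data/...' sub-path
# that read_sstrain_data() reconstructs from each image path.
predict_label_dict = {
    './data/oxford_pets/images/Abyssinian_001.jpg': 0,
    './data/oxford_pets/images/beagle_003.jpg': 12,
}
# A concrete subclass (e.g. the SSOxfordPets dataset under datasets/) would then
# attach these labels to its unlabeled split:
# sstrain = dataset.add_label(predict_label_dict, 'SSOxfordPets')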
/upl_test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | # torch.set_printoptions(profile="full")
4 | import os
5 | import random
6 | import matplotlib.pyplot as plt
7 | import numpy as np
8 |
9 |
10 | from dassl.utils import setup_logger, set_random_seed, collect_env_info
11 | from configs.upl_default_config.upl_default import get_cfg_default
12 | from dassl.engine import build_trainer
13 |
14 | # custom
15 | import datasets.oxford_pets
16 | import datasets.oxford_flowers
17 | import datasets.fgvc_aircraft
18 | import datasets.dtd
19 | import datasets.eurosat
20 | import datasets.stanford_cars
21 | import datasets.food101
22 | import datasets.sun397
23 | import datasets.caltech101
24 | import datasets.ucf101
25 | import datasets.imagenet
26 |
27 | import trainers.upltrainer
28 | import trainers.hhzsclip
29 |
30 |
31 | def print_args(args, cfg):
32 | print('***************')
33 | print('** Arguments **')
34 | print('***************')
35 | optkeys = list(args.__dict__.keys())
36 | optkeys.sort()
37 | for key in optkeys:
38 | print('{}: {}'.format(key, args.__dict__[key]))
39 | print('************')
40 | print('** Config **')
41 | print('************')
42 | print(cfg)
43 |
44 | def reset_cfg(cfg, args):
45 | if args.root:
46 | cfg.DATASET.ROOT = args.root
47 |
48 | if args.output_dir:
49 | cfg.OUTPUT_DIR = args.output_dir
50 |
51 | if args.resume:
52 | cfg.RESUME = args.resume
53 |
54 | if args.seed:
55 | cfg.SEED = args.seed
56 |
57 | if args.source_domains:
58 | cfg.DATASET.SOURCE_DOMAINS = args.source_domains
59 |
60 | if args.target_domains:
61 | cfg.DATASET.TARGET_DOMAINS = args.target_domains
62 |
63 | if args.transforms:
64 | cfg.INPUT.TRANSFORMS = args.transforms
65 |
66 | if args.trainer:
67 | cfg.TRAINER.NAME = args.trainer
68 |
69 | if args.backbone:
70 | cfg.MODEL.BACKBONE.NAME = args.backbone
71 |
72 | if args.head:
73 | cfg.MODEL.HEAD.NAME = args.head
74 |
75 |
76 | def extend_cfg(cfg):
77 | """
78 | Add new config variables.
79 |
80 | E.g.
81 | from yacs.config import CfgNode as CN
82 | cfg.TRAINER.MY_MODEL = CN()
83 | cfg.TRAINER.MY_MODEL.PARAM_A = 1.
84 | cfg.TRAINER.MY_MODEL.PARAM_B = 0.5
85 | cfg.TRAINER.MY_MODEL.PARAM_C = False
86 | """
87 | from yacs.config import CfgNode as CN
88 |
89 | cfg.TRAINER.UPLTrainer = CN()
90 | cfg.TRAINER.UPLTrainer.N_CTX = 16 # number of context vectors
91 | cfg.TRAINER.UPLTrainer.CSC = False # class-specific context
92 | cfg.TRAINER.UPLTrainer.CTX_INIT = "" # initialization words
93 | cfg.TRAINER.UPLTrainer.PREC = "fp16" # fp16, fp32, amp
94 | cfg.TRAINER.UPLTrainer.CLASS_TOKEN_POSITION = "end" # 'middle' or 'end' or 'front'
95 |
96 |
97 | def setup_cfg(args):
98 | cfg = get_cfg_default()
99 | extend_cfg(cfg)
100 |
101 | # 1. From the dataset config file
102 | if args.dataset_config_file:
103 | cfg.merge_from_file(args.dataset_config_file)
104 |
105 | # 2. From the method config file
106 | if args.config_file:
107 | cfg.merge_from_file(args.config_file)
108 |
109 | # 3. hh config
110 | if args.hh_config_file:
111 | cfg.merge_from_file(args.hh_config_file)
112 |
113 | # 4. From input arguments
114 | reset_cfg(cfg, args)
115 |
116 | # 5. From optional input arguments
117 | cfg.merge_from_list(args.opts)
118 |
119 | cfg.freeze()
120 |
121 | return cfg
122 |
123 |
124 | def main(args):
125 | cfg = setup_cfg(args)
126 | if cfg.SEED >= 0:
127 | print('Setting fixed seed: {}'.format(cfg.SEED))
128 | set_random_seed(cfg.SEED)
129 | setup_logger(cfg.OUTPUT_DIR)
130 |
131 | if torch.cuda.is_available() and cfg.USE_CUDA:
132 | torch.backends.cudnn.benchmark = True
133 |
134 | print_args(args, cfg)
135 | print('Collecting env info ...')
136 | print('** System info **\n{}\n'.format(collect_env_info()))
137 |
138 | trainer_list = []
139 | for i in range(int(cfg.TRAINER.ENSEMBLE_NUM)):
140 | trainer = build_trainer(cfg)
141 | if args.model_dir:
142 | trainer.load_model_by_id(args.model_dir, epoch=args.load_epoch, model_id=i)
143 | trainer_list.append(trainer)
144 |
145 |
146 | prob_end = []
147 | results_dict = {}
148 | for SHOTS in [16]:
149 | print('SHOTS:', SHOTS, '\n\n\n\n\n')
150 | for seed in range(1, 17):
151 | try:
152 | prob = torch.load('./analysis_results_test/{}/50_{}_{}_random_initend/test_logits.pt'.format(cfg.DATASET.NAME, seed, SHOTS))
153 | prob_end.append(prob)
154 | except FileNotFoundError:
155 | print('missing test_logits.pt for seed {}'.format(seed))
156 |
157 | print(len(prob_end), ' shots ensemble')
158 | prob_test = sum(prob_end) / len(prob_end)
159 | results_end = trainer_list[0].test_with_existing_logits(prob_test)
160 |
161 | print('ensemble: {:.2f}'.format((results_end['accuracy'])))
162 | results_dict[len(prob_end)] = results_end['accuracy']
163 |
164 |
165 |
166 | if __name__ == '__main__':
167 | parser = argparse.ArgumentParser()
168 | parser.add_argument('--root', type=str, default='', help='path to dataset')
169 | parser.add_argument(
170 | '--output-dir', type=str, default='', help='output directory'
171 | )
172 | parser.add_argument(
173 | '--resume',
174 | type=str,
175 | default='',
176 | help='checkpoint directory (from which the training resumes)'
177 | )
178 | parser.add_argument(
179 | '--seed',
180 | type=int,
181 | default=-1,
182 | help='only positive value enables a fixed seed'
183 | )
184 | parser.add_argument(
185 | '--source-domains',
186 | type=str,
187 | nargs='+',
188 | help='source domains for DA/DG'
189 | )
190 | parser.add_argument(
191 | '--target-domains',
192 | type=str,
193 | nargs='+',
194 | help='target domains for DA/DG'
195 | )
196 | parser.add_argument(
197 | '--transforms', type=str, nargs='+', help='data augmentation methods'
198 | )
199 | parser.add_argument(
200 | '--config-file', type=str, default='', help='path to config file'
201 | )
202 | parser.add_argument(
203 | '--dataset-config-file',
204 | type=str,
205 | default='',
206 | help='path to config file for dataset setup'
207 | )
208 | parser.add_argument(
209 | '--hh-config-file', type=str, default='', help='path to config file'
210 | )
211 | parser.add_argument(
212 | '--trainer', type=str, default='', help='name of trainer'
213 | )
214 | parser.add_argument(
215 | '--backbone', type=str, default='', help='name of CNN backbone'
216 | )
217 | parser.add_argument('--head', type=str, default='', help='name of head')
218 | parser.add_argument(
219 | '--eval-only', action='store_true', help='evaluation only'
220 | )
221 | parser.add_argument(
222 | '--model-dir',
223 | type=str,
224 | default='',
225 | help='load model from this directory for eval-only mode'
226 | )
227 | parser.add_argument(
228 | '--load-epoch',
229 | type=int,
230 | help='load model weights at this epoch for evaluation'
231 | )
232 | parser.add_argument(
233 | '--no-train', action='store_true', help='do not call trainer.train()'
234 | )
235 | parser.add_argument(
236 | 'opts',
237 | default=None,
238 | nargs=argparse.REMAINDER,
239 | help='modify config options using the command-line'
240 | )
241 | args = parser.parse_args()
242 | main(args)
243 |
--------------------------------------------------------------------------------
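The main() loop above ensembles test-time logits saved by earlier runs instead of re-running inference. A minimal standalone sketch of that averaging step, assuming logit tensors of shape (num_test_images, num_classes) were saved under the path pattern used above (the dataset name and seed range are placeholders):

import torch

paths = [
    './analysis_results_test/SSOxfordPets/50_{}_16_random_initend/test_logits.pt'.format(seed)
    for seed in range(1, 5)
]
logits = [torch.load(p) for p in paths]    # each tensor: (num_test_images, num_classes)
ensembled = sum(logits) / len(logits)      # element-wise mean over the individual runs
pred = ensembled.argmax(dim=1)             # final per-image predictions of the ensemble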
/trainers/hhzsclip.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import os
4 |
5 | from dassl.engine import TRAINER_REGISTRY, TrainerX
6 | from dassl.optim import build_optimizer, build_lr_scheduler
7 |
8 | from clip import clip
9 | from clip.model import convert_weights
10 |
11 | from .coop import load_clip_to_cpu
12 |
13 | from datasets.data_manager import UPLDataManager
14 |
15 | from .utils import plotLogitsMap, plotPRMap
16 |
17 | CUSTOM_TEMPLATES = {
18 | "OxfordPets": "a photo of a {}, a type of pet.",
19 | "OxfordFlowers": "a photo of a {}, a type of flower.",
20 | "FGVCAircraft": "a photo of a {}, a type of aircraft.",
21 | "DescribableTextures": "{} texture.",
22 | "EuroSAT": "a centered satellite photo of {}.",
23 | "StanfordCars": "a photo of a {}.",
24 | "Food101": "a photo of {}, a type of food.",
25 | "SUN397": "a photo of a {}.",
26 | "Caltech101": "a photo of a {}.",
27 | "UCF101": "a photo of a person doing {}.",
28 | "ImageNet": "a photo of a {}.",
29 | "ImageNetSketch": "a photo of a {}.",
30 | "ImageNetV2": "a photo of a {}.",
31 | "ImageNetA": "a photo of a {}.",
32 | "ImageNetR": "a photo of a {}.",
33 | # semi-supervised templates
34 | "SSOxfordPets": "a photo of a {}, a type of pet.",
35 | "SSOxfordFlowers": "a photo of a {}, a type of flower.",
36 | "SSFGVCAircraft": "a photo of a {}, a type of aircraft.",
37 | "SSDescribableTextures": "{} texture.",
38 | "SSEuroSAT": "a centered satellite photo of {}.",
39 | "SSStanfordCars": "a photo of a {}.",
40 | "SSFood101": "a photo of {}, a type of food.",
41 | "SSSUN397": "a photo of a {}.",
42 | "SSCaltech101": "a photo of a {}.",
43 | "SSUCF101": "a photo of a person doing {}.",
44 | "SSImageNet": "a photo of a {}.",
45 | }
46 |
47 |
48 |
49 |
50 | @TRAINER_REGISTRY.register()
51 | class ZeroshotCLIP(TrainerX):
52 |
53 | def build_data_loader(self):
54 | """Create essential data-related attributes.
55 |
56 | A re-implementation of this method must create the
57 | same attributes (except self.dm).
58 | """
59 | dm = UPLDataManager(self.cfg)
60 |
61 | self.train_loader_x = dm.train_loader_x
62 | self.train_loader_u = dm.train_loader_u # optional, can be None
63 | self.val_loader = dm.val_loader # optional, can be None
64 | self.test_loader = dm.test_loader
65 | self.num_classes = dm.num_classes
66 | self.num_source_domains = dm.num_source_domains
67 | self.lab2cname = dm.lab2cname # dict {label: classname}
68 |
69 | if self.cfg.DATALOADER.OPEN_SETTING:
70 | self.test_novel_loader = dm.test_novel_loader
71 | self.test_base_loader = dm.test_base_loader
72 |
73 | self.dm = dm
74 |
75 | def build_model(self):
76 | cfg = self.cfg
77 | classnames = self.dm.dataset.classnames
78 |
79 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
80 | clip_model = load_clip_to_cpu(cfg)
81 | clip_model.to(self.device)
82 |
83 | temp = CUSTOM_TEMPLATES[cfg.DATASET.NAME]
84 | prompts = [temp.format(c.replace("_", " ")) for c in classnames]
85 | print(f"Prompts: {prompts}")
86 | prompts = torch.cat([clip.tokenize(p) for p in prompts])
87 | prompts = prompts.to(self.device)
88 |
89 | with torch.no_grad():
90 | text_features = clip_model.encode_text(prompts)
91 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
92 |
93 | self.text_features = text_features
94 | self.clip_model = clip_model
95 |
96 | def model_inference(self, image):
97 | image_features = self.clip_model.encode_image(image)
98 | image_features = image_features / image_features.norm(dim=-1, keepdim=True)
99 | logit_scale = self.clip_model.logit_scale.exp()
100 | logits = logit_scale * image_features @ self.text_features.t()
101 | return logits, image_features, self.text_features
102 |
103 | @torch.no_grad()
104 | def test(self, split=None, trainer_list=None):
105 | """A generic testing pipeline."""
106 | # For ensemble runs, add a tag so they can be told apart in the experiment records
107 | # if trainer_list is not None and "ENSEMBLE:{}".format(self.cfg.TRAINER.ENSEMBLE_NUM) not in self.online_logger.tags:
108 | # self.online_logger.tags += ("ENSEMBLE:{}".format(self.cfg.TRAINER.ENSEMBLE_NUM),)
109 |
110 | # if self.cfg.DATASET.NAME not in self.online_logger.tags:
111 | # self.online_logger.tags += ("DATASET INFERENCE:{}".format(self.cfg.DATASET.NAME),)
112 |
113 | self.set_model_mode("eval")
114 | self.evaluator.reset()
115 |
116 | save_path = os.path.join(self.cfg.TEST.Analyze_Result_Path, self.cfg.DATASET.NAME,
117 | str(self.cfg.OPTIM.MAX_EPOCH)+'_'+str(self.cfg.SEED)+'_'+str(self.cfg.DATASET.NUM_SHOTS))
118 | if not os.path.exists(save_path):
119 | os.makedirs(save_path)
120 |
121 | results_id = 0
122 | while os.path.exists(os.path.join(save_path, 'per_image_results_{}_{}.txt'.format(split, results_id))):
123 | results_id += 1
124 | self.per_image_txt_writer = open(os.path.join(save_path, 'per_image_results_{}_{}.txt'.format(split, results_id)), 'w')
125 | self.per_class_txt_writer = open(os.path.join(save_path, 'per_class_results_{}_{}.txt'.format(split, results_id)), 'w')
126 |
127 | if split is None:
128 | split = self.cfg.TEST.SPLIT
129 |
130 | if split == "val" and self.val_loader is not None:
131 | data_loader = self.val_loader
132 | print("Do evaluation on {} set".format(split))
133 | elif split=="novel":
134 | data_loader = self.test_novel_loader
135 | print("Do evaluation on test novel set")
136 | elif split=="base":
137 | data_loader = self.test_base_loader
138 | print("Do evaluation on test base set")
139 | elif split=="train":
140 | data_loader = self.train_loader_x
141 | print("Do evaluation on train set")
142 | elif split=="all":
143 | data_loader = self.test_loader
144 | print("Do evaluation on test set")
145 | else:
146 | data_loader = self.test_loader
147 | print("Do evaluation on test set")
148 |
149 | outputs_for_tsne = []
150 | label_for_tsne = []
151 | image_features_for_tsne = []
152 | text_features_for_tsne = []
153 | for batch_idx, batch in enumerate(data_loader):
154 | input, label = self.parse_batch_test(batch)
155 | if trainer_list is None or len(trainer_list)==1:
156 | # single-model (non-ensemble) evaluation
157 | output, image_features, text_features = self.model_inference(input)
158 | image_features_for_tsne.append(image_features)
159 | text_features_for_tsne.append(text_features)
160 | else:
161 | # ensemble evaluation: average the logits over all trainers
162 | outputs = [t.model_inference(input)[0] for t in trainer_list]
163 | output = sum(outputs) / len(outputs)
164 | self.evaluator.process(output, label, self.per_image_txt_writer, self.per_class_txt_writer)
165 | outputs_for_tsne.append(output)
166 | label_for_tsne.append(label)
167 | results = self.evaluator.evaluate()
168 | if split in ['all', 'train', 'test', 'novel', 'base']:
169 | if len(outputs_for_tsne) != 0:
170 | outputs_for_tsne = torch.cat(outputs_for_tsne, dim=0)
171 | print('outputs_for_tsne', outputs_for_tsne.shape)
172 | label_for_tsne = torch.cat(label_for_tsne, dim=0)
173 | print('label_for_tsne', label_for_tsne.shape)
174 | image_features_for_tsne = torch.cat(image_features_for_tsne, dim=0)
175 | text_features_for_tsne = text_features_for_tsne[0]
176 | torch.save(image_features_for_tsne, os.path.join(save_path, '{}_v_features.pt'.format(split)))
177 | torch.save(label_for_tsne, os.path.join(save_path, '{}_targets.pt'.format(split)))
178 | torch.save(outputs_for_tsne, os.path.join(save_path, '{}_logits.pt'.format(split)))
179 | torch.save(text_features_for_tsne, os.path.join(save_path, '{}_l_logits.pt'.format(split)))
180 | # plotLogitsMap(outputs_for_tsne, label_for_tsne, os.path.join(save_path, '{}_map.png'.format(split)), self.cfg.DATASET.NAME+'_'+split)
181 | plotPRMap(outputs_for_tsne, label_for_tsne, os.path.join(save_path, '{}_PR.png'.format(split)), self.cfg.DATASET.NAME+'_'+split)
182 |
183 |
184 | # self.online_logger.log({"inputs": wandb.Image('./{}.png'.format(split)), "captions": wandb.Html(split)})
185 | # wandb.run.summary["{}_accuracy".format(split)] = results['accuracy']
186 | self.per_image_txt_writer.close()
187 | self.per_class_txt_writer.close()
188 |
189 |
190 | for k, v in results.items():
191 | tag = "{}/{}".format(split, k)
192 | self.write_scalar(tag, v, self.epoch)
193 |
194 | return list(results.values())[0]
195 |
196 |
--------------------------------------------------------------------------------
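For reference, the zero-shot pipeline implemented by ZeroshotCLIP.build_model() and model_inference() above reduces to the following sketch (the class names and the random stand-in image are illustrative only; loading on CPU so clip.load() casts the model to fp32):

import torch
from clip import clip

model, preprocess = clip.load('RN50', device='cpu')
classnames = ['abyssinian', 'beagle', 'bengal']   # placeholder class names
prompts = clip.tokenize(['a photo of a {}, a type of pet.'.format(c) for c in classnames])

with torch.no_grad():
    text_features = model.encode_text(prompts)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    image = torch.randn(1, 3, 224, 224)           # stands in for preprocess(PIL image).unsqueeze(0)
    image_features = model.encode_image(image)
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    logits = model.logit_scale.exp() * image_features @ text_features.t()
    print(logits.softmax(dim=-1))                 # per-class probabilities for the single image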
/clip/clip.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | import urllib
4 | import warnings
5 | from typing import Union, List
6 |
7 | import torch
8 | from PIL import Image
9 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
10 | from tqdm import tqdm
11 |
12 | from .model import build_model
13 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer
14 |
15 | try:
16 | from torchvision.transforms import InterpolationMode
17 | BICUBIC = InterpolationMode.BICUBIC
18 | except ImportError:
19 | BICUBIC = Image.BICUBIC
20 |
21 |
22 | if torch.__version__.split(".") < ["1", "7", "1"]:
23 | warnings.warn("PyTorch version 1.7.1 or higher is recommended")
24 |
25 |
26 | __all__ = ["available_models", "load", "tokenize"]
27 | _tokenizer = _Tokenizer()
28 |
29 | _MODELS = {
30 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
31 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
32 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
33 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
34 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
35 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
36 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
37 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
38 | }
39 |
40 |
41 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
42 | os.makedirs(root, exist_ok=True)
43 | filename = os.path.basename(url)
44 |
45 | expected_sha256 = url.split("/")[-2]
46 | download_target = os.path.join(root, filename)
47 |
48 | if os.path.exists(download_target) and not os.path.isfile(download_target):
49 | raise RuntimeError(f"{download_target} exists and is not a regular file")
50 |
51 | if os.path.isfile(download_target):
52 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
53 | return download_target
54 | else:
55 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
56 |
57 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
58 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
59 | while True:
60 | buffer = source.read(8192)
61 | if not buffer:
62 | break
63 |
64 | output.write(buffer)
65 | loop.update(len(buffer))
66 |
67 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
68 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")
69 |
70 | return download_target
71 |
72 |
73 | def _transform(n_px):
74 | return Compose([
75 | Resize(n_px, interpolation=BICUBIC),
76 | CenterCrop(n_px),
77 | lambda image: image.convert("RGB"),
78 | ToTensor(),
79 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
80 | ])
81 |
82 |
83 | def available_models() -> List[str]:
84 | """Returns the names of available CLIP models"""
85 | return list(_MODELS.keys())
86 |
87 |
88 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=False):
89 | """Load a CLIP model
90 |
91 | Parameters
92 | ----------
93 | name : str
94 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
95 |
96 | device : Union[str, torch.device]
97 | The device to put the loaded model
98 |
99 | jit : bool
100 | Whether to load the optimized JIT model or more hackable non-JIT model (default).
101 |
102 | Returns
103 | -------
104 | model : torch.nn.Module
105 | The CLIP model
106 |
107 | preprocess : Callable[[PIL.Image], torch.Tensor]
108 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
109 | """
110 | if name in _MODELS:
111 | model_path = _download(_MODELS[name])
112 | elif os.path.isfile(name):
113 | model_path = name
114 | else:
115 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
116 |
117 | try:
118 | # loading JIT archive
119 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
120 | state_dict = None
121 | except RuntimeError:
122 | # loading saved state dict
123 | if jit:
124 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
125 | jit = False
126 | state_dict = torch.load(model_path, map_location="cpu")
127 |
128 | if not jit:
129 | model = build_model(state_dict or model.state_dict()).to(device)
130 | if str(device) == "cpu":
131 | model.float()
132 | return model, _transform(model.visual.input_resolution)
133 |
134 | # patch the device names
135 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
136 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
137 |
138 | def patch_device(module):
139 | try:
140 | graphs = [module.graph] if hasattr(module, "graph") else []
141 | except RuntimeError:
142 | graphs = []
143 |
144 | if hasattr(module, "forward1"):
145 | graphs.append(module.forward1.graph)
146 |
147 | for graph in graphs:
148 | for node in graph.findAllNodes("prim::Constant"):
149 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
150 | node.copyAttributes(device_node)
151 |
152 | model.apply(patch_device)
153 | patch_device(model.encode_image)
154 | patch_device(model.encode_text)
155 |
156 | # patch dtype to float32 on CPU
157 | if str(device) == "cpu":
158 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
159 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
160 | float_node = float_input.node()
161 |
162 | def patch_float(module):
163 | try:
164 | graphs = [module.graph] if hasattr(module, "graph") else []
165 | except RuntimeError:
166 | graphs = []
167 |
168 | if hasattr(module, "forward1"):
169 | graphs.append(module.forward1.graph)
170 |
171 | for graph in graphs:
172 | for node in graph.findAllNodes("aten::to"):
173 | inputs = list(node.inputs())
174 | for i in [1, 2]: # dtype can be the second or third argument to aten::to()
175 | if inputs[i].node()["value"] == 5:
176 | inputs[i].node().copyAttributes(float_node)
177 |
178 | model.apply(patch_float)
179 | patch_float(model.encode_image)
180 | patch_float(model.encode_text)
181 |
182 | model.float()
183 |
184 | return model, _transform(model.input_resolution.item())
185 |
186 |
187 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
188 | """
189 | Returns the tokenized representation of given input string(s)
190 |
191 | Parameters
192 | ----------
193 | texts : Union[str, List[str]]
194 | An input string or a list of input strings to tokenize
195 |
196 | context_length : int
197 | The context length to use; all CLIP models use 77 as the context length
198 |
199 | truncate: bool
200 | Whether to truncate the text in case its encoding is longer than the context length
201 |
202 | Returns
203 | -------
204 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
205 | """
206 | if isinstance(texts, str):
207 | texts = [texts]
208 |
209 | sot_token = _tokenizer.encoder["<|startoftext|>"]
210 | eot_token = _tokenizer.encoder["<|endoftext|>"]
211 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
212 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
213 |
214 | for i, tokens in enumerate(all_tokens):
215 | if len(tokens) > context_length:
216 | if truncate:
217 | tokens = tokens[:context_length]
218 | tokens[-1] = eot_token
219 | else:
220 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
221 | result[i, :len(tokens)] = torch.tensor(tokens)
222 |
223 | return result
224 |
--------------------------------------------------------------------------------
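A short usage note for tokenize() above: every input string is wrapped with the start/end-of-text tokens and padded (or, with truncate=True, clipped) to the fixed 77-token context window. A small sketch:

from clip import clip

tokens = clip.tokenize(['a photo of a dog', 'a centered satellite photo of forest.'])
print(tokens.shape)                 # torch.Size([2, 77])

# Over-long inputs raise a RuntimeError unless truncate=True, in which case the
# sequence is clipped and terminated with the end-of-text token.
long_text = 'word ' * 200
tokens = clip.tokenize(long_text, truncate=True)
print(tokens.shape)                 # torch.Size([1, 77])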
/configs/upl_default_config/upl_default.py:
--------------------------------------------------------------------------------
1 | from yacs.config import CfgNode as CN
2 |
3 | ###########################
4 | # Config definition
5 | ###########################
6 |
7 | _C = CN()
8 |
9 | _C.VERSION = 1
10 |
11 | # Directory to save the output files (like log.txt and model weights)
12 | _C.OUTPUT_DIR = "./output"
13 | # Path to a directory where the files were saved previously
14 | _C.RESUME = ""
15 | # Set seed to negative value to randomize everything
16 | # Set seed to positive value to use a fixed seed
17 | _C.SEED = -1
18 | _C.USE_CUDA = True
19 | # Print detailed information
20 | # E.g. trainer, dataset, and backbone
21 | _C.VERBOSE = True
22 |
23 | ###########################
24 | # Input
25 | ###########################
26 | _C.INPUT = CN()
27 | _C.INPUT.SIZE = (224, 224)
28 | # Mode of interpolation in resize functions
29 | _C.INPUT.INTERPOLATION = "bilinear"
30 | # For available choices please refer to transforms.py
31 | _C.INPUT.TRANSFORMS = ()
32 | # If True, tfm_train and tfm_test will be None
33 | _C.INPUT.NO_TRANSFORM = False
34 | # Default mean and std come from ImageNet
35 | _C.INPUT.PIXEL_MEAN = [0.485, 0.456, 0.406]
36 | _C.INPUT.PIXEL_STD = [0.229, 0.224, 0.225]
37 | # Padding for random crop
38 | _C.INPUT.CROP_PADDING = 4
39 | # Cutout
40 | _C.INPUT.CUTOUT_N = 1
41 | _C.INPUT.CUTOUT_LEN = 16
42 | # Gaussian noise
43 | _C.INPUT.GN_MEAN = 0.0
44 | _C.INPUT.GN_STD = 0.15
45 | # RandomAugment
46 | _C.INPUT.RANDAUGMENT_N = 2
47 | _C.INPUT.RANDAUGMENT_M = 10
48 | # ColorJitter (brightness, contrast, saturation, hue)
49 | _C.INPUT.COLORJITTER_B = 0.4
50 | _C.INPUT.COLORJITTER_C = 0.4
51 | _C.INPUT.COLORJITTER_S = 0.4
52 | _C.INPUT.COLORJITTER_H = 0.1
53 | # Random gray scale's probability
54 | _C.INPUT.RGS_P = 0.2
55 | # Gaussian blur
56 | _C.INPUT.GB_P = 0.5 # probability of applying this operation
57 | _C.INPUT.GB_K = 21 # kernel size (should be an odd number)
58 |
59 | ###########################
60 | # Dataset
61 | ###########################
62 | _C.DATASET = CN()
63 | # Directory where datasets are stored
64 | _C.DATASET.ROOT = ""
65 | _C.DATASET.NAME = ""
66 | # List of names of source domains
67 | _C.DATASET.SOURCE_DOMAINS = ()
68 | # List of names of target domains
69 | _C.DATASET.TARGET_DOMAINS = ()
70 | # Number of labeled instances in total
71 | # Useful for the semi-supervised learning
72 | _C.DATASET.NUM_LABELED = -1
73 | # Number of images per class
74 | _C.DATASET.NUM_SHOTS = -1
75 | # Percentage of validation data (only used for SSL datasets)
76 | # Set to 0 if do not want to use val data
77 | # Using val data for hyperparameter tuning was done in Oliver et al. 2018
78 | _C.DATASET.VAL_PERCENT = 0.1
79 | # Fold index for STL-10 dataset (normal range is 0 - 9)
80 | # Negative number means None
81 | _C.DATASET.STL10_FOLD = -1
82 | # CIFAR-10/100-C's corruption type and intensity level
83 | _C.DATASET.CIFAR_C_TYPE = ""
84 | _C.DATASET.CIFAR_C_LEVEL = 1
85 | # Use all data in the unlabeled data set (e.g. FixMatch)
86 | _C.DATASET.ALL_AS_UNLABELED = False
87 | _C.DATASET.IGNORE_NUM = 0
88 | _C.DATASET.IGNORE_FILE = ""
89 | _C.DATASET.CLASS_EQULE = True
90 | _C.DATASET.CONF_THRESHOLD = 0.9
91 |
92 | ###########################
93 | # Dataloader
94 | ###########################
95 | _C.DATALOADER = CN()
96 | _C.DATALOADER.NUM_WORKERS = 4
97 | # Apply transformations to an image K times (during training)
98 | _C.DATALOADER.K_TRANSFORMS = 1
99 | # img0 denotes image tensor without augmentation
100 | # Useful for consistency learning
101 | _C.DATALOADER.RETURN_IMG0 = False
102 | # Setting for the train_x data-loader
103 | _C.DATALOADER.TRAIN_X = CN()
104 | _C.DATALOADER.TRAIN_X.SAMPLER = "RandomSampler"
105 | _C.DATALOADER.TRAIN_X.BATCH_SIZE = 32
106 | # Parameter for RandomDomainSampler
107 | # 0 or -1 means sampling from all domains
108 | _C.DATALOADER.TRAIN_X.N_DOMAIN = 0
109 | # Parameter of RandomClassSampler
110 | # Number of instances per class
111 | _C.DATALOADER.TRAIN_X.N_INS = 16
112 |
113 | # Setting for the train_u data-loader
114 | _C.DATALOADER.TRAIN_U = CN()
115 | # Set to false if you want to have unique
116 | # data loader params for train_u
117 | _C.DATALOADER.TRAIN_U.SAME_AS_X = True
118 | _C.DATALOADER.TRAIN_U.SAMPLER = "RandomSampler"
119 | _C.DATALOADER.TRAIN_U.BATCH_SIZE = 32
120 | _C.DATALOADER.TRAIN_U.N_DOMAIN = 0
121 | _C.DATALOADER.TRAIN_U.N_INS = 16
122 |
123 | # Setting for the test data-loader
124 | _C.DATALOADER.TEST = CN()
125 | _C.DATALOADER.TEST.SAMPLER = "SequentialSampler"
126 | _C.DATALOADER.TEST.BATCH_SIZE = 32
127 |
128 | _C.DATALOADER.OPEN_SETTING = False
129 |
130 | ###########################
131 | # Model
132 | ###########################
133 | _C.MODEL = CN()
134 | # Path to model weights (for initialization)
135 | _C.MODEL.INIT_WEIGHTS = ""
136 | _C.MODEL.BACKBONE = CN()
137 | _C.MODEL.BACKBONE.NAME = ""
138 | _C.MODEL.BACKBONE.PRETRAINED = True
139 | # Definition of embedding layers
140 | _C.MODEL.HEAD = CN()
141 | # If none, do not construct embedding layers, the
142 | # backbone's output will be passed to the classifier
143 | _C.MODEL.HEAD.NAME = ""
144 | # Structure of hidden layers (a list), e.g. [512, 512]
145 | # If undefined, no embedding layer will be constructed
146 | _C.MODEL.HEAD.HIDDEN_LAYERS = ()
147 | _C.MODEL.HEAD.ACTIVATION = "relu"
148 | _C.MODEL.HEAD.BN = True
149 | _C.MODEL.HEAD.DROPOUT = 0.0
150 |
151 | _C.MODEL.PSEUDO_LABEL_MODELS = []
152 |
153 | ###########################
154 | # Optimization
155 | ###########################
156 | _C.OPTIM = CN()
157 | _C.OPTIM.NAME = "adam"
158 | _C.OPTIM.LR = 0.0003
159 | _C.OPTIM.WEIGHT_DECAY = 5e-4
160 | _C.OPTIM.MOMENTUM = 0.9
161 | _C.OPTIM.SGD_DAMPNING = 0
162 | _C.OPTIM.SGD_NESTEROV = False
163 | _C.OPTIM.RMSPROP_ALPHA = 0.99
164 | _C.OPTIM.ADAM_BETA1 = 0.9
165 | _C.OPTIM.ADAM_BETA2 = 0.999
166 | # STAGED_LR allows different layers to have
167 | # different lr, e.g. pre-trained base layers
168 | # can be assigned a smaller lr than the new
169 | # classification layer
170 | _C.OPTIM.STAGED_LR = False
171 | _C.OPTIM.NEW_LAYERS = ()
172 | _C.OPTIM.BASE_LR_MULT = 0.1
173 | # Learning rate scheduler
174 | _C.OPTIM.LR_SCHEDULER = "single_step"
175 | # -1 or 0 means the stepsize is equal to max_epoch
176 | _C.OPTIM.STEPSIZE = (-1, )
177 | _C.OPTIM.GAMMA = 0.1
178 | _C.OPTIM.MAX_EPOCH = 10
179 | # Set WARMUP_EPOCH larger than 0 to activate warmup training
180 | _C.OPTIM.WARMUP_EPOCH = -1
181 | # Either linear or constant
182 | _C.OPTIM.WARMUP_TYPE = "linear"
183 | # Constant learning rate when type=constant
184 | _C.OPTIM.WARMUP_CONS_LR = 1e-5
185 | # Minimum learning rate when type=linear
186 | _C.OPTIM.WARMUP_MIN_LR = 1e-5
187 | # Recount epoch for the next scheduler (last_epoch=-1)
188 | # Otherwise last_epoch=warmup_epoch
189 | _C.OPTIM.WARMUP_RECOUNT = True
190 |
191 | ###########################
192 | # Train
193 | ###########################
194 | _C.TRAIN = CN()
195 | # How often (epoch) to save model during training
196 | # Set to 0 or negative value to only save the last one
197 | _C.TRAIN.CHECKPOINT_FREQ = 0
198 | # How often (batch) to print training information
199 | _C.TRAIN.PRINT_FREQ = 10
200 | # Use 'train_x', 'train_u' or 'smaller_one' to count
201 | # the number of iterations in an epoch (for DA and SSL)
202 | _C.TRAIN.COUNT_ITER = "train_x"
203 |
204 | ###########################
205 | # Test
206 | ###########################
207 | _C.TEST = CN()
208 | _C.TEST.EVALUATOR = "UPLClassification"
209 | _C.TEST.PER_CLASS_RESULT = False
210 | # Compute confusion matrix, which will be saved
211 | # to $OUTPUT_DIR/cmat.pt
212 | _C.TEST.COMPUTE_CMAT = False
213 | # If NO_TEST=True, no testing will be conducted
214 | _C.TEST.NO_TEST = False
215 | # Use test or val set for FINAL evaluation
216 | _C.TEST.SPLIT = "test"
217 | # Which model to test after training
218 | # Either last_step or best_val
219 | _C.TEST.FINAL_MODEL = "last_step"
220 | _C.TEST.Analyze_Result_Path = None
221 |
222 | ###########################
223 | # Trainer specifics
224 | ###########################
225 | _C.TRAINER = CN()
226 | _C.TRAINER.NAME = ""
227 |
228 | # MCD
229 | _C.TRAINER.MCD = CN()
230 | _C.TRAINER.MCD.N_STEP_F = 4 # number of steps to train F
231 | # MME
232 | _C.TRAINER.MME = CN()
233 | _C.TRAINER.MME.LMDA = 0.1 # weight for the entropy loss
234 | # SelfEnsembling
235 | _C.TRAINER.SE = CN()
236 | _C.TRAINER.SE.EMA_ALPHA = 0.999
237 | _C.TRAINER.SE.CONF_THRE = 0.95
238 | _C.TRAINER.SE.RAMPUP = 300
239 |
240 | # M3SDA
241 | _C.TRAINER.M3SDA = CN()
242 | _C.TRAINER.M3SDA.LMDA = 0.5 # weight for the moment distance loss
243 | _C.TRAINER.M3SDA.N_STEP_F = 4 # follow MCD
244 | # DAEL
245 | _C.TRAINER.DAEL = CN()
246 | _C.TRAINER.DAEL.WEIGHT_U = 0.5 # weight on the unlabeled loss
247 | _C.TRAINER.DAEL.CONF_THRE = 0.95 # confidence threshold
248 | _C.TRAINER.DAEL.STRONG_TRANSFORMS = ()
249 |
250 | # CrossGrad
251 | _C.TRAINER.CG = CN()
252 | _C.TRAINER.CG.EPS_F = 1.0 # scaling parameter for D's gradients
253 | _C.TRAINER.CG.EPS_D = 1.0 # scaling parameter for F's gradients
254 | _C.TRAINER.CG.ALPHA_F = 0.5 # balancing weight for the label net's loss
255 | _C.TRAINER.CG.ALPHA_D = 0.5 # balancing weight for the domain net's loss
256 | # DDAIG
257 | _C.TRAINER.DDAIG = CN()
258 | _C.TRAINER.DDAIG.G_ARCH = "" # generator's architecture
259 | _C.TRAINER.DDAIG.LMDA = 0.3 # perturbation weight
260 | _C.TRAINER.DDAIG.CLAMP = False # clamp perturbation values
261 | _C.TRAINER.DDAIG.CLAMP_MIN = -1.0
262 | _C.TRAINER.DDAIG.CLAMP_MAX = 1.0
263 | _C.TRAINER.DDAIG.WARMUP = 0
264 | _C.TRAINER.DDAIG.ALPHA = 0.5 # balancing weight for the losses
265 |
266 | # EntMin
267 | _C.TRAINER.ENTMIN = CN()
268 | _C.TRAINER.ENTMIN.LMDA = 1e-3 # weight on the entropy loss
269 | # Mean Teacher
270 | _C.TRAINER.MEANTEA = CN()
271 | _C.TRAINER.MEANTEA.WEIGHT_U = 1.0 # weight on the unlabeled loss
272 | _C.TRAINER.MEANTEA.EMA_ALPHA = 0.999
273 | _C.TRAINER.MEANTEA.RAMPUP = 5 # epochs used to ramp up the loss_u weight
274 | # MixMatch
275 | _C.TRAINER.MIXMATCH = CN()
276 | _C.TRAINER.MIXMATCH.WEIGHT_U = 100.0 # weight on the unlabeled loss
277 | _C.TRAINER.MIXMATCH.TEMP = 2.0 # temperature for sharpening the probability
278 | _C.TRAINER.MIXMATCH.MIXUP_BETA = 0.75
279 | _C.TRAINER.MIXMATCH.RAMPUP = 20000 # steps used to ramp up the loss_u weight
280 | # FixMatch
281 | _C.TRAINER.FIXMATCH = CN()
282 | _C.TRAINER.FIXMATCH.WEIGHT_U = 1.0 # weight on the unlabeled loss
283 | _C.TRAINER.FIXMATCH.CONF_THRE = 0.95 # confidence threshold
284 | _C.TRAINER.FIXMATCH.STRONG_TRANSFORMS = ()
285 |
286 | _C.TRAINER.ENSEMBLE_NUM = 1
287 |
288 |
289 | def get_cfg_default():
290 | return _C.clone()
--------------------------------------------------------------------------------
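The defaults above are meant to be cloned and then overridden. A minimal sketch mirroring setup_cfg() in upl_train.py / upl_test.py (the override values are illustrative; trainer-specific yaml files additionally require extend_cfg() to be applied first):

from configs.upl_default_config.upl_default import get_cfg_default

cfg = get_cfg_default()
cfg.merge_from_file('configs/datasets/ssoxford_pets.yaml')               # dataset config
cfg.merge_from_list(['OPTIM.MAX_EPOCH', 50, 'DATASET.NUM_SHOTS', 16])    # CLI-style opts
cfg.freeze()
print(cfg.DATASET.NAME, cfg.OPTIM.MAX_EPOCH, cfg.DATASET.NUM_SHOTS)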
/trainers/coop.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn import functional as F
6 | from torch.cuda.amp import GradScaler, autocast
7 |
8 | from dassl.engine import TRAINER_REGISTRY, TrainerX
9 | from dassl.metrics import compute_accuracy
10 | from dassl.utils import load_pretrained_weights, load_checkpoint
11 | from dassl.optim import build_optimizer, build_lr_scheduler
12 |
13 | from clip import clip
14 | from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
15 |
16 | _tokenizer = _Tokenizer()
17 |
18 |
19 | def load_clip_to_cpu(cfg):
20 | backbone_name = cfg.MODEL.BACKBONE.NAME
21 | url = clip._MODELS[backbone_name]
22 | model_path = clip._download(url)
23 |
24 | try:
25 | # loading JIT archive
26 | model = torch.jit.load(model_path, map_location="cpu").eval()
27 | state_dict = None
28 |
29 | except RuntimeError:
30 | state_dict = torch.load(model_path, map_location="cpu")
31 |
32 | model = clip.build_model(state_dict or model.state_dict())
33 |
34 | return model
35 |
36 |
37 | class TextEncoder(nn.Module):
38 | def __init__(self, clip_model):
39 | super().__init__()
40 | self.transformer = clip_model.transformer
41 | self.positional_embedding = clip_model.positional_embedding
42 | self.ln_final = clip_model.ln_final
43 | self.text_projection = clip_model.text_projection
44 | self.dtype = clip_model.dtype
45 |
46 | def forward(self, prompts, tokenized_prompts):
47 | x = prompts + self.positional_embedding.type(self.dtype)
48 | x = x.permute(1, 0, 2) # NLD -> LND
49 | x = self.transformer(x)
50 | x = x.permute(1, 0, 2) # LND -> NLD
51 | x = self.ln_final(x).type(self.dtype)
52 |
53 | # x.shape = [batch_size, n_ctx, transformer.width]
54 | # take features from the eot embedding (eot_token is the highest number in each sequence)
55 | x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection
56 |
57 | return x
58 |
59 |
60 | class PromptLearner(nn.Module):
61 | def __init__(self, cfg, classnames, clip_model):
62 | super().__init__()
63 | n_cls = len(classnames)
64 | n_ctx = cfg.TRAINER.COOP.N_CTX
65 | ctx_init = cfg.TRAINER.COOP.CTX_INIT
66 | dtype = clip_model.dtype
67 | ctx_dim = clip_model.ln_final.weight.shape[0]
68 | clip_imsize = clip_model.visual.input_resolution
69 | cfg_imsize = cfg.INPUT.SIZE[0]
70 | assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})"
71 |
72 | if ctx_init:
73 | # use given words to initialize context vectors
74 | ctx_init = ctx_init.replace("_", " ")
75 | n_ctx = len(ctx_init.split(" "))
76 | prompt = clip.tokenize(ctx_init)
77 | with torch.no_grad():
78 | embedding = clip_model.token_embedding(prompt).type(dtype)
79 | ctx_vectors = embedding[0, 1 : 1 + n_ctx, :]
80 | prompt_prefix = ctx_init
81 |
82 | else:
83 | # random initialization
84 | if cfg.TRAINER.COOP.CSC:
85 | print("Initializing class-specific contexts")
86 | ctx_vectors = torch.empty(n_cls, n_ctx, ctx_dim, dtype=dtype)
87 | else:
88 | print("Initializing a generic context")
89 | ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
90 | nn.init.normal_(ctx_vectors, std=0.02)
91 | prompt_prefix = " ".join(["X"] * n_ctx)
92 |
93 | print(f'Initial context: "{prompt_prefix}"')
94 | print(f"Number of context words (tokens): {n_ctx}")
95 |
96 | self.ctx = nn.Parameter(ctx_vectors) # to be optimized
97 |
98 | classnames = [name.replace("_", " ") for name in classnames]
99 | name_lens = [len(_tokenizer.encode(name)) for name in classnames]
100 | prompts = [prompt_prefix + " " + name + "." for name in classnames]
101 |
102 | tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts])
103 | with torch.no_grad():
104 | embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)
105 |
106 | # These token vectors will be saved by save_model(),
107 | # but they should be ignored in load_model() as we want to use
108 | # those computed using the current class names
109 | self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS
110 | self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :]) # CLS, EOS
111 |
112 | self.n_cls = n_cls
113 | self.n_ctx = n_ctx
114 | self.tokenized_prompts = tokenized_prompts # torch.Tensor
115 | self.name_lens = name_lens
116 | self.class_token_position = cfg.TRAINER.COOP.CLASS_TOKEN_POSITION
117 |
118 | def forward(self):
119 | ctx = self.ctx
120 | if ctx.dim() == 2:
121 | ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)
122 |
123 | prefix = self.token_prefix
124 | suffix = self.token_suffix
125 |
126 | if self.class_token_position == "end":
127 | prompts = torch.cat(
128 | [
129 | prefix, # (n_cls, 1, dim)
130 | ctx, # (n_cls, n_ctx, dim)
131 | suffix, # (n_cls, *, dim)
132 | ],
133 | dim=1,
134 | )
135 |
136 | elif self.class_token_position == "middle":
137 | half_n_ctx = self.n_ctx // 2
138 | prompts = []
139 | for i in range(self.n_cls):
140 | name_len = self.name_lens[i]
141 | prefix_i = prefix[i : i + 1, :, :]
142 | class_i = suffix[i : i + 1, :name_len, :]
143 | suffix_i = suffix[i : i + 1, name_len:, :]
144 | ctx_i_half1 = ctx[i : i + 1, :half_n_ctx, :]
145 | ctx_i_half2 = ctx[i : i + 1, half_n_ctx:, :]
146 | prompt = torch.cat(
147 | [
148 | prefix_i, # (1, 1, dim)
149 | ctx_i_half1, # (1, n_ctx//2, dim)
150 | class_i, # (1, name_len, dim)
151 | ctx_i_half2, # (1, n_ctx//2, dim)
152 | suffix_i, # (1, *, dim)
153 | ],
154 | dim=1,
155 | )
156 | prompts.append(prompt)
157 | prompts = torch.cat(prompts, dim=0)
158 |
159 | elif self.class_token_position == "front":
160 | prompts = []
161 | for i in range(self.n_cls):
162 | name_len = self.name_lens[i]
163 | prefix_i = prefix[i : i + 1, :, :]
164 | class_i = suffix[i : i + 1, :name_len, :]
165 | suffix_i = suffix[i : i + 1, name_len:, :]
166 | ctx_i = ctx[i : i + 1, :, :]
167 | prompt = torch.cat(
168 | [
169 | prefix_i, # (1, 1, dim)
170 | class_i, # (1, name_len, dim)
171 | ctx_i, # (1, n_ctx, dim)
172 | suffix_i, # (1, *, dim)
173 | ],
174 | dim=1,
175 | )
176 | prompts.append(prompt)
177 | prompts = torch.cat(prompts, dim=0)
178 |
179 | else:
180 | raise ValueError
181 |
182 | return prompts
183 |
184 |
185 | class CustomCLIP(nn.Module):
186 | def __init__(self, cfg, classnames, clip_model):
187 | super().__init__()
188 | self.prompt_learner = PromptLearner(cfg, classnames, clip_model)
189 | self.tokenized_prompts = self.prompt_learner.tokenized_prompts
190 | self.image_encoder = clip_model.visual
191 | self.text_encoder = TextEncoder(clip_model)
192 | self.logit_scale = clip_model.logit_scale
193 | self.dtype = clip_model.dtype
194 |
195 | def forward(self, image):
196 | image_features = self.image_encoder(image.type(self.dtype))
197 |
198 | prompts = self.prompt_learner()
199 | tokenized_prompts = self.tokenized_prompts
200 | text_features = self.text_encoder(prompts, tokenized_prompts)
201 |
202 | image_features = image_features / image_features.norm(dim=-1, keepdim=True)
203 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
204 |
205 | logit_scale = self.logit_scale.exp()
206 | logits = logit_scale * image_features @ text_features.t()
207 |
208 | return logits
209 |
210 |
211 | @TRAINER_REGISTRY.register()
212 | class CoOp(TrainerX):
213 | """Context Optimization (CoOp).
214 |
215 | Learning to Prompt for Vision-Language Models
216 | https://arxiv.org/abs/2109.01134
217 | """
218 |
219 | def check_cfg(self, cfg):
220 | assert cfg.TRAINER.COOP.PREC in ["fp16", "fp32", "amp"]
221 |
222 | def build_model(self):
223 | cfg = self.cfg
224 | classnames = self.dm.dataset.classnames
225 |
226 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
227 | clip_model = load_clip_to_cpu(cfg)
228 |
229 | if cfg.TRAINER.COOP.PREC == "fp32" or cfg.TRAINER.COOP.PREC == "amp":
230 | # CLIP's default precision is fp16
231 | clip_model.float()
232 |
233 | print("Building custom CLIP")
234 | self.model = CustomCLIP(cfg, classnames, clip_model)
235 |
236 | print("Turning off gradients in both the image and the text encoder")
237 | for name, param in self.model.named_parameters():
238 | if "prompt_learner" not in name:
239 | param.requires_grad_(False)
240 |
241 | if cfg.MODEL.INIT_WEIGHTS:
242 | load_pretrained_weights(self.model.prompt_learner, cfg.MODEL.INIT_WEIGHTS)
243 |
244 | self.model.to(self.device)
245 | # NOTE: only give prompt_learner to the optimizer
246 | self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM)
247 | self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
248 | self.register_model("prompt_learner", self.model.prompt_learner, self.optim, self.sched)
249 |
250 | self.scaler = GradScaler() if cfg.TRAINER.COOP.PREC == "amp" else None
251 |
252 | # Note that multi-gpu training could be slow because CLIP's size is
253 | # big, which slows down the copy operation in DataParallel
254 | device_count = torch.cuda.device_count()
255 | if device_count > 1:
256 | print(f"Multiple GPUs detected (n_gpus={device_count}), use all of them!")
257 | self.model = nn.DataParallel(self.model)
258 |
259 | def forward_backward(self, batch):
260 | image, label = self.parse_batch_train(batch)
261 |
262 | prec = self.cfg.TRAINER.COOP.PREC
263 | if prec == "amp":
264 | with autocast():
265 | output = self.model(image)
266 | loss = F.cross_entropy(output, label)
267 | self.optim.zero_grad()
268 | self.scaler.scale(loss).backward()
269 | self.scaler.step(self.optim)
270 | self.scaler.update()
271 | else:
272 | output = self.model(image)
273 | loss = F.cross_entropy(output, label)
274 | self.model_backward_and_update(loss)
275 |
276 | loss_summary = {
277 | "loss": loss.item(),
278 | "acc": compute_accuracy(output, label)[0].item(),
279 | }
280 |
281 | if (self.batch_idx + 1) == self.num_batches:
282 | self.update_lr()
283 |
284 | return loss_summary
285 |
286 | def parse_batch_train(self, batch):
287 | input = batch["img"]
288 | label = batch["label"]
289 | input = input.to(self.device)
290 | label = label.to(self.device)
291 | return input, label
292 |
293 | def load_model(self, directory, epoch=None):
294 | if not directory:
295 | print("Note that load_model() is skipped as no pretrained model is given")
296 | return
297 |
298 | names = self.get_model_names()
299 |
300 | # By default, the best model is loaded
301 | model_file = "model-best.pth.tar"
302 |
303 | if epoch is not None:
304 | model_file = "model.pth.tar-" + str(epoch)
305 |
306 | for name in names:
307 | model_path = osp.join(directory, name, model_file)
308 |
309 | if not osp.exists(model_path):
310 | raise FileNotFoundError('Model not found at "{}"'.format(model_path))
311 |
312 | checkpoint = load_checkpoint(model_path)
313 | state_dict = checkpoint["state_dict"]
314 | epoch = checkpoint["epoch"]
315 |
316 | # Ignore fixed token vectors
317 | if "token_prefix" in state_dict:
318 | del state_dict["token_prefix"]
319 |
320 | if "token_suffix" in state_dict:
321 | del state_dict["token_suffix"]
322 |
323 | print("Loading weights to {} " 'from "{}" (epoch = {})'.format(name, model_path, epoch))
324 | # set strict=False
325 | self._models[name].load_state_dict(state_dict, strict=False)
326 |
327 | def load_model_by_id(self, directory, model_id, epoch=None):
328 | if not directory:
329 | print(
330 | 'Note that load_model() is skipped as no pretrained model is given'
331 | )
332 | return
333 |
334 | names = self.get_model_names()
335 |
336 | # By default, the best model is loaded
337 | model_file = 'model-best-{}.pth.tar'.format(model_id)
338 |
339 | if epoch is not None:
340 | model_file = 'model.pth.tar-' + str(epoch)
341 |
342 | for name in names:
343 | model_path = osp.join(directory, name, model_file)
344 |
345 | if not osp.exists(model_path):
346 | raise FileNotFoundError(
347 | 'Model not found at "{}"'.format(model_path)
348 | )
349 |
350 | checkpoint = load_checkpoint(model_path)
351 | state_dict = checkpoint['state_dict']
352 | epoch = checkpoint['epoch']
353 |
354 | # Ignore fixed token vectors
355 | if 'token_prefix' in state_dict:
356 | del state_dict['token_prefix']
357 |
358 | if 'token_suffix' in state_dict:
359 | del state_dict['token_suffix']
360 |
361 | print(
362 | 'Loading weights to {} '
363 | 'from "{}" (epoch = {})'.format(name, model_path, epoch)
364 | )
365 | # set strict=False
366 | self._models[name].load_state_dict(state_dict, strict=False)
367 |
368 |
--------------------------------------------------------------------------------
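A shape sketch of the prompt assembly performed in PromptLearner.forward() for the default 'end' token position: the learned context vectors are spliced between the SOS embedding and the class-name/EOS embeddings, giving one fixed-length prompt per class (the dimensions below are illustrative, not tied to a specific backbone):

import torch

n_cls, n_ctx, ctx_dim, context_length = 10, 16, 512, 77
prefix = torch.zeros(n_cls, 1, ctx_dim)                                # SOS token embedding
ctx = torch.zeros(n_ctx, ctx_dim).unsqueeze(0).expand(n_cls, -1, -1)   # shared learned context
suffix = torch.zeros(n_cls, context_length - 1 - n_ctx, ctx_dim)       # class tokens, EOS, padding
prompts = torch.cat([prefix, ctx, suffix], dim=1)
print(prompts.shape)                                                   # torch.Size([10, 77, 512])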
/clip/model.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from typing import Tuple, Union
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | from torch import nn
8 |
9 |
10 | class Bottleneck(nn.Module):
11 | expansion = 4
12 |
13 | def __init__(self, inplanes, planes, stride=1):
14 | super().__init__()
15 |
16 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
18 | self.bn1 = nn.BatchNorm2d(planes)
19 |
20 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
21 | self.bn2 = nn.BatchNorm2d(planes)
22 |
23 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
24 |
25 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
26 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
27 |
28 | self.relu = nn.ReLU(inplace=True)
29 | self.downsample = None
30 | self.stride = stride
31 |
32 | if stride > 1 or inplanes != planes * Bottleneck.expansion:
33 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
34 | self.downsample = nn.Sequential(OrderedDict([
35 | ("-1", nn.AvgPool2d(stride)),
36 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
37 | ("1", nn.BatchNorm2d(planes * self.expansion))
38 | ]))
39 |
40 | def forward(self, x: torch.Tensor):
41 | identity = x
42 |
43 | out = self.relu(self.bn1(self.conv1(x)))
44 | out = self.relu(self.bn2(self.conv2(out)))
45 | out = self.avgpool(out)
46 | out = self.bn3(self.conv3(out))
47 |
48 | if self.downsample is not None:
49 | identity = self.downsample(x)
50 |
51 | out += identity
52 | out = self.relu(out)
53 | return out
54 |
55 |
56 | class AttentionPool2d(nn.Module):
57 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
58 | super().__init__()
59 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
60 | self.k_proj = nn.Linear(embed_dim, embed_dim)
61 | self.q_proj = nn.Linear(embed_dim, embed_dim)
62 | self.v_proj = nn.Linear(embed_dim, embed_dim)
63 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
64 | self.num_heads = num_heads
65 |
66 | def forward(self, x):
67 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
68 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
69 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
70 | x, _ = F.multi_head_attention_forward(
71 | query=x, key=x, value=x,
72 | embed_dim_to_check=x.shape[-1],
73 | num_heads=self.num_heads,
74 | q_proj_weight=self.q_proj.weight,
75 | k_proj_weight=self.k_proj.weight,
76 | v_proj_weight=self.v_proj.weight,
77 | in_proj_weight=None,
78 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
79 | bias_k=None,
80 | bias_v=None,
81 | add_zero_attn=False,
82 | dropout_p=0,
83 | out_proj_weight=self.c_proj.weight,
84 | out_proj_bias=self.c_proj.bias,
85 | use_separate_proj_weight=True,
86 | training=self.training,
87 | need_weights=False
88 | )
89 |
90 | return x[0]
91 |
92 |
93 | class ModifiedResNet(nn.Module):
94 | """
95 | A ResNet class that is similar to torchvision's but contains the following changes:
96 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
97 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
98 | - The final pooling layer is a QKV attention instead of an average pool
99 | """
100 |
101 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
102 | super().__init__()
103 | self.output_dim = output_dim
104 | self.input_resolution = input_resolution
105 |
106 | # the 3-layer stem
107 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
108 | self.bn1 = nn.BatchNorm2d(width // 2)
109 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
110 | self.bn2 = nn.BatchNorm2d(width // 2)
111 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
112 | self.bn3 = nn.BatchNorm2d(width)
113 | self.avgpool = nn.AvgPool2d(2)
114 | self.relu = nn.ReLU(inplace=True)
115 |
116 | # residual layers
117 | self._inplanes = width # this is a *mutable* variable used during construction
118 | self.layer1 = self._make_layer(width, layers[0])
119 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
120 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
121 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
122 |
123 | embed_dim = width * 32 # the ResNet feature dimension
124 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
125 |
126 | def _make_layer(self, planes, blocks, stride=1):
127 | layers = [Bottleneck(self._inplanes, planes, stride)]
128 |
129 | self._inplanes = planes * Bottleneck.expansion
130 | for _ in range(1, blocks):
131 | layers.append(Bottleneck(self._inplanes, planes))
132 |
133 | return nn.Sequential(*layers)
134 |
135 | def forward(self, x):
136 | def stem(x):
137 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
138 | x = self.relu(bn(conv(x)))
139 | x = self.avgpool(x)
140 | return x
141 |
142 | x = x.type(self.conv1.weight.dtype)
143 | x = stem(x)
144 | x = self.layer1(x)
145 | x = self.layer2(x)
146 | x = self.layer3(x)
147 | x = self.layer4(x)
148 | x = self.attnpool(x)
149 |
150 | return x
151 |
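# Shape sketch for ModifiedResNet (hypothetical RN50-style configuration:
# layers=(3, 4, 6, 3) and width=64, so embed_dim = 64 * 32 = 2048 and heads = 32):
#
#   backbone = ModifiedResNet(layers=(3, 4, 6, 3), output_dim=1024, heads=32,
#                             input_resolution=224, width=64)
#   feats = backbone(torch.randn(2, 3, 224, 224))   # -> tensor of shape [2, 1024]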
152 |
153 | class LayerNorm(nn.LayerNorm):
154 | """Subclass torch's LayerNorm to handle fp16."""
155 |
156 | def forward(self, x: torch.Tensor):
157 | orig_type = x.dtype
158 | ret = super().forward(x.type(torch.float32))
159 | return ret.type(orig_type)
160 |
161 |
162 | class QuickGELU(nn.Module):
163 | def forward(self, x: torch.Tensor):
164 | return x * torch.sigmoid(1.702 * x)
165 |
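# QuickGELU is the cheap sigmoid approximation x * sigmoid(1.702 * x) of GELU; the
# released CLIP checkpoints were trained with this activation, so it is kept here
# rather than swapped for nn.GELU.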
166 |
167 | class ResidualAttentionBlock(nn.Module):
168 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
169 | super().__init__()
170 |
171 | self.attn = nn.MultiheadAttention(d_model, n_head)
172 | self.ln_1 = LayerNorm(d_model)
173 | self.mlp = nn.Sequential(OrderedDict([
174 | ("c_fc", nn.Linear(d_model, d_model * 4)),
175 | ("gelu", QuickGELU()),
176 | ("c_proj", nn.Linear(d_model * 4, d_model))
177 | ]))
178 | self.ln_2 = LayerNorm(d_model)
179 | self.attn_mask = attn_mask
180 |
181 | def attention(self, x: torch.Tensor):
182 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
183 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
184 |
185 | def forward(self, x: torch.Tensor):
186 | x = x + self.attention(self.ln_1(x))
187 | x = x + self.mlp(self.ln_2(x))
188 | return x
189 |
190 |
191 | class Transformer(nn.Module):
192 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
193 | super().__init__()
194 | self.width = width
195 | self.layers = layers
196 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
197 |
198 | def forward(self, x: torch.Tensor):
199 | return self.resblocks(x)
200 |
201 |
202 | class VisionTransformer(nn.Module):
203 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
204 | super().__init__()
205 | self.input_resolution = input_resolution
206 | self.output_dim = output_dim
207 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
208 |
209 | scale = width ** -0.5
210 | self.class_embedding = nn.Parameter(scale * torch.randn(width))
211 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
212 | self.ln_pre = LayerNorm(width)
213 |
214 | self.transformer = Transformer(width, layers, heads)
215 |
216 | self.ln_post = LayerNorm(width)
217 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
218 |
219 | def forward(self, x: torch.Tensor):
220 | x = self.conv1(x) # shape = [*, width, grid, grid]
221 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
222 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
223 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
224 | x = x + self.positional_embedding.to(x.dtype)
225 | x = self.ln_pre(x)
226 |
227 | x = x.permute(1, 0, 2) # NLD -> LND
228 | x = self.transformer(x)
229 | x = x.permute(1, 0, 2) # LND -> NLD
230 |
231 | x = self.ln_post(x[:, 0, :])
232 |
233 | if self.proj is not None:
234 | x = x @ self.proj
235 |
236 | return x
237 |
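# Shape sketch for VisionTransformer (hypothetical ViT-B/32-style numbers:
# patch_size=32, width=768, layers=12, heads=12, projected to a 512-d embedding):
#
#   vit = VisionTransformer(input_resolution=224, patch_size=32, width=768,
#                           layers=12, heads=12, output_dim=512)
#   feats = vit(torch.randn(2, 3, 224, 224))   # -> tensor of shape [2, 512]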
238 |
239 | class CLIP(nn.Module):
240 | def __init__(self,
241 | embed_dim: int,
242 | # vision
243 | image_resolution: int,
244 | vision_layers: Union[Tuple[int, int, int, int], int],
245 | vision_width: int,
246 | vision_patch_size: int,
247 | # text
248 | context_length: int,
249 | vocab_size: int,
250 | transformer_width: int,
251 | transformer_heads: int,
252 | transformer_layers: int
253 | ):
254 | super().__init__()
255 |
256 | self.context_length = context_length
257 |
258 | if isinstance(vision_layers, (tuple, list)):
259 | vision_heads = vision_width * 32 // 64
260 | self.visual = ModifiedResNet(
261 | layers=vision_layers,
262 | output_dim=embed_dim,
263 | heads=vision_heads,
264 | input_resolution=image_resolution,
265 | width=vision_width
266 | )
267 | else:
268 | vision_heads = vision_width // 64
269 | self.visual = VisionTransformer(
270 | input_resolution=image_resolution,
271 | patch_size=vision_patch_size,
272 | width=vision_width,
273 | layers=vision_layers,
274 | heads=vision_heads,
275 | output_dim=embed_dim
276 | )
277 |
278 | self.transformer = Transformer(
279 | width=transformer_width,
280 | layers=transformer_layers,
281 | heads=transformer_heads,
282 | attn_mask=self.build_attention_mask()
283 | )
284 |
285 | self.vocab_size = vocab_size
286 | self.token_embedding = nn.Embedding(vocab_size, transformer_width)
287 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
288 | self.ln_final = LayerNorm(transformer_width)
289 |
290 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
291 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
292 |
293 | self.initialize_parameters()
294 |
295 | def initialize_parameters(self):
296 | nn.init.normal_(self.token_embedding.weight, std=0.02)
297 | nn.init.normal_(self.positional_embedding, std=0.01)
298 |
299 | if isinstance(self.visual, ModifiedResNet):
300 | if self.visual.attnpool is not None:
301 | std = self.visual.attnpool.c_proj.in_features ** -0.5
302 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
303 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
304 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
305 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
306 |
307 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
308 | for name, param in resnet_block.named_parameters():
309 | if name.endswith("bn3.weight"):
310 | nn.init.zeros_(param)
311 |
312 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
313 | attn_std = self.transformer.width ** -0.5
314 | fc_std = (2 * self.transformer.width) ** -0.5
315 | for block in self.transformer.resblocks:
316 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
317 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
318 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
319 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
320 |
321 | if self.text_projection is not None:
322 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
323 |
324 | def build_attention_mask(self):
325 | # lazily create causal attention mask, with full attention between the vision tokens
326 | # pytorch uses additive attention mask; fill with -inf
327 | mask = torch.empty(self.context_length, self.context_length)
328 | mask.fill_(float("-inf"))
329 | mask.triu_(1) # zero out the lower diagonal
330 | return mask
331 |
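    # For example, with context_length = 4 the mask built above is
    #   [[0., -inf, -inf, -inf],
    #    [0.,   0., -inf, -inf],
    #    [0.,   0.,   0., -inf],
    #    [0.,   0.,   0.,   0.]]
    # so each text token can attend only to itself and to earlier tokens.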
332 | @property
333 | def dtype(self):
334 | return self.visual.conv1.weight.dtype
335 |
336 | def encode_image(self, image):
337 | return self.visual(image.type(self.dtype))
338 |
339 | def encode_text(self, text):
340 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
341 |
342 | x = x + self.positional_embedding.type(self.dtype)
343 | x = x.permute(1, 0, 2) # NLD -> LND
344 | x = self.transformer(x)
345 | x = x.permute(1, 0, 2) # LND -> NLD
346 | x = self.ln_final(x).type(self.dtype)
347 |
348 | # x.shape = [batch_size, n_ctx, transformer.width]
349 | # take features from the eot embedding (eot_token is the highest number in each sequence)
350 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
351 |
352 | return x
353 |
354 | def forward(self, image, text):
355 | image_features = self.encode_image(image)
356 | text_features = self.encode_text(text)
357 |
358 | # normalized features
359 | image_features = image_features / image_features.norm(dim=-1, keepdim=True)
360 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
361 |
362 | # cosine similarity as logits
363 | logit_scale = self.logit_scale.exp()
364 | logits_per_image = logit_scale * image_features @ text_features.t()
365 | logits_per_text = logit_scale * text_features @ image_features.t()
366 |
367 | # shape = [global_batch_size, global_batch_size]
368 | return logits_per_image, logits_per_text
369 |
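    # Usage sketch (hypothetical tensors): `image` is a preprocessed batch
    # [B, 3, H, W] and `text` a batch of tokenized prompts [B, context_length].
    # Both returned logits are [B, B]; CLIP's training objective is the symmetric
    # cross-entropy over them, e.g.
    #
    #   labels = torch.arange(image.shape[0], device=image.device)
    #   loss = (F.cross_entropy(logits_per_image, labels) +
    #           F.cross_entropy(logits_per_text, labels)) / 2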
370 |
371 | def convert_weights(model: nn.Module):
372 | """Convert applicable model parameters to fp16"""
373 |
374 | def _convert_weights_to_fp16(l):
375 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
376 | l.weight.data = l.weight.data.half()
377 | if l.bias is not None:
378 | l.bias.data = l.bias.data.half()
379 |
380 | if isinstance(l, nn.MultiheadAttention):
381 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
382 | tensor = getattr(l, attr)
383 | if tensor is not None:
384 | tensor.data = tensor.data.half()
385 |
386 | for name in ["text_projection", "proj"]:
387 | if hasattr(l, name):
388 | attr = getattr(l, name)
389 | if attr is not None:
390 | attr.data = attr.data.half()
391 |
392 | model.apply(_convert_weights_to_fp16)
393 |
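# Note: build_model() below applies convert_weights() before load_state_dict(), so the
# convolution, linear and attention parameters end up in fp16, matching the precision
# of the released CLIP checkpoints.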
394 |
395 | def build_model(state_dict: dict):
396 | vit = "visual.proj" in state_dict
397 |
398 | if vit:
399 | vision_width = state_dict["visual.conv1.weight"].shape[0]
400 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
401 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
402 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
403 | image_resolution = vision_patch_size * grid_size
404 | else:
405 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
406 | vision_layers = tuple(counts)
407 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
408 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
409 | vision_patch_size = None
410 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
411 | image_resolution = output_width * 32
412 |
413 | embed_dim = state_dict["text_projection"].shape[1]
414 | context_length = state_dict["positional_embedding"].shape[0]
415 | vocab_size = state_dict["token_embedding.weight"].shape[0]
416 | transformer_width = state_dict["ln_final.weight"].shape[0]
417 | transformer_heads = transformer_width // 64
418 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
419 |
420 | model = CLIP(
421 | embed_dim,
422 | image_resolution, vision_layers, vision_width, vision_patch_size,
423 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
424 | )
425 |
426 | for key in ["input_resolution", "context_length", "vocab_size"]:
427 | if key in state_dict:
428 | del state_dict[key]
429 |
430 | convert_weights(model)
431 | model.load_state_dict(state_dict)
432 | return model.eval()
433 |
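# Usage sketch for build_model() (hypothetical path; the released checkpoints are
# TorchScript archives, and clip.py normally performs this loading):
#
#   jit_model = torch.jit.load("RN50.pt", map_location="cpu")
#   model = build_model(jit_model.state_dict())   # returns the CLIP module in eval mode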
--------------------------------------------------------------------------------
/trainers/utils.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | import os
4 | import json
5 | import torch
6 | from sklearn.metrics import precision_recall_curve
7 | from tqdm import tqdm
8 |
9 |
10 | def plotLogitsMap(outputs, label, save_path, fig_title, max_lines=1000):
11 | fig, ax = plt.subplots(figsize=(5, 200))
12 | Softmax = torch.nn.Softmax(dim=1)
13 | output_m = Softmax(outputs)
14 | output_m = output_m.cpu().detach().numpy()
15 |
16 | pred = outputs.max(1)[1]
17 | matches = pred.eq(label).float()
18 | output_m = np.sort(output_m)
19 | output_m = output_m[:,::-1] # sort each row in descending order
20 | output_m = output_m[:,:5] # keep only the top-5 scores per row
21 | output_m_index = output_m[:,0].argsort()
22 | output_m = output_m[output_m_index]
23 | output_m = output_m[::-1,:] # sort rows by the first column, descending
24 | matches = matches[output_m_index]
25 | matches = torch.flip(matches, dims=[0])
26 | matches = matches.cpu().detach().numpy()
27 |
28 |
29 | if len(matches) > max_lines:
30 | gap = int(len(matches) / max_lines)
31 | index = np.arange(0, gap*max_lines, gap, int)
32 | output_m = output_m[index]
33 | matches = matches[index]
34 | print(save_path)
35 | matches = matches.tolist()
36 |
37 |
38 | im = ax.imshow(output_m, aspect='auto')
39 | ax.set_yticks(np.arange(output_m.shape[0]), labels=matches)
40 | for i, label in enumerate(ax.get_yticklabels()):
41 | if (int(matches[i])==0):
42 | label.set_color('red')
43 | elif (int(matches[i])==1):
44 | label.set_color('green')
45 |
46 | for i in range(output_m.shape[0]):
47 | for j in range(output_m.shape[1]):
48 | text = ax.text(j, i, str(round(output_m[i, j],2)),
49 | ha="center", va="center", color="w")
50 | plt.title(fig_title)
51 | plt.savefig(save_path)
52 | plt.close()
53 |
54 | def plotPRMap(outputs, label, save_path, fig_title):
55 | plt.figure(figsize=(15,15))
56 | plt.title('{} Precision/Recall Curve'.format(fig_title))
57 | plt.xlabel('Recall')
58 | plt.ylabel('Precision')
59 | output_m = outputs.cpu().detach().numpy()
60 | pred = outputs.max(1)[1]
61 | matches = pred.eq(label).float()
62 | output_m = np.sort(output_m)
63 | output_m = output_m[:,::-1] # sort each row in descending order
64 | output_m = output_m[:,:5] # keep only the top-5 scores per row
65 | output_m_index = output_m[:,0].argsort()
66 | output_m = output_m[output_m_index]
67 | output_m = output_m[::-1,:] # sort rows by the first column, descending
68 | matches = matches[output_m_index]
69 | matches = torch.flip(matches, dims=[0])
70 | matches = matches.cpu().detach().numpy()
71 | # print(output_m[:,0].shape, matches.shape)
72 | precision, recall, thresholds = precision_recall_curve(matches, output_m[:,0])
73 | # print(precision)
74 | # print(recall)
75 | # print(thresholds, len(thresholds))
76 | plt.plot(recall, precision)
77 |
78 | step = 0
79 | for a, b, text in zip(recall, precision, thresholds):
80 | # if float(text) % 0.05 == 0:
81 | if step % 40 == 0:
82 | plt.text(a, b, text, ha='center', va='bottom', fontsize=10, color='blue')
83 | plt.plot(a, b, marker='o', color='red')
84 | step += 1
85 | plt.grid(ls='--')
86 | plt.savefig(save_path)
87 | plt.close()
88 |
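# Usage sketch (hypothetical tensors): `outputs` are logits of shape [N, C] and
# `label` the ground-truth class indices of shape [N]; both helpers write a figure
# to `save_path` and close it.
#
#   plotLogitsMap(outputs, label, './logits_map.png', 'top-5 logits')
#   plotPRMap(outputs, label, './pr_curve.png', 'ssucf101')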
89 | def select_top_k_similarity_per_class(outputs, img_paths, K=1, image_features=None, is_softmax=True):
90 | # print(outputs.shape)
91 | if is_softmax:
92 | outputs = torch.nn.Softmax(dim=1)(outputs)
93 | output_m = outputs.cpu().detach().numpy()
94 | output_ori = outputs.cpu().detach()
95 | output_m_max = output_m.max(axis=1)
96 | output_m_max_id = np.argsort(-output_m_max)
97 | output_m = output_m[output_m_max_id]
98 | img_paths = img_paths[output_m_max_id]
99 | output_m_max = output_m_max[output_m_max_id]
100 | output_ori = output_ori[output_m_max_id]
101 | ids = (-output_m).argsort()[:, 0] # predicted class id of each row
102 |
103 | if image_features is not None:
104 | image_features = image_features.cpu().detach()
105 | image_features = image_features[output_m_max_id]
106 |
107 | predict_label_dict = {}
108 | predict_conf_dict = {}
109 | from tqdm import tqdm
110 | for id in tqdm(list(set(ids.tolist()))): # iterate over the unique predicted labels
111 | index = np.where(ids==id)
112 | conf_class = output_m_max[index] # confidences of this class
113 | output_class = output_ori[index]
114 | img_paths_class = img_paths[index] # image paths of this class
115 |
116 | if image_features is not None:
117 | img_features = image_features[index]
118 | if K >= 0:
119 | for img_path, img_feature, conf, logit in zip(img_paths_class[:K], img_features[:K], conf_class[:K], output_class):
120 | if '/data/' in img_path:
121 | img_path = './data/' + img_path.split('/data/')[1]
122 | predict_label_dict[img_path] = [id, img_feature, conf, logit]
123 | else:
124 | for img_path, img_feature, conf, logit in zip(img_paths_class, img_features, conf_class, output_class):
125 | if '/data/' in img_path:
126 | img_path = './data/' + img_path.split('/data/')[1]
127 | predict_label_dict[img_path] = [id, img_feature, conf, logit]
128 | else:
129 | if K >= 0:
130 | for img_path, conf in zip(img_paths_class[:K], conf_class):
131 | if '/data/' in img_path:
132 | img_path = './data/' + img_path.split('/data/')[1]
133 | predict_label_dict[img_path] = id
134 | predict_conf_dict[img_path] = conf
135 | else:
136 | for img_path, conf in zip(img_paths_class, conf_class):
137 | if '/data/' in img_path:
138 | img_path = './data/' + img_path.split('/data/')[1]
139 | predict_label_dict[img_path] = id
140 | predict_conf_dict[img_path] = conf
141 | return predict_label_dict, predict_conf_dict
142 |
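# Usage sketch (hypothetical shapes): `outputs` holds CLIP logits for the unlabeled
# split ([N, C]) and `img_paths` the matching image paths (array of length N); the
# call keeps the K most confident images per predicted class as pseudo-labels.
#
#   pseudo_labels, pseudo_confs = select_top_k_similarity_per_class(
#       outputs, img_paths, K=16, is_softmax=True)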
143 | def select_by_conf(outputs, img_paths, K=1, conf_threshold=None, is_softmax=True):
144 | # print(outputs.shape)
145 | if is_softmax:
146 | outputs = torch.nn.Softmax(dim=1)(outputs)
147 | output_m = outputs.cpu().detach().numpy()
148 | output_ori = outputs.cpu().detach()
149 | output_m_max = output_m.max(axis=1)
150 | output_m_max_id = np.argsort(-output_m_max)
151 | output_m = output_m[output_m_max_id]
152 | img_paths = img_paths[output_m_max_id]
153 | output_m_max = output_m_max[output_m_max_id]
154 | output_ori = output_ori[output_m_max_id]
155 | ids = (-output_m).argsort()[:, 0] # predicted class id of each row
156 |
157 |
158 | predict_label_dict = {}
159 | predict_conf_dict = {}
160 | from tqdm import tqdm
161 | for id in tqdm(list(set(ids.tolist()))): # iterate over the unique predicted labels
162 | index = np.where(ids==id)
163 | conf_class = output_m_max[index] # confidences of this class
164 | output_class = output_ori[index]
165 | img_paths_class = img_paths[index] # image paths of this class
166 |
167 | for img_path, conf in zip(img_paths_class, conf_class):
168 | if conf > conf_threshold:
169 | predict_label_dict[img_path] = id
170 | predict_conf_dict[img_path] = conf
171 | return predict_label_dict, predict_conf_dict
172 |
173 | def select_top_k_similarity(outputs, img_paths, K=1, image_features=None, repeat=False):
174 | # print(outputs.shape)
175 | outputs = torch.nn.Softmax(dim=1)(outputs)
176 | output_m = outputs.cpu().detach().numpy()
177 | output_ori = outputs.cpu().detach()
178 | output_m_max = output_m.max(axis=1)
179 | output_m_max_id = np.argsort(-output_m_max)
180 | output_m = output_m[output_m_max_id]
181 | img_paths = img_paths[output_m_max_id]
182 | # output_m_max = output_m_max[output_m_max_id]
183 | output_ori = output_ori[output_m_max_id]
184 | conf_class = output_m_max[output_m_max_id] # confidences
185 | ids = (-output_m).argsort()[:, 0] # predicted class id of each row
186 |
187 | if image_features is not None:
188 | image_features = image_features.cpu().detach()
189 | image_features = image_features[output_m_max_id]
190 |
191 | predict_label_dict = {}
192 | predict_conf_dict = {}
193 | if image_features is not None:
194 | img_features = image_features
195 | if K >= 0:
196 |             for img_path, img_feature, conf, logit, id in zip(img_paths[:K], img_features[:K], conf_class[:K], output_ori[:K], ids[:K]):
197 |                 predict_label_dict[img_path] = [id, img_feature, conf, logit]
198 |         else:
199 |             for img_path, img_feature, conf, logit, id in zip(img_paths, img_features, conf_class, output_ori, ids):
200 |                 predict_label_dict[img_path] = [id, img_feature, conf, logit]
201 | else:
202 | if K >= 0:
203 | for img_path, conf, id in zip(img_paths[:K], conf_class[:K], ids[:K]):
204 | predict_label_dict[img_path] = id
205 | predict_conf_dict[img_path] = conf
206 | else:
207 | for img_path, conf, id in zip(img_paths, conf_class, ids):
208 | predict_label_dict[img_path] = id
209 | predict_conf_dict[img_path] = conf
210 | return predict_label_dict, predict_conf_dict
211 |
212 |
213 | def select_top_by_value(outputs, img_paths, conf_threshold=0.95, image_features=None, repeat=False):
214 | outputs = torch.nn.Softmax(dim=1)(outputs)
215 | output_m = outputs.cpu().detach().numpy()
216 | output_ori = outputs.cpu().detach()
217 | output_m_max = output_m.max(axis=1)
218 | output_m_max_id = np.argsort(-output_m_max)
219 | output_m = output_m[output_m_max_id]
220 | img_paths = img_paths[output_m_max_id]
221 | # output_m_max = output_m_max[output_m_max_id]
222 | output_ori = output_ori[output_m_max_id]
223 | conf_class = output_m_max[output_m_max_id] # confidences
224 | ids = (-output_m).argsort()[:, 0] # predicted class id of each row
225 |
226 | if image_features is not None:
227 | image_features = image_features.cpu().detach()
228 | image_features = image_features[output_m_max_id]
229 |
230 | predict_label_dict = {}
231 | predict_conf_dict = {}
232 | if image_features is not None:
233 | img_features = image_features
234 |         for img_path, img_feature, conf, logit, id in zip(img_paths, img_features, conf_class, output_ori, ids):
235 |             if conf > conf_threshold:
236 |                 predict_label_dict[img_path] = [id, img_feature, conf, logit]
237 | else:
238 | for img_path, id, conf in zip(img_paths, ids, conf_class):
239 | if conf > conf_threshold:
240 | predict_label_dict[img_path] = id
241 | predict_conf_dict[img_path] = conf
242 | return predict_label_dict, predict_conf_dict
243 |
244 |
245 | def caculate_noise_rate(predict_label_dict, train_loader, trainer, sample_level=False):
246 | gt_label_dict = {}
247 | for batch_idx, batch in enumerate(train_loader):
248 | input, label, impath = trainer.parse_batch_test_with_impath(batch)
249 | for l, ip in zip(label, impath):
250 | if '/data/' in ip:
251 | ip = './data/' + ip.split('/data/')[1]
252 | gt_label_dict[ip] = l
253 |
254 | # print('gt_label_dict', len(gt_label_dict))
255 | # print('gt_label_dict', gt_label_dict)
256 | total = 0
257 | correct = 0
258 | for item in predict_label_dict:
259 | if '/data/' in item:
260 | item = './data/' + item.split('/data/')[1]
261 | if gt_label_dict[item] == predict_label_dict[item]:
262 | correct += 1
263 | total += 1
264 | print('Acc Rate {:.4f}'.format(correct/total))
265 |
266 |
267 | def caculate_noise_rate_analyze(predict_label_dict, train_loader, trainer, sample_level=False):
268 | gt_label_dict = {}
269 | for batch_idx, batch in enumerate(train_loader):
270 | input, label, impath = trainer.parse_batch_test_with_impath(batch)
271 | for l, ip in zip(label, impath):
272 | ip = './data/' + ip.split('/data/')[1]
273 | gt_label_dict[ip] = l
274 | total = 0
275 | correct = 0
276 | for item in predict_label_dict:
277 | if gt_label_dict[item] == predict_label_dict[item][0]:
278 | correct += 1
279 | if sample_level is True:
280 | print(gt_label_dict[item], 1)
281 | total += 1
282 | print('Acc Rate {:.4f}'.format(correct/total))
283 |
284 |
285 | def save_outputs(train_loader, trainer, predict_label_dict, dataset_name, text_features, backbone_name=None):
286 | backbone_name = backbone_name.replace('/', '-')
287 | gt_pred_label_dict = {}
288 | for batch_idx, batch in enumerate(train_loader):
289 | input, label, impath = trainer.parse_batch_test_with_impath(batch)
290 | for l, ip in zip(label, impath):
291 | l = l.item()
292 | ip = './data/' + ip.split('/data/')[1]
293 | if l not in gt_pred_label_dict:
294 | gt_pred_label_dict[l] = []
295 | pred_label = predict_label_dict[ip][0]
296 | pred_v_feature = predict_label_dict[ip][1]
297 |
298 | conf = predict_label_dict[ip][2]
299 | logits = predict_label_dict[ip][3]
300 | gt_pred_label_dict[l].append([ip, pred_label, pred_v_feature, conf, logits])
301 | else:
302 | pred_label = predict_label_dict[ip][0]
303 | pred_v_feature = predict_label_dict[ip][1]
304 | conf = predict_label_dict[ip][2]
305 | logits = predict_label_dict[ip][3]
306 | gt_pred_label_dict[l].append([ip, pred_label, pred_v_feature, conf, logits])
307 |
308 | idx = 0
309 | v_distance_dict = {}
310 | v_features = []
311 | logits_list = []
312 | for label in gt_pred_label_dict:
313 | avg_feature = None
314 | for item in gt_pred_label_dict[label]:
315 | impath, pred_label, pred_v_feature = item[0], item[1], item[2],
316 | if avg_feature is None:
317 | avg_feature = pred_v_feature.clone()
318 | else:
319 | avg_feature += pred_v_feature.clone()
320 | avg_feature /= len(gt_pred_label_dict[label]) # class center
321 | v_distance_dict_per_class = {}
322 | for item in gt_pred_label_dict[label]:
323 | impath, pred_label, pred_v_feature, conf, logits = item[0], item[1], item[2], item[3], item[4]
324 | v_features.append(pred_v_feature)
325 | logits_list.append(logits)
326 | v_dis = torch.dist(avg_feature, pred_v_feature, p=2)
327 | v_distance_dict_per_class[impath] = [idx, v_dis.item(), conf.item(), pred_label] # id, visual distance, confidence, predicted label
328 | idx += 1
329 | v_distance_dict[label] = v_distance_dict_per_class
330 |
331 | v_features = torch.vstack(v_features)
332 | logits_tensor = torch.vstack(logits_list)
333 |
334 | if not os.path.exists('./analyze_results/{}/'.format(backbone_name)):
335 | os.makedirs('./analyze_results/{}/'.format(backbone_name))
336 |
337 | torch.save(v_features, './analyze_results/{}/{}_v_feature.pt'.format(backbone_name, dataset_name))
338 | torch.save(text_features, './analyze_results/{}/{}_l_feature.pt'.format(backbone_name, dataset_name))
339 | torch.save(logits_tensor, './analyze_results/{}/{}_logits.pt'.format(backbone_name, dataset_name))
340 |
341 |
342 | with open("./analyze_results/{}/{}.json".format(backbone_name, dataset_name), "w") as outfile:
343 | json.dump(v_distance_dict, outfile)
344 |
345 |
346 | def select_top_k_similarity_per_class_with_high_conf(outputs, img_paths, K=1, image_features=None, repeat=False):
347 | # print(outputs.shape)
348 | outputs = torch.nn.Softmax(dim=1)(outputs)
349 | output_m = outputs.cpu().detach().numpy()
350 | output_ori = outputs.cpu().detach()
351 | output_m_max = output_m.max(axis=1)
352 | output_m_max_id = np.argsort(-output_m_max)
353 | output_m = output_m[output_m_max_id]
354 | img_paths = img_paths[output_m_max_id]
355 | output_m_max = output_m_max[output_m_max_id]
356 | output_ori = output_ori[output_m_max_id]
357 | ids = (-output_m).argsort()[:, 0] # predicted class id of each row
358 |
359 |
360 | # compute the average confidence of each predicted class
361 | class_avg_conf = {}
362 | for id in list(set(ids.tolist())): # iterate over the unique predicted labels
363 | index = np.where(ids==id)
364 | conf_class = output_m_max[index] # confidences of this class
365 | class_avg_conf[id] = conf_class.sum() / conf_class.size
366 | # print(class_avg_conf[id])
367 |
368 | selected_ids = sorted(class_avg_conf.items(), key = lambda kv:(kv[1], kv[0]), reverse=True)[:int(0.8*len(class_avg_conf))]
369 | remain_ids = sorted(class_avg_conf.items(), key = lambda kv:(kv[1], kv[0]), reverse=True)[int(0.8*len(class_avg_conf)):]
370 |
371 | selected_ids = [id[0] for id in selected_ids]
372 | remain_ids = [id[0] for id in remain_ids]
373 |
374 | if image_features is not None:
375 | image_features = image_features.cpu().detach()
376 | image_features = image_features[output_m_max_id]
377 |
378 | predict_label_dict = {}
379 | predict_conf_dict = {}
380 |
381 | # selected_ids.append(0)
382 |
383 | for id in selected_ids: # classes whose average confidence is in the top 80%
384 | index = np.where(ids==id)
385 | conf_class = output_m_max[index] # confidences of this class
386 | output_class = output_ori[index]
387 | img_paths_class = img_paths[index] # image paths of this class
388 |
389 | if image_features is not None:
390 | img_features = image_features[index]
391 | if K >= 0:
392 | for img_path, img_feature, conf, logit in zip(img_paths_class[:K], img_features[:K], conf_class[:K], output_class[:K]):
393 | img_path = './data/' + img_path.split('/data/')[1]
394 | predict_label_dict[img_path] = [id, img_feature, conf, logit]
395 | else:
396 | for img_path, img_feature, conf, logit in zip(img_paths_class, img_features, conf_class, output_class):
397 | img_path = './data/' + img_path.split('/data/')[1]
398 | predict_label_dict[img_path] = [id, img_feature, conf, logit]
399 | else:
400 | if K >= 0:
401 | for img_path, conf in zip(img_paths_class[:K], conf_class):
402 | img_path = './data/' + img_path.split('/data/')[1]
403 | predict_label_dict[img_path] = id
404 | predict_conf_dict[img_path] = conf
405 | else:
406 | for img_path, conf in zip(img_paths_class, conf_class):
407 | img_path = './data/' + img_path.split('/data/')[1]
408 | predict_label_dict[img_path] = id
409 | predict_conf_dict[img_path] = conf
410 | return predict_label_dict, predict_conf_dict, remain_ids, selected_ids
411 |
412 |
413 | def select_top_k_similarity_per_class_with_low_conf(outputs, img_paths, conf_threshold, remain_ids, selected_ids, K=2):
414 | # print(outputs.shape)
415 | outputs = torch.nn.Softmax(dim=1)(outputs)
416 | remain_ids_list = remain_ids
417 | remain_ids = np.sort(np.array(remain_ids).astype(int))
418 | remain_logits = -100*torch.ones(outputs.shape).half().cuda()
419 | remain_logits[:, remain_ids] = outputs[:, remain_ids] * 5
420 | remain_logits = torch.nn.Softmax(dim=1)(remain_logits.float())
421 | outputs = remain_logits
422 |
423 |
424 | output_m = outputs.cpu().detach().numpy()
425 | output_ori = outputs.cpu().detach()
426 | output_m_max = output_m.max(axis=1)
427 | output_m_max_id = np.argsort(-output_m_max)
428 | output_m = output_m[output_m_max_id]
429 | img_paths = img_paths[output_m_max_id]
430 | output_m_max = output_m_max[output_m_max_id]
431 | output_ori = output_ori[output_m_max_id]
432 | ids = (-output_m).argsort()[:, 0] # predicted class id of each row
433 |
434 |
435 | predict_label_dict = {}
436 | predict_conf_dict = {}
437 | no_sample_ids = []
438 |
439 | for id in remain_ids_list: # classes not selected in the high-confidence pass
440 | # print(id)
441 | is_id_have_sample = False
442 | index = np.where(ids==id)
443 | conf_class = output_m_max[index] # confidences of this class
444 | output_class = output_ori[index]
445 | img_paths_class = img_paths[index] # image paths of this class
446 |
447 | if K >= 0:
448 | for img_path, conf in zip(img_paths_class[:K], conf_class[:K]):
449 | # print(conf)
450 | if conf > conf_threshold:
451 | predict_label_dict[img_path] = id
452 | predict_conf_dict[img_path] = conf
453 | is_id_have_sample = True
454 | else:
455 | for img_path, conf in zip(img_paths_class, conf_class):
456 | predict_label_dict[img_path] = id
457 | predict_conf_dict[img_path] = conf
458 | # print(is_id_have_sample)
459 | if is_id_have_sample is False:
460 | no_sample_ids.append(id)
461 |
462 | print(no_sample_ids)
463 | return predict_label_dict, predict_conf_dict, no_sample_ids
464 |
465 | def select_top_k_similarity_per_class_no_smaple(outputs, img_paths, no_sample_ids, K=16):
466 | # print(outputs.shape)
467 | outputs = torch.nn.Softmax(dim=1)(outputs)
468 | output_m = outputs.cpu().detach().numpy()
469 | output_ori = outputs.cpu().detach()
470 | output_m_max = output_m.max(axis=1)
471 | output_m_max_id = np.argsort(-output_m_max)
472 | output_m = output_m[output_m_max_id]
473 | img_paths = img_paths[output_m_max_id]
474 | output_m_max = output_m_max[output_m_max_id]
475 | output_ori = output_ori[output_m_max_id]
476 | ids = (-output_m).argsort()[:, 0] # predicted class id of each row
477 |
478 |
479 | predict_label_dict = {}
480 | predict_conf_dict = {}
481 |
482 | for id in no_sample_ids: # classes that received no pseudo-labeled samples so far
483 | print(id)
484 | index = np.where(ids==id)
485 | conf_class = output_m_max[index] # confidences of this class
486 | output_class = output_ori[index]
487 | img_paths_class = img_paths[index] # image paths of this class
488 |
489 | if K >= 0:
490 | for img_path, conf in zip(img_paths_class[:K], conf_class[:K]):
491 | predict_label_dict[img_path] = id
492 | predict_conf_dict[img_path] = conf
493 | else:
494 | for img_path, conf in zip(img_paths_class, conf_class):
495 | predict_label_dict[img_path] = id
496 | predict_conf_dict[img_path] = conf
497 | return predict_label_dict, predict_conf_dict
498 |
499 |
500 |
--------------------------------------------------------------------------------