├── continual_clip
│   ├── __init__.py
│   ├── utils.py
│   ├── datasets.py
│   └── models.py
├── loraclip
│   ├── loralib
│   │   ├── __init__.py
│   │   └── utils.py
│   ├── __init__.py
│   ├── lora_clip.py
│   └── model.py
├── class_orders
│   ├── cifar100.yaml
│   ├── imagenet100.yaml
│   └── imagenet_R.yaml
├── run_experiment_img100.sh
├── run_experiment_cifar.sh
├── run_experiment.sh
├── requirements.txt
├── setup_environment.sh
├── dataset_reqs
│   ├── oxford_pet_classes.txt
│   ├── vtab_classes.txt
│   ├── caltech101_classes.txt
│   ├── food101_classes.txt
│   ├── imagenet100_classes.txt
│   ├── cub200_classes.txt
│   ├── imagenet_R_classes.txt
│   ├── tinyimagenet_classes.txt
│   ├── imagenet_c_classes.txt
│   └── imagenet1000_classes.txt
├── configs
│   └── class
│       ├── cifar100_10-10.yaml
│       ├── imagenet100_10-10.yaml
│       └── imagenet_r_20-20.yaml
├── .gitignore
├── README.md
├── epoch.py
└── main.py
/continual_clip/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/loraclip/loralib/__init__.py:
--------------------------------------------------------------------------------
1 | name = "lora"
2 |
3 | from .layers import *
4 | from .utils import *
--------------------------------------------------------------------------------
/class_orders/cifar100.yaml:
--------------------------------------------------------------------------------
1 | class_order: [87, 0, 52, 58, 44, 91, 68, 97, 51, 15, 94, 92, 10, 72, 49, 78, 61, 14, 8, 86, 84, 96, 18, 24, 32, 45, 88, 11, 4, 67, 69, 66, 77, 47, 79, 93, 29, 50, 57, 83, 17, 81, 41, 12, 37, 59, 25, 20, 80, 73, 1, 28, 6, 46, 62, 82, 53, 9, 31, 75, 38, 63, 33, 74, 27, 22, 36, 3, 16, 21, 60, 19, 70, 90, 89, 43, 5, 42, 65, 76, 40, 30, 23, 85, 2, 95, 56, 48, 71, 64, 98, 13, 99, 7, 34, 55, 54, 26, 35, 39]
2 |
--------------------------------------------------------------------------------
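These class-order files are consumed by `get_class_order` in `continual_clip/utils.py`, which simply reads the `class_order` key. A minimal sketch (run from the repo root):

```python
# Minimal sketch: load a class order the same way continual_clip/utils.py does.
import yaml

with open("class_orders/cifar100.yaml", "r") as f:
    class_order = yaml.safe_load(f)["class_order"]

print(len(class_order), sorted(class_order) == list(range(100)))  # 100 True: a permutation of the label ids
```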
/run_experiment_img100.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 | python epoch.py \
5 | --config-path configs/class \
6 | --config-name imagenet100_10-10.yaml \
7 | dataset_root="path" \
8 | class_order="class_orders/imagenet100.yaml"
9 |
10 | python main.py \
11 | --config-path configs/class \
12 | --config-name imagenet100_10-10.yaml \
13 | dataset_root="path" \
14 | class_order="class_orders/imagenet100.yaml"
--------------------------------------------------------------------------------
/run_experiment_cifar.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 |
5 | python epoch.py \
6 | --config-path configs/class \
7 | --config-name cifar100_10-10.yaml \
8 | dataset_root="path" \
9 | class_order="class_orders/cifar100.yaml"
10 |
11 | python main.py \
12 | --config-path configs/class \
13 | --config-name cifar100_10-10.yaml \
14 | dataset_root="path" \
15 | class_order="class_orders/cifar100.yaml"
16 |
17 |
18 |
--------------------------------------------------------------------------------
/run_experiment.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 |
5 | python epoch.py \
6 | --config-path configs/class \
7 | --config-name imagenet_r_20-20.yaml \
8 | dataset_root="path" \
9 | class_order="class_orders/imagenet_R.yaml"
10 |
11 | python main.py \
12 | --config-path configs/class \
13 | --config-name imagenet_r_20-20.yaml \
14 | dataset_root="path" \
15 | class_order="class_orders/imagenet_R.yaml"
16 |
17 |
18 |
19 |
20 |
21 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | continuum
2 | hydra-core==1.2.0
3 | numpy
4 | oauthlib==3.2.1
5 | omegaconf==2.2.3
6 | open-clip-torch==1.3.0
7 | pandas==1.4.3
8 | Pillow==9.2.0
9 | pipreqs==0.4.11
10 | scikit-image==0.19.3
11 | scikit-learn==1.1.1
12 | scipy==1.8.1
13 | tensorboard==2.10.0
14 | timm @ git+https://github.com/Arnav0400/pytorch-image-models.git@ceea7127c1ef608179ba06eaeddc22ad3ef22de0
15 | tokenizers==0.12.1
16 | tqdm==4.64.0
17 | transformers==4.21.1
18 | ftfy
19 | regex
20 |
21 |
--------------------------------------------------------------------------------
/class_orders/imagenet100.yaml:
--------------------------------------------------------------------------------
1 | class_order: [
2 | 68, 56, 78, 8, 23, 84, 90, 65, 74, 76, 40,
3 | 89, 3, 92, 55, 9, 26, 80, 43, 38, 58, 70,
4 | 77, 1, 85, 19, 17, 50, 28, 53, 13, 81, 45,
5 | 82, 6, 59, 83, 16, 15, 44, 91, 41, 72, 60,
6 | 79, 52, 20, 10, 31, 54, 37, 95, 14, 71, 96,
7 | 98, 97, 2, 64, 66, 42, 22, 35, 86, 24, 34,
8 | 87, 21, 99, 0, 88, 27, 18, 94, 11, 12, 47,
9 | 25, 30, 46, 62, 69, 36, 61, 7, 63, 75, 5, 32,
10 | 4, 51, 48, 73, 93, 39, 67, 29, 49, 57, 33
11 | ]
12 |
--------------------------------------------------------------------------------
/setup_environment.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # create environment using Miniconda (or Anaconda)
4 | conda create -n continual_clip python=3.8
5 | conda activate continual_clip
6 |
7 | # install pytorch
8 | pip install torch==1.8.2 torchvision==0.9.2 torchaudio==0.8.2 \
9 | --extra-index-url https://download.pytorch.org/whl/lts/1.8/cu111
10 |
11 | # install other dependencies
12 | pip install -r requirements.txt
13 |
14 | # install CLIP
15 | pip install git+https://github.com/openai/CLIP.git
16 |
17 |
--------------------------------------------------------------------------------
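A quick post-setup sanity check — a minimal sketch, assuming the installs above succeeded; `ViT-B/16` is the backbone named in the configs:

```python
# Sanity check for the environment: torch and OpenAI CLIP should both import,
# and the ViT-B/16 backbone used by the configs should be downloadable.
import torch
import clip

print(torch.__version__, "CUDA available:", torch.cuda.is_available())
model, preprocess = clip.load("ViT-B/16", device="cpu")
print("CLIP models:", clip.available_models())
```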
/dataset_reqs/oxford_pet_classes.txt:
--------------------------------------------------------------------------------
1 | 1 abyssinian
2 | 2 american_bulldog
3 | 3 american_pit_bull_terrier
4 | 4 basset_hound
5 | 5 beagle
6 | 6 bengal
7 | 7 birman
8 | 8 bombay
9 | 9 boxer
10 | 10 british_shorthair
11 | 11 chihuahua
12 | 12 egyptian_mau
13 | 13 english_cocker_spaniel
14 | 14 english_setter
15 | 15 german_shorthaired
16 | 16 great_pyrenees
17 | 17 havanese
18 | 18 japanese_chin
19 | 19 keeshond
20 | 20 leonberger
21 | 21 maine_coon
22 | 22 miniature_pinscher
23 | 23 newfoundland
24 | 24 persian
25 | 25 pomeranian
26 | 26 pug
27 | 27 ragdoll
28 | 28 russian_blue
29 | 29 saint_bernard
30 | 30 samoyed
31 | 31 scottish_terrier
32 | 32 shiba_inu
33 | 33 siamese
34 | 34 sphynx
35 | 35 staffordshire_bull_terrier
36 | 36 wheaten_terrier
37 | 37 yorkshire_terrier
38 |
--------------------------------------------------------------------------------
/dataset_reqs/vtab_classes.txt:
--------------------------------------------------------------------------------
1 | airplane
2 | airport
3 | baseball diamond
4 | basketball court
5 | beach
6 | bridge
7 | chaparral
8 | church
9 | circular farmland
10 | cloud
11 | banded
12 | blotchy
13 | braided
14 | bubbly
15 | bumpy
16 | chequered
17 | cobwebbed
18 | cracked
19 | crosshatched
20 | crystalline
21 | Abyssinian Cat
22 | American Bulldog
23 | American Pit Bull Terrier
24 | Basset Hound
25 | Beagle
26 | Bengal Cat
27 | Birman
28 | Bombay Cat
29 | Boxer Dog
30 | British Shorthair Cat
31 | AnnualCrop
32 | Forest
33 | HerbaceousVegetation
34 | Highway
35 | Industrial
36 | Pasture
37 | PermanentCrop
38 | Residential
39 | River
40 | SeaLake
41 | petunia
42 | wild pansy
43 | primula
44 | sunflower
45 | pelargonium
46 | bishop of llandaff
47 | gaura
48 | geranium
49 | orange dahlia
50 | pink-yellow dahlia
--------------------------------------------------------------------------------
/class_orders/imagenet_R.yaml:
--------------------------------------------------------------------------------
1 | class_order: [150, 119, 193, 2, 75, 179, 129, 68, 145, 53, 171, 133, 31, 103, 146, 174, 149, 40, 142, 30, 183, 156, 167, 166, 124, 56, 109, 198, 185, 170, 51, 55, 77, 32, 45, 23, 115, 39, 134, 177, 12, 126, 130, 37, 19, 34, 0, 135, 161, 108, 128, 18, 141, 24, 154, 160, 78, 85, 89, 73, 152, 112, 114, 104, 44, 93, 138, 107, 100, 163, 3, 148, 147, 35, 199, 50, 25, 96, 79, 172, 98, 140, 22, 162, 6, 27, 60, 144, 197, 67, 165, 196, 102, 71, 65, 47, 9, 122, 42, 63, 61, 180, 97, 84, 157, 168, 69, 41, 83, 92, 190, 191, 59, 72, 111, 155, 105, 57, 17, 116, 13, 28, 43, 136, 5, 14, 169, 7, 62, 110, 189, 95, 8, 4, 137, 194, 90, 74, 70, 113, 186, 132, 192, 33, 36, 158, 188, 94, 121, 164, 58, 11, 178, 151, 184, 1, 10, 64, 15, 106, 52, 176, 26, 81, 82, 88, 125, 46, 123, 91, 187, 86, 20, 76, 80, 66, 120, 131, 16, 159, 101, 87, 99, 195, 143, 21, 49, 38, 181, 29, 153, 54, 173, 127, 117, 182, 139, 118, 48, 175]
2 |
3 |
4 |
5 |
--------------------------------------------------------------------------------
/configs/class/cifar100_10-10.yaml:
--------------------------------------------------------------------------------
1 | hydra:
2 | run:
3 | dir: ./experiments/${scenario}/${dataset}_${initial_increment}-${increment}
4 | job:
5 | chdir: true
6 |
7 | job_logging:
8 | version: 1
9 | formatters:
10 | simple:
11 | format: '%(message)s'
12 |
13 | class_order: ""
14 | dataset_root: ""
15 | workdir: ""
16 | log_path: "metric.json"
17 | model_name: "ViT-B/16"
18 | prompt_template: "a good photo of a {}."
19 | #128
20 | batch_size: 128
21 | increment: 10
22 | initial_increment: 10
23 | scenario: "class"
24 | dataset: "cifar100"
25 | task_num: 10
26 | seed: 42
27 |
28 | epochs: 1
29 | train_batch_size: 128
30 | num_workers: 8
31 | lora_rank: 8
32 | lora_mode: "vision+only_kv+text"
33 | lr: 0.001
34 | reset: False
35 | only_reset_B: False
36 | freeze_A: False
37 | all_test: False
38 |
39 |
40 | visual_clsf: true
41 | visual_clsf_lr: 0.0005
42 | visual_clsf_batch_size: 64
43 | visual_clsf_epochs: 3
44 |
45 |
46 |
47 | real_replay: false
48 | balance_ft: False
49 | balance_epochs: 0
50 |
--------------------------------------------------------------------------------
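For orientation, a minimal sketch of how a Hydra 1.2 entry point such as `main.py` (not included in this excerpt) would receive this config; the function body is illustrative, not the actual `main.py`:

```python
# Illustrative sketch of a Hydra entry point consuming configs/class/*.yaml.
# The real main.py is not shown in this dump; only the config plumbing is depicted.
import hydra
from omegaconf import DictConfig

@hydra.main(config_path="configs/class", config_name="cifar100_10-10")
def main(cfg: DictConfig) -> None:
    # command-line overrides such as dataset_root="..." land directly on cfg
    print(cfg.model_name, cfg.initial_increment, cfg.increment, cfg.lora_mode)

if __name__ == "__main__":
    main()
```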
/loraclip/__init__.py:
--------------------------------------------------------------------------------
1 | from .lora_clip import *
2 |
3 |
4 | def print_trainable_parameters(model):
5 | """
6 | Prints the number of trainable parameters in the model.
7 | """
8 | trainable_params = 0
9 | all_param = 0
10 | trainable_params_names = []
11 |
12 | for name, param in model.named_parameters():
13 | all_param += param.numel()
14 | if param.requires_grad:
15 | trainable_params += param.numel()
16 | trainable_params_names.append(name)
17 | # if "lora_B" in name:
18 | # print(param.sum())
19 |
20 | print(f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}")
21 |
22 | print("Trainable Parameters:")
23 | for name in trainable_params_names:
24 | print(name)
25 | # check = True
26 | # for name in trainable_params_names:
27 | # if "lora" not in name:
28 | # check = False
29 |
30 | # if check:
31 | # print("Are LoRA parameters correctly present? Yes")
32 | # else:
33 | # print("Are LORA parameters correctly present? No")
--------------------------------------------------------------------------------
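Hypothetical usage of `print_trainable_parameters`; the `lora_clip.load` keywords mirror the call in `continual_clip/models.py`:

```python
# Hypothetical usage: wrap CLIP with LoRA and report what is trainable.
import loraclip
from loraclip import lora_clip

model, preprocess = lora_clip.load(
    "ViT-B/16", device="cuda", jit=False, r=8, lora_mode="vision+only_kv+text"
)
loraclip.print_trainable_parameters(model)  # only lora_* parameters should be listed
```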
/configs/class/imagenet100_10-10.yaml:
--------------------------------------------------------------------------------
1 | hydra:
2 | run:
3 | dir: ./experiments/${scenario}/${dataset}_${initial_increment}-${increment}
4 | job:
5 | chdir: true
6 |
7 | job_logging:
8 | version: 1
9 | formatters:
10 | simple:
11 | format: '%(message)s'
12 |
13 | class_order: ""
14 | dataset_root: ""
15 | workdir: ""
16 | log_path: "metrics.json"
17 | model_name: "ViT-B/16"
18 | prompt_template: "a good photo of a {}."
19 |
20 | batch_size: 128
21 | initial_increment: 10
22 | increment: ${initial_increment}
23 | scenario: "class"
24 | dataset: "imagenet100"
25 | task_num: 10
26 |
27 | epochs: 2
28 | train_batch_size: 128
29 | num_workers: 8
30 | lora_rank: 8
31 | lora_mode: "vision+only_kv+text"
32 | lr: 0.0005
33 | reset: False
34 | only_reset_B: False
35 | freeze_A: False
36 | all_test: False
37 | weight_decay: 1e-4
38 | momentum: 0.9
39 | seed: 0
40 |
41 |
42 | visual_clsf: true
43 | visual_clsf_lr: 0.001
44 | visual_clsf_batch_size: 128
45 | visual_clsf_epochs: 3
46 |
47 | real_replay: false
48 |
49 | balance_ft: false
50 | balance_epochs: 1
51 |
--------------------------------------------------------------------------------
/configs/class/imagenet_r_20-20.yaml:
--------------------------------------------------------------------------------
1 | hydra:
2 | run:
3 | dir: ./experiments/${scenario}/${dataset}_${initial_increment}-${increment}
4 | job:
5 | chdir: true
6 |
7 | job_logging:
8 | version: 1
9 | formatters:
10 | simple:
11 | format: '%(message)s'
12 |
13 | class_order: ""
14 | dataset_root: ""
15 | workdir: ""
16 | log_path: "metric.json"
17 | model_name: "ViT-B/16"
18 | prompt_template: "a good photo of a {}."
19 |
20 | batch_size: 32
21 | initial_increment: 20
22 | increment: ${initial_increment}
23 | scenario: "class"
24 | dataset: "imagenet_R"
25 | task_num: 10
26 |
27 | epochs: 2
28 | # 64
29 | train_batch_size: 64
30 | num_workers: 8
31 | lora_rank: 8
32 | lora_mode: "vision+only_kv+text"
33 | lr: 0.001
34 | reset: False
35 | only_reset_B: False
36 | freeze_A: False
37 | all_test: False
38 | weight_decay: 1e-4
39 | momentum: 0.9
40 | seed: 0
41 |
42 |
43 | visual_clsf: true
44 | visual_clsf_lr: 0.0005
45 | visual_clsf_batch_size: 32
46 | visual_clsf_epochs: 3
47 |
48 | real_replay: false
49 |
50 |
51 | balance_ft: False
52 | balance_epochs: 1
53 |
54 |
--------------------------------------------------------------------------------
/continual_clip/utils.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import json
4 | import yaml
5 |
6 | from omegaconf import DictConfig, OmegaConf
7 | import pdb
8 |
9 |
10 |
11 | def get_class_order(file_name: str) -> list:
12 |     r"""Load the ``class_order`` list of class ids from a YAML file."""
13 |     with open(file_name, "r") as f:
14 |         data = yaml.safe_load(f)
15 |     return data["class_order"]
16 |
17 |
18 | def get_class_ids_per_task(args):
19 |     yield args.class_order[:args.initial_increment]
20 |     for i in range(args.initial_increment, len(args.class_order), args.increment):
21 |         yield args.class_order[i:i + args.increment]
22 |
23 | def get_class_names(classes_names, class_ids_per_task):
24 |     return [classes_names[class_id] for class_id in class_ids_per_task]
25 |
26 |
27 | def get_dataset_class_names(workdir, dataset_name, long=False):
28 |     with open(os.path.join(workdir, "dataset_reqs", f"{dataset_name}_classes.txt"), "r") as f:
29 |         lines = f.read().splitlines()
30 |     return [line.split("\t")[-1] for line in lines]  # last tab-separated field = class name
31 |
32 |
33 | def save_config(config: DictConfig) -> None:
34 |     OmegaConf.save(config, "config.yaml")
35 |
36 |
37 | def get_workdir(path):
38 |     split_path = path.split("/")
39 |     workdir_idx = split_path.index("MindtheGap")
40 |     return "/".join(split_path[:workdir_idx+1])
41 |
42 |
43 |
--------------------------------------------------------------------------------
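A worked example of `get_class_ids_per_task`: the first chunk uses `initial_increment`, later chunks use `increment` (`args` is stubbed here for illustration):

```python
# Worked example for get_class_ids_per_task (args stubbed with SimpleNamespace).
from types import SimpleNamespace

from continual_clip.utils import get_class_ids_per_task

args = SimpleNamespace(class_order=list(range(10)), initial_increment=4, increment=2)
print(list(get_class_ids_per_task(args)))
# -> [[0, 1, 2, 3], [4, 5], [6, 7], [8, 9]]
```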
/dataset_reqs/caltech101_classes.txt:
--------------------------------------------------------------------------------
1 | 1 accordion
2 | 2 airplanes
3 | 3 anchor
4 | 4 ant
5 | 5 barrel
6 | 6 bass
7 | 7 beaver
8 | 8 binocular
9 | 9 bonsai
10 | 10 brain
11 | 11 brontosaurus
12 | 12 buddha
13 | 13 butterfly
14 | 14 camera
15 | 15 cannon
16 | 16 car_side
17 | 17 ceiling_fan
18 | 18 cellphone
19 | 19 chair
20 | 20 chandelier
21 | 21 cougar_body
22 | 22 cougar_face
23 | 23 crab
24 | 24 crayfish
25 | 25 crocodile
26 | 26 crocodile_head
27 | 27 cup
28 | 28 dalmatian
29 | 29 dollar_bill
30 | 30 dolphin
31 | 31 dragonfly
32 | 32 electric_guitar
33 | 33 elephant
34 | 34 emu
35 | 35 euphonium
36 | 36 ewer
37 | 37 faces
38 | 38 faces_easy
39 | 39 ferry
40 | 40 flamingo
41 | 41 flamingo_head
42 | 42 garfield
43 | 43 gerenuk
44 | 44 gramophone
45 | 45 grand_piano
46 | 46 hawksbill
47 | 47 headphone
48 | 48 hedgehog
49 | 49 helicopter
50 | 50 ibis
51 | 51 inline_skate
52 | 52 joshua_tree
53 | 53 kangaroo
54 | 54 ketch
55 | 55 lamp
56 | 56 laptop
57 | 57 leopards
58 | 58 llama
59 | 59 lobster
60 | 60 lotus
61 | 61 mandolin
62 | 62 mayfly
63 | 63 menorah
64 | 64 metronome
65 | 65 minaret
66 | 66 motorbikes
67 | 67 nautilus
68 | 68 octopus
69 | 69 okapi
70 | 70 pagoda
71 | 71 panda
72 | 72 pigeon
73 | 73 pizza
74 | 74 platypus
75 | 75 pyramid
76 | 76 revolver
77 | 77 rhino
78 | 78 rooster
79 | 79 saxophone
80 | 80 schooner
81 | 81 scissors
82 | 82 scorpion
83 | 83 sea_horse
84 | 84 snoopy
85 | 85 soccer_ball
86 | 86 stapler
87 | 87 starfish
88 | 88 stegosaurus
89 | 89 stop_sign
90 | 90 strawberry
91 | 91 sunflower
92 | 92 tick
93 | 93 trilobite
94 | 94 umbrella
95 | 95 watch
96 | 96 water_lilly
97 | 97 wheelchair
98 | 98 wild_cat
99 | 99 windsor_chair
100 | 100 wrench
101 | 101 yin_yang
102 |
--------------------------------------------------------------------------------
/dataset_reqs/food101_classes.txt:
--------------------------------------------------------------------------------
1 | apple_pie
2 | baby_back_ribs
3 | baklava
4 | beef_carpaccio
5 | beef_tartare
6 | beet_salad
7 | beignets
8 | bibimbap
9 | bread_pudding
10 | breakfast_burrito
11 | bruschetta
12 | caesar_salad
13 | cannoli
14 | caprese_salad
15 | carrot_cake
16 | ceviche
17 | cheesecake
18 | cheese_plate
19 | chicken_curry
20 | chicken_quesadilla
21 | chicken_wings
22 | chocolate_cake
23 | chocolate_mousse
24 | churros
25 | clam_chowder
26 | club_sandwich
27 | crab_cakes
28 | creme_brulee
29 | croque_madame
30 | cup_cakes
31 | deviled_eggs
32 | donuts
33 | dumplings
34 | edamame
35 | eggs_benedict
36 | escargots
37 | falafel
38 | filet_mignon
39 | fish_and_chips
40 | foie_gras
41 | french_fries
42 | french_onion_soup
43 | french_toast
44 | fried_calamari
45 | fried_rice
46 | frozen_yogurt
47 | garlic_bread
48 | gnocchi
49 | greek_salad
50 | grilled_cheese_sandwich
51 | grilled_salmon
52 | guacamole
53 | gyoza
54 | hamburger
55 | hot_and_sour_soup
56 | hot_dog
57 | huevos_rancheros
58 | hummus
59 | ice_cream
60 | lasagna
61 | lobster_bisque
62 | lobster_roll_sandwich
63 | macaroni_and_cheese
64 | macarons
65 | miso_soup
66 | mussels
67 | nachos
68 | omelette
69 | onion_rings
70 | oysters
71 | pad_thai
72 | paella
73 | pancakes
74 | panna_cotta
75 | peking_duck
76 | pho
77 | pizza
78 | pork_chop
79 | poutine
80 | prime_rib
81 | pulled_pork_sandwich
82 | ramen
83 | ravioli
84 | red_velvet_cake
85 | risotto
86 | samosa
87 | sashimi
88 | scallops
89 | seaweed_salad
90 | shrimp_and_grits
91 | spaghetti_bolognese
92 | spaghetti_carbonara
93 | spring_rolls
94 | steak
95 | strawberry_shortcake
96 | sushi
97 | tacos
98 | takoyaki
99 | tiramisu
100 | tuna_tartare
101 | waffles
102 |
--------------------------------------------------------------------------------
/loraclip/loralib/utils.py:
--------------------------------------------------------------------------------
1 | # Source: https://github.com/SivanDoveh/TSVLC/blob/main/src/open_clip/loralib/layers.py
2 | # ------------------------------------------------------------------------------------------
3 | # Copyright (c) Microsoft Corporation. All rights reserved.
4 | # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
5 | # ------------------------------------------------------------------------------------------
6 | #
7 | import torch
8 | import torch.nn as nn
9 |
10 | from typing import Dict
11 |
12 | from .layers import LoRALayer
13 |
14 |
15 | def mark_only_lora_as_trainable(model: nn.Module, bias: str = 'none') -> None:
16 | for n, p in model.named_parameters():
17 | if 'lora_' not in n:
18 | p.requires_grad = False
19 | # print(f'Freezing {n}')
20 | # else:
21 | # print(f'Optimizing {n}')
22 | if bias == 'none':
23 | return
24 | elif bias == 'all':
25 | for n, p in model.named_parameters():
26 | if 'bias' in n:
27 | p.requires_grad = True
28 | elif bias == 'lora_only':
29 | for m in model.modules():
30 | if isinstance(m, LoRALayer) and \
31 | hasattr(m, 'bias') and \
32 | m.bias is not None:
33 | m.bias.requires_grad = True
34 | else:
35 | raise NotImplementedError
36 |
37 |
38 | def lora_state_dict(model: nn.Module, bias: str = 'none') -> Dict[str, torch.Tensor]:
39 | my_state_dict = model.state_dict()
40 | if bias == 'none':
41 | return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k}
42 | elif bias == 'all':
43 | return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k or 'bias' in k}
44 | elif bias == 'lora_only':
45 | to_return = {}
46 | for k in my_state_dict:
47 | if 'lora_' in k:
48 | to_return[k] = my_state_dict[k]
49 | bias_name = k.split('lora_')[0]+'bias'
50 | if bias_name in my_state_dict:
51 | to_return[bias_name] = my_state_dict[bias_name]
52 | return to_return
53 | else:
54 | raise NotImplementedError
55 |
--------------------------------------------------------------------------------
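Typical usage of the two helpers together — a minimal sketch, assuming `model` is an `nn.Module` whose LoRA parameters are named `lora_*` (e.g. the CLIP model returned by `lora_clip.load`):

```python
# Minimal sketch: freeze everything except LoRA parameters, then checkpoint only those.
import torch

from loraclip.loralib.utils import lora_state_dict, mark_only_lora_as_trainable

mark_only_lora_as_trainable(model, bias="none")      # non-LoRA weights get requires_grad=False
torch.save(lora_state_dict(model), "lora_only.pt")   # much smaller than a full state_dict
```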
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | # lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # custom
132 | experiments/*
133 | run_cifar100.sh
134 | run_imagenet100.sh
135 | run_imagenet1000.sh
136 | run_tinyimagenet.sh
137 | vit_nms*
138 | continual_clip_nms*
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [ICCV 2025 Highlight] Mind the Gap: Preserving and Compensating for the Modality Gap in CLIP-Based Continual Learning
2 | This is the official code for our paper.
3 |
4 |
5 | ## Getting Started
6 |
7 | ## Environment
8 | The environment is the same as that of our [RAPF](https://github.com/linlany/RAPF).
9 |
10 | Create an environment using Miniconda (or Anaconda):
11 | ```bash
12 | conda create -n continual_clip python=3.8
13 | conda activate continual_clip
14 | ```
15 | Then install the dependencies:
16 | ```bash
17 | bash setup_environment.sh
18 | ```
19 | ### Running scripts
20 |
21 | We provide scripts for ImageNet100. Please run:
22 |
23 | ```bash
24 | python main.py \
25 | --config-path configs/class \
26 | --config-name imagenet100_10-10.yaml \
27 | dataset_root="[imagenet1k_path]" \
28 | class_order="class_orders/imagenet100.yaml"
29 | ```
30 | **Note:** To obtain the epoch parameter for the first task, described in Eq. (3) of the paper, run `epoch.py` first with the same arguments (see `run_experiment_img100.sh`).
31 |
32 |
33 | The `dataset_root` folder should contain `train` and `val` subfolders:
34 | ```
35 | imagenet1k_path
36 | ├── train
37 | │ ├── n01440764
38 | │ └── ···
39 | ├── val
40 | │ ├── n01440764
41 | │ └── ···
42 |
43 | imagenet-r_path
44 | ├── train
45 | │ ├── n01443537
46 | │ └── ···
47 | ├── val
48 | │ ├── n01443537
49 | │ └── ···
50 |
51 | ```
52 |
53 | The commands for the other two datasets are similar; see `run_experiment.sh` and `run_experiment_cifar.sh`.
54 |
55 | ### Datasets
56 | CIFAR100 is downloaded automatically.
57 | ImageNet-R is randomly split. You can also use our split list in RAPF/imgr_split/imgr_train_test_split.txt.
58 |
59 | The format of imgr_train_test_split.txt:
60 | ```
61 | train
62 | n02051845/art_0.jpg
63 | ...
64 | test
65 | n02051845/tattoo_4.jpg
66 | ...
67 | ```
68 |
69 | ## Acknowledgement
70 | Our implementation builds on [Continual-CLIP](https://github.com/vgthengane/Continual-CLIP).
71 |
72 | ## Citation
73 |
74 | If you find our repo useful for your research, please consider citing our paper:
75 |
76 | ```bibtex
77 | @inproceedings{huang2025mind,
78 | title={Mind the gap: Preserving and compensating for the modality gap in clip-based continual learning},
79 | author={Huang, Linlan and Cao, Xusheng and Lu, Haori and Meng, Yifan and Yang, Fei and Liu, Xialei},
80 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
81 | pages={3777--3786},
82 | year={2025}
83 | }
84 | ```
85 |
86 | ## License
87 | This code is licensed under the [Creative Commons Attribution-NonCommercial 4.0 International](https://creativecommons.org/licenses/by-nc/4.0/) for non-commercial use only.
88 | Please note that any commercial use of this code requires formal permission prior to use.
89 |
90 | ## Contact
91 |
92 | For technical questions, please contact huanglinlan@mail.nankai.edu.cn
93 |
--------------------------------------------------------------------------------
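The `imgr_train_test_split.txt` layout shown in the README can be parsed in a few lines — a hedged sketch, assuming the file follows exactly the `train`/`test` header format above:

```python
# Sketch: parse the imgr_train_test_split.txt layout shown in the README
# ("train" header, one relative image path per line, then a "test" header).
def read_imgr_split(path):
    splits, current = {"train": [], "test": []}, None
    with open(path) as f:
        for raw in f:
            line = raw.strip()
            if line in splits:
                current = line                 # switch section on "train"/"test"
            elif line and current is not None:
                splits[current].append(line)   # e.g. "n02051845/art_0.jpg"
    return splits["train"], splits["test"]
```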
/dataset_reqs/imagenet100_classes.txt:
--------------------------------------------------------------------------------
1 | 0 n01440764 tench
2 | 1 n01443537 goldfish
3 | 2 n01484850 great white shark
4 | 3 n01491361 tiger shark
5 | 4 n01494475 hammerhead shark
6 | 5 n01496331 electric ray
7 | 6 n01498041 stingray
8 | 7 n01514668 rooster
9 | 8 n01514859 hen
10 | 9 n01518878 ostrich
11 | 10 n01530575 brambling
12 | 11 n01531178 goldfinch
13 | 12 n01532829 house finch
14 | 13 n01534433 junco
15 | 14 n01537544 indigo bunting
16 | 15 n01558993 American robin
17 | 16 n01560419 bulbul
18 | 17 n01580077 jay
19 | 18 n01582220 magpie
20 | 19 n01592084 chickadee
21 | 20 n01601694 American dipper
22 | 21 n01608432 kite (bird of prey)
23 | 22 n01614925 bald eagle
24 | 23 n01616318 vulture
25 | 24 n01622779 great grey owl
26 | 25 n01629819 fire salamander
27 | 26 n01630670 smooth newt
28 | 27 n01631663 newt
29 | 28 n01632458 spotted salamander
30 | 29 n01632777 axolotl
31 | 30 n01641577 American bullfrog
32 | 31 n01644373 tree frog
33 | 32 n01644900 tailed frog
34 | 33 n01664065 loggerhead sea turtle
35 | 34 n01665541 leatherback sea turtle
36 | 35 n01667114 mud turtle
37 | 36 n01667778 terrapin
38 | 37 n01669191 box turtle
39 | 38 n01675722 banded gecko
40 | 39 n01677366 green iguana
41 | 40 n01682714 Carolina anole
42 | 41 n01685808 desert grassland whiptail lizard
43 | 42 n01687978 agama
44 | 43 n01688243 frilled-necked lizard
45 | 44 n01689811 alligator lizard
46 | 45 n01692333 Gila monster
47 | 46 n01693334 European green lizard
48 | 47 n01694178 chameleon
49 | 48 n01695060 Komodo dragon
50 | 49 n01697457 Nile crocodile
51 | 50 n01698640 American alligator
52 | 51 n01704323 triceratops
53 | 52 n01728572 worm snake
54 | 53 n01728920 ring-necked snake
55 | 54 n01729322 eastern hog-nosed snake
56 | 55 n01729977 smooth green snake
57 | 56 n01734418 kingsnake
58 | 57 n01735189 garter snake
59 | 58 n01737021 water snake
60 | 59 n01739381 vine snake
61 | 60 n01740131 night snake
62 | 61 n01742172 boa constrictor
63 | 62 n01744401 African rock python
64 | 63 n01748264 Indian cobra
65 | 64 n01749939 green mamba
66 | 65 n01751748 sea snake
67 | 66 n01753488 Saharan horned viper
68 | 67 n01755581 eastern diamondback rattlesnake
69 | 68 n01756291 sidewinder rattlesnake
70 | 69 n01768244 trilobite
71 | 70 n01770081 harvestman
72 | 71 n01770393 scorpion
73 | 72 n01773157 yellow garden spider
74 | 73 n01773549 barn spider
75 | 74 n01773797 European garden spider
76 | 75 n01774384 southern black widow
77 | 76 n01774750 tarantula
78 | 77 n01775062 wolf spider
79 | 78 n01776313 tick
80 | 79 n01784675 centipede
81 | 80 n01795545 black grouse
82 | 81 n01796340 ptarmigan
83 | 82 n01797886 ruffed grouse
84 | 83 n01798484 prairie grouse
85 | 84 n01806143 peafowl
86 | 85 n01806567 quail
87 | 86 n01807496 partridge
88 | 87 n01817953 african grey parrot
89 | 88 n01818515 macaw
90 | 89 n01819313 sulphur-crested cockatoo
91 | 90 n01820546 lorikeet
92 | 91 n01824575 coucal
93 | 92 n01828970 bee eater
94 | 93 n01829413 hornbill
95 | 94 n01833805 hummingbird
96 | 95 n01843065 jacamar
97 | 96 n01843383 toucan
98 | 97 n01847000 duck
99 | 98 n01855032 red-breasted merganser
100 | 99 n01855672 goose
101 |
--------------------------------------------------------------------------------
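These class files are read by `get_dataset_class_names` in `continual_clip/utils.py`, which splits each line on tabs and keeps the last field — a minimal sketch, assuming the fields are tab-separated on disk:

```python
# How dataset_reqs/*_classes.txt is consumed (mirrors get_dataset_class_names):
# split each line on tabs and keep the last field as the human-readable name.
with open("dataset_reqs/imagenet100_classes.txt") as f:
    names = [line.split("\t")[-1] for line in f.read().splitlines()]
print(names[:3])  # e.g. ['tench', 'goldfish', 'great white shark']
```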
/dataset_reqs/cub200_classes.txt:
--------------------------------------------------------------------------------
1 | 1 Black_footed_Albatross
2 | 2 Laysan_Albatross
3 | 3 Sooty_Albatross
4 | 4 Groove_billed_Ani
5 | 5 Crested_Auklet
6 | 6 Least_Auklet
7 | 7 Parakeet_Auklet
8 | 8 Rhinoceros_Auklet
9 | 9 Brewer_Blackbird
10 | 10 Red_winged_Blackbird
11 | 11 Rusty_Blackbird
12 | 12 Yellow_headed_Blackbird
13 | 13 Bobolink
14 | 14 Indigo_Bunting
15 | 15 Lazuli_Bunting
16 | 16 Painted_Bunting
17 | 17 Cardinal
18 | 18 Spotted_Catbird
19 | 19 Gray_Catbird
20 | 20 Yellow_breasted_Chat
21 | 21 Eastern_Towhee
22 | 22 Chuck_will_Widow
23 | 23 Brandt_Cormorant
24 | 24 Red_faced_Cormorant
25 | 25 Pelagic_Cormorant
26 | 26 Bronzed_Cowbird
27 | 27 Shiny_Cowbird
28 | 28 Brown_Creeper
29 | 29 American_Crow
30 | 30 Fish_Crow
31 | 31 Black_billed_Cuckoo
32 | 32 Mangrove_Cuckoo
33 | 33 Yellow_billed_Cuckoo
34 | 34 Gray_crowned_Rosy_Finch
35 | 35 Purple_Finch
36 | 36 Northern_Flicker
37 | 37 Acadian_Flycatcher
38 | 38 Great_Crested_Flycatcher
39 | 39 Least_Flycatcher
40 | 40 Olive_sided_Flycatcher
41 | 41 Scissor_tailed_Flycatcher
42 | 42 Vermilion_Flycatcher
43 | 43 Yellow_bellied_Flycatcher
44 | 44 Frigatebird
45 | 45 Northern_Fulmar
46 | 46 Gadwall
47 | 47 American_Goldfinch
48 | 48 European_Goldfinch
49 | 49 Boat_tailed_Grackle
50 | 50 Eared_Grebe
51 | 51 Horned_Grebe
52 | 52 Pied_billed_Grebe
53 | 53 Western_Grebe
54 | 54 Blue_Grosbeak
55 | 55 Evening_Grosbeak
56 | 56 Pine_Grosbeak
57 | 57 Rose_breasted_Grosbeak
58 | 58 Pigeon_Guillemot
59 | 59 California_Gull
60 | 60 Glaucous_winged_Gull
61 | 61 Heermann_Gull
62 | 62 Herring_Gull
63 | 63 Ivory_Gull
64 | 64 Ring_billed_Gull
65 | 65 Slaty_backed_Gull
66 | 66 Western_Gull
67 | 67 Anna_Hummingbird
68 | 68 Ruby_throated_Hummingbird
69 | 69 Rufous_Hummingbird
70 | 70 Green_Violetear
71 | 71 Long_tailed_Jaeger
72 | 72 Pomarine_Jaeger
73 | 73 Blue_Jay
74 | 74 Florida_Jay
75 | 75 Green_Jay
76 | 76 Dark_eyed_Junco
77 | 77 Tropical_Kingbird
78 | 78 Gray_Kingbird
79 | 79 Belted_Kingfisher
80 | 80 Green_Kingfisher
81 | 81 Pied_Kingfisher
82 | 82 Ringed_Kingfisher
83 | 83 White_breasted_Kingfisher
84 | 84 Red_legged_Kittiwake
85 | 85 Horned_Lark
86 | 86 Pacific_Loon
87 | 87 Mallard
88 | 88 Western_Meadowlark
89 | 89 Hooded_Merganser
90 | 90 Red_breasted_Merganser
91 | 91 Mockingbird
92 | 92 Nighthawk
93 | 93 Clark_Nutcracker
94 | 94 White_breasted_Nuthatch
95 | 95 Baltimore_Oriole
96 | 96 Hooded_Oriole
97 | 97 Orchard_Oriole
98 | 98 Scott_Oriole
99 | 99 Ovenbird
100 | 100 Brown_Pelican
101 | 101 White_Pelican
102 | 102 Western_Wood_Pewee
103 | 103 Sayornis
104 | 104 American_Pipit
105 | 105 Whip_poor_Will
106 | 106 Horned_Puffin
107 | 107 Common_Raven
108 | 108 White_necked_Raven
109 | 109 American_Redstart
110 | 110 Geococcyx
111 | 111 Loggerhead_Shrike
112 | 112 Great_Grey_Shrike
113 | 113 Baird_Sparrow
114 | 114 Black_throated_Sparrow
115 | 115 Brewer_Sparrow
116 | 116 Chipping_Sparrow
117 | 117 Clay_colored_Sparrow
118 | 118 House_Sparrow
119 | 119 Field_Sparrow
120 | 120 Fox_Sparrow
121 | 121 Grasshopper_Sparrow
122 | 122 Harris_Sparrow
123 | 123 Henslow_Sparrow
124 | 124 Le_Conte_Sparrow
125 | 125 Lincoln_Sparrow
126 | 126 Nelson_Sharp_tailed_Sparrow
127 | 127 Savannah_Sparrow
128 | 128 Seaside_Sparrow
129 | 129 Song_Sparrow
130 | 130 Tree_Sparrow
131 | 131 Vesper_Sparrow
132 | 132 White_crowned_Sparrow
133 | 133 White_throated_Sparrow
134 | 134 Cape_Glossy_Starling
135 | 135 Bank_Swallow
136 | 136 Barn_Swallow
137 | 137 Cliff_Swallow
138 | 138 Tree_Swallow
139 | 139 Scarlet_Tanager
140 | 140 Summer_Tanager
141 | 141 Artic_Tern
142 | 142 Black_Tern
143 | 143 Caspian_Tern
144 | 144 Common_Tern
145 | 145 Elegant_Tern
146 | 146 Forsters_Tern
147 | 147 Least_Tern
148 | 148 Green_tailed_Towhee
149 | 149 Brown_Thrasher
150 | 150 Sage_Thrasher
151 | 151 Black_capped_Vireo
152 | 152 Blue_headed_Vireo
153 | 153 Philadelphia_Vireo
154 | 154 Red_eyed_Vireo
155 | 155 Warbling_Vireo
156 | 156 White_eyed_Vireo
157 | 157 Yellow_throated_Vireo
158 | 158 Bay_breasted_Warbler
159 | 159 Black_and_white_Warbler
160 | 160 Black_throated_Blue_Warbler
161 | 161 Blue_winged_Warbler
162 | 162 Canada_Warbler
163 | 163 Cape_May_Warbler
164 | 164 Cerulean_Warbler
165 | 165 Chestnut_sided_Warbler
166 | 166 Golden_winged_Warbler
167 | 167 Hooded_Warbler
168 | 168 Kentucky_Warbler
169 | 169 Magnolia_Warbler
170 | 170 Mourning_Warbler
171 | 171 Myrtle_Warbler
172 | 172 Nashville_Warbler
173 | 173 Orange_crowned_Warbler
174 | 174 Palm_Warbler
175 | 175 Pine_Warbler
176 | 176 Prairie_Warbler
177 | 177 Prothonotary_Warbler
178 | 178 Swainson_Warbler
179 | 179 Tennessee_Warbler
180 | 180 Wilson_Warbler
181 | 181 Worm_eating_Warbler
182 | 182 Yellow_Warbler
183 | 183 Northern_Waterthrush
184 | 184 Louisiana_Waterthrush
185 | 185 Bohemian_Waxwing
186 | 186 Cedar_Waxwing
187 | 187 American_Three_toed_Woodpecker
188 | 188 Pileated_Woodpecker
189 | 189 Red_bellied_Woodpecker
190 | 190 Red_cockaded_Woodpecker
191 | 191 Red_headed_Woodpecker
192 | 192 Downy_Woodpecker
193 | 193 Bewick_Wren
194 | 194 Cactus_Wren
195 | 195 Carolina_Wren
196 | 196 House_Wren
197 | 197 Marsh_Wren
198 | 198 Rock_Wren
199 | 199 Winter_Wren
200 | 200 Common_Yellowthroat
201 |
--------------------------------------------------------------------------------
/dataset_reqs/imagenet_R_classes.txt:
--------------------------------------------------------------------------------
1 | 0 n01443537 goldfish
2 | 1 n01484850 great_white_shark
3 | 2 n01494475 hammerhead
4 | 3 n01498041 stingray
5 | 4 n01514859 hen
6 | 5 n01518878 ostrich
7 | 6 n01531178 goldfinch
8 | 7 n01534433 junco
9 | 8 n01614925 bald_eagle
10 | 9 n01616318 vulture
11 | 10 n01630670 common_newt
12 | 11 n01632777 axolotl
13 | 12 n01644373 tree_frog
14 | 13 n01677366 common_iguana
15 | 14 n01694178 African_chameleon
16 | 15 n01748264 Indian_cobra
17 | 16 n01770393 scorpion
18 | 17 n01774750 tarantula
19 | 18 n01784675 centipede
20 | 19 n01806143 peacock
21 | 20 n01820546 lorikeet
22 | 21 n01833805 hummingbird
23 | 22 n01843383 toucan
24 | 23 n01847000 drake
25 | 24 n01855672 goose
26 | 25 n01860187 black_swan
27 | 26 n01882714 koala
28 | 27 n01910747 jellyfish
29 | 28 n01944390 snail
30 | 29 n01983481 American_lobster
31 | 30 n01986214 hermit_crab
32 | 31 n02007558 flamingo
33 | 32 n02009912 American_egret
34 | 33 n02051845 pelican
35 | 34 n02056570 king_penguin
36 | 35 n02066245 grey_whale
37 | 36 n02071294 killer_whale
38 | 37 n02077923 sea_lion
39 | 38 n02085620 Chihuahua
40 | 39 n02086240 Shih-Tzu
41 | 40 n02088094 Afghan_hound
42 | 41 n02088238 basset
43 | 42 n02088364 beagle
44 | 43 n02088466 bloodhound
45 | 44 n02091032 Italian_greyhound
46 | 45 n02091134 whippet
47 | 46 n02092339 Weimaraner
48 | 47 n02094433 Yorkshire_terrier
49 | 48 n02096585 Boston_bull
50 | 49 n02097298 Scotch_terrier
51 | 50 n02098286 West_Highland_white_terrier
52 | 51 n02099601 golden_retriever
53 | 52 n02099712 Labrador_retriever
54 | 53 n02102318 cocker_spaniel
55 | 54 n02106030 collie
56 | 55 n02106166 Border_collie
57 | 56 n02106550 Rottweiler
58 | 57 n02106662 German_shepherd
59 | 58 n02108089 boxer
60 | 59 n02108915 French_bulldog
61 | 60 n02109525 Saint_Bernard
62 | 61 n02110185 Siberian_husky
63 | 62 n02110341 dalmatian
64 | 63 n02110958 pug
65 | 64 n02112018 Pomeranian
66 | 65 n02112137 chow
67 | 66 n02113023 Pembroke
68 | 67 n02113624 toy_poodle
69 | 68 n02113799 standard_poodle
70 | 69 n02114367 timber_wolf
71 | 70 n02117135 hyena
72 | 71 n02119022 red_fox
73 | 72 n02123045 tabby
74 | 73 n02128385 leopard
75 | 74 n02128757 snow_leopard
76 | 75 n02129165 lion
77 | 76 n02129604 tiger
78 | 77 n02130308 cheetah
79 | 78 n02134084 ice_bear
80 | 79 n02138441 meerkat
81 | 80 n02165456 ladybug
82 | 81 n02190166 fly
83 | 82 n02206856 bee
84 | 83 n02219486 ant
85 | 84 n02226429 grasshopper
86 | 85 n02233338 cockroach
87 | 86 n02236044 mantis
88 | 87 n02268443 dragonfly
89 | 88 n02279972 monarch
90 | 89 n02317335 starfish
91 | 90 n02325366 wood_rabbit
92 | 91 n02346627 porcupine
93 | 92 n02356798 fox_squirrel
94 | 93 n02363005 beaver
95 | 94 n02364673 guinea_pig
96 | 95 n02391049 zebra
97 | 96 n02395406 hog
98 | 97 n02398521 hippopotamus
99 | 98 n02410509 bison
100 | 99 n02423022 gazelle
101 | 100 n02437616 llama
102 | 101 n02445715 skunk
103 | 102 n02447366 badger
104 | 103 n02480495 orangutan
105 | 104 n02480855 gorilla
106 | 105 n02481823 chimpanzee
107 | 106 n02483362 gibbon
108 | 107 n02486410 baboon
109 | 108 n02510455 giant_panda
110 | 109 n02526121 eel
111 | 110 n02607072 anemone_fish
112 | 111 n02655020 puffer
113 | 112 n02672831 accordion
114 | 113 n02701002 ambulance
115 | 114 n02749479 assault_rifle
116 | 115 n02769748 backpack
117 | 116 n02793495 barn
118 | 117 n02797295 barrow
119 | 118 n02802426 basketball
120 | 119 n02808440 bathtub
121 | 120 n02814860 beacon
122 | 121 n02823750 beer_glass
123 | 122 n02841315 binoculars
124 | 123 n02843684 birdhouse
125 | 124 n02883205 bow_tie
126 | 125 n02906734 broom
127 | 126 n02909870 bucket
128 | 127 n02939185 caldron
129 | 128 n02948072 candle
130 | 129 n02950826 cannon
131 | 130 n02951358 canoe
132 | 131 n02966193 carousel
133 | 132 n02980441 castle
134 | 133 n02992529 cellular_telephone
135 | 134 n03124170 cowboy_hat
136 | 135 n03272010 electric_guitar
137 | 136 n03345487 fire_engine
138 | 137 n03372029 flute
139 | 138 n03424325 gasmask
140 | 139 n03452741 grand_piano
141 | 140 n03467068 guillotine
142 | 141 n03481172 hammer
143 | 142 n03494278 harmonica
144 | 143 n03495258 harp
145 | 144 n03498962 hatchet
146 | 145 n03594945 jeep
147 | 146 n03602883 joystick
148 | 147 n03630383 lab_coat
149 | 148 n03649909 lawn_mower
150 | 149 n03676483 lipstick
151 | 150 n03710193 mailbox
152 | 151 n03773504 missile
153 | 152 n03775071 mitten
154 | 153 n03888257 parachute
155 | 154 n03930630 pickup
156 | 155 n03947888 pirate
157 | 156 n04086273 revolver
158 | 157 n04118538 rugby_ball
159 | 158 n04133789 sandal
160 | 159 n04141076 sax
161 | 160 n04146614 school_bus
162 | 161 n04147183 schooner
163 | 162 n04192698 shield
164 | 163 n04254680 soccer_ball
165 | 164 n04266014 space_shuttle
166 | 165 n04275548 spider_web
167 | 166 n04310018 steam_locomotive
168 | 167 n04325704 stole
169 | 168 n04347754 submarine
170 | 169 n04389033 tank
171 | 170 n04409515 tennis_ball
172 | 171 n04465501 tractor
173 | 172 n04487394 trombone
174 | 173 n04522168 vase
175 | 174 n04536866 violin
176 | 175 n04552348 warplane
177 | 176 n04591713 wine_bottle
178 | 177 n07614500 ice_cream
179 | 178 n07693725 bagel
180 | 179 n07695742 pretzel
181 | 180 n07697313 cheeseburger
182 | 181 n07697537 hotdog
183 | 182 n07714571 head_cabbage
184 | 183 n07714990 broccoli
185 | 184 n07718472 cucumber
186 | 185 n07720875 bell_pepper
187 | 186 n07734744 mushroom
188 | 187 n07742313 Granny_Smith
189 | 188 n07745940 strawberry
190 | 189 n07749582 lemon
191 | 190 n07753275 pineapple
192 | 191 n07753592 banana
193 | 192 n07768694 pomegranate
194 | 193 n07873807 pizza
195 | 194 n07880968 burrito
196 | 195 n07920052 espresso
197 | 196 n09472597 volcano
198 | 197 n09835506 ballplayer
199 | 198 n10565667 scuba_diver
200 | 199 n12267677 acorn
201 |
--------------------------------------------------------------------------------
/dataset_reqs/tinyimagenet_classes.txt:
--------------------------------------------------------------------------------
1 | 0 n02124075 Egyptian Mau
2 | 1 n04067472 fishing casting reel
3 | 2 n04540053 volleyball
4 | 3 n04099969 rocking chair
5 | 4 n07749582 lemon
6 | 5 n01641577 American bullfrog
7 | 6 n02802426 basketball
8 | 7 n09246464 cliff
9 | 8 n07920052 espresso
10 | 9 n03970156 plunger
11 | 10 n03891332 parking meter
12 | 11 n02106662 German Shepherd Dog
13 | 12 n03201208 dining table
14 | 13 n02279972 monarch butterfly
15 | 14 n02132136 brown bear
16 | 15 n04146614 school bus
17 | 16 n07873807 pizza
18 | 17 n02364673 guinea pig
19 | 18 n04507155 umbrella
20 | 19 n03854065 pipe organ
21 | 20 n03838899 oboe
22 | 21 n03733131 maypole
23 | 22 n01443537 goldfish
24 | 23 n07875152 pot pie
25 | 24 n03544143 hourglass
26 | 25 n09428293 beach
27 | 26 n03085013 computer keyboard
28 | 27 n02437312 arabian camel
29 | 28 n07614500 ice cream
30 | 29 n03804744 metal nail
31 | 30 n04265275 space heater
32 | 31 n02963159 cardigan
33 | 32 n02486410 baboon
34 | 33 n01944390 snail
35 | 34 n09256479 coral reef
36 | 35 n02058221 albatross
37 | 36 n04275548 spider web
38 | 37 n02321529 sea cucumber
39 | 38 n02769748 backpack
40 | 39 n02099712 Labrador Retriever
41 | 40 n07695742 pretzel
42 | 41 n02056570 king penguin
43 | 42 n02281406 sulphur butterfly
44 | 43 n01774750 tarantula
45 | 44 n02509815 red panda
46 | 45 n03983396 soda bottle
47 | 46 n07753592 banana
48 | 47 n04254777 sock
49 | 48 n02233338 cockroach
50 | 49 n04008634 missile
51 | 50 n02823428 beer bottle
52 | 51 n02236044 praying mantis
53 | 52 n03393912 freight car
54 | 53 n07583066 guacamole
55 | 54 n04074963 remote control
56 | 55 n01629819 fire salamander
57 | 56 n09332890 lakeshore
58 | 57 n02481823 chimpanzee
59 | 58 n03902125 payphone
60 | 59 n03404251 fur coat
61 | 60 n09193705 mountain
62 | 61 n03637318 lampshade
63 | 62 n04456115 torch
64 | 63 n02666196 abacus
65 | 64 n03796401 moving van
66 | 65 n02795169 barrel
67 | 66 n02123045 tabby cat
68 | 67 n01855672 goose
69 | 68 n01882714 koala
70 | 69 n02917067 high-speed train
71 | 70 n02988304 CD player
72 | 71 n04398044 teapot
73 | 72 n02843684 birdhouse
74 | 73 n02423022 gazelle
75 | 74 n02669723 academic gown
76 | 75 n04465501 tractor
77 | 76 n02165456 ladybug
78 | 77 n03770439 miniskirt
79 | 78 n02099601 Golden Retriever
80 | 79 n04486054 triumphal arch
81 | 80 n02950826 cannon
82 | 81 n03814639 neck brace
83 | 82 n04259630 sombrero
84 | 83 n03424325 gas mask or respirator
85 | 84 n02948072 candle
86 | 85 n03179701 desk
87 | 86 n03400231 frying pan
88 | 87 n02206856 bee
89 | 88 n03160309 dam
90 | 89 n01984695 spiny lobster
91 | 90 n03977966 police van
92 | 91 n03584254 iPod
93 | 92 n04023962 punching bag
94 | 93 n02814860 lighthouse
95 | 94 n01910747 jellyfish
96 | 95 n04596742 wok
97 | 96 n03992509 potter's wheel
98 | 97 n04133789 sandal
99 | 98 n03937543 pill bottle
100 | 99 n02927161 butcher shop
101 | 100 n01945685 slug
102 | 101 n02395406 pig
103 | 102 n02125311 cougar
104 | 103 n03126707 construction crane
105 | 104 n04532106 vestment
106 | 105 n02268443 dragonfly
107 | 106 n02977058 automated teller machine
108 | 107 n07734744 mushroom
109 | 108 n03599486 rickshaw
110 | 109 n04562935 water tower
111 | 110 n03014705 storage chest
112 | 111 n04251144 snorkel
113 | 112 n04356056 sunglasses
114 | 113 n02190166 fly
115 | 114 n03670208 limousine
116 | 115 n02002724 black stork
117 | 116 n02074367 dugong
118 | 117 n04285008 sports car
119 | 118 n04560804 water jug
120 | 119 n04366367 suspension bridge
121 | 120 n02403003 ox
122 | 121 n07615774 popsicle
123 | 122 n04501370 turnstile
124 | 123 n03026506 Christmas stocking
125 | 124 n02906734 broom
126 | 125 n01770393 scorpion
127 | 126 n04597913 wooden spoon
128 | 127 n03930313 picket fence
129 | 128 n04118538 rugby ball
130 | 129 n04179913 sewing machine
131 | 130 n04311004 through arch bridge
132 | 131 n02123394 Persian cat
133 | 132 n04070727 refrigerator
134 | 133 n02793495 barn
135 | 134 n02730930 apron
136 | 135 n02094433 Yorkshire Terrier
137 | 136 n04371430 swim trunks / shorts
138 | 137 n04328186 stopwatch
139 | 138 n03649909 lawn mower
140 | 139 n04417672 thatched roof
141 | 140 n03388043 fountain
142 | 141 n01774384 southern black widow
143 | 142 n02837789 bikini
144 | 143 n07579787 plate
145 | 144 n04399382 teddy bear
146 | 145 n02791270 barbershop
147 | 146 n03089624 candy store
148 | 147 n02814533 station wagon
149 | 148 n04149813 scoreboard
150 | 149 n07747607 orange
151 | 150 n03355925 flagpole
152 | 151 n01983481 American lobster
153 | 152 n04487081 trolleybus
154 | 153 n03250847 drumstick
155 | 154 n03255030 dumbbell
156 | 155 n02892201 brass memorial plaque
157 | 156 n02883205 bow tie
158 | 157 n03100240 convertible
159 | 158 n02415577 bighorn sheep
160 | 159 n02480495 orangutan
161 | 160 n01698640 American alligator
162 | 161 n01784675 centipede
163 | 162 n04376876 syringe
164 | 163 n03444034 go-kart
165 | 164 n01917289 brain coral
166 | 165 n01950731 sea slug
167 | 166 n03042490 cliff dwelling
168 | 167 n07711569 mashed potatoes
169 | 168 n04532670 viaduct
170 | 169 n03763968 military uniform
171 | 170 n07768694 pomegranate
172 | 171 n02999410 chain
173 | 172 n03617480 kimono
174 | 173 n06596364 comic book
175 | 174 n01768244 trilobite
176 | 175 n02410509 bison
177 | 176 n03976657 pole
178 | 177 n01742172 boa constrictor
179 | 178 n03980874 poncho
180 | 179 n02808440 bathtub
181 | 180 n02226429 grasshopper
182 | 181 n02231487 stick insect
183 | 182 n02085620 Chihuahua
184 | 183 n01644900 tailed frog
185 | 184 n02129165 lion
186 | 185 n02699494 altar
187 | 186 n03837869 obelisk
188 | 187 n02815834 beaker
189 | 188 n07720875 bell pepper
190 | 189 n02788148 baluster / handrail
191 | 190 n02909870 bucket
192 | 191 n03706229 magnetic compass
193 | 192 n07871810 meatloaf
194 | 193 n03447447 gondola
195 | 194 n02113799 Standard Poodle
196 | 195 n12267677 acorn
197 | 196 n03662601 lifeboat
198 | 197 n02841315 binoculars
199 | 198 n07715103 cauliflower
200 | 199 n02504458 African bush elephant
201 |
--------------------------------------------------------------------------------
/continual_clip/datasets.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import os
4 | import torch.nn as nn
5 | from torchvision import transforms  # used by the transformations properties below
6 | from continuum import ClassIncremental, InstanceIncremental
7 | from continuum.datasets import (
8 | CIFAR100, ImageNet100, TinyImageNet200, ImageFolderDataset, Core50, CUB200, Food101,OxfordPet,Caltech101
9 | )
10 | from .utils import get_dataset_class_names, get_workdir
11 | import pdb
12 |
13 |
14 | class ImageNet_C(ImageFolderDataset):
15 |     """Continuum dataset for image folders with a tree-like structure.
16 |     :param data_path: The root folder of the data.
17 |     :param train: Whether to load the train split (else the test split).
18 |     :param download: Dummy parameter.
19 |     """
20 |
21 | def __init__(
22 | self,
23 | data_path: str,
24 | train: bool = True,
25 | download: bool = False,
26 | ):
27 | super().__init__(data_path=data_path, train=train, download=download)
28 |
29 | def get_data(self):
30 |         # data_path is used as-is: no train/val subfolders are appended for ImageNet-C
31 | return super().get_data()
32 |
33 |
34 | class ImageNet1000(ImageFolderDataset):
35 |     """Continuum dataset for image folders with a tree-like structure.
36 |     :param data_path: The root folder of the data.
37 |     :param train: Whether to load the train split (else the test split).
38 |     :param download: Dummy parameter.
39 |     """
40 |
41 | def __init__(
42 | self,
43 | data_path: str,
44 | train: bool = True,
45 | download: bool = False,
46 | ):
47 | super().__init__(data_path=data_path, train=train, download=download)
48 |
49 | def get_data(self):
50 | if self.train:
51 | self.data_path = os.path.join(self.data_path, "train")
52 | else:
53 | self.data_path = os.path.join(self.data_path, "val")
54 | return super().get_data()
55 |
56 |
57 | class ImageNet_R(ImageFolderDataset):
58 |     """Continuum dataset for image folders with a tree-like structure.
59 |     :param data_path: The root folder of the data.
60 |     :param train: Whether to load the train split (else the test split).
61 |     :param download: Dummy parameter.
62 |     """
63 |
64 | def __init__(
65 | self,
66 | data_path: str,
67 | train: bool = True,
68 | download: bool = False,
69 | ):
70 | super().__init__(data_path=data_path, train=train, download=download)
71 | @property
72 | def transformations(self):
73 | """Default transformations if nothing is provided to the scenario."""
74 | return [
75 | transforms.Resize((224, 224)),
76 | transforms.ToTensor(),
77 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
78 | ]
79 |
80 | def get_data(self):
81 | if self.train:
82 | self.data_path = os.path.join(self.data_path, "train")
83 | else:
84 | self.data_path = os.path.join(self.data_path, "test")
85 | return super().get_data()
86 |
87 | class VTAB(ImageFolderDataset):
88 |     """Continuum dataset for image folders with a tree-like structure.
89 |     :param data_path: The root folder of the data.
90 |     :param train: Whether to load the train split (else the test split).
91 |     :param download: Dummy parameter.
92 |     """
93 |
94 | def __init__(
95 | self,
96 | data_path: str,
97 | train: bool = True,
98 | download: bool = False,
99 | ):
100 | super().__init__(data_path=data_path, train=train, download=download)
101 | @property
102 | def transformations(self):
103 | """Default transformations if nothing is provided to the scenario."""
104 | return [
105 | transforms.Resize((224, 224)),
106 | transforms.ToTensor(),
107 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
108 | ]
109 |
110 | def get_data(self):
111 | if self.train:
112 | self.data_path = os.path.join(self.data_path, "train")
113 | else:
114 | self.data_path = os.path.join(self.data_path, "test")
115 | return super().get_data()
116 |
117 |
118 | def get_dataset(cfg, is_train, transforms=None):
119 | if cfg.dataset == "cifar100":
120 | data_path = cfg.dataset_root
121 | dataset = CIFAR100(
122 | data_path=data_path,
123 | download=True,
124 | train=is_train,
125 | # transforms=transforms
126 | )
127 | classes_names = dataset.dataset.classes
128 | elif cfg.dataset == "imagenet_R":
129 | data_path = cfg.dataset_root
130 | dataset = ImageNet_R(
131 | data_path,
132 | train=is_train
133 | )
134 | classes_names = get_dataset_class_names(cfg.workdir, cfg.dataset)
135 | elif cfg.dataset == "cub200":
136 | data_path = cfg.dataset_root
137 | dataset = CUB200(
138 | data_path,
139 | train=is_train
140 | )
141 | classes_names = get_dataset_class_names(cfg.workdir, cfg.dataset)
142 | elif cfg.dataset == "food101":
143 | data_path = cfg.dataset_root
144 | dataset = Food101(
145 | data_path,
146 | train=is_train
147 | )
148 | classes_names = get_dataset_class_names(cfg.workdir, cfg.dataset)
149 | elif cfg.dataset == "oxford_pet":
150 | data_path = cfg.dataset_root
151 | dataset = OxfordPet(
152 | data_path,
153 | train=is_train
154 | )
155 | classes_names = get_dataset_class_names(cfg.workdir, cfg.dataset)
156 | elif cfg.dataset == "caltech101":
157 | data_path = cfg.dataset_root
158 | dataset = Caltech101(
159 | data_path,
160 | train=is_train
161 | )
162 | classes_names = get_dataset_class_names(cfg.workdir, cfg.dataset)
163 |
164 | elif cfg.dataset == "vtab":
165 | data_path = cfg.dataset_root
166 | dataset = VTAB(
167 | data_path,
168 | train=is_train
169 | )
170 | classes_names = get_dataset_class_names(cfg.workdir, cfg.dataset)
171 | elif cfg.dataset == "imagenet_c":
172 | data_path = cfg.dataset_root
173 | dataset = ImageNet_C(
174 | data_path,
175 | train=is_train
176 | )
177 | classes_names = get_dataset_class_names(cfg.workdir, cfg.dataset)
178 |
179 | elif cfg.dataset == "tinyimagenet":
180 | data_path = os.path.join(cfg.dataset_root, cfg.dataset)
181 | dataset = TinyImageNet200(
182 | data_path,
183 | train=is_train,
184 | download=True
185 | )
186 | classes_names = get_dataset_class_names(cfg.workdir, cfg.dataset)
187 |
188 | elif cfg.dataset == "imagenet100":
189 | data_path = cfg.dataset_root
190 | dataset = ImageNet100(
191 | data_path,
192 | train=is_train,
193 | data_subset=os.path.join(get_workdir(os.getcwd()), "class_orders/train_100.txt" if is_train else "class_orders/val_100.txt")
194 | )
195 | classes_names = get_dataset_class_names(cfg.workdir, cfg.dataset)
196 |
197 | elif cfg.dataset == "imagenet1000":
198 | data_path = cfg.dataset_root
199 | dataset = ImageNet1000(
200 | data_path,
201 | train=is_train
202 | )
203 | classes_names = get_dataset_class_names(cfg.workdir, cfg.dataset)
204 |
205 | elif cfg.dataset == "core50":
206 | data_path = os.path.join(cfg.dataset_root, cfg.dataset)
207 |         dataset = Core50(
208 | data_path,
209 | scenario="domains",
210 | classification="category",
211 | train=is_train
212 | )
213 | classes_names = [
214 | "plug adapters", "mobile phones", "scissors", "light bulbs", "cans",
215 | "glasses", "balls", "markers", "cups", "remote controls"
216 | ]
217 |
218 | else:
219 |         raise ValueError(f"'{cfg.dataset}' is an invalid dataset.")
220 |
221 | return dataset, classes_names
222 |
223 |
224 | def build_cl_scenarios(cfg, is_train, transforms):  # returns (scenario, classes_names)
225 |
226 | dataset, classes_names = get_dataset(cfg, is_train)
227 |
228 | if cfg.scenario == "class":
229 | scenario = ClassIncremental(
230 | dataset,
231 | initial_increment=cfg.initial_increment,
232 | increment=cfg.increment,
233 | transformations=transforms.transforms, # Convert Compose into list
234 | class_order=cfg.class_order,
235 | )
236 |
237 | elif cfg.scenario == "domain":
238 | scenario = InstanceIncremental(
239 | dataset,
240 | transformations=transforms.transforms,
241 | )
242 |
243 | elif cfg.scenario == "task-agnostic":
244 |         raise NotImplementedError("The task-agnostic scenario has not been implemented yet.")
245 |
246 | else:
247 |         raise ValueError(f"You have entered `{cfg.scenario}`, which is not a defined scenario; "
248 |                          "please choose from {'class', 'domain', 'task-agnostic'}.")
249 |
250 | return scenario, classes_names
--------------------------------------------------------------------------------
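A minimal sketch of driving `build_cl_scenarios`; the `cfg` fields mirror `configs/class/cifar100_10-10.yaml`, and plain torchvision transforms stand in for the preprocessing that `clip.load`/`lora_clip.load` would normally return:

```python
# Minimal sketch: build the class-incremental scenario and iterate its tasks.
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from torchvision import transforms as T

from continual_clip.datasets import build_cl_scenarios
from continual_clip.utils import get_class_order

cfg = OmegaConf.create({
    "dataset": "cifar100", "dataset_root": "./data", "scenario": "class",
    "initial_increment": 10, "increment": 10, "train_batch_size": 128,
    "class_order": get_class_order("class_orders/cifar100.yaml"),
})
preprocess = T.Compose([T.Resize((224, 224)), T.ToTensor()])  # stand-in for CLIP's transform

scenario, class_names = build_cl_scenarios(cfg, is_train=True, transforms=preprocess)
for task_id, taskset in enumerate(scenario):          # one continuum TaskSet per task
    loader = DataLoader(taskset, batch_size=cfg.train_batch_size, shuffle=True)
```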
/continual_clip/models.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import pdb
4 | from omegaconf import DictConfig
5 |
6 | import clip
7 | import torch
8 | import torch.nn as nn
9 | import types
10 | from loraclip import lora_clip
11 | from clip.model import VisionTransformer as CLIPVisionTransformer
12 | from torch.nn import functional as F
13 | from .utils import get_class_ids_per_task, get_class_names
14 | import random
15 | import numpy as np
16 |
17 |
18 |
19 |
20 |
21 | def forward_clip(self, image, text, return_feature=False):
22 | image_features = self.encode_image(image)
23 | text_features = self.encode_text(text)
24 |
25 | # normalized features
26 | image_features = image_features / image_features.norm(dim=1, keepdim=True)
27 | text_features = text_features / text_features.norm(dim=1, keepdim=True)
28 |
29 | # cosine similarity as logits
30 | logit_scale = self.logit_scale.exp()
31 | logits_per_image = logit_scale * image_features @ text_features.t()
32 | logits_per_text = logits_per_image.t()
33 |
34 | if return_feature:
35 | return logits_per_image, logits_per_text, image_features, text_features
36 |
37 |
38 | # shape = [global_batch_size, global_batch_size]
39 | return logits_per_image, logits_per_text
40 |
41 |
42 |
43 | class VisionClassifier(nn.Module):
44 | def __init__(self, in_features, num_classes, weight_init=None, activation=None):
45 | super().__init__()
46 |         fc = nn.Linear(in_features, num_classes, bias=False)   # initialize weights via nn.Linear
47 |         self.fc = nn.Parameter(fc.weight.data)                 # keep only the weight matrix (cosine classifier)
48 | if weight_init is not None:
49 | self.fc.data = weight_init
50 | if activation is not None:
51 | self.activation = activation
52 | else:
53 | self.activation = nn.Identity()
54 |
55 | def add_weight(self, weight):
56 | self.fc = nn.Parameter(torch.cat([self.fc, weight], dim=0))
57 |
58 | def set_weight(self, weight):
59 | self.fc = nn.Parameter(weight)
60 |
61 |
62 | def forward(self, x):
63 | # normalize both the features and the weights so the logits are cosine similarities
64 | x = F.normalize(x, p=2, dim=-1)
65 | weight = F.normalize(self.fc, p=2, dim=-1)
66 | x = F.linear(x, weight)
67 | x = self.activation(x)
68 | return x
69 |
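# Minimal usage sketch (the feature size and class counts below are
# illustrative placeholders, not values taken from this repo):
#   clf = VisionClassifier(in_features=512, num_classes=10)
#   logits = clf(torch.randn(4, 512))    # cosine-similarity logits in [-1, 1]
#   clf.add_weight(torch.randn(5, 512))  # grow the head by 5 new classes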
70 |
71 |
72 | class ClassIncrementalCLIP(nn.Module):
73 | def __init__(self, cfg, device, jit=False):
74 | super().__init__()
75 | self.cfg = cfg
76 | self.prompt_template = cfg.prompt_template
77 | self.device = device
78 | self.classes_names = None
79 | # self.model, self.transforms = clip.load(cfg.model_name, device=device, jit=jit)
80 |
81 |
82 | #lora_clip
83 | self.model, self.transforms = lora_clip.load(cfg.model_name, device=device, jit=jit, r=cfg.lora_rank, lora_mode=cfg.lora_mode)
84 | # for name, param in self.model.named_parameters():
85 | # if 'adapter_mlp' in name:
86 | # param.requires_grad = True
87 | # for name, param in self.model.named_parameters():
88 | # if param.requires_grad:
89 | # print(f"Trainable: {name}")
90 | self.model.forward = types.MethodType(forward_clip, self.model)
91 | torch.save(self.model.state_dict(), 'ori_state.pth')  # snapshot; adaptation(reset=True) reloads this file
92 |
93 | self.class_ids_per_task = list(get_class_ids_per_task(cfg))
94 | self.current_class_names = []
95 | self.text_tokens = None
96 | self.current_task = -1
97 | self.only_reset_B = cfg.only_reset_B
98 | self.freeze_A = cfg.freeze_A
99 |
100 |
101 | def cur_text_features(self):
102 | f = self.model.encode_text(self.text_tokens)
103 | f = f / f.norm(dim=1, keepdim=True)
104 | return f
105 |
106 | def inference(self, image, text_tokens):
107 | text_features = self.model.encode_text(text_tokens)
108 | image_features = self.model.visual(image.type(self.model.dtype), all_tokens=False, adapt=self.attention_adapter)
109 | # pdb.set_trace()
110 |
111 | # image_features = self.attention_adapter(image_features.type(torch.float32))[:, 0, :]
112 |
113 | image_features = image_features / image_features.norm(dim=1, keepdim=True)
114 | text_features = text_features / text_features.norm(dim=1, keepdim=True)
115 |
116 | logit_scale = self.model.logit_scale.exp()
117 | logits_per_image = logit_scale * image_features @ text_features.t()
118 | return logits_per_image
119 |
120 | def forward(self, image, test=False, all_test=False, return_feature=False, replay=None):
121 | if test:
122 | # pdb.set_trace()
123 | with torch.no_grad():
124 | if all_test:
125 | if return_feature:
126 | logits_per_image, _, image_features, __ = self.model(image, self.all_text_tokens, return_feature=return_feature)
127 | else:
128 | logits_per_image, _ = self.model(image, self.all_text_tokens)
129 | # logits_per_image = self.inference(image, self.all_text_tokens)
130 | else:
131 | if return_feature:
132 | logits_per_image, _, image_features, __ = self.model(image, self.text_tokens, return_feature=return_feature)
133 | else:
134 | logits_per_image, _ = self.model(image, self.text_tokens)
135 | # pdb.set_trace()
136 | probs = logits_per_image.softmax(dim=-1)
137 | else:
138 |
139 | if return_feature:
140 | __, _, image_features, text_features = self.model(image, self.text_tokens, return_feature=return_feature)
141 | return image_features, text_features
142 | if replay is not None:
143 | logits_per_image, _ = self.model(image, self.text_tokens)
144 | # text_features_for_replay = self.model.encode_text(self.text_tokens[:-self.cfg.increment])
145 | text_features_for_replay = self.model.encode_text(self.text_tokens)
146 | text_features_for_replay = text_features_for_replay / text_features_for_replay.norm(dim=1, keepdim=True)
147 | replay_features = replay / replay.norm(dim=1, keepdim=True)
148 | replay_logits = replay_features @ text_features_for_replay.t() * 100  # 100 matches CLIP's trained exp(logit_scale)
149 | else:
150 | logits_per_image, _ = self.model(image, self.text_tokens)
151 | probs = logits_per_image
152 |
153 | if return_feature:
154 | text_features = self.model.encode_text(self.all_text_tokens)
155 | return probs, image_features, text_features
156 |
157 | if replay is not None:
158 | return probs, replay_logits
159 | return probs
160 |
161 | def adaptation(self, task_id, reset=False):
162 | self.current_task += 1
163 | if reset and self.current_task>0:
164 | ori_state = torch.load('ori_state.pth')
165 | if self.only_reset_B:
166 | now_state = self.model.state_dict()
167 | lora_params = {k: v for k, v in ori_state.items() if 'lora_B' in k}
168 | now_state.update(lora_params)
169 | else:
170 | now_state = ori_state
171 | self.model.load_state_dict(now_state)
172 | if self.freeze_A and self.current_task>0:
173 | for name, param in self.model.named_parameters():
174 | if 'lora_A' in name:
175 | param.requires_grad = False
176 |
177 | self.current_task_class_names = get_class_names(self.classes_names, self.class_ids_per_task[task_id])
178 | self.current_class_names += self.current_task_class_names
179 | self.text_tokens = clip.tokenize(
180 | [self.prompt_template.format(c) for c in self.current_class_names]
181 | ).to(self.device)
182 | self.current_task_text_tokens = clip.tokenize(
183 | [self.prompt_template.format(c) for c in self.current_task_class_names]
184 | ).to(self.device)
185 | if self.current_task == 0:
186 | class_names = []
187 | for i in range(self.cfg.task_num):
188 | class_names += get_class_names(self.classes_names, self.class_ids_per_task[i])
189 | self.all_class_names = class_names
190 | self.all_text_tokens = clip.tokenize(
191 | [self.prompt_template.format(c) for c in self.all_class_names]
192 | ).to(self.device)
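# Typical per-task flow (a sketch; the flags follow the config names used in epoch.py):
#   model.adaptation(task_id, reset=cfg.reset)         # extend class prompts, optionally reset LoRA
#   logits = model(images)                             # training: raw logits over classes seen so far
#   probs = model(images, test=True, all_test=True)    # evaluation: softmax over all task classes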
193 |
194 |
195 |
196 |
197 |
198 | def load_model(cfg: DictConfig, device: torch.device) -> nn.Module:
199 | r"""Load a CLIP model in different continual scenarios.
200 |
201 | Arguments:
202 | cfg (DictConfig): Experiment configurations.
203 | device (torch.device): Device to train (or) evaluate the model on.
204 |
205 | Returns:
206 | nn.Module: Return scenario specific CLIP model.
207 | """
208 | if cfg.scenario == "class":
209 | return ClassIncrementalCLIP(cfg, device)
210 | else:
211 | raise ValueError(f"""
212 | `{cfg.scenario}` is not a valid scenario,
213 | please choose from ['class', 'domain', 'task-agnostic'].
214 | """)
215 |
216 |
--------------------------------------------------------------------------------
/loraclip/lora_clip.py:
--------------------------------------------------------------------------------
1 | # adapted from https://github.com/jaisidhsingh/LoRA-CLIP
2 | import hashlib
3 | import os
4 | import urllib
5 | import warnings
6 | from typing import Any, Union, List
7 | from pkg_resources import packaging
8 |
9 | import torch
10 | from PIL import Image
11 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
12 | from tqdm import tqdm
13 |
14 | import clip
15 | from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer # use default clip's tokenization
16 |
17 | from .model import build_LoRA_model
18 |
19 | try:
20 | from torchvision.transforms import InterpolationMode
21 | BICUBIC = InterpolationMode.BICUBIC
22 | except ImportError:
23 | BICUBIC = Image.BICUBIC
24 |
25 |
26 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
27 | warnings.warn("PyTorch version 1.7.1 or higher is recommended")
28 |
29 |
30 | __all__ = ["available_models", "load", "tokenize"]
31 | _tokenizer = _Tokenizer()
32 |
33 | _MODELS = {
34 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
35 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
36 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
37 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
38 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
39 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
40 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
41 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
42 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
43 | }
44 |
45 |
46 | def _download(url: str, root: str):
47 | os.makedirs(root, exist_ok=True)
48 | filename = os.path.basename(url)
49 |
50 | expected_sha256 = url.split("/")[-2]
51 | download_target = os.path.join(root, filename)
52 |
53 | if os.path.exists(download_target) and not os.path.isfile(download_target):
54 | raise RuntimeError(f"{download_target} exists and is not a regular file")
55 |
56 | if os.path.isfile(download_target):
57 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
58 | return download_target
59 | else:
60 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
61 |
62 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
63 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
64 | while True:
65 | buffer = source.read(8192)
66 | if not buffer:
67 | break
68 |
69 | output.write(buffer)
70 | loop.update(len(buffer))
71 |
72 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
73 | raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
74 |
75 | return download_target
76 |
77 |
78 | def _convert_image_to_rgb(image):
79 | return image.convert("RGB")
80 |
81 |
82 | def _transform(n_px):
83 | return Compose([
84 | Resize(n_px, interpolation=BICUBIC),
85 | CenterCrop(n_px),
86 | _convert_image_to_rgb,
87 | ToTensor(),
88 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
89 | ])
90 |
91 |
92 | def available_models() -> List[str]:
93 | """Returns the names of available CLIP models"""
94 | return list(_MODELS.keys())
95 |
96 |
97 | def load(name: str,
98 | device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
99 | jit: bool = False,
100 | download_root: str = None,
101 | r: int = 4,
102 | lora_mode: str = "vision+text"
103 | ):
104 | """Load a CLIP model
105 |
106 | Parameters
107 | ----------
108 | name : str
109 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
110 |
111 | device : Union[str, torch.device]
112 | The device to put the loaded model
113 |
114 | jit : bool
115 | Whether to load the optimized JIT model or more hackable non-JIT model (default).
116 |
117 | download_root: str
118 | path to download the model files; by default, it uses "~/.cache/clip"
119 |
120 | r : int
121 | Rank of the LoRA matrices; `lora_mode` selects which encoders receive the LoRA layers (default "vision+text")
122 |
123 | Returns
124 | -------
125 | model : torch.nn.Module
126 | The CLIP model
127 |
128 | preprocess : Callable[[PIL.Image], torch.Tensor]
129 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
130 | """
131 | if name in _MODELS:
132 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
133 | elif os.path.isfile(name):
134 | model_path = name
135 | else:
136 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
137 |
138 | with open(model_path, 'rb') as opened_file:
139 | try:
140 | # loading JIT archive
141 | model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
142 | state_dict = None
143 | except RuntimeError:
144 | # loading saved state dict
145 | if jit:
146 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
147 | jit = False
148 | state_dict = torch.load(opened_file, map_location="cpu")
149 |
150 | if not jit:
151 | model = build_LoRA_model(state_dict or model.state_dict(), r, lora_mode).to(device)
152 | if str(device) == "cpu":
153 | model.float()
154 | return model, _transform(model.visual.input_resolution)
155 |
156 | # patch the device names
157 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
158 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
159 |
160 | def _node_get(node: torch._C.Node, key: str):
161 | """Gets attributes of a node which is polymorphic over return type.
162 |
163 | From https://github.com/pytorch/pytorch/pull/82628
164 | """
165 | sel = node.kindOf(key)
166 | return getattr(node, sel)(key)
167 |
168 | def patch_device(module):
169 | try:
170 | graphs = [module.graph] if hasattr(module, "graph") else []
171 | except RuntimeError:
172 | graphs = []
173 |
174 | if hasattr(module, "forward1"):
175 | graphs.append(module.forward1.graph)
176 |
177 | for graph in graphs:
178 | for node in graph.findAllNodes("prim::Constant"):
179 | if "value" in node.attributeNames() and str(_node_get(node, "value")).startswith("cuda"):
180 | node.copyAttributes(device_node)
181 |
182 | model.apply(patch_device)
183 | patch_device(model.encode_image)
184 | patch_device(model.encode_text)
185 |
186 | # patch dtype to float32 on CPU
187 | if str(device) == "cpu":
188 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
189 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
190 | float_node = float_input.node()
191 |
192 | def patch_float(module):
193 | try:
194 | graphs = [module.graph] if hasattr(module, "graph") else []
195 | except RuntimeError:
196 | graphs = []
197 |
198 | if hasattr(module, "forward1"):
199 | graphs.append(module.forward1.graph)
200 |
201 | for graph in graphs:
202 | for node in graph.findAllNodes("aten::to"):
203 | inputs = list(node.inputs())
204 | for i in [1, 2]: # dtype can be the second or third argument to aten::to()
205 | if _node_get(inputs[i].node(), "value") == 5:
206 | inputs[i].node().copyAttributes(float_node)
207 |
208 | model.apply(patch_float)
209 | patch_float(model.encode_image)
210 | patch_float(model.encode_text)
211 |
212 | model.float()
213 |
214 | return model, _transform(model.input_resolution.item())
215 |
216 | def tokenize(texts, context_length=77, truncate=False):
217 | return clip.tokenize(texts, context_length=context_length, truncate=truncate)
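# Example usage (a sketch; the model name, rank, and image path are illustrative):
#   from PIL import Image
#   model, preprocess = load("ViT-B/16", device="cuda", r=4, lora_mode="vision+text")
#   image = preprocess(Image.open("photo.jpg")).unsqueeze(0).to("cuda")
#   tokens = tokenize(["a photo of a cat", "a photo of a dog"]).to("cuda")
#   logits_per_image, logits_per_text = model(image, tokens)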
--------------------------------------------------------------------------------
/epoch.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import json
4 | import pdb
5 | import hydra
6 | import logging
7 | from omegaconf import DictConfig
8 |
9 | import torch
10 | import statistics
11 | from torch.utils.data import DataLoader
12 | import torch.nn.functional as F
13 | from continuum.metrics import Logger
14 | import random
15 | import numpy as np
16 | from collections import defaultdict
17 |
18 | from tqdm import tqdm
19 | from continual_clip import utils
20 | from continual_clip.models import load_model, VisionClassifier
21 | from continual_clip.datasets import build_cl_scenarios
22 | from sklearn.cluster import KMeans
23 | from continuum import rehearsal
24 | import copy
25 | from torchvision import transforms
26 | try:
27 | from torchvision.transforms import InterpolationMode
28 | BICUBIC = InterpolationMode.BICUBIC
29 | except ImportError:
30 | from PIL import Image; BICUBIC = Image.BICUBIC  # fall back to PIL's constant on older torchvision
31 |
32 | def intra_cls(logits, y, classes):
33 | y = y - classes
34 | logits1 = logits[:, classes:]
35 | return F.cross_entropy(logits1, y, reduction='none')
36 |
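# intra_cls restricts cross-entropy to the current task's logit columns: with
# `classes` previously seen classes, only logits[:, classes:] compete and the
# labels are shifted down by `classes`. Sketch with 50 old + 10 new classes:
#   logits = torch.randn(8, 60)
#   y = torch.randint(50, 60, (8,))
#   loss = intra_cls(logits, y, 50).mean()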
37 | def get_finetuning_dataset(dataset, memory, finetuning='balanced', oversample_old=1, task_id=0):
38 | if finetuning == 'balanced':
39 | x, y, t = memory.get()
40 |
41 | if oversample_old > 1:
42 | old_indexes = np.where(t < task_id)[0]
43 | assert len(old_indexes) > 0
44 | new_indexes = np.where(t >= task_id)[0]
45 |
46 | indexes = np.concatenate([
47 | np.repeat(old_indexes, oversample_old),
48 | new_indexes
49 | ])
50 | x, y, t = x[indexes], y[indexes], t[indexes]
51 |
52 | new_dataset = copy.deepcopy(dataset)
53 | new_dataset._x = x
54 | new_dataset._y = y
55 | new_dataset._t = t
56 | return new_dataset
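# get_finetuning_dataset rebuilds a dataset from the rehearsal memory;
# oversample_old > 1 repeats samples whose task id is below task_id. Sketch:
#   ft_data = get_finetuning_dataset(train_dataset[task_id], memory,
#                                    oversample_old=2, task_id=task_id)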
57 |
58 |
59 | def seed_everything(seed=0):
60 | """Fix all random seeds"""
61 | random.seed(seed)
62 | np.random.seed(seed)
63 | torch.manual_seed(seed)
64 | torch.cuda.manual_seed_all(seed)
65 | torch.backends.cudnn.deterministic = True
66 | os.environ['PYTHONHASHSEED'] = str(seed)
67 |
68 | def activation(x):
69 | return torch.exp(-10*(1-x))
70 |
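# activation maps a cosine similarity x in [-1, 1] to exp(-10 * (1 - x)):
# roughly 1.0 at x = 1 and ~4.5e-5 at x = 0, so near-identical features dominate.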
71 |
72 |
73 | def run_class_incremental(cfg, device):
74 |
75 | cfg.class_order = utils.get_class_order(os.path.join(cfg.workdir, cfg.class_order))
76 | model = load_model(cfg, device)
77 |
78 | eval_dataset, classes_names = build_cl_scenarios(
79 | cfg, is_train=False, transforms=model.transforms
80 | )
81 | train_dataset, _ = build_cl_scenarios(
82 | cfg, is_train=True, transforms=model.transforms
83 | )
84 | # pdb.set_trace()
85 | model.classes_names = classes_names
86 | if cfg.visual_clsf:
87 | if cfg.model_name == "ViT-L/14":
88 | vision_clsf = VisionClassifier(768, cfg.increment, activation=None)
89 | else:
90 | vision_clsf = VisionClassifier(512, cfg.increment, activation=None)
91 |
92 |
93 | acc_list = []
94 | metric_logger = Logger(list_subsets=["test"])
95 |
96 | p1 = 0
97 | p2 = 0
98 | negative_records = 0
99 | trainable_params = {k: v for k, v in model.named_parameters() if v.requires_grad}
100 | # pdb.set_trace()
101 | torch.save(trainable_params, 'ori_params.pth')
102 |
103 | if cfg.real_replay:
104 | memory = rehearsal.RehearsalMemory(
105 | memory_size=2000,
106 | herding_method="random"
107 | )
108 | for task_id, _ in enumerate(eval_dataset):
109 |
110 | # negative_records = 0
111 |
112 | torch.cuda.empty_cache()
113 | if task_id == 0:
114 | targets_bias = 0  # offset of this task's first class id
115 | else:
116 | targets_bias = cfg.initial_increment + (task_id - 1) * cfg.increment
117 |
118 | logging.info(f"Evaluation for task {task_id} has started.")
119 | model.adaptation(task_id, reset=cfg.reset)
120 |
121 | # save the model's trainable parameters
122 | trainable_params = {k: v for k, v in model.named_parameters() if v.requires_grad}
123 | torch.save(trainable_params, f'trainable_params.pth')
124 |
125 | trainable_params = torch.load(f'ori_params.pth')
126 | model.load_state_dict(trainable_params, strict=False)
127 |
128 | # compute the mean outputs of the positive and negative classes before training on this task
129 | model.eval()  # switch to evaluation mode
130 | positive_outputs = []
131 | negative_outputs = []
132 |
133 | val_gap_loader = DataLoader(train_dataset[task_id], batch_size=cfg.train_batch_size, shuffle=True, num_workers=cfg.num_workers)
134 |
135 | with torch.no_grad():
136 | for inputs, targets, t in val_gap_loader:
137 | inputs, targets = inputs.to(device), targets.to(device)
138 | outputs = model(inputs)
139 |
140 | one_hot_targets = torch.nn.functional.one_hot(targets, outputs.shape[1]).float()
141 | positive_outputs.append((outputs * one_hot_targets).sum(dim=1).mean())
142 | mask = 1 - one_hot_targets
143 | negative_outputs.append(((outputs * mask).sum(dim=1) / mask.sum(dim=1)).mean())
144 | positive_mean = sum(positive_outputs) / len(positive_outputs)
145 | negative_mean = sum(negative_outputs) / len(negative_outputs)
146 | # if task_id == 0:
147 | negative_records = negative_mean
148 | # if task_id == 0:
149 | logit_size = cfg.increment if task_id > 0 else cfg.initial_increment
150 | bias_logit = torch.full((logit_size,), negative_mean, device=device)
151 | bias_logit[0] = positive_mean
152 | # pdb.set_trace()
154 | logging.info(f"positive_records: {positive_mean}")
155 | logging.info(f"negative_records: {negative_mean}")
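# The loop above probes the untrained model on the new task: positive_mean is
# the average logit a sample assigns to its own class, negative_mean the average
# over all other classes; negative_records later serves as the drift reference
# for the 10% early-stopping check at the end of the epoch loop.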
156 | # pdb.set_trace()
157 | trainable_params = torch.load('trainable_params.pth')
158 | model.load_state_dict(trainable_params, strict=False)
159 |
160 | model.train()
161 | if task_id > 0 and cfg.real_replay:
162 | mem_x, mem_y, mem_t = memory.get()
163 | t_data = train_dataset[task_id]
164 | t_data.add_samples(mem_x, mem_y, mem_t)
165 | else:
166 | t_data = train_dataset[task_id]
167 | train_loader = DataLoader(t_data, batch_size=cfg.train_batch_size, shuffle=True, num_workers=cfg.num_workers)
168 |
169 | epochs = 10
170 |
171 | if epochs>0:
172 | # filter out the parameters that require grad
173 | params = filter(lambda p: p.requires_grad, model.parameters())
174 | optimizer = torch.optim.Adam(params, lr=cfg.lr)
175 | # optimizer = torch.optim.SGD(params, lr=cfg.lr, momentum=cfg.momentum, weight_decay=cfg.weight_decay)
176 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs, eta_min=cfg.lr*0.01)
177 | # for name, param in model.named_parameters():
178 | # if param.requires_grad:
179 | # print(name)
180 | torch.cuda.empty_cache()
181 | for i_epoch in range(epochs):
182 |
183 | for batch_i, (inputs, targets, t) in enumerate(train_loader):
184 | loss_c = torch.tensor(0.0).to(device)
185 | loss = torch.tensor(0.0).to(device)
186 |
187 | replay_loss = torch.tensor(0.0).to(device)
188 | torch.cuda.empty_cache()
189 |
190 |
191 | # targets = targets - targets_bias
192 | inputs, targets = inputs.to(device), targets.to(device)
193 |
194 | outputs = model(inputs)
195 | # image_f, text_f = model(inputs, return_feature=True)
196 | if task_id > 0:
197 |
198 | if cfg.real_replay:
199 | # split the batch into rehearsal (old-class) and current-task samples
200 | mask_replay = (targets < targets_bias)
201 | old_targets = targets[mask_replay].clone()
202 | old_outputs = outputs[mask_replay].clone()
203 | targets = targets[~mask_replay]
204 | outputs = outputs[~mask_replay]
205 | replay_loss = intra_cls(old_outputs, old_targets, 0).mean() * 0.1
206 | loss_c = intra_cls(outputs, targets, targets_bias).mean() + replay_loss
207 | else:
208 |
209 | loss_c = torch.nn.functional.cross_entropy(outputs, targets)
210 | loss += loss_c
211 | optimizer.zero_grad()
212 | loss.backward()
213 | optimizer.step()
214 | if batch_i % 10 == 0:
215 | logging.info(f"Epoch {i_epoch + 1}/{epochs} | Batch {batch_i + 1}/{len(train_loader)} | Loss: {loss.item()} | Loss_c: {loss_c.item()}")
216 | scheduler.step()
217 |
218 |
219 |
220 | torch.cuda.empty_cache()
221 | positive_outputs = []
222 | negative_outputs = []
223 | with torch.no_grad():
224 | model.eval()
225 | for inputs, targets, t in val_gap_loader:
226 | inputs, targets = inputs.to(device), targets.to(device)
227 | outputs = model(inputs)
228 | # pdb.set_trace()
229 | one_hot_targets = torch.nn.functional.one_hot(targets, outputs.shape[1]).float()
230 | positive_outputs.append((outputs * one_hot_targets).sum(dim=1).mean())
231 | mask = 1 - one_hot_targets
232 | negative_outputs.append(((outputs * mask).sum(dim=1) / mask.sum(dim=1)).mean())
233 | model.train()
234 | positive_mean = sum(positive_outputs) / len(positive_outputs)
235 | negative_mean = sum(negative_outputs) / len(negative_outputs)
236 | all_mean = (sum(positive_outputs) + sum(negative_outputs)) / (len(positive_outputs) + len(negative_outputs))
237 |
238 | logging.info(f"positive_mean: {positive_mean}")
239 | logging.info(f"negative_mean: {negative_mean}")
240 | torch.cuda.empty_cache()
241 | if (abs(negative_records - negative_mean) / negative_records) > 0.1:
242 | # stop once the negative-class mean drifts more than 10% from its pre-training value
243 | logging.info(f"Negative records changed too much, epoch {max(i_epoch, 1)}")
244 | exit(0)
247 |
248 |
249 |
250 |
251 |
252 | def run_domain_incremental(cfg, device):
253 |
254 | model = load_model(cfg, device)
255 | eval_dataset, classes_names = build_cl_scenarios(
256 | cfg, is_train=False, transforms=model.transforms
257 | )
258 | model.tokenize(classes_names)
259 |
260 | with open(cfg.log_path, 'w+') as f:
261 | pass
262 |
263 | logger = Logger(list_subsets=["test"])
264 | logging.info(f">>> Evaluation scenario length is {len(eval_dataset)}")
265 | for task_id, _ in enumerate(eval_dataset):
266 |
267 | dataset_val = eval_dataset[:task_id + 1]
268 | eval_loader = DataLoader(dataset_val, batch_size=cfg.batch_size)
269 | for input, target, task_ids in tqdm(eval_loader):
270 | input, target = input.to(device), target.to(device)
271 | output = torch.from_numpy(model(input))
272 | logger.add([output.cpu().argmax(dim=1), target.cpu(), task_ids], subset='test')
273 |
274 | with open(cfg.log_path, 'a+') as f:
275 | f.write(json.dumps({
276 | 'task': task_id,
277 | 'acc': round(100 * logger.accuracy, 2),
278 | }) + '\n')
279 |
280 | logger.end_task()
281 |
282 | def run_task_agnostic():
283 | pass
284 |
285 |
286 |
287 | @hydra.main(config_path=None, config_name=None, version_base="1.1")
288 | def continual_clip(cfg: DictConfig) -> None:
289 | seed_everything(cfg.seed)
290 | cfg.workdir = utils.get_workdir(path=os.getcwd())
291 | cfg.dataset_root = os.path.join(cfg.workdir, cfg.dataset_root)
292 |
293 | utils.save_config(cfg)
294 | with open(cfg.log_path, 'w+') as f:
295 | pass
296 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
297 |
298 | if cfg.scenario == "class":
299 | run_class_incremental(cfg, device)
300 |
301 | elif cfg.scenario == "domain":
302 | run_domain_incremental(cfg, device)
303 |
304 | elif cfg.scenario == "task-agnostic":
305 | raise NotImplementedError("The task-agnostic scenario has not been implemented yet.")
306 |
307 | else:
308 | raise ValueError(f"You have entered `{cfg.scenario}`, which is not a defined scenario; "
309 | "please choose from {'class', 'domain', 'task-agnostic'}.")
310 |
311 |
312 |
313 |
314 |
315 |
316 |
317 |
318 |
319 |
320 |
321 |
322 |
323 |
324 |
325 |
326 |
327 |
328 |
329 |
330 |
331 |
332 | if __name__ == "__main__":
333 | continual_clip()
334 |
--------------------------------------------------------------------------------
/dataset_reqs/imagenet_c_classes.txt:
--------------------------------------------------------------------------------
1 | 0 tench
2 | 1 goldfish
3 | 2 great white shark
4 | 3 tiger shark
5 | 4 hammerhead shark
6 | 5 electric ray
7 | 6 stingray
8 | 7 rooster
9 | 8 hen
10 | 9 ostrich
11 | 10 brambling
12 | 11 goldfinch
13 | 12 house finch
14 | 13 junco
15 | 14 indigo bunting
16 | 15 American robin
17 | 16 bulbul
18 | 17 jay
19 | 18 magpie
20 | 19 chickadee
21 | 20 American dipper
22 | 21 kite (bird of prey)
23 | 22 bald eagle
24 | 23 vulture
25 | 24 great grey owl
26 | 25 fire salamander
27 | 26 smooth newt
28 | 27 newt
29 | 28 spotted salamander
30 | 29 axolotl
31 | 30 American bullfrog
32 | 31 tree frog
33 | 32 tailed frog
34 | 33 loggerhead sea turtle
35 | 34 leatherback sea turtle
36 | 35 mud turtle
37 | 36 terrapin
38 | 37 box turtle
39 | 38 banded gecko
40 | 39 green iguana
41 | 40 Carolina anole
42 | 41 desert grassland whiptail lizard
43 | 42 agama
44 | 43 frilled-necked lizard
45 | 44 alligator lizard
46 | 45 Gila monster
47 | 46 European green lizard
48 | 47 chameleon
49 | 48 Komodo dragon
50 | 49 Nile crocodile
51 | 50 American alligator
52 | 51 triceratops
53 | 52 worm snake
54 | 53 ring-necked snake
55 | 54 eastern hog-nosed snake
56 | 55 smooth green snake
57 | 56 kingsnake
58 | 57 garter snake
59 | 58 water snake
60 | 59 vine snake
61 | 60 night snake
62 | 61 boa constrictor
63 | 62 African rock python
64 | 63 Indian cobra
65 | 64 green mamba
66 | 65 sea snake
67 | 66 Saharan horned viper
68 | 67 eastern diamondback rattlesnake
69 | 68 sidewinder rattlesnake
70 | 69 trilobite
71 | 70 harvestman
72 | 71 scorpion
73 | 72 yellow garden spider
74 | 73 barn spider
75 | 74 European garden spider
76 | 75 southern black widow
77 | 76 tarantula
78 | 77 wolf spider
79 | 78 tick
80 | 79 centipede
81 | 80 black grouse
82 | 81 ptarmigan
83 | 82 ruffed grouse
84 | 83 prairie grouse
85 | 84 peafowl
86 | 85 quail
87 | 86 partridge
88 | 87 african grey parrot
89 | 88 macaw
90 | 89 sulphur-crested cockatoo
91 | 90 lorikeet
92 | 91 coucal
93 | 92 bee eater
94 | 93 hornbill
95 | 94 hummingbird
96 | 95 jacamar
97 | 96 toucan
98 | 97 duck
99 | 98 red-breasted merganser
100 | 99 goose
101 | 100 black swan
102 | 101 tusker
103 | 102 echidna
104 | 103 platypus
105 | 104 wallaby
106 | 105 koala
107 | 106 wombat
108 | 107 jellyfish
109 | 108 sea anemone
110 | 109 brain coral
111 | 110 flatworm
112 | 111 nematode
113 | 112 conch
114 | 113 snail
115 | 114 slug
116 | 115 sea slug
117 | 116 chiton
118 | 117 chambered nautilus
119 | 118 Dungeness crab
120 | 119 rock crab
121 | 120 fiddler crab
122 | 121 red king crab
123 | 122 American lobster
124 | 123 spiny lobster
125 | 124 crayfish
126 | 125 hermit crab
127 | 126 isopod
128 | 127 white stork
129 | 128 black stork
130 | 129 spoonbill
131 | 130 flamingo
132 | 131 little blue heron
133 | 132 great egret
134 | 133 bittern bird
135 | 134 crane bird
136 | 135 limpkin
137 | 136 common gallinule
138 | 137 American coot
139 | 138 bustard
140 | 139 ruddy turnstone
141 | 140 dunlin
142 | 141 common redshank
143 | 142 dowitcher
144 | 143 oystercatcher
145 | 144 pelican
146 | 145 king penguin
147 | 146 albatross
148 | 147 grey whale
149 | 148 killer whale
150 | 149 dugong
151 | 150 sea lion
152 | 151 Chihuahua
153 | 152 Japanese Chin
154 | 153 Maltese
155 | 154 Pekingese
156 | 155 Shih Tzu
157 | 156 King Charles Spaniel
158 | 157 Papillon
159 | 158 toy terrier
160 | 159 Rhodesian Ridgeback
161 | 160 Afghan Hound
162 | 161 Basset Hound
163 | 162 Beagle
164 | 163 Bloodhound
165 | 164 Bluetick Coonhound
166 | 165 Black and Tan Coonhound
167 | 166 Treeing Walker Coonhound
168 | 167 English foxhound
169 | 168 Redbone Coonhound
170 | 169 borzoi
171 | 170 Irish Wolfhound
172 | 171 Italian Greyhound
173 | 172 Whippet
174 | 173 Ibizan Hound
175 | 174 Norwegian Elkhound
176 | 175 Otterhound
177 | 176 Saluki
178 | 177 Scottish Deerhound
179 | 178 Weimaraner
180 | 179 Staffordshire Bull Terrier
181 | 180 American Staffordshire Terrier
182 | 181 Bedlington Terrier
183 | 182 Border Terrier
184 | 183 Kerry Blue Terrier
185 | 184 Irish Terrier
186 | 185 Norfolk Terrier
187 | 186 Norwich Terrier
188 | 187 Yorkshire Terrier
189 | 188 Wire Fox Terrier
190 | 189 Lakeland Terrier
191 | 190 Sealyham Terrier
192 | 191 Airedale Terrier
193 | 192 Cairn Terrier
194 | 193 Australian Terrier
195 | 194 Dandie Dinmont Terrier
196 | 195 Boston Terrier
197 | 196 Miniature Schnauzer
198 | 197 Giant Schnauzer
199 | 198 Standard Schnauzer
200 | 199 Scottish Terrier
201 | 200 Tibetan Terrier
202 | 201 Australian Silky Terrier
203 | 202 Soft-coated Wheaten Terrier
204 | 203 West Highland White Terrier
205 | 204 Lhasa Apso
206 | 205 Flat-Coated Retriever
207 | 206 Curly-coated Retriever
208 | 207 Golden Retriever
209 | 208 Labrador Retriever
210 | 209 Chesapeake Bay Retriever
211 | 210 German Shorthaired Pointer
212 | 211 Vizsla
213 | 212 English Setter
214 | 213 Irish Setter
215 | 214 Gordon Setter
216 | 215 Brittany dog
217 | 216 Clumber Spaniel
218 | 217 English Springer Spaniel
219 | 218 Welsh Springer Spaniel
220 | 219 Cocker Spaniel
221 | 220 Sussex Spaniel
222 | 221 Irish Water Spaniel
223 | 222 Kuvasz
224 | 223 Schipperke
225 | 224 Groenendael dog
226 | 225 Malinois
227 | 226 Briard
228 | 227 Australian Kelpie
229 | 228 Komondor
230 | 229 Old English Sheepdog
231 | 230 Shetland Sheepdog
232 | 231 collie
233 | 232 Border Collie
234 | 233 Bouvier des Flandres dog
235 | 234 Rottweiler
236 | 235 German Shepherd Dog
237 | 236 Dobermann
238 | 237 Miniature Pinscher
239 | 238 Greater Swiss Mountain Dog
240 | 239 Bernese Mountain Dog
241 | 240 Appenzeller Sennenhund
242 | 241 Entlebucher Sennenhund
243 | 242 Boxer
244 | 243 Bullmastiff
245 | 244 Tibetan Mastiff
246 | 245 French Bulldog
247 | 246 Great Dane
248 | 247 St. Bernard
249 | 248 husky
250 | 249 Alaskan Malamute
251 | 250 Siberian Husky
252 | 251 Dalmatian
253 | 252 Affenpinscher
254 | 253 Basenji
255 | 254 pug
256 | 255 Leonberger
257 | 256 Newfoundland dog
258 | 257 Great Pyrenees dog
259 | 258 Samoyed
260 | 259 Pomeranian
261 | 260 Chow Chow
262 | 261 Keeshond
263 | 262 brussels griffon
264 | 263 Pembroke Welsh Corgi
265 | 264 Cardigan Welsh Corgi
266 | 265 Toy Poodle
267 | 266 Miniature Poodle
268 | 267 Standard Poodle
269 | 268 Mexican hairless dog (xoloitzcuintli)
270 | 269 grey wolf
271 | 270 Alaskan tundra wolf
272 | 271 red wolf or maned wolf
273 | 272 coyote
274 | 273 dingo
275 | 274 dhole
276 | 275 African wild dog
277 | 276 hyena
278 | 277 red fox
279 | 278 kit fox
280 | 279 Arctic fox
281 | 280 grey fox
282 | 281 tabby cat
283 | 282 tiger cat
284 | 283 Persian cat
285 | 284 Siamese cat
286 | 285 Egyptian Mau
287 | 286 cougar
288 | 287 lynx
289 | 288 leopard
290 | 289 snow leopard
291 | 290 jaguar
292 | 291 lion
293 | 292 tiger
294 | 293 cheetah
295 | 294 brown bear
296 | 295 American black bear
297 | 296 polar bear
298 | 297 sloth bear
299 | 298 mongoose
300 | 299 meerkat
301 | 300 tiger beetle
302 | 301 ladybug
303 | 302 ground beetle
304 | 303 longhorn beetle
305 | 304 leaf beetle
306 | 305 dung beetle
307 | 306 rhinoceros beetle
308 | 307 weevil
309 | 308 fly
310 | 309 bee
311 | 310 ant
312 | 311 grasshopper
313 | 312 cricket insect
314 | 313 stick insect
315 | 314 cockroach
316 | 315 praying mantis
317 | 316 cicada
318 | 317 leafhopper
319 | 318 lacewing
320 | 319 dragonfly
321 | 320 damselfly
322 | 321 red admiral butterfly
323 | 322 ringlet butterfly
324 | 323 monarch butterfly
325 | 324 small white butterfly
326 | 325 sulphur butterfly
327 | 326 gossamer-winged butterfly
328 | 327 starfish
329 | 328 sea urchin
330 | 329 sea cucumber
331 | 330 cottontail rabbit
332 | 331 hare
333 | 332 Angora rabbit
334 | 333 hamster
335 | 334 porcupine
336 | 335 fox squirrel
337 | 336 marmot
338 | 337 beaver
339 | 338 guinea pig
340 | 339 common sorrel horse
341 | 340 zebra
342 | 341 pig
343 | 342 wild boar
344 | 343 warthog
345 | 344 hippopotamus
346 | 345 ox
347 | 346 water buffalo
348 | 347 bison
349 | 348 ram (adult male sheep)
350 | 349 bighorn sheep
351 | 350 Alpine ibex
352 | 351 hartebeest
353 | 352 impala (antelope)
354 | 353 gazelle
355 | 354 arabian camel
356 | 355 llama
357 | 356 weasel
358 | 357 mink
359 | 358 European polecat
360 | 359 black-footed ferret
361 | 360 otter
362 | 361 skunk
363 | 362 badger
364 | 363 armadillo
365 | 364 three-toed sloth
366 | 365 orangutan
367 | 366 gorilla
368 | 367 chimpanzee
369 | 368 gibbon
370 | 369 siamang
371 | 370 guenon
372 | 371 patas monkey
373 | 372 baboon
374 | 373 macaque
375 | 374 langur
376 | 375 black-and-white colobus
377 | 376 proboscis monkey
378 | 377 marmoset
379 | 378 white-headed capuchin
380 | 379 howler monkey
381 | 380 titi monkey
382 | 381 Geoffroy's spider monkey
383 | 382 common squirrel monkey
384 | 383 ring-tailed lemur
385 | 384 indri
386 | 385 Asian elephant
387 | 386 African bush elephant
388 | 387 red panda
389 | 388 giant panda
390 | 389 snoek fish
391 | 390 eel
392 | 391 silver salmon
393 | 392 rock beauty fish
394 | 393 clownfish
395 | 394 sturgeon
396 | 395 gar fish
397 | 396 lionfish
398 | 397 pufferfish
399 | 398 abacus
400 | 399 abaya
401 | 400 academic gown
402 | 401 accordion
403 | 402 acoustic guitar
404 | 403 aircraft carrier
405 | 404 airliner
406 | 405 airship
407 | 406 altar
408 | 407 ambulance
409 | 408 amphibious vehicle
410 | 409 analog clock
411 | 410 apiary
412 | 411 apron
413 | 412 trash can
414 | 413 assault rifle
415 | 414 backpack
416 | 415 bakery
417 | 416 balance beam
418 | 417 balloon
419 | 418 ballpoint pen
420 | 419 Band-Aid
421 | 420 banjo
422 | 421 baluster / handrail
423 | 422 barbell
424 | 423 barber chair
425 | 424 barbershop
426 | 425 barn
427 | 426 barometer
428 | 427 barrel
429 | 428 wheelbarrow
430 | 429 baseball
431 | 430 basketball
432 | 431 bassinet
433 | 432 bassoon
434 | 433 swimming cap
435 | 434 bath towel
436 | 435 bathtub
437 | 436 station wagon
438 | 437 lighthouse
439 | 438 beaker
440 | 439 military hat (bearskin or shako)
441 | 440 beer bottle
442 | 441 beer glass
443 | 442 bell tower
444 | 443 baby bib
445 | 444 tandem bicycle
446 | 445 bikini
447 | 446 ring binder
448 | 447 binoculars
449 | 448 birdhouse
450 | 449 boathouse
451 | 450 bobsleigh
452 | 451 bolo tie
453 | 452 poke bonnet
454 | 453 bookcase
455 | 454 bookstore
456 | 455 bottle cap
457 | 456 hunting bow
458 | 457 bow tie
459 | 458 brass memorial plaque
460 | 459 bra
461 | 460 breakwater
462 | 461 breastplate
463 | 462 broom
464 | 463 bucket
465 | 464 buckle
466 | 465 bulletproof vest
467 | 466 high-speed train
468 | 467 butcher shop
469 | 468 taxicab
470 | 469 cauldron
471 | 470 candle
472 | 471 cannon
473 | 472 canoe
474 | 473 can opener
475 | 474 cardigan
476 | 475 car mirror
477 | 476 carousel
478 | 477 tool kit
479 | 478 cardboard box / carton
480 | 479 car wheel
481 | 480 automated teller machine
482 | 481 cassette
483 | 482 cassette player
484 | 483 castle
485 | 484 catamaran
486 | 485 CD player
487 | 486 cello
488 | 487 mobile phone
489 | 488 chain
490 | 489 chain-link fence
491 | 490 chain mail
492 | 491 chainsaw
493 | 492 storage chest
494 | 493 chiffonier
495 | 494 bell or wind chime
496 | 495 china cabinet
497 | 496 Christmas stocking
498 | 497 church
499 | 498 movie theater
500 | 499 cleaver
501 | 500 cliff dwelling
502 | 501 cloak
503 | 502 clogs
504 | 503 cocktail shaker
505 | 504 coffee mug
506 | 505 coffeemaker
507 | 506 spiral or coil
508 | 507 combination lock
509 | 508 computer keyboard
510 | 509 candy store
511 | 510 container ship
512 | 511 convertible
513 | 512 corkscrew
514 | 513 cornet
515 | 514 cowboy boot
516 | 515 cowboy hat
517 | 516 cradle
518 | 517 construction crane
519 | 518 crash helmet
520 | 519 crate
521 | 520 infant bed
522 | 521 Crock Pot
523 | 522 croquet ball
524 | 523 crutch
525 | 524 cuirass
526 | 525 dam
527 | 526 desk
528 | 527 desktop computer
529 | 528 rotary dial telephone
530 | 529 diaper
531 | 530 digital clock
532 | 531 digital watch
533 | 532 dining table
534 | 533 dishcloth
535 | 534 dishwasher
536 | 535 disc brake
537 | 536 dock
538 | 537 dog sled
539 | 538 dome
540 | 539 doormat
541 | 540 drilling rig
542 | 541 drum
543 | 542 drumstick
544 | 543 dumbbell
545 | 544 Dutch oven
546 | 545 electric fan
547 | 546 electric guitar
548 | 547 electric locomotive
549 | 548 entertainment center
550 | 549 envelope
551 | 550 espresso machine
552 | 551 face powder
553 | 552 feather boa
554 | 553 filing cabinet
555 | 554 fireboat
556 | 555 fire truck
557 | 556 fire screen
558 | 557 flagpole
559 | 558 flute
560 | 559 folding chair
561 | 560 football helmet
562 | 561 forklift
563 | 562 fountain
564 | 563 fountain pen
565 | 564 four-poster bed
566 | 565 freight car
567 | 566 French horn
568 | 567 frying pan
569 | 568 fur coat
570 | 569 garbage truck
571 | 570 gas mask or respirator
572 | 571 gas pump
573 | 572 goblet
574 | 573 go-kart
575 | 574 golf ball
576 | 575 golf cart
577 | 576 gondola
578 | 577 gong
579 | 578 gown
580 | 579 grand piano
581 | 580 greenhouse
582 | 581 radiator grille
583 | 582 grocery store
584 | 583 guillotine
585 | 584 hair clip
586 | 585 hair spray
587 | 586 half-track
588 | 587 hammer
589 | 588 hamper
590 | 589 hair dryer
591 | 590 hand-held computer
592 | 591 handkerchief
593 | 592 hard disk drive
594 | 593 harmonica
595 | 594 harp
596 | 595 combine harvester
597 | 596 hatchet
598 | 597 holster
599 | 598 home theater
600 | 599 honeycomb
601 | 600 hook
602 | 601 hoop skirt
603 | 602 gymnastic horizontal bar
604 | 603 horse-drawn vehicle
605 | 604 hourglass
606 | 605 iPod
607 | 606 clothes iron
608 | 607 carved pumpkin
609 | 608 jeans
610 | 609 jeep
611 | 610 T-shirt
612 | 611 jigsaw puzzle
613 | 612 rickshaw
614 | 613 joystick
615 | 614 kimono
616 | 615 knee pad
617 | 616 knot
618 | 617 lab coat
619 | 618 ladle
620 | 619 lampshade
621 | 620 laptop computer
622 | 621 lawn mower
623 | 622 lens cap
624 | 623 letter opener
625 | 624 library
626 | 625 lifeboat
627 | 626 lighter
628 | 627 limousine
629 | 628 ocean liner
630 | 629 lipstick
631 | 630 slip-on shoe
632 | 631 lotion
633 | 632 music speaker
634 | 633 loupe magnifying glass
635 | 634 sawmill
636 | 635 magnetic compass
637 | 636 messenger bag
638 | 637 mailbox
639 | 638 tights
640 | 639 one-piece bathing suit
641 | 640 manhole cover
642 | 641 maraca
643 | 642 marimba
644 | 643 mask
645 | 644 matchstick
646 | 645 maypole
647 | 646 maze
648 | 647 measuring cup
649 | 648 medicine cabinet
650 | 649 megalith
651 | 650 microphone
652 | 651 microwave oven
653 | 652 military uniform
654 | 653 milk can
655 | 654 minibus
656 | 655 miniskirt
657 | 656 minivan
658 | 657 missile
659 | 658 mitten
660 | 659 mixing bowl
661 | 660 mobile home
662 | 661 ford model t
663 | 662 modem
664 | 663 monastery
665 | 664 monitor
666 | 665 moped
667 | 666 mortar and pestle
668 | 667 graduation cap
669 | 668 mosque
670 | 669 mosquito net
671 | 670 vespa
672 | 671 mountain bike
673 | 672 tent
674 | 673 computer mouse
675 | 674 mousetrap
676 | 675 moving van
677 | 676 muzzle
678 | 677 metal nail
679 | 678 neck brace
680 | 679 necklace
681 | 680 baby pacifier
682 | 681 notebook computer
683 | 682 obelisk
684 | 683 oboe
685 | 684 ocarina
686 | 685 odometer
687 | 686 oil filter
688 | 687 pipe organ
689 | 688 oscilloscope
690 | 689 overskirt
691 | 690 bullock cart
692 | 691 oxygen mask
693 | 692 product packet / packaging
694 | 693 paddle
695 | 694 paddle wheel
696 | 695 padlock
697 | 696 paintbrush
698 | 697 pajamas
699 | 698 palace
700 | 699 pan flute
701 | 700 paper towel
702 | 701 parachute
703 | 702 parallel bars
704 | 703 park bench
705 | 704 parking meter
706 | 705 railroad car
707 | 706 patio
708 | 707 payphone
709 | 708 pedestal
710 | 709 pencil case
711 | 710 pencil sharpener
712 | 711 perfume
713 | 712 Petri dish
714 | 713 photocopier
715 | 714 plectrum
716 | 715 Pickelhaube
717 | 716 picket fence
718 | 717 pickup truck
719 | 718 pier
720 | 719 piggy bank
721 | 720 pill bottle
722 | 721 pillow
723 | 722 ping-pong ball
724 | 723 pinwheel
725 | 724 pirate ship
726 | 725 drink pitcher
727 | 726 block plane
728 | 727 planetarium
729 | 728 plastic bag
730 | 729 plate rack
731 | 730 farm plow
732 | 731 plunger
733 | 732 Polaroid camera
734 | 733 pole
735 | 734 police van
736 | 735 poncho
737 | 736 pool table
738 | 737 soda bottle
739 | 738 plant pot
740 | 739 potter's wheel
741 | 740 power drill
742 | 741 prayer rug
743 | 742 printer
744 | 743 prison
745 | 744 missile
746 | 745 projector
747 | 746 hockey puck
748 | 747 punching bag
749 | 748 purse
750 | 749 quill
751 | 750 quilt
752 | 751 race car
753 | 752 racket
754 | 753 radiator
755 | 754 radio
756 | 755 radio telescope
757 | 756 rain barrel
758 | 757 recreational vehicle
759 | 758 fishing casting reel
760 | 759 reflex camera
761 | 760 refrigerator
762 | 761 remote control
763 | 762 restaurant
764 | 763 revolver
765 | 764 rifle
766 | 765 rocking chair
767 | 766 rotisserie
768 | 767 eraser
769 | 768 rugby ball
770 | 769 ruler measuring stick
771 | 770 sneaker
772 | 771 safe
773 | 772 safety pin
774 | 773 salt shaker
775 | 774 sandal
776 | 775 sarong
777 | 776 saxophone
778 | 777 scabbard
779 | 778 weighing scale
780 | 779 school bus
781 | 780 schooner
782 | 781 scoreboard
783 | 782 CRT monitor
784 | 783 screw
785 | 784 screwdriver
786 | 785 seat belt
787 | 786 sewing machine
788 | 787 shield
789 | 788 shoe store
790 | 789 shoji screen / room divider
791 | 790 shopping basket
792 | 791 shopping cart
793 | 792 shovel
794 | 793 shower cap
795 | 794 shower curtain
796 | 795 ski
797 | 796 balaclava ski mask
798 | 797 sleeping bag
799 | 798 slide rule
800 | 799 sliding door
801 | 800 slot machine
802 | 801 snorkel
803 | 802 snowmobile
804 | 803 snowplow
805 | 804 soap dispenser
806 | 805 soccer ball
807 | 806 sock
808 | 807 solar thermal collector
809 | 808 sombrero
810 | 809 soup bowl
811 | 810 keyboard space bar
812 | 811 space heater
813 | 812 space shuttle
814 | 813 spatula
815 | 814 motorboat
816 | 815 spider web
817 | 816 spindle
818 | 817 sports car
819 | 818 spotlight
820 | 819 stage
821 | 820 steam locomotive
822 | 821 through arch bridge
823 | 822 steel drum
824 | 823 stethoscope
825 | 824 scarf
826 | 825 stone wall
827 | 826 stopwatch
828 | 827 stove
829 | 828 strainer
830 | 829 tram
831 | 830 stretcher
832 | 831 couch
833 | 832 stupa
834 | 833 submarine
835 | 834 suit
836 | 835 sundial
837 | 836 sunglasses
838 | 837 sunglasses
839 | 838 sunscreen
840 | 839 suspension bridge
841 | 840 mop
842 | 841 sweatshirt
843 | 842 swim trunks / shorts
844 | 843 swing
845 | 844 electrical switch
846 | 845 syringe
847 | 846 table lamp
848 | 847 tank
849 | 848 tape player
850 | 849 teapot
851 | 850 teddy bear
852 | 851 television
853 | 852 tennis ball
854 | 853 thatched roof
855 | 854 front curtain
856 | 855 thimble
857 | 856 threshing machine
858 | 857 throne
859 | 858 tile roof
860 | 859 toaster
861 | 860 tobacco shop
862 | 861 toilet seat
863 | 862 torch
864 | 863 totem pole
865 | 864 tow truck
866 | 865 toy store
867 | 866 tractor
868 | 867 semi-trailer truck
869 | 868 tray
870 | 869 trench coat
871 | 870 tricycle
872 | 871 trimaran
873 | 872 tripod
874 | 873 triumphal arch
875 | 874 trolleybus
876 | 875 trombone
877 | 876 hot tub
878 | 877 turnstile
879 | 878 typewriter keyboard
880 | 879 umbrella
881 | 880 unicycle
882 | 881 upright piano
883 | 882 vacuum cleaner
884 | 883 vase
885 | 884 vaulted or arched ceiling
886 | 885 velvet fabric
887 | 886 vending machine
888 | 887 vestment
889 | 888 viaduct
890 | 889 violin
891 | 890 volleyball
892 | 891 waffle iron
893 | 892 wall clock
894 | 893 wallet
895 | 894 wardrobe
896 | 895 military aircraft
897 | 896 sink
898 | 897 washing machine
899 | 898 water bottle
900 | 899 water jug
901 | 900 water tower
902 | 901 whiskey jug
903 | 902 whistle
904 | 903 hair wig
905 | 904 window screen
906 | 905 window shade
907 | 906 Windsor tie
908 | 907 wine bottle
909 | 908 airplane wing
910 | 909 wok
911 | 910 wooden spoon
912 | 911 wool
913 | 912 split-rail fence
914 | 913 shipwreck
915 | 914 sailboat
916 | 915 yurt
917 | 916 website
918 | 917 comic book
919 | 918 crossword
920 | 919 traffic or street sign
921 | 920 traffic light
922 | 921 dust jacket
923 | 922 menu
924 | 923 plate
925 | 924 guacamole
926 | 925 consomme
927 | 926 hot pot
928 | 927 trifle
929 | 928 ice cream
930 | 929 popsicle
931 | 930 baguette
932 | 931 bagel
933 | 932 pretzel
934 | 933 cheeseburger
935 | 934 hot dog
936 | 935 mashed potatoes
937 | 936 cabbage
938 | 937 broccoli
939 | 938 cauliflower
940 | 939 zucchini
941 | 940 spaghetti squash
942 | 941 acorn squash
943 | 942 butternut squash
944 | 943 cucumber
945 | 944 artichoke
946 | 945 bell pepper
947 | 946 cardoon
948 | 947 mushroom
949 | 948 Granny Smith apple
950 | 949 strawberry
951 | 950 orange
952 | 951 lemon
953 | 952 fig
954 | 953 pineapple
955 | 954 banana
956 | 955 jackfruit
957 | 956 cherimoya (custard apple)
958 | 957 pomegranate
959 | 958 hay
960 | 959 carbonara
961 | 960 chocolate syrup
962 | 961 dough
963 | 962 meatloaf
964 | 963 pizza
965 | 964 pot pie
966 | 965 burrito
967 | 966 red wine
968 | 967 espresso
969 | 968 tea cup
970 | 969 eggnog
971 | 970 mountain
972 | 971 bubble
973 | 972 cliff
974 | 973 coral reef
975 | 974 geyser
976 | 975 lakeshore
977 | 976 promontory
978 | 977 sandbar
979 | 978 beach
980 | 979 valley
981 | 980 volcano
982 | 981 baseball player
983 | 982 bridegroom
984 | 983 scuba diver
985 | 984 rapeseed
986 | 985 daisy
987 | 986 yellow lady's slipper
988 | 987 corn
989 | 988 acorn
990 | 989 rose hip
991 | 990 horse chestnut seed
992 | 991 coral fungus
993 | 992 agaric
994 | 993 gyromitra
995 | 994 stinkhorn mushroom
996 | 995 earth star fungus
997 | 996 hen of the woods mushroom
998 | 997 bolete
999 | 998 corn cob
1000 | 999 toilet paper
1001 |
--------------------------------------------------------------------------------
/dataset_reqs/imagenet1000_classes.txt:
--------------------------------------------------------------------------------
1 | 0 tench
2 | 1 goldfish
3 | 2 great white shark
4 | 3 tiger shark
5 | 4 hammerhead shark
6 | 5 electric ray
7 | 6 stingray
8 | 7 rooster
9 | 8 hen
10 | 9 ostrich
11 | 10 brambling
12 | 11 goldfinch
13 | 12 house finch
14 | 13 junco
15 | 14 indigo bunting
16 | 15 American robin
17 | 16 bulbul
18 | 17 jay
19 | 18 magpie
20 | 19 chickadee
21 | 20 American dipper
22 | 21 kite (bird of prey)
23 | 22 bald eagle
24 | 23 vulture
25 | 24 great grey owl
26 | 25 fire salamander
27 | 26 smooth newt
28 | 27 newt
29 | 28 spotted salamander
30 | 29 axolotl
31 | 30 American bullfrog
32 | 31 tree frog
33 | 32 tailed frog
34 | 33 loggerhead sea turtle
35 | 34 leatherback sea turtle
36 | 35 mud turtle
37 | 36 terrapin
38 | 37 box turtle
39 | 38 banded gecko
40 | 39 green iguana
41 | 40 Carolina anole
42 | 41 desert grassland whiptail lizard
43 | 42 agama
44 | 43 frilled-necked lizard
45 | 44 alligator lizard
46 | 45 Gila monster
47 | 46 European green lizard
48 | 47 chameleon
49 | 48 Komodo dragon
50 | 49 Nile crocodile
51 | 50 American alligator
52 | 51 triceratops
53 | 52 worm snake
54 | 53 ring-necked snake
55 | 54 eastern hog-nosed snake
56 | 55 smooth green snake
57 | 56 kingsnake
58 | 57 garter snake
59 | 58 water snake
60 | 59 vine snake
61 | 60 night snake
62 | 61 boa constrictor
63 | 62 African rock python
64 | 63 Indian cobra
65 | 64 green mamba
66 | 65 sea snake
67 | 66 Saharan horned viper
68 | 67 eastern diamondback rattlesnake
69 | 68 sidewinder rattlesnake
70 | 69 trilobite
71 | 70 harvestman
72 | 71 scorpion
73 | 72 yellow garden spider
74 | 73 barn spider
75 | 74 European garden spider
76 | 75 southern black widow
77 | 76 tarantula
78 | 77 wolf spider
79 | 78 tick
80 | 79 centipede
81 | 80 black grouse
82 | 81 ptarmigan
83 | 82 ruffed grouse
84 | 83 prairie grouse
85 | 84 peafowl
86 | 85 quail
87 | 86 partridge
88 | 87 african grey parrot
89 | 88 macaw
90 | 89 sulphur-crested cockatoo
91 | 90 lorikeet
92 | 91 coucal
93 | 92 bee eater
94 | 93 hornbill
95 | 94 hummingbird
96 | 95 jacamar
97 | 96 toucan
98 | 97 duck
99 | 98 red-breasted merganser
100 | 99 goose
101 | 100 black swan
102 | 101 tusker
103 | 102 echidna
104 | 103 platypus
105 | 104 wallaby
106 | 105 koala
107 | 106 wombat
108 | 107 jellyfish
109 | 108 sea anemone
110 | 109 brain coral
111 | 110 flatworm
112 | 111 nematode
113 | 112 conch
114 | 113 snail
115 | 114 slug
116 | 115 sea slug
117 | 116 chiton
118 | 117 chambered nautilus
119 | 118 Dungeness crab
120 | 119 rock crab
121 | 120 fiddler crab
122 | 121 red king crab
123 | 122 American lobster
124 | 123 spiny lobster
125 | 124 crayfish
126 | 125 hermit crab
127 | 126 isopod
128 | 127 white stork
129 | 128 black stork
130 | 129 spoonbill
131 | 130 flamingo
132 | 131 little blue heron
133 | 132 great egret
134 | 133 bittern bird
135 | 134 crane bird
136 | 135 limpkin
137 | 136 common gallinule
138 | 137 American coot
139 | 138 bustard
140 | 139 ruddy turnstone
141 | 140 dunlin
142 | 141 common redshank
143 | 142 dowitcher
144 | 143 oystercatcher
145 | 144 pelican
146 | 145 king penguin
147 | 146 albatross
148 | 147 grey whale
149 | 148 killer whale
150 | 149 dugong
151 | 150 sea lion
152 | 151 Chihuahua
153 | 152 Japanese Chin
154 | 153 Maltese
155 | 154 Pekingese
156 | 155 Shih Tzu
157 | 156 King Charles Spaniel
158 | 157 Papillon
159 | 158 toy terrier
160 | 159 Rhodesian Ridgeback
161 | 160 Afghan Hound
162 | 161 Basset Hound
163 | 162 Beagle
164 | 163 Bloodhound
165 | 164 Bluetick Coonhound
166 | 165 Black and Tan Coonhound
167 | 166 Treeing Walker Coonhound
168 | 167 English foxhound
169 | 168 Redbone Coonhound
170 | 169 borzoi
171 | 170 Irish Wolfhound
172 | 171 Italian Greyhound
173 | 172 Whippet
174 | 173 Ibizan Hound
175 | 174 Norwegian Elkhound
176 | 175 Otterhound
177 | 176 Saluki
178 | 177 Scottish Deerhound
179 | 178 Weimaraner
180 | 179 Staffordshire Bull Terrier
181 | 180 American Staffordshire Terrier
182 | 181 Bedlington Terrier
183 | 182 Border Terrier
184 | 183 Kerry Blue Terrier
185 | 184 Irish Terrier
186 | 185 Norfolk Terrier
187 | 186 Norwich Terrier
188 | 187 Yorkshire Terrier
189 | 188 Wire Fox Terrier
190 | 189 Lakeland Terrier
191 | 190 Sealyham Terrier
192 | 191 Airedale Terrier
193 | 192 Cairn Terrier
194 | 193 Australian Terrier
195 | 194 Dandie Dinmont Terrier
196 | 195 Boston Terrier
197 | 196 Miniature Schnauzer
198 | 197 Giant Schnauzer
199 | 198 Standard Schnauzer
200 | 199 Scottish Terrier
201 | 200 Tibetan Terrier
202 | 201 Australian Silky Terrier
203 | 202 Soft-coated Wheaten Terrier
204 | 203 West Highland White Terrier
205 | 204 Lhasa Apso
206 | 205 Flat-Coated Retriever
207 | 206 Curly-coated Retriever
208 | 207 Golden Retriever
209 | 208 Labrador Retriever
210 | 209 Chesapeake Bay Retriever
211 | 210 German Shorthaired Pointer
212 | 211 Vizsla
213 | 212 English Setter
214 | 213 Irish Setter
215 | 214 Gordon Setter
216 | 215 Brittany dog
217 | 216 Clumber Spaniel
218 | 217 English Springer Spaniel
219 | 218 Welsh Springer Spaniel
220 | 219 Cocker Spaniel
221 | 220 Sussex Spaniel
222 | 221 Irish Water Spaniel
223 | 222 Kuvasz
224 | 223 Schipperke
225 | 224 Groenendael dog
226 | 225 Malinois
227 | 226 Briard
228 | 227 Australian Kelpie
229 | 228 Komondor
230 | 229 Old English Sheepdog
231 | 230 Shetland Sheepdog
232 | 231 collie
233 | 232 Border Collie
234 | 233 Bouvier des Flandres dog
235 | 234 Rottweiler
236 | 235 German Shepherd Dog
237 | 236 Dobermann
238 | 237 Miniature Pinscher
239 | 238 Greater Swiss Mountain Dog
240 | 239 Bernese Mountain Dog
241 | 240 Appenzeller Sennenhund
242 | 241 Entlebucher Sennenhund
243 | 242 Boxer
244 | 243 Bullmastiff
245 | 244 Tibetan Mastiff
246 | 245 French Bulldog
247 | 246 Great Dane
248 | 247 St. Bernard
249 | 248 husky
250 | 249 Alaskan Malamute
251 | 250 Siberian Husky
252 | 251 Dalmatian
253 | 252 Affenpinscher
254 | 253 Basenji
255 | 254 pug
256 | 255 Leonberger
257 | 256 Newfoundland dog
258 | 257 Great Pyrenees dog
259 | 258 Samoyed
260 | 259 Pomeranian
261 | 260 Chow Chow
262 | 261 Keeshond
263 | 262 brussels griffon
264 | 263 Pembroke Welsh Corgi
265 | 264 Cardigan Welsh Corgi
266 | 265 Toy Poodle
267 | 266 Miniature Poodle
268 | 267 Standard Poodle
269 | 268 Mexican hairless dog (xoloitzcuintli)
270 | 269 grey wolf
271 | 270 Alaskan tundra wolf
272 | 271 red wolf or maned wolf
273 | 272 coyote
274 | 273 dingo
275 | 274 dhole
276 | 275 African wild dog
277 | 276 hyena
278 | 277 red fox
279 | 278 kit fox
280 | 279 Arctic fox
281 | 280 grey fox
282 | 281 tabby cat
283 | 282 tiger cat
284 | 283 Persian cat
285 | 284 Siamese cat
286 | 285 Egyptian Mau
287 | 286 cougar
288 | 287 lynx
289 | 288 leopard
290 | 289 snow leopard
291 | 290 jaguar
292 | 291 lion
293 | 292 tiger
294 | 293 cheetah
295 | 294 brown bear
296 | 295 American black bear
297 | 296 polar bear
298 | 297 sloth bear
299 | 298 mongoose
300 | 299 meerkat
301 | 300 tiger beetle
302 | 301 ladybug
303 | 302 ground beetle
304 | 303 longhorn beetle
305 | 304 leaf beetle
306 | 305 dung beetle
307 | 306 rhinoceros beetle
308 | 307 weevil
309 | 308 fly
310 | 309 bee
311 | 310 ant
312 | 311 grasshopper
313 | 312 cricket insect
314 | 313 stick insect
315 | 314 cockroach
316 | 315 praying mantis
317 | 316 cicada
318 | 317 leafhopper
319 | 318 lacewing
320 | 319 dragonfly
321 | 320 damselfly
322 | 321 red admiral butterfly
323 | 322 ringlet butterfly
324 | 323 monarch butterfly
325 | 324 small white butterfly
326 | 325 sulphur butterfly
327 | 326 gossamer-winged butterfly
328 | 327 starfish
329 | 328 sea urchin
330 | 329 sea cucumber
331 | 330 cottontail rabbit
332 | 331 hare
333 | 332 Angora rabbit
334 | 333 hamster
335 | 334 porcupine
336 | 335 fox squirrel
337 | 336 marmot
338 | 337 beaver
339 | 338 guinea pig
340 | 339 common sorrel horse
341 | 340 zebra
342 | 341 pig
343 | 342 wild boar
344 | 343 warthog
345 | 344 hippopotamus
346 | 345 ox
347 | 346 water buffalo
348 | 347 bison
349 | 348 ram (adult male sheep)
350 | 349 bighorn sheep
351 | 350 Alpine ibex
352 | 351 hartebeest
353 | 352 impala (antelope)
354 | 353 gazelle
355 | 354 arabian camel
356 | 355 llama
357 | 356 weasel
358 | 357 mink
359 | 358 European polecat
360 | 359 black-footed ferret
361 | 360 otter
362 | 361 skunk
363 | 362 badger
364 | 363 armadillo
365 | 364 three-toed sloth
366 | 365 orangutan
367 | 366 gorilla
368 | 367 chimpanzee
369 | 368 gibbon
370 | 369 siamang
371 | 370 guenon
372 | 371 patas monkey
373 | 372 baboon
374 | 373 macaque
375 | 374 langur
376 | 375 black-and-white colobus
377 | 376 proboscis monkey
378 | 377 marmoset
379 | 378 white-headed capuchin
380 | 379 howler monkey
381 | 380 titi monkey
382 | 381 Geoffroy's spider monkey
383 | 382 common squirrel monkey
384 | 383 ring-tailed lemur
385 | 384 indri
386 | 385 Asian elephant
387 | 386 African bush elephant
388 | 387 red panda
389 | 388 giant panda
390 | 389 snoek fish
391 | 390 eel
392 | 391 silver salmon
393 | 392 rock beauty fish
394 | 393 clownfish
395 | 394 sturgeon
396 | 395 gar fish
397 | 396 lionfish
398 | 397 pufferfish
399 | 398 abacus
400 | 399 abaya
401 | 400 academic gown
402 | 401 accordion
403 | 402 acoustic guitar
404 | 403 aircraft carrier
405 | 404 airliner
406 | 405 airship
407 | 406 altar
408 | 407 ambulance
409 | 408 amphibious vehicle
410 | 409 analog clock
411 | 410 apiary
412 | 411 apron
413 | 412 trash can
414 | 413 assault rifle
415 | 414 backpack
416 | 415 bakery
417 | 416 balance beam
418 | 417 balloon
419 | 418 ballpoint pen
420 | 419 Band-Aid
421 | 420 banjo
422 | 421 baluster / handrail
423 | 422 barbell
424 | 423 barber chair
425 | 424 barbershop
426 | 425 barn
427 | 426 barometer
428 | 427 barrel
429 | 428 wheelbarrow
430 | 429 baseball
431 | 430 basketball
432 | 431 bassinet
433 | 432 bassoon
434 | 433 swimming cap
435 | 434 bath towel
436 | 435 bathtub
437 | 436 station wagon
438 | 437 lighthouse
439 | 438 beaker
440 | 439 military hat (bearskin or shako)
441 | 440 beer bottle
442 | 441 beer glass
443 | 442 bell tower
444 | 443 baby bib
445 | 444 tandem bicycle
446 | 445 bikini
447 | 446 ring binder
448 | 447 binoculars
449 | 448 birdhouse
450 | 449 boathouse
451 | 450 bobsleigh
452 | 451 bolo tie
453 | 452 poke bonnet
454 | 453 bookcase
455 | 454 bookstore
456 | 455 bottle cap
457 | 456 hunting bow
458 | 457 bow tie
459 | 458 brass memorial plaque
460 | 459 bra
461 | 460 breakwater
462 | 461 breastplate
463 | 462 broom
464 | 463 bucket
465 | 464 buckle
466 | 465 bulletproof vest
467 | 466 high-speed train
468 | 467 butcher shop
469 | 468 taxicab
470 | 469 cauldron
471 | 470 candle
472 | 471 cannon
473 | 472 canoe
474 | 473 can opener
475 | 474 cardigan
476 | 475 car mirror
477 | 476 carousel
478 | 477 tool kit
479 | 478 cardboard box / carton
480 | 479 car wheel
481 | 480 automated teller machine
482 | 481 cassette
483 | 482 cassette player
484 | 483 castle
485 | 484 catamaran
486 | 485 CD player
487 | 486 cello
488 | 487 mobile phone
489 | 488 chain
490 | 489 chain-link fence
491 | 490 chain mail
492 | 491 chainsaw
493 | 492 storage chest
494 | 493 chiffonier
495 | 494 bell or wind chime
496 | 495 china cabinet
497 | 496 Christmas stocking
498 | 497 church
499 | 498 movie theater
500 | 499 cleaver
501 | 500 cliff dwelling
502 | 501 cloak
503 | 502 clogs
504 | 503 cocktail shaker
505 | 504 coffee mug
506 | 505 coffeemaker
507 | 506 spiral or coil
508 | 507 combination lock
509 | 508 computer keyboard
510 | 509 candy store
511 | 510 container ship
512 | 511 convertible
513 | 512 corkscrew
514 | 513 cornet
515 | 514 cowboy boot
516 | 515 cowboy hat
517 | 516 cradle
518 | 517 construction crane
519 | 518 crash helmet
520 | 519 crate
521 | 520 infant bed
522 | 521 Crock Pot
523 | 522 croquet ball
524 | 523 crutch
525 | 524 cuirass
526 | 525 dam
527 | 526 desk
528 | 527 desktop computer
529 | 528 rotary dial telephone
530 | 529 diaper
531 | 530 digital clock
532 | 531 digital watch
533 | 532 dining table
534 | 533 dishcloth
535 | 534 dishwasher
536 | 535 disc brake
537 | 536 dock
538 | 537 dog sled
539 | 538 dome
540 | 539 doormat
541 | 540 drilling rig
542 | 541 drum
543 | 542 drumstick
544 | 543 dumbbell
545 | 544 Dutch oven
546 | 545 electric fan
547 | 546 electric guitar
548 | 547 electric locomotive
549 | 548 entertainment center
550 | 549 envelope
551 | 550 espresso machine
552 | 551 face powder
553 | 552 feather boa
554 | 553 filing cabinet
555 | 554 fireboat
556 | 555 fire truck
557 | 556 fire screen
558 | 557 flagpole
559 | 558 flute
560 | 559 folding chair
561 | 560 football helmet
562 | 561 forklift
563 | 562 fountain
564 | 563 fountain pen
565 | 564 four-poster bed
566 | 565 freight car
567 | 566 French horn
568 | 567 frying pan
569 | 568 fur coat
570 | 569 garbage truck
571 | 570 gas mask or respirator
572 | 571 gas pump
573 | 572 goblet
574 | 573 go-kart
575 | 574 golf ball
576 | 575 golf cart
577 | 576 gondola
578 | 577 gong
579 | 578 gown
580 | 579 grand piano
581 | 580 greenhouse
582 | 581 radiator grille
583 | 582 grocery store
584 | 583 guillotine
585 | 584 hair clip
586 | 585 hair spray
587 | 586 half-track
588 | 587 hammer
589 | 588 hamper
590 | 589 hair dryer
591 | 590 hand-held computer
592 | 591 handkerchief
593 | 592 hard disk drive
594 | 593 harmonica
595 | 594 harp
596 | 595 combine harvester
597 | 596 hatchet
598 | 597 holster
599 | 598 home theater
600 | 599 honeycomb
601 | 600 hook
602 | 601 hoop skirt
603 | 602 gymnastic horizontal bar
604 | 603 horse-drawn vehicle
605 | 604 hourglass
606 | 605 iPod
607 | 606 clothes iron
608 | 607 carved pumpkin
609 | 608 jeans
610 | 609 jeep
611 | 610 T-shirt
612 | 611 jigsaw puzzle
613 | 612 rickshaw
614 | 613 joystick
615 | 614 kimono
616 | 615 knee pad
617 | 616 knot
618 | 617 lab coat
619 | 618 ladle
620 | 619 lampshade
621 | 620 laptop computer
622 | 621 lawn mower
623 | 622 lens cap
624 | 623 letter opener
625 | 624 library
626 | 625 lifeboat
627 | 626 lighter
628 | 627 limousine
629 | 628 ocean liner
630 | 629 lipstick
631 | 630 slip-on shoe
632 | 631 lotion
633 | 632 music speaker
634 | 633 loupe magnifying glass
635 | 634 sawmill
636 | 635 magnetic compass
637 | 636 messenger bag
638 | 637 mailbox
639 | 638 tights
640 | 639 one-piece bathing suit
641 | 640 manhole cover
642 | 641 maraca
643 | 642 marimba
644 | 643 mask
645 | 644 matchstick
646 | 645 maypole
647 | 646 maze
648 | 647 measuring cup
649 | 648 medicine cabinet
650 | 649 megalith
651 | 650 microphone
652 | 651 microwave oven
653 | 652 military uniform
654 | 653 milk can
655 | 654 minibus
656 | 655 miniskirt
657 | 656 minivan
658 | 657 missile
659 | 658 mitten
660 | 659 mixing bowl
661 | 660 mobile home
662 | 661 ford model t
663 | 662 modem
664 | 663 monastery
665 | 664 monitor
666 | 665 moped
667 | 666 mortar and pestle
668 | 667 graduation cap
669 | 668 mosque
670 | 669 mosquito net
671 | 670 vespa
672 | 671 mountain bike
673 | 672 tent
674 | 673 computer mouse
675 | 674 mousetrap
676 | 675 moving van
677 | 676 muzzle
678 | 677 metal nail
679 | 678 neck brace
680 | 679 necklace
681 | 680 baby pacifier
682 | 681 notebook computer
683 | 682 obelisk
684 | 683 oboe
685 | 684 ocarina
686 | 685 odometer
687 | 686 oil filter
688 | 687 pipe organ
689 | 688 oscilloscope
690 | 689 overskirt
691 | 690 bullock cart
692 | 691 oxygen mask
693 | 692 product packet / packaging
694 | 693 paddle
695 | 694 paddle wheel
696 | 695 padlock
697 | 696 paintbrush
698 | 697 pajamas
699 | 698 palace
700 | 699 pan flute
701 | 700 paper towel
702 | 701 parachute
703 | 702 parallel bars
704 | 703 park bench
705 | 704 parking meter
706 | 705 railroad car
707 | 706 patio
708 | 707 payphone
709 | 708 pedestal
710 | 709 pencil case
711 | 710 pencil sharpener
712 | 711 perfume
713 | 712 Petri dish
714 | 713 photocopier
715 | 714 plectrum
716 | 715 Pickelhaube
717 | 716 picket fence
718 | 717 pickup truck
719 | 718 pier
720 | 719 piggy bank
721 | 720 pill bottle
722 | 721 pillow
723 | 722 ping-pong ball
724 | 723 pinwheel
725 | 724 pirate ship
726 | 725 drink pitcher
727 | 726 block plane
728 | 727 planetarium
729 | 728 plastic bag
730 | 729 plate rack
731 | 730 farm plow
732 | 731 plunger
733 | 732 Polaroid camera
734 | 733 pole
735 | 734 police van
736 | 735 poncho
737 | 736 pool table
738 | 737 soda bottle
739 | 738 plant pot
740 | 739 potter's wheel
741 | 740 power drill
742 | 741 prayer rug
743 | 742 printer
744 | 743 prison
745 | 744 missile
746 | 745 projector
747 | 746 hockey puck
748 | 747 punching bag
749 | 748 purse
750 | 749 quill
751 | 750 quilt
752 | 751 race car
753 | 752 racket
754 | 753 radiator
755 | 754 radio
756 | 755 radio telescope
757 | 756 rain barrel
758 | 757 recreational vehicle
759 | 758 fishing casting reel
760 | 759 reflex camera
761 | 760 refrigerator
762 | 761 remote control
763 | 762 restaurant
764 | 763 revolver
765 | 764 rifle
766 | 765 rocking chair
767 | 766 rotisserie
768 | 767 eraser
769 | 768 rugby ball
770 | 769 ruler measuring stick
771 | 770 sneaker
772 | 771 safe
773 | 772 safety pin
774 | 773 salt shaker
775 | 774 sandal
776 | 775 sarong
777 | 776 saxophone
778 | 777 scabbard
779 | 778 weighing scale
780 | 779 school bus
781 | 780 schooner
782 | 781 scoreboard
783 | 782 CRT monitor
784 | 783 screw
785 | 784 screwdriver
786 | 785 seat belt
787 | 786 sewing machine
788 | 787 shield
789 | 788 shoe store
790 | 789 shoji screen / room divider
791 | 790 shopping basket
792 | 791 shopping cart
793 | 792 shovel
794 | 793 shower cap
795 | 794 shower curtain
796 | 795 ski
797 | 796 balaclava ski mask
798 | 797 sleeping bag
799 | 798 slide rule
800 | 799 sliding door
801 | 800 slot machine
802 | 801 snorkel
803 | 802 snowmobile
804 | 803 snowplow
805 | 804 soap dispenser
806 | 805 soccer ball
807 | 806 sock
808 | 807 solar thermal collector
809 | 808 sombrero
810 | 809 soup bowl
811 | 810 keyboard space bar
812 | 811 space heater
813 | 812 space shuttle
814 | 813 spatula
815 | 814 motorboat
816 | 815 spider web
817 | 816 spindle
818 | 817 sports car
819 | 818 spotlight
820 | 819 stage
821 | 820 steam locomotive
822 | 821 through arch bridge
823 | 822 steel drum
824 | 823 stethoscope
825 | 824 scarf
826 | 825 stone wall
827 | 826 stopwatch
828 | 827 stove
829 | 828 strainer
830 | 829 tram
831 | 830 stretcher
832 | 831 couch
833 | 832 stupa
834 | 833 submarine
835 | 834 suit
836 | 835 sundial
837 | 836 sunglasses
838 | 837 sunglasses
839 | 838 sunscreen
840 | 839 suspension bridge
841 | 840 mop
842 | 841 sweatshirt
843 | 842 swim trunks / shorts
844 | 843 swing
845 | 844 electrical switch
846 | 845 syringe
847 | 846 table lamp
848 | 847 tank
849 | 848 tape player
850 | 849 teapot
851 | 850 teddy bear
852 | 851 television
853 | 852 tennis ball
854 | 853 thatched roof
855 | 854 front curtain
856 | 855 thimble
857 | 856 threshing machine
858 | 857 throne
859 | 858 tile roof
860 | 859 toaster
861 | 860 tobacco shop
862 | 861 toilet seat
863 | 862 torch
864 | 863 totem pole
865 | 864 tow truck
866 | 865 toy store
867 | 866 tractor
868 | 867 semi-trailer truck
869 | 868 tray
870 | 869 trench coat
871 | 870 tricycle
872 | 871 trimaran
873 | 872 tripod
874 | 873 triumphal arch
875 | 874 trolleybus
876 | 875 trombone
877 | 876 hot tub
878 | 877 turnstile
879 | 878 typewriter keyboard
880 | 879 umbrella
881 | 880 unicycle
882 | 881 upright piano
883 | 882 vacuum cleaner
884 | 883 vase
885 | 884 vaulted or arched ceiling
886 | 885 velvet fabric
887 | 886 vending machine
888 | 887 vestment
889 | 888 viaduct
890 | 889 violin
891 | 890 volleyball
892 | 891 waffle iron
893 | 892 wall clock
894 | 893 wallet
895 | 894 wardrobe
896 | 895 military aircraft
897 | 896 sink
898 | 897 washing machine
899 | 898 water bottle
900 | 899 water jug
901 | 900 water tower
902 | 901 whiskey jug
903 | 902 whistle
904 | 903 hair wig
905 | 904 window screen
906 | 905 window shade
907 | 906 Windsor tie
908 | 907 wine bottle
909 | 908 airplane wing
910 | 909 wok
911 | 910 wooden spoon
912 | 911 wool
913 | 912 split-rail fence
914 | 913 shipwreck
915 | 914 sailboat
916 | 915 yurt
917 | 916 website
918 | 917 comic book
919 | 918 crossword
920 | 919 traffic or street sign
921 | 920 traffic light
922 | 921 dust jacket
923 | 922 menu
924 | 923 plate
925 | 924 guacamole
926 | 925 consomme
927 | 926 hot pot
928 | 927 trifle
929 | 928 ice cream
930 | 929 popsicle
931 | 930 baguette
932 | 931 bagel
933 | 932 pretzel
934 | 933 cheeseburger
935 | 934 hot dog
936 | 935 mashed potatoes
937 | 936 cabbage
938 | 937 broccoli
939 | 938 cauliflower
940 | 939 zucchini
941 | 940 spaghetti squash
942 | 941 acorn squash
943 | 942 butternut squash
944 | 943 cucumber
945 | 944 artichoke
946 | 945 bell pepper
947 | 946 cardoon
948 | 947 mushroom
949 | 948 Granny Smith apple
950 | 949 strawberry
951 | 950 orange
952 | 951 lemon
953 | 952 fig
954 | 953 pineapple
955 | 954 banana
956 | 955 jackfruit
957 | 956 cherimoya (custard apple)
958 | 957 pomegranate
959 | 958 hay
960 | 959 carbonara
961 | 960 chocolate syrup
962 | 961 dough
963 | 962 meatloaf
964 | 963 pizza
965 | 964 pot pie
966 | 965 burrito
967 | 966 red wine
968 | 967 espresso
969 | 968 tea cup
970 | 969 eggnog
971 | 970 mountain
972 | 971 bubble
973 | 972 cliff
974 | 973 coral reef
975 | 974 geyser
976 | 975 lakeshore
977 | 976 promontory
978 | 977 sandbar
979 | 978 beach
980 | 979 valley
981 | 980 volcano
982 | 981 baseball player
983 | 982 bridegroom
984 | 983 scuba diver
985 | 984 rapeseed
986 | 985 daisy
987 | 986 yellow lady's slipper
988 | 987 corn
989 | 988 acorn
990 | 989 rose hip
991 | 990 horse chestnut seed
992 | 991 coral fungus
993 | 992 agaric
994 | 993 gyromitra
995 | 994 stinkhorn mushroom
996 | 995 earth star fungus
997 | 996 hen of the woods mushroom
998 | 997 bolete
999 | 998 corn cob
1000 | 999 toilet paper
1001 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import json
4 | import pdb
5 | import hydra
6 | import logging
7 | from omegaconf import DictConfig
8 |
9 | import torch
10 | import statistics
11 | from torch.utils.data import DataLoader
12 | import torch.nn.functional as F
13 | from continuum.metrics import Logger
14 | import random
15 | import numpy as np
16 | from collections import defaultdict
17 |
18 | from tqdm import tqdm
19 | from continual_clip import utils
20 | from continual_clip.models import load_model, VisionClassifier
21 | from continual_clip.datasets import build_cl_scenarios
22 | from sklearn.cluster import KMeans
23 | from continuum import rehearsal
24 | import copy
25 | from torchvision import transforms
26 | try:
27 | from torchvision.transforms import InterpolationMode
28 | BICUBIC = InterpolationMode.BICUBIC
29 | except ImportError:
30 |     from PIL import Image; BICUBIC = Image.BICUBIC  # fallback for older torchvision without InterpolationMode
31 |
32 | def intra_cls(logits, y, classes):
33 |     y = y - classes  # shift labels so the current task's classes start at index 0
34 |     logits1 = logits[:, classes:]  # keep only the logits of the current task's classes
35 | return F.cross_entropy(logits1, y, reduction='none')
36 |
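# Worked example (hypothetical numbers): with classes=10 and a target y=12, the
# shifted label becomes 2 and cross-entropy is computed over logits[:, 10:], so
# only the current task's classes compete with each other.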
37 | def get_finetuning_dataset(dataset, memory, finetuning='balanced', oversample_old=1, task_id=0):
38 | if finetuning == 'balanced':
39 | x, y, t = memory.get()
40 |
41 | if oversample_old > 1:
42 | old_indexes = np.where(t < task_id)[0]
43 | assert len(old_indexes) > 0
44 | new_indexes = np.where(t >= task_id)[0]
45 |
46 | indexes = np.concatenate([
47 | np.repeat(old_indexes, oversample_old),
48 | new_indexes
49 | ])
50 | x, y, t = x[indexes], y[indexes], t[indexes]
51 |
52 | new_dataset = copy.deepcopy(dataset)
53 | new_dataset._x = x
54 | new_dataset._y = y
55 | new_dataset._t = t
56 | return new_dataset
57 |
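# Note: with finetuning='balanced', the returned dataset is a copy of `dataset`
# whose samples are replaced by the replay memory's contents, which are roughly
# class-balanced across all tasks seen so far. A hedged usage sketch, mirroring
# the call made later in run_class_incremental:
#
#   balanced = get_finetuning_dataset(t_data, memory, 'balanced')
#   loader = DataLoader(balanced, batch_size=cfg.train_batch_size, shuffle=True)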
58 |
59 | def seed_everything(seed=0):
60 | """Fix all random seeds"""
61 | random.seed(seed)
62 | np.random.seed(seed)
63 | torch.manual_seed(seed)
64 | torch.cuda.manual_seed_all(seed)
65 | torch.backends.cudnn.deterministic = True
66 | os.environ['PYTHONHASHSEED'] = str(seed)
67 |
68 | def activation(x):
69 | return torch.exp(-10*(1-x))
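# Worked values: activation(1.0) = 1.0, activation(0.9) = exp(-1) ≈ 0.37, and
# activation(0.5) = exp(-5) ≈ 0.0067; a sharp gate that only passes inputs
# close to 1 (presumably similarity-like scores in [0, 1]).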
70 |
71 |
72 |
73 | def run_class_incremental(cfg, device):
74 |
75 | cfg.class_order = utils.get_class_order(os.path.join(cfg.workdir, cfg.class_order))
76 | model = load_model(cfg, device)
77 |
78 | eval_dataset, classes_names = build_cl_scenarios(
79 | cfg, is_train=False, transforms=model.transforms
80 | )
81 | train_dataset, _ = build_cl_scenarios(
82 | cfg, is_train=True, transforms=model.transforms
83 | )
84 | # pdb.set_trace()
85 | model.classes_names = classes_names
86 | if cfg.visual_clsf:
87 | if cfg.model_name == "ViT-L/14":
88 | vision_clsf = VisionClassifier(768, cfg.increment, activation=None)
89 | else:
90 | vision_clsf = VisionClassifier(512, cfg.increment, activation=None)
91 |
92 |
93 | acc_list = []
94 | metric_logger = Logger(list_subsets=["test"])
95 |
96 | p1 = 0
97 | p2 = 0
98 | negative_records = 0
99 | trainable_params = {k: v for k, v in model.named_parameters() if v.requires_grad}
100 | # pdb.set_trace()
101 | torch.save(trainable_params, f'ori_params.pth')
102 |
103 | if cfg.real_replay:
104 | memory = rehearsal.RehearsalMemory(
105 | memory_size=2000,
106 | herding_method="random"
107 | )
108 | for task_id, _ in enumerate(eval_dataset):
109 |
110 | # negative_records = 0
111 |
112 | torch.cuda.empty_cache()
113 | if task_id == 0:
114 | targets_bais = 0
115 | else:
116 | targets_bais = cfg.initial_increment + (task_id - 1) * cfg.increment
117 |
118 | logging.info(f"Evaluation for task {task_id} has started.")
119 | model.adaptation(task_id, reset=cfg.reset)
120 |
121 |         # save the model's current trainable parameters
122 | trainable_params = {k: v for k, v in model.named_parameters() if v.requires_grad}
123 | torch.save(trainable_params, f'trainable_params.pth')
124 |
125 | trainable_params = torch.load(f'ori_params.pth')
126 | model.load_state_dict(trainable_params, strict=False)
127 |
128 |         # measure the mean positive-class and negative-class outputs before any training on this task
129 |         model.eval()  # switch to evaluation mode
130 | positive_outputs = []
131 | negative_outputs = []
132 |
133 | val_gap_loader = DataLoader(train_dataset[task_id], batch_size=cfg.train_batch_size, shuffle=True, num_workers=cfg.num_workers)
134 |
135 | with torch.no_grad():
136 | for inputs, targets, t in val_gap_loader:
137 | inputs, targets = inputs.to(device), targets.to(device)
138 | outputs = model(inputs)
139 |
140 | one_hot_targets = torch.nn.functional.one_hot(targets, outputs.shape[1]).float()
141 | positive_outputs.append((outputs * one_hot_targets).sum(dim=1).mean())
142 | mask = 1 - one_hot_targets
143 | negative_outputs.append(((outputs * mask).sum(dim=1) / mask.sum(dim=1)).mean())
144 | positive_mean = sum(positive_outputs) / len(positive_outputs)
145 | negative_mean = sum(negative_outputs) / len(negative_outputs)
146 | # if task_id == 0:
147 | negative_records = negative_mean
148 | # if task_id == 0:
149 |         logit_size = cfg.increment if task_id > 0 else cfg.initial_increment
150 |         bias_logit = torch.full((logit_size,), negative_mean, device=device)  # reference logits filled with the negative mean
151 |         bias_logit[0] = positive_mean  # one slot carries the positive mean
152 | # pdb.set_trace()
153 | # pdb.set_trace()
154 | logging.info(f"positive_records: {positive_mean}")
155 | logging.info(f"negative_records: {negative_mean}")
156 | # pdb.set_trace()
157 | trainable_params = torch.load(f'trainable_params.pth')
158 | model.load_state_dict(trainable_params, strict=False)
159 |
160 | model.train()
161 | if task_id > 0 and cfg.real_replay:
162 | mem_x, mem_y, mem_t = memory.get()
163 | t_data = train_dataset[task_id]
164 | t_data.add_samples(mem_x, mem_y, mem_t)
165 | else:
166 | t_data = train_dataset[task_id]
167 | # train_loader = DataLoader(train_dataset[:task_id+1], batch_size=cfg.train_batch_size, shuffle=True, num_workers=cfg.num_workers)
168 | train_loader = DataLoader(t_data, batch_size=cfg.train_batch_size, shuffle=True, num_workers=cfg.num_workers)
169 |
170 | epochs = cfg.epochs
171 |
172 | if epochs>0:
173 | # filter out the parameters that require grad
174 | params = filter(lambda p: p.requires_grad, model.parameters())
175 | optimizer = torch.optim.Adam(params, lr=cfg.lr)
176 | # optimizer = torch.optim.SGD(params, lr=cfg.lr, momentum=cfg.momentum, weight_decay=cfg.weight_decay)
177 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs, eta_min=cfg.lr*0.01)
178 | # for name, param in model.named_parameters():
179 | # if param.requires_grad:
180 | # print(name)
181 | torch.cuda.empty_cache()
182 | for i_epoch in range(epochs):
183 |
184 | for bach_i, (inputs, targets, t) in enumerate(train_loader):
185 | loss_c = torch.tensor(0.0).to(device)
186 | loss = torch.tensor(0.0).to(device)
187 |
188 | replay_loss = torch.tensor(0.0).to(device)
189 | torch.cuda.empty_cache()
190 |
191 |
192 | # targets = targets - targets_bais
193 | inputs, targets = inputs.to(device), targets.to(device)
194 |
195 | outputs = model(inputs)
196 | # image_f, text_f = model(inputs, return_feature=True)
197 | if task_id >0:
198 | if cfg.real_replay:
199 | mask_replay = (targets < targets_bais)
200 | old_targets = targets[mask_replay].clone()
201 | old_outputs = outputs[mask_replay].clone()
202 | targets = targets[~mask_replay]
203 | outputs = outputs[~mask_replay]
204 | replay_loss = intra_cls(old_outputs, old_targets, 0).mean()*0.1
205 | loss_c = intra_cls(outputs,targets,targets_bais).mean() + replay_loss
206 | pass
207 | else:
208 | loss_c = torch.nn.functional.cross_entropy(outputs, targets)
209 | loss += loss_c
210 | optimizer.zero_grad()
211 | loss.backward()
212 | optimizer.step()
213 | if bach_i % 10 == 0:
214 | logging.info(f"Epoch {i_epoch + 1}/{epochs} | Batch {bach_i + 1}/{len(train_loader)} | Loss: {loss.item()} | Loss_c: {loss_c.item()}")
215 | scheduler.step()
216 |
217 |
218 |
219 | # torch.cuda.empty_cache()
220 | # positive_outputs = []
221 | # negative_outputs = []
222 | # with torch.no_grad():
223 | # model.eval()
224 | # for inputs, targets, t in val_gap_loader:
225 | # inputs, targets = inputs.to(device), targets.to(device)
226 | # outputs = model(inputs)
227 | # # pdb.set_trace()
228 | # one_hot_targets = torch.nn.functional.one_hot(targets, outputs.shape[1]).float()
229 | # positive_outputs.append((outputs * one_hot_targets).sum(dim=1).mean())
230 | # mask = 1 - one_hot_targets
231 | # negative_outputs.append(((outputs * mask).sum(dim=1) / mask.sum(dim=1)).mean())
232 | # model.train()
233 | # positive_mean = sum(positive_outputs) / len(positive_outputs)
234 | # negative_mean = sum(negative_outputs) / len(negative_outputs)
235 | # all_mean = (sum(positive_outputs)+ sum(positive_outputs))/ (len(positive_outputs)+len(negative_outputs))
236 |
237 | # logging.info(f"positive_mean: {positive_mean}")
238 | # logging.info(f"negative_mean: {negative_mean}")
239 | # torch.cuda.empty_cache()
240 |
241 | if cfg.real_replay:
242 | memory.add(*train_dataset[task_id].get_raw_samples(), None)
243 |
244 | if cfg.balance_ft and cfg.real_replay and task_id > 0:
245 | balance_data = get_finetuning_dataset(t_data, memory, 'balanced')
246 | balance_loader = DataLoader(balance_data, batch_size=cfg.train_batch_size, shuffle=True, num_workers=cfg.num_workers)
247 | epochs = cfg.balance_epochs
248 |
249 | params = filter(lambda p: p.requires_grad, model.parameters())
250 | optimizer = torch.optim.Adam(params, lr=cfg.lr*0.01)
251 | # optimizer = torch.optim.SGD(params, lr=cfg.lr, momentum=cfg.momentum, weight_decay=cfg.weight_decay)
252 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs, eta_min=cfg.lr*0.001)
253 | for i_epoch in range(epochs):
254 | for bach_i, (inputs, targets, t) in enumerate(balance_loader):
255 | loss_c = torch.tensor(0.0).to(device)
256 | loss = torch.tensor(0.0).to(device)
257 |
258 | replay_loss = torch.tensor(0.0).to(device)
259 | torch.cuda.empty_cache()
260 |
261 | inputs, targets = inputs.to(device), targets.to(device)
262 | outputs = model(inputs)
263 | # image_f, text_f = model(inputs, return_feature=True)
264 | loss_c = torch.nn.functional.cross_entropy(outputs, targets)
265 | loss += loss_c
266 | optimizer.zero_grad()
267 | loss.backward()
268 | optimizer.step()
269 | if bach_i % 10 == 0:
270 | logging.info(f"Epoch {i_epoch + 1}/{epochs} | Batch {bach_i + 1}/{len(balance_loader)} | Loss: {loss.item()} | Loss_c: {loss_c.item()}")
271 | # break
272 | scheduler.step()
273 |
274 | # if task_id > 0:
275 | # alpha = 0.2 # EMA
276 | # print("EMA")
277 | # with torch.no_grad():
278 | # for name, param in model.named_parameters():
279 | # if param.requires_grad and name in initial_params:
280 | # param.copy_(alpha * initial_params[name] + (1 - alpha) * param)
281 | if cfg.visual_clsf:
282 | # pdb.set_trace()
283 | torch.cuda.empty_cache()
284 | model.eval()
285 | e_num = cfg.visual_clsf_epochs
286 | vision_clsf_loader = DataLoader(train_dataset[task_id], batch_size=cfg.visual_clsf_batch_size, shuffle=True, num_workers=cfg.num_workers)
287 | features_dict = {}
288 | with torch.no_grad():
289 | for inputs, targets, t in vision_clsf_loader:
290 | inputs, targets = inputs.to(device), targets.to(device)
291 | _, features, __ = model(inputs, test=True, return_feature=True)
292 | for feature, target in zip(features, targets):
293 | target = target.item()
294 | if target not in features_dict:
295 | features_dict[target] = []
296 | features_dict[target].append(feature.cpu())
297 | mean_features = []
298 | for target in sorted(features_dict.keys()):
299 | features = torch.stack(features_dict[target])
300 | mean_feature = features.mean(dim=0)
301 | mean_features.append(mean_feature.unsqueeze(0))
302 | mean_features = torch.cat(mean_features).to(device)
303 | if task_id > 0:
304 | vision_clsf.add_weight(mean_features)
305 | pass
306 | else:
307 | vision_clsf.set_weight(mean_features)
308 | pass
309 | optimizer = torch.optim.Adam(vision_clsf.parameters(), lr=cfg.visual_clsf_lr)
310 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, e_num*len(vision_clsf_loader), eta_min=cfg.visual_clsf_lr*0.01)
311 | for e in range(e_num):
312 | bach_i = -1
313 | for inputs, targets, t in vision_clsf_loader:
314 | inputs, targets = inputs.to(device), targets.to(device)
315 | # pdb.set_trace()
316 | with torch.no_grad():
317 | outputs, _ = model(inputs, return_feature=True)
318 | # pdb.set_trace()
319 | outputs = vision_clsf(outputs)
320 | # pdb.set_trace()
321 | loss = intra_cls(outputs,targets,targets_bais).mean()
322 | # loss = F.cross_entropy(outputs, targets)
323 | optimizer.zero_grad()
324 | loss.backward()
325 | optimizer.step()
326 | bach_i+=1
327 | if bach_i % 10 == 0:
328 | logging.info(f"Epoch {e + 1}/{e_num} | Batch {bach_i + 1}/{len(vision_clsf_loader)} | Loss: {loss.item()}")
329 | scheduler.step()
330 |
331 | if cfg.balance_ft and cfg.real_replay and task_id > 0:
332 | balance_data = get_finetuning_dataset(t_data, memory, 'balanced')
333 | balance_loader = DataLoader(balance_data, batch_size=cfg.train_batch_size, shuffle=True, num_workers=cfg.num_workers)
334 | epochs = cfg.balance_epochs
335 |
336 | optimizer = torch.optim.Adam(vision_clsf.parameters(), lr=cfg.visual_clsf_lr*0.1)
337 | # optimizer = torch.optim.SGD(params, lr=cfg.lr, momentum=cfg.momentum, weight_decay=cfg.weight_decay)
338 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs*len(balance_loader), eta_min=cfg.lr*0.01)
339 | for i_epoch in range(epochs):
340 | for bach_i, (inputs, targets, t) in enumerate(balance_loader):
341 | loss_c = torch.tensor(0.0).to(device)
342 | loss = torch.tensor(0.0).to(device)
343 |
344 | replay_loss = torch.tensor(0.0).to(device)
345 | torch.cuda.empty_cache()
346 |
347 | inputs, targets = inputs.to(device), targets.to(device)
348 | with torch.no_grad():
349 | outputs, _ = model(inputs, return_feature=True)
350 | # pdb.set_trace()
351 | outputs = vision_clsf(outputs)
352 | loss_c = torch.nn.functional.cross_entropy(outputs, targets)
353 | loss += loss_c
354 | optimizer.zero_grad()
355 | loss.backward()
356 | optimizer.step()
357 | if bach_i % 10 == 0:
358 | logging.info(f"Epoch {i_epoch + 1}/{epochs} | Batch {bach_i + 1}/{len(balance_loader)} | Loss: {loss.item()} | Loss_c: {loss_c.item()}")
359 | # break
360 | scheduler.step()
361 |
362 |
363 | if cfg.all_test:
364 | eval_loader = DataLoader(eval_dataset[:cfg.task_num], batch_size=cfg.batch_size)
365 | else:
366 | eval_loader = DataLoader(eval_dataset[:task_id + 1], batch_size=cfg.batch_size)
367 | # eval_loader = DataLoader(eval_dataset[:10], batch_size=cfg.batch_size)
368 | image_feature_list = []
369 | targets_list = []
370 | model.eval()
371 | text_feature_list = []
372 | torch.cuda.empty_cache()
373 | correct_per_class = defaultdict(int)
374 | total_per_class = defaultdict(int)
375 | for inputs, targets, task_ids in eval_loader:
376 | inputs, targets = inputs.to(device), targets.to(device)
377 |
378 | with torch.no_grad():
379 | if cfg.visual_clsf:
380 |                     a = 1  # weight for the text-side CLIP predictions
381 |                     b = 4  # weight for the vision-classifier predictions
382 |
383 | outputs, image_feature, text_feature = model(inputs, test=True, all_test=cfg.all_test, return_feature=True)
384 | vision_outputs = vision_clsf(image_feature)
385 |
386 | outputs_softmax = F.softmax(outputs, dim=1)
387 | vision_outputs_softmax = F.softmax(vision_outputs, dim=1)
388 |
389 | combined_outputs = (a*outputs_softmax + b*vision_outputs_softmax) / (a + b)
390 |
391 | metric_logger.add([combined_outputs.cpu().argmax(dim=1), targets.cpu(), task_ids], subset="test")
392 | preds = combined_outputs.cpu().argmax(dim=1)
393 | for l,p in zip(targets.cpu(), preds):
394 | label = l.item()
395 | total_per_class[label] += 1
396 | if l == p:
397 | correct_per_class[label] += 1
398 | else:
399 | outputs = model(inputs, test=True, all_test=cfg.all_test)
400 | metric_logger.add([outputs.cpu().argmax(dim=1), targets.cpu(), task_ids], subset="test")
401 | class_acc = {}
402 | for clas in total_per_class:
403 | acc = correct_per_class[clas] / total_per_class[clas]
404 | class_acc[clas] = acc
405 | avg_acc = np.mean(list(class_acc.values()))
406 |
407 |
408 |
409 | acc_list.append(100 * metric_logger.accuracy)
410 | with open(cfg.log_path, 'a+') as f:
411 | f.write(json.dumps({
412 | 'task': task_id,
413 | 'acc': round(100 * metric_logger.accuracy, 2),
414 | 'avg_acc': round(100 * metric_logger.average_incremental_accuracy, 2),
415 | 'forgetting': round(100 * metric_logger.forgetting, 6),
416 | 'acc_per_task': [round(100 * acc_t, 2) for acc_t in metric_logger.accuracy_per_task],
417 | 'bwt': round(100 * metric_logger.backward_transfer, 2),
418 | 'fwt': round(100 * metric_logger.forward_transfer, 2),
419 | }) + '\n')
420 | metric_logger.end_task()
421 | torch.save(model.state_dict(), f'final_model.pth')
422 | with open(cfg.log_path, 'a+') as f:
423 | f.write(json.dumps({
424 | 'last': round(acc_list[-1], 2),
425 | 'avg': round(statistics.mean(acc_list), 2)
426 | }) + '\n')
427 |
428 |
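# The evaluation loop above fuses the text-side CLIP logits with the vision-only
# classifier by averaging their softmax distributions with fixed weights a=1 and
# b=4. A minimal, self-contained sketch of that fusion step (shapes assumed to be
# [batch, num_classes]; illustrative only, not part of the original script):
def fuse_predictions(text_logits, vision_logits, a=1.0, b=4.0):
    """Weighted average of two softmax distributions; returns predicted class ids."""
    p = (a * F.softmax(text_logits, dim=1)
         + b * F.softmax(vision_logits, dim=1)) / (a + b)
    return p.argmax(dim=1)
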
429 |
430 | def run_domain_incremental(cfg, device):
431 |
432 |     model = load_model(cfg, device)
433 | eval_dataset, classes_names = build_cl_scenarios(
434 | cfg, is_train=False, transforms=model.transforms
435 | )
436 | model.tokenize(classes_names)
437 |
438 | with open(cfg.log_path, 'w+') as f:
439 | pass
440 |
441 | logger = Logger(list_subsets=["test"])
442 | logging.info(f">>> Evaluation scenario length is {len(eval_dataset)}")
443 | for task_id, _ in enumerate(eval_dataset):
444 |
445 | dataset_val = eval_dataset[:task_id + 1]
446 | eval_loader = DataLoader(dataset_val, batch_size=cfg.batch_size)
447 | for input, target, task_ids in tqdm(eval_loader):
448 | input, target = input.to(device), target.to(device)
449 | output = torch.from_numpy(model(input))
450 | logger.add([output.cpu().argmax(dim=1), target.cpu(), task_ids], subset='test')
451 |
452 | with open(cfg.log_path, 'a+') as f:
453 | f.write(json.dumps({
454 | 'task': task_id,
455 | 'acc': round(100 * logger.accuracy, 2),
456 | }) + '\n')
457 |
458 | logger.end_task()
459 |
460 | def run_task_agnostic():
461 | pass
462 |
463 |
464 |
465 | @hydra.main(config_path=None, config_name=None, version_base="1.1")
466 | def continual_clip(cfg: DictConfig) -> None:
467 | seed_everything(cfg.seed)
468 | cfg.workdir = utils.get_workdir(path=os.getcwd())
469 | cfg.dataset_root = os.path.join(cfg.workdir, cfg.dataset_root)
470 |
471 | utils.save_config(cfg)
472 | with open(cfg.log_path, 'w+') as f:
473 | pass
474 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
475 |
476 | if cfg.scenario == "class":
477 | run_class_incremental(cfg, device)
478 |
479 | elif cfg.scenario == "domain":
480 | run_domain_incremental(cfg, device)
481 |
482 | elif cfg.scenario == "task-agnostic":
483 |         raise NotImplementedError("Method has not been implemented; it will be added soon.")
484 |
485 | else:
486 |         raise ValueError(f"You have entered `{cfg.scenario}` which is not a defined scenario, "
487 |                          "please choose from {'class', 'domain', 'task-agnostic'}.")
488 |
489 |
490 |
491 |
492 |
493 |
494 |
495 |
496 |
497 |
498 |
499 |
500 |
501 |
502 |
503 |
504 |
505 |
506 |
507 |
508 |
509 |
510 | if __name__ == "__main__":
511 | continual_clip()
512 |
--------------------------------------------------------------------------------
/loraclip/model.py:
--------------------------------------------------------------------------------
1 | # content adapted from CLIP's official github: openai/CLIP
2 |
3 | import pdb
4 | from .loralib import layers as lora
5 | from .loralib import utils as lora_utils
6 |
7 | from collections import OrderedDict
8 | from typing import Tuple, Union
9 |
10 | import numpy as np
11 | import torch
12 | import torch.nn.functional as F
13 | from torch import nn
14 |
15 |
16 | class Bottleneck(nn.Module):
17 | expansion = 4
18 |
19 | def __init__(self, inplanes, planes, stride=1):
20 | super().__init__()
21 |
22 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
23 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
24 | self.bn1 = nn.BatchNorm2d(planes)
25 | self.relu1 = nn.ReLU(inplace=True)
26 |
27 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
28 | self.bn2 = nn.BatchNorm2d(planes)
29 | self.relu2 = nn.ReLU(inplace=True)
30 |
31 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
32 |
33 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
34 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
35 | self.relu3 = nn.ReLU(inplace=True)
36 |
37 | self.downsample = None
38 | self.stride = stride
39 |
40 | if stride > 1 or inplanes != planes * Bottleneck.expansion:
41 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
42 | self.downsample = nn.Sequential(OrderedDict([
43 | ("-1", nn.AvgPool2d(stride)),
44 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
45 | ("1", nn.BatchNorm2d(planes * self.expansion))
46 | ]))
47 |
48 | def forward(self, x: torch.Tensor):
49 | identity = x
50 |
51 | out = self.relu1(self.bn1(self.conv1(x)))
52 | out = self.relu2(self.bn2(self.conv2(out)))
53 | out = self.avgpool(out)
54 | out = self.bn3(self.conv3(out))
55 |
56 | if self.downsample is not None:
57 | identity = self.downsample(x)
58 |
59 | out += identity
60 | out = self.relu3(out)
61 | return out
62 |
63 |
64 | class AttentionPool2d(nn.Module):
65 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
66 | super().__init__()
67 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
68 | self.k_proj = nn.Linear(embed_dim, embed_dim)
69 | self.q_proj = nn.Linear(embed_dim, embed_dim)
70 | self.v_proj = nn.Linear(embed_dim, embed_dim)
71 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
72 | self.num_heads = num_heads
73 |
74 | def forward(self, x):
75 | x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
76 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
77 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
78 | x, _ = F.multi_head_attention_forward(
79 | query=x[:1], key=x, value=x,
80 | embed_dim_to_check=x.shape[-1],
81 | num_heads=self.num_heads,
82 | q_proj_weight=self.q_proj.weight,
83 | k_proj_weight=self.k_proj.weight,
84 | v_proj_weight=self.v_proj.weight,
85 | in_proj_weight=None,
86 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
87 | bias_k=None,
88 | bias_v=None,
89 | add_zero_attn=False,
90 | dropout_p=0,
91 | out_proj_weight=self.c_proj.weight,
92 | out_proj_bias=self.c_proj.bias,
93 | use_separate_proj_weight=True,
94 | training=self.training,
95 | need_weights=False
96 | )
97 | return x.squeeze(0)
98 |
99 |
100 | class ModifiedResNet(nn.Module):
101 | """
102 | A ResNet class that is similar to torchvision's but contains the following changes:
103 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
104 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
105 | - The final pooling layer is a QKV attention instead of an average pool
106 | """
107 |
108 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
109 | super().__init__()
110 | self.output_dim = output_dim
111 | self.input_resolution = input_resolution
112 |
113 | # the 3-layer stem
114 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
115 | self.bn1 = nn.BatchNorm2d(width // 2)
116 | self.relu1 = nn.ReLU(inplace=True)
117 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
118 | self.bn2 = nn.BatchNorm2d(width // 2)
119 | self.relu2 = nn.ReLU(inplace=True)
120 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
121 | self.bn3 = nn.BatchNorm2d(width)
122 | self.relu3 = nn.ReLU(inplace=True)
123 | self.avgpool = nn.AvgPool2d(2)
124 |
125 | # residual layers
126 | self._inplanes = width # this is a *mutable* variable used during construction
127 | self.layer1 = self._make_layer(width, layers[0])
128 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
129 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
130 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
131 |
132 | embed_dim = width * 32 # the ResNet feature dimension
133 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
134 |
135 | def _make_layer(self, planes, blocks, stride=1):
136 | layers = [Bottleneck(self._inplanes, planes, stride)]
137 |
138 | self._inplanes = planes * Bottleneck.expansion
139 | for _ in range(1, blocks):
140 | layers.append(Bottleneck(self._inplanes, planes))
141 |
142 | return nn.Sequential(*layers)
143 |
144 | def forward(self, x):
145 | def stem(x):
146 | x = self.relu1(self.bn1(self.conv1(x)))
147 | x = self.relu2(self.bn2(self.conv2(x)))
148 | x = self.relu3(self.bn3(self.conv3(x)))
149 | x = self.avgpool(x)
150 | return x
151 |
152 | x = x.type(self.conv1.weight.dtype)
153 | x = stem(x)
154 | x = self.layer1(x)
155 | x = self.layer2(x)
156 | x = self.layer3(x)
157 | x = self.layer4(x)
158 | x = self.attnpool(x)
159 |
160 | return x
161 |
162 |
163 | class LayerNorm(nn.LayerNorm):
164 | """Subclass torch's LayerNorm to handle fp16."""
165 |
166 | def forward(self, x: torch.Tensor):
167 | orig_type = x.dtype
168 | ret = super().forward(x.type(torch.float32))
169 | return ret.type(orig_type)
170 |
171 |
172 | class QuickGELU(nn.Module):
173 | def forward(self, x: torch.Tensor):
174 |         return x * torch.sigmoid(1.702 * x)  # sigmoid-based approximation of GELU
175 |
176 |
177 | class ResidualAttentionBlock(nn.Module):
178 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
179 | super().__init__()
180 |
181 | self.attn = nn.MultiheadAttention(d_model, n_head)
182 | self.ln_1 = LayerNorm(d_model)
183 | self.mlp = nn.Sequential(OrderedDict([
184 | ("c_fc", nn.Linear(d_model, d_model * 4)),
185 | ("gelu", QuickGELU()),
186 | ("c_proj", nn.Linear(d_model * 4, d_model))
187 | ]))
188 | self.ln_2 = LayerNorm(d_model)
189 | self.attn_mask = attn_mask
190 |
191 | def attention(self, x: torch.Tensor):
192 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
193 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
194 |
195 | def forward(self, x: torch.Tensor):
196 | x = x + self.attention(self.ln_1(x))
197 | x = x + self.mlp(self.ln_2(x))
198 | return x
199 |
200 |
201 | # LoRA implementation of ResidualAttentionBlock:
202 | class LoRAResidualAttentionBlock(nn.Module):
203 |     def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, r=4, only_kv=False, mlp=False):
204 | super().__init__()
205 |
206 |         self.attn = lora.MultiheadAttention(d_model, n_head, r=r, only_kv=only_kv, mlp=mlp)  # attention with LoRA adapters of rank r (default 4)
207 | self.ln_1 = LayerNorm(d_model)
208 | if only_kv:
209 | self.mlp = nn.Sequential(OrderedDict([
210 | ("c_fc", nn.Linear(d_model, d_model * 4)),
211 | ("gelu", QuickGELU()),
212 | ("c_proj", nn.Linear(d_model * 4, d_model))
213 | ]))
214 | else:
215 | self.mlp = nn.Sequential(OrderedDict([
216 | ("c_fc", lora.Linear(d_model, d_model * 4, r=r)),
217 | ("gelu", QuickGELU()),
218 | ("c_proj", lora.Linear(d_model * 4, d_model, r=r))
219 | ]))
220 | if mlp:
221 | self.adapter_mlp = nn.Sequential(OrderedDict([
222 | ("c_fc", nn.Linear(d_model, r, bias=False)),
223 | ("gelu", QuickGELU()),
224 | ("c_proj", nn.Linear(r, d_model, bias=False))
225 | ]))
226 | nn.init.zeros_(self.adapter_mlp.c_proj.weight)
227 | self.mlp_flag = mlp
228 | self.ln_2 = LayerNorm(d_model)
229 | self.attn_mask = attn_mask
230 |
231 | def attention(self, x: torch.Tensor):
232 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
233 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
234 |
235 | def forward(self, x: torch.Tensor):
236 | x = x + self.attention(self.ln_1(x))
237 | if self.mlp_flag:
238 | x = x + self.adapter_mlp(self.ln_2(x)) + self.mlp(self.ln_2(x))
239 | else:
240 | x = x + self.mlp(self.ln_2(x))
241 | return x
242 |
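# A minimal sketch of the low-rank update these blocks rely on (assuming the
# loralib layers follow LoRA, Hu et al. 2021): a frozen weight W is augmented
# with a trainable low-rank pair so that h = W x + B (A x), where B starts at
# zero so training begins exactly at the pretrained mapping. Illustrative only,
# not the loralib API:
class TinyLoRALinear(nn.Module):
    def __init__(self, d_in: int, d_out: int, r: int = 4):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(d_out, d_in), requires_grad=False)  # frozen W
        self.lora_A = nn.Parameter(torch.randn(r, d_in) * 0.01)  # small random init
        self.lora_B = nn.Parameter(torch.zeros(d_out, r))        # zero init => no initial change

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ self.weight.T + (x @ self.lora_A.T) @ self.lora_B.T
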
243 |
244 | class Transformer(nn.Module):
245 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
246 | super().__init__()
247 | self.width = width
248 | self.layers = layers
249 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
250 |
251 | def forward(self, x: torch.Tensor):
252 | return self.resblocks(x)
253 |
254 |
255 | # LoRA implementation of Transformer:
256 | class LoRATransformer(nn.Module):
257 |     def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, r=4, only_kv=False, mlp=False):
258 | super().__init__()
259 | self.width = width
260 | self.layers = layers
261 | self.resblocks = nn.Sequential(*[LoRAResidualAttentionBlock(width, heads, attn_mask, r=r, only_kv=only_kv, mlp=mlp) for _ in range(layers)])
262 |
263 | def forward(self, x: torch.Tensor):
264 | return self.resblocks(x)
265 |
266 |
267 | class VisionTransformer(nn.Module):
268 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
269 | super().__init__()
270 | self.input_resolution = input_resolution
271 | self.output_dim = output_dim
272 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
273 |
274 | scale = width ** -0.5
275 | self.class_embedding = nn.Parameter(scale * torch.randn(width))
276 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
277 | self.ln_pre = LayerNorm(width)
278 |
279 | self.transformer = Transformer(width, layers, heads)
280 |
281 | self.ln_post = LayerNorm(width)
282 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
283 |
284 | def forward(self, x: torch.Tensor):
285 | x = self.conv1(x) # shape = [*, width, grid, grid]
286 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
287 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
288 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
289 | x = x + self.positional_embedding.to(x.dtype)
290 | x = self.ln_pre(x)
291 |
292 | x = x.permute(1, 0, 2) # NLD -> LND
293 | x = self.transformer(x)
294 | x = x.permute(1, 0, 2) # LND -> NLD
295 |
296 | x = self.ln_post(x[:, 0, :])
297 |
298 | if self.proj is not None:
299 | x = x @ self.proj
300 |
301 | return x
302 |
303 |
304 | class LoRAVisionTransformer(nn.Module):
305 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int, r: int, only_kv=False, mlp=False):
306 | super().__init__()
307 | self.input_resolution = input_resolution
308 | self.output_dim = output_dim
309 | if only_kv:
310 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
311 | else:
312 | self.conv1 = lora.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
313 |
314 | scale = width ** -0.5
315 | self.class_embedding = nn.Parameter(scale * torch.randn(width))
316 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
317 | self.ln_pre = LayerNorm(width)
318 |
319 | self.transformer = LoRATransformer(width, layers, heads, only_kv=only_kv, r=r, mlp=mlp)
320 | # self.transformer = Transformer(width, layers, heads)
321 |
322 | self.ln_post = LayerNorm(width)
323 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
324 |
325 | def forward(self, x: torch.Tensor):
326 | x = self.conv1(x) # shape = [*, width, grid, grid]
327 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
328 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
329 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
330 | x = x + self.positional_embedding.to(x.dtype)
331 | x = self.ln_pre(x)
332 |
333 | x = x.permute(1, 0, 2) # NLD -> LND
334 | x = self.transformer(x)
335 | x = x.permute(1, 0, 2) # LND -> NLD
336 |
337 | x = self.ln_post(x[:, 0, :])
338 |
339 | if self.proj is not None:
340 | x = x @ self.proj
341 |
342 | return x
343 |
344 |
345 | class CLIP(nn.Module):
346 | def __init__(self,
347 | embed_dim: int,
348 | # vision
349 | image_resolution: int,
350 | vision_layers: Union[Tuple[int, int, int, int], int],
351 | vision_width: int,
352 | vision_patch_size: int,
353 | # text
354 | context_length: int,
355 | vocab_size: int,
356 | transformer_width: int,
357 | transformer_heads: int,
358 | transformer_layers: int
359 | ):
360 | super().__init__()
361 |
362 | self.context_length = context_length
363 |
364 | if isinstance(vision_layers, (tuple, list)):
365 | vision_heads = vision_width * 32 // 64
366 | self.visual = ModifiedResNet(
367 | layers=vision_layers,
368 | output_dim=embed_dim,
369 | heads=vision_heads,
370 | input_resolution=image_resolution,
371 | width=vision_width
372 | )
373 | else:
374 | vision_heads = vision_width // 64
375 | self.visual = VisionTransformer(
376 | input_resolution=image_resolution,
377 | patch_size=vision_patch_size,
378 | width=vision_width,
379 | layers=vision_layers,
380 | heads=vision_heads,
381 | output_dim=embed_dim
382 | )
383 |
384 | self.transformer = Transformer(
385 | width=transformer_width,
386 | layers=transformer_layers,
387 | heads=transformer_heads,
388 | attn_mask=self.build_attention_mask()
389 | )
390 |
391 | self.vocab_size = vocab_size
392 | self.token_embedding = nn.Embedding(vocab_size, transformer_width)
393 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
394 | self.ln_final = LayerNorm(transformer_width)
395 |
396 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
397 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
398 |
399 | self.initialize_parameters()
400 |
401 | def initialize_parameters(self):
402 | nn.init.normal_(self.token_embedding.weight, std=0.02)
403 | nn.init.normal_(self.positional_embedding, std=0.01)
404 |
405 | if isinstance(self.visual, ModifiedResNet):
406 | if self.visual.attnpool is not None:
407 | std = self.visual.attnpool.c_proj.in_features ** -0.5
408 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
409 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
410 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
411 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
412 |
413 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
414 | for name, param in resnet_block.named_parameters():
415 | if name.endswith("bn3.weight"):
416 | nn.init.zeros_(param)
417 |
418 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
419 | attn_std = self.transformer.width ** -0.5
420 | fc_std = (2 * self.transformer.width) ** -0.5
421 | for block in self.transformer.resblocks:
422 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
423 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
424 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
425 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
426 |
427 | if self.text_projection is not None:
428 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
429 |
430 | def build_attention_mask(self):
431 | # lazily create causal attention mask, with full attention between the vision tokens
432 | # pytorch uses additive attention mask; fill with -inf
433 | mask = torch.empty(self.context_length, self.context_length)
434 | mask.fill_(float("-inf"))
435 |         mask.triu_(1)  # keep -inf strictly above the diagonal (causal mask)
436 | return mask
437 |
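    # For example, with context_length = 3 the additive mask built above is
    #
    #   [[0., -inf, -inf],
    #    [0.,   0., -inf],
    #    [0.,   0.,   0.]]
    #
    # so token i can attend only to positions j <= i.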
438 | @property
439 | def dtype(self):
440 | return self.visual.conv1.weight.dtype
441 |
442 | def encode_image(self, image):
443 | return self.visual(image.type(self.dtype))
444 |
445 | def encode_text(self, text):
446 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
447 |
448 | x = x + self.positional_embedding.type(self.dtype)
449 | x = x.permute(1, 0, 2) # NLD -> LND
450 | x = self.transformer(x)
451 | x = x.permute(1, 0, 2) # LND -> NLD
452 | x = self.ln_final(x).type(self.dtype)
453 |
454 | # x.shape = [batch_size, n_ctx, transformer.width]
455 | # take features from the eot embedding (eot_token is the highest number in each sequence)
456 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
457 |
458 | return x
459 |
460 | def forward(self, image, text):
461 | image_features = self.encode_image(image)
462 | text_features = self.encode_text(text)
463 |
464 | # normalized features
465 | image_features = image_features / image_features.norm(dim=1, keepdim=True)
466 | text_features = text_features / text_features.norm(dim=1, keepdim=True)
467 |
468 | # cosine similarity as logits
469 | logit_scale = self.logit_scale.exp()
470 | logits_per_image = logit_scale * image_features @ text_features.t()
471 | logits_per_text = logits_per_image.t()
472 |
473 | # shape = [global_batch_size, global_batch_size]
474 | return logits_per_image, logits_per_text
475 |
476 |
477 | class LoRACLIP(nn.Module):
478 | def __init__(self,
479 | embed_dim: int,
480 | # vision
481 | image_resolution: int,
482 | vision_layers: Union[Tuple[int, int, int, int], int],
483 | vision_width: int,
484 | vision_patch_size: int,
485 | # text
486 | context_length: int,
487 | vocab_size: int,
488 | transformer_width: int,
489 | transformer_heads: int,
490 | transformer_layers: int,
491 | r: int,
492 | lora_mode: str
493 | ):
494 | super().__init__()
495 |
496 | self.context_length = context_length
497 |
498 | if isinstance(vision_layers, (tuple, list)):
499 | vision_heads = vision_width * 32 // 64
500 | self.visual = ModifiedResNet(
501 | layers=vision_layers,
502 | output_dim=embed_dim,
503 | heads=vision_heads,
504 | input_resolution=image_resolution,
505 | width=vision_width
506 | )
507 | else:
508 | vision_heads = vision_width // 64
509 |
510 | if "vision" in lora_mode:
511 | self.visual = LoRAVisionTransformer(
512 | input_resolution=image_resolution,
513 | patch_size=vision_patch_size,
514 | width=vision_width,
515 | layers=vision_layers,
516 | heads=vision_heads,
517 | output_dim=embed_dim,
518 | r=r,
519 | only_kv=("only_kv" in lora_mode),
520 | mlp="mlp" in lora_mode,
521 | )
522 | else:
523 | self.visual = VisionTransformer(
524 | input_resolution=image_resolution,
525 | patch_size=vision_patch_size,
526 | width=vision_width,
527 | layers=vision_layers,
528 | heads=vision_heads,
529 | output_dim=embed_dim
530 | )
531 |
532 | if "text" in lora_mode:
533 | self.transformer = LoRATransformer(
534 | width=transformer_width,
535 | layers=transformer_layers,
536 | heads=transformer_heads,
537 | attn_mask=self.build_attention_mask(),
538 | r = r,
539 | only_kv=("only_kv" in lora_mode),
540 | mlp="mlp" in lora_mode,
541 | )
542 |
543 | else:
544 | self.transformer = Transformer(
545 | width=transformer_width,
546 | layers=transformer_layers,
547 | heads=transformer_heads,
548 | attn_mask=self.build_attention_mask()
549 | )
550 |
551 | self.vocab_size = vocab_size
552 |
553 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
554 | self.ln_final = LayerNorm(transformer_width)
555 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
556 |
557 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
558 |
559 | if "text" in lora_mode:
560 | if "only_kv" in lora_mode:
561 | self.lora_text_projection = nn.Linear(transformer_width, embed_dim, bias=False)
562 | self.token_embedding = nn.Embedding(vocab_size, transformer_width)
563 | else:
564 | self.lora_text_projection = lora.Linear(transformer_width, embed_dim, r=r, bias=False)
565 | self.token_embedding = lora.Embedding(vocab_size, transformer_width, r=r)
566 |
567 | else:
568 | self.lora_text_projection = nn.Linear(transformer_width, embed_dim, bias=False)
569 | self.token_embedding = nn.Embedding(vocab_size, transformer_width)
570 |
571 | self.initialize_parameters()
572 |
573 | def initialize_parameters(self):
574 | nn.init.normal_(self.token_embedding.weight, std=0.02)
575 | nn.init.normal_(self.positional_embedding, std=0.01)
576 |
577 | if isinstance(self.visual, ModifiedResNet):
578 | if self.visual.attnpool is not None:
579 | std = self.visual.attnpool.c_proj.in_features ** -0.5
580 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
581 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
582 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
583 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
584 |
585 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
586 | for name, param in resnet_block.named_parameters():
587 | if name.endswith("bn3.weight"):
588 | nn.init.zeros_(param)
589 |
590 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
591 | attn_std = self.transformer.width ** -0.5
592 | fc_std = (2 * self.transformer.width) ** -0.5
593 |         for block in self.transformer.resblocks:
594 |             pass  # re-initialization skipped for LoRA-wrapped blocks; pretrained weights are loaded afterwards
595 |             # nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
596 |             # nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
597 |             # nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
598 |             # nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
599 |
600 | if self.text_projection is not None:
601 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
602 |
603 | def build_attention_mask(self):
604 | # lazily create causal attention mask, with full attention between the vision tokens
605 | # pytorch uses additive attention mask; fill with -inf
606 | mask = torch.empty(self.context_length, self.context_length)
607 | mask.fill_(float("-inf"))
608 |         mask.triu_(1)  # keep -inf strictly above the diagonal (causal mask)
609 | return mask
610 |
611 | @property
612 | def dtype(self):
613 | return self.visual.conv1.weight.dtype
614 |
615 | def encode_image(self, image):
616 | return self.visual(image.type(self.dtype))
617 |
618 | def encode_text(self, text):
619 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
620 |
621 | x = x + self.positional_embedding.type(self.dtype)
622 | x = x.permute(1, 0, 2) # NLD -> LND
623 | x = self.transformer(x)
624 | x = x.permute(1, 0, 2) # LND -> NLD
625 | x = self.ln_final(x).type(self.dtype)
626 |
627 | # x.shape = [batch_size, n_ctx, transformer.width]
628 | # take features from the eot embedding (eot_token is the highest number in each sequence)
629 | # x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
630 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)]
631 | x = self.lora_text_projection(x)
632 | return x
633 |
634 | def forward(self, image, text):
635 | image_features = self.encode_image(image)
636 | text_features = self.encode_text(text)
637 |
638 | # normalized features
639 | image_features = image_features / image_features.norm(dim=1, keepdim=True)
640 | text_features = text_features / text_features.norm(dim=1, keepdim=True)
641 |
642 | # cosine similarity as logits
643 | logit_scale = self.logit_scale.exp()
644 | logits_per_image = logit_scale * image_features @ text_features.t()
645 | logits_per_text = logits_per_image.t()
646 |
647 | # shape = [global_batch_size, global_batch_size]
648 | return logits_per_image, logits_per_text
649 |
650 | def convert_weights(model: nn.Module):
651 | """Convert applicable model parameters to fp16"""
652 |
653 | def _convert_weights_to_fp16(l):
654 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
655 | l.weight.data = l.weight.data.half()
656 | if l.bias is not None:
657 | l.bias.data = l.bias.data.half()
658 |
659 | if isinstance(l, nn.MultiheadAttention):
660 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
661 | tensor = getattr(l, attr)
662 | if tensor is not None:
663 | tensor.data = tensor.data.half()
664 |
665 | for name in ["text_projection", "proj"]:
666 | if hasattr(l, name):
667 | attr = getattr(l, name)
668 | if attr is not None:
669 | attr.data = attr.data.half()
670 |
671 | model.apply(_convert_weights_to_fp16)
672 |
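# A tiny self-check sketch for convert_weights. The helper name
# _example_convert_weights is hypothetical; it relies only on the torch / nn
# imports already made at the top of this module.
def _example_convert_weights():
    toy = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Linear(8, 4))
    convert_weights(toy)
    # both layers match the isinstance check above and are halved, biases included
    assert toy[0].weight.dtype == torch.float16
    assert toy[1].bias.dtype == torch.float16
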
673 | def convert_weights_lora_model(model: nn.Module):
674 | """Convert applicable model parameters to fp16"""
675 |
676 | def _convert_weights_to_fp16(l):
677 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
678 | l.weight.data = l.weight.data.half()
679 | if l.bias is not None:
680 | l.bias.data = l.bias.data.half()
681 |
682 | if isinstance(l, lora.MultiheadAttention):
683 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v", *[f"{s}_proj_weight_lora_A" for s in ["in", "q", "k", "v"]], *[f"{s}_proj_weight_lora_B" for s in ["in", "q", "k", "v"]]]:
684 | tensor = getattr(l, attr)
685 | if tensor is not None:
686 | tensor.data = tensor.data.half()
687 |
688 | for name in ["text_projection", "proj"]:
689 | if hasattr(l, name):
690 | attr = getattr(l, name)
691 | if attr is not None:
692 | attr.data = attr.data.half()
693 |
694 | model.apply(_convert_weights_to_fp16)
695 |
696 |
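# build_model reconstructs every architecture hyper-parameter from the
# checkpoint itself: tensor shapes give widths and resolutions, key names give
# layer counts, and the presence of "visual.proj" (a ViT-only key) selects the
# vision-transformer branch below.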
697 | def build_model(state_dict: dict):
698 | vit = "visual.proj" in state_dict
699 |
700 | if vit:
701 | vision_width = state_dict["visual.conv1.weight"].shape[0]
702 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
703 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
704 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
705 | image_resolution = vision_patch_size * grid_size
706 | else:
707 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
708 | vision_layers = tuple(counts)
709 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
710 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
711 | vision_patch_size = None
712 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
713 | image_resolution = output_width * 32
714 |
715 | embed_dim = state_dict["text_projection"].shape[1]
716 | context_length = state_dict["positional_embedding"].shape[0]
717 | vocab_size = state_dict["token_embedding.weight"].shape[0]
718 | transformer_width = state_dict["ln_final.weight"].shape[0]
719 | transformer_heads = transformer_width // 64
720 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
721 |
722 | model = CLIP(
723 | embed_dim,
724 | image_resolution, vision_layers, vision_width, vision_patch_size,
725 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
726 | )
727 |
728 | for key in ["input_resolution", "context_length", "vocab_size"]:
729 | if key in state_dict:
730 | del state_dict[key]
731 |
732 | convert_weights(model)
733 | model.load_state_dict(state_dict)
734 | return model.eval()
735 |
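# Usage sketch for build_model (the checkpoint path is illustrative; this
# mirrors how OpenAI's clip.load feeds a jit archive's state dict into the builder):
#
#     jit_model = torch.jit.load("ViT-B-16.pt", map_location="cpu")
#     model = build_model(jit_model.state_dict())
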
736 | def build_LoRA_model(state_dict: dict, r: int, lora_mode: str):
737 | vit = "visual.proj" in state_dict
738 |
739 | if vit:
740 | vision_width = state_dict["visual.conv1.weight"].shape[0]
741 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
742 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
743 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
744 | image_resolution = vision_patch_size * grid_size
745 | else:
746 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
747 | vision_layers = tuple(counts)
748 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
749 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
750 | vision_patch_size = None
751 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
752 | image_resolution = output_width * 32
753 | if "only_kv" in lora_mode:
754 | for name, param in list(state_dict.items()):
755 | if "transformer.resblocks." in name and 'in_proj_weight' in name:
756 | # pdb.set_trace()
757 | shape_chunk = param.shape[0]//3
758 | state_dict[name.replace("in_proj_weight", "q_proj_weight")] = param[:1*shape_chunk].clone()
759 | state_dict[name.replace("in_proj_weight", "k_proj_weight")] = param[1*shape_chunk:2*shape_chunk].clone()
760 | state_dict[name.replace("in_proj_weight", "v_proj_weight")] = param[2*shape_chunk:3*shape_chunk].clone()
761 | del state_dict[name]
762 | # if "transformer.resblocks" in name and 'in_proj_weight' in name:
763 | # shape_chunk = param.shape[0]//3
764 | # state_dict[name.replace("in_proj_weight", "q_proj_weight")] = param[:1*shape_chunk].clone()
765 | # state_dict[name.replace("in_proj_weight", "k_proj_weight")] = param[1*shape_chunk:2*shape_chunk].clone()
766 | # state_dict[name.replace("in_proj_weight", "v_proj_weight")] = param[2*shape_chunk:3*shape_chunk].clone()
767 | # del state_dict[name]
768 | embed_dim = state_dict["text_projection"].shape[1]
769 | context_length = state_dict["positional_embedding"].shape[0]
770 | vocab_size = state_dict["token_embedding.weight"].shape[0]
771 | transformer_width = state_dict["ln_final.weight"].shape[0]
772 | transformer_heads = transformer_width // 64
773 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
774 |
775 | model = LoRACLIP(
776 | embed_dim,
777 | image_resolution, vision_layers, vision_width, vision_patch_size,
778 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers,
779 | r, lora_mode
780 | )
781 |
782 | for key in ["input_resolution", "context_length", "vocab_size"]:
783 | if key in state_dict:
784 | del state_dict[key]
785 |
786 |     # the frozen text projection becomes the weight of the Linear replacement
787 |     new_state_dict = state_dict  # alias; the dict is updated in place
788 |     new_state_dict["lora_text_projection.weight"] = state_dict["text_projection"].T
789 |
790 | res = model.load_state_dict(new_state_dict, strict=False)
791 | missing_keys = res.missing_keys
792 | unexpected_keys = res.unexpected_keys
793 | missing_keys = [x for x in missing_keys if 'lora_' not in x] # ignore LoRA extra weights
794 |
795 | print("Model loaded")
796 | if len(missing_keys) != 0:
797 | print(f"Missing keys: {missing_keys}")
798 |
799 | if len(unexpected_keys) != 0:
800 | print(f"Unexpected keys: {unexpected_keys}")
801 |
802 | print(" ")
803 |
804 |     # mark only the LoRA parameters as trainable
805 |     lora_utils.mark_only_lora_as_trainable(model)
806 |
807 |     # freeze text_proj as well as the positional embeddings etc.
808 |     ##---------------------------------------------------------
809 |     ##
810 |
811 |     # the re-mapped text projection stays fully frozen, LoRA factors included
812 |     for name, param in model.named_parameters():
813 |         if 'lora_text_projection' in name:
814 |             param.requires_grad = False
819 | # some caveats for loading a model for fine-tuning:
820 | # if "text" in lora_mode:
821 | # for name, param in model.named_parameters():
822 | # if "positional_embedding" in name:
823 | # param.requires_grad = True
824 | # if "text_projection" in name:
825 |     #         param.requires_grad = True
826 | # if "logit_scale" in name:
827 | # param.requires_grad = True
828 |
829 | # if "vision" in lora_mode:
830 | # for name, param in model.named_parameters():
831 | # if "visual.proj" in name:
832 | # param.requires_grad = True
833 | # if "visual.class_embedding" in name:
834 | # param.requires_grad = True
835 | # if "visual.positional_embedding" in name:
836 | # param.requires_grad = True
837 | ##
838 | ## ---------------------------------------------------------
839 |     # convert_weights_lora_model(model)  # fp16 conversion is left disabled here
840 | return model.eval()
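
# Usage sketch (clip_state_dict is a hypothetical, already-loaded CLIP state
# dict; the r and lora_mode values are illustrative). After building, only
# parameters with 'lora_' in their name may require grad, and the frozen
# lora_text_projection must not:
#
#     model = build_LoRA_model(clip_state_dict, r=4, lora_mode="only_kv")
#     trainable = [n for n, p in model.named_parameters() if p.requires_grad]
#     assert all('lora_' in n for n in trainable)
#     assert not any('lora_text_projection' in n for n in trainable)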
--------------------------------------------------------------------------------