├── continual_clip
│   ├── __init__.py
│   ├── utils.py
│   ├── datasets.py
│   └── models.py
├── loraclip
│   ├── loralib
│   │   ├── __init__.py
│   │   └── utils.py
│   ├── __init__.py
│   ├── lora_clip.py
│   └── model.py
├── class_orders
│   ├── cifar100.yaml
│   ├── imagenet100.yaml
│   └── imagenet_R.yaml
├── run_experiment_img100.sh
├── run_experiment_cifar.sh
├── run_experiment.sh
├── requirements.txt
├── setup_environment.sh
├── dataset_reqs
│   ├── oxford_pet_classes.txt
│   ├── vtab_classes.txt
│   ├── caltech101_classes.txt
│   ├── food101_classes.txt
│   ├── imagenet100_classes.txt
│   ├── cub200_classes.txt
│   ├── imagenet_R_classes.txt
│   ├── tinyimagenet_classes.txt
│   ├── imagenet_c_classes.txt
│   └── imagenet1000_classes.txt
├── configs
│   └── class
│       ├── cifar100_10-10.yaml
│       ├── imagenet100_10-10.yaml
│       └── imagenet_r_20-20.yaml
├── .gitignore
├── README.md
├── epoch.py
└── main.py

/continual_clip/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/loraclip/loralib/__init__.py:
--------------------------------------------------------------------------------
1 | name = "lora"
2 | 
3 | from .layers import *
4 | from .utils import *
--------------------------------------------------------------------------------
/class_orders/cifar100.yaml:
--------------------------------------------------------------------------------
1 | class_order: [87, 0, 52, 58, 44, 91, 68, 97, 51, 15, 94, 92, 10, 72, 49, 78, 61, 14, 8, 86, 84, 96, 18, 24, 32, 45, 88, 11, 4, 67, 69, 66, 77, 47, 79, 93, 29, 50, 57, 83, 17, 81, 41, 12, 37, 59, 25, 20, 80, 73, 1, 28, 6, 46, 62, 82, 53, 9, 31, 75, 38, 63, 33, 74, 27, 22, 36, 3, 16, 21, 60, 19, 70, 90, 89, 43, 5, 42, 65, 76, 40, 30, 23, 85, 2, 95, 56, 48, 71, 64, 98, 13, 99, 7, 34, 55, 54, 26, 35, 39]
2 | 
--------------------------------------------------------------------------------
/run_experiment_img100.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | 
4 | python epoch.py \
5 |     --config-path configs/class \
6 |     --config-name imagenet100_10-10.yaml \
7 |     dataset_root="path" \
8 |     class_order="class_orders/imagenet100.yaml"
9 | 
10 | python main.py \
11 |     --config-path configs/class \
12 |     --config-name imagenet100_10-10.yaml \
13 |     dataset_root="path" \
14 |     class_order="class_orders/imagenet100.yaml"
--------------------------------------------------------------------------------
/run_experiment_cifar.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | 
4 | 
5 | python epoch.py \
6 |     --config-path configs/class \
7 |     --config-name cifar100_10-10.yaml \
8 |     dataset_root="path" \
9 |     class_order="class_orders/cifar100.yaml"
10 | 
11 | python main.py \
12 |     --config-path configs/class \
13 |     --config-name cifar100_10-10.yaml \
14 |     dataset_root="path" \
15 |     class_order="class_orders/cifar100.yaml"
16 | 
17 | 
--------------------------------------------------------------------------------
/run_experiment.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | 
4 | 
5 | python epoch.py \
6 |     --config-path configs/class \
7 |     --config-name imagenet_r_20-20.yaml \
8 |     dataset_root="path" \
9 |     class_order="class_orders/imagenet_R.yaml"
10 | 
11 | python main.py \
12 |     --config-path configs/class \
13 |     --config-name imagenet_r_20-20.yaml \
14 |     dataset_root="path" \
15 |     class_order="class_orders/imagenet_R.yaml"
16 | 
17 | 
18 | 
19 | 
20 | 
21 | 
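The three run scripts above drive Hydra entry points (requirements.txt pins hydra-core 1.2.0): `--config-path`/`--config-name` select the YAML config, and the trailing `key="value"` arguments override fields declared inside it. Below is a minimal sketch of such an entry point, assuming main.py follows the standard `@hydra.main` pattern; the actual epoch.py/main.py are not shown in this dump and may differ.

```python
# Minimal sketch of the Hydra entry point the scripts above drive (an
# assumption; the repo's real main.py is not included in this listing).
import hydra
from omegaconf import DictConfig


@hydra.main(version_base=None, config_path="configs/class", config_name="imagenet100_10-10")
def main(cfg: DictConfig) -> None:
    # Command-line overrides such as dataset_root="path" replace the empty
    # defaults declared in the YAML before cfg reaches this function.
    print(cfg.dataset, cfg.dataset_root, cfg.class_order)


if __name__ == "__main__":
    main()
```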
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | continuum
2 | hydra-core==1.2.0
3 | numpy
4 | oauthlib==3.2.1
5 | omegaconf==2.2.3
6 | open-clip-torch==1.3.0
7 | pandas==1.4.3
8 | Pillow==9.2.0
9 | pipreqs==0.4.11
10 | scikit-image==0.19.3
11 | scikit-learn==1.1.1
12 | scipy==1.8.1
13 | tensorboard==2.10.0
14 | timm @ git+https://github.com/Arnav0400/pytorch-image-models.git@ceea7127c1ef608179ba06eaeddc22ad3ef22de0
15 | tokenizers==0.12.1
16 | tqdm==4.64.0
17 | transformers==4.21.1
18 | ftfy
19 | regex
20 | 
21 | 
--------------------------------------------------------------------------------
/class_orders/imagenet100.yaml:
--------------------------------------------------------------------------------
1 | class_order: [
2 |     68, 56, 78, 8, 23, 84, 90, 65, 74, 76, 40,
3 |     89, 3, 92, 55, 9, 26, 80, 43, 38, 58, 70,
4 |     77, 1, 85, 19, 17, 50, 28, 53, 13, 81, 45,
5 |     82, 6, 59, 83, 16, 15, 44, 91, 41, 72, 60,
6 |     79, 52, 20, 10, 31, 54, 37, 95, 14, 71, 96,
7 |     98, 97, 2, 64, 66, 42, 22, 35, 86, 24, 34,
8 |     87, 21, 99, 0, 88, 27, 18, 94, 11, 12, 47,
9 |     25, 30, 46, 62, 69, 36, 61, 7, 63, 75, 5, 32,
10 |     4, 51, 48, 73, 93, 39, 67, 29, 49, 57, 33
11 | ]
12 | 
--------------------------------------------------------------------------------
/setup_environment.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | # create environment using Miniconda (or Anaconda)
4 | conda create -n continual_clip python=3.8
5 | conda activate continual_clip
6 | 
7 | # install pytorch
8 | pip install torch==1.8.2 torchvision==0.9.2 torchaudio==0.8.2 \
9 |     --extra-index-url https://download.pytorch.org/whl/lts/1.8/cu111
10 | 
11 | # install other dependencies
12 | pip install -r requirements.txt
13 | 
14 | # install CLIP
15 | pip install git+https://github.com/openai/CLIP.git
16 | 
17 | 
--------------------------------------------------------------------------------
/dataset_reqs/oxford_pet_classes.txt:
--------------------------------------------------------------------------------
1 | 1 abyssinian
2 | 2 american_bulldog
3 | 3 american_pit_bull_terrier
4 | 4 basset_hound
5 | 5 beagle
6 | 6 bengal
7 | 7 birman
8 | 8 bombay
9 | 9 boxer
10 | 10 british_shorthair
11 | 11 chihuahua
12 | 12 egyptian_mau
13 | 13 english_cocker_spaniel
14 | 14 english_setter
15 | 15 german_shorthaired
16 | 16 great_pyrenees
17 | 17 havanese
18 | 18 japanese_chin
19 | 19 keeshond
20 | 20 leonberger
21 | 21 maine_coon
22 | 22 miniature_pinscher
23 | 23 newfoundland
24 | 24 persian
25 | 25 pomeranian
26 | 26 pug
27 | 27 ragdoll
28 | 28 russian_blue
29 | 29 saint_bernard
30 | 30 samoyed
31 | 31 scottish_terrier
32 | 32 shiba_inu
33 | 33 siamese
34 | 34 sphynx
35 | 35 staffordshire_bull_terrier
36 | 36 wheaten_terrier
37 | 37 yorkshire_terrier
38 | 
--------------------------------------------------------------------------------
/dataset_reqs/vtab_classes.txt:
--------------------------------------------------------------------------------
1 | airplane
2 | airport
3 | baseball diamond
4 | basketball court
5 | beach
6 | bridge
7 | chaparral
8 | church
9 | circular farmland
10 | cloud
11 | banded
12 | blotchy
13 | braided
14 | bubbly
15 | bumpy
16 | chequered
17 | cobwebbed
18 | cracked
19 | crosshatched
20 | crystalline
21 | Abyssinian Cat
22 | American Bulldog
23 | American Pit Bull Terrier
24 | Basset Hound
25 | Beagle
26 | Bengal Cat
27 | Birman
28 | Bombay Cat
29
| Boxer Dog 30 | British Shorthair Cat 31 | AnnualCrop 32 | Forest 33 | HerbaceousVegetation 34 | Highway 35 | Industrial 36 | Pasture 37 | PermanentCrop 38 | Residential 39 | River 40 | SeaLake 41 | petunia 42 | wild pansy 43 | primula 44 | sunflower 45 | pelargonium 46 | bishop of llandaff 47 | gaura 48 | geranium 49 | orange dahlia 50 | pink-yellow dahlia -------------------------------------------------------------------------------- /class_orders/imagenet_R.yaml: -------------------------------------------------------------------------------- 1 | class_order: [150, 119, 193, 2, 75, 179, 129, 68, 145, 53, 171, 133, 31, 103, 146, 174, 149, 40, 142, 30, 183, 156, 167, 166, 124, 56, 109, 198, 185, 170, 51, 55, 77, 32, 45, 23, 115, 39, 134, 177, 12, 126, 130, 37, 19, 34, 0, 135, 161, 108, 128, 18, 141, 24, 154, 160, 78, 85, 89, 73, 152, 112, 114, 104, 44, 93, 138, 107, 100, 163, 3, 148, 147, 35, 199, 50, 25, 96, 79, 172, 98, 140, 22, 162, 6, 27, 60, 144, 197, 67, 165, 196, 102, 71, 65, 47, 9, 122, 42, 63, 61, 180, 97, 84, 157, 168, 69, 41, 83, 92, 190, 191, 59, 72, 111, 155, 105, 57, 17, 116, 13, 28, 43, 136, 5, 14, 169, 7, 62, 110, 189, 95, 8, 4, 137, 194, 90, 74, 70, 113, 186, 132, 192, 33, 36, 158, 188, 94, 121, 164, 58, 11, 178, 151, 184, 1, 10, 64, 15, 106, 52, 176, 26, 81, 82, 88, 125, 46, 123, 91, 187, 86, 20, 76, 80, 66, 120, 131, 16, 159, 101, 87, 99, 195, 143, 21, 49, 38, 181, 29, 153, 54, 173, 127, 117, 182, 139, 118, 48, 175] 2 | 3 | 4 | 5 | -------------------------------------------------------------------------------- /configs/class/cifar100_10-10.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | run: 3 | dir: ./experiments/${scenario}/${dataset}_${initial_increment}-${increment} 4 | job: 5 | chdir: true 6 | 7 | job_logging: 8 | version: 1 9 | formatters: 10 | simple: 11 | format: '%(message)s' 12 | 13 | class_order: "" 14 | dataset_root: "" 15 | workdir: "" 16 | log_path: "metric.json" 17 | model_name: "ViT-B/16" 18 | prompt_template: "a good photo of a {}." 19 | #128 20 | batch_size: 128 21 | increment: 10 22 | initial_increment: 10 23 | scenario: "class" 24 | dataset: "cifar100" 25 | task_num: 10 26 | seed: 42 27 | 28 | epochs: 1 29 | train_batch_size: 128 30 | num_workers: 8 31 | lora_rank: 8 32 | lora_mode: "vision+only_kv+text" 33 | lr: 0.001 34 | reset: False 35 | only_reset_B: False 36 | freeze_A: False 37 | all_test: False 38 | 39 | 40 | visual_clsf: true 41 | visual_clsf_lr: 0.0005 42 | visual_clsf_batch_size: 64 43 | visual_clsf_epochs: 3 44 | 45 | 46 | 47 | real_replay: false 48 | balance_ft: False 49 | balance_epochs: 0 50 | -------------------------------------------------------------------------------- /loraclip/__init__.py: -------------------------------------------------------------------------------- 1 | from .lora_clip import * 2 | 3 | 4 | def print_trainable_parameters(model): 5 | """ 6 | Prints the number of trainable parameters in the model. 
7 | """ 8 | trainable_params = 0 9 | all_param = 0 10 | trainable_params_names = [] 11 | 12 | for name, param in model.named_parameters(): 13 | all_param += param.numel() 14 | if param.requires_grad: 15 | trainable_params += param.numel() 16 | trainable_params_names.append(name) 17 | # if "lora_B" in name: 18 | # print(param.sum()) 19 | 20 | print(f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}") 21 | 22 | print("Trainable Parameters:") 23 | for name in trainable_params_names: 24 | print(name) 25 | # check = True 26 | # for name in trainable_params_names: 27 | # if "lora" not in name: 28 | # check = False 29 | 30 | # if check: 31 | # print("Are LoRA parameters correctly present? Yes") 32 | # else: 33 | # print("Are LORA parameters correctly present? No") -------------------------------------------------------------------------------- /configs/class/imagenet100_10-10.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | run: 3 | dir: ./experiments/${scenario}/${dataset}_${initial_increment}-${increment} 4 | job: 5 | chdir: true 6 | 7 | job_logging: 8 | version: 1 9 | formatters: 10 | simple: 11 | format: '%(message)s' 12 | 13 | class_order: "" 14 | dataset_root: "" 15 | workdir: "" 16 | log_path: "metrics.json" 17 | model_name: "ViT-B/16" 18 | prompt_template: "a good photo of a {}." 19 | 20 | batch_size: 128 21 | initial_increment: 10 22 | increment: ${initial_increment} 23 | scenario: "class" 24 | dataset: "imagenet100" 25 | task_num: 10 26 | 27 | epochs: 2 28 | train_batch_size: 128 29 | num_workers: 8 30 | lora_rank: 8 31 | lora_mode: "vision+only_kv+text" 32 | lr: 0.0005 33 | reset: False 34 | only_reset_B: False 35 | freeze_A: False 36 | all_test: False 37 | weight_decay: 1e-4 38 | momentum: 0.9 39 | seed: 0 40 | 41 | 42 | visual_clsf: true 43 | visual_clsf_lr: 0.001 44 | visual_clsf_batch_size: 128 45 | visual_clsf_epochs: 3 46 | 47 | real_replay: false 48 | 49 | balance_ft: false 50 | balance_epochs: 1 51 | -------------------------------------------------------------------------------- /configs/class/imagenet_r_20-20.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | run: 3 | dir: ./experiments/${scenario}/${dataset}_${initial_increment}-${increment} 4 | job: 5 | chdir: true 6 | 7 | job_logging: 8 | version: 1 9 | formatters: 10 | simple: 11 | format: '%(message)s' 12 | 13 | class_order: "" 14 | dataset_root: "" 15 | workdir: "" 16 | log_path: "metric.json" 17 | model_name: "ViT-B/16" 18 | prompt_template: "a good photo of a {}." 
19 | 
20 | batch_size: 32
21 | initial_increment: 20
22 | increment: ${initial_increment}
23 | scenario: "class"
24 | dataset: "imagenet_R"
25 | task_num: 10
26 | 
27 | epochs: 2
28 | # 64
29 | train_batch_size: 64
30 | num_workers: 8
31 | lora_rank: 8
32 | lora_mode: "vision+only_kv+text"
33 | lr: 0.001
34 | reset: False
35 | only_reset_B: False
36 | freeze_A: False
37 | all_test: False
38 | weight_decay: 1e-4
39 | momentum: 0.9
40 | seed: 0
41 | 
42 | 
43 | visual_clsf: true
44 | visual_clsf_lr: 0.0005
45 | visual_clsf_batch_size: 32
46 | visual_clsf_epochs: 3
47 | 
48 | real_replay: false
49 | 
50 | 
51 | balance_ft: False
52 | balance_epochs: 1
53 | 
54 | 
--------------------------------------------------------------------------------
/continual_clip/utils.py:
--------------------------------------------------------------------------------
1 | 
2 | import os
3 | import json
4 | import yaml
5 | 
6 | from omegaconf import DictConfig, OmegaConf
7 | import pdb
8 | 
9 | 
10 | 
11 | def get_class_order(file_name: str) -> list:
12 |     r"""Load the class order from a YAML file and return it as a list."""
13 |     with open(file_name, "r") as f:
14 |         data = yaml.safe_load(f)
15 |     return data["class_order"]
16 | 
17 | 
18 | def get_class_ids_per_task(args):
19 |     yield args.class_order[:args.initial_increment]
20 |     for i in range(args.initial_increment, len(args.class_order), args.increment):
21 |         yield args.class_order[i:i + args.increment]
22 | 
23 | def get_class_names(classes_names, class_ids_per_task):
24 |     return [classes_names[class_id] for class_id in class_ids_per_task]
25 | 
26 | 
27 | def get_dataset_class_names(workdir, dataset_name, long=False):
28 |     with open(os.path.join(workdir, "dataset_reqs", f"{dataset_name}_classes.txt"), "r") as f:
29 |         lines = f.read().splitlines()
30 |     return [line.split("\t")[-1] for line in lines]
31 | 
32 | 
33 | def save_config(config: DictConfig) -> None:
34 |     OmegaConf.save(config, "config.yaml")
35 | 
36 | 
37 | def get_workdir(path):
38 |     split_path = path.split("/")
39 |     workdir_idx = split_path.index("MindtheGap")
40 |     return "/".join(split_path[:workdir_idx+1])
41 | 
42 | 
43 | 
--------------------------------------------------------------------------------
/dataset_reqs/caltech101_classes.txt:
--------------------------------------------------------------------------------
1 | 1 accordion
2 | 2 airplanes
3 | 3 anchor
4 | 4 ant
5 | 5 barrel
6 | 6 bass
7 | 7 beaver
8 | 8 binocular
9 | 9 bonsai
10 | 10 brain
11 | 11 brontosaurus
12 | 12 buddha
13 | 13 butterfly
14 | 14 camera
15 | 15 cannon
16 | 16 car_side
17 | 17 ceiling_fan
18 | 18 cellphone
19 | 19 chair
20 | 20 chandelier
21 | 21 cougar_body
22 | 22 cougar_face
23 | 23 crab
24 | 24 crayfish
25 | 25 crocodile
26 | 26 crocodile_head
27 | 27 cup
28 | 28 dalmatian
29 | 29 dollar_bill
30 | 30 dolphin
31 | 31 dragonfly
32 | 32 electric_guitar
33 | 33 elephant
34 | 34 emu
35 | 35 euphonium
36 | 36 ewer
37 | 37 faces
38 | 38 faces_easy
39 | 39 ferry
40 | 40 flamingo
41 | 41 flamingo_head
42 | 42 garfield
43 | 43 gerenuk
44 | 44 gramophone
45 | 45 grand_piano
46 | 46 hawksbill
47 | 47 headphone
48 | 48 hedgehog
49 | 49 helicopter
50 | 50 ibis
51 | 51 inline_skate
52 | 52 joshua_tree
53 | 53 kangaroo
54 | 54 ketch
55 | 55 lamp
56 | 56 laptop
57 | 57 leopards
58 | 58 llama
59 | 59 lobster
60 | 60 lotus
61 | 61 mandolin
62 | 62 mayfly
63 | 63 menorah
64 | 64 metronome
65 | 65 minaret
66 | 66 motorbikes
67 | 67 nautilus
68 | 68 octopus
69 | 69 okapi
70 | 70 pagoda
71 | 71 panda
72 | 72 pigeon
73 | 73 pizza
74 | 74 platypus
75 | 75 pyramid
76 | 76 revolver
77 | 77 rhino
78 | 78
rooster 79 | 79 saxophone 80 | 80 schooner 81 | 81 scissors 82 | 82 scorpion 83 | 83 sea_horse 84 | 84 snoopy 85 | 85 soccer_ball 86 | 86 stapler 87 | 87 starfish 88 | 88 stegosaurus 89 | 89 stop_sign 90 | 90 strawberry 91 | 91 sunflower 92 | 92 tick 93 | 93 trilobite 94 | 94 umbrella 95 | 95 watch 96 | 96 water_lilly 97 | 97 wheelchair 98 | 98 wild_cat 99 | 99 windsor_chair 100 | 100 wrench 101 | 101 yin_yang 102 | -------------------------------------------------------------------------------- /dataset_reqs/food101_classes.txt: -------------------------------------------------------------------------------- 1 | apple_pie 2 | baby_back_ribs 3 | baklava 4 | beef_carpaccio 5 | beef_tartare 6 | beet_salad 7 | beignets 8 | bibimbap 9 | bread_pudding 10 | breakfast_burrito 11 | bruschetta 12 | caesar_salad 13 | cannoli 14 | caprese_salad 15 | carrot_cake 16 | ceviche 17 | cheesecake 18 | cheese_plate 19 | chicken_curry 20 | chicken_quesadilla 21 | chicken_wings 22 | chocolate_cake 23 | chocolate_mousse 24 | churros 25 | clam_chowder 26 | club_sandwich 27 | crab_cakes 28 | creme_brulee 29 | croque_madame 30 | cup_cakes 31 | deviled_eggs 32 | donuts 33 | dumplings 34 | edamame 35 | eggs_benedict 36 | escargots 37 | falafel 38 | filet_mignon 39 | fish_and_chips 40 | foie_gras 41 | french_fries 42 | french_onion_soup 43 | french_toast 44 | fried_calamari 45 | fried_rice 46 | frozen_yogurt 47 | garlic_bread 48 | gnocchi 49 | greek_salad 50 | grilled_cheese_sandwich 51 | grilled_salmon 52 | guacamole 53 | gyoza 54 | hamburger 55 | hot_and_sour_soup 56 | hot_dog 57 | huevos_rancheros 58 | hummus 59 | ice_cream 60 | lasagna 61 | lobster_bisque 62 | lobster_roll_sandwich 63 | macaroni_and_cheese 64 | macarons 65 | miso_soup 66 | mussels 67 | nachos 68 | omelette 69 | onion_rings 70 | oysters 71 | pad_thai 72 | paella 73 | pancakes 74 | panna_cotta 75 | peking_duck 76 | pho 77 | pizza 78 | pork_chop 79 | poutine 80 | prime_rib 81 | pulled_pork_sandwich 82 | ramen 83 | ravioli 84 | red_velvet_cake 85 | risotto 86 | samosa 87 | sashimi 88 | scallops 89 | seaweed_salad 90 | shrimp_and_grits 91 | spaghetti_bolognese 92 | spaghetti_carbonara 93 | spring_rolls 94 | steak 95 | strawberry_shortcake 96 | sushi 97 | tacos 98 | takoyaki 99 | tiramisu 100 | tuna_tartare 101 | waffles 102 | -------------------------------------------------------------------------------- /loraclip/loralib/utils.py: -------------------------------------------------------------------------------- 1 | # Source: https://github.com/SivanDoveh/TSVLC/blob/main/src/open_clip/loralib/layers.py 2 | # ------------------------------------------------------------------------------------------ 3 | # Copyright (c) Microsoft Corporation. All rights reserved. 4 | # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. 
5 | # ------------------------------------------------------------------------------------------ 6 | # 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Dict 11 | 12 | from .layers import LoRALayer 13 | 14 | 15 | def mark_only_lora_as_trainable(model: nn.Module, bias: str = 'none') -> None: 16 | for n, p in model.named_parameters(): 17 | if 'lora_' not in n: 18 | p.requires_grad = False 19 | # print(f'Freezing {n}') 20 | # else: 21 | # print(f'Optimizing {n}') 22 | if bias == 'none': 23 | return 24 | elif bias == 'all': 25 | for n, p in model.named_parameters(): 26 | if 'bias' in n: 27 | p.requires_grad = True 28 | elif bias == 'lora_only': 29 | for m in model.modules(): 30 | if isinstance(m, LoRALayer) and \ 31 | hasattr(m, 'bias') and \ 32 | m.bias is not None: 33 | m.bias.requires_grad = True 34 | else: 35 | raise NotImplementedError 36 | 37 | 38 | def lora_state_dict(model: nn.Module, bias: str = 'none') -> Dict[str, torch.Tensor]: 39 | my_state_dict = model.state_dict() 40 | if bias == 'none': 41 | return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k} 42 | elif bias == 'all': 43 | return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k or 'bias' in k} 44 | elif bias == 'lora_only': 45 | to_return = {} 46 | for k in my_state_dict: 47 | if 'lora_' in k: 48 | to_return[k] = my_state_dict[k] 49 | bias_name = k.split('lora_')[0]+'bias' 50 | if bias_name in my_state_dict: 51 | to_return[bias_name] = my_state_dict[bias_name] 52 | return to_return 53 | else: 54 | raise NotImplementedError 55 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | # lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 
92 | #Pipfile.lock
93 | 
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 | 
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 | 
101 | # SageMath parsed files
102 | *.sage.py
103 | 
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 | 
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 | 
117 | # Rope project settings
118 | .ropeproject
119 | 
120 | # mkdocs documentation
121 | /site
122 | 
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 | 
128 | # Pyre type checker
129 | .pyre/
130 | 
131 | # custom
132 | experiments/*
133 | run_cifar100.sh
134 | run_imagenet100.sh
135 | run_imagenet1000.sh
136 | run_tinyimagenet.sh
137 | vit_nms*
138 | continual_clip_nms*
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [ICCV 2025 Highlight] Mind the Gap: Preserving and Compensating for the Modality Gap in CLIP-Based Continual Learning
2 | This is the official code for our paper.
3 | 
4 | 
5 | ## Getting Started
6 | 
7 | ## Environment
8 | The environment is the same as that of our [RAPF](https://github.com/linlany/RAPF).
9 | 
10 | Create the environment using Miniconda (or Anaconda):
11 | ```
12 | conda create -n continual_clip python=3.8
13 | conda activate continual_clip
14 | ```
15 | Install the dependencies:
16 | ```bash
17 | bash setup_environment.sh
18 | ```
19 | ### Running scripts
20 | 
21 | We provide the scripts for imagenet100. Please run:
22 | 
23 | ```
24 | python main.py \
25 |     --config-path configs/class \
26 |     --config-name imagenet100_10-10.yaml \
27 |     dataset_root="[imagenet1k_path]" \
28 |     class_order="class_orders/imagenet100.yaml"
29 | ```
30 | **Note:** To obtain the epoch parameter for the first task, described in Eq. (3), please run epoch.py first.
31 | 
32 | 
33 | The dataset_root folder should contain the train and val folders.
34 | ```
35 | imagenet1k_path
36 | ├── train
37 | │   ├── n01440764
38 | │   └── ···
39 | ├── val
40 | │   ├── n01440764
41 | │   └── ···
42 | 
43 | imagenet-r_path
44 | ├── train
45 | │   ├── n01443537
46 | │   └── ···
47 | ├── val
48 | │   ├── n01443537
49 | │   └── ···
50 | 
51 | ```
52 | 
53 | The commands for the other two datasets are similar; see run_experiment.sh and run_experiment_cifar.sh.
54 | 
55 | ### Datasets
56 | CIFAR100 will be downloaded automatically.
57 | ImageNet-R is split randomly. You can also use our split list in RAPF/imgr_split/imgr_train_test_split.txt.
58 | 
59 | The format of imgr_train_test_split.txt (a short parsing sketch follows the Acknowledgement section below):
60 | ```
61 | train
62 | n02051845/art_0.jpg
63 | ...
64 | test
65 | n02051845/tattoo_4.jpg
66 | ...
67 | ```
68 | 
69 | ## Acknowledgement
70 | Our method implementation is based on the [Continual-CLIP](https://github.com/vgthengane/Continual-CLIP).
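As referenced in the Datasets section above, here is a minimal sketch for reading imgr_train_test_split.txt into train/test path lists; `read_imgr_split` is a hypothetical helper written for illustration, not part of this repository.

```python
# Hypothetical helper (not part of this repo): parse the split file whose
# format is shown in the Datasets section into train/test path lists.
def read_imgr_split(path: str):
    splits = {"train": [], "test": []}
    current = None
    with open(path) as f:
        for line in f:
            line = line.strip()
            if line in splits:                    # section header: "train" or "test"
                current = line
            elif line and current is not None:
                splits[current].append(line)      # e.g. "n02051845/art_0.jpg"
    return splits["train"], splits["test"]
```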
71 | 72 | ## Citation 73 | 74 | If you find our repo useful for your research, please consider citing our paper: 75 | 76 | ```bibtex 77 | @inproceedings{huang2025mind, 78 | title={Mind the gap: Preserving and compensating for the modality gap in clip-based continual learning}, 79 | author={Huang, Linlan and Cao, Xusheng and Lu, Haori and Meng, Yifan and Yang, Fei and Liu, Xialei}, 80 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision}, 81 | pages={3777--3786}, 82 | year={2025} 83 | } 84 | ``` 85 | 86 | ## License 87 | This code is licensed under the [Creative Commons Attribution-NonCommercial 4.0 International](https://creativecommons.org/licenses/by-nc/4.0/) for non-commercial use only. 88 | Please note that any commercial use of this code requires formal permission prior to use. 89 | 90 | ## Contact 91 | 92 | For technical questions, please contact huanglinlan@mail.nankai.edu.cn 93 | -------------------------------------------------------------------------------- /dataset_reqs/imagenet100_classes.txt: -------------------------------------------------------------------------------- 1 | 0 n01440764 tench 2 | 1 n01443537 goldfish 3 | 2 n01484850 great white shark 4 | 3 n01491361 tiger shark 5 | 4 n01494475 hammerhead shark 6 | 5 n01496331 electric ray 7 | 6 n01498041 stingray 8 | 7 n01514668 rooster 9 | 8 n01514859 hen 10 | 9 n01518878 ostrich 11 | 10 n01530575 brambling 12 | 11 n01531178 goldfinch 13 | 12 n01532829 house finch 14 | 13 n01534433 junco 15 | 14 n01537544 indigo bunting 16 | 15 n01558993 American robin 17 | 16 n01560419 bulbul 18 | 17 n01580077 jay 19 | 18 n01582220 magpie 20 | 19 n01592084 chickadee 21 | 20 n01601694 American dipper 22 | 21 n01608432 kite (bird of prey) 23 | 22 n01614925 bald eagle 24 | 23 n01616318 vulture 25 | 24 n01622779 great grey owl 26 | 25 n01629819 fire salamander 27 | 26 n01630670 smooth newt 28 | 27 n01631663 newt 29 | 28 n01632458 spotted salamander 30 | 29 n01632777 axolotl 31 | 30 n01641577 American bullfrog 32 | 31 n01644373 tree frog 33 | 32 n01644900 tailed frog 34 | 33 n01664065 loggerhead sea turtle 35 | 34 n01665541 leatherback sea turtle 36 | 35 n01667114 mud turtle 37 | 36 n01667778 terrapin 38 | 37 n01669191 box turtle 39 | 38 n01675722 banded gecko 40 | 39 n01677366 green iguana 41 | 40 n01682714 Carolina anole 42 | 41 n01685808 desert grassland whiptail lizard 43 | 42 n01687978 agama 44 | 43 n01688243 frilled-necked lizard 45 | 44 n01689811 alligator lizard 46 | 45 n01692333 Gila monster 47 | 46 n01693334 European green lizard 48 | 47 n01694178 chameleon 49 | 48 n01695060 Komodo dragon 50 | 49 n01697457 Nile crocodile 51 | 50 n01698640 American alligator 52 | 51 n01704323 triceratops 53 | 52 n01728572 worm snake 54 | 53 n01728920 ring-necked snake 55 | 54 n01729322 eastern hog-nosed snake 56 | 55 n01729977 smooth green snake 57 | 56 n01734418 kingsnake 58 | 57 n01735189 garter snake 59 | 58 n01737021 water snake 60 | 59 n01739381 vine snake 61 | 60 n01740131 night snake 62 | 61 n01742172 boa constrictor 63 | 62 n01744401 African rock python 64 | 63 n01748264 Indian cobra 65 | 64 n01749939 green mamba 66 | 65 n01751748 sea snake 67 | 66 n01753488 Saharan horned viper 68 | 67 n01755581 eastern diamondback rattlesnake 69 | 68 n01756291 sidewinder rattlesnake 70 | 69 n01768244 trilobite 71 | 70 n01770081 harvestman 72 | 71 n01770393 scorpion 73 | 72 n01773157 yellow garden spider 74 | 73 n01773549 barn spider 75 | 74 n01773797 European garden spider 76 | 75 n01774384 southern black widow 77 | 76 n01774750 
tarantula 78 | 77 n01775062 wolf spider 79 | 78 n01776313 tick 80 | 79 n01784675 centipede 81 | 80 n01795545 black grouse 82 | 81 n01796340 ptarmigan 83 | 82 n01797886 ruffed grouse 84 | 83 n01798484 prairie grouse 85 | 84 n01806143 peafowl 86 | 85 n01806567 quail 87 | 86 n01807496 partridge 88 | 87 n01817953 african grey parrot 89 | 88 n01818515 macaw 90 | 89 n01819313 sulphur-crested cockatoo 91 | 90 n01820546 lorikeet 92 | 91 n01824575 coucal 93 | 92 n01828970 bee eater 94 | 93 n01829413 hornbill 95 | 94 n01833805 hummingbird 96 | 95 n01843065 jacamar 97 | 96 n01843383 toucan 98 | 97 n01847000 duck 99 | 98 n01855032 red-breasted merganser 100 | 99 n01855672 goose 101 | -------------------------------------------------------------------------------- /dataset_reqs/cub200_classes.txt: -------------------------------------------------------------------------------- 1 | 1 Black_footed_Albatross 2 | 2 Laysan_Albatross 3 | 3 Sooty_Albatross 4 | 4 Groove_billed_Ani 5 | 5 Crested_Auklet 6 | 6 Least_Auklet 7 | 7 Parakeet_Auklet 8 | 8 Rhinoceros_Auklet 9 | 9 Brewer_Blackbird 10 | 10 Red_winged_Blackbird 11 | 11 Rusty_Blackbird 12 | 12 Yellow_headed_Blackbird 13 | 13 Bobolink 14 | 14 Indigo_Bunting 15 | 15 Lazuli_Bunting 16 | 16 Painted_Bunting 17 | 17 Cardinal 18 | 18 Spotted_Catbird 19 | 19 Gray_Catbird 20 | 20 Yellow_breasted_Chat 21 | 21 Eastern_Towhee 22 | 22 Chuck_will_Widow 23 | 23 Brandt_Cormorant 24 | 24 Red_faced_Cormorant 25 | 25 Pelagic_Cormorant 26 | 26 Bronzed_Cowbird 27 | 27 Shiny_Cowbird 28 | 28 Brown_Creeper 29 | 29 American_Crow 30 | 30 Fish_Crow 31 | 31 Black_billed_Cuckoo 32 | 32 Mangrove_Cuckoo 33 | 33 Yellow_billed_Cuckoo 34 | 34 Gray_crowned_Rosy_Finch 35 | 35 Purple_Finch 36 | 36 Northern_Flicker 37 | 37 Acadian_Flycatcher 38 | 38 Great_Crested_Flycatcher 39 | 39 Least_Flycatcher 40 | 40 Olive_sided_Flycatcher 41 | 41 Scissor_tailed_Flycatcher 42 | 42 Vermilion_Flycatcher 43 | 43 Yellow_bellied_Flycatcher 44 | 44 Frigatebird 45 | 45 Northern_Fulmar 46 | 46 Gadwall 47 | 47 American_Goldfinch 48 | 48 European_Goldfinch 49 | 49 Boat_tailed_Grackle 50 | 50 Eared_Grebe 51 | 51 Horned_Grebe 52 | 52 Pied_billed_Grebe 53 | 53 Western_Grebe 54 | 54 Blue_Grosbeak 55 | 55 Evening_Grosbeak 56 | 56 Pine_Grosbeak 57 | 57 Rose_breasted_Grosbeak 58 | 58 Pigeon_Guillemot 59 | 59 California_Gull 60 | 60 Glaucous_winged_Gull 61 | 61 Heermann_Gull 62 | 62 Herring_Gull 63 | 63 Ivory_Gull 64 | 64 Ring_billed_Gull 65 | 65 Slaty_backed_Gull 66 | 66 Western_Gull 67 | 67 Anna_Hummingbird 68 | 68 Ruby_throated_Hummingbird 69 | 69 Rufous_Hummingbird 70 | 70 Green_Violetear 71 | 71 Long_tailed_Jaeger 72 | 72 Pomarine_Jaeger 73 | 73 Blue_Jay 74 | 74 Florida_Jay 75 | 75 Green_Jay 76 | 76 Dark_eyed_Junco 77 | 77 Tropical_Kingbird 78 | 78 Gray_Kingbird 79 | 79 Belted_Kingfisher 80 | 80 Green_Kingfisher 81 | 81 Pied_Kingfisher 82 | 82 Ringed_Kingfisher 83 | 83 White_breasted_Kingfisher 84 | 84 Red_legged_Kittiwake 85 | 85 Horned_Lark 86 | 86 Pacific_Loon 87 | 87 Mallard 88 | 88 Western_Meadowlark 89 | 89 Hooded_Merganser 90 | 90 Red_breasted_Merganser 91 | 91 Mockingbird 92 | 92 Nighthawk 93 | 93 Clark_Nutcracker 94 | 94 White_breasted_Nuthatch 95 | 95 Baltimore_Oriole 96 | 96 Hooded_Oriole 97 | 97 Orchard_Oriole 98 | 98 Scott_Oriole 99 | 99 Ovenbird 100 | 100 Brown_Pelican 101 | 101 White_Pelican 102 | 102 Western_Wood_Pewee 103 | 103 Sayornis 104 | 104 American_Pipit 105 | 105 Whip_poor_Will 106 | 106 Horned_Puffin 107 | 107 Common_Raven 108 | 108 White_necked_Raven 109 | 109 American_Redstart 110 | 110 
Geococcyx 111 | 111 Loggerhead_Shrike 112 | 112 Great_Grey_Shrike 113 | 113 Baird_Sparrow 114 | 114 Black_throated_Sparrow 115 | 115 Brewer_Sparrow 116 | 116 Chipping_Sparrow 117 | 117 Clay_colored_Sparrow 118 | 118 House_Sparrow 119 | 119 Field_Sparrow 120 | 120 Fox_Sparrow 121 | 121 Grasshopper_Sparrow 122 | 122 Harris_Sparrow 123 | 123 Henslow_Sparrow 124 | 124 Le_Conte_Sparrow 125 | 125 Lincoln_Sparrow 126 | 126 Nelson_Sharp_tailed_Sparrow 127 | 127 Savannah_Sparrow 128 | 128 Seaside_Sparrow 129 | 129 Song_Sparrow 130 | 130 Tree_Sparrow 131 | 131 Vesper_Sparrow 132 | 132 White_crowned_Sparrow 133 | 133 White_throated_Sparrow 134 | 134 Cape_Glossy_Starling 135 | 135 Bank_Swallow 136 | 136 Barn_Swallow 137 | 137 Cliff_Swallow 138 | 138 Tree_Swallow 139 | 139 Scarlet_Tanager 140 | 140 Summer_Tanager 141 | 141 Artic_Tern 142 | 142 Black_Tern 143 | 143 Caspian_Tern 144 | 144 Common_Tern 145 | 145 Elegant_Tern 146 | 146 Forsters_Tern 147 | 147 Least_Tern 148 | 148 Green_tailed_Towhee 149 | 149 Brown_Thrasher 150 | 150 Sage_Thrasher 151 | 151 Black_capped_Vireo 152 | 152 Blue_headed_Vireo 153 | 153 Philadelphia_Vireo 154 | 154 Red_eyed_Vireo 155 | 155 Warbling_Vireo 156 | 156 White_eyed_Vireo 157 | 157 Yellow_throated_Vireo 158 | 158 Bay_breasted_Warbler 159 | 159 Black_and_white_Warbler 160 | 160 Black_throated_Blue_Warbler 161 | 161 Blue_winged_Warbler 162 | 162 Canada_Warbler 163 | 163 Cape_May_Warbler 164 | 164 Cerulean_Warbler 165 | 165 Chestnut_sided_Warbler 166 | 166 Golden_winged_Warbler 167 | 167 Hooded_Warbler 168 | 168 Kentucky_Warbler 169 | 169 Magnolia_Warbler 170 | 170 Mourning_Warbler 171 | 171 Myrtle_Warbler 172 | 172 Nashville_Warbler 173 | 173 Orange_crowned_Warbler 174 | 174 Palm_Warbler 175 | 175 Pine_Warbler 176 | 176 Prairie_Warbler 177 | 177 Prothonotary_Warbler 178 | 178 Swainson_Warbler 179 | 179 Tennessee_Warbler 180 | 180 Wilson_Warbler 181 | 181 Worm_eating_Warbler 182 | 182 Yellow_Warbler 183 | 183 Northern_Waterthrush 184 | 184 Louisiana_Waterthrush 185 | 185 Bohemian_Waxwing 186 | 186 Cedar_Waxwing 187 | 187 American_Three_toed_Woodpecker 188 | 188 Pileated_Woodpecker 189 | 189 Red_bellied_Woodpecker 190 | 190 Red_cockaded_Woodpecker 191 | 191 Red_headed_Woodpecker 192 | 192 Downy_Woodpecker 193 | 193 Bewick_Wren 194 | 194 Cactus_Wren 195 | 195 Carolina_Wren 196 | 196 House_Wren 197 | 197 Marsh_Wren 198 | 198 Rock_Wren 199 | 199 Winter_Wren 200 | 200 Common_Yellowthroat 201 | -------------------------------------------------------------------------------- /dataset_reqs/imagenet_R_classes.txt: -------------------------------------------------------------------------------- 1 | 0 n01443537 goldfish 2 | 1 n01484850 great_white_shark 3 | 2 n01494475 hammerhead 4 | 3 n01498041 stingray 5 | 4 n01514859 hen 6 | 5 n01518878 ostrich 7 | 6 n01531178 goldfinch 8 | 7 n01534433 junco 9 | 8 n01614925 bald_eagle 10 | 9 n01616318 vulture 11 | 10 n01630670 common_newt 12 | 11 n01632777 axolotl 13 | 12 n01644373 tree_frog 14 | 13 n01677366 common_iguana 15 | 14 n01694178 African_chameleon 16 | 15 n01748264 Indian_cobra 17 | 16 n01770393 scorpion 18 | 17 n01774750 tarantula 19 | 18 n01784675 centipede 20 | 19 n01806143 peacock 21 | 20 n01820546 lorikeet 22 | 21 n01833805 hummingbird 23 | 22 n01843383 toucan 24 | 23 n01847000 drake 25 | 24 n01855672 goose 26 | 25 n01860187 black_swan 27 | 26 n01882714 koala 28 | 27 n01910747 jellyfish 29 | 28 n01944390 snail 30 | 29 n01983481 American_lobster 31 | 30 n01986214 hermit_crab 32 | 31 n02007558 flamingo 33 | 32 n02009912 American_egret 
34 | 33 n02051845 pelican 35 | 34 n02056570 king_penguin 36 | 35 n02066245 grey_whale 37 | 36 n02071294 killer_whale 38 | 37 n02077923 sea_lion 39 | 38 n02085620 Chihuahua 40 | 39 n02086240 Shih-Tzu 41 | 40 n02088094 Afghan_hound 42 | 41 n02088238 basset 43 | 42 n02088364 beagle 44 | 43 n02088466 bloodhound 45 | 44 n02091032 Italian_greyhound 46 | 45 n02091134 whippet 47 | 46 n02092339 Weimaraner 48 | 47 n02094433 Yorkshire_terrier 49 | 48 n02096585 Boston_bull 50 | 49 n02097298 Scotch_terrier 51 | 50 n02098286 West_Highland_white_terrier 52 | 51 n02099601 golden_retriever 53 | 52 n02099712 Labrador_retriever 54 | 53 n02102318 cocker_spaniel 55 | 54 n02106030 collie 56 | 55 n02106166 Border_collie 57 | 56 n02106550 Rottweiler 58 | 57 n02106662 German_shepherd 59 | 58 n02108089 boxer 60 | 59 n02108915 French_bulldog 61 | 60 n02109525 Saint_Bernard 62 | 61 n02110185 Siberian_husky 63 | 62 n02110341 dalmatian 64 | 63 n02110958 pug 65 | 64 n02112018 Pomeranian 66 | 65 n02112137 chow 67 | 66 n02113023 Pembroke 68 | 67 n02113624 toy_poodle 69 | 68 n02113799 standard_poodle 70 | 69 n02114367 timber_wolf 71 | 70 n02117135 hyena 72 | 71 n02119022 red_fox 73 | 72 n02123045 tabby 74 | 73 n02128385 leopard 75 | 74 n02128757 snow_leopard 76 | 75 n02129165 lion 77 | 76 n02129604 tiger 78 | 77 n02130308 cheetah 79 | 78 n02134084 ice_bear 80 | 79 n02138441 meerkat 81 | 80 n02165456 ladybug 82 | 81 n02190166 fly 83 | 82 n02206856 bee 84 | 83 n02219486 ant 85 | 84 n02226429 grasshopper 86 | 85 n02233338 cockroach 87 | 86 n02236044 mantis 88 | 87 n02268443 dragonfly 89 | 88 n02279972 monarch 90 | 89 n02317335 starfish 91 | 90 n02325366 wood_rabbit 92 | 91 n02346627 porcupine 93 | 92 n02356798 fox_squirrel 94 | 93 n02363005 beaver 95 | 94 n02364673 guinea_pig 96 | 95 n02391049 zebra 97 | 96 n02395406 hog 98 | 97 n02398521 hippopotamus 99 | 98 n02410509 bison 100 | 99 n02423022 gazelle 101 | 100 n02437616 llama 102 | 101 n02445715 skunk 103 | 102 n02447366 badger 104 | 103 n02480495 orangutan 105 | 104 n02480855 gorilla 106 | 105 n02481823 chimpanzee 107 | 106 n02483362 gibbon 108 | 107 n02486410 baboon 109 | 108 n02510455 giant_panda 110 | 109 n02526121 eel 111 | 110 n02607072 anemone_fish 112 | 111 n02655020 puffer 113 | 112 n02672831 accordion 114 | 113 n02701002 ambulance 115 | 114 n02749479 assault_rifle 116 | 115 n02769748 backpack 117 | 116 n02793495 barn 118 | 117 n02797295 barrow 119 | 118 n02802426 basketball 120 | 119 n02808440 bathtub 121 | 120 n02814860 beacon 122 | 121 n02823750 beer_glass 123 | 122 n02841315 binoculars 124 | 123 n02843684 birdhouse 125 | 124 n02883205 bow_tie 126 | 125 n02906734 broom 127 | 126 n02909870 bucket 128 | 127 n02939185 caldron 129 | 128 n02948072 candle 130 | 129 n02950826 cannon 131 | 130 n02951358 canoe 132 | 131 n02966193 carousel 133 | 132 n02980441 castle 134 | 133 n02992529 cellular_telephone 135 | 134 n03124170 cowboy_hat 136 | 135 n03272010 electric_guitar 137 | 136 n03345487 fire_engine 138 | 137 n03372029 flute 139 | 138 n03424325 gasmask 140 | 139 n03452741 grand_piano 141 | 140 n03467068 guillotine 142 | 141 n03481172 hammer 143 | 142 n03494278 harmonica 144 | 143 n03495258 harp 145 | 144 n03498962 hatchet 146 | 145 n03594945 jeep 147 | 146 n03602883 joystick 148 | 147 n03630383 lab_coat 149 | 148 n03649909 lawn_mower 150 | 149 n03676483 lipstick 151 | 150 n03710193 mailbox 152 | 151 n03773504 missile 153 | 152 n03775071 mitten 154 | 153 n03888257 parachute 155 | 154 n03930630 pickup 156 | 155 n03947888 pirate 157 | 156 n04086273 revolver 158 | 157 
n04118538 rugby_ball 159 | 158 n04133789 sandal 160 | 159 n04141076 sax 161 | 160 n04146614 school_bus 162 | 161 n04147183 schooner 163 | 162 n04192698 shield 164 | 163 n04254680 soccer_ball 165 | 164 n04266014 space_shuttle 166 | 165 n04275548 spider_web 167 | 166 n04310018 steam_locomotive 168 | 167 n04325704 stole 169 | 168 n04347754 submarine 170 | 169 n04389033 tank 171 | 170 n04409515 tennis_ball 172 | 171 n04465501 tractor 173 | 172 n04487394 trombone 174 | 173 n04522168 vase 175 | 174 n04536866 violin 176 | 175 n04552348 warplane 177 | 176 n04591713 wine_bottle 178 | 177 n07614500 ice_cream 179 | 178 n07693725 bagel 180 | 179 n07695742 pretzel 181 | 180 n07697313 cheeseburger 182 | 181 n07697537 hotdog 183 | 182 n07714571 head_cabbage 184 | 183 n07714990 broccoli 185 | 184 n07718472 cucumber 186 | 185 n07720875 bell_pepper 187 | 186 n07734744 mushroom 188 | 187 n07742313 Granny_Smith 189 | 188 n07745940 strawberry 190 | 189 n07749582 lemon 191 | 190 n07753275 pineapple 192 | 191 n07753592 banana 193 | 192 n07768694 pomegranate 194 | 193 n07873807 pizza 195 | 194 n07880968 burrito 196 | 195 n07920052 espresso 197 | 196 n09472597 volcano 198 | 197 n09835506 ballplayer 199 | 198 n10565667 scuba_diver 200 | 199 n12267677 acorn 201 | -------------------------------------------------------------------------------- /dataset_reqs/tinyimagenet_classes.txt: -------------------------------------------------------------------------------- 1 | 0 n02124075 Egyptian Mau 2 | 1 n04067472 fishing casting reel 3 | 2 n04540053 volleyball 4 | 3 n04099969 rocking chair 5 | 4 n07749582 lemon 6 | 5 n01641577 American bullfrog 7 | 6 n02802426 basketball 8 | 7 n09246464 cliff 9 | 8 n07920052 espresso 10 | 9 n03970156 plunger 11 | 10 n03891332 parking meter 12 | 11 n02106662 German Shepherd Dog 13 | 12 n03201208 dining table 14 | 13 n02279972 monarch butterfly 15 | 14 n02132136 brown bear 16 | 15 n04146614 school bus 17 | 16 n07873807 pizza 18 | 17 n02364673 guinea pig 19 | 18 n04507155 umbrella 20 | 19 n03854065 pipe organ 21 | 20 n03838899 oboe 22 | 21 n03733131 maypole 23 | 22 n01443537 goldfish 24 | 23 n07875152 pot pie 25 | 24 n03544143 hourglass 26 | 25 n09428293 beach 27 | 26 n03085013 computer keyboard 28 | 27 n02437312 arabian camel 29 | 28 n07614500 ice cream 30 | 29 n03804744 metal nail 31 | 30 n04265275 space heater 32 | 31 n02963159 cardigan 33 | 32 n02486410 baboon 34 | 33 n01944390 snail 35 | 34 n09256479 coral reef 36 | 35 n02058221 albatross 37 | 36 n04275548 spider web 38 | 37 n02321529 sea cucumber 39 | 38 n02769748 backpack 40 | 39 n02099712 Labrador Retriever 41 | 40 n07695742 pretzel 42 | 41 n02056570 king penguin 43 | 42 n02281406 sulphur butterfly 44 | 43 n01774750 tarantula 45 | 44 n02509815 red panda 46 | 45 n03983396 soda bottle 47 | 46 n07753592 banana 48 | 47 n04254777 sock 49 | 48 n02233338 cockroach 50 | 49 n04008634 missile 51 | 50 n02823428 beer bottle 52 | 51 n02236044 praying mantis 53 | 52 n03393912 freight car 54 | 53 n07583066 guacamole 55 | 54 n04074963 remote control 56 | 55 n01629819 fire salamander 57 | 56 n09332890 lakeshore 58 | 57 n02481823 chimpanzee 59 | 58 n03902125 payphone 60 | 59 n03404251 fur coat 61 | 60 n09193705 mountain 62 | 61 n03637318 lampshade 63 | 62 n04456115 torch 64 | 63 n02666196 abacus 65 | 64 n03796401 moving van 66 | 65 n02795169 barrel 67 | 66 n02123045 tabby cat 68 | 67 n01855672 goose 69 | 68 n01882714 koala 70 | 69 n02917067 high-speed train 71 | 70 n02988304 CD player 72 | 71 n04398044 teapot 73 | 72 n02843684 birdhouse 74 | 73 
n02423022 gazelle 75 | 74 n02669723 academic gown 76 | 75 n04465501 tractor 77 | 76 n02165456 ladybug 78 | 77 n03770439 miniskirt 79 | 78 n02099601 Golden Retriever 80 | 79 n04486054 triumphal arch 81 | 80 n02950826 cannon 82 | 81 n03814639 neck brace 83 | 82 n04259630 sombrero 84 | 83 n03424325 gas mask or respirator 85 | 84 n02948072 candle 86 | 85 n03179701 desk 87 | 86 n03400231 frying pan 88 | 87 n02206856 bee 89 | 88 n03160309 dam 90 | 89 n01984695 spiny lobster 91 | 90 n03977966 police van 92 | 91 n03584254 iPod 93 | 92 n04023962 punching bag 94 | 93 n02814860 lighthouse 95 | 94 n01910747 jellyfish 96 | 95 n04596742 wok 97 | 96 n03992509 potter's wheel 98 | 97 n04133789 sandal 99 | 98 n03937543 pill bottle 100 | 99 n02927161 butcher shop 101 | 100 n01945685 slug 102 | 101 n02395406 pig 103 | 102 n02125311 cougar 104 | 103 n03126707 construction crane 105 | 104 n04532106 vestment 106 | 105 n02268443 dragonfly 107 | 106 n02977058 automated teller machine 108 | 107 n07734744 mushroom 109 | 108 n03599486 rickshaw 110 | 109 n04562935 water tower 111 | 110 n03014705 storage chest 112 | 111 n04251144 snorkel 113 | 112 n04356056 sunglasses 114 | 113 n02190166 fly 115 | 114 n03670208 limousine 116 | 115 n02002724 black stork 117 | 116 n02074367 dugong 118 | 117 n04285008 sports car 119 | 118 n04560804 water jug 120 | 119 n04366367 suspension bridge 121 | 120 n02403003 ox 122 | 121 n07615774 popsicle 123 | 122 n04501370 turnstile 124 | 123 n03026506 Christmas stocking 125 | 124 n02906734 broom 126 | 125 n01770393 scorpion 127 | 126 n04597913 wooden spoon 128 | 127 n03930313 picket fence 129 | 128 n04118538 rugby ball 130 | 129 n04179913 sewing machine 131 | 130 n04311004 through arch bridge 132 | 131 n02123394 Persian cat 133 | 132 n04070727 refrigerator 134 | 133 n02793495 barn 135 | 134 n02730930 apron 136 | 135 n02094433 Yorkshire Terrier 137 | 136 n04371430 swim trunks / shorts 138 | 137 n04328186 stopwatch 139 | 138 n03649909 lawn mower 140 | 139 n04417672 thatched roof 141 | 140 n03388043 fountain 142 | 141 n01774384 southern black widow 143 | 142 n02837789 bikini 144 | 143 n07579787 plate 145 | 144 n04399382 teddy bear 146 | 145 n02791270 barbershop 147 | 146 n03089624 candy store 148 | 147 n02814533 station wagon 149 | 148 n04149813 scoreboard 150 | 149 n07747607 orange 151 | 150 n03355925 flagpole 152 | 151 n01983481 American lobster 153 | 152 n04487081 trolleybus 154 | 153 n03250847 drumstick 155 | 154 n03255030 dumbbell 156 | 155 n02892201 brass memorial plaque 157 | 156 n02883205 bow tie 158 | 157 n03100240 convertible 159 | 158 n02415577 bighorn sheep 160 | 159 n02480495 orangutan 161 | 160 n01698640 American alligator 162 | 161 n01784675 centipede 163 | 162 n04376876 syringe 164 | 163 n03444034 go-kart 165 | 164 n01917289 brain coral 166 | 165 n01950731 sea slug 167 | 166 n03042490 cliff dwelling 168 | 167 n07711569 mashed potatoes 169 | 168 n04532670 viaduct 170 | 169 n03763968 military uniform 171 | 170 n07768694 pomegranate 172 | 171 n02999410 chain 173 | 172 n03617480 kimono 174 | 173 n06596364 comic book 175 | 174 n01768244 trilobite 176 | 175 n02410509 bison 177 | 176 n03976657 pole 178 | 177 n01742172 boa constrictor 179 | 178 n03980874 poncho 180 | 179 n02808440 bathtub 181 | 180 n02226429 grasshopper 182 | 181 n02231487 stick insect 183 | 182 n02085620 Chihuahua 184 | 183 n01644900 tailed frog 185 | 184 n02129165 lion 186 | 185 n02699494 altar 187 | 186 n03837869 obelisk 188 | 187 n02815834 beaker 189 | 188 n07720875 bell pepper 190 | 189 n02788148 baluster / handrail 
191 | 190 n02909870 bucket
192 | 191 n03706229 magnetic compass
193 | 192 n07871810 meatloaf
194 | 193 n03447447 gondola
195 | 194 n02113799 Standard Poodle
196 | 195 n12267677 acorn
197 | 196 n03662601 lifeboat
198 | 197 n02841315 binoculars
199 | 198 n07715103 cauliflower
200 | 199 n02504458 African bush elephant
201 | 
--------------------------------------------------------------------------------
/continual_clip/datasets.py:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | import os
4 | import torch.nn as nn
5 | from torchvision import transforms
6 | from continuum import ClassIncremental, InstanceIncremental
7 | from continuum.datasets import (
8 |     CIFAR100, ImageNet100, TinyImageNet200, ImageFolderDataset, Core50, CUB200, Food101, OxfordPet, Caltech101
9 | )
10 | from .utils import get_dataset_class_names, get_workdir
11 | import pdb
12 | 
13 | 
14 | class ImageNet_C(ImageFolderDataset):
15 |     """Continuum dataset for datasets with tree-like structure.
16 |     :param train_folder: The folder of the train data.
17 |     :param test_folder: The folder of the test data.
18 |     :param download: Dummy parameter.
19 |     """
20 | 
21 |     def __init__(
22 |         self,
23 |         data_path: str,
24 |         train: bool = True,
25 |         download: bool = False,
26 |     ):
27 |         super().__init__(data_path=data_path, train=train, download=download)
28 | 
29 |     def get_data(self):
30 |         # data_path is used as-is: no train/val subfolders are appended here
31 |         return super().get_data()
32 | 
33 | 
34 | class ImageNet1000(ImageFolderDataset):
35 |     """Continuum dataset for datasets with tree-like structure.
36 |     :param train_folder: The folder of the train data.
37 |     :param test_folder: The folder of the test data.
38 |     :param download: Dummy parameter.
39 |     """
40 | 
41 |     def __init__(
42 |         self,
43 |         data_path: str,
44 |         train: bool = True,
45 |         download: bool = False,
46 |     ):
47 |         super().__init__(data_path=data_path, train=train, download=download)
48 | 
49 |     def get_data(self):
50 |         if self.train:
51 |             self.data_path = os.path.join(self.data_path, "train")
52 |         else:
53 |             self.data_path = os.path.join(self.data_path, "val")
54 |         return super().get_data()
55 | 
56 | 
57 | class ImageNet_R(ImageFolderDataset):
58 |     """Continuum dataset for datasets with tree-like structure.
59 |     :param train_folder: The folder of the train data.
60 |     :param test_folder: The folder of the test data.
61 |     :param download: Dummy parameter.
62 |     """
63 | 
64 |     def __init__(
65 |         self,
66 |         data_path: str,
67 |         train: bool = True,
68 |         download: bool = False,
69 |     ):
70 |         super().__init__(data_path=data_path, train=train, download=download)
71 |     @property
72 |     def transformations(self):
73 |         """Default transformations if nothing is provided to the scenario."""
74 |         return [
75 |             transforms.Resize((224, 224)),
76 |             transforms.ToTensor(),
77 |             transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
78 |         ]
79 | 
80 |     def get_data(self):
81 |         if self.train:
82 |             self.data_path = os.path.join(self.data_path, "train")
83 |         else:
84 |             self.data_path = os.path.join(self.data_path, "test")
85 |         return super().get_data()
86 | 
87 | class VTAB(ImageFolderDataset):
88 |     """Continuum dataset for datasets with tree-like structure.
89 |     :param train_folder: The folder of the train data.
90 |     :param test_folder: The folder of the test data.
91 |     :param download: Dummy parameter.
92 | """ 93 | 94 | def __init__( 95 | self, 96 | data_path: str, 97 | train: bool = True, 98 | download: bool = False, 99 | ): 100 | super().__init__(data_path=data_path, train=train, download=download) 101 | @property 102 | def transformations(self): 103 | """Default transformations if nothing is provided to the scenario.""" 104 | return [ 105 | transforms.Resize((224, 224)), 106 | transforms.ToTensor(), 107 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 108 | ] 109 | 110 | def get_data(self): 111 | if self.train: 112 | self.data_path = os.path.join(self.data_path, "train") 113 | else: 114 | self.data_path = os.path.join(self.data_path, "test") 115 | return super().get_data() 116 | 117 | 118 | def get_dataset(cfg, is_train, transforms=None): 119 | if cfg.dataset == "cifar100": 120 | data_path = cfg.dataset_root 121 | dataset = CIFAR100( 122 | data_path=data_path, 123 | download=True, 124 | train=is_train, 125 | # transforms=transforms 126 | ) 127 | classes_names = dataset.dataset.classes 128 | elif cfg.dataset == "imagenet_R": 129 | data_path = cfg.dataset_root 130 | dataset = ImageNet_R( 131 | data_path, 132 | train=is_train 133 | ) 134 | classes_names = get_dataset_class_names(cfg.workdir, cfg.dataset) 135 | elif cfg.dataset == "cub200": 136 | data_path = cfg.dataset_root 137 | dataset = CUB200( 138 | data_path, 139 | train=is_train 140 | ) 141 | classes_names = get_dataset_class_names(cfg.workdir, cfg.dataset) 142 | elif cfg.dataset == "food101": 143 | data_path = cfg.dataset_root 144 | dataset = Food101( 145 | data_path, 146 | train=is_train 147 | ) 148 | classes_names = get_dataset_class_names(cfg.workdir, cfg.dataset) 149 | elif cfg.dataset == "oxford_pet": 150 | data_path = cfg.dataset_root 151 | dataset = OxfordPet( 152 | data_path, 153 | train=is_train 154 | ) 155 | classes_names = get_dataset_class_names(cfg.workdir, cfg.dataset) 156 | elif cfg.dataset == "caltech101": 157 | data_path = cfg.dataset_root 158 | dataset = Caltech101( 159 | data_path, 160 | train=is_train 161 | ) 162 | classes_names = get_dataset_class_names(cfg.workdir, cfg.dataset) 163 | 164 | elif cfg.dataset == "vtab": 165 | data_path = cfg.dataset_root 166 | dataset = VTAB( 167 | data_path, 168 | train=is_train 169 | ) 170 | classes_names = get_dataset_class_names(cfg.workdir, cfg.dataset) 171 | elif cfg.dataset == "imagenet_c": 172 | data_path = cfg.dataset_root 173 | dataset = ImageNet_C( 174 | data_path, 175 | train=is_train 176 | ) 177 | classes_names = get_dataset_class_names(cfg.workdir, cfg.dataset) 178 | 179 | elif cfg.dataset == "tinyimagenet": 180 | data_path = os.path.join(cfg.dataset_root, cfg.dataset) 181 | dataset = TinyImageNet200( 182 | data_path, 183 | train=is_train, 184 | download=True 185 | ) 186 | classes_names = get_dataset_class_names(cfg.workdir, cfg.dataset) 187 | 188 | elif cfg.dataset == "imagenet100": 189 | data_path = cfg.dataset_root 190 | dataset = ImageNet100( 191 | data_path, 192 | train=is_train, 193 | data_subset=os.path.join(get_workdir(os.getcwd()), "class_orders/train_100.txt" if is_train else "class_orders/val_100.txt") 194 | ) 195 | classes_names = get_dataset_class_names(cfg.workdir, cfg.dataset) 196 | 197 | elif cfg.dataset == "imagenet1000": 198 | data_path = cfg.dataset_root 199 | dataset = ImageNet1000( 200 | data_path, 201 | train=is_train 202 | ) 203 | classes_names = get_dataset_class_names(cfg.workdir, cfg.dataset) 204 | 205 | elif cfg.dataset == "core50": 206 | data_path = os.path.join(cfg.dataset_root, cfg.dataset) 207 | dataset = dataset 
= Core50(
208 |             data_path,
209 |             scenario="domains",
210 |             classification="category",
211 |             train=is_train
212 |         )
213 |         classes_names = [
214 |             "plug adapters", "mobile phones", "scissors", "light bulbs", "cans",
215 |             "glasses", "balls", "markers", "cups", "remote controls"
216 |         ]
217 | 
218 |     else:
219 |         raise ValueError(f"'{cfg.dataset}' is an invalid dataset.")
220 | 
221 |     return dataset, classes_names
222 | 
223 | 
224 | def build_cl_scenarios(cfg, is_train, transforms) -> nn.Module:
225 | 
226 |     dataset, classes_names = get_dataset(cfg, is_train)
227 | 
228 |     if cfg.scenario == "class":
229 |         scenario = ClassIncremental(
230 |             dataset,
231 |             initial_increment=cfg.initial_increment,
232 |             increment=cfg.increment,
233 |             transformations=transforms.transforms, # Convert Compose into list
234 |             class_order=cfg.class_order,
235 |         )
236 | 
237 |     elif cfg.scenario == "domain":
238 |         scenario = InstanceIncremental(
239 |             dataset,
240 |             transformations=transforms.transforms,
241 |         )
242 | 
243 |     elif cfg.scenario == "task-agnostic":
244 |         raise NotImplementedError("This scenario has not been implemented yet; it will be added soon.")
245 | 
246 |     else:
247 |         raise ValueError(f"You have entered `{cfg.scenario}`, which is not a defined scenario; "
248 |                          "please choose from {'class', 'domain', 'task-agnostic'}.")
249 | 
250 |     return scenario, classes_names
--------------------------------------------------------------------------------
/continual_clip/models.py:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | import pdb
4 | from omegaconf import DictConfig
5 | 
6 | import clip
7 | import torch
8 | import torch.nn as nn
9 | import types
10 | from loraclip import lora_clip
11 | from clip.model import VisionTransformer as CLIPVisionTransformer
12 | from torch.nn import functional as F
13 | from .utils import get_class_ids_per_task, get_class_names
14 | import random
15 | import numpy as np
16 | 
17 | 
18 | 
19 | 
20 | 
21 | def forward_clip(self, image, text, return_feature=False):
22 |     image_features = self.encode_image(image)
23 |     text_features = self.encode_text(text)
24 | 
25 |     # normalized features
26 |     image_features = image_features / image_features.norm(dim=1, keepdim=True)
27 |     text_features = text_features / text_features.norm(dim=1, keepdim=True)
28 | 
29 |     # cosine similarity as logits
30 |     logit_scale = self.logit_scale.exp()
31 |     logits_per_image = logit_scale * image_features @ text_features.t()
32 |     logits_per_text = logits_per_image.t()
33 | 
34 |     if return_feature:
35 |         return logits_per_image, logits_per_text, image_features, text_features
36 | 
37 | 
38 |     # shape = [global_batch_size, global_batch_size]
39 |     return logits_per_image, logits_per_text
40 | 
41 | 
42 | 
43 | class VisionClassifier(nn.Module):
44 |     def __init__(self, in_features, num_classes, weight_init=None, activation=None):
45 |         super().__init__()
46 |         self.fc = nn.Linear(in_features, num_classes, bias=False)
47 |         self.fc = nn.Parameter(self.fc.weight.data)
48 |         if weight_init is not None:
49 |             self.fc.data = weight_init
50 |         if activation is not None:
51 |             self.activation = activation
52 |         else:
53 |             self.activation = nn.Identity()
54 | 
55 |     def add_weight(self, weight):
56 |         self.fc = nn.Parameter(torch.cat([self.fc, weight], dim=0))
57 | 
58 |     def set_weight(self, weight):
59 |         self.fc = nn.Parameter(weight)
60 | 
61 | 
62 |     def forward(self, x):
63 |         # cosine classifier: normalize both the features and the weights
64 |         x = F.normalize(x, p=2, dim=-1)
65 |         weight = F.normalize(self.fc, p=2, dim=-1)
66 |         x = F.linear(x, weight)
67 |         x = self.activation(x)
68 |         return x
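A brief usage sketch of VisionClassifier above: it is a bias-free cosine classifier (features and weights are L2-normalized before the linear map), and `add_weight` grows the weight matrix as new classes arrive. The sizes and random weights below are made up for illustration.

```python
# Usage sketch for the VisionClassifier defined above (hypothetical sizes,
# random weights): outputs are cosine similarities in [-1, 1].
import torch

clf = VisionClassifier(in_features=512, num_classes=10)  # task 0: 10 classes
clf.add_weight(torch.randn(10, 512))                     # task 1: +10 classes
logits = clf(torch.randn(4, 512))                        # shape [4, 20]
```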
69 | 70 | 71 | 72 | class ClassIncrementalCLIP(nn.Module): 73 | def __init__(self, cfg, device, jit=False): 74 | super().__init__() 75 | self.cfg = cfg 76 | self.prompt_template = cfg.prompt_template 77 | self.device = device 78 | self.classes_names = None 79 | # self.model, self.transforms = clip.load(cfg.model_name, device=device, jit=jit) 80 | 81 | 82 | # LoRA-CLIP backbone 83 | self.model, self.transforms = lora_clip.load(cfg.model_name, device=device, jit=jit, r=cfg.lora_rank, lora_mode=cfg.lora_mode) 84 | # for name, param in self.model.named_parameters(): 85 | # if 'adapter_mlp' in name: 86 | # param.requires_grad = True 87 | # for name, param in self.model.named_parameters(): 88 | # if param.requires_grad: 89 | # print(f"Trainable: {name}") 90 | self.model.forward = types.MethodType(forward_clip, self.model) 91 | ori_state = self.model.state_dict() 92 | torch.save(ori_state, 'ori_state.pth') # save the pristine weights; adaptation(reset=True) below expects to load 'ori_state.pth' 93 | self.class_ids_per_task = list(get_class_ids_per_task(cfg)) 94 | self.current_class_names = [] 95 | self.text_tokens = None 96 | self.current_task = -1 97 | self.only_reset_B = cfg.only_reset_B 98 | self.freeze_A = cfg.freeze_A 99 | 100 | 101 | def cur_text_features(self): 102 | f = self.model.encode_text(self.text_tokens) 103 | f = f / f.norm(dim=1, keepdim=True) 104 | return f 105 | 106 | def inference(self, image, text_tokens): 107 | text_features = self.model.encode_text(text_tokens) 108 | image_features = self.model.visual(image.type(self.model.dtype), all_tokens=False, adapt=self.attention_adapter) 109 | # NOTE: expects self.attention_adapter to be set externally; this path is currently unused 110 | 111 | # image_features = self.attention_adapter(image_features.type(torch.float32))[:, 0, :] 112 | 113 | image_features = image_features / image_features.norm(dim=1, keepdim=True) 114 | text_features = text_features / text_features.norm(dim=1, keepdim=True) 115 | 116 | logit_scale = self.model.logit_scale.exp() 117 | logits_per_image = logit_scale * image_features @ text_features.t() 118 | return logits_per_image 119 | 120 | def forward(self, image, test=False, all_test=False, return_feature=False, replay=None): 121 | if test: 122 | # pdb.set_trace() 123 | with torch.no_grad(): 124 | if all_test: 125 | if return_feature: 126 | logits_per_image, _, image_features, __ = self.model(image, self.all_text_tokens, return_feature=return_feature) 127 | else: 128 | logits_per_image, _ = self.model(image, self.all_text_tokens) 129 | # logits_per_image = self.inference(image, self.all_text_tokens) 130 | else: 131 | if return_feature: 132 | logits_per_image, _, image_features, __ = self.model(image, self.text_tokens, return_feature=return_feature) 133 | else: 134 | logits_per_image, _ = self.model(image, self.text_tokens) 135 | # pdb.set_trace() 136 | probs = logits_per_image.softmax(dim=-1) 137 | else: 138 | 139 | if return_feature: 140 | __, _, image_features, text_features = self.model(image, self.text_tokens, return_feature=return_feature) 141 | return image_features, text_features 142 | if replay is not None: 143 | logits_per_image, _ = self.model(image, self.text_tokens) 144 | # text_features_for_replay = self.model.encode_text(self.text_tokens[:-self.cfg.increment]) 145 | text_features_for_replay = self.model.encode_text(self.text_tokens) 146 | text_features_for_replay = text_features_for_replay / text_features_for_replay.norm(dim=1, keepdim=True) 147 | replay_features = replay / replay.norm(dim=1, keepdim=True) 148 | replay_logits = replay_features @ text_features_for_replay.t() * 100 149 | else: 150 | logits_per_image, _ = self.model(image, self.text_tokens) 151 | probs = logits_per_image 152 | 153 | if return_feature:
154 | text_features = self.model.encode_text(self.all_text_tokens) 155 | return probs, image_features, text_features 156 | 157 | if replay is not None: 158 | return probs, replay_logits 159 | return probs 160 | 161 | def adaptation(self, task_id, reset=False): 162 | self.current_task += 1 163 | if reset and self.current_task > 0: 164 | ori_state = torch.load('ori_state.pth') 165 | if self.only_reset_B: 166 | now_state = self.model.state_dict() 167 | lora_params = {k: v for k, v in ori_state.items() if 'lora_B' in k} 168 | now_state.update(lora_params) 169 | else: 170 | now_state = ori_state 171 | self.model.load_state_dict(now_state) 172 | if self.freeze_A and self.current_task > 0: 173 | for name, param in self.model.named_parameters(): 174 | if 'lora_A' in name: 175 | param.requires_grad = False 176 | 177 | self.current_task_class_names = get_class_names(self.classes_names, self.class_ids_per_task[task_id]) 178 | self.current_class_names += self.current_task_class_names 179 | self.text_tokens = clip.tokenize( 180 | [self.prompt_template.format(c) for c in self.current_class_names] 181 | ).to(self.device) 182 | self.current_task_text_tokens = clip.tokenize( 183 | [self.prompt_template.format(c) for c in self.current_task_class_names] 184 | ).to(self.device) 185 | if self.current_task == 0: 186 | class_names = [] 187 | for i in range(self.cfg.task_num): 188 | class_names += get_class_names(self.classes_names, self.class_ids_per_task[i]) 189 | self.all_class_names = class_names 190 | self.all_text_tokens = clip.tokenize( 191 | [self.prompt_template.format(c) for c in self.all_class_names] 192 | ).to(self.device) 193 | 194 | 195 | 196 | 197 | 198 | def load_model(cfg: DictConfig, device: torch.device) -> nn.Module: 199 | r"""Load a CLIP model in different continual scenarios. 200 | 201 | Arguments: 202 | cfg (DictConfig): Experiment configurations. 203 | device (torch.device): Device to train or evaluate the model on. 204 | 205 | Returns: 206 | nn.Module: Return the scenario-specific CLIP model.
207 | """ 208 | if cfg.scenario == "class": 209 | return ClassIncrementalCLIP(cfg, device) 210 | else: 211 | raise ValueError(f""" 212 | `{cfg.scenario}` is not a valid scenario, 213 | Please choose from ['class', 'domain', 'task-agnostic'] 214 | """) 215 | 216 | -------------------------------------------------------------------------------- /loraclip/lora_clip.py: -------------------------------------------------------------------------------- 1 | # adapted from https://github.com/jaisidhsingh/LoRA-CLIP 2 | import hashlib 3 | import os 4 | import urllib 5 | import warnings 6 | from typing import Any, Union, List 7 | from pkg_resources import packaging 8 | 9 | import torch 10 | from PIL import Image 11 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 12 | from tqdm import tqdm 13 | 14 | import clip 15 | from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer # use default clip's tokenization 16 | 17 | from .model import build_LoRA_model 18 | 19 | try: 20 | from torchvision.transforms import InterpolationMode 21 | BICUBIC = InterpolationMode.BICUBIC 22 | except ImportError: 23 | BICUBIC = Image.BICUBIC 24 | 25 | 26 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): 27 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 28 | 29 | 30 | __all__ = ["available_models", "load", "tokenize"] 31 | _tokenizer = _Tokenizer() 32 | 33 | _MODELS = { 34 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 35 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 36 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 37 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 38 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 39 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 40 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 41 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 42 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", 43 | } 44 | 45 | 46 | def _download(url: str, root: str): 47 | os.makedirs(root, exist_ok=True) 48 | filename = os.path.basename(url) 49 | 50 | expected_sha256 = url.split("/")[-2] 51 | download_target = os.path.join(root, filename) 52 | 53 | if os.path.exists(download_target) and not os.path.isfile(download_target): 54 | raise RuntimeError(f"{download_target} exists and is not a regular file") 55 | 56 | if os.path.isfile(download_target): 57 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 58 | return download_target 59 | else: 60 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 61 | 62 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
63 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 64 | while True: 65 | buffer = source.read(8192) 66 | if not buffer: 67 | break 68 | 69 | output.write(buffer) 70 | loop.update(len(buffer)) 71 | 72 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 73 | raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match") 74 | 75 | return download_target 76 | 77 | 78 | def _convert_image_to_rgb(image): 79 | return image.convert("RGB") 80 | 81 | 82 | def _transform(n_px): 83 | return Compose([ 84 | Resize(n_px, interpolation=BICUBIC), 85 | CenterCrop(n_px), 86 | _convert_image_to_rgb, 87 | ToTensor(), 88 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 89 | ]) 90 | 91 | 92 | def available_models() -> List[str]: 93 | """Returns the names of available CLIP models""" 94 | return list(_MODELS.keys()) 95 | 96 | 97 | def load(name: str, 98 | device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", 99 | jit: bool = False, 100 | download_root: str = None, 101 | r: int = 4, 102 | lora_mode: str = "vision+text" 103 | ): 104 | """Load a CLIP model 105 | 106 | Parameters 107 | ---------- 108 | name : str 109 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 110 | 111 | device : Union[str, torch.device] 112 | The device to put the loaded model 113 | 114 | jit : bool 115 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 116 | 117 | download_root: str 118 | path to download the model files; by default, it uses "~/.cache/clip" 119 | 120 | r : int 121 | Rank of the LoRA matrices 122 | lora_mode : str Which CLIP encoders receive LoRA adapters (default: "vision+text") 123 | Returns 124 | ------- 125 | model : torch.nn.Module 126 | The CLIP model 127 | 128 | preprocess : Callable[[PIL.Image], torch.Tensor] 129 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 130 | """ 131 | if name in _MODELS: 132 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) 133 | elif os.path.isfile(name): 134 | model_path = name 135 | else: 136 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 137 | 138 | with open(model_path, 'rb') as opened_file: 139 | try: 140 | # loading JIT archive 141 | model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() 142 | state_dict = None 143 | except RuntimeError: 144 | # loading saved state dict 145 | if jit: 146 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") 147 | jit = False 148 | state_dict = torch.load(opened_file, map_location="cpu") 149 | 150 | if not jit: 151 | model = build_LoRA_model(state_dict or model.state_dict(), r, lora_mode).to(device) 152 | if str(device) == "cpu": 153 | model.float() 154 | return model, _transform(model.visual.input_resolution) 155 | 156 | # patch the device names 157 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 158 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 159 | 160 | def _node_get(node: torch._C.Node, key: str): 161 | """Gets attributes of a node which is polymorphic over return type.
162 | 163 | From https://github.com/pytorch/pytorch/pull/82628 164 | """ 165 | sel = node.kindOf(key) 166 | return getattr(node, sel)(key) 167 | 168 | def patch_device(module): 169 | try: 170 | graphs = [module.graph] if hasattr(module, "graph") else [] 171 | except RuntimeError: 172 | graphs = [] 173 | 174 | if hasattr(module, "forward1"): 175 | graphs.append(module.forward1.graph) 176 | 177 | for graph in graphs: 178 | for node in graph.findAllNodes("prim::Constant"): 179 | if "value" in node.attributeNames() and str(_node_get(node, "value")).startswith("cuda"): 180 | node.copyAttributes(device_node) 181 | 182 | model.apply(patch_device) 183 | patch_device(model.encode_image) 184 | patch_device(model.encode_text) 185 | 186 | # patch dtype to float32 on CPU 187 | if str(device) == "cpu": 188 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 189 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 190 | float_node = float_input.node() 191 | 192 | def patch_float(module): 193 | try: 194 | graphs = [module.graph] if hasattr(module, "graph") else [] 195 | except RuntimeError: 196 | graphs = [] 197 | 198 | if hasattr(module, "forward1"): 199 | graphs.append(module.forward1.graph) 200 | 201 | for graph in graphs: 202 | for node in graph.findAllNodes("aten::to"): 203 | inputs = list(node.inputs()) 204 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 205 | if _node_get(inputs[i].node(), "value") == 5: 206 | inputs[i].node().copyAttributes(float_node) 207 | 208 | model.apply(patch_float) 209 | patch_float(model.encode_image) 210 | patch_float(model.encode_text) 211 | 212 | model.float() 213 | 214 | return model, _transform(model.input_resolution.item()) 215 | 216 | def tokenize(texts, context_length=77, truncate=False): 217 | return clip.tokenize(texts, context_length=context_length, truncate=truncate) -------------------------------------------------------------------------------- /epoch.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import json 4 | import pdb 5 | import hydra 6 | import logging 7 | from omegaconf import DictConfig 8 | 9 | import torch 10 | import statistics 11 | from torch.utils.data import DataLoader 12 | import torch.nn.functional as F 13 | from continuum.metrics import Logger 14 | import random 15 | import numpy as np 16 | from collections import defaultdict 17 | 18 | from tqdm import tqdm 19 | from continual_clip import utils 20 | from continual_clip.models import load_model, VisionClassifier 21 | from continual_clip.datasets import build_cl_scenarios 22 | from sklearn.cluster import KMeans 23 | from continuum import rehearsal 24 | import copy 25 | from torchvision import transforms 26 | try: 27 | from torchvision.transforms import InterpolationMode 28 | BICUBIC = InterpolationMode.BICUBIC 29 | except ImportError: 30 | from PIL import Image; BICUBIC = Image.BICUBIC 31 | 32 | def intra_cls(logits, y, classes): 33 | y = y - classes 34 | logits1 = logits[:, classes:] 35 | return F.cross_entropy(logits1, y, reduction='none') 36 | 37 | def get_finetuning_dataset(dataset, memory, finetuning='balanced', oversample_old=1, task_id=0): 38 | if finetuning == 'balanced': 39 | x, y, t = memory.get() 40 | 41 | if oversample_old > 1: 42 | old_indexes = np.where(t < task_id)[0] 43 | assert len(old_indexes) > 0 44 | new_indexes = np.where(t >= task_id)[0] 45 | 46 | indexes = np.concatenate([ 47 | np.repeat(old_indexes, oversample_old), 48 | new_indexes 49 | ]) 50 | x, y, t = x[indexes], y[indexes], t[indexes]
51 | 52 | new_dataset = copy.deepcopy(dataset) 53 | new_dataset._x = x 54 | new_dataset._y = y 55 | new_dataset._t = t 56 | return new_dataset 57 | 58 | 59 | def seed_everything(seed=0): 60 | """Fix all random seeds""" 61 | random.seed(seed) 62 | np.random.seed(seed) 63 | torch.manual_seed(seed) 64 | torch.cuda.manual_seed_all(seed) 65 | torch.backends.cudnn.deterministic = True 66 | os.environ['PYTHONHASHSEED'] = str(seed) 67 | 68 | def activation(x): 69 | return torch.exp(-10*(1-x)) 70 | 71 | 72 | 73 | def run_class_incremental(cfg, device): 74 | 75 | cfg.class_order = utils.get_class_order(os.path.join(cfg.workdir, cfg.class_order)) 76 | model = load_model(cfg, device) 77 | 78 | eval_dataset, classes_names = build_cl_scenarios( 79 | cfg, is_train=False, transforms=model.transforms 80 | ) 81 | train_dataset, _ = build_cl_scenarios( 82 | cfg, is_train=True, transforms=model.transforms 83 | ) 84 | # pdb.set_trace() 85 | model.classes_names = classes_names 86 | if cfg.visual_clsf: 87 | if cfg.model_name == "ViT-L/14": 88 | vision_clsf = VisionClassifier(768, cfg.increment, activation=None) 89 | else: 90 | vision_clsf = VisionClassifier(512, cfg.increment, activation=None) 91 | 92 | 93 | acc_list = [] 94 | metric_logger = Logger(list_subsets=["test"]) 95 | 96 | p1 = 0 97 | p2 = 0 98 | negative_records = 0 99 | trainable_params = {k: v for k, v in model.named_parameters() if v.requires_grad} 100 | # pdb.set_trace() 101 | torch.save(trainable_params, 'ori_params.pth') 102 | 103 | if cfg.real_replay: 104 | memory = rehearsal.RehearsalMemory( 105 | memory_size=2000, 106 | herding_method="random" 107 | ) 108 | for task_id, _ in enumerate(eval_dataset): 109 | 110 | # negative_records = 0 111 | 112 | torch.cuda.empty_cache() 113 | if task_id == 0: 114 | targets_bias = 0 115 | else: 116 | targets_bias = cfg.initial_increment + (task_id - 1) * cfg.increment 117 | 118 | logging.info(f"Evaluation for task {task_id} has started.") 119 | model.adaptation(task_id, reset=cfg.reset) 120 | 121 | # save the model's trainable parameters 122 | trainable_params = {k: v for k, v in model.named_parameters() if v.requires_grad} 123 | torch.save(trainable_params, 'trainable_params.pth') 124 | 125 | trainable_params = torch.load('ori_params.pth') 126 | model.load_state_dict(trainable_params, strict=False) 127 | 128 | # compute the mean outputs of the positive and negative classes before training 129 | model.eval() # switch to evaluation mode 130 | positive_outputs = [] 131 | negative_outputs = [] 132 | 133 | val_gap_loader = DataLoader(train_dataset[task_id], batch_size=cfg.train_batch_size, shuffle=True, num_workers=cfg.num_workers) 134 | 135 | with torch.no_grad(): 136 | for inputs, targets, t in val_gap_loader: 137 | inputs, targets = inputs.to(device), targets.to(device) 138 | outputs = model(inputs) 139 | 140 | one_hot_targets = torch.nn.functional.one_hot(targets, outputs.shape[1]).float() 141 | positive_outputs.append((outputs * one_hot_targets).sum(dim=1).mean()) 142 | mask = 1 - one_hot_targets 143 | negative_outputs.append(((outputs * mask).sum(dim=1) / mask.sum(dim=1)).mean()) 144 | positive_mean = sum(positive_outputs) / len(positive_outputs) 145 | negative_mean = sum(negative_outputs) / len(negative_outputs) 146 | # if task_id == 0: 147 | negative_records = negative_mean 148 | # if task_id == 0: 149 | logit_size = cfg.increment if task_id > 0 else cfg.initial_increment 150 | bias_logit = torch.full((logit_size,), negative_mean, device=device) 151 | bias_logit[0] = positive_mean 152 | # pdb.set_trace() 153 | # pdb.set_trace() 154 | logging.info(f"positive_records: {positive_mean}") 155 | logging.info(f"negative_records: {negative_mean}") 156 | # pdb.set_trace() 157 | trainable_params = torch.load('trainable_params.pth') 158 | model.load_state_dict(trainable_params, strict=False) 159 | 160 | model.train() 161 | if task_id > 0 and cfg.real_replay: 162 | mem_x, mem_y, mem_t = memory.get() 163 | t_data = train_dataset[task_id] 164 | t_data.add_samples(mem_x, mem_y, mem_t) 165 | else: 166 | t_data = train_dataset[task_id] 167 | train_loader = DataLoader(t_data, batch_size=cfg.train_batch_size, shuffle=True, num_workers=cfg.num_workers) 168 | 169 | epochs = 10 170 | 171 | if epochs > 0: 172 | # filter out the parameters that require grad 173 | params = filter(lambda p: p.requires_grad, model.parameters()) 174 | optimizer = torch.optim.Adam(params, lr=cfg.lr) 175 | # optimizer = torch.optim.SGD(params, lr=cfg.lr, momentum=cfg.momentum, weight_decay=cfg.weight_decay) 176 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs, eta_min=cfg.lr*0.01) 177 | # for name, param in model.named_parameters(): 178 | # if param.requires_grad: 179 | # print(name) 180 | torch.cuda.empty_cache() 181 | for i_epoch in range(epochs): 182 | 183 | for batch_i, (inputs, targets, t) in enumerate(train_loader): 184 | loss_c = torch.tensor(0.0).to(device) 185 | loss = torch.tensor(0.0).to(device) 186 | 187 | replay_loss = torch.tensor(0.0).to(device) 188 | torch.cuda.empty_cache() 189 | 190 | 191 | # targets = targets - targets_bias 192 | inputs, targets = inputs.to(device), targets.to(device) 193 | 194 | outputs = model(inputs) 195 | # image_f, text_f = model(inputs, return_feature=True) 196 | if task_id > 0: 197 | 198 | if cfg.real_replay: 199 | mask_replay = (targets < targets_bias) 200 | old_targets = targets[mask_replay].clone() 201 | old_outputs = outputs[mask_replay].clone() 202 | targets = targets[~mask_replay] 203 | outputs = outputs[~mask_replay] 204 | replay_loss = intra_cls(old_outputs, old_targets, 0).mean()*0.1 205 | loss_c = intra_cls(outputs, targets, targets_bias).mean() + replay_loss 206 | pass 207 | else: 208 | 209 | loss_c = torch.nn.functional.cross_entropy(outputs, targets) 210 | loss += loss_c 211 | optimizer.zero_grad() 212 | loss.backward() 213 | optimizer.step() 214 | if batch_i % 10 == 0: 215 | logging.info(f"Epoch {i_epoch + 1}/{epochs} | Batch {batch_i + 1}/{len(train_loader)} | Loss: {loss.item()} | Loss_c: {loss_c.item()}") 216 | scheduler.step() 217 | 218 | 219 | 220 | torch.cuda.empty_cache() 221 | positive_outputs = [] 222 | negative_outputs = [] 223 | with torch.no_grad(): 224 | model.eval() 225 | for inputs, targets, t in val_gap_loader: 226 | inputs, targets = inputs.to(device), targets.to(device) 227 | outputs = model(inputs) 228 | # pdb.set_trace() 229 | one_hot_targets = torch.nn.functional.one_hot(targets, outputs.shape[1]).float() 230 | positive_outputs.append((outputs * one_hot_targets).sum(dim=1).mean()) 231 | mask = 1 - one_hot_targets 232 | negative_outputs.append(((outputs * mask).sum(dim=1) / mask.sum(dim=1)).mean()) 233 | model.train() 234 | positive_mean = sum(positive_outputs) / len(positive_outputs) 235 | negative_mean = sum(negative_outputs) / len(negative_outputs) 236 | all_mean = (sum(positive_outputs) + sum(negative_outputs)) / (len(positive_outputs) + len(negative_outputs)) 237 | 238 | logging.info(f"positive_mean: {positive_mean}") 239 | logging.info(f"negative_mean: {negative_mean}") 240 | torch.cuda.empty_cache() 241 | if (abs(negative_records - negative_mean)/negative_records) > 0.1: 242 | if i_epoch > 0: 243 | logging.info(f"Negative records changed too much, epoch {i_epoch}")
244 | else: 245 | logging.info("Negative records changed too much, epoch 1") 246 | exit(0) 247 | 248 | 249 | 250 | 251 | 252 | def run_domain_incremental(cfg, device): 253 | 254 | model = load_model(cfg, device) 255 | eval_dataset, classes_names = build_cl_scenarios( 256 | cfg, is_train=False, transforms=model.transforms 257 | ) 258 | model.tokenize(classes_names) 259 | 260 | with open(cfg.log_path, 'w+') as f: 261 | pass 262 | 263 | logger = Logger(list_subsets=["test"]) 264 | logging.info(f">>> Evaluation scenario length is {len(eval_dataset)}") 265 | for task_id, _ in enumerate(eval_dataset): 266 | 267 | dataset_val = eval_dataset[:task_id + 1] 268 | eval_loader = DataLoader(dataset_val, batch_size=cfg.batch_size) 269 | for input, target, task_ids in tqdm(eval_loader): 270 | input, target = input.to(device), target.to(device) 271 | output = model(input) 272 | logger.add([output.cpu().argmax(dim=1), target.cpu(), task_ids], subset='test') 273 | 274 | with open(cfg.log_path, 'a+') as f: 275 | f.write(json.dumps({ 276 | 'task': task_id, 277 | 'acc': round(100 * logger.accuracy, 2), 278 | }) + '\n') 279 | 280 | logger.end_task() 281 | 282 | def run_task_agnostic(): 283 | pass 284 | 285 | 286 | 287 | @hydra.main(config_path=None, config_name=None, version_base="1.1") 288 | def continual_clip(cfg: DictConfig) -> None: 289 | seed_everything(cfg.seed) 290 | cfg.workdir = utils.get_workdir(path=os.getcwd()) 291 | cfg.dataset_root = os.path.join(cfg.workdir, cfg.dataset_root) 292 | 293 | utils.save_config(cfg) 294 | with open(cfg.log_path, 'w+') as f: 295 | pass 296 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 297 | 298 | if cfg.scenario == "class": 299 | run_class_incremental(cfg, device) 300 | 301 | elif cfg.scenario == "domain": 302 | run_domain_incremental(cfg, device) 303 | 304 | elif cfg.scenario == "task-agnostic": 305 | raise NotImplementedError("Method has not been implemented. It will be added soon.")
Soon be added.") 306 | 307 | else: 308 | ValueError(f"You have entered `{cfg.scenario}` which is not a defined scenario, " 309 | "please choose from {{'class', 'domain', 'task-agnostic'}}.") 310 | 311 | 312 | 313 | 314 | 315 | 316 | 317 | 318 | 319 | 320 | 321 | 322 | 323 | 324 | 325 | 326 | 327 | 328 | 329 | 330 | 331 | 332 | if __name__ == "__main__": 333 | continual_clip() 334 | -------------------------------------------------------------------------------- /dataset_reqs/imagenet_c_classes.txt: -------------------------------------------------------------------------------- 1 | 0 tench 2 | 1 goldfish 3 | 2 great white shark 4 | 3 tiger shark 5 | 4 hammerhead shark 6 | 5 electric ray 7 | 6 stingray 8 | 7 rooster 9 | 8 hen 10 | 9 ostrich 11 | 10 brambling 12 | 11 goldfinch 13 | 12 house finch 14 | 13 junco 15 | 14 indigo bunting 16 | 15 American robin 17 | 16 bulbul 18 | 17 jay 19 | 18 magpie 20 | 19 chickadee 21 | 20 American dipper 22 | 21 kite (bird of prey) 23 | 22 bald eagle 24 | 23 vulture 25 | 24 great grey owl 26 | 25 fire salamander 27 | 26 smooth newt 28 | 27 newt 29 | 28 spotted salamander 30 | 29 axolotl 31 | 30 American bullfrog 32 | 31 tree frog 33 | 32 tailed frog 34 | 33 loggerhead sea turtle 35 | 34 leatherback sea turtle 36 | 35 mud turtle 37 | 36 terrapin 38 | 37 box turtle 39 | 38 banded gecko 40 | 39 green iguana 41 | 40 Carolina anole 42 | 41 desert grassland whiptail lizard 43 | 42 agama 44 | 43 frilled-necked lizard 45 | 44 alligator lizard 46 | 45 Gila monster 47 | 46 European green lizard 48 | 47 chameleon 49 | 48 Komodo dragon 50 | 49 Nile crocodile 51 | 50 American alligator 52 | 51 triceratops 53 | 52 worm snake 54 | 53 ring-necked snake 55 | 54 eastern hog-nosed snake 56 | 55 smooth green snake 57 | 56 kingsnake 58 | 57 garter snake 59 | 58 water snake 60 | 59 vine snake 61 | 60 night snake 62 | 61 boa constrictor 63 | 62 African rock python 64 | 63 Indian cobra 65 | 64 green mamba 66 | 65 sea snake 67 | 66 Saharan horned viper 68 | 67 eastern diamondback rattlesnake 69 | 68 sidewinder rattlesnake 70 | 69 trilobite 71 | 70 harvestman 72 | 71 scorpion 73 | 72 yellow garden spider 74 | 73 barn spider 75 | 74 European garden spider 76 | 75 southern black widow 77 | 76 tarantula 78 | 77 wolf spider 79 | 78 tick 80 | 79 centipede 81 | 80 black grouse 82 | 81 ptarmigan 83 | 82 ruffed grouse 84 | 83 prairie grouse 85 | 84 peafowl 86 | 85 quail 87 | 86 partridge 88 | 87 african grey parrot 89 | 88 macaw 90 | 89 sulphur-crested cockatoo 91 | 90 lorikeet 92 | 91 coucal 93 | 92 bee eater 94 | 93 hornbill 95 | 94 hummingbird 96 | 95 jacamar 97 | 96 toucan 98 | 97 duck 99 | 98 red-breasted merganser 100 | 99 goose 101 | 100 black swan 102 | 101 tusker 103 | 102 echidna 104 | 103 platypus 105 | 104 wallaby 106 | 105 koala 107 | 106 wombat 108 | 107 jellyfish 109 | 108 sea anemone 110 | 109 brain coral 111 | 110 flatworm 112 | 111 nematode 113 | 112 conch 114 | 113 snail 115 | 114 slug 116 | 115 sea slug 117 | 116 chiton 118 | 117 chambered nautilus 119 | 118 Dungeness crab 120 | 119 rock crab 121 | 120 fiddler crab 122 | 121 red king crab 123 | 122 American lobster 124 | 123 spiny lobster 125 | 124 crayfish 126 | 125 hermit crab 127 | 126 isopod 128 | 127 white stork 129 | 128 black stork 130 | 129 spoonbill 131 | 130 flamingo 132 | 131 little blue heron 133 | 132 great egret 134 | 133 bittern bird 135 | 134 crane bird 136 | 135 limpkin 137 | 136 common gallinule 138 | 137 American coot 139 | 138 bustard 140 | 139 ruddy turnstone 141 | 140 dunlin 142 | 141 common redshank 
143 | 142 dowitcher 144 | 143 oystercatcher 145 | 144 pelican 146 | 145 king penguin 147 | 146 albatross 148 | 147 grey whale 149 | 148 killer whale 150 | 149 dugong 151 | 150 sea lion 152 | 151 Chihuahua 153 | 152 Japanese Chin 154 | 153 Maltese 155 | 154 Pekingese 156 | 155 Shih Tzu 157 | 156 King Charles Spaniel 158 | 157 Papillon 159 | 158 toy terrier 160 | 159 Rhodesian Ridgeback 161 | 160 Afghan Hound 162 | 161 Basset Hound 163 | 162 Beagle 164 | 163 Bloodhound 165 | 164 Bluetick Coonhound 166 | 165 Black and Tan Coonhound 167 | 166 Treeing Walker Coonhound 168 | 167 English foxhound 169 | 168 Redbone Coonhound 170 | 169 borzoi 171 | 170 Irish Wolfhound 172 | 171 Italian Greyhound 173 | 172 Whippet 174 | 173 Ibizan Hound 175 | 174 Norwegian Elkhound 176 | 175 Otterhound 177 | 176 Saluki 178 | 177 Scottish Deerhound 179 | 178 Weimaraner 180 | 179 Staffordshire Bull Terrier 181 | 180 American Staffordshire Terrier 182 | 181 Bedlington Terrier 183 | 182 Border Terrier 184 | 183 Kerry Blue Terrier 185 | 184 Irish Terrier 186 | 185 Norfolk Terrier 187 | 186 Norwich Terrier 188 | 187 Yorkshire Terrier 189 | 188 Wire Fox Terrier 190 | 189 Lakeland Terrier 191 | 190 Sealyham Terrier 192 | 191 Airedale Terrier 193 | 192 Cairn Terrier 194 | 193 Australian Terrier 195 | 194 Dandie Dinmont Terrier 196 | 195 Boston Terrier 197 | 196 Miniature Schnauzer 198 | 197 Giant Schnauzer 199 | 198 Standard Schnauzer 200 | 199 Scottish Terrier 201 | 200 Tibetan Terrier 202 | 201 Australian Silky Terrier 203 | 202 Soft-coated Wheaten Terrier 204 | 203 West Highland White Terrier 205 | 204 Lhasa Apso 206 | 205 Flat-Coated Retriever 207 | 206 Curly-coated Retriever 208 | 207 Golden Retriever 209 | 208 Labrador Retriever 210 | 209 Chesapeake Bay Retriever 211 | 210 German Shorthaired Pointer 212 | 211 Vizsla 213 | 212 English Setter 214 | 213 Irish Setter 215 | 214 Gordon Setter 216 | 215 Brittany dog 217 | 216 Clumber Spaniel 218 | 217 English Springer Spaniel 219 | 218 Welsh Springer Spaniel 220 | 219 Cocker Spaniel 221 | 220 Sussex Spaniel 222 | 221 Irish Water Spaniel 223 | 222 Kuvasz 224 | 223 Schipperke 225 | 224 Groenendael dog 226 | 225 Malinois 227 | 226 Briard 228 | 227 Australian Kelpie 229 | 228 Komondor 230 | 229 Old English Sheepdog 231 | 230 Shetland Sheepdog 232 | 231 collie 233 | 232 Border Collie 234 | 233 Bouvier des Flandres dog 235 | 234 Rottweiler 236 | 235 German Shepherd Dog 237 | 236 Dobermann 238 | 237 Miniature Pinscher 239 | 238 Greater Swiss Mountain Dog 240 | 239 Bernese Mountain Dog 241 | 240 Appenzeller Sennenhund 242 | 241 Entlebucher Sennenhund 243 | 242 Boxer 244 | 243 Bullmastiff 245 | 244 Tibetan Mastiff 246 | 245 French Bulldog 247 | 246 Great Dane 248 | 247 St. 
Bernard 249 | 248 husky 250 | 249 Alaskan Malamute 251 | 250 Siberian Husky 252 | 251 Dalmatian 253 | 252 Affenpinscher 254 | 253 Basenji 255 | 254 pug 256 | 255 Leonberger 257 | 256 Newfoundland dog 258 | 257 Great Pyrenees dog 259 | 258 Samoyed 260 | 259 Pomeranian 261 | 260 Chow Chow 262 | 261 Keeshond 263 | 262 brussels griffon 264 | 263 Pembroke Welsh Corgi 265 | 264 Cardigan Welsh Corgi 266 | 265 Toy Poodle 267 | 266 Miniature Poodle 268 | 267 Standard Poodle 269 | 268 Mexican hairless dog (xoloitzcuintli) 270 | 269 grey wolf 271 | 270 Alaskan tundra wolf 272 | 271 red wolf or maned wolf 273 | 272 coyote 274 | 273 dingo 275 | 274 dhole 276 | 275 African wild dog 277 | 276 hyena 278 | 277 red fox 279 | 278 kit fox 280 | 279 Arctic fox 281 | 280 grey fox 282 | 281 tabby cat 283 | 282 tiger cat 284 | 283 Persian cat 285 | 284 Siamese cat 286 | 285 Egyptian Mau 287 | 286 cougar 288 | 287 lynx 289 | 288 leopard 290 | 289 snow leopard 291 | 290 jaguar 292 | 291 lion 293 | 292 tiger 294 | 293 cheetah 295 | 294 brown bear 296 | 295 American black bear 297 | 296 polar bear 298 | 297 sloth bear 299 | 298 mongoose 300 | 299 meerkat 301 | 300 tiger beetle 302 | 301 ladybug 303 | 302 ground beetle 304 | 303 longhorn beetle 305 | 304 leaf beetle 306 | 305 dung beetle 307 | 306 rhinoceros beetle 308 | 307 weevil 309 | 308 fly 310 | 309 bee 311 | 310 ant 312 | 311 grasshopper 313 | 312 cricket insect 314 | 313 stick insect 315 | 314 cockroach 316 | 315 praying mantis 317 | 316 cicada 318 | 317 leafhopper 319 | 318 lacewing 320 | 319 dragonfly 321 | 320 damselfly 322 | 321 red admiral butterfly 323 | 322 ringlet butterfly 324 | 323 monarch butterfly 325 | 324 small white butterfly 326 | 325 sulphur butterfly 327 | 326 gossamer-winged butterfly 328 | 327 starfish 329 | 328 sea urchin 330 | 329 sea cucumber 331 | 330 cottontail rabbit 332 | 331 hare 333 | 332 Angora rabbit 334 | 333 hamster 335 | 334 porcupine 336 | 335 fox squirrel 337 | 336 marmot 338 | 337 beaver 339 | 338 guinea pig 340 | 339 common sorrel horse 341 | 340 zebra 342 | 341 pig 343 | 342 wild boar 344 | 343 warthog 345 | 344 hippopotamus 346 | 345 ox 347 | 346 water buffalo 348 | 347 bison 349 | 348 ram (adult male sheep) 350 | 349 bighorn sheep 351 | 350 Alpine ibex 352 | 351 hartebeest 353 | 352 impala (antelope) 354 | 353 gazelle 355 | 354 arabian camel 356 | 355 llama 357 | 356 weasel 358 | 357 mink 359 | 358 European polecat 360 | 359 black-footed ferret 361 | 360 otter 362 | 361 skunk 363 | 362 badger 364 | 363 armadillo 365 | 364 three-toed sloth 366 | 365 orangutan 367 | 366 gorilla 368 | 367 chimpanzee 369 | 368 gibbon 370 | 369 siamang 371 | 370 guenon 372 | 371 patas monkey 373 | 372 baboon 374 | 373 macaque 375 | 374 langur 376 | 375 black-and-white colobus 377 | 376 proboscis monkey 378 | 377 marmoset 379 | 378 white-headed capuchin 380 | 379 howler monkey 381 | 380 titi monkey 382 | 381 Geoffroy's spider monkey 383 | 382 common squirrel monkey 384 | 383 ring-tailed lemur 385 | 384 indri 386 | 385 Asian elephant 387 | 386 African bush elephant 388 | 387 red panda 389 | 388 giant panda 390 | 389 snoek fish 391 | 390 eel 392 | 391 silver salmon 393 | 392 rock beauty fish 394 | 393 clownfish 395 | 394 sturgeon 396 | 395 gar fish 397 | 396 lionfish 398 | 397 pufferfish 399 | 398 abacus 400 | 399 abaya 401 | 400 academic gown 402 | 401 accordion 403 | 402 acoustic guitar 404 | 403 aircraft carrier 405 | 404 airliner 406 | 405 airship 407 | 406 altar 408 | 407 ambulance 409 | 408 amphibious vehicle 410 | 409 analog clock 411 | 
410 apiary 412 | 411 apron 413 | 412 trash can 414 | 413 assault rifle 415 | 414 backpack 416 | 415 bakery 417 | 416 balance beam 418 | 417 balloon 419 | 418 ballpoint pen 420 | 419 Band-Aid 421 | 420 banjo 422 | 421 baluster / handrail 423 | 422 barbell 424 | 423 barber chair 425 | 424 barbershop 426 | 425 barn 427 | 426 barometer 428 | 427 barrel 429 | 428 wheelbarrow 430 | 429 baseball 431 | 430 basketball 432 | 431 bassinet 433 | 432 bassoon 434 | 433 swimming cap 435 | 434 bath towel 436 | 435 bathtub 437 | 436 station wagon 438 | 437 lighthouse 439 | 438 beaker 440 | 439 military hat (bearskin or shako) 441 | 440 beer bottle 442 | 441 beer glass 443 | 442 bell tower 444 | 443 baby bib 445 | 444 tandem bicycle 446 | 445 bikini 447 | 446 ring binder 448 | 447 binoculars 449 | 448 birdhouse 450 | 449 boathouse 451 | 450 bobsleigh 452 | 451 bolo tie 453 | 452 poke bonnet 454 | 453 bookcase 455 | 454 bookstore 456 | 455 bottle cap 457 | 456 hunting bow 458 | 457 bow tie 459 | 458 brass memorial plaque 460 | 459 bra 461 | 460 breakwater 462 | 461 breastplate 463 | 462 broom 464 | 463 bucket 465 | 464 buckle 466 | 465 bulletproof vest 467 | 466 high-speed train 468 | 467 butcher shop 469 | 468 taxicab 470 | 469 cauldron 471 | 470 candle 472 | 471 cannon 473 | 472 canoe 474 | 473 can opener 475 | 474 cardigan 476 | 475 car mirror 477 | 476 carousel 478 | 477 tool kit 479 | 478 cardboard box / carton 480 | 479 car wheel 481 | 480 automated teller machine 482 | 481 cassette 483 | 482 cassette player 484 | 483 castle 485 | 484 catamaran 486 | 485 CD player 487 | 486 cello 488 | 487 mobile phone 489 | 488 chain 490 | 489 chain-link fence 491 | 490 chain mail 492 | 491 chainsaw 493 | 492 storage chest 494 | 493 chiffonier 495 | 494 bell or wind chime 496 | 495 china cabinet 497 | 496 Christmas stocking 498 | 497 church 499 | 498 movie theater 500 | 499 cleaver 501 | 500 cliff dwelling 502 | 501 cloak 503 | 502 clogs 504 | 503 cocktail shaker 505 | 504 coffee mug 506 | 505 coffeemaker 507 | 506 spiral or coil 508 | 507 combination lock 509 | 508 computer keyboard 510 | 509 candy store 511 | 510 container ship 512 | 511 convertible 513 | 512 corkscrew 514 | 513 cornet 515 | 514 cowboy boot 516 | 515 cowboy hat 517 | 516 cradle 518 | 517 construction crane 519 | 518 crash helmet 520 | 519 crate 521 | 520 infant bed 522 | 521 Crock Pot 523 | 522 croquet ball 524 | 523 crutch 525 | 524 cuirass 526 | 525 dam 527 | 526 desk 528 | 527 desktop computer 529 | 528 rotary dial telephone 530 | 529 diaper 531 | 530 digital clock 532 | 531 digital watch 533 | 532 dining table 534 | 533 dishcloth 535 | 534 dishwasher 536 | 535 disc brake 537 | 536 dock 538 | 537 dog sled 539 | 538 dome 540 | 539 doormat 541 | 540 drilling rig 542 | 541 drum 543 | 542 drumstick 544 | 543 dumbbell 545 | 544 Dutch oven 546 | 545 electric fan 547 | 546 electric guitar 548 | 547 electric locomotive 549 | 548 entertainment center 550 | 549 envelope 551 | 550 espresso machine 552 | 551 face powder 553 | 552 feather boa 554 | 553 filing cabinet 555 | 554 fireboat 556 | 555 fire truck 557 | 556 fire screen 558 | 557 flagpole 559 | 558 flute 560 | 559 folding chair 561 | 560 football helmet 562 | 561 forklift 563 | 562 fountain 564 | 563 fountain pen 565 | 564 four-poster bed 566 | 565 freight car 567 | 566 French horn 568 | 567 frying pan 569 | 568 fur coat 570 | 569 garbage truck 571 | 570 gas mask or respirator 572 | 571 gas pump 573 | 572 goblet 574 | 573 go-kart 575 | 574 golf ball 576 | 575 golf cart 577 | 576 gondola 578 | 577 gong 
579 | 578 gown 580 | 579 grand piano 581 | 580 greenhouse 582 | 581 radiator grille 583 | 582 grocery store 584 | 583 guillotine 585 | 584 hair clip 586 | 585 hair spray 587 | 586 half-track 588 | 587 hammer 589 | 588 hamper 590 | 589 hair dryer 591 | 590 hand-held computer 592 | 591 handkerchief 593 | 592 hard disk drive 594 | 593 harmonica 595 | 594 harp 596 | 595 combine harvester 597 | 596 hatchet 598 | 597 holster 599 | 598 home theater 600 | 599 honeycomb 601 | 600 hook 602 | 601 hoop skirt 603 | 602 gymnastic horizontal bar 604 | 603 horse-drawn vehicle 605 | 604 hourglass 606 | 605 iPod 607 | 606 clothes iron 608 | 607 carved pumpkin 609 | 608 jeans 610 | 609 jeep 611 | 610 T-shirt 612 | 611 jigsaw puzzle 613 | 612 rickshaw 614 | 613 joystick 615 | 614 kimono 616 | 615 knee pad 617 | 616 knot 618 | 617 lab coat 619 | 618 ladle 620 | 619 lampshade 621 | 620 laptop computer 622 | 621 lawn mower 623 | 622 lens cap 624 | 623 letter opener 625 | 624 library 626 | 625 lifeboat 627 | 626 lighter 628 | 627 limousine 629 | 628 ocean liner 630 | 629 lipstick 631 | 630 slip-on shoe 632 | 631 lotion 633 | 632 music speaker 634 | 633 loupe magnifying glass 635 | 634 sawmill 636 | 635 magnetic compass 637 | 636 messenger bag 638 | 637 mailbox 639 | 638 tights 640 | 639 one-piece bathing suit 641 | 640 manhole cover 642 | 641 maraca 643 | 642 marimba 644 | 643 mask 645 | 644 matchstick 646 | 645 maypole 647 | 646 maze 648 | 647 measuring cup 649 | 648 medicine cabinet 650 | 649 megalith 651 | 650 microphone 652 | 651 microwave oven 653 | 652 military uniform 654 | 653 milk can 655 | 654 minibus 656 | 655 miniskirt 657 | 656 minivan 658 | 657 missile 659 | 658 mitten 660 | 659 mixing bowl 661 | 660 mobile home 662 | 661 ford model t 663 | 662 modem 664 | 663 monastery 665 | 664 monitor 666 | 665 moped 667 | 666 mortar and pestle 668 | 667 graduation cap 669 | 668 mosque 670 | 669 mosquito net 671 | 670 vespa 672 | 671 mountain bike 673 | 672 tent 674 | 673 computer mouse 675 | 674 mousetrap 676 | 675 moving van 677 | 676 muzzle 678 | 677 metal nail 679 | 678 neck brace 680 | 679 necklace 681 | 680 baby pacifier 682 | 681 notebook computer 683 | 682 obelisk 684 | 683 oboe 685 | 684 ocarina 686 | 685 odometer 687 | 686 oil filter 688 | 687 pipe organ 689 | 688 oscilloscope 690 | 689 overskirt 691 | 690 bullock cart 692 | 691 oxygen mask 693 | 692 product packet / packaging 694 | 693 paddle 695 | 694 paddle wheel 696 | 695 padlock 697 | 696 paintbrush 698 | 697 pajamas 699 | 698 palace 700 | 699 pan flute 701 | 700 paper towel 702 | 701 parachute 703 | 702 parallel bars 704 | 703 park bench 705 | 704 parking meter 706 | 705 railroad car 707 | 706 patio 708 | 707 payphone 709 | 708 pedestal 710 | 709 pencil case 711 | 710 pencil sharpener 712 | 711 perfume 713 | 712 Petri dish 714 | 713 photocopier 715 | 714 plectrum 716 | 715 Pickelhaube 717 | 716 picket fence 718 | 717 pickup truck 719 | 718 pier 720 | 719 piggy bank 721 | 720 pill bottle 722 | 721 pillow 723 | 722 ping-pong ball 724 | 723 pinwheel 725 | 724 pirate ship 726 | 725 drink pitcher 727 | 726 block plane 728 | 727 planetarium 729 | 728 plastic bag 730 | 729 plate rack 731 | 730 farm plow 732 | 731 plunger 733 | 732 Polaroid camera 734 | 733 pole 735 | 734 police van 736 | 735 poncho 737 | 736 pool table 738 | 737 soda bottle 739 | 738 plant pot 740 | 739 potter's wheel 741 | 740 power drill 742 | 741 prayer rug 743 | 742 printer 744 | 743 prison 745 | 744 missile 746 | 745 projector 747 | 746 hockey puck 748 | 747 punching bag 749 | 748 
purse 750 | 749 quill 751 | 750 quilt 752 | 751 race car 753 | 752 racket 754 | 753 radiator 755 | 754 radio 756 | 755 radio telescope 757 | 756 rain barrel 758 | 757 recreational vehicle 759 | 758 fishing casting reel 760 | 759 reflex camera 761 | 760 refrigerator 762 | 761 remote control 763 | 762 restaurant 764 | 763 revolver 765 | 764 rifle 766 | 765 rocking chair 767 | 766 rotisserie 768 | 767 eraser 769 | 768 rugby ball 770 | 769 ruler measuring stick 771 | 770 sneaker 772 | 771 safe 773 | 772 safety pin 774 | 773 salt shaker 775 | 774 sandal 776 | 775 sarong 777 | 776 saxophone 778 | 777 scabbard 779 | 778 weighing scale 780 | 779 school bus 781 | 780 schooner 782 | 781 scoreboard 783 | 782 CRT monitor 784 | 783 screw 785 | 784 screwdriver 786 | 785 seat belt 787 | 786 sewing machine 788 | 787 shield 789 | 788 shoe store 790 | 789 shoji screen / room divider 791 | 790 shopping basket 792 | 791 shopping cart 793 | 792 shovel 794 | 793 shower cap 795 | 794 shower curtain 796 | 795 ski 797 | 796 balaclava ski mask 798 | 797 sleeping bag 799 | 798 slide rule 800 | 799 sliding door 801 | 800 slot machine 802 | 801 snorkel 803 | 802 snowmobile 804 | 803 snowplow 805 | 804 soap dispenser 806 | 805 soccer ball 807 | 806 sock 808 | 807 solar thermal collector 809 | 808 sombrero 810 | 809 soup bowl 811 | 810 keyboard space bar 812 | 811 space heater 813 | 812 space shuttle 814 | 813 spatula 815 | 814 motorboat 816 | 815 spider web 817 | 816 spindle 818 | 817 sports car 819 | 818 spotlight 820 | 819 stage 821 | 820 steam locomotive 822 | 821 through arch bridge 823 | 822 steel drum 824 | 823 stethoscope 825 | 824 scarf 826 | 825 stone wall 827 | 826 stopwatch 828 | 827 stove 829 | 828 strainer 830 | 829 tram 831 | 830 stretcher 832 | 831 couch 833 | 832 stupa 834 | 833 submarine 835 | 834 suit 836 | 835 sundial 837 | 836 sunglasses 838 | 837 sunglasses 839 | 838 sunscreen 840 | 839 suspension bridge 841 | 840 mop 842 | 841 sweatshirt 843 | 842 swim trunks / shorts 844 | 843 swing 845 | 844 electrical switch 846 | 845 syringe 847 | 846 table lamp 848 | 847 tank 849 | 848 tape player 850 | 849 teapot 851 | 850 teddy bear 852 | 851 television 853 | 852 tennis ball 854 | 853 thatched roof 855 | 854 front curtain 856 | 855 thimble 857 | 856 threshing machine 858 | 857 throne 859 | 858 tile roof 860 | 859 toaster 861 | 860 tobacco shop 862 | 861 toilet seat 863 | 862 torch 864 | 863 totem pole 865 | 864 tow truck 866 | 865 toy store 867 | 866 tractor 868 | 867 semi-trailer truck 869 | 868 tray 870 | 869 trench coat 871 | 870 tricycle 872 | 871 trimaran 873 | 872 tripod 874 | 873 triumphal arch 875 | 874 trolleybus 876 | 875 trombone 877 | 876 hot tub 878 | 877 turnstile 879 | 878 typewriter keyboard 880 | 879 umbrella 881 | 880 unicycle 882 | 881 upright piano 883 | 882 vacuum cleaner 884 | 883 vase 885 | 884 vaulted or arched ceiling 886 | 885 velvet fabric 887 | 886 vending machine 888 | 887 vestment 889 | 888 viaduct 890 | 889 violin 891 | 890 volleyball 892 | 891 waffle iron 893 | 892 wall clock 894 | 893 wallet 895 | 894 wardrobe 896 | 895 military aircraft 897 | 896 sink 898 | 897 washing machine 899 | 898 water bottle 900 | 899 water jug 901 | 900 water tower 902 | 901 whiskey jug 903 | 902 whistle 904 | 903 hair wig 905 | 904 window screen 906 | 905 window shade 907 | 906 Windsor tie 908 | 907 wine bottle 909 | 908 airplane wing 910 | 909 wok 911 | 910 wooden spoon 912 | 911 wool 913 | 912 split-rail fence 914 | 913 shipwreck 915 | 914 sailboat 916 | 915 yurt 917 | 916 website 918 | 917 
comic book 919 | 918 crossword 920 | 919 traffic or street sign 921 | 920 traffic light 922 | 921 dust jacket 923 | 922 menu 924 | 923 plate 925 | 924 guacamole 926 | 925 consomme 927 | 926 hot pot 928 | 927 trifle 929 | 928 ice cream 930 | 929 popsicle 931 | 930 baguette 932 | 931 bagel 933 | 932 pretzel 934 | 933 cheeseburger 935 | 934 hot dog 936 | 935 mashed potatoes 937 | 936 cabbage 938 | 937 broccoli 939 | 938 cauliflower 940 | 939 zucchini 941 | 940 spaghetti squash 942 | 941 acorn squash 943 | 942 butternut squash 944 | 943 cucumber 945 | 944 artichoke 946 | 945 bell pepper 947 | 946 cardoon 948 | 947 mushroom 949 | 948 Granny Smith apple 950 | 949 strawberry 951 | 950 orange 952 | 951 lemon 953 | 952 fig 954 | 953 pineapple 955 | 954 banana 956 | 955 jackfruit 957 | 956 cherimoya (custard apple) 958 | 957 pomegranate 959 | 958 hay 960 | 959 carbonara 961 | 960 chocolate syrup 962 | 961 dough 963 | 962 meatloaf 964 | 963 pizza 965 | 964 pot pie 966 | 965 burrito 967 | 966 red wine 968 | 967 espresso 969 | 968 tea cup 970 | 969 eggnog 971 | 970 mountain 972 | 971 bubble 973 | 972 cliff 974 | 973 coral reef 975 | 974 geyser 976 | 975 lakeshore 977 | 976 promontory 978 | 977 sandbar 979 | 978 beach 980 | 979 valley 981 | 980 volcano 982 | 981 baseball player 983 | 982 bridegroom 984 | 983 scuba diver 985 | 984 rapeseed 986 | 985 daisy 987 | 986 yellow lady's slipper 988 | 987 corn 989 | 988 acorn 990 | 989 rose hip 991 | 990 horse chestnut seed 992 | 991 coral fungus 993 | 992 agaric 994 | 993 gyromitra 995 | 994 stinkhorn mushroom 996 | 995 earth star fungus 997 | 996 hen of the woods mushroom 998 | 997 bolete 999 | 998 corn cob 1000 | 999 toilet paper 1001 | -------------------------------------------------------------------------------- /dataset_reqs/imagenet1000_classes.txt: -------------------------------------------------------------------------------- 1 | 0 tench 2 | 1 goldfish 3 | 2 great white shark 4 | 3 tiger shark 5 | 4 hammerhead shark 6 | 5 electric ray 7 | 6 stingray 8 | 7 rooster 9 | 8 hen 10 | 9 ostrich 11 | 10 brambling 12 | 11 goldfinch 13 | 12 house finch 14 | 13 junco 15 | 14 indigo bunting 16 | 15 American robin 17 | 16 bulbul 18 | 17 jay 19 | 18 magpie 20 | 19 chickadee 21 | 20 American dipper 22 | 21 kite (bird of prey) 23 | 22 bald eagle 24 | 23 vulture 25 | 24 great grey owl 26 | 25 fire salamander 27 | 26 smooth newt 28 | 27 newt 29 | 28 spotted salamander 30 | 29 axolotl 31 | 30 American bullfrog 32 | 31 tree frog 33 | 32 tailed frog 34 | 33 loggerhead sea turtle 35 | 34 leatherback sea turtle 36 | 35 mud turtle 37 | 36 terrapin 38 | 37 box turtle 39 | 38 banded gecko 40 | 39 green iguana 41 | 40 Carolina anole 42 | 41 desert grassland whiptail lizard 43 | 42 agama 44 | 43 frilled-necked lizard 45 | 44 alligator lizard 46 | 45 Gila monster 47 | 46 European green lizard 48 | 47 chameleon 49 | 48 Komodo dragon 50 | 49 Nile crocodile 51 | 50 American alligator 52 | 51 triceratops 53 | 52 worm snake 54 | 53 ring-necked snake 55 | 54 eastern hog-nosed snake 56 | 55 smooth green snake 57 | 56 kingsnake 58 | 57 garter snake 59 | 58 water snake 60 | 59 vine snake 61 | 60 night snake 62 | 61 boa constrictor 63 | 62 African rock python 64 | 63 Indian cobra 65 | 64 green mamba 66 | 65 sea snake 67 | 66 Saharan horned viper 68 | 67 eastern diamondback rattlesnake 69 | 68 sidewinder rattlesnake 70 | 69 trilobite 71 | 70 harvestman 72 | 71 scorpion 73 | 72 yellow garden spider 74 | 73 barn spider 75 | 74 European garden spider 76 | 75 southern black widow 77 | 76 
tarantula 78 | 77 wolf spider 79 | 78 tick 80 | 79 centipede 81 | 80 black grouse 82 | 81 ptarmigan 83 | 82 ruffed grouse 84 | 83 prairie grouse 85 | 84 peafowl 86 | 85 quail 87 | 86 partridge 88 | 87 african grey parrot 89 | 88 macaw 90 | 89 sulphur-crested cockatoo 91 | 90 lorikeet 92 | 91 coucal 93 | 92 bee eater 94 | 93 hornbill 95 | 94 hummingbird 96 | 95 jacamar 97 | 96 toucan 98 | 97 duck 99 | 98 red-breasted merganser 100 | 99 goose 101 | 100 black swan 102 | 101 tusker 103 | 102 echidna 104 | 103 platypus 105 | 104 wallaby 106 | 105 koala 107 | 106 wombat 108 | 107 jellyfish 109 | 108 sea anemone 110 | 109 brain coral 111 | 110 flatworm 112 | 111 nematode 113 | 112 conch 114 | 113 snail 115 | 114 slug 116 | 115 sea slug 117 | 116 chiton 118 | 117 chambered nautilus 119 | 118 Dungeness crab 120 | 119 rock crab 121 | 120 fiddler crab 122 | 121 red king crab 123 | 122 American lobster 124 | 123 spiny lobster 125 | 124 crayfish 126 | 125 hermit crab 127 | 126 isopod 128 | 127 white stork 129 | 128 black stork 130 | 129 spoonbill 131 | 130 flamingo 132 | 131 little blue heron 133 | 132 great egret 134 | 133 bittern bird 135 | 134 crane bird 136 | 135 limpkin 137 | 136 common gallinule 138 | 137 American coot 139 | 138 bustard 140 | 139 ruddy turnstone 141 | 140 dunlin 142 | 141 common redshank 143 | 142 dowitcher 144 | 143 oystercatcher 145 | 144 pelican 146 | 145 king penguin 147 | 146 albatross 148 | 147 grey whale 149 | 148 killer whale 150 | 149 dugong 151 | 150 sea lion 152 | 151 Chihuahua 153 | 152 Japanese Chin 154 | 153 Maltese 155 | 154 Pekingese 156 | 155 Shih Tzu 157 | 156 King Charles Spaniel 158 | 157 Papillon 159 | 158 toy terrier 160 | 159 Rhodesian Ridgeback 161 | 160 Afghan Hound 162 | 161 Basset Hound 163 | 162 Beagle 164 | 163 Bloodhound 165 | 164 Bluetick Coonhound 166 | 165 Black and Tan Coonhound 167 | 166 Treeing Walker Coonhound 168 | 167 English foxhound 169 | 168 Redbone Coonhound 170 | 169 borzoi 171 | 170 Irish Wolfhound 172 | 171 Italian Greyhound 173 | 172 Whippet 174 | 173 Ibizan Hound 175 | 174 Norwegian Elkhound 176 | 175 Otterhound 177 | 176 Saluki 178 | 177 Scottish Deerhound 179 | 178 Weimaraner 180 | 179 Staffordshire Bull Terrier 181 | 180 American Staffordshire Terrier 182 | 181 Bedlington Terrier 183 | 182 Border Terrier 184 | 183 Kerry Blue Terrier 185 | 184 Irish Terrier 186 | 185 Norfolk Terrier 187 | 186 Norwich Terrier 188 | 187 Yorkshire Terrier 189 | 188 Wire Fox Terrier 190 | 189 Lakeland Terrier 191 | 190 Sealyham Terrier 192 | 191 Airedale Terrier 193 | 192 Cairn Terrier 194 | 193 Australian Terrier 195 | 194 Dandie Dinmont Terrier 196 | 195 Boston Terrier 197 | 196 Miniature Schnauzer 198 | 197 Giant Schnauzer 199 | 198 Standard Schnauzer 200 | 199 Scottish Terrier 201 | 200 Tibetan Terrier 202 | 201 Australian Silky Terrier 203 | 202 Soft-coated Wheaten Terrier 204 | 203 West Highland White Terrier 205 | 204 Lhasa Apso 206 | 205 Flat-Coated Retriever 207 | 206 Curly-coated Retriever 208 | 207 Golden Retriever 209 | 208 Labrador Retriever 210 | 209 Chesapeake Bay Retriever 211 | 210 German Shorthaired Pointer 212 | 211 Vizsla 213 | 212 English Setter 214 | 213 Irish Setter 215 | 214 Gordon Setter 216 | 215 Brittany dog 217 | 216 Clumber Spaniel 218 | 217 English Springer Spaniel 219 | 218 Welsh Springer Spaniel 220 | 219 Cocker Spaniel 221 | 220 Sussex Spaniel 222 | 221 Irish Water Spaniel 223 | 222 Kuvasz 224 | 223 Schipperke 225 | 224 Groenendael dog 226 | 225 Malinois 227 | 226 Briard 228 | 227 Australian Kelpie 229 | 228 Komondor 
230 | 229 Old English Sheepdog 231 | 230 Shetland Sheepdog 232 | 231 collie 233 | 232 Border Collie 234 | 233 Bouvier des Flandres dog 235 | 234 Rottweiler 236 | 235 German Shepherd Dog 237 | 236 Dobermann 238 | 237 Miniature Pinscher 239 | 238 Greater Swiss Mountain Dog 240 | 239 Bernese Mountain Dog 241 | 240 Appenzeller Sennenhund 242 | 241 Entlebucher Sennenhund 243 | 242 Boxer 244 | 243 Bullmastiff 245 | 244 Tibetan Mastiff 246 | 245 French Bulldog 247 | 246 Great Dane 248 | 247 St. Bernard 249 | 248 husky 250 | 249 Alaskan Malamute 251 | 250 Siberian Husky 252 | 251 Dalmatian 253 | 252 Affenpinscher 254 | 253 Basenji 255 | 254 pug 256 | 255 Leonberger 257 | 256 Newfoundland dog 258 | 257 Great Pyrenees dog 259 | 258 Samoyed 260 | 259 Pomeranian 261 | 260 Chow Chow 262 | 261 Keeshond 263 | 262 brussels griffon 264 | 263 Pembroke Welsh Corgi 265 | 264 Cardigan Welsh Corgi 266 | 265 Toy Poodle 267 | 266 Miniature Poodle 268 | 267 Standard Poodle 269 | 268 Mexican hairless dog (xoloitzcuintli) 270 | 269 grey wolf 271 | 270 Alaskan tundra wolf 272 | 271 red wolf or maned wolf 273 | 272 coyote 274 | 273 dingo 275 | 274 dhole 276 | 275 African wild dog 277 | 276 hyena 278 | 277 red fox 279 | 278 kit fox 280 | 279 Arctic fox 281 | 280 grey fox 282 | 281 tabby cat 283 | 282 tiger cat 284 | 283 Persian cat 285 | 284 Siamese cat 286 | 285 Egyptian Mau 287 | 286 cougar 288 | 287 lynx 289 | 288 leopard 290 | 289 snow leopard 291 | 290 jaguar 292 | 291 lion 293 | 292 tiger 294 | 293 cheetah 295 | 294 brown bear 296 | 295 American black bear 297 | 296 polar bear 298 | 297 sloth bear 299 | 298 mongoose 300 | 299 meerkat 301 | 300 tiger beetle 302 | 301 ladybug 303 | 302 ground beetle 304 | 303 longhorn beetle 305 | 304 leaf beetle 306 | 305 dung beetle 307 | 306 rhinoceros beetle 308 | 307 weevil 309 | 308 fly 310 | 309 bee 311 | 310 ant 312 | 311 grasshopper 313 | 312 cricket insect 314 | 313 stick insect 315 | 314 cockroach 316 | 315 praying mantis 317 | 316 cicada 318 | 317 leafhopper 319 | 318 lacewing 320 | 319 dragonfly 321 | 320 damselfly 322 | 321 red admiral butterfly 323 | 322 ringlet butterfly 324 | 323 monarch butterfly 325 | 324 small white butterfly 326 | 325 sulphur butterfly 327 | 326 gossamer-winged butterfly 328 | 327 starfish 329 | 328 sea urchin 330 | 329 sea cucumber 331 | 330 cottontail rabbit 332 | 331 hare 333 | 332 Angora rabbit 334 | 333 hamster 335 | 334 porcupine 336 | 335 fox squirrel 337 | 336 marmot 338 | 337 beaver 339 | 338 guinea pig 340 | 339 common sorrel horse 341 | 340 zebra 342 | 341 pig 343 | 342 wild boar 344 | 343 warthog 345 | 344 hippopotamus 346 | 345 ox 347 | 346 water buffalo 348 | 347 bison 349 | 348 ram (adult male sheep) 350 | 349 bighorn sheep 351 | 350 Alpine ibex 352 | 351 hartebeest 353 | 352 impala (antelope) 354 | 353 gazelle 355 | 354 arabian camel 356 | 355 llama 357 | 356 weasel 358 | 357 mink 359 | 358 European polecat 360 | 359 black-footed ferret 361 | 360 otter 362 | 361 skunk 363 | 362 badger 364 | 363 armadillo 365 | 364 three-toed sloth 366 | 365 orangutan 367 | 366 gorilla 368 | 367 chimpanzee 369 | 368 gibbon 370 | 369 siamang 371 | 370 guenon 372 | 371 patas monkey 373 | 372 baboon 374 | 373 macaque 375 | 374 langur 376 | 375 black-and-white colobus 377 | 376 proboscis monkey 378 | 377 marmoset 379 | 378 white-headed capuchin 380 | 379 howler monkey 381 | 380 titi monkey 382 | 381 Geoffroy's spider monkey 383 | 382 common squirrel monkey 384 | 383 ring-tailed lemur 385 | 384 indri 386 | 385 Asian elephant 387 | 386 African bush 
elephant 388 | 387 red panda 389 | 388 giant panda 390 | 389 snoek fish 391 | 390 eel 392 | 391 silver salmon 393 | 392 rock beauty fish 394 | 393 clownfish 395 | 394 sturgeon 396 | 395 gar fish 397 | 396 lionfish 398 | 397 pufferfish 399 | 398 abacus 400 | 399 abaya 401 | 400 academic gown 402 | 401 accordion 403 | 402 acoustic guitar 404 | 403 aircraft carrier 405 | 404 airliner 406 | 405 airship 407 | 406 altar 408 | 407 ambulance 409 | 408 amphibious vehicle 410 | 409 analog clock 411 | 410 apiary 412 | 411 apron 413 | 412 trash can 414 | 413 assault rifle 415 | 414 backpack 416 | 415 bakery 417 | 416 balance beam 418 | 417 balloon 419 | 418 ballpoint pen 420 | 419 Band-Aid 421 | 420 banjo 422 | 421 baluster / handrail 423 | 422 barbell 424 | 423 barber chair 425 | 424 barbershop 426 | 425 barn 427 | 426 barometer 428 | 427 barrel 429 | 428 wheelbarrow 430 | 429 baseball 431 | 430 basketball 432 | 431 bassinet 433 | 432 bassoon 434 | 433 swimming cap 435 | 434 bath towel 436 | 435 bathtub 437 | 436 station wagon 438 | 437 lighthouse 439 | 438 beaker 440 | 439 military hat (bearskin or shako) 441 | 440 beer bottle 442 | 441 beer glass 443 | 442 bell tower 444 | 443 baby bib 445 | 444 tandem bicycle 446 | 445 bikini 447 | 446 ring binder 448 | 447 binoculars 449 | 448 birdhouse 450 | 449 boathouse 451 | 450 bobsleigh 452 | 451 bolo tie 453 | 452 poke bonnet 454 | 453 bookcase 455 | 454 bookstore 456 | 455 bottle cap 457 | 456 hunting bow 458 | 457 bow tie 459 | 458 brass memorial plaque 460 | 459 bra 461 | 460 breakwater 462 | 461 breastplate 463 | 462 broom 464 | 463 bucket 465 | 464 buckle 466 | 465 bulletproof vest 467 | 466 high-speed train 468 | 467 butcher shop 469 | 468 taxicab 470 | 469 cauldron 471 | 470 candle 472 | 471 cannon 473 | 472 canoe 474 | 473 can opener 475 | 474 cardigan 476 | 475 car mirror 477 | 476 carousel 478 | 477 tool kit 479 | 478 cardboard box / carton 480 | 479 car wheel 481 | 480 automated teller machine 482 | 481 cassette 483 | 482 cassette player 484 | 483 castle 485 | 484 catamaran 486 | 485 CD player 487 | 486 cello 488 | 487 mobile phone 489 | 488 chain 490 | 489 chain-link fence 491 | 490 chain mail 492 | 491 chainsaw 493 | 492 storage chest 494 | 493 chiffonier 495 | 494 bell or wind chime 496 | 495 china cabinet 497 | 496 Christmas stocking 498 | 497 church 499 | 498 movie theater 500 | 499 cleaver 501 | 500 cliff dwelling 502 | 501 cloak 503 | 502 clogs 504 | 503 cocktail shaker 505 | 504 coffee mug 506 | 505 coffeemaker 507 | 506 spiral or coil 508 | 507 combination lock 509 | 508 computer keyboard 510 | 509 candy store 511 | 510 container ship 512 | 511 convertible 513 | 512 corkscrew 514 | 513 cornet 515 | 514 cowboy boot 516 | 515 cowboy hat 517 | 516 cradle 518 | 517 construction crane 519 | 518 crash helmet 520 | 519 crate 521 | 520 infant bed 522 | 521 Crock Pot 523 | 522 croquet ball 524 | 523 crutch 525 | 524 cuirass 526 | 525 dam 527 | 526 desk 528 | 527 desktop computer 529 | 528 rotary dial telephone 530 | 529 diaper 531 | 530 digital clock 532 | 531 digital watch 533 | 532 dining table 534 | 533 dishcloth 535 | 534 dishwasher 536 | 535 disc brake 537 | 536 dock 538 | 537 dog sled 539 | 538 dome 540 | 539 doormat 541 | 540 drilling rig 542 | 541 drum 543 | 542 drumstick 544 | 543 dumbbell 545 | 544 Dutch oven 546 | 545 electric fan 547 | 546 electric guitar 548 | 547 electric locomotive 549 | 548 entertainment center 550 | 549 envelope 551 | 550 espresso machine 552 | 551 face powder 553 | 552 feather boa 554 | 553 filing cabinet 555 | 
554 fireboat 556 | 555 fire truck 557 | 556 fire screen 558 | 557 flagpole 559 | 558 flute 560 | 559 folding chair 561 | 560 football helmet 562 | 561 forklift 563 | 562 fountain 564 | 563 fountain pen 565 | 564 four-poster bed 566 | 565 freight car 567 | 566 French horn 568 | 567 frying pan 569 | 568 fur coat 570 | 569 garbage truck 571 | 570 gas mask or respirator 572 | 571 gas pump 573 | 572 goblet 574 | 573 go-kart 575 | 574 golf ball 576 | 575 golf cart 577 | 576 gondola 578 | 577 gong 579 | 578 gown 580 | 579 grand piano 581 | 580 greenhouse 582 | 581 radiator grille 583 | 582 grocery store 584 | 583 guillotine 585 | 584 hair clip 586 | 585 hair spray 587 | 586 half-track 588 | 587 hammer 589 | 588 hamper 590 | 589 hair dryer 591 | 590 hand-held computer 592 | 591 handkerchief 593 | 592 hard disk drive 594 | 593 harmonica 595 | 594 harp 596 | 595 combine harvester 597 | 596 hatchet 598 | 597 holster 599 | 598 home theater 600 | 599 honeycomb 601 | 600 hook 602 | 601 hoop skirt 603 | 602 gymnastic horizontal bar 604 | 603 horse-drawn vehicle 605 | 604 hourglass 606 | 605 iPod 607 | 606 clothes iron 608 | 607 carved pumpkin 609 | 608 jeans 610 | 609 jeep 611 | 610 T-shirt 612 | 611 jigsaw puzzle 613 | 612 rickshaw 614 | 613 joystick 615 | 614 kimono 616 | 615 knee pad 617 | 616 knot 618 | 617 lab coat 619 | 618 ladle 620 | 619 lampshade 621 | 620 laptop computer 622 | 621 lawn mower 623 | 622 lens cap 624 | 623 letter opener 625 | 624 library 626 | 625 lifeboat 627 | 626 lighter 628 | 627 limousine 629 | 628 ocean liner 630 | 629 lipstick 631 | 630 slip-on shoe 632 | 631 lotion 633 | 632 music speaker 634 | 633 loupe magnifying glass 635 | 634 sawmill 636 | 635 magnetic compass 637 | 636 messenger bag 638 | 637 mailbox 639 | 638 tights 640 | 639 one-piece bathing suit 641 | 640 manhole cover 642 | 641 maraca 643 | 642 marimba 644 | 643 mask 645 | 644 matchstick 646 | 645 maypole 647 | 646 maze 648 | 647 measuring cup 649 | 648 medicine cabinet 650 | 649 megalith 651 | 650 microphone 652 | 651 microwave oven 653 | 652 military uniform 654 | 653 milk can 655 | 654 minibus 656 | 655 miniskirt 657 | 656 minivan 658 | 657 missile 659 | 658 mitten 660 | 659 mixing bowl 661 | 660 mobile home 662 | 661 ford model t 663 | 662 modem 664 | 663 monastery 665 | 664 monitor 666 | 665 moped 667 | 666 mortar and pestle 668 | 667 graduation cap 669 | 668 mosque 670 | 669 mosquito net 671 | 670 vespa 672 | 671 mountain bike 673 | 672 tent 674 | 673 computer mouse 675 | 674 mousetrap 676 | 675 moving van 677 | 676 muzzle 678 | 677 metal nail 679 | 678 neck brace 680 | 679 necklace 681 | 680 baby pacifier 682 | 681 notebook computer 683 | 682 obelisk 684 | 683 oboe 685 | 684 ocarina 686 | 685 odometer 687 | 686 oil filter 688 | 687 pipe organ 689 | 688 oscilloscope 690 | 689 overskirt 691 | 690 bullock cart 692 | 691 oxygen mask 693 | 692 product packet / packaging 694 | 693 paddle 695 | 694 paddle wheel 696 | 695 padlock 697 | 696 paintbrush 698 | 697 pajamas 699 | 698 palace 700 | 699 pan flute 701 | 700 paper towel 702 | 701 parachute 703 | 702 parallel bars 704 | 703 park bench 705 | 704 parking meter 706 | 705 railroad car 707 | 706 patio 708 | 707 payphone 709 | 708 pedestal 710 | 709 pencil case 711 | 710 pencil sharpener 712 | 711 perfume 713 | 712 Petri dish 714 | 713 photocopier 715 | 714 plectrum 716 | 715 Pickelhaube 717 | 716 picket fence 718 | 717 pickup truck 719 | 718 pier 720 | 719 piggy bank 721 | 720 pill bottle 722 | 721 pillow 723 | 722 ping-pong ball 724 | 723 pinwheel 725 | 724 
pirate ship 726 | 725 drink pitcher 727 | 726 block plane 728 | 727 planetarium 729 | 728 plastic bag 730 | 729 plate rack 731 | 730 farm plow 732 | 731 plunger 733 | 732 Polaroid camera 734 | 733 pole 735 | 734 police van 736 | 735 poncho 737 | 736 pool table 738 | 737 soda bottle 739 | 738 plant pot 740 | 739 potter's wheel 741 | 740 power drill 742 | 741 prayer rug 743 | 742 printer 744 | 743 prison 745 | 744 missile 746 | 745 projector 747 | 746 hockey puck 748 | 747 punching bag 749 | 748 purse 750 | 749 quill 751 | 750 quilt 752 | 751 race car 753 | 752 racket 754 | 753 radiator 755 | 754 radio 756 | 755 radio telescope 757 | 756 rain barrel 758 | 757 recreational vehicle 759 | 758 fishing casting reel 760 | 759 reflex camera 761 | 760 refrigerator 762 | 761 remote control 763 | 762 restaurant 764 | 763 revolver 765 | 764 rifle 766 | 765 rocking chair 767 | 766 rotisserie 768 | 767 eraser 769 | 768 rugby ball 770 | 769 ruler measuring stick 771 | 770 sneaker 772 | 771 safe 773 | 772 safety pin 774 | 773 salt shaker 775 | 774 sandal 776 | 775 sarong 777 | 776 saxophone 778 | 777 scabbard 779 | 778 weighing scale 780 | 779 school bus 781 | 780 schooner 782 | 781 scoreboard 783 | 782 CRT monitor 784 | 783 screw 785 | 784 screwdriver 786 | 785 seat belt 787 | 786 sewing machine 788 | 787 shield 789 | 788 shoe store 790 | 789 shoji screen / room divider 791 | 790 shopping basket 792 | 791 shopping cart 793 | 792 shovel 794 | 793 shower cap 795 | 794 shower curtain 796 | 795 ski 797 | 796 balaclava ski mask 798 | 797 sleeping bag 799 | 798 slide rule 800 | 799 sliding door 801 | 800 slot machine 802 | 801 snorkel 803 | 802 snowmobile 804 | 803 snowplow 805 | 804 soap dispenser 806 | 805 soccer ball 807 | 806 sock 808 | 807 solar thermal collector 809 | 808 sombrero 810 | 809 soup bowl 811 | 810 keyboard space bar 812 | 811 space heater 813 | 812 space shuttle 814 | 813 spatula 815 | 814 motorboat 816 | 815 spider web 817 | 816 spindle 818 | 817 sports car 819 | 818 spotlight 820 | 819 stage 821 | 820 steam locomotive 822 | 821 through arch bridge 823 | 822 steel drum 824 | 823 stethoscope 825 | 824 scarf 826 | 825 stone wall 827 | 826 stopwatch 828 | 827 stove 829 | 828 strainer 830 | 829 tram 831 | 830 stretcher 832 | 831 couch 833 | 832 stupa 834 | 833 submarine 835 | 834 suit 836 | 835 sundial 837 | 836 sunglasses 838 | 837 sunglasses 839 | 838 sunscreen 840 | 839 suspension bridge 841 | 840 mop 842 | 841 sweatshirt 843 | 842 swim trunks / shorts 844 | 843 swing 845 | 844 electrical switch 846 | 845 syringe 847 | 846 table lamp 848 | 847 tank 849 | 848 tape player 850 | 849 teapot 851 | 850 teddy bear 852 | 851 television 853 | 852 tennis ball 854 | 853 thatched roof 855 | 854 front curtain 856 | 855 thimble 857 | 856 threshing machine 858 | 857 throne 859 | 858 tile roof 860 | 859 toaster 861 | 860 tobacco shop 862 | 861 toilet seat 863 | 862 torch 864 | 863 totem pole 865 | 864 tow truck 866 | 865 toy store 867 | 866 tractor 868 | 867 semi-trailer truck 869 | 868 tray 870 | 869 trench coat 871 | 870 tricycle 872 | 871 trimaran 873 | 872 tripod 874 | 873 triumphal arch 875 | 874 trolleybus 876 | 875 trombone 877 | 876 hot tub 878 | 877 turnstile 879 | 878 typewriter keyboard 880 | 879 umbrella 881 | 880 unicycle 882 | 881 upright piano 883 | 882 vacuum cleaner 884 | 883 vase 885 | 884 vaulted or arched ceiling 886 | 885 velvet fabric 887 | 886 vending machine 888 | 887 vestment 889 | 888 viaduct 890 | 889 violin 891 | 890 volleyball 892 | 891 waffle iron 893 | 892 wall clock 894 | 893 
wallet
895 | 894 wardrobe
896 | 895 military aircraft
897 | 896 sink
898 | 897 washing machine
899 | 898 water bottle
900 | 899 water jug
901 | 900 water tower
902 | 901 whiskey jug
903 | 902 whistle
904 | 903 hair wig
905 | 904 window screen
906 | 905 window shade
907 | 906 Windsor tie
908 | 907 wine bottle
909 | 908 airplane wing
910 | 909 wok
911 | 910 wooden spoon
912 | 911 wool
913 | 912 split-rail fence
914 | 913 shipwreck
915 | 914 sailboat
916 | 915 yurt
917 | 916 website
918 | 917 comic book
919 | 918 crossword
920 | 919 traffic or street sign
921 | 920 traffic light
922 | 921 dust jacket
923 | 922 menu
924 | 923 plate
925 | 924 guacamole
926 | 925 consomme
927 | 926 hot pot
928 | 927 trifle
929 | 928 ice cream
930 | 929 popsicle
931 | 930 baguette
932 | 931 bagel
933 | 932 pretzel
934 | 933 cheeseburger
935 | 934 hot dog
936 | 935 mashed potatoes
937 | 936 cabbage
938 | 937 broccoli
939 | 938 cauliflower
940 | 939 zucchini
941 | 940 spaghetti squash
942 | 941 acorn squash
943 | 942 butternut squash
944 | 943 cucumber
945 | 944 artichoke
946 | 945 bell pepper
947 | 946 cardoon
948 | 947 mushroom
949 | 948 Granny Smith apple
950 | 949 strawberry
951 | 950 orange
952 | 951 lemon
953 | 952 fig
954 | 953 pineapple
955 | 954 banana
956 | 955 jackfruit
957 | 956 cherimoya (custard apple)
958 | 957 pomegranate
959 | 958 hay
960 | 959 carbonara
961 | 960 chocolate syrup
962 | 961 dough
963 | 962 meatloaf
964 | 963 pizza
965 | 964 pot pie
966 | 965 burrito
967 | 966 red wine
968 | 967 espresso
969 | 968 tea cup
970 | 969 eggnog
971 | 970 mountain
972 | 971 bubble
973 | 972 cliff
974 | 973 coral reef
975 | 974 geyser
976 | 975 lakeshore
977 | 976 promontory
978 | 977 sandbar
979 | 978 beach
980 | 979 valley
981 | 980 volcano
982 | 981 baseball player
983 | 982 bridegroom
984 | 983 scuba diver
985 | 984 rapeseed
986 | 985 daisy
987 | 986 yellow lady's slipper
988 | 987 corn
989 | 988 acorn
990 | 989 rose hip
991 | 990 horse chestnut seed
992 | 991 coral fungus
993 | 992 agaric
994 | 993 gyromitra
995 | 994 stinkhorn mushroom
996 | 995 earth star fungus
997 | 996 hen of the woods mushroom
998 | 997 bolete
999 | 998 corn cob
1000 | 999 toilet paper
1001 | 
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | 
2 | import os
3 | import json
4 | import pdb
5 | import hydra
6 | import logging
7 | from omegaconf import DictConfig
8 | 
9 | import torch
10 | import statistics
11 | from torch.utils.data import DataLoader
12 | import torch.nn.functional as F
13 | from continuum.metrics import Logger
14 | import random
15 | import numpy as np
16 | from collections import defaultdict
17 | 
18 | from tqdm import tqdm
19 | from continual_clip import utils
20 | from continual_clip.models import load_model, VisionClassifier
21 | from continual_clip.datasets import build_cl_scenarios
22 | from sklearn.cluster import KMeans
23 | from continuum import rehearsal
24 | import copy
25 | from torchvision import transforms
26 | try:
27 |     from torchvision.transforms import InterpolationMode
28 |     BICUBIC = InterpolationMode.BICUBIC
29 | except ImportError:
30 |     from PIL import Image  # fall back to PIL's constant on older torchvision
31 |     BICUBIC = Image.BICUBIC
32 | def intra_cls(logits, y, classes):
33 |     y = y - classes
34 |     logits1 = logits[:, classes:]
35 |     return F.cross_entropy(logits1, y, reduction='none')
36 | 
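# --- editor's note (annotation; not part of the original source) -------------
# intra_cls() is a task-local cross-entropy: `classes` is the index of the
# first class of the current task, so logits are sliced to the new-class
# columns and targets are shifted into that range. For example, with
# initial_increment=10 and increment=10, task 3 starts at class 30; a batch
# with logits of shape (B, 40) and targets in [30, 39] is scored as
# F.cross_entropy(logits[:, 30:], targets - 30), so old-class logits
# contribute neither to the loss value nor to its gradient.
# ------------------------------------------------------------------------------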
37 | def get_finetuning_dataset(dataset, memory, finetuning='balanced', oversample_old=1, task_id=0):
38 |     if finetuning == 'balanced':
39 |         x, y, t = memory.get()
40 | 
41 |     if oversample_old > 1:
42 |         old_indexes = np.where(t < task_id)[0]
43 |         assert len(old_indexes) > 0
44 |         new_indexes = np.where(t >= task_id)[0]
45 | 
46 |         indexes = np.concatenate([
47 |             np.repeat(old_indexes, oversample_old),
48 |             new_indexes
49 |         ])
50 |         x, y, t = x[indexes], y[indexes], t[indexes]
51 | 
52 |     new_dataset = copy.deepcopy(dataset)
53 |     new_dataset._x = x
54 |     new_dataset._y = y
55 |     new_dataset._t = t
56 |     return new_dataset
57 | 
58 | 
59 | def seed_everything(seed=0):
60 |     """Fix all random seeds"""
61 |     random.seed(seed)
62 |     np.random.seed(seed)
63 |     torch.manual_seed(seed)
64 |     torch.cuda.manual_seed_all(seed)
65 |     torch.backends.cudnn.deterministic = True
66 |     os.environ['PYTHONHASHSEED'] = str(seed)
67 | 
68 | def activation(x):
69 |     return torch.exp(-10*(1-x))
70 | 
71 | 
72 | 
73 | def run_class_incremental(cfg, device):
74 | 
75 |     cfg.class_order = utils.get_class_order(os.path.join(cfg.workdir, cfg.class_order))
76 |     model = load_model(cfg, device)
77 | 
78 |     eval_dataset, classes_names = build_cl_scenarios(
79 |         cfg, is_train=False, transforms=model.transforms
80 |     )
81 |     train_dataset, _ = build_cl_scenarios(
82 |         cfg, is_train=True, transforms=model.transforms
83 |     )
84 |     # pdb.set_trace()
85 |     model.classes_names = classes_names
86 |     if cfg.visual_clsf:
87 |         if cfg.model_name == "ViT-L/14":
88 |             vision_clsf = VisionClassifier(768, cfg.increment, activation=None)
89 |         else:
90 |             vision_clsf = VisionClassifier(512, cfg.increment, activation=None)
91 | 
92 | 
93 |     acc_list = []
94 |     metric_logger = Logger(list_subsets=["test"])
95 | 
96 |     p1 = 0
97 |     p2 = 0
98 |     negative_records = 0
99 |     trainable_params = {k: v for k, v in model.named_parameters() if v.requires_grad}
100 |     # pdb.set_trace()
101 |     torch.save(trainable_params, 'ori_params.pth')
102 | 
103 |     if cfg.real_replay:
104 |         memory = rehearsal.RehearsalMemory(
105 |             memory_size=2000,
106 |             herding_method="random"
107 |         )
108 |     for task_id, _ in enumerate(eval_dataset):
109 | 
110 |         # negative_records = 0
111 | 
112 |         torch.cuda.empty_cache()
113 |         if task_id == 0:
114 |             targets_bias = 0
115 |         else:
116 |             targets_bias = cfg.initial_increment + (task_id - 1) * cfg.increment
117 | 
118 |         logging.info(f"Evaluation for task {task_id} has started.")
119 |         model.adaptation(task_id, reset=cfg.reset)
120 | 
121 |         # save the current task's trainable parameters
122 |         trainable_params = {k: v for k, v in model.named_parameters() if v.requires_grad}
123 |         torch.save(trainable_params, 'trainable_params.pth')
124 | 
125 |         trainable_params = torch.load('ori_params.pth')
126 |         model.load_state_dict(trainable_params, strict=False)
127 | 
128 |         # measure the untrained model's mean logits for positive and negative classes
129 |         model.eval()  # switch to evaluation mode
130 |         positive_outputs = []
131 |         negative_outputs = []
132 | 
133 |         val_gap_loader = DataLoader(train_dataset[task_id], batch_size=cfg.train_batch_size, shuffle=True, num_workers=cfg.num_workers)
134 | 
135 |         with torch.no_grad():
136 |             for inputs, targets, t in val_gap_loader:
137 |                 inputs, targets = inputs.to(device), targets.to(device)
138 |                 outputs = model(inputs)
139 | 
140 |                 one_hot_targets = torch.nn.functional.one_hot(targets, outputs.shape[1]).float()
141 |                 positive_outputs.append((outputs * one_hot_targets).sum(dim=1).mean())
142 |                 mask = 1 - one_hot_targets
143 |                 negative_outputs.append(((outputs * mask).sum(dim=1) / mask.sum(dim=1)).mean())
144 |         positive_mean = sum(positive_outputs) / len(positive_outputs)
145 |         negative_mean = sum(negative_outputs) / len(negative_outputs)
146 |         # if task_id == 0:
147 |         negative_records = negative_mean
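# --- editor's note (annotation; not part of the original source) -------------
# The loop above probes the model *before* it trains on the new task: it
# averages the logit of each sample's ground-truth class ("positive") and the
# mean logit over all other classes ("negative"). The bias_logit vector built
# just below is initialised from this gap (negative mean everywhere, positive
# mean in slot 0), apparently intended as a zero-shot calibration of the new
# task's logits; note that it is constructed but not consumed later in this
# function.
# ------------------------------------------------------------------------------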
148 |         # if task_id == 0:
149 |         logit_size = cfg.increment if task_id > 0 else cfg.initial_increment
150 |         bias_logit = torch.full((logit_size,), negative_mean, device=device)
151 |         bias_logit[0] = positive_mean
152 |         # pdb.set_trace()
153 | 
154 |         logging.info(f"positive_records: {positive_mean}")
155 |         logging.info(f"negative_records: {negative_mean}")
156 |         # pdb.set_trace()
157 |         trainable_params = torch.load('trainable_params.pth')
158 |         model.load_state_dict(trainable_params, strict=False)
159 | 
160 |         model.train()
161 |         if task_id > 0 and cfg.real_replay:
162 |             mem_x, mem_y, mem_t = memory.get()
163 |             t_data = train_dataset[task_id]
164 |             t_data.add_samples(mem_x, mem_y, mem_t)
165 |         else:
166 |             t_data = train_dataset[task_id]
167 |         # train_loader = DataLoader(train_dataset[:task_id+1], batch_size=cfg.train_batch_size, shuffle=True, num_workers=cfg.num_workers)
168 |         train_loader = DataLoader(t_data, batch_size=cfg.train_batch_size, shuffle=True, num_workers=cfg.num_workers)
169 | 
170 |         epochs = cfg.epochs
171 | 
172 |         if epochs > 0:
173 |             # keep only the parameters that require grad
174 |             params = filter(lambda p: p.requires_grad, model.parameters())
175 |             optimizer = torch.optim.Adam(params, lr=cfg.lr)
176 |             # optimizer = torch.optim.SGD(params, lr=cfg.lr, momentum=cfg.momentum, weight_decay=cfg.weight_decay)
177 |             scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs, eta_min=cfg.lr*0.01)
178 |             # for name, param in model.named_parameters():
179 |             #     if param.requires_grad:
180 |             #         print(name)
181 |             torch.cuda.empty_cache()
182 |             for i_epoch in range(epochs):
183 | 
184 |                 for batch_i, (inputs, targets, t) in enumerate(train_loader):
185 |                     loss_c = torch.tensor(0.0).to(device)
186 |                     loss = torch.tensor(0.0).to(device)
187 | 
188 |                     replay_loss = torch.tensor(0.0).to(device)
189 |                     torch.cuda.empty_cache()
190 | 
191 | 
192 |                     # targets = targets - targets_bias
193 |                     inputs, targets = inputs.to(device), targets.to(device)
194 | 
195 |                     outputs = model(inputs)
196 |                     # image_f, text_f = model(inputs, return_feature=True)
197 |                     if task_id > 0:
198 |                         if cfg.real_replay:
199 |                             mask_replay = (targets < targets_bias)
200 |                             old_targets = targets[mask_replay].clone()
201 |                             old_outputs = outputs[mask_replay].clone()
202 |                             targets = targets[~mask_replay]
203 |                             outputs = outputs[~mask_replay]
204 |                             replay_loss = intra_cls(old_outputs, old_targets, 0).mean() * 0.1
205 |                         loss_c = intra_cls(outputs, targets, targets_bias).mean() + replay_loss
206 | 
207 |                     else:
208 |                         loss_c = torch.nn.functional.cross_entropy(outputs, targets)
209 |                     loss += loss_c
210 |                     optimizer.zero_grad()
211 |                     loss.backward()
212 |                     optimizer.step()
213 |                     if batch_i % 10 == 0:
214 |                         logging.info(f"Epoch {i_epoch + 1}/{epochs} | Batch {batch_i + 1}/{len(train_loader)} | Loss: {loss.item()} | Loss_c: {loss_c.item()}")
215 |                 scheduler.step()
216 | 
217 | 
218 | 
219 |         # torch.cuda.empty_cache()
220 |         # positive_outputs = []
221 |         # negative_outputs = []
222 |         # with torch.no_grad():
223 |         #     model.eval()
224 |         #     for inputs, targets, t in val_gap_loader:
225 |         #         inputs, targets = inputs.to(device), targets.to(device)
226 |         #         outputs = model(inputs)
227 |         #         # pdb.set_trace()
228 |         #         one_hot_targets = torch.nn.functional.one_hot(targets, outputs.shape[1]).float()
229 |         #         positive_outputs.append((outputs * one_hot_targets).sum(dim=1).mean())
230 |         #         mask = 1 - one_hot_targets
231 |         #         negative_outputs.append(((outputs * mask).sum(dim=1) / mask.sum(dim=1)).mean())
232 |         #     model.train()
233 |         #     positive_mean = sum(positive_outputs) / len(positive_outputs)
234 |         #     negative_mean = sum(negative_outputs) / len(negative_outputs)
235 |         #     all_mean = (sum(positive_outputs) + sum(negative_outputs)) / (len(positive_outputs) + len(negative_outputs))
236 | 
237 |         #     logging.info(f"positive_mean: {positive_mean}")
238 |         #     logging.info(f"negative_mean: {negative_mean}")
239 |         #     torch.cuda.empty_cache()
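# --- editor's note (annotation; not part of the original source) -------------
# When cfg.real_replay is on, a fixed-size rehearsal buffer (continuum's
# RehearsalMemory: 2000 samples, random selection) accumulates raw samples
# from each finished task via memory.add() below; memory.get() returns
# (x, y, t) arrays that are mixed into the next task's training set and into
# the balanced finetuning passes.
# ------------------------------------------------------------------------------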
240 | 
241 |         if cfg.real_replay:
242 |             memory.add(*train_dataset[task_id].get_raw_samples(), None)
243 | 
244 |         if cfg.balance_ft and cfg.real_replay and task_id > 0:
245 |             balance_data = get_finetuning_dataset(t_data, memory, 'balanced')
246 |             balance_loader = DataLoader(balance_data, batch_size=cfg.train_batch_size, shuffle=True, num_workers=cfg.num_workers)
247 |             epochs = cfg.balance_epochs
248 | 
249 |             params = filter(lambda p: p.requires_grad, model.parameters())
250 |             optimizer = torch.optim.Adam(params, lr=cfg.lr*0.01)
251 |             # optimizer = torch.optim.SGD(params, lr=cfg.lr, momentum=cfg.momentum, weight_decay=cfg.weight_decay)
252 |             scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs, eta_min=cfg.lr*0.001)
253 |             for i_epoch in range(epochs):
254 |                 for batch_i, (inputs, targets, t) in enumerate(balance_loader):
255 |                     loss_c = torch.tensor(0.0).to(device)
256 |                     loss = torch.tensor(0.0).to(device)
257 | 
258 |                     replay_loss = torch.tensor(0.0).to(device)
259 |                     torch.cuda.empty_cache()
260 | 
261 |                     inputs, targets = inputs.to(device), targets.to(device)
262 |                     outputs = model(inputs)
263 |                     # image_f, text_f = model(inputs, return_feature=True)
264 |                     loss_c = torch.nn.functional.cross_entropy(outputs, targets)
265 |                     loss += loss_c
266 |                     optimizer.zero_grad()
267 |                     loss.backward()
268 |                     optimizer.step()
269 |                     if batch_i % 10 == 0:
270 |                         logging.info(f"Epoch {i_epoch + 1}/{epochs} | Batch {batch_i + 1}/{len(balance_loader)} | Loss: {loss.item()} | Loss_c: {loss_c.item()}")
271 |                     # break
272 |                 scheduler.step()
273 | 
274 |         # if task_id > 0:
275 |         #     alpha = 0.2 # EMA
276 |         #     print("EMA")
277 |         #     with torch.no_grad():
278 |         #         for name, param in model.named_parameters():
279 |         #             if param.requires_grad and name in initial_params:
280 |         #                 param.copy_(alpha * initial_params[name] + (1 - alpha) * param)
281 |         if cfg.visual_clsf:
282 |             # pdb.set_trace()
283 |             torch.cuda.empty_cache()
284 |             model.eval()
285 |             e_num = cfg.visual_clsf_epochs
286 |             vision_clsf_loader = DataLoader(train_dataset[task_id], batch_size=cfg.visual_clsf_batch_size, shuffle=True, num_workers=cfg.num_workers)
287 |             features_dict = {}
288 |             with torch.no_grad():
289 |                 for inputs, targets, t in vision_clsf_loader:
290 |                     inputs, targets = inputs.to(device), targets.to(device)
291 |                     _, features, __ = model(inputs, test=True, return_feature=True)
292 |                     for feature, target in zip(features, targets):
293 |                         target = target.item()
294 |                         if target not in features_dict:
295 |                             features_dict[target] = []
296 |                         features_dict[target].append(feature.cpu())
297 |             mean_features = []
298 |             for target in sorted(features_dict.keys()):
299 |                 features = torch.stack(features_dict[target])
300 |                 mean_feature = features.mean(dim=0)
301 |                 mean_features.append(mean_feature.unsqueeze(0))
302 |             mean_features = torch.cat(mean_features).to(device)
303 |             if task_id > 0:
304 |                 vision_clsf.add_weight(mean_features)
305 | 
306 |             else:
307 |                 vision_clsf.set_weight(mean_features)
308 | 
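# --- editor's note (annotation; not part of the original source) -------------
# The vision classifier is initialised with class prototypes: each new class's
# weight row is the mean image feature of its training samples (a
# nearest-class-mean style initialisation). set_weight() creates the weight
# matrix on the first task, add_weight() appends rows on later tasks, and the
# gradient steps below then refine the prototypes.
# ------------------------------------------------------------------------------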
309 |             optimizer = torch.optim.Adam(vision_clsf.parameters(), lr=cfg.visual_clsf_lr)
310 |             scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, e_num*len(vision_clsf_loader), eta_min=cfg.visual_clsf_lr*0.01)
311 |             for e in range(e_num):
312 |                 batch_i = -1
313 |                 for inputs, targets, t in vision_clsf_loader:
314 |                     inputs, targets = inputs.to(device), targets.to(device)
315 |                     # pdb.set_trace()
316 |                     with torch.no_grad():
317 |                         outputs, _ = model(inputs, return_feature=True)
318 |                     # pdb.set_trace()
319 |                     outputs = vision_clsf(outputs)
320 |                     # pdb.set_trace()
321 |                     loss = intra_cls(outputs, targets, targets_bias).mean()
322 |                     # loss = F.cross_entropy(outputs, targets)
323 |                     optimizer.zero_grad()
324 |                     loss.backward()
325 |                     optimizer.step()
326 |                     batch_i += 1
327 |                     if batch_i % 10 == 0:
328 |                         logging.info(f"Epoch {e + 1}/{e_num} | Batch {batch_i + 1}/{len(vision_clsf_loader)} | Loss: {loss.item()}")
329 |                 scheduler.step()
330 | 
331 |             if cfg.balance_ft and cfg.real_replay and task_id > 0:
332 |                 balance_data = get_finetuning_dataset(t_data, memory, 'balanced')
333 |                 balance_loader = DataLoader(balance_data, batch_size=cfg.train_batch_size, shuffle=True, num_workers=cfg.num_workers)
334 |                 epochs = cfg.balance_epochs
335 | 
336 |                 optimizer = torch.optim.Adam(vision_clsf.parameters(), lr=cfg.visual_clsf_lr*0.1)
337 |                 # optimizer = torch.optim.SGD(params, lr=cfg.lr, momentum=cfg.momentum, weight_decay=cfg.weight_decay)
338 |                 scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs*len(balance_loader), eta_min=cfg.lr*0.01)
339 |                 for i_epoch in range(epochs):
340 |                     for batch_i, (inputs, targets, t) in enumerate(balance_loader):
341 |                         loss_c = torch.tensor(0.0).to(device)
342 |                         loss = torch.tensor(0.0).to(device)
343 | 
344 |                         replay_loss = torch.tensor(0.0).to(device)
345 |                         torch.cuda.empty_cache()
346 | 
347 |                         inputs, targets = inputs.to(device), targets.to(device)
348 |                         with torch.no_grad():
349 |                             outputs, _ = model(inputs, return_feature=True)
350 |                         # pdb.set_trace()
351 |                         outputs = vision_clsf(outputs)
352 |                         loss_c = torch.nn.functional.cross_entropy(outputs, targets)
353 |                         loss += loss_c
354 |                         optimizer.zero_grad()
355 |                         loss.backward()
356 |                         optimizer.step()
357 |                         if batch_i % 10 == 0:
358 |                             logging.info(f"Epoch {i_epoch + 1}/{epochs} | Batch {batch_i + 1}/{len(balance_loader)} | Loss: {loss.item()} | Loss_c: {loss_c.item()}")
359 |                         # break
360 |                     scheduler.step()
361 | 
362 | 
363 |         if cfg.all_test:
364 |             eval_loader = DataLoader(eval_dataset[:cfg.task_num], batch_size=cfg.batch_size)
365 |         else:
366 |             eval_loader = DataLoader(eval_dataset[:task_id + 1], batch_size=cfg.batch_size)
367 |             # eval_loader = DataLoader(eval_dataset[:10], batch_size=cfg.batch_size)
368 |         image_feature_list = []
369 |         targets_list = []
370 |         model.eval()
371 |         text_feature_list = []
372 |         torch.cuda.empty_cache()
373 |         correct_per_class = defaultdict(int)
374 |         total_per_class = defaultdict(int)
375 |         for inputs, targets, task_ids in eval_loader:
376 |             inputs, targets = inputs.to(device), targets.to(device)
377 | 
378 |             with torch.no_grad():
379 |                 if cfg.visual_clsf:
380 |                     a = 1
381 |                     b = 4
382 | 
383 |                     outputs, image_feature, text_feature = model(inputs, test=True, all_test=cfg.all_test, return_feature=True)
384 |                     vision_outputs = vision_clsf(image_feature)
385 | 
386 |                     outputs_softmax = F.softmax(outputs, dim=1)
387 |                     vision_outputs_softmax = F.softmax(vision_outputs, dim=1)
388 | 
389 |                     combined_outputs = (a*outputs_softmax + b*vision_outputs_softmax) / (a + b)
390 | 
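                    # --- editor's note (annotation; not part of the original source) ---
                    # Inference ensembles the two heads in probability space:
                    #   p = (a * softmax(text_logits) + b * softmax(vision_logits)) / (a + b)
                    # With a = 1 and b = 4 this is a fixed 0.2 / 0.8 blend of the CLIP
                    # text-prototype head and the prototype-initialised vision classifier.
                    # --------------------------------------------------------------------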
391 |                     metric_logger.add([combined_outputs.cpu().argmax(dim=1), targets.cpu(), task_ids], subset="test")
392 |                     preds = combined_outputs.cpu().argmax(dim=1)
393 |                     for l, p in zip(targets.cpu(), preds):
394 |                         label = l.item()
395 |                         total_per_class[label] += 1
396 |                         if l == p:
397 |                             correct_per_class[label] += 1
398 |                 else:
399 |                     outputs = model(inputs, test=True, all_test=cfg.all_test)
400 |                     metric_logger.add([outputs.cpu().argmax(dim=1), targets.cpu(), task_ids], subset="test")
401 |         class_acc = {}
402 |         for clas in total_per_class:
403 |             acc = correct_per_class[clas] / total_per_class[clas]
404 |             class_acc[clas] = acc
405 |         avg_acc = np.mean(list(class_acc.values()))
406 | 
407 | 
408 | 
409 |         acc_list.append(100 * metric_logger.accuracy)
410 |         with open(cfg.log_path, 'a+') as f:
411 |             f.write(json.dumps({
412 |                 'task': task_id,
413 |                 'acc': round(100 * metric_logger.accuracy, 2),
414 |                 'avg_acc': round(100 * metric_logger.average_incremental_accuracy, 2),
415 |                 'forgetting': round(100 * metric_logger.forgetting, 6),
416 |                 'acc_per_task': [round(100 * acc_t, 2) for acc_t in metric_logger.accuracy_per_task],
417 |                 'bwt': round(100 * metric_logger.backward_transfer, 2),
418 |                 'fwt': round(100 * metric_logger.forward_transfer, 2),
419 |             }) + '\n')
420 |         metric_logger.end_task()
421 |     torch.save(model.state_dict(), 'final_model.pth')
422 |     with open(cfg.log_path, 'a+') as f:
423 |         f.write(json.dumps({
424 |             'last': round(acc_list[-1], 2),
425 |             'avg': round(statistics.mean(acc_list), 2)
426 |         }) + '\n')
427 | 
428 | 
429 | 
430 | def run_domain_incremental(cfg, device):
431 | 
432 |     model = load_model(cfg, device)
433 |     eval_dataset, classes_names = build_cl_scenarios(
434 |         cfg, is_train=False, transforms=model.transforms
435 |     )
436 |     model.tokenize(classes_names)
437 | 
438 |     with open(cfg.log_path, 'w+') as f:
439 |         pass
440 | 
441 |     logger = Logger(list_subsets=["test"])
442 |     logging.info(f">>> Evaluation scenario length is {len(eval_dataset)}")
443 |     for task_id, _ in enumerate(eval_dataset):
444 | 
445 |         dataset_val = eval_dataset[:task_id + 1]
446 |         eval_loader = DataLoader(dataset_val, batch_size=cfg.batch_size)
447 |         for input, target, task_ids in tqdm(eval_loader):
448 |             input, target = input.to(device), target.to(device)
449 |             output = torch.from_numpy(model(input))
450 |             logger.add([output.cpu().argmax(dim=1), target.cpu(), task_ids], subset='test')
451 | 
452 |         with open(cfg.log_path, 'a+') as f:
453 |             f.write(json.dumps({
454 |                 'task': task_id,
455 |                 'acc': round(100 * logger.accuracy, 2),
456 |             }) + '\n')
457 | 
458 |         logger.end_task()
459 | 
460 | def run_task_agnostic():
461 |     pass
462 | 
463 | 
464 | 
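# --- editor's note (annotation; not part of the original source) -------------
# A minimal class-incremental config sketch. The field names below are
# inferred from the cfg.* accesses in this file; the values are illustrative
# placeholders only -- the shipped files under configs/class/ are the
# authoritative configs.
#
#   scenario: class
#   model_name: "ViT-B/16"
#   initial_increment: 10       # classes in task 0
#   increment: 10               # classes per subsequent task
#   epochs: 10
#   lr: 1e-3
#   train_batch_size: 128
#   batch_size: 128
#   num_workers: 8
#   seed: 0
#   reset: false
#   real_replay: true
#   balance_ft: true
#   balance_epochs: 2
#   visual_clsf: true
#   visual_clsf_epochs: 10
#   visual_clsf_lr: 1e-3
#   visual_clsf_batch_size: 128
#   all_test: false
#   task_num: 10
#   log_path: "metrics.jsonl"
#   dataset_root: "path"
#   class_order: "class_orders/cifar100.yaml"
# ------------------------------------------------------------------------------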
465 | @hydra.main(config_path=None, config_name=None, version_base="1.1")
466 | def continual_clip(cfg: DictConfig) -> None:
467 |     seed_everything(cfg.seed)
468 |     cfg.workdir = utils.get_workdir(path=os.getcwd())
469 |     cfg.dataset_root = os.path.join(cfg.workdir, cfg.dataset_root)
470 | 
471 |     utils.save_config(cfg)
472 |     with open(cfg.log_path, 'w+') as f:
473 |         pass
474 |     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
475 | 
476 |     if cfg.scenario == "class":
477 |         run_class_incremental(cfg, device)
478 | 
479 |     elif cfg.scenario == "domain":
480 |         run_domain_incremental(cfg, device)
481 | 
482 |     elif cfg.scenario == "task-agnostic":
483 |         raise NotImplementedError("Method has not been implemented yet; it will be added soon.")
484 | 
485 |     else:
486 |         raise ValueError(f"You have entered `{cfg.scenario}`, which is not a defined scenario; "
487 |                          "please choose from {'class', 'domain', 'task-agnostic'}.")
488 | 
489 | 
490 | 
491 | 
492 | 
493 | 
494 | 
495 | 
496 | 
497 | 
498 | 
499 | 
500 | 
501 | 
502 | 
503 | 
504 | 
505 | 
506 | 
507 | 
508 | 
509 | 
510 | 
511 | if __name__ == "__main__":
512 |     continual_clip()
--------------------------------------------------------------------------------
/loraclip/model.py:
--------------------------------------------------------------------------------
1 | # content adapted from CLIP's official github: openai/CLIP
2 | 
3 | import pdb
4 | from .loralib import layers as lora
5 | from .loralib import utils as lora_utils
6 | 
7 | from collections import OrderedDict
8 | from typing import Tuple, Union
9 | 
10 | import numpy as np
11 | import torch
12 | import torch.nn.functional as F
13 | from torch import nn
14 | 
15 | 
16 | class Bottleneck(nn.Module):
17 |     expansion = 4
18 | 
19 |     def __init__(self, inplanes, planes, stride=1):
20 |         super().__init__()
21 | 
22 |         # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
23 |         self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
24 |         self.bn1 = nn.BatchNorm2d(planes)
25 |         self.relu1 = nn.ReLU(inplace=True)
26 | 
27 |         self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
28 |         self.bn2 = nn.BatchNorm2d(planes)
29 |         self.relu2 = nn.ReLU(inplace=True)
30 | 
31 |         self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
32 | 
33 |         self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
34 |         self.bn3 = nn.BatchNorm2d(planes * self.expansion)
35 |         self.relu3 = nn.ReLU(inplace=True)
36 | 
37 |         self.downsample = None
38 |         self.stride = stride
39 | 
40 |         if stride > 1 or inplanes != planes * Bottleneck.expansion:
41 |             # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
42 |             self.downsample = nn.Sequential(OrderedDict([
43 |                 ("-1", nn.AvgPool2d(stride)),
44 |                 ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
45 |                 ("1", nn.BatchNorm2d(planes * self.expansion))
46 |             ]))
47 | 
48 |     def forward(self, x: torch.Tensor):
49 |         identity = x
50 | 
51 |         out = self.relu1(self.bn1(self.conv1(x)))
52 |         out = self.relu2(self.bn2(self.conv2(out)))
53 |         out = self.avgpool(out)
54 |         out = self.bn3(self.conv3(out))
55 | 
56 |         if self.downsample is not None:
57 |             identity = self.downsample(x)
58 | 
59 |         out += identity
60 |         out = self.relu3(out)
61 |         return out
62 | 
63 | 
64 | class AttentionPool2d(nn.Module):
65 |     def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
66 |         super().__init__()
67 |         self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
68 |         self.k_proj = nn.Linear(embed_dim, embed_dim)
69 |         self.q_proj = nn.Linear(embed_dim, embed_dim)
70 |         self.v_proj = nn.Linear(embed_dim, embed_dim)
71 |         self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
72 |         self.num_heads = num_heads
73 | 
74 |     def forward(self, x):
75 |         x = x.flatten(start_dim=2).permute(2, 0, 1)  # NCHW -> (HW)NC
76 |         x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
77 |         x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
78 |         x, _ = F.multi_head_attention_forward(
79 |             query=x[:1], key=x, value=x,
80 |             embed_dim_to_check=x.shape[-1],
81 |             num_heads=self.num_heads,
82 |             q_proj_weight=self.q_proj.weight,
83 |
k_proj_weight=self.k_proj.weight, 84 | v_proj_weight=self.v_proj.weight, 85 | in_proj_weight=None, 86 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 87 | bias_k=None, 88 | bias_v=None, 89 | add_zero_attn=False, 90 | dropout_p=0, 91 | out_proj_weight=self.c_proj.weight, 92 | out_proj_bias=self.c_proj.bias, 93 | use_separate_proj_weight=True, 94 | training=self.training, 95 | need_weights=False 96 | ) 97 | return x.squeeze(0) 98 | 99 | 100 | class ModifiedResNet(nn.Module): 101 | """ 102 | A ResNet class that is similar to torchvision's but contains the following changes: 103 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 104 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 105 | - The final pooling layer is a QKV attention instead of an average pool 106 | """ 107 | 108 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 109 | super().__init__() 110 | self.output_dim = output_dim 111 | self.input_resolution = input_resolution 112 | 113 | # the 3-layer stem 114 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 115 | self.bn1 = nn.BatchNorm2d(width // 2) 116 | self.relu1 = nn.ReLU(inplace=True) 117 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 118 | self.bn2 = nn.BatchNorm2d(width // 2) 119 | self.relu2 = nn.ReLU(inplace=True) 120 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 121 | self.bn3 = nn.BatchNorm2d(width) 122 | self.relu3 = nn.ReLU(inplace=True) 123 | self.avgpool = nn.AvgPool2d(2) 124 | 125 | # residual layers 126 | self._inplanes = width # this is a *mutable* variable used during construction 127 | self.layer1 = self._make_layer(width, layers[0]) 128 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 129 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 130 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 131 | 132 | embed_dim = width * 32 # the ResNet feature dimension 133 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) 134 | 135 | def _make_layer(self, planes, blocks, stride=1): 136 | layers = [Bottleneck(self._inplanes, planes, stride)] 137 | 138 | self._inplanes = planes * Bottleneck.expansion 139 | for _ in range(1, blocks): 140 | layers.append(Bottleneck(self._inplanes, planes)) 141 | 142 | return nn.Sequential(*layers) 143 | 144 | def forward(self, x): 145 | def stem(x): 146 | x = self.relu1(self.bn1(self.conv1(x))) 147 | x = self.relu2(self.bn2(self.conv2(x))) 148 | x = self.relu3(self.bn3(self.conv3(x))) 149 | x = self.avgpool(x) 150 | return x 151 | 152 | x = x.type(self.conv1.weight.dtype) 153 | x = stem(x) 154 | x = self.layer1(x) 155 | x = self.layer2(x) 156 | x = self.layer3(x) 157 | x = self.layer4(x) 158 | x = self.attnpool(x) 159 | 160 | return x 161 | 162 | 163 | class LayerNorm(nn.LayerNorm): 164 | """Subclass torch's LayerNorm to handle fp16.""" 165 | 166 | def forward(self, x: torch.Tensor): 167 | orig_type = x.dtype 168 | ret = super().forward(x.type(torch.float32)) 169 | return ret.type(orig_type) 170 | 171 | 172 | class QuickGELU(nn.Module): 173 | def forward(self, x: torch.Tensor): 174 | return x * torch.sigmoid(1.702 * x) 175 | 176 | 177 | class ResidualAttentionBlock(nn.Module): 178 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 179 | 
super().__init__() 180 | 181 | self.attn = nn.MultiheadAttention(d_model, n_head) 182 | self.ln_1 = LayerNorm(d_model) 183 | self.mlp = nn.Sequential(OrderedDict([ 184 | ("c_fc", nn.Linear(d_model, d_model * 4)), 185 | ("gelu", QuickGELU()), 186 | ("c_proj", nn.Linear(d_model * 4, d_model)) 187 | ])) 188 | self.ln_2 = LayerNorm(d_model) 189 | self.attn_mask = attn_mask 190 | 191 | def attention(self, x: torch.Tensor): 192 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 193 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 194 | 195 | def forward(self, x: torch.Tensor): 196 | x = x + self.attention(self.ln_1(x)) 197 | x = x + self.mlp(self.ln_2(x)) 198 | return x 199 | 200 | 201 | # LoRA implementation of ResidualAttentionBlock: 202 | class LoRAResidualAttentionBlock(nn.Module): 203 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, r=4, only_kv=False,mlp=False): 204 | super().__init__() 205 | 206 | self.attn = lora.MultiheadAttention(d_model, n_head, r=r, only_kv=only_kv, mlp=mlp) # LoRA rank set as 4 207 | self.ln_1 = LayerNorm(d_model) 208 | if only_kv: 209 | self.mlp = nn.Sequential(OrderedDict([ 210 | ("c_fc", nn.Linear(d_model, d_model * 4)), 211 | ("gelu", QuickGELU()), 212 | ("c_proj", nn.Linear(d_model * 4, d_model)) 213 | ])) 214 | else: 215 | self.mlp = nn.Sequential(OrderedDict([ 216 | ("c_fc", lora.Linear(d_model, d_model * 4, r=r)), 217 | ("gelu", QuickGELU()), 218 | ("c_proj", lora.Linear(d_model * 4, d_model, r=r)) 219 | ])) 220 | if mlp: 221 | self.adapter_mlp = nn.Sequential(OrderedDict([ 222 | ("c_fc", nn.Linear(d_model, r, bias=False)), 223 | ("gelu", QuickGELU()), 224 | ("c_proj", nn.Linear(r, d_model, bias=False)) 225 | ])) 226 | nn.init.zeros_(self.adapter_mlp.c_proj.weight) 227 | self.mlp_flag = mlp 228 | self.ln_2 = LayerNorm(d_model) 229 | self.attn_mask = attn_mask 230 | 231 | def attention(self, x: torch.Tensor): 232 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 233 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 234 | 235 | def forward(self, x: torch.Tensor): 236 | x = x + self.attention(self.ln_1(x)) 237 | if self.mlp_flag: 238 | x = x + self.adapter_mlp(self.ln_2(x)) + self.mlp(self.ln_2(x)) 239 | else: 240 | x = x + self.mlp(self.ln_2(x)) 241 | return x 242 | 243 | 244 | class Transformer(nn.Module): 245 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 246 | super().__init__() 247 | self.width = width 248 | self.layers = layers 249 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 250 | 251 | def forward(self, x: torch.Tensor): 252 | return self.resblocks(x) 253 | 254 | 255 | # LoRA implementation of Transformer: 256 | class LoRATransformer(nn.Module): 257 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, r = 4, only_kv=False,mlp=False): 258 | super().__init__() 259 | self.width = width 260 | self.layers = layers 261 | self.resblocks = nn.Sequential(*[LoRAResidualAttentionBlock(width, heads, attn_mask, r=r, only_kv=only_kv, mlp=mlp) for _ in range(layers)]) 262 | 263 | def forward(self, x: torch.Tensor): 264 | return self.resblocks(x) 265 | 266 | 267 | class VisionTransformer(nn.Module): 268 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): 269 | 
super().__init__() 270 | self.input_resolution = input_resolution 271 | self.output_dim = output_dim 272 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 273 | 274 | scale = width ** -0.5 275 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 276 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 277 | self.ln_pre = LayerNorm(width) 278 | 279 | self.transformer = Transformer(width, layers, heads) 280 | 281 | self.ln_post = LayerNorm(width) 282 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 283 | 284 | def forward(self, x: torch.Tensor): 285 | x = self.conv1(x) # shape = [*, width, grid, grid] 286 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 287 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 288 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 289 | x = x + self.positional_embedding.to(x.dtype) 290 | x = self.ln_pre(x) 291 | 292 | x = x.permute(1, 0, 2) # NLD -> LND 293 | x = self.transformer(x) 294 | x = x.permute(1, 0, 2) # LND -> NLD 295 | 296 | x = self.ln_post(x[:, 0, :]) 297 | 298 | if self.proj is not None: 299 | x = x @ self.proj 300 | 301 | return x 302 | 303 | 304 | class LoRAVisionTransformer(nn.Module): 305 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int, r: int, only_kv=False, mlp=False): 306 | super().__init__() 307 | self.input_resolution = input_resolution 308 | self.output_dim = output_dim 309 | if only_kv: 310 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 311 | else: 312 | self.conv1 = lora.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 313 | 314 | scale = width ** -0.5 315 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 316 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 317 | self.ln_pre = LayerNorm(width) 318 | 319 | self.transformer = LoRATransformer(width, layers, heads, only_kv=only_kv, r=r, mlp=mlp) 320 | # self.transformer = Transformer(width, layers, heads) 321 | 322 | self.ln_post = LayerNorm(width) 323 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 324 | 325 | def forward(self, x: torch.Tensor): 326 | x = self.conv1(x) # shape = [*, width, grid, grid] 327 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 328 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 329 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 330 | x = x + self.positional_embedding.to(x.dtype) 331 | x = self.ln_pre(x) 332 | 333 | x = x.permute(1, 0, 2) # NLD -> LND 334 | x = self.transformer(x) 335 | x = x.permute(1, 0, 2) # LND -> NLD 336 | 337 | x = self.ln_post(x[:, 0, :]) 338 | 339 | if self.proj is not None: 340 | x = x @ self.proj 341 | 342 | return x 343 | 344 | 345 | class CLIP(nn.Module): 346 | def __init__(self, 347 | embed_dim: int, 348 | # vision 349 | image_resolution: int, 350 | vision_layers: Union[Tuple[int, int, int, int], int], 351 | vision_width: int, 352 | vision_patch_size: int, 353 | # text 354 | context_length: 
int, 355 | vocab_size: int, 356 | transformer_width: int, 357 | transformer_heads: int, 358 | transformer_layers: int 359 | ): 360 | super().__init__() 361 | 362 | self.context_length = context_length 363 | 364 | if isinstance(vision_layers, (tuple, list)): 365 | vision_heads = vision_width * 32 // 64 366 | self.visual = ModifiedResNet( 367 | layers=vision_layers, 368 | output_dim=embed_dim, 369 | heads=vision_heads, 370 | input_resolution=image_resolution, 371 | width=vision_width 372 | ) 373 | else: 374 | vision_heads = vision_width // 64 375 | self.visual = VisionTransformer( 376 | input_resolution=image_resolution, 377 | patch_size=vision_patch_size, 378 | width=vision_width, 379 | layers=vision_layers, 380 | heads=vision_heads, 381 | output_dim=embed_dim 382 | ) 383 | 384 | self.transformer = Transformer( 385 | width=transformer_width, 386 | layers=transformer_layers, 387 | heads=transformer_heads, 388 | attn_mask=self.build_attention_mask() 389 | ) 390 | 391 | self.vocab_size = vocab_size 392 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 393 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 394 | self.ln_final = LayerNorm(transformer_width) 395 | 396 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 397 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 398 | 399 | self.initialize_parameters() 400 | 401 | def initialize_parameters(self): 402 | nn.init.normal_(self.token_embedding.weight, std=0.02) 403 | nn.init.normal_(self.positional_embedding, std=0.01) 404 | 405 | if isinstance(self.visual, ModifiedResNet): 406 | if self.visual.attnpool is not None: 407 | std = self.visual.attnpool.c_proj.in_features ** -0.5 408 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) 409 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) 410 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) 411 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) 412 | 413 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: 414 | for name, param in resnet_block.named_parameters(): 415 | if name.endswith("bn3.weight"): 416 | nn.init.zeros_(param) 417 | 418 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 419 | attn_std = self.transformer.width ** -0.5 420 | fc_std = (2 * self.transformer.width) ** -0.5 421 | for block in self.transformer.resblocks: 422 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 423 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 424 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 425 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 426 | 427 | if self.text_projection is not None: 428 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 429 | 430 | def build_attention_mask(self): 431 | # lazily create causal attention mask, with full attention between the vision tokens 432 | # pytorch uses additive attention mask; fill with -inf 433 | mask = torch.empty(self.context_length, self.context_length) 434 | mask.fill_(float("-inf")) 435 | mask.triu_(1) # zero out the lower diagonal 436 | return mask 437 | 438 | @property 439 | def dtype(self): 440 | return self.visual.conv1.weight.dtype 441 | 442 | def encode_image(self, image): 443 | return self.visual(image.type(self.dtype)) 444 | 445 | def encode_text(self, text): 446 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, 
d_model] 447 | 448 | x = x + self.positional_embedding.type(self.dtype) 449 | x = x.permute(1, 0, 2) # NLD -> LND 450 | x = self.transformer(x) 451 | x = x.permute(1, 0, 2) # LND -> NLD 452 | x = self.ln_final(x).type(self.dtype) 453 | 454 | # x.shape = [batch_size, n_ctx, transformer.width] 455 | # take features from the eot embedding (eot_token is the highest number in each sequence) 456 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 457 | 458 | return x 459 | 460 | def forward(self, image, text): 461 | image_features = self.encode_image(image) 462 | text_features = self.encode_text(text) 463 | 464 | # normalized features 465 | image_features = image_features / image_features.norm(dim=1, keepdim=True) 466 | text_features = text_features / text_features.norm(dim=1, keepdim=True) 467 | 468 | # cosine similarity as logits 469 | logit_scale = self.logit_scale.exp() 470 | logits_per_image = logit_scale * image_features @ text_features.t() 471 | logits_per_text = logits_per_image.t() 472 | 473 | # shape = [global_batch_size, global_batch_size] 474 | return logits_per_image, logits_per_text 475 | 476 | 477 | class LoRACLIP(nn.Module): 478 | def __init__(self, 479 | embed_dim: int, 480 | # vision 481 | image_resolution: int, 482 | vision_layers: Union[Tuple[int, int, int, int], int], 483 | vision_width: int, 484 | vision_patch_size: int, 485 | # text 486 | context_length: int, 487 | vocab_size: int, 488 | transformer_width: int, 489 | transformer_heads: int, 490 | transformer_layers: int, 491 | r: int, 492 | lora_mode: str 493 | ): 494 | super().__init__() 495 | 496 | self.context_length = context_length 497 | 498 | if isinstance(vision_layers, (tuple, list)): 499 | vision_heads = vision_width * 32 // 64 500 | self.visual = ModifiedResNet( 501 | layers=vision_layers, 502 | output_dim=embed_dim, 503 | heads=vision_heads, 504 | input_resolution=image_resolution, 505 | width=vision_width 506 | ) 507 | else: 508 | vision_heads = vision_width // 64 509 | 510 | if "vision" in lora_mode: 511 | self.visual = LoRAVisionTransformer( 512 | input_resolution=image_resolution, 513 | patch_size=vision_patch_size, 514 | width=vision_width, 515 | layers=vision_layers, 516 | heads=vision_heads, 517 | output_dim=embed_dim, 518 | r=r, 519 | only_kv=("only_kv" in lora_mode), 520 | mlp="mlp" in lora_mode, 521 | ) 522 | else: 523 | self.visual = VisionTransformer( 524 | input_resolution=image_resolution, 525 | patch_size=vision_patch_size, 526 | width=vision_width, 527 | layers=vision_layers, 528 | heads=vision_heads, 529 | output_dim=embed_dim 530 | ) 531 | 532 | if "text" in lora_mode: 533 | self.transformer = LoRATransformer( 534 | width=transformer_width, 535 | layers=transformer_layers, 536 | heads=transformer_heads, 537 | attn_mask=self.build_attention_mask(), 538 | r = r, 539 | only_kv=("only_kv" in lora_mode), 540 | mlp="mlp" in lora_mode, 541 | ) 542 | 543 | else: 544 | self.transformer = Transformer( 545 | width=transformer_width, 546 | layers=transformer_layers, 547 | heads=transformer_heads, 548 | attn_mask=self.build_attention_mask() 549 | ) 550 | 551 | self.vocab_size = vocab_size 552 | 553 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 554 | self.ln_final = LayerNorm(transformer_width) 555 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 556 | 557 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 558 | 559 | if "text" in lora_mode: 560 | if "only_kv" in lora_mode: 561 | 
self.lora_text_projection = nn.Linear(transformer_width, embed_dim, bias=False) 562 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 563 | else: 564 | self.lora_text_projection = lora.Linear(transformer_width, embed_dim, r=r, bias=False) 565 | self.token_embedding = lora.Embedding(vocab_size, transformer_width, r=r) 566 | 567 | else: 568 | self.lora_text_projection = nn.Linear(transformer_width, embed_dim, bias=False) 569 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 570 | 571 | self.initialize_parameters() 572 | 573 | def initialize_parameters(self): 574 | nn.init.normal_(self.token_embedding.weight, std=0.02) 575 | nn.init.normal_(self.positional_embedding, std=0.01) 576 | 577 | if isinstance(self.visual, ModifiedResNet): 578 | if self.visual.attnpool is not None: 579 | std = self.visual.attnpool.c_proj.in_features ** -0.5 580 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) 581 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) 582 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) 583 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) 584 | 585 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: 586 | for name, param in resnet_block.named_parameters(): 587 | if name.endswith("bn3.weight"): 588 | nn.init.zeros_(param) 589 | 590 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 591 | attn_std = self.transformer.width ** -0.5 592 | fc_std = (2 * self.transformer.width) ** -0.5 593 | for block in self.transformer.resblocks: 594 | pass 595 | # nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 596 | # nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 597 | # nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 598 | # nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 599 | 600 | if self.text_projection is not None: 601 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 602 | 603 | def build_attention_mask(self): 604 | # lazily create causal attention mask, with full attention between the vision tokens 605 | # pytorch uses additive attention mask; fill with -inf 606 | mask = torch.empty(self.context_length, self.context_length) 607 | mask.fill_(float("-inf")) 608 | mask.triu_(1) # zero out the lower diagonal 609 | return mask 610 | 611 | @property 612 | def dtype(self): 613 | return self.visual.conv1.weight.dtype 614 | 615 | def encode_image(self, image): 616 | return self.visual(image.type(self.dtype)) 617 | 618 | def encode_text(self, text): 619 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 620 | 621 | x = x + self.positional_embedding.type(self.dtype) 622 | x = x.permute(1, 0, 2) # NLD -> LND 623 | x = self.transformer(x) 624 | x = x.permute(1, 0, 2) # LND -> NLD 625 | x = self.ln_final(x).type(self.dtype) 626 | 627 | # x.shape = [batch_size, n_ctx, transformer.width] 628 | # take features from the eot embedding (eot_token is the highest number in each sequence) 629 | # x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 630 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] 631 | x = self.lora_text_projection(x) 632 | return x 633 | 634 | def forward(self, image, text): 635 | image_features = self.encode_image(image) 636 | text_features = self.encode_text(text) 637 | 638 | # normalized features 639 | image_features = image_features / image_features.norm(dim=1, keepdim=True) 640 | text_features = 
text_features / text_features.norm(dim=1, keepdim=True) 641 | 642 | # cosine similarity as logits 643 | logit_scale = self.logit_scale.exp() 644 | logits_per_image = logit_scale * image_features @ text_features.t() 645 | logits_per_text = logits_per_image.t() 646 | 647 | # shape = [global_batch_size, global_batch_size] 648 | return logits_per_image, logits_per_text 649 | 650 | def convert_weights(model: nn.Module): 651 | """Convert applicable model parameters to fp16""" 652 | 653 | def _convert_weights_to_fp16(l): 654 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 655 | l.weight.data = l.weight.data.half() 656 | if l.bias is not None: 657 | l.bias.data = l.bias.data.half() 658 | 659 | if isinstance(l, nn.MultiheadAttention): 660 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 661 | tensor = getattr(l, attr) 662 | if tensor is not None: 663 | tensor.data = tensor.data.half() 664 | 665 | for name in ["text_projection", "proj"]: 666 | if hasattr(l, name): 667 | attr = getattr(l, name) 668 | if attr is not None: 669 | attr.data = attr.data.half() 670 | 671 | model.apply(_convert_weights_to_fp16) 672 | 673 | def convert_weights_lora_model(model: nn.Module): 674 | """Convert applicable model parameters to fp16""" 675 | 676 | def _convert_weights_to_fp16(l): 677 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 678 | l.weight.data = l.weight.data.half() 679 | if l.bias is not None: 680 | l.bias.data = l.bias.data.half() 681 | 682 | if isinstance(l, lora.MultiheadAttention): 683 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v", *[f"{s}_proj_weight_lora_A" for s in ["in", "q", "k", "v"]], *[f"{s}_proj_weight_lora_B" for s in ["in", "q", "k", "v"]]]: 684 | tensor = getattr(l, attr) 685 | if tensor is not None: 686 | tensor.data = tensor.data.half() 687 | 688 | for name in ["text_projection", "proj"]: 689 | if hasattr(l, name): 690 | attr = getattr(l, name) 691 | if attr is not None: 692 | attr.data = attr.data.half() 693 | 694 | model.apply(_convert_weights_to_fp16) 695 | 696 | 697 | def build_model(state_dict: dict): 698 | vit = "visual.proj" in state_dict 699 | 700 | if vit: 701 | vision_width = state_dict["visual.conv1.weight"].shape[0] 702 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 703 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 704 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 705 | image_resolution = vision_patch_size * grid_size 706 | else: 707 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 708 | vision_layers = tuple(counts) 709 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 710 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 711 | vision_patch_size = None 712 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 713 | image_resolution = output_width * 32 714 | 715 | embed_dim = state_dict["text_projection"].shape[1] 716 | context_length = state_dict["positional_embedding"].shape[0] 717 | vocab_size = state_dict["token_embedding.weight"].shape[0] 718 | transformer_width = state_dict["ln_final.weight"].shape[0] 719 | transformer_heads = transformer_width // 64 720 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if 
k.startswith("transformer.resblocks"))) 721 | 722 | model = CLIP( 723 | embed_dim, 724 | image_resolution, vision_layers, vision_width, vision_patch_size, 725 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers 726 | ) 727 | 728 | for key in ["input_resolution", "context_length", "vocab_size"]: 729 | if key in state_dict: 730 | del state_dict[key] 731 | 732 | convert_weights(model) 733 | model.load_state_dict(state_dict) 734 | return model.eval() 735 | 736 | def build_LoRA_model(state_dict: dict, r: int, lora_mode: str): 737 | vit = "visual.proj" in state_dict 738 | 739 | if vit: 740 | vision_width = state_dict["visual.conv1.weight"].shape[0] 741 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 742 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 743 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 744 | image_resolution = vision_patch_size * grid_size 745 | else: 746 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 747 | vision_layers = tuple(counts) 748 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 749 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 750 | vision_patch_size = None 751 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 752 | image_resolution = output_width * 32 753 | if "only_kv" in lora_mode: 754 | for name, param in list(state_dict.items()): 755 | if "transformer.resblocks." in name and 'in_proj_weight' in name: 756 | # pdb.set_trace() 757 | shape_chunk = param.shape[0]//3 758 | state_dict[name.replace("in_proj_weight", "q_proj_weight")] = param[:1*shape_chunk].clone() 759 | state_dict[name.replace("in_proj_weight", "k_proj_weight")] = param[1*shape_chunk:2*shape_chunk].clone() 760 | state_dict[name.replace("in_proj_weight", "v_proj_weight")] = param[2*shape_chunk:3*shape_chunk].clone() 761 | del state_dict[name] 762 | # if "transformer.resblocks" in name and 'in_proj_weight' in name: 763 | # shape_chunk = param.shape[0]//3 764 | # state_dict[name.replace("in_proj_weight", "q_proj_weight")] = param[:1*shape_chunk].clone() 765 | # state_dict[name.replace("in_proj_weight", "k_proj_weight")] = param[1*shape_chunk:2*shape_chunk].clone() 766 | # state_dict[name.replace("in_proj_weight", "v_proj_weight")] = param[2*shape_chunk:3*shape_chunk].clone() 767 | # del state_dict[name] 768 | embed_dim = state_dict["text_projection"].shape[1] 769 | context_length = state_dict["positional_embedding"].shape[0] 770 | vocab_size = state_dict["token_embedding.weight"].shape[0] 771 | transformer_width = state_dict["ln_final.weight"].shape[0] 772 | transformer_heads = transformer_width // 64 773 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks"))) 774 | 775 | model = LoRACLIP( 776 | embed_dim, 777 | image_resolution, vision_layers, vision_width, vision_patch_size, 778 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers, 779 | r, lora_mode 780 | ) 781 | 782 | for key in ["input_resolution", "context_length", "vocab_size"]: 783 | if key in state_dict: 784 | del state_dict[key] 785 | 786 | new_state_dict = state_dict 787 | new_state_dict["lora_text_projection.weight"] = state_dict["text_projection"].T 788 | # pdb.set_trace() 789 | 790 | res = 
model.load_state_dict(new_state_dict, strict=False)
791 |     missing_keys = res.missing_keys
792 |     unexpected_keys = res.unexpected_keys
793 |     missing_keys = [x for x in missing_keys if 'lora_' not in x]  # ignore LoRA extra weights
794 | 
795 |     print("Model loaded")
796 |     if len(missing_keys) != 0:
797 |         print(f"Missing keys: {missing_keys}")
798 | 
799 |     if len(unexpected_keys) != 0:
800 |         print(f"Unexpected keys: {unexpected_keys}")
801 | 
802 |     print(" ")
803 | 
804 |     # here we mark only LoRA parameters as trainable
805 |     # for name, param in model.named_parameters():
806 |     #     if param.requires_grad:
807 |     #         print(name)
808 |     lora_utils.mark_only_lora_as_trainable(model)
809 |     # freeze the text projection (positional embeddings etc. are handled by the commented block below)
810 |     ##---------------------------------------------------------
811 |     ##
812 | 
813 |     for name, param in model.named_parameters():
814 |         # if param.requires_grad:
815 |         #     print(name)
816 |         # pdb.set_trace()
817 |         if 'lora_text_projection' in name:
818 |             param.requires_grad = False
819 |     # some caveats for loading a model for fine-tuning:
820 |     # if "text" in lora_mode:
821 |     #     for name, param in model.named_parameters():
822 |     #         if "positional_embedding" in name:
823 |     #             param.requires_grad = True
824 |     #         if "text_projection" in name:
825 |     #             param.requires_grad = True
826 |     #         if "logit_scale" in name:
827 |     #             param.requires_grad = True
828 | 
829 |     # if "vision" in lora_mode:
830 |     #     for name, param in model.named_parameters():
831 |     #         if "visual.proj" in name:
832 |     #             param.requires_grad = True
833 |     #         if "visual.class_embedding" in name:
834 |     #             param.requires_grad = True
835 |     #         if "visual.positional_embedding" in name:
836 |     #             param.requires_grad = True
837 |     ##
838 |     ## ---------------------------------------------------------
839 |     # convert_weights_lora_model(model)
840 |     return model.eval()
--------------------------------------------------------------------------------
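Editor's note: a minimal usage sketch for build_LoRA_model (an illustration, not a file in this repository). It assumes the OpenAI `clip` package installed by setup_environment.sh and the repository root on PYTHONPATH; everything other than build_LoRA_model and the substring-matched lora_mode flags ("vision", "text", "only_kv", "mlp") is standard PyTorch/CLIP usage.

    # hypothetical_usage.py -- sketch only, not shipped with the repo
    import clip
    import torch

    from loraclip.model import build_LoRA_model

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load pretrained CLIP weights on CPU to obtain a plain fp32 state dict.
    base, _ = clip.load("ViT-B/16", device="cpu", jit=False)

    # r is the LoRA rank; lora_mode selects where adapters are injected and is
    # matched by substring, so "vision+text" enables adapters in both towers.
    model = build_LoRA_model(base.state_dict(), r=4, lora_mode="vision+text").to(device)

    # build_LoRA_model already called mark_only_lora_as_trainable(), so only the
    # LoRA tensors (minus the frozen lora_text_projection) still require grad.
    trainable = [n for n, p in model.named_parameters() if p.requires_grad]
    print(f"{len(trainable)} trainable tensors")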