├── requirements.txt
├── src
│   ├── models
│   │   ├── utils.py
│   │   └── clip.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── wandb.py
│   │   ├── metrics.py
│   │   ├── config.py
│   │   └── dist_utils.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── core_dataset.py
│   │   ├── transform.py
│   │   ├── base.py
│   │   └── utils.py
│   └── trainer
│       ├── __init__.py
│       ├── utils.py
│       ├── snd_trainer.py
│       └── base_trainer.py
├── configs
│   ├── inference_config.yaml
│   ├── snd_config_1_gpu.yaml
│   └── snd_config_4_gpus.yaml
├── .pre-commit-config.yaml
├── main
│   ├── evaluate.py
│   └── train.py
├── scripts
│   ├── inference.py
│   ├── continually_train.py
│   ├── train_and_eval.py
│   └── utils.py
├── LICENSE
├── .gitignore
├── visualization
│   └── plot.py
└── README.md
/requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib==3.10.0 2 | pandas==2.2.3 3 | numpy==1.26.3 4 | pillow==10.3.0 5 | timm==0.9.12 6 | wandb==0.16.3 7 | omegaconf==2.3.0 8 | open_clip_torch==2.24.0 -------------------------------------------------------------------------------- /src/models/utils.py: -------------------------------------------------------------------------------- 1 | def disabled_train(self, mode=True): 2 | """Overwrite model.train with this function to make sure train/eval mode 3 | does not change anymore.""" 4 | return self 5 | -------------------------------------------------------------------------------- /configs/inference_config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | vit_base: ViT-B-16 3 | pretrained: openai 4 | use_pure_clip: True 5 | 6 | data: 7 | name: fgvc-aircraft 8 | root: /work/chu980802/data/classification 9 | split: 10 | test: 11 | split_name: test 12 | batch_size: 256 13 | shuffle: False 14 | drop_last: False 15 | 16 | task: 17 | seed: 1102 18 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import dump_config, flatten_config, get_config 2 | from .dist_utils import ( 3 | get_job_id, 4 | get_rank, 5 | get_world_size, 6 | init_distributed_mode, 7 | main_process, 8 | setup_seeds, 9 | is_main_process, 10 | ) 11 | from .metrics import AccuracyMeter 12 | from .wandb import local_logger, wandb_logger 13 | -------------------------------------------------------------------------------- /src/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .core_dataset import ( 2 | DTD, 3 | UCF101, 4 | Caltech101, 5 | EuroSAT, 6 | FGVCAircraft, 7 | Flowers102, 8 | Food101, 9 | ImageNet, 10 | OxfordPets, 11 | StanfordCars, 12 | ) 13 | 14 | DATASET_MAPPING = { 15 | "dtd": DTD, 16 | "ucf-101": UCF101, 17 | "caltech-101": Caltech101, 18 | "eurosat": EuroSAT, 19 | "fgvc-aircraft": FGVCAircraft, 20 | "flowers-102": Flowers102, 21 | "food-101": Food101, 22 | "imagenet": ImageNet, 23 | "oxford-pets": OxfordPets, 24 | "stanford-cars": StanfordCars, 25 | } 26 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v4.0.1 6 | hooks: 7 | - id: check-yaml 8 | - id: end-of-file-fixer 9 | - id: trailing-whitespace 10 | - repo: https://github.com/PyCQA/isort 11 | rev: 5.11.5 12 | hooks: 13 | - id: isort 14 | args: ["--profile", "black"] 15 | - repo: https://github.com/ambv/black 16 | rev: 22.3.0 17 | hooks: 18 | - id: black 19 | entry: bash -c 'python -m black "$@"; git add -u' -- 20 | - repo: https://github.com/pycqa/flake8 21 | rev: 3.9.2 22 | hooks: 23 | - id: flake8 24 | entry: bash -c 'python -m flake8 "$@"; git add -u' -- 25 | -------------------------------------------------------------------------------- /main/evaluate.py: -------------------------------------------------------------------------------- 1 | from src.datasets.utils import get_dataloaders_from_config, load_class_name_list 2 | from src.models.clip import get_model 3 | from src.trainer import BaseTrainer as Trainer 4 | from src.utils import get_config, setup_seeds 5 | 6 | 7 | def main(config): 8 | setup_seeds(config.task.seed) 9 | 10 | class_name_list, num_classes_accumulation_dict = load_class_name_list(config) 11 | 12 | model = get_model( 13 | config, class_name_list, device="cuda", freeze=True, pretrained=False 14 | ) 15 | 16 | dataloaders = get_dataloaders_from_config(config, num_classes_accumulation_dict) 17 | 18 | trainer = Trainer(model, dataloaders, config) 19 | 20 | trainer.logging( 21 | local_desc="zero shot", 22 | test_acc=trainer.evaluate(trainer.test_loader), 23 | use_wandb=False, 24 | ) 25 | trainer.dump_results(print_result=True) 26 | 27 | 28 | if __name__ == "__main__": 29 | config = get_config(mode="evaluate") 30 | main(config) 31 | -------------------------------------------------------------------------------- /src/utils/wandb.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | 3 | import wandb 4 | from src.utils import is_main_process 5 | from src.utils.config import flatten_config 6 | 7 | 8 | def print_text_in_center_with_border(text, symbol="="): 9 | terminal_width, _ = shutil.get_terminal_size() 10 | padding = (terminal_width - len(text)) // 2 11 | 12 | print(symbol * padding + text + symbol * padding) 13 | 14 | 15 | def wandb_logger(func): 16 | def wrap(config): 17 | if is_main_process(): 18 | wandb.init( 19 | project=config.data.name, 20 | name=config.data.name, 21 | config=flatten_config(config), 22 | ) 23 | func(config) 24 | wandb.finish() 25 | else: 26 | func(config) 27 | 28 | return wrap 29 | 30 | 31 | def local_logger(func): 32 | def wrap(config): 33 | print() 34 | print_text_in_center_with_border(f" Dataset: {config.data.name} ") 35 | print() 36 | func(config) 37 | print() 38 | print_text_in_center_with_border(" Done! ") 39 | 40 | return wrap 41 | -------------------------------------------------------------------------------- /scripts/inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | from scripts.utils import ( 5 | DEFAULT_DATASET_SEQ, 6 | eval_on_multiple_datasets_script, 7 | ) 8 | 9 | 10 | def main(args): 11 | args.dataset_seq = ( 12 | DEFAULT_DATASET_SEQ if args.dataset_seq is None else args.dataset_seq.split(",") 13 | ) 14 | 15 | eval_on_multiple_datasets_script( 16 | datasets=args.dataset_seq, 17 | pretrained_model_path=args.model_path, 18 | dump_result_path=args.model_path.parent / "eval_results.json", 19 | ) 20 | 21 | 22 | if __name__ == "__main__": 23 | p = argparse.ArgumentParser() 24 | p.add_argument( 25 | "--dataset_seq", 26 | type=str, 27 | default=None, 28 | help="the sequence of evaluation datasets, split by comma. 
Do not set it if you want to evaluate on all of our default datasets", 29 | ) 30 | p.add_argument( 31 | "--model_path", 32 | type=Path, 33 | required=True, 34 | help="specify the path of the model to evaluate", 35 | ) 36 | args = p.parse_args() 37 | 38 | main(args) 39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Yu-Chu Yu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from src.datasets.utils import get_dataloader, load_transform 2 | 3 | from .base_trainer import BaseTrainer, BaseKDTrainer 4 | from .snd_trainer import SnDTrainer 5 | 6 | TRAINER_MAPPING = { 7 | "snd": SnDTrainer, 8 | } 9 | 10 | 11 | def get_kd_trainer(model, dataloaders, config, teacher_models, job_id=None): 12 | if "ref_dataset" in config.method: 13 | train_transform, _ = load_transform(config) 14 | dataset_name, dataloader_config = ( 15 | config.method.ref_dataset, 16 | config.method.ref_dataset_config, 17 | ) 18 | 19 | dataloaders["ref"] = get_dataloader( 20 | dataset_name=dataset_name, 21 | root=config.data.root, 22 | mode=dataloader_config.split_name, 23 | transform=train_transform, 24 | seed=config.task.seed, 25 | distributed=config.task.get("distributed", False), 26 | **dataloader_config, 27 | ) 28 | 29 | meta_trainer_class = TRAINER_MAPPING.get(config.method.name, BaseKDTrainer) 30 | 31 | return meta_trainer_class(model, dataloaders, config, teacher_models, job_id) 32 | -------------------------------------------------------------------------------- /configs/snd_config_1_gpu.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | vit_base: ViT-B-16 3 | pretrained: openai 4 | use_pure_clip: True 5 | freeze_classification_head: True 6 | 7 | method: 8 | name: snd 9 | params: 10 | ratio_prev: 9 11 | ratio_pretrained: 0.5 12 | threshold: 0.2 13 | scale: 6 14 | label_smoothing: 0.0 15 | ref_dataset: imagenet 16 | ref_dataset_config: 17 | split_name: train 18 | batch_size: 64 19 | shuffle: True 20 | drop_last: True 21 | sample_num: 100000 22 | 23 | data: 24 | name: fgvc-aircraft 25 | root: /work/chu980802/data/classification 26 | split: 27 | train: 28 | split_name: train 29 | 
batch_size: 32 30 | shuffle: True 31 | drop_last: True 32 | val: 33 | split_name: val 34 | batch_size: 256 35 | shuffle: False 36 | drop_last: False 37 | test: 38 | split_name: test 39 | batch_size: 256 40 | shuffle: False 41 | drop_last: False 42 | 43 | task: 44 | # fine-tuning arguments 45 | init_lrs: 0.00001 46 | weight_decay: 0.0005 47 | seed: 1102 48 | max_epoch: 10 49 | max_iterations: 1000 50 | warmup_length: 0 51 | 52 | output_dir: "outputs" 53 | log_interval: 10 54 | 55 | world_size: 1 56 | dist_url: "env://" 57 | distributed: False 58 | use_dist_eval_sampler: False 59 | -------------------------------------------------------------------------------- /configs/snd_config_4_gpus.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | vit_base: ViT-B-16 3 | pretrained: openai 4 | use_pure_clip: True 5 | freeze_classification_head: True 6 | 7 | method: 8 | name: snd 9 | params: 10 | ratio_prev: 9 11 | ratio_pretrained: 0.5 12 | threshold: 0.2 13 | scale: 6 14 | label_smoothing: 0.0 15 | ref_dataset: imagenet 16 | ref_dataset_config: 17 | split_name: train 18 | batch_size: 16 19 | shuffle: True 20 | drop_last: True 21 | sample_num: 100000 22 | 23 | data: 24 | name: fgvc-aircraft 25 | root: /work/chu980802/data/classification 26 | split: 27 | train: 28 | split_name: train 29 | batch_size: 8 30 | shuffle: True 31 | drop_last: True 32 | val: 33 | split_name: val 34 | batch_size: 256 35 | shuffle: False 36 | drop_last: False 37 | test: 38 | split_name: test 39 | batch_size: 256 40 | shuffle: False 41 | drop_last: False 42 | 43 | task: 44 | # fine-tuning arguments 45 | init_lrs: 0.00001 46 | weight_decay: 0.0005 47 | seed: 1102 48 | max_epoch: 10 49 | max_iterations: 1000 50 | warmup_length: 0 51 | 52 | output_dir: "outputs" 53 | log_interval: 10 54 | 55 | world_size: 1 56 | dist_url: "env://" 57 | distributed: True 58 | use_dist_eval_sampler: False 59 | -------------------------------------------------------------------------------- /src/utils/metrics.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import numpy 4 | import torch 5 | 6 | 7 | @dataclass 8 | class AccuracyMeter: 9 | num_correct: int = 0 10 | num_total: int = 0 11 | 12 | def _astype(self, value): 13 | match value: 14 | case int() | float(): 15 | return AccuracyMeter(value, 1) 16 | case numpy.ndarray(): 17 | return AccuracyMeter(value.sum(), value.shape[0]) 18 | case torch.Tensor(): 19 | return AccuracyMeter(value.sum().item(), value.shape[0]) 20 | case AccuracyMeter(): 21 | return value 22 | case _: 23 | raise TypeError(f"Unsupported type: {type(value)}") 24 | 25 | def __add__(self, other): 26 | other = self._astype(other) 27 | return AccuracyMeter( 28 | self.num_correct + other.num_correct, self.num_total + other.num_total 29 | ) 30 | 31 | def __radd__(self, other): 32 | return self.__add__(other) 33 | 34 | def acc(self): 35 | return self.num_correct / self.num_total 36 | 37 | 38 | if __name__ == "__main__": 39 | import numpy as np 40 | import torch 41 | 42 | bool_list = [True, False, True, False] 43 | 44 | scores = AccuracyMeter() 45 | scores += np.array(bool_list) 46 | scores += torch.tensor(bool_list) 47 | scores += 1 48 | scores += 0.0 49 | 50 | # expected value: 5 / 10 = 0.5 51 | print(scores.acc()) 52 | -------------------------------------------------------------------------------- /main/train.py: -------------------------------------------------------------------------------- 1 | from 
src.datasets.utils import get_dataloaders_from_config, load_class_name_list 2 | from src.models.clip import get_model 3 | from src.trainer import get_kd_trainer 4 | from src.utils import ( 5 | get_config, 6 | get_job_id, 7 | init_distributed_mode, 8 | is_main_process, 9 | setup_seeds, 10 | wandb_logger, 11 | ) 12 | 13 | 14 | @wandb_logger 15 | def main(config): 16 | job_id = get_job_id() if is_main_process() else None 17 | setup_seeds(config.task.seed) 18 | 19 | class_name_list, num_classes_accumulation_dict = load_class_name_list(config) 20 | 21 | model = get_model( 22 | config, class_name_list, device="cuda", freeze=False, pretrained=False 23 | ) 24 | 25 | dataloaders = get_dataloaders_from_config(config, num_classes_accumulation_dict) 26 | 27 | teachers = dict() 28 | teachers["pretrained"] = get_model( 29 | config, class_name_list, device="cuda", freeze=True, pretrained=True 30 | ) 31 | 32 | if config.method.name == "snd": 33 | teachers["prev"] = get_model( 34 | config, class_name_list, device="cuda", freeze=True, pretrained=False 35 | ) 36 | 37 | trainer = get_kd_trainer(model, dataloaders, config, teachers, job_id) 38 | 39 | trainer.train(set_validation=False) 40 | 41 | trainer.logging( 42 | local_desc="fine-tuned", test_acc=trainer.evaluate(trainer.test_loader) 43 | ) 44 | trainer.dump_results() 45 | 46 | 47 | if __name__ == "__main__": 48 | config = get_config(mode="train") 49 | init_distributed_mode(config.task) 50 | main(config) 51 | -------------------------------------------------------------------------------- /src/utils/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | from omegaconf import OmegaConf 5 | 6 | 7 | class Config: 8 | def __init__(self, args, mode="train"): 9 | self.config = OmegaConf.merge( 10 | OmegaConf.load(args.cfg_path), 11 | self._build_user_config(args.options), 12 | {"mode": mode}, 13 | ) 14 | 15 | def _build_user_config(self, opts): 16 | return OmegaConf.from_dotlist([] if opts is None else opts) 17 | 18 | 19 | def parse_args(): 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument( 22 | "--cfg-path", 23 | default=f"configs/inference_config.yaml", 24 | help="path to configuration file.", 25 | ) 26 | parser.add_argument( 27 | "--options", 28 | nargs="+", 29 | help="override some settings in the used config, the key-value pair " 30 | "in xxx=yyy format will be merged into config file (deprecate), " 31 | "change to --cfg-options instead.", 32 | ) 33 | 34 | args = parser.parse_args() 35 | 36 | return args 37 | 38 | 39 | def flatten_config(d, parent_key="", sep="."): 40 | items = [] 41 | for k, v in d.items(): 42 | new_key = f"{parent_key}{sep}{k}" if parent_key else k 43 | if isinstance(v, dict): 44 | items.extend(flatten_config(v, new_key, sep=sep).items()) 45 | else: 46 | items.append((new_key, v)) 47 | return dict(items) 48 | 49 | 50 | def dump_config(config, path, flatten=False): 51 | dict_config = OmegaConf.to_container(config) 52 | if flatten: 53 | dict_config = flatten_config(dict_config) 54 | with open(path, "w") as f: 55 | json.dump(dict_config, f, indent=4) 56 | 57 | 58 | def get_config(mode="train"): 59 | return Config(parse_args(), mode=mode).config 60 | -------------------------------------------------------------------------------- /src/datasets/core_dataset.py: -------------------------------------------------------------------------------- 1 | from src.datasets.base import BaseClassificationDataset, BaseUnlabeledDataset 2 | 3 | 4 | class 
ImageNet(BaseClassificationDataset): 5 | dataset_name = "imagenet" 6 | annotation_filename = "imagenet_annotations.json" 7 | 8 | 9 | class Caltech101(BaseClassificationDataset): 10 | dataset_name = "caltech-101" 11 | annotation_filename = "caltech101_annotations.json" 12 | 13 | 14 | class OxfordPets(BaseClassificationDataset): 15 | dataset_name = "oxford-pets" 16 | annotation_filename = "oxfordpets_annotations.json" 17 | 18 | 19 | class StanfordCars(BaseClassificationDataset): 20 | dataset_name = "stanford-cars" 21 | annotation_filename = "stanfordcars_annotations.json" 22 | 23 | 24 | class Flowers102(BaseClassificationDataset): 25 | dataset_name = "flowers-102" 26 | annotation_filename = "flowers102_annotations.json" 27 | 28 | 29 | class Food101(BaseClassificationDataset): 30 | dataset_name = "food-101" 31 | annotation_filename = "food101_annotations.json" 32 | 33 | 34 | class FGVCAircraft(BaseClassificationDataset): 35 | dataset_name = "fgvc-aircraft" 36 | annotation_filename = "aircraft_annotations.json" 37 | 38 | 39 | class EuroSAT(BaseClassificationDataset): 40 | dataset_name = "eurosat" 41 | annotation_filename = "eurosat_annotations.json" 42 | 43 | 44 | class UCF101(BaseClassificationDataset): 45 | dataset_name = "ucf-101" 46 | annotation_filename = "ucf101_annotations.json" 47 | 48 | 49 | class DTD(BaseClassificationDataset): 50 | dataset_name = "dtd" 51 | annotation_filename = "dtd_annotations.json" 52 | 53 | 54 | if __name__ == "__main__": 55 | dataset = DTD(root="/work/chu980802/data/classification", mode="train") 56 | print(dataset[0]) 57 | print(len(dataset)) 58 | print(len(dataset.class_name_list)) 59 | -------------------------------------------------------------------------------- /scripts/continually_train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from collections import deque 3 | from copy import deepcopy 4 | from pathlib import Path 5 | 6 | from scripts.utils import DEFAULT_DATASET_SEQ, ContinualTrainer 7 | 8 | 9 | def parse_dataset_seq(args): 10 | dataset_seq = deque(deepcopy(DEFAULT_DATASET_SEQ)) 11 | dataset_seq.rotate(len(DEFAULT_DATASET_SEQ) - args.order) 12 | sub_output_dir = f"order_{args.order}" 13 | 14 | return dataset_seq, sub_output_dir 15 | 16 | 17 | def main(args): 18 | dataset_seq, sub_output_dir = parse_dataset_seq(args) 19 | 20 | continual_trainer = ContinualTrainer( 21 | config_path=args.config_path, 22 | training_dataset_seq=dataset_seq, 23 | sub_output_dir=sub_output_dir, 24 | output_root=args.output_root, 25 | max_iterations=args.max_iterations, 26 | distributed=args.distributed, 27 | nproc_per_node=args.nproc_per_node, 28 | method_config=args.method_config, 29 | ) 30 | 31 | continual_trainer.train_and_eval() 32 | 33 | 34 | if __name__ == "__main__": 35 | p = argparse.ArgumentParser() 36 | p.add_argument( 37 | "--config_path", type=str, default="configs/snd_config_4_gpus.yaml" 38 | ) 39 | 40 | p.add_argument("--max_iterations", type=int, default=1000) 41 | 42 | p.add_argument("--order", type=int, default=0) 43 | p.add_argument("--output_root", type=Path, default=Path("outputs")) 44 | 45 | p.add_argument("--distributed", action="store_true") 46 | p.add_argument("--nproc_per_node", type=int, default=1) 47 | 48 | p.add_argument("--method_config", nargs="+") 49 | args = p.parse_args() 50 | 51 | if args.method_config is not None: 52 | args.method_config = { 53 | k.split("=")[0]: k.split("=")[1] for k in args.method_config 54 | } 55 | 56 | main(args) 57 | 
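# Illustrative example (d0..d7 are assumed placeholder names, not from the original script):
# with --order 2 and DEFAULT_DATASET_SEQ = [d0, d1, ..., d7], parse_dataset_seq calls
# rotate(8 - 2), so the training sequence becomes [d2, d3, ..., d7, d0, d1].
# Likewise, --method_config threshold=0.1 scale=4 (hypothetical values) is parsed into
# {"threshold": "0.1", "scale": "4"} before being forwarded to ContinualTrainer.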
-------------------------------------------------------------------------------- /scripts/train_and_eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from scripts.utils import train_and_eval_script 4 | 5 | 6 | def main(args): 7 | train_and_eval_script( 8 | config_path=args.config_path, 9 | training_dataset=args.dataset, 10 | pretrained_dataset=args.pretrained_dataset, 11 | max_iterations=args.max_iterations, 12 | sub_output_dir=args.sub_output_dir, 13 | timestamp=args.timestamp, 14 | distributed=args.distributed, 15 | nproc_per_node=args.nproc_per_node, 16 | ) 17 | 18 | 19 | if __name__ == "__main__": 20 | p = argparse.ArgumentParser() 21 | p.add_argument( 22 | "--config_path", type=str, default="configs/snd_config_4_gpus.yaml", help="select the config file" 23 | ) 24 | p.add_argument( 25 | "--pretrained_dataset", 26 | type=str, 27 | default=None, 28 | help="the latest training dataset, this is specified to choose the corresponding model. Do not set it if you want to select the original pre-trained CLIP model.", 29 | ) 30 | p.add_argument( 31 | "--dataset", type=str, default="fgvc-aircraft", help="the dataset to train" 32 | ) 33 | p.add_argument( 34 | "--sub_output_dir", 35 | type=str, 36 | default="default", 37 | help="the sub-directory to save the training results, choose any name you want", 38 | ) 39 | 40 | p.add_argument("--max_iterations", type=int, default=1000, help="the maximum number of iterations to train.") 41 | p.add_argument( 42 | "--timestamp", 43 | type=str, 44 | default="latest", 45 | help="select the timestamp folder. This is used to select the model you fine-tuned previously. Do not set it if you want to select the latest model.", 46 | ) 47 | 48 | p.add_argument("--distributed", action="store_true", help="use distributed training") 49 | p.add_argument("--nproc_per_node", type=int, default=1, help="number of GPUs per node") 50 | 51 | args = p.parse_args() 52 | 53 | main(args) 54 | -------------------------------------------------------------------------------- /src/datasets/transform.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from open_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 3 | from open_clip.transform import PreprocessCfg, image_transform_v2 4 | from torchvision.transforms import ( 5 | CenterCrop, 6 | Compose, 7 | ConvertImageDtype, 8 | InterpolationMode, 9 | Normalize, 10 | PILToTensor, 11 | RandomResizedCrop, 12 | Resize, 13 | ToTensor, 14 | ) 15 | 16 | DEFAULT_PREPROCESS_CONFIG = { 17 | "size": (224, 224), 18 | "mode": "RGB", 19 | "mean": OPENAI_DATASET_MEAN, 20 | "std": OPENAI_DATASET_STD, 21 | "interpolation": "bicubic", 22 | "resize_mode": "shortest", 23 | "fill_color": 0, 24 | } 25 | 26 | 27 | RAW_TRANSFORM = Compose( 28 | [ 29 | CenterCrop(224), 30 | PILToTensor(), 31 | ConvertImageDtype(torch.float), 32 | ] 33 | ) 34 | 35 | 36 | def _convert_to_rgb(image): 37 | return image.convert("RGB") 38 | 39 | 40 | def original_clip_transform(n_px: int = 224, is_train: bool = False): 41 | normalize = Normalize( 42 | (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711) 43 | ) 44 | if is_train: 45 | return Compose( 46 | [ 47 | RandomResizedCrop( 48 | n_px, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC 49 | ), 50 | _convert_to_rgb, 51 | ToTensor(), 52 | normalize, 53 | ] 54 | ) 55 | else: 56 | return Compose( 57 | [ 58 | Resize(n_px, interpolation=InterpolationMode.BICUBIC), 59 | 
CenterCrop(n_px), 60 | _convert_to_rgb, 61 | ToTensor(), 62 | normalize, 63 | ] 64 | ) 65 | 66 | 67 | def load_transform(config, model_preprocess_config=None): 68 | if model_preprocess_config is None: 69 | model_preprocess_config = DEFAULT_PREPROCESS_CONFIG 70 | 71 | if config.data.get("use_original_clip_transform", False): 72 | train_transform = original_clip_transform(is_train=True) 73 | eval_transform = original_clip_transform(is_train=False) 74 | else: 75 | pp_cfg = PreprocessCfg(**model_preprocess_config) 76 | 77 | train_transform = image_transform_v2( 78 | pp_cfg, 79 | is_train=True, 80 | ) 81 | 82 | eval_transform = image_transform_v2( 83 | pp_cfg, 84 | is_train=False, 85 | ) 86 | 87 | return train_transform, eval_transform 88 | -------------------------------------------------------------------------------- /src/trainer/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class CosineSimilarityLoss(nn.CosineEmbeddingLoss): 9 | def forward(self, x, y): 10 | return super().forward(x, y, torch.ones(x.shape[0]).to(x.device)) 11 | 12 | 13 | class L2Loss(nn.Module): 14 | def __init__(self, reduce=None, square=False): 15 | super().__init__() 16 | self.reduce = reduce 17 | self.square = square 18 | 19 | def forward(self, x, y): 20 | loss = torch.pow(torch.norm(x - y, dim=-1), 2) 21 | if self.square: 22 | loss = loss**2 23 | if self.reduce == "mean": 24 | return loss.mean() 25 | return loss 26 | 27 | 28 | def get_optimizer(model, task_config): 29 | optim_params = model.get_params() 30 | 31 | num_parameters = 0 32 | for param_group in optim_params: 33 | for p in param_group["params"]: 34 | num_parameters += p.data.nelement() 35 | logging.info(f"number of trainable parameters: {num_parameters}") 36 | 37 | return torch.optim.AdamW( 38 | optim_params, 39 | weight_decay=float(task_config.weight_decay), 40 | ) 41 | 42 | 43 | class CosineLRScheduler(object): 44 | def __init__(self, optimizer, task_config, num_steps): 45 | self.current_step = 0 46 | self.optimizer = optimizer 47 | 48 | init_lrs = task_config.init_lrs 49 | self.init_lrs = ( 50 | init_lrs 51 | if isinstance(init_lrs, list) 52 | else [init_lrs for _ in optimizer.param_groups] 53 | ) 54 | 55 | self.warmup_length = task_config.warmup_length 56 | self.num_steps = num_steps 57 | self._current_lr = self.init_lrs[0] 58 | 59 | def step(self): 60 | for param_group, init_lr in zip(self.optimizer.param_groups, self.init_lrs): 61 | if self.current_step < self.warmup_length: 62 | param_group["lr"] = ( 63 | init_lr * (self.current_step + 1) / self.warmup_length 64 | ) 65 | else: 66 | e = self.current_step - self.warmup_length 67 | es = self.num_steps - self.warmup_length 68 | param_group["lr"] = 0.5 * (1 + np.cos(np.pi * e / es)) * init_lr 69 | 70 | self.current_step += 1 71 | self._current_lr = self.optimizer.param_groups[0]["lr"] 72 | 73 | def refresh(self): 74 | self.current_step = 0 75 | 76 | @property 77 | def current_lr(self): 78 | return self._current_lr 79 | -------------------------------------------------------------------------------- /src/trainer/snd_trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from src.trainer.base_trainer import BaseKDTrainer 5 | from src.trainer.utils import L2Loss 6 | 7 | 8 | class SnDTrainer(BaseKDTrainer): 9 | def __init__(self, *args, **kwargs): 10 | super().__init__(*args, 
**kwargs) 11 | self.prev_teacher_model.eval() 12 | self.feature_criterion = L2Loss(reduce=None, square=False) 13 | self.num_valid_prev_data = 0 14 | 15 | @property 16 | def prev_teacher_model(self): 17 | return self._teachers["prev"] 18 | 19 | def snd_loss( 20 | self, 21 | images, 22 | labels, 23 | ratio_prev=9, 24 | ratio_pretrained=0.5, 25 | threshold=0.2, 26 | scale=6, 27 | label_smoothing=0.0, 28 | ): 29 | ref_images, _ = self.get_ref_data(self.ref_loader) 30 | base_loss, loss_dict = self.base_loss( 31 | images, labels, label_smoothing=label_smoothing 32 | ) 33 | 34 | student_ref_image_embedding = self.unwrapped_model(self.train_model).encode( 35 | images=ref_images 36 | ) 37 | 38 | with torch.no_grad(): 39 | ( 40 | pretrained_teacher_ref_image_embedding, 41 | _, 42 | _, 43 | ) = self.pretrained_teacher_model(ref_images, get_features=True) 44 | 45 | ( 46 | prev_teacher_ref_image_embedding, 47 | _, 48 | _, 49 | ) = self.prev_teacher_model(ref_images, get_features=True) 50 | 51 | pre_scores = torch.norm( 52 | pretrained_teacher_ref_image_embedding - prev_teacher_ref_image_embedding, 53 | dim=-1, 54 | ) 55 | 56 | self.num_valid_prev_data += (pre_scores > threshold).float().sum().item() 57 | 58 | scaled_scores = scale * (pre_scores - threshold) 59 | 60 | scores = nn.functional.sigmoid(scaled_scores).reshape(-1, 1) 61 | 62 | pretrained_kd_loss = self._get_kd_loss( 63 | student_ref_image_embedding, 64 | pretrained_teacher_ref_image_embedding, 65 | feature_criterion=self.feature_criterion, 66 | ) 67 | prev_kd_loss = self._get_kd_loss( 68 | student_ref_image_embedding, 69 | prev_teacher_ref_image_embedding, 70 | feature_criterion=self.feature_criterion, 71 | ) 72 | 73 | prev_kd_loss = (scores * prev_kd_loss).mean() 74 | 75 | pretrained_kd_loss = ((1 - scores) * pretrained_kd_loss).mean() 76 | 77 | return ( 78 | base_loss 79 | + ratio_prev * prev_kd_loss 80 | + ratio_pretrained * pretrained_kd_loss, 81 | { 82 | **loss_dict, 83 | "prev_kd_loss": prev_kd_loss.item(), 84 | "pretrained_kd_loss": pretrained_kd_loss.item(), 85 | "num_valid_prev_data": self.num_valid_prev_data, 86 | }, 87 | ) 88 | -------------------------------------------------------------------------------- /src/utils/dist_utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import functools 3 | import os 4 | import random 5 | 6 | import numpy as np 7 | import torch 8 | import torch.backends.cudnn as cudnn 9 | import torch.distributed as dist 10 | 11 | 12 | def setup_for_distributed(is_master): 13 | """ 14 | This function disables printing when not in master process 15 | """ 16 | import builtins as __builtin__ 17 | 18 | builtin_print = __builtin__.print 19 | 20 | def print(*args, **kwargs): 21 | force = kwargs.pop("force", False) 22 | if is_master or force: 23 | builtin_print(*args, **kwargs) 24 | 25 | __builtin__.print = print 26 | 27 | 28 | def is_dist_avail_and_initialized(): 29 | if not dist.is_available(): 30 | return False 31 | if not dist.is_initialized(): 32 | return False 33 | return True 34 | 35 | 36 | def get_world_size(): 37 | if not is_dist_avail_and_initialized(): 38 | return 1 39 | return dist.get_world_size() 40 | 41 | 42 | def get_rank(): 43 | if not is_dist_avail_and_initialized(): 44 | return 0 45 | return dist.get_rank() 46 | 47 | 48 | def is_main_process(): 49 | return get_rank() == 0 50 | 51 | 52 | def init_distributed_mode(args): 53 | if "RANK" in os.environ and "WORLD_SIZE" in os.environ: 54 | args.rank = int(os.environ["RANK"]) 55 | 
args.world_size = int(os.environ["WORLD_SIZE"]) 56 | args.gpu = int(os.environ["LOCAL_RANK"]) 57 | elif "SLURM_PROCID" in os.environ: 58 | args.rank = int(os.environ["SLURM_PROCID"]) 59 | args.gpu = args.rank % torch.cuda.device_count() 60 | else: 61 | print("Not using distributed mode") 62 | args.distributed = False 63 | return 64 | 65 | args.distributed = True 66 | 67 | torch.cuda.set_device(args.gpu) 68 | args.dist_backend = "nccl" 69 | print( 70 | "| distributed init (rank {}, world {}): {}".format( 71 | args.rank, args.world_size, args.dist_url 72 | ), 73 | flush=True, 74 | ) 75 | torch.distributed.init_process_group( 76 | backend=args.dist_backend, 77 | init_method=args.dist_url, 78 | world_size=args.world_size, 79 | rank=args.rank, 80 | timeout=datetime.timedelta( 81 | days=365 82 | ), # allow auto-downloading and de-compressing 83 | ) 84 | torch.distributed.barrier() 85 | setup_for_distributed(args.rank == 0) 86 | 87 | 88 | def get_dist_info(): 89 | if torch.__version__ < "1.0": 90 | initialized = dist._initialized 91 | else: 92 | initialized = dist.is_initialized() 93 | if initialized: 94 | rank = dist.get_rank() 95 | world_size = dist.get_world_size() 96 | else: # non-distributed training 97 | rank = 0 98 | world_size = 1 99 | return rank, world_size 100 | 101 | 102 | def main_process(func): 103 | @functools.wraps(func) 104 | def wrapper(*args, **kwargs): 105 | rank, _ = get_dist_info() 106 | if rank == 0: 107 | return func(*args, **kwargs) 108 | 109 | return wrapper 110 | 111 | 112 | def get_job_id(): 113 | return datetime.datetime.now().strftime("%Y%m%d%H%M%S") 114 | 115 | 116 | def setup_seeds(seed): 117 | seed = seed + get_rank() 118 | 119 | random.seed(seed) 120 | np.random.seed(seed) 121 | torch.manual_seed(seed) 122 | 123 | cudnn.benchmark = False 124 | cudnn.deterministic = True 125 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | run.py 2 | prev_outputs 3 | *.png 4 | .vscode 5 | wandb 6 | *.sh 7 | *.pt 8 | outputs* 9 | *.pdf 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | bash_scripts 15 | # C extensions 16 | *.so 17 | scripts/metrics.py 18 | 19 | # Distribution / packaging 20 | .Python 21 | build/ 22 | develop-eggs/ 23 | dist/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | lib/ 28 | lib64/ 29 | parts/ 30 | sdist/ 31 | var/ 32 | wheels/ 33 | share/python-wheels/ 34 | *.egg-info/ 35 | .installed.cfg 36 | *.egg 37 | MANIFEST 38 | 39 | # PyInstaller 40 | # Usually these files are written by a python script from a template 41 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
42 | *.manifest 43 | *.spec 44 | 45 | # Installer logs 46 | pip-log.txt 47 | pip-delete-this-directory.txt 48 | 49 | # Unit test / coverage reports 50 | htmlcov/ 51 | .tox/ 52 | .nox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *.cover 59 | *.py,cover 60 | .hypothesis/ 61 | .pytest_cache/ 62 | cover/ 63 | 64 | # Translations 65 | *.mo 66 | *.pot 67 | 68 | # Django stuff: 69 | *.log 70 | local_settings.py 71 | db.sqlite3 72 | db.sqlite3-journal 73 | 74 | # Flask stuff: 75 | instance/ 76 | .webassets-cache 77 | 78 | # Scrapy stuff: 79 | .scrapy 80 | 81 | # Sphinx documentation 82 | docs/_build/ 83 | 84 | # PyBuilder 85 | .pybuilder/ 86 | target/ 87 | 88 | # Jupyter Notebook 89 | .ipynb_checkpoints 90 | 91 | # IPython 92 | profile_default/ 93 | ipython_config.py 94 | 95 | # pyenv 96 | # For a library or package, you might want to ignore these files since the code is 97 | # intended to run in multiple environments; otherwise, check them in: 98 | # .python-version 99 | 100 | # pipenv 101 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 102 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 103 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 104 | # install all needed dependencies. 105 | #Pipfile.lock 106 | 107 | # poetry 108 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 109 | # This is especially recommended for binary packages to ensure reproducibility, and is more 110 | # commonly ignored for libraries. 111 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 112 | #poetry.lock 113 | 114 | # pdm 115 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 116 | #pdm.lock 117 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 118 | # in version control. 119 | # https://pdm.fming.dev/#use-with-ide 120 | .pdm.toml 121 | 122 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 123 | __pypackages__/ 124 | 125 | # Celery stuff 126 | celerybeat-schedule 127 | celerybeat.pid 128 | 129 | # SageMath parsed files 130 | *.sage.py 131 | 132 | # Environments 133 | .env 134 | .venv 135 | env/ 136 | venv/ 137 | ENV/ 138 | env.bak/ 139 | venv.bak/ 140 | 141 | # Spyder project settings 142 | .spyderproject 143 | .spyproject 144 | 145 | # Rope project settings 146 | .ropeproject 147 | 148 | # mkdocs documentation 149 | /site 150 | 151 | # mypy 152 | .mypy_cache/ 153 | .dmypy.json 154 | dmypy.json 155 | 156 | # Pyre type checker 157 | .pyre/ 158 | 159 | # pytype static type analyzer 160 | .pytype/ 161 | 162 | # Cython debug symbols 163 | cython_debug/ 164 | 165 | # PyCharm 166 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 167 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 168 | # and can be added to the global gitignore or merged into this file. For a more nuclear 169 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
170 | #.idea/ -------------------------------------------------------------------------------- /src/models/clip.py: -------------------------------------------------------------------------------- 1 | import open_clip 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from src.models.utils import disabled_train 7 | 8 | 9 | SIMPLE_TEMPLATE = lambda c: f"a photo of a {c}." 10 | 11 | # In the PureClip model, the text encoder is involved in the training process. 12 | # FIXME: Directly removing the logit scale before loading a previous model might result in an error. 13 | class PureClip(nn.Module): 14 | def __init__( 15 | self, 16 | model_name, 17 | class_name_list, 18 | freeze_classification_head=False, 19 | device="cuda", 20 | ): 21 | super().__init__() 22 | self.model = open_clip.create_model_from_pretrained( 23 | model_name, 24 | pretrained="openai", 25 | return_transform=False, 26 | ).to(device) 27 | self.tokenizer = open_clip.get_tokenizer(model_name) 28 | self.template = SIMPLE_TEMPLATE 29 | self.device = device 30 | self.class_tokens = self.tokenize(class_name_list) 31 | 32 | self.freeze_classification_head = freeze_classification_head 33 | 34 | if self.freeze_classification_head: 35 | 36 | for name, p in self.model.named_parameters(): 37 | if "visual" not in name: 38 | p.requires_grad = False 39 | self.model.transformer.eval() 40 | self.model.transformer.train = disabled_train 41 | 42 | @property 43 | def preprocess_config(self): 44 | return self.model.visual.preprocess_cfg 45 | 46 | def tokenize(self, texts, device="cuda"): 47 | return self.tokenizer([self.template(t) for t in texts]).to(device) 48 | 49 | @torch.no_grad() 50 | def get_class_embedding(self, class_name_list, device="cuda"): 51 | tokens = self.tokenizer([self.template(t) for t in class_name_list]).to(device) 52 | text_embedding = self.model.encode_text(tokens) 53 | return F.normalize(text_embedding) 54 | 55 | def encode(self, images=None, text=None, normalize=True): 56 | if images is None: 57 | text_embeddings = self.model.encode_text(text) 58 | return F.normalize(text_embeddings) if normalize else text_embeddings 59 | if text is None: 60 | image_embeddings = self.model.encode_image(images) 61 | return F.normalize(image_embeddings) if normalize else image_embeddings 62 | 63 | # To fit the format of the clip-classifier, we send a list of data to pure-clip if text is needed. 
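# Illustrative shapes for the forward pass below (B, C, D are assumed names for the
# image batch size, number of class prompts, and embedding dimension):
#   image_embeddings: (B, D), text_embeddings: (C, D)
#   logits = image_embeddings @ text_embeddings.t() -> (B, C), scaled by logit_scale.exp()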
64 | def forward(self, images, text=None, normalize=True, get_features=False): 65 | if text is None: 66 | text = self.class_tokens 67 | 68 | image_embeddings = self.encode(images=images, normalize=normalize) 69 | text_embeddings = self.encode(text=text, normalize=normalize) 70 | 71 | if get_features: 72 | return image_embeddings, text_embeddings, self.model.logit_scale.exp() 73 | 74 | res = image_embeddings @ text_embeddings.t() 75 | 76 | res *= self.model.logit_scale.exp() 77 | 78 | return res 79 | 80 | def get_params(self): 81 | exclude_param = "logit_scale" 82 | return [ 83 | { 84 | "params": [ 85 | p 86 | for k, p in self.model.named_parameters() 87 | if p.requires_grad and exclude_param not in k 88 | ] 89 | } 90 | ] 91 | 92 | def get_state_dict(self): 93 | return self.model.state_dict() 94 | 95 | def load_state_dict(self, state_dict): 96 | self.model.load_state_dict(state_dict) 97 | 98 | 99 | def get_model( 100 | config, 101 | class_name_list, 102 | pretrained=False, 103 | freeze=False, 104 | device="cuda", 105 | ): 106 | 107 | model_config = config.model 108 | 109 | model = PureClip( 110 | model_config.vit_base, 111 | class_name_list, 112 | freeze_classification_head=model_config.get( 113 | "freeze_classification_head", False 114 | ), 115 | device=device, 116 | ) 117 | 118 | # then load from a checkpoint if not pre-trained 119 | if model_config.pretrained != "openai" and not pretrained: 120 | model.load_state_dict(torch.load(model_config.pretrained)["model"]) 121 | 122 | model = model.to(device) 123 | 124 | if freeze: 125 | for _, v in model.named_parameters(): 126 | v.requires_grad = False 127 | model.eval() 128 | 129 | if config.task.get("distributed", False) and not freeze: 130 | model = nn.parallel.DistributedDataParallel(model) 131 | 132 | return model 133 | -------------------------------------------------------------------------------- /visualization/plot.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | 4 | import matplotlib.pyplot as plt 5 | import pandas as pd 6 | 7 | from scripts.metrics import DEFAULT_ZERO_SHOT_PERFORMANCE 8 | from scripts.utils import DEFAULT_DATASET_SEQ 9 | 10 | plt.rcParams["font.family"] = "Times New Roman" 11 | 12 | METHOD_MAP = { 13 | "Continual-FT": "base", 14 | "LwF": "lwf", 15 | "iCaRL": "icarl", 16 | "ZSCL": "zscl", 17 | } 18 | 19 | VISUALIZED_DATASET_NAME_MAP = { 20 | "fgvc-aircraft": "Aircraft", 21 | "dtd": "DTD", 22 | "eurosat": "EuroSAT", 23 | "flowers-102": "Flowers", 24 | "food-101": "Food", 25 | "oxford-pets": "Pets", 26 | "stanford-cars": "Cars", 27 | "ucf-101": "UCF101", 28 | } 29 | 30 | 31 | def parse_results(method="split_teacher_pure_clip", is_mdcil=False): 32 | config_name = f"{method}_config" 33 | 34 | res_list = [] 35 | for order in range(8): 36 | res_path = ( 37 | Path("/work/chu980802/mix-teacher") 38 | / method 39 | / "outputs" 40 | / f"order_{order}" 41 | / config_name 42 | / "final_results.json" 43 | ) 44 | with res_path.open("r") as f: 45 | res = json.load(f) 46 | res_list.append(pd.DataFrame(res).T) 47 | return res_list 48 | 49 | 50 | def plot_figure( 51 | data_dict, zero_shot, title, legend=None, save_path="comparison_plot.pdf" 52 | ): 53 | plt.figure(figsize=(10, 10)) 54 | markers = ["o", "s", "D", "^", "P"] # Different markers for each series 55 | # colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k'] 56 | colors = ["#92CD00", "#A3D1F2", "#F4B6C2", "#FED1BD", "#957DAD", "#88D8C0"] 57 | 58 | plt.scatter( 59 | 0, zero_shot, marker="*", s=800, 
color="#88D8C0", label="Zero-shot", zorder=3 60 | ) 61 | 62 | for (label, data), marker, color in zip(data_dict.items(), markers, colors): 63 | plt.plot( 64 | range(len(data) + 1), 65 | [zero_shot] + data, 66 | label=label, 67 | marker=marker, 68 | lw=3, 69 | markersize=10, 70 | color=color, 71 | ) 72 | plt.title(title, fontsize=40) 73 | plt.xlabel("Task sequence", fontsize=40) 74 | plt.ylabel("Accuracy", fontsize=40) 75 | plt.tick_params(labelsize=30) 76 | plt.xticks(range(9), range(0, 9)) # Example x-axis labels 77 | if legend is not None: 78 | plt.legend(fontsize=26, loc=legend) 79 | plt.grid(linestyle=":") 80 | plt.tight_layout() 81 | # Save the figure as a PDF 82 | plt.savefig(save_path) 83 | 84 | 85 | res_list_dict = {k: parse_results(v) for k, v in METHOD_MAP.items()} 86 | 87 | res_list = [ 88 | {method: res[order] for method, res in res_list_dict.items()} for order in range(8) 89 | ] 90 | 91 | # Catastrophic forgetting 92 | for order in range(8): 93 | dataset_name = {res.index[0] for res in res_list[order].values()} 94 | assert len(dataset_name) == 1 95 | dataset_name = dataset_name.pop() 96 | display_order = DEFAULT_DATASET_SEQ.index(dataset_name) 97 | legend = "lower left" if order == 0 else None 98 | plot_figure( 99 | { 100 | method: res.loc[:, dataset_name].values.tolist() 101 | for method, res in res_list[order].items() 102 | }, 103 | zero_shot=DEFAULT_ZERO_SHOT_PERFORMANCE[dataset_name], 104 | title="Acc. of the 1st task in $\mathcal{S}^$ ()".replace( 105 | "", str(display_order + 1) 106 | ).replace("", VISUALIZED_DATASET_NAME_MAP[dataset_name]), 107 | legend=legend, 108 | save_path=f"visualization/{dataset_name}_forgetting_order_{display_order+1}.pdf", 109 | ) 110 | 111 | # Zero-shot degradation 112 | for order in range(8): 113 | dataset_name = {res.index[-1] for res in res_list[order].values()} 114 | assert len(dataset_name) == 1 115 | dataset_name = dataset_name.pop() 116 | display_order = ( 117 | DEFAULT_DATASET_SEQ.index(dataset_name) 118 | + 1 119 | - 8 * (DEFAULT_DATASET_SEQ.index(dataset_name) == 7) 120 | ) 121 | legend = "upper left" if order == 0 else None 122 | plot_figure( 123 | { 124 | method: res.loc[:, dataset_name].values.tolist() 125 | for method, res in res_list[order].items() 126 | }, 127 | zero_shot=DEFAULT_ZERO_SHOT_PERFORMANCE[dataset_name], 128 | title="Acc. 
of the 8th task in $\mathcal{S}^$ ()".replace( 129 | "", str(display_order + 1) 130 | ).replace("", VISUALIZED_DATASET_NAME_MAP[dataset_name]), 131 | legend=legend, 132 | save_path=f"visualization/{dataset_name}_degradation_order_{display_order+1}.pdf", 133 | ) 134 | -------------------------------------------------------------------------------- /src/datasets/base.py: -------------------------------------------------------------------------------- 1 | import json 2 | from ast import literal_eval 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import torch 7 | from PIL import Image 8 | from torch.utils.data import Dataset 9 | from torchvision.datasets import ImageFolder 10 | from torchvision.transforms import Compose, ConvertImageDtype, PILToTensor 11 | 12 | 13 | def pil_loader(path: str): 14 | with open(path, "rb") as f: 15 | img = Image.open(f) 16 | return img.convert("RGB") 17 | 18 | 19 | class BaseClassificationDataset(Dataset): 20 | def __init__( 21 | self, 22 | root, 23 | mode="train", 24 | transform=None, 25 | sample_num=-1, 26 | seed=1102, 27 | label_shift=0, 28 | ): 29 | self.root = Path(root) / self.dataset_name 30 | self.mode = mode 31 | self._data_list, self._class_name_list = self.make_dataset() 32 | self.transform = transform 33 | self.rng = np.random.default_rng(seed) 34 | self.label_shift = label_shift 35 | 36 | if sample_num != -1: 37 | sample_idx = self.rng.choice( 38 | len(self._data_list), sample_num, replace=False 39 | ) 40 | self._data_list = [self._data_list[i] for i in sample_idx] 41 | 42 | @property 43 | def class_name_list(self): 44 | return self._class_name_list 45 | 46 | def make_dataset(self): 47 | """ 48 | data annotation format: 49 | { 50 | "data": { 51 | "train":[ 52 | [image_path, label], 53 | ... 54 | ], 55 | "val": [ 56 | [image_path, label], 57 | ... 58 | ], 59 | "test": [ 60 | [image_path, label], 61 | ... 62 | ] 63 | }, 64 | "class_names": [ 65 | class_0_name, 66 | class_1_name, 67 | ... 
68 | ] 69 | } 70 | """ 71 | with (self.root / self.annotation_filename).open("r") as f: 72 | data = json.load(f) 73 | 74 | data_list = [] 75 | for d in data["data"][self.mode]: 76 | data_list.append(((self.root / "images" / d[0]).as_posix(), d[1])) 77 | 78 | return data_list, data["class_names"] 79 | 80 | def get_class_name(self, class_idx): 81 | return self._class_name_list[class_idx] 82 | 83 | def __len__(self): 84 | return len(self._data_list) 85 | 86 | def __getitem__(self, index): 87 | path, label = self._data_list[index] 88 | image = pil_loader(path) 89 | 90 | if self.transform: 91 | image = self.transform(image) 92 | 93 | return image, label + self.label_shift, index 94 | 95 | 96 | class BaseUnlabeledDataset(BaseClassificationDataset): 97 | @property 98 | def class_name_list(self): 99 | return None 100 | 101 | def make_dataset(self): 102 | with (self.root / self.annotation_filename).open("r") as f: 103 | data = json.load(f) 104 | 105 | data_list = [] 106 | for d in data["data"][self.mode]: 107 | data_list.append((self.root / "images" / d).as_posix()) 108 | 109 | return data_list, None 110 | 111 | def get_class_name(self, _): 112 | return None 113 | 114 | def __getitem__(self, index): 115 | path = self._data_list[index] 116 | # image = pil_loader(path) 117 | try: 118 | image = pil_loader(path) 119 | except Exception: # log unreadable images and fall back to the first sample 120 | with open("error.log", "a") as f: 121 | f.write(path + "\n") 122 | image = pil_loader(self._data_list[0]) 123 | 124 | if self.transform: 125 | image = self.transform(image) 126 | 127 | return image, -1, index 128 | 129 | 130 | class ImageListDataset(BaseClassificationDataset): 131 | def __init__(self, image_list_path, transform=None, seed=1102, sample_num=-1): 132 | if not isinstance(image_list_path, Path): 133 | image_list_path = Path(image_list_path) 134 | 135 | self._data_list = [ 136 | literal_eval(line) 137 | for line in image_list_path.read_text().strip().split("\n") 138 | ] 139 | self.label_shift = 0 140 | self.transform = transform 141 | self.rng = np.random.default_rng(seed) 142 | 143 | if sample_num != -1: 144 | sample_idx = self.rng.choice( 145 | len(self._data_list), sample_num, replace=False 146 | ) 147 | self._data_list = [self._data_list[i] for i in sample_idx] 148 | 149 | @property 150 | def class_name_list(self): 151 | return None 152 | 153 | 154 | class NoisyImageListDataset(ImageListDataset): 155 | def __init__(self, noise_path, *args, **kwargs): 156 | super().__init__(*args, **kwargs) 157 | self.noise = torch.load(noise_path) 158 | 159 | def __getitem__(self, idx): 160 | image, label, _ = super().__getitem__(idx) 161 | noise = self.noise[idx] 162 | 163 | return image, noise, label, idx 164 | -------------------------------------------------------------------------------- /src/datasets/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | from typing import List, Union 4 | 5 | import numpy as np 6 | import pandas as pd 7 | from torch.utils.data import DataLoader, DistributedSampler 8 | 9 | from src.datasets import DATASET_MAPPING 10 | from src.datasets.transform import load_transform 11 | from src.utils import get_rank, get_world_size 12 | 13 | 14 | class DataIterativeLoader: 15 | def __init__(self, dataloader, device="cuda"): 16 | self.len = len(dataloader) 17 | self.dataloader = dataloader 18 | self.iterator = None 19 | self.device = device 20 | 21 | def set_epoch(self, epoch): 22 | if hasattr(self.dataloader.sampler, "set_epoch"): 23 | 
self.dataloader.sampler.set_epoch(epoch) 24 | 25 | def init(self): 26 | self.iterator = iter(self.dataloader) 27 | 28 | def __next__(self): 29 | data = next(self.iterator) 30 | if isinstance(data, list): 31 | data = [d.to(self.device) for d in data] 32 | return data 33 | else: 34 | data = data.to(self.device) 35 | return data 36 | 37 | def __iter__(self): 38 | return self 39 | 40 | def __len__(self): 41 | return self.len 42 | 43 | 44 | def build_dataloader( 45 | dataset, 46 | batch_size=8, 47 | num_workers=4, 48 | pin_memory=True, 49 | shuffle=False, 50 | drop_last=False, 51 | distributed=False, 52 | ): 53 | if distributed: 54 | sampler = DistributedSampler( 55 | dataset, 56 | shuffle=shuffle, 57 | num_replicas=get_world_size(), 58 | rank=get_rank(), 59 | ) 60 | else: 61 | sampler = None 62 | return DataLoader( 63 | dataset, 64 | batch_size=batch_size, 65 | num_workers=num_workers, 66 | pin_memory=pin_memory, 67 | sampler=sampler, 68 | shuffle=shuffle and not distributed, 69 | drop_last=drop_last, 70 | ) 71 | 72 | 73 | def build_iter_dataloader( 74 | dataset, 75 | batch_size=8, 76 | num_workers=4, 77 | pin_memory=True, 78 | shuffle=False, 79 | drop_last=False, 80 | device="cuda", 81 | distributed=False, 82 | **kwargs, 83 | ): 84 | dataloader = build_dataloader( 85 | dataset, 86 | batch_size=batch_size, 87 | num_workers=num_workers, 88 | pin_memory=pin_memory, 89 | shuffle=shuffle, 90 | drop_last=drop_last, 91 | distributed=distributed, 92 | ) 93 | 94 | return DataIterativeLoader(dataloader, device=device) 95 | 96 | 97 | def get_dataloader( 98 | dataset_name, 99 | root, 100 | mode, 101 | transform, 102 | sample_num=-1, 103 | device="cuda", 104 | seed=1102, 105 | distributed=False, 106 | label_shift=0, 107 | **dataloader_config, 108 | ): 109 | dataset_class = DATASET_MAPPING[dataset_name] 110 | 111 | dataset = dataset_class( 112 | root, 113 | mode=mode, 114 | transform=transform, 115 | sample_num=sample_num, 116 | seed=seed, 117 | label_shift=label_shift, 118 | ) 119 | 120 | distributed = distributed and mode == "train" 121 | return build_iter_dataloader( 122 | dataset, **dataloader_config, device=device, distributed=distributed 123 | ) 124 | 125 | 126 | def get_dataloaders_from_config(config, num_classes_accumulation_dict, device="cuda"): 127 | dataloaders = {} 128 | train_transform, eval_transform = load_transform(config) 129 | 130 | for dataloader_type, dataloader_config in config.data.split.items(): 131 | label_shift = num_classes_accumulation_dict[config.data.name] 132 | 133 | dataloaders[dataloader_type] = get_dataloader( 134 | dataset_name=config.data.name, 135 | root=config.data.root, 136 | mode=dataloader_config.split_name, 137 | transform=train_transform if dataloader_type == "train" else eval_transform, 138 | sample_num=config.data.get("sample_num", -1), 139 | device=device, 140 | distributed=config.task.get("distributed", False), 141 | label_shift=label_shift, 142 | **dataloader_config, 143 | ) 144 | 145 | return dataloaders 146 | 147 | 148 | def load_single_class_name_list(dataset_name: str, data_root: str): 149 | dataset_class = DATASET_MAPPING[dataset_name] 150 | name, annotation_filename = ( 151 | dataset_class.dataset_name, 152 | dataset_class.annotation_filename, 153 | ) 154 | 155 | with (Path(data_root) / name / annotation_filename).open("r") as f: 156 | data = json.load(f) 157 | 158 | return data["class_names"] 159 | 160 | 161 | def load_class_name_list(config): 162 | dataset_list = config.data.get("inference_dataset_list", [config.data.name]) 163 | class_name_list = [] 
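# Illustrative example (assumed dataset choice): for inference_dataset_list = ["dtd", "eurosat"]
# (47 and 10 classes), the accumulation below yields {"dtd": 0, "eurosat": 47}, i.e., the
# per-dataset label shift used to index one concatenated classification head.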
164 | num_classes_accumulation = [] 165 | for dataset_name in dataset_list: 166 | class_names = load_single_class_name_list(dataset_name, config.data.root) 167 | class_name_list += class_names 168 | num_classes_accumulation.append(len(class_names)) 169 | num_classes_accumulation = [0] + np.cumsum(num_classes_accumulation).tolist()[:-1] 170 | 171 | return class_name_list, dict(zip(dataset_list, num_classes_accumulation)) 172 | 173 | 174 | def get_conceptual_captions( 175 | config, filename="Validation_GCC-1.1.0-Validation.tsv", size=100 176 | ): 177 | path = Path(config.data.root) / "conceptual_captions" / filename 178 | 179 | df = pd.read_csv(path, sep="\t") 180 | 181 | rng = np.random.default_rng(config.task.seed) 182 | random_index = rng.choice( 183 | df.index, size=size if size else len(df.index), replace=False 184 | ) 185 | 186 | return df.iloc[random_index, 0].tolist() 187 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # Select and Distill 4 | 5 | This is an official implementation of our work, Select and Distill: Selective Dual-Teacher Knowledge Transfer for Continual Learning on Vision-Language Models, accepted to ECCV'24. 6 | 7 | [![PWC](https://img.shields.io/badge/arXiv-2403.09296-b31b1b)](https://arxiv.org/abs/2403.09296) 8 | [![PWC](https://img.shields.io/badge/ECCV%202024-PDF-FACE27)](https://www.ecva.net/papers/eccv_2024/papers_ECCV/papers/03759.pdf) 9 | [![PWC](https://img.shields.io/badge/ECCV%202024-Supp-7DCBFF)](https://www.ecva.net/papers/eccv_2024/papers_ECCV/papers/03759-supp.pdf) 10 | [![PWC](https://img.shields.io/badge/ECCV%202024-Bibtex-CB8CEA)](#citation) 11 | 12 | https://yuchuyu.org/research/snd 13 | 14 | 15 |
16 | 17 | 18 | ## Table of Contents 19 | - [Announcement](#announcement) 20 | - [Installation](#install) 21 | - [Data Preparation](#data) 22 | - [Model Checkpoints](#checkpoints) 23 | - [Running the model](#run) 24 | - [Citation](#citation) 25 | 26 | 27 | ## Annoucement 28 | 29 | **[2025/01/19]** The model checkpoints have also been uploaded! Check [here](#checkpoints) for more details. 30 | 31 | **[2025/01/19]** The instruction page is ready! We plan to release our original checkpoints soon. 32 | 33 | **[2024/12/31]** Our full codebase has been released! Introduction and installation method (include packages) would be updated soon. 34 | 35 | 36 | ## Installation 37 | 38 | Create a new Conda environment with Python 3.10.14: 39 | 40 | ``` 41 | conda create -n snd python==3.10.14 42 | ``` 43 | 44 | Activate the environment and install PyTorch with the specified version and CUDA support: 45 | 46 | ``` 47 | conda install pytorch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 pytorch-cuda=12.1 -c pytorch -c nvidia 48 | ``` 49 | 50 | Install additional dependencies using the provided `requirements.txt` file: 51 | 52 | ``` 53 | pip install -r requirements.txt 54 | ``` 55 | 56 | 57 | ## Dataset Preparation 58 | 59 | To reproduce our experiments, download the following datasets from the guidance provided [here](https://github.com/KaiyangZhou/CoOp/blob/main/DATASETS.md). 60 | 61 | - FGVCAircraft 62 | - DTD 63 | - EuroSAT 64 | - Flowers102 65 | - Food101 66 | - OxfordPets 67 | - StanfordCars 68 | - UCF101 69 | - ImageNet 70 | 71 | 72 | Organize each dataset in the following directory structure: 73 | 74 | ``` 75 | / 76 | ├── images/ 77 | │ ├── image data / folders 78 | ├── _annotations.json 79 | ``` 80 | 81 | The `_annotations.json` file contains the training, validation, and test splits, along with class names. The files we used for all datasets are provided [here](https://drive.google.com/drive/folders/144OIxusHyB8tRtlnvVGttCx0UE0ab_bv?usp=sharing). Download these files and place them in the appropriate paths as described above. 82 | 83 | 84 | ## Model Checkpoints 85 | 86 | We provide our original model checkpoints for public use. Due to limited storage space, only the last checkpoints for each training sequence are released. 87 | 88 | Unfortunately, while reproducing our experiments, we observed a slight performance drop (0.08% in mean scores). This discrepancy may be attributed to differences in hardware or package versions. Despite this minor variation, our method still achieve state-of-the-art performance compared to previous works. 89 | 90 | You can access the model checkpoints and the reproduced average accuracy scores [here](https://drive.google.com/drive/folders/1V4rubgQsq-e9ydHbiEs5BtwnySJwghOG?usp=sharing). 91 | 92 | 93 | ## Running with the Scripts 94 | 95 | We provide several scripts to help you easily reproduce our experiments. Our experiments were conducted using 4x V100 GPUs in distributed parallel mode. Note that we have not tested our method outside of distributed mode. If you have only one GPU, run the code in distributed mode by specifying `--nproc_per_node` to 1. 96 | 97 | --- 98 | 99 | ### Prerequisite 100 | 101 | Before running the scripts, ensure that the root paths to your dataset folders are correctly configured in all files within the `configs/` directory. 102 | 103 | Specifically, update the `data.root` attribute to point to your dataset's root directory. 
104 | 105 | Other configuration attributes do not need modification, as our scripts will automatically adjust them during runtime. However, you may modify these attributes if you wish to experiment with different hyper-parameters. 106 | 107 | --- 108 | 109 | ### Train and Eval 110 | 111 | The following script allows training on **a single dataset** (e.g., fgvc-aircraft) and evaluating on **all datasets** using 4 GPUs. 112 | 113 | Run the command below to execute the script: 114 | 115 | ```sh 116 | python -m scripts.train_and_eval --config_path configs/snd_config_4_gpus.yaml --dataset fgvc-aircraft --distributed --nproc_per_node 4 117 | ``` 118 | 119 | #### Using a Single GPU 120 | 121 | If you are using only one GPU, modify the command as follows: 122 | 123 | ```sh 124 | python -m scripts.train_and_eval --config_path configs/snd_config_1_gpu.yaml --dataset fgvc-aircraft --distributed --nproc_per_node 1 125 | ``` 126 | 127 | #### Continual Training 128 | 129 | To load a model trained on a specific dataset and **continue training** on another dataset, include the `--pretrained_dataset` argument: 130 | 131 | ```sh 132 | python -m scripts.train_and_eval --config_path configs/snd_config_4_gpus.yaml --pretrained_dataset fgvc-aircraft --dataset dtd --distributed --nproc_per_node 4 133 | ``` 134 | 135 | #### Note 136 | 137 | - Our code has only been verified with 1 or 4 GPUs. 138 | - Using more than 4 GPUs is not recommended, as we observed a slight performance drop. 139 | - When training with 1–4 GPUs, ensure that the batch sizes for the training and reference data are adjusted to match the number of GPUs. 140 | 141 | --- 142 | 143 | ### Continual Training on the Whole Training Sequence 144 | 145 | We also provide a script to continually train and evaluate across an entire sequence of datasets (i.e., reproduce our Multi-Domain Task Incremental Learning setting): 146 | 147 | ```sh 148 | python -m scripts.continually_train --config_path configs/snd_config_4_gpus.yaml --order 0 --distributed --nproc_per_node 4 149 | ``` 150 | 151 | #### Note 152 | 153 | - The `--order` argument specifies an offset to shift the pre-defined dataset sequence. 154 | - For detailed task orders of each training sequence, refer to the supplementary materials. 155 | - The whole process of training and evaluation on a single training sequence using 4x V100 GPUs takes approximately **150 minutes** on our devices. 156 | --- 157 | 158 | ### Inference 159 | 160 | We also provide a script for performing inference on all datasets used in our experiments.
161 | 162 | Run the following command to execute the inference script using the model stored in `outputs/order_0/checkpoint_latest.pth`: 163 | 164 | ```sh 165 | python -m scripts.inference --model_path outputs/order_0/checkpoint_latest.pth 166 | ``` 167 | 168 | 169 | ## Citation 170 | 171 | If you find our work useful, please cite it using the following BibTeX entry: 172 | 173 | ```bibtex 174 | @inproceedings{yu2025select, 175 | title={Select and distill: Selective dual-teacher knowledge transfer for continual learning on vision-language models}, 176 | author={Yu, Yu-Chu and Huang, Chi-Pin and Chen, Jr-Jen and Chang, Kai-Po and Lai, Yung-Hsuan and Yang, Fu-En and Wang, Yu-Chiang Frank}, 177 | booktitle={European Conference on Computer Vision}, 178 | pages={219--236}, 179 | year={2025}, 180 | organization={Springer} 181 | } 182 | ``` 183 | -------------------------------------------------------------------------------- /src/trainer/base_trainer.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import json 3 | from collections import defaultdict 4 | from pathlib import Path 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from tqdm import tqdm 10 | 11 | import wandb 12 | from src.trainer.utils import CosineLRScheduler, get_optimizer 13 | from src.utils import AccuracyMeter, dump_config, is_main_process, main_process 14 | 15 | 16 | class BaseTrainer: 17 | def __init__(self, model, dataloaders, config, job_id=None): 18 | self._model = model 19 | self.dataloaders = dataloaders 20 | self.config = config 21 | self._current_num_iterations = 0 22 | 23 | if self.training_mode and job_id: 24 | self.output_dir = ( 25 | Path(self.config.task.output_dir) / self.config.data.name / job_id 26 | ) 27 | self.output_dir.mkdir(parents=True, exist_ok=True) 28 | dump_config(self.config, self.output_dir / "config.json") 29 | 30 | self.latest_dir = self.output_dir.parent / "latest" 31 | 32 | if self.latest_dir.exists(): 33 | # unlink it since it's a symbolic link 34 | self.latest_dir.unlink() 35 | 36 | self.latest_dir.symlink_to(self.output_dir.name) 37 | 38 | self.local_log = defaultdict(dict) 39 | 40 | if self.training_mode: 41 | self.optimizer = get_optimizer( 42 | self.unwrapped_model(self.train_model), self.config.task 43 | ) 44 | self.lr_scheduler = CosineLRScheduler( 45 | self.optimizer, self.config.task, self.num_total_train_steps 46 | ) 47 | 48 | @main_process 49 | def save(self, epoch=None): 50 | # TODO: check whether the classification head should be frozen 51 | 52 | unwrapped_eval_model = self.unwrapped_model(self.eval_model) 53 | 54 | state_dict = unwrapped_eval_model.get_state_dict() 55 | 56 | save_obj = {"model": state_dict} 57 | 58 | if not epoch: 59 | epoch = "latest" 60 | 61 | save_path = self.output_dir / f"checkpoint_{epoch}.pth" 62 | 63 | print(f"Saving checkpoint at epoch {epoch} to {save_path}.") 64 | torch.save(save_obj, save_path) 65 | 66 | @property 67 | def distributed(self): 68 | return self.config.task.get("distributed", False) 69 | 70 | @property 71 | def eval_model(self): 72 | return self._model 73 | 74 | @property 75 | def train_model(self): 76 | return self._model 77 | 78 | @property 79 | def method_config(self): 80 | return self.config.method 81 | 82 | @property 83 | def training_mode(self): 84 | return self.config.mode == "train" 85 | 86 | @property 87 | def current_num_iterations(self): 88 | return self._current_num_iterations 89 | 90 | @current_num_iterations.setter 91 | def
current_num_iterations(self, value): 92 | self._current_num_iterations = value 93 | 94 | @property 95 | def num_total_train_steps(self): 96 | minimum_iterations = max(2 * len(self.train_loader), self.max_iterations)  # at least two epochs' worth of steps, or max_iterations if larger 97 | return min(self.max_epoch * len(self.train_loader), minimum_iterations)  # but never more than max_epoch full epochs 98 | 99 | @property 100 | def max_epoch(self): 101 | return self.config.task.max_epoch 102 | 103 | @property 104 | def max_iterations(self): 105 | return self.config.task.max_iterations 106 | 107 | @property 108 | def lr(self): 109 | return self.lr_scheduler.current_lr 110 | 111 | @property 112 | def log_interval(self): 113 | return self.config.task.log_interval 114 | 115 | @property 116 | def train_loader(self): 117 | return self.dataloaders.get("train", None) 118 | 119 | @property 120 | def val_loader(self): 121 | return self.dataloaders.get("val", None) 122 | 123 | @property 124 | def test_loader(self): 125 | return self.dataloaders.get("test", None) 126 | 127 | def get_current_training_step(self, epoch, local_step): 128 | return len(self.train_loader) * (epoch - 1) + local_step 129 | 130 | @main_process 131 | def logging(self, local_desc=None, use_wandb=True, **message_dict): 132 | if use_wandb: 133 | wandb.log(message_dict) 134 | if local_desc is not None: 135 | self.local_log[local_desc].update(message_dict) 136 | 137 | @main_process 138 | def dump_results(self, filename="results.json", print_result=False): 139 | if self.training_mode: 140 | with open(self.output_dir / filename, "w") as f: 141 | json.dump(self.local_log, f, indent=4) 142 | 143 | if print_result: 144 | print(json.dumps(self.local_log)) 145 | 146 | def unwrapped_model(self, model): 147 | if self.distributed and hasattr(model, "module"): 148 | return model.module 149 | else: 150 | return model 151 | 152 | def base_loss(self, images, labels, label_smoothing=0.2, **_): 153 | outputs = self.train_model(images) 154 | loss = F.cross_entropy(outputs, labels, label_smoothing=label_smoothing) 155 | return loss, {"loss": loss.item()} 156 | 157 | def train_step(self, images, labels): 158 | self.current_num_iterations += 1 159 | # Step the lr_scheduler first, since this repo does not explicitly set a learning rate in the optimizer.
160 | self.lr_scheduler.step() 161 | 162 | loss_fn = getattr(self, f"{self.method_config.name}_loss") 163 | loss, loss_dict = loss_fn(images, labels, **self.method_config.params) 164 | 165 | loss.backward() 166 | self.optimizer.step() 167 | self.optimizer.zero_grad() 168 | 169 | loss_dict.update({"total_loss": loss.item()}) 170 | return loss_dict 171 | 172 | def evaluate(self, dataloader=None): 173 | if dataloader is None: 174 | dataloader = self.test_loader 175 | 176 | self.eval_model.eval() 177 | 178 | scores = AccuracyMeter() 179 | 180 | dataloader.init() 181 | with torch.no_grad(), tqdm(total=len(dataloader)) as pbar: 182 | for images, labels, _ in dataloader: 183 | preds = self.eval_model(images).argmax(dim=1) 184 | scores += preds == labels 185 | pbar.set_postfix_str(f"acc: {100 * scores.acc():.2f}%") 186 | pbar.update(1) 187 | 188 | return scores.acc() 189 | 190 | def train(self, set_validation=False): 191 | # test zero-shot validation performance 192 | if self.val_loader and set_validation: 193 | self.logging(val_acc=self.evaluate(self.val_loader)) 194 | 195 | with tqdm(total=self.num_total_train_steps) as pbar: 196 | 197 | # TODO: make this double for-loop a single for-loop 198 | for epoch in range(1, self.max_epoch + 1): 199 | pbar.set_description(f"Epoch {epoch}/{self.max_epoch}: ") 200 | 201 | self.train_model.train() 202 | self.train_loader.init() 203 | 204 | for i, (images, labels, _) in enumerate(self.train_loader): 205 | loss_dict = self.train_step(images, labels) 206 | 207 | pbar.set_postfix_str( 208 | f"lr: {self.lr:.2e}, loss: {loss_dict['total_loss']:.2e}" 209 | ) 210 | pbar.update(1) 211 | 212 | if i % self.log_interval == 0: 213 | self.logging(lr=self.lr, **loss_dict) 214 | 215 | if self.current_num_iterations >= self.num_total_train_steps: 216 | break 217 | 218 | if self.val_loader and set_validation and is_main_process(): 219 | self.logging(val_acc=self.evaluate(self.val_loader)) 220 | 221 | if self.current_num_iterations >= self.num_total_train_steps: 222 | self.save(epoch=None) 223 | break 224 | 225 | if self.distributed: 226 | self.train_loader.set_epoch(epoch) 227 | 228 | # self.save(epoch) 229 | 230 | 231 | class BaseKDTrainer(BaseTrainer): 232 | def __init__(self, model, dataloaders, config, teachers, job_id=None): 233 | super().__init__(model, dataloaders, config, job_id=job_id) 234 | self.epoch_counter = 0 235 | self._teachers = teachers 236 | self.pretrained_teacher_model.eval() 237 | 238 | @property 239 | def pretrained_teacher_model(self): 240 | return self._teachers["pretrained"] 241 | 242 | @property 243 | def ref_loader(self): 244 | return self.dataloaders["ref"] 245 | 246 | def _get_kd_loss(self, student_logits, teacher_logits, feature_criterion=None, T=2): 247 | if feature_criterion: 248 | return feature_criterion(student_logits, teacher_logits) 249 | 250 | soft_labels = nn.functional.softmax(teacher_logits / T, dim=-1) 251 | return nn.functional.cross_entropy( 252 | student_logits / T, soft_labels, reduction="mean" 253 | ) * (T**2) 254 | 255 | def get_ref_data(self, loader, has_noise=False): 256 | try: 257 | ref_data = next(loader) 258 | except StopIteration: 259 | self.epoch_counter += 1 260 | loader.init() 261 | ref_data = next(loader) 262 | 263 | if self.distributed: 264 | self.ref_loader.set_epoch(self.epoch_counter) 265 | 266 | data, index = ref_data[0], ref_data[-1] 267 | if has_noise: 268 | data += ref_data[1] 269 | 270 | return data, index 271 | 272 | def train(self, *args, **kwargs): 273 | self.dataloaders["ref"].init() 274 | 
super().train(*args, **kwargs) 275 | -------------------------------------------------------------------------------- /scripts/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import subprocess 3 | from ast import literal_eval 4 | from pathlib import Path 5 | from typing import List, Union 6 | 7 | DEFAULT_OUTPUT_ROOT = Path("outputs") 8 | DEFAULT_STORAGE_ROOT = Path("/work/chu980802/mix-teacher") 9 | 10 | DEFAULT_DATASET_SEQ = [ 11 | "fgvc-aircraft", 12 | "dtd", 13 | "eurosat", 14 | "flowers-102", 15 | "food-101", 16 | "oxford-pets", 17 | "stanford-cars", 18 | "ucf-101", 19 | ] 20 | 21 | 22 | class ContinualTrainer: 23 | def __init__( 24 | self, 25 | config_path: str = "configs/mix_teacher_config.yaml", 26 | module: str = "main.train", 27 | training_dataset_seq: List[str] = DEFAULT_DATASET_SEQ, 28 | eval_dataset_seq: List[str] = None, 29 | output_root: Path = DEFAULT_OUTPUT_ROOT, 30 | sub_output_dir: str = "default", 31 | method_config=None, 32 | max_epoch: int = 10, 33 | max_iterations: int = 1000, 34 | distributed: bool = False, 35 | nnodes: int = 1, 36 | nproc_per_node: int = 1, 37 | ): 38 | self.config_path = config_path 39 | self.module = module 40 | self.training_dataset_seq = training_dataset_seq 41 | self.eval_dataset_seq = ( 42 | training_dataset_seq if eval_dataset_seq is None else eval_dataset_seq 43 | ) 44 | self.train_eval_config = { 45 | "distributed": distributed, 46 | "nnodes": nnodes, 47 | "nproc_per_node": nproc_per_node, 48 | "max_epoch": max_epoch, 49 | "max_iterations": max_iterations, 50 | } 51 | 52 | self.output_root = output_root 53 | self.sub_output_dir = sub_output_dir 54 | 55 | self.method_config = method_config if method_config is not None else {} 56 | 57 | self.output_dir = ( 58 | self.output_root / self.sub_output_dir / Path(self.config_path).stem 59 | ) 60 | self.output_dir.mkdir(parents=True, exist_ok=True) 61 | 62 | @classmethod 63 | def aggregate_results(cls, training_dataset_seq, output_root): 64 | results_dict = dict() 65 | for dataset in training_dataset_seq: 66 | eval_result_path = ( 67 | get_output_dataset_dir(dataset, output_root=output_root) 68 | / "eval_results.json" 69 | ) 70 | 71 | with eval_result_path.open("r") as f: 72 | results = json.load(f) 73 | 74 | results_dict[dataset] = results 75 | 76 | return results_dict 77 | 78 | @classmethod 79 | def format_results( 80 | cls, 81 | res_dict, 82 | training_dataset_seq, 83 | eval_dataset_seq, 84 | pad=4, 85 | decimal=2, 86 | ): 87 | longest_training_dataset_name_len = max([len(k) for k in training_dataset_seq]) 88 | lines = [] 89 | lines.append( 90 | (" " * pad).join( 91 | [" " * longest_training_dataset_name_len] 92 | + [ 93 | f"%{max(len(dataset), 5)}s" % (dataset) 94 | for dataset in eval_dataset_seq 95 | ] 96 | ) 97 | ) 98 | 99 | for training_dataset in training_dataset_seq: 100 | line = [f"%{longest_training_dataset_name_len}s" % (training_dataset)] 101 | line += [ 102 | f"%{max(len(eval_dataset), 5)}s"  # match the header's column width 103 | % (f"{100*res_dict[training_dataset][eval_dataset]:.{decimal}f}") 104 | for eval_dataset in eval_dataset_seq 105 | ] 106 | lines.append((" " * pad).join(line)) 107 | 108 | return "\n".join(lines) + "\n" 109 | 110 | def train_and_eval(self, format=True): 111 | pretrained_dataset = None 112 | for training_dataset in self.training_dataset_seq: 113 | train_and_eval_script( 114 | config_path=self.config_path, 115 | training_module=self.module, 116 | training_dataset=training_dataset, 117 | pretrained_dataset=pretrained_dataset, 118 |
eval_dataset_seq=self.eval_dataset_seq, 119 | output_root=self.output_root, 120 | sub_output_dir=self.sub_output_dir, 121 | **self.train_eval_config, 122 | **self.method_config, 123 | ) 124 | pretrained_dataset = training_dataset 125 | 126 | res = self.aggregate_results( 127 | training_dataset_seq=self.training_dataset_seq, 128 | output_root=self.output_dir.parent, 129 | ) 130 | 131 | with (self.output_dir / "final_results.json").open("w") as f: 132 | json.dump(res, f, indent=4) 133 | 134 | if format: 135 | formatted_results = self.format_results( 136 | res, self.training_dataset_seq, self.eval_dataset_seq 137 | ) 138 | with (self.output_dir / "formatted_results.txt").open("w") as f: 139 | f.write(formatted_results) 140 | print(formatted_results) 141 | 142 | return res 143 | 144 | 145 | def get_output_dataset_dir( 146 | dataset=None, output_root=DEFAULT_OUTPUT_ROOT, timestamp="latest" 147 | ): 148 | if dataset is None: 149 | dataset = "openai" 150 | return output_root / dataset / timestamp 151 | 152 | 153 | def get_model_path( 154 | dataset=None, output_root=DEFAULT_OUTPUT_ROOT, timestamp="latest", epoch="latest" 155 | ): 156 | if dataset is None: 157 | return "openai" 158 | model_dir = get_output_dataset_dir(dataset, output_root, timestamp) 159 | return model_dir / f"checkpoint_{epoch}.pth" 160 | 161 | 162 | def start_subprocess(command, print_command=False): 163 | if isinstance(command, list): 164 | command = " ".join(command) 165 | if print_command: 166 | print(command + "\n") 167 | output = subprocess.check_output(command, shell=True) 168 | 169 | return output.decode("utf-8") 170 | 171 | 172 | def train_and_eval_script( 173 | config_path: str = "configs/mix_teacher_config.yaml", 174 | training_module: str = "main.train", 175 | training_dataset: str = "fgvc-aircraft", 176 | pretrained_dataset: str = None, 177 | eval_dataset_seq: List[str] = DEFAULT_DATASET_SEQ, 178 | sample_num: int = -1, 179 | max_epoch: int = 10, 180 | max_iterations: int = 1000, 181 | output_root: Path = DEFAULT_OUTPUT_ROOT, 182 | sub_output_dir: str = "default", 183 | eval_epoch: Union[int, str] = "latest", 184 | timestamp="latest", 185 | distributed=False, 186 | nnodes=1, 187 | nproc_per_node=1, 188 | **method_config, 189 | ): 190 | output_dir = output_root / sub_output_dir 191 | pretrained_model_path = get_model_path( 192 | pretrained_dataset, output_root=output_dir, timestamp=timestamp 193 | ) 194 | 195 | training_script( 196 | config_path=config_path, 197 | training_module=training_module, 198 | dataset=training_dataset, 199 | pretrained_model_path=pretrained_model_path, 200 | sample_num=sample_num, 201 | max_epoch=max_epoch, 202 | max_iterations=max_iterations, 203 | output_root=output_root, 204 | sub_output_dir=sub_output_dir, 205 | distributed=distributed, 206 | nnodes=nnodes, 207 | nproc_per_node=nproc_per_node, 208 | **method_config, 209 | ) 210 | 211 | model_path = get_model_path( 212 | training_dataset, output_root=output_dir, epoch=eval_epoch 213 | ) 214 | eval_results_path = ( 215 | get_output_dataset_dir(training_dataset, output_root=output_dir) 216 | / "eval_results.json" 217 | ) 218 | 219 | eval_on_multiple_datasets_script( 220 | datasets=eval_dataset_seq, 221 | pretrained_model_path=model_path, 222 | dump_result_path=eval_results_path, 223 | ) 224 | 225 | 226 | def training_script( 227 | config_path, 228 | training_module="main.train", 229 | dataset="fgvc-aircraft", 230 | pretrained_model_path="openai", 231 | sample_num=-1, 232 | max_epoch=10, 233 | max_iterations=1000, 234 |
output_root=DEFAULT_OUTPUT_ROOT, 235 | sub_output_dir="default", 236 | distributed=False, 237 | nnodes=1, 238 | nproc_per_node=1, 239 | **method_config, 240 | ): 241 | runner = ( 242 | "python" 243 | if not distributed 244 | else f"torchrun --nnodes={nnodes} --nproc_per_node={nproc_per_node}" 245 | ) 246 | command = [ 247 | runner, 248 | "-m", 249 | training_module, 250 | "--cfg-path", 251 | config_path, 252 | "--options", 253 | f"data.name={dataset}", 254 | f"model.pretrained={pretrained_model_path}", 255 | f"data.sample_num={sample_num}", 256 | f"task.max_epoch={max_epoch}", 257 | f"task.max_iterations={max_iterations}", 258 | f"task.output_dir={output_root}/{sub_output_dir}", 259 | f"task.distributed={distributed}", 260 | ] 261 | 262 | if len(method_config) > 0: 263 | command += [f"method.{k}={v}" for k, v in method_config.items()] 264 | 265 | start_subprocess(command, print_command=True) 266 | 267 | 268 | def eval_on_multiple_datasets_script( 269 | config_path="configs/inference_config.yaml", 270 | eval_module="main.evaluate", 271 | datasets=DEFAULT_DATASET_SEQ, 272 | pretrained_model_path="openai", 273 | sample_num=-1, 274 | dump_result_path=None, 275 | ): 276 | eval_results = {} 277 | for eval_dataset in datasets: 278 | command = [ 279 | "python", 280 | "-m", 281 | eval_module, 282 | "--cfg-path", 283 | config_path, 284 | "--options", 285 | f"model.pretrained={pretrained_model_path}", 286 | f"data.name={eval_dataset}", 287 | f"data.sample_num={sample_num}", 288 | ] 289 | 290 | res = start_subprocess(command, print_command=True) 291 | 292 | eval_results[eval_dataset] = float(literal_eval(res)["zero shot"]["test_acc"]) 293 | 294 | if dump_result_path: 295 | with open(dump_result_path, "w") as f: 296 | json.dump(eval_results, f, indent=4) 297 | 298 | return eval_results 299 | --------------------------------------------------------------------------------
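For reference, the following is a minimal sketch of how `ContinualTrainer` might be driven to reproduce a full training sequence. The rotation by `order` is an assumption based on the README's description of the `--order` argument; `scripts/continually_train.py` itself is not shown above.

```python
# Hypothetical driver for ContinualTrainer; the `order` rotation below is illustrative only.
from scripts.utils import DEFAULT_DATASET_SEQ, ContinualTrainer

order = 0  # offset that shifts the pre-defined dataset sequence
seq = DEFAULT_DATASET_SEQ[order:] + DEFAULT_DATASET_SEQ[:order]

trainer = ContinualTrainer(
    config_path="configs/snd_config_4_gpus.yaml",
    training_dataset_seq=seq,
    distributed=True,
    nproc_per_node=4,
)
trainer.train_and_eval()  # trains on each dataset in turn, then aggregates and formats eval results
```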