├── classifier
│   ├── __init__.py
│   ├── .DS_Store
│   ├── __pycache__
│   ├── zero_shot.py
│   ├── clip_adapter.py
│   ├── utils.py
│   └── evaluator.py
├── utils
│   ├── progress
│   │   ├── MANIFEST.in
│   │   ├── .gitignore
│   │   ├── demo.gif
│   │   ├── LICENSE
│   │   ├── setup.py
│   │   ├── progress
│   │   │   ├── spinner.py
│   │   │   ├── counter.py
│   │   │   ├── bar.py
│   │   │   ├── helpers.py
│   │   │   └── __init__.py
│   │   ├── test_progress.py
│   │   └── README.rst
│   ├── .DS_Store
│   ├── figs
│   │   ├── front.png
│   │   ├── results.png
│   │   └── results3.png
│   ├── __pycache__
│   ├── __init__.py
│   ├── toolkit.py
│   ├── CELS.py
│   ├── eval.py
│   ├── cutout.py
│   ├── misc.py
│   ├── get_num_of_params.py
│   ├── parser_results.py
│   ├── inc_net.py
│   ├── rotation_angle_matrix.py
│   ├── visualize.py
│   ├── logger.py
│   ├── visualize_ood_detection_trend.py
│   └── display_results.py
├── dataset
│   ├── .DS_Store
│   ├── __pycache__
│   ├── class_order
│   │   ├── cifar100_order1.yaml
│   │   ├── cifar100_order2.yaml
│   │   ├── cifar100_order3.yaml
│   │   └── imagenet100_order1.yaml
│   ├── exemplars_dataset.py
│   ├── build_dataset.py
│   ├── templates.json
│   ├── cifar.py
│   └── imagenetr.py
├── images
│   ├── Slide13-1.png
│   └── Screenshot 2024-05-24 162355.png
├── clip
│   ├── bpe_simple_vocab_16e6.txt.gz
│   ├── __pycache__
│   ├── simple_tokenizer.py
│   ├── clip_2.py
│   └── clip.py
├── requirements.txt
├── scripts
│   ├── runner_zero_shot_all_datasets.sh
│   ├── runner_clip_adapter_all_datasets.sh
│   └── runner_clip_var_all_datasets.sh
└── README.md

/classifier/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/utils/progress/MANIFEST.in:
--------------------------------------------------------------------------------
include README.rst LICENSE
--------------------------------------------------------------------------------
/utils/progress/.gitignore:
--------------------------------------------------------------------------------
*.pyc
*.egg-info
build/
dist/
--------------------------------------------------------------------------------
/images/Slide13-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/clap4clip/HEAD/images/Slide13-1.png
--------------------------------------------------------------------------------
/utils/figs/front.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/clap4clip/HEAD/utils/figs/front.png
--------------------------------------------------------------------------------
/utils/figs/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/clap4clip/HEAD/utils/figs/results.png
--------------------------------------------------------------------------------
/utils/figs/results3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/clap4clip/HEAD/utils/figs/results3.png
--------------------------------------------------------------------------------
/utils/progress/demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/clap4clip/HEAD/utils/progress/demo.gif
--------------------------------------------------------------------------------
/clip/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/clap4clip/HEAD/clip/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/images/Screenshot 2024-05-24 162355.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srvCodes/clap4clip/HEAD/images/Screenshot 2024-05-24 162355.png
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
"""Useful utils
"""
from .misc import *
from .logger import *
from .visualize import *
from .eval import *

# progress bar
import os, sys
sys.path.append(os.path.join(os.path.dirname(__file__), "progress"))
from progress.bar import Bar as Bar
--------------------------------------------------------------------------------
/utils/toolkit.py:
--------------------------------------------------------------------------------
import os
import numpy as np


def split_images_labels(imgs):
    # split trainset.imgs in ImageFolder
    images = []
    labels = []
    for item in imgs:
        images.append(item[0])
        labels.append(item[1])

    return np.array(images), labels
--------------------------------------------------------------------------------
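A minimal usage sketch for split_images_labels (not part of the repository): the helper expects the .imgs attribute of a torchvision ImageFolder, i.e. a list of (path, class_index) pairs, and returns the image paths as a NumPy array together with the plain label list. The dataset path below is illustrative only.

    from torchvision.datasets import ImageFolder
    from utils.toolkit import split_images_labels

    dataset = ImageFolder("../mammoth_datasets/imagenet-r/train")  # illustrative path
    paths, labels = split_images_labels(dataset.imgs)              # dataset.imgs: list of (path, class_index)
    print(paths.shape, len(labels))                                # paths: np.ndarray of file paths; labels: list of ints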
/dataset/class_order/cifar100_order1.yaml:
--------------------------------------------------------------------------------
class_order: [87, 0, 52, 58, 44, 91, 68, 97, 51, 15, 94, 92, 10, 72, 49, 78, 61, 14, 8, 86, 84, 96, 18, 24, 32, 45, 88, 11, 4, 67, 69, 66, 77, 47, 79, 93, 29, 50, 57, 83, 17, 81, 41, 12, 37, 59, 25, 20, 80, 73, 1, 28, 6, 46, 62, 82, 53, 9, 31, 75, 38, 63, 33, 74, 27, 22, 36, 3, 16, 21, 60, 19, 70, 90, 89, 43, 5, 42, 65, 76, 40, 30, 23, 85, 2, 95, 56, 48, 71, 64, 98, 13, 99, 7, 34, 55, 54, 26, 35, 39]
--------------------------------------------------------------------------------
/dataset/class_order/cifar100_order2.yaml:
--------------------------------------------------------------------------------
class_order: [58, 30, 93, 69, 21, 77, 3, 78, 12, 71, 65, 40, 16, 49, 89, 46, 24, 66, 19, 41, 5, 29, 15, 73, 11, 70, 90, 63, 67, 25, 59, 72, 80, 94, 54, 33, 18, 96, 2, 10, 43, 9, 57, 81, 76, 50, 32, 6, 37, 7, 68, 91, 88, 95, 85, 4, 60, 36, 22, 27, 39, 42, 34, 51, 55, 28, 53, 48, 38, 17, 83, 86, 56, 35, 45, 79, 99, 84, 97, 82, 98, 26, 47, 44, 62, 13, 31, 0, 75, 14, 52, 74, 8, 20, 1, 92, 87, 23, 64, 61]
--------------------------------------------------------------------------------
/dataset/class_order/cifar100_order3.yaml:
--------------------------------------------------------------------------------
class_order: [71, 54, 45, 32, 4, 8, 48, 66, 1, 91, 28, 82, 29, 22, 80, 27, 86, 23, 37, 47, 55, 9, 14, 68, 25, 96, 36, 90, 58, 21, 57, 81, 12, 26, 16, 89, 79, 49, 31, 38, 46, 20, 92, 88, 40, 39, 98, 94, 19, 95, 72, 24, 64, 18, 60, 50, 63, 61, 83, 76, 69, 35, 0, 52, 7, 65, 42, 73, 74, 30, 41, 3, 6, 53, 13, 56, 70, 77, 34, 97, 75, 2, 17, 93, 33, 84, 99, 51, 62, 87, 5, 15, 10, 78, 67, 44, 59, 85, 43, 11]
--------------------------------------------------------------------------------
/dataset/class_order/imagenet100_order1.yaml:
--------------------------------------------------------------------------------
class_order: [68, 56, 78, 8, 23, 84, 90, 65, 74, 76, 40, 89, 3, 92, 55, 9, 26, 80, 43, 38, 58, 70, 77, 1, 85, 19, 17, 50,
              28, 53, 13, 81, 45, 82, 6, 59, 83, 16, 15, 44, 91, 41, 72, 60, 79, 52, 20, 10, 31, 54, 37, 95, 14, 71, 96,
              98, 97, 2, 64, 66, 42, 22, 35, 86, 24, 34, 87, 21, 99, 0, 88, 27, 18, 94, 11, 12, 47, 25, 30, 46, 62, 69,
              36, 61, 7, 63, 75, 5, 32, 4, 51, 48, 73, 93, 39, 67, 29, 49, 57, 33]
--------------------------------------------------------------------------------
/utils/CELS.py:
--------------------------------------------------------------------------------
import math
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

class CrossEntropyLabelSmooth(nn.Module):

    def __init__(self, num_classes, epsilon):
        super(CrossEntropyLabelSmooth, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        log_probs = self.logsoftmax(inputs)
        targets = torch.zeros_like(log_probs).scatter_(1, torch.unsqueeze(targets, 1), 1)
        targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
        loss = torch.mean(torch.sum(-targets * log_probs, 1))
        return loss
--------------------------------------------------------------------------------
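CrossEntropyLabelSmooth implements standard label smoothing: the one-hot target is replaced by (1 - epsilon) * one_hot + epsilon / num_classes before taking the cross-entropy against the log-softmax of the logits. A minimal sketch of how it could be called (the num_classes and epsilon values below are illustrative, not taken from the repository):

    import torch
    from utils.CELS import CrossEntropyLabelSmooth

    criterion = CrossEntropyLabelSmooth(num_classes=100, epsilon=0.1)  # illustrative hyperparameters
    logits = torch.randn(4, 100)               # (batch_size, num_classes) raw scores
    targets = torch.tensor([3, 17, 42, 99])    # integer class labels
    loss = criterion(logits, targets)          # scalar mean of -sum(smoothed_target * log_prob)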
* log_probs, 1)) 20 | return loss 21 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib @ file:///croot/matplotlib-suite_1693812469450/work 2 | matplotlib-inline @ file:///opt/conda/conda-bld/matplotlib-inline_1662014470464/work 3 | numpy @ file:///croot/numpy_and_numpy_base_1682520569166/work 4 | numpydoc @ file:///croot/numpydoc_1668085905352/work 5 | pandas @ file:///croot/pandas_1692289311655/work 6 | protobuf==4.25.2 7 | scikit-image @ file:///croot/scikit-image_1669241743693/work 8 | scikit-learn @ file:///croot/scikit-learn_1694788527225/work 9 | scikit-learn-intelex==20230228.214413 10 | scipy @ file:///tmp/build/80754af9/scipy_1630606796110/work 11 | seaborn @ file:///croot/seaborn_1673479180098/work 12 | statsmodels @ file:///croot/statsmodels_1689937266057/work 13 | tb-nightly==2.14.0a20230808 14 | torch==2.1.2 15 | torchmetrics==1.2.1 16 | torchvision==0.16.2 17 | -------------------------------------------------------------------------------- /utils/progress/LICENSE: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2012 Giorgos Verigakis 2 | # 3 | # Permission to use, copy, modify, and distribute this software for any 4 | # purpose with or without fee is hereby granted, provided that the above 5 | # copyright notice and this permission notice appear in all copies. 6 | # 7 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 8 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 9 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 10 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 11 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 12 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 13 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
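# Example (not from the repository): a minimal sketch of calling the
# CrossEntropyLabelSmooth loss defined in utils/CELS.py above. The class count,
# epsilon and batch size are illustrative assumptions, and the import assumes
# the script is run from the repository root.
import torch
from utils.CELS import CrossEntropyLabelSmooth

criterion = CrossEntropyLabelSmooth(num_classes=100, epsilon=0.1)
logits = torch.randn(8, 100)             # raw model outputs for a batch of 8
targets = torch.randint(0, 100, (8,))    # integer class labels in [0, 100)
loss = criterion(logits, targets)        # scalar smoothed cross-entropy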
14 | -------------------------------------------------------------------------------- /scripts/runner_zero_shot_all_datasets.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #PBS -q gpuvolta 3 | #PBS -l storage=gdata/kf26 4 | #PBS -l walltime=30:00:00 5 | #PBS -l ngpus=2 6 | #PBS -l ncpus=24 7 | #PBS -P kf26 8 | #PBS -l mem=40GB 9 | #PBS -l jobfs=40GB 10 | #PBS -l wd 11 | #module load intel-mkl/2020.3.304 12 | # module load python3/3.9.2 13 | # module load cudnn/8.2.2-cuda11.4 14 | # module load cuda/11.4.1 15 | for ARCH in ViT-B-16 #ViT-L-14 16 | do 17 | for RUN in 0 #1 2 3 4 5 6 7 8 9 18 | do 19 | for EPOCH in 1 #10 #5 10 #5 10 15 #5 10 #15 20 | do 21 | for DATASET in cifar100 imagenet-r 22 | do 23 | python3 main_incremental_submit.py --db_name $DATASET --num-run $RUN --compute-ece --compute-bwt --train_batch 32 --root ../mammoth_datasets/ --multi-gpu --gpus 0,1 --default-gpu 0 --model clclip --epochs $EPOCH --forward-times 20 --arch $ARCH 24 | done 25 | done 26 | done 27 | done 28 | -------------------------------------------------------------------------------- /scripts/runner_clip_adapter_all_datasets.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #PBS -q gpuvolta 3 | #PBS -l storage=gdata/kf26 4 | #PBS -l walltime=30:00:00 5 | #PBS -l ngpus=2 6 | #PBS -l ncpus=24 7 | #PBS -P kf26 8 | #PBS -l mem=40GB 9 | #PBS -l jobfs=40GB 10 | #PBS -l wd 11 | #module load intel-mkl/2020.3.304 12 | # module load python3/3.9.2 13 | # module load cudnn/8.2.2-cuda11.4 14 | # module load cuda/11.4.1 15 | for ARCH in ViT-B-16 #ViT-L-14 16 | do 17 | for RUN in 0 #1 2 3 4 5 6 7 8 9 18 | do 19 | for EPOCH in 5 20 | do 21 | for DATASET in cifar100 imagenet-r 22 | do 23 | python3 main_incremental_submit.py --db_name $DATASET --finetuning --finetune-epochs 2 --num-run $RUN --compute-ece --compute-bwt --train_batch 32 --exemplar-selector random --root ../mammoth_datasets/ --multi-gpu --gpus 0,1 --default-gpu 0 --model clip_adapter --epochs $EPOCH --arch $ARCH --method er 24 | done 25 | done 26 | done 27 | done 28 | -------------------------------------------------------------------------------- /utils/progress/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import setup 4 | 5 | import progress 6 | 7 | 8 | setup( 9 | name='progress', 10 | version=progress.__version__, 11 | description='Easy to use progress bars', 12 | long_description=open('README.rst').read(), 13 | author='Giorgos Verigakis', 14 | author_email='verigak@gmail.com', 15 | url='http://github.com/verigak/progress/', 16 | license='ISC', 17 | packages=['progress'], 18 | classifiers=[ 19 | 'Environment :: Console', 20 | 'Intended Audience :: Developers', 21 | 'License :: OSI Approved :: ISC License (ISCL)', 22 | 'Programming Language :: Python :: 2.6', 23 | 'Programming Language :: Python :: 2.7', 24 | 'Programming Language :: Python :: 3.3', 25 | 'Programming Language :: Python :: 3.4', 26 | 'Programming Language :: Python :: 3.5', 27 | 'Programming Language :: Python :: 3.6', 28 | ] 29 | ) 30 | -------------------------------------------------------------------------------- /scripts/runner_clip_var_all_datasets.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #PBS -q gpuvolta 3 | #PBS -l storage=gdata/kf26 4 | #PBS -l walltime=30:00:00 5 | #PBS -l ngpus=2 6 | #PBS -l ncpus=24 7 | #PBS -P kf26 
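# Example (not from the repository): the same zero-shot run that
# runner_zero_shot_all_datasets.sh above launches, started from Python via
# subprocess instead of a PBS job. Every flag is copied from that script; the
# dataset root and GPU ids are the script's own values and may need adjusting.
import subprocess

subprocess.run([
    "python3", "main_incremental_submit.py",
    "--db_name", "cifar100", "--num-run", "0",
    "--compute-ece", "--compute-bwt",
    "--train_batch", "32", "--root", "../mammoth_datasets/",
    "--multi-gpu", "--gpus", "0,1", "--default-gpu", "0",
    "--model", "clclip", "--epochs", "1",
    "--forward-times", "20", "--arch", "ViT-B-16",
], check=True)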
8 | #PBS -l mem=40GB 9 | #PBS -l jobfs=40GB 10 | #PBS -l wd 11 | #module load intel-mkl/2020.3.304 12 | # module load python3/3.9.2 13 | # module load cudnn/8.2.2-cuda11.4 14 | # module load cuda/11.4.1 15 | for ARCH in ViT-B-16 #ViT-L-14 16 | do 17 | for RUN in 0 #1 2 3 4 5 6 7 8 9 18 | do 19 | for EPOCH in 1 #10 #5 10 #5 10 15 #5 10 #15 20 | do 21 | for ALPHA in 10. # 2. 5. 10. 15. 20. 22 | do 23 | for DATASET in cifar100 imagenet-r 24 | do 25 | for BETA in 15 26 | do 27 | python3 main_incremental_submit.py --lasp --beta $BETA --db_name $DATASET --use-vga --expandable-adapter --finetuning --finetune-epochs 1 --num-run $RUN --compute-ece --compute-bwt --train_batch 32 --exemplar-selector random --root ../mammoth_datasets/ --multi-gpu --gpus 0,1 --default-gpu 0 --model clclip_var --epochs $EPOCH --forward-times 20 --arch $ARCH --method er --variational 28 | done 29 | done 30 | done 31 | done 32 | done 33 | done 34 | -------------------------------------------------------------------------------- /utils/eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import torch 3 | import pdb 4 | __all__ = ['accuracy'] 5 | 6 | # def accuracy(output, target, topk=(1,)): 7 | # """Computes the precision@k for the specified values of k""" 8 | # maxk = max(topk) 9 | # batch_size = target.size(0) 10 | # _, pred = output.topk(maxk, 1, True, True) 11 | # pred = pred.t() 12 | # correct = pred.eq(target.view(1, -1).expand_as(pred)) 13 | 14 | # res = [] 15 | # for k in topk: 16 | # correct_k = correct[:k].view(-1).float().sum(0) 17 | # res.append(correct_k.mul_(100.0 / batch_size)) 18 | # return res 19 | 20 | 21 | def accuracy(output, target, topk=(1,)): 22 | """Computes the accuracy over the k top predictions for the specified values of k""" 23 | 24 | maxk = max(topk) 25 | batch_size = target.size(0) 26 | _, pred = output.topk(maxk, 1, True, True) 27 | pred = pred.t() 28 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 29 | 30 | 31 | res = [] 32 | for k in topk: 33 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 34 | res.append(correct_k.mul_(100.0 / batch_size)) 35 | return res 36 | 37 | 38 | -------------------------------------------------------------------------------- /utils/cutout.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class Cutout(object): 6 | """Randomly mask out one or more patches from an image. 7 | 8 | Args: 9 | n_holes (int): Number of patches to cut out of each image. 10 | length (int): The length (in pixels) of each square patch. 11 | """ 12 | def __init__(self, n_holes, length): 13 | self.n_holes = n_holes 14 | self.length = length 15 | 16 | def __call__(self, img): 17 | """ 18 | Args: 19 | img (Tensor): Tensor image of size (C, H, W). 20 | Returns: 21 | Tensor: Image with n_holes of dimension length x length cut out of it. 22 | """ 23 | 24 | h = 32 25 | w = 32 26 | 27 | mask = np.ones((h, w), np.float32) 28 | 29 | for n in range(self.n_holes): 30 | y = np.random.randint(h) 31 | x = np.random.randint(w) 32 | 33 | y1 = np.clip(y - self.length // 2, 0, h) 34 | y2 = np.clip(y + self.length // 2, 0, h) 35 | x1 = np.clip(x - self.length // 2, 0, w) 36 | x2 = np.clip(x + self.length // 2, 0, w) 37 | 38 | mask[y1: y2, x1: x2] = 0. 
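# Example (not from the repository): typical use of the accuracy() helper from
# utils/eval.py above. The tensor shapes are illustrative assumptions; each
# returned value is a one-element tensor holding a percentage.
import torch
from utils.eval import accuracy

logits = torch.randn(16, 100)                        # batch of 16, 100 classes
labels = torch.randint(0, 100, (16,))
top1, top5 = accuracy(logits, labels, topk=(1, 5))   # top-1 and top-5 accuracy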
39 | 40 | mask = torch.from_numpy(mask) 41 | mask = mask.expand_as(img) 42 | img = img * mask 43 | 44 | return img 45 | -------------------------------------------------------------------------------- /utils/progress/progress/spinner.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright (c) 2012 Giorgos Verigakis 4 | # 5 | # Permission to use, copy, modify, and distribute this software for any 6 | # purpose with or without fee is hereby granted, provided that the above 7 | # copyright notice and this permission notice appear in all copies. 8 | # 9 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 10 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 11 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 12 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 13 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 14 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 15 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 16 | 17 | from __future__ import unicode_literals 18 | from . import Infinite 19 | from .helpers import WriteMixin 20 | 21 | 22 | class Spinner(WriteMixin, Infinite): 23 | message = '' 24 | phases = ('-', '\\', '|', '/') 25 | hide_cursor = True 26 | 27 | def update(self): 28 | i = self.index % len(self.phases) 29 | self.write(self.phases[i]) 30 | 31 | 32 | class PieSpinner(Spinner): 33 | phases = ['◷', '◶', '◵', '◴'] 34 | 35 | 36 | class MoonSpinner(Spinner): 37 | phases = ['◑', '◒', '◐', '◓'] 38 | 39 | 40 | class LineSpinner(Spinner): 41 | phases = ['⎺', '⎻', '⎼', '⎽', '⎼', '⎻'] 42 | 43 | class PixelSpinner(Spinner): 44 | phases = ['⣾','⣷', '⣯', '⣟', '⡿', '⢿', '⣻', '⣽'] 45 | -------------------------------------------------------------------------------- /utils/progress/test_progress.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from __future__ import print_function 4 | 5 | import random 6 | import time 7 | 8 | from progress.bar import (Bar, ChargingBar, FillingSquaresBar, 9 | FillingCirclesBar, IncrementalBar, PixelBar, 10 | ShadyBar) 11 | from progress.spinner import (Spinner, PieSpinner, MoonSpinner, LineSpinner, 12 | PixelSpinner) 13 | from progress.counter import Counter, Countdown, Stack, Pie 14 | 15 | 16 | def sleep(): 17 | t = 0.01 18 | t += t * random.uniform(-0.1, 0.1) # Add some variance 19 | time.sleep(t) 20 | 21 | 22 | for bar_cls in (Bar, ChargingBar, FillingSquaresBar, FillingCirclesBar): 23 | suffix = '%(index)d/%(max)d [%(elapsed)d / %(eta)d / %(eta_td)s]' 24 | bar = bar_cls(bar_cls.__name__, suffix=suffix) 25 | for i in bar.iter(range(200)): 26 | sleep() 27 | 28 | for bar_cls in (IncrementalBar, PixelBar, ShadyBar): 29 | suffix = '%(percent)d%% [%(elapsed_td)s / %(eta)d / %(eta_td)s]' 30 | bar = bar_cls(bar_cls.__name__, suffix=suffix) 31 | for i in bar.iter(range(200)): 32 | sleep() 33 | 34 | for spin in (Spinner, PieSpinner, MoonSpinner, LineSpinner, PixelSpinner): 35 | for i in spin(spin.__name__ + ' ').iter(range(100)): 36 | sleep() 37 | print() 38 | 39 | for singleton in (Counter, Countdown, Stack, Pie): 40 | for i in singleton(singleton.__name__ + ' ').iter(range(100)): 41 | sleep() 42 | print() 43 | 44 | bar = IncrementalBar('Random', suffix='%(index)d') 45 | for i in range(100): 46 | bar.goto(random.randint(0, 100)) 47 | sleep() 48 | bar.finish() 49 
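# Example (not from the repository): a sketch of composing the Cutout
# augmentation from utils/cutout.py above with torchvision transforms for
# 32x32 inputs (the class hard-codes h = w = 32). The n_holes and length
# values are illustrative assumptions.
import torchvision.transforms as T
from utils.cutout import Cutout

train_transform = T.Compose([
    T.RandomCrop(32, padding=4),
    T.RandomHorizontalFlip(),
    T.ToTensor(),                    # PIL image -> (C, H, W) tensor
    Cutout(n_holes=1, length=8),     # zero out one 8x8 patch per image
])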
| -------------------------------------------------------------------------------- /utils/progress/progress/counter.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright (c) 2012 Giorgos Verigakis 4 | # 5 | # Permission to use, copy, modify, and distribute this software for any 6 | # purpose with or without fee is hereby granted, provided that the above 7 | # copyright notice and this permission notice appear in all copies. 8 | # 9 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 10 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 11 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 12 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 13 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 14 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 15 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 16 | 17 | from __future__ import unicode_literals 18 | from . import Infinite, Progress 19 | from .helpers import WriteMixin 20 | 21 | 22 | class Counter(WriteMixin, Infinite): 23 | message = '' 24 | hide_cursor = True 25 | 26 | def update(self): 27 | self.write(str(self.index)) 28 | 29 | 30 | class Countdown(WriteMixin, Progress): 31 | hide_cursor = True 32 | 33 | def update(self): 34 | self.write(str(self.remaining)) 35 | 36 | 37 | class Stack(WriteMixin, Progress): 38 | phases = (' ', '▁', '▂', '▃', '▄', '▅', '▆', '▇', '█') 39 | hide_cursor = True 40 | 41 | def update(self): 42 | nphases = len(self.phases) 43 | i = min(nphases - 1, int(self.progress * nphases)) 44 | self.write(self.phases[i]) 45 | 46 | 47 | class Pie(Stack): 48 | phases = ('○', '◔', '◑', '◕', '●') 49 | -------------------------------------------------------------------------------- /dataset/exemplars_dataset.py: -------------------------------------------------------------------------------- 1 | from .incremental_dataloader import IncrementalDataset 2 | 3 | class ExemplarsDataset(IncrementalDataset): 4 | """Exemplar storage for approaches with an interface of Dataset""" 5 | 6 | def __init__(self, transform, class_indices, 7 | num_exemplars=0, num_exemplars_per_class=0, exemplar_selection='random'): 8 | super().__init__({'x': [], 'y': []}, transform, class_indices=class_indices) 9 | self.max_num_exemplars_per_class = num_exemplars_per_class 10 | self.max_num_exemplars = num_exemplars 11 | assert (num_exemplars_per_class == 0) or (num_exemplars == 0), 'Cannot use both limits at once!' 
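# Example (not from the repository): driving the vendored progress counters
# above by hand; test_progress.py uses .iter(), while this shows the explicit
# next()/finish() API. Assumes utils/progress is installed (see its setup.py)
# or otherwise importable as the 'progress' package.
from progress.counter import Counter

counter = Counter('Processed ')
for _ in range(1000):
    counter.next()
counter.finish()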
12 | cls_name = "{}ExemplarsSelector".format(exemplar_selection.capitalize()) 13 | selector_cls = getattr(importlib.import_module(name='datasets.exemplars_selection'), cls_name) 14 | self.exemplars_selector = selector_cls(self) 15 | 16 | # Returns a parser containing the approach specific parameters 17 | @staticmethod 18 | def extra_parser(args): 19 | parser = ArgumentParser("Exemplars Management Parameters") 20 | _group = parser.add_mutually_exclusive_group() 21 | _group.add_argument('--num-exemplars', default=0, type=int, required=False, 22 | help='Fixed memory, total number of exemplars (default=%(default)s)') 23 | _group.add_argument('--num-exemplars-per-class', default=0, type=int, required=False, 24 | help='Growing memory, number of exemplars per class (default=%(default)s)') 25 | parser.add_argument('--exemplar-selection', default='random', type=str, 26 | choices=['herding', 'random', 'entropy', 'distance'], 27 | required=False, help='Exemplar selection strategy (default=%(default)s)') 28 | return parser.parse_known_args(args) 29 | 30 | def _is_active(self): 31 | return self.max_num_exemplars_per_class > 0 or self.max_num_exemplars > 0 32 | 33 | def collect_exemplars(self, model, trn_loader, selection_transform): 34 | if self._is_active(): 35 | self.images, self.labels = self.exemplars_selector(model, trn_loader, selection_transform) -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | '''Some helper functions for PyTorch, including: 2 | - get_mean_and_std: calculate the mean and std value of dataset. 3 | - msr_init: net parameter initialization. 4 | - progress_bar: progress bar mimic xlua.progress. 5 | ''' 6 | import errno 7 | import os 8 | import sys 9 | import time 10 | import math 11 | 12 | import torch.nn as nn 13 | import torch.nn.init as init 14 | from torch.autograd import Variable 15 | 16 | __all__ = ['get_mean_and_std', 'init_params', 'mkdir_p', 'AverageMeter'] 17 | 18 | 19 | def get_mean_and_std(dataset): 20 | '''Compute the mean and std value of dataset.''' 21 | dataloader = trainloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2) 22 | 23 | mean = torch.zeros(3) 24 | std = torch.zeros(3) 25 | print('==> Computing mean and std..') 26 | for inputs, targets in dataloader: 27 | for i in range(3): 28 | mean[i] += inputs[:,i,:,:].mean() 29 | std[i] += inputs[:,i,:,:].std() 30 | mean.div_(len(dataset)) 31 | std.div_(len(dataset)) 32 | return mean, std 33 | 34 | def init_params(net): 35 | '''Init layer parameters.''' 36 | for m in net.modules(): 37 | if isinstance(m, nn.Conv2d): 38 | init.kaiming_normal(m.weight, mode='fan_out') 39 | if m.bias: 40 | init.constant(m.bias, 0) 41 | elif isinstance(m, nn.BatchNorm2d): 42 | init.constant(m.weight, 1) 43 | init.constant(m.bias, 0) 44 | elif isinstance(m, nn.Linear): 45 | init.normal(m.weight, std=1e-3) 46 | if m.bias: 47 | init.constant(m.bias, 0) 48 | 49 | def mkdir_p(path): 50 | '''make dir if not exist''' 51 | try: 52 | os.makedirs(path) 53 | except OSError as exc: # Python >2.5 54 | if exc.errno == errno.EEXIST and os.path.isdir(path): 55 | pass 56 | else: 57 | raise 58 | 59 | class AverageMeter(object): 60 | """Computes and stores the average and current value 61 | Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 62 | """ 63 | def __init__(self): 64 | self.reset() 65 | 66 | def reset(self): 67 | self.val = 0 68 | self.avg = 0 
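# Example (not from the repository): a sketch of parsing the exemplar-memory
# options exposed by ExemplarsDataset.extra_parser above. Note that
# dataset/exemplars_dataset.py as listed uses ArgumentParser and importlib
# without importing them, so this assumes 'from argparse import ArgumentParser'
# and 'import importlib' are added at the top of that file. The argv values
# below are illustrative.
from dataset.exemplars_dataset import ExemplarsDataset

args, _ = ExemplarsDataset.extra_parser(['--num-exemplars', '2000',
                                         '--exemplar-selection', 'herding'])
print(args.num_exemplars, args.exemplar_selection)   # -> 2000 herding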
69 | self.sum = 0 70 | self.count = 0 71 | 72 | def update(self, val, n=1): 73 | self.val = val 74 | self.sum += val * n 75 | self.count += n 76 | self.avg = self.sum / self.count -------------------------------------------------------------------------------- /utils/get_num_of_params.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import seaborn as sns 3 | import matplotlib.pyplot as plt 4 | 5 | sns.color_palette("Spectral") 6 | 7 | plt.rcParams.update({'font.size': 28,}) 8 | 9 | min_coord = 1 10 | max_coord = 480 11 | 12 | colors = { 13 | 'iCaRL': 'blue', 14 | 'Continual-CLIP': 'deepskyblue', 15 | 'CoOp': 'gray', 16 | 'CLIP-Adapter': 'lime', 17 | 'AttriCLIP': 'orange', 18 | 'DualPrompt': 'fuchsia', 19 | 'L2P': 'brown', 20 | 'PROOF': 'green', 21 | 'TAw-UB': 'black', 22 | 'Ours': 'red', 23 | "Ours + CoOp": 'black', 24 | "Ours + AttriCLIP": 'indigo' 25 | } 26 | 27 | def reverse_scale(coord): 28 | return max_coord - coord 29 | 30 | def get_max_param_num(zsclip_percentage): 31 | max_percent = 149.6 / zsclip_percentage 32 | return max_percent 33 | 34 | def get_num_of_params(coord_dict): 35 | coord_dict = {method: reverse_scale(c) for method, c in coord_dict.items()} 36 | coord_dict_percentage = {method: (value - min_coord) / (max_coord - min_coord) for method, value in coord_dict.items()} 37 | max_param = get_max_param_num(coord_dict_percentage['Continual-CLIP']) 38 | param_nums = {method: value * max_param for method, value in coord_dict_percentage.items()} 39 | print(param_nums) 40 | return param_nums 41 | 42 | def draw_plot(data): 43 | # Set the style for the plot 44 | # sns.set(style="whitegrid") 45 | 46 | # Create a figure with the specified size 47 | plt.figure(figsize=(10, 6)) 48 | 49 | hues = [colors[key_] for key_ in list(data.keys())] 50 | # Create the bar plot with a unique color for each bar 51 | barplot = sns.barplot(x=list(data.keys()), y=list(data.values()), palette=hues, hue=list(data.keys()), hue_order=list(data.keys()), dodge=False) 52 | for i in range(len(data)): 53 | barplot.bar_label(barplot.containers[i], fmt='%.1f', rotation=20, label_type='center', fontsize=22) 54 | 55 | # Set the y-axis range to start from 50 56 | barplot.set(ylim=(0, max(data.values()) + 50)) 57 | barplot.set_xticklabels(list(data.keys()), rotation = 25) 58 | 59 | # Add a legend across 3 columns at the top 60 | barplot.legend(loc="upper center", bbox_to_anchor=(0.55, 1.15), ncol=2, prop={'size': 25}, frameon=False) 61 | # Set labels and title 62 | # plt.xlabel('Methods') 63 | plt.ylabel(f'Parameters (in millions)') 64 | barplot.spines[['right', 'top']].set_visible(False) 65 | barplot.set_xticklabels([]) 66 | # RED 67 | barplot.set_xticks([]) 68 | plt.tight_layout() 69 | plt.savefig("parameter_comparison.pdf") 70 | # Show the plot 71 | plt.show() 72 | 73 | param_nums = {'iCaRL': 299.2} 74 | param_nums['CLIP-Adapter'] = 149.8 75 | param_nums['Continual-CLIP'] = 149.6 76 | coord_list = {'L2P': 295, 'DualPrompt': 295, 'PROOF': 314, 'Continual-CLIP': 320, 'CoOp': 320,} #, 'iCaRL': 161} 77 | param_nums.update(get_num_of_params(coord_list)) 78 | param_nums['AttriCLIP'] = 149.7 79 | # param_nums['PROOF'] = 153.1 80 | param_nums['Ours'] = 159.5 81 | # param_nums['iCaRL'] = 299.2 82 | 83 | draw_plot(param_nums) -------------------------------------------------------------------------------- /utils/progress/progress/bar.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # 
Copyright (c) 2012 Giorgos Verigakis 4 | # 5 | # Permission to use, copy, modify, and distribute this software for any 6 | # purpose with or without fee is hereby granted, provided that the above 7 | # copyright notice and this permission notice appear in all copies. 8 | # 9 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 10 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 11 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 12 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 13 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 14 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 15 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 16 | 17 | from __future__ import unicode_literals 18 | from . import Progress 19 | from .helpers import WritelnMixin 20 | 21 | 22 | class Bar(WritelnMixin, Progress): 23 | width = 32 24 | message = '' 25 | suffix = '%(index)d/%(max)d' 26 | bar_prefix = ' |' 27 | bar_suffix = '| ' 28 | empty_fill = ' ' 29 | fill = '#' 30 | hide_cursor = True 31 | 32 | def update(self): 33 | filled_length = int(self.width * self.progress) 34 | empty_length = self.width - filled_length 35 | 36 | message = self.message % self 37 | bar = self.fill * filled_length 38 | empty = self.empty_fill * empty_length 39 | suffix = self.suffix % self 40 | line = ''.join([message, self.bar_prefix, bar, empty, self.bar_suffix, 41 | suffix]) 42 | self.writeln(line) 43 | 44 | 45 | class ChargingBar(Bar): 46 | suffix = '%(percent)d%%' 47 | bar_prefix = ' ' 48 | bar_suffix = ' ' 49 | empty_fill = '∙' 50 | fill = '█' 51 | 52 | 53 | class FillingSquaresBar(ChargingBar): 54 | empty_fill = '▢' 55 | fill = '▣' 56 | 57 | 58 | class FillingCirclesBar(ChargingBar): 59 | empty_fill = '◯' 60 | fill = '◉' 61 | 62 | 63 | class IncrementalBar(Bar): 64 | phases = (' ', '▏', '▎', '▍', '▌', '▋', '▊', '▉', '█') 65 | 66 | def update(self): 67 | nphases = len(self.phases) 68 | filled_len = self.width * self.progress 69 | nfull = int(filled_len) # Number of full chars 70 | phase = int((filled_len - nfull) * nphases) # Phase of last char 71 | nempty = self.width - nfull # Number of empty chars 72 | 73 | message = self.message % self 74 | bar = self.phases[-1] * nfull 75 | current = self.phases[phase] if phase > 0 else '' 76 | empty = self.empty_fill * max(0, nempty - len(current)) 77 | suffix = self.suffix % self 78 | line = ''.join([message, self.bar_prefix, bar, current, empty, 79 | self.bar_suffix, suffix]) 80 | self.writeln(line) 81 | 82 | 83 | class PixelBar(IncrementalBar): 84 | phases = ('⡀', '⡄', '⡆', '⡇', '⣇', '⣧', '⣷', '⣿') 85 | 86 | 87 | class ShadyBar(IncrementalBar): 88 | phases = (' ', '░', '▒', '▓', '█') 89 | -------------------------------------------------------------------------------- /utils/progress/progress/helpers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2012 Giorgos Verigakis 2 | # 3 | # Permission to use, copy, modify, and distribute this software for any 4 | # purpose with or without fee is hereby granted, provided that the above 5 | # copyright notice and this permission notice appear in all copies. 6 | # 7 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 8 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 9 | # MERCHANTABILITY AND FITNESS. 
IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 10 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 11 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 12 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 13 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 14 | 15 | from __future__ import print_function 16 | 17 | 18 | HIDE_CURSOR = '\x1b[?25l' 19 | SHOW_CURSOR = '\x1b[?25h' 20 | 21 | 22 | class WriteMixin(object): 23 | hide_cursor = False 24 | 25 | def __init__(self, message=None, **kwargs): 26 | super(WriteMixin, self).__init__(**kwargs) 27 | self._width = 0 28 | if message: 29 | self.message = message 30 | 31 | if self.file.isatty(): 32 | if self.hide_cursor: 33 | print(HIDE_CURSOR, end='', file=self.file) 34 | print(self.message, end='', file=self.file) 35 | self.file.flush() 36 | 37 | def write(self, s): 38 | if self.file.isatty(): 39 | b = '\b' * self._width 40 | c = s.ljust(self._width) 41 | print(b + c, end='', file=self.file) 42 | self._width = max(self._width, len(s)) 43 | self.file.flush() 44 | 45 | def finish(self): 46 | if self.file.isatty() and self.hide_cursor: 47 | print(SHOW_CURSOR, end='', file=self.file) 48 | 49 | 50 | class WritelnMixin(object): 51 | hide_cursor = False 52 | 53 | def __init__(self, message=None, **kwargs): 54 | super(WritelnMixin, self).__init__(**kwargs) 55 | if message: 56 | self.message = message 57 | 58 | if self.file.isatty() and self.hide_cursor: 59 | print(HIDE_CURSOR, end='', file=self.file) 60 | 61 | def clearln(self): 62 | if self.file.isatty(): 63 | print('\r\x1b[K', end='', file=self.file) 64 | 65 | def writeln(self, line): 66 | if self.file.isatty(): 67 | self.clearln() 68 | print(line, end='', file=self.file) 69 | self.file.flush() 70 | 71 | def finish(self): 72 | if self.file.isatty(): 73 | print(file=self.file) 74 | if self.hide_cursor: 75 | print(SHOW_CURSOR, end='', file=self.file) 76 | 77 | 78 | from signal import signal, SIGINT 79 | from sys import exit 80 | 81 | 82 | class SigIntMixin(object): 83 | """Registers a signal handler that calls finish on SIGINT""" 84 | 85 | def __init__(self, *args, **kwargs): 86 | super(SigIntMixin, self).__init__(*args, **kwargs) 87 | signal(SIGINT, self._sigint_handler) 88 | 89 | def _sigint_handler(self, signum, frame): 90 | self.finish() 91 | exit(0) 92 | -------------------------------------------------------------------------------- /utils/parser_results.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | from pathlib import Path 4 | import numpy as np 5 | 6 | def get_average_results(fname, method="method1"): 7 | with open(Path(fname), "r") as f: 8 | all_lines = f.readlines() 9 | avg_accs, last_accs, calibration_errors, joint_accuracies = [], [], [], [] 10 | for line in all_lines: 11 | if line.startswith("Acc avg"): 12 | result = re.findall(r'(? 2 | # 3 | # Permission to use, copy, modify, and distribute this software for any 4 | # purpose with or without fee is hereby granted, provided that the above 5 | # copyright notice and this permission notice appear in all copies. 6 | # 7 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 8 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 9 | # MERCHANTABILITY AND FITNESS. 
IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 10 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 11 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 12 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 13 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 14 | 15 | from __future__ import division 16 | 17 | from collections import deque 18 | from datetime import timedelta 19 | from math import ceil 20 | from sys import stderr 21 | from time import time 22 | 23 | 24 | __version__ = '1.3' 25 | 26 | 27 | class Infinite(object): 28 | file = stderr 29 | sma_window = 10 # Simple Moving Average window 30 | 31 | def __init__(self, *args, **kwargs): 32 | self.index = 0 33 | self.start_ts = time() 34 | self.avg = 0 35 | self._ts = self.start_ts 36 | self._xput = deque(maxlen=self.sma_window) 37 | for key, val in kwargs.items(): 38 | setattr(self, key, val) 39 | 40 | def __getitem__(self, key): 41 | if key.startswith('_'): 42 | return None 43 | return getattr(self, key, None) 44 | 45 | @property 46 | def elapsed(self): 47 | return int(time() - self.start_ts) 48 | 49 | @property 50 | def elapsed_td(self): 51 | return timedelta(seconds=self.elapsed) 52 | 53 | def update_avg(self, n, dt): 54 | if n > 0: 55 | self._xput.append(dt / n) 56 | self.avg = sum(self._xput) / len(self._xput) 57 | 58 | def update(self): 59 | pass 60 | 61 | def start(self): 62 | pass 63 | 64 | def finish(self): 65 | pass 66 | 67 | def next(self, n=1): 68 | now = time() 69 | dt = now - self._ts 70 | self.update_avg(n, dt) 71 | self._ts = now 72 | self.index = self.index + n 73 | self.update() 74 | 75 | def iter(self, it): 76 | try: 77 | for x in it: 78 | yield x 79 | self.next() 80 | finally: 81 | self.finish() 82 | 83 | 84 | class Progress(Infinite): 85 | def __init__(self, *args, **kwargs): 86 | super(Progress, self).__init__(*args, **kwargs) 87 | self.max = kwargs.get('max', 100) 88 | 89 | @property 90 | def eta(self): 91 | return int(ceil(self.avg * self.remaining)) 92 | 93 | @property 94 | def eta_td(self): 95 | return timedelta(seconds=self.eta) 96 | 97 | @property 98 | def percent(self): 99 | return self.progress * 100 100 | 101 | @property 102 | def progress(self): 103 | return min(1, self.index / self.max) 104 | 105 | @property 106 | def remaining(self): 107 | return max(self.max - self.index, 0) 108 | 109 | def start(self): 110 | self.update() 111 | 112 | def goto(self, index): 113 | incr = index - self.index 114 | self.next(incr) 115 | 116 | def iter(self, it): 117 | try: 118 | self.max = len(it) 119 | except TypeError: 120 | pass 121 | 122 | try: 123 | for x in it: 124 | yield x 125 | self.next() 126 | finally: 127 | self.finish() 128 | -------------------------------------------------------------------------------- /utils/rotation_angle_matrix.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | 5 | from tqdm import tqdm 6 | from copy import deepcopy 7 | import numpy as np 8 | 9 | 10 | 11 | class RotationAngleMatrix(): 12 | def __init__(self, args) -> None: 13 | self.args = args 14 | self.task_to_sample_to_visual_feats = {} 15 | self.task_to_sample_to_textual_feats = {} 16 | self.task_to_sample_to_indices = {} # for verification, check if indices are same at each task 17 | self.label_range_to_select = np.arange(10) 18 | self.arccos_by_tasks_visual = {} 19 | self.arccos_by_tasks_textual = {} 20 | 21 | def 
get_relevant_logits(self, labels, logits_visual, logits_textual, sample_indices): 22 | relevant_logits_visual = [] 23 | relevant_logits_textual = [] 24 | relevant_indices = [] # for verification 25 | labels = np.array(labels) 26 | for curr_cls in self.label_range_to_select: 27 | cls_ind = np.where(labels == curr_cls)[0] 28 | relevant_logits_visual.append(logits_visual[cls_ind]) 29 | relevant_logits_textual.append(logits_textual[cls_ind]) 30 | relevant_indices.extend(sample_indices[cls_ind]) 31 | 32 | relevant_indices = np.array(relevant_indices) 33 | sorting_indices = np.argsort(relevant_indices) 34 | relevant_logits_visual = np.array(torch.cat(relevant_logits_visual))[sorting_indices] 35 | relevant_logits_textual = np.array(torch.cat(relevant_logits_textual))[sorting_indices] 36 | relevant_indices = relevant_indices[sorting_indices] 37 | return relevant_logits_visual, relevant_logits_textual, relevant_indices 38 | 39 | 40 | def store_relevant_logits(self, cur_task, labels, logits_visual, logits_textual, sample_indices): 41 | print(f"Storing logits ..") 42 | relevant_logits_visual, relevant_logits_textual, relevant_indices = self.get_relevant_logits(labels, logits_visual, logits_textual, sample_indices) 43 | self.task_to_sample_to_visual_feats[cur_task] = dict(zip(relevant_indices, relevant_logits_visual)) 44 | self.task_to_sample_to_textual_feats[cur_task] = dict(zip(relevant_indices, relevant_logits_textual)) 45 | 46 | def compute_arccos(self, task_a, task_b, mode='visual'): 47 | assert list(self.task_to_sample_to_visual_feats[task_a].keys()) == list(self.task_to_sample_to_visual_feats[task_b].keys()), \ 48 | f"Test indices mismatch: {list(self.task_to_sample_to_visual_feats[task_a].keys())[:20]} vs {list(self.task_to_sample_to_visual_feats[task_b].keys())[:20]}!" 
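# Example (not from the repository): a standalone numpy sketch of the mean
# pairwise rotation angle that compute_arccos (continued just below) reports
# for two sets of unit-normalised features; the 4x8 shapes are illustrative
# assumptions.
import numpy as np

feats_a = np.random.randn(4, 8)
feats_b = np.random.randn(4, 8)
feats_a /= np.linalg.norm(feats_a, axis=1, keepdims=True)
feats_b /= np.linalg.norm(feats_b, axis=1, keepdims=True)

dot = np.clip(feats_a @ feats_b.T, -1, 1)         # pairwise cosine similarities
mean_angle = np.rad2deg(np.arccos(dot)).mean()    # mean angle in degrees
print(mean_angle)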
49 | 50 | if mode == 'visual': 51 | dot_prod = np.array(list(self.task_to_sample_to_visual_feats[task_a].values())) @ np.array(list(self.task_to_sample_to_visual_feats[task_b].values())).T 52 | dot_prod = np.clip(dot_prod, -1, 1) 53 | arccos = np.rad2deg(np.arccos(dot_prod)).mean() 54 | self.arccos_by_tasks_visual[(task_a, task_b)] = arccos 55 | elif mode == 'textual': 56 | dot_prod = np.array(list(self.task_to_sample_to_textual_feats[task_a].values())) @ np.array(list(self.task_to_sample_to_textual_feats[task_b].values())).T 57 | dot_prod = np.clip(dot_prod, -1, 1) 58 | arccos = np.rad2deg(np.arccos(dot_prod)).mean() 59 | self.arccos_by_tasks_textual[(task_a, task_b)] = arccos 60 | else: 61 | raise NotImplementedError 62 | 63 | def compute_rotation_angle_matrix(self, cur_task, labels, logits_visual, logits_textual, sample_indices): 64 | self.store_relevant_logits(cur_task, labels, logits_visual, logits_textual, sample_indices) 65 | if cur_task > 0: 66 | self.compute_arccos(cur_task, 0, mode='visual') 67 | self.compute_arccos(cur_task, 0, mode='textual') 68 | print(f"RAM across visual: {self.arccos_by_tasks_visual}, textual: {self.arccos_by_tasks_textual}") 69 | -------------------------------------------------------------------------------- /utils/visualize.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import torch 3 | import torch.nn as nn 4 | import torchvision 5 | import torchvision.transforms as transforms 6 | import numpy as np 7 | from .misc import * 8 | 9 | __all__ = ['make_image', 'show_batch', 'show_mask', 'show_mask_single'] 10 | 11 | # functions to show an image 12 | def make_image(img, mean=(0,0,0), std=(1,1,1)): 13 | for i in range(0, 3): 14 | img[i] = img[i] * std[i] + mean[i] # unnormalize 15 | npimg = img.numpy() 16 | return np.transpose(npimg, (1, 2, 0)) 17 | 18 | def gauss(x,a,b,c): 19 | return torch.exp(-torch.pow(torch.add(x,-b),2).div(2*c*c)).mul(a) 20 | 21 | def colorize(x): 22 | ''' Converts a one-channel grayscale image to a color heatmap image ''' 23 | if x.dim() == 2: 24 | torch.unsqueeze(x, 0, out=x) 25 | if x.dim() == 3: 26 | cl = torch.zeros([3, x.size(1), x.size(2)]) 27 | cl[0] = gauss(x,.5,.6,.2) + gauss(x,1,.8,.3) 28 | cl[1] = gauss(x,1,.5,.3) 29 | cl[2] = gauss(x,1,.2,.3) 30 | cl[cl.gt(1)] = 1 31 | elif x.dim() == 4: 32 | cl = torch.zeros([x.size(0), 3, x.size(2), x.size(3)]) 33 | cl[:,0,:,:] = gauss(x,.5,.6,.2) + gauss(x,1,.8,.3) 34 | cl[:,1,:,:] = gauss(x,1,.5,.3) 35 | cl[:,2,:,:] = gauss(x,1,.2,.3) 36 | return cl 37 | 38 | def show_batch(images, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)): 39 | images = make_image(torchvision.utils.make_grid(images), Mean, Std) 40 | plt.imshow(images) 41 | plt.show() 42 | 43 | 44 | def show_mask_single(images, mask, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)): 45 | im_size = images.size(2) 46 | 47 | # save for adding mask 48 | im_data = images.clone() 49 | for i in range(0, 3): 50 | im_data[:,i,:,:] = im_data[:,i,:,:] * Std[i] + Mean[i] # unnormalize 51 | 52 | images = make_image(torchvision.utils.make_grid(images), Mean, Std) 53 | plt.subplot(2, 1, 1) 54 | plt.imshow(images) 55 | plt.axis('off') 56 | 57 | # for b in range(mask.size(0)): 58 | # mask[b] = (mask[b] - mask[b].min())/(mask[b].max() - mask[b].min()) 59 | mask_size = mask.size(2) 60 | # print('Max %f Min %f' % (mask.max(), mask.min())) 61 | mask = (upsampling(mask, scale_factor=im_size/mask_size)) 62 | # mask = colorize(upsampling(mask, scale_factor=im_size/mask_size)) 63 | # for c in range(3): 64 | 
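# Example (not from the repository): a sketch of driving the RotationAngleMatrix
# tracker above across two tasks with the same held-out samples. The feature
# sizes and the empty args object are illustrative assumptions.
import numpy as np
import torch
import torch.nn.functional as F
from types import SimpleNamespace
from utils.rotation_angle_matrix import RotationAngleMatrix

ram = RotationAngleMatrix(SimpleNamespace())
labels = np.repeat(np.arange(10), 5)              # 5 samples for each of classes 0-9
indices = np.arange(len(labels))
vis = F.normalize(torch.randn(50, 512), dim=-1)   # unit-norm visual features
txt = F.normalize(torch.randn(50, 512), dim=-1)   # unit-norm textual features

ram.compute_rotation_angle_matrix(0, labels, vis, txt, indices)  # task 0: features stored
ram.compute_rotation_angle_matrix(1, labels, vis, txt, indices)  # task 1: angles vs. task 0 printed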
# mask[:,c,:,:] = (mask[:,c,:,:] - Mean[c])/Std[c] 65 | 66 | # print(mask.size()) 67 | mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data))) 68 | # mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask), Mean, Std) 69 | plt.subplot(2, 1, 2) 70 | plt.imshow(mask) 71 | plt.axis('off') 72 | 73 | def show_mask(images, masklist, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)): 74 | im_size = images.size(2) 75 | 76 | # save for adding mask 77 | im_data = images.clone() 78 | for i in range(0, 3): 79 | im_data[:,i,:,:] = im_data[:,i,:,:] * Std[i] + Mean[i] # unnormalize 80 | 81 | images = make_image(torchvision.utils.make_grid(images), Mean, Std) 82 | plt.subplot(1+len(masklist), 1, 1) 83 | plt.imshow(images) 84 | plt.axis('off') 85 | 86 | for i in range(len(masklist)): 87 | mask = masklist[i].data.cpu() 88 | # for b in range(mask.size(0)): 89 | # mask[b] = (mask[b] - mask[b].min())/(mask[b].max() - mask[b].min()) 90 | mask_size = mask.size(2) 91 | # print('Max %f Min %f' % (mask.max(), mask.min())) 92 | mask = (upsampling(mask, scale_factor=im_size/mask_size)) 93 | # mask = colorize(upsampling(mask, scale_factor=im_size/mask_size)) 94 | # for c in range(3): 95 | # mask[:,c,:,:] = (mask[:,c,:,:] - Mean[c])/Std[c] 96 | 97 | # print(mask.size()) 98 | mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data))) 99 | # mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask), Mean, Std) 100 | plt.subplot(1+len(masklist), 1, i+2) 101 | plt.imshow(mask) 102 | plt.axis('off') 103 | 104 | 105 | 106 | # x = torch.zeros(1, 3, 3) 107 | # out = colorize(x) 108 | # out_im = make_image(out) 109 | # plt.imshow(out_im) 110 | # plt.show() -------------------------------------------------------------------------------- /classifier/zero_shot.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from tqdm import tqdm 5 | 6 | from clip.clip import load, tokenize 7 | from .evaluator import Evaluator 8 | 9 | # import open_clip 10 | 11 | 12 | class ZeroshotCLIP(Evaluator): 13 | def __init__(self, args): 14 | super().__init__(args) 15 | self.args = args 16 | self.clip_model, _ = load(args.ckpt_path, device=f"cuda:{args.default_gpu}") 17 | self.clip_model = self.clip_model.eval() 18 | self.current_class_names = [] 19 | 20 | @torch.no_grad() 21 | def fit(self, data): 22 | self.current_class_names += data['class_names'] 23 | print(f"Class names: {self.current_class_names}") 24 | self.n_class = len(self.current_class_names) 25 | prompts = [[temp.format(c.replace("_", " ")) for temp in data['prompt_templates'] ] for c in self.current_class_names] 26 | self.text_features = [] 27 | with torch.no_grad(): 28 | for per_cls_prompts in prompts: 29 | per_cls_prompt_embs = tokenize(per_cls_prompts).cuda(device=self.args.default_gpu) 30 | text_features = self.clip_model.encode_text(per_cls_prompt_embs) 31 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 32 | text_features = text_features.mean(dim=0) 33 | text_features = text_features / text_features.norm() 34 | self.text_features.append(text_features) 35 | self.text_features = torch.stack(self.text_features, dim=0) 36 | # if self.args.sess == 2: 37 | # self.tsne_plot_text_features() 38 | # print(self.text_features.shape) 39 | 40 | def tsne_plot_text_features(self): 41 | from sklearn.manifold import TSNE 42 | import matplotlib.pyplot as plt 43 | import seaborn as sns 44 | 
plt.rcParams.update({'font.size': 18}) 45 | 46 | taskwise_means = self.text_features.view(10, -1, 512).mean(0) 47 | to_plot = taskwise_means @ taskwise_means.t() 48 | ax = sns.displot(to_plot.detach().cpu().numpy(), kind="kde", bw_adjust=.25, aspect=1.7, linewidth=3, fill=True, common_norm=True, palette=['red', 'deepskyblue', 'orange'], legend=False) 49 | ax.set(xticklabels=[], yticklabels=[]) 50 | ax.set(xlabel=None, ylabel=None) 51 | ax.tick_params(bottom=False, left=False) # remove the ticks 52 | plt.legend(title='Task', labels=['1', '2', 't']) 53 | plt.tight_layout() 54 | # plt.axis('off') 55 | plt.savefig("distributions1.png") 56 | plt.show() 57 | pass 58 | 59 | @torch.no_grad() 60 | def inference(self,image, label, num_test=None, test_class=None): 61 | with torch.no_grad(): 62 | image_features = self.clip_model.encode_image(image) 63 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 64 | logit_scale = self.clip_model.logit_scale.exp() 65 | logits = image_features @ self.text_features.t() * logit_scale 66 | if self.args.compute_ram: 67 | samplewise_text_feats = self.text_features[label] 68 | return logits.float().softmax(dim=-1), (image_features.detach().cpu(), samplewise_text_feats.detach().cpu()) 69 | return logits.float(), (None, None) 70 | 71 | # def accuracy(self, loader, num_test=None, test_class=None, mean_per_class=False): 72 | # total_count=0 73 | # acc_count=0 74 | 75 | # if mean_per_class: 76 | # n_class = self.text_features.shape[0] 77 | # acc_per_class = [0 for _ in range(n_class)] 78 | # count_per_class = [0 for _ in range(n_class)] 79 | 80 | # for i, (x, y, _) in tqdm(enumerate(loader), total=len(loader), desc = 'Running zero-shot inference..'): 81 | # pred_y = self.inference(x.cuda()) 82 | # _, top_labels = pred_y.topk(1, dim=-1) 83 | 84 | # if not mean_per_class: 85 | # acc_count += (top_labels.view(-1)==y.cuda()).sum().cpu().numpy() 86 | # total_count += y.shape[0] 87 | # else: 88 | # for c in range(n_class): 89 | # acc_per_class[c] += ((top_labels.view(-1) == y.cuda()) * (y.cuda()== c)).sum().item() 90 | # count_per_class[c]+=(y.cuda()==c).sum().item() 91 | 92 | # if not mean_per_class: 93 | # acc = acc_count*1.0/total_count 94 | # acc = acc.item() 95 | # else: 96 | # acc = [a*1.0/c for (a, c) in zip(acc_per_class, count_per_class)] 97 | # acc = np.array(acc).mean() 98 | 99 | # return acc -------------------------------------------------------------------------------- /dataset/build_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from PIL import Image 4 | import numpy as np 5 | import torchvision.transforms as transforms 6 | from torch.utils.data import Dataset 7 | 8 | try: 9 | from torchvision.transforms import InterpolationMode 10 | BICUBIC = InterpolationMode.BICUBIC 11 | except ImportError: 12 | BICUBIC = Image.BICUBIC 13 | 14 | from .cifar import cifar10, cifar100 15 | from .imagenet import imagenet 16 | from .mnist import mnist 17 | import pdb 18 | 19 | def get_transforms(n_px=224, transform_mode='origin'): 20 | 21 | if transform_mode == 'origin': 22 | transform = transforms.Compose([ 23 | transforms.Resize(n_px, interpolation=BICUBIC), 24 | # transforms.Resize((n_px,n_px), interpolation=BICUBIC), 25 | transforms.CenterCrop(n_px), 26 | # transforms.RandomApply([transforms.GaussianBlur(kernel_size=(5, 5))], p=0.3), 27 | lambda image: image.convert("RGB"), 28 | transforms.ToTensor(), 29 | transforms.Normalize((0.48145466, 0.4578275, 
0.40821073),(0.26862954,0.26130258, 0.27577711)), 30 | ]) 31 | 32 | elif transform_mode == 'flip': 33 | transform = transforms.Compose([ 34 | transforms.Resize(n_px, interpolation=BICUBIC), 35 | transforms.CenterCrop(n_px), 36 | # transforms.RandomApply([transforms.GaussianBlur(kernel_size=(5, 5))], p=0.3), #modifed 37 | lambda image: image.convert("RGB"), 38 | transforms.RandomHorizontalFlip(), 39 | # transforms.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5), 40 | #modifed 41 | transforms.ToTensor(), 42 | transforms.Normalize((0.48145466, 0.4578275, 0.40821073),(0.26862954,0.26130258, 0.27577711)), 43 | ]) 44 | return transform 45 | 46 | class FewShotDatasetWrapper(Dataset): 47 | def __init__(self, db, select_labels): 48 | self.db=db 49 | self.select_labels = select_labels 50 | 51 | def __getitem__(self, index): 52 | db_idx = self.select_labels[index] 53 | return self.db.__getitem__(db_idx) 54 | 55 | def __len__(self): 56 | return len(self.select_labels) 57 | 58 | def prompts(self, mode='single'): 59 | return self.db.prompts(mode) 60 | 61 | def get_labels(self): 62 | return self.db.get_labels() 63 | 64 | def get_classes(self): 65 | return self.db.get_classes() 66 | 67 | def get_split_index(labels, n_shot, n_val=0, seed=None): 68 | if seed is not None: 69 | np.random.seed(seed) 70 | random.seed(seed) 71 | all_label_list = np.unique(labels) 72 | train_idx_list =[] 73 | val_idx_list =[] 74 | # pdb.set_trace() 75 | for label in all_label_list: 76 | label_collection = np.where(labels == label)[0] 77 | random.shuffle(label_collection) 78 | selected_idx = label_collection[:n_shot+n_val] 79 | train_idx_list.extend(selected_idx[:n_shot]) 80 | val_idx_list.extend(selected_idx[n_shot:]) 81 | # pdb.set_trace() 82 | return train_idx_list, val_idx_list 83 | 84 | def build_dataset(db_name, root, n_shot=-1, n_val=0, transform_mode='origin'): 85 | root = os.path.join(root, db_name) 86 | transform = get_transforms(transform_mode=transform_mode) 87 | test_transform = get_transforms(transform_mode='origin') 88 | 89 | db_func={ 90 | 'cifar10':cifar10, 91 | 'cifar100':cifar100, 92 | 'imagenet': imagenet, 93 | 'mnist':mnist, 94 | } 95 | 96 | train_db = db_func[db_name](root, transform, train=True) 97 | test_db = db_func[db_name](root, test_transform, train=False) 98 | 99 | return train_db, test_db 100 | 101 | def build_dataset_fs(db_name, root, n_shot=1, n_val=0, transform_mode='origin'): 102 | root = os.path.join(root,db_name) 103 | transform = get_transforms(transform_mode=transform_mode) 104 | test_transform = get_transforms(transform_mode='origin') 105 | 106 | db_func ={ 107 | 'cifar10':cifar10, 108 | 'cifar100':cifar100, 109 | 'imagenet': imagenet, 110 | 'mnist':mnist, 111 | } 112 | 113 | train_db = db_func[db_name](root, transform, train=True) 114 | val_db = db_func[db_name](root, test_transform, train=True) 115 | test_db = db_func[db_name](root, test_transform, train=False) 116 | 117 | if n_shot >0: 118 | labels = train_db.get_labels() 119 | train_index, val_index = get_split_index(labels, n_shot, n_val) 120 | train_db = FewShotDatasetWrapper(train_db, train_index) 121 | val_db = FewShotDatasetWrapper(val_db, val_index) 122 | return train_db, val_db, test_db 123 | else: 124 | return train_db, test_db -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | # A simple torch style logger 2 | # (C) Wei YANG 2017 3 | from __future__ import absolute_import 4 | import 
matplotlib.pyplot as plt 5 | import os 6 | import sys 7 | import numpy as np 8 | 9 | __all__ = ['Logger', 'LoggerMonitor', 'savefig'] 10 | 11 | def savefig(fname, dpi=None): 12 | dpi = 150 if dpi == None else dpi 13 | plt.savefig(fname, dpi=dpi) 14 | 15 | def plot_overlap(logger, names=None): 16 | names = logger.names if names == None else names 17 | numbers = logger.numbers 18 | for _, name in enumerate(names): 19 | x = np.arange(len(numbers[name])) 20 | plt.plot(x, np.asarray(numbers[name])) 21 | return [logger.title + '(' + name + ')' for name in names] 22 | 23 | class Logger(object): 24 | '''Save training process to log file with simple plot function.''' 25 | def __init__(self, fpath, title=None, resume=False): 26 | self.file = None 27 | self.resume = resume 28 | self.title = '' if title == None else title 29 | if fpath is not None: 30 | if resume: 31 | self.file = open(fpath, 'r') 32 | name = self.file.readline() 33 | self.names = name.rstrip().split('\t') 34 | self.numbers = {} 35 | for _, name in enumerate(self.names): 36 | self.numbers[name] = [] 37 | 38 | for numbers in self.file: 39 | numbers = numbers.rstrip().split('\t') 40 | for i in range(0, len(numbers)): 41 | self.numbers[self.names[i]].append(numbers[i]) 42 | self.file.close() 43 | self.file = open(fpath, 'a') 44 | else: 45 | self.file = open(fpath, 'w') 46 | 47 | def set_names(self, names): 48 | if self.resume: 49 | pass 50 | # initialize numbers as empty list 51 | self.numbers = {} 52 | self.names = names 53 | for _, name in enumerate(self.names): 54 | self.file.write(name) 55 | self.file.write('\t') 56 | self.numbers[name] = [] 57 | self.file.write('\n') 58 | self.file.flush() 59 | 60 | 61 | def append(self, numbers): 62 | assert len(self.names) == len(numbers), 'Numbers do not match names' 63 | for index, num in enumerate(numbers): 64 | self.file.write("{0:.6f}".format(num)) 65 | self.file.write('\t') 66 | self.numbers[self.names[index]].append(num) 67 | self.file.write('\n') 68 | self.file.flush() 69 | 70 | def plot(self, names=None): 71 | names = self.names if names == None else names 72 | numbers = self.numbers 73 | for _, name in enumerate(names): 74 | x = np.arange(len(numbers[name])) 75 | plt.plot(x, np.asarray(numbers[name])) 76 | plt.legend([self.title + '(' + name + ')' for name in names]) 77 | plt.grid(True) 78 | 79 | def close(self): 80 | if self.file is not None: 81 | self.file.close() 82 | 83 | class LoggerMonitor(object): 84 | '''Load and visualize multiple logs.''' 85 | def __init__ (self, paths): 86 | '''paths is a distionary with {name:filepath} pair''' 87 | self.loggers = [] 88 | for title, path in paths.items(): 89 | logger = Logger(path, title=title, resume=True) 90 | self.loggers.append(logger) 91 | 92 | def plot(self, names=None): 93 | plt.figure() 94 | plt.subplot(121) 95 | legend_text = [] 96 | for logger in self.loggers: 97 | legend_text += plot_overlap(logger, names) 98 | plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) 
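# Example (not from the repository): minimal use of the Logger class above for
# tracking metrics during training; the file name, title and metric values are
# illustrative assumptions.
from utils.logger import Logger

log = Logger('log.txt', title='cifar100')
log.set_names(['Train Loss', 'Valid Acc.'])
for epoch in range(3):
    log.append([1.0 / (epoch + 1), 60.0 + epoch])   # one row per epoch
log.close()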
99 | plt.grid(True) 100 | 101 | if __name__ == '__main__': 102 | # # Example 103 | # logger = Logger('test.txt') 104 | # logger.set_names(['Train loss', 'Valid loss','Test loss']) 105 | 106 | # length = 100 107 | # t = np.arange(length) 108 | # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 109 | # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 110 | # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 111 | 112 | # for i in range(0, length): 113 | # logger.append([train_loss[i], valid_loss[i], test_loss[i]]) 114 | # logger.plot() 115 | 116 | # Example: logger monitor 117 | paths = { 118 | 'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt', 119 | 'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt', 120 | 'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt', 121 | } 122 | 123 | field = ['Valid Acc.'] 124 | 125 | monitor = LoggerMonitor(paths) 126 | monitor.plot(names=field) 127 | savefig('test.eps') -------------------------------------------------------------------------------- /clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-s byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on.... 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord(";"), ord("-")+1))+list(range(ord("@"), ord("y")+1)) 27 | cs= bs[:] 28 | n=0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n+=1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 
41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char,char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ',text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | 63 | class SimpleTokenizer(object): 64 | def __init__(self,bpe_path: str = default_bpe()): 65 | self.byte_encoder = bytes_to_unicode() 66 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 67 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 68 | merges = merges[1:49152-256-2+1] 69 | merges = [tuple(merge.split()) for merge in merges] 70 | vocab = list(bytes_to_unicode().values()) 71 | vocab = vocab + [v+'' for v in vocab] 72 | for merge in merges: 73 | vocab.append(''.join(merge)) 74 | vocab.extend(['<|startoftext|>','<|endoftext|>']) 75 | self.encoder = dict(zip(vocab,range(len(vocab)))) 76 | self.decoder = {v: k for k, v in self.encoder.items()} 77 | self.bpe_ranks = dict(zip(merges,range(len(merges)))) 78 | self.cache = {'<|startoftext|>':'<|startoftext|>','<|endoftext|>':'<|endoftext|>'} 79 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",re.IGNORECASE) 80 | 81 | def bpe(self,token): 82 | if token in self.cache: 83 | return self.cache[token] 84 | word = tuple(token[:-1]) +( token[-1] + '',) 85 | pairs = get_pairs(word) 86 | 87 | if not pairs: 88 | return token+'' 89 | 90 | while True: 91 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 92 | if bigram not in self.bpe_ranks: 93 | break 94 | first, second = bigram 95 | new_word =[] 96 | i=0 97 | while i < len(word): 98 | try: 99 | j = word.index(first, i) 100 | new_word.extend(word[i:j]) 101 | i = j 102 | except: 103 | new_word.extend(word[i:]) 104 | break 105 | 106 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 107 | new_word.append(first+second) 108 | i += 2 109 | else: 110 | new_word.append(word[i]) 111 | i+=1 112 | new_word = tuple(new_word) 113 | word = new_word 114 | if len(word) == 1: 115 | break 116 | else: 117 | pairs = get_pairs(word) 118 | word =' '.join(word) 119 | self.cache[token] = word 120 | return word 121 | 122 | def encode(self, text): 123 | bpe_tokens = [] 124 | text = whitespace_clean(basic_clean(text)).lower() 125 | for token in re.findall(self.pat, text): 126 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 127 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 128 | return bpe_tokens 129 | 130 | def decode(self, tokens): 131 | text = ''.join([self.decoder[token] for token in tokens]) 132 | text = bytearray([self.byte_decoder[c] for c in text]).decode( 'utf-8', errors="replace").replace('',' ') 133 | return text 134 | -------------------------------------------------------------------------------- /dataset/templates.json: -------------------------------------------------------------------------------- 1 | { 2 | "objectnet": [ 3 | "itap of a {}.", 4 | "a bad photo of the {}.", 5 | "a origami {}.", 6 | "a photo of the large {}.", 7 | "a {} in a video game.", 8 | "art of the {}.", 9 | "a photo of the small {}." 
10 | ], 11 | "imagenetr": [ 12 | "itap of a {}.", 13 | "a bad photo of the {}.", 14 | "a origami {}.", 15 | "a photo of the large {}.", 16 | "a {} in a video game.", 17 | "art of the {}.", 18 | "a photo of the small {}." 19 | ], 20 | "food101": [ 21 | "a photo of {}, a type of food." 22 | ], 23 | "cifar224": [ 24 | "a photo of a {}.", 25 | "a blurry photo of a {}.", 26 | "a black and white photo of a {}.", 27 | "a low contrast photo of a {}.", 28 | "a high contrast photo of a {}.", 29 | "a bad photo of a {}.", 30 | "a good photo of a {}.", 31 | "a photo of a small {}.", 32 | "a photo of a big {}.", 33 | "a photo of the {}.", 34 | "a blurry photo of the {}.", 35 | "a black and white photo of the {}.", 36 | "a low contrast photo of the {}.", 37 | "a high contrast photo of the {}.", 38 | "a bad photo of the {}.", 39 | "a good photo of the {}.", 40 | "a photo of the small {}.", 41 | "a photo of the big {}." 42 | ], 43 | 44 | "cub": [ 45 | "a photo of a {}, a type of bird." 46 | ], 47 | "imagenet": [ 48 | "itap of a {}.", 49 | "a bad photo of the {}.", 50 | "a origami {}.", 51 | "a photo of the large {}.", 52 | "a {} in a video game.", 53 | "art of the {}.", 54 | "a photo of the small {}." 55 | ], 56 | "ucf101": [ 57 | "a photo of a person {}.", 58 | "a video of a person {}.", 59 | "a example of a person {}.", 60 | "a demonstration of a person {}.", 61 | "a photo of the person {}.", 62 | "a video of the person {}.", 63 | "a example of the person {}.", 64 | "a demonstration of the person {}.", 65 | "a photo of a person using {}.", 66 | "a video of a person using {}.", 67 | "a example of a person using {}.", 68 | "a demonstration of a person using {}.", 69 | "a photo of the person using {}.", 70 | "a video of the person using {}.", 71 | "a example of the person using {}.", 72 | "a demonstration of the person using {}.", 73 | "a photo of a person doing {}.", 74 | "a video of a person doing {}.", 75 | "a example of a person doing {}.", 76 | "a demonstration of a person doing {}.", 77 | "a photo of the person doing {}.", 78 | "a video of the person doing {}.", 79 | "a example of the person doing {}.", 80 | "a demonstration of the person doing {}.", 81 | "a photo of a person during {}.", 82 | "a video of a person during {}.", 83 | "a example of a person during {}.", 84 | "a demonstration of a person during {}.", 85 | "a photo of the person during {}.", 86 | "a video of the person during {}.", 87 | "a example of the person during {}.", 88 | "a demonstration of the person during {}.", 89 | "a photo of a person performing {}.", 90 | "a video of a person performing {}.", 91 | "a example of a person performing {}.", 92 | "a demonstration of a person performing {}.", 93 | "a photo of the person performing {}.", 94 | "a video of the person performing {}.", 95 | "a example of the person performing {}.", 96 | "a demonstration of the person performing {}.", 97 | "a photo of a person practicing {}.", 98 | "a video of a person practicing {}.", 99 | "a example of a person practicing {}.", 100 | "a demonstration of a person practicing {}.", 101 | "a photo of the person practicing {}.", 102 | "a video of the person practicing {}.", 103 | "a example of the person practicing {}.", 104 | "a demonstration of the person practicing {}." 105 | ], 106 | 107 | "flowers": [ 108 | "a photo of a {}, a type of flower." 
109 | ], 110 | "caltech101": [ 111 | "a photo of a {}.", 112 | "a painting of a {}.", 113 | "a plastic {}.", 114 | "a sculpture of a {}.", 115 | "a sketch of a {}.", 116 | "a tattoo of a {}.", 117 | "a toy {}.", 118 | "a rendition of a {}.", 119 | "a embroidered {}.", 120 | "a cartoon {}.", 121 | "a {} in a video game.", 122 | "a plushie {}.", 123 | "a origami {}.", 124 | "art of a {}.", 125 | "graffiti of a {}.", 126 | "a drawing of a {}.", 127 | "a doodle of a {}.", 128 | "a photo of the {}.", 129 | "a painting of the {}.", 130 | "the plastic {}.", 131 | "a sculpture of the {}.", 132 | "a sketch of the {}.", 133 | "a tattoo of the {}.", 134 | "the toy {}.", 135 | "a rendition of the {}.", 136 | "the embroidered {}.", 137 | "the cartoon {}.", 138 | "the {} in a video game.", 139 | "the plushie {}.", 140 | "the origami {}.", 141 | "art of the {}.", 142 | "graffiti of the {}.", 143 | "a drawing of the {}.", 144 | "a doodle of the {}." 145 | ], 146 | 147 | "aircraft": [ 148 | "a photo of a {}, a type of aircraft.", 149 | "a photo of the {}, a type of aircraft." 150 | ], 151 | "cars": [ 152 | "a photo of a {}.", 153 | "a photo of the {}.", 154 | "a photo of my {}.", 155 | "i love my {}!", 156 | "a photo of my dirty {}.", 157 | "a photo of my clean {}.", 158 | "a photo of my new {}.", 159 | "a photo of my old {}." 160 | ], 161 | "sun": [ 162 | "a photo of a {}.", 163 | "a photo of the {}." 164 | ] 165 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # CLAP4CLIP: Continual Learning with Probabilistic Finetuning for Vision-Language Models [[Paper]](https://arxiv.org/pdf/2403.19137v2) 4 | 5 | ## ✨ ***Now accepted to NeurIPS 2024!*** ✨ 6 | 7 |

8 | What is CLAP4CLIP? • 9 | Get going • 10 | What is in this repo? • 11 | Language-aware knowledge • 12 | Uncertainty-related ablations • 13 | Cite 14 |

15 |
16 | 17 | --- 18 | 19 | ## What is CLAP4CLIP? 20 | 21 | ![alt text](https://github.com/srvCodes/clap4clip/blob/main/images/Slide13-1.png "Logo Title Text 1") 22 | 23 | [CLAP4CLIP](https://arxiv.org/pdf/2403.19137v2) is a general probabilistic finetuning framework for the pre-trained CLIP model on downstream class-incremental learning tasks. 24 | 25 | The framework is general because (as depicted below) it supports a diverse range of prompt styles, including hand-crafted prompts like [Continual-CLIP](https://arxiv.org/abs/2210.03114), task-conditioned prompts like [CoOp](https://arxiv.org/abs/2109.01134), instance-conditioned prompts like [AttriCLIP](https://arxiv.org/abs/2305.11488), and multi-modal prompts like [MaPLe](https://arxiv.org/abs/2210.03117): 26 | 27 | ![alt text](https://github.com/srvCodes/clap4clip/blob/main/images/Screenshot%202024-05-24%20162355.png) 28 | 29 | 30 | ## Get going 31 | 32 | Clone this GitHub repository: 33 | ``` 34 | git clone https://github.com/srvCodes/clap4clip.git 35 | cd clap4clip 36 | mkdir ckpt/ 37 | ``` 38 | - Download models: Download the pretrained ViT-B-16.pt and ViT-L-14.pt checkpoints into the `ckpt/` directory. 39 | 40 | - Download datasets: We suggest following the [mammoth](https://github.com/aimagelab/mammoth) library to download all the datasets into the repo's `datasets/` directory. Instructions for ImageNet-R can be found [here](https://github.com/muzairkhattak/multimodal-prompt-learning/blob/main/docs/DATASETS.md). 41 | 42 | 43 | ## What is in this repo? 44 | 45 | This repo is designed to benchmark various finetuning methods for class-incremental learning with the pre-trained CLIP model. 46 | 47 | The instructions below show how to run the models provided with the initial release on CIFAR100 (check the scripts in `scripts/` and edit them as needed): 48 | 49 | - CLAP4CLIP with hand-crafted prompts (our base CLAP model): 50 | ``` 51 | python3 main_incremental_submit.py --lasp --beta 15 --db_name cifar100 --use-vga --expandable-adapter --finetuning --finetune-epochs 2 --num-run 10 --compute-ece --compute-bwt --train_batch 32 --exemplar-selector random --root ../path_to_datasets/ --multi-gpu --gpus 0,1 --default-gpu 0 --model clclip_var --epochs 5 --forward-times 20 --arch ViT-B-16 --method er --variational 52 | ``` 53 | - Continual-CLIP (zero-shot): 54 | ``` 55 | python3 main_incremental_submit.py --db_name cifar100 --num-run 10 --compute-ece --compute-bwt --train_batch 32 --root ../path_to_datasets/ --multi-gpu --gpus 0,1 --default-gpu 0 --model clclip --arch ViT-B-16 56 | ``` 57 | - CLIP-Adapter: 58 | ``` 59 | python3 main_incremental_submit.py --db_name cifar100 --finetuning --finetune-epochs 2 --num-run 10 --compute-ece --compute-bwt --train_batch 32 --exemplar-selector random --root ../path_to_datasets/ --multi-gpu --gpus 0,1 --default-gpu 0 --model clip_adapter --epochs 5 --arch ViT-B-16 --method er 60 | ``` 61 | - CLAP with CoOp: 62 | ``` 63 | python3 main_incremental_submit.py --db_name cifar100 --num-run 10 --compute-ece --compute-bwt --train_batch 32 --root ../path_to_datasets/ --multi-gpu --gpus 0,1 --default-gpu 0 --model coop_variational --arch ViT-B-16 64 | ``` 65 | - CLAP with MaPLe: 66 | ``` 67 | python3 main_incremental_submit.py --db_name cifar100 --num-run 10 --compute-ece --compute-bwt --train_batch 32 --root ../path_to_datasets/ --multi-gpu --gpus 0,1 --default-gpu 0 --model maple_variational --arch ViT-B-16 68 | ``` 69 | - CoOp with Adapter (used in Fig.
3b in the paper): 70 | ``` 71 | python3 main_incremental_submit.py --db_name cifar100 --num-run 10 --compute-ece --compute-bwt --train_batch 32 --root ../path_to_datasets/ --multi-gpu --gpus 0,1 --default-gpu 0 --model coop_adapter --arch ViT-B-16 72 | ``` 73 | 74 | We plan to release the following models upon the acceptance of our paper: 75 | - ~~CLAP4CLIP with support for CoOp/MaPLe~~ Now released! 76 | 77 | ## Language-aware knowledge 78 | 79 | - Past-task distribution regularization (for reducing **forgetting** in general): Can be evoked by passing the argument `--lasp --beta` $\gamma$ where $\gamma$ is the loss weight used in Eq. (12) in our paper. 80 | - Weight initialization (for reducing **stability gap**): Currently, controlled by commenting/uncommenting [this line](https://github.com/srvCodes/clap4clip/blob/main/classifier/continual_clip_variational.py#L99). 81 | 82 | ## Uncertainty-related ablations 83 | 84 | In our paper, we show the out-of-the-box perks of uncertainty-aware modelling for the following two tasks: 85 | 86 | ### Post-hoc novel data detection (PhNDD) 87 | 88 | - PhNDD is a post-hoc setting proposed in our paper for evaluating the novel data detection capabilities of a finetuning algorithm within the continual learning setting. To evoke this, simply pass the argument `--eval-ood-score` in the script. 89 | 90 | ### Exemplar selection 91 | - For all but the zero-shot models, the repo implements the following exemplar selection criteria: Random, Herding, Entropy, Variance, Variance of entropy, and Energy scores. These can simply be evoked by passing the value `x` to the argument `--exemplar-selector`, where `x` can be {`random`, `icarl`, `entropy`, `variance`, `distance`, `var_entropy`, `energy`}. 92 | 93 | ## Cite 94 | 95 | If you want to cite this framework feel free to use this preprint citation: 96 | 97 | ```bibtex 98 | @inproceedings{jha2024clap4clip, 99 | title={{CLAP4CLIP}: Continual Learning with Probabilistic Finetuning for Vision-Language Models}, 100 | author={Saurav Jha and Dong Gong and Lina Yao}, 101 | booktitle={Thirty-eighth Conference on Neural Information Processing Systems}, 102 | year={2024}, 103 | url={https://arxiv.org/pdf/2403.19137} 104 | } 105 | ``` 106 | -------------------------------------------------------------------------------- /utils/visualize_ood_detection_trend.py: -------------------------------------------------------------------------------- 1 | import seaborn as sns 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | 5 | # plt.rcParams.update({'font.size': 16}) 6 | sns.color_palette("Spectral") 7 | # sns.set(font_scale=15) 8 | plt.rcParams.update({'font.size': 20,}) 9 | 10 | datum = { 11 | # "Softmax": { 12 | # # softmax 13 | # "FPR95" : { 14 | # "ZS-CLIP": [62.56, 71.31, 73.1, 76.37, 77.02, 80.08, 80.87, 85.75, 85.9], 15 | # "ZS-CLIP+Ours": [83.16, 80.25, 79.5, 81.48, 81.3, 86.15, 84.93, 86.45, 85.0], 16 | # # "Ours": [65.14, 81.74, 82.17, 85.33, 85.88, 88.7, 87.33, 89.5, 89.7] 17 | # }, 18 | # "AUROC" : { 19 | # "ZS-CLIP": [87.83, 84.13, 83.68, 80.86, 81.43, 79.31, 76.94, 74.54, 73.8], 20 | # "ZS-CLIP+Ours": [57.86, 57.75, 57.79, 59.28, 60.66, 57.67, 58.62, 58.73, 61.43], 21 | # # "Ours": [65.95, 59.01, 58.23, 57.25, 57.4, 56.16, 56.94, 56.3, 57.05] 22 | # }, 23 | # "AUPR" : { 24 | # "ZS-CLIP": [58.33, 64.09, 74.25, 76.64, 82.75, 85.58, 88.5, 92.2, 96.14], 25 | # "ZS-CLIP+Ours": [11.66, 22.92, 33.79, 45.22 , 56.25, 64.12, 74.01, 83.06, 92.26], 26 | # # "Ours": [14.11, 23.51, 34.06, 44, 54.19, 63.31, 73.21, 82.21, 
91.4] 27 | # }}, 28 | 29 | # energy 30 | "Energy":{ 31 | "FPR95" : { 32 | "Continual-CLIP": [60.42, 69.91, 73.51, 76.35, 76.82,81.23, 84.17, 87.3, 86.3], 33 | "CoOp": [39.99, 60.02, 59.33,61.47, 64.88, 76.02, 80.03, 83.3, 76.2], 34 | "Ours": [47.32, 70.08, 62.31, 69.72, 67.08, 72.72, 74.37, 81.05, 73.8], 35 | "Ours w/o VI": [52.63, 63.32, 63.41, 68.38, 64.48, 73.08, 76.53, 80.25,77.4], 36 | "Ours + CoOp": [38.12, 56.39, 58.94, 65.22, 62.9, 67.5, 73.07, 78.05, 63.9], 37 | "Ours + CoOp (w/o VI)": [48.18, 62.44, 57.13, 63.22, 64.4, 69.95, 76.6, 78.8, 75.2] 38 | }, 39 | "AUROC" : { 40 | "Continual-CLIP": [85.29, 80.25, 78.21, 75.66, 75.41, 72.04, 67.98, 64.8, 70.48], 41 | "CoOp": [90.42, 84.15, 84.79, 81.52, 81.13, 77.43, 73.17, 70.84, 77.93], 42 | "Ours": [88.9, 84.62, 85.31, 82.1, 83.58, 79.87, 78.4, 75.17, 81.93], 43 | "Ours w/o VI": [86.16, 84.58, 85.55, 83.37, 85.46, 80.22, 77.96, 77.11, 80.18], 44 | "Ours + CoOp": [91.36, 86.76, 85.41, 83.72, 84.79, 81.67, 78.57, 76.57, 84.7], 45 | "Ours + CoOp (w/o VI)": [88.26, 85.36, 84.69, 83.18, 82.97, 80.19, 77.64, 76.72, 78.8] 46 | }, 47 | "AUPR" : { 48 | "Continual-CLIP": [40.91, 52.36, 61.82, 67.72, 74.66, 78.28, 82.02, 87.15, 95.09], 49 | "CoOp": [57.83, 60, 71.21, 73.55, 80.47, 83.42, 85.45, 90.02, 96.62], 50 | "Ours": [51.69, 65.11, 73.97, 76.32, 84.6, 85.81, 89.03, 91.89, 97.41], 51 | "Ours w/o VI": [42.7, 61.93, 74.69, 78.86, 86.75, 86.45, 89.09, 92.74, 96.67], 52 | "Ours + CoOp": [59.15, 67.08, 72.66, 78.52, 85.39, 86.92, 88.86, 92.28, 97.83], 53 | "Ours + CoOp (w/o VI)": [50.15, 63.8, 68.96, 76.22, 83.07, 85.52, 88.49, 92.42, 96.68] 54 | }} 55 | } 56 | 57 | # Set the style for the plots 58 | # sns.set(style="whitegrid") 59 | 60 | colors = { 61 | 'iCaRL': 'blue', 62 | 'Continual-CLIP': 'deepskyblue', 63 | 'CoOp': 'gray', 64 | 'CLIP-Adapter': 'purple', 65 | 'Ours + CoOp (w/o VI)': 'green', 66 | 'DualPrompt': 'fuchsia', 67 | 'L2P': 'brown', 68 | 'Ours w/o VI': 'orange', 69 | 'TAw-UB': 'black', 70 | 'Ours': 'red', 71 | "Ours + CoOp": 'black', 72 | "Ours + AttriCLIP": 'indigo' 73 | } 74 | 75 | 76 | markers = { 77 | "iCaRL": "o", 78 | "Continual-CLIP": "*", 79 | "CoOp": "D", 80 | "CLIP-Adapter": "h", 81 | "Ours + CoOp (w/o VI)": "P", 82 | "DualPrompt": "v", 83 | "L2P": "s", 84 | "Ours w/o VI": "^", 85 | "Ours": "X", 86 | "Ours + CoOp": "p", 87 | "Ours + AttriCLIP": ">" 88 | } 89 | 90 | # data = data["softmax"] 91 | def ood_detection_plot(datum, dataset_name="cifar100"): 92 | step_size = 20 if dataset_name == "imagenet-r" else 10 93 | x_labels = list(range(0, step_size*10+1, step_size))[1:-1] #if dataset_name == "imagenet-r" else list(range(0, 101, 10))[1:] 94 | for j, metric in enumerate(datum.keys()): 95 | print(f"{'== ' * 30}\n Confidence Metric: {metric}") 96 | # Create subplots with three plots 97 | fig, axes = plt.subplots(3, 1, figsize=(10, 18)) 98 | 99 | data = datum[metric] 100 | # Plot each dictionary in a subplot 101 | for i, element in enumerate(data.keys()): 102 | print(f"\nEvaluation Metric: {element}") 103 | ax = axes[i] 104 | for k, (key, values) in enumerate(data[element].items()): 105 | print(f"Method: {key}, Average: {np.mean(values)} ") 106 | sns.lineplot(x=x_labels, y=values, ax=ax, label=key, marker=markers[key], color=colors[key], linewidth=2., markersize=17) 107 | 108 | # Set labels and title for each subplot 109 | ax.set_xlabel('Number of Classes') 110 | downarrow = r'$\downarrow$' 111 | uparrow = r'$\uparrow$' 112 | ax.set_ylabel(f"{element}{downarrow if element == 'FPR95' else uparrow}") 113 | # 
ax.set_title(f'{element}') 114 | # ax.legend(ncol=1, fontsize=14, loc='lower right' if element != "AUROC" else "lower left", frameon=False, framealpha=0.9) 115 | ax.get_legend().remove() 116 | # ax.set_xticklabels(x_labels) 117 | ax.set_axisbelow(True) 118 | ax.spines[['top', 'right']].set_visible(False) 119 | ax.yaxis.grid(color='lightgrey', linestyle='dashed', linewidth=0.3, which='major') 120 | ax.xaxis.grid(color='lightgrey', linestyle='dashed', linewidth=0.3, which='major') 121 | 122 | ax.yaxis.grid(color='aliceblue', linestyle=':', linewidth=0.3, which='minor') 123 | ax.xaxis.grid(color='aliceblue', linestyle=':', linewidth=0.3, which='minor') 124 | 125 | ax.minorticks_on() 126 | 127 | lines_labels = [ax.get_legend_handles_labels() for i, ax in enumerate(fig.axes) if i == 0] 128 | lines, labels = [sum(lol, []) for lol in zip(*lines_labels)] 129 | fig.legend(lines, labels, loc='upper center', ncol=len(labels)//2, frameon=True, bbox_to_anchor=(0.52, 0.95), fontsize=18) 130 | # Adjust spacing and display the subplots 131 | # plt.tight_layout() 132 | plt.minorticks_on() 133 | # fig.suptitle(f'{metric}') 134 | plt.savefig(f"OOD_{metric}.pdf") 135 | plt.show() 136 | 137 | ood_detection_plot(datum) -------------------------------------------------------------------------------- /utils/display_results.py: -------------------------------------------------------------------------------- 1 | """ 2 | Source: https://github.com/wetliu/energy_ood/blob/master/utils/display_results.py 3 | """ 4 | import numpy as np 5 | import sklearn.metrics as sk 6 | 7 | recall_level_default = 0.95 8 | 9 | 10 | def stable_cumsum(arr, rtol=1e-05, atol=1e-08): 11 | """Use high precision for cumsum and check that final value matches sum 12 | Parameters 13 | ---------- 14 | arr : array-like 15 | To be cumulatively summed as flat 16 | rtol : float 17 | Relative tolerance, see ``np.allclose`` 18 | atol : float 19 | Absolute tolerance, see ``np.allclose`` 20 | """ 21 | out = np.cumsum(arr, dtype=np.float64) 22 | expected = np.sum(arr, dtype=np.float64) 23 | if not np.allclose(out[-1], expected, rtol=rtol, atol=atol): 24 | raise RuntimeError('cumsum was found to be unstable: ' 25 | 'its last element does not correspond to sum') 26 | return out 27 | 28 | 29 | def fpr_and_fdr_at_recall(y_true, y_score, recall_level=recall_level_default, pos_label=None): 30 | classes = np.unique(y_true) 31 | if (pos_label is None and 32 | not (np.array_equal(classes, [0, 1]) or 33 | np.array_equal(classes, [-1, 1]) or 34 | np.array_equal(classes, [0]) or 35 | np.array_equal(classes, [-1]) or 36 | np.array_equal(classes, [1]))): 37 | raise ValueError("Data is not binary and pos_label is not specified") 38 | elif pos_label is None: 39 | pos_label = 1. 40 | 41 | # make y_true a boolean vector 42 | y_true = (y_true == pos_label) 43 | 44 | # sort scores and corresponding truth values 45 | desc_score_indices = np.argsort(y_score, kind="mergesort")[::-1] 46 | y_score = y_score[desc_score_indices] 47 | y_true = y_true[desc_score_indices] 48 | 49 | # y_score typically has many tied values. Here we extract 50 | # the indices associated with the distinct values. We also 51 | # concatenate a value for the end of the curve. 
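    # Illustrative example (assumed values): for y_score = [0.9, 0.8, 0.8, 0.3], already sorted in
    # descending order, np.diff(y_score) = [-0.1, 0.0, -0.5], so distinct_value_indices = [0, 2]
    # and threshold_idxs = [0, 2, 3], where y_true.size - 1 = 3 closes the curve.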
52 | distinct_value_indices = np.where(np.diff(y_score))[0] 53 | threshold_idxs = np.r_[distinct_value_indices, y_true.size - 1] 54 | 55 | # accumulate the true positives with decreasing threshold 56 | tps = stable_cumsum(y_true)[threshold_idxs] 57 | fps = 1 + threshold_idxs - tps # add one because of zero-based indexing 58 | 59 | thresholds = y_score[threshold_idxs] 60 | 61 | recall = tps / tps[-1] 62 | 63 | last_ind = tps.searchsorted(tps[-1]) 64 | sl = slice(last_ind, None, -1) # [last_ind::-1] 65 | recall, fps, tps, thresholds = np.r_[recall[sl], 1], np.r_[fps[sl], 0], np.r_[tps[sl], 0], thresholds[sl] 66 | 67 | cutoff = np.argmin(np.abs(recall - recall_level)) 68 | 69 | return fps[cutoff] / (np.sum(np.logical_not(y_true))) # , fps[cutoff]/(fps[cutoff] + tps[cutoff]) 70 | 71 | 72 | def get_measures(_pos, _neg, recall_level=recall_level_default): 73 | pos = np.array(_pos[:]).reshape((-1, 1)) 74 | neg = np.array(_neg[:]).reshape((-1, 1)) 75 | examples = np.squeeze(np.vstack((pos, neg))) 76 | labels = np.zeros(len(examples), dtype=np.int32) 77 | labels[:len(pos)] += 1 78 | 79 | auroc = sk.roc_auc_score(labels, examples) 80 | aupr = sk.average_precision_score(labels, examples) 81 | fpr = fpr_and_fdr_at_recall(labels, examples, recall_level) 82 | 83 | return auroc, aupr, fpr 84 | 85 | 86 | def show_performance(pos, neg, method_name='Ours', recall_level=recall_level_default): 87 | ''' 88 | :param pos: 1's class, class to detect, outliers, or wrongly predicted 89 | example scores 90 | :param neg: 0's class scores 91 | ''' 92 | 93 | auroc, aupr, fpr = get_measures(pos[:], neg[:], recall_level) 94 | 95 | print('\t\t\t' + method_name) 96 | print('FPR{:d}:\t\t\t{:.2f}'.format(int(100 * recall_level), 100 * fpr)) 97 | print('AUROC:\t\t\t{:.2f}'.format(100 * auroc)) 98 | print('AUPR:\t\t\t{:.2f}'.format(100 * aupr)) 99 | # print('FDR{:d}:\t\t\t{:.2f}'.format(int(100 * recall_level), 100 * fdr)) 100 | 101 | 102 | def print_measures(auroc, aupr, fpr, method_name='Ours', recall_level=recall_level_default): 103 | # print('\t\t\t\t' + method_name) 104 | # print(' FPR{:d} AUROC AUPR'.format(int(100*recall_level))) 105 | # print('& {:.2f} & {:.2f} & {:.2f}'.format(100*fpr, 100*auroc, 100*aupr)) 106 | # print('FPR{:d}:\t\t\t{:.2f}'.format(int(100 * recall_level), 100 * fpr)) 107 | # print('AUROC: \t\t\t{:.2f}'.format(100 * auroc)) 108 | # print('AUPR: \t\t\t{:.2f}'.format(100 * aupr)) 109 | print(f"Metric: {method_name} || FPR{int(100 * recall_level)}: {100*fpr:.2f} || AUROC: {100*auroc:.2f} || AUPR: {100*aupr:.2f}") 110 | 111 | 112 | def print_measures_with_std(aurocs, auprs, fprs, method_name='Ours', recall_level=recall_level_default): 113 | print('\t\t\t\t' + method_name) 114 | print(' FPR{:d} AUROC AUPR'.format(int(100*recall_level))) 115 | print('& {:.2f} & {:.2f} & {:.2f}'.format(100*np.mean(fprs), 100*np.mean(aurocs), 100*np.mean(auprs))) 116 | print('& {:.2f} & {:.2f} & {:.2f}'.format(100*np.std(fprs), 100*np.std(aurocs), 100*np.std(auprs))) 117 | #print('FPR{:d}:\t\t\t{:.2f}\t+/- {:.2f}'.format(int(100 * recall_level), 100 * np.mean(fprs), 100 * np.std(fprs))) 118 | #print('AUROC: \t\t\t{:.2f}\t+/- {:.2f}'.format(100 * np.mean(aurocs), 100 * np.std(aurocs))) 119 | #print('AUPR: \t\t\t{:.2f}\t+/- {:.2f}'.format(100 * np.mean(auprs), 100 * np.std(auprs))) 120 | 121 | 122 | def show_performance_comparison(pos_base, neg_base, pos_ours, neg_ours, baseline_name='Baseline', 123 | method_name='Ours', recall_level=recall_level_default): 124 | ''' 125 | :param pos_base: 1's class, class to detect, outliers, 
or wrongly predicted 126 | example scores from the baseline 127 | :param neg_base: 0's class scores generated by the baseline 128 | ''' 129 | auroc_base, aupr_base, fpr_base = get_measures(pos_base[:], neg_base[:], recall_level) 130 | auroc_ours, aupr_ours, fpr_ours = get_measures(pos_ours[:], neg_ours[:], recall_level) 131 | 132 | print('\t\t\t' + baseline_name + '\t' + method_name) 133 | print('FPR{:d}:\t\t\t{:.2f}\t\t{:.2f}'.format( 134 | int(100 * recall_level), 100 * fpr_base, 100 * fpr_ours)) 135 | print('AUROC:\t\t\t{:.2f}\t\t{:.2f}'.format( 136 | 100 * auroc_base, 100 * auroc_ours)) 137 | print('AUPR:\t\t\t{:.2f}\t\t{:.2f}'.format( 138 | 100 * aupr_base, 100 * aupr_ours)) 139 | # print('FDR{:d}:\t\t\t{:.2f}\t\t{:.2f}'.format( 140 | # int(100 * recall_level), 100 * fdr_base, 100 * fdr_ours)) -------------------------------------------------------------------------------- /dataset/cifar.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import numpy as np 4 | from PIL import Image 5 | import torch 6 | from torch.utils.data import Dataset 7 | import torchvision.transforms as transforms 8 | 9 | class cifar10(Dataset): 10 | 11 | filename = "cifar-10-python.tar.gz" 12 | 13 | train_list=[ 14 | 15 | ['data_batch_1', 'c99cafc152244af753f735de768cd75f'], 16 | ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'], 17 | ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'], 18 | ['data_batch_4', '634d18415352ddfa80567beed471001a'], 19 | ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'], 20 | ] 21 | test_list=[ 22 | 23 | ['test_batch', '40351d587109b95175f43aff81a1287e'], 24 | ] 25 | 26 | meta ={ 27 | 'filename':'batches.meta', 28 | 'key': 'label_names', 29 | 'md5':'5ff9c542aee3614f3951f8cda6e48888', 30 | } 31 | 32 | templates=[ 33 | 'a photo of a {}.', 34 | 'a blurry photo of a {}.', 35 | 'a black and white photo of a {}.', 36 | 'a low contrast photo-of a {}.', 37 | 'a high contrast photb of a {}.', 38 | 'a bad photo of a {}.', 39 | 'a good photo of a {}.', 40 | 'a photo of a small {}.', 41 | 'a photo of a big {}.', 42 | 'a photo of the {}.', 43 | 'a blurry photo of the {}.', 44 | 'a black and white photo of the {}.', 45 | 'a low contrast photo of the {}.', 46 | 'a high contrast photo of the {}.', 47 | 'a bad photo of the {}.', 48 | 'a good photo of the {}.', 49 | 'a photo of the small {}.', 50 | 'a photo of the big {}.',] 51 | 52 | def __init__(self,root, transform=None, train=True): 53 | self.root = root 54 | self.train = train 55 | self.transform = transform 56 | self.base_folder = 'cifar-10-batches-py' 57 | if self.train: 58 | downloaded_list = self.train_list 59 | else: 60 | downloaded_list = self.test_list 61 | 62 | self.data =[] 63 | self.targets =[] 64 | for file_name, checksum in downloaded_list: 65 | file_path = os.path.join(self.root, self.base_folder, file_name) 66 | with open(file_path, 'rb') as f: 67 | entry = pickle.load(f, encoding='latin1') 68 | self.data.append(entry['data']) 69 | if 'labels' in entry: 70 | self.targets.extend(entry['labels']) 71 | else: 72 | self.targets.extend(entry['fine_labels']) 73 | self.data = np.vstack(self.data).reshape(-1, 3, 32, 32) 74 | self.data = self.data.transpose((0, 2, 3, 1)) # convert to HwC 75 | 76 | self._load_meta() 77 | 78 | def _load_meta(self) -> None: 79 | path = os.path.join(self.root, self.base_folder, self.meta['filename']) 80 | with open(path, 'rb') as infile: 81 | data = pickle.load(infile, encoding='latin1') 82 | self.classes = data[self.meta['key']] 83 | 
self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)} 84 | 85 | def __getitem__(self, index): 86 | img, target = self.data[index], self.targets[index] 87 | img = Image.fromarray(img) 88 | 89 | if self.transform is not None: 90 | img = self.transform(img) 91 | 92 | return img, target 93 | 94 | def __len__(self): 95 | return len(self.data) 96 | 97 | def prompts(self,mode='single'): 98 | if mode == 'single': 99 | prompts = [[self.templates[0].format(label)] for label in self.classes] 100 | return prompts 101 | elif mode == 'ensemble': 102 | prompts = [[template.format(label) for template in self.templates] for label in self.classes] 103 | return prompts 104 | 105 | def get_labels(self): 106 | return np.array(self.targets) 107 | 108 | def get_classes(self): 109 | return self.classes 110 | 111 | 112 | class cifar100(Dataset): 113 | # base_folder = 'cifar-100-python' 114 | train_list=[ 115 | ['train', '16019d7e3df5f24257cddd939b257f8d'], 116 | ] 117 | 118 | test_list=[ 119 | ['test', 'foef6b0ae62326f3e7ffdfab6717acfc'], 120 | ] 121 | meta={ 122 | 'filename':'meta', 123 | 'key': 'fine_label_names', 124 | 'md5':'7973b15100ade9c7d40fb424638fde48' 125 | } 126 | 127 | templates=[ 128 | 'a photo of a {}.', 129 | 'a blurry photo of a {}.', 130 | 'a black and white photo of a {}.', 131 | 'a low contrast photo of a {}.', 132 | 'a high contrast photo of a {}.', 133 | 'a bad photo of a {}.', 134 | 'a good photo of a {}.', 135 | 'a photo of a small {}.', 136 | 'a photo of a big {}.', 137 | 'a photo of the {}.', 138 | 'a blurry photo of the {}.', 139 | 'a black and white photo of the {}.', 140 | 'a low contrast photo of the {}.', 141 | 'a high contrast photo of the {}.', 142 | 'a bad photo of the {}.', 143 | 'a good photo of the {}.', 144 | 'a photo of the small {}.', 145 | 'a photo of the big {}.', 146 | ] 147 | 148 | 149 | def __init__(self, root, transform=None, train=True): 150 | self.root = root 151 | self.train = train 152 | self.transform = transform 153 | self.base_folder = 'cifar-100-python' 154 | 155 | if self.train: 156 | downloaded_list = self.train_list 157 | else: 158 | downloaded_list = self.test_list 159 | 160 | self.data =[] 161 | self.targets =[] 162 | for file_name, checksum in downloaded_list: 163 | file_path = os.path.join(self.root, self.base_folder, file_name) 164 | with open(file_path, 'rb') as f: 165 | entry = pickle.load(f, encoding='latin1') 166 | self.data.append(entry['data']) 167 | if 'labels' in entry: 168 | self.targets.extend(entry['labels']) 169 | else: 170 | self.targets.extend(entry['fine_labels']) 171 | self.data = np.vstack(self.data).reshape(-1, 3, 32, 32) 172 | self.data = self.data.transpose((0, 2, 3, 1))# convert to HWC 173 | 174 | self._load_meta() 175 | 176 | def _load_meta(self) -> None: 177 | path = os.path.join(self.root, self.base_folder, self.meta['filename']) 178 | with open(path, 'rb') as infile: 179 | data = pickle.load(infile, encoding='latin1') 180 | self.classes = data[self.meta[ 'key']] 181 | self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)} 182 | 183 | def __getitem__(self, index): 184 | img, target = self.data[index], self.targets[index] 185 | img = Image.fromarray(img) 186 | 187 | if self.transform is not None: 188 | img = self.transform(img) 189 | 190 | return img,target,int(index) 191 | 192 | def __len__(self): 193 | return len(self.data) 194 | 195 | def prompts(self, mode='single'): 196 | if mode =='single': 197 | prompts = [[self.templates[0].format(label)] for label in self.classes] 198 | return prompts 
199 | elif mode == 'ensemble': 200 | prompts = [[template.format(label) for template in self.templates] for label in self.classes] 201 | return prompts 202 | 203 | def get_labels(self): 204 | return np.array(self.targets) 205 | 206 | def get_classes(self): 207 | return self.classes 208 | -------------------------------------------------------------------------------- /classifier/clip_adapter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | 5 | from tqdm import tqdm 6 | from copy import deepcopy 7 | import numpy as np 8 | 9 | from clip.clip import load, tokenize 10 | from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer 11 | _tokenizer = _Tokenizer() 12 | import dataset.incremental_dataloader 13 | 14 | from .utils import build_cosine_scheduler, freeze_parameters 15 | import pdb 16 | import time 17 | from .evaluator import Evaluator 18 | 19 | 20 | class Adapter(nn.Module): 21 | def __init__(self, c_in, reduction=4): 22 | super(Adapter, self).__init__() 23 | self.fc = nn.Sequential( 24 | nn.Linear(c_in, c_in // reduction, bias=False), 25 | nn.ReLU(inplace=True), 26 | nn.Linear(c_in // reduction, c_in, bias=False), 27 | nn.ReLU(inplace=True) 28 | ) 29 | 30 | def forward(self, x): 31 | x = self.fc(x) 32 | return x 33 | 34 | class CLIP(nn.Module): 35 | def __init__(self, args, class_names, clip_model, temp=None): 36 | super().__init__() 37 | self.n_class = len(class_names) 38 | self.args = args 39 | self.clip_model = clip_model 40 | self.current_class_names = class_names 41 | 42 | # image encoder 43 | self.image_encoder = clip_model.visual 44 | self.logit_scale = clip_model.logit_scale 45 | ctx_dim = self.clip_model.ln_final.weight.shape[0] 46 | 47 | # clip adapter 48 | self.adapter = Adapter(ctx_dim, 4).cuda(device=self.args.default_gpu).type(self.clip_model.dtype) 49 | # prompt templaates 50 | self.temp = temp if temp is not None else ["A photo of a"] 51 | self.text_features = self.get_text_features() 52 | 53 | def get_text_features(self): 54 | prompts = [[template.format(c.replace("_", " ")) for template in self.temp] for c in self.current_class_names] 55 | self.text_features = [] 56 | with torch.no_grad(): 57 | for per_cls_prompts in prompts: 58 | per_cls_prompt_embs = tokenize(per_cls_prompts).cuda(device=self.args.default_gpu) 59 | text_features = self.clip_model.encode_text(per_cls_prompt_embs) 60 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 61 | text_features = text_features.mean(dim=0) 62 | text_features = text_features / text_features.norm() 63 | self.text_features.append(text_features) 64 | text_features = torch.stack(self.text_features, dim=0) 65 | return text_features 66 | 67 | 68 | def forward(self, image, label=None, test=False): 69 | 70 | image_features = self.image_encoder(image.type(self.dtype)) 71 | x = self.adapter(image_features) 72 | ratio = 0.2 73 | image_features = ratio * x + (1 - ratio) * image_features 74 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 75 | 76 | n_class = self.n_class 77 | 78 | if test: 79 | text_features = self.text_features 80 | logit_scale = self.logit_scale.exp() 81 | logits = logit_scale * image_features @ text_features.t() 82 | if self.args.compute_ram: 83 | visual_feats = image_features 84 | textual_feats = text_features[label] 85 | return logits, (visual_feats.detach().cpu(), textual_feats.detach().cpu()) 86 | return logits, (None, None) 87 | 88 | else: 
89 | # text_features_all = [] 90 | text_features = self.text_features 91 | # text_features_all.append(text_features) 92 | # text_features = torch.stack(text_features_all).sum(0) 93 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 94 | text_features = text_features.view(n_class, -1) 95 | 96 | logit_scale = self.logit_scale.exp() 97 | logits = logit_scale * image_features @ text_features.t() 98 | 99 | return logits 100 | 101 | @torch.no_grad() 102 | def set_classifier(self): 103 | pass 104 | 105 | @property #变成属性 106 | def dtype(self): 107 | return self.image_encoder.conv1.weight.dtype #return int/float 108 | 109 | 110 | class ClipAdapter(Evaluator): 111 | def __init__(self, args, use_float32=False, use_grad_checkpoint=False): 112 | super().__init__(args) 113 | self.args = args 114 | clip_model, _ = load(args.ckpt_path, device=f"cuda:{args.default_gpu}") 115 | clip_model.eval() 116 | if use_float32: 117 | clip_model.float() 118 | self.clip_model = clip_model 119 | self.use_grad_checkpoint = use_grad_checkpoint 120 | 121 | self.lr = args.lr 122 | self.wd = args.wd 123 | self.epochs = args.epochs 124 | self.train_batch = args.train_batch 125 | self.current_class_names = [] 126 | 127 | def fit(self, data): 128 | self.current_class_names += data['class_names'] 129 | print(f"Classes: {self.current_class_names}") 130 | train_loader = data['train_loader'] 131 | 132 | if len(train_loader.dataset)< self.train_batch: 133 | real_img_bsz = len(train_loader.dataset) 134 | self.lr = self.lr * real_img_bsz / self.train_batch 135 | else: 136 | real_img_bsz = self.train_batch 137 | 138 | per_epoch_steps = len(train_loader) 139 | 140 | self.init_model(class_names=self.current_class_names, per_epoch_steps=per_epoch_steps, temp=data['prompt_templates']) 141 | 142 | self.model.eval() 143 | if self.args.sess >= 0: 144 | for epoch in tqdm(range(self.epochs)): 145 | for idx, (x, y, index) in tqdm(enumerate(train_loader), total=len(train_loader), desc = 'Training'): 146 | 147 | cur_iter_idx = epoch*per_epoch_steps+idx 148 | self.cur_iter_idx = cur_iter_idx 149 | self.scheduler.step(cur_iter_idx) 150 | 151 | output = self.model(x.cuda(device=self.args.default_gpu)) 152 | # pdb.set_trace() 153 | loss = F.cross_entropy(output, y.cuda(device=self.args.default_gpu)) 154 | self.optimizer.zero_grad() 155 | loss.backward() 156 | self.optimizer.step() 157 | 158 | self.model.set_classifier() 159 | return self.model 160 | 161 | def finetuning(self, data): 162 | memory_loader = data['memory_loader'] 163 | self.cur_iter_idx = 0 164 | 165 | if len(memory_loader.dataset)< self.train_batch: 166 | real_img_bsz = len(memory_loader.dataset) 167 | self.lr = self.lr * real_img_bsz / self.train_batch 168 | else: 169 | real_img_bsz = self.train_batch 170 | 171 | per_epoch_steps = len(memory_loader) 172 | 173 | self.build_optimizer(per_epoch_steps=per_epoch_steps, lr=self.lr/10., finetune=True) 174 | 175 | self.model.eval() 176 | 177 | for epoch in tqdm(range(self.args.finetune_epochs)): 178 | for idx, (x, y, index) in tqdm(enumerate(memory_loader), total=len(memory_loader), desc = 'Finetuning'): 179 | 180 | cur_iter_idx = epoch*per_epoch_steps+idx 181 | self.cur_iter_idx = cur_iter_idx 182 | self.scheduler.step(cur_iter_idx) 183 | 184 | output = self.model(x.cuda(device=self.args.default_gpu)) 185 | # pdb.set_trace() 186 | loss = F.cross_entropy(output, y.cuda(device=self.args.default_gpu)) 187 | self.optimizer.zero_grad() 188 | loss.backward() 189 | self.optimizer.step() 190 | 191 | return self.model 192 | 193 | 
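    # Illustrative sketch of the residual blending done in CLIP.forward above (CLIP-Adapter style);
    # the 512-dim feature size assumes the ViT-B/16 backbone, and the dummy batch is hypothetical:
    #
    #   adapter = Adapter(c_in=512, reduction=4)              # bottleneck: 512 -> 128 -> 512
    #   feats = torch.randn(8, 512)                           # stand-in for frozen image features
    #   blended = 0.2 * adapter(feats) + 0.8 * feats          # ratio * x + (1 - ratio) * features
    #   blended = blended / blended.norm(dim=-1, keepdim=True)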
194 | def post_training(self, finalize=False): 195 | self.model.set_classifier() 196 | 197 | def init_model(self, class_names, per_epoch_steps, temp=None): 198 | self.n_class = len(class_names) 199 | clip_model = deepcopy(self.clip_model) 200 | 201 | self.model = CLIP(self.args, class_names, clip_model, temp=temp) 202 | 203 | if self.use_grad_checkpoint: 204 | try: 205 | self.model.text_encoder.transformer.use_gradient_checkpoint = True 206 | except: 207 | self.model.text_encoder.module.transformer.use_gradient_checkpoint = True 208 | self.build_optimizer(per_epoch_steps, lr=self.lr, warmup=True) 209 | 210 | def build_optimizer(self, per_epoch_steps, lr, warmup=False, finetune=False): 211 | for name, param in self.model.named_parameters(): 212 | if "adapter" not in name: 213 | param.requires_grad_(False) 214 | # double check 215 | enabled = set() 216 | for name, param in self.model.named_parameters(): 217 | if param.requires_grad: 218 | enabled.add(name) 219 | 220 | print(f"\nParameters to be updated: {sorted(enabled)}\n") 221 | 222 | param_dict = [{'params': [p for p in self.model.parameters() if p.requires_grad]}] 223 | total_step=self.epochs*per_epoch_steps if not finetune else self.args.finetune_epochs*per_epoch_steps 224 | self.optimizer = torch.optim.SGD(param_dict, lr=lr, weight_decay=self.wd) 225 | self.scheduler = build_cosine_scheduler( 226 | self.optimizer, 227 | lr=lr, 228 | total_step=total_step) 229 | 230 | 231 | @torch.no_grad() 232 | def inference(self,image, label, num_test, test_class): 233 | self.model.eval() 234 | logits, feats = self.model(image, label, test=True) 235 | return logits.float(), feats 236 | 237 | 238 | -------------------------------------------------------------------------------- /clip/clip_2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import hashlib 3 | import urllib 4 | import warnings 5 | from typing import Any, Union, List 6 | 7 | from PIL import Image 8 | import torch 9 | from tqdm import tqdm 10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 11 | 12 | from .model_2 import build_model 13 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 14 | import pdb 15 | try: 16 | from torchvision.transforms import InterpolationMode 17 | BICUBIC = InterpolationMode.BICUBIC 18 | except ImportError: 19 | BICUBIC = Image.BICUBIC 20 | 21 | if torch.__version__.split(".") < ["1","7","1"]: 22 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 23 | 24 | __all__ = ["available_models", "load", "tokenize"] 25 | _tokenizer = _Tokenizer() 26 | 27 | _MODELS = { 28 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 29 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 30 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 31 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 32 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 33 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 34 | "ViT-B/16": 
"https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 35 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 36 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", 37 | } 38 | 39 | def _download(url:str,root:str): 40 | os.makedirs(root,exist_ok=True) 41 | filename = os.path.basename(url) 42 | 43 | expected_sha256 = url.split("/")[-2] 44 | download_target = os.path.join(root,filename) 45 | 46 | if os.path.exists(download_target) and not os.path.isfile(download_target): 47 | raise RuntimeError(f"{download_target} exists and is not a regular file") 48 | 49 | if os.path.isfile(download_target): 50 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 51 | return download_target 52 | else: 53 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 54 | 55 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 56 | with tqdm(total=int(source.info().get("content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 57 | while True: 58 | buffer = source.read(8192) 59 | if not buffer: 60 | break 61 | 62 | output.write(buffer) 63 | loop.update(len(buffer)) 64 | 65 | if hashlib.sha256(open(download_target,"rb").read()).hexdigest()!= expected_sha256: 66 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 67 | 68 | return download_target 69 | 70 | 71 | def _transform(n_px): 72 | return Compose([ 73 | Resize(n_px,interpolation=BICUBIC), 74 | CenterCrop(n_px), 75 | lambda image: image.convert("RGB"), 76 | ToTensor(), 77 | Normalize((0.48145466, 0.4578275, 0.40821073),(0.26862954, 0.26130258, 0.27577711)), 78 | ]) 79 | 80 | 81 | def available_models() -> List[str]: 82 | """Returns the names of available CLIP models""" 83 | return list(_MODELS.keys()) 84 | 85 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): 86 | """Load a CLIP model 87 | 88 | Parameters 89 | ---------- 90 | name:str 91 | A model name listed by `clip.available_models()", or the path to a model checkpoint containing the state_dict 92 | 93 | device : Union[str, torch.device] 94 | The device to put the loaded model 95 | 96 | jit: bool 97 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 
98 | 99 | download_root: str 100 | path to download the model files; by default, it uses "~/.cache/clip" 101 | 102 | Returns 103 | ------- 104 | model: torch.nn.Module 105 | The CLIP model 106 | 107 | preprocess : Callable[[PIL.Image], torch.Tensor] 108 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 109 | """ 110 | # pdb.set_trace() 111 | if name in _MODELS: 112 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) 113 | elif os.path.isfile(name): 114 | model_path = name 115 | else: 116 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 117 | try: 118 | # loading JIT archive 119 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 120 | state_dict = None 121 | except RuntimeError: 122 | # loading saved state dict 123 | if jit: 124 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") 125 | jit = False 126 | state_dict = torch.load(model_path, map_location="cpu") 127 | 128 | if not jit: 129 | model = build_model(state_dict or model.state_dict()).to(device) 130 | if str(device) == "cpu": 131 | model.float() 132 | return model, _transform(model.visual.input_resolution) 133 | 134 | # patch the device names 135 | device_holder = torch.jit.trace(lambda:torch.ones([]).to(torch.device(device)), example_inputs=[]) 136 | device_node = [n for n in device_holder.graph.findAllNodes("prim: :Constant") if "Device" in repr(n)][-1] 137 | 138 | def patch_device(module): 139 | try: 140 | graphs = [module.graph] if hasattr(module, "graph") else [] 141 | except RuntimeError: 142 | graphs =[] 143 | 144 | if hasattr(module, "forward1"): 145 | graphs.append(module.forward1.graph) 146 | 147 | for graph in graphs: 148 | for node in graph.findAllNodes("prim::Constant"): 149 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 150 | node.copyAttributes(device_node) 151 | 152 | model.apply(patch_device) 153 | patch_device(model.encode_image) 154 | patch_device(model.encode_text) 155 | 156 | # patch dtype to float32 on CPU 157 | if str(device)=="cpu": 158 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 159 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 160 | float_node = float_input.node() 161 | 162 | 163 | def patch_float(module): 164 | try: 165 | graphs = [module.graph] if hasattr(module, "graph") else [] 166 | except RuntimeError: 167 | graphs = [] 168 | 169 | if hasattr(module, "forward1"): 170 | graphs.append(module.forward1.graph) 171 | 172 | for graph in graphs: 173 | for node in graph.findAllNodes("aten::to"): 174 | inputs = list(node.inputs()) 175 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 176 | if inputs[i].node()["value"] == 5: 177 | inputs[i].node().copyAttributes(float_node) 178 | 179 | model.apply(patch_float) 180 | patch_float(model.encode_image) 181 | patch_float(model.encode_text) 182 | 183 | model.float() 184 | 185 | return model, _transform(model.input_resolution.item()) 186 | 187 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor: 188 | """ 189 | Returns the tokenized representation of given input string(s) 190 | 191 | Parameters 192 | --------- 193 | texts : Union[str, List[str]] 194 | An input string or a list of input strings to tokenize 195 | 196 | context_length : int 197 | The context length to use; 
all CLIP models use 77 as the context length 198 | 199 | truncate:bool 200 | whether to truncate the text in case its encoding is longer than the context length 201 | 202 | Returns 203 | ------- 204 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 205 | """ 206 | 207 | if isinstance(texts, str): 208 | texts = [texts] 209 | 210 | sot_token = _tokenizer.encoder["<|startoftext|>"] 211 | eot_token = _tokenizer.encoder["<|endoftext|>"] 212 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 213 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 214 | 215 | for i, tokens in enumerate(all_tokens): 216 | if len(tokens) > context_length: 217 | if truncate: 218 | tokens = tokens[:context_length] 219 | tokens[-1] = eot_token 220 | else: 221 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 222 | result[i, :len(tokens)] = torch.tensor(tokens) 223 | 224 | return result -------------------------------------------------------------------------------- /clip/clip.py: -------------------------------------------------------------------------------- 1 | import os 2 | import hashlib 3 | import urllib 4 | import warnings 5 | from typing import Any, Union, List 6 | 7 | from PIL import Image 8 | import torch 9 | from tqdm import tqdm 10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 11 | 12 | from .model import build_model 13 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 14 | 15 | try: 16 | from torchvision.transforms import InterpolationMode 17 | BICUBIC = InterpolationMode.BICUBIC 18 | except ImportError: 19 | BICUBIC = Image.BICUBIC 20 | 21 | if torch.__version__.split(".") < ["1","7","1"]: 22 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 23 | 24 | __all__ = ["available_models", "load", "tokenize"] 25 | _tokenizer = _Tokenizer() 26 | 27 | _MODELS = { 28 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 29 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 30 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 31 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 32 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 33 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 34 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 35 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 36 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", 37 | } 38 | 39 | def _download(url:str,root:str): 40 | os.makedirs(root,exist_ok=True) 41 | filename = os.path.basename(url) 42 | 43 | expected_sha256 = url.split("/")[-2] 44 | download_target = os.path.join(root,filename) 45 | 46 | if os.path.exists(download_target) and 
not os.path.isfile(download_target): 47 | raise RuntimeError(f"{download_target} exists and is not a regular file") 48 | 49 | if os.path.isfile(download_target): 50 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 51 | return download_target 52 | else: 53 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 54 | 55 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 56 | with tqdm(total=int(source.info().get("content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 57 | while True: 58 | buffer = source.read(8192) 59 | if not buffer: 60 | break 61 | 62 | output.write(buffer) 63 | loop.update(len(buffer)) 64 | 65 | if hashlib.sha256(open(download_target,"rb").read()).hexdigest()!= expected_sha256: 66 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 67 | 68 | return download_target 69 | 70 | 71 | def _transform(n_px): 72 | return Compose([ 73 | Resize(n_px,interpolation=BICUBIC), 74 | CenterCrop(n_px), 75 | lambda image: image.convert("RGB"), 76 | ToTensor(), 77 | Normalize((0.48145466, 0.4578275, 0.40821073),(0.26862954, 0.26130258, 0.27577711)), 78 | ]) 79 | 80 | 81 | def available_models() -> List[str]: 82 | """Returns the names of available CLIP models""" 83 | return list(_MODELS.keys()) 84 | 85 | 86 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None, design_details={"vision_depth":0, "trainer":"", "language_depth":0}): 87 | """Load a CLIP model 88 | 89 | Parameters 90 | ---------- 91 | name:str 92 | A model name listed by `clip.available_models()", or the path to a model checkpoint containing the state_dict 93 | 94 | device : Union[str, torch.device] 95 | The device to put the loaded model 96 | 97 | jit: bool 98 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 99 | 100 | download_root: str 101 | path to download the model files; by default, it uses "~/.cache/clip" 102 | 103 | Returns 104 | ------- 105 | model: torch.nn.Module 106 | The CLIP model 107 | 108 | preprocess : Callable[[PIL.Image], torch.Tensor] 109 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 110 | """ 111 | if name in _MODELS: 112 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) 113 | elif os.path.isfile(name): 114 | model_path = name 115 | else: 116 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 117 | 118 | try: 119 | # loading JIT archive 120 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 121 | state_dict = None 122 | except RuntimeError: 123 | # loading saved state dict 124 | if jit: 125 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 126 | jit = False 127 | state_dict = torch.load(model_path, map_location="cpu") 128 | 129 | if not jit: 130 | model = build_model(state_dict or model.state_dict(), design_details=design_details).to(device) 131 | if str(device) == "cpu": 132 | model.float() 133 | return model, _transform(model.visual.input_resolution) 134 | 135 | # patch the device names 136 | device_holder = torch.jit.trace(lambda:torch.ones([]).to(torch.device(device)), example_inputs=[]) 137 | device_node = [n for n in device_holder.graph.findAllNodes("prim: :Constant") if "Device" in repr(n)][-1] 138 | 139 | def patch_device(module): 140 | try: 141 | graphs = [module.graph] if hasattr(module, "graph") else [] 142 | except RuntimeError: 143 | graphs =[] 144 | 145 | if hasattr(module, "forward1"): 146 | graphs.append(module.forward1.graph) 147 | 148 | for graph in graphs: 149 | for node in graph.findAllNodes("prim::Constant"): 150 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 151 | node.copyAttributes(device_node) 152 | 153 | model.apply(patch_device) 154 | patch_device(model.encode_image) 155 | patch_device(model.encode_text) 156 | 157 | # patch dtype to float32 on CPU 158 | if str(device)=="cpu": 159 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 160 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 161 | float_node = float_input.node() 162 | 163 | 164 | def patch_float(module): 165 | try: 166 | graphs = [module.graph] if hasattr(module, "graph") else [] 167 | except RuntimeError: 168 | graphs = [] 169 | 170 | if hasattr(module, "forward1"): 171 | graphs.append(module.forward1.graph) 172 | 173 | for graph in graphs: 174 | for node in graph.findAllNodes("aten::to"): 175 | inputs = list(node.inputs()) 176 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 177 | if inputs[i].node()["value"] == 5: 178 | inputs[i].node().copyAttributes(float_node) 179 | 180 | model.apply(patch_float) 181 | patch_float(model.encode_image) 182 | patch_float(model.encode_text) 183 | 184 | model.float() 185 | 186 | return model, _transform(model.input_resolution.item()) 187 | 188 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor: 189 | """ 190 | Returns the tokenized representation of given input string(s) 191 | 192 | Parameters 193 | --------- 194 | texts : Union[str, List[str]] 195 | An input string or a list of input strings to tokenize 196 | 197 | context_length : int 198 | The context length to use; all CLIP models use 77 as the context length 199 | 200 | truncate:bool 201 | whether to truncate the text in case its encoding is longer than the context length 202 | 203 | Returns 204 | ------- 205 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 206 | """ 207 | 208 | if isinstance(texts, str): 209 | texts = [texts] 210 | 211 | sot_token = _tokenizer.encoder["<|startoftext|>"] 212 | eot_token = _tokenizer.encoder["<|endoftext|>"] 213 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 214 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 215 | 216 | for i, tokens in enumerate(all_tokens): 217 | if len(tokens) > context_length: 218 | if truncate: 219 | tokens = tokens[:context_length] 220 | tokens[-1] = eot_token 221 | else: 222 | raise RuntimeError(f"Input {texts[i]} is too long for context length 
{context_length}") 223 | result[i, :len(tokens)] = torch.tensor(tokens) 224 | 225 | return result -------------------------------------------------------------------------------- /classifier/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | from sklearn import datasets 5 | from sklearn.manifold import TSNE 6 | from sklearn.cluster import KMeans 7 | import pdb 8 | import math 9 | import random 10 | from torch.utils.data import Sampler 11 | import torch.nn.functional as F 12 | 13 | class SubsetRandomSampler(Sampler): 14 | r"""Samples elements randomly from a given list of indices, without replacement. 15 | 16 | Arguments: 17 | indices (sequence): a sequence of indices 18 | """ 19 | 20 | def __init__(self, indices, shuffle): 21 | self.indices = indices 22 | self.shuffle = shuffle 23 | 24 | def __iter__(self): 25 | if(self.shuffle): 26 | return (self.indices[i] for i in torch.randperm(len(self.indices))) 27 | else: 28 | return (self.indices[i] for i in range(len(self.indices))) 29 | 30 | def __len__(self): 31 | return len(self.indices) 32 | 33 | def ce_loss_np(logits, targets_onehot, sample_T): 34 | 35 | pred = F.softmax(logits, dim=-1) 36 | 37 | B = pred.size(1) 38 | targets_onehot_expand = targets_onehot.unsqueeze(0).expand(sample_T, -1, -1) 39 | loss =torch.sum(-targets_onehot_expand * pred.log()) 40 | 41 | return loss/(B*sample_T) 42 | 43 | def get_context_by_labels(labels, m=None): 44 | _, idx, counts = torch.unique(labels, dim=0, sorted=True, return_inverse=True, return_counts=True) 45 | _, idx_sorted = torch.sort(idx, stable=True) 46 | cum_sum = counts.cumsum(0) 47 | cum_sum = torch.cat((torch.tensor([0]).to(cum_sum.device), cum_sum[:-1])) 48 | context_indices = idx_sorted[cum_sum] 49 | if m is not None and context_indices.size(0) < m: 50 | diff = m - context_indices.size(0) 51 | context_indices_permuted = torch.randperm(labels.size(0)).to(labels.device) 52 | context_indices_permuted = context_indices_permuted[(context_indices_permuted != context_indices.view(-1, 1)).all(dim=0)] 53 | context_indices = torch.cat((context_indices, context_indices_permuted[:diff])) 54 | return context_indices 55 | 56 | def compute_uncertainty(logits, T=1.): 57 | logits = logits * T 58 | pseudo_label = torch.softmax(logits, dim=-1) 59 | if logits.dim() == 3: 60 | pseudo_label = pseudo_label.mean(0) 61 | uncertainty = torch.special.entr(pseudo_label).sum(1) 62 | return uncertainty 63 | 64 | @torch.no_grad() 65 | def get_context_indices_by_uncertainty(bs, labels, logits, task_specific_labels=None, top_k=1): 66 | unique_labels = torch.unique(labels) 67 | labels_to_indices = {label.item(): (labels == label).nonzero().flatten() for label in unique_labels} 68 | uncertainties = compute_uncertainty(logits) 69 | uncertainties_by_labels = {label.item(): uncertainties[labels_to_indices[label.item()]] for label in unique_labels} 70 | uncertainties_by_labels_sorted_indices = {label: labels_to_indices[label][torch.argsort(uncs, descending=False)] for label, uncs in uncertainties_by_labels.items()} 71 | context_indices = torch.cat([indices[:top_k] for _, indices in uncertainties_by_labels_sorted_indices.items()]) 72 | return context_indices.detach() 73 | 74 | @torch.no_grad() 75 | def get_context_indices( bs, labels, task_specific_labels=None, context_size=0.67): 76 | if task_specific_labels is None: 77 | # m = random.randint(math.ceil(0.3 * bs), math.ceil(0.8 * bs)) 78 | m = math.ceil(context_size * 
bs) 79 | context_indices = torch.randperm(labels.size(0)).to(labels.device)[:m] 80 | # context_indices = get_context_by_labels(labels, m) 81 | else: 82 | context_indices = [] 83 | for label in task_specific_labels: 84 | idx = (labels == label).nonzero(as_tuple=True)[0] 85 | context_indices.append(idx) 86 | context_indices = torch.cat(context_indices) 87 | if context_indices.shape[0] == labels.shape[0]: 88 | context_indices = get_context_indices(bs, labels) 89 | return context_indices 90 | 91 | def init_weights(m): 92 | if isinstance(m, torch.nn.Linear): 93 | torch.nn.init.xavier_uniform_(m.weight) 94 | if m.bias is not None: 95 | m.bias.data.fill_(0.01) 96 | 97 | def freeze_parameters(m, requires_grad=False): 98 | if m is None: 99 | return 100 | 101 | if isinstance(m, torch.nn.Parameter): 102 | m.requires_grad = requires_grad 103 | else: 104 | for p in m.parameters(): 105 | p.requires_grad = requires_grad 106 | 107 | def cosine_schedule_warmup(total_step, value, final_value=0, warmup_step=0, warmup_value=0): 108 | if warmup_step > 0: 109 | warmup_schedule = np.linspace(warmup_value, value, warmup_step+2)[1:-1] 110 | else: 111 | warmup_schedule = np.array([]) 112 | steps = np.arange(total_step - warmup_step) 113 | schedule = final_value + 0.5 * (value-final_value) * (1+np.cos(np.pi * steps / len(steps))) 114 | schedule = np.concatenate((warmup_schedule, schedule)) 115 | assert len(schedule) == total_step 116 | return schedule 117 | 118 | class build_cosine_scheduler: 119 | def __init__(self, optimizer, lr, total_step, lr_warmup_step=0): 120 | init_lr = 0 121 | final_lr = lr * 1e-3 122 | self.lrs = cosine_schedule_warmup(total_step, lr, final_lr, lr_warmup_step, init_lr) 123 | self.optimizer = optimizer 124 | 125 | def step(self, idx): 126 | lr = self.lrs[idx] 127 | for i, param_group in enumerate(self.optimizer.param_groups): 128 | param_group["lr"] = lr 129 | self.lr = lr 130 | 131 | class build_bicosine_scheduler: 132 | def __init__(self, optimizer, lr, total_step, lr_warmup_step=0): 133 | lr_prompt = lr[0] 134 | lr_conv = lr[1] 135 | init_lr = 0 136 | final_lr_prompt = lr_prompt * 1e-3 137 | final_lr_conv = lr_conv * 1e-3 138 | self.lrs_prompt = cosine_schedule_warmup(total_step, lr_prompt, final_lr_prompt, lr_warmup_step, init_lr) 139 | self.lrs_conv = cosine_schedule_warmup(total_step, lr_conv, final_lr_conv, lr_warmup_step, init_lr) 140 | self.optimizer = optimizer 141 | 142 | def step(self, idx): 143 | lr_prompt = self.lrs_prompt[idx] 144 | lr_conv = self.lrs_conv[idx] 145 | for i, param_group in enumerate(self.optimizer.param_groups): 146 | # pdb.set_trace() 147 | if i == 0: 148 | param_group["lr"] = lr_conv 149 | else: 150 | param_group["lr"] = lr_prompt 151 | self.lr_conv = lr_conv 152 | self.lr_prompt = lr_prompt 153 | 154 | def plot_tsne(features, labels, id): 155 | """ 156 | features: (N, m) array of N samples, each with m-dimensional features 157 | labels: (N,) array of N labels 158 | """ 159 | fig_path = "/home/ma-user/work/proda/visualization/tsne_{}.png".format(id) 160 | features = features.detach().cpu().numpy() 161 | labels = labels.detach().cpu().numpy() 162 | # import pandas as pd 163 | # tsne = TSNE(n_components=2, init='pca', random_state=0) 164 | # import seaborn as sns 165 | # class_num = len(np.unique(labels))  # number of classes to plot, e.g. labels [0,1,2,3] give 4 166 | 167 | 168 | 169 | # tsne_features = tsne.fit_transform(features)  # reduce the features to 2 dimensions 170 | # print('tsne_features shape:', tsne_features.shape) 171 | # # plt.scatter(tsne_features[:, 0], tsne_features[:, 1])  # visualize the reduced features 172 | # # plt.show() 173 | # plt.savefig(fig_path) 174 | 175 | # sns.set()
176 | # df = pd.DataFrame() 177 | # df["y"] = labels 178 | # df["comp-1"] = tsne_features[:,0] 179 | # df["comp-2"] = tsne_features[:,1] 180 | 181 | 182 | # fig = sns.scatterplot(x="comp-1", y="comp-2",hue=df.y.tolist(), 183 | # palette=sns.color_palette("hls", class_num), 184 | # data=df).set(title="Bearing data T-SNE projection") 185 | 186 | # scatter_fig = fig.get_figure() 187 | # scatter_fig.savefig(fig_path, dpi = 400) 188 | 189 | tSNE = TSNE() 190 | word_embeddings = tSNE.fit_transform(features) 191 | classifier = KMeans(n_clusters=len(np.unique(labels))) 192 | classifier.fit(word_embeddings) 193 | labels = classifier.labels_ 194 | min_left = min(word_embeddings[:, 0]) 195 | max_right = max(word_embeddings[:, 0]) 196 | min_bottom = min(word_embeddings[:, 1]) 197 | max_top = max(word_embeddings[:, 1]) 198 | # markers = ["bo","go",,"mo","yo","ko","bx","gx", "rx"] 199 | colors = ["b","g","r","y", "k", "slategrey","slateblue","pink"] 200 | marks = ["o","o","o","o","o","o","o","o","o","o","x","x","x","x","x","x","x","x","x","x"] 201 | for i in range(len(word_embeddings)): 202 | plt.plot(word_embeddings[i][0], word_embeddings[i][1], marker=marks[labels[i]], color=colors[labels[i]]) 203 | plt.axis([min_left, max_right, min_bottom, max_top]) 204 | plt.savefig(fig_path) 205 | plt.clf() 206 | 207 | 208 | def plot_histogram(image1, image2, n): 209 | # image1 = image1.reshape(image1.shape[0],-1).cpu() 210 | # image2 = image2.reshape(image2.shape[0],-1).cpu() 211 | image1 = image1.reshape(-1).cpu() 212 | image2 = image2.reshape(-1).cpu() 213 | image3 = torch.cat((image1,image2),0).detach().numpy() 214 | image1 = image1.detach().numpy() 215 | image2 = image2.detach().numpy() 216 | # bins = np.linspace(image3.min(),image3.max(),n) 217 | bins = np.linspace(-0.045,0.045,n) 218 | # for i in range(image1.shape[0]): 219 | # pdb.set_trace() 220 | i = 0 221 | j = 8 222 | # plt.ylim((0,15000)) 223 | plt.ylim((0,400)) 224 | # plt.hist(image1[i], bins, alpha=0.5, label='x_1') 225 | # plt.hist(image1[j], bins, alpha=0.5, label='x_2') 226 | plt.hist(image1, bins, alpha=0.5, label='Image features') 227 | plt.hist(image2, bins, alpha=0.5, label='Text features') 228 | plt.legend(loc='upper right',fontsize=15) 229 | # print("image",image1[i].mean(),image1[j].mean(),image1[i].mean()-image1[j].mean()) 230 | fig_path = "/home/ma-user/work/proda/visualization/histogram_kl.png" 231 | plt.savefig(fig_path) 232 | plt.clf() 233 | # plt.ylim((0,15000)) 234 | # plt.hist(image2[i], bins, alpha=0.5, label='adv_1') 235 | # plt.hist(image2[j], bins, alpha=0.5, label='adv_2') 236 | # plt.legend(loc='upper right') 237 | # print("text",image2[i].mean(),image2[j].mean(),image2[i].mean()-image2[j].mean()) 238 | # fig_path = "/home/ma-user/work/proda/visualization/histogram_text0.png" 239 | # plt.savefig(fig_path) 240 | # plt.clf() 241 | # pdb.set_trace() 242 | 243 | def cosine_loss(q, k): 244 | # pdb.set_trace() 245 | q = q.repeat(1,k.shape[1],1) 246 | # k = k.squeeze(1) 247 | # q = q/q.norm(dim=-1) 248 | k_norm = k.norm(dim=-1,keepdim=True) 249 | # pdb.set_trace() 250 | # k_norm = k.norm(dim=-1).unsqueeze(1).repeat(1,k.shape[1]) 251 | k = k/k_norm 252 | cos = ((q*k)/(k.shape[0]*k.shape[1])).sum() 253 | return 1-cos 254 | -------------------------------------------------------------------------------- /dataset/imagenetr.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import os 3 | import numpy as np 4 | from PIL import Image 5 | from torchvision import
datasets 6 | import torch 7 | from shutil import move, rmtree 8 | 9 | def prepare_imagenet_r(fpath="/home/srv/Documents/mammoth_datasets/imagenet-r/"): 10 | if not os.path.exists(fpath + '/train') and not os.path.exists(fpath + '/test'): 11 | dataset = datasets.ImageFolder(fpath, transform=None) 12 | 13 | train_size = int(0.8 * len(dataset)) 14 | val_size = len(dataset) - train_size 15 | 16 | train, val = torch.utils.data.random_split(dataset, [train_size, val_size]) 17 | train_idx, val_idx = train.indices, val.indices 18 | 19 | train_file_list = [dataset.imgs[i][0] for i in train_idx] 20 | test_file_list = [dataset.imgs[i][0] for i in val_idx] 21 | 22 | 23 | train_folder = fpath + '/train' 24 | test_folder = fpath + '/test' 25 | 26 | if os.path.exists(train_folder): 27 | rmtree(train_folder) 28 | if os.path.exists(test_folder): 29 | rmtree(test_folder) 30 | os.mkdir(train_folder) 31 | os.mkdir(test_folder) 32 | 33 | for c in dataset.classes: 34 | if not os.path.exists(os.path.join(train_folder, c)): 35 | os.mkdir(os.path.join(os.path.join(train_folder, c))) 36 | if not os.path.exists(os.path.join(test_folder, c)): 37 | os.mkdir(os.path.join(os.path.join(test_folder, c))) 38 | 39 | for path in train_file_list: 40 | if '\\' in path: 41 | path = path.replace('\\', '/') 42 | src = path 43 | dst = os.path.join(train_folder, '/'.join(path.split('/')[-2:])) 44 | move(src, dst) 45 | 46 | for path in test_file_list: 47 | if '\\' in path: 48 | path = path.replace('\\', '/') 49 | src = path 50 | dst = os.path.join(test_folder, '/'.join(path.split('/')[-2:])) 51 | move(src, dst) 52 | 53 | for c in dataset.classes: 54 | path = os.path.join(fpath, c) 55 | rmtree(path) 56 | 57 | class imagenetR(Dataset): 58 | 59 | templates = [ 60 | 'a photo of a {}.', 61 | 'a bad photo of a {}.', 62 | 'a photo of many {}.', 63 | 'a sculpture of a {}.', 64 | 'a photo of the hard to see {}.', 65 | 'a low resolution photo of the {}.', 66 | 'a rendering of a {}.', 67 | 'graffiti of a {}.', 68 | 'a bad photo of the {}.', 69 | 'a cropped photo of the {}.', 70 | 'a tattoo of a {}.', 71 | 'the embroidered {}.', 72 | 'a photo of a hard to see {}.', 73 | 'a bright photo of a {}.', 74 | 'a photo of a clean {}.', 75 | 'a photo of a dirty {}.', 76 | 'a dark photo of the {}.', 77 | 'a drawing of a {}.', 78 | 'a photo of my {}.', 79 | 'the plastic {}.', 80 | 'a photo of the cool {}.', 81 | 'a close-up photo of a {}.', 82 | 'a black and white photo of the {}.', 83 | 'a painting of the {}.', 84 | 'a painting of a {}.', 85 | 'a pixelated photo of the {}.', 86 | 'a sculpture of the {}.', 87 | 'a bright photo of the {}.', 88 | 'a cropped photo of a {}.', 89 | 'a plastic {}.', 90 | 'a photo of the dirty {}.', 91 | 'a jpeg corrupted photo of a {}.', 92 | 'a blurry photo of the {}.', 93 | 'a photo of the {}.', 94 | 'a good photo of the {}.', 95 | 'a rendering of the {}.', 96 | 'a {} in a video game.', 97 | 'a photo of one {}.', 98 | 'a doodle of a {}.', 99 | 'a close-up photo of the {}.', 100 | 'the origami {}.', 101 | 'the {} in a video game.', 102 | 'a sketch of a {}.', 103 | 'a doodle of the {}.', 104 | 'a origami {}.', 105 | 'a low resolution photo of a {}.', 106 | 'the toy {}.', 107 | 'a rendition of the {}.', 108 | 'a photo of the clean {}.', 109 | 'a photo of a large {}.', 110 | 'a rendition of a {}.', 111 | 'a photo of a nice {}.', 112 | 'a photo of a weird {}.', 113 | 'a blurry photo of a {}.', 114 | 'a cartoon {}.', 115 | 'art of a {}.', 116 | 'a sketch of the {}.', 117 | 'a embroidered {}.', 118 | 'a pixelated photo of a {}.', 119 
| 'itap of the {}.', 120 | 'a jpeg corrupted photo of the {}.', 121 | 'a good photo of a {}.', 122 | 'a plushie {}.', 123 | 'a photo of the nice {}.', 124 | 'a photo of the small {}.', 125 | 'a photo of the weird {}.', 126 | 'the cartoon {}.', 127 | 'art of the {}.', 128 | 'a drawing of the {}.', 129 | 'a photo of the large {}.', 130 | 'a black and white photo of a {}.', 131 | 'the plushie {}.', 132 | 'a dark photo of a {}.', 133 | 'itap of a {}.', 134 | 'graffiti of the {}.', 135 | 'a toy {}.', 136 | 'itap of my {}.', 137 | 'a photo of a cool {}.', 138 | 'a photo of a small {}.', 139 | 'a tattoo of the {}.', 140 | ] 141 | 142 | new_classes = ['goldfish', 'great_white_shark', 'hammerhead', 'stingray', 'hen', 'ostrich', 'goldfinch', 143 | 'junco', 'bald_eagle', 'vulture', 'newt', 'axolotl', 'tree_frog', 'iguana', 'African_chameleon', 144 | 'cobra', 'scorpion', 'tarantula', 'centipede', 'peacock', 'lorikeet', 'hummingbird', 'toucan', 145 | 'duck', 'goose', 'black_swan', 'koala', 'jellyfish', 'snail', 'lobster', 'hermit_crab', 'flamingo', 146 | 'american_egret', 'pelican', 'king_penguin', 'grey_whale', 'killer_whale', 'sea_lion', 'chihuahua', 147 | 'shih_tzu', 'afghan_hound', 'basset_hound', 'beagle', 'bloodhound', 'italian_greyhound', 'whippet', 148 | 'weimaraner', 'yorkshire_terrier', 'boston_terrier', 'scottish_terrier', 149 | 'west_highland_white_terrier', 'golden_retriever', 'labrador_retriever', 'cocker_spaniels', 150 | 'collie', 'border_collie', 'rottweiler', 'german_shepherd_dog', 'boxer', 'french_bulldog', 151 | 'saint_bernard', 'husky', 'dalmatian', 'pug', 'pomeranian', 'chow_chow', 152 | 'pembroke_welsh_corgi', 'toy_poodle', 'standard_poodle', 'timber_wolf', 'hyena', 'red_fox', 153 | 'tabby_cat', 'leopard', 'snow_leopard', 'lion', 'tiger', 'cheetah', 'polar_bear', 'meerkat', 154 | 'ladybug', 'fly', 'bee', 'ant', 'grasshopper', 'cockroach', 'mantis', 'dragonfly', 155 | 'monarch_butterfly', 'starfish', 'wood_rabbit', 'porcupine', 'fox_squirrel', 'beaver', 'guinea_pig', 156 | 'zebra', 'pig', 'hippopotamus', 'bison', 'gazelle', 'llama', 'skunk', 'badger', 'orangutan', 157 | 'gorilla', 'chimpanzee', 'gibbon', 'baboon', 'panda', 'eel', 'clown_fish', 'puffer_fish', 158 | 'accordion', 'ambulance', 'assault_rifle', 'backpack', 'barn', 'wheelbarrow', 'basketball', 'bathtub', 159 | 'lighthouse', 'beer_glass', 'binoculars', 'birdhouse', 'bow_tie', 'broom', 'bucket', 'cauldron', 'candle', 160 | 'cannon', 'canoe', 'carousel', 'castle', 'mobile_phone', 'cowboy_hat', 'electric_guitar', 'fire_engine', 'flute', 161 | 'gasmask', 'grand_piano', 'guillotine', 'hammer', 'harmonica', 'harp', 'hatchet', 'jeep', 'joystick', 'lab_coat', 162 | 'lawn_mower', 'lipstick', 'mailbox', 'missile', 'mitten', 'parachute', 'pickup_truck', 'pirate_ship', 'revolver', 163 | 'rugby_ball', 'sandal', 'saxophone', 'school_bus', 'schooner', 'shield', 'soccer_ball', 'space_shuttle', 'spider_web', 164 | 'steam_locomotive', 'scarf', 'submarine', 'tank', 'tennis_ball', 'tractor', 'trombone', 'vase', 'violin', 165 | 'military_aircraft', 'wine_bottle', 'ice_cream', 'bagel', 'pretzel', 'cheeseburger', 'hotdog', 'cabbage', 'broccoli', 166 | 'cucumber', 'bell_pepper', 'mushroom', 'Granny_Smith', 'strawberry', 'lemon', 'pineapple', 'banana', 'pomegranate', 167 | 'pizza', 'burrito', 'espresso', 'volcano', 'baseball_player', 'scuba_diver', 'acorn'] 168 | 169 | 170 | 171 | def __init__(self, root, transform=None,train=True): 172 | split = 'train' if train else 'test' 173 | self.split = split 174 | self.classes = self.new_classes 175 | 
self.root = root 176 | self.datadir = os.path.join(root, f'imagenet-r/{split}') 177 | self.transform = transform 178 | # self.prepare_files_() 179 | self._load_meta() 180 | 181 | def prepare_files_(self): 182 | metadata_path = os.path.join(self.root, "imagenet-r/README.txt") 183 | self.data, self.targets = [], [] 184 | with open(metadata_path) as f: 185 | lines = [line for line in f.readlines()[13:]] 186 | f.close() 187 | dir_names = ["".join(line.split(" ")[0]) for line in lines] 188 | relative_path_and_labels = [] 189 | for class_id, dir in enumerate(dir_names): 190 | for dirpath, dnames, fnames in os.walk(f"{self.datadir}/{dir}"): 191 | for fname in fnames: 192 | relative_fpath = f"{self.split}/{dir}/{fname}" 193 | to_write = f"{relative_fpath} {class_id}\n" 194 | relative_path_and_labels.append(to_write) 195 | with open(f"imagenet_split/imagenetr_{self.split}.txt", "w") as f: 196 | f.writelines(relative_path_and_labels) 197 | 198 | 199 | def _load_meta(self): 200 | metadata_path = f"imagenet_split/imagenetr_{self.split}.txt" 201 | self.data, self.targets = [], [] 202 | with open(metadata_path) as f: 203 | for line in f: 204 | path, target = line.strip().split(" ") 205 | self.data.append(os.path.join(self.root, f'imagenet-r/{path}')) 206 | self.targets.append(int(target)) 207 | self.data = np.array(self.data) 208 | 209 | 210 | def __getitem__(self, index): 211 | img, target = self.data[index], self.targets[index] 212 | img = Image.open(img).convert("RGB") 213 | 214 | if self.transform is not None: 215 | img = self.transform(img) 216 | 217 | return img, target, int(index) 218 | 219 | def __len__(self): 220 | return len(self.data) 221 | 222 | def prompts(self, mode='single'): 223 | if mode == 'single': 224 | prompts = [[self.templates[0].format(label)] for label in self.new_classes] 225 | return prompts 226 | elif mode == 'ensemble': 227 | prompts = [[template.format(label) for template in self.templates] for label in self.new_classes] 228 | return prompts 229 | 230 | def get_labels(self): 231 | return np.array(self.targets) 232 | 233 | def get_classes(self): 234 | return self.new_classes 235 | 236 | # first download the imagenet-r.tar file 237 | # then uncomment this to prepare the train and test datasets 238 | # prepare_imagenet_r() -------------------------------------------------------------------------------- /classifier/evaluator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | 5 | from tqdm import tqdm 6 | from copy import deepcopy 7 | import numpy as np 8 | 9 | from clip.clip import load, tokenize 10 | from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer 11 | _tokenizer = _Tokenizer() 12 | import dataset.incremental_dataloader 13 | 14 | from .utils import build_cosine_scheduler, freeze_parameters 15 | import pdb 16 | import time 17 | from torchmetrics.classification import MulticlassCalibrationError 18 | from utils.display_results import get_measures, print_measures 19 | 20 | class Evaluator(): 21 | def __init__(self, args): 22 | self.args = args 23 | self.task_to_original_acc = {} 24 | self.time_step_to_acc = {} 25 | self.time_step_to_taw_acc = {} 26 | self.task_to_ood_metrics = {} 27 | self.time_step_to_future_task_acc = {} 28 | self.time_step_to_test_id_to_module_id = {} 29 | 30 | 31 | def flush_task_to_accs(self): 32 | self.task_to_original_acc = {} 33 | 34 | def compute_backward_transfer(self, curr_task_accs): 35 | bwt_scores = 
[] 36 | for t in range(len(curr_task_accs) -1): 37 | acc = curr_task_accs[t] 38 | bwt = acc - self.task_to_original_acc[t] 39 | bwt_scores.append(bwt) 40 | print(f"Average BWT score: {np.mean(bwt_scores)}") 41 | 42 | 43 | def mask_classes(self, outputs: torch.Tensor, k: int) -> None: 44 | """ 45 | Given the output tensor, the dataset at hand and the current task, 46 | masks the former by setting the responses for the other tasks at -inf. 47 | It is used to obtain the results for the task-il setting. 48 | :param outputs: the output tensor 49 | :param dataset: the continual dataset 50 | :param k: the task index 51 | """ 52 | outputs[:, 0:k * self.args.class_per_task] = -float('inf') 53 | outputs[:, (k + 1) * self.args.class_per_task: 54 | self.args.num_task * self.args.class_per_task] = -float('inf') 55 | 56 | @staticmethod 57 | def compute_confidence_score(preds, metric="energy"): 58 | conf = None 59 | if metric == "energy": 60 | if preds.dim() == 3: 61 | preds = preds.mean(1) 62 | conf = torch.logsumexp(preds, dim=-1) 63 | elif metric == "softmax": 64 | if preds.dim() == 3: 65 | preds = preds.mean(1) 66 | conf, _ = torch.max(preds.softmax(dim=-1), dim=1) 67 | elif metric == "variance": 68 | conf = -torch.var(preds, dim=1).sum(1) 69 | elif metric == "variance_softmax": 70 | conf = -torch.var(preds.softmax(dim=-1), dim=1).sum(1) 71 | elif metric == "variance_max_prob": 72 | preds, _ = torch.max(preds.softmax(dim=-1), dim=-1) 73 | conf = -torch.var(preds, dim=1)#.sum(1) 74 | else: 75 | raise NotImplementedError(f"Confidence metric: '{metric}' is not defined!") 76 | return conf 77 | 78 | def compute_ood_scores(self, id_preds, ood_test_loader, num_test=None, test_class=None): 79 | ood_preds = [] 80 | total_count, acc_count = 0, 0 81 | for i, (x, y, idx) in tqdm(enumerate(ood_test_loader), total=len(ood_test_loader), desc=f"Running OOD inference:"): 82 | pred_y_, _ = self.inference(x.cuda(device=self.args.default_gpu), y, num_test=num_test, test_class=test_class) 83 | if pred_y_.dim() == 3: 84 | pred_y_ = pred_y_.permute(1, 0, 2) 85 | ood_preds.append(pred_y_.clone().cpu()) 86 | pred_y = pred_y_.mean(0) if pred_y_.dim() == 3 else pred_y_ 87 | pred_y = pred_y.softmax(dim=-1) 88 | _, top_labels = pred_y.topk(1, dim=-1) 89 | acc_count += (top_labels.view(-1)==y.cuda(device=self.args.default_gpu)).sum().cpu().numpy() 90 | total_count += y.shape[0] 91 | ood_preds = torch.cat(ood_preds, 0) 92 | acc = acc_count*1.0/total_count 93 | acc = acc.item() 94 | self.time_step_to_future_task_acc[self.args.sess] = acc 95 | print(f"Future tasks avg acc: {np.mean(list(self.time_step_to_future_task_acc.values()))} {acc}") 96 | print(f"Total ID examples: {id_preds.shape[0]}, Total OOD examples: {ood_preds.shape[0]}") 97 | for metric in ["energy"]:#, "softmax", "variance", "variance_softmax", "variance_max_prob"]: 98 | if "variance" in metric and id_preds.dim() == 2: 99 | continue 100 | confidence_id = self.compute_confidence_score(id_preds, metric=metric) 101 | confidence_ood = self.compute_confidence_score(ood_preds, metric=metric) 102 | measures = get_measures(confidence_id, confidence_ood) 103 | print_measures(measures[0], measures[1], measures[2], metric) 104 | if metric == "softmax": 105 | self.task_to_ood_metrics[self.args.sess] = [measures[0], measures[1], measures[2]] 106 | if self.args.sess > 0: 107 | all_vals = list(self.task_to_ood_metrics.values()) 108 | means_ = np.mean(all_vals, 0) 109 | print(f"Average {metric}: FPR95: {means_[2]} || AUROC: {means_[0]} || AUPR: {means_[1]}") 110 | 111 | def 
map_class_id_to_module_id(self, class_id): 112 | module_id = torch.div(class_id, self.args.class_per_task, rounding_mode='trunc') 113 | return module_id 114 | 115 | @torch.no_grad() 116 | def _accuracy(self, loaders, num_test=None, test_class=None, only_eval=False, ood_test_loader=None): 117 | visual_feats, textual_feats, indices, labels = [],[], [], [] 118 | # pdb.set_trace() 119 | accs, accs_mask_classes = [], [] 120 | calibration_errors = [] 121 | task_to_module_accuracy = {} 122 | self.calibration_evaluator = MulticlassCalibrationError(num_classes=len(self.current_class_names)) if self.args.compute_ece else None 123 | id_preds = [] 124 | inference_times = [] 125 | if self.args.sess >= 0: 126 | # return 0 127 | # else: 128 | for k, loader in enumerate(loaders): 129 | total_count=0 130 | acc_count =0 131 | correct_mask_classes = 0 132 | task_calibration_errors=[] 133 | selected_module_ids = [] 134 | for i, (x, y, idx) in tqdm(enumerate(loader), total=len(loader), desc=f"Task {k} inference:"): 135 | start_time = time.time() 136 | pred_y_, feats = self.inference(x.cuda(device=self.args.default_gpu), y, num_test=num_test, test_class=test_class) 137 | inference_times.append(time.time() - start_time) 138 | 139 | pred_y = pred_y_.mean(0) if pred_y_.dim() == 3 else pred_y_ 140 | pred_y = pred_y.softmax(dim=-1) 141 | _, top_labels = pred_y.topk(1, dim=-1) 142 | acc_count += (top_labels.view(-1)==y.cuda(device=self.args.default_gpu)).sum().cpu().numpy() 143 | total_count += y.shape[0] 144 | if self.args.viz_module_selection: 145 | selected_module_id = self.map_class_id_to_module_id(top_labels) 146 | selected_module_ids.append(selected_module_id) 147 | if self.args.compute_ece: 148 | task_calibration_errors.append(self.calibration_evaluator(pred_y, y.cuda(device=self.args.default_gpu))) 149 | if self.args.compute_ram: 150 | visual_feats.append(feats[0]) 151 | textual_feats.append(feats[1]) 152 | indices.append(deepcopy(idx)) 153 | labels.append(deepcopy(y)) 154 | del idx 155 | del y 156 | 157 | if self.args.eval_ood_score and ood_test_loader is not None: 158 | if pred_y_.dim() == 3: 159 | pred_y_ = pred_y_.permute(1, 0, 2) 160 | id_preds.append(pred_y_.clone().cpu()) 161 | 162 | pred_y_ = pred_y_.mean(0) if pred_y_.dim() == 3 else pred_y_ 163 | self.mask_classes(pred_y_, k) 164 | _, taw_pred = pred_y_.topk(1, dim=-1) 165 | correct_mask_classes += (taw_pred.view(-1)==y.cuda(device=self.args.default_gpu)).sum().cpu().numpy() 166 | 167 | acc = acc_count*1.0/total_count 168 | acc = acc.item() 169 | accs.append(acc) 170 | 171 | acc_taw = correct_mask_classes*1.0/total_count 172 | acc_taw = acc_taw.item() 173 | accs_mask_classes.append(acc_taw) 174 | 175 | if not only_eval and k == len(loaders) - 1: 176 | self.task_to_original_acc[self.args.sess] = acc 177 | 178 | if self.args.compute_ece: 179 | calibration_errors.extend(task_calibration_errors) 180 | 181 | if self.args.viz_module_selection: 182 | selected_module_ids = torch.cat(selected_module_ids) 183 | module_ids, counts = torch.unique(selected_module_ids, return_counts=True) 184 | individual_task_allocations = {j: 0 for j in range(len(loaders))} 185 | for module_id, count in zip(module_ids, counts): 186 | individual_task_allocations[module_id.item()] += count.item() 187 | task_to_module_accuracy[k] = {task_label: count / total_count * 100. 
for task_label, count in list(individual_task_allocations.items())} 188 | 189 | print(f"Average inference time: {np.mean(inference_times)}") 190 | if self.args.viz_module_selection: 191 | self.time_step_to_test_id_to_module_id[self.args.sess] = task_to_module_accuracy 192 | print(self.time_step_to_test_id_to_module_id) 193 | 194 | if self.args.eval_ood_score and ood_test_loader is not None: 195 | self.compute_ood_scores(torch.cat(id_preds, 0), ood_test_loader, num_test=num_test, test_class=test_class) 196 | 197 | if self.args.compute_ram: 198 | visual_feats = torch.cat(visual_feats) 199 | textual_feats = torch.cat(textual_feats) 200 | indices = torch.cat(indices) 201 | labels = torch.cat(labels) 202 | self.args.ram_computer.compute_rotation_angle_matrix(self.args.sess, labels, visual_feats, textual_feats, indices) 203 | 204 | if self.args.compute_ece: 205 | print(f"Avg. Expected Calibration Error: {torch.stack(calibration_errors).mean()}") 206 | 207 | acc = np.mean(accs) 208 | self.time_step_to_acc[self.args.sess] = acc 209 | 210 | acc_taw = np.mean(accs_mask_classes) 211 | self.time_step_to_taw_acc[self.args.sess] = acc_taw 212 | 213 | print(f"Acc avg: {np.mean(list(self.time_step_to_acc.values()))}, Acc last: {acc}") 214 | print(f"TaW Acc avg: {np.mean(list(self.time_step_to_taw_acc.values()))}, TaW Acc last: {acc_taw}") 215 | 216 | if self.args.sess > 0 and self.args.compute_bwt: 217 | self.compute_backward_transfer(accs) 218 | 219 | return acc 220 | 221 | @torch.no_grad() 222 | def _accuracy_mpc(self, loader): 223 | n_class = self.n_class 224 | acc_per_class = [0 for _ in range(n_class)] 225 | count_per_class = [0 for _ in range(n_class)] 226 | visual_feats, textual_feats, indices, labels = [],[], [], [] 227 | for i, (x, y, idx) in tqdm(enumerate(loader), total=len(loader), desc = 'running inference'): 228 | pred_y, feats = self.inference(x.cuda(device=self.args.default_gpu), y) 229 | if self.args.compute_ram: 230 | visual_feats.append(feats[0]) 231 | textual_feats.append(feats[1]) 232 | indices.extend(idx) 233 | labels.extend(y) 234 | _, top_labels = pred_y.topk(1, dim=-1) 235 | for c in range(n_class): 236 | acc_per_class[c] += ((top_labels.view(-1) == y.cuda(device=self.args.default_gpu)) * (y.cuda(device=self.args.default_gpu)== c)).sum().item() 237 | count_per_class[c] += (y.cuda(device=self.args.default_gpu) == c).sum().item() 238 | acc = [a*1.0/c for (a, c) in zip(acc_per_class, count_per_class)] 239 | acc = np.array(acc).mean() 240 | 241 | if self.args.compute_ram: 242 | visual_feats = torch.cat(visual_feats) 243 | textual_feats = torch.cat(textual_feats) 244 | self.args.ram_computer.compute_rotation_angle_matrix(self.args.sess, labels, visual_feats, textual_feats, indices) 245 | return acc 246 | 247 | @torch.no_grad() 248 | def accuracy(self, loaders, num_test=None, test_class=None, mean_per_class=False, only_eval=False, ood_test_loader=None): 249 | if mean_per_class: 250 | return self._accuracy_mpc(loaders) 251 | else: 252 | return self._accuracy(loaders, num_test, test_class, only_eval=only_eval, ood_test_loader=ood_test_loader) 253 | 254 | 255 | def post_training(self, finalize=False): 256 | pass 257 | 258 | def finetuning(self, data=None): 259 | pass 260 | --------------------------------------------------------------------------------
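Note on the OOD scoring in classifier/evaluator.py: compute_confidence_score reduces the (optionally Monte-Carlo sampled) logits to a single scalar per example, and for the "energy" metric that scalar is a logsumexp over the class dimension; compute_ood_scores then passes the in-distribution and out-of-distribution confidences to get_measures to obtain AUROC/AUPR/FPR95. The snippet below is a minimal, self-contained sketch of that energy score for reference only; the helper name energy_confidence and the toy logits and expected values are illustrative assumptions, not code or numbers from this repository.

import torch

def energy_confidence(logits: torch.Tensor) -> torch.Tensor:
    # Sketch of the "energy" branch of Evaluator.compute_confidence_score:
    # average over Monte Carlo samples if logits are [batch, samples, classes],
    # then reduce the class dimension with logsumexp (higher = more in-distribution).
    if logits.dim() == 3:
        logits = logits.mean(1)
    return torch.logsumexp(logits, dim=-1)

# Toy check (illustrative values only): "ID" logits have one dominant class,
# "OOD" logits are nearly flat, so the ID energy score is higher on average.
torch.manual_seed(0)
id_logits = torch.randn(8, 5, 10)
id_logits[..., 0] += 6.0                  # one confidently predicted class per ID example
ood_logits = torch.randn(8, 5, 10) * 0.5  # near-uniform logits for OOD examples

print(energy_confidence(id_logits).mean().item())   # roughly 6
print(energy_confidence(ood_logits).mean().item())  # roughly log(10) ~= 2.3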