├── clip
│   ├── __init__.py
│   ├── bpe_simple_vocab_16e6.txt.gz
│   ├── simple_tokenizer.py
│   ├── clip.py
│   └── model.py
├── loralib
│   ├── __init__.py
│   ├── easymultiheadattention.py
│   └── utils.py
├── datasets
│   ├── caltech101.py
│   ├── food101.py
│   ├── __init__.py
│   ├── stanford_cars.py
│   ├── sun397.py
│   ├── eurosat.py
│   ├── ucf101.py
│   ├── sipakmed.py
│   ├── fgvc.py
│   ├── oxford_flowers.py
│   ├── dtd.py
│   ├── dataset_kaggle.py
│   ├── hicervix.py
│   ├── oxford_pets.py
│   ├── cu_us.py
│   ├── utils.py
│   └── imagenet.py
├── launch_run.sh
├── LICENSE
├── requirements.txt
├── utils.py
├── dataset_hicervix.py
├── features.py
├── run_utils.py
├── run.py
├── README.md
├── main.py
└── lora.py

--------------------------------------------------------------------------------
/clip/__init__.py:
--------------------------------------------------------------------------------
1 | from .clip import *
--------------------------------------------------------------------------------
/loralib/__init__.py:
--------------------------------------------------------------------------------
1 | from .layers import *
2 | from .utils import *
--------------------------------------------------------------------------------
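Note: clip/ is a vendored copy of OpenAI's CLIP (clip.py, model.py, simple_tokenizer.py and the BPE vocabulary below); the rest of the code reaches it through clip.tokenize and the model's encode_image/encode_text methods (see utils.py further down). A minimal loading sketch, assuming the vendored package keeps the standard OpenAI CLIP interface; the "ViT-B/16" backbone name and the prompt string are illustrative choices, not something pinned by this repository:

import torch
import clip  # the vendored package above

device = "cuda" if torch.cuda.is_available() else "cpu"
# clip.load returns the model plus its image preprocessing transform.
model, preprocess = clip.load("ViT-B/16", device=device)

# Encode a text prompt; images would go through preprocess() and model.encode_image() instead.
tokens = clip.tokenize(["a pap smear slide showing dyskeratotic cervical cells."]).to(device)
with torch.no_grad():
    text_features = model.encode_text(tokens)
    # L2-normalize, as utils.py does before building the zero-shot classifier.
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

--------------------------------------------------------------------------------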
/clip/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mdausort/Cytology-fine-tuning/HEAD/clip/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
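The dataset wrappers that follow (caltech101.py through imagenet.py) all share the same pattern: read or build a train/val/test split, cap the validation set at 4 shots per class, subsample the training set to num_shots examples per class, and hand the three lists of Datum objects to DatasetBase. A hedged usage sketch based on the build_dataset dispatcher defined later in datasets/__init__.py; the data root is a placeholder, and the train_x/val/test/classnames attribute names are assumed to mirror the keyword arguments passed to DatasetBase:

from datasets import build_dataset

# The root directory is expected to contain the per-dataset folders (Caltech101/, sipakmed/, MLCC/, ...).
dataset = build_dataset("caltech101", "/path/to/data", 16)  # 16 shots per class

train_items = dataset.train_x      # few-shot training split (16 Datum objects per class)
val_items = dataset.val            # validation split, capped at 4 shots per class
test_items = dataset.test          # full test split
prompts = dataset.template         # e.g. ['a photo of a {}.'], used to build text prompts
class_names = dataset.classnames   # class names consumed by clip_classifier / textual_extractor

--------------------------------------------------------------------------------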
/datasets/caltech101.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .oxford_pets import OxfordPets 3 | from .utils import Datum, DatasetBase 4 | 5 | 6 | template = ['a photo of a {}.'] 7 | 8 | 9 | class Caltech101(DatasetBase): 10 | 11 | dataset_dir = 'Caltech101' 12 | 13 | def __init__(self, root, num_shots): 14 | self.dataset_dir = os.path.join(root, self.dataset_dir) 15 | self.image_dir = os.path.join(self.dataset_dir, '101_ObjectCategories') 16 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_Caltech101.json') 17 | 18 | self.template = template 19 | 20 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 21 | n_shots_val = min(num_shots, 4) 22 | val = self.generate_fewshot_dataset(val, num_shots=n_shots_val) 23 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 24 | 25 | super().__init__(train_x=train, val=val, test=test) 26 | -------------------------------------------------------------------------------- /datasets/food101.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader 4 | from .oxford_pets import OxfordPets 5 | 6 | """ 7 | template = ['a photo of {}, a type of food.'] 8 | """ 9 | template = ['a photo of a {}.'] 10 | 11 | class Food101(DatasetBase): 12 | 13 | dataset_dir = 'Food101' 14 | 15 | def __init__(self, root, num_shots): 16 | self.dataset_dir = os.path.join(root, self.dataset_dir) 17 | self.image_dir = os.path.join(self.dataset_dir, 'images') 18 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_Food101.json') 19 | 20 | self.template = template 21 | 22 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 23 | n_shots_val = min(num_shots, 4) 24 | val = self.generate_fewshot_dataset(val, num_shots=n_shots_val) 25 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 26 | 27 | super().__init__(train_x=train, val=val, test=test) -------------------------------------------------------------------------------- /launch_run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | #SBATCH --job-name=cyto_finetune 4 | # 5 | #SBATCH --cpus-per-task=32 6 | #SBATCH --ntasks=1 7 | # 8 | #SBATCH --partition=gpu 9 | #SBATCH --gres=gpu:2 10 | 11 | source PATH_TO_CHANGE/bin/activate 12 | 13 | cd PATH_TO_CHANGE 14 | 15 | # Experiment 1: Linear Classifier 16 | # python3 run.py --seed_launch "$seed" --shots_launch -1 --lr_launch "$lr" --model_launch "$model" --dataset_launch kaggle1 --task_launch classifier 17 | 18 | # Experiment 2: LoRA Few-Shot Adaptation 19 | # python3 run.py --seed_launch "$seed" --shots_launch "$shot" --lr_launch "$lr" --iterations 100 --model_launch "$model" --dataset_launch kaggle1 --task_launch lora 20 | # python3 run.py --seed_launch "$seed" --shots_launch "$shot" --lr_launch "$lr" --iterations 100 --model_launch "$model" --dataset_launch hicervix --task_launch lora --level_launch "level_3" 21 | 22 | # Experiment 3: Pushing Model Fine-Tuning Limits 23 | # python3 run.py --seed_launch "$seed" --shots_launch "$shot" --lr_launch "$lr" --iterations 100 --model_launch "$model" --dataset_launch hicervix --level_launch "level_3" --percent_launch 10 --task_launch percentage_lora 24 | -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Manon Dausort 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .ucf101 import UCF101 2 | from .sun397 import SUN397 3 | from .eurosat import EuroSAT 4 | from .food101 import Food101 5 | from .imagenet import ImageNet 6 | from .sipakmed import SipakMed 7 | from .fgvc import FGVCAircraft 8 | from .hicervix import HiCervix 9 | from .caltech101 import Caltech101 10 | from .oxford_pets import OxfordPets 11 | from .dtd import DescribableTextures 12 | from .stanford_cars import StanfordCars 13 | from .oxford_flowers import OxfordFlowers 14 | from .dataset_kaggle import MLCC, BCFC 15 | 16 | 17 | dataset_list = { 18 | "oxford_pets": OxfordPets, 19 | "eurosat": EuroSAT, 20 | "ucf101": UCF101, 21 | "sun397": SUN397, 22 | "caltech101": Caltech101, 23 | "dtd": DescribableTextures, 24 | "fgvc": FGVCAircraft, 25 | "food101": Food101, 26 | "oxford_flowers": OxfordFlowers, 27 | "stanford_cars": StanfordCars, 28 | "imagenet": ImageNet, 29 | "sipakmed": SipakMed, 30 | "mlcc": MLCC, 31 | "bcfc": BCFC, 32 | "hicervix": HiCervix 33 | } 34 | 35 | 36 | def build_dataset( 37 | dataset, root_path, shots, level="level_1", pourcentage=0.0, preprocess=None 38 | ): 39 | if dataset == "imagenet": 40 | return dataset_list[dataset](root_path, shots, preprocess) 41 | elif dataset == "hicervix": 42 | return dataset_list[dataset](root_path, shots, level, pourcentage) 43 | else: 44 | return dataset_list[dataset](root_path, shots) 45 | -------------------------------------------------------------------------------- /datasets/stanford_cars.py: -------------------------------------------------------------------------------- 1 | import os 2 | from scipy.io import loadmat 3 | 4 | from .oxford_pets import OxfordPets 5 | from .utils import Datum, DatasetBase 6 | 7 | 8 | template = ['a photo of a {}.'] 9 | 10 | 11 | class StanfordCars(DatasetBase): 12 | 13 | dataset_dir = 'StanfordCars' 14 | 15 | def __init__(self, root, num_shots): 16 | self.dataset_dir = os.path.join(root, self.dataset_dir) 17 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_StanfordCars.json') 18 | 19 | self.template = template 20 | 21 | train, val, test = OxfordPets.read_split(self.split_path, 
self.dataset_dir) 22 | n_shots_val = min(num_shots, 4) 23 | val = self.generate_fewshot_dataset(val, num_shots=n_shots_val) 24 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 25 | 26 | super().__init__(train_x=train, val=val, test=test) 27 | 28 | def read_data(self, image_dir, anno_file, meta_file): 29 | anno_file = loadmat(anno_file)['annotations'][0] 30 | meta_file = loadmat(meta_file)['class_names'][0] 31 | items = [] 32 | 33 | for i in range(len(anno_file)): 34 | imname = anno_file[i]['fname'][0] 35 | impath = os.path.join(self.dataset_dir, image_dir, imname) 36 | label = anno_file[i]['class'][0, 0] 37 | label = int(label) - 1 # convert to 0-based index 38 | classname = meta_file[label][0] 39 | names = classname.split(' ') 40 | year = names.pop(-1) 41 | names.insert(0, year) 42 | classname = ' '.join(names) 43 | item = Datum( 44 | impath=impath, 45 | label=label, 46 | classname=classname 47 | ) 48 | items.append(item) 49 | 50 | return items -------------------------------------------------------------------------------- /datasets/sun397.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader 4 | 5 | from .oxford_pets import OxfordPets 6 | 7 | 8 | template = ['a photo of a {}.'] 9 | 10 | 11 | class SUN397(DatasetBase): 12 | 13 | dataset_dir = 'SUN397' 14 | 15 | def __init__(self, root, num_shots): 16 | self.dataset_dir = os.path.join(root, self.dataset_dir) 17 | self.image_dir = os.path.join(self.dataset_dir, 'SUN397') 18 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_SUN397.json') 19 | 20 | self.template = template 21 | 22 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 23 | n_shots_val = min(num_shots, 4) 24 | val = self.generate_fewshot_dataset(val, num_shots=n_shots_val) 25 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 26 | 27 | super().__init__(train_x=train, val=val, test=test) 28 | 29 | def read_data(self, cname2lab, text_file): 30 | text_file = os.path.join(self.dataset_dir, text_file) 31 | items = [] 32 | 33 | with open(text_file, 'r') as f: 34 | lines = f.readlines() 35 | for line in lines: 36 | imname = line.strip()[1:] # remove / 37 | classname = os.path.dirname(imname) 38 | label = cname2lab[classname] 39 | impath = os.path.join(self.image_dir, imname) 40 | 41 | names = classname.split('/')[1:] # remove 1st letter 42 | names = names[::-1] # put words like indoor/outdoor at first 43 | classname = ' '.join(names) 44 | 45 | item = Datum( 46 | impath=impath, 47 | label=label, 48 | classname=classname 49 | ) 50 | items.append(item) 51 | 52 | return items 53 | -------------------------------------------------------------------------------- /datasets/eurosat.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader 4 | from .oxford_pets import OxfordPets 5 | 6 | """ 7 | template = ['a centered satellite photo of {}.'] 8 | """ 9 | template = ['a photo of a {}.'] 10 | 11 | NEW_CNAMES = { 12 | 'AnnualCrop': 'Annual Crop Land', 13 | 'Forest': 'Forest', 14 | 'HerbaceousVegetation': 'Herbaceous Vegetation Land', 15 | 'Highway': 'Highway or Road', 16 | 'Industrial': 'Industrial Buildings', 17 | 'Pasture': 'Pasture Land', 18 | 'PermanentCrop': 'Permanent Crop Land', 19 | 'Residential': 'Residential Buildings', 20 | 'River': 'River', 21 | 'SeaLake': 
'Sea or Lake'
22 | }
23 | 
24 | 
25 | class EuroSAT(DatasetBase):
26 | 
27 |     dataset_dir = 'eurosat'
28 | 
29 |     def __init__(self, root, num_shots):
30 |         self.dataset_dir = os.path.join(root, self.dataset_dir)
31 |         self.image_dir = os.path.join(self.dataset_dir, '2750')
32 |         self.split_path = os.path.join(self.dataset_dir, 'split_zhou_EuroSAT.json')
33 | 
34 |         self.template = template
35 | 
36 |         train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
37 |         n_shots_val = min(num_shots, 4)
38 |         val = self.generate_fewshot_dataset(val, num_shots=n_shots_val)
39 |         train = self.generate_fewshot_dataset(train, num_shots=num_shots)
40 | 
41 |         super().__init__(train_x=train, val=val, test=test)
42 | 
43 |     def update_classname(self, dataset_old):
44 |         dataset_new = []
45 |         for item_old in dataset_old:
46 |             cname_old = item_old.classname
47 |             cname_new = NEW_CNAMES[cname_old]  # NEW_CNAMES is the mapping defined at the top of this file
48 |             item_new = Datum(
49 |                 impath=item_old.impath,
50 |                 label=item_old.label,
51 |                 classname=cname_new
52 |             )
53 |             dataset_new.append(item_new)
54 |         return dataset_new
55 | 
--------------------------------------------------------------------------------
/datasets/ucf101.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re  # needed for re.findall in read_data
3 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader
4 | 
5 | from .oxford_pets import OxfordPets
6 | 
7 | """
8 | template = ['a photo of a person doing {}.']
9 | """
10 | template = ['a photo of a {}.']
11 | 
12 | class UCF101(DatasetBase):
13 | 
14 |     dataset_dir = 'UCF101'
15 | 
16 |     def __init__(self, root, num_shots):
17 |         self.dataset_dir = os.path.join(root, self.dataset_dir)
18 |         self.image_dir = os.path.join(self.dataset_dir, 'UCF-101-midframes')
19 |         self.split_path = os.path.join(self.dataset_dir, 'split_zhou_UCF101.json')
20 | 
21 |         self.template = template
22 | 
23 |         train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
24 |         n_shots_val = min(num_shots, 4)
25 |         val = self.generate_fewshot_dataset(val, num_shots=n_shots_val)
26 |         train = self.generate_fewshot_dataset(train, num_shots=num_shots)
27 | 
28 |         super().__init__(train_x=train, val=val, test=test)
29 | 
30 |     def read_data(self, cname2lab, text_file):
31 |         text_file = os.path.join(self.dataset_dir, text_file)
32 |         items = []
33 | 
34 |         with open(text_file, 'r') as f:
35 |             lines = f.readlines()
36 |             for line in lines:
37 |                 line = line.strip().split(' ')[0]  # trainlist: filename, label
38 |                 action, filename = line.split('/')
39 |                 label = cname2lab[action]
40 | 
41 |                 elements = re.findall('[A-Z][^A-Z]*', action)
42 |                 renamed_action = '_'.join(elements)
43 | 
44 |                 filename = filename.replace('.avi', '.jpg')
45 |                 impath = os.path.join(self.image_dir, renamed_action, filename)
46 | 
47 |                 item = Datum(
48 |                     impath=impath,
49 |                     label=label,
50 |                     classname=renamed_action
51 |                 )
52 |                 items.append(item)
53 | 
54 |         return items
55 | 
--------------------------------------------------------------------------------
/datasets/sipakmed.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import glob
3 | import os
4 | 
5 | from .utils import Datum, DatasetBase
6 | 
7 | template = ["A pap smear slide showing a {} cervical cells."]
8 | 
9 | 
10 | class SipakMed(DatasetBase):
11 | 
12 |     dataset_dir = "sipakmed"
13 |     classes = [
14 |         "Dyskeratotic",
15 |         "Koilocytotic",
16 |         "Metaplastic",
17 |         "Parabasal",
18 |         "Superficial-Intermediate",
19 |     ]
20 | 
21 |     def __init__(self, root, num_shots):
22 |         self.dataset_dir = 
os.path.join(root, self.dataset_dir) 23 | self.image_dir = os.path.join(self.dataset_dir, "images") 24 | 25 | self.template = template 26 | 27 | train = self.create_list_of_datum("train") 28 | val = self.create_list_of_datum("val") 29 | test = self.create_list_of_datum("test") 30 | 31 | n_shots_val = min(num_shots, 4) 32 | val = self.generate_fewshot_dataset(val, num_shots=n_shots_val) 33 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 34 | 35 | super().__init__(train_x=train, val=val, test=test) 36 | 37 | def __getitem__(self, im_files, idx): 38 | image = cv2.imread(im_files[idx]) 39 | class_name = im_files[idx].split("/")[-1].split("_")[0] 40 | class_ = self.classes.index(class_name) 41 | 42 | return image, class_, class_name, im_files[idx] 43 | 44 | def create_list_of_datum(self, set): 45 | """Create a list of Datum objects, each containing the image and label.""" 46 | datum_list = [] 47 | 48 | im_files = glob.glob(os.path.join(self.image_dir, set, "*.bmp")) 49 | for i in range(len(im_files)): 50 | # Get the image and the class 51 | image, class_, class_name, impath = self.__getitem__(im_files, i) 52 | 53 | # Create a Datum object 54 | datum = Datum(impath=impath, label=class_, classname=class_name) 55 | 56 | # Append the datum to the list 57 | datum_list.append(datum) 58 | 59 | return datum_list 60 | -------------------------------------------------------------------------------- /datasets/fgvc.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader 4 | 5 | """ 6 | template = ['a photo of a {}, a type of aircraft.'] 7 | """ 8 | template = ['a photo of a {}.'] 9 | class FGVCAircraft(DatasetBase): 10 | 11 | dataset_dir = 'fgvc_aircraft' 12 | 13 | def __init__(self, root, num_shots): 14 | 15 | self.dataset_dir = os.path.join(root, self.dataset_dir) 16 | self.image_dir = os.path.join(self.dataset_dir, 'images') 17 | 18 | self.template = template 19 | 20 | classnames = [] 21 | with open(os.path.join(self.dataset_dir, 'variants.txt'), 'r') as f: 22 | lines = f.readlines() 23 | for line in lines: 24 | classnames.append(line.strip()) 25 | cname2lab = {c: i for i, c in enumerate(classnames)} 26 | 27 | train = self.read_data(cname2lab, 'images_variant_train.txt') 28 | val = self.read_data(cname2lab, 'images_variant_val.txt') 29 | test = self.read_data(cname2lab, 'images_variant_test.txt') 30 | 31 | n_shots_val = min(num_shots, 4) 32 | val = self.generate_fewshot_dataset(val, num_shots=n_shots_val) 33 | 34 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 35 | 36 | super().__init__(train_x=train, val=val, test=test) 37 | 38 | def read_data(self, cname2lab, split_file): 39 | filepath = os.path.join(self.dataset_dir, split_file) 40 | items = [] 41 | 42 | with open(filepath, 'r') as f: 43 | lines = f.readlines() 44 | for line in lines: 45 | line = line.strip().split(' ') 46 | imname = line[0] + '.jpg' 47 | classname = ' '.join(line[1:]) 48 | impath = os.path.join(self.image_dir, imname) 49 | label = cname2lab[classname] 50 | item = Datum( 51 | impath=impath, 52 | label=label, 53 | classname=classname 54 | ) 55 | items.append(item) 56 | 57 | return items -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | anyio==3.6.1 2 | argon2-cffi==21.3.0 3 | argon2-cffi-bindings==21.2.0 4 | asttokens==2.0.5 5 | 
attrs==21.4.0 6 | Automat==20.2.0 7 | Babel==2.10.3 8 | backcall==0.2.0 9 | beautifulsoup4==4.11.1 10 | bleach==5.0.0 11 | certifi==2022.6.15 12 | cffi==1.15.0 13 | charset-normalizer==2.0.12 14 | constantly==15.1.0 15 | cycler==0.11.0 16 | debugpy==1.6.0 17 | decorator==5.1.1 18 | defusedxml==0.7.1 19 | easybuild==4.6.0 20 | easybuild-easyblocks==4.6.0 21 | easybuild-easyconfigs==4.6.0 22 | easybuild-framework==4.6.0 23 | entrypoints==0.4 24 | executing==0.8.3 25 | fastjsonschema==2.15.3 26 | fonttools==4.33.3 27 | h5py==3.10.0 28 | hyperlink==21.0.0 29 | idna==3.3 30 | importlib-metadata==4.12.0 31 | incremental==21.3.0 32 | ipykernel==6.15.0 33 | ipython==8.4.0 34 | ipython-genutils==0.2.0 35 | jedi==0.18.1 36 | Jinja2==3.1.2 37 | json5==0.9.8 38 | jsonschema==4.6.0 39 | jupyter-client==7.3.4 40 | jupyter-core==4.10.0 41 | jupyter-server==1.18.0 42 | jupyterlab==3.4.3 43 | jupyterlab-pygments==0.2.2 44 | jupyterlab-server==2.14.0 45 | kiwisolver==1.4.3 46 | MarkupSafe==2.1.1 47 | matplotlib==3.5.2 48 | matplotlib-inline==0.1.3 49 | mistune==0.8.4 50 | mysql-connector-python==8.0.29 51 | mysqlclient==2.1.1 52 | nbclassic==0.3.7 53 | nbclient==0.6.4 54 | nbconvert==6.5.0 55 | nbformat==5.4.0 56 | nest-asyncio==1.5.5 57 | notebook==6.4.12 58 | notebook-shim==0.1.0 59 | numpy==1.23.0 60 | opencv-python==4.10.0.84 61 | packaging==21.3 62 | pandas==2.2.3 63 | pandocfilters==1.5.0 64 | parso==0.8.3 65 | pexpect==4.8.0 66 | pickleshare==0.7.5 67 | Pillow==9.1.1 68 | prometheus-client==0.14.1 69 | prompt-toolkit==3.0.29 70 | protobuf==4.21.3 71 | psutil==5.9.1 72 | ptyprocess==0.7.0 73 | pure-eval==0.2.2 74 | pycparser==2.21 75 | Pygments==2.12.0 76 | pyparsing==3.0.9 77 | pyrsistent==0.18.1 78 | python-dateutil==2.8.2 79 | pytz==2022.1 80 | pyzmq==23.2.0 81 | requests==2.28.0 82 | scipy==1.8.1 83 | Send2Trash==1.8.0 84 | six==1.16.0 85 | sniffio==1.2.0 86 | soupsieve==2.3.2.post1 87 | stack-data==0.3.0 88 | terminado==0.15.0 89 | tinycss2==1.1.1 90 | tornado==6.1 91 | traitlets==5.3.0 92 | Twisted==22.4.0 93 | typing_extensions==4.2.0 94 | tzdata==2024.2 95 | urllib3==1.26.9 96 | wcwidth==0.2.5 97 | webencodings==0.5.1 98 | websocket-client==1.3.3 99 | zipp==3.8.0 100 | zope.interface==5.4.0 101 | -------------------------------------------------------------------------------- /datasets/oxford_flowers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from scipy.io import loadmat 4 | from collections import defaultdict 5 | 6 | from .oxford_pets import OxfordPets 7 | from .utils import Datum, DatasetBase, read_json 8 | 9 | """ 10 | template = ['a photo of a {}, a type of flower.'] 11 | """ 12 | template = ['a photo of a {}.'] 13 | 14 | class OxfordFlowers(DatasetBase): 15 | 16 | dataset_dir = 'Flower102' 17 | 18 | def __init__(self, root, num_shots): 19 | self.dataset_dir = os.path.join(root, self.dataset_dir) 20 | self.image_dir = os.path.join(self.dataset_dir, 'jpg') 21 | self.label_file = os.path.join(self.dataset_dir, 'imagelabels.mat') 22 | self.lab2cname_file = os.path.join(self.dataset_dir, 'cat_to_name.json') 23 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_OxfordFlowers.json') 24 | 25 | self.template = template 26 | 27 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 28 | n_shots_val = min(num_shots, 4) 29 | val = self.generate_fewshot_dataset(val, num_shots=n_shots_val) 30 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 31 | 32 | 
super().__init__(train_x=train, val=val, test=test) 33 | 34 | def read_data(self): 35 | tracker = defaultdict(list) 36 | label_file = loadmat(self.label_file)['labels'][0] 37 | for i, label in enumerate(label_file): 38 | imname = f'image_{str(i + 1).zfill(5)}.jpg' 39 | impath = os.path.join(self.image_dir, imname) 40 | label = int(label) 41 | tracker[label].append(impath) 42 | 43 | print('Splitting data into 50% train, 20% val, and 30% test') 44 | 45 | def _collate(ims, y, c): 46 | items = [] 47 | for im in ims: 48 | item = Datum( 49 | impath=im, 50 | label=y-1, # convert to 0-based label 51 | classname=c 52 | ) 53 | items.append(item) 54 | return items 55 | 56 | lab2cname = read_json(self.lab2cname_file) 57 | train, val, test = [], [], [] 58 | for label, impaths in tracker.items(): 59 | random.shuffle(impaths) 60 | n_total = len(impaths) 61 | n_train = round(n_total * 0.5) 62 | n_val = round(n_total * 0.2) 63 | n_test = n_total - n_train - n_val 64 | assert n_train > 0 and n_val > 0 and n_test > 0 65 | cname = lab2cname[str(label)] 66 | train.extend(_collate(impaths[:n_train], label, cname)) 67 | val.extend(_collate(impaths[n_train:n_train+n_val], label, cname)) 68 | test.extend(_collate(impaths[n_train+n_val:], label, cname)) 69 | 70 | return train, val, test -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import clip 2 | import torch 3 | from tqdm import tqdm 4 | 5 | 6 | def get_function(model_name, clip_model, tokenizer=None): 7 | MODEL_NAME = {} 8 | if model_name == "clip": 9 | MODEL_NAME = { 10 | "vision": clip_model.encode_image, 11 | "text": clip_model.encode_text, 12 | "token": clip.tokenize, 13 | } 14 | 15 | elif model_name == "quilt": 16 | MODEL_NAME = { 17 | "vision": clip_model.encode_image, 18 | "text": clip_model.encode_text, 19 | "token": clip.tokenize, 20 | } 21 | 22 | elif model_name == "biomedclip": 23 | MODEL_NAME = { 24 | "vision": clip_model.visual, 25 | "text": clip_model.text, 26 | "token": tokenizer, 27 | } 28 | 29 | elif model_name == "uni": 30 | 31 | MODEL_NAME = {"vision": clip_model, "text": None, "token": None} 32 | 33 | elif model_name == "vit_google": 34 | 35 | MODEL_NAME = {"vision": clip_model, "text": None, "token": None} 36 | 37 | return MODEL_NAME["vision"], MODEL_NAME["text"], MODEL_NAME["token"] 38 | 39 | 40 | def cls_acc(output, target, topk=1): 41 | pred = output.topk(topk, 1, True, True)[1].t() 42 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 43 | acc = float(correct[:topk].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) 44 | acc = 100 * acc / target.shape[0] 45 | 46 | return acc 47 | 48 | 49 | def clip_classifier(classnames, template, clip_model, model_name, tokenizer): 50 | 51 | vision, text, token = get_function(model_name, clip_model, tokenizer) 52 | 53 | with torch.no_grad(): 54 | 55 | clip_weights = [] 56 | 57 | for classname in classnames: 58 | # Tokenize the prompts 59 | classname = classname.replace("_", " ") 60 | texts = [t.format(classname) for t in template] 61 | texts = token(texts).cuda() 62 | class_embeddings = text(texts) 63 | class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True) 64 | class_embedding = class_embeddings.mean(dim=0) 65 | class_embedding /= class_embedding.norm() 66 | clip_weights.append(class_embedding) 67 | 68 | clip_weights = torch.stack(clip_weights, dim=1).cuda() 69 | 70 | return clip_weights 71 | 72 | 73 | def pre_load_features(clip_model, loader, model_name, 
tokenizer): 74 | 75 | vision, text, token = get_function(model_name, clip_model, tokenizer) 76 | features, labels = [], [] 77 | 78 | with torch.no_grad(): 79 | for i, (images, target) in enumerate(tqdm(loader)): 80 | images, target = images.cuda(), target.cuda() 81 | image_features = vision(images) 82 | image_features /= image_features.norm(dim=-1, keepdim=True) 83 | features.append(image_features.cpu()) 84 | labels.append(target.cpu()) 85 | 86 | features, labels = torch.cat(features), torch.cat(labels) 87 | 88 | return features, labels 89 | -------------------------------------------------------------------------------- /datasets/dtd.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | from .utils import Datum, DatasetBase, listdir_nohidden 5 | from .oxford_pets import OxfordPets 6 | 7 | """ 8 | template = ['{} texture.'] 9 | """ 10 | template = ['a photo of a {}.'] 11 | 12 | class DescribableTextures(DatasetBase): 13 | 14 | dataset_dir = 'DTD' 15 | 16 | def __init__(self, root, num_shots): 17 | self.dataset_dir = os.path.join(root, self.dataset_dir) 18 | self.image_dir = os.path.join(self.dataset_dir, 'images') 19 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_DescribableTextures.json') 20 | 21 | self.template = template 22 | 23 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 24 | n_shots_val = min(num_shots, 4) 25 | val = self.generate_fewshot_dataset(val, num_shots=n_shots_val) 26 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 27 | 28 | super().__init__(train_x=train, val=val, test=test) 29 | 30 | @staticmethod 31 | def read_and_split_data( 32 | image_dir, 33 | p_trn=0.5, 34 | p_val=0.2, 35 | ignored=[], 36 | new_cnames=None 37 | ): 38 | # The data are supposed to be organized into the following structure 39 | # ============= 40 | # images/ 41 | # dog/ 42 | # cat/ 43 | # horse/ 44 | # ============= 45 | categories = listdir_nohidden(image_dir) 46 | categories = [c for c in categories if c not in ignored] 47 | categories.sort() 48 | 49 | p_tst = 1 - p_trn - p_val 50 | print(f'Splitting into {p_trn:.0%} train, {p_val:.0%} val, and {p_tst:.0%} test') 51 | 52 | def _collate(ims, y, c): 53 | items = [] 54 | for im in ims: 55 | item = Datum( 56 | impath=im, 57 | label=y, # is already 0-based 58 | classname=c 59 | ) 60 | items.append(item) 61 | return items 62 | 63 | train, val, test = [], [], [] 64 | for label, category in enumerate(categories): 65 | category_dir = os.path.join(image_dir, category) 66 | images = listdir_nohidden(category_dir) 67 | images = [os.path.join(category_dir, im) for im in images] 68 | random.shuffle(images) 69 | n_total = len(images) 70 | n_train = round(n_total * p_trn) 71 | n_val = round(n_total * p_val) 72 | n_test = n_total - n_train - n_val 73 | assert n_train > 0 and n_val > 0 and n_test > 0 74 | 75 | if new_cnames is not None and category in new_cnames: 76 | category = new_cnames[category] 77 | 78 | train.extend(_collate(images[:n_train], label, category)) 79 | val.extend(_collate(images[n_train:n_train+n_val], label, category)) 80 | test.extend(_collate(images[n_train+n_val:], label, category)) 81 | 82 | return train, val, test 83 | -------------------------------------------------------------------------------- /dataset_hicervix.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import numpy as np 5 | import pandas as pd 6 | from 
datasets import build_dataset 7 | 8 | 9 | if __name__ == "__main__": 10 | 11 | # Parser creation 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--seed_launch", type=int, help="Seed number") 14 | parser.add_argument("--shots_launch", type=int, help="Shot number") 15 | parser.add_argument( 16 | "--level_launch", 17 | default="level_1", 18 | type=str, 19 | help="This is the level of the hierarchical tree to capture different fine-grained subtype information. Only applicable in the case of hicervix.", 20 | ) 21 | parser.add_argument( 22 | "--percent_launch", 23 | type=float, 24 | help="Percentage of the dataset considered. Used for the third experiment.", 25 | ) 26 | 27 | args = parser.parse_args() 28 | 29 | # Variables 30 | seed = args.seed_launch 31 | shot = args.shots_launch 32 | dataset = "hicervix" 33 | level = args.level_launch 34 | percent = args.percent_launch 35 | 36 | if dataset == "hicervix": 37 | 38 | df = pd.read_csv( 39 | "path_of_dataset/train.csv" 40 | ) # TO CHANGE 41 | 42 | if level == "level_3": 43 | class_list_2 = sorted(np.unique(df.loc[:, "level_2"].dropna().tolist())) 44 | class_list_3 = sorted(np.unique(df.loc[:, "level_3"].dropna().tolist())) 45 | 46 | combined_class_list = np.append(class_list_2, class_list_3) 47 | 48 | cleaned_class_list = pd.Series(combined_class_list).dropna() 49 | 50 | class_list = sorted(np.unique(cleaned_class_list).tolist()) 51 | 52 | elif level == "level_2": 53 | class_list_2 = sorted(np.unique(df.loc[:, "level_2"].dropna().tolist())) 54 | class_list_1 = sorted(np.unique(df.loc[:, "level_1"].dropna().tolist())) 55 | 56 | combined_class_list = np.append(class_list_2, class_list_1) 57 | 58 | cleaned_class_list = pd.Series(combined_class_list).dropna() 59 | 60 | class_list = sorted(np.unique(cleaned_class_list).tolist()) 61 | 62 | else: 63 | class_list = np.unique(df.loc[:, level].tolist()) 64 | 65 | num_classes = len(class_list) 66 | 67 | root_path = "path_of_dataset" # TO CHANGE 68 | 69 | else: 70 | raise RuntimeError("Wrong dataset") 71 | 72 | level_name = level.replace("_", "") 73 | 74 | if percent > 0: 75 | if not os.path.exists( 76 | f"./{dataset}_{seed}_{shot}_{level_name}_{percent}_percent.pt" 77 | ): 78 | dataset_all = build_dataset(dataset, root_path, shot, level, percent) 79 | torch.save( 80 | dataset_all, 81 | f"./{dataset}_{seed}_{shot}_{level_name}_{percent}_percent.pt", 82 | ) 83 | else: 84 | if not os.path.exists(f"./{dataset}_{seed}_{shot}_{level_name}.pt"): 85 | dataset_all = build_dataset(dataset, root_path, shot, level) 86 | torch.save(dataset_all, f"./{dataset}_{seed}_{shot}_{level_name}.pt") 87 | -------------------------------------------------------------------------------- /datasets/dataset_kaggle.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import glob 4 | 5 | from .utils import Datum, DatasetBase # type: ignore 6 | 7 | template = ["A pap smear slide showing a {} cervical cells."] 8 | 9 | 10 | class MLCC(DatasetBase): 11 | 12 | dataset_dir = "MLCC" 13 | classes = [ 14 | "HSIL", 15 | "LSIL", 16 | "NL", 17 | "SCC", 18 | ] 19 | 20 | def __init__(self, root, num_shots): 21 | self.dataset_dir = os.path.join(root, self.dataset_dir) 22 | self.image_dir = os.path.join(self.dataset_dir, "images") 23 | 24 | self.template = template 25 | 26 | train = self.create_list_of_datum("train") 27 | val = self.create_list_of_datum("val") 28 | test = self.create_list_of_datum("test") 29 | 30 | n_shots_val = min(num_shots, 4) 31 | val = 
self.generate_fewshot_dataset(val, num_shots=n_shots_val) 32 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 33 | 34 | super().__init__(train_x=train, val=val, test=test) 35 | 36 | def __getitem__(self, im_files, idx): 37 | image = cv2.imread(im_files[idx]) 38 | class_name = im_files[idx].split("/")[-1].split("_")[0] 39 | class_ = self.classes.index(class_name) 40 | 41 | return image, class_, class_name, im_files[idx] 42 | 43 | def create_list_of_datum(self, set): 44 | """Create a list of Datum objects, each containing the image and label.""" 45 | datum_list = [] 46 | 47 | im_files = glob.glob(os.path.join(self.image_dir, set, "*.jpg")) 48 | for i in range(len(im_files)): 49 | # Get the image and the class 50 | image, class_, class_name, impath = self.__getitem__(im_files, i) 51 | 52 | # Create a Datum object 53 | datum = Datum(impath=impath, label=class_, classname=class_name) 54 | 55 | # Append the datum to the list 56 | datum_list.append(datum) 57 | 58 | return datum_list 59 | 60 | 61 | class BCFC(DatasetBase): 62 | 63 | dataset_dir = "BCFC" 64 | classes = [ 65 | "malignant", 66 | "benign", 67 | ] 68 | 69 | def __init__(self, root, num_shots): 70 | self.dataset_dir = os.path.join(root, self.dataset_dir) 71 | self.image_dir = os.path.join(self.dataset_dir, "images") 72 | 73 | self.template = template 74 | 75 | train = self.create_list_of_datum("train") 76 | val = self.create_list_of_datum("val") 77 | test = self.create_list_of_datum("test") 78 | 79 | n_shots_val = min(num_shots, 4) 80 | val = self.generate_fewshot_dataset(val, num_shots=n_shots_val) 81 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 82 | 83 | super().__init__(train_x=train, val=val, test=test) 84 | 85 | def __getitem__(self, im_files, idx): 86 | image = cv2.imread(im_files[idx]) 87 | class_name = im_files[idx].split("/")[-1].split("_")[0] 88 | class_ = self.classes.index(class_name) 89 | 90 | return image, class_, class_name, im_files[idx] 91 | 92 | def create_list_of_datum(self, set): 93 | """Create a list of Datum objects, each containing the image and label.""" 94 | datum_list = [] 95 | 96 | im_files = glob.glob(os.path.join(self.image_dir, set, "*.png")) 97 | for i in range(len(im_files)): 98 | # Get the image and the class 99 | image, class_, class_name, impath = self.__getitem__(im_files, i) 100 | 101 | # Create a Datum object 102 | datum = Datum(impath=impath, label=class_, classname=class_name) 103 | 104 | # Append the datum to the list 105 | datum_list.append(datum) 106 | 107 | return datum_list 108 | -------------------------------------------------------------------------------- /features.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from tqdm import tqdm 5 | from utils import get_function 6 | from torch.utils.data import Dataset 7 | from transformers.modeling_outputs import ImageClassifierOutput 8 | 9 | 10 | class FeaturesDataset(Dataset): 11 | def __init__(self, features_path): 12 | self.data = np.load(features_path) 13 | self.features = self.data["features"] 14 | self.labels = self.data["labels"] 15 | 16 | def __len__(self): 17 | return len(self.features) 18 | 19 | def __getitem__(self, idx): 20 | feature = torch.tensor(self.features[idx], dtype=torch.float) 21 | label = torch.tensor(self.labels[idx], dtype=torch.long) 22 | 23 | return feature, label 24 | 25 | 26 | def features_extractor(args, model, train_loader, val_loader, test_loader): 27 | 28 | device = 
torch.device("cuda") 29 | model.to(device) 30 | 31 | if args.model_name == "vit_google": 32 | setattr(model, "classifier", torch.nn.Identity()) 33 | 34 | features_csv_train = os.path.join( 35 | args.root_path, args.dataset + "_" + args.model_name + "_features_train.npz" 36 | ) 37 | features_csv_val = os.path.join( 38 | args.root_path, args.dataset + "_" + args.model_name + "_features_val.npz" 39 | ) 40 | features_csv_test = os.path.join( 41 | args.root_path, args.dataset + "_" + args.model_name + "_features_test.npz" 42 | ) 43 | 44 | list_dataloader = [] 45 | features_path = [] 46 | if not os.path.exists(features_csv_train): 47 | features_path.append(features_csv_train) 48 | list_dataloader.append(train_loader) 49 | if not os.path.exists(features_csv_val): 50 | features_path.append(features_csv_val) 51 | list_dataloader.append(val_loader) 52 | if not os.path.exists(features_csv_test): 53 | features_path.append(features_csv_test) 54 | list_dataloader.append(test_loader) 55 | 56 | if len(features_path) == 0: 57 | print( 58 | f"All features have been extracted for {args.model_name} and {args.dataset}" 59 | ) 60 | else: 61 | encode_image, _, __ = get_function(args.model_name, model) 62 | 63 | with torch.no_grad(): 64 | for dataloader, path in zip(list_dataloader, features_path): 65 | 66 | features = [] 67 | labels = [] 68 | 69 | for image, label in tqdm(dataloader): 70 | 71 | image = image.to(device) 72 | 73 | img = encode_image(image) 74 | 75 | if isinstance(img, ImageClassifierOutput): 76 | img = img.logits 77 | 78 | features.append(img.cpu().numpy()) 79 | labels.append(label) 80 | 81 | features = np.concatenate(features) 82 | labels = np.concatenate(labels) 83 | np.savez(path, features=features, labels=labels) 84 | 85 | return 86 | 87 | 88 | def textual_extractor(args, dataset, model, tokenizer): 89 | 90 | textual_csv = os.path.join( 91 | args.root_path, args.dataset + "_" + args.model_name + "_textual_train.npz" 92 | ) 93 | 94 | if os.path.exists(textual_csv): 95 | print( 96 | f"All textual features have been extracted for {args.model_name} and {args.dataset}" 97 | ) 98 | else: 99 | 100 | if args.dataset in ["sipakmed", "hicervix"]: 101 | template = "A cytological slide showing a {} cell" 102 | else: 103 | template = "A cytological slide showing {} cells" 104 | 105 | texts = [ 106 | template.format(classname.replace("_", " ")) 107 | for classname in dataset.classnames 108 | ] 109 | _, text, token = get_function(args.model_name, model, tokenizer) 110 | 111 | with torch.no_grad(): 112 | with torch.amp.autocast(device_type="cuda", dtype=torch.float16): 113 | texts = token(texts).cuda() 114 | class_embeddings = text(texts) 115 | text_features = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True) 116 | text_features = text_features.cpu().numpy() 117 | 118 | np.savez(textual_csv, textuals=text_features) 119 | 120 | return 121 | -------------------------------------------------------------------------------- /datasets/hicervix.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import numpy as np 4 | import pandas as pd 5 | from tqdm import tqdm 6 | from .utils import Datum, DatasetBase 7 | 8 | template = ["A pap smear slide showing a {} cervical cells."] 9 | 10 | 11 | class HiCervix(DatasetBase): 12 | 13 | dataset_dir = "HiCervix" 14 | 15 | def __init__(self, root, num_shots, level, pourcentage): 16 | self.dataset_dir = os.path.join(root) 17 | self.image_dir = os.path.join(self.dataset_dir) 18 | 19 | self.level = level 
20 | self.pourcentage = pourcentage 21 | 22 | self.train_csv = os.path.join(self.image_dir, "train.csv") 23 | self.test_csv = os.path.join(self.image_dir, "test.csv") 24 | self.val_csv = os.path.join(self.image_dir, "val.csv") 25 | 26 | self.template = template 27 | 28 | train = self.create_list_of_datum("train") 29 | val = self.create_list_of_datum("val") 30 | test = self.create_list_of_datum("test") 31 | 32 | n_shots_val = min(num_shots, 4) 33 | 34 | # To take the same percentage of each class (experience 3) 35 | if self.pourcentage > 0: 36 | print("Percentage of the dataset considered :", self.pourcentage) 37 | train = self.generate_pourcent_dataset(train, pourcentage=self.pourcentage) 38 | val = self.generate_pourcent_dataset(val, pourcentage=self.pourcentage) 39 | else: 40 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 41 | val = self.generate_fewshot_dataset(val, num_shots=n_shots_val) 42 | 43 | super().__init__(train_x=train, val=val, test=test) 44 | 45 | def __getitem__(self, im_files, idx, set): 46 | 47 | name = im_files[idx].split("/")[-1] 48 | 49 | if set == "train": 50 | df = pd.read_csv(self.train_csv) 51 | elif set == "test": 52 | df = pd.read_csv(self.test_csv) 53 | elif set == "val": 54 | df = pd.read_csv(self.val_csv) 55 | 56 | interm = df[df["image_name"] == name] 57 | class_name = (interm.loc[:, self.level].values)[0] 58 | 59 | # Creation of a list of cases considered according to classification level. 60 | if self.level == "level_1": 61 | class_name = (interm.loc[:, "level_1"].values)[0] 62 | class_list = sorted(np.unique(df.loc[:, "level_1"].dropna().tolist())) 63 | 64 | class_ = class_list.index(class_name) 65 | 66 | elif self.level == "level_2": 67 | class_name = (interm.loc[:, "level_2"].values)[0] 68 | class_list = sorted(np.unique(df.loc[:, "level_2"].dropna().tolist())) 69 | 70 | if pd.isna(class_name): 71 | class_ = -1 72 | else: 73 | class_ = class_list.index(class_name) 74 | 75 | elif self.level == "level_3": 76 | class_name_3 = (interm.loc[:, self.level].values)[0] 77 | class_list_2 = sorted(np.unique(df.loc[:, "level_2"].dropna().tolist())) 78 | class_list_3 = sorted(np.unique(df.loc[:, "level_3"].dropna().tolist())) 79 | 80 | combined_class_list = np.append(class_list_2, class_list_3) 81 | 82 | cleaned_class_list = pd.Series(combined_class_list).dropna() 83 | 84 | class_list = sorted(np.unique(cleaned_class_list).tolist()) 85 | 86 | if pd.isna(class_name_3): 87 | class_name = (interm.loc[:, "level_2"].values)[0] 88 | else: 89 | class_name = class_name_3 90 | 91 | if pd.isna(class_name): 92 | class_ = -1 93 | else: 94 | class_ = class_list.index(class_name) 95 | 96 | return class_, class_name, im_files[idx] 97 | 98 | def create_list_of_datum(self, set): 99 | """Create a list of Datum objects, each containing the image and label.""" 100 | datum_list = [] 101 | 102 | im_files = glob.glob(os.path.join(self.image_dir, set, "*.jpg")) 103 | 104 | for i in tqdm(range(len(im_files))): 105 | 106 | class_, class_name, impath = self.__getitem__(im_files, i, set) 107 | if class_name is None or class_name == "" or class_ == -1: 108 | continue 109 | 110 | # Create a Datum object 111 | datum = Datum(impath=impath, label=class_, classname=class_name) 112 | 113 | # Append the datum to the list 114 | datum_list.append(datum) 115 | 116 | return datum_list 117 | -------------------------------------------------------------------------------- /datasets/oxford_pets.py: -------------------------------------------------------------------------------- 1 | import 
os 2 | import math 3 | import random 4 | from collections import defaultdict 5 | 6 | import torchvision.transforms as transforms 7 | 8 | from .utils import Datum, DatasetBase, read_json, write_json, build_data_loader 9 | 10 | """ 11 | template = ['a photo of a {}, a type of pet.'] 12 | """ 13 | template = ['a photo of a {}.'] 14 | 15 | class OxfordPets(DatasetBase): 16 | 17 | dataset_dir = 'OxfordPets' 18 | 19 | def __init__(self, root, num_shots): 20 | self.dataset_dir = os.path.join(root, self.dataset_dir) 21 | self.image_dir = os.path.join(self.dataset_dir, 'images') 22 | self.anno_dir = os.path.join(self.dataset_dir, 'annotations') 23 | self.split_path = os.path.join(self.dataset_dir, 'split_zhou_OxfordPets.json') 24 | 25 | self.template = template 26 | 27 | train, val, test = self.read_split(self.split_path, self.image_dir) 28 | n_shots_val = min(num_shots, 4) 29 | val = self.generate_fewshot_dataset(val, num_shots=n_shots_val) 30 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 31 | 32 | super().__init__(train_x=train, val=val, test=test) 33 | 34 | def read_data(self, split_file): 35 | filepath = os.path.join(self.anno_dir, split_file) 36 | items = [] 37 | 38 | with open(filepath, 'r') as f: 39 | lines = f.readlines() 40 | for line in lines: 41 | line = line.strip() 42 | imname, label, species, _ = line.split(' ') 43 | breed = imname.split('_')[:-1] 44 | breed = '_'.join(breed) 45 | breed = breed.lower() 46 | imname += '.jpg' 47 | impath = os.path.join(self.image_dir, imname) 48 | label = int(label) - 1 # convert to 0-based index 49 | item = Datum( 50 | impath=impath, 51 | label=label, 52 | classname=breed 53 | ) 54 | items.append(item) 55 | 56 | return items 57 | 58 | @staticmethod 59 | def split_trainval(trainval, p_val=0.2): 60 | p_trn = 1 - p_val 61 | print(f'Splitting trainval into {p_trn:.0%} train and {p_val:.0%} val') 62 | tracker = defaultdict(list) 63 | for idx, item in enumerate(trainval): 64 | label = item.label 65 | tracker[label].append(idx) 66 | 67 | train, val = [], [] 68 | for label, idxs in tracker.items(): 69 | n_val = round(len(idxs) * p_val) 70 | assert n_val > 0 71 | random.shuffle(idxs) 72 | for n, idx in enumerate(idxs): 73 | item = trainval[idx] 74 | if n < n_val: 75 | val.append(item) 76 | else: 77 | train.append(item) 78 | 79 | return train, val 80 | 81 | @staticmethod 82 | def save_split(train, val, test, filepath, path_prefix): 83 | def _extract(items): 84 | out = [] 85 | for item in items: 86 | impath = item.impath 87 | label = item.label 88 | classname = item.classname 89 | impath = impath.replace(path_prefix, '') 90 | if impath.startswith('/'): 91 | impath = impath[1:] 92 | out.append((impath, label, classname)) 93 | return out 94 | 95 | train = _extract(train) 96 | val = _extract(val) 97 | test = _extract(test) 98 | 99 | split = { 100 | 'train': train, 101 | 'val': val, 102 | 'test': test 103 | } 104 | 105 | write_json(split, filepath) 106 | print(f'Saved split to {filepath}') 107 | 108 | @staticmethod 109 | def read_split(filepath, path_prefix): 110 | def _convert(items): 111 | out = [] 112 | for impath, label, classname in items: 113 | impath = os.path.join(path_prefix, impath) 114 | item = Datum( 115 | impath=impath, 116 | label=int(label), 117 | classname=classname 118 | ) 119 | out.append(item) 120 | return out 121 | 122 | print(f'Reading split from {filepath}') 123 | split = read_json(filepath) 124 | train = _convert(split['train']) 125 | val = _convert(split['val']) 126 | test = _convert(split['test']) 127 | 128 | return 
train, val, test -------------------------------------------------------------------------------- /run_utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import argparse 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def set_random_seed(seed): 8 | random.seed(seed) 9 | np.random.seed(seed) 10 | torch.manual_seed(seed) 11 | torch.cuda.manual_seed_all(seed) 12 | 13 | 14 | def get_arguments(): 15 | 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument("--seed", type=int, default=1, help="Seed number") 18 | 19 | # Dataset arguments 20 | parser.add_argument( 21 | "--root_path", 22 | type=str, 23 | default="", 24 | help="Path of your root directory. We put our dataset in it.", 25 | ) 26 | parser.add_argument( 27 | "--dataset", 28 | type=str, 29 | default="mlcc", 30 | help="Name of the dataset used", 31 | choices=["mlcc", "bcfc", "sipakmed", "hicervix"], 32 | ) 33 | parser.add_argument("--shots", type=int, default=16, help="Shot number") 34 | parser.add_argument( 35 | "--percentage", 36 | type=float, 37 | default=0.0, 38 | help="Percentage of the dataset considered. Used for the third experiment.", 39 | ) 40 | parser.add_argument( 41 | "--textual", 42 | type=str, 43 | default="False", 44 | help="If True, the classifier is initialized with textual embeddings. If False, the textual information is ignored.", 45 | ) 46 | parser.add_argument( 47 | "--task", 48 | type=str, 49 | default="lora", 50 | help="Task name", 51 | choices=["classifier", "lora", "percentage_lora"], 52 | ) 53 | 54 | # Model arguments 55 | parser.add_argument( 56 | "--model_name", 57 | type=str, 58 | default="clip", 59 | help="Name of the model used", 60 | choices=["clip", "quilt", "biomedclip", "vit_google", "uni"], 61 | ) 62 | parser.add_argument( 63 | "--num_classes", 64 | type=int, 65 | default=2, 66 | help="Number of classes considered for the classification task", 67 | ) 68 | parser.add_argument( 69 | "--level", 70 | type=str, 71 | default="level_1", 72 | help="This is the level of the hierarchical tree to capture different fine-grained subtype information. 
Only applicable in the case of hicervix.", 73 | choices=["level_1", "level_2", "level_3", "class_name"], 74 | ) 75 | parser.add_argument( 76 | "--backbone", 77 | type=str, 78 | default="ViT-B/16", 79 | help="Configuration of the model's backbone", 80 | choices=["ViT-L/14", "ViT-B/16"], 81 | ) 82 | 83 | # Training arguments 84 | parser.add_argument("--lr", type=float, default=0.001, help="Learning rate") 85 | parser.add_argument("--n_iters", type=int, default=500, help="Number of iterations") 86 | parser.add_argument("--batch_size", type=int, default=32, help="Size of the batch") 87 | 88 | # Argument definition 89 | parser.add_argument( 90 | "--position", 91 | type=str, 92 | default="all", 93 | help="where to put the LoRA modules", 94 | choices=["bottom", "mid", "up", "half-up", "half-bottom", "all", "top3"], 95 | ) 96 | 97 | parser.add_argument( 98 | "--encoder", 99 | type=str, 100 | default="both", 101 | choices=["text", "vision", "both"], 102 | help="It is the part of the model on which we want apply LoRA, either on the visual or textual part.", 103 | ) 104 | 105 | parser.add_argument( 106 | "--params", 107 | type=str, 108 | metavar="N", 109 | nargs="+", 110 | default=["q", "k", "v"], 111 | help="list of attention matrices where putting a LoRA", 112 | ) 113 | 114 | parser.add_argument( 115 | "--r", type=int, default=2, help="the rank of the low-rank matrices" 116 | ) 117 | 118 | parser.add_argument("--alpha", default=1, type=int, help="scaling (see LoRA paper)") 119 | 120 | parser.add_argument( 121 | "--dropout_rate", 122 | default=0.25, 123 | type=float, 124 | help="dropout rate applied before the LoRA module", 125 | ) 126 | 127 | parser.add_argument( 128 | "--save_path", 129 | default=None, 130 | help="path to save the lora modules after training, not saved if None", 131 | ) 132 | parser.add_argument( 133 | "--filename", 134 | default="lora_weights", 135 | help="file name to save the lora weights (.pt extension will be added)", 136 | ) 137 | 138 | parser.add_argument( 139 | "--eval_only", 140 | default=False, 141 | action="store_true", 142 | help="only evaluate the LoRA modules (save_path should not be None)", 143 | ) 144 | 145 | args = parser.parse_args() 146 | 147 | return args 148 | -------------------------------------------------------------------------------- /clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 
25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'</w>' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '</w>',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'</w>' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ') 132 | return text 133 | --------------------------------------------------------------------------------
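A minimal, self-contained sketch (not part of the repository) of how the BPE tokenizer in `clip/simple_tokenizer.py` can be exercised on its own; `clip.tokenize` in `clip/clip.py` wraps the same `encode` call with `<|startoftext|>`/`<|endoftext|>` tokens and pads to the 77-token context length. The prompt string below is only an illustration.

```python
from clip.simple_tokenizer import SimpleTokenizer

# Loads bpe_simple_vocab_16e6.txt.gz from the clip/ package directory by default.
tokenizer = SimpleTokenizer()

prompt = "A cytological slide showing dyskeratotic cells"
token_ids = tokenizer.encode(prompt)     # list of BPE token ids (no start/end tokens)
recovered = tokenizer.decode(token_ids)  # lower-cased, whitespace-normalised text

print(token_ids)
print(recovered)
```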
/run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | import pandas as pd 5 | 6 | 7 | if __name__ == "__main__": 8 | 9 | # Parser creation 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("--seed_launch", type=int, default=1, help="Seed number") 12 | parser.add_argument("--shots_launch", type=int, default=16, help="Shot number") 13 | parser.add_argument("--lr_launch", type=float, default=0.001, help="Learning rate") 14 | parser.add_argument( 15 | "--iterations", type=int, default=50, help="Number of iteration" 16 | ) 17 | parser.add_argument( 18 | "--parameters", 19 | type=str, 20 | default="q v", 21 | help="Layers of the model on which LoRA will be applied", 22 | ) 23 | parser.add_argument( 24 | "--rank_launch", 25 | type=int, 26 | default=8, 27 | help="Rank of matrices on which LoRA will be applied", 28 | ) 29 | 30 | parser.add_argument( 31 | "--model_launch", 32 | type=str, 33 | default="clip", 34 | help="Name of the model used", 35 | choices=["clip", "quilt", "biomedclip", "vit_google", "uni"], 36 | ) 37 | parser.add_argument( 38 | "--textual_launch", 39 | type=str, 40 | default="False", 41 | help="If True, the classifier is initialized with textual embeddings. If False, the textual information is ignored.", 42 | ) 43 | parser.add_argument( 44 | "--dataset_launch", 45 | type=str, 46 | default="mlcc", 47 | help="Name of the dataset used", 48 | choices=[ 49 | "mlcc", 50 | "bcfc", 51 | "sipakmed", 52 | "hicervix" 53 | ], 54 | ) 55 | parser.add_argument( 56 | "--level_launch", 57 | type=str, 58 | default="level_1", 59 | help="This is the level of the hierarchical tree to capture different fine-grained subtype information. Only applicable in the case of hicervix.", 60 | ) 61 | parser.add_argument( 62 | "--percent_launch", 63 | type=float, 64 | default=0.0, 65 | help="Percentage of the dataset considered. 
Used for the third experiment.", 66 | ) 67 | 68 | parser.add_argument( 69 | "--task_launch", 70 | type=str, 71 | default="lora", 72 | help="Task name", 73 | choices=["classifier", "lora", "percentage_lora"], 74 | ) 75 | args = parser.parse_args() 76 | 77 | # Variables 78 | seed = args.seed_launch 79 | shot = args.shots_launch 80 | lr = args.lr_launch 81 | n_iters = args.iterations 82 | params = args.parameters 83 | position = "all" 84 | encoder = "vision" 85 | r = args.rank_launch 86 | 87 | model_name = args.model_launch 88 | textual = args.textual_launch 89 | dataset = args.dataset_launch 90 | level = args.level_launch 91 | percent = args.percent_launch 92 | 93 | task = args.task_launch 94 | 95 | if dataset == "mlcc": 96 | num_classes = 4 97 | root_path = "path_of_dataset" # TO CHANGE 98 | 99 | elif dataset == "bcfc": 100 | num_classes = 2 101 | root_path = "path_of_dataset" # TO CHANGE 102 | 103 | elif dataset == "sipakmed": 104 | num_classes = 5 105 | root_path = "path_of_dataset" # TO CHANGE 106 | 107 | elif dataset == "hicervix": 108 | df = pd.read_csv("path_of_dataset") # TO CHANGE 109 | 110 | if level == "level_3": 111 | class_list_2 = sorted(np.unique(df.loc[:, "level_2"].dropna().tolist())) 112 | class_list_3 = sorted(np.unique(df.loc[:, "level_3"].dropna().tolist())) 113 | 114 | combined_class_list = np.append(class_list_2, class_list_3) 115 | 116 | cleaned_class_list = pd.Series(combined_class_list).dropna() 117 | 118 | class_list = sorted(np.unique(cleaned_class_list).tolist()) 119 | 120 | elif level == "level_2": 121 | class_list_2 = sorted(np.unique(df.loc[:, "level_2"].dropna().tolist())) 122 | class_list_1 = sorted(np.unique(df.loc[:, "level_1"].dropna().tolist())) 123 | 124 | combined_class_list = np.append(class_list_2, class_list_1) 125 | 126 | cleaned_class_list = pd.Series(combined_class_list).dropna() 127 | 128 | class_list = sorted(np.unique(cleaned_class_list).tolist()) 129 | 130 | else: 131 | class_list = np.unique(df.loc[:, level].tolist()) 132 | 133 | num_classes = len(class_list) 134 | root_path = "path_of_the_dataset" # TO CHANGE 135 | 136 | else: 137 | raise RuntimeError("Wrong dataset") 138 | 139 | if model_name in ["clip", "quilt", "biomedclip", "vit_google"]: 140 | backbone = "ViT-B/16" 141 | elif model_name in ["uni"]: 142 | backbone = "ViT-L/14" 143 | 144 | if task == "percentage_lora": 145 | backbone = "ViT-L/14" 146 | 147 | print(f"Run started: model {model_name}, lr {lr}, r {r}, seed {seed}") 148 | 149 | os.system( 150 | f"python3 main.py --root_path {root_path} \ 151 | --dataset {dataset} --seed {seed} --shots {shot} --lr {lr} \ 152 | --n_iters {n_iters} --position {position} --encoder {encoder} --percentage {percent}\ 153 | --params {params} --r {r} --model_name {model_name} --num_classes {num_classes}\ 154 | --level {level} --backbone {backbone} --textual {textual} --task {task}" 155 | ) 156 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Exploring Foundation Models Fine-Tuning for Cytology Tasks [Accepted to ISBI 2025] 2 | 3 | Implementation of **[Exploring Foundation Models Fine-Tuning for Cytology Tasks](https://doi.org/10.48550/arXiv.2411.14975)**. 4 | 5 | In this paper, we explore the application of existing foundation models to cytological classification tasks, focusing on low-rank adaptation (LoRA), a parameter-efficient fine-tuning method well-suited to few-shot learning scenarios. 
We evaluate five foundation models across four cytological classification datasets. Our results demonstrate that fine-tuning the pre-trained backbones with LoRA significantly enhances model performance compared to merely fine-tuning the classifier head, achieving state-of-the-art results on both simple and complex classification tasks while requiring fewer data samples. 6 | 7 | **Authors**: [M. Dausort](https://scholar.google.com/citations?user=hXTkITwAAAAJ&hl=en), [T. Godelaine](https://scholar.google.com/citations?user=xKcPd0oAAAAJ&hl=en&oi=ao), [M. Zanella](https://scholar.google.com/citations?user=FIoE9YIAAAAJ&hl=fr&oi=ao), [K. El Khoury](https://scholar.google.be/citations?user=UU_keGAAAAAJ&hl=fr), [I. Salmon](https://scholar.google.be/citations?user=S1dmusUAAAAJ&hl=en), [B. Macq](https://scholar.google.be/citations?user=H9pGN70AAAAJ&hl=fr) 8 | 9 | 📌 **NB:** This GitHub repository is based on the implementation of [CLIP-LoRA](https://github.com/MaxZanella/CLIP-LoRA). 10 | 11 | ## Contents 12 | 13 | - [Installation](#installation) 14 | - [Usage](#usage) 15 | - [Contact](#contact) 16 | 17 | ## Installation 18 | 19 | 📌 **NB:** The Python version used is 3.9.13. 20 | 21 | 1. Create a virtual environment 22 | ```bash 23 | python3 -m venv cyto_ft_venv 24 | source cyto_ft_venv/bin/activate 25 | ``` 26 | 27 | Clone the GitHub repository 28 | ```bash 29 | pip3 install torch==2.2.2 torchaudio==2.2.2 torchvision==0.17.2 30 | git clone https://github.com/mdausort/Cytology-fine-tuning.git 31 | ``` 32 | 33 | Install the required packages 34 | ```bash 35 | cd Cytology-fine-tuning 36 | pip3 install -r requirements.txt 37 | ``` 38 | 39 | 40 | 2. Datasets downloads: 41 | 42 | | Dataset | 🔗 Download Link | 43 | | -------------- | -------------------------------------------------------------------------------------------------------- | 44 | | BCFC | [📥 Link](https://www.kaggle.com/datasets/cmacus/body-cavity-fluid-cytology-images) | 45 | | MLCC | [📥 Link](https://www.kaggle.com/datasets/blank1508/mendeley-lbc-cervical-cancer-) | 46 | | SIPaKMeD | [📥 Link](https://www.kaggle.com/datasets/prahladmehandiratta/cervical-cancer-largest-dataset-sipakmed) | 47 | | HiCervix | [📥 Link](https://zenodo.org/records/11087263) | 48 | 49 | 📌 Each dataset must be divided into three folders: train, val and test. Images were named following this structure: *classname_number*. 50 | **Important**: All file paths in scripts are set with the placeholder "TO CHANGE". You will need to search for this placeholder in the cloned repository's files and replace it with the appropriate path ```/root/path/``` as specified for your system. In this setup, we have placed the different datasets inside a folder named `./data`. 51 | 52 | ## Usage 53 | 54 | To launch the experiments, use the provided `launch_run.sh` bash script: 55 | 56 | 1. Open the relevant script and locate the required line for configuration (e.g., line 28 for Experiment 1). Uncomment this line to enable the specific settings needed for the experiment. 57 | 2. Start the experiment by executing the `launch_run.sh` script: 58 | 59 | ```bash 60 | bash launch_run.sh 61 | ``` 62 | 63 | 3. To visualize the changes and track the experiment's progress, you must integrate your code with Weights & Biases. Add the following line to your script if it's not already included: 64 | ```python 65 | import wandb 66 | wandb.init(project='your_project_name') 67 | ``` 68 | You can view the results and metrics of your experiment on [Weights & Biases](https://wandb.ai/site). 69 | 70 | 4. 
The results of the experiment are also saved into a JSON file for further analysis or documentation. 71 | 72 | | Experiment | Command line | 73 | | -----------------------| --------------------------------------------------------------------------------------------------------------------------------- | 74 | | **Linear Classifier** | `python3 main.py --root_path ./data/ --dataset {dataset} --seed {seed} --shots -1 --lr {lr} --n_iters 50 --model_name {model_name} --num_classes {num_classes} --level {level} --textual False --task classifier` | 75 | | **LoRA Few-Shot** | `python3 --root_path ./data/ --dataset {dataset} --seed {seed} --shots {shots} --lr {lr} --n_iters 50 --position "all" --encoder "vision" --params "q v" --r 2 --model_name {model_name} --num_classes {num_classes} --level {level} --task lora` | 76 | | **Advanced LoRA** | `python3 run.py --root_path ./data/ --dataset hicervix --seed {seed} --shots 0 --lr 1e-3 --n_iters 100 --position "all" --encoder "vision" --pourcentage {pourcentage} --params "q k v o" --r 16 --model_name clip --level level_3 --task percentage_lora` | 77 | 78 | 79 | ## Contact 80 | 81 | If you have any questions, you can contact us by email: [manon.dausort@uclouvain.be](mailto\:manon.dausort@uclouvain.be), [tiffanie.godelaine@uclouvain.be](mailto\:tiffanie.godelaine@uclouvain.be) 82 | -------------------------------------------------------------------------------- /loralib/easymultiheadattention.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | """ 6 | Source : https://github.com/KyanChen/MakeMultiHeadNaive/blob/master/main.py 7 | """ 8 | 9 | 10 | class PlainMultiHeadAttention(nn.Module): 11 | def __init__(self, existing_mha: nn.MultiheadAttention): 12 | super().__init__() 13 | 14 | self.dropout = 0 # this module is not used to retrain the main block 15 | self.embed_dim = existing_mha.embed_dim 16 | self.kdim = existing_mha.kdim 17 | self.vdim = existing_mha.vdim 18 | self._qkv_same_embed_dim = existing_mha._qkv_same_embed_dim 19 | self.num_heads = existing_mha.num_heads 20 | self.batch_first = existing_mha.batch_first 21 | self.head_dim = existing_mha.head_dim 22 | self.qkv = nn.Linear( 23 | self.embed_dim, 24 | self.embed_dim * 3, 25 | bias=existing_mha.in_proj_bias is not None, 26 | ) 27 | self.proj = nn.Linear( 28 | self.embed_dim, self.embed_dim, bias=existing_mha.out_proj.bias is not None 29 | ) 30 | 31 | # Initialize parameters 32 | with torch.no_grad(): 33 | self.qkv.weight.data.copy_(existing_mha.in_proj_weight.data) 34 | if self.qkv.bias is not None: 35 | self.qkv.bias.data.copy_(existing_mha.in_proj_bias.data) 36 | self.proj.weight.data.copy_(existing_mha.out_proj.weight.data) 37 | if self.proj.bias is not None: 38 | self.proj.bias.data.copy_(existing_mha.out_proj.bias.data) 39 | 40 | self.scaled_dot_product_attention = F.scaled_dot_product_attention 41 | 42 | def forward( 43 | self, 44 | query, 45 | key, 46 | value, 47 | key_padding_mask=None, 48 | need_weights=True, 49 | attn_mask=None, 50 | average_attn_weights=True, 51 | is_causal=False, 52 | ): 53 | 54 | if attn_mask is not None and is_causal: 55 | raise AssertionError("Only allow causal mask or attn_mask") 56 | is_batched = query.dim() == 3 57 | key_padding_mask = F._canonical_mask( 58 | mask=key_padding_mask, 59 | mask_name="key_padding_mask", 60 | other_type=F._none_or_dtype(attn_mask), 61 | other_name="attn_mask", 62 | target_type=query.dtype, 63 | ) 64 | 65 | if 
self.batch_first and is_batched: 66 | if key is value: 67 | if query is key: 68 | query = key = value = query.transpose(1, 0) 69 | else: 70 | query, key = [x.transpose(1, 0) for x in (query, key)] 71 | value = key 72 | else: 73 | query, key, value = [x.transpose(1, 0) for x in (query, key, value)] 74 | 75 | tgt_len, bsz, embed_dim = query.shape 76 | src_len, _, _ = key.shape 77 | 78 | E = query.size(-1) 79 | qkv = self.qkv(query) 80 | qkv = ( 81 | qkv.unflatten(-1, (3, E)) 82 | .unsqueeze(0) 83 | .transpose(0, -2) 84 | .squeeze(-2) 85 | .contiguous() 86 | ) 87 | q, k, v = qkv[0], qkv[1], qkv[2] 88 | 89 | attn_mask = F._canonical_mask( 90 | mask=attn_mask, 91 | mask_name="attn_mask", 92 | other_type=F._none_or_dtype(key_padding_mask), 93 | other_name="key_padding_mask", 94 | target_type=q.dtype, 95 | check_other=False, 96 | ) 97 | 98 | if attn_mask is not None: 99 | # ensure attn_mask's dim is 3 100 | if attn_mask.dim() == 2: 101 | correct_2d_size = (tgt_len, src_len) 102 | if attn_mask.shape != correct_2d_size: 103 | raise RuntimeError( 104 | f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}." 105 | ) 106 | attn_mask = attn_mask.unsqueeze(0) 107 | elif attn_mask.dim() == 3: 108 | correct_3d_size = (bsz * self.num_heads, tgt_len, src_len) 109 | if attn_mask.shape != correct_3d_size: 110 | raise RuntimeError( 111 | f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}." 112 | ) 113 | else: 114 | raise RuntimeError( 115 | f"attn_mask's dimension {attn_mask.dim()} is not supported" 116 | ) 117 | 118 | if attn_mask is not None: 119 | if attn_mask.size(0) == 1 and attn_mask.dim() == 3: 120 | attn_mask = attn_mask.unsqueeze(0) 121 | else: 122 | attn_mask = attn_mask.view(bsz, self.num_heads, -1, src_len) 123 | 124 | dropout_p = self.dropout if self.training else 0.0 125 | 126 | q = q.view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1) 127 | k = k.view(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1) 128 | v = v.view(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1) 129 | src_len = k.size(1) 130 | q = q.view(bsz, self.num_heads, tgt_len, self.head_dim) 131 | k = k.view(bsz, self.num_heads, src_len, self.head_dim) 132 | v = v.view(bsz, self.num_heads, src_len, self.head_dim) 133 | 134 | attn_output = self.scaled_dot_product_attention( 135 | q, k, v, attn_mask, dropout_p, is_causal 136 | ) 137 | attn_output = ( 138 | attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim) 139 | ) 140 | attn_output = self.proj(attn_output) 141 | attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1)) 142 | if self.batch_first and is_batched: 143 | return attn_output.transpose(1, 0), None 144 | return attn_output, None 145 | -------------------------------------------------------------------------------- /datasets/cu_us.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import pandas as pd 4 | from PIL import Image 5 | from .utils import Datum, DatasetBase # type: ignore 6 | 7 | 8 | class USDataset_C1(DatasetBase): 9 | 10 | dataset_dir = "Dataset/US_CU/" 11 | classes = ["C11", "C12"] 12 | 13 | def __init__(self, root, num_shots, preprocess=None): 14 | self.root = root 15 | self.dataset_dir = os.path.join(root, self.dataset_dir + "split_data/") 16 | self.image_dir = os.path.join(self.dataset_dir, "step3_c1") 17 | 18 | train = self.create_list_of_datum("train") 19 | val = self.create_list_of_datum("val") 20 | test = 
self.create_list_of_datum("test") 21 | 22 | n_shots_val = min(num_shots, 4) 23 | val = self.generate_fewshot_dataset(val, num_shots=n_shots_val) 24 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 25 | 26 | super().__init__(train_x=train, val=val, test=test) 27 | 28 | def __getitem__(self, im_files, idx): 29 | 30 | image = Image.open(im_files[idx]).convert("RGB") 31 | image_name = im_files[idx].split("/")[-1].split(".")[0] 32 | 33 | df = pd.read_excel(os.path.join(self.root, "Dataset/US_CU/Labels.xlsx")) 34 | self.labels_dict = dict(zip(df["Cyto"], df["C1"])) 35 | 36 | class_name = self.labels_dict[image_name] 37 | class_ = self.classes.index(class_name) 38 | return image, class_, class_name, im_files[idx] 39 | 40 | def create_list_of_datum(self, set): 41 | """ 42 | Create a list of Datum objects, each containing the image and label. 43 | """ 44 | 45 | datum_list = [] 46 | 47 | im_files = glob.glob(os.path.join(self.image_dir, set, "*.jpg")) 48 | for i in range(len(im_files)): 49 | # Get the image and the class 50 | image, class_, class_name, impath = self.__getitem__(im_files, i) 51 | 52 | # Create a Datum object 53 | datum = Datum(impath=impath, label=class_, classname=class_name) 54 | 55 | # Append the datum to the list 56 | datum_list.append(datum) 57 | 58 | return datum_list 59 | 60 | 61 | class USDataset_C2(DatasetBase): 62 | 63 | dataset_dir = "Dataset/US_CU/" 64 | classes = ["C21", "C22", "C23"] 65 | 66 | def __init__(self, root, num_shots, preprocess=None): 67 | self.root = root 68 | self.dataset_dir = os.path.join(root, self.dataset_dir + "split_data/") 69 | self.image_dir = os.path.join(self.dataset_dir, "step3_c2") 70 | 71 | train = self.create_list_of_datum("train") 72 | val = self.create_list_of_datum("val") 73 | test = self.create_list_of_datum("test") 74 | 75 | n_shots_val = min(num_shots, 4) 76 | val = self.generate_fewshot_dataset(val, num_shots=n_shots_val) 77 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 78 | 79 | super().__init__(train_x=train, val=val, test=test) 80 | 81 | def __getitem__(self, im_files, idx): 82 | 83 | image = Image.open(im_files[idx]).convert("RGB") 84 | image_name = im_files[idx].split("/")[-1].split(".")[0] 85 | 86 | df = pd.read_excel(os.path.join(self.root, "Dataset/US_CU/Labels.xlsx")) 87 | self.labels_dict = dict(zip(df["Cyto"], df["C2"])) 88 | 89 | class_name = self.labels_dict[image_name] 90 | 91 | class_ = self.classes.index(class_name) 92 | 93 | return image, class_, class_name, im_files[idx] 94 | 95 | def create_list_of_datum(self, set): 96 | """ 97 | Create a list of Datum objects, each containing the image and label. 
98 | """ 99 | 100 | datum_list = [] 101 | 102 | im_files = glob.glob(os.path.join(self.image_dir, set, "*.jpg")) 103 | for i in range(len(im_files)): 104 | # Get the image and the class 105 | image, class_, class_name, impath = self.__getitem__(im_files, i) 106 | 107 | # Create a Datum object 108 | datum = Datum(impath=impath, label=class_, classname=class_name) 109 | 110 | # Append the datum to the list 111 | datum_list.append(datum) 112 | 113 | return datum_list 114 | 115 | 116 | class USDataset_C3(DatasetBase): 117 | 118 | dataset_dir = "Dataset/US_CU/" 119 | classes = ["C31", "C32", "C33", "C34", "C35"] 120 | 121 | def __init__(self, root, num_shots, preprocess=None): 122 | self.root = root 123 | self.dataset_dir = os.path.join(root, self.dataset_dir + "split_data/") 124 | self.image_dir = os.path.join(self.dataset_dir, "step3_c3") 125 | 126 | train = self.create_list_of_datum("train") 127 | val = self.create_list_of_datum("val") 128 | test = self.create_list_of_datum("test") 129 | 130 | n_shots_val = min(num_shots, 4) 131 | val = self.generate_fewshot_dataset(val, num_shots=n_shots_val) 132 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 133 | 134 | super().__init__(train_x=train, val=val, test=test) 135 | 136 | def __getitem__(self, im_files, idx): 137 | 138 | image = Image.open(im_files[idx]).convert("RGB") 139 | image_name = im_files[idx].split("/")[-1].split(".")[0] 140 | 141 | df = pd.read_excel(os.path.join(self.root, "Dataset/US_CU/Labels.xlsx")) 142 | self.labels_dict = dict(zip(df["Cyto"], df["C3"])) 143 | 144 | class_name = self.labels_dict[image_name] 145 | 146 | class_ = self.classes.index(class_name) 147 | 148 | return image, class_, class_name, im_files[idx] 149 | 150 | def create_list_of_datum(self, set): 151 | """ 152 | Create a list of Datum objects, each containing the image and label. 
153 | """ 154 | 155 | datum_list = [] 156 | 157 | im_files = glob.glob(os.path.join(self.image_dir, set, "*.jpg")) 158 | for i in range(len(im_files)): 159 | # Get the image and the class 160 | image, class_, class_name, impath = self.__getitem__(im_files, i) 161 | 162 | # Create a Datum object 163 | datum = Datum(impath=impath, label=class_, classname=class_name) 164 | 165 | # Append the datum to the list 166 | datum_list.append(datum) 167 | 168 | return datum_list 169 | -------------------------------------------------------------------------------- /clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Any, Union, List 6 | 7 | import torch 8 | from PIL import Image 9 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 10 | from tqdm import tqdm 11 | 12 | from .model import build_model 13 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 14 | 15 | try: 16 | from torchvision.transforms import InterpolationMode 17 | BICUBIC = InterpolationMode.BICUBIC 18 | except ImportError: 19 | BICUBIC = Image.BICUBIC 20 | 21 | 22 | if torch.__version__.split(".") < ["1", "7", "1"]: 23 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 24 | 25 | 26 | __all__ = ["available_models", "load", "tokenize"] 27 | _tokenizer = _Tokenizer() 28 | 29 | _MODELS = { 30 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 31 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 32 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 33 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 34 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 35 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 36 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 37 | # "ViT-L/16": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-16.pt", 38 | } 39 | 40 | 41 | def _download(url: str, root: str): 42 | os.makedirs(root, exist_ok=True) 43 | filename = os.path.basename(url) 44 | 45 | expected_sha256 = url.split("/")[-2] 46 | download_target = os.path.join(root, filename) 47 | 48 | if os.path.exists(download_target) and not os.path.isfile(download_target): 49 | raise RuntimeError(f"{download_target} exists and is not a regular file") 50 | 51 | if os.path.isfile(download_target): 52 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 53 | return download_target 54 | else: 55 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 56 | 57 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 58 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 59 | while True: 60 | buffer = source.read(8192) 61 | if 
not buffer: 62 | break 63 | 64 | output.write(buffer) 65 | loop.update(len(buffer)) 66 | 67 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 68 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 69 | 70 | return download_target 71 | 72 | 73 | def _convert_image_to_rgb(image): 74 | return image.convert("RGB") 75 | 76 | 77 | def _transform(n_px): 78 | return Compose([ 79 | Resize(n_px, interpolation=BICUBIC), 80 | CenterCrop(n_px), 81 | _convert_image_to_rgb, 82 | ToTensor(), 83 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 84 | ]) 85 | 86 | 87 | def available_models() -> List[str]: 88 | """Returns the names of available CLIP models""" 89 | return list(_MODELS.keys()) 90 | 91 | 92 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): 93 | """Load a CLIP model 94 | 95 | Parameters 96 | ---------- 97 | name : str 98 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 99 | 100 | device : Union[str, torch.device] 101 | The device to put the loaded model 102 | 103 | jit : bool 104 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 105 | 106 | download_root: str 107 | path to download the model files; by default, it uses "~/.cache/clip" 108 | 109 | Returns 110 | ------- 111 | model : torch.nn.Module 112 | The CLIP model 113 | 114 | preprocess : Callable[[PIL.Image], torch.Tensor] 115 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 116 | """ 117 | if name in _MODELS: 118 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) 119 | elif os.path.isfile(name): 120 | model_path = name 121 | else: 122 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 123 | 124 | try: 125 | # loading JIT archive 126 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 127 | state_dict = None 128 | except RuntimeError: 129 | # loading saved state dict 130 | if jit: 131 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 132 | jit = False 133 | state_dict = torch.load(model_path, map_location="cpu") 134 | 135 | if not jit: 136 | model = build_model(state_dict or model.state_dict()).to(device) 137 | if str(device) == "cpu": 138 | model.float() 139 | return model, _transform(model.visual.input_resolution) 140 | 141 | # patch the device names 142 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 143 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 144 | 145 | def patch_device(module): 146 | try: 147 | graphs = [module.graph] if hasattr(module, "graph") else [] 148 | except RuntimeError: 149 | graphs = [] 150 | 151 | if hasattr(module, "forward1"): 152 | graphs.append(module.forward1.graph) 153 | 154 | for graph in graphs: 155 | for node in graph.findAllNodes("prim::Constant"): 156 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 157 | node.copyAttributes(device_node) 158 | 159 | model.apply(patch_device) 160 | patch_device(model.encode_image) 161 | patch_device(model.encode_text) 162 | 163 | # patch dtype to float32 on CPU 164 | if str(device) == "cpu": 165 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 166 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 167 | float_node = float_input.node() 168 | 169 | def patch_float(module): 170 | try: 171 | graphs = [module.graph] if hasattr(module, "graph") else [] 172 | except RuntimeError: 173 | graphs = [] 174 | 175 | if hasattr(module, "forward1"): 176 | graphs.append(module.forward1.graph) 177 | 178 | for graph in graphs: 179 | for node in graph.findAllNodes("aten::to"): 180 | inputs = list(node.inputs()) 181 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 182 | if inputs[i].node()["value"] == 5: 183 | inputs[i].node().copyAttributes(float_node) 184 | 185 | model.apply(patch_float) 186 | patch_float(model.encode_image) 187 | patch_float(model.encode_text) 188 | 189 | model.float() 190 | 191 | return model, _transform(model.input_resolution.item()) 192 | 193 | 194 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor: 195 | """ 196 | Returns the tokenized representation of given input string(s) 197 | 198 | Parameters 199 | ---------- 200 | texts : Union[str, List[str]] 201 | An input string or a list of input strings to tokenize 202 | 203 | context_length : int 204 | The context length to use; all CLIP models use 77 as the context length 205 | 206 | truncate: bool 207 | Whether to truncate the text in case its encoding is longer than the context length 208 | 209 | Returns 210 | ------- 211 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 212 | """ 213 | if isinstance(texts, str): 214 | texts = [texts] 215 | 216 | sot_token = _tokenizer.encoder["<|startoftext|>"] 217 | eot_token = _tokenizer.encoder["<|endoftext|>"] 218 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 219 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 220 | 221 | for i, tokens in enumerate(all_tokens): 222 | if len(tokens) > context_length: 223 | if truncate: 224 | tokens = tokens[:context_length] 225 | tokens[-1] = eot_token 226 | else: 227 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 228 | result[i, 
:len(tokens)] = torch.tensor(tokens) 229 | 230 | return result 231 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import timm 3 | import clip 4 | import torch 5 | import open_clip 6 | from datasets import build_dataset # type: ignore 7 | import torchvision.transforms as transforms 8 | from datasets.utils import build_data_loader # type: ignore 9 | from lora import run_uni, run_uni_lora, run_uni_lora_percent # type: ignore 10 | from run_utils import set_random_seed, get_arguments # type: ignore 11 | from transformers import AutoModelForImageClassification 12 | from features import ( # type: ignore 13 | features_extractor, 14 | FeaturesDataset, 15 | textual_extractor, 16 | ) 17 | 18 | 19 | def main(): 20 | 21 | args = get_arguments() 22 | 23 | set_random_seed(args.seed) 24 | 25 | # -------------------------------- Models -------------------------------- 26 | tokenizer = None 27 | 28 | if args.model_name == "clip": 29 | model_clip, _ = clip.load(args.backbone) 30 | tokenizer = clip.tokenize 31 | 32 | # Preprocess for CLIP 33 | preprocess = transforms.Compose( 34 | [ 35 | transforms.ToTensor(), 36 | transforms.Normalize( 37 | mean=(0.48145466, 0.4578275, 0.40821073), 38 | std=(0.26862954, 0.26130258, 0.27577711), 39 | ), 40 | ] 41 | ) 42 | 43 | elif args.model_name == "quilt": 44 | model_clip, _, _ = open_clip.create_model_and_transforms( 45 | "hf-hub:wisdomik/QuiltNet-B-32" 46 | ) 47 | tokenizer = open_clip.get_tokenizer("hf-hub:wisdomik/QuiltNet-B-32") 48 | 49 | # Preprocess for Quilt 50 | preprocess = transforms.Compose( 51 | [ 52 | transforms.ToTensor(), 53 | transforms.Normalize( 54 | mean=(0.48145466, 0.4578275, 0.40821073), 55 | std=(0.26862954, 0.26130258, 0.27577711), 56 | ), 57 | ] 58 | ) 59 | 60 | elif args.model_name == "uni": 61 | model_clip = timm.create_model( 62 | "hf-hub:MahmoodLab/UNI", 63 | pretrained=True, 64 | init_values=1e-5, 65 | dynamic_img_size=True, 66 | ) 67 | 68 | # Preprocess for UNI 69 | preprocess = transforms.Compose( 70 | [ 71 | transforms.ToTensor(), 72 | transforms.Normalize( 73 | mean=(0.485, 0.456, 0.406), 74 | std=(0.229, 0.224, 0.225), 75 | ), 76 | ] 77 | ) 78 | 79 | elif args.model_name == "vit_google": 80 | model_clip = AutoModelForImageClassification.from_pretrained( 81 | "google/vit-base-patch16-224" 82 | ) 83 | 84 | # Preprocess for ViT-Google 85 | preprocess = transforms.Compose([ 86 | transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BILINEAR), 87 | transforms.ToTensor(), 88 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 89 | ]) 90 | 91 | elif args.model_name == "biomedclip": 92 | model_clip, preprocess, _ = open_clip.create_model_and_transforms( 93 | "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224" 94 | ) 95 | tokenizer = open_clip.get_tokenizer( 96 | "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224" 97 | ) 98 | 99 | # Preprocess for BiomedCLIP 100 | preprocess = transforms.Compose( 101 | [ 102 | transforms.ToTensor(), 103 | transforms.Normalize( 104 | mean=(0.48145466, 0.4578275, 0.40821073), 105 | std=(0.26862954, 0.26130258, 0.27577711), 106 | ), 107 | ] 108 | ) 109 | 110 | else: 111 | raise RuntimeError( 112 | "Wrong model name used. Try clip, uni, biomedclip, vit_google or quilt." 
113 | ) 114 | 115 | da_transform = transforms.Compose( 116 | [ 117 | transforms.RandomResizedCrop( 118 | size=224, 119 | scale=(0.7, 1), 120 | interpolation=transforms.InterpolationMode.BICUBIC, 121 | ), 122 | transforms.RandomHorizontalFlip(p=0.5), 123 | ] 124 | ) 125 | 126 | train_transform = transforms.Compose( 127 | [preprocess, da_transform] 128 | ) 129 | 130 | model_clip.eval() 131 | model_clip.cuda() 132 | logit_scale = 100 133 | 134 | # ---------------------------- Prepare dataset ---------------------------- 135 | print("Preparing dataset.") 136 | print(preprocess) 137 | print(train_transform) 138 | 139 | level_name = (args.level).replace("_", "") 140 | 141 | if args.task == "classifier": 142 | 143 | textual_csv_train = os.path.join( 144 | args.root_path, args.dataset + "_" + args.model_name + "_textual_train.npz" 145 | ) 146 | 147 | if not os.path.exists(textual_csv_train) and args.textual == "True": 148 | dataset = build_dataset(args.dataset, args.root_path, -1, args.level) 149 | 150 | textual_extractor(args, dataset, model_clip, tokenizer) 151 | 152 | features_csv_train = os.path.join( 153 | args.root_path, args.dataset + "_" + args.model_name + "_features_train.npz" 154 | ) 155 | features_csv_val = os.path.join( 156 | args.root_path, args.dataset + "_" + args.model_name + "_features_val.npz" 157 | ) 158 | features_csv_test = os.path.join( 159 | args.root_path, args.dataset + "_" + args.model_name + "_features_test.npz" 160 | ) 161 | 162 | if ( 163 | not os.path.exists(features_csv_train) 164 | or not os.path.exists(features_csv_val) 165 | or not os.path.exists(features_csv_test) 166 | ): 167 | dataset = build_dataset(args.dataset, args.root_path, -1, args.level) 168 | 169 | val_loader = build_data_loader( 170 | data_source=dataset.val, 171 | batch_size=256, 172 | is_train=False, 173 | tfm=preprocess, 174 | shuffle=False, 175 | num_workers=5, 176 | ) 177 | 178 | test_loader = build_data_loader( 179 | data_source=dataset.test, 180 | batch_size=256, 181 | is_train=False, 182 | tfm=preprocess, 183 | shuffle=False, 184 | num_workers=5, 185 | ) 186 | 187 | train_loader = None 188 | if not args.eval_only: 189 | 190 | train_loader = build_data_loader( 191 | data_source=dataset.train_x, 192 | batch_size=args.batch_size, 193 | tfm=train_transform, 194 | is_train=True, 195 | shuffle=True, 196 | num_workers=5, 197 | ) 198 | 199 | features_extractor(args, model_clip, train_loader, val_loader, test_loader) 200 | 201 | train_dataset = FeaturesDataset( 202 | os.path.join( 203 | args.root_path, 204 | args.dataset + "_" + args.model_name + "_features_train.npz", 205 | ) 206 | ) 207 | val_dataset = FeaturesDataset( 208 | os.path.join( 209 | args.root_path, 210 | args.dataset + "_" + args.model_name + "_features_val.npz", 211 | ) 212 | ) 213 | test_dataset = FeaturesDataset( 214 | os.path.join( 215 | args.root_path, 216 | args.dataset + "_" + args.model_name + "_features_test.npz", 217 | ) 218 | ) 219 | 220 | train_loader = torch.utils.data.DataLoader( 221 | train_dataset, 222 | batch_size=args.batch_size, 223 | num_workers=5, 224 | shuffle=True, 225 | pin_memory=True, 226 | ) 227 | 228 | val_loader = torch.utils.data.DataLoader( 229 | val_dataset, 230 | batch_size=args.batch_size, 231 | num_workers=5, 232 | shuffle=True, 233 | pin_memory=True, 234 | ) 235 | 236 | test_loader = torch.utils.data.DataLoader( 237 | test_dataset, 238 | batch_size=args.batch_size, 239 | num_workers=5, 240 | shuffle=True, 241 | pin_memory=True, 242 | ) 243 | 244 | elif args.task == "lora": 245 | 246 | if args.dataset 
== "hicervix": 247 | pt_path = ( 248 | "./" 249 | + str(args.dataset) 250 | + "_" 251 | + str(args.seed) 252 | + "_" 253 | + str(args.shots) 254 | + "_" 255 | + str(level_name) 256 | + ".pt" 257 | ) 258 | 259 | if not os.path.exists(pt_path): 260 | # Doing this to save time. 261 | os.system( 262 | f"python3 dataset_hicervix.py --seed_launch {args.seed} --shots_launch {args.shots} --level_launch {args.level}" 263 | ) 264 | 265 | dataset = torch.load(pt_path, weights_only=False) 266 | else: 267 | dataset = build_dataset(args.dataset, args.root_path, args.shots) 268 | 269 | val_loader = build_data_loader( 270 | data_source=dataset.val, 271 | batch_size=256, 272 | is_train=False, 273 | tfm=preprocess, 274 | shuffle=False, 275 | num_workers=5, 276 | ) 277 | 278 | test_loader = build_data_loader( 279 | data_source=dataset.test, 280 | batch_size=256, 281 | is_train=False, 282 | tfm=preprocess, 283 | shuffle=False, 284 | num_workers=5, 285 | ) 286 | 287 | train_loader = build_data_loader( 288 | data_source=dataset.train_x, 289 | batch_size=args.batch_size, 290 | tfm=train_transform, 291 | is_train=True, 292 | shuffle=True, 293 | num_workers=5, 294 | ) 295 | 296 | elif args.task == "percentage_lora": 297 | 298 | assert args.percentage > 0, "The percentage should be greater than zero." 299 | 300 | if args.dataset == "hicervix": 301 | pt_path = ( 302 | "./" 303 | + str(args.dataset) 304 | + "_" 305 | + str(args.seed) 306 | + "_" 307 | + str(args.shots) 308 | + "_" 309 | + str(level_name) 310 | + "_" 311 | + str(args.percentage) 312 | + "_percent.pt" 313 | ) 314 | 315 | if not os.path.exists(pt_path): 316 | # Doing this to save time. 317 | os.system( 318 | f"python3 dataset_hicervix.py --seed_launch {args.seed} --shots_launch {args.shots} --level_launch {args.level} --percent_launch {args.percentage}" 319 | ) 320 | 321 | dataset = torch.load(pt_path, weights_only=False) 322 | else: 323 | print("Percentage experiment was not implemented for the other datasets.") 324 | 325 | val_loader = build_data_loader( 326 | data_source=dataset.val, 327 | batch_size=256, 328 | is_train=False, 329 | tfm=preprocess, 330 | shuffle=False, 331 | num_workers=5, 332 | ) 333 | 334 | test_loader = build_data_loader( 335 | data_source=dataset.test, 336 | batch_size=256, 337 | is_train=False, 338 | tfm=preprocess, 339 | shuffle=False, 340 | num_workers=5, 341 | ) 342 | 343 | train_loader = build_data_loader( 344 | data_source=dataset.train_x, 345 | batch_size=args.batch_size, 346 | tfm=train_transform, 347 | is_train=True, 348 | shuffle=True, 349 | num_workers=5, 350 | ) 351 | 352 | else: 353 | print("We are in the wrong situation") 354 | 355 | # Classifier experiment 356 | if args.task == "classifier": 357 | run_uni(args, model_clip, logit_scale, train_loader, val_loader, test_loader) 358 | 359 | # LoRA experiment 360 | elif args.task == "lora": 361 | run_uni_lora( 362 | args, model_clip, logit_scale, train_loader, val_loader, test_loader 363 | ) 364 | 365 | # Percentage - LoRA experiment 366 | elif args.task == "percentage_lora": 367 | run_uni_lora_percent( 368 | args, model_clip, logit_scale, train_loader, val_loader, test_loader 369 | ) 370 | 371 | else: 372 | print("Wrong task name") 373 | 374 | 375 | if __name__ == "__main__": 376 | main() 377 | -------------------------------------------------------------------------------- /datasets/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import gdown 4 | import torch 5 | import random 6 | import 
tarfile 7 | import zipfile 8 | import os.path as osp 9 | from PIL import Image 10 | import torchvision.transforms as T 11 | from collections import defaultdict 12 | from torch.utils.data import Dataset as TorchDataset 13 | 14 | 15 | def read_json(fpath): 16 | """Read json file from a path.""" 17 | with open(fpath, "r") as f: 18 | obj = json.load(f) 19 | return obj 20 | 21 | 22 | def write_json(obj, fpath): 23 | """Writes to a json file.""" 24 | if not osp.exists(osp.dirname(fpath)): 25 | os.makedirs(osp.dirname(fpath)) 26 | with open(fpath, "w") as f: 27 | json.dump(obj, f, indent=4, separators=(",", ": ")) 28 | 29 | 30 | def read_image(path): 31 | """Read image from path using ``PIL.Image``. 32 | 33 | Args: 34 | path (str): path to an image. 35 | 36 | Returns: 37 | PIL image 38 | """ 39 | if not osp.exists(path): 40 | raise IOError("No file exists at {}".format(path)) 41 | 42 | while True: 43 | try: 44 | img = Image.open(path).convert("RGB") 45 | return img 46 | except IOError: 47 | print( 48 | "Cannot read image from {}, " 49 | "probably due to heavy IO. Will re-try".format(path) 50 | ) 51 | 52 | 53 | def listdir_nohidden(path, sort=False): 54 | """List non-hidden items in a directory. 55 | 56 | Args: 57 | path (str): directory path. 58 | sort (bool): sort the items. 59 | """ 60 | items = [f for f in os.listdir(path) if not f.startswith(".") and "sh" not in f] 61 | if sort: 62 | items.sort() 63 | return items 64 | 65 | 66 | class Datum: 67 | """Data instance which defines the basic attributes. 68 | 69 | Args: 70 | impath (str): image path. 71 | label (int): class label. 72 | domain (int): domain label. 73 | classname (str): class name. 74 | """ 75 | 76 | def __init__(self, impath="", label=0, domain=-1, classname=""): 77 | assert isinstance(impath, str) 78 | assert isinstance(label, int) 79 | assert isinstance(domain, int) 80 | assert isinstance(classname, str) 81 | 82 | self._impath = impath 83 | self._label = label 84 | self._domain = domain 85 | self._classname = classname 86 | 87 | @property 88 | def impath(self): 89 | return self._impath 90 | 91 | @property 92 | def label(self): 93 | return self._label 94 | 95 | @property 96 | def domain(self): 97 | return self._domain 98 | 99 | @property 100 | def classname(self): 101 | return self._classname 102 | 103 | 104 | class DatasetBase: 105 | """A unified dataset class for 106 | 1) domain adaptation 107 | 2) domain generalization 108 | 3) semi-supervised learning 109 | """ 110 | 111 | dataset_dir = "" # the directory where the dataset is stored 112 | domains = [] # string names of all domains 113 | 114 | def __init__(self, train_x=None, train_u=None, val=None, test=None): 115 | self._train_x = train_x # labeled training data 116 | self._train_u = train_u # unlabeled training data (optional) 117 | self._val = val # validation data (optional) 118 | self._test = test # test data 119 | 120 | self._num_classes = self.get_num_classes(train_x) 121 | self._lab2cname, self._classnames = self.get_lab2cname(train_x) 122 | 123 | @property 124 | def train_x(self): 125 | return self._train_x 126 | 127 | @property 128 | def train_u(self): 129 | return self._train_u 130 | 131 | @property 132 | def val(self): 133 | return self._val 134 | 135 | @property 136 | def test(self): 137 | return self._test 138 | 139 | @property 140 | def lab2cname(self): 141 | return self._lab2cname 142 | 143 | @property 144 | def classnames(self): 145 | return self._classnames 146 | 147 | @property 148 | def num_classes(self): 149 | return self._num_classes 150 | 151 | def 
get_num_classes(self, data_source): 152 | """Count number of classes. 153 | 154 | Args: 155 | data_source (list): a list of Datum objects. 156 | """ 157 | label_set = set() 158 | 159 | for item in data_source: 160 | label_set.add(item.label) 161 | 162 | return max(label_set) + 1 163 | 164 | def get_lab2cname(self, data_source): 165 | """Get a label-to-classname mapping (dict). 166 | 167 | Args: 168 | data_source (list): a list of Datum objects. 169 | """ 170 | container = set() 171 | for item in data_source: 172 | container.add((item.label, item.classname)) 173 | mapping = {label: classname for label, classname in container} 174 | labels = list(mapping.keys()) 175 | labels.sort() 176 | classnames = [mapping[label] for label in labels] 177 | return mapping, classnames 178 | 179 | def check_input_domains(self, source_domains, target_domains): 180 | self.is_input_domain_valid(source_domains) 181 | self.is_input_domain_valid(target_domains) 182 | 183 | def is_input_domain_valid(self, input_domains): 184 | for domain in input_domains: 185 | if domain not in self.domains: 186 | raise ValueError( 187 | "Input domain must belong to {}, " 188 | "but got [{}]".format(self.domains, domain) 189 | ) 190 | 191 | def download_data(self, url, dst, from_gdrive=True): 192 | if not osp.exists(osp.dirname(dst)): 193 | os.makedirs(osp.dirname(dst)) 194 | 195 | if from_gdrive: 196 | gdown.download(url, dst, quiet=False) 197 | else: 198 | raise NotImplementedError 199 | 200 | print("Extracting file ...") 201 | 202 | try: 203 | tar = tarfile.open(dst) 204 | tar.extractall(path=osp.dirname(dst)) 205 | tar.close() 206 | except: # noqa: E722 207 | zip_ref = zipfile.ZipFile(dst, "r") 208 | zip_ref.extractall(osp.dirname(dst)) 209 | zip_ref.close() 210 | 211 | print("File extracted to {}".format(osp.dirname(dst))) 212 | 213 | def generate_pourcent_dataset(self, *data_sources, pourcentage, repeat=True): 214 | """Generate a few-shot dataset (typically for the training set). 215 | 216 | This function is useful when one wants to evaluate a model 217 | in a few-shot learning setting where each class only contains 218 | a few number of images. 219 | 220 | Args: 221 | data_sources: each individual is a list containing Datum objects. 222 | num_shots (int): number of instances per class to sample. 223 | repeat (bool): repeat images if needed. 224 | """ 225 | 226 | print("WE ARE DANS LA BONNE FONCTION") 227 | 228 | output = [] 229 | 230 | for data_source in data_sources: 231 | tracker = self.split_dataset_by_label(data_source) 232 | dataset = [] 233 | 234 | for label, items in tracker.items(): 235 | 236 | sampled_items = items 237 | 238 | random.shuffle(sampled_items) 239 | borne = int(len(sampled_items) * pourcentage / 100) 240 | data_interm = sampled_items[:borne] 241 | 242 | dataset.extend(data_interm) 243 | 244 | output.append(dataset) 245 | 246 | if len(output) == 1: 247 | return output[0] 248 | 249 | return output 250 | 251 | def generate_fewshot_dataset(self, *data_sources, num_shots=-1, repeat=True): 252 | """Generate a few-shot dataset (typically for the training set). 253 | 254 | This function is useful when one wants to evaluate a model 255 | in a few-shot learning setting where each class only contains 256 | a few number of images. 257 | 258 | Args: 259 | data_sources: each individual is a list containing Datum objects. 260 | num_shots (int): number of instances per class to sample. 261 | repeat (bool): repeat images if needed. 
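# A self-contained sketch of what generate_pourcent_dataset above does for each
# class: group items by label, shuffle, and keep the first `percentage` percent.
# The Item namedtuple and toy data below stand in for real Datum objects.
import random
from collections import defaultdict, namedtuple

Item = namedtuple("Item", ["impath", "label"])
data = [Item(f"img_{i}.png", i % 3) for i in range(30)]  # 10 items per class

by_label = defaultdict(list)
for item in data:
    by_label[item.label].append(item)

percentage = 20
subset = []
for label, items in by_label.items():
    random.shuffle(items)
    keep = int(len(items) * percentage / 100)  # 20% of 10 -> 2 items per class
    subset.extend(items[:keep])

print(len(subset))  # 6 items in total, 2 per class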
262 | """ 263 | if num_shots == -1: 264 | output = [] 265 | 266 | for data_source in data_sources: 267 | tracker = self.split_dataset_by_label(data_source) 268 | dataset = [] 269 | 270 | for label, items in tracker.items(): 271 | sampled_items = items 272 | dataset.extend(sampled_items) 273 | 274 | output.append(dataset) 275 | 276 | if len(output) == 1: 277 | return output[0] 278 | 279 | return output 280 | 281 | print(f"Creating a {num_shots}-shot dataset") 282 | 283 | if num_shots < 1: 284 | if len(data_sources) == 1: 285 | return data_sources[0] 286 | return data_sources 287 | 288 | print(f"Creating a {num_shots}-shot dataset") 289 | 290 | output = [] 291 | 292 | for data_source in data_sources: 293 | tracker = self.split_dataset_by_label(data_source) 294 | dataset = [] 295 | 296 | for label, items in tracker.items(): 297 | if len(items) >= num_shots: 298 | sampled_items = random.sample(items, num_shots) 299 | else: 300 | if repeat: 301 | sampled_items = random.choices(items, k=num_shots) 302 | else: 303 | sampled_items = items 304 | dataset.extend(sampled_items) 305 | 306 | output.append(dataset) 307 | 308 | if len(output) == 1: 309 | return output[0] 310 | 311 | return output 312 | 313 | def split_dataset_by_label(self, data_source): 314 | """Split a dataset, i.e. a list of Datum objects, 315 | into class-specific groups stored in a dictionary. 316 | 317 | Args: 318 | data_source (list): a list of Datum objects. 319 | """ 320 | output = defaultdict(list) 321 | 322 | for item in data_source: 323 | output[item.label].append(item) 324 | 325 | return output 326 | 327 | def split_dataset_by_domain(self, data_source): 328 | """Split a dataset, i.e. a list of Datum objects, 329 | into domain-specific groups stored in a dictionary. 330 | 331 | Args: 332 | data_source (list): a list of Datum objects. 
333 | """ 334 | output = defaultdict(list) 335 | 336 | for item in data_source: 337 | output[item.domain].append(item) 338 | 339 | return output 340 | 341 | 342 | class DatasetWrapper(TorchDataset): 343 | def __init__( 344 | self, 345 | data_source, 346 | input_size, 347 | transform=None, 348 | is_train=False, 349 | return_img0=False, 350 | k_tfm=1, 351 | ): 352 | self.data_source = data_source 353 | self.transform = transform # accept list (tuple) as input 354 | self.is_train = is_train 355 | # Augmenting an image K>1 times is only allowed during training 356 | self.k_tfm = k_tfm if is_train else 1 357 | self.return_img0 = return_img0 358 | 359 | if self.k_tfm > 1 and transform is None: 360 | raise ValueError( 361 | "Cannot augment the image {} times " 362 | "because transform is None".format(self.k_tfm) 363 | ) 364 | 365 | # Build transform that doesn't apply any data augmentation 366 | interp_mode = T.InterpolationMode.BICUBIC 367 | to_tensor = [] 368 | to_tensor += [T.Resize(input_size, interpolation=interp_mode)] 369 | to_tensor += [T.ToTensor()] 370 | normalize = T.Normalize( 371 | mean=(0.48145466, 0.4578275, 0.40821073), 372 | std=(0.26862954, 0.26130258, 0.27577711), 373 | ) 374 | to_tensor += [normalize] 375 | self.to_tensor = T.Compose(to_tensor) 376 | 377 | def __len__(self): 378 | return len(self.data_source) 379 | 380 | def __getitem__(self, idx): 381 | item = self.data_source[idx] 382 | 383 | output = {"label": item.label, "domain": item.domain, "impath": item.impath} 384 | 385 | img0 = read_image(item.impath) 386 | 387 | if self.transform is not None: 388 | if isinstance(self.transform, (list, tuple)): 389 | for i, tfm in enumerate(self.transform): 390 | img = self._transform_image(tfm, img0) 391 | keyname = "img" 392 | if (i + 1) > 1: 393 | keyname += str(i + 1) 394 | output[keyname] = img 395 | else: 396 | img = self._transform_image(self.transform, img0) 397 | output["img"] = img 398 | 399 | if self.return_img0: 400 | output["img0"] = self.to_tensor(img0) 401 | 402 | return output["img"], output["label"] 403 | 404 | def _transform_image(self, tfm, img0): 405 | img_list = [] 406 | 407 | for k in range(self.k_tfm): 408 | img_list.append(tfm(img0)) 409 | 410 | img = img_list 411 | if len(img) == 1: 412 | img = img[0] 413 | 414 | return img 415 | 416 | 417 | def build_data_loader( 418 | data_source=None, 419 | batch_size=64, 420 | input_size=224, 421 | tfm=None, 422 | is_train=True, 423 | shuffle=False, 424 | dataset_wrapper=None, 425 | num_workers=8, 426 | ): 427 | 428 | if dataset_wrapper is None: 429 | dataset_wrapper = DatasetWrapper 430 | 431 | # Build data loader 432 | data_loader = torch.utils.data.DataLoader( 433 | dataset_wrapper( 434 | data_source, input_size=input_size, transform=tfm, is_train=is_train 435 | ), 436 | batch_size=batch_size, 437 | num_workers=num_workers, 438 | shuffle=shuffle, 439 | drop_last=False, 440 | pin_memory=(torch.cuda.is_available()), 441 | ) 442 | assert len(data_loader) > 0 443 | 444 | return data_loader 445 | -------------------------------------------------------------------------------- /loralib/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from typing import Dict 7 | 8 | from .layers import ( 9 | LoRALayer, 10 | PlainMultiheadAttentionLoRA, 11 | SimpleAttentionLoRA, 12 | LinearLoRA, 13 | ) 14 | 15 | INDEX_POSITIONS_TEXT = { 16 | "top1": [11], 17 | "top2": [10, 11], 18 | "top3": [9, 10, 11], 19 | "bottom": [0, 1, 2, 3], 
20 | "mid": [4, 5, 6, 7], 21 | "up": [8, 9, 10, 11], 22 | "half-up": [6, 7, 8, 9, 10, 11], 23 | "half-bottom": [0, 1, 2, 3, 4, 5], 24 | "all": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], 25 | } 26 | 27 | 28 | INDEX_POSITIONS_VISION = { 29 | "ViT-B/16": { 30 | "top": [11], 31 | "top3": [9, 10, 11], 32 | "bottom": [0, 1, 2, 3], 33 | "mid": [4, 5, 6, 7], 34 | "up": [8, 9, 10, 11], 35 | "half-up": [6, 7, 8, 9, 10, 11], 36 | "half-bottom": [0, 1, 2, 3, 4, 5], 37 | "all": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], 38 | }, 39 | "ViT-B/32": { 40 | "bottom": [0, 1, 2, 3], 41 | "mid": [4, 5, 6, 7], 42 | "up": [8, 9, 10, 11], 43 | "half-up": [6, 7, 8, 9, 10, 11], 44 | "half-bottom": [0, 1, 2, 3, 4, 5], 45 | "all": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], 46 | }, 47 | "ViT-L/14": { 48 | "half-up": [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23], 49 | "half-bottom": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], 50 | "all": [ 51 | 0, 52 | 1, 53 | 2, 54 | 3, 55 | 4, 56 | 5, 57 | 6, 58 | 7, 59 | 8, 60 | 9, 61 | 10, 62 | 11, 63 | 12, 64 | 13, 65 | 14, 66 | 15, 67 | 16, 68 | 17, 69 | 18, 70 | 19, 71 | 20, 72 | 21, 73 | 22, 74 | 23, 75 | ], 76 | }, 77 | } 78 | 79 | 80 | def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none") -> None: 81 | for n, p in model.named_parameters(): 82 | if "lora_" not in n: 83 | p.requires_grad = False 84 | if bias == "none": 85 | return 86 | elif bias == "all": 87 | for n, p in model.named_parameters(): 88 | if "bias" in n: 89 | p.requires_grad = True 90 | elif bias == "lora_only": 91 | for m in model.modules(): 92 | if isinstance(m, LoRALayer) and hasattr(m, "bias") and m.bias is not None: 93 | m.bias.requires_grad = True 94 | else: 95 | raise NotImplementedError 96 | 97 | 98 | def lora_state_dict(model: nn.Module, bias: str = "none") -> Dict[str, torch.Tensor]: 99 | my_state_dict = model.state_dict() 100 | if bias == "none": 101 | return {k: my_state_dict[k] for k in my_state_dict if "lora_" in k} 102 | elif bias == "all": 103 | return { 104 | k: my_state_dict[k] for k in my_state_dict if "lora_" in k or "bias" in k 105 | } 106 | elif bias == "lora_only": 107 | to_return = {} 108 | for k in my_state_dict: 109 | if "lora_" in k: 110 | to_return[k] = my_state_dict[k] 111 | bias_name = k.split("lora_")[0] + "bias" 112 | if bias_name in my_state_dict: 113 | to_return[bias_name] = my_state_dict[bias_name] 114 | return to_return 115 | else: 116 | raise NotImplementedError 117 | 118 | 119 | def get_lora_parameters(model, bias="none"): 120 | params = [] 121 | for name, param in model.named_parameters(): 122 | if bias == "none": 123 | if "lora_" in name: 124 | params.append(param) 125 | elif bias == "all": 126 | if "lora_" in name or "bias" in name: 127 | params.append(param) 128 | elif bias == "lora_only": 129 | if "lora_" in name: 130 | params.append(param) 131 | bias_name = name.split("lora_")[0] + "bias" 132 | if bias_name in model.state_dict(): 133 | bias_param = dict(model.named_parameters())[bias_name] 134 | params.append(bias_param) 135 | else: 136 | raise NotImplementedError 137 | return params 138 | 139 | 140 | def apply_lora(args, clip_model): 141 | list_lora_layers = [] 142 | if args.encoder == "text" or args.encoder == "both": 143 | indices = INDEX_POSITIONS_TEXT[args.position] 144 | text_encoder = clip_model.transformer 145 | for i, block in enumerate(text_encoder.resblocks): 146 | # print(f"Residual Attention Block {i}: {block}") 147 | if i in indices: 148 | for name, submodule in block.named_children(): 149 | if isinstance(submodule, nn.MultiheadAttention): 150 | 
new_multi_head_lora = PlainMultiheadAttentionLoRA( 151 | submodule, 152 | enable_lora=args.params, 153 | r=args.r, 154 | lora_alpha=args.alpha, 155 | dropout_rate=args.dropout_rate, 156 | ) 157 | setattr(block, name, new_multi_head_lora) 158 | list_lora_layers.append(new_multi_head_lora) 159 | 160 | if args.encoder == "vision" or args.encoder == "both": 161 | indices = INDEX_POSITIONS_VISION[args.backbone][args.position] 162 | if args.model_name == "uni": 163 | vision_encoder = clip_model 164 | for i, block in enumerate(vision_encoder.blocks): 165 | # print(f"Residual Attention Block {i}: {block}") 166 | if i in indices: 167 | for name, submodule in block.named_children(): 168 | if "Attention" in str(type(submodule)): 169 | new_attention_lora = SimpleAttentionLoRA( 170 | submodule, 171 | enable_lora=args.params, 172 | r=args.r, 173 | lora_alpha=args.alpha, 174 | dropout_rate=args.dropout_rate, 175 | ) 176 | setattr(block, name, new_attention_lora) 177 | list_lora_layers.append(new_attention_lora) 178 | 179 | elif args.model_name == "biomedclip": 180 | vision_encoder = clip_model.visual.trunk 181 | for i, block in enumerate(vision_encoder.blocks): 182 | # print(f"Residual Attention Block {i}: {block}") 183 | if i in indices: 184 | for name, submodule in block.named_children(): 185 | if "Attention" in str(type(submodule)): 186 | new_attention_lora = SimpleAttentionLoRA( 187 | submodule, 188 | enable_lora=args.params, 189 | r=args.r, 190 | lora_alpha=args.alpha, 191 | dropout_rate=args.dropout_rate, 192 | ) 193 | setattr(block, name, new_attention_lora) 194 | list_lora_layers.append(new_attention_lora) 195 | 196 | elif args.model_name == "conch": 197 | vision_encoder = clip_model.visual.trunk 198 | for i, block in enumerate(vision_encoder.blocks): 199 | # print(f"Residual Attention Block {i}: {block}") 200 | if i in indices: 201 | for name, submodule in block.named_children(): 202 | if "Attention" in str(type(submodule)): 203 | new_attention_lora = SimpleAttentionLoRA( 204 | submodule, 205 | enable_lora=args.params, 206 | r=args.r, 207 | lora_alpha=args.alpha, 208 | dropout_rate=args.dropout_rate, 209 | ) 210 | setattr(block, name, new_attention_lora) 211 | list_lora_layers.append(new_attention_lora) 212 | 213 | elif args.model_name == "vit_google": 214 | vision_encoder = clip_model.vit.encoder 215 | for i, block in enumerate(vision_encoder.layer): 216 | for name, submodule in block.attention.attention.named_children(): 217 | for item in args.params: 218 | if item == "q" and name == "query": 219 | new_attention_lora = LinearLoRA( 220 | submodule, 221 | r=args.r, 222 | lora_alpha=args.alpha, 223 | dropout_rate=args.dropout_rate, 224 | ) 225 | setattr(block.attention.attention, name, new_attention_lora) 226 | list_lora_layers.append(new_attention_lora) 227 | elif item == "k" and name == "key": 228 | new_attention_lora = LinearLoRA( 229 | submodule, 230 | r=args.r, 231 | lora_alpha=args.alpha, 232 | dropout_rate=args.dropout_rate, 233 | ) 234 | setattr(block.attention.attention, name, new_attention_lora) 235 | list_lora_layers.append(new_attention_lora) 236 | elif item == "v" and name == "value": 237 | new_attention_lora = LinearLoRA( 238 | submodule, 239 | r=args.r, 240 | lora_alpha=args.alpha, 241 | dropout_rate=args.dropout_rate, 242 | ) 243 | setattr(block.attention.attention, name, new_attention_lora) 244 | list_lora_layers.append(new_attention_lora) 245 | for name, submodule in block.attention.output.named_children(): 246 | if "o" in args.params and name == "dense": 247 | 
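# The module-swapping mechanism apply_lora relies on above, in isolation: walk
# a block's children, detect the attention module, and setattr() a wrapped
# version in its place.  WrappedAttention is a placeholder for
# PlainMultiheadAttentionLoRA / SimpleAttentionLoRA, not the real layers.
import torch.nn as nn


class WrappedAttention(nn.Module):
    def __init__(self, attn):
        super().__init__()
        self.attn = attn  # a real LoRA layer would add low-rank A/B matrices here

    def forward(self, x):
        out, _ = self.attn(x, x, x, need_weights=False)
        return out


class Block(nn.Module):
    def __init__(self, d_model=64, n_head=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.mlp = nn.Linear(d_model, d_model)


block = Block()
for name, child in block.named_children():
    if isinstance(child, nn.MultiheadAttention):
        setattr(block, name, WrappedAttention(child))

print(type(block.attn).__name__)  # WrappedAttention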
new_attention_lora = LinearLoRA( 248 | submodule, 249 | r=args.r, 250 | lora_alpha=args.alpha, 251 | dropout_rate=args.dropout_rate, 252 | ) 253 | setattr(block.attention.attention, name, new_attention_lora) 254 | list_lora_layers.append(new_attention_lora) 255 | 256 | else: 257 | vision_encoder = clip_model.visual.transformer 258 | for i, block in enumerate(vision_encoder.resblocks): 259 | # print(f"Residual Attention Block {i}: {block}") 260 | if i in indices: 261 | for name, submodule in block.named_children(): 262 | if isinstance(submodule, nn.MultiheadAttention): 263 | new_multi_head_lora = PlainMultiheadAttentionLoRA( 264 | submodule, 265 | enable_lora=args.params, 266 | r=args.r, 267 | lora_alpha=args.alpha, 268 | dropout_rate=args.dropout_rate, 269 | ) 270 | setattr(block, name, new_multi_head_lora) 271 | list_lora_layers.append(new_multi_head_lora) 272 | return list_lora_layers 273 | 274 | 275 | def save_lora(args, list_lora_layers): 276 | weights = {} 277 | for i, layer in enumerate(list_lora_layers): 278 | layer_weights = {} 279 | if "q" in args.params: 280 | if args.model_name == "vit_google": 281 | layer_weights["q_proj"] = { 282 | "w_lora_A": layer.w_lora_A.data, 283 | "w_lora_B": layer.w_lora_B.data, 284 | } 285 | else: 286 | layer_weights["q_proj"] = { 287 | "w_lora_A": layer.q_proj.w_lora_A.data, 288 | "w_lora_B": layer.q_proj.w_lora_B.data, 289 | } 290 | if "k" in args.params: 291 | if args.model_name == "vit_google": 292 | layer_weights["k_proj"] = { 293 | "w_lora_A": layer.w_lora_A.data, 294 | "w_lora_B": layer.w_lora_B.data, 295 | } 296 | else: 297 | layer_weights["k_proj"] = { 298 | "w_lora_A": layer.k_proj.w_lora_A.data, 299 | "w_lora_B": layer.k_proj.w_lora_B.data, 300 | } 301 | if "v" in args.params: 302 | if args.model_name == "vit_google": 303 | layer_weights["v_proj"] = { 304 | "w_lora_A": layer.w_lora_A.data, 305 | "w_lora_B": layer.w_lora_B.data, 306 | } 307 | else: 308 | layer_weights["v_proj"] = { 309 | "w_lora_A": layer.v_proj.w_lora_A.data, 310 | "w_lora_B": layer.v_proj.w_lora_B.data, 311 | } 312 | 313 | if "o" in args.params: 314 | layer_weights["proj"] = { 315 | "w_lora_A": layer.proj.w_lora_A.data, 316 | "w_lora_B": layer.proj.w_lora_B.data, 317 | } 318 | 319 | weights[f"layer_{i}"] = layer_weights 320 | 321 | metadata = { 322 | "r": args.r, 323 | "alpha": args.alpha, 324 | "encoder": args.encoder, 325 | "params": args.params, 326 | "position": args.position, 327 | } 328 | 329 | save_data = {"weights": weights, "metadata": metadata} 330 | 331 | # to manage names like ViT-B/16 332 | backbone = args.backbone.replace("/", "").replace("-", "").lower() 333 | save_dir = ( 334 | f"{args.save_path}/{backbone}/{args.dataset}/{args.shots}shots/seed{args.seed}" 335 | ) 336 | os.makedirs(save_dir, exist_ok=True) 337 | 338 | save_path = f"{save_dir}/{args.filename}.pt" 339 | torch.save(save_data, save_path) 340 | print(f"LoRA weights saved to {save_path}") 341 | 342 | 343 | def load_lora(args, list_lora_layers): 344 | # to manage names like ViT-B/16 345 | backbone = args.backbone.replace("/", "").replace("-", "").lower() 346 | load_path = f"{args.save_path}/{backbone}/{args.dataset}/{args.shots}shots/seed{args.seed}/{args.filename}.pt" 347 | 348 | if not os.path.exists(load_path): 349 | raise FileNotFoundError(f"File {load_path} does not exist.") 350 | 351 | loaded_data = torch.load(load_path) 352 | 353 | metadata = loaded_data["metadata"] 354 | if metadata["r"] != args.r: 355 | raise ValueError(f"r mismatch: expected {args.r}, found {metadata['r']}") 356 | if 
metadata["alpha"] != args.alpha: 357 | raise ValueError( 358 | f"alpha mismatch: expected {args.alpha}, found {metadata['alpha']}" 359 | ) 360 | if metadata["encoder"] != args.encoder: 361 | raise ValueError( 362 | f"Encoder mismatch: expected {args.encoder}, found {metadata['encoder']}" 363 | ) 364 | if metadata["params"] != args.params: 365 | raise ValueError( 366 | f"Params mismatch: expected {args.params}, found {metadata['params']}" 367 | ) 368 | if metadata["position"] != args.position: 369 | raise ValueError( 370 | f"Position mismatch: expected {args.position}, found {metadata['position']}" 371 | ) 372 | 373 | weights = loaded_data["weights"] 374 | for i, layer in enumerate(list_lora_layers): 375 | layer_weights = weights[f"layer_{i}"] 376 | if "q" in args.params and "q_proj" in layer_weights: 377 | layer.q_proj.w_lora_A.data.copy_(layer_weights["q_proj"]["w_lora_A"]) 378 | layer.q_proj.w_lora_B.data.copy_(layer_weights["q_proj"]["w_lora_B"]) 379 | if "k" in args.params and "k_proj" in layer_weights: 380 | layer.k_proj.w_lora_A.data.copy_(layer_weights["k_proj"]["w_lora_A"]) 381 | layer.k_proj.w_lora_B.data.copy_(layer_weights["k_proj"]["w_lora_B"]) 382 | if "v" in args.params and "v_proj" in layer_weights: 383 | layer.v_proj.w_lora_A.data.copy_(layer_weights["v_proj"]["w_lora_A"]) 384 | layer.v_proj.w_lora_B.data.copy_(layer_weights["v_proj"]["w_lora_B"]) 385 | if "o" in args.params and "proj" in layer_weights: 386 | layer.proj.w_lora_A.data.copy_(layer_weights["proj"]["w_lora_A"]) 387 | layer.proj.w_lora_B.data.copy_(layer_weights["proj"]["w_lora_B"]) 388 | 389 | print(f"LoRA weights loaded from {load_path}") 390 | -------------------------------------------------------------------------------- /lora.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import numpy as np 5 | import torch.nn as nn 6 | from tqdm import tqdm 7 | from utils import cls_acc 8 | import pytorch_warmup as warmup 9 | import torch.nn.functional as F 10 | from transformers.modeling_outputs import ImageClassifierOutput 11 | from loralib.utils import ( 12 | mark_only_lora_as_trainable, 13 | apply_lora, 14 | get_lora_parameters, 15 | save_lora, 16 | ) 17 | 18 | 19 | class ClipWrapper(nn.Module): 20 | def __init__(self, clip_model, model_linear): 21 | super(ClipWrapper, self).__init__() 22 | self.clip_model = clip_model 23 | self.model_linear = model_linear 24 | 25 | def forward(self, x): 26 | # clip_model returns a tuple, we unpack it 27 | clip_output = self.clip_model(x) 28 | # Select only the first element of the tuple 29 | image_features = clip_output[0] 30 | # Apply the linear layer on the image features 31 | output = self.model_linear(image_features) 32 | return output 33 | 34 | 35 | def get_number_trainable_parameters(model_name, clip_model): 36 | n_param = np.sum([p.numel() for p in get_lora_parameters(clip_model)]) 37 | return n_param 38 | 39 | 40 | def get_feature_size(model, input_size): 41 | model.eval() 42 | 43 | # Move model to GPU if available 44 | device = next(model.parameters()).device 45 | 46 | with torch.no_grad(): 47 | # Create a sample input tensor and move it to the same device as the model 48 | sample_input = torch.randn(1, *input_size).to(device) 49 | features = model(sample_input) 50 | 51 | if isinstance(features, ImageClassifierOutput): 52 | features = features.logits 53 | 54 | return np.prod(features.size()[1:]) 55 | 56 | 57 | def evaluate_uni(args, clip_model, loader): 58 | 59 | clip_model.eval() 60 | 
61 | acc = 0.0 62 | loss_epoch = 0.0 63 | tot_samples = 0 64 | 65 | with torch.no_grad(): 66 | for i, (images, target) in enumerate(loader): 67 | images, target = images.cuda(), target.cuda() 68 | image_features = clip_model(images) 69 | 70 | if isinstance(image_features, ImageClassifierOutput): 71 | image_features = image_features.logits 72 | 73 | loss = F.cross_entropy(image_features, target) 74 | loss_epoch += loss.item() * target.shape[0] 75 | acc += cls_acc(image_features, target) * target.shape[0] 76 | tot_samples += target.shape[0] 77 | 78 | acc /= tot_samples 79 | loss_epoch /= tot_samples 80 | return acc, loss_epoch 81 | 82 | 83 | def run_uni(args, clip_model, logit_scale, train_loader, val_loader, test_loader): 84 | """Classifier experiment - backbone freezed and classification layer added on the top of it""" 85 | 86 | VALIDATION = True 87 | 88 | if args.model_name in ["vit_google"]: 89 | num_features = 768 90 | elif args.model_name in ["clip", "quilt", "biomedclip"]: 91 | num_features = 512 92 | elif args.model_name in ["uni"]: 93 | num_features = get_feature_size(clip_model, (3, 224, 224)) 94 | 95 | clip_model_ = nn.Sequential( 96 | nn.Flatten(start_dim=1), nn.Linear(num_features, args.num_classes) 97 | ).cuda() 98 | 99 | if args.textual == "True": 100 | 101 | textual_features = np.load( 102 | os.path.join( 103 | args.root_path, 104 | args.dataset + "_" + args.model_name + "_textual_train.npz", 105 | ) 106 | )["textuals"] 107 | clip_model_[1].weight.data = torch.tensor(textual_features, dtype=torch.float32) 108 | 109 | clip_model_ = clip_model_.cuda() 110 | trainable_parameters_ = [] 111 | for _, param in clip_model_.named_parameters(): 112 | trainable_parameters_.append(param) 113 | 114 | optimizer = torch.optim.AdamW( 115 | trainable_parameters_, 116 | weight_decay=1e-2, 117 | betas=(0.9, 0.999), 118 | lr=args.lr, 119 | ) 120 | 121 | num_steps = args.n_iters 122 | warmup_period = 50 123 | total_iters = warmup_period + num_steps 124 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 125 | optimizer, num_steps, eta_min=1e-6 126 | ) 127 | warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period) 128 | 129 | # training LoRA 130 | scaler = torch.cuda.amp.GradScaler() 131 | count_iters = 0 132 | 133 | while count_iters < total_iters: 134 | clip_model_.train() 135 | 136 | acc_train = 0 137 | tot_samples = 0 138 | loss_epoch = 0.0 139 | 140 | for i, (images, target) in enumerate(tqdm(train_loader)): 141 | 142 | images, target = images.cuda(), target.cuda() 143 | output = clip_model_(images) 144 | if isinstance(output, ImageClassifierOutput): 145 | output = output.logits 146 | 147 | loss = F.cross_entropy(output, target) 148 | acc_train += cls_acc(output, target) * target.shape[0] 149 | loss_epoch += loss.item() * target.shape[0] 150 | tot_samples += target.shape[0] 151 | 152 | optimizer.zero_grad() 153 | scaler.scale(loss).backward() 154 | scaler.step(optimizer) 155 | scaler.update() 156 | 157 | with warmup_scheduler.dampening(): 158 | if warmup_scheduler.last_step + 1 >= warmup_period: 159 | scheduler.step() 160 | 161 | count_iters += 1 162 | 163 | if count_iters == total_iters: 164 | break 165 | 166 | acc_train /= tot_samples 167 | loss_epoch /= tot_samples 168 | 169 | current_lr = scheduler.get_last_lr()[0] 170 | for param_group in optimizer.param_groups: 171 | optimizer_lr = param_group["lr"] 172 | print( 173 | " OptLR: {:.6f}, LR: {:.6f}, Acc: {:.4f}, Loss: {:.4f}".format( 174 | optimizer_lr, current_lr, acc_train, loss_epoch 175 | ) 176 | ) 177 | 178 | # Eval 179 
| if VALIDATION: 180 | acc_val, loss_val = evaluate_uni(args, clip_model_, val_loader) 181 | print("**** Val accuracy: {:.2f}. ****\n".format(acc_val)) 182 | 183 | acc_test, _ = evaluate_uni(args, clip_model_, test_loader) 184 | print("**** Final test accuracy: {:.2f}. ****\n".format(acc_test)) 185 | 186 | json_path = ( 187 | "/CECI/proj/medresyst/manon/US/results/US_CU/Clip-LoRA/classifier_" 188 | + str(args.dataset) 189 | + "_" 190 | + str(args.model_name) 191 | + "_" 192 | + str(args.seed) 193 | + "_" 194 | + str(args.shots) 195 | + "_" 196 | + str(args.lr) 197 | + "_" 198 | + str(args.textual) 199 | + "_results.json" 200 | ) 201 | 202 | with open( 203 | json_path, 204 | "w", 205 | ) as f: 206 | json.dump({"val_acc": acc_val, "test_acc": acc_test}, f) 207 | 208 | return 209 | 210 | 211 | def evaluate_lora_uni(args, clip_model, loader): 212 | 213 | clip_model.eval() 214 | 215 | acc = 0.0 216 | loss_epoch = 0.0 217 | tot_samples = 0 218 | with torch.no_grad(): 219 | for i, (images, target) in enumerate(loader): 220 | 221 | images, target = images.cuda(), target.cuda() 222 | 223 | if args.model_name in ["clip"]: 224 | with torch.amp.autocast(device_type="cuda", dtype=torch.float16): 225 | image_features = clip_model(images) 226 | else: 227 | image_features = clip_model(images) 228 | 229 | if isinstance(image_features, ImageClassifierOutput): 230 | image_features = image_features.logits 231 | 232 | loss = F.cross_entropy(image_features, target) 233 | loss_epoch += loss.item() * target.shape[0] 234 | acc += cls_acc(image_features, target) * target.shape[0] 235 | tot_samples += target.shape[0] 236 | 237 | acc /= tot_samples 238 | loss_epoch /= tot_samples 239 | 240 | return acc, loss_epoch 241 | 242 | 243 | def run_uni_lora(args, clip_model, logit_scale, train_loader, val_loader, test_loader): 244 | 245 | VALIDATION = True 246 | acc_val = 0.0 247 | 248 | if args.model_name in ["vit_google"]: 249 | num_features = 768 250 | elif args.model_name in ["clip", "quilt", "biomedclip"]: 251 | num_features = 512 252 | elif args.model_name in ["uni"]: 253 | num_features = get_feature_size(clip_model, (3, 224, 224)) 254 | 255 | model_linear = nn.Sequential( 256 | nn.Flatten(start_dim=1), nn.Linear(num_features, args.num_classes) 257 | ).cuda() 258 | 259 | list_lora_layers = apply_lora(args, clip_model) 260 | clip_model = clip_model.cuda() 261 | 262 | mark_only_lora_as_trainable(clip_model) 263 | trainable_parameters_ = get_lora_parameters(clip_model) 264 | for _, param in model_linear.named_parameters(): 265 | trainable_parameters_.append(param) 266 | 267 | if args.model_name in ["clip", "quilt", "biomedclip"]: 268 | clip_model_ = nn.Sequential(clip_model.visual, model_linear) 269 | elif args.model_name in ["uni"]: 270 | clip_model_ = nn.Sequential(clip_model, model_linear) 271 | elif args.model_name in ["vit_google"]: 272 | setattr(clip_model, "classifier", model_linear) 273 | clip_model_ = clip_model 274 | else: 275 | raise RuntimeError( 276 | "Wrong model name used. Try clip, uni, biomedclip, vit_google or quilt." 
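# The per-iteration schedule used by run_uni above (and by the LoRA runs),
# reduced to its skeleton: AdamW, a 50-step linear warmup from pytorch_warmup,
# then cosine annealing that is only stepped once warmup is over.  The single
# parameter and quadratic loss are placeholders for the real model and
# cross-entropy objective.
import torch
import pytorch_warmup as warmup

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=1e-3, weight_decay=1e-2, betas=(0.9, 0.999))

warmup_period, num_steps = 50, 200
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_steps, eta_min=1e-6)
warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period)

for step in range(warmup_period + num_steps):
    loss = (param ** 2).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    with warmup_scheduler.dampening():
        if warmup_scheduler.last_step + 1 >= warmup_period:
            scheduler.step()

print(scheduler.get_last_lr()[0])  # close to eta_min after the full schedule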
277 | ) 278 | 279 | optimizer = torch.optim.AdamW( 280 | trainable_parameters_, 281 | weight_decay=1e-2, 282 | betas=(0.9, 0.999), 283 | lr=args.lr, 284 | ) 285 | 286 | num_steps = args.n_iters * args.shots 287 | warmup_period = 50 288 | total_iters = num_steps 289 | 290 | if args.shots > 0: 291 | total_iters = warmup_period + num_steps 292 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 293 | optimizer, num_steps, eta_min=1e-6 294 | ) 295 | warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period) 296 | 297 | # training LoRA 298 | scaler = torch.cuda.amp.GradScaler() 299 | count_iters = 0 300 | 301 | while count_iters < total_iters: 302 | clip_model_.train() 303 | 304 | acc_train = 0 305 | tot_samples = 0 306 | loss_epoch = 0.0 307 | 308 | for i, (images, target) in enumerate(tqdm(train_loader)): 309 | 310 | images, target = images.cuda(), target.cuda() 311 | if args.model_name in ["clip"]: 312 | with torch.amp.autocast(device_type="cuda", dtype=torch.float16): 313 | output = clip_model_(images) 314 | else: 315 | output = clip_model_(images) 316 | 317 | if isinstance(output, ImageClassifierOutput): 318 | output = output.logits 319 | 320 | loss = F.cross_entropy(output, target) 321 | acc_train += cls_acc(output, target) * target.shape[0] 322 | loss_epoch += loss.item() * target.shape[0] 323 | tot_samples += target.shape[0] 324 | 325 | optimizer.zero_grad() 326 | scaler.scale(loss).backward() 327 | scaler.step(optimizer) 328 | scaler.update() 329 | 330 | with warmup_scheduler.dampening(): 331 | if warmup_scheduler.last_step + 1 >= warmup_period: 332 | scheduler.step() 333 | 334 | count_iters += 1 335 | 336 | if count_iters == total_iters: 337 | break 338 | 339 | acc_train /= tot_samples 340 | loss_epoch /= tot_samples 341 | 342 | current_lr = scheduler.get_last_lr()[0] 343 | for param_group in optimizer.param_groups: 344 | optimizer_lr = param_group["lr"] 345 | print( 346 | " OptLR: {:.6f}, LR: {:.6f}, Acc: {:.4f}, Loss: {:.4f}".format( 347 | optimizer_lr, current_lr, acc_train, loss_epoch 348 | ) 349 | ) 350 | 351 | # Eval 352 | if VALIDATION: 353 | acc_val, loss_val = evaluate_lora_uni(args, clip_model_, val_loader) 354 | print("**** Val accuracy: {:.2f}. ****\n".format(acc_val)) 355 | 356 | acc_test, _ = evaluate_lora_uni(args, clip_model_, test_loader) 357 | print("**** Final test accuracy: {:.2f}. 
****\n".format(acc_test)) 358 | 359 | json_path = ( 360 | "/CECI/proj/medresyst/manon/US/results/US_CU/Clip-LoRA/lora_" 361 | + str(args.dataset) 362 | + "_" 363 | + str(args.model_name) 364 | + "_" 365 | + str(args.seed) 366 | + "_" 367 | + str(args.shots) 368 | + "_" 369 | + str(args.lr) 370 | + "_" 371 | + str(args.r) 372 | + "_results.json" 373 | ) 374 | 375 | with open( 376 | json_path, 377 | "w", 378 | ) as f: 379 | json.dump({"val_acc": acc_val, "test_acc": acc_test}, f) 380 | 381 | # Figure confusion matrix 382 | # plt.figure(figsize=(6, 5)) 383 | # sns.heatmap(conf_matrix, annot=True, fmt="d", cmap="Blues", cbar=True) 384 | # plt.xlabel("Predicted Labels") 385 | # plt.ylabel("True Labels") 386 | # plt.title(f"Confusion Matrix - Test accuraccy {test_acc:.4f}") 387 | # plt.savefig(os.path.join(results_dir, f"{name_run}_confusion_matrix.png")) 388 | # plt.close() 389 | 390 | args.save_path = json_path.replace(".json", ".pt") 391 | save_lora(args, list_lora_layers) 392 | 393 | return 394 | 395 | 396 | def run_uni_lora_percent( 397 | args, clip_model, logit_scale, train_loader, val_loader, test_loader 398 | ): 399 | 400 | VALIDATION = True 401 | acc_val = 0.0 402 | 403 | if args.model_name in ["vit_google", "clip"]: 404 | num_features = 768 405 | elif args.model_name in ["quilt", "biomedclip"]: 406 | num_features = 512 407 | elif args.model_name in ["uni"]: 408 | num_features = get_feature_size(clip_model, (3, 224, 224)) 409 | 410 | model_linear = nn.Sequential( 411 | nn.Flatten(start_dim=1), nn.Linear(num_features, args.num_classes) 412 | ).cuda() 413 | 414 | list_lora_layers = apply_lora(args, clip_model) 415 | clip_model = clip_model.cuda() 416 | 417 | mark_only_lora_as_trainable(clip_model) 418 | trainable_parameters_ = get_lora_parameters(clip_model) 419 | for _, param in model_linear.named_parameters(): 420 | trainable_parameters_.append(param) 421 | 422 | if args.model_name in ["clip", "quilt", "biomedclip"]: 423 | clip_model_ = nn.Sequential(clip_model.visual, model_linear) 424 | elif args.model_name in ["uni"]: 425 | clip_model_ = nn.Sequential(clip_model, model_linear) 426 | elif args.model_name in ["vit_google"]: 427 | setattr(clip_model, "classifier", model_linear) 428 | clip_model_ = clip_model 429 | else: 430 | raise RuntimeError( 431 | "Wrong model name used. Try clip, uni, biomedclip, vit_google or quilt." 
432 | ) 433 | 434 | optimizer = torch.optim.AdamW( 435 | trainable_parameters_, 436 | weight_decay=1e-2, 437 | betas=(0.9, 0.999), 438 | lr=args.lr, 439 | ) 440 | 441 | num_steps = args.n_iters * len(train_loader) 442 | warmup_period = 50 443 | total_iters = warmup_period + num_steps 444 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 445 | optimizer, num_steps, eta_min=1e-6 446 | ) 447 | warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period) 448 | 449 | # training LoRA 450 | scaler = torch.cuda.amp.GradScaler() 451 | count_iters = 0 452 | 453 | while count_iters < total_iters: 454 | clip_model_.train() 455 | 456 | acc_train = 0 457 | tot_samples = 0 458 | loss_epoch = 0.0 459 | 460 | for i, (images, target) in enumerate(tqdm(train_loader)): 461 | 462 | images, target = images.cuda(), target.cuda() 463 | if args.model_name in ["clip"]: 464 | with torch.amp.autocast(device_type="cuda", dtype=torch.float16): 465 | output = clip_model_(images) 466 | else: 467 | output = clip_model_(images) 468 | 469 | if isinstance(output, ImageClassifierOutput): 470 | output = output.logits 471 | 472 | loss = F.cross_entropy(output, target) 473 | acc_train += cls_acc(output, target) * target.shape[0] 474 | loss_epoch += loss.item() * target.shape[0] 475 | tot_samples += target.shape[0] 476 | 477 | optimizer.zero_grad() 478 | scaler.scale(loss).backward() 479 | scaler.step(optimizer) 480 | scaler.update() 481 | 482 | with warmup_scheduler.dampening(): 483 | if warmup_scheduler.last_step + 1 >= warmup_period: 484 | scheduler.step() 485 | 486 | count_iters += 1 487 | 488 | if count_iters == total_iters: 489 | break 490 | 491 | acc_train /= tot_samples 492 | loss_epoch /= tot_samples 493 | 494 | current_lr = scheduler.get_last_lr()[0] 495 | for param_group in optimizer.param_groups: 496 | optimizer_lr = param_group["lr"] 497 | print( 498 | " OptLR: {:.6f}, LR: {:.6f}, Acc: {:.4f}, Loss: {:.4f}".format( 499 | optimizer_lr, current_lr, acc_train, loss_epoch 500 | ) 501 | ) 502 | 503 | # Eval 504 | if VALIDATION: 505 | acc_val, loss_val = evaluate_lora_uni(args, clip_model_, val_loader) 506 | print("**** Val accuracy: {:.2f}. ****\n".format(acc_val)) 507 | 508 | acc_test, _ = evaluate_lora_uni(args, clip_model_, test_loader) 509 | print("**** Final test accuracy: {:.2f}. ****\n".format(acc_test)) 510 | 511 | json_path = ( 512 | "/CECI/proj/medresyst/manon/US/results/US_CU/Clip-LoRA/lora_" 513 | + str(args.dataset) 514 | + "_" 515 | + str(args.model_name) 516 | + "_" 517 | + str(args.seed) 518 | + "_" 519 | + str(args.shots) 520 | + "_" 521 | + str(args.lr) 522 | + "_" 523 | + str(args.r) 524 | + str(args.percentage) 525 | + "_percent_results.json" 526 | ) 527 | 528 | with open( 529 | json_path, 530 | "w", 531 | ) as f: 532 | json.dump({"val_acc": acc_val, "test_acc": acc_test}, f) 533 | 534 | args.save_path = json_path.replace(".json", ".pt") 535 | save_lora(args, list_lora_layers) 536 | 537 | return 538 | -------------------------------------------------------------------------------- /clip/model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, inplanes, planes, stride=1): 14 | super().__init__() 15 | 16 | # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | 20 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 21 | self.bn2 = nn.BatchNorm2d(planes) 22 | 23 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 24 | 25 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 26 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 27 | 28 | self.relu = nn.ReLU(inplace=True) 29 | self.downsample = None 30 | self.stride = stride 31 | 32 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 33 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 34 | self.downsample = nn.Sequential(OrderedDict([ 35 | ("-1", nn.AvgPool2d(stride)), 36 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 37 | ("1", nn.BatchNorm2d(planes * self.expansion)) 38 | ])) 39 | 40 | def forward(self, x: torch.Tensor): 41 | identity = x 42 | 43 | out = self.relu(self.bn1(self.conv1(x))) 44 | out = self.relu(self.bn2(self.conv2(out))) 45 | out = self.avgpool(out) 46 | out = self.bn3(self.conv3(out)) 47 | 48 | if self.downsample is not None: 49 | identity = self.downsample(x) 50 | 51 | out += identity 52 | out = self.relu(out) 53 | return out 54 | 55 | 56 | class AttentionPool2d(nn.Module): 57 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 58 | super().__init__() 59 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 60 | self.k_proj = nn.Linear(embed_dim, embed_dim) 61 | self.q_proj = nn.Linear(embed_dim, embed_dim) 62 | self.v_proj = nn.Linear(embed_dim, embed_dim) 63 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 64 | self.num_heads = num_heads 65 | 66 | def forward(self, x): 67 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 68 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 69 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 70 | x, _ = F.multi_head_attention_forward( 71 | query=x, key=x, value=x, 72 | embed_dim_to_check=x.shape[-1], 73 | num_heads=self.num_heads, 74 | q_proj_weight=self.q_proj.weight, 75 | k_proj_weight=self.k_proj.weight, 76 | v_proj_weight=self.v_proj.weight, 77 | in_proj_weight=None, 78 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 79 | bias_k=None, 80 | bias_v=None, 81 | add_zero_attn=False, 82 | dropout_p=0, 83 | out_proj_weight=self.c_proj.weight, 84 | out_proj_bias=self.c_proj.bias, 85 | use_separate_proj_weight=True, 86 | training=self.training, 87 | need_weights=False 88 | ) 89 | 90 | return x[0] 91 | 92 | 93 | class ModifiedResNet(nn.Module): 94 | """ 95 | A ResNet class that is similar to torchvision's but contains the following changes: 96 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
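# The anti-aliased downsampling used by Bottleneck above, shown in isolation:
# instead of a stride-2 convolution, average-pool first and then convolve with
# stride 1 (this is what Bottleneck.avgpool and the "-1" AvgPool2d in its
# downsample branch implement).  Channel sizes here are arbitrary.
import torch
import torch.nn as nn

x = torch.randn(1, 64, 56, 56)

strided = nn.Conv2d(64, 128, 3, stride=2, padding=1)  # plain strided conv
anti_aliased = nn.Sequential(
    nn.AvgPool2d(2),                                   # blur + downsample first
    nn.Conv2d(64, 128, 3, stride=1, padding=1),        # then a stride-1 conv
)

print(strided(x).shape)       # torch.Size([1, 128, 28, 28])
print(anti_aliased(x).shape)  # torch.Size([1, 128, 28, 28])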
97 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 98 | - The final pooling layer is a QKV attention instead of an average pool 99 | """ 100 | 101 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 102 | super().__init__() 103 | self.output_dim = output_dim 104 | self.input_resolution = input_resolution 105 | 106 | # the 3-layer stem 107 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 108 | self.bn1 = nn.BatchNorm2d(width // 2) 109 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 110 | self.bn2 = nn.BatchNorm2d(width // 2) 111 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 112 | self.bn3 = nn.BatchNorm2d(width) 113 | self.avgpool = nn.AvgPool2d(2) 114 | self.relu = nn.ReLU(inplace=True) 115 | 116 | # residual layers 117 | self._inplanes = width # this is a *mutable* variable used during construction 118 | self.layer1 = self._make_layer(width, layers[0]) 119 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 120 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 121 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 122 | 123 | embed_dim = width * 32 # the ResNet feature dimension 124 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) 125 | 126 | def _make_layer(self, planes, blocks, stride=1): 127 | layers = [Bottleneck(self._inplanes, planes, stride)] 128 | 129 | self._inplanes = planes * Bottleneck.expansion 130 | for _ in range(1, blocks): 131 | layers.append(Bottleneck(self._inplanes, planes)) 132 | 133 | return nn.Sequential(*layers) 134 | 135 | def forward(self, x): 136 | def stem(x): 137 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: 138 | x = self.relu(bn(conv(x))) 139 | x = self.avgpool(x) 140 | return x 141 | 142 | x = x.type(self.conv1.weight.dtype) 143 | x = stem(x) 144 | x = self.layer1(x) 145 | x = self.layer2(x) 146 | x = self.layer3(x) 147 | x = self.layer4(x) 148 | x = self.attnpool(x) 149 | 150 | return x 151 | 152 | 153 | class LayerNorm(nn.LayerNorm): 154 | """Subclass torch's LayerNorm to handle fp16.""" 155 | 156 | def forward(self, x: torch.Tensor): 157 | orig_type = x.dtype 158 | ret = super().forward(x.type(torch.float32)) 159 | return ret.type(orig_type) 160 | 161 | 162 | class QuickGELU(nn.Module): 163 | def forward(self, x: torch.Tensor): 164 | return x * torch.sigmoid(1.702 * x) 165 | 166 | 167 | class ResidualAttentionBlock(nn.Module): 168 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 169 | super().__init__() 170 | 171 | self.attn = nn.MultiheadAttention(d_model, n_head) 172 | self.ln_1 = LayerNorm(d_model) 173 | self.mlp = nn.Sequential(OrderedDict([ 174 | ("c_fc", nn.Linear(d_model, d_model * 4)), 175 | ("gelu", QuickGELU()), 176 | ("c_proj", nn.Linear(d_model * 4, d_model)) 177 | ])) 178 | self.ln_2 = LayerNorm(d_model) 179 | self.attn_mask = attn_mask 180 | 181 | def attention(self, x: torch.Tensor): 182 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 183 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 184 | 185 | def forward(self, x: torch.Tensor): 186 | x = x + self.attention(self.ln_1(x)) 187 | x = x + self.mlp(self.ln_2(x)) 188 | return x 189 | 190 | 191 | class Transformer(nn.Module): 192 | def 
__init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 193 | super().__init__() 194 | self.width = width 195 | self.layers = layers 196 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 197 | 198 | def forward(self, x: torch.Tensor): 199 | return self.resblocks(x) 200 | 201 | 202 | class VisionTransformer(nn.Module): 203 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): 204 | super().__init__() 205 | self.input_resolution = input_resolution 206 | self.output_dim = output_dim 207 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 208 | 209 | scale = width ** -0.5 210 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 211 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 212 | self.ln_pre = LayerNorm(width) 213 | 214 | self.transformer = Transformer(width, layers, heads) 215 | 216 | self.ln_post = LayerNorm(width) 217 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 218 | 219 | def forward(self, x: torch.Tensor): 220 | x = self.conv1(x) # shape = [*, width, grid, grid] 221 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 222 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 223 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 224 | x = x + self.positional_embedding.to(x.dtype) 225 | x = self.ln_pre(x) 226 | 227 | x = x.permute(1, 0, 2) # NLD -> LND 228 | x = self.transformer(x) 229 | x = x.permute(1, 0, 2) # LND -> NLD 230 | 231 | x = self.ln_post(x[:, 0, :]) 232 | 233 | if self.proj is not None: 234 | x = x @ self.proj 235 | 236 | return x 237 | 238 | 239 | class CLIP(nn.Module): 240 | def __init__(self, 241 | embed_dim: int, 242 | # vision 243 | image_resolution: int, 244 | vision_layers: Union[Tuple[int, int, int, int], int], 245 | vision_width: int, 246 | vision_patch_size: int, 247 | # text 248 | context_length: int, 249 | vocab_size: int, 250 | transformer_width: int, 251 | transformer_heads: int, 252 | transformer_layers: int 253 | ): 254 | super().__init__() 255 | 256 | self.context_length = context_length 257 | 258 | if isinstance(vision_layers, (tuple, list)): 259 | vision_heads = vision_width * 32 // 64 260 | self.visual = ModifiedResNet( 261 | layers=vision_layers, 262 | output_dim=embed_dim, 263 | heads=vision_heads, 264 | input_resolution=image_resolution, 265 | width=vision_width 266 | ) 267 | else: 268 | vision_heads = vision_width // 64 269 | self.visual = VisionTransformer( 270 | input_resolution=image_resolution, 271 | patch_size=vision_patch_size, 272 | width=vision_width, 273 | layers=vision_layers, 274 | heads=vision_heads, 275 | output_dim=embed_dim 276 | ) 277 | 278 | self.transformer = Transformer( 279 | width=transformer_width, 280 | layers=transformer_layers, 281 | heads=transformer_heads, 282 | attn_mask=self.build_attention_mask() 283 | ) 284 | 285 | self.vocab_size = vocab_size 286 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 287 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 288 | self.ln_final = LayerNorm(transformer_width) 289 | 290 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 291 | 
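# The token bookkeeping in VisionTransformer.forward above, step by step: a
# patch_size-strided convolution yields grid**2 patch tokens, a class token is
# prepended, so the positional embedding has grid**2 + 1 rows.  Sizes follow
# the ViT-B/16 configuration.
import torch
import torch.nn as nn

width, patch_size, resolution = 768, 16, 224
conv1 = nn.Conv2d(3, width, kernel_size=patch_size, stride=patch_size, bias=False)

x = conv1(torch.randn(1, 3, resolution, resolution))        # [1, width, 14, 14]
x = x.reshape(x.shape[0], x.shape[1], -1).permute(0, 2, 1)  # [1, 196, width]

class_token = torch.zeros(1, 1, width)
x = torch.cat([class_token, x], dim=1)                      # [1, 197, width]

print(x.shape[1], (resolution // patch_size) ** 2 + 1)      # 197 197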
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 292 | 293 | self.initialize_parameters() 294 | 295 | def initialize_parameters(self): 296 | nn.init.normal_(self.token_embedding.weight, std=0.02) 297 | nn.init.normal_(self.positional_embedding, std=0.01) 298 | 299 | if isinstance(self.visual, ModifiedResNet): 300 | if self.visual.attnpool is not None: 301 | std = self.visual.attnpool.c_proj.in_features ** -0.5 302 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) 303 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) 304 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) 305 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) 306 | 307 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: 308 | for name, param in resnet_block.named_parameters(): 309 | if name.endswith("bn3.weight"): 310 | nn.init.zeros_(param) 311 | 312 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 313 | attn_std = self.transformer.width ** -0.5 314 | fc_std = (2 * self.transformer.width) ** -0.5 315 | for block in self.transformer.resblocks: 316 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 317 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 318 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 319 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 320 | 321 | if self.text_projection is not None: 322 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 323 | 324 | def build_attention_mask(self): 325 | # lazily create causal attention mask, with full attention between the vision tokens 326 | # pytorch uses additive attention mask; fill with -inf 327 | mask = torch.empty(self.context_length, self.context_length) 328 | mask.fill_(float("-inf")) 329 | mask.triu_(1) # zero out the lower diagonal 330 | return mask 331 | 332 | @property 333 | def dtype(self): 334 | return self.visual.conv1.weight.dtype 335 | 336 | def encode_image(self, image): 337 | return self.visual(image.type(self.dtype)) 338 | 339 | def encode_text(self, text): 340 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 341 | 342 | x = x + self.positional_embedding.type(self.dtype) 343 | x = x.permute(1, 0, 2) # NLD -> LND 344 | x = self.transformer(x) 345 | x = x.permute(1, 0, 2) # LND -> NLD 346 | x = self.ln_final(x).type(self.dtype) 347 | 348 | # x.shape = [batch_size, n_ctx, transformer.width] 349 | # take features from the eot embedding (eot_token is the highest number in each sequence) 350 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 351 | 352 | return x 353 | 354 | def forward(self, image, text): 355 | image_features = self.encode_image(image) 356 | text_features = self.encode_text(text) 357 | 358 | # normalized features 359 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 360 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 361 | 362 | # cosine similarity as logits 363 | logit_scale = self.logit_scale.exp() 364 | logits_per_image = logit_scale * image_features @ text_features.t() 365 | logits_per_text = logits_per_image.t() 366 | 367 | # shape = [global_batch_size, global_batch_size] 368 | return logits_per_image, logits_per_text 369 | 370 | 371 | def convert_weights(model: nn.Module): 372 | """Convert applicable model parameters to fp16""" 373 | 374 | def _convert_weights_to_fp16(l): 375 | if isinstance(l, (nn.Conv1d, 
nn.Conv2d, nn.Linear)): 376 | l.weight.data = l.weight.data.half() 377 | if l.bias is not None: 378 | l.bias.data = l.bias.data.half() 379 | 380 | if isinstance(l, nn.MultiheadAttention): 381 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 382 | tensor = getattr(l, attr) 383 | if tensor is not None: 384 | tensor.data = tensor.data.half() 385 | 386 | for name in ["text_projection", "proj"]: 387 | if hasattr(l, name): 388 | attr = getattr(l, name) 389 | if attr is not None: 390 | attr.data = attr.data.half() 391 | 392 | model.apply(_convert_weights_to_fp16) 393 | 394 | 395 | def build_model(state_dict: dict): 396 | vit = "visual.proj" in state_dict 397 | 398 | if vit: 399 | vision_width = state_dict["visual.conv1.weight"].shape[0] 400 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 401 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 402 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 403 | image_resolution = vision_patch_size * grid_size 404 | else: 405 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 406 | vision_layers = tuple(counts) 407 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 408 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 409 | vision_patch_size = None 410 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 411 | image_resolution = output_width * 32 412 | 413 | embed_dim = state_dict["text_projection"].shape[1] 414 | context_length = state_dict["positional_embedding"].shape[0] 415 | vocab_size = state_dict["token_embedding.weight"].shape[0] 416 | transformer_width = state_dict["ln_final.weight"].shape[0] 417 | transformer_heads = transformer_width // 64 418 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) 419 | 420 | model = CLIP( 421 | embed_dim, 422 | image_resolution, vision_layers, vision_width, vision_patch_size, 423 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers 424 | ) 425 | 426 | for key in ["input_resolution", "context_length", "vocab_size"]: 427 | if key in state_dict: 428 | del state_dict[key] 429 | 430 | convert_weights(model) 431 | model.load_state_dict(state_dict) 432 | return model.eval() 433 | -------------------------------------------------------------------------------- /datasets/imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import random 4 | from collections import defaultdict 5 | 6 | import torch 7 | import torchvision 8 | import torchvision.transforms as transforms 9 | 10 | import torchvision.datasets as datasets 11 | 12 | 13 | imagenet_classes = ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray", 14 | "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", 15 | "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper", 16 | "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander", 17 | "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog", 18 | "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin", 
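# The similarity computation in CLIP.forward above, on random features: L2
# normalise both modalities, then scale the cosine similarities by
# logit_scale.exp() (logit_scale is initialised to log(1/0.07)) to obtain the
# image-text logits.  Feature sizes here are arbitrary.
import numpy as np
import torch

image_features = torch.randn(4, 512)
text_features = torch.randn(3, 512)

image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)

logit_scale = torch.ones([]) * np.log(1 / 0.07)
logits_per_image = logit_scale.exp() * image_features @ text_features.t()
logits_per_text = logits_per_image.t()

print(logits_per_image.shape, logits_per_text.shape)  # [4, 3] and [3, 4]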
19 | "box turtle", "banded gecko", "green iguana", "Carolina anole", 20 | "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard", 21 | "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile", 22 | "American alligator", "triceratops", "worm snake", "ring-necked snake", 23 | "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake", 24 | "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra", 25 | "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake", 26 | "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider", 27 | "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider", 28 | "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl", 29 | "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet", 30 | "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck", 31 | "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", 32 | "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", 33 | "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", 34 | "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", 35 | "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron", 36 | "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot", 37 | "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher", 38 | "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", 39 | "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel", 40 | "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle", 41 | "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound", 42 | "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound", 43 | "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound", 44 | "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier", 45 | "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", 46 | "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier", 47 | "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier", 48 | "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer", 49 | "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier", 50 | "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", 51 | "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever", 52 | "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla", 53 | "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel", 54 | "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", 55 | "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard", 56 | "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie", 57 | "Border Collie", "Bouvier des 
Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann", 58 | "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", 59 | "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff", 60 | "French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky", 61 | "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog", 62 | "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon", 63 | "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle", 64 | "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf", 65 | "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox", 66 | "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat", 67 | "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", 68 | "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose", 69 | "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle", 70 | "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", 71 | "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper", 72 | "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly", 73 | "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly", 74 | "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit", 75 | "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse", 76 | "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", 77 | "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)", 78 | "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat", 79 | "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", 80 | "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque", 81 | "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin", 82 | "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey", 83 | "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda", 84 | "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish", 85 | "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown", 86 | "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", 87 | "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle", 88 | "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo", 89 | "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", 90 | "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel", 91 | "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)", 92 | "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini", 93 | "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet", 94 | "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", 
"bra", 95 | "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest", 96 | "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe", 97 | "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton", 98 | "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran", 99 | "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw", 100 | "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking", 101 | "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker", 102 | "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard", 103 | "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", 104 | "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed", 105 | "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", 106 | "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table", 107 | "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig", 108 | "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", 109 | "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder", 110 | "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute", 111 | "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed", 112 | "freight car", "French horn", "frying pan", "fur coat", "garbage truck", 113 | "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola", 114 | "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine", 115 | "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer", 116 | "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet", 117 | "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar", 118 | "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep", 119 | "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat", 120 | "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library", 121 | "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion", 122 | "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag", 123 | "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask", 124 | "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone", 125 | "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", 126 | "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor", 127 | "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa", 128 | "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail", 129 | "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina", 130 | "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart", 131 | "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", 
"padlock", "paintbrush", 132 | "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench", 133 | "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case", 134 | "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube", 135 | "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball", 136 | "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag", 137 | "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho", 138 | "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug", 139 | "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill", 140 | "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel", 141 | "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator", 142 | "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser", 143 | "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal", 144 | "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard", 145 | "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store", 146 | "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap", 147 | "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door", 148 | "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock", 149 | "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater", 150 | "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight", 151 | "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf", 152 | "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa", 153 | "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge", 154 | "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe", 155 | "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball", 156 | "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof", 157 | "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store", 158 | "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", 159 | "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard", 160 | "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling", 161 | "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball", 162 | "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink", 163 | "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", 164 | "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing", 165 | "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website", 166 | "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu", 167 | "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette", 168 | "bagel", "pretzel", "cheeseburger", "hot dog", "mashed 
potatoes", "cabbage", "broccoli", 169 | "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", 170 | "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange", 171 | "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate", 172 | "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito", 173 | "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef", 174 | "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player", 175 | "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn", 176 | "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom", 177 | "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"] 178 | """ 179 | imagenet_templates = ["itap of a {}.", 180 | "a bad photo of the {}.", 181 | "a origami {}.", 182 | "a photo of the large {}.", 183 | "a {} in a video game.", 184 | "art of the {}.", 185 | "a photo of the small {}."] 186 | 187 | """ 188 | imagenet_templates = ["a photo of a {}."] 189 | 190 | class ImageNet(): 191 | 192 | dataset_dir = 'imagenet' 193 | 194 | def __init__(self, root, num_shots, preprocess, train_preprocess=None, test_preprocess=None): 195 | 196 | self.dataset_dir = os.path.join(root, self.dataset_dir) 197 | self.image_dir = os.path.join(self.dataset_dir, 'images') 198 | 199 | if train_preprocess is None: 200 | train_preprocess = transforms.Compose([ 201 | transforms.RandomResizedCrop(size=224, scale=(0.08, 1), interpolation=transforms.InterpolationMode.BICUBIC), 202 | transforms.RandomHorizontalFlip(p=0.5), 203 | transforms.ToTensor(), 204 | transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)) 205 | ]) 206 | 207 | if test_preprocess is None: 208 | test_preprocess = preprocess 209 | 210 | self.train_x = datasets.ImageFolder(os.path.join(os.path.join(self.dataset_dir, 'train')), transform=train_preprocess) 211 | self.val = datasets.ImageFolder(os.path.join(os.path.join(self.dataset_dir, 'train')), transform=preprocess) 212 | self.test = datasets.ImageFolder(os.path.join(os.path.join(self.dataset_dir, 'val')), transform=test_preprocess) 213 | 214 | num_shots_val = min(4, num_shots) 215 | 216 | self.template = imagenet_templates 217 | self.classnames = imagenet_classes 218 | 219 | split_by_label_dict = defaultdict(list) 220 | for i in range(len(self.train_x.imgs)): 221 | split_by_label_dict[self.train_x.targets[i]].append(self.train_x.imgs[i]) 222 | imgs = [] 223 | targets = [] 224 | imgs_val = [] 225 | targets_val = [] 226 | for label, items in split_by_label_dict.items(): 227 | samples = random.sample(items, num_shots + num_shots_val) 228 | imgs = imgs + samples[0:num_shots] 229 | imgs_val = imgs_val + samples[num_shots:num_shots+num_shots_val] 230 | targets = targets + [label for i in range(num_shots)] 231 | targets_val = targets_val + [label for i in range(num_shots_val)] 232 | 233 | self.train_x.imgs = imgs 234 | self.train_x.targets = targets 235 | self.train_x.samples = imgs 236 | 237 | self.val.imgs = imgs_val 238 | self.val.targets = targets_val 239 | self.val.samples = imgs_val --------------------------------------------------------------------------------