├── .gitignore
├── README.md
├── data
│   ├── __init__.py
│   ├── class_files
│   │   ├── DTD
│   │   │   └── class_names.txt
│   │   ├── EuroSAT
│   │   │   └── class_names.txt
│   │   ├── FGVCAircraft
│   │   │   └── labels.txt
│   │   ├── Flowers102
│   │   │   └── class_names.txt
│   │   └── MNIST
│   │       └── labels.txt
│   ├── data_splits
│   │   ├── DTD.json
│   │   ├── EuroSAT.json
│   │   ├── FGVCAircraft.json
│   │   ├── Flowers102.json
│   │   ├── MNIST.json
│   │   └── RESICS45.json
│   ├── dataset.py
│   └── dataset_prompts.py
├── imgs
│   └── overview.png
├── methods
│   ├── __init__.py
│   ├── clip_baseline.py
│   ├── main_CLIP.py
│   ├── main_SSL.py
│   ├── main_TRZSL.py
│   ├── main_UL.py
│   ├── semi_supervised_learning
│   │   ├── __init__.py
│   │   ├── multimodal_fpl.py
│   │   ├── multimodal_prompt.py
│   │   ├── pseudo_iterative.py
│   │   ├── textual_fpl.py
│   │   ├── textual_prompt.py
│   │   ├── visual_fpl.py
│   │   └── visual_prompt.py
│   ├── transductive_zsl
│   │   ├── __init__.py
│   │   ├── multimodal_fpl.py
│   │   ├── multimodal_prompt.py
│   │   ├── pseudo_iterative.py
│   │   ├── textual_fpl.py
│   │   ├── textual_prompt.py
│   │   ├── visual_fpl.py
│   │   └── visual_prompt.py
│   └── unsupervised_learning
│       ├── __init__.py
│       ├── multimodal_fpl.py
│       ├── multimodal_prompt.py
│       ├── pseudo_iterative.py
│       ├── textual_fpl.py
│       ├── textual_prompt.py
│       ├── visual_fpl.py
│       └── visual_prompt.py
├── methods_config
│   ├── accelerate_config.yml
│   ├── accelerate_localtest_config.yml
│   ├── clip_config.yml
│   ├── evaluation_config.yml
│   ├── grip_multimodal_config.yml
│   ├── grip_textual_config.yml
│   ├── grip_visual_config.yml
│   ├── iterative_multimodal_fpl_config.yml
│   ├── iterative_textual_fpl_config.yml
│   ├── iterative_visual_fpl_config.yml
│   ├── multimodal_fpl_config.yml
│   ├── multimodal_prompt_config.yml
│   ├── textual_fpl_config.yml
│   ├── textual_prompt_config.yml
│   ├── visual_fpl_config.yml
│   └── visual_prompt_config.yml
├── models
│   ├── __init__.py
│   ├── clip_encoders.py
│   └── prompts_models.py
├── requirements.txt
├── run_main_clip.py
├── run_main_ssl.py
├── run_main_trzsl.py
├── run_main_ul.py
├── scripts
│   ├── run_clip.sh
│   ├── run_prompts_ssl.sh
│   ├── run_prompts_trzsl.sh
│   ├── run_pseudolabels_ssl.sh
│   ├── run_pseudolabels_trzsl.sh
│   └── run_pseudolabels_ul.sh
├── setup.sh
└── utils
    ├── __init__.py
    ├── clip_pseudolabels.py
    ├── compute_metrics.py
    ├── prepare_data.py
    ├── schedulers.py
    └── utils.py
/.gitignore: --------------------------------------------------------------------------------
1 | ########################
2 | # Taglets specific stuff
3 | ########################
4 |
5 | # Data
6 | taglets/sql_data
7 | test/MNIST
8 |
9 | # Saved models
10 | trained_models
11 | taglets/trained_models
12 | test/trained_models
13 |
14 | ########################
15 | # Other stuff
16 | ########################
17 |
18 | # OS X
19 | .DS_Store
20 |
21 | # Byte-compiled / optimized / DLL files
22 | __pycache__/
23 | *.py[cod]
24 | *$py.class
25 |
26 | # C extensions
27 | *.so
28 |
29 | # Distribution / packaging
30 | .Python
31 | build/
32 | develop-eggs/
33 | dist/
34 | downloads/
35 | eggs/
36 | .eggs/
37 | lib/
38 | lib64/
39 | parts/
40 | sdist/
41 | var/
42 | wheels/
43 | pip-wheel-metadata/
44 | share/python-wheels/
45 | *.egg-info/
46 | .installed.cfg
47 | *.egg
48 | MANIFEST
49 |
50 | # PyInstaller
51 | # Usually these files are written by a python script from a template
52 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
53 | *.manifest
54 | *.spec
55 |
56 | # Installer logs
57 | pip-log.txt
58 | pip-delete-this-directory.txt
59 |
60 | # Unit test / coverage reports
61 | htmlcov/
62 | .tox/
63 | .nox/
64 | .coverage
65 | .coverage.*
66 | .cache
67 | nosetests.xml
68 | coverage.xml
69 | *.cover
70 | *.py,cover
71 | .hypothesis/
72 | .pytest_cache/
73 |
74 | # Translations
75 | *.mo
76 | *.pot
77 |
78 | # Django stuff:
79 | *.log
80 | local_settings.py
81 | db.sqlite3
82 | db.sqlite3-journal
83 |
84 | # Flask stuff:
85 | instance/
86 | .webassets-cache
87 |
88 | # Scrapy stuff:
89 | .scrapy
90 |
91 | # Sphinx documentation
92 | docs/_build/
93 |
94 | # PyBuilder
95 | target/
96 |
97 | # Jupyter Notebook
98 | .ipynb_checkpoints
99 |
100 | # IPython
101 | profile_default/
102 | ipython_config.py
103 |
104 | # pyenv
105 | .python-version
106 |
107 | # pipenv
108 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
109 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
110 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
111 | # install all needed dependencies.
112 | #Pipfile.lock
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 | .idea
151 | .vscode
152 |
153 | # Data
154 | /taglets/sql_data
155 | .DS_Store
156 | /taglets/test_data
157 |
158 | # Saved models
159 | /trained_models
160 | /taglets/trained_models
161 | /test/trained_models/*
162 |
163 | # predefined
164 | predefined
165 |
166 | # Notebook
167 | API.ipynb
168 | taglets/JPL_test.ipynb
169 | taglets/task/JPL_test.ipynb
170 | api.ipynb
171 | Look_dataset.ipynb
172 | pseudolabels/
173 | results/
174 | trained_prompts/
175 | #Process_results_file.ipynb
176 | notebooks/
177 | evaluation/
178 | representations/
179 | logs/
180 |
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # Enhancing CLIP with CLIP: Exploring Pseudolabeling for Limited-Label Prompt Tuning
2 |
3 | This repository contains the code to reproduce the experiments from the paper ["Enhancing CLIP with CLIP: Exploring Pseudolabeling for Limited-Label Prompt Tuning"](http://arxiv.org/abs/2306.01669). The paper explores the effect of leveraging pseudolabels to adapt vision-language models such as CLIP to downstream tasks in a unified way across prompt modalities, learning paradigms, and training strategies.
4 |
5 |
6 | <img src="imgs/overview.png" alt="Design space">
7 |
8 |
9 |
10 | ## Table of Contents
11 |
12 | - [Setup](#setup)
13 | - [Data](#data)
14 | - [Experiments](#experiments)
15 | - [Results](#results)
16 | - [Citation](#citation)
17 |
18 | ## Setup
19 |
20 | To set up the project environment, follow these steps:
21 |
22 | 1. Ensure that you have Python version 3.7.4 installed. You can check the Python version by running the following command:
23 |
24 | ```bash
25 | python --version
26 | ```
27 |
28 | 2. Clone the repository by running the following command:
29 |
30 | ```bash
31 | git clone https://github.com/BatsResearch/menghini-enhanceCLIPwithCLIP-code.git
32 | ```
33 |
34 | 3. Navigate to the root folder and execute the `setup.sh` script to install the required dependencies, including `pytorch`. Note that we assume the installation of a CUDA-compatible version of `pytorch`, since GPUs are recommended for running the experiments. If you don't have access to GPUs, you can modify the script to remove the CUDA requirement.
35 |
36 | ```bash
37 | cd menghini-enhanceCLIPwithCLIP-code/
38 | bash setup.sh
39 | ```
40 |
41 | [Back to Table of Contents](#table-of-contents)
42 |
43 | ## Data
44 |
45 | The experiments are conducted on the following six datasets: **F**lowers102, **R**ESICS45, FGVC-**A**ircraft, **M**NIST, **E**uroSAT, and **D**TD (FRAMED). We use the train and test splits provided in the paper [ELEVATER: A Benchmark and Toolkit for Evaluating Language-Augmented Visual Models](https://openreview.net/pdf?id=hGl8rsmNXzs).
46 |
47 | To access the FRAMED dataset, you can download it [here](https://drive.google.com/file/d/1_ns7regg8dfAAGmYcmCXa5ryuJeOoug-/view?usp=share_link). After downloading, unzip the folder to obtain the required data.
48 |
49 | If you encounter any issues with the download or prefer an alternative method, you can follow these steps:
50 |
51 | 1. Download the data by following the instructions provided [here](https://github.com/Computer-Vision-in-the-Wild/DataDownload).
52 | 2. Rename the folders as follows:
53 | - `dtd/` to `DTD/`
54 | - `eurosat_clip/` to `EuroSAT/`
55 | - `fgvc-aircraft-2013b-variants102/` to `FGVCAircraft/`
56 | - `oxford-flower-102/` to `Flowers102/`
57 | - `mnist/` to `MNIST/`
58 | - `resisc45_clip/` to `RESICS45/`
59 | 3. Ensure that each folder contains the following files:
60 | - `DTD/` should contain the [`class_names.txt`](https://github.com/BatsResearch/menghini-neurips23-code/blob/main/data/class_files/DTD/class_names.txt) file
61 | - `EuroSAT/` should contain the [`class_names.txt`](https://github.com/BatsResearch/menghini-neurips23-code/blob/main/data/class_files/EuroSAT/class_names.txt) file
62 | - `FGVCAircraft/` should contain the [`labels.txt`](https://github.com/BatsResearch/menghini-neurips23-code/blob/main/data/class_files/FGVCAircraft/labels.txt) file
63 | - `Flowers102/` should contain the [`class_names.txt`](https://github.com/BatsResearch/menghini-neurips23-code/blob/main/data/class_files/Flowers102/class_names.txt) file
64 | - `MNIST/` should contain the [`labels.txt`](https://github.com/BatsResearch/menghini-neurips23-code/blob/main/data/class_files/MNIST/labels.txt) file
65 |
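As a quick sanity check after downloading, the minimal sketch below verifies that the folders and class files listed above are in place. It is not part of the repository's tooling, and the `data_dir` value is an assumption; point it at wherever you unzipped the data.

```python
# Minimal layout check for the FRAMED folders; illustrative only.
from pathlib import Path

data_dir = Path("./FRAMED")  # assumption: root of the unzipped data

class_files = {
    "DTD": "class_names.txt",
    "EuroSAT": "class_names.txt",
    "FGVCAircraft": "labels.txt",
    "Flowers102": "class_names.txt",
    "MNIST": "labels.txt",
    "RESICS45": None,  # no class file is listed for RESICS45
}

for dataset, fname in class_files.items():
    folder = data_dir / dataset
    assert folder.is_dir(), f"missing folder: {folder}"
    if fname is not None:
        assert (folder / fname).is_file(), f"missing class file: {folder / fname}"

print("FRAMED layout looks good.")
```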
66 | [Back to Table of Contents](#table-of-contents)
67 |
68 | ## Experiments
69 |
70 | Before running the experiments, create the following folders to save prompts, pseudolabels, and results:
71 |
72 | ```bash
73 | mkdir pseudolabels
74 | mkdir logs
75 | mkdir trained_prompts
76 | mkdir evaluation
77 | ```
78 |
79 | The code is organized so that, for each learning paradigm, any combination of prompt modality and training strategy can be run.
80 |
81 | ### Baselines
82 |
83 | 1. CLIP [1]
84 | ```
85 | bash scripts/run_clip.sh
86 | ```
87 | 2. Standard prompt tuning without pseudolabels: CoOp [2], VPT [3], UPT [4].
88 | - For SSL:
89 | ```
90 | bash scripts/run_prompts_ssl.sh
91 | ```
92 | - For TRZSL:
93 | ```
94 | bash scripts/run_prompts_trzsl.sh
95 | ```
96 |
97 | ### Training strategies employing pseudolabels
98 |
99 | To execute the training strategies employing pseudolabels across prompt modalities, run the following:
100 |
101 | - For SSL:
102 | ```
103 | bash scripts/run_pseudolabels_ssl.sh
104 | ```
105 |
106 | - For UL:
107 | ```
108 | bash scripts/run_pseudolabels_ul.sh
109 | ```
110 |
111 | - For TRZSL:
112 | ```
113 | bash scripts/run_pseudolabels_trzsl.sh
114 | ```
115 |
116 | Logs of the runs are saved in `logs/`.
117 | The folder `pseudolabels/` gathers the pseudolabels used for each prompt modality, learning paradigm, and training strategy. For iterative methods, we store them at each iteration.
118 | In `trained_prompts/`, we save the prompts used to make predictions. For iterative methods, we save the prompts at each iteration.
119 | The predictions of each method are stored in `evaluation/`. To reproduce a single configuration rather than a full sweep, see the sketch below.
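Each script sweeps over datasets, seeds, and backbones. As a rough sketch of what a single run looks like, the snippet below sets the environment variables read by `methods/main_CLIP.py` and invokes the CLIP baseline. The wrapper's command-line interface and the `MODEL` tag are assumptions on our part; treat the scripts in `scripts/` as the authoritative entry points.

```python
# Hypothetical single run of the CLIP baseline; illustrative only.
import os
import subprocess

env = dict(
    os.environ,
    OPTIM_SEED="1",             # optimization seed
    SPLIT_SEED="500",           # one of the split ids in data/data_splits/
    VIS_ENCODER="ViT-B/32",     # CLIP visual backbone
    DATASET_NAME="Flowers102",  # one of the FRAMED datasets
    DATASET_DIR="./FRAMED",     # assumption: root of the unzipped data
    MODEL="clip_baseline",      # assumption: tag used in log and result names
)

subprocess.run(
    ["python", "run_main_clip.py",
     "--model_config", "clip_config.yml",
     "--learning_paradigm", "ul"],
    env=env,
    check=True,
)
```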
120 |
121 | [Back to Table of Contents](#table-of-contents)
122 |
123 |
124 | [1] [Learning Transferable Visual Models From Natural Language Supervision, Radford et al. 2021](https://arxiv.org/pdf/2103.00020.pdf)
125 |
126 | [2] [Learning to prompt for vision-language models, Zhou et al. 2021](https://arxiv.org/pdf/2109.01134.pdf)
127 |
128 | [3] [Visual prompt tuning, Jia et al. 2022](https://arxiv.org/pdf/2203.12119.pdf)
129 |
130 | [4] [Unified vision and language prompt learning, Zang et al., 2022](https://arxiv.org/pdf/2210.07225.pdf)
131 |
132 | ## Results
133 |
134 | The tables below report the accuracy of CLIP, of standard prompt tuning (CoOp, VPT, UPT), and of GRIP for each prompt modality, dataset, and learning paradigm (SSL, UL, TRZSL).
135 |
136 |
137 | **Textual prompts**
138 | | | | | | | | | | | |
139 | |---------------------|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
140 | | | | **Flowers102** | | | **RESICS45** | | | **FGVCAircraft** | |
141 | | **Method** | **SSL** | **UL** | **TRZSL** | **SSL** | **UL** | **TRZSL** | **SSL** | **UL** | **TRZSL** |
142 | | CLIP [1] | 63.7 | 63.7 | 63.4 | 54.5 | 54.5 | 54.5 | **17.6** | **17.6** | 17.9 |
143 | | CoOp [2] | 76.8 | - | 63.2 | 58.5 | - | 63.4 | 14.9 | - | 21.7 |
144 | | GRIP | **83.6** | **69.8** | **86.3** | **74.1** | **70.6** | **81.1** | 17.0 | 15.2 | **26.1** |
145 | | | | **MNIST** | | | **EuroSAT** | | | **DTD** | |
146 | | CLIP [1] | 25.1 | 25.1 | 20.8 | 32.9 | 32.9 | 30.5 | 43.2 | 43.2 | 43.4 |
147 | | CoOp [2] | 56.4 | - | 21.2 | 59.5 | - | 49.7 | 37.1 | - | 46.3 |
148 | | GRIP | **71.8** | **67.9** | **74.1** | **58.7** | **57.2** | **92.3** | **56.1** | **46.1** | **65.3** |
149 |
150 | **Visual prompts**
151 | | | | | | | | | | | |
152 | |---------------------|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
153 | | | | **Flowers102** | | | **RESICS45** | | | **FGVCAircraft** | |
154 | | **Method** | **SSL** | **UL** | **TRZSL** | **SSL** | **UL** | **TRZSL** | **SSL** | **UL** | **TRZSL** |
155 | | CLIP [1] | 63.7 | **63.7** | 63.4 | 54.5 | 54.5 | 54.5 | 17.6 | **17.6** | 17.9 |
156 | | VPT [3] | 63.7 | - | 64.7 | 60.8 | - | 67.1 | 17.8 | - | 26.7 |
157 | | GRIP | **67.9** | **63.1** | **77.2** | **71.2** | **68.4** | **82.2** | **19.4** | **17.5** | **26.4** |
158 | | | | **MNIST** | | | **EuroSAT** | | | **DTD** | |
159 | | CLIP [1] | 25.1 | 25.1 | 20.8 | 32.9 | 32.9 | 30.5 | 43.2 | 43.2 | 43.4 |
160 | | VPT [3] | 42.5 | - | 25.5 | 47.1 | - | 62.2 | 36.4 | - | 44.2 |
161 | | GRIP | **69.7** | **68.0** | **69.5** | **63.5** | **63.7** | **97.0** | **54.6** | **50.5** | **62.8** |
162 |
163 | **Multimodal prompts**
164 | | | | | | | | | | | |
165 | |---------------------|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
166 | | | | **Flowers102** | | | **RESICS45** | | | **FGVCAircraft** | |
167 | | **Method** | **SSL** | **UL** | **TRZSL** | **SSL** | **UL** | **TRZSL** | **SSL** | **UL** | **TRZSL** |
168 | | CLIP [1] | 63.7 | **63.7** | 63.4 | 54.5 | 54.5 | 54.5 | **17.6** | **17.6** | **17.9** |
169 | | UPT [4] | 68.0 | - | 61.1 | 62.8 | - | 58.8 | 11.1 | - | 15.9 |
170 | | GRIP | **74.6** | **64.8** | **82.0** | **73.7** | **69.4** | **82.2** | **17.4** | 14.7 | **17.9** |
171 | | | | **MNIST** | | | **EuroSAT** | | | **DTD** | |
172 | | CLIP [1] | 25.1 | 25.1 | 20.8 | 32.9 | 32.9 | 30.5 | 43.2 | 43.2 | 43.4 |
173 | | UPT [4] | **64.4** | - | 63.6 | 68.9 | - | 60.4 | 43.7 | - | 36.9 |
174 | | GRIP | **65.9** | **68.2** | **73.8** | **60.4** | **61.5** | **95.5** | **54.1** | **47.4** | **64.4** |
175 |
176 |
177 |
178 | [Back to Table of Contents](#table-of-contents)
179 |
180 | ## Citation
181 |
182 | If you find this work helpful, please consider citing the following paper:
183 |
184 | ```
185 | @inproceedings{
186 | menghini2023enhancing,
187 | title={Enhancing {CLIP} with {CLIP}: Exploring Pseudolabeling for Limited-Label Prompt Tuning},
188 | author={Cristina Menghini and Andrew Delworth and Stephen Bach},
189 | booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
190 | year={2023},
191 | url={https://openreview.net/forum?id=2b9aY2NgXE}
192 | }
193 | ```
194 |
195 |
196 | [Back to Table of Contents](#table-of-contents)
197 |
--------------------------------------------------------------------------------
/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import ( 2 | CUB, 3 | DTD, 4 | MNIST, 5 | RESICS45, 6 | CustomDataset, 7 | EuroSAT, 8 | FGVCAircraft, 9 | Flowers102, 10 | ) 11 | from .dataset_prompts import dataset_custom_prompts 12 | -------------------------------------------------------------------------------- /data/class_files/DTD/class_names.txt: -------------------------------------------------------------------------------- 1 | banded 2 | blotchy 3 | braided 4 | bubbly 5 | bumpy 6 | chequered 7 | cobwebbed 8 | cracked 9 | crosshatched 10 | crystalline 11 | dotted 12 | fibrous 13 | flecked 14 | freckled 15 | frilly 16 | gauzy 17 | grid 18 | grooved 19 | honeycombed 20 | interlaced 21 | knitted 22 | lacelike 23 | lined 24 | marbled 25 | matted 26 | meshed 27 | paisley 28 | perforated 29 | pitted 30 | pleated 31 | polka-dotted 32 | porous 33 | potholed 34 | scaly 35 | smeared 36 | spiralled 37 | sprinkled 38 | stained 39 | stratified 40 | striped 41 | studded 42 | swirly 43 | veined 44 | waffled 45 | woven 46 | wrinkled 47 | zigzagged -------------------------------------------------------------------------------- /data/class_files/EuroSAT/class_names.txt: -------------------------------------------------------------------------------- 1 | annual crop land 2 | forest 3 | brushland or shrubland 4 | highway or road 5 | industrial buildings or commercial buildings 6 | pasture land 7 | permanent crop land 8 | residential buildings or homes or apartments 9 | river 10 | lake or sea -------------------------------------------------------------------------------- /data/class_files/FGVCAircraft/labels.txt: -------------------------------------------------------------------------------- 1 | 707-320 2 | 727-200 3 | 737-200 4 | 737-300 5 | 737-400 6 | 737-500 7 | 737-600 8 | 737-700 9 | 737-800 10 | 737-900 11 | 747-100 12 | 747-200 13 | 747-300 14 | 747-400 15 | 757-200 16 | 757-300 17 | 767-200 18 | 767-300 19 | 767-400 20 | 777-200 21 | 777-300 22 | A300B4 23 | A310 24 | A318 25 | A319 26 | A320 27 | A321 28 | A330-200 29 | A330-300 30 | A340-200 31 | A340-300 32 | A340-500 33 | A340-600 34 | A380 35 | An-12 36 | ATR-42 37 | ATR-72 38 | BAE 146-200 39 | BAE 146-300 40 | BAE-125 41 | Beechcraft 1900 42 | Boeing 717 43 | C-130 44 | C-47 45 | Cessna 172 46 | Cessna 208 47 | Cessna 525 48 | Cessna 560 49 | Challenger 600 50 | CRJ-200 51 | CRJ-700 52 | CRJ-900 53 | DC-10 54 | DC-3 55 | DC-6 56 | DC-8 57 | DC-9-30 58 | DH-82 59 | DHC-1 60 | DHC-6 61 | DHC-8-100 62 | DHC-8-300 63 | Dornier 328 64 | DR-400 65 | E-170 66 | E-190 67 | E-195 68 | EMB-120 69 | Embraer Legacy 600 70 | ERJ 135 71 | ERJ 145 72 | Eurofighter Typhoon 73 | F-16A-B 74 | F-A-18 75 | Falcon 2000 76 | Falcon 900 77 | Fokker 100 78 | Fokker 50 79 | Fokker 70 80 | Global Express 81 | Gulfstream IV 82 | Gulfstream V 83 | Hawk T1 84 | Il-76 85 | L-1011 86 | MD-11 87 | MD-80 88 | MD-87 89 | MD-90 90 | Metroliner 91 | Model B200 92 | PA-28 93 | Saab 2000 94 | Saab 340 95 | Spitfire 96 | SR-20 97 | Tornado 98 | Tu-134 99 | Tu-154 100 | Yak-42 -------------------------------------------------------------------------------- /data/class_files/Flowers102/class_names.txt: -------------------------------------------------------------------------------- 1 | pink primrose 2 | hard-leaved pocket orchid 3 | canterbury bells 4 | sweet pea 5 | english marigold 6 | tiger lily 7 | moon orchid 8 | bird of paradise 9 | monkshood 10 | globe thistle 11 | snapdragon 12 | colt's foot 13 | 
king protea 14 | spear thistle 15 | yellow iris 16 | globe flower 17 | purple coneflower 18 | peruvian lily 19 | balloon flower 20 | giant white arum lily 21 | fire lily 22 | pincushion flower 23 | fritillary 24 | red ginger 25 | grape hyacinth 26 | corn poppy 27 | prince of wales feathers 28 | stemless gentian 29 | artichoke 30 | sweet william 31 | carnation 32 | garden phlox 33 | love in the mist 34 | mexican aster 35 | alpine sea holly 36 | ruby-lipped cattleya 37 | cape flower 38 | great masterwort 39 | siam tulip 40 | lenten rose 41 | barbeton daisy 42 | daffodil 43 | sword lily 44 | poinsettia 45 | bolero deep blue 46 | wallflower 47 | marigold 48 | buttercup 49 | oxeye daisy 50 | common dandelion 51 | petunia 52 | wild pansy 53 | primula 54 | sunflower 55 | pelargonium 56 | bishop of llandaff 57 | gaura 58 | geranium 59 | orange dahlia 60 | pink and yellow dahlia 61 | cautleya spicata 62 | japanese anemone 63 | black-eyed susan 64 | silverbush 65 | californian poppy 66 | osteospermum 67 | spring crocus 68 | bearded iris 69 | windflower 70 | tree poppy 71 | gazania 72 | azalea 73 | water lily 74 | rose 75 | thorn apple 76 | morning glory 77 | passion flower 78 | lotus 79 | toad lily 80 | anthurium 81 | frangipani 82 | clematis 83 | hibiscus 84 | columbine 85 | desert-rose 86 | tree mallow 87 | magnolia 88 | cyclamen 89 | watercress 90 | canna lily 91 | hippeastrum 92 | bee balm 93 | air plant 94 | foxglove 95 | bougainvillea 96 | camellia 97 | mallow 98 | mexican petunia 99 | bromelia 100 | blanket flower 101 | trumpet creeper 102 | blackberry lily -------------------------------------------------------------------------------- /data/class_files/MNIST/labels.txt: -------------------------------------------------------------------------------- 1 | 0 2 | 1 3 | 2 4 | 3 5 | 4 6 | 5 7 | 6 8 | 7 9 | 8 10 | 9 -------------------------------------------------------------------------------- /data/data_splits/DTD.json: -------------------------------------------------------------------------------- 1 | { 2 | "split_500": { 3 | "seen": ["knitted", "pitted", "studded", "bumpy", "spiralled", "scaly", "polka-dotted", "veined", "wrinkled", "banded", "flecked", "stained", "chequered", "sprinkled", "bubbly", "grid", "lined", "crystalline", "fibrous", "meshed", "zigzagged", "pleated", "braided", "perforated", "potholed", "waffled", "dotted", "matted", "gauzy"], 4 | "unseen": ["blotchy", "smeared", "cobwebbed", "cracked", "crosshatched", "stratified", "striped", "swirly", "woven", "freckled", "frilly", "grooved", "honeycombed", "interlaced", "lacelike", "marbled", "paisley", "porous"] 5 | }, 6 | "split_0": { 7 | "seen": ["pitted", "scaly", "polka-dotted", "bumpy", "honeycombed", "fibrous", "veined", "porous", "lined", "dotted", "perforated", "potholed", "pleated", "waffled", "braided", "wrinkled", "paisley", "gauzy", "meshed", "grid", "studded", "knitted", "swirly", "crosshatched", "freckled", "chequered", "grooved", "smeared", "frilly"], 8 | "unseen": ["banded", "blotchy", "bubbly", "spiralled", "sprinkled", "cobwebbed", "cracked", "stained", "crystalline", "stratified", "striped", "flecked", "woven", "zigzagged", "interlaced", "lacelike", "marbled", "matted"] 9 | }, 10 | "split_200": { 11 | "seen": ["pitted", "pleated", "polka-dotted", "sprinkled", "grooved", "knitted", "matted", "wrinkled", "honeycombed", "chequered", "braided", "zigzagged", "spiralled", "banded", "waffled", "crosshatched", "bubbly", "smeared", "dotted", "porous", "woven", "freckled", "lined", "potholed", "lacelike", "marbled", 
"stratified", "scaly", "studded"], 12 | "unseen": ["blotchy", "bumpy", "stained", "cobwebbed", "cracked", "striped", "crystalline", "swirly", "fibrous", "flecked", "veined", "frilly", "gauzy", "grid", "interlaced", "meshed", "paisley", "perforated"] 13 | } 14 | } -------------------------------------------------------------------------------- /data/data_splits/EuroSAT.json: -------------------------------------------------------------------------------- 1 | { 2 | "split_500": { 3 | "seen": ["industrial buildings or commercial buildings", "brushland or shrubland", "lake or sea", "highway or road", "annual crop land", "pasture land"], 4 | "unseen": ["river", "forest", "permanent crop land", "residential buildings or homes or apartments"] 5 | }, 6 | "split_0": { 7 | "seen": ["brushland or shrubland", "river", "industrial buildings or commercial buildings", "lake or sea", "forest", "permanent crop land"], 8 | "unseen": ["annual crop land", "highway or road", "pasture land", "residential buildings or homes or apartments"] 9 | }, 10 | "split_200": { 11 | "seen": ["river", "highway or road", "pasture land", "permanent crop land", "forest", "residential buildings or homes or apartments"], 12 | "unseen": ["annual crop land", "lake or sea", "brushland or shrubland", "industrial buildings or commercial buildings"] 13 | } 14 | } -------------------------------------------------------------------------------- /data/data_splits/FGVCAircraft.json: -------------------------------------------------------------------------------- 1 | { 2 | "split_500": { 3 | "seen": ["Tu-134", "Spitfire", "Challenger 600", "737-700", "F-A-18", "E-170", "727-200", "A300B4", "Falcon 2000", "DR-400", "MD-87", "CRJ-700", "ERJ 145", "Falcon 900", "MD-80", "DC-10", "Il-76", "Global Express", "Gulfstream IV", "Saab 340", "Yak-42", "CRJ-900", "L-1011", "A330-200", "A321", "747-300", "DC-3", "A310", "ATR-42", "CRJ-200", "Hawk T1", "Fokker 100", "ATR-72", "PA-28", "A319", "707-320", "A318", "A320", "BAE-125", "747-200", "ERJ 135", "737-800", "SR-20", "BAE 146-300", "Beechcraft 1900", "Cessna 172", "A340-300", "EMB-120", "737-900", "737-400", "Cessna 208", "MD-90", "777-300", "A340-600", "737-600", "737-300", "DHC-1", "DC-6", "A380", "C-47", "767-200", "BAE 146-200"], 4 | "unseen": ["737-200", "737-500", "747-100", "747-400", "757-200", "757-300", "767-300", "767-400", "777-200", "A330-300", "A340-200", "A340-500", "An-12", "Boeing 717", "C-130", "Cessna 525", "Cessna 560", "DC-8", "DC-9-30", "DH-82", "DHC-6", "DHC-8-100", "DHC-8-300", "Dornier 328", "E-190", "E-195", "Embraer Legacy 600", "Eurofighter Typhoon", "F-16A-B", "Fokker 50", "Fokker 70", "Gulfstream V", "MD-11", "Metroliner", "Model B200", "Saab 2000", "Tornado", "Tu-154"] 5 | }, 6 | "split_0": { 7 | "seen": ["A321", "MD-80", "737-200", "DC-8", "Falcon 900", "Saab 340", "767-200", "F-A-18", "DC-6", "SR-20", "DC-3", "Saab 2000", "Fokker 70", "747-400", "737-700", "A340-300", "A310", "A319", "A380", "737-800", "C-47", "Dornier 328", "737-300", "Eurofighter Typhoon", "Cessna 208", "Challenger 600", "737-600", "Yak-42", "Hawk T1", "Fokker 100", "DHC-8-100", "Gulfstream IV", "Model B200", "Embraer Legacy 600", "CRJ-900", "A330-200", "767-400", "DC-9-30", "DR-400", "Falcon 2000", "727-200", "DHC-8-300", "C-130", "Boeing 717", "737-400", "757-300", "767-300", "Beechcraft 1900", "BAE 146-300", "737-500", "PA-28", "DHC-6", "707-320", "An-12", "A330-300", "CRJ-700", "747-200", "ATR-42", "A318", "DC-10", "747-100", "A340-500"], 8 | "unseen": ["737-900", "747-300", "757-200", 
"777-200", "777-300", "A300B4", "A320", "A340-200", "A340-600", "ATR-72", "BAE 146-200", "BAE-125", "Cessna 172", "Cessna 525", "Cessna 560", "CRJ-200", "DH-82", "DHC-1", "E-170", "E-190", "E-195", "EMB-120", "ERJ 135", "ERJ 145", "F-16A-B", "Fokker 50", "Global Express", "Gulfstream V", "Il-76", "L-1011", "MD-11", "MD-87", "MD-90", "Metroliner", "Spitfire", "Tornado", "Tu-134", "Tu-154"] 9 | }, 10 | "split_200": { 11 | "seen": ["An-12", "737-200", "F-16A-B", "BAE 146-200", "MD-80", "E-170", "Gulfstream IV", "DR-400", "737-900", "777-200", "Boeing 717", "747-100", "Saab 340", "Cessna 525", "Challenger 600", "MD-90", "DHC-8-100", "Cessna 172", "C-47", "747-400", "BAE-125", "MD-11", "767-300", "Cessna 560", "A330-300", "E-195", "737-500", "Fokker 50", "ATR-72", "BAE 146-300", "Fokker 70", "Falcon 900", "Falcon 2000", "Spitfire", "A340-200", "DC-3", "A340-300", "Beechcraft 1900", "A320", "Hawk T1", "E-190", "Gulfstream V", "Tu-134", "767-400", "CRJ-200", "737-400", "747-300", "Eurofighter Typhoon", "PA-28", "MD-87", "Yak-42", "DHC-1", "737-800", "A380", "Model B200", "ERJ 135", "SR-20", "737-300", "707-320", "DC-10", "Dornier 328", "A300B4"], 12 | "unseen": ["727-200", "737-600", "737-700", "747-200", "757-200", "757-300", "767-200", "777-300", "A310", "A318", "A319", "A321", "A330-200", "A340-500", "A340-600", "ATR-42", "C-130", "Cessna 208", "CRJ-700", "CRJ-900", "DC-6", "DC-8", "DC-9-30", "DH-82", "DHC-6", "DHC-8-300", "EMB-120", "Embraer Legacy 600", "ERJ 145", "F-A-18", "Fokker 100", "Global Express", "Il-76", "L-1011", "Metroliner", "Saab 2000", "Tornado", "Tu-154"] 13 | } 14 | } -------------------------------------------------------------------------------- /data/data_splits/Flowers102.json: -------------------------------------------------------------------------------- 1 | { 2 | "split_500": { 3 | "seen": ["canna lily", "petunia", "silverbush", "prince of wales feathers", "pincushion flower", "bird of paradise", "frangipani", "hard-leaved pocket orchid", "bearded iris", "passion flower", "tiger lily", "lenten rose", "cape flower", "air plant", "mexican petunia", "common dandelion", "magnolia", "foxglove", "hibiscus", "camellia", "orange dahlia", "clematis", "anthurium", "bougainvillea", "ruby-lipped cattleya", "stemless gentian", "oxeye daisy", "spring crocus", "king protea", "cyclamen", "fritillary", "californian poppy", "wild pansy", "desert-rose", "sunflower", "rose", "grape hyacinth", "pink primrose", "red ginger", "corn poppy", "watercress", "colt's foot", "blanket flower", "monkshood", "morning glory", "siam tulip", "barbeton daisy", "bolero deep blue", "carnation", "tree poppy", "globe thistle", "english marigold", "primula", "wallflower", "blackberry lily", "fire lily", "love in the mist", "moon orchid", "sweet pea", "mallow", "pelargonium", "mexican aster", "poinsettia"], 4 | "unseen": ["canterbury bells", "snapdragon", "spear thistle", "yellow iris", "globe flower", "purple coneflower", "peruvian lily", "balloon flower", "giant white arum lily", "artichoke", "sweet william", "garden phlox", "alpine sea holly", "great masterwort", "daffodil", "sword lily", "marigold", "buttercup", "bishop of llandaff", "gaura", "geranium", "pink and yellow dahlia", "cautleya spicata", "japanese anemone", "black-eyed susan", "osteospermum", "windflower", "gazania", "azalea", "water lily", "thorn apple", "lotus", "toad lily", "columbine", "tree mallow", "hippeastrum", "bee balm", "bromelia", "trumpet creeper"] 5 | }, 6 | "split_0": { 7 | "seen": ["prince of wales feathers", "air plant", 
"canterbury bells", "bishop of llandaff", "bee balm", "desert-rose", "purple coneflower", "spring crocus", "pelargonium", "windflower", "sunflower", "bougainvillea", "rose", "spear thistle", "bird of paradise", "carnation", "fritillary", "grape hyacinth", "mexican aster", "monkshood", "poinsettia", "black-eyed susan", "sweet pea", "anthurium", "wallflower", "oxeye daisy", "moon orchid", "blackberry lily", "hibiscus", "frangipani", "cautleya spicata", "camellia", "canna lily", "passion flower", "wild pansy", "stemless gentian", "balloon flower", "gaura", "thorn apple", "morning glory", "hard-leaved pocket orchid", "japanese anemone", "sword lily", "daffodil", "english marigold", "globe flower", "peruvian lily", "barbeton daisy", "siam tulip", "tiger lily", "foxglove", "pink and yellow dahlia", "pink primrose", "alpine sea holly", "artichoke", "petunia", "colt's foot", "ruby-lipped cattleya", "red ginger", "primula", "snapdragon", "garden phlox", "mexican petunia"], 8 | "unseen": ["globe thistle", "king protea", "yellow iris", "giant white arum lily", "fire lily", "pincushion flower", "corn poppy", "sweet william", "love in the mist", "cape flower", "great masterwort", "lenten rose", "bolero deep blue", "marigold", "buttercup", "common dandelion", "geranium", "orange dahlia", "silverbush", "californian poppy", "osteospermum", "bearded iris", "tree poppy", "gazania", "azalea", "water lily", "lotus", "toad lily", "clematis", "columbine", "tree mallow", "magnolia", "cyclamen", "watercress", "hippeastrum", "mallow", "bromelia", "blanket flower", "trumpet creeper"] 9 | }, 10 | "split_200": { 11 | "seen": ["oxeye daisy", "canterbury bells", "clematis", "siam tulip", "cape flower", "black-eyed susan", "air plant", "californian poppy", "globe thistle", "giant white arum lily", "cyclamen", "snapdragon", "frangipani", "buttercup", "common dandelion", "hippeastrum", "columbine", "spring crocus", "bolero deep blue", "spear thistle", "barbeton daisy", "poinsettia", "peruvian lily", "alpine sea holly", "artichoke", "sunflower", "tiger lily", "toad lily", "magnolia", "lenten rose", "great masterwort", "camellia", "mallow", "morning glory", "lotus", "sweet william", "thorn apple", "carnation", "daffodil", "corn poppy", "cautleya spicata", "marigold", "hibiscus", "tree poppy", "balloon flower", "osteospermum", "english marigold", "king protea", "azalea", "foxglove", "watercress", "blackberry lily", "bearded iris", "monkshood", "mexican aster", "orange dahlia", "water lily", "mexican petunia", "sweet pea", "pink primrose", "primula", "silverbush", "pincushion flower"], 12 | "unseen": ["hard-leaved pocket orchid", "moon orchid", "bird of paradise", "colt's foot", "yellow iris", "globe flower", "purple coneflower", "fire lily", "fritillary", "red ginger", "grape hyacinth", "prince of wales feathers", "stemless gentian", "garden phlox", "love in the mist", "ruby-lipped cattleya", "sword lily", "wallflower", "petunia", "wild pansy", "pelargonium", "bishop of llandaff", "gaura", "geranium", "pink and yellow dahlia", "japanese anemone", "windflower", "gazania", "rose", "passion flower", "anthurium", "desert-rose", "tree mallow", "canna lily", "bee balm", "bougainvillea", "bromelia", "blanket flower", "trumpet creeper"] 13 | } 14 | } -------------------------------------------------------------------------------- /data/data_splits/MNIST.json: -------------------------------------------------------------------------------- 1 | { 2 | "split_500": { 3 | "seen": ["4", "2", "9", "3", "0", "5"], 4 | "unseen": ["8", "1", 
"6", "7"] 5 | }, 6 | "split_0": { 7 | "seen": ["2", "8", "4", "9", "1", "6"], 8 | "unseen": ["0", "3", "5", "7"] 9 | }, 10 | "split_200": { 11 | "seen": ["8", "3", "5", "6", "1", "7"], 12 | "unseen": ["0", "9", "2", "4"] 13 | } 14 | } -------------------------------------------------------------------------------- /data/data_splits/RESICS45.json: -------------------------------------------------------------------------------- 1 | { 2 | "split_500": { 3 | "seen": ["beach", "palace", "roundabout", "railway station", "railway", "thermal power station", "river", "airplane", "island", "bridge", "basketball court", "desert", "runway", "ground track field", "sea ice", "sparse residential", "cloud", "dense residential", "wetland", "mountain", "meadow", "baseball diamond", "parking lot", "storage tank", "tennis court", "commercial area", "mobile home park"], 4 | "unseen": ["airport", "ship", "snowberg", "chaparral", "church", "circular farmland", "stadium", "terrace", "forest", "freeway", "golf course", "harbor", "industrial area", "intersection", "lake", "medium residential", "overpass", "rectangular farmland"] 5 | }, 6 | "split_0": { 7 | "seen": ["railway station", "snowberg", "palace", "beach", "commercial area", "mountain", "parking lot", "dense residential", "sparse residential", "rectangular farmland", "railway", "island", "tennis court", "baseball diamond", "thermal power station", "industrial area", "golf course", "meadow", "ground track field", "storage tank", "circular farmland", "forest", "bridge", "harbor", "river", "freeway", "sea ice"], 8 | "unseen": ["airplane", "airport", "roundabout", "basketball court", "runway", "ship", "chaparral", "church", "stadium", "cloud", "terrace", "desert", "wetland", "intersection", "lake", "medium residential", "mobile home park", "overpass"] 9 | }, 10 | "split_200": { 11 | "seen": ["railway", "parking lot", "wetland", "meadow", "harbor", "island", "mobile home park", "storage tank", "industrial area", "bridge", "baseball diamond", "sea ice", "runway", "airplane", "thermal power station", "circular farmland", "basketball court", "roundabout", "commercial area", "railway station", "terrace", "forest", "rectangular farmland", "lake", "medium residential", "snowberg", "river"], 12 | "unseen": ["airport", "beach", "ship", "chaparral", "church", "sparse residential", "cloud", "stadium", "dense residential", "desert", "tennis court", "freeway", "golf course", "ground track field", "intersection", "mountain", "overpass", "palace"] 13 | } 14 | } -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import torch 5 | import torch.nn as nn 6 | from PIL import Image 7 | from torch.utils.data import Dataset 8 | 9 | log = logging.getLogger(__name__) 10 | 11 | 12 | class CustomDataset(Dataset): 13 | def __init__( 14 | self, 15 | filepaths, 16 | root, 17 | transform, 18 | augmentations=None, 19 | train=True, 20 | labels=None, 21 | label_id=False, 22 | label_map=None, 23 | ): 24 | """ 25 | :param filepaths: list of images 26 | :param root: path to images 27 | :param transform: standard transform 28 | :param augmentations: None or tuple 29 | :param train: indicates in the data is in train or test folder 30 | :param labels: list of label 31 | :param label_id: true if labeles are passed as int 32 | :param label_map: dict mpping string labels to int 33 | """ 34 | # Adjust filepaths 35 | self.train = train 36 
| if self.train:
37 | self.filepaths = [f"{root}/train/{f}" for f in filepaths]
38 | else:
39 | self.filepaths = [f"{root}/test/{f}" for f in filepaths]
40 |
41 | self.transform = transform
42 | if augmentations:
43 | self.aug1_transform = augmentations[0]
44 | self.aug2_transform = augmentations[1]
45 | else:
46 | self.aug1_transform = None
47 | self.aug2_transform = None
48 | self.labels = labels
49 | self.label_id = label_id
50 | self.label_map = label_map
51 |
52 | def __len__(self):
53 | # dataset size
54 | return len(self.filepaths)
55 |
56 | def __getitem__(self, index: int):
57 | """
58 | Args:
59 | index (int): Index
60 | Returns:
61 | tuple: (image, aug1, aug2, target) where target is index of the target class.
62 | """
63 |
64 | img = Image.open(self.filepaths[index]).convert("RGB")
65 |
66 | # Apply two transformations (strong and weak)
67 | if self.aug1_transform is not None:
68 | aug_1 = self.aug1_transform(img)
69 | else:
70 | img1 = self.transform(img)
71 | aug_1 = img1
72 | if self.aug2_transform is not None:
73 | aug_2 = self.aug2_transform(img)
74 | else:
75 | img2 = self.transform(img)
76 | aug_2 = img2
77 |
78 | if self.transform is not None:
79 | img = self.transform(img)
80 |
81 | # Get image label
82 | if self.labels is not None:
83 | if self.label_id:
84 | label = int(self.labels[index])
85 | else:
86 | label = int(self.label_map[self.labels[index]])
87 | return img, aug_1, aug_2, label, self.filepaths[index].split("/")[-1]
88 | else:
89 | return img, aug_1, aug_2, self.filepaths[index].split("/")[-1]
90 |
91 |
92 |
93 | class EuroSAT(CustomDataset):
94 | def __init__(
95 | self,
96 | filepaths,
97 | root,
98 | transform,
99 | augmentations=None,
100 | train=True,
101 | labels=None,
102 | label_id=False,
103 | label_map=None,
104 | class_folder=False,
105 | original_filepaths=None,
106 | ):
107 | """
108 | :param filepaths: list of images
109 | :param root: path to images
110 | :param transform: standard transform
111 | :param augmentations: None or tuple
112 | :param train: indicates whether the data is in the train or test folder
113 | :param labels: list of labels
114 | :param label_id: true if labels are passed as int
115 | :param label_map: dict mapping string labels to int
116 | """
117 | super().__init__(
118 | filepaths,
119 | root,
120 | transform,
121 | augmentations,
122 | train,
123 | labels,
124 | label_id,
125 | label_map,
126 | )
127 | # Adjust filepaths
128 | self.filepaths = [f"{root}/{f.split('_')[0]}/{f}" for f in filepaths]
129 |
130 |
131 | class DTD(CustomDataset):
132 | def __init__(
133 | self,
134 | filepaths,
135 | root,
136 | transform,
137 | augmentations=None,
138 | train=True,
139 | labels=None,
140 | label_id=False,
141 | label_map=None,
142 | class_folder=False,
143 | original_filepaths=None,
144 | ):
145 | """
146 | :param filepaths: list of images
147 | :param root: path to images
148 | :param transform: standard transform
149 | :param augmentations: None or tuple
150 | :param train: indicates whether the data is in the train or test folder
151 | :param labels: list of labels
152 | :param label_id: true if labels are passed as int
153 | :param label_map: dict mapping string labels to int
154 | """
155 | super().__init__(
156 | filepaths,
157 | root,
158 | transform,
159 | augmentations,
160 | train,
161 | labels,
162 | label_id,
163 | label_map,
164 | )
165 | # Adjust filepaths
166 | if class_folder:
167 | paths = []
168 | for f in filepaths:
169 | cl = f.split("_")[0]
170 | tr_files = os.listdir(f"{root}/train/{cl}")
171 | val_files = os.listdir(f"{root}/val/{cl}")
172 | if f in tr_files:
173 | paths.append(f"{root}/train/{cl}/{f}")
174 | elif f in val_files:
175 | paths.append(f"{root}/val/{cl}/{f}")
176 |
177 | self.filepaths = paths
178 |
179 | else:
180 | self.filepaths = [f"{root}/{f}" for f in filepaths]
181 |
182 |
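# Added note: the dataset subclasses below differ from CustomDataset only in
# how `self.filepaths` is assembled, matching each dataset's on-disk layout
# (per-class folders, train/val folders, or flat file lists).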
183 | class CUB(CustomDataset):
184 | def __init__(
185 | self,
186 | filepaths,
187 | root,
188 | transform,
189 | augmentations=None,
190 | train=True,
191 | labels=None,
192 | label_id=False,
193 | label_map=None,
194 | class_folder=False,
195 | original_filepaths=None,
196 | ):
197 | """
198 | :param filepaths: list of images
199 | :param root: path to images
200 | :param transform: standard transform
201 | :param augmentations: None or tuple
202 | :param train: indicates whether the data is in the train or test folder
203 | :param labels: list of labels
204 | :param label_id: true if labels are passed as int
205 | :param label_map: dict mapping string labels to int
206 | """
207 | super().__init__(
208 | filepaths,
209 | root,
210 | transform,
211 | augmentations,
212 | train,
213 | labels,
214 | label_id,
215 | label_map,
216 | )
217 | # Adjust filepaths
218 | self.filepaths = [f"{root}/{f}" for f in filepaths]
219 |
220 |
221 | class RESICS45(CustomDataset):
222 | def __init__(
223 | self,
224 | filepaths,
225 | root,
226 | transform,
227 | augmentations=None,
228 | train=True,
229 | labels=None,
230 | label_id=False,
231 | label_map=None,
232 | class_folder=False,
233 | original_filepaths=None,
234 | ):
235 | """
236 | :param filepaths: list of images
237 | :param root: path to images
238 | :param transform: standard transform
239 | :param augmentations: None or tuple
240 | :param train: indicates whether the data is in the train or test folder
241 | :param labels: list of labels
242 | :param label_id: true if labels are passed as int
243 | :param label_map: dict mapping string labels to int
244 | """
245 | super().__init__(
246 | filepaths,
247 | root,
248 | transform,
249 | augmentations,
250 | train,
251 | labels,
252 | label_id,
253 | label_map,
254 | )
255 | # Adjust filepaths
256 | self.filepaths = []
257 | for f in filepaths:
258 | folder = "_".join(f.split("_")[:-1])
259 | self.filepaths.append(f"{root}/{folder}/{f}")
260 |
261 |
262 | class FGVCAircraft(CustomDataset):
263 | def __init__(
264 | self,
265 | filepaths,
266 | root,
267 | transform,
268 | augmentations=None,
269 | train=True,
270 | labels=None,
271 | label_id=False,
272 | label_map=None,
273 | class_folder=False,
274 | original_filepaths=None,
275 | ):
276 | """
277 | :param filepaths: list of images
278 | :param root: path to images
279 | :param transform: standard transform
280 | :param augmentations: None or tuple
281 | :param train: indicates whether the data is in the train or test folder
282 | :param labels: list of labels
283 | :param label_id: true if labels are passed as int
284 | :param label_map: dict mapping string labels to int
285 | """
286 | super().__init__(
287 | filepaths,
288 | root,
289 | transform,
290 | augmentations,
291 | train,
292 | labels,
293 | label_id,
294 | label_map,
295 | )
296 | if class_folder:
297 | filepaths = list(filepaths)
298 | new_paths = []
299 | for f in original_filepaths:
300 | img = f.split("/")[-1]
301 | if img in filepaths:
302 | new_paths.append(f"{f}")
303 |
304 | self.filepaths = new_paths
305 | else:
306 | # Adjust filepaths
307 | self.filepaths = [f"{root}/{f}" for f in filepaths]
308 |
309 |
310 | class MNIST(CustomDataset):
311 | def __init__(
312 | self,
313 | filepaths,
314 | root,
315 |
transform,
316 | augmentations=None,
317 | train=True,
318 | labels=None,
319 | label_id=False,
320 | label_map=None,
321 | class_folder=False,
322 | original_filepaths=None,
323 | ):
324 | """
325 | :param filepaths: list of images
326 | :param root: path to images
327 | :param transform: standard transform
328 | :param augmentations: None or tuple
329 | :param train: indicates whether the data is in the train or test folder
330 | :param labels: list of labels
331 | :param label_id: true if labels are passed as int
332 | :param label_map: dict mapping string labels to int
333 | """
334 | super().__init__(
335 | filepaths,
336 | root,
337 | transform,
338 | augmentations,
339 | train,
340 | labels,
341 | label_id,
342 | label_map,
343 | )
344 | if class_folder:
345 | filepaths = list(filepaths)
346 | new_paths = []
347 | for f in original_filepaths:
348 | img = f.split("/")[-1]
349 | if img in filepaths:
350 | new_paths.append(f"{f}")
351 |
352 | self.filepaths = new_paths
353 | else:
354 | # Adjust filepaths
355 | self.filepaths = [f"{root}/{f}" for f in filepaths]
356 |
357 |
358 | class Flowers102(CustomDataset):
359 | def __init__(
360 | self,
361 | filepaths,
362 | root,
363 | transform,
364 | augmentations=None,
365 | train=True,
366 | labels=None,
367 | label_id=False,
368 | label_map=None,
369 | class_folder=False,
370 | original_filepaths=None,
371 | ):
372 | """
373 | :param filepaths: list of images
374 | :param root: path to images
375 | :param transform: standard transform
376 | :param augmentations: None or tuple
377 | :param train: indicates whether the data is in the train or test folder
378 | :param labels: list of labels
379 | :param label_id: true if labels are passed as int
380 | :param label_map: dict mapping string labels to int
381 | """
382 | super().__init__(
383 | filepaths,
384 | root,
385 | transform,
386 | augmentations,
387 | train,
388 | labels,
389 | label_id,
390 | label_map,
391 | )
392 | # Adjust filepaths
393 | if class_folder:
394 | filepaths = list(filepaths)
395 | new_paths = []
396 | for f in original_filepaths:
397 | img = f.split("/")[-1]
398 | if img in filepaths:
399 | new_paths.append(f"{f}")
400 |
401 | self.filepaths = new_paths
402 |
403 | else:
404 | self.filepaths = [f"{root}/{f}" for f in filepaths]
405 |
-------------------------------------------------------------------------------- /data/dataset_prompts.py: --------------------------------------------------------------------------------
1 | dataset_custom_prompts = {
2 | 'EuroSAT' : 'a photo of a {}', # 'a centered satellite photo of a {}',
3 | 'DTD' : 'a photo of a {}', # 'a photo of a {} texture',
4 | 'RESICS45' : 'a photo of a {}', # 'satellite imagery of a {}',
5 | 'FGVCAircraft' : 'a photo of a {}', # 'a photo of a {}, a type of aircraft',
6 | 'MNIST' : 'a photo of a {}', # 'a photo of the number: "{}"',
7 | 'Flowers102' : 'a photo of a {}', # 'a photo of a {}, a type of flower',
8 | }
-------------------------------------------------------------------------------- /imgs/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BatsResearch/menghini-neurips23-code/b387624c15638d9db8024135cafe386089fbe09f/imgs/overview.png -------------------------------------------------------------------------------- /methods/__init__.py: --------------------------------------------------------------------------------
1 | from .clip_baseline import ClipBaseline
-------------------------------------------------------------------------------- /methods/clip_baseline.py: --------------------------------------------------------------------------------
1 | import logging
2 |
3 | import clip
4 | import numpy as np
5 | import pandas as pd
6 | import torch
7 | from PIL import Image
8 | from tqdm import tqdm
9 |
10 |
11 | g = torch.Generator()
12 | g.manual_seed(0)
13 |
14 | log = logging.getLogger(__name__)
15 |
16 |
17 | class ClipBaseline(object):
18 |     def __init__(
19 |         self, config, label_to_idx, classes, seen_classes, unseen_classes, device
20 |     ):
21 |         """This class implements the zero-shot CLIP baseline.
22 |
23 |         :param config: class object with model configurations in the
24 |                        file methods_config/clip_config.yml
25 |         :param label_to_idx: dictionary (key, value):(class name, id)
26 |         :param classes: list of class names
27 |         :param seen_classes: list of seen classes' names
28 |         :param unseen_classes: list of unseen classes' names
29 |         :param device: device in use
30 |
31 |         """
32 |         self.config = config
33 |         self.classes = classes
34 |         self.seen_classes = seen_classes
35 |         self.unseen_classes = unseen_classes
36 |         self.label_to_idx = label_to_idx
37 |
38 |         self.device = device
39 |         self.model, self.transform = clip.load(
40 |             self.config.VIS_ENCODER, device=self.device
41 |         )
42 |         self.template = self.config.PROMPT_TEMPLATE
43 |
44 |     def test_predictions(self, data):
45 |         """
46 |         :param data: test dataset
47 |         """
48 |
49 |         # Declare data pre-processing
50 |         data.transform = self.transform
51 |         # Define the data loader
52 |         test_loader = torch.utils.data.DataLoader(
53 |             data, batch_size=self.config.BATCH_SIZE
54 |         )
55 |
56 |         # Build textual prompts
57 |         prompts = [
58 |             self.template.format(" ".join(i.split("_"))) for i in self.classes
59 |         ]
60 |         log.info(f"Number of prompts: {len(prompts)}")
61 |
62 |         # Encode text
63 |         text = clip.tokenize(prompts).to(self.device)
64 |         text_features = self.model.encode_text(text)
65 |
66 |         log.info("Start inference for test data")
67 |         predictions = []
68 |         images = []
69 |         prob_preds = []
70 |         # Iterate through data loader
71 |         for img, _, _, img_path in tqdm(test_loader):
72 |             with torch.no_grad():
73 |                 logits_per_image, logits_per_text = self.model(
74 |                     img.to(self.device), text.to(self.device)
75 |                 )
76 |                 probs = logits_per_image.softmax(dim=-1)
77 |                 idx_preds = torch.argmax(probs, dim=1)
78 |
79 |                 predictions += [self.classes[i] for i in idx_preds]
80 |                 images += [i for i in img_path]
81 |                 prob_preds += [logits_per_image]
82 |
83 |         prob_preds = torch.cat(prob_preds, axis=0).detach().to('cpu')
84 |         df_predictions = pd.DataFrame({"id": images, "class": predictions})
85 |
86 |         return df_predictions, images, predictions, prob_preds
87 |
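# Illustrative usage sketch (added; not part of the original file). It mirrors
# how methods/main_CLIP.py wires up ClipBaseline; the config values and the
# MNIST file name below are assumptions.
if __name__ == "__main__":
    from types import SimpleNamespace

    from data import CustomDataset

    cfg = SimpleNamespace(
        VIS_ENCODER="ViT-B/32",             # any backbone accepted by clip.load
        PROMPT_TEMPLATE="a photo of a {}",  # template from data/dataset_prompts.py
        BATCH_SIZE=64,
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"

    classes = [str(d) for d in range(10)]  # MNIST class names
    label_to_idx = {c: i for i, c in enumerate(classes)}
    model = ClipBaseline(
        cfg, label_to_idx,
        classes=classes, seen_classes=classes[:6], unseen_classes=classes[6:],
        device=device,
    )

    # Hypothetical file name; CustomDataset prepends `<root>/test/` to each entry.
    test_set = CustomDataset(["img_0.png"], "FRAMED/MNIST", transform=None, train=False)
    df_preds, images, predictions, logits = model.test_predictions(test_set)
    print(df_preds.head())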
-------------------------------------------------------------------------------- /methods/main_CLIP.py: --------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import logging
4 | import os
5 | import random
6 | import sys
7 | from collections import defaultdict
8 | from logging import StreamHandler
9 | from pathlib import Path
10 |
11 | import numpy as np
12 | import pandas as pd
13 | import requests
14 | import scipy.stats as st
15 | import torch
16 | import yaml
17 | from accelerate import Accelerator
18 | from requests.adapters import HTTPAdapter
19 | from torch import nn
20 | from urllib3.util import Retry
21 |
22 | from data import CustomDataset, dataset_custom_prompts
23 | from methods import ClipBaseline
24 | from utils import (
25 |     Config,
26 |     dataset_object,
27 |     evaluate_predictions,
28 |     get_class_names,
29 |     get_labeled_and_unlabeled_data,
30 |     save_parameters,
31 |     save_predictions,
32 |     store_results,
33 | )
34 |
35 | accelerator = Accelerator()
36 |
37 | logger_ = logging.getLogger()
38 | logger_.level = logging.INFO
39 | formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s")
40 |
41 |
42 | class AccelerateHandler(StreamHandler):
43 |     def __init__(self, stream):
44 |         super().__init__(stream)
45 |
46 |     def emit(self, record):
47 |         if accelerator.is_local_main_process:
48 |             super().emit(record)
49 |
50 |
51 | stream_handler = AccelerateHandler(sys.stdout)
52 | stream_handler.setLevel(logging.INFO)
53 | stream_handler.setFormatter(formatter)
54 | logger_.addHandler(stream_handler)
55 |
56 | log = logging.getLogger(__name__)
57 |
58 | def workflow(dataset_dir, obj_conf):
59 |
60 |     dataset = obj_conf.DATASET_NAME
61 |     # Get class names of target task
62 |     classes, seen_classes, unseen_classes = get_class_names(
63 |         dataset,
64 |         dataset_dir,
65 |         obj_conf.SPLIT_SEED
66 |     )
67 |     # Create dict of class lists to pass as a variable
68 |     dict_classes = {
69 |         "classes": classes,
70 |         "seen_classes": seen_classes,
71 |         "unseen_classes": unseen_classes,
72 |     }
73 |     # Log number of classes
74 |     log.info(f"\n----------------------DATA INFO-----------------------\n")
75 |     log.info(f"Number of classes split {obj_conf.SPLIT_SEED}: {len(classes)}")
76 |     log.info(f"Number of seen classes split {obj_conf.SPLIT_SEED}: {len(seen_classes)}")
77 |     log.info(f"Number of unseen classes split {obj_conf.SPLIT_SEED}: {len(unseen_classes)}")
78 |     # Path for images
79 |     data_folder = f"{dataset_dir}/{dataset}"
80 |     log.info(f"Data folder: {data_folder}")
81 |     log.info(f"\n-------------------------------------------------------------\n")
82 |
83 |     # Get labeled data (seen classes)
84 |     # Get unlabeled data (unseen classes)
85 |     # Get test data (both seen and unseen classes)
86 |     _, _, test_data = get_labeled_and_unlabeled_data(
87 |         dataset, data_folder, seen_classes, unseen_classes, classes
88 |     )
89 |
90 |     # Create datasets
91 |     test_labeled_files, test_labels = zip(*test_data)
92 |     label_to_idx = {c: idx for idx, c in enumerate(classes)}
93 |
94 |     DatasetObject = dataset_object(obj_conf.DATASET_NAME)
95 |
96 |     # Test set (test seen and unseen)
97 |     test_dataset = DatasetObject(
98 |         test_labeled_files,
99 |         data_folder,
100 |         transform=None,
101 |         augmentations=None,
102 |         train=False,
103 |         labels=None,
104 |         label_map=label_to_idx,
105 |     )
106 |     # Log info data
107 |     log.info(f"\n----------------------TEST DATA INFO-----------------------\n")
108 |     log.info(f"Len test data: {len(test_dataset.filepaths)}")
109 |     log.info(f"\n-------------------------------------------------------------\n")
110 |     # Define model
111 |     device = "cuda" if torch.cuda.is_available() else "cpu"
112 |
113 |
114 |     log.info(f"\n----------------------MODEL INFO-----------------------\n")
115 |     log.info(f"The model in use is: {obj_conf.MODEL}")
116 |     model = ClipBaseline(obj_conf, label_to_idx, device=device, **dict_classes)
117 |
118 |     # Validate on test set (standard)
119 |     std_predictions, images, predictions, prob_preds = model.test_predictions(test_dataset)
120 |     # Evaluate predictions (standard)
121 |     std_response = evaluate_predictions(
122 |         obj_conf,
123 |         std_predictions,
124 |         test_labeled_files,
125 |         test_labels,
126 |         unseen_classes,
127 |         seen_classes
128 |     )
129 |     log.info(f"ZSL accuracy: {std_response}")
130 |
131 |     # Store model results
132 |     store_results(obj_conf, std_response)
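    # Bundle image ids, hard predictions, ground-truth labels, and raw logits
    # so save_predictions can write them out (the README notes predictions
    # land in evaluation/).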
133 |
134 |     dictionary_predictions = {
135 |         'images' : images,
136 |         'predictions' : predictions,
137 |         'labels' : test_labels,
138 |         'logits' : prob_preds,
139 |     }
140 |
141 |     save_predictions(dictionary_predictions, obj_conf, iteration=None)
142 |
143 |
144 | def main():
145 |     parser = argparse.ArgumentParser(description="Run the CLIP baseline")
146 |     parser.add_argument(
147 |         "--model_config",
148 |         type=str,
149 |         default="model_config.yml",
150 |         help="Name of model config file",
151 |     )
152 |     parser.add_argument(
153 |         "--learning_paradigm",
154 |         type=str,
155 |         default="trzsl",
156 |         help="Choose among trzsl, ssl, and ul",
157 |     )
158 |
159 |     args = parser.parse_args()
160 |
161 |     with open(f"methods_config/{args.model_config}", "r") as file:
162 |         config = yaml.safe_load(file)
163 |
164 |     # Cast configs to object
165 |     obj_conf = Config(config)
166 |
167 |     # Set optimization seed
168 |     optim_seed = int(os.environ["OPTIM_SEED"])
169 |     obj_conf.OPTIM_SEED = optim_seed
170 |     # Set backbone
171 |     obj_conf.VIS_ENCODER = os.environ["VIS_ENCODER"]
172 |     # Set dataset name
173 |     obj_conf.DATASET_NAME = os.environ["DATASET_NAME"]
174 |     # Set dataset dir
175 |     obj_conf.DATASET_DIR = os.environ["DATASET_DIR"]
176 |     # Set model name
177 |     obj_conf.MODEL = os.environ["MODEL"]
178 |     # Set split seed
179 |     obj_conf.SPLIT_SEED = int(os.environ["SPLIT_SEED"])
180 |     # Set dataset's template for textual prompts
181 |     obj_conf.PROMPT_TEMPLATE = dataset_custom_prompts[obj_conf.DATASET_NAME]
182 |     # Set data dir
183 |     dataset_dir = obj_conf.DATASET_DIR
184 |     # Set learning paradigm
185 |     obj_conf.LEARNING_PARADIGM = args.learning_paradigm
186 |
187 |     # Set the file path for the log file
188 |     log_file = f"logs/{obj_conf.DATASET_NAME}_{obj_conf.MODEL}_{obj_conf.VIS_ENCODER.replace('/', '-')}.log"
189 |     # Create a FileHandler and set the log file
190 |     file_handler = logging.FileHandler(log_file)
191 |     file_handler.setFormatter(formatter)
192 |     # Add the FileHandler to the logger
193 |     logger_.addHandler(file_handler)
194 |
195 |     log.info(f"Current working directory: {os.getcwd()}")
196 |     log.info(f"Dataset dir: {dataset_dir}")
197 |     # Check dataset directory exists
198 |     if not Path(dataset_dir).exists():
199 |         print(dataset_dir)
200 |         raise Exception("`dataset_dir` does not exist..")
201 |
202 |     # Set random seeds
203 |     device = "cuda" if torch.cuda.is_available() else "cpu"
204 |     np.random.seed(obj_conf.OPTIM_SEED)
205 |     random.seed(obj_conf.OPTIM_SEED)
206 |     torch.manual_seed(obj_conf.OPTIM_SEED)
207 |     accelerator.wait_for_everyone()
208 |     # Seed for cuda
209 |     if torch.cuda.is_available():
210 |         torch.cuda.manual_seed(obj_conf.OPTIM_SEED)
211 |         torch.cuda.manual_seed_all(obj_conf.OPTIM_SEED)
212 |         accelerator.wait_for_everyone()
213 |
214 |     torch.backends.cudnn.benchmark = True
215 |
216 |     workflow(dataset_dir, obj_conf)
217 |
218 |
219 | if __name__ == "__main__":
220 |     main()
221 |
-------------------------------------------------------------------------------- /methods/semi_supervised_learning/__init__.py: --------------------------------------------------------------------------------
1 | from .training_strategies import TrainingStrategy
2 | from .visual_prompt import VisualPrompt
3 | from .textual_prompt import TextualPrompt
4 | from .multimodal_prompt import MultimodalPrompt
5 | from .visual_fpl import VisualFPL
6 | from .textual_fpl import TextualFPL
7 | from .multimodal_fpl import MultimodalFPL
8 |
/methods/semi_supervised_learning/multimodal_fpl.py:
--------------------------------------------------------------------------------
1 | import logging
2 | 
3 | import clip
4 | import numpy as np
5 | import pandas as pd
6 | import torch
7 | from accelerate import Accelerator
8 | from PIL import Image
9 | from torch import nn
10 | import math
11 | 
12 | accelerator = Accelerator()
13 | 
14 | from methods.semi_supervised_learning import MultimodalPrompt
15 | from utils import (
16 |     dataset_object,
17 |     make_scheduler,
18 |     pseudolabel_top_k,
19 | )
20 | 
21 | 
22 | log = logging.getLogger(__name__)
23 | 
24 | 
25 | class MultimodalFPL(MultimodalPrompt):
26 |     def __init__(
27 |         self,
28 |         config,
29 |         label_to_idx,
30 |         data_folder,
31 |         unlabeled_files,
32 |         classes,
33 |         seen_classes,
34 |         unseen_classes,
35 |         device
36 |     ):
37 |         """This class defines self-training UPT's training and evaluation.
38 |         :param config: dictionary of parameters in models_config/upt_baseline_pseudo_config.yml
39 |         :param label_to_idx: dictionary (key, value):(class name, id)
40 |         :param classes: list of class names
41 |         :param seen_classes: list of seen classes' names
42 |         :param unseen_classes: list of unseen classes' names
43 |         :param device: device in use
44 |         """
45 | 
46 |         super().__init__(
47 |             config, label_to_idx, classes, seen_classes, unseen_classes, device
48 |         )
49 | 
50 |         self.data_folder = data_folder
51 | 
52 |         self.check_unlabeled = unlabeled_files
53 | 
54 |     def create_training_dataset(self, train_data, unlabeled_data=None):
55 |         """This function creates the dataset for training. Specifically, it
56 |         merges pseudo-labels for unseen data and labeled data for seen classes.
57 |         :param train_data: Dataset object - training seen classes (defined in zsl_jpl line 323)
58 |         :param unlabeled_data: Dataset object - dataset of unlabeled data for
59 |         unseen classes (defined in zsl_jpl line 328)
60 |         """
61 | 
62 |         # Get pseudo-labels for unlabeled data from unseen classes
63 |         train_unseen_dataset = pseudolabel_top_k(
64 |             self.config,
65 |             self.config.DATASET_NAME,
66 |             self.config.N_PSEUDOSHOTS,
67 |             self.config.PROMPT_TEMPLATE,
68 |             unlabeled_data,
69 |             self.unseen_classes,
70 |             self.transform,
71 |             self.clip_model,
72 |             self.label_to_idx,
73 |             self.device,
74 |             self.config.VIS_ENCODER,
75 |             self.config.SPLIT_SEED
76 |         )
77 | 
78 |         # Define the lists of training data from seen and unseen classes
79 |         unseen_imgs = train_unseen_dataset.filepaths
80 |         unseen_labs = train_unseen_dataset.labels
81 | 
82 |         # Use a portion of the pseudo-labeled data to build a validation set
83 |         if self.config.N_PSEUDOSHOTS >= 10:
84 |             np.random.seed(self.config.validation_seed)
85 |             train_indices = np.random.choice(
86 |                 range(len(unseen_imgs)),
87 |                 size=int(len(unseen_imgs) * self.config.ratio_train_val),
88 |                 replace=False,
89 |             )
90 |             val_indices = list(
91 |                 set(range(len(unseen_imgs))).difference(set(train_indices))
92 |             )
93 | 
94 |             self.val_unseen_files = np.array(unseen_imgs)[val_indices]
95 |             self.val_unseen_labs = np.array(unseen_labs)[val_indices]
96 | 
97 |             unseen_imgs = list(np.array(unseen_imgs)[train_indices])
98 |             unseen_labs = list(np.array(unseen_labs)[train_indices])
99 | 
100 |         else:
101 |             self.val_unseen_files = None
102 |             self.val_unseen_labs = None
103 | 
104 |         seen_imgs = train_data.filepaths
105 |         seen_labs = [self.label_to_idx[l] for l in train_data.labels]
106 | 
107 |         self.balance_param = math.sqrt(len(unseen_imgs) / len(seen_imgs))
108 | 
109 |         train_data.filepaths = list(unseen_imgs) + list(seen_imgs)
110 |
train_data.labels = list(unseen_labs) + list(seen_labs) 111 | train_data.label_id = True 112 | 113 | def define_loss_function(self, logits, labs, paths): 114 | 115 | loss_ce_seen = self.cross_entropy(logits, labs, paths, False) 116 | loss_ce_unseen = self.cross_entropy(logits, labs, paths, True) 117 | 118 | return self.balance_param * loss_ce_seen + loss_ce_unseen 119 | 120 | 121 | def cross_entropy(self, logits, labels, paths, unlabeled=True): 122 | """This loss computes the probability mass on the 123 | opposite set of classes for each sample. 124 | :param logits: continuous vector 125 | :param labels: class ids 126 | """ 127 | 128 | # self.check_unlabeled 129 | if unlabeled: 130 | samples = [] 131 | for idx in range(len(paths)): 132 | if paths[idx] in self.check_unlabeled: 133 | samples.append(idx) 134 | 135 | #log.info(f"Unlabeled: {len(samples)} {self.unlabeled_weight}") 136 | if samples: 137 | error = self.loss_func(logits[samples], labels[samples]) 138 | else: 139 | error = 0 140 | else: 141 | samples = [] 142 | for idx in range(len(paths)): 143 | if paths[idx] not in self.check_unlabeled: 144 | samples.append(idx) 145 | 146 | #log.info(f"Labeled: {len(samples)} {self.labeled_weight}") 147 | if samples: 148 | error = self.loss_func(logits[samples], labels[samples]) 149 | else: 150 | error = 0 151 | 152 | return error 153 | 154 | def reindex_predicted_labels(self, idx_preds, only_unlabelled=False): 155 | """This function returns the correct index of predictions to compute 156 | model's accuracy. 157 | :param idx_pred: list of predictions ids 158 | :param only_unlabelled: boolean. It is True if the training only involves 159 | pseudo-labeled unseen data 160 | """ 161 | 162 | if only_unlabelled: 163 | return [self.unseen_classes[i.item()] for i in idx_preds] 164 | else: 165 | return [self.classes[i.item()] for i in idx_preds] 166 | 167 | def reindex_true_labels(self, label, only_unlabelled=False): 168 | """This function returns the correct index of true labels. 169 | :param label: list of labels from data loader 170 | :param only_unlabelled: boolean. It is True if the training only involves 171 | pseudo-labeled unseen data 172 | """ 173 | 174 | if only_unlabelled: 175 | return torch.tensor( 176 | [self.unseen_classes.index(self.classes[l.item()]) for l in label] 177 | ) 178 | else: 179 | return torch.tensor([l for l in label]) 180 | 181 | def get_pseudo_labels(self, unlabeled_examples): 182 | log.info(f"Num unlabeled data: {len(unlabeled_examples)}") 183 | # Get prediction on unlabeled data 184 | std_preds = self.test_predictions( 185 | unlabeled_examples, standard_zsl=True 186 | ) 187 | 188 | DatasetObject = dataset_object(self.config.DATASET_NAME) 189 | # 4. 
Take top-16 pseudo-labels to finetune the student 190 | pseudo_unseen_examples = DatasetObject( 191 | std_preds["id"], 192 | self.data_folder, 193 | transform=self.transform, 194 | augmentations=None, 195 | train=True, 196 | labels=None, 197 | label_map=self.label_to_idx, 198 | class_folder=True, 199 | original_filepaths=unlabeled_examples.filepaths, 200 | ) 201 | 202 | pseudo_labels = self.assign_pseudo_labels( 203 | self.config.N_PSEUDOSHOTS, pseudo_unseen_examples 204 | ) 205 | 206 | return pseudo_labels 207 | 208 | def assign_pseudo_labels(self, k, unlabeled_data): 209 | 210 | # to find the top k for each class, each class has it's own "leaderboard" 211 | top_k_leaderboard = { 212 | self.label_to_idx[self.unseen_classes[i]]: [] 213 | for i in range(len(self.unseen_classes)) 214 | } # maps class idx -> (confidence, image_path) tuple 215 | 216 | classes = self.unseen_classes 217 | for img_path in unlabeled_data.filepaths: 218 | # log.info(f"IMAGEPATH: {img_path}") 219 | img = Image.open(img_path).convert("RGB") 220 | img = torch.unsqueeze(self.transform(img), 0).to(self.device) 221 | with torch.no_grad(): 222 | # Get text and image prompts using UPT 223 | text_features, image_features = self.model(img, classes) 224 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 225 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 226 | 227 | logit_scale = self.clip_model.logit_scale.exp() 228 | logits = logit_scale * image_features @ text_features.t() 229 | probs = logits.softmax(dim=-1) 230 | idx_preds = torch.argmax(logits, dim=1) 231 | pred_id = idx_preds.item() 232 | pred = self.label_to_idx[self.unseen_classes[idx_preds.item()]] 233 | 234 | """if predicted class has empty leaderboard, or if the confidence is high 235 | enough for predicted class leaderboard, add the new example 236 | """ 237 | prob_score = probs[0][pred_id] 238 | if len(top_k_leaderboard[pred]) < k: 239 | top_k_leaderboard[pred].append((prob_score, img_path)) 240 | elif ( 241 | top_k_leaderboard[pred][-1][0] < prob_score 242 | ): # if the confidence in predicted class "qualifies" for top-k 243 | # default sorting of tuples is by first element 244 | top_k_leaderboard[pred] = sorted( 245 | top_k_leaderboard[pred] + [(probs[0][pred_id], img_path)], 246 | reverse=True, 247 | )[:k] 248 | else: 249 | # sort the other classes by confidence score 250 | order_of_classes = sorted( 251 | [ 252 | (probs[0][j], j) 253 | for j in range(len(self.unseen_classes)) 254 | if j != pred_id 255 | ], 256 | reverse=True, 257 | ) 258 | for score, index in order_of_classes: 259 | index_dict = self.label_to_idx[self.unseen_classes[index]] 260 | # log.info(f"{classnames[index]}") 261 | # log.info(f"{index_dict}") 262 | if len(top_k_leaderboard[index_dict]) < k: 263 | top_k_leaderboard[index_dict].append( 264 | (probs[0][index], img_path) 265 | ) 266 | elif top_k_leaderboard[index_dict][-1][0] < probs[0][index]: 267 | # default sorting of tuples is by first element 268 | top_k_leaderboard[index_dict] = sorted( 269 | top_k_leaderboard[index_dict] 270 | + [((probs[0][index], img_path))], 271 | reverse=True, 272 | )[:k] 273 | 274 | new_imgs = [] 275 | new_labels = [] 276 | # loop through, and rebuild the dataset 277 | for index, leaderboard in top_k_leaderboard.items(): 278 | new_imgs += [tup[1] for tup in leaderboard] 279 | new_labels += [index for _ in leaderboard] 280 | 281 | unlabeled_data.filepaths = new_imgs 282 | unlabeled_data.labels = new_labels 283 | unlabeled_data.label_id = True 284 | 285 | 
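        # By construction, each retained image sits on exactly one class
        # leaderboard, so the rebuilt dataset carries at most k pseudo-labeled
        # examples per unseen class; `label_id = True` above marks that
        # `labels` now stores class ids rather than class names.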
        return unlabeled_data
286 | 
--------------------------------------------------------------------------------
/methods/semi_supervised_learning/pseudo_iterative.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import logging
3 | import math
4 | import pickle
5 | 
6 | import clip
7 | import numpy as np
8 | import pandas as pd
9 | import scipy.stats as st
10 | import torch
11 | from accelerate import Accelerator
12 | from PIL import Image
13 | from torch import nn
14 | 
15 | 
16 | accelerator = Accelerator()
17 | 
18 | from data import CustomDataset
19 | from models import CustomImageEncoder, ImagePrefixModel
20 | from methods import TeacherStudent
21 | from utils import (
22 |     dataset_object,
23 |     evaluate_predictions,
24 |     make_scheduler,
25 |     pseudolabel_top_k,
26 |     seed_worker,
27 |     save_parameters,
28 |     save_pseudo_labels,
29 | )
30 | 
31 | g = torch.Generator()
32 | g.manual_seed(0)
33 | 
34 | log = logging.getLogger(__name__)
35 | 
36 | 
37 | class PseudoIterative(TeacherStudent):
38 |     def __init__(
39 |         self,
40 |         config,
41 |         label_to_idx,
42 |         data_folder,
43 |         classes,
44 |         seen_classes,
45 |         unseen_classes,
46 |         device,
47 |     ):
48 |         super().__init__(
49 |             config, label_to_idx, data_folder, classes, seen_classes, unseen_classes, device
50 |         )
51 | 
52 | 
53 |     def train(
54 |         self,
55 |         train_data,
56 |         val_data,
57 |         unlabeled_data,
58 |         test_data,
59 |         test_labeled_files,
60 |         test_labeles,
61 |     ):
62 |         # Number of total iterations to cover all unlabeled data
63 |         num_iter = int(100/self.config.STEP_QUANTILE)
64 |         num_samples = int(len(unlabeled_data) / num_iter)
65 |         # Initialize the number of pseudo-labels per class
66 |         n_per_class = int(num_samples / len(self.unseen_classes))
67 |         n_unseen = len(self.unseen_classes)
68 |         if n_per_class * n_unseen <= len(unlabeled_data.filepaths):
69 |             # self.num_pseudo_labels_per_class = n_per_class
70 |             self.config.N_PSEUDOSHOTS = n_per_class
71 |         else:
72 |             # self.num_pseudo_labels_per_class = math.floor(len(unlabeled_data.filepaths)/n_unseen)
73 |             self.config.N_PSEUDOSHOTS = math.floor(
74 |                 len(unlabeled_data.filepaths) / n_unseen
75 |             )
76 | 
77 |         log.info(f"We select {self.config.N_PSEUDOSHOTS} pseudo-labels per unseen class.")
78 |         log.info(f"The number of unseen classes is: {len(self.unseen_classes)}.")
79 |         log.info(f"Thus we expect an initial number of pseudo-labels equal to {len(self.unseen_classes) * self.config.N_PSEUDOSHOTS}.")
80 |         # Create a safe copy of labeled/unlabeled data
81 |         original_train_data = copy.deepcopy(train_data)
82 |         # log.info(f"Training data labels: {original_train_data.labels}")
83 |         original_unlabeled_data = copy.deepcopy(unlabeled_data)
84 |         # Original val
85 |         original_val_data = copy.deepcopy(val_data)
86 | 
87 |         # Initialize here first batch of pseudo labels
88 |         #self.create_training_dataset(train_data, unlabeled_data)
89 |         #log.info(f"The original train data has size: {len(original_train_data.filepaths)}.")
90 |         #log.info(f"Plus: {len(unlabeled_data.filepaths)}.")
91 | 
92 |         for niter in range(1, num_iter + 1):
93 |             log.info(f"NUM PSEUDO SHOTS: {self.config.N_PSEUDOSHOTS}")
94 |             pseudolabel_top_k(self.config,
95 |                 self.config.DATASET_NAME,
96 |                 self.config.N_PSEUDOSHOTS,
97 |                 self.config.PROMPT_TEMPLATE,
98 |                 unlabeled_data,
99 |                 self.unseen_classes,
100 |                 self.transform,
101 |                 self.clip_model,
102 |                 self.label_to_idx,
103 |                 self.device,
104 |                 self.config.VIS_ENCODER,
105 |                 self.config.SPLIT_SEED,
106 |             )
107 | 
108 |             log.info(f"Plus: {len(unlabeled_data.filepaths)}.")
109 |             filename = f"pseudolabels/{self.config.DATASET_NAME}_CLIP_{self.config.VIS_ENCODER.replace('/', '')}_iter_{niter}_pseudolabels_spl_{self.config.SPLIT_SEED}.pickle"
110 |             with open(filename, "wb") as f:
111 |                 pickle.dump({"filepaths": unlabeled_data.filepaths, "labels": unlabeled_data.labels}, f)
112 | 
113 |             # Exploit all the available unlabeled data
114 |             if self.config.ALL_UNLABELED:
115 |                 n_per_class = int((niter + 1) * num_samples / n_unseen)
116 |                 log.info(f"n_per_class: {n_per_class}")
117 |                 if n_per_class * n_unseen <= len(original_unlabeled_data.filepaths):
118 |                     log.info(f"if n_per_class: {n_per_class}")
119 |                     self.config.N_PSEUDOSHOTS = n_per_class
120 |                 else:
121 |                     log.info(f"else new val: {len(original_unlabeled_data.filepaths) / n_unseen}")
122 |                     # We are making a strong assumption about the distribution of unlabeled data
123 |                     self.config.N_PSEUDOSHOTS = math.floor(
124 |                         len(original_unlabeled_data.filepaths) / n_unseen
125 |                     )
126 | 
127 |             unlabeled_data = original_unlabeled_data
128 |             original_unlabeled_data = copy.deepcopy(unlabeled_data)
--------------------------------------------------------------------------------
/methods/semi_supervised_learning/textual_fpl.py:
--------------------------------------------------------------------------------
1 | import logging
2 | 
3 | import clip
4 | import numpy as np
5 | import pandas as pd
6 | import torch
7 | from accelerate import Accelerator
8 | from PIL import Image
9 | from torch import nn
10 | from tqdm import tqdm
11 | 
12 | accelerator = Accelerator()
13 | 
14 | from methods.semi_supervised_learning import TextualPrompt
15 | from utils import (
16 |     dataset_object,
17 |     make_scheduler,
18 |     pseudolabel_top_k,
19 |     seed_worker,
20 | )
21 | 
22 | 
23 | g = torch.Generator()
24 | g.manual_seed(0)
25 | 
26 | log = logging.getLogger(__name__)
27 | 
28 | 
29 | class TextualFPL(TextualPrompt):
30 |     def __init__(
31 |         self,
32 |         config,
33 |         label_to_idx,
34 |         data_folder,
35 |         unlabeled_files,
36 |         classes,
37 |         seen_classes,
38 |         unseen_classes,
39 |         device,
40 |     ):
41 |         """This class defines the CoOp baseline.
42 | 
43 |         :param config: dictionary of parameters in models_config/coop_baseline_config.yml
44 |         :param label_to_idx: dictionary (key, value):(class name, id)
45 |         :param classes: list of class names
46 |         :param seen_classes: list of seen classes' names
47 |         :param unseen_classes: list of unseen classes' names
48 |         :param device: device in use
49 |         """
50 |         super().__init__(
51 |             config, label_to_idx, classes, seen_classes, unseen_classes, device
52 |         )
53 | 
54 |         self.data_folder = data_folder
55 | 
56 |         self.check_unlabeled = unlabeled_files
57 | 
58 |     def create_training_dataset(self, train_data, unlabeled_data=None):
59 |         """This function creates the dataset for training. Specifically, it
60 |         merges pseudo-labels for unseen data and labeled data for seen classes.
61 | 62 | :param train_data: Dataset object - training seen classes (defined in zsl_jpl line 323) 63 | :param unlabeled_data: Dataset object - dataset of unlabeled data for 64 | unseen classes (defined in zsl_jpl line 328) 65 | """ 66 | 67 | # Get pseudo-labels for unlabeled data from unseen classes 68 | train_unseen_dataset = pseudolabel_top_k( 69 | self.config, 70 | self.config.DATASET_NAME, 71 | self.config.N_PSEUDOSHOTS, 72 | self.config.PROMPT_TEMPLATE, 73 | unlabeled_data, 74 | self.unseen_classes, 75 | self.transform, 76 | self.clip_model, 77 | self.label_to_idx, 78 | self.device, 79 | self.config.VIS_ENCODER, 80 | self.config.SPLIT_SEED, 81 | ) 82 | 83 | # Define the lists of traiing data from seen and unseen classes 84 | unseen_imgs = train_unseen_dataset.filepaths 85 | unseen_labs = train_unseen_dataset.labels 86 | 87 | # Use a portion of the pseudo-labeled data to build a validation set 88 | if self.config.N_PSEUDOSHOTS >= 10: 89 | np.random.seed(self.config.validation_seed) 90 | train_indices = np.random.choice( 91 | range(len(unseen_imgs)), 92 | size=int(len(unseen_imgs) * self.config.ratio_train_val), 93 | replace=False, 94 | ) 95 | val_indices = list( 96 | set(range(len(unseen_imgs))).difference(set(train_indices)) 97 | ) 98 | 99 | self.val_unseen_files = np.array(unseen_imgs)[val_indices] 100 | self.val_unseen_labs = np.array(unseen_labs)[val_indices] 101 | 102 | unseen_imgs = list(np.array(unseen_imgs)[train_indices]) 103 | unseen_labs = list(np.array(unseen_labs)[train_indices]) 104 | 105 | else: 106 | self.val_unseen_files = None 107 | self.val_unseen_labs = None 108 | 109 | seen_imgs = train_data.filepaths 110 | seen_labs = [self.label_to_idx[l] for l in train_data.labels] 111 | 112 | # self.labeled_weight = 1 / len(seen_imgs) 113 | # self.unlabeled_weight = 1 / len(unseen_imgs) 114 | # log.info(f"UNLABELD: {len(unseen_imgs)} and LABELED {len(seen_imgs)}") 115 | self.balance_param = len(unseen_imgs) / len(seen_imgs) 116 | 117 | train_data.filepaths = list(unseen_imgs) + list(seen_imgs) 118 | train_data.labels = list(unseen_labs) + list(seen_labs) 119 | train_data.label_id = True 120 | 121 | return train_data 122 | 123 | def define_loss_function(self, logits, labs, paths): 124 | 125 | loss_ce_seen = self.cross_entropy(logits, labs, paths, False) 126 | loss_ce_unseen = self.cross_entropy(logits, labs, paths, True) 127 | 128 | return self.balance_param * loss_ce_seen + loss_ce_unseen 129 | 130 | def cross_entropy(self, logits, labels, paths, unlabeled=True): 131 | """This loss computes the probability mass on the 132 | opposite set of classes for each sample. 
133 | 134 | :param logits: continuous vector 135 | :param labels: class ids 136 | """ 137 | 138 | # log.info(f"IMAGE PATHS: {paths}") 139 | # log.info(f"IMAGE CHECK: {self.check_unlabeled[:10]}") 140 | 141 | # self.check_unlabeled 142 | if unlabeled: 143 | samples = [] 144 | for idx in range(len(paths)): 145 | if paths[idx] in self.check_unlabeled: 146 | samples.append(idx) 147 | 148 | # log.info(f"Unlabeled: {len(samples)} {self.balance_param}") 149 | if samples: 150 | error = self.loss_func(logits[samples], labels[samples]) 151 | else: 152 | error = 0 153 | else: 154 | samples = [] 155 | for idx in range(len(paths)): 156 | if paths[idx] not in self.check_unlabeled: 157 | samples.append(idx) 158 | 159 | # log.info(f"Labeled: {len(samples)} {self.balance_param}") 160 | if samples: 161 | error = self.loss_func(logits[samples], labels[samples]) 162 | else: 163 | error = 0 164 | 165 | return error 166 | 167 | 168 | def get_pseudo_labels(self, unlabeled_examples): 169 | log.info(f"Num unlabeled data: {len(unlabeled_examples)}") 170 | # Get prediction on unlabeled data 171 | std_preds = self.test_predictions( 172 | unlabeled_examples, standard_zsl=True 173 | ) 174 | 175 | DatasetObject = dataset_object(self.config.DATASET_NAME) 176 | # 4. Take top-16 pseudo-labels to finetune the student 177 | pseudo_unseen_examples = DatasetObject( 178 | std_preds["id"], 179 | self.data_folder, 180 | transform=self.transform, 181 | augmentations=None, 182 | train=True, 183 | labels=None, 184 | label_map=self.label_to_idx, 185 | class_folder=True, 186 | original_filepaths=unlabeled_examples.filepaths, 187 | ) 188 | 189 | pseudo_labels = self.assign_pseudo_labels( 190 | self.config.N_PSEUDOSHOTS, pseudo_unseen_examples 191 | ) 192 | 193 | return pseudo_labels 194 | 195 | def assign_pseudo_labels(self, k, unlabeled_data): 196 | # Define text queries 197 | # prompts = [f"{self.template}{' '.join(i.split('_'))}" \ 198 | # for i in self.unseen_classes] 199 | 200 | log.info(f"[self.assign_pseudo_labels] Number of prompts: {len(self.unseen_classes)}") 201 | 202 | # Get prompts 203 | self.model.classes = self.unseen_classes 204 | text_features = self.model(self.model.classes) 205 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 206 | log.info(f"TEXT FEATURES SHAPE: {text_features.size()}") 207 | 208 | # to find the top k for each class, each class has it's own "leaderboard" 209 | top_k_leaderboard = { 210 | self.label_to_idx[self.unseen_classes[i]]: [] 211 | for i in range(len(self.unseen_classes)) 212 | } # maps class idx -> (confidence, image_path) tuple 213 | 214 | for img_path in unlabeled_data.filepaths: 215 | # log.info(f"IMAGEPATH: {img_path}") 216 | img = Image.open(img_path).convert("RGB") 217 | img = torch.unsqueeze(self.transform(img), 0).to(self.device) 218 | with torch.no_grad(): 219 | image_features = self.clip_model.encode_image(img) 220 | image_features = image_features / image_features.norm( 221 | dim=-1, keepdim=True 222 | ) 223 | # cosine similarity as logits 224 | 225 | logit_scale = self.clip_model.logit_scale.exp() 226 | logits = logit_scale * image_features @ text_features.t() 227 | probs = logits.softmax(dim=-1) 228 | idx_preds = torch.argmax(logits, dim=1) 229 | pred_id = idx_preds.item() 230 | pred = self.label_to_idx[self.unseen_classes[idx_preds.item()]] 231 | 232 | """if predicted class has empty leaderboard, or if the confidence is high 233 | enough for predicted class leaderboard, add the new example 234 | """ 235 | prob_score = probs[0][pred_id] 236 | if 
len(top_k_leaderboard[pred]) < k: 237 | top_k_leaderboard[pred].append((prob_score, img_path)) 238 | elif ( 239 | top_k_leaderboard[pred][-1][0] < prob_score 240 | ): # if the confidence in predicted class "qualifies" for top-k 241 | # default sorting of tuples is by first element 242 | top_k_leaderboard[pred] = sorted( 243 | top_k_leaderboard[pred] + [(probs[0][pred_id], img_path)], 244 | reverse=True, 245 | )[:k] 246 | else: 247 | # sort the other classes by confidence score 248 | order_of_classes = sorted( 249 | [ 250 | (probs[0][j], j) 251 | for j in range(len(self.unseen_classes)) 252 | if j != pred_id 253 | ], 254 | reverse=True, 255 | ) 256 | for score, index in order_of_classes: 257 | index_dict = self.label_to_idx[self.unseen_classes[index]] 258 | # log.info(f"{classnames[index]}") 259 | # log.info(f"{index_dict}") 260 | if len(top_k_leaderboard[index_dict]) < k: 261 | top_k_leaderboard[index_dict].append( 262 | (probs[0][index], img_path) 263 | ) 264 | elif top_k_leaderboard[index_dict][-1][0] < probs[0][index]: 265 | # default sorting of tuples is by first element 266 | top_k_leaderboard[index_dict] = sorted( 267 | top_k_leaderboard[index_dict] 268 | + [((probs[0][index], img_path))], 269 | reverse=True, 270 | )[:k] 271 | 272 | new_imgs = [] 273 | new_labels = [] 274 | # loop through, and rebuild the dataset 275 | for index, leaderboard in top_k_leaderboard.items(): 276 | new_imgs += [tup[1] for tup in leaderboard] 277 | new_labels += [index for _ in leaderboard] 278 | 279 | unlabeled_data.filepaths = new_imgs 280 | unlabeled_data.labels = new_labels 281 | unlabeled_data.label_id = True 282 | 283 | return unlabeled_data 284 | -------------------------------------------------------------------------------- /methods/transductive_zsl/__init__.py: -------------------------------------------------------------------------------- 1 | from .training_strategies import TrainingStrategy 2 | from .visual_prompt import VisualPrompt 3 | from .textual_prompt import TextualPrompt 4 | from .multimodal_prompt import MultimodalPrompt 5 | from .visual_fpl import VisualFPL 6 | from .textual_fpl import TextualFPL 7 | from .multimodal_fpl import MultimodalFPL 8 | -------------------------------------------------------------------------------- /methods/transductive_zsl/multimodal_fpl.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import clip 4 | import numpy as np 5 | import pandas as pd 6 | import torch 7 | from accelerate import Accelerator 8 | from PIL import Image 9 | from torch import nn 10 | import math 11 | 12 | accelerator = Accelerator() 13 | 14 | from methods.transductive_zsl import MultimodalPrompt 15 | from utils import ( 16 | dataset_object, 17 | make_scheduler, 18 | pseudolabel_top_k, 19 | ) 20 | 21 | 22 | log = logging.getLogger(__name__) 23 | 24 | 25 | class MultimodalFPL(MultimodalPrompt): 26 | def __init__( 27 | self, 28 | config, 29 | label_to_idx, 30 | data_folder, 31 | classes, 32 | seen_classes, 33 | unseen_classes, 34 | device 35 | ): 36 | """This class defines self-trainig UPT's training and evaluation. 
37 | :param config: dictionaries of prameters in models_config/upt_baseline_pseudo_config.yml 38 | :param label_to_idx: dictionary (key, value):(class name, id) 39 | :param classes: list of class names 40 | :param seen_classes: list of seen classes' names 41 | :param unseen_classes: list of unseen classes' names 42 | :param device: device in use 43 | """ 44 | 45 | super().__init__( 46 | config, label_to_idx, classes, seen_classes, unseen_classes, device 47 | ) 48 | 49 | self.data_folder = data_folder 50 | 51 | def create_training_dataset(self, train_data, unlabeled_data=None): 52 | """This function creates the dataset for training. Specifically, it 53 | merges pseudo-labels for unseen data and labeled data for seen classes. 54 | :param train_data: Dataset object - training seen classes (defined in zsl_jpl line 323) 55 | :param unlabeled_data: Dataset object - dataset of unlabeled data for 56 | unseen classes (defined in zsl_jpl line 328) 57 | """ 58 | 59 | # Get pseudo-labels for unlabeled data from unseen classes 60 | train_unseen_dataset = pseudolabel_top_k( 61 | self.config, 62 | self.config.DATASET_NAME, 63 | self.config.N_PSEUDOSHOTS, 64 | self.config.PROMPT_TEMPLATE, 65 | unlabeled_data, 66 | self.unseen_classes, 67 | self.transform, 68 | self.clip_model, 69 | self.label_to_idx, 70 | self.device, 71 | self.config.VIS_ENCODER, 72 | self.config.SPLIT_SEED 73 | ) 74 | 75 | # Define the lists of traiing data from seen and unseen classes 76 | unseen_imgs = train_unseen_dataset.filepaths 77 | unseen_labs = train_unseen_dataset.labels 78 | 79 | # Use a portion of the pseudo-labeled data to build a validation set 80 | if self.config.N_PSEUDOSHOTS >= 10: 81 | np.random.seed(self.config.validation_seed) 82 | train_indices = np.random.choice( 83 | range(len(unseen_imgs)), 84 | size=int(len(unseen_imgs) * self.config.ratio_train_val), 85 | replace=False, 86 | ) 87 | val_indices = list( 88 | set(range(len(unseen_imgs))).difference(set(train_indices)) 89 | ) 90 | 91 | self.val_unseen_files = np.array(unseen_imgs)[val_indices] 92 | self.val_unseen_labs = np.array(unseen_labs)[val_indices] 93 | 94 | unseen_imgs = list(np.array(unseen_imgs)[train_indices]) 95 | unseen_labs = list(np.array(unseen_labs)[train_indices]) 96 | 97 | else: 98 | self.val_unseen_files = None 99 | self.val_unseen_labs = None 100 | 101 | seen_imgs = train_data.filepaths 102 | seen_labs = [self.label_to_idx[l] for l in train_data.labels] 103 | 104 | self.balance_param = math.sqrt(len(seen_imgs) / len(unseen_imgs)) 105 | 106 | train_data.filepaths = list(unseen_imgs) + list(seen_imgs) 107 | train_data.labels = list(unseen_labs) + list(seen_labs) 108 | train_data.label_id = True 109 | 110 | def define_loss_function(self, logits, labs, teacher=False): 111 | 112 | loss_ce_seen = self.cross_entropy(logits, labs, self.seen_classes) 113 | loss_ce_unseen = self.cross_entropy(logits, labs, self.unseen_classes) 114 | 115 | return loss_ce_seen + self.balance_param * loss_ce_unseen 116 | 117 | def cross_entropy(self, logits, labels, classes): 118 | """This loss computes the probability mass on the 119 | opposite set of classes for each sample. 
120 | :param logits: continuous vector 121 | :param labels: class ids 122 | """ 123 | 124 | ids = [self.label_to_idx[c] for c in classes] 125 | 126 | # Get indices of unseen and seen samples in the batch 127 | samples = [] 128 | 129 | for idx, l in enumerate(labels): 130 | if l in ids: 131 | samples.append(idx) 132 | 133 | if samples: 134 | error = self.loss_func(logits[samples], labels[samples]) 135 | else: 136 | error = 0 137 | 138 | return error 139 | 140 | def reindex_predicted_labels(self, idx_preds, only_unlabelled=False): 141 | """This function returns the correct index of predictions to compute 142 | model's accuracy. 143 | :param idx_pred: list of predictions ids 144 | :param only_unlabelled: boolean. It is True if the training only involves 145 | pseudo-labeled unseen data 146 | """ 147 | 148 | if only_unlabelled: 149 | return [self.unseen_classes[i.item()] for i in idx_preds] 150 | else: 151 | return [self.classes[i.item()] for i in idx_preds] 152 | 153 | def reindex_true_labels(self, label, only_unlabelled=False): 154 | """This function returns the correct index of true labels. 155 | :param label: list of labels from data loader 156 | :param only_unlabelled: boolean. It is True if the training only involves 157 | pseudo-labeled unseen data 158 | """ 159 | 160 | if only_unlabelled: 161 | return torch.tensor( 162 | [self.unseen_classes.index(self.classes[l.item()]) for l in label] 163 | ) 164 | else: 165 | return torch.tensor([l for l in label]) 166 | 167 | def get_pseudo_labels(self, unlabeled_examples): 168 | log.info(f"Num unlabeled data: {len(unlabeled_examples)}") 169 | # Get prediction on unlabeled data 170 | std_preds = self.test_predictions( 171 | unlabeled_examples, standard_zsl=True 172 | ) 173 | 174 | DatasetObject = dataset_object(self.config.DATASET_NAME) 175 | # 4. 
Take top-16 pseudo-labels to finetune the student 176 | pseudo_unseen_examples = DatasetObject( 177 | std_preds["id"], 178 | self.data_folder, 179 | transform=self.transform, 180 | augmentations=None, 181 | train=True, 182 | labels=None, 183 | label_map=self.label_to_idx, 184 | class_folder=True, 185 | original_filepaths=unlabeled_examples.filepaths, 186 | ) 187 | 188 | pseudo_labels = self.assign_pseudo_labels( 189 | self.config.N_PSEUDOSHOTS, pseudo_unseen_examples 190 | ) 191 | 192 | return pseudo_labels 193 | 194 | def assign_pseudo_labels(self, k, unlabeled_data): 195 | 196 | # to find the top k for each class, each class has it's own "leaderboard" 197 | top_k_leaderboard = { 198 | self.label_to_idx[self.unseen_classes[i]]: [] 199 | for i in range(len(self.unseen_classes)) 200 | } # maps class idx -> (confidence, image_path) tuple 201 | 202 | classes = self.unseen_classes 203 | for img_path in unlabeled_data.filepaths: 204 | # log.info(f"IMAGEPATH: {img_path}") 205 | img = Image.open(img_path).convert("RGB") 206 | img = torch.unsqueeze(self.transform(img), 0).to(self.device) 207 | with torch.no_grad(): 208 | # Get text and image prompts using UPT 209 | # coop_embeddings, vpt_embeddings, vpt_deep_embeddings = self.model(0) 210 | # Calculate text prompts 211 | # text_features = self.text_encoder(coop_embeddings, classes) 212 | text_features, image_features = self.model(img, classes) 213 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 214 | # Calculate image prompts 215 | # image_features = self.image_encoder(img, vpt_embeddings, deep_embds=vpt_deep_embeddings) 216 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 217 | 218 | logit_scale = self.clip_model.logit_scale.exp() 219 | logits = logit_scale * image_features @ text_features.t() 220 | probs = logits.softmax(dim=-1) 221 | idx_preds = torch.argmax(logits, dim=1) 222 | pred_id = idx_preds.item() 223 | pred = self.label_to_idx[self.unseen_classes[idx_preds.item()]] 224 | 225 | """if predicted class has empty leaderboard, or if the confidence is high 226 | enough for predicted class leaderboard, add the new example 227 | """ 228 | prob_score = probs[0][pred_id] 229 | if len(top_k_leaderboard[pred]) < k: 230 | top_k_leaderboard[pred].append((prob_score, img_path)) 231 | elif ( 232 | top_k_leaderboard[pred][-1][0] < prob_score 233 | ): # if the confidence in predicted class "qualifies" for top-k 234 | # default sorting of tuples is by first element 235 | top_k_leaderboard[pred] = sorted( 236 | top_k_leaderboard[pred] + [(probs[0][pred_id], img_path)], 237 | reverse=True, 238 | )[:k] 239 | else: 240 | # sort the other classes by confidence score 241 | order_of_classes = sorted( 242 | [ 243 | (probs[0][j], j) 244 | for j in range(len(self.unseen_classes)) 245 | if j != pred_id 246 | ], 247 | reverse=True, 248 | ) 249 | for score, index in order_of_classes: 250 | index_dict = self.label_to_idx[self.unseen_classes[index]] 251 | # log.info(f"{classnames[index]}") 252 | # log.info(f"{index_dict}") 253 | if len(top_k_leaderboard[index_dict]) < k: 254 | top_k_leaderboard[index_dict].append( 255 | (probs[0][index], img_path) 256 | ) 257 | elif top_k_leaderboard[index_dict][-1][0] < probs[0][index]: 258 | # default sorting of tuples is by first element 259 | top_k_leaderboard[index_dict] = sorted( 260 | top_k_leaderboard[index_dict] 261 | + [((probs[0][index], img_path))], 262 | reverse=True, 263 | )[:k] 264 | 265 | new_imgs = [] 266 | new_labels = [] 267 | # loop through, and rebuild the 
dataset
268 |         for index, leaderboard in top_k_leaderboard.items():
269 |             new_imgs += [tup[1] for tup in leaderboard]
270 |             new_labels += [index for _ in leaderboard]
271 | 
272 |         unlabeled_data.filepaths = new_imgs
273 |         unlabeled_data.labels = new_labels
274 |         unlabeled_data.label_id = True
275 | 
276 |         return unlabeled_data
277 | 
--------------------------------------------------------------------------------
/methods/transductive_zsl/pseudo_iterative.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import logging
3 | import math
4 | import pickle
5 | 
6 | import clip
7 | import numpy as np
8 | import pandas as pd
9 | import scipy.stats as st
10 | import torch
11 | from accelerate import Accelerator
12 | from PIL import Image
13 | from torch import nn
14 | 
15 | 
16 | accelerator = Accelerator()
17 | 
18 | from data import CustomDataset
19 | from models import CustomImageEncoder, ImagePrefixModel
20 | from methods import TeacherStudent
21 | from utils import (
22 |     dataset_object,
23 |     evaluate_predictions,
24 |     make_scheduler,
25 |     pseudolabel_top_k,
26 |     seed_worker,
27 |     save_parameters,
28 |     save_pseudo_labels,
29 | )
30 | 
31 | g = torch.Generator()
32 | g.manual_seed(0)
33 | 
34 | log = logging.getLogger(__name__)
35 | 
36 | 
37 | class PseudoIterative(TeacherStudent):
38 |     def __init__(
39 |         self,
40 |         config,
41 |         label_to_idx,
42 |         data_folder,
43 |         classes,
44 |         seen_classes,
45 |         unseen_classes,
46 |         device,
47 |     ):
48 |         super().__init__(
49 |             config, label_to_idx, data_folder, classes, seen_classes, unseen_classes, device
50 |         )
51 | 
52 | 
53 |     def train(
54 |         self,
55 |         train_data,
56 |         val_data,
57 |         unlabeled_data,
58 |         test_data,
59 |         test_labeled_files,
60 |         test_labeles,
61 |     ):
62 |         # Number of total iterations to cover all unlabeled data
63 |         num_iter = int(100/self.config.STEP_QUANTILE)
64 |         num_samples = int(len(unlabeled_data) / num_iter)
65 |         # Initialize the number of pseudo-labels per class
66 |         n_per_class = int(num_samples / len(self.unseen_classes))
67 |         n_unseen = len(self.unseen_classes)
68 |         if n_per_class * n_unseen <= len(unlabeled_data.filepaths):
69 |             # self.num_pseudo_labels_per_class = n_per_class
70 |             self.config.N_PSEUDOSHOTS = n_per_class
71 |         else:
72 |             # self.num_pseudo_labels_per_class = math.floor(len(unlabeled_data.filepaths)/n_unseen)
73 |             self.config.N_PSEUDOSHOTS = math.floor(
74 |                 len(unlabeled_data.filepaths) / n_unseen
75 |             )
76 | 
77 |         log.info(f"We select {self.config.N_PSEUDOSHOTS} pseudo-labels per unseen class.")
78 |         log.info(f"The number of unseen classes is: {len(self.unseen_classes)}.")
79 |         log.info(f"Thus we expect an initial number of pseudo-labels equal to {len(self.unseen_classes) * self.config.N_PSEUDOSHOTS}.")
80 |         # Create a safe copy of labeled/unlabeled data
81 |         original_train_data = copy.deepcopy(train_data)
82 |         # log.info(f"Training data labels: {original_train_data.labels}")
83 |         original_unlabeled_data = copy.deepcopy(unlabeled_data)
84 |         # Original val
85 |         original_val_data = copy.deepcopy(val_data)
86 | 
87 |         # Initialize here first batch of pseudo labels
88 |         #self.create_training_dataset(train_data, unlabeled_data)
89 |         #log.info(f"The original train data has size: {len(original_train_data.filepaths)}.")
90 |         #log.info(f"Plus: {len(unlabeled_data.filepaths)}.")
91 | 
92 |         for niter in range(1, num_iter + 1):
93 |             log.info(f"NUM PSEUDO SHOTS: {self.config.N_PSEUDOSHOTS}")
94 |             pseudolabel_top_k(self.config,
95 |                 self.config.DATASET_NAME,
96 |                 self.config.N_PSEUDOSHOTS,
97 |                 self.config.PROMPT_TEMPLATE,
98 |                 unlabeled_data,
99 |                 self.unseen_classes,
100 |                 self.transform,
101 |                 self.clip_model,
102 |                 self.label_to_idx,
103 |                 self.device,
104 |                 self.config.VIS_ENCODER,
105 |                 self.config.SPLIT_SEED,
106 |             )
107 | 
108 |             log.info(f"Plus: {len(unlabeled_data.filepaths)}.")
109 |             filename = f"pseudolabels/{self.config.DATASET_NAME}_CLIP_{self.config.VIS_ENCODER.replace('/', '')}_iter_{niter}_pseudolabels_spl_{self.config.SPLIT_SEED}.pickle"
110 |             with open(filename, "wb") as f:
111 |                 pickle.dump({"filepaths": unlabeled_data.filepaths, "labels": unlabeled_data.labels}, f)
112 | 
113 |             # Exploit all the available unlabeled data
114 |             if self.config.ALL_UNLABELED:
115 |                 n_per_class = int((niter + 1) * num_samples / n_unseen)
116 |                 log.info(f"n_per_class: {n_per_class}")
117 |                 if n_per_class * n_unseen <= len(original_unlabeled_data.filepaths):
118 |                     log.info(f"if n_per_class: {n_per_class}")
119 |                     self.config.N_PSEUDOSHOTS = n_per_class
120 |                 else:
121 |                     log.info(f"else new val: {len(original_unlabeled_data.filepaths) / n_unseen}")
122 |                     # We are making a strong assumption about the distribution of unlabeled data
123 |                     self.config.N_PSEUDOSHOTS = math.floor(
124 |                         len(original_unlabeled_data.filepaths) / n_unseen
125 |                     )
126 | 
127 |             unlabeled_data = original_unlabeled_data
128 |             original_unlabeled_data = copy.deepcopy(unlabeled_data)
--------------------------------------------------------------------------------
/methods/transductive_zsl/textual_fpl.py:
--------------------------------------------------------------------------------
1 | import logging
2 | 
3 | import clip
4 | import numpy as np
5 | import pandas as pd
6 | import torch
7 | from accelerate import Accelerator
8 | from PIL import Image
9 | from torch import nn
10 | from tqdm import tqdm
11 | 
12 | accelerator = Accelerator()
13 | 
14 | from methods.transductive_zsl import TextualPrompt
15 | from utils import (
16 |     dataset_object,
17 |     make_scheduler,
18 |     pseudolabel_top_k,
19 |     seed_worker,
20 | )
21 | 
22 | 
23 | g = torch.Generator()
24 | g.manual_seed(0)
25 | 
26 | log = logging.getLogger(__name__)
27 | 
28 | 
29 | class TextualFPL(TextualPrompt):
30 |     def __init__(
31 |         self,
32 |         config,
33 |         label_to_idx,
34 |         data_folder,
35 |         classes,
36 |         seen_classes,
37 |         unseen_classes,
38 |         device,
39 |     ):
40 |         """This class defines the CoOp baseline.
41 | 
42 |         :param config: dictionary of parameters in models_config/coop_baseline_config.yml
43 |         :param label_to_idx: dictionary (key, value):(class name, id)
44 |         :param classes: list of class names
45 |         :param seen_classes: list of seen classes' names
46 |         :param unseen_classes: list of unseen classes' names
47 |         :param device: device in use
48 |         """
49 |         super().__init__(
50 |             config, label_to_idx, classes, seen_classes, unseen_classes, device
51 |         )
52 | 
53 |         self.data_folder = data_folder
54 | 
55 |     def create_training_dataset(self, train_data, unlabeled_data=None):
56 |         """This function creates the dataset for training. Specifically, it
57 |         merges pseudo-labels for unseen data and labeled data for seen classes.
58 | 59 | :param train_data: Dataset object - training seen classes (defined in zsl_jpl line 323) 60 | :param unlabeled_data: Dataset object - dataset of unlabeled data for 61 | unseen classes (defined in zsl_jpl line 328) 62 | """ 63 | 64 | # Get pseudo-labels for unlabeled data from unseen classes 65 | train_unseen_dataset = pseudolabel_top_k( 66 | self.config, 67 | self.config.DATASET_NAME, 68 | self.config.N_PSEUDOSHOTS, 69 | self.config.PROMPT_TEMPLATE, 70 | unlabeled_data, 71 | self.unseen_classes, 72 | self.transform, 73 | self.clip_model, 74 | self.label_to_idx, 75 | self.device, 76 | self.config.VIS_ENCODER, 77 | self.config.SPLIT_SEED, 78 | ) 79 | 80 | # Define the lists of traiing data from seen and unseen classes 81 | unseen_imgs = train_unseen_dataset.filepaths 82 | unseen_labs = train_unseen_dataset.labels 83 | 84 | # Use a portion of the pseudo-labeled data to build a validation set 85 | if self.config.N_PSEUDOSHOTS >= 10: 86 | np.random.seed(self.config.validation_seed) 87 | train_indices = np.random.choice( 88 | range(len(unseen_imgs)), 89 | size=int(len(unseen_imgs) * self.config.ratio_train_val), 90 | replace=False, 91 | ) 92 | val_indices = list( 93 | set(range(len(unseen_imgs))).difference(set(train_indices)) 94 | ) 95 | 96 | self.val_unseen_files = np.array(unseen_imgs)[val_indices] 97 | self.val_unseen_labs = np.array(unseen_labs)[val_indices] 98 | 99 | unseen_imgs = list(np.array(unseen_imgs)[train_indices]) 100 | unseen_labs = list(np.array(unseen_labs)[train_indices]) 101 | 102 | else: 103 | self.val_unseen_files = None 104 | self.val_unseen_labs = None 105 | 106 | seen_imgs = train_data.filepaths 107 | seen_labs = [self.label_to_idx[l] for l in train_data.labels] 108 | 109 | self.balance_param = len(seen_imgs) / len(unseen_imgs) 110 | 111 | train_data.filepaths = list(unseen_imgs) + list(seen_imgs) 112 | train_data.labels = list(unseen_labs) + list(seen_labs) 113 | train_data.label_id = True 114 | 115 | return train_data 116 | 117 | def define_loss_function(self, logits, labs): 118 | 119 | loss_ce_seen = self.cross_entropy(logits, labs, self.seen_classes) 120 | loss_ce_unseen = self.cross_entropy(logits, labs, self.unseen_classes) 121 | 122 | return loss_ce_seen + self.balance_param * loss_ce_unseen 123 | 124 | def cross_entropy(self, logits, labels, classes): 125 | """This loss computes the probability mass on the 126 | opposite set of classes for each sample. 127 | 128 | :param logits: continuous vector 129 | :param labels: class ids 130 | """ 131 | 132 | ids = [self.label_to_idx[c] for c in classes] 133 | 134 | # Get indices of unseen and seen samples in the batch 135 | samples = [] 136 | 137 | for idx, l in enumerate(labels): 138 | if l in ids: 139 | samples.append(idx) 140 | 141 | # Get logit sums on unseen samples 142 | if samples: 143 | error = self.loss_func(logits[samples], labels[samples]) 144 | else: 145 | error = 0 146 | 147 | return error 148 | 149 | 150 | def get_pseudo_labels(self, unlabeled_examples): 151 | log.info(f"Num unlabeled data: {len(unlabeled_examples)}") 152 | # Get prediction on unlabeled data 153 | std_preds = self.test_predictions( 154 | unlabeled_examples, standard_zsl=True 155 | ) 156 | 157 | DatasetObject = dataset_object(self.config.DATASET_NAME) 158 | # 4. 
Take top-16 pseudo-labels to finetune the student 159 | pseudo_unseen_examples = DatasetObject( 160 | std_preds["id"], 161 | self.data_folder, 162 | transform=self.transform, 163 | augmentations=None, 164 | train=True, 165 | labels=None, 166 | label_map=self.label_to_idx, 167 | class_folder=True, 168 | original_filepaths=unlabeled_examples.filepaths, 169 | ) 170 | 171 | pseudo_labels = self.assign_pseudo_labels( 172 | self.config.N_PSEUDOSHOTS, pseudo_unseen_examples 173 | ) 174 | 175 | return pseudo_labels 176 | 177 | def assign_pseudo_labels(self, k, unlabeled_data): 178 | # Define text queries 179 | # prompts = [f"{self.template}{' '.join(i.split('_'))}" \ 180 | # for i in self.unseen_classes] 181 | 182 | log.info(f"[self.assign_pseudo_labels] Number of prompts: {len(self.unseen_classes)}") 183 | 184 | # Get prompts 185 | self.model.classes = self.unseen_classes 186 | text_features = self.model(self.model.classes) 187 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 188 | log.info(f"TEXT FEATURES SHAPE: {text_features.size()}") 189 | 190 | # to find the top k for each class, each class has it's own "leaderboard" 191 | top_k_leaderboard = { 192 | self.label_to_idx[self.unseen_classes[i]]: [] 193 | for i in range(len(self.unseen_classes)) 194 | } # maps class idx -> (confidence, image_path) tuple 195 | 196 | for img_path in unlabeled_data.filepaths: 197 | # log.info(f"IMAGEPATH: {img_path}") 198 | img = Image.open(img_path).convert("RGB") 199 | img = torch.unsqueeze(self.transform(img), 0).to(self.device) 200 | with torch.no_grad(): 201 | image_features = self.clip_model.encode_image(img) 202 | image_features = image_features / image_features.norm( 203 | dim=-1, keepdim=True 204 | ) 205 | # cosine similarity as logits 206 | 207 | logit_scale = self.clip_model.logit_scale.exp() 208 | logits = logit_scale * image_features @ text_features.t() 209 | probs = logits.softmax(dim=-1) 210 | idx_preds = torch.argmax(logits, dim=1) 211 | pred_id = idx_preds.item() 212 | pred = self.label_to_idx[self.unseen_classes[idx_preds.item()]] 213 | 214 | """if predicted class has empty leaderboard, or if the confidence is high 215 | enough for predicted class leaderboard, add the new example 216 | """ 217 | prob_score = probs[0][pred_id] 218 | if len(top_k_leaderboard[pred]) < k: 219 | top_k_leaderboard[pred].append((prob_score, img_path)) 220 | elif ( 221 | top_k_leaderboard[pred][-1][0] < prob_score 222 | ): # if the confidence in predicted class "qualifies" for top-k 223 | # default sorting of tuples is by first element 224 | top_k_leaderboard[pred] = sorted( 225 | top_k_leaderboard[pred] + [(probs[0][pred_id], img_path)], 226 | reverse=True, 227 | )[:k] 228 | else: 229 | # sort the other classes by confidence score 230 | order_of_classes = sorted( 231 | [ 232 | (probs[0][j], j) 233 | for j in range(len(self.unseen_classes)) 234 | if j != pred_id 235 | ], 236 | reverse=True, 237 | ) 238 | for score, index in order_of_classes: 239 | index_dict = self.label_to_idx[self.unseen_classes[index]] 240 | # log.info(f"{classnames[index]}") 241 | # log.info(f"{index_dict}") 242 | if len(top_k_leaderboard[index_dict]) < k: 243 | top_k_leaderboard[index_dict].append( 244 | (probs[0][index], img_path) 245 | ) 246 | elif top_k_leaderboard[index_dict][-1][0] < probs[0][index]: 247 | # default sorting of tuples is by first element 248 | top_k_leaderboard[index_dict] = sorted( 249 | top_k_leaderboard[index_dict] 250 | + [((probs[0][index], img_path))], 251 | reverse=True, 252 | )[:k] 253 | 254 | 
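        # Note that `probs` is a softmax over the unseen classes only, so the
        # confidences ranked above are relative to this restricted label set,
        # not to the full class vocabulary.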
new_imgs = [] 255 | new_labels = [] 256 | # loop through, and rebuild the dataset 257 | for index, leaderboard in top_k_leaderboard.items(): 258 | new_imgs += [tup[1] for tup in leaderboard] 259 | new_labels += [index for _ in leaderboard] 260 | 261 | unlabeled_data.filepaths = new_imgs 262 | unlabeled_data.labels = new_labels 263 | unlabeled_data.label_id = True 264 | 265 | return unlabeled_data 266 | -------------------------------------------------------------------------------- /methods/transductive_zsl/visual_fpl.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import clip 4 | import numpy as np 5 | import pandas as pd 6 | import torch 7 | from accelerate import Accelerator 8 | from PIL import Image 9 | from torch import nn 10 | 11 | accelerator = Accelerator() 12 | 13 | from methods.transductive_zsl import VisualPrompt 14 | from utils import ( 15 | dataset_object, 16 | make_scheduler, 17 | pseudolabel_top_k, 18 | ) 19 | 20 | 21 | log = logging.getLogger(__name__) 22 | 23 | 24 | class VisualFPL(VisualPrompt): 25 | def __init__( 26 | self, 27 | config, 28 | label_to_idx, 29 | data_folder, 30 | classes, 31 | seen_classes, 32 | unseen_classes, 33 | device 34 | ): 35 | """This class defines few-pseudo-learning (FPL) VPT's training and evaluation. 36 | 37 | :param config: dictionaries of prameters in models_config/vpt_baseline_config.yml 38 | :param label_to_idx: dictionary (key, value):(class name, id) 39 | :param classes: list of class names 40 | :param seen_classes: list of seen classes' names 41 | :param unseen_classes: list of unseen classes' names 42 | :param device: device in use 43 | """ 44 | 45 | super().__init__( 46 | config, label_to_idx, classes, seen_classes, unseen_classes, device 47 | ) 48 | 49 | self.data_folder = data_folder 50 | 51 | def create_training_dataset(self, train_data, unlabeled_data=None): 52 | """This function creates the dataset for training. Specifically, it 53 | merges pseudo-labels for unseen data and labeled data for seen classes. 
54 | 55 | :param train_data: Dataset object - training seen classes (defined in zsl_jpl line 323) 56 | :param unlabeled_data: Dataset object - dataset of unlabeled data for 57 | unseen classes (defined in zsl_jpl line 328) 58 | """ 59 | 60 | # Get pseudo-labels for unlabeled data from unseen classes 61 | train_unseen_dataset = pseudolabel_top_k( 62 | self.config, 63 | self.config.DATASET_NAME, 64 | self.config.N_PSEUDOSHOTS, 65 | self.config.PROMPT_TEMPLATE, 66 | unlabeled_data, 67 | self.unseen_classes, 68 | self.transform, 69 | self.clip_model, 70 | self.label_to_idx, 71 | self.device, 72 | self.config.VIS_ENCODER, 73 | self.config.SPLIT_SEED, 74 | ) 75 | 76 | # Define the lists of traiing data from seen and unseen classes 77 | unseen_imgs = train_unseen_dataset.filepaths 78 | unseen_labs = train_unseen_dataset.labels 79 | 80 | # Use a portion of the pseudo-labeled data to build a validation set 81 | if self.config.N_PSEUDOSHOTS >= 10: 82 | np.random.seed(self.config.validation_seed) 83 | train_indices = np.random.choice( 84 | range(len(unseen_imgs)), 85 | size=int(len(unseen_imgs) * self.config.ratio_train_val), 86 | replace=False, 87 | ) 88 | val_indices = list( 89 | set(range(len(unseen_imgs))).difference(set(train_indices)) 90 | ) 91 | 92 | self.val_unseen_files = np.array(unseen_imgs)[val_indices] 93 | self.val_unseen_labs = np.array(unseen_labs)[val_indices] 94 | 95 | unseen_imgs = list(np.array(unseen_imgs)[train_indices]) 96 | unseen_labs = list(np.array(unseen_labs)[train_indices]) 97 | 98 | else: 99 | self.val_unseen_files = None 100 | self.val_unseen_labs = None 101 | 102 | seen_imgs = train_data.filepaths 103 | seen_labs = [self.label_to_idx[l] for l in train_data.labels] 104 | 105 | self.balance_param = len(seen_imgs) / len(unseen_imgs) 106 | 107 | train_data.filepaths = list(unseen_imgs) + list(seen_imgs) 108 | train_data.labels = list(unseen_labs) + list(seen_labs) 109 | train_data.label_id = True 110 | 111 | 112 | def define_loss_function(self, logits, labs): 113 | 114 | loss_ce_seen = self.cross_entropy(logits, labs, self.seen_classes) 115 | loss_ce_unseen = self.cross_entropy(logits, labs, self.unseen_classes) 116 | 117 | return loss_ce_seen + self.balance_param * loss_ce_unseen 118 | 119 | def cross_entropy(self, logits, labels, classes): 120 | """This loss computes the probability mass on the 121 | opposite set of classes for each sample. 122 | 123 | :param logits: continuous vector 124 | :param labels: class ids 125 | """ 126 | 127 | ids = [self.label_to_idx[c] for c in classes] 128 | 129 | # Get indices of unseen and seen samples in the batch 130 | samples = [] 131 | 132 | for idx, l in enumerate(labels): 133 | if l in ids: 134 | samples.append(idx) 135 | 136 | if samples: 137 | error = self.loss_func(logits[samples], labels[samples]) 138 | else: 139 | error = 0 140 | 141 | return error 142 | 143 | def define_textual_prompts(self, only_unlabelled=None, validation=False): 144 | """This function returns the textual prompts. You can modify the list 145 | of classes of interest. 146 | 147 | :param only_unlabelled: boolean. 
It is True if the training only involves 148 | pseudo-labeled unseen data 149 | """ 150 | 151 | if only_unlabelled: 152 | return [ 153 | self.template.format(" ".join(i.split("_"))) 154 | for i in self.unseen_classes 155 | ] 156 | else: 157 | if validation: 158 | return [ 159 | self.template.format(" ".join(i.split("_"))) 160 | for i in self.seen_classes 161 | ] 162 | else: 163 | return [ 164 | self.template.format(" ".join(i.split("_"))) for i in self.classes 165 | ] 166 | 167 | def reindex_predicted_labels(self, idx_preds, only_unlabelled=False): 168 | """This function returns the correct index of predictions to compute 169 | model's accuracy. 170 | 171 | :param idx_pred: list of predictions ids 172 | :param only_unlabelled: boolean. It is True if the training only involves 173 | pseudo-labeled unseen data 174 | """ 175 | 176 | if only_unlabelled: 177 | return [self.unseen_classes[i.item()] for i in idx_preds] 178 | else: 179 | return [self.classes[i.item()] for i in idx_preds] 180 | 181 | def reindex_true_labels(self, label, only_unlabelled=False): 182 | """This function returns the correct index of true labels. 183 | 184 | :param label: list of labels from data loader 185 | :param only_unlabelled: boolean. It is True if the training only involves 186 | pseudo-labeled unseen data 187 | """ 188 | 189 | if only_unlabelled: 190 | return torch.tensor( 191 | [self.unseen_classes.index(self.classes[l.item()]) for l in label] 192 | ) 193 | else: 194 | return torch.tensor([l for l in label]) 195 | 196 | 197 | def get_pseudo_labels(self, unlabeled_examples): 198 | log.info(f"Num unlabeled data: {len(unlabeled_examples)}") 199 | # Get prediction on unlabeled data 200 | std_preds = self.test_predictions( 201 | unlabeled_examples, standard_zsl=True 202 | ) 203 | 204 | DatasetObject = dataset_object(self.config.DATASET_NAME) 205 | # 4. 
Take top-16 pseudo-labels to finetune the student 206 | pseudo_unseen_examples = DatasetObject( 207 | std_preds["id"], 208 | self.data_folder, 209 | transform=self.transform, 210 | augmentations=None, 211 | train=True, 212 | labels=None, 213 | label_map=self.label_to_idx, 214 | class_folder=True, 215 | original_filepaths=unlabeled_examples.filepaths, 216 | ) 217 | 218 | pseudo_labels = self.assign_pseudo_labels( 219 | self.config.N_PSEUDOSHOTS, pseudo_unseen_examples 220 | ) 221 | 222 | return pseudo_labels 223 | 224 | def assign_pseudo_labels(self, k, unlabeled_data): 225 | # Define text queries 226 | # prompts = [f"{self.template}{' '.join(i.split('_'))}" \ 227 | # for i in self.unseen_classes] 228 | prompts = [ 229 | self.template.format(" ".join(i.split("_"))) for i in self.unseen_classes 230 | ] 231 | log.info(f"Number of prompts: {len(prompts)}") 232 | 233 | # Encode text 234 | text = clip.tokenize(prompts).to(self.device) 235 | text_features = self.clip_model.encode_text(text) 236 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 237 | 238 | # to find the top k for each class, each class has it's own "leaderboard" 239 | top_k_leaderboard = { 240 | self.label_to_idx[self.unseen_classes[i]]: [] 241 | for i in range(len(self.unseen_classes)) 242 | } # maps class idx -> (confidence, image_path) tuple 243 | 244 | for img_path in unlabeled_data.filepaths: 245 | # log.info(f"IMAGEPATH: {img_path}") 246 | img = Image.open(img_path).convert("RGB") 247 | img = torch.unsqueeze(self.transform(img), 0).to(self.device) 248 | with torch.no_grad(): 249 | image_features = self.model(img) 250 | image_features = image_features / image_features.norm( 251 | dim=-1, keepdim=True 252 | ) 253 | # cosine similarity as logits 254 | 255 | logit_scale = self.clip_model.logit_scale.exp() 256 | logits = logit_scale * image_features @ text_features.t() 257 | probs = logits.softmax(dim=-1) 258 | idx_preds = torch.argmax(logits, dim=1) 259 | pred_id = idx_preds.item() 260 | pred = self.label_to_idx[self.unseen_classes[idx_preds.item()]] 261 | 262 | """if predicted class has empty leaderboard, or if the confidence is high 263 | enough for predicted class leaderboard, add the new example 264 | """ 265 | prob_score = probs[0][pred_id] 266 | if len(top_k_leaderboard[pred]) < k: 267 | top_k_leaderboard[pred].append((prob_score, img_path)) 268 | elif ( 269 | top_k_leaderboard[pred][-1][0] < prob_score 270 | ): # if the confidence in predicted class "qualifies" for top-k 271 | # default sorting of tuples is by first element 272 | top_k_leaderboard[pred] = sorted( 273 | top_k_leaderboard[pred] + [(probs[0][pred_id], img_path)], 274 | reverse=True, 275 | )[:k] 276 | else: 277 | # sort the other classes by confidence score 278 | order_of_classes = sorted( 279 | [ 280 | (probs[0][j], j) 281 | for j in range(len(self.unseen_classes)) 282 | if j != pred_id 283 | ], 284 | reverse=True, 285 | ) 286 | for score, index in order_of_classes: 287 | index_dict = self.label_to_idx[self.unseen_classes[index]] 288 | # log.info(f"{classnames[index]}") 289 | # log.info(f"{index_dict}") 290 | if len(top_k_leaderboard[index_dict]) < k: 291 | top_k_leaderboard[index_dict].append( 292 | (probs[0][index], img_path) 293 | ) 294 | elif top_k_leaderboard[index_dict][-1][0] < probs[0][index]: 295 | # default sorting of tuples is by first element 296 | top_k_leaderboard[index_dict] = sorted( 297 | top_k_leaderboard[index_dict] 298 | + [((probs[0][index], img_path))], 299 | reverse=True, 300 | )[:k] 301 | 302 | new_imgs = [] 
303 | new_labels = [] 304 | # loop through, and rebuild the dataset 305 | for index, leaderboard in top_k_leaderboard.items(): 306 | new_imgs += [tup[1] for tup in leaderboard] 307 | new_labels += [index for _ in leaderboard] 308 | 309 | unlabeled_data.filepaths = new_imgs 310 | unlabeled_data.labels = new_labels 311 | unlabeled_data.label_id = True 312 | 313 | return unlabeled_data 314 | -------------------------------------------------------------------------------- /methods/unsupervised_learning/__init__.py: -------------------------------------------------------------------------------- 1 | from .training_strategies import TrainingStrategy 2 | from .visual_prompt import VisualPrompt 3 | from .textual_prompt import TextualPrompt 4 | from .multimodal_prompt import MultimodalPrompt 5 | from .visual_fpl import VisualFPL 6 | from .textual_fpl import TextualFPL 7 | from .multimodal_fpl import MultimodalFPL 8 | -------------------------------------------------------------------------------- /methods/unsupervised_learning/multimodal_fpl.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import clip 4 | import numpy as np 5 | import pandas as pd 6 | import torch 7 | from accelerate import Accelerator 8 | from PIL import Image 9 | from torch import nn 10 | import math 11 | 12 | accelerator = Accelerator() 13 | 14 | from methods.unsupervised_learning import MultimodalPrompt 15 | from utils import ( 16 | dataset_object, 17 | make_scheduler, 18 | pseudolabel_top_k, 19 | ) 20 | 21 | 22 | log = logging.getLogger(__name__) 23 | 24 | 25 | class MultimodalFPL(MultimodalPrompt): 26 | def __init__( 27 | self, 28 | config, 29 | label_to_idx, 30 | data_folder, 31 | classes, 32 | seen_classes, 33 | unseen_classes, 34 | device 35 | ): 36 | """This class defines self-training UPT's training and evaluation. 37 | :param config: dictionary of parameters in models_config/upt_baseline_pseudo_config.yml 38 | :param label_to_idx: dictionary (key, value):(class name, id) 39 | :param classes: list of class names 40 | :param seen_classes: list of seen classes' names 41 | :param unseen_classes: list of unseen classes' names 42 | :param device: device in use 43 | """ 44 | 45 | super().__init__( 46 | config, label_to_idx, classes, seen_classes, unseen_classes, device 47 | ) 48 | 49 | self.data_folder = data_folder 50 | 51 | def create_training_dataset(self, train_data, unlabeled_data=None): 52 | """This function creates the dataset for training. Specifically, it 53 | merges pseudo-labels for unseen data and labeled data for seen classes.
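(In this unsupervised setting no labeled data for seen classes is available, so train_data is overwritten with the pseudo-labeled examples rather than merged.)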
54 | :param train_data: Dataset object - training seen classes (defined in zsl_jpl line 323) 55 | :param unlabeled_data: Dataset object - dataset of unlabeled data for 56 | unseen classes (defined in zsl_jpl line 328) 57 | """ 58 | 59 | # Get pseudo-labels for unlabeled data from unseen classes 60 | train_unseen_dataset = pseudolabel_top_k( 61 | self.config, 62 | self.config.DATASET_NAME, 63 | self.config.N_PSEUDOSHOTS, 64 | self.config.PROMPT_TEMPLATE, 65 | unlabeled_data, 66 | self.classes, 67 | self.transform, 68 | self.clip_model, 69 | self.label_to_idx, 70 | self.device, 71 | self.config.VIS_ENCODER, 72 | self.config.SPLIT_SEED 73 | ) 74 | 75 | # Define the lists of traiing data from seen and unseen classes 76 | unseen_imgs = train_unseen_dataset.filepaths 77 | unseen_labs = train_unseen_dataset.labels 78 | log.info(f"Number of classes in pseudolabels: {len(set(unseen_labs))}") 79 | 80 | # Use a portion of the pseudo-labeled data to build a validation set 81 | if self.config.N_PSEUDOSHOTS >= 10: 82 | np.random.seed(self.config.validation_seed) 83 | train_indices = np.random.choice( 84 | range(len(unseen_imgs)), 85 | size=int(len(unseen_imgs) * self.config.ratio_train_val), 86 | replace=False, 87 | ) 88 | val_indices = list( 89 | set(range(len(unseen_imgs))).difference(set(train_indices)) 90 | ) 91 | 92 | self.val_unseen_files = np.array(unseen_imgs)[val_indices] 93 | self.val_unseen_labs = np.array(unseen_labs)[val_indices] 94 | 95 | unseen_imgs = list(np.array(unseen_imgs)[train_indices]) 96 | unseen_labs = list(np.array(unseen_labs)[train_indices]) 97 | 98 | else: 99 | self.val_unseen_files = None 100 | self.val_unseen_labs = None 101 | 102 | train_data.filepaths = list(unseen_imgs) 103 | train_data.labels = list(unseen_labs) 104 | train_data.label_id = True 105 | 106 | def define_loss_function(self, logits, labs): 107 | 108 | loss_ce = self.cross_entropy(logits, labs, self.classes) 109 | 110 | return loss_ce 111 | 112 | def cross_entropy(self, logits, labels, classes): 113 | """This loss computes the probability mass on the 114 | opposite set of classes for each sample. 115 | :param logits: continuous vector 116 | :param labels: class ids 117 | """ 118 | 119 | error = self.loss_func(logits, labels) 120 | 121 | return error 122 | 123 | def reindex_predicted_labels(self, idx_preds, only_unlabelled=False): 124 | """This function returns the correct index of predictions to compute 125 | model's accuracy. 126 | :param idx_pred: list of predictions ids 127 | :param only_unlabelled: boolean. It is True if the training only involves 128 | pseudo-labeled unseen data 129 | """ 130 | 131 | return [self.classes[i.item()] for i in idx_preds] 132 | 133 | def reindex_true_labels(self, label, only_unlabelled=False): 134 | """This function returns the correct index of true labels. 135 | :param label: list of labels from data loader 136 | :param only_unlabelled: boolean. It is True if the training only involves 137 | pseudo-labeled unseen data 138 | """ 139 | 140 | return torch.tensor([l for l in label]) 141 | 142 | def get_pseudo_labels(self, unlabeled_examples): 143 | log.info(f"[self.get_pseudo_labels] Num unlabeled data: {len(unlabeled_examples)}") 144 | # Get prediction on unlabeled data 145 | std_preds = self.test_predictions( 146 | unlabeled_examples, standard_zsl=True 147 | ) 148 | 149 | DatasetObject = dataset_object(self.config.DATASET_NAME) 150 | # 4. 
Take top-16 pseudo-labels to finetune the student 151 | pseudo_unseen_examples = DatasetObject( 152 | std_preds["id"], 153 | self.data_folder, 154 | transform=self.transform, 155 | augmentations=None, 156 | train=True, 157 | labels=None, 158 | label_map=self.label_to_idx, 159 | class_folder=True, 160 | original_filepaths=unlabeled_examples.filepaths, 161 | ) 162 | 163 | pseudo_labels = self.assign_pseudo_labels( 164 | self.config.N_PSEUDOSHOTS, pseudo_unseen_examples 165 | ) 166 | 167 | return pseudo_labels 168 | 169 | def assign_pseudo_labels(self, k, unlabeled_data): 170 | 171 | # to find the top k for each class, each class has it's own "leaderboard" 172 | top_k_leaderboard = { 173 | self.label_to_idx[self.classes[i]]: [] 174 | for i in range(len(self.classes)) 175 | } # maps class idx -> (confidence, image_path) tuple 176 | 177 | classes = self.classes 178 | for img_path in unlabeled_data.filepaths: 179 | # log.info(f"IMAGEPATH: {img_path}") 180 | img = Image.open(img_path).convert("RGB") 181 | img = torch.unsqueeze(self.transform(img), 0).to(self.device) 182 | with torch.no_grad(): 183 | # Get text and image prompts using UPT 184 | # coop_embeddings, vpt_embeddings, vpt_deep_embeddings = self.model(0) 185 | # Calculate text prompts 186 | # text_features = self.text_encoder(coop_embeddings, classes) 187 | text_features, image_features = self.model(img, classes) 188 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 189 | # Calculate image prompts 190 | # image_features = self.image_encoder(img, vpt_embeddings, deep_embds=vpt_deep_embeddings) 191 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 192 | 193 | logit_scale = self.clip_model.logit_scale.exp() 194 | logits = logit_scale * image_features @ text_features.t() 195 | probs = logits.softmax(dim=-1) 196 | idx_preds = torch.argmax(logits, dim=1) 197 | pred_id = idx_preds.item() 198 | pred = self.label_to_idx[self.classes[idx_preds.item()]] 199 | 200 | """if predicted class has empty leaderboard, or if the confidence is high 201 | enough for predicted class leaderboard, add the new example 202 | """ 203 | prob_score = probs[0][pred_id] 204 | if len(top_k_leaderboard[pred]) < k: 205 | top_k_leaderboard[pred].append((prob_score, img_path)) 206 | elif ( 207 | top_k_leaderboard[pred][-1][0] < prob_score 208 | ): # if the confidence in predicted class "qualifies" for top-k 209 | # default sorting of tuples is by first element 210 | top_k_leaderboard[pred] = sorted( 211 | top_k_leaderboard[pred] + [(probs[0][pred_id], img_path)], 212 | reverse=True, 213 | )[:k] 214 | else: 215 | # sort the other classes by confidence score 216 | order_of_classes = sorted( 217 | [ 218 | (probs[0][j], j) 219 | for j in range(len(self.classes)) 220 | if j != pred_id 221 | ], 222 | reverse=True, 223 | ) 224 | for score, index in order_of_classes: 225 | index_dict = self.label_to_idx[self.classes[index]] 226 | # log.info(f"{classnames[index]}") 227 | # log.info(f"{index_dict}") 228 | if len(top_k_leaderboard[index_dict]) < k: 229 | top_k_leaderboard[index_dict].append( 230 | (probs[0][index], img_path) 231 | ) 232 | elif top_k_leaderboard[index_dict][-1][0] < probs[0][index]: 233 | # default sorting of tuples is by first element 234 | top_k_leaderboard[index_dict] = sorted( 235 | top_k_leaderboard[index_dict] 236 | + [((probs[0][index], img_path))], 237 | reverse=True, 238 | )[:k] 239 | 240 | new_imgs = [] 241 | new_labels = [] 242 | # loop through, and rebuild the dataset 243 | for index, leaderboard in 
top_k_leaderboard.items(): 244 | new_imgs += [tup[1] for tup in leaderboard] 245 | new_labels += [index for _ in leaderboard] 246 | 247 | unlabeled_data.filepaths = new_imgs 248 | unlabeled_data.labels = new_labels 249 | unlabeled_data.label_id = True 250 | 251 | return unlabeled_data 252 | -------------------------------------------------------------------------------- /methods/unsupervised_learning/pseudo_iterative.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | import math 4 | import pickle 5 | 6 | import clip 7 | import numpy as np 8 | import pandas as pd 9 | import scipy.stats as st 10 | import torch 11 | from accelerate import Accelerator 12 | from PIL import Image 13 | from torch import nn 14 | 15 | 16 | accelerator = Accelerator() 17 | 18 | from data import CustomDataset 19 | from models import CustomImageEncoder, ImagePrefixModel 20 | from methods import TeacherStudent 21 | from utils import ( 22 | dataset_object, 23 | evaluate_predictions, 24 | make_scheduler, 25 | pseudolabel_top_k, 26 | seed_worker, 27 | save_parameters, 28 | save_pseudo_labels, 29 | ) 30 | 31 | g = torch.Generator() 32 | g.manual_seed(0) 33 | 34 | log = logging.getLogger(__name__) 35 | 36 | 37 | class PseudoIterative(TeacherStudent): 38 | def __init__( 39 | self, 40 | config, 41 | label_to_idx, 42 | data_folder, 43 | classes, 44 | seen_classes, 45 | unseen_classes, 46 | device, 47 | ): 48 | super().__init__( 49 | config, label_to_idx, data_folder, classes, seen_classes, unseen_classes, device 50 | ) 51 | 52 | 53 | def train( 54 | self, 55 | train_data, 56 | val_data, 57 | unlabeled_data, 58 | test_data, 59 | test_labeled_files, 60 | test_labeles, 61 | ): 62 | # Number of total iterations to cover all unlabeled data 63 | num_iter = int(100/self.config.STEP_QUANTILE) 64 | num_samples = int(len(unlabeled_data) / num_iter) 65 | # Initialize the number of pseudo-labels per class 66 | n_per_class = int(num_samples / len(self.unseen_classes)) 67 | n_unseen = len(self.unseen_classes) 68 | if n_per_class * n_unseen <= len(unlabeled_data.filepaths): 69 | # self.num_pseudo_labels_per_class = n_per_class 70 | self.config.N_PSEUDOSHOTS = n_per_class 71 | else: 72 | # self.num_pseudo_labels_per_class = math.floor(len(unlabeled_data.filepaths)/n_unseen) 73 | self.config.N_PSEUDOSHOTS = math.floor( 74 | len(unlabeled_data.filepaths) / n_unseen 75 | ) 76 | 77 | log.info(f"We select {self.config.N_PSEUDOSHOTS} pseudo-labels for each unseen class.") 78 | log.info(f"The number of unseen classes is: {len(self.unseen_classes)}.") 79 | log.info(f"Thus we expect an initial number of pseudo-labels equal to {len(self.unseen_classes) * self.config.N_PSEUDOSHOTS}.") 80 | # Create a safe copy of labeled/unlabeled data 81 | original_train_data = copy.deepcopy(train_data) 82 | # log.info(f"Training data labels: {original_train_data.labels}") 83 | original_unlabeled_data = copy.deepcopy(unlabeled_data) 84 | # Original val 85 | original_val_data = copy.deepcopy(val_data) 86 | 87 | # Initialize here the first batch of pseudo-labels 88 | #self.create_training_dataset(train_data, unlabeled_data) 89 | #log.info(f"The original train data has size: {len(original_train_data.filepaths)}.") 90 | #log.info(f"Plus: {len(unlabeled_data.filepaths)}.") 91 | 92 | for niter in range(1, num_iter + 1): 93 | log.info(f"NUM PSEUDO SHOTS: {self.config.N_PSEUDOSHOTS}") 94 | pseudolabel_top_k( 95 | self.config.DATASET_NAME, 96 | self.config.N_PSEUDOSHOTS, 97 | self.config.PROMPT_TEMPLATE, 98 |
unlabeled_data, 99 | self.unseen_classes, 100 | self.transform, 101 | self.clip_model, 102 | self.label_to_idx, 103 | self.device, 104 | self.config.VIS_ENCODER, 105 | self.config.SPLIT_SEED, 106 | ) 107 | 108 | log.info(f"Plus: {len(unlabeled_data.filepaths)}.") 109 | filename = f"pseudolabels/{self.config.DATASET_NAME}_CLIP_{self.config.VIS_ENCODER.replace('/', '')}_iter_{niter}_pseudolabels_spl_{self.config.SPLIT_SEED}.pickle" 110 | with open(filename, "wb") as f: 111 | pickle.dump({"filepaths": unlabeled_data.filepaths, "labels": unlabeled_data.labels}, f) 112 | 113 | # Exploit all the available unlabeled data 114 | if self.config.ALL_UNLABELED: 115 | n_per_class = int((niter + 1) * num_samples / n_unseen) 116 | log.info(f"n_per_class: {n_per_class}") 117 | if n_per_class * n_unseen <= len(original_unlabeled_data.filepaths): 118 | log.info(f"if n_per_class: {n_per_class}") 119 | self.config.N_PSEUDOSHOTS = n_per_class 120 | else: 121 | log.info(f"else new val: {len(original_unlabeled_data.filepaths) / n_unseen}") 122 | # We are making a strong assumption about the distribution of unlabeled data 123 | self.config.N_PSEUDOSHOTS = math.floor( 124 | len(original_unlabeled_data.filepaths) / n_unseen 125 | ) 126 | 127 | unlabeled_data = original_unlabeled_data 128 | original_unlabeled_data = copy.deepcopy(unlabeled_data) -------------------------------------------------------------------------------- /methods/unsupervised_learning/textual_fpl.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import clip 4 | import numpy as np 5 | import pandas as pd 6 | import torch 7 | from accelerate import Accelerator 8 | from PIL import Image 9 | from torch import nn 10 | from tqdm import tqdm 11 | 12 | accelerator = Accelerator() 13 | 14 | from methods.unsupervised_learning import TextualPrompt 15 | from utils import ( 16 | dataset_object, 17 | make_scheduler, 18 | pseudolabel_top_k, 19 | seed_worker, 20 | ) 21 | 22 | 23 | g = torch.Generator() 24 | g.manual_seed(0) 25 | 26 | log = logging.getLogger(__name__) 27 | 28 | 29 | class TextualFPL(TextualPrompt): 30 | def __init__( 31 | self, 32 | config, 33 | label_to_idx, 34 | data_folder, 35 | classes, 36 | seen_classes, 37 | unseen_classes, 38 | device, 39 | ): 40 | """This class defines few-pseudo-learning (FPL) CoOp's training and evaluation. 41 | 42 | :param config: dictionary of parameters in models_config/coop_baseline_config.yml 43 | :param label_to_idx: dictionary (key, value):(class name, id) 44 | :param classes: list of class names 45 | :param seen_classes: list of seen classes' names 46 | :param unseen_classes: list of unseen classes' names 47 | :param device: device in use 48 | """ 49 | super().__init__( 50 | config, label_to_idx, classes, seen_classes, unseen_classes, device 51 | ) 52 | 53 | self.data_folder = data_folder 54 | 55 | def create_training_dataset(self, train_data, unlabeled_data=None): 56 | """This function creates the dataset for training. Specifically, it 57 | merges pseudo-labels for unseen data and labeled data for seen classes.
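Unlike the visual and multimodal variants, this implementation also returns the modified train_data object.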
58 | 59 | :param train_data: Dataset object - training seen classes (defined in zsl_jpl line 323) 60 | :param unlabeled_data: Dataset object - dataset of unlabeled data for 61 | unseen classes (defined in zsl_jpl line 328) 62 | """ 63 | 64 | # Get pseudo-labels for unlabeled data from unseen classes 65 | train_unseen_dataset = pseudolabel_top_k( 66 | self.config, 67 | self.config.DATASET_NAME, 68 | self.config.N_PSEUDOSHOTS, 69 | self.config.PROMPT_TEMPLATE, 70 | unlabeled_data, 71 | self.classes, 72 | self.transform, 73 | self.clip_model, 74 | self.label_to_idx, 75 | self.device, 76 | self.config.VIS_ENCODER, 77 | self.config.SPLIT_SEED, 78 | ) 79 | 80 | # Define the lists of traiing data from seen and unseen classes 81 | unseen_imgs = train_unseen_dataset.filepaths 82 | unseen_labs = train_unseen_dataset.labels 83 | 84 | # Use a portion of the pseudo-labeled data to build a validation set 85 | if self.config.N_PSEUDOSHOTS >= 10: 86 | np.random.seed(self.config.validation_seed) 87 | train_indices = np.random.choice( 88 | range(len(unseen_imgs)), 89 | size=int(len(unseen_imgs) * self.config.ratio_train_val), 90 | replace=False, 91 | ) 92 | val_indices = list( 93 | set(range(len(unseen_imgs))).difference(set(train_indices)) 94 | ) 95 | 96 | self.val_unseen_files = np.array(unseen_imgs)[val_indices] 97 | self.val_unseen_labs = np.array(unseen_labs)[val_indices] 98 | 99 | unseen_imgs = list(np.array(unseen_imgs)[train_indices]) 100 | unseen_labs = list(np.array(unseen_labs)[train_indices]) 101 | 102 | else: 103 | self.val_unseen_files = None 104 | self.val_unseen_labs = None 105 | 106 | train_data.filepaths = list(unseen_imgs) 107 | train_data.labels = list(unseen_labs) 108 | train_data.label_id = True 109 | 110 | return train_data 111 | 112 | def define_loss_function(self, logits, labs): 113 | 114 | loss_ce = self.cross_entropy(logits, labs, self.classes) 115 | 116 | return loss_ce 117 | 118 | def cross_entropy(self, logits, labels, classes): 119 | """This loss computes the probability mass on the 120 | opposite set of classes for each sample. 121 | 122 | :param logits: continuous vector 123 | :param labels: class ids 124 | """ 125 | 126 | error = self.loss_func(logits, labels) 127 | 128 | return error 129 | 130 | 131 | def get_pseudo_labels(self, unlabeled_examples): 132 | log.info(f"Num unlabeled data: {len(unlabeled_examples)}") 133 | # Get prediction on unlabeled data 134 | std_preds = self.test_predictions( 135 | unlabeled_examples, standard_zsl=True 136 | ) 137 | 138 | DatasetObject = dataset_object(self.config.DATASET_NAME) 139 | # 4. 
Take top-16 pseudo-labels to finetune the student 140 | pseudo_unseen_examples = DatasetObject( 141 | std_preds["id"], 142 | self.data_folder, 143 | transform=self.transform, 144 | augmentations=None, 145 | train=True, 146 | labels=None, 147 | label_map=self.label_to_idx, 148 | class_folder=True, 149 | original_filepaths=unlabeled_examples.filepaths, 150 | ) 151 | 152 | pseudo_labels = self.assign_pseudo_labels( 153 | self.config.N_PSEUDOSHOTS, pseudo_unseen_examples 154 | ) 155 | 156 | return pseudo_labels 157 | 158 | def assign_pseudo_labels(self, k, unlabeled_data): 159 | # Define text queries 160 | # prompts = [f"{self.template}{' '.join(i.split('_'))}" \ 161 | # for i in self.unseen_classes] 162 | 163 | log.info(f"[self.assign_pseudo_labels] Number of prompts: {len(self.classes)}") 164 | 165 | # Get prompts 166 | self.model.classes = self.classes 167 | text_features = self.model(self.model.classes) 168 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 169 | log.info(f"TEXT FEATURES SHAPE: {text_features.size()}") 170 | 171 | # to find the top k for each class, each class has it's own "leaderboard" 172 | top_k_leaderboard = { 173 | self.label_to_idx[self.classes[i]]: [] 174 | for i in range(len(self.classes)) 175 | } # maps class idx -> (confidence, image_path) tuple 176 | 177 | for img_path in unlabeled_data.filepaths: 178 | # log.info(f"IMAGEPATH: {img_path}") 179 | img = Image.open(img_path).convert("RGB") 180 | img = torch.unsqueeze(self.transform(img), 0).to(self.device) 181 | with torch.no_grad(): 182 | image_features = self.clip_model.encode_image(img) 183 | image_features = image_features / image_features.norm( 184 | dim=-1, keepdim=True 185 | ) 186 | # cosine similarity as logits 187 | 188 | logit_scale = self.clip_model.logit_scale.exp() 189 | logits = logit_scale * image_features @ text_features.t() 190 | probs = logits.softmax(dim=-1) 191 | idx_preds = torch.argmax(logits, dim=1) 192 | pred_id = idx_preds.item() 193 | pred = self.label_to_idx[self.classes[idx_preds.item()]] 194 | 195 | """if predicted class has empty leaderboard, or if the confidence is high 196 | enough for predicted class leaderboard, add the new example 197 | """ 198 | prob_score = probs[0][pred_id] 199 | if len(top_k_leaderboard[pred]) < k: 200 | top_k_leaderboard[pred].append((prob_score, img_path)) 201 | elif ( 202 | top_k_leaderboard[pred][-1][0] < prob_score 203 | ): # if the confidence in predicted class "qualifies" for top-k 204 | # default sorting of tuples is by first element 205 | top_k_leaderboard[pred] = sorted( 206 | top_k_leaderboard[pred] + [(probs[0][pred_id], img_path)], 207 | reverse=True, 208 | )[:k] 209 | else: 210 | # sort the other classes by confidence score 211 | order_of_classes = sorted( 212 | [ 213 | (probs[0][j], j) 214 | for j in range(len(self.classes)) 215 | if j != pred_id 216 | ], 217 | reverse=True, 218 | ) 219 | for score, index in order_of_classes: 220 | index_dict = self.label_to_idx[self.classes[index]] 221 | # log.info(f"{classnames[index]}") 222 | # log.info(f"{index_dict}") 223 | if len(top_k_leaderboard[index_dict]) < k: 224 | top_k_leaderboard[index_dict].append( 225 | (probs[0][index], img_path) 226 | ) 227 | elif top_k_leaderboard[index_dict][-1][0] < probs[0][index]: 228 | # default sorting of tuples is by first element 229 | top_k_leaderboard[index_dict] = sorted( 230 | top_k_leaderboard[index_dict] 231 | + [((probs[0][index], img_path))], 232 | reverse=True, 233 | )[:k] 234 | 235 | new_imgs = [] 236 | new_labels = [] 237 | # loop 
through, and rebuild the dataset 238 | for index, leaderboard in top_k_leaderboard.items(): 239 | new_imgs += [tup[1] for tup in leaderboard] 240 | new_labels += [index for _ in leaderboard] 241 | 242 | unlabeled_data.filepaths = new_imgs 243 | unlabeled_data.labels = new_labels 244 | unlabeled_data.label_id = True 245 | 246 | return unlabeled_data 247 | -------------------------------------------------------------------------------- /methods/unsupervised_learning/visual_fpl.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import clip 4 | import numpy as np 5 | import pandas as pd 6 | import torch 7 | from accelerate import Accelerator 8 | from PIL import Image 9 | from torch import nn 10 | 11 | accelerator = Accelerator() 12 | 13 | from methods.unsupervised_learning import VisualPrompt 14 | from utils import ( 15 | dataset_object, 16 | make_scheduler, 17 | pseudolabel_top_k, 18 | ) 19 | 20 | 21 | log = logging.getLogger(__name__) 22 | 23 | 24 | class VisualFPL(VisualPrompt): 25 | def __init__( 26 | self, 27 | config, 28 | label_to_idx, 29 | data_folder, 30 | classes, 31 | seen_classes, 32 | unseen_classes, 33 | device 34 | ): 35 | """This class defines few-pseudo-learning (FPL) VPT's training and evaluation. 36 | 37 | :param config: dictionaries of prameters in models_config/vpt_baseline_config.yml 38 | :param label_to_idx: dictionary (key, value):(class name, id) 39 | :param classes: list of class names 40 | :param seen_classes: list of seen classes' names 41 | :param unseen_classes: list of unseen classes' names 42 | :param device: device in use 43 | """ 44 | 45 | super().__init__( 46 | config, label_to_idx, classes, seen_classes, unseen_classes, device 47 | ) 48 | 49 | self.data_folder = data_folder 50 | 51 | def create_training_dataset(self, train_data, unlabeled_data=None): 52 | """This function creates the dataset for training. Specifically, it 53 | merges pseudo-labels for unseen data and labeled data for seen classes. 
54 | 55 | :param train_data: Dataset object - training seen classes (defined in zsl_jpl line 323) 56 | :param unlabeled_data: Dataset object - dataset of unlabeled data for 57 | unseen classes (defined in zsl_jpl line 328) 58 | """ 59 | 60 | # Get pseudo-labels for unlabeled data from unseen classes 61 | train_unseen_dataset = pseudolabel_top_k( 62 | self.config, 63 | self.config.DATASET_NAME, 64 | self.config.N_PSEUDOSHOTS, 65 | self.config.PROMPT_TEMPLATE, 66 | unlabeled_data, 67 | self.classes, 68 | self.transform, 69 | self.clip_model, 70 | self.label_to_idx, 71 | self.device, 72 | self.config.VIS_ENCODER, 73 | self.config.SPLIT_SEED, 74 | ) 75 | log.info(f"[self.create_training_dataset] Training data: {len(train_unseen_dataset.filepaths)}") 76 | # Define the lists of traiing data from seen and unseen classes 77 | unseen_imgs = train_unseen_dataset.filepaths 78 | unseen_labs = train_unseen_dataset.labels 79 | 80 | # Use a portion of the pseudo-labeled data to build a validation set 81 | if self.config.N_PSEUDOSHOTS >= 10: 82 | np.random.seed(self.config.validation_seed) 83 | train_indices = np.random.choice( 84 | range(len(unseen_imgs)), 85 | size=int(len(unseen_imgs) * self.config.ratio_train_val), 86 | replace=False, 87 | ) 88 | val_indices = list( 89 | set(range(len(unseen_imgs))).difference(set(train_indices)) 90 | ) 91 | 92 | self.val_unseen_files = np.array(unseen_imgs)[val_indices] 93 | self.val_unseen_labs = np.array(unseen_labs)[val_indices] 94 | 95 | unseen_imgs = list(np.array(unseen_imgs)[train_indices]) 96 | unseen_labs = list(np.array(unseen_labs)[train_indices]) 97 | 98 | else: 99 | self.val_unseen_files = None 100 | self.val_unseen_labs = None 101 | 102 | train_data.filepaths = list(unseen_imgs) 103 | train_data.labels = list(unseen_labs) 104 | train_data.label_id = True 105 | 106 | 107 | def define_loss_function(self, logits, labs): 108 | 109 | loss_ce = self.cross_entropy(logits, labs, self.classes) 110 | return loss_ce 111 | 112 | def cross_entropy(self, logits, labels, classes): 113 | """This loss computes the probability mass on the 114 | opposite set of classes for each sample. 115 | 116 | :param logits: continuous vector 117 | :param labels: class ids 118 | """ 119 | 120 | error = self.loss_func(logits, labels) 121 | 122 | return error 123 | 124 | def define_textual_prompts(self, only_unlabelled=None, validation=False): 125 | """This function returns the textual prompts. You can modify the list 126 | of classes of interest. 127 | 128 | :param only_unlabelled: boolean. It is True if the training only involves 129 | pseudo-labeled unseen data 130 | """ 131 | 132 | return [ 133 | self.template.format(" ".join(i.split("_"))) for i in self.classes 134 | ] 135 | 136 | def reindex_predicted_labels(self, idx_preds, only_unlabelled=False): 137 | """This function returns the correct index of predictions to compute 138 | model's accuracy. 139 | 140 | :param idx_pred: list of predictions ids 141 | :param only_unlabelled: boolean. It is True if the training only involves 142 | pseudo-labeled unseen data 143 | """ 144 | 145 | return [self.classes[i.item()] for i in idx_preds] 146 | 147 | def reindex_true_labels(self, label, only_unlabelled=False): 148 | """This function returns the correct index of true labels. 149 | 150 | :param label: list of labels from data loader 151 | :param only_unlabelled: boolean. 
It is True if the training only involves 152 | pseudo-labeled unseen data 153 | """ 154 | 155 | 156 | return torch.tensor([l for l in label]) 157 | 158 | 159 | def get_pseudo_labels(self, unlabeled_examples): 160 | log.info(f"Num unlabeled data: {len(unlabeled_examples)}") 161 | # Get prediction on unlabeled data 162 | std_preds = self.test_predictions( 163 | unlabeled_examples, standard_zsl=True 164 | ) 165 | 166 | DatasetObject = dataset_object(self.config.DATASET_NAME) 167 | # 4. Take top-16 pseudo-labels to finetune the student 168 | pseudo_unseen_examples = DatasetObject( 169 | std_preds["id"], 170 | self.data_folder, 171 | transform=self.transform, 172 | augmentations=None, 173 | train=True, 174 | labels=None, 175 | label_map=self.label_to_idx, 176 | class_folder=True, 177 | original_filepaths=unlabeled_examples.filepaths, 178 | ) 179 | 180 | pseudo_labels = self.assign_pseudo_labels( 181 | self.config.N_PSEUDOSHOTS, pseudo_unseen_examples 182 | ) 183 | 184 | return pseudo_labels 185 | 186 | def assign_pseudo_labels(self, k, unlabeled_data): 187 | # Define text queries 188 | # prompts = [f"{self.template}{' '.join(i.split('_'))}" \ 189 | # for i in self.unseen_classes] 190 | prompts = [ 191 | self.template.format(" ".join(i.split("_"))) for i in self.classes 192 | ] 193 | log.info(f"Number of prompts: {len(prompts)}") 194 | 195 | # Encode text 196 | text = clip.tokenize(prompts).to(self.device) 197 | text_features = self.clip_model.encode_text(text) 198 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 199 | 200 | # to find the top k for each class, each class has it's own "leaderboard" 201 | top_k_leaderboard = { 202 | self.label_to_idx[self.classes[i]]: [] 203 | for i in range(len(self.classes)) 204 | } # maps class idx -> (confidence, image_path) tuple 205 | 206 | for img_path in unlabeled_data.filepaths: 207 | # log.info(f"IMAGEPATH: {img_path}") 208 | img = Image.open(img_path).convert("RGB") 209 | img = torch.unsqueeze(self.transform(img), 0).to(self.device) 210 | with torch.no_grad(): 211 | image_features = self.model(img) 212 | image_features = image_features / image_features.norm( 213 | dim=-1, keepdim=True 214 | ) 215 | # cosine similarity as logits 216 | 217 | logit_scale = self.clip_model.logit_scale.exp() 218 | logits = logit_scale * image_features @ text_features.t() 219 | probs = logits.softmax(dim=-1) 220 | idx_preds = torch.argmax(logits, dim=1) 221 | pred_id = idx_preds.item() 222 | pred = self.label_to_idx[self.classes[idx_preds.item()]] 223 | 224 | """if predicted class has empty leaderboard, or if the confidence is high 225 | enough for predicted class leaderboard, add the new example 226 | """ 227 | prob_score = probs[0][pred_id] 228 | if len(top_k_leaderboard[pred]) < k: 229 | top_k_leaderboard[pred].append((prob_score, img_path)) 230 | elif ( 231 | top_k_leaderboard[pred][-1][0] < prob_score 232 | ): # if the confidence in predicted class "qualifies" for top-k 233 | # default sorting of tuples is by first element 234 | top_k_leaderboard[pred] = sorted( 235 | top_k_leaderboard[pred] + [(probs[0][pred_id], img_path)], 236 | reverse=True, 237 | )[:k] 238 | else: 239 | # sort the other classes by confidence score 240 | order_of_classes = sorted( 241 | [ 242 | (probs[0][j], j) 243 | for j in range(len(self.classes)) 244 | if j != pred_id 245 | ], 246 | reverse=True, 247 | ) 248 | for score, index in order_of_classes: 249 | index_dict = self.label_to_idx[self.classes[index]] 250 | # log.info(f"{classnames[index]}") 251 | # 
log.info(f"{index_dict}") 252 | if len(top_k_leaderboard[index_dict]) < k: 253 | top_k_leaderboard[index_dict].append( 254 | (probs[0][index], img_path) 255 | ) 256 | elif top_k_leaderboard[index_dict][-1][0] < probs[0][index]: 257 | # default sorting of tuples is by first element 258 | top_k_leaderboard[index_dict] = sorted( 259 | top_k_leaderboard[index_dict] 260 | + [((probs[0][index], img_path))], 261 | reverse=True, 262 | )[:k] 263 | 264 | new_imgs = [] 265 | new_labels = [] 266 | # loop through, and rebuild the dataset 267 | for index, leaderboard in top_k_leaderboard.items(): 268 | new_imgs += [tup[1] for tup in leaderboard] 269 | new_labels += [index for _ in leaderboard] 270 | 271 | unlabeled_data.filepaths = new_imgs 272 | unlabeled_data.labels = new_labels 273 | unlabeled_data.label_id = True 274 | 275 | return unlabeled_data 276 | -------------------------------------------------------------------------------- /methods_config/accelerate_config.yml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | distributed_type: MULTI_GPU 3 | fp16: false 4 | machine_rank: 0 5 | main_process_ip: null 6 | main_process_port: null 7 | main_training_function: main 8 | num_machines: 1 9 | num_processes: 4 -------------------------------------------------------------------------------- /methods_config/accelerate_localtest_config.yml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | distributed_type: 'NO' 3 | fp16: false 4 | machine_rank: 0 5 | main_process_ip: null 6 | main_process_port: null 7 | main_training_function: main 8 | num_machines: 1 9 | num_processes: 1 -------------------------------------------------------------------------------- /methods_config/clip_config.yml: -------------------------------------------------------------------------------- 1 | # Dataset settings 2 | DATASET_NAME: "$DATASET_NAME" 3 | DATASET_DIR: "$DATASET_DIR" 4 | 5 | # Model setting 6 | MODEL: "$MODEL" 7 | VIS_ENCODER: "$VIS_ENCODER" 8 | 9 | # Prompt template 10 | PROMPT_TEMPLATE: 'imported in main.py' 11 | 12 | # Train-validation separation 13 | validation_seed: 0 14 | ratio_train_val: 0.8 15 | 16 | # Batch size 17 | BATCH_SIZE: 8 18 | 19 | # Seed settings 20 | OPTIM_SEED: "$OPTIM_SEED" 21 | 22 | # Classes split 23 | CLASSES_SPLIT: SPLIT_SEED 24 | SPLIT_SEED: "$SPLIT_SEED" 25 | -------------------------------------------------------------------------------- /methods_config/evaluation_config.yml: -------------------------------------------------------------------------------- 1 | # Dataset name 2 | DATASET_NAME: "$DATASET_NAME" 3 | DATASET_DIR: '/Users/menga/Desktop/github/zsl_taglets/development' # '/users/cmenghin/data/bats/datasets/classification' 4 | # Model 5 | MODEL: "$MODEL" 6 | MODALITY: 'image' 7 | # Visual ecoder 8 | VIS_ENCODER: "$VIS_ENCODER" 9 | # Prompt template 10 | PROMPT_TEMPLATE: 'imported in main.py' 11 | # Number of shats per classes in SSL 12 | N_LABEL: 16 13 | ALPHA: 0.3 14 | # Prefix size 15 | # Text Prefix size 16 | TEXT_PREFIX_SIZE: 4 17 | # Vision Prefix size 18 | VISION_PREFIX_SIZE: 4 19 | # Lightweight transformer dim 20 | TRANSFORMER_DIM: 128 21 | # Use VPT-Deep? 
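# When True, prompt vectors are also injected into the deeper transformer blocks (see the deep_embs path of CustomVisionTransformer.forward in models/clip_encoders.py).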
22 | VPT_DEEP: False 23 | PREFIX_SIZE: 16 24 | # Prefix initialization: normal/uniform 25 | VIS_PREFIX_INIT: "normal" 26 | # Initialization mean and variance 27 | MEAN_INIT: 0 28 | VAR_INIT: 0.02 29 | # Seeed to separate train and validation 30 | validation_seed: 0 31 | # Ratio validation 32 | ratio_train_val: 0.8 33 | # Batch size 34 | BATCH_SIZE: 2 35 | # Number of epochs 36 | EPOCHS: 150 37 | # Scheduler 38 | SCHEDULER: "cosine" 39 | # Scheduler warmup epochs 40 | WARMUP_EPOCHS: 5 41 | WARMUP_LR: 0.0001 42 | # Number of accumulation iter 43 | ACCUMULATION_ITER: 1 44 | # Optimizer 45 | OPTIM: "SGD" 46 | LR: 0.1 47 | DECAY: 0.1 48 | STEP_SIZE: 1 49 | # Set seeds 50 | OPTIM_SEED: "$OPTIM_SEED" 51 | # Classes split 52 | CLASSES_SPLIT: SPLIT_SEED 53 | # Seed split 54 | SPLIT_SEED: "$SPLIT_SEED" -------------------------------------------------------------------------------- /methods_config/grip_multimodal_config.yml: -------------------------------------------------------------------------------- 1 | DATASET_DIR: '/users/cmenghin/data/bats/datasets/classification' #'/Users/menga/Desktop/github/zsl_taglets/development' # 2 | DATASET_NAME: "$DATASET_NAME" 3 | # Model 4 | MODALITY: 'multi' 5 | MODEL: "$MODEL" 6 | # Visual ecoder 7 | VIS_ENCODER: "$VIS_ENCODER" 8 | # Prompt template 9 | PROMPT_TEMPLATE: 'imported in main.py' 10 | # Number of shats per classes in SSL 11 | N_LABEL: 2 12 | ALPHA: 0.3 13 | # Text Prefix size 14 | TEXT_PREFIX_SIZE: 4 15 | # Vision Prefix size 16 | VISION_PREFIX_SIZE: 4 17 | # Pseudolabels 18 | N_PSEUDOSHOTS: 16 19 | STEP_QUANTILE: 10 20 | # Lightweight transformer dim 21 | TRANSFORMER_DIM: 128 22 | # Use VPT-Deep? 23 | VPT_DEEP: False 24 | # Prefix initialization: normal/uniform 25 | VIS_PREFIX_INIT: "normal" 26 | # Initialization mean and variance 27 | MEAN_INIT: 0 28 | VAR_INIT: 0.02 29 | # Seeed to separate train and validation 30 | validation_seed: 0 31 | # Ratio validation 32 | ratio_train_val: 0.8 33 | # Batch size 34 | BATCH_SIZE: 16 35 | # Number of epochs 36 | EPOCHS: 150 37 | # Scheduler 38 | SCHEDULER: "cosine" 39 | # Scheduler warmup epochs 40 | WARMUP_EPOCHS: 5 41 | WARMUP_LR: 0.0001 42 | # Number of accumulation iter 43 | ACCUMULATION_ITER: 1 44 | # Optimizer teacher 45 | OPTIM: "SGD" 46 | LR: 0.01 47 | DECAY: 0.1 48 | STEP_SIZE: 1 49 | STEP_SIZE: 1 50 | # Set seeds 51 | OPTIM_SEED: "$OPTIM_SEED" 52 | # Classes split 53 | CLASSES_SPLIT: SPLIT_SEED 54 | # Seed split 55 | SPLIT_SEED: "$SPLIT_SEED" 56 | -------------------------------------------------------------------------------- /methods_config/grip_textual_config.yml: -------------------------------------------------------------------------------- 1 | DATASET_DIR: '/users/cmenghin/data/bats/datasets/classification' #'/Users/menga/Desktop/github/zsl_taglets/development' # '/users/cmenghin/data/bats/datasets/classification' 2 | DATASET_NAME: "$DATASET_NAME" 3 | # Model 4 | MODALITY: 'text' 5 | MODEL: "$MODEL" 6 | # Visual ecoder 7 | VIS_ENCODER: "$VIS_ENCODER" 8 | # Prompt template 9 | PROMPT_TEMPLATE: 'imported in main.py' 10 | # Number of shats per classes in SSL 11 | N_LABEL: 2 12 | ALPHA: 0.3 13 | # Prefix size 14 | PREFIX_SIZE: 16 15 | N_PSEUDOSHOTS: 16 16 | STEP_QUANTILE: 10 17 | # Prefix initialization: normal/uniform 18 | VIS_PREFIX_INIT: "normal" 19 | # Initialization mean and variance 20 | MEAN_INIT: 0 21 | VAR_INIT: 0.02 22 | # Seeed to separate train and validation 23 | validation_seed: 0 24 | # Ratio validation 25 | ratio_train_val: 0.8 26 | # Batch size 27 | BATCH_SIZE: 16 28 | # Number of 
epochs 29 | EPOCHS: 150 30 | # Scheduler 31 | SCHEDULER: "cosine" 32 | # Scheduler warmup epochs 33 | WARMUP_EPOCHS: 5 34 | WARMUP_LR: 0.0001 35 | # Number of accumulation iter 36 | ACCUMULATION_ITER: 1 37 | # Optimizer teacher 38 | OPTIM: "SGD" 39 | LR: 0.1 40 | DECAY: 0.1 41 | STEP_SIZE: 1 42 | STEP_SIZE: 1 43 | # Set seeds 44 | OPTIM_SEED: "$OPTIM_SEED" 45 | # Classes split 46 | CLASSES_SPLIT: SPLIT_SEED 47 | # Seed split 48 | SPLIT_SEED: "$SPLIT_SEED" 49 | -------------------------------------------------------------------------------- /methods_config/grip_visual_config.yml: -------------------------------------------------------------------------------- 1 | DATASET_DIR: '/users/cmenghin/data/bats/datasets/classification' #'/Users/menga/Desktop/github/zsl_taglets/development' # 2 | DATASET_NAME: "$DATASET_NAME" 3 | # Model 4 | MODALITY: 'image' 5 | MODEL: "$MODEL" 6 | # Visual ecoder 7 | VIS_ENCODER: "$VIS_ENCODER" 8 | # Prompt template 9 | PROMPT_TEMPLATE: 'imported in main.py' 10 | # Number of shats per classes in SSL 11 | N_LABEL: 2 12 | ALPHA: 0.3 13 | # Prefix size 14 | PREFIX_SIZE: 16 15 | N_PSEUDOSHOTS: 16 16 | STEP_QUANTILE: 10 17 | # Prefix initialization: normal/uniform 18 | VIS_PREFIX_INIT: "normal" 19 | # Initialization mean and variance 20 | MEAN_INIT: 0 21 | VAR_INIT: 0.02 22 | # Seeed to separate train and validation 23 | validation_seed: 0 24 | # Ratio validation 25 | ratio_train_val: 0.8 26 | # Batch size 27 | BATCH_SIZE: 16 28 | # Number of epochs 29 | EPOCHS: 150 30 | # Scheduler 31 | SCHEDULER: "cosine" 32 | # Scheduler warmup epochs 33 | WARMUP_EPOCHS: 5 34 | WARMUP_LR: 0.0001 35 | # Number of accumulation iter 36 | ACCUMULATION_ITER: 1 37 | # Optimizer teacher 38 | OPTIM: "SGD" 39 | LR: 0.1 40 | DECAY: 0.1 41 | STEP_SIZE: 1 42 | STEP_SIZE: 1 43 | # Set seeds 44 | OPTIM_SEED: "$OPTIM_SEED" 45 | # Classes split 46 | CLASSES_SPLIT: SPLIT_SEED 47 | # Seed split 48 | SPLIT_SEED: "$SPLIT_SEED" 49 | -------------------------------------------------------------------------------- /methods_config/iterative_multimodal_fpl_config.yml: -------------------------------------------------------------------------------- 1 | DATASET_DIR: "$DATASET_DIR" #'/users/cmenghin/data/bats/datasets/classification' # '/Users/menga/Desktop/github/zsl_taglets/development' # 2 | DATASET_NAME: "$DATASET_NAME" 3 | # Model 4 | MODALITY: 'multi' 5 | MODEL: "$MODEL" 6 | # Visual ecoder 7 | VIS_ENCODER: "$VIS_ENCODER" 8 | # Prompt template 9 | PROMPT_TEMPLATE: 'imported in main.py' 10 | # Number of shats per classes in SSL 11 | N_LABEL: 2 12 | ALPHA: 0.3 13 | # Text Prefix size 14 | TEXT_PREFIX_SIZE: 4 15 | # Vision Prefix size 16 | VISION_PREFIX_SIZE: 4 17 | # Prefix size 18 | N_PSEUDOSHOTS: 16 19 | STEP_QUANTILE: 10 20 | # Lightweight transformer dim 21 | TRANSFORMER_DIM: 128 22 | # Use VPT-Deep? 
23 | VPT_DEEP: False 24 | # Prefix initialization: normal/uniform 25 | VIS_PREFIX_INIT: "normal" 26 | # Initialization mean and variance 27 | MEAN_INIT: 0 28 | VAR_INIT: 0.02 29 | # Seeed to separate train and validation 30 | validation_seed: 0 31 | # Ratio validation 32 | ratio_train_val: 0.8 33 | # Batch size 34 | BATCH_SIZE: 16 35 | # Number of epochs 36 | EPOCHS: 150 37 | # Scheduler 38 | SCHEDULER: "cosine" 39 | # Scheduler warmup epochs 40 | WARMUP_EPOCHS: 5 41 | WARMUP_LR: 0.0001 42 | # Number of accumulation iter 43 | ACCUMULATION_ITER: 1 44 | # Optimizer teacher 45 | OPTIM: "SGD" 46 | LR: 0.01 47 | DECAY: 0.1 48 | STEP_SIZE: 1 49 | STEP_SIZE: 1 50 | # Set seeds 51 | OPTIM_SEED: "$OPTIM_SEED" 52 | # Classes split 53 | CLASSES_SPLIT: SPLIT_SEED 54 | # Seed split 55 | SPLIT_SEED: "$SPLIT_SEED" 56 | -------------------------------------------------------------------------------- /methods_config/iterative_textual_fpl_config.yml: -------------------------------------------------------------------------------- 1 | DATASET_DIR: "$DATASET_DIR" # '/users/cmenghin/data/bats/datasets/classification' # '/Users/menga/Desktop/github/zsl_taglets/development' # 2 | DATASET_NAME: "$DATASET_NAME" 3 | # Model 4 | MODALITY: 'text' 5 | MODEL: "$MODEL" 6 | # Visual ecoder 7 | VIS_ENCODER: "$VIS_ENCODER" 8 | # Prompt template 9 | PROMPT_TEMPLATE: 'imported in main.py' 10 | # Number of shats per classes in SSL 11 | N_LABEL: 2 12 | ALPHA: 0.3 13 | # Prefix size 14 | PREFIX_SIZE: 16 15 | N_PSEUDOSHOTS: 16 16 | STEP_QUANTILE: 10 17 | # Prefix initialization: normal/uniform 18 | VIS_PREFIX_INIT: "normal" 19 | # Initialization mean and variance 20 | MEAN_INIT: 0 21 | VAR_INIT: 0.02 22 | # Seeed to separate train and validation 23 | validation_seed: 0 24 | # Ratio validation 25 | ratio_train_val: 0.8 26 | # Batch size 27 | BATCH_SIZE: 16 28 | # Number of epochs 29 | EPOCHS: 150 30 | # Scheduler 31 | SCHEDULER: "cosine" 32 | # Scheduler warmup epochs 33 | WARMUP_EPOCHS: 5 34 | WARMUP_LR: 0.0001 35 | # Number of accumulation iter 36 | ACCUMULATION_ITER: 1 37 | # Optimizer teacher 38 | OPTIM: "SGD" 39 | LR: 0.1 40 | DECAY: 0.1 41 | STEP_SIZE: 1 42 | STEP_SIZE: 1 43 | # Set seeds 44 | OPTIM_SEED: "$OPTIM_SEED" 45 | # Classes split 46 | CLASSES_SPLIT: SPLIT_SEED 47 | # Seed split 48 | SPLIT_SEED: "$SPLIT_SEED" 49 | -------------------------------------------------------------------------------- /methods_config/iterative_visual_fpl_config.yml: -------------------------------------------------------------------------------- 1 | DATASET_DIR: "$DATASET_DIR" # '/users/cmenghin/data/bats/datasets/classification' # '/Users/menga/Desktop/github/zsl_taglets/development' # 2 | DATASET_NAME: "$DATASET_NAME" 3 | # Model 4 | MODALITY: 'image' 5 | MODEL: "$MODEL" 6 | # Visual ecoder 7 | VIS_ENCODER: "$VIS_ENCODER" 8 | # Prompt template 9 | PROMPT_TEMPLATE: 'imported in main.py' 10 | # Number of shats per classes in SSL 11 | N_LABEL: 2 12 | ALPHA: 0.3 13 | # Prefix size 14 | PREFIX_SIZE: 16 15 | N_PSEUDOSHOTS: 16 16 | STEP_QUANTILE: 10 17 | # Prefix initialization: normal/uniform 18 | VIS_PREFIX_INIT: "normal" 19 | # Initialization mean and variance 20 | MEAN_INIT: 0 21 | VAR_INIT: 0.02 22 | # Seeed to separate train and validation 23 | validation_seed: 0 24 | # Ratio validation 25 | ratio_train_val: 0.8 26 | # Batch size 27 | BATCH_SIZE: 16 28 | # Number of epochs 29 | EPOCHS: 150 30 | # Scheduler 31 | SCHEDULER: "cosine" 32 | # Scheduler warmup epochs 33 | WARMUP_EPOCHS: 5 34 | WARMUP_LR: 0.0001 35 | # Number of accumulation 
iter 36 | ACCUMULATION_ITER: 1 37 | # Optimizer teacher 38 | OPTIM: "SGD" 39 | LR: 0.1 40 | DECAY: 0.1 41 | STEP_SIZE: 1 42 | STEP_SIZE: 1 43 | # Set seeds 44 | OPTIM_SEED: "$OPTIM_SEED" 45 | # Classes split 46 | CLASSES_SPLIT: SPLIT_SEED 47 | # Seed split 48 | SPLIT_SEED: "$SPLIT_SEED" 49 | -------------------------------------------------------------------------------- /methods_config/multimodal_fpl_config.yml: -------------------------------------------------------------------------------- 1 | # Dataset name 2 | DATASET_DIR: "$DATASET_DIR" #'/users/cmenghin/data/bats/datasets/classification' # '/Users/menga/Desktop/github/zsl_taglets/development' # 3 | DATASET_NAME: "$DATASET_NAME" 4 | # Model 5 | MODALITY: 'multi' 6 | MODEL: "$MODEL" 7 | # Visual ecoder 8 | VIS_ENCODER: "$VIS_ENCODER" 9 | # Prompt template 10 | PROMPT_TEMPLATE: 'imported in main.py' 11 | # Number of shats per classes in SSL 12 | N_LABEL: 2 13 | ALPHA: 0.3 14 | # Text Prefix size 15 | TEXT_PREFIX_SIZE: 4 16 | # Vision Prefix size 17 | VISION_PREFIX_SIZE: 4 18 | # Pseudolabels 19 | N_PSEUDOSHOTS: 16 20 | STEP_QUANTILE: 10 21 | # Lightweight transformer dim 22 | TRANSFORMER_DIM: 128 23 | # Use VPT-Deep? 24 | VPT_DEEP: False 25 | # Prefix initialization: normal/uniform 26 | VIS_PREFIX_INIT: "normal" 27 | # Initialization mean and variance 28 | MEAN_INIT: 0 29 | VAR_INIT: 0.02 30 | # Seeed to separate train and validation 31 | validation_seed: 0 32 | # Ratio validation 33 | ratio_train_val: 0.8 34 | # Batch size 35 | BATCH_SIZE: 16 36 | # Number of epochs 37 | EPOCHS: 150 38 | # Scheduler 39 | SCHEDULER: "cosine" 40 | # Scheduler warmup epochs 41 | WARMUP_EPOCHS: 5 42 | WARMUP_LR: 0.0001 43 | # Number of accumulation iter 44 | ACCUMULATION_ITER: 1 45 | # Optimizer 46 | OPTIM: "SGD" 47 | LR: 0.01 48 | DECAY: 0.1 49 | STEP_SIZE: 1 50 | # Set seeds 51 | OPTIM_SEED: "$OPTIM_SEED" 52 | # Classes split 53 | CLASSES_SPLIT: SPLIT_SEED 54 | # Seed split 55 | SPLIT_SEED: "$SPLIT_SEED" -------------------------------------------------------------------------------- /methods_config/multimodal_prompt_config.yml: -------------------------------------------------------------------------------- 1 | # Dataset name 2 | DATASET_DIR: "$DATASET_DIR" 3 | DATASET_NAME: "$DATASET_NAME" 4 | # Model 5 | MODALITY: 'multi' 6 | MODEL: "$MODEL" 7 | # Visual ecoder 8 | VIS_ENCODER: "$VIS_ENCODER" 9 | # Prompt template 10 | PROMPT_TEMPLATE: 'imported in main.py' 11 | # Number of shats per classes in SSL 12 | N_LABEL: 2 13 | ALPHA: 0.3 14 | # Text Prefix size 15 | TEXT_PREFIX_SIZE: 4 16 | # Vision Prefix size 17 | VISION_PREFIX_SIZE: 4 18 | # Lightweight transformer dim 19 | TRANSFORMER_DIM: 128 20 | # Use VPT-Deep? 
21 | VPT_DEEP: False 22 | # Prefix initialization: normal/uniform 23 | VIS_PREFIX_INIT: "normal" 24 | # Initialization mean and variance 25 | MEAN_INIT: 0 26 | VAR_INIT: 0.02 27 | # Seeed to separate train and validation 28 | validation_seed: 0 29 | # Ratio validation 30 | ratio_train_val: 0.8 31 | # Batch size 32 | BATCH_SIZE: 16 33 | # Number of epochs 34 | EPOCHS: 150 35 | # Scheduler 36 | SCHEDULER: "cosine" 37 | # Scheduler warmup epochs 38 | WARMUP_EPOCHS: 5 39 | WARMUP_LR: 0.0001 40 | # Number of accumulation iter 41 | ACCUMULATION_ITER: 1 42 | # Optimizer 43 | OPTIM: "SGD" 44 | LR: 0.01 45 | DECAY: 0.1 46 | STEP_SIZE: 1 47 | # Set seeds 48 | OPTIM_SEED: "$OPTIM_SEED" 49 | # Classes split 50 | CLASSES_SPLIT: SPLIT_SEED 51 | # Seed split 52 | SPLIT_SEED: "$SPLIT_SEED" -------------------------------------------------------------------------------- /methods_config/textual_fpl_config.yml: -------------------------------------------------------------------------------- 1 | DATASET_DIR: "$DATASET_DIR" #'/users/cmenghin/data/bats/datasets/classification' # '/Users/menga/Desktop/github/zsl_taglets/development' # 2 | DATASET_NAME: "$DATASET_NAME" 3 | # Model 4 | MODALITY: 'text' 5 | MODEL: "$MODEL" 6 | # Visual ecoder 7 | VIS_ENCODER: "$VIS_ENCODER" 8 | # Prompt template 9 | PROMPT_TEMPLATE: 'imported in main.py' 10 | # Number of shats per classes in SSL 11 | N_LABEL: 2 12 | ALPHA: 0.3 13 | # Prefix size 14 | PREFIX_SIZE: 16 15 | N_PSEUDOSHOTS: 16 16 | STEP_QUANTILE: 10 17 | # Prefix initialization: normal/uniform 18 | VIS_PREFIX_INIT: "normal" 19 | # Initialization mean and variance 20 | MEAN_INIT: 0 21 | VAR_INIT: 0.02 22 | # Seeed to separate train and validation 23 | validation_seed: 0 24 | # Ratio validation 25 | ratio_train_val: 0.8 26 | # Batch size 27 | BATCH_SIZE: 16 28 | # Number of epochs 29 | EPOCHS: 150 30 | # Scheduler 31 | SCHEDULER: "cosine" 32 | # Scheduler warmup epochs 33 | WARMUP_EPOCHS: 5 34 | WARMUP_LR: 0.0001 35 | # Number of accumulation iter 36 | ACCUMULATION_ITER: 1 37 | # Optimizer teacher 38 | OPTIM: "SGD" 39 | LR: 0.1 40 | DECAY: 0.1 41 | STEP_SIZE: 1 42 | STEP_SIZE: 1 43 | # Set seeds 44 | OPTIM_SEED: "$OPTIM_SEED" 45 | # Classes split 46 | CLASSES_SPLIT: SPLIT_SEED 47 | # Seed split 48 | SPLIT_SEED: "$SPLIT_SEED" 49 | -------------------------------------------------------------------------------- /methods_config/textual_prompt_config.yml: -------------------------------------------------------------------------------- 1 | # Dataset name 2 | DATASET_DIR: "$DATASET_DIR" # '/Users/menga/Desktop/github/zsl_taglets/development' # '/users/cmenghin/data/bats/datasets/classification' 3 | DATASET_NAME: "$DATASET_NAME" 4 | # Model 5 | MODALITY: 'text' 6 | MODEL: "$MODEL" 7 | # Visual ecoder 8 | VIS_ENCODER: "$VIS_ENCODER" 9 | # Prompt template 10 | PROMPT_TEMPLATE: 'imported in main.py' 11 | # Number of shats per classes in SSL 12 | N_LABEL: 2 13 | ALPHA: 0.3 14 | # Prefix size 15 | PREFIX_SIZE: 16 16 | # Prefix initialization: normal/uniform 17 | VIS_PREFIX_INIT: "normal" 18 | # Initialization mean and variance 19 | MEAN_INIT: 0 20 | VAR_INIT: 0.02 21 | # Seeed to separate train and validation 22 | validation_seed: 0 23 | # Ratio validation 24 | ratio_train_val: 0.8 25 | # Batch size 26 | BATCH_SIZE: 16 27 | # Number of epochs 28 | EPOCHS: 150 29 | # Scheduler 30 | SCHEDULER: "cosine" 31 | # Scheduler warmup epochs 32 | WARMUP_EPOCHS: 5 33 | WARMUP_LR: 0.0001 34 | # Number of accumulation iter 35 | ACCUMULATION_ITER: 1 36 | # Optimizer 37 | OPTIM: "SGD" 38 | LR: 0.1 
39 | DECAY: 0.1 40 | STEP_SIZE: 1 41 | # Set seeds 42 | OPTIM_SEED: "$OPTIM_SEED" 43 | # Classes split 44 | CLASSES_SPLIT: SPLIT_SEED 45 | # Seed split 46 | SPLIT_SEED: "$SPLIT_SEED" -------------------------------------------------------------------------------- /methods_config/visual_fpl_config.yml: -------------------------------------------------------------------------------- 1 | DATASET_DIR: "$DATASET_DIR" #'/users/cmenghin/data/bats/datasets/classification' # '/Users/menga/Desktop/github/zsl_taglets/development' # 2 | DATASET_NAME: "$DATASET_NAME" 3 | # Model 4 | MODALITY: 'image' 5 | MODEL: "$MODEL" 6 | # Visual encoder 7 | VIS_ENCODER: "$VIS_ENCODER" 8 | # Prompt template 9 | PROMPT_TEMPLATE: 'imported in main.py' 10 | # Number of shots per class in SSL 11 | N_LABEL: 2 12 | ALPHA: 0.3 13 | # Prefix size 14 | PREFIX_SIZE: 16 15 | N_PSEUDOSHOTS: 16 16 | STEP_QUANTILE: 10 17 | # Prefix initialization: normal/uniform 18 | VIS_PREFIX_INIT: "normal" 19 | # Initialization mean and variance 20 | MEAN_INIT: 0 21 | VAR_INIT: 0.02 22 | # Seed to separate train and validation 23 | validation_seed: 0 24 | # Ratio validation 25 | ratio_train_val: 0.8 26 | # Batch size 27 | BATCH_SIZE: 16 28 | # Number of epochs 29 | EPOCHS: 150 30 | # Scheduler 31 | SCHEDULER: "cosine" 32 | # Scheduler warmup epochs 33 | WARMUP_EPOCHS: 5 34 | WARMUP_LR: 0.0001 35 | # Number of accumulation iter 36 | ACCUMULATION_ITER: 1 37 | # Optimizer 38 | OPTIM: "SGD" 39 | LR: 0.1 40 | DECAY: 0.1 41 | STEP_SIZE: 1 42 | # Set seeds 43 | OPTIM_SEED: "$OPTIM_SEED" 44 | # Classes split 45 | CLASSES_SPLIT: SPLIT_SEED 46 | # Seed split 47 | SPLIT_SEED: "$SPLIT_SEED" 48 | -------------------------------------------------------------------------------- /methods_config/visual_prompt_config.yml: -------------------------------------------------------------------------------- 1 | # Dataset name 2 | DATASET_DIR: "$DATASET_DIR" # '/Users/menga/Desktop/github/zsl_taglets/development' # 3 | DATASET_NAME: "$DATASET_NAME" 4 | # Model 5 | MODALITY: 'image' 6 | MODEL: "$MODEL" 7 | # Visual encoder 8 | VIS_ENCODER: "$VIS_ENCODER" 9 | # Prompt template 10 | PROMPT_TEMPLATE: 'imported in main.py' 11 | # Number of shots per class in SSL 12 | N_LABEL: 2 13 | ALPHA: 0.3 14 | # Prefix size 15 | PREFIX_SIZE: 16 16 | # Prefix initialization: normal/uniform 17 | VIS_PREFIX_INIT: "normal" 18 | # Initialization mean and variance 19 | MEAN_INIT: 0 20 | VAR_INIT: 0.02 21 | # Seed to separate train and validation 22 | validation_seed: 0 23 | # Ratio validation 24 | ratio_train_val: 0.8 25 | # Batch size 26 | BATCH_SIZE: 16 27 | # Number of epochs 28 | EPOCHS: 150 29 | # Scheduler 30 | SCHEDULER: "cosine" 31 | # Scheduler warmup epochs 32 | WARMUP_EPOCHS: 5 33 | WARMUP_LR: 0.0001 34 | # Number of accumulation iter 35 | ACCUMULATION_ITER: 1 36 | # Optimizer 37 | OPTIM: "SGD" 38 | LR: 0.1 39 | DECAY: 0.1 40 | STEP_SIZE: 1 41 | # Set seeds 42 | OPTIM_SEED: "$OPTIM_SEED" 43 | # Classes split 44 | CLASSES_SPLIT: SPLIT_SEED 45 | # Seed split 46 | SPLIT_SEED: "$SPLIT_SEED" -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip_encoders import ( 2 | CustomImageEncoder, 3 | CustomTextEncoder, 4 | ImageEncoder, 5 | TextEncoder, 6 | ) 7 | from .prompts_models import ( 8 | ImagePrefixModel, 9 | TextPrefixModel, 10 | UPTModel, 11 | ) 12 |
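models/__init__.py above re-exports the encoder wrappers and prompt models defined in the two files that follow. A minimal sketch of how these pieces are typically wired together for visual prompt tuning; the backbone name, prefix length, and initialization below are illustrative assumptions, not values taken from this repository:

import clip
import torch

from models import CustomImageEncoder, ImagePrefixModel

# Load a CLIP checkpoint (assumed: ViT-B/32) and wrap its vision tower.
clip_model, transform = clip.load("ViT-B/32", device="cpu")
image_encoder = CustomImageEncoder(clip_model.visual)

# A learnable visual prefix: 16 tokens of transformer width, initialized
# roughly like the configs' "normal" option (MEAN_INIT: 0, VAR_INIT: 0.02).
width = clip_model.visual.conv1.out_channels
prefix = torch.normal(0.0, 0.02, size=(16, width))

model = ImagePrefixModel(prefix, image_encoder)
# image_features = model(batch)  # batch: (B, 3, 224, 224) images from `transform`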
-------------------------------------------------------------------------------- /models/clip_encoders.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | import os.path as osp 4 | 5 | import torch 6 | import torch.nn as nn 7 | from clip import clip 8 | 9 | 10 | log = logging.getLogger(__name__) 11 | 12 | 13 | class TextEncoder(nn.Module): 14 | """CLIP text encoder""" 15 | 16 | def __init__(self, clip_model): 17 | super(TextEncoder, self).__init__() 18 | self.clip_model = clip_model 19 | 20 | def forward(self, text): 21 | encoded_text = self.clip_model.encode_text(text) 22 | return encoded_text 23 | 24 | 25 | class CustomTextEncoder(torch.nn.Module): 26 | """This class is adapted from the codebase of "Learning to Compose Soft Prompts for Compositional Zero-Shot Learning" 27 | https://github.com/BatsResearch/csp/blob/main/clip_modules/text_encoder.py""" 28 | 29 | def __init__(self, clip_model, device, dtype): 30 | super(CustomTextEncoder, self).__init__() 31 | self.dtype = dtype 32 | self.clip_model = clip_model 33 | self.transformer = clip_model.transformer 34 | self.positional_embedding = clip_model.positional_embedding 35 | self.ln_final = clip_model.ln_final 36 | self.text_projection = clip_model.text_projection 37 | self.token_embedding = clip_model.token_embedding 38 | self.device = device 39 | 40 | def tokenize(self, text): 41 | return torch.cat([clip.tokenize(tok) for tok in text]) 42 | 43 | def forward(self, class_embeddings, classes, enable_pos_emb=True): 44 | """The forward function to compute representations for the prompts. 45 | 46 | :param class_embedding: These are the vectors for class names 47 | :param labels: tensor of labels (already preprocessed without _) 48 | :param enable_pos_emb: We set this to True since we want to account for the order 49 | 50 | Returns: 51 | torch.Tensor: the vector representation of the prompt. 
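Example (assuming class_embeddings holds a prefix of 4 learned vectors): for class "dog" the template built below is "X X X X dog", and the embeddings of the four "X" placeholder tokens are then replaced by class_embeddings.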
52 | """ 53 | 54 | prompts = [ 55 | " ".join([" ".join(["X"] * (class_embeddings.size()[1])).strip(), c]) 56 | for c in classes 57 | ] 58 | # log.info(f"Extended text: {prompts}") 59 | 60 | token_ids = clip.tokenize(prompts) 61 | 62 | # Get embeddings for the prompt 63 | text_embedding = self.token_embedding(token_ids.to(self.device)) 64 | # for idx in range(class_embeddings.size()[0]): 65 | # text_embedding[idx, 1:(class_embeddings[idx].size()[0]+1), :] = class_embeddings[idx] 66 | 67 | text_embedding[:, 1 : (class_embeddings[0].size()[0] + 1), :] = class_embeddings 68 | 69 | text_features = text_embedding.type(self.dtype) 70 | x = ( 71 | text_features + self.positional_embedding.type(self.dtype) 72 | if enable_pos_emb 73 | else text_features 74 | ) 75 | x = x.permute(1, 0, 2) 76 | # log.info(f'DEVICE: {type(self.device)}') 77 | 78 | if torch.cuda.is_available(): 79 | # log.info('WRONG CPU') 80 | x = self.transformer(x) 81 | else: 82 | # log.info('CPU') 83 | x = self.transformer(x.float()) 84 | x = x.permute(1, 0, 2) 85 | x = self.ln_final(x) 86 | tf = ( 87 | x[torch.arange(x.shape[0]), token_ids.argmax(dim=-1)] # POS of 88 | @ self.text_projection 89 | ) 90 | return tf 91 | 92 | 93 | class ImageEncoder(nn.Module): 94 | """CLIP image encoder""" 95 | 96 | def __init__(self, clip_model): 97 | super(ImageEncoder, self).__init__() 98 | self.clip_model = clip_model 99 | 100 | def forward(self, text): 101 | encoded_image = self.clip_model.encode_image(text) 102 | return encoded_image 103 | 104 | 105 | class CustomVisionTransformer(nn.Module): 106 | def __init__(self, vision_transformer): 107 | super().__init__() 108 | self.input_resolution = vision_transformer.input_resolution 109 | self.output_dim = vision_transformer.output_dim 110 | self.conv1 = vision_transformer.conv1 111 | 112 | self.class_embedding = vision_transformer.class_embedding 113 | self.positional_embedding = vision_transformer.positional_embedding 114 | self.ln_pre = vision_transformer.ln_pre 115 | 116 | self.transformer = vision_transformer.transformer 117 | 118 | self.ln_post = vision_transformer.ln_post 119 | self.proj = vision_transformer.proj 120 | 121 | # self.type = config.TYPE 122 | 123 | def forward( 124 | self, 125 | x: torch.Tensor, 126 | image_prefix: torch.Tensor, 127 | pos_emb=True, 128 | deep_embs=None, 129 | ): 130 | 131 | x = self.conv1(x) # shape = [*, width, grid, grid] 132 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 133 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 134 | 135 | x = torch.cat( 136 | [ 137 | self.class_embedding.to(x.dtype).to(x.device) 138 | + torch.zeros( 139 | x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device 140 | ), 141 | x, 142 | ], 143 | dim=1, 144 | ) # shape = [*, grid ** 2 + 1, width] 145 | 146 | x = x + self.positional_embedding.to(x.dtype) if pos_emb else x 147 | 148 | image_prefix = image_prefix.expand(x.shape[0], -1, -1) 149 | # Here we concat the prefix to the flattened patches 150 | x = torch.cat([ 151 | x[:,:1,:], 152 | image_prefix, 153 | x[:,1:,:], 154 | ], 155 | dim=1,) 156 | 157 | # x = torch.cat([ 158 | # image_prefix, 159 | # x, 160 | # ], 161 | # dim=1,) 162 | 163 | x = self.ln_pre(x) 164 | 165 | x = x.permute(1, 0, 2) # NLD -> LND 166 | if deep_embs is not None: 167 | B = x.shape[0] 168 | # Code adapted from https://github.com/sIncerass/MVLPT/blob/main/trainers/mvlpt.py#L65 169 | for layer_idx in range(self.visual.transformer.layers): 170 | layer = self.transformer.resblocks[layer_idx] 171 | 172 | if layer_idx == 0: 
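# the first resblock runs on the sequence as-is: the shallow prompt tokens were already concatenated to the patch embeddings before ln_pre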
173 | x = layer(x) 174 | elif layer_idx <= deep_embs.shape[0]: 175 | vpt_emb_deep = self.mvlpt_model.vpt_dropout(self.mvlpt_model.vpt_proj( 176 | deep_embs[layer_idx-1]).expand(B, -1, -1)).to(x.dtype) 177 | 178 | vpt_emb_deep = vpt_emb_deep.permute(1, 0, 2) # NLD -> LND 179 | x = torch.cat(( 180 | x[:1, :, :], 181 | vpt_emb_deep, 182 | x[(1+self.mvlpt_model.vpt_n_ctx):, :, :] 183 | ), dim=0) 184 | x = layer(x) 185 | else: 186 | x = self.transformer(x) 187 | x = x.permute(1, 0, 2) # LND -> NLD 188 | 189 | x = self.ln_post(x[:, 0, :]) 190 | 191 | if self.proj is not None: 192 | x = x @ self.proj 193 | 194 | return x 195 | 196 | 197 | 198 | class CustomImageEncoder(nn.Module): 199 | """CLIP image encoder""" 200 | 201 | def __init__(self, visual): 202 | super(CustomImageEncoder, self).__init__() 203 | self.visual = CustomVisionTransformer(visual) 204 | self.dtype = self.visual.conv1.weight.dtype 205 | 206 | def forward(self, image, prefix, deep_embds=None): 207 | encoded_image = self.visual(image.type(self.dtype), prefix.type(self.dtype), deep_embs=deep_embds) 208 | return encoded_image 209 | -------------------------------------------------------------------------------- /models/prompts_models.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import clip 4 | import torch 5 | from torch import nn 6 | 7 | log = logging.getLogger(__name__) 8 | 9 | 10 | class TextPrefixModel(nn.Module): 11 | def __init__( 12 | self, initial_prefix, text_encoder, classes, temperature=0.07, device="cpu" 13 | ): 14 | """Define the model for textual prompt tuning. 15 | 16 | :param initial_prefix: initializes tensor of floats 17 | :param text_encoder: text encoder to use 18 | :param classes: list of classes' names 19 | :param temperature: fix parameter, same as clip 20 | :param device: device in use 21 | """ 22 | 23 | super(TextPrefixModel, self).__init__() 24 | self.device = device 25 | self.initialized_prefix = initial_prefix 26 | self.classes = classes 27 | 28 | self.prefix = nn.Parameter(initial_prefix) 29 | self.text_encoder = text_encoder 30 | 31 | def forward(self, classes): 32 | # log.info(f"classes: {classes}") 33 | out = self.text_encoder(self.prefix, classes) 34 | norm_out = out / out.norm(dim=-1, keepdim=True) 35 | 36 | return out 37 | 38 | 39 | class ImagePrefixModel(nn.Module): 40 | def __init__( 41 | self, 42 | initial_prefix, 43 | image_encoder, 44 | temperature=0.07, 45 | device="cpu", 46 | ): 47 | super(ImagePrefixModel, self).__init__() 48 | self.device = device 49 | self.initialized_prefix = initial_prefix 50 | 51 | # Initialize the model's parametwets 52 | self.prefix = nn.Parameter(initial_prefix) 53 | self.image_encoder = image_encoder 54 | 55 | def forward(self, x): 56 | # Combine prefix and class embeddings to get the entire prompt representation for the 57 | # two augmented images 58 | out = self.image_encoder(x, self.prefix) 59 | norm_out = out / out.norm(dim=-1, keepdim=True) 60 | 61 | return out 62 | 63 | 64 | class UPTModel(nn.Module): 65 | def __init__( 66 | self, 67 | coop_embeddings, 68 | vpt_embeddings, 69 | vpt_embeddings_deep, 70 | image_encoder, 71 | text_encoder, 72 | classes, 73 | dim_transformer, 74 | temperature=0.07, 75 | device="cpu", 76 | dtype=torch.float32, 77 | ): 78 | super(UPTModel, self).__init__() 79 | self.device = device 80 | self.classes = classes 81 | self.temperature = temperature 82 | self.dtype = dtype 83 | 84 | # Initialize the model's parameters 85 | self.coop_embeddings = 
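The two custom encoders above expose CLIP's transformers so that learned prefix vectors can be injected. As a minimal sketch (not repo code; the prefix shape `[1, n_ctx, ctx_dim]` and the class names are assumptions inferred from `forward()` above), `CustomTextEncoder` can be driven like this:

```python
import clip
import torch

from models.clip_encoders import CustomTextEncoder

device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, _ = clip.load("ViT-B/32", device=device)

text_encoder = CustomTextEncoder(clip_model, device, clip_model.dtype)

# Four placeholder tokens; the "X" tokens in the tokenized prompt are
# overwritten by these vectors inside forward()
n_ctx = 4
ctx_dim = clip_model.token_embedding.embedding_dim
prefix = 0.02 * torch.randn(1, n_ctx, ctx_dim, device=device)

classes = ["forest", "river", "airport"]  # example class names
text_features = text_encoder(prefix, classes)  # one row per class
```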
-------------------------------------------------------------------------------- /models/prompts_models.py: --------------------------------------------------------------------------------
import logging

import clip
import torch
from torch import nn

log = logging.getLogger(__name__)


class TextPrefixModel(nn.Module):
    def __init__(
        self, initial_prefix, text_encoder, classes, temperature=0.07, device="cpu"
    ):
        """Define the model for textual prompt tuning.

        :param initial_prefix: initial tensor of floats for the learnable prefix
        :param text_encoder: text encoder to use
        :param classes: list of class names
        :param temperature: fixed parameter, same as CLIP
        :param device: device in use
        """

        super(TextPrefixModel, self).__init__()
        self.device = device
        self.initialized_prefix = initial_prefix
        self.classes = classes

        self.prefix = nn.Parameter(initial_prefix)
        self.text_encoder = text_encoder

    def forward(self, classes):
        out = self.text_encoder(self.prefix, classes)
        # NOTE: the normalized features are currently unused; callers receive
        # the unnormalized output
        norm_out = out / out.norm(dim=-1, keepdim=True)

        return out


class ImagePrefixModel(nn.Module):
    def __init__(
        self,
        initial_prefix,
        image_encoder,
        temperature=0.07,
        device="cpu",
    ):
        super(ImagePrefixModel, self).__init__()
        self.device = device
        self.initialized_prefix = initial_prefix

        # Initialize the model's parameters
        self.prefix = nn.Parameter(initial_prefix)
        self.image_encoder = image_encoder

    def forward(self, x):
        # Prepend the learned prefix to the image patches and encode
        out = self.image_encoder(x, self.prefix)
        # NOTE: the normalized features are currently unused; callers receive
        # the unnormalized output
        norm_out = out / out.norm(dim=-1, keepdim=True)

        return out


class UPTModel(nn.Module):
    def __init__(
        self,
        coop_embeddings,
        vpt_embeddings,
        vpt_embeddings_deep,
        image_encoder,
        text_encoder,
        classes,
        dim_transformer,
        temperature=0.07,
        device="cpu",
        dtype=torch.float32,
    ):
        super(UPTModel, self).__init__()
        self.device = device
        self.classes = classes
        self.temperature = temperature
        self.dtype = dtype

        # Initialize the model's parameters
        self.coop_embeddings = nn.Parameter(coop_embeddings)
        self.vpt_embeddings = nn.Parameter(vpt_embeddings)

        self.coop_length = self.coop_embeddings.size()[1]
        self.coop_dim = self.coop_embeddings.size()[2]

        self.vpt_length = self.vpt_embeddings.size()[1]
        self.vpt_dim = self.vpt_embeddings.size()[2]

        if vpt_embeddings_deep is not None:
            self.vpt_embeddings_deep = nn.Parameter(vpt_embeddings_deep)
        else:
            self.vpt_embeddings_deep = None

        self.proj_coop_pre = nn.Linear(
            self.coop_dim,
            dim_transformer,
            dtype=self.dtype).to(self.device)
        self.proj_coop_post = nn.Linear(
            dim_transformer,
            self.coop_dim,
            dtype=self.dtype).to(self.device)
        self.proj_vpt_pre = nn.Linear(
            self.vpt_dim,
            dim_transformer,
            dtype=self.dtype).to(self.device)
        self.proj_vpt_post = nn.Linear(
            dim_transformer,
            self.vpt_dim,
            dtype=self.dtype).to(self.device)

        self.transformer = clip.model.Transformer(
            width=dim_transformer,
            layers=1,
            heads=1).to(self.device)

        self.image_encoder = image_encoder
        self.text_encoder = text_encoder

    # Given coop_embeddings, vpt_embeddings, and vpt_embeddings_deep:
    # - project them into the transformer dimension (dim_transformer),
    # - run the sequence through a small transformer,
    # - project back to the CLIP prompt dimension.
    # (Error when there is no input arg: https://github.com/pytorch/pytorch/pull/37902)
    def forward(self, x, classes):
        # First, project the prompts into the lower-dimensional space,
        # concatenate them, and cast them to the right dtype
        coop_embds = self.proj_coop_pre(self.coop_embeddings).to(self.device)
        if self.vpt_embeddings_deep is not None:
            vpt_embds = torch.cat((self.vpt_embeddings, self.vpt_embeddings_deep), dim=0).to(self.device)
        else:
            vpt_embds = self.vpt_embeddings
        # FIX: the original code projected self.vpt_embeddings unconditionally,
        # discarding the deep embeddings concatenated above
        vpt_embds = self.proj_vpt_pre(vpt_embds).to(self.device)
        # Concatenate the CoOp and VPT prompts
        prompt_seq = torch.cat((coop_embds, vpt_embds), dim=0).to(torch.float32)  # TODO: fix hacky type change

        # Then, run the sequence through the transformer
        output_seq = self.transformer(prompt_seq).to(torch.float16)  # TODO: fix hacky type change

        # Finally, project the sequence back into prompt space
        coop_embs = self.proj_coop_post(output_seq[:len(self.coop_embeddings)].to(self.dtype)).reshape(-1, self.coop_length, self.coop_dim)
        vpt_embs = self.proj_vpt_post(output_seq[len(self.coop_embeddings):].to(self.dtype)).reshape(-1, self.vpt_length, self.vpt_dim)
        vpt_emb_deep = None if vpt_embs.shape[0] == 1 else vpt_embs[1:, :, :]

        # Normalization of both outputs is left to the callers.
        text_out = self.text_encoder(coop_embs, classes)
        # FIX: pass only the shallow prefix to the image encoder and forward the
        # deep embeddings (previously computed but unused) separately
        visual_out = self.image_encoder(x, vpt_embs[:1], deep_embds=vpt_emb_deep)

        return text_out, visual_out
-------------------------------------------------------------------------------- /requirements.txt: --------------------------------------------------------------------------------
accelerate==0.15.0
git+https://github.com/openai/CLIP.git
numpy
pandas
Pillow
PyYAML
requests
tqdm
urllib3
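For orientation, here is how `TextPrefixModel` plugs into the encoder above to produce class embeddings for pseudolabeling. This is illustrative only: the initialization scale and the class list are assumptions, not repo defaults.

```python
import clip
import torch

from models.clip_encoders import CustomTextEncoder
from models.prompts_models import TextPrefixModel

device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, preprocess = clip.load("ViT-B/32", device=device)

classes = ["forest", "river", "airport"]
ctx_dim = clip_model.token_embedding.embedding_dim
initial_prefix = 0.02 * torch.randn(1, 4, ctx_dim, device=device)

model = TextPrefixModel(
    initial_prefix,
    CustomTextEncoder(clip_model, device, clip_model.dtype),
    classes,
    device=device,
)

text_features = model(model.classes)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
# Logits for an image batch would then be
# logit_scale * image_features @ text_features.t(), as in standard CLIP.
```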
-------------------------------------------------------------------------------- /run_main_clip.py: --------------------------------------------------------------------------------
import methods.main_CLIP as pipeline

if __name__ == '__main__':
    pipeline.main()
-------------------------------------------------------------------------------- /run_main_ssl.py: --------------------------------------------------------------------------------
import methods.main_SSL as pipeline

if __name__ == '__main__':
    pipeline.main()
-------------------------------------------------------------------------------- /run_main_trzsl.py: --------------------------------------------------------------------------------
import methods.main_TRZSL as pipeline

if __name__ == '__main__':
    pipeline.main()
-------------------------------------------------------------------------------- /run_main_ul.py: --------------------------------------------------------------------------------
import methods.main_UL as pipeline

if __name__ == '__main__':
    pipeline.main()
-------------------------------------------------------------------------------- /scripts/run_clip.sh: --------------------------------------------------------------------------------
#!/bin/bash

for dataset_dir in '/path/data' ; do # add here the path to the folder containing the dataset folders
for vis_encoder in 'ViT-B/32'; do # choose between 'ViT-B/32' and 'ViT-L/14'
for split_seed in 500; do # this indicates the TRZSL split, i.e., 500, 0, or 200; for the other learning settings the default is 500
for dataset_name in RESICS45; do # DTD Flowers102 EuroSAT FGVCAircraft MNIST
for model in clip; do
for optim_seed in 1; do # for plain CLIP inference multiple seeds are not needed

    export OPTIM_SEED="$optim_seed"
    export VIS_ENCODER="$vis_encoder"
    export DATASET_NAME="$dataset_name"
    export SPLIT_SEED="$split_seed"
    export MODEL="$model"
    export DATASET_DIR="$dataset_dir"

    python3 ./run_main_clip.py \
        --model_config ${model}_config.yml \
        --learning_paradigm trzsl # choose among ul, ssl, and trzsl

done
done
done
done
done
done
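Each script simply exports a handful of environment variables and launches the matching entry point. A Python sketch of a single `run_clip.sh` loop iteration, mirroring the shell loop above:

```python
import os
import subprocess

# Environment variables consumed by the pipeline, as exported by the script
env = dict(
    os.environ,
    OPTIM_SEED="1",
    VIS_ENCODER="ViT-B/32",
    DATASET_NAME="RESICS45",
    SPLIT_SEED="500",
    MODEL="clip",
    DATASET_DIR="/path/data",  # folder containing the dataset folders
)

subprocess.run(
    [
        "python3", "run_main_clip.py",
        "--model_config", "clip_config.yml",
        "--learning_paradigm", "trzsl",  # choose among ul, ssl, and trzsl
    ],
    env=env,
    check=True,
)
```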
-------------------------------------------------------------------------------- /scripts/run_prompts_ssl.sh: --------------------------------------------------------------------------------
#!/bin/bash

for dataset_dir in '/path/data' ; do # add here the path to the folder containing the dataset folders
for vis_encoder in 'ViT-B/32'; do # choose between 'ViT-B/32' and 'ViT-L/14'
for split_seed in 500; do # this indicates the TRZSL split, i.e., 500, 0, or 200; for the other learning settings the default is 500
for dataset_name in RESICS45; do # DTD Flowers102 EuroSAT FGVCAircraft MNIST
for model in textual_prompt visual_prompt multimodal_prompt; do
for optim_seed in 1; do # 1 2 3 4 5 are the seeds we used

    export OPTIM_SEED="$optim_seed"
    export VIS_ENCODER="$vis_encoder"
    export DATASET_NAME="$dataset_name"
    export SPLIT_SEED="$split_seed"
    export MODEL="$model"
    export DATASET_DIR="$dataset_dir"

    # Set the accelerate configuration file to accelerate_config.yml when running on GPUs
    accelerate launch --config_file methods_config/accelerate_localtest_config.yml run_main_ssl.py --model_config ${model}_config.yml \
        --learning_paradigm ssl # choose among ul, ssl, and trzsl

done
done
done
done
done
done
-------------------------------------------------------------------------------- /scripts/run_prompts_trzsl.sh: --------------------------------------------------------------------------------
#!/bin/bash

for dataset_dir in '/path/data' ; do # add here the path to the folder containing the dataset folders
for vis_encoder in 'ViT-B/32'; do # choose between 'ViT-B/32' and 'ViT-L/14'
for split_seed in 500; do # this indicates the TRZSL split, i.e., 500, 0, or 200; for the other learning settings the default is 500
for dataset_name in RESICS45; do # DTD Flowers102 EuroSAT FGVCAircraft MNIST
for model in textual_prompt visual_prompt multimodal_prompt; do
for optim_seed in 1; do # 1 2 3 4 5 are the seeds we used

    export OPTIM_SEED="$optim_seed"
    export VIS_ENCODER="$vis_encoder"
    export DATASET_NAME="$dataset_name"
    export SPLIT_SEED="$split_seed"
    export MODEL="$model"
    export DATASET_DIR="$dataset_dir"

    # Set the accelerate configuration file to accelerate_config.yml when running on GPUs
    accelerate launch --config_file methods_config/accelerate_localtest_config.yml run_main_trzsl.py --model_config ${model}_config.yml \
        --learning_paradigm trzsl # choose among ul, ssl, and trzsl

done
done
done
done
done
done
-------------------------------------------------------------------------------- /scripts/run_pseudolabels_ssl.sh: --------------------------------------------------------------------------------
#!/bin/bash

for dataset_dir in '/path/data' ; do # add here the path to the folder containing the dataset folders
for vis_encoder in 'ViT-B/32'; do # choose between 'ViT-B/32' and 'ViT-L/14'
for split_seed in 500; do # this indicates the TRZSL split, i.e., 500, 0, or 200; for the other learning settings the default is 500
for dataset_name in RESICS45; do # DTD Flowers102 EuroSAT FGVCAircraft MNIST
for model in textual_fpl; do # choose among: textual_fpl, visual_fpl, multimodal_fpl, iterative_textual_fpl, iterative_visual_fpl, iterative_multimodal_fpl, grip_textual, grip_visual, grip_multimodal
for optim_seed in 1; do # 1 2 3 4 5 are the seeds we used

    export OPTIM_SEED="$optim_seed"
    export VIS_ENCODER="$vis_encoder"
    export DATASET_NAME="$dataset_name"
    export SPLIT_SEED="$split_seed"
    export MODEL="$model"
    export DATASET_DIR="$dataset_dir"

    # Set the accelerate configuration file to accelerate_config.yml when running on GPUs
    accelerate launch --config_file methods_config/accelerate_localtest_config.yml run_main_ssl.py --model_config ${model}_config.yml \
        --learning_paradigm ssl # choose among ul, ssl, and trzsl

done
done
done
done
done
done
-------------------------------------------------------------------------------- /scripts/run_pseudolabels_trzsl.sh: --------------------------------------------------------------------------------
#!/bin/bash

for dataset_dir in '/path/data' ; do # add here the path to the folder containing the dataset folders
for vis_encoder in 'ViT-B/32'; do # choose between 'ViT-B/32' and 'ViT-L/14'
for split_seed in 500; do # this indicates the TRZSL split, i.e., 500, 0, or 200; for the other learning settings the default is 500
for dataset_name in RESICS45; do # DTD Flowers102 EuroSAT FGVCAircraft MNIST
for model in textual_fpl; do # choose among: textual_fpl, visual_fpl, multimodal_fpl, iterative_textual_fpl, iterative_visual_fpl, iterative_multimodal_fpl, grip_textual, grip_visual, grip_multimodal
for optim_seed in 1; do # 1 2 3 4 5 are the seeds we used

    export OPTIM_SEED="$optim_seed"
    export VIS_ENCODER="$vis_encoder"
    export DATASET_NAME="$dataset_name"
    export SPLIT_SEED="$split_seed"
    export MODEL="$model"
    export DATASET_DIR="$dataset_dir"

    # Set the accelerate configuration file to accelerate_config.yml when running on GPUs
    accelerate launch --config_file methods_config/accelerate_localtest_config.yml run_main_trzsl.py --model_config ${model}_config.yml \
        --learning_paradigm trzsl # choose among ul, ssl, and trzsl

done
done
done
done
done
done
-------------------------------------------------------------------------------- /scripts/run_pseudolabels_ul.sh: --------------------------------------------------------------------------------
#!/bin/bash

for dataset_dir in '/path/data' ; do # add here the path to the folder containing the dataset folders
for vis_encoder in 'ViT-B/32'; do # choose between 'ViT-B/32' and 'ViT-L/14'
for split_seed in 500; do # this indicates the TRZSL split, i.e., 500, 0, or 200; for the other learning settings the default is 500
for dataset_name in RESICS45; do # DTD Flowers102 EuroSAT FGVCAircraft MNIST
for model in textual_fpl; do # choose among: textual_fpl, visual_fpl, multimodal_fpl, iterative_textual_fpl, iterative_visual_fpl, iterative_multimodal_fpl, grip_textual, grip_visual, grip_multimodal
for optim_seed in 1; do # 1 2 3 4 5 are the seeds we used

    export OPTIM_SEED="$optim_seed"
    export VIS_ENCODER="$vis_encoder"
    export DATASET_NAME="$dataset_name"
    export SPLIT_SEED="$split_seed"
    export MODEL="$model"
    export DATASET_DIR="$dataset_dir"

    # Set the accelerate configuration file to accelerate_config.yml when running on GPUs
    accelerate launch --config_file methods_config/accelerate_localtest_config.yml run_main_ul.py --model_config ${model}_config.yml \
        --learning_paradigm ul # choose among ul, ssl, and trzsl

done
done
done
done
done
done
-------------------------------------------------------------------------------- /setup.sh: --------------------------------------------------------------------------------
#!/usr/bin/env bash

pip install --upgrade pip
pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html
pip install -r requirements.txt
pip install importlib_metadata
pip install pytorch-metric-learning
-------------------------------------------------------------------------------- /utils/__init__.py: --------------------------------------------------------------------------------
from .clip_pseudolabels import pseudolabel_top_k
from .compute_metrics import (
    evaluate_predictions,
    store_results,
    save_parameters,
    save_predictions,
    save_pseudo_labels,
)
from .prepare_data import (
    get_class_names,
    get_labeled_and_unlabeled_data,
)
from .schedulers import make_scheduler
from .utils import (
    Config,
    dataset_object,
    seed_worker,
)
-------------------------------------------------------------------------------- /utils/clip_pseudolabels.py: --------------------------------------------------------------------------------
import logging
import os
import pickle

import clip
import torch
from PIL import Image
from tqdm import tqdm

log = logging.getLogger(__name__)


def compute_pseudo_labels(
    k,
    template,
    dataset,
    classnames,
    transform,
    clip_model,
    label_to_idx,
    device,
    filename,
):
    prompts = [f"{template}{' '.join(i.split('_'))}" for i in classnames]
    text = clip.tokenize(prompts).to(device)

    if k == 10000000:
        log.info("Computing pseudo-labels on all unlabeled data")
        new_labels = []
        new_imgs = []
        for i, image_path in enumerate(tqdm(dataset.filepaths)):
            img = Image.open(image_path).convert("RGB")
            img = transform(img).to(device)
            with torch.no_grad():
                logits_per_image, logits_per_text = clip_model(
                    torch.unsqueeze(img, 0).to(device), text
                )
            probs = logits_per_image.softmax(dim=-1)
            idx_preds = torch.argmax(probs, dim=1)
            pred_id = idx_preds.item()
            pred = label_to_idx[classnames[pred_id]]

            new_labels.append(pred)
            new_imgs.append(image_path)

    else:
        # To find the top k for each class, each class has its own "leaderboard":
        # a dict mapping class idx -> list of (confidence, image_path) tuples
        top_k_leaderboard = {
            label_to_idx[classnames[i]]: [] for i in range(len(classnames))
        }

        log.info(f"Computing the top-{k} pseudo-labels per class")
        for i, image_path in enumerate(tqdm(dataset.filepaths)):
            img = Image.open(image_path).convert("RGB")
            img = transform(img).to(device)
            with torch.no_grad():
                logits_per_image, logits_per_text = clip_model(
                    torch.unsqueeze(img, 0).to(device), text
                )
            probs = logits_per_image.softmax(dim=-1)
            idx_preds = torch.argmax(probs, dim=1)
            pred_id = idx_preds.item()
            pred = label_to_idx[classnames[pred_id]]

            # If the predicted class has an empty leaderboard, or if the confidence
            # is high enough for the predicted class leaderboard, add the new example
            prob_score = probs[0][pred_id]
            if len(top_k_leaderboard[pred]) < k:
                top_k_leaderboard[pred].append((prob_score, image_path))
            elif (
                top_k_leaderboard[pred][-1][0] < prob_score
            ):  # the confidence in the predicted class "qualifies" for the top k
                # default sorting of tuples is by first element
                top_k_leaderboard[pred] = sorted(
                    top_k_leaderboard[pred] + [(prob_score, image_path)],
                    reverse=True,
                )[:k]
            else:
                # otherwise, try the remaining classes in order of confidence score
                order_of_classes = sorted(
                    [(probs[0][j], j) for j in range(len(classnames)) if j != pred_id],
                    reverse=True,
                )
                for score, index in order_of_classes:
                    index_dict = label_to_idx[classnames[index]]
                    if len(top_k_leaderboard[index_dict]) < k:
                        top_k_leaderboard[index_dict].append((probs[0][index], image_path))
                    elif top_k_leaderboard[index_dict][-1][0] < probs[0][index]:
                        # default sorting of tuples is by first element
                        top_k_leaderboard[index_dict] = sorted(
                            top_k_leaderboard[index_dict]
                            + [(probs[0][index], image_path)],
                            reverse=True,
                        )[:k]

        new_imgs = []
        new_labels = []
        # loop through the leaderboards and rebuild the dataset
        for index, leaderboard in top_k_leaderboard.items():
            new_imgs += [tup[1] for tup in leaderboard]
            new_labels += [index for _ in leaderboard]

    dataset.filepaths = new_imgs
    dataset.labels = new_labels

    with open(filename, "wb") as f:
        pickle.dump({"filepaths": new_imgs, "labels": new_labels}, f)

    return dataset


def pseudolabel_top_k(
    config,
    data_name,
    k,
    template,
    dataset,
    classnames,
    transform,
    clip_model,
    label_to_idx,
    device,
    vis_encoder,
    split_seed,
):
    filename = f"pseudolabels/{data_name}_{vis_encoder.replace('/', '')}_{config.LEARNING_PARADIGM}_{config.MODEL}_{k}_pseudolabels_split_{split_seed}.pickle"
    if os.path.exists(filename):
        # Load cached pseudolabels
        with open(filename, "rb") as f:
            pseudolabels = pickle.load(f)
        new_imgs = pseudolabels["filepaths"]
        new_labels = pseudolabels["labels"]

        dataset.filepaths = new_imgs
        dataset.labels = new_labels
    else:
        dataset = compute_pseudo_labels(
            k,
            template,
            dataset,
            classnames,
            transform,
            clip_model,
            label_to_idx,
            device,
            filename,
        )

    return dataset
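The selection above keeps, for every class, the k images CLIP is most confident about. The same per-class top-k logic distilled with a min-heap, as an illustrative rewrite (not the code the repo runs):

```python
import heapq
from collections import defaultdict

def top_k_per_class(scored_items, k):
    """scored_items: iterable of (class_idx, confidence, image_path) tuples."""
    boards = defaultdict(list)  # class_idx -> min-heap of (confidence, image_path)
    for cls, conf, path in scored_items:
        heap = boards[cls]
        if len(heap) < k:
            heapq.heappush(heap, (conf, path))
        elif conf > heap[0][0]:  # beats the weakest kept example
            heapq.heapreplace(heap, (conf, path))
    # return each leaderboard sorted by descending confidence
    return {cls: sorted(heap, reverse=True) for cls, heap in boards.items()}
```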
-------------------------------------------------------------------------------- /utils/compute_metrics.py: --------------------------------------------------------------------------------
import json
import logging
import os
import pickle

import numpy as np
import pandas as pd
import scipy.stats as st
import torch
from accelerate import Accelerator

accelerator = Accelerator()


log = logging.getLogger(__name__)


def evaluate_predictions(
    config,
    df_predictions,
    test_labeled_files,
    labels,
    unseen_classes,
    seen_classes=None,
):
    df_test = pd.DataFrame({"id": test_labeled_files, "true": labels})
    df_test["id"] = df_test["id"].apply(lambda x: x.split("/")[-1])
    df_predictions = pd.merge(df_predictions, df_test, on="id")

    if config.LEARNING_PARADIGM == 'ul' or config.LEARNING_PARADIGM == 'ssl':
        accuracy = (
            np.sum(df_predictions["class"] == df_predictions["true"])
            / df_predictions.shape[0]
        )

        return accuracy, None, None

    else:
        # Compute unseen accuracy
        unseen_predictions = df_predictions[df_predictions["true"].isin(unseen_classes)]
        unseen_accuracy = (
            np.sum(unseen_predictions["class"] == unseen_predictions["true"])
            / unseen_predictions.shape[0]
        )
        # Compute seen accuracy
        seen_predictions = df_predictions[df_predictions["true"].isin(seen_classes)]
        seen_accuracy = (
            np.sum(seen_predictions["class"] == seen_predictions["true"])
            / seen_predictions.shape[0]
        )

        harmonic_mean = st.hmean([unseen_accuracy, seen_accuracy])

        return unseen_accuracy, seen_accuracy, harmonic_mean


def store_results(
    obj_conf,
    std_response
):
    """Store the results of the model in a json file.

    :param obj_conf: class object that stores configurations
    :param std_response: for UL and SSL it corresponds to the accuracy of
        the model. For TRZSL it is a tuple of unseen, seen, and
        harmonic-mean accuracy.
    """
    if obj_conf.LEARNING_PARADIGM == 'trzsl':
        if accelerator.is_local_main_process:
            results_to_store = {
                "model": obj_conf.MODEL,
                "config": obj_conf.__dict__,
                "harmonic_mean": std_response[2],
                "seen_accuracy": std_response[1],
                "unseen_accuracy": std_response[0],
            }
    else:
        if accelerator.is_local_main_process:
            results_to_store = {
                "model": obj_conf.MODEL,
                "config": obj_conf.__dict__,
                "accuracy": std_response[0],
            }

    if accelerator.is_local_main_process:
        file_name = f"results_model_{obj_conf.MODEL}.json"

        # Append to the file if it already exists, otherwise create it
        mode = "a" if os.path.exists(file_name) else "w"
        with open(file_name, mode) as f:
            f.write(json.dumps(results_to_store) + "\n")
def save_parameters(obj, config, iteration=None):
    """Save in a pickle the prompt parameters used for evaluation.

    :param obj: object to save
    :param config: object with method configurations
    :param iteration: indicates the iteration number for iterative strategies
    """

    if iteration is None:
        file_name = f"trained_prompts/{config.DATASET_NAME}_{config.LEARNING_PARADIGM}_{config.MODEL}_{config.VIS_ENCODER.replace('/','')}_opt_{config.OPTIM_SEED}_spl_{config.SPLIT_SEED}.pickle"
    else:
        file_name = f"trained_prompts/{config.DATASET_NAME}_{config.LEARNING_PARADIGM}_{config.MODEL}_{config.VIS_ENCODER.replace('/','')}_iter_{iteration}_opt_{config.OPTIM_SEED}_spl_{config.SPLIT_SEED}.pickle"

    if config.MODALITY == 'multi':
        names = [
            'transformer',
            'proj_coop_pre',
            'proj_coop_post',
            'proj_vpt_pre',
            'proj_vpt_post',
            'coop_embeddings',
            'deep_vpt',
            'vpt_embeddings',
        ]
        for idx, param in enumerate(obj):
            ff = file_name.split('.')[:-1][0]
            if names[idx] in [
                'transformer',
                'proj_coop_pre',
                'proj_coop_post',
                'proj_vpt_pre',
                'proj_vpt_post',
            ]:
                torch.save(obj[idx], f'{ff}_{names[idx]}.pt')
            else:
                with open(f'{ff}_{names[idx]}.pickle', 'wb') as f:
                    pickle.dump(obj[idx], f)

    else:
        with open(file_name, 'wb') as f:
            pickle.dump(obj, f)


def save_pseudo_labels(imgs, labs, config, iteration):
    filename = f"pseudolabels/{config.DATASET_NAME}_{config.LEARNING_PARADIGM}_{config.MODEL}_{config.VIS_ENCODER.replace('/', '')}_iter_{iteration}_opt_{config.OPTIM_SEED}_spl_{config.SPLIT_SEED}.pickle"
    with open(filename, "wb") as f:
        pickle.dump({"filepaths": imgs, "labels": labs}, f)


def save_predictions(obj, config, iteration=None):
    """Save in a pickle the predictions made at evaluation time.

    :param obj: object to save
    :param config: object with method configurations
    :param iteration: indicates the iteration number for iterative strategies
    """

    if iteration is None:
        file_name = f"evaluation/{config.DATASET_NAME}_{config.LEARNING_PARADIGM}_{config.MODEL}_{config.VIS_ENCODER.replace('/','')}_opt_{config.OPTIM_SEED}_spl_{config.SPLIT_SEED}.pickle"
    else:
        file_name = f"evaluation/{config.DATASET_NAME}_{config.LEARNING_PARADIGM}_{config.MODEL}_{config.VIS_ENCODER.replace('/','')}_iter_{iteration}_opt_{config.OPTIM_SEED}_spl_{config.SPLIT_SEED}.pickle"

    with open(file_name, 'wb') as f:
        pickle.dump(obj, f)
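As a quick sanity check on the TRZSL metric computed in `evaluate_predictions`: the harmonic mean penalizes imbalance between seen and unseen accuracy much more than the arithmetic mean would. The values below are made up:

```python
import scipy.stats as st

unseen_accuracy, seen_accuracy = 0.30, 0.90
print(st.hmean([unseen_accuracy, seen_accuracy]))  # 0.45
print((unseen_accuracy + seen_accuracy) / 2)       # 0.60
```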
-------------------------------------------------------------------------------- /utils/schedulers.py: --------------------------------------------------------------------------------
import logging
import math

import torch.optim as optim
from torch.optim import lr_scheduler as scheduler
from torch.optim.lr_scheduler import LambdaLR

log = logging.getLogger(__name__)


def make_scheduler(optimizer, config, double=False, teacher=False):
    warmup = config.WARMUP_EPOCHS
    if double:
        if teacher:
            total_iters = config.t_EPOCHS
        else:
            total_iters = config.s_EPOCHS
    else:
        total_iters = config.EPOCHS
    if config.SCHEDULER == "cosine":
        lr_scheduler = WarmupCosineSchedule(
            optimizer, warmup_steps=warmup, t_total=total_iters
        )
    elif config.SCHEDULER == "one_warmup_epoch":
        lr_scheduler = LambdaLR(
            optimizer,
            lr_lambda=lambda epoch: config.WARMUP_LR / config.LR if epoch == 0 else 1,
        )
    else:
        lr_scheduler = scheduler.StepLR(
            optimizer, step_size=config.STEP_SIZE, gamma=0.1
        )
    return lr_scheduler


class WarmupCosineSchedule(LambdaLR):
    """Linear warmup followed by cosine decay.
    Linearly increases the learning-rate multiplier from 0 to 1 over
    `warmup_steps`, then decreases it from 1 to 0 over the remaining
    `t_total - warmup_steps` steps following a cosine curve.
    If `cycles` differs from the default 0.5, the learning rate follows
    a cosine with that many cycles after warmup.
    """

    def __init__(self, optimizer, warmup_steps, t_total, cycles=0.5, last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.t_total = t_total
        self.cycles = cycles
        super(WarmupCosineSchedule, self).__init__(
            optimizer, self.lr_lambda, last_epoch=last_epoch, verbose=True
        )

    def lr_lambda(self, step):
        if step < self.warmup_steps:
            return float(step) / float(max(1.0, self.warmup_steps))
        # progress after warmup
        progress = float(step - self.warmup_steps) / float(
            max(1, self.t_total - self.warmup_steps)
        )
        return max(
            0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.cycles) * 2.0 * progress))
        )
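Since `make_scheduler` sets `t_total` to `config.EPOCHS`, the schedule is presumably stepped once per epoch. A minimal sketch of driving `WarmupCosineSchedule` directly, with made-up hyperparameters:

```python
import torch
from utils.schedulers import WarmupCosineSchedule

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.SGD(params, lr=0.1)
scheduler = WarmupCosineSchedule(optimizer, warmup_steps=5, t_total=50)

for epoch in range(50):
    optimizer.step()  # a training epoch would go here
    scheduler.step()  # LR rises linearly for 5 epochs, then decays along a cosine
```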
-------------------------------------------------------------------------------- /utils/utils.py: --------------------------------------------------------------------------------
import logging
import random

import numpy as np
import torch


log = logging.getLogger(__name__)


def dataset_object(dataset_name):
    if dataset_name == "aPY":
        from data import aPY as DataObject
    elif dataset_name == "Animals_with_Attributes2":
        from data import AwA2 as DataObject
    elif dataset_name == "EuroSAT":
        from data import EuroSAT as DataObject
    elif dataset_name == "DTD":
        from data import DTD as DataObject
    elif dataset_name == "sun397":
        from data import SUN397 as DataObject
    elif dataset_name == "CUB":
        from data import CUB as DataObject
    elif dataset_name == "RESICS45":
        from data import RESICS45 as DataObject
    elif dataset_name == "FGVCAircraft":
        from data import FGVCAircraft as DataObject
    elif dataset_name == "MNIST":
        from data import MNIST as DataObject
    elif dataset_name == "Flowers102":
        from data import Flowers102 as DataObject
    else:
        # Fail early instead of raising a confusing UnboundLocalError below
        raise ValueError(f"Unknown dataset: {dataset_name}")

    return DataObject


def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


class Config(object):
    def __init__(self, config):
        for k, v in config.items():
            setattr(self, k, v)
--------------------------------------------------------------------------------
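Finally, `Config` simply promotes dictionary keys (e.g., from a parsed YAML in `methods_config/`) to attributes. A tiny sketch, using hypothetical values for keys that `make_scheduler` actually reads:

```python
from utils.utils import Config

cfg = Config({"SCHEDULER": "cosine", "EPOCHS": 50, "WARMUP_EPOCHS": 5, "LR": 0.002})
print(cfg.SCHEDULER, cfg.EPOCHS)  # -> cosine 50
```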