├── .gitignore
├── README.md
├── data
│   ├── __init__.py
│   ├── class_files
│   │   ├── DTD
│   │   │   └── class_names.txt
│   │   ├── EuroSAT
│   │   │   └── class_names.txt
│   │   ├── FGVCAircraft
│   │   │   └── labels.txt
│   │   ├── Flowers102
│   │   │   └── class_names.txt
│   │   └── MNIST
│   │       └── labels.txt
│   ├── data_splits
│   │   ├── DTD.json
│   │   ├── EuroSAT.json
│   │   ├── FGVCAircraft.json
│   │   ├── Flowers102.json
│   │   ├── MNIST.json
│   │   └── RESICS45.json
│   ├── dataset.py
│   └── dataset_prompts.py
├── imgs
│   └── overview.png
├── methods
│   ├── __init__.py
│   ├── clip_baseline.py
│   ├── main_CLIP.py
│   ├── main_SSL.py
│   ├── main_TRZSL.py
│   ├── main_UL.py
│   ├── semi_supervised_learning
│   │   ├── __init__.py
│   │   ├── multimodal_fpl.py
│   │   ├── multimodal_prompt.py
│   │   ├── pseudo_iterative.py
│   │   ├── textual_fpl.py
│   │   ├── textual_prompt.py
│   │   ├── visual_fpl.py
│   │   └── visual_prompt.py
│   ├── transductive_zsl
│   │   ├── __init__.py
│   │   ├── multimodal_fpl.py
│   │   ├── multimodal_prompt.py
│   │   ├── pseudo_iterative.py
│   │   ├── textual_fpl.py
│   │   ├── textual_prompt.py
│   │   ├── visual_fpl.py
│   │   └── visual_prompt.py
│   └── unsupervised_learning
│       ├── __init__.py
│       ├── multimodal_fpl.py
│       ├── multimodal_prompt.py
│       ├── pseudo_iterative.py
│       ├── textual_fpl.py
│       ├── textual_prompt.py
│       ├── visual_fpl.py
│       └── visual_prompt.py
├── methods_config
│   ├── accelerate_config.yml
│   ├── accelerate_localtest_config.yml
│   ├── clip_config.yml
│   ├── evaluation_config.yml
│   ├── grip_multimodal_config.yml
│   ├── grip_textual_config.yml
│   ├── grip_visual_config.yml
│   ├── iterative_multimodal_fpl_config.yml
│   ├── iterative_textual_fpl_config.yml
│   ├── iterative_visual_fpl_config.yml
│   ├── multimodal_fpl_config.yml
│   ├── multimodal_prompt_config.yml
│   ├── textual_fpl_config.yml
│   ├── textual_prompt_config.yml
│   ├── visual_fpl_config.yml
│   └── visual_prompt_config.yml
├── models
│   ├── __init__.py
│   ├── clip_encoders.py
│   └── prompts_models.py
├── requirements.txt
├── run_main_clip.py
├── run_main_ssl.py
├── run_main_trzsl.py
├── run_main_ul.py
├── scripts
│   ├── run_clip.sh
│   ├── run_prompts_ssl.sh
│   ├── run_prompts_trzsl.sh
│   ├── run_pseudolabels_ssl.sh
│   ├── run_pseudolabels_trzsl.sh
│   └── run_pseudolabels_ul.sh
├── setup.sh
└── utils
    ├── __init__.py
    ├── clip_pseudolabels.py
    ├── compute_metrics.py
    ├── prepare_data.py
    ├── schedulers.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | ########################
2 | # Taglets specific stuff
3 | ########################
4 |
5 | # Data
6 | taglets/sql_data
7 | test/MNIST
8 |
9 | # Saved models
10 | trained_models
11 | taglets/trained_models
12 | test/trained_models
13 |
14 | ########################
15 | # Other stuff
16 | ########################
17 |
18 | # OS X
19 | .DS_Store
20 |
21 | # Byte-compiled / optimized / DLL files
22 | __pycache__/
23 | *.py[cod]
24 | *$py.class
25 |
26 | # C extensions
27 | *.so
28 |
29 | # Distribution / packaging
30 | .Python
31 | build/
32 | develop-eggs/
33 | dist/
34 | downloads/
35 | eggs/
36 | .eggs/
37 | lib/
38 | lib64/
39 | parts/
40 | sdist/
41 | var/
42 | wheels/
43 | pip-wheel-metadata/
44 | share/python-wheels/
45 | *.egg-info/
46 | .installed.cfg
47 | *.egg
48 | MANIFEST
49 |
50 | # PyInstaller
51 | # Usually these files are written by a python script from a template
52 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
53 | *.manifest
54 | *.spec
55 |
56 | # Installer logs
57 | pip-log.txt
58 | pip-delete-this-directory.txt
59 |
60 | # Unit test / coverage reports
61 | htmlcov/
62 | .tox/
63 | .nox/
64 | .coverage
65 | .coverage.*
66 | .cache
67 | nosetests.xml
68 | coverage.xml
69 | *.cover
70 | *.py,cover
71 | .hypothesis/
72 | .pytest_cache/
73 |
74 | # Translations
75 | *.mo
76 | *.pot
77 |
78 | # Django stuff:
79 | *.log
80 | local_settings.py
81 | db.sqlite3
82 | db.sqlite3-journal
83 |
84 | # Flask stuff:
85 | instance/
86 | .webassets-cache
87 |
88 | # Scrapy stuff:
89 | .scrapy
90 |
91 | # Sphinx documentation
92 | docs/_build/
93 |
94 | # PyBuilder
95 | target/
96 |
97 | # Jupyter Notebook
98 | .ipynb_checkpoints
99 |
100 | # IPython
101 | profile_default/
102 | ipython_config.py
103 |
104 | # pyenv
105 | .python-version
106 |
107 | # pipenv
108 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
109 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
110 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
111 | # install all needed dependencies.
112 | #Pipfile.lock
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 | .idea
151 | .vscode
152 |
153 | # Data
154 | /taglets/sql_data
155 | .DS_Store
156 | /taglets/test_data
157 |
158 | # Saved models
159 | /trained_models
160 | /taglets/trained_models
161 | /test/trained_models/*
162 |
163 | # predefined
164 | predefined
165 |
166 | # Notebook
167 | API.ipynb
168 | taglets/JPL_test.ipynb
169 | taglets/task/JPL_test.ipynb
170 | api.ipynb
171 | Look_dataset.ipynb
172 | pseudolabels/
173 | results/
174 | trained_prompts/
175 | #Process_results_file.ipynb
176 | notebooks/
177 | evaluation/
178 | representations/
179 | logs/
180 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Enhancing CLIP with CLIP: Exploring Pseudolabeling for Limited-Labeled Prompt Tuning
2 |
3 | This repository contains the code to reproduce the experiments from the paper ["Enhancing CLIP with CLIP: Exploring Pseudolabeling for Limited-Label Prompt Tuning"](http://arxiv.org/abs/2306.01669). The paper explores the effect of leveraging pseudolabels to adapt vision-language models such as CLIP to downstream tasks in a unified way across prompt modalities, learning paradigms, and training strategies.
4 |
5 |
6 | ![Overview](imgs/overview.png)
7 |
8 |
9 |
10 | ## Table of Contents
11 |
12 | - [Setup](#setup)
13 | - [Data](#data)
14 | - [Experiments](#experiments)
15 | - [Results](#results)
16 | - [Citation](#citation)
17 |
18 | ## Setup
19 |
20 | To set up the project environment, follow these steps:
21 |
22 | 1. Ensure that you have Python version 3.7.4 installed. You can check the Python version by running the following command:
23 |
24 | ```bash
25 | python --version
26 | ```
27 |
28 | 2. Clone the repository by running the following command:
29 |
30 | ```bash
31 | git clone https://github.com/BatsResearch/menghini-enhanceCLIPwithCLIP-code.git
32 | ```
33 |
34 | 3. Navigate to the root folder and execute the `setup.sh` script to install the required dependencies, including `pytorch`. Note that we assume the installation of a CUDA-compatible version of `pytorch` since GPUs are recommended for running the experiments. If you don't have access to GPUs, you can modify the script to remove the CUDA requirement.
35 |
36 | ```bash
37 | cd menghini-enhanceCLIPwithCLIP-code/
38 | bash setup.sh
39 | ```
40 |
41 | [Back to Table of Contents](#table-of-contents)
42 |
43 | ## Data
44 |
45 | The experiments are conducted on the following six datasets: **F**lowers102, **R**ESICS45, **F**GVC-Aircraft, **M**NIST, **E**uroSAT, and **D**TD (FRAMED). We use the train and test splits provided in the paper [ELEVATER: A Benchmark and Toolkit for Evaluating Language-Augmented Visual Models](https://openreview.net/pdf?id=hGl8rsmNXzs).
46 |
47 | To access the FRAMED dataset, you can download it [here](https://drive.google.com/file/d/1_ns7regg8dfAAGmYcmCXa5ryuJeOoug-/view?usp=share_link). After downloading, unzip the folder to obtain the required data.
48 |
49 | If you encounter any issues with the download or prefer an alternative method, you can follow these steps:
50 |
51 | 1. Download the data by following the instructions provided [here](https://github.com/Computer-Vision-in-the-Wild/DataDownload).
52 | 2. Rename the folders as follows (see the command sketch after these steps):
53 | - `dtd/` to `DTD/`
54 | - `eurosat_clip/` to `EuroSAT/`
55 | - `fgvc-aircraft-2013b-variants102/` to `FGVCAircraft/`
56 | - `oxford-flower-102/` to `Flowers102/`
57 | - `mnist/` to `MNIST/`
58 | - `resisc45_clip/` to `RESICS45/`
59 | 3. Ensure that each folder contains the following files:
60 | - `DTD/` should contain the [`class_names.txt`](https://github.com/BatsResearch/menghini-neurips23-code/blob/main/data/class_files/DTD/class_names.txt) file
61 | - `EuroSAT/` should contain the [`class_names.txt`](https://github.com/BatsResearch/menghini-neurips23-code/blob/main/data/class_files/EuroSAT/class_names.txt) file
62 | - `FGVCAircraft/` should contain the [`labels.txt`](https://github.com/BatsResearch/menghini-neurips23-code/blob/main/data/class_files/FGVCAircraft/labels.txt) file
63 | - `Flowers102/` should contain the [`class_names.txt`](https://github.com/BatsResearch/menghini-neurips23-code/blob/main/data/class_files/Flowers102/class_names.txt) file
64 | - `MNIST/` should contain the [`labels.txt`](https://github.com/BatsResearch/menghini-neurips23-code/blob/main/data/class_files/MNIST/labels.txt) file
65 |
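For reference, the renames in step 2 amount to a few `mv` commands. This is a minimal sketch, assuming the downloaded folders sit in your dataset directory:

```bash
mv dtd DTD
mv eurosat_clip EuroSAT
mv fgvc-aircraft-2013b-variants102 FGVCAircraft
mv oxford-flower-102 Flowers102
mv mnist MNIST
mv resisc45_clip RESICS45
```
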
66 | [Back to Table of Contents](#table-of-contents)
67 |
68 | ## Experiments
69 |
70 | Before running the experiments, create the following folders to save prompts, pseudolabels, and results.
71 |
72 | ```bash
73 | mkdir pseudolabels
74 | mkdir logs
75 | mkdir trained_prompts
76 | mkdir evaluation
77 | ```
78 |
79 | The code is organized so that, for each learning paradigm, any combination of prompt modality and training strategy can be run.
80 |
81 | ### Baselines
82 |
83 | 1. CLIP [1]
84 | ```
85 | bash scripts/run_clip.sh
86 | ```
87 | 2. Standard prompt tuning without pseudolabels: CoOp [2], VPT [3], UPT [4].
88 | - For SSL:
89 | ```
90 | bash scripts/run_prompts_ssl.sh
91 | ```
92 | - For TRZSL:
93 | ```
94 | bash scripts/run_prompts_trzsl.sh
95 | ```
96 |
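Under the hood, each script sets a handful of environment variables and launches a `methods/main_*.py` entry point, which loads its config from `methods_config/` (see `main()` in `methods/main_CLIP.py`). The following is a minimal sketch of a single CLIP-baseline run from the repo root, with hypothetical values for `MODEL` and `DATASET_DIR`; the exact invocation is in `scripts/run_clip.sh`:

```bash
export OPTIM_SEED=1 SPLIT_SEED=500        # seeds read by methods/main_CLIP.py
export VIS_ENCODER="ViT-B/32"             # CLIP backbone name passed to clip.load
export DATASET_NAME=Flowers102 DATASET_DIR=./datasets MODEL=clip_baseline
accelerate launch --config_file methods_config/accelerate_config.yml \
    methods/main_CLIP.py --model_config clip_config.yml --learning_paradigm ssl
```
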
97 | ### Training strategies employing pseudolabels
98 |
99 | To execute the training strategies that employ pseudolabels across prompt modalities, run the following:
100 |
101 | - For SSL:
102 | ```
103 | bash scripts/run_pseudolabels_ssl.sh
104 | ```
105 |
106 | - For UL:
107 | ```
108 | bash scripts/run_pseudolabels_ul.sh
109 | ```
110 |
111 | - For TRZSL:
112 | ```
113 | bash scripts/run_pseudolabels_trzsl.sh
114 | ```
115 |
116 | Logs of the runs are saved in `logs/`.
117 | The folder `pseudolabels/` gathers the pseudolabels used for each prompt modality, learning paradigm, and training strategy. For iterative methods, we store them at each iteration.
118 | In `trained_prompts/`, we save the prompts used to make predictions. For iterative methods, we save the prompts at each iteration.
119 | In `evaluation/`, we save the predictions of each method.
120 |
121 | [Back to Table of Contents](#table-of-contents)
122 |
123 |
124 | [1] [Learning Transferable Visual Models From Natural Language Supervision, Radford et al. 2021](https://arxiv.org/pdf/2103.00020.pdf)
125 |
126 | [2] [Learning to prompt for vision-language models, Zhou et al. 2021](https://arxiv.org/pdf/2109.01134.pdf)
127 |
128 | [3] [Visual prompt tuning, Jia et al. 2022](https://arxiv.org/pdf/2203.12119.pdf)
129 |
130 | [4] [Unified vision and language prompt learning, Zang et al., 2022](https://arxiv.org/pdf/2210.07225.pdf)
131 |
132 | ## Results
133 |
134 | The tables below report the accuracy of CLIP, standard prompt tuning, and GRIP for each prompt modality, across the six datasets and the three learning paradigms (SSL, UL, TRZSL).
135 |
136 |
137 | **Textual prompts**
138 | | | | | | | | | | | |
139 | |---------------------|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
140 | | | | **Flowers102** | | | **RESICS45** | | | **FGVCAircraft** | |
141 | | **Method** | **SSL** | **UL** | **TRZSL** | **SSL** | **UL** | **TRZSL** | **SSL** | **UL** | **TRZSL** |
142 | | CLIP [1] | 63.7 | 63.7 | 63.4 | 54.5 | 54.5 | 54.5 | **17.6** | **17.6** | 17.9 |
143 | | CoOp [2] | 76.8 | - | 63.2 | 58.5 | - | 63.4 | 14.9 | - | 21.7 |
144 | | GRIP | **83.6** | **69.8** | **86.3** | **74.1** | **70.6** | **81.1** | 17.0 | 15.2 | **26.1** |
145 | | | | **MNIST** | | | **EuroSAT** | | | **DTD** | |
146 | | CLIP | 25.1 | 25.1 | 20.8 | 32.9 | 32.9 | 30.5 | 43.2 | 43.2 | 43.4 |
147 | | CoOp [2] | 56.4 | - | 21.2 | 59.5 | - | 49.7 | 37.1 | - | 46.3 |
148 | | GRIP | **71.8** | **67.9** | **74.1** | **58.7** | **57.2** | **92.3** | **56.1** | **46.1** | **65.3** |
149 |
150 | **Visual prompts**
151 | | | | | | | | | | | |
152 | |---------------------|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
153 | | | | **Flowers102** | | | **RESICS45** | | | **FGVCAircraft** | |
154 | | **Method** | **SSL** | **UL** | **TRZSL** | **SSL** | **UL** | **TRZSL** | **SSL** | **UL** | **TRZSL** |
155 | | CLIP [1] | 63.7 | **63.7** | 63.4 | 54.5 | 54.5 | 54.5 | 17.6 | **17.6** | 17.9 |
156 | | VPT [3] | 63.7 | - | 64.7 | 60.8 | - | 67.1 | 17.8 | - | 26.7 |
157 | | GRIP | **67.9** | **63.1** | **77.2** | **71.2** | **68.4** | **82.2** | **19.4** | **17.5** | **26.4** |
158 | | | | **MNIST** | | | **EuroSAT** | | | **DTD** | |
159 | | CLIP [1] | 25.1 | 25.1 | 20.8 | 32.9 | 32.9 | 30.5 | 43.2 | 43.2 | 43.4 |
160 | | VPT [3] | 42.5 | - | 25.5 | 47.1 | - | 62.2 | 36.4 | - | 44.2 |
161 | | GRIP | **69.7** | **68.0** | **69.5** | **63.5** | **63.7** | **97.0** | **54.6** | **50.5** | **62.8** |
162 |
163 | **Multimodal prompts**
164 | | | | | | | | | | | |
165 | |---------------------|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
166 | | | | **Flowers102** | | | **RESICS45** | | | **FGVCAircraft** | |
167 | | **Method** | **SSL** | **UL** | **TRZSL** | **SSL** | **UL** | **TRZSL** | **SSL** | **UL** | **TRZSL** |
168 | | CLIP [1] | 63.7 | **63.7** | 63.4 | 54.5 | 54.5 | 54.5 | **17.6** | **17.6** | **17.9** |
169 | | UPT [4] | 68.0 | - | 61.1 | 62.8 | - | 58.8 | 11.1 | - | 15.9 |
170 | | GRIP | **74.6** | **64.8** | **82.0** | **73.7** | **69.4** | **82.2** | **17.4** | 14.7 | **17.9** |
171 | | | | **MNIST** | | | **EuroSAT** | | | **DTD** | |
172 | | CLIP | 25.1 | 25.1 | 20.8 | 32.9 | 32.9 | 30.5 | 43.2 | 43.2 | 43.4 |
173 | | UPT [4] | **64.4** | - | 63.6 | 68.9 | - | 60.4 | 43.7 | - | 36.9 |
174 | | GRIP | **65.9** | **68.2** | **73.8** | **60.4** | **61.5** | **95.5** | **54.1** | **47.4** | **64.4** |
175 |
176 |
177 |
178 | [Back to Table of Contents](#table-of-contents)
179 |
180 | ## Citation
181 |
182 | If you find this work helpful, please consider citing the following paper:
183 |
184 | ```
185 | @inproceedings{
186 | menghini2023enhancing,
187 | title={Enhancing {CLIP} with {CLIP}: Exploring Pseudolabeling for Limited-Label Prompt Tuning},
188 | author={Cristina Menghini and Andrew Delworth and Stephen Bach},
189 | booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
190 | year={2023},
191 | url={https://openreview.net/forum?id=2b9aY2NgXE}
192 | }
193 | ```
194 |
195 |
196 | [Back to Table of Contents](#table-of-contents)
197 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .dataset import (
2 | CUB,
3 | DTD,
4 | MNIST,
5 | RESICS45,
6 | CustomDataset,
7 | EuroSAT,
8 | FGVCAircraft,
9 | Flowers102,
10 | )
11 | from .dataset_prompts import dataset_custom_prompts
12 |
--------------------------------------------------------------------------------
/data/class_files/DTD/class_names.txt:
--------------------------------------------------------------------------------
1 | banded
2 | blotchy
3 | braided
4 | bubbly
5 | bumpy
6 | chequered
7 | cobwebbed
8 | cracked
9 | crosshatched
10 | crystalline
11 | dotted
12 | fibrous
13 | flecked
14 | freckled
15 | frilly
16 | gauzy
17 | grid
18 | grooved
19 | honeycombed
20 | interlaced
21 | knitted
22 | lacelike
23 | lined
24 | marbled
25 | matted
26 | meshed
27 | paisley
28 | perforated
29 | pitted
30 | pleated
31 | polka-dotted
32 | porous
33 | potholed
34 | scaly
35 | smeared
36 | spiralled
37 | sprinkled
38 | stained
39 | stratified
40 | striped
41 | studded
42 | swirly
43 | veined
44 | waffled
45 | woven
46 | wrinkled
47 | zigzagged
--------------------------------------------------------------------------------
/data/class_files/EuroSAT/class_names.txt:
--------------------------------------------------------------------------------
1 | annual crop land
2 | forest
3 | brushland or shrubland
4 | highway or road
5 | industrial buildings or commercial buildings
6 | pasture land
7 | permanent crop land
8 | residential buildings or homes or apartments
9 | river
10 | lake or sea
--------------------------------------------------------------------------------
/data/class_files/FGVCAircraft/labels.txt:
--------------------------------------------------------------------------------
1 | 707-320
2 | 727-200
3 | 737-200
4 | 737-300
5 | 737-400
6 | 737-500
7 | 737-600
8 | 737-700
9 | 737-800
10 | 737-900
11 | 747-100
12 | 747-200
13 | 747-300
14 | 747-400
15 | 757-200
16 | 757-300
17 | 767-200
18 | 767-300
19 | 767-400
20 | 777-200
21 | 777-300
22 | A300B4
23 | A310
24 | A318
25 | A319
26 | A320
27 | A321
28 | A330-200
29 | A330-300
30 | A340-200
31 | A340-300
32 | A340-500
33 | A340-600
34 | A380
35 | An-12
36 | ATR-42
37 | ATR-72
38 | BAE 146-200
39 | BAE 146-300
40 | BAE-125
41 | Beechcraft 1900
42 | Boeing 717
43 | C-130
44 | C-47
45 | Cessna 172
46 | Cessna 208
47 | Cessna 525
48 | Cessna 560
49 | Challenger 600
50 | CRJ-200
51 | CRJ-700
52 | CRJ-900
53 | DC-10
54 | DC-3
55 | DC-6
56 | DC-8
57 | DC-9-30
58 | DH-82
59 | DHC-1
60 | DHC-6
61 | DHC-8-100
62 | DHC-8-300
63 | Dornier 328
64 | DR-400
65 | E-170
66 | E-190
67 | E-195
68 | EMB-120
69 | Embraer Legacy 600
70 | ERJ 135
71 | ERJ 145
72 | Eurofighter Typhoon
73 | F-16A-B
74 | F-A-18
75 | Falcon 2000
76 | Falcon 900
77 | Fokker 100
78 | Fokker 50
79 | Fokker 70
80 | Global Express
81 | Gulfstream IV
82 | Gulfstream V
83 | Hawk T1
84 | Il-76
85 | L-1011
86 | MD-11
87 | MD-80
88 | MD-87
89 | MD-90
90 | Metroliner
91 | Model B200
92 | PA-28
93 | Saab 2000
94 | Saab 340
95 | Spitfire
96 | SR-20
97 | Tornado
98 | Tu-134
99 | Tu-154
100 | Yak-42
--------------------------------------------------------------------------------
/data/class_files/Flowers102/class_names.txt:
--------------------------------------------------------------------------------
1 | pink primrose
2 | hard-leaved pocket orchid
3 | canterbury bells
4 | sweet pea
5 | english marigold
6 | tiger lily
7 | moon orchid
8 | bird of paradise
9 | monkshood
10 | globe thistle
11 | snapdragon
12 | colt's foot
13 | king protea
14 | spear thistle
15 | yellow iris
16 | globe flower
17 | purple coneflower
18 | peruvian lily
19 | balloon flower
20 | giant white arum lily
21 | fire lily
22 | pincushion flower
23 | fritillary
24 | red ginger
25 | grape hyacinth
26 | corn poppy
27 | prince of wales feathers
28 | stemless gentian
29 | artichoke
30 | sweet william
31 | carnation
32 | garden phlox
33 | love in the mist
34 | mexican aster
35 | alpine sea holly
36 | ruby-lipped cattleya
37 | cape flower
38 | great masterwort
39 | siam tulip
40 | lenten rose
41 | barbeton daisy
42 | daffodil
43 | sword lily
44 | poinsettia
45 | bolero deep blue
46 | wallflower
47 | marigold
48 | buttercup
49 | oxeye daisy
50 | common dandelion
51 | petunia
52 | wild pansy
53 | primula
54 | sunflower
55 | pelargonium
56 | bishop of llandaff
57 | gaura
58 | geranium
59 | orange dahlia
60 | pink and yellow dahlia
61 | cautleya spicata
62 | japanese anemone
63 | black-eyed susan
64 | silverbush
65 | californian poppy
66 | osteospermum
67 | spring crocus
68 | bearded iris
69 | windflower
70 | tree poppy
71 | gazania
72 | azalea
73 | water lily
74 | rose
75 | thorn apple
76 | morning glory
77 | passion flower
78 | lotus
79 | toad lily
80 | anthurium
81 | frangipani
82 | clematis
83 | hibiscus
84 | columbine
85 | desert-rose
86 | tree mallow
87 | magnolia
88 | cyclamen
89 | watercress
90 | canna lily
91 | hippeastrum
92 | bee balm
93 | air plant
94 | foxglove
95 | bougainvillea
96 | camellia
97 | mallow
98 | mexican petunia
99 | bromelia
100 | blanket flower
101 | trumpet creeper
102 | blackberry lily
--------------------------------------------------------------------------------
/data/class_files/MNIST/labels.txt:
--------------------------------------------------------------------------------
1 | 0
2 | 1
3 | 2
4 | 3
5 | 4
6 | 5
7 | 6
8 | 7
9 | 8
10 | 9
--------------------------------------------------------------------------------
/data/data_splits/DTD.json:
--------------------------------------------------------------------------------
1 | {
2 | "split_500": {
3 | "seen": ["knitted", "pitted", "studded", "bumpy", "spiralled", "scaly", "polka-dotted", "veined", "wrinkled", "banded", "flecked", "stained", "chequered", "sprinkled", "bubbly", "grid", "lined", "crystalline", "fibrous", "meshed", "zigzagged", "pleated", "braided", "perforated", "potholed", "waffled", "dotted", "matted", "gauzy"],
4 | "unseen": ["blotchy", "smeared", "cobwebbed", "cracked", "crosshatched", "stratified", "striped", "swirly", "woven", "freckled", "frilly", "grooved", "honeycombed", "interlaced", "lacelike", "marbled", "paisley", "porous"]
5 | },
6 | "split_0": {
7 | "seen": ["pitted", "scaly", "polka-dotted", "bumpy", "honeycombed", "fibrous", "veined", "porous", "lined", "dotted", "perforated", "potholed", "pleated", "waffled", "braided", "wrinkled", "paisley", "gauzy", "meshed", "grid", "studded", "knitted", "swirly", "crosshatched", "freckled", "chequered", "grooved", "smeared", "frilly"],
8 | "unseen": ["banded", "blotchy", "bubbly", "spiralled", "sprinkled", "cobwebbed", "cracked", "stained", "crystalline", "stratified", "striped", "flecked", "woven", "zigzagged", "interlaced", "lacelike", "marbled", "matted"]
9 | },
10 | "split_200": {
11 | "seen": ["pitted", "pleated", "polka-dotted", "sprinkled", "grooved", "knitted", "matted", "wrinkled", "honeycombed", "chequered", "braided", "zigzagged", "spiralled", "banded", "waffled", "crosshatched", "bubbly", "smeared", "dotted", "porous", "woven", "freckled", "lined", "potholed", "lacelike", "marbled", "stratified", "scaly", "studded"],
12 | "unseen": ["blotchy", "bumpy", "stained", "cobwebbed", "cracked", "striped", "crystalline", "swirly", "fibrous", "flecked", "veined", "frilly", "gauzy", "grid", "interlaced", "meshed", "paisley", "perforated"]
13 | }
14 | }
--------------------------------------------------------------------------------
/data/data_splits/EuroSAT.json:
--------------------------------------------------------------------------------
1 | {
2 | "split_500": {
3 | "seen": ["industrial buildings or commercial buildings", "brushland or shrubland", "lake or sea", "highway or road", "annual crop land", "pasture land"],
4 | "unseen": ["river", "forest", "permanent crop land", "residential buildings or homes or apartments"]
5 | },
6 | "split_0": {
7 | "seen": ["brushland or shrubland", "river", "industrial buildings or commercial buildings", "lake or sea", "forest", "permanent crop land"],
8 | "unseen": ["annual crop land", "highway or road", "pasture land", "residential buildings or homes or apartments"]
9 | },
10 | "split_200": {
11 | "seen": ["river", "highway or road", "pasture land", "permanent crop land", "forest", "residential buildings or homes or apartments"],
12 | "unseen": ["annual crop land", "lake or sea", "brushland or shrubland", "industrial buildings or commercial buildings"]
13 | }
14 | }
--------------------------------------------------------------------------------
/data/data_splits/FGVCAircraft.json:
--------------------------------------------------------------------------------
1 | {
2 | "split_500": {
3 | "seen": ["Tu-134", "Spitfire", "Challenger 600", "737-700", "F-A-18", "E-170", "727-200", "A300B4", "Falcon 2000", "DR-400", "MD-87", "CRJ-700", "ERJ 145", "Falcon 900", "MD-80", "DC-10", "Il-76", "Global Express", "Gulfstream IV", "Saab 340", "Yak-42", "CRJ-900", "L-1011", "A330-200", "A321", "747-300", "DC-3", "A310", "ATR-42", "CRJ-200", "Hawk T1", "Fokker 100", "ATR-72", "PA-28", "A319", "707-320", "A318", "A320", "BAE-125", "747-200", "ERJ 135", "737-800", "SR-20", "BAE 146-300", "Beechcraft 1900", "Cessna 172", "A340-300", "EMB-120", "737-900", "737-400", "Cessna 208", "MD-90", "777-300", "A340-600", "737-600", "737-300", "DHC-1", "DC-6", "A380", "C-47", "767-200", "BAE 146-200"],
4 | "unseen": ["737-200", "737-500", "747-100", "747-400", "757-200", "757-300", "767-300", "767-400", "777-200", "A330-300", "A340-200", "A340-500", "An-12", "Boeing 717", "C-130", "Cessna 525", "Cessna 560", "DC-8", "DC-9-30", "DH-82", "DHC-6", "DHC-8-100", "DHC-8-300", "Dornier 328", "E-190", "E-195", "Embraer Legacy 600", "Eurofighter Typhoon", "F-16A-B", "Fokker 50", "Fokker 70", "Gulfstream V", "MD-11", "Metroliner", "Model B200", "Saab 2000", "Tornado", "Tu-154"]
5 | },
6 | "split_0": {
7 | "seen": ["A321", "MD-80", "737-200", "DC-8", "Falcon 900", "Saab 340", "767-200", "F-A-18", "DC-6", "SR-20", "DC-3", "Saab 2000", "Fokker 70", "747-400", "737-700", "A340-300", "A310", "A319", "A380", "737-800", "C-47", "Dornier 328", "737-300", "Eurofighter Typhoon", "Cessna 208", "Challenger 600", "737-600", "Yak-42", "Hawk T1", "Fokker 100", "DHC-8-100", "Gulfstream IV", "Model B200", "Embraer Legacy 600", "CRJ-900", "A330-200", "767-400", "DC-9-30", "DR-400", "Falcon 2000", "727-200", "DHC-8-300", "C-130", "Boeing 717", "737-400", "757-300", "767-300", "Beechcraft 1900", "BAE 146-300", "737-500", "PA-28", "DHC-6", "707-320", "An-12", "A330-300", "CRJ-700", "747-200", "ATR-42", "A318", "DC-10", "747-100", "A340-500"],
8 | "unseen": ["737-900", "747-300", "757-200", "777-200", "777-300", "A300B4", "A320", "A340-200", "A340-600", "ATR-72", "BAE 146-200", "BAE-125", "Cessna 172", "Cessna 525", "Cessna 560", "CRJ-200", "DH-82", "DHC-1", "E-170", "E-190", "E-195", "EMB-120", "ERJ 135", "ERJ 145", "F-16A-B", "Fokker 50", "Global Express", "Gulfstream V", "Il-76", "L-1011", "MD-11", "MD-87", "MD-90", "Metroliner", "Spitfire", "Tornado", "Tu-134", "Tu-154"]
9 | },
10 | "split_200": {
11 | "seen": ["An-12", "737-200", "F-16A-B", "BAE 146-200", "MD-80", "E-170", "Gulfstream IV", "DR-400", "737-900", "777-200", "Boeing 717", "747-100", "Saab 340", "Cessna 525", "Challenger 600", "MD-90", "DHC-8-100", "Cessna 172", "C-47", "747-400", "BAE-125", "MD-11", "767-300", "Cessna 560", "A330-300", "E-195", "737-500", "Fokker 50", "ATR-72", "BAE 146-300", "Fokker 70", "Falcon 900", "Falcon 2000", "Spitfire", "A340-200", "DC-3", "A340-300", "Beechcraft 1900", "A320", "Hawk T1", "E-190", "Gulfstream V", "Tu-134", "767-400", "CRJ-200", "737-400", "747-300", "Eurofighter Typhoon", "PA-28", "MD-87", "Yak-42", "DHC-1", "737-800", "A380", "Model B200", "ERJ 135", "SR-20", "737-300", "707-320", "DC-10", "Dornier 328", "A300B4"],
12 | "unseen": ["727-200", "737-600", "737-700", "747-200", "757-200", "757-300", "767-200", "777-300", "A310", "A318", "A319", "A321", "A330-200", "A340-500", "A340-600", "ATR-42", "C-130", "Cessna 208", "CRJ-700", "CRJ-900", "DC-6", "DC-8", "DC-9-30", "DH-82", "DHC-6", "DHC-8-300", "EMB-120", "Embraer Legacy 600", "ERJ 145", "F-A-18", "Fokker 100", "Global Express", "Il-76", "L-1011", "Metroliner", "Saab 2000", "Tornado", "Tu-154"]
13 | }
14 | }
--------------------------------------------------------------------------------
/data/data_splits/Flowers102.json:
--------------------------------------------------------------------------------
1 | {
2 | "split_500": {
3 | "seen": ["canna lily", "petunia", "silverbush", "prince of wales feathers", "pincushion flower", "bird of paradise", "frangipani", "hard-leaved pocket orchid", "bearded iris", "passion flower", "tiger lily", "lenten rose", "cape flower", "air plant", "mexican petunia", "common dandelion", "magnolia", "foxglove", "hibiscus", "camellia", "orange dahlia", "clematis", "anthurium", "bougainvillea", "ruby-lipped cattleya", "stemless gentian", "oxeye daisy", "spring crocus", "king protea", "cyclamen", "fritillary", "californian poppy", "wild pansy", "desert-rose", "sunflower", "rose", "grape hyacinth", "pink primrose", "red ginger", "corn poppy", "watercress", "colt's foot", "blanket flower", "monkshood", "morning glory", "siam tulip", "barbeton daisy", "bolero deep blue", "carnation", "tree poppy", "globe thistle", "english marigold", "primula", "wallflower", "blackberry lily", "fire lily", "love in the mist", "moon orchid", "sweet pea", "mallow", "pelargonium", "mexican aster", "poinsettia"],
4 | "unseen": ["canterbury bells", "snapdragon", "spear thistle", "yellow iris", "globe flower", "purple coneflower", "peruvian lily", "balloon flower", "giant white arum lily", "artichoke", "sweet william", "garden phlox", "alpine sea holly", "great masterwort", "daffodil", "sword lily", "marigold", "buttercup", "bishop of llandaff", "gaura", "geranium", "pink and yellow dahlia", "cautleya spicata", "japanese anemone", "black-eyed susan", "osteospermum", "windflower", "gazania", "azalea", "water lily", "thorn apple", "lotus", "toad lily", "columbine", "tree mallow", "hippeastrum", "bee balm", "bromelia", "trumpet creeper"]
5 | },
6 | "split_0": {
7 | "seen": ["prince of wales feathers", "air plant", "canterbury bells", "bishop of llandaff", "bee balm", "desert-rose", "purple coneflower", "spring crocus", "pelargonium", "windflower", "sunflower", "bougainvillea", "rose", "spear thistle", "bird of paradise", "carnation", "fritillary", "grape hyacinth", "mexican aster", "monkshood", "poinsettia", "black-eyed susan", "sweet pea", "anthurium", "wallflower", "oxeye daisy", "moon orchid", "blackberry lily", "hibiscus", "frangipani", "cautleya spicata", "camellia", "canna lily", "passion flower", "wild pansy", "stemless gentian", "balloon flower", "gaura", "thorn apple", "morning glory", "hard-leaved pocket orchid", "japanese anemone", "sword lily", "daffodil", "english marigold", "globe flower", "peruvian lily", "barbeton daisy", "siam tulip", "tiger lily", "foxglove", "pink and yellow dahlia", "pink primrose", "alpine sea holly", "artichoke", "petunia", "colt's foot", "ruby-lipped cattleya", "red ginger", "primula", "snapdragon", "garden phlox", "mexican petunia"],
8 | "unseen": ["globe thistle", "king protea", "yellow iris", "giant white arum lily", "fire lily", "pincushion flower", "corn poppy", "sweet william", "love in the mist", "cape flower", "great masterwort", "lenten rose", "bolero deep blue", "marigold", "buttercup", "common dandelion", "geranium", "orange dahlia", "silverbush", "californian poppy", "osteospermum", "bearded iris", "tree poppy", "gazania", "azalea", "water lily", "lotus", "toad lily", "clematis", "columbine", "tree mallow", "magnolia", "cyclamen", "watercress", "hippeastrum", "mallow", "bromelia", "blanket flower", "trumpet creeper"]
9 | },
10 | "split_200": {
11 | "seen": ["oxeye daisy", "canterbury bells", "clematis", "siam tulip", "cape flower", "black-eyed susan", "air plant", "californian poppy", "globe thistle", "giant white arum lily", "cyclamen", "snapdragon", "frangipani", "buttercup", "common dandelion", "hippeastrum", "columbine", "spring crocus", "bolero deep blue", "spear thistle", "barbeton daisy", "poinsettia", "peruvian lily", "alpine sea holly", "artichoke", "sunflower", "tiger lily", "toad lily", "magnolia", "lenten rose", "great masterwort", "camellia", "mallow", "morning glory", "lotus", "sweet william", "thorn apple", "carnation", "daffodil", "corn poppy", "cautleya spicata", "marigold", "hibiscus", "tree poppy", "balloon flower", "osteospermum", "english marigold", "king protea", "azalea", "foxglove", "watercress", "blackberry lily", "bearded iris", "monkshood", "mexican aster", "orange dahlia", "water lily", "mexican petunia", "sweet pea", "pink primrose", "primula", "silverbush", "pincushion flower"],
12 | "unseen": ["hard-leaved pocket orchid", "moon orchid", "bird of paradise", "colt's foot", "yellow iris", "globe flower", "purple coneflower", "fire lily", "fritillary", "red ginger", "grape hyacinth", "prince of wales feathers", "stemless gentian", "garden phlox", "love in the mist", "ruby-lipped cattleya", "sword lily", "wallflower", "petunia", "wild pansy", "pelargonium", "bishop of llandaff", "gaura", "geranium", "pink and yellow dahlia", "japanese anemone", "windflower", "gazania", "rose", "passion flower", "anthurium", "desert-rose", "tree mallow", "canna lily", "bee balm", "bougainvillea", "bromelia", "blanket flower", "trumpet creeper"]
13 | }
14 | }
--------------------------------------------------------------------------------
/data/data_splits/MNIST.json:
--------------------------------------------------------------------------------
1 | {
2 | "split_500": {
3 | "seen": ["4", "2", "9", "3", "0", "5"],
4 | "unseen": ["8", "1", "6", "7"]
5 | },
6 | "split_0": {
7 | "seen": ["2", "8", "4", "9", "1", "6"],
8 | "unseen": ["0", "3", "5", "7"]
9 | },
10 | "split_200": {
11 | "seen": ["8", "3", "5", "6", "1", "7"],
12 | "unseen": ["0", "9", "2", "4"]
13 | }
14 | }
--------------------------------------------------------------------------------
/data/data_splits/RESICS45.json:
--------------------------------------------------------------------------------
1 | {
2 | "split_500": {
3 | "seen": ["beach", "palace", "roundabout", "railway station", "railway", "thermal power station", "river", "airplane", "island", "bridge", "basketball court", "desert", "runway", "ground track field", "sea ice", "sparse residential", "cloud", "dense residential", "wetland", "mountain", "meadow", "baseball diamond", "parking lot", "storage tank", "tennis court", "commercial area", "mobile home park"],
4 | "unseen": ["airport", "ship", "snowberg", "chaparral", "church", "circular farmland", "stadium", "terrace", "forest", "freeway", "golf course", "harbor", "industrial area", "intersection", "lake", "medium residential", "overpass", "rectangular farmland"]
5 | },
6 | "split_0": {
7 | "seen": ["railway station", "snowberg", "palace", "beach", "commercial area", "mountain", "parking lot", "dense residential", "sparse residential", "rectangular farmland", "railway", "island", "tennis court", "baseball diamond", "thermal power station", "industrial area", "golf course", "meadow", "ground track field", "storage tank", "circular farmland", "forest", "bridge", "harbor", "river", "freeway", "sea ice"],
8 | "unseen": ["airplane", "airport", "roundabout", "basketball court", "runway", "ship", "chaparral", "church", "stadium", "cloud", "terrace", "desert", "wetland", "intersection", "lake", "medium residential", "mobile home park", "overpass"]
9 | },
10 | "split_200": {
11 | "seen": ["railway", "parking lot", "wetland", "meadow", "harbor", "island", "mobile home park", "storage tank", "industrial area", "bridge", "baseball diamond", "sea ice", "runway", "airplane", "thermal power station", "circular farmland", "basketball court", "roundabout", "commercial area", "railway station", "terrace", "forest", "rectangular farmland", "lake", "medium residential", "snowberg", "river"],
12 | "unseen": ["airport", "beach", "ship", "chaparral", "church", "sparse residential", "cloud", "stadium", "dense residential", "desert", "tennis court", "freeway", "golf course", "ground track field", "intersection", "mountain", "overpass", "palace"]
13 | }
14 | }
--------------------------------------------------------------------------------
/data/dataset.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 |
4 | import torch
5 | import torch.nn as nn
6 | from PIL import Image
7 | from torch.utils.data import Dataset
8 |
9 | log = logging.getLogger(__name__)
10 |
11 |
12 | class CustomDataset(Dataset):
13 | def __init__(
14 | self,
15 | filepaths,
16 | root,
17 | transform,
18 | augmentations=None,
19 | train=True,
20 | labels=None,
21 | label_id=False,
22 | label_map=None,
23 | ):
24 | """
25 | :param filepaths: list of images
26 | :param root: path to images
27 | :param transform: standard transform
28 | :param augmentations: None or tuple
29 | :param train: indicates if the data is in the train or test folder
30 | :param labels: list of labels
31 | :param label_id: true if labels are passed as ints
32 | :param label_map: dict mapping string labels to ints
33 | """
34 | # Adjust filepaths
35 | self.train = train
36 | if self.train:
37 | self.filepaths = [f"{root}/train/{f}" for f in filepaths]
38 | else:
39 | self.filepaths = [f"{root}/test/{f}" for f in filepaths]
40 |
41 | self.transform = transform
42 | if augmentations:
43 | self.aug1_transform = augmentations[0]
44 | self.aug2_transform = augmentations[1]
45 | else:
46 | self.aug1_transform = None
47 | self.aug2_transform = None
48 | self.labels = labels
49 | self.label_id = label_id
50 | self.label_map = label_map
51 |
52 | def __len__(self):
53 | # dataset size
54 | return len(self.filepaths)
55 |
56 | def __getitem__(self, index: int):
57 | """
58 | Args:
59 | index (int): Index
60 | Returns:
61 | tuple: (image, aug1, aug2, target) where target is index of the target class.
62 | """
63 |
64 | img = Image.open(self.filepaths[index]).convert("RGB")
65 |
66 | # Apply two transformations (strong and weak)
67 | if self.aug1_transform is not None:
68 | aug_1 = self.aug1_transform(img)
69 | else:
70 | img1 = self.transform(img)
71 | aug_1 = img1
72 | if self.aug2_transform is not None:
73 | aug_2 = self.aug2_transform(img)
74 | else:
75 | img2 = self.transform(img)
76 | aug_2 = img2
77 |
78 | if self.transform is not None:
79 | img = self.transform(img)
80 |
81 | # Get image label
82 | if self.labels is not None:
83 | if self.label_id:
84 | label = int(self.labels[index])
85 | else:
86 | label = int(self.label_map[self.labels[index]])
87 | return img, aug_1, aug_2, label, self.filepaths[index].split("/")[-1]
88 | else:
89 | return img, aug_1, aug_2, self.filepaths[index].split("/")[-1]
90 |
91 |
92 |
93 | class EuroSAT(CustomDataset):
94 | def __init__(
95 | self,
96 | filepaths,
97 | root,
98 | transform,
99 | augmentations=None,
100 | train=True,
101 | labels=None,
102 | label_id=False,
103 | label_map=None,
104 | class_folder=False,
105 | original_filepaths=None,
106 | ):
107 | """
108 | :param filepaths: list of images
109 | :param root: path to images
110 | :param transform: standard transform
111 | :param augmentations: None or tuple
112 | :param train: indicates if the data is in the train or test folder
113 | :param labels: list of labels
114 | :param label_id: true if labels are passed as ints
115 | :param label_map: dict mapping string labels to ints
116 | """
117 | super().__init__(
118 | filepaths,
119 | root,
120 | transform,
121 | augmentations,
122 | train,
123 | labels,
124 | label_id,
125 | label_map,
126 | )
127 | # Adjust filepaths: images live in per-class folders named by the filename prefix before "_"
128 | self.filepaths = [f"{root}/{f.split('_')[0]}/{f}" for f in filepaths]
129 |
130 |
131 | class DTD(CustomDataset):
132 | def __init__(
133 | self,
134 | filepaths,
135 | root,
136 | transform,
137 | augmentations=None,
138 | train=True,
139 | labels=None,
140 | label_id=False,
141 | label_map=None,
142 | class_folder=False,
143 | original_filepaths=None,
144 | ):
145 | """
146 | :param filepaths: list of images
147 | :param root: path to images
148 | :param transform: standard transform
149 | :param augmentations: None or tuple
150 | :param train: indicates if the data is in the train or test folder
151 | :param labels: list of labels
152 | :param label_id: true if labels are passed as ints
153 | :param label_map: dict mapping string labels to ints
154 | """
155 | super().__init__(
156 | filepaths,
157 | root,
158 | transform,
159 | augmentations,
160 | train,
161 | labels,
162 | label_id,
163 | label_map,
164 | )
165 | # Adjust filepaths
166 | if class_folder:
167 | paths = []
168 | for f in filepaths:
169 | cl = f.split("_")[0]
170 | tr_files = os.listdir(f"{root}/train/{cl}")
171 | val_files = os.listdir(f"{root}/val/{cl}")
172 | if f in tr_files:
173 | paths.append(f"{root}/train/{cl}/{f}")
174 | elif f in val_files:
175 | paths.append(f"{root}/val/{cl}/{f}")
176 |
177 | self.filepaths = paths
178 |
179 | else:
180 | self.filepaths = [f"{root}/{f}" for f in filepaths]
181 |
182 |
183 | class CUB(CustomDataset):
184 | def __init__(
185 | self,
186 | filepaths,
187 | root,
188 | transform,
189 | augmentations=None,
190 | train=True,
191 | labels=None,
192 | label_id=False,
193 | label_map=None,
194 | class_folder=False,
195 | original_filepaths=None,
196 | ):
197 | """
198 | :param filepaths: list of images
199 | :param root: path to images
200 | :param transform: standard transform
201 | :param augmentations: None or tuple
202 | :param train: indicates if the data is in the train or test folder
203 | :param labels: list of labels
204 | :param label_id: true if labels are passed as ints
205 | :param label_map: dict mapping string labels to ints
206 | """
207 | super().__init__(
208 | filepaths,
209 | root,
210 | transform,
211 | augmentations,
212 | train,
213 | labels,
214 | label_id,
215 | label_map,
216 | )
217 | # Adjust filepaths
218 | self.filepaths = [f"{root}/{f}" for f in filepaths]
219 |
220 |
221 | class RESICS45(CustomDataset):
222 | def __init__(
223 | self,
224 | filepaths,
225 | root,
226 | transform,
227 | augmentations=None,
228 | train=True,
229 | labels=None,
230 | label_id=False,
231 | label_map=None,
232 | class_folder=False,
233 | original_filepaths=None,
234 | ):
235 | """
236 | :param filepaths: list of images
237 | :param root: path to images
238 | :param transform: standard transform
239 | :param augmentations: None or tuple
240 | :param train: indicates if the data is in the train or test folder
241 | :param labels: list of labels
242 | :param label_id: true if labels are passed as ints
243 | :param label_map: dict mapping string labels to ints
244 | """
245 | super().__init__(
246 | filepaths,
247 | root,
248 | transform,
249 | augmentations,
250 | train,
251 | labels,
252 | label_id,
253 | label_map,
254 | )
255 | # Adjust filepaths: the class folder name is the filename without its trailing "_<index>"
256 | self.filepaths = []
257 | for f in filepaths:
258 | folder = "_".join(f.split("_")[:-1])
259 | self.filepaths.append(f"{root}/{folder}/{f}")
260 |
261 |
262 | class FGVCAircraft(CustomDataset):
263 | def __init__(
264 | self,
265 | filepaths,
266 | root,
267 | transform,
268 | augmentations=None,
269 | train=True,
270 | labels=None,
271 | label_id=False,
272 | label_map=None,
273 | class_folder=False,
274 | original_filepaths=None,
275 | ):
276 | """
277 | :param filepaths: list of images
278 | :param root: path to images
279 | :param transform: standard transform
280 | :param augmentations: None or tuple
281 | :param train: indicates if the data is in the train or test folder
282 | :param labels: list of labels
283 | :param label_id: true if labels are passed as ints
284 | :param label_map: dict mapping string labels to ints
285 | """
286 | super().__init__(
287 | filepaths,
288 | root,
289 | transform,
290 | augmentations,
291 | train,
292 | labels,
293 | label_id,
294 | label_map,
295 | )
296 | if class_folder:
297 | filepaths = list(filepaths)
298 | new_paths = []
299 | for f in original_filepaths:
300 | img = f.split("/")[-1]
301 | if img in filepaths:
302 | new_paths.append(f"{f}")
303 |
304 | self.filepaths = new_paths
305 | else:
306 | # Adjust filepaths
307 | self.filepaths = [f"{root}/{f}" for f in filepaths]
308 |
309 |
310 | class MNIST(CustomDataset):
311 | def __init__(
312 | self,
313 | filepaths,
314 | root,
315 | transform,
316 | augmentations=None,
317 | train=True,
318 | labels=None,
319 | label_id=False,
320 | label_map=None,
321 | class_folder=False,
322 | original_filepaths=None,
323 | ):
324 | """
325 | :param filepaths: list of images
326 | :param root: path to images
327 | :param transform: standard transform
328 | :param augmentations: None or tuple
329 | :param train: indicates if the data is in the train or test folder
330 | :param labels: list of labels
331 | :param label_id: true if labels are passed as ints
332 | :param label_map: dict mapping string labels to ints
333 | """
334 | super().__init__(
335 | filepaths,
336 | root,
337 | transform,
338 | augmentations,
339 | train,
340 | labels,
341 | label_id,
342 | label_map,
343 | )
344 | if class_folder:
345 | filepaths = list(filepaths)
346 | new_paths = []
347 | for f in original_filepaths:
348 | img = f.split("/")[-1]
349 | if img in filepaths:
350 | new_paths.append(f"{f}")
351 |
352 | self.filepaths = new_paths
353 | else:
354 | # Adjust filepaths
355 | self.filepaths = [f"{root}/{f}" for f in filepaths]
356 |
357 |
358 | class Flowers102(CustomDataset):
359 | def __init__(
360 | self,
361 | filepaths,
362 | root,
363 | transform,
364 | augmentations=None,
365 | train=True,
366 | labels=None,
367 | label_id=False,
368 | label_map=None,
369 | class_folder=False,
370 | original_filepaths=None,
371 | ):
372 | """
373 | :param filepaths: list of images
374 | :param root: path to images
375 | :param transform: standard transform
376 | :param augmentations: None or tuple
377 | :param train: indicates if the data is in the train or test folder
378 | :param labels: list of labels
379 | :param label_id: true if labels are passed as ints
380 | :param label_map: dict mapping string labels to ints
381 | """
382 | super().__init__(
383 | filepaths,
384 | root,
385 | transform,
386 | augmentations,
387 | train,
388 | labels,
389 | label_id,
390 | label_map,
391 | )
392 | # Adjust filepaths
393 | if class_folder:
394 | filepaths = list(filepaths)
395 | new_paths = []
396 | for f in original_filepaths:
397 | img = f.split("/")[-1]
398 | if img in filepaths:
399 | new_paths.append(f"{f}")
400 |
401 | self.filepaths = new_paths
402 |
403 | else:
404 | self.filepaths = [f"{root}/{f}" for f in filepaths]
405 |
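406 |
407 | # Hypothetical usage sketch (added for illustration, not part of the original
408 | # file): with labels, these datasets return (img, aug_1, aug_2, label, filename);
409 | # without labels they return (img, aug_1, aug_2, filename). `preprocess` would
410 | # be the transform returned by clip.load(). For example:
411 | #   ds = DTD(files, root="DTD", transform=preprocess, train=True,
412 | #            labels=labs, label_map={c: i for i, c in enumerate(classes)})
413 | #   img, aug_1, aug_2, label, fname = ds[0]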
--------------------------------------------------------------------------------
/data/dataset_prompts.py:
--------------------------------------------------------------------------------
1 | dataset_custom_prompts = {
2 | 'EuroSAT' : 'a photo of a {}', # 'a centered satellite photo of a {}',
3 | 'DTD' : 'a photo of a {}', # 'a photo of a {} texture',
4 | 'RESICS45' : 'a photo of a {}', # 'satellite imagery of a {}',
5 | 'FGVCAircraft' : 'a photo of a {}', # 'a photo of a {}, a type of aircraft',
6 | 'MNIST' : 'a photo of a {}', # 'a photo of the number: "{}"',
7 | 'Flowers102' : 'a photo of a {}', # 'a photo of a {}, a type of flower',
8 | }
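9 |
10 | # Hypothetical usage sketch (added for illustration): each template is filled
11 | # with a class name, e.g. dataset_custom_prompts["DTD"].format("banded")
12 | # -> "a photo of a banded". The commented strings above are dataset-specific
13 | # alternative templates.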
--------------------------------------------------------------------------------
/imgs/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BatsResearch/menghini-neurips23-code/b387624c15638d9db8024135cafe386089fbe09f/imgs/overview.png
--------------------------------------------------------------------------------
/methods/__init__.py:
--------------------------------------------------------------------------------
1 | from .clip_baseline import ClipBaseline
--------------------------------------------------------------------------------
/methods/clip_baseline.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import clip
4 | import numpy as np
5 | import pandas as pd
6 | import torch
7 | from PIL import Image
8 | from tqdm import tqdm
9 |
10 |
11 | g = torch.Generator()
12 | g.manual_seed(0)
13 |
14 | log = logging.getLogger(__name__)
15 |
16 |
17 | class ClipBaseline(object):
18 | def __init__(
19 | self, config, label_to_idx, classes, seen_classes, unseen_classes, device
20 | ):
21 | """This class is CLIP model.
22 |
23 | :param config: class object with model configurations in the
24 | file models_config/clip_baseline_config.yml
25 | :param label_to_idx: dictionary (key, value):(class name, id)
26 | :param classes: list of class names
27 | :param seen_classes: list on seen classes' names
28 | :param unseen_classes: list on unseen classes' names
29 | :param device: device in use
30 |
31 | """
32 | self.config = config
33 | self.classes = classes
34 | self.seen_classes = seen_classes
35 | self.unseen_classes = unseen_classes
36 | self.label_to_idx = label_to_idx
37 |
38 | self.device = device
39 | self.model, self.transform = clip.load(
40 | self.config.VIS_ENCODER, device=self.device
41 | )
42 | self.template = self.config.PROMPT_TEMPLATE
43 |
44 | def test_predictions(self, data):
45 | """
46 | :param data: test dataset
47 | """
48 |
49 | # Declare data pre-processing
50 | data.transform = self.transform
51 | # Define the data loader
52 | test_loader = torch.utils.data.DataLoader(
53 | data, batch_size=self.config.BATCH_SIZE
54 | )
55 |
56 | # Build textual prompts (underscores in class names become spaces)
57 | prompts = [
58 | self.template.format(" ".join(i.split("_"))) for i in self.classes
59 | ]
60 | log.info(f"Number of prompts: {len(prompts)}")
61 |
62 | # Encode text
63 | text = clip.tokenize(prompts).to(self.device)
64 | text_features = self.model.encode_text(text)
65 |
66 | log.info(f"Start inference for test data")
67 | predictions = []
68 | images = []
69 | prob_preds = []
70 | # Iterate through data loader
71 | for img, _, _, img_path in tqdm(test_loader):
72 | with torch.no_grad():
73 | logits_per_image, logits_per_text = self.model(
74 | img.to(self.device), text.to(self.device)
75 | )
76 | probs = logits_per_image.softmax(dim=-1)
77 | idx_preds = torch.argmax(probs, dim=1)
78 |
79 | predictions += [self.classes[i] for i in idx_preds]
80 | images += [i for i in img_path]
81 | prob_preds += [logits_per_image]
82 |
83 | prob_preds = torch.cat(prob_preds, axis=0).detach().to('cpu')
84 | df_predictions = pd.DataFrame({"id": images, "class": predictions})
85 |
86 | return df_predictions, images, predictions, prob_preds
87 |
--------------------------------------------------------------------------------
/methods/main_CLIP.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import logging
4 | import os
5 | import random
6 | import sys
7 | from collections import defaultdict
8 | from logging import StreamHandler
9 | from pathlib import Path
10 |
11 | import numpy as np
12 | import pandas as pd
13 | import requests
14 | import scipy.stats as st
15 | import torch
16 | import yaml
17 | from accelerate import Accelerator
18 | from requests.adapters import HTTPAdapter
19 | from torch import nn
20 | from urllib3.util import Retry
21 |
22 | from data import CustomDataset, dataset_custom_prompts
23 | from methods import ClipBaseline
24 | from utils import (
25 | Config,
26 | dataset_object,
27 | evaluate_predictions,
28 | get_class_names,
29 | get_labeled_and_unlabeled_data,
30 | save_parameters,
31 | save_predictions,
32 | store_results,
33 | )
34 |
35 | accelerator = Accelerator()
36 |
37 | logger_ = logging.getLogger()
38 | logger_.level = logging.INFO
39 | formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s")
40 |
41 |
42 | class AccelerateHandler(StreamHandler):
43 | def __init__(self, stream):
44 | super().__init__(stream)
45 |
46 | def emit(self, record):
47 | if accelerator.is_local_main_process:
48 | super().emit(record)
49 |
50 |
51 | stream_handler = AccelerateHandler(sys.stdout)
52 | stream_handler.setLevel(logging.INFO)
53 | stream_handler.setFormatter(formatter)
54 | logger_.addHandler(stream_handler)
55 |
56 | log = logging.getLogger(__name__)
57 |
58 | def workflow(dataset_dir, obj_conf):
59 |
60 | dataset = obj_conf.DATASET_NAME
61 | # Get class names of target task
62 | classes, seen_classes, unseen_classes = get_class_names(
63 | dataset,
64 | dataset_dir,
65 | obj_conf.SPLIT_SEED
66 | )
67 | # Create dict classes to pass as variable
68 | dict_classes = {
69 | "classes": classes,
70 | "seen_classes": seen_classes,
71 | "unseen_classes": unseen_classes,
72 | }
73 | # Log number of classes
74 | log.info(f"\n----------------------DATA INFO-----------------------\n")
75 | log.info(f"Number of classes split {obj_conf.SPLIT_SEED}: {len(classes)}")
76 | log.info(f"Number of seen classes split {obj_conf.SPLIT_SEED}: {len(seen_classes)}")
77 | log.info(f"Number of unseen classes split {obj_conf.SPLIT_SEED}: {len(unseen_classes)}")
78 | # Path for images
79 | data_folder = f"{dataset_dir}/{dataset}"
80 | log.info(f"Data folder: {data_folder}")
81 | log.info(f"\n-------------------------------------------------------------\n")
82 |
83 | # Get labeled data (seen classes)
84 | # Get unlabeled data (unseen classes)
85 | # Get test data (both seen and unseen classes)
86 | _, _, test_data = get_labeled_and_unlabeled_data(
87 | dataset, data_folder, seen_classes, unseen_classes, classes
88 | )
89 |
90 | # Create datasets
91 | test_labeled_files, test_labels = zip(*test_data)
92 | label_to_idx = {c: idx for idx, c in enumerate(classes)}
93 |
94 | DatasetObject = dataset_object(obj_conf.DATASET_NAME)
95 |
96 | # Test set (test seen and unseen)
97 | test_dataset = DatasetObject(
98 | test_labeled_files,
99 | data_folder,
100 | transform=None,
101 | augmentations=None,
102 | train=False,
103 | labels=None,
104 | label_map=label_to_idx,
105 | )
106 | # Log info data
107 | log.info(f"\n----------------------TEST DATA INFO-----------------------\n")
108 | log.info(f"Len test data: {len(test_dataset.filepaths)}")
109 | log.info(f"\n-------------------------------------------------------------\n")
110 | # Define model
111 | device = "cuda" if torch.cuda.is_available() else "cpu"
112 |
113 |
114 | log.info(f"\n----------------------MODEL INFO-----------------------\n")
115 | log.info(f"The model in use is: {obj_conf.MODEL}")
116 | model = ClipBaseline(obj_conf, label_to_idx, device=device, **dict_classes)
117 |
118 | # Validate on test set (standard)
119 | std_predictions, images, predictions, prob_preds = model.test_predictions(test_dataset)
120 | # Submit predictions (standard)
121 | std_response = evaluate_predictions(
122 | obj_conf,
123 | std_predictions,
124 | test_labeled_files,
125 | test_labels,
126 | unseen_classes,
127 | seen_classes
128 | )
129 | log.info(f"ZSL accuracy: {std_response}")
130 |
131 | # Store model results
132 | store_results(obj_conf, std_response)
133 |
134 | dictionary_predictions = {
135 | 'images' : images,
136 | 'predictions' : predictions,
137 | 'labels' : test_labels,
138 | 'logits' : prob_preds,
139 | }
140 |
141 | save_predictions(dictionary_predictions, obj_conf, iteration=None)
142 |
143 |
144 | def main():
145 | parser = argparse.ArgumentParser(description="Run the CLIP zero-shot baseline")
146 | parser.add_argument(
147 | "--model_config",
148 | type=str,
149 | default="model_config.yml",
150 | help="Name of model config file",
151 | )
152 | parser.add_argument(
153 | "--learning_paradigm",
154 | type=str,
155 | default="trzsl",
156 | help="Choose among trzsl, ssl, and ul",
157 | )
158 |
159 | args = parser.parse_args()
160 |
161 | with open(f"methods_config/{args.model_config}", "r") as file:
162 | config = yaml.safe_load(file)
163 |
164 | # Cast configs to object
165 | obj_conf = Config(config)
166 |
167 | # Set optimization seed
168 | optim_seed = int(os.environ["OPTIM_SEED"])
169 | obj_conf.OPTIM_SEED = optim_seed
170 | # Set backbone
171 | obj_conf.VIS_ENCODER = os.environ["VIS_ENCODER"]
172 | # Set dataset name
173 | obj_conf.DATASET_NAME = os.environ["DATASET_NAME"]
174 | # Set dataset dir
175 | obj_conf.DATASET_DIR = os.environ["DATASET_DIR"]
176 | # Set model name
177 | obj_conf.MODEL = os.environ["MODEL"]
178 | # Set split seed
179 | obj_conf.SPLIT_SEED = int(os.environ["SPLIT_SEED"])
180 | # Set dataset's template for textual prompts
181 | obj_conf.PROMPT_TEMPLATE = dataset_custom_prompts[obj_conf.DATASET_NAME]
182 | # Set data dir
183 | dataset_dir = obj_conf.DATASET_DIR
184 | # Set learning paradigm
185 | obj_conf.LEARNING_PARADIGM = args.learning_paradigm
186 |
187 | # Set the file path for the log file
188 | log_file = f"logs/{obj_conf.DATASET_NAME}_{obj_conf.MODEL}_{obj_conf.VIS_ENCODER.replace('/', '-')}.log"
189 | # Create a FileHandler and set the log file
190 | file_handler = logging.FileHandler(log_file)
191 | file_handler.setFormatter(formatter)
192 | # Add the FileHandler to the logger
193 | logger_.addHandler(file_handler)
194 |
195 | log.info(f"Current working directory: {os.getcwd()}")
196 | log.info(f"Dataset dir: {dataset_dir}")
197 | # Check dataset directory exists
198 | if not Path(dataset_dir).exists():
199 | print(dataset_dir)
200 | raise Exception("`dataset_dir` does not exist..")
201 |
202 | # Set random seeds
203 | device = "cuda" if torch.cuda.is_available() else "cpu"
204 | np.random.seed(obj_conf.OPTIM_SEED)
205 | random.seed(obj_conf.OPTIM_SEED)
206 | torch.manual_seed(obj_conf.OPTIM_SEED)
207 | accelerator.wait_for_everyone()
208 | # Seed for cuda
209 | if torch.cuda.is_available():
210 | torch.cuda.manual_seed(obj_conf.OPTIM_SEED)
211 | torch.cuda.manual_seed_all(obj_conf.OPTIM_SEED)
212 | accelerator.wait_for_everyone()
213 |
214 | torch.backends.cudnn.benchmark = True
215 |
216 | workflow(dataset_dir, obj_conf)
217 |
218 |
219 | if __name__ == "__main__":
220 | main()
221 |
--------------------------------------------------------------------------------
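
A minimal launch sketch for the script above. Everything here is illustrative rather than prescribed by the repository: the seeds, backbone, and dataset are example values, and the entry-point path is an assumption based on the file tree (see also scripts/run_clip.sh).

    # Hypothetical environment setup; main() reads exactly these variables.
    import os

    os.environ["OPTIM_SEED"] = "1"          # seeds numpy/random/torch
    os.environ["SPLIT_SEED"] = "500"        # identifies the class split
    os.environ["VIS_ENCODER"] = "ViT-B/32"  # CLIP visual backbone
    os.environ["DATASET_NAME"] = "EuroSAT"
    os.environ["DATASET_DIR"] = "data/EuroSAT"
    os.environ["MODEL"] = "clip_baseline"

    # Then, for example:
    #   python methods/main_CLIP.py --model_config clip_config.yml --learning_paradigm trzsl
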
/methods/semi_supervised_learning/__init__.py:
--------------------------------------------------------------------------------
1 | from .training_strategies import TrainingStrategy
2 | from .visual_prompt import VisualPrompt
3 | from .textual_prompt import TextualPrompt
4 | from .multimodal_prompt import MultimodalPrompt
5 | from .visual_fpl import VisualFPL
6 | from .textual_fpl import TextualFPL
7 | from .multimodal_fpl import MultimodalFPL
8 |
--------------------------------------------------------------------------------
/methods/semi_supervised_learning/multimodal_fpl.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import clip
4 | import numpy as np
5 | import pandas as pd
6 | import torch
7 | from accelerate import Accelerator
8 | from PIL import Image
9 | from torch import nn
10 | import math
11 |
12 | accelerator = Accelerator()
13 |
14 | from methods.semi_supervised_learning import MultimodalPrompt
15 | from utils import (
16 | dataset_object,
17 | make_scheduler,
18 | pseudolabel_top_k,
19 | )
20 |
21 |
22 | log = logging.getLogger(__name__)
23 |
24 |
25 | class MultimodalFPL(MultimodalPrompt):
26 | def __init__(
27 | self,
28 | config,
29 | label_to_idx,
30 | data_folder,
31 | unlabeled_files,
32 | classes,
33 | seen_classes,
34 | unseen_classes,
35 | device
36 | ):
37 | """This class defines self-trainig UPT's training and evaluation.
38 | :param config: dictionaries of prameters in models_config/upt_baseline_pseudo_config.yml
39 | :param label_to_idx: dictionary (key, value):(class name, id)
40 | :param classes: list of class names
41 | :param seen_classes: list of seen classes' names
42 | :param unseen_classes: list of unseen classes' names
43 | :param device: device in use
44 | """
45 |
46 | super().__init__(
47 | config, label_to_idx, classes, seen_classes, unseen_classes, device
48 | )
49 |
50 | self.data_folder = data_folder
51 |
52 | self.check_unlabeled = unlabeled_files
53 |
54 | def create_training_dataset(self, train_data, unlabeled_data=None):
55 | """This function creates the dataset for training. Specifically, it
56 | merges pseudo-labels for unseen data and labeled data for seen classes.
57 | :param train_data: Dataset object - labeled training data for seen classes
58 | :param unlabeled_data: Dataset object - dataset of unlabeled data for
59 | unseen classes
60 | """
61 |
62 | # Get pseudo-labels for unlabeled data from unseen classes
63 | train_unseen_dataset = pseudolabel_top_k(
64 | self.config,
65 | self.config.DATASET_NAME,
66 | self.config.N_PSEUDOSHOTS,
67 | self.config.PROMPT_TEMPLATE,
68 | unlabeled_data,
69 | self.unseen_classes,
70 | self.transform,
71 | self.clip_model,
72 | self.label_to_idx,
73 | self.device,
74 | self.config.VIS_ENCODER,
75 | self.config.SPLIT_SEED
76 | )
77 |
78 | # Define the lists of training data from seen and unseen classes
79 | unseen_imgs = train_unseen_dataset.filepaths
80 | unseen_labs = train_unseen_dataset.labels
81 |
82 | # Use a portion of the pseudo-labeled data to build a validation set
83 | if self.config.N_PSEUDOSHOTS >= 10:
84 | np.random.seed(self.config.validation_seed)
85 | train_indices = np.random.choice(
86 | range(len(unseen_imgs)),
87 | size=int(len(unseen_imgs) * self.config.ratio_train_val),
88 | replace=False,
89 | )
90 | val_indices = list(
91 | set(range(len(unseen_imgs))).difference(set(train_indices))
92 | )
93 |
94 | self.val_unseen_files = np.array(unseen_imgs)[val_indices]
95 | self.val_unseen_labs = np.array(unseen_labs)[val_indices]
96 |
97 | unseen_imgs = list(np.array(unseen_imgs)[train_indices])
98 | unseen_labs = list(np.array(unseen_labs)[train_indices])
99 |
100 | else:
101 | self.val_unseen_files = None
102 | self.val_unseen_labs = None
103 |
104 | seen_imgs = train_data.filepaths
105 | seen_labs = [self.label_to_idx[l] for l in train_data.labels]
106 |
107 | self.balance_param = math.sqrt(len(unseen_imgs) / len(seen_imgs))
108 |
109 | train_data.filepaths = list(unseen_imgs) + list(seen_imgs)
110 | train_data.labels = list(unseen_labs) + list(seen_labs)
111 | train_data.label_id = True
112 |
113 | def define_loss_function(self, logits, labs, paths):
114 |
115 | loss_ce_seen = self.cross_entropy(logits, labs, paths, False)
116 | loss_ce_unseen = self.cross_entropy(logits, labs, paths, True)
117 |
118 | return self.balance_param * loss_ce_seen + loss_ce_unseen
119 |
120 |
121 | def cross_entropy(self, logits, labels, paths, unlabeled=True):
122 | """This loss computes the probability mass on the
123 | opposite set of classes for each sample.
124 | :param logits: continuous vector
125 | :param labels: class ids
126 | """
127 |
128 | # self.check_unlabeled
129 | if unlabeled:
130 | samples = []
131 | for idx in range(len(paths)):
132 | if paths[idx] in self.check_unlabeled:
133 | samples.append(idx)
134 |
135 | #log.info(f"Unlabeled: {len(samples)} {self.unlabeled_weight}")
136 | if samples:
137 | error = self.loss_func(logits[samples], labels[samples])
138 | else:
139 | error = 0
140 | else:
141 | samples = []
142 | for idx in range(len(paths)):
143 | if paths[idx] not in self.check_unlabeled:
144 | samples.append(idx)
145 |
146 | #log.info(f"Labeled: {len(samples)} {self.labeled_weight}")
147 | if samples:
148 | error = self.loss_func(logits[samples], labels[samples])
149 | else:
150 | error = 0
151 |
152 | return error
153 |
154 | def reindex_predicted_labels(self, idx_preds, only_unlabelled=False):
155 | """This function returns the correct index of predictions to compute
156 | model's accuracy.
157 | :param idx_preds: list of prediction ids
158 | :param only_unlabelled: boolean. It is True if the training only involves
159 | pseudo-labeled unseen data
160 | """
161 |
162 | if only_unlabelled:
163 | return [self.unseen_classes[i.item()] for i in idx_preds]
164 | else:
165 | return [self.classes[i.item()] for i in idx_preds]
166 |
167 | def reindex_true_labels(self, label, only_unlabelled=False):
168 | """This function returns the correct index of true labels.
169 | :param label: list of labels from data loader
170 | :param only_unlabelled: boolean. It is True if the training only involves
171 | pseudo-labeled unseen data
172 | """
173 |
174 | if only_unlabelled:
175 | return torch.tensor(
176 | [self.unseen_classes.index(self.classes[l.item()]) for l in label]
177 | )
178 | else:
179 | return torch.tensor([l for l in label])
180 |
181 | def get_pseudo_labels(self, unlabeled_examples):
182 | log.info(f"Num unlabeled data: {len(unlabeled_examples)}")
183 | # Get prediction on unlabeled data
184 | std_preds = self.test_predictions(
185 | unlabeled_examples, standard_zsl=True
186 | )
187 |
188 | DatasetObject = dataset_object(self.config.DATASET_NAME)
189 | # Take the top-K pseudo-labels (K = N_PSEUDOSHOTS) to fine-tune the student
190 | pseudo_unseen_examples = DatasetObject(
191 | std_preds["id"],
192 | self.data_folder,
193 | transform=self.transform,
194 | augmentations=None,
195 | train=True,
196 | labels=None,
197 | label_map=self.label_to_idx,
198 | class_folder=True,
199 | original_filepaths=unlabeled_examples.filepaths,
200 | )
201 |
202 | pseudo_labels = self.assign_pseudo_labels(
203 | self.config.N_PSEUDOSHOTS, pseudo_unseen_examples
204 | )
205 |
206 | return pseudo_labels
207 |
208 | def assign_pseudo_labels(self, k, unlabeled_data):
209 |
210 | # To find the top k for each class, each class has its own "leaderboard"
211 | top_k_leaderboard = {
212 | self.label_to_idx[self.unseen_classes[i]]: []
213 | for i in range(len(self.unseen_classes))
214 | } # maps class idx -> (confidence, image_path) tuple
215 |
216 | classes = self.unseen_classes
217 | for img_path in unlabeled_data.filepaths:
218 | # log.info(f"IMAGEPATH: {img_path}")
219 | img = Image.open(img_path).convert("RGB")
220 | img = torch.unsqueeze(self.transform(img), 0).to(self.device)
221 | with torch.no_grad():
222 | # Get text and image prompts using UPT
223 | text_features, image_features = self.model(img, classes)
224 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
225 | image_features = image_features / image_features.norm(dim=-1, keepdim=True)
226 |
227 | logit_scale = self.clip_model.logit_scale.exp()
228 | logits = logit_scale * image_features @ text_features.t()
229 | probs = logits.softmax(dim=-1)
230 | idx_preds = torch.argmax(logits, dim=1)
231 | pred_id = idx_preds.item()
232 | pred = self.label_to_idx[self.unseen_classes[idx_preds.item()]]
233 |
234 | """if predicted class has empty leaderboard, or if the confidence is high
235 | enough for predicted class leaderboard, add the new example
236 | """
237 | prob_score = probs[0][pred_id]
238 | if len(top_k_leaderboard[pred]) < k:
239 | top_k_leaderboard[pred].append((prob_score, img_path))
240 | elif (
241 | top_k_leaderboard[pred][-1][0] < prob_score
242 | ): # if the confidence in predicted class "qualifies" for top-k
243 | # default sorting of tuples is by first element
244 | top_k_leaderboard[pred] = sorted(
245 | top_k_leaderboard[pred] + [(probs[0][pred_id], img_path)],
246 | reverse=True,
247 | )[:k]
248 | else:
249 | # sort the other classes by confidence score
250 | order_of_classes = sorted(
251 | [
252 | (probs[0][j], j)
253 | for j in range(len(self.unseen_classes))
254 | if j != pred_id
255 | ],
256 | reverse=True,
257 | )
258 | for score, index in order_of_classes:
259 | index_dict = self.label_to_idx[self.unseen_classes[index]]
260 | # log.info(f"{classnames[index]}")
261 | # log.info(f"{index_dict}")
262 | if len(top_k_leaderboard[index_dict]) < k:
263 | top_k_leaderboard[index_dict].append(
264 | (probs[0][index], img_path)
265 | )
266 | elif top_k_leaderboard[index_dict][-1][0] < probs[0][index]:
267 | # default sorting of tuples is by first element
268 | top_k_leaderboard[index_dict] = sorted(
269 | top_k_leaderboard[index_dict]
270 | + [(probs[0][index], img_path)],
271 | reverse=True,
272 | )[:k]
273 |
274 | new_imgs = []
275 | new_labels = []
276 | # loop through, and rebuild the dataset
277 | for index, leaderboard in top_k_leaderboard.items():
278 | new_imgs += [tup[1] for tup in leaderboard]
279 | new_labels += [index for _ in leaderboard]
280 |
281 | unlabeled_data.filepaths = new_imgs
282 | unlabeled_data.labels = new_labels
283 | unlabeled_data.label_id = True
284 |
285 | return unlabeled_data
286 |
--------------------------------------------------------------------------------
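
The per-class bookkeeping in assign_pseudo_labels above is easier to see in isolation. Below is a self-contained sketch of the top-k "leaderboard" rule (all names are illustrative, not from the repository):

    # Keep the k highest-confidence (score, path) pairs per class, sorted
    # descending so board[-1] is always the current minimum.
    def update_leaderboard(board, k, score, path):
        if len(board) < k:
            board.append((score, path))
            board.sort(reverse=True)
        elif board[-1][0] < score:
            board[:] = sorted(board + [(score, path)], reverse=True)[:k]

    board = []
    for score, path in [(0.90, "a.jpg"), (0.40, "b.jpg"), (0.95, "c.jpg")]:
        update_leaderboard(board, k=2, score=score, path=path)

    assert board == [(0.95, "c.jpg"), (0.90, "a.jpg")]

Note one difference: the repository code appends without re-sorting while a leaderboard is still filling, so its board[-1] is only guaranteed to be the minimum after the first truncating sort.
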
/methods/semi_supervised_learning/pseudo_iterative.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import logging
3 | import math
4 | import pickle
5 |
6 | import clip
7 | import numpy as np
8 | import pandas as pd
9 | import scipy.stats as st
10 | import torch
11 | from accelerate import Accelerator
12 | from PIL import Image
13 | from torch import nn
14 |
15 |
16 | accelerator = Accelerator()
17 |
18 | from data import CustomDataset
19 | from models import CustomImageEncoder, ImagePrefixModel
20 | from methods import TeacherStudent
21 | from utils import (
22 | dataset_object,
23 | evaluate_predictions,
24 | make_scheduler,
25 | pseudolabel_top_k,
26 | seed_worker,
27 | save_parameters,
28 | save_pseudo_labels,
29 | )
30 |
31 | g = torch.Generator()
32 | g.manual_seed(0)
33 |
34 | log = logging.getLogger(__name__)
35 |
36 |
37 | class PseudoIterative(TeacherStudent):
38 | def __init__(
39 | self,
40 | config,
41 | label_to_idx,
42 | data_folder,
43 | classes,
44 | seen_classes,
45 | unseen_classes,
46 | device,
47 | ):
48 | super().__init__(
49 | config, label_to_idx, data_folder, classes, seen_classes, unseen_classes, device
50 | )
51 |
52 |
53 | def train(
54 | self,
55 | train_data,
56 | val_data,
57 | unlabeled_data,
58 | test_data,
59 | test_labeled_files,
60 | test_labeles,
61 | ):
62 | # Number of total iterations to cover all unlabeled data
63 | num_iter = int(100/self.config.STEP_QUANTILE)
64 | num_samples = int(len(unlabeled_data) / num_iter)
65 | # Initialize the number of pseudo-labels per class
66 | n_per_class = int(num_samples / len(self.unseen_classes))
67 | n_unseen = len(self.unseen_classes)
68 | if n_per_class * n_unseen <= len(unlabeled_data.filepaths):
69 | # self.num_pseudo_labels_per_class = n_per_class
70 | self.config.N_PSEUDOSHOTS = n_per_class
71 | else:
72 | # self.num_pseudo_labels_per_class = math.floor(len(unlabeled_data.filepaths)/n_unseen)
73 | self.config.N_PSEUDOSHOTS = math.floor(
74 | len(unlabeled_data.filepaths) / n_unseen
75 | )
76 |
77 | log.info(f"We select {self.config.N_PSEUDOSHOTS} pseudolabel per each unseen classes.")
78 | log.info(f"The number of unseen classes is: {len(self.unseen_classes)}.")
79 | log.info(f"Thus we expect an initial number of pseudo labeles equal to {len(self.unseen_classes) * self.config.N_PSEUDOSHOTS}.")
80 | # Create a safe copy of labeled/unlabeled data
81 | original_train_data = copy.deepcopy(train_data)
82 | # log.info(f"Training data labels: {original_train_data.labels}")
83 | original_unlabeled_data = copy.deepcopy(unlabeled_data)
84 | # Original val
85 | original_val_data = copy.deepcopy(val_data)
86 |
87 | # Initialize here first batch of pseudo labels
88 | #self.create_training_dataset(train_data, unlabeled_data)
89 | #log.info(f"The original train data has size: {len(original_train_data.filepaths)}.")
90 | #log.info(f"Plus: {len(unlabeled_data.filepaths)}.")
91 |
92 | for niter in range(1, num_iter + 1):
93 | log.info(f"NUM PSEUDO SHOTS: {self.config.N_PSEUDOSHOTS}")
94 | pseudolabel_top_k(
95 | self.config, self.config.DATASET_NAME,  # config is passed first, as in the FPL pseudolabel_top_k calls
96 | self.config.N_PSEUDOSHOTS,
97 | self.config.PROMPT_TEMPLATE,
98 | unlabeled_data,
99 | self.unseen_classes,
100 | self.transform,
101 | self.clip_model,
102 | self.label_to_idx,
103 | self.device,
104 | self.config.VIS_ENCODER,
105 | self.config.SPLIT_SEED,
106 | )
107 |
108 | log.info(f"Plus: {len(unlabeled_data.filepaths)}.")
109 | filename = f"pseudolabels/{self.config.DATASET_NAME}_CLIP_{self.config.VIS_ENCODER.replace('/', '')}_iter_{niter}_pseudolabels_spl_{self.config.SPLIT_SEED}.pickle"
110 | with open(filename, "wb") as f:
111 | pickle.dump({"filepaths": unlabeled_data.filepaths, "labels": unlabeled_data.labels}, f)
112 |
113 | # Exploit all the available unlabeled data
114 | if self.config.ALL_UNLABELED:
115 | n_per_class = int((niter + 1) * num_samples / n_unseen)
116 | log.info(f"n_per_class: {n_per_class}")
117 | if n_per_class * n_unseen <= len(original_unlabeled_data.filepaths):
118 | log.info(f"if n_per_class: {n_per_class}")
119 | self.config.N_PSEUDOSHOTS = n_per_class
120 | else:
121 | log.info(f"else new val: {len(original_unlabeled_data.filepaths) / n_unseen}")
122 | # We are making a stong assumption about the distribution of unlabeled data
123 | self.config.N_PSEUDOSHOTS = math.floor(
124 | len(original_unlabeled_data.filepaths) / n_unseen
125 | )
126 |
127 | unlabeled_data = original_unlabeled_data
128 | original_unlabeled_data = copy.deepcopy(unlabeled_data)
--------------------------------------------------------------------------------
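
A worked example of the pseudo-label budget schedule computed in train() above; the dataset sizes and STEP_QUANTILE value are illustrative.

    STEP_QUANTILE = 10
    unlabeled_size = 5400
    n_unseen = 5

    num_iter = int(100 / STEP_QUANTILE)           # 10 iterations
    num_samples = int(unlabeled_size / num_iter)  # 540 samples per iteration
    n_per_class = int(num_samples / n_unseen)     # 108 pseudo-shots per class

    # With ALL_UNLABELED set, the budget at iteration `niter` grows to
    # int((niter + 1) * num_samples / n_unseen), capped by what is actually
    # available: floor(unlabeled_size / n_unseen).
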
/methods/semi_supervised_learning/textual_fpl.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import clip
4 | import numpy as np
5 | import pandas as pd
6 | import torch
7 | from accelerate import Accelerator
8 | from PIL import Image
9 | from torch import nn
10 | from tqdm import tqdm
11 |
12 | accelerator = Accelerator()
13 |
14 | from methods.semi_supervised_learning import TextualPrompt
15 | from utils import (
16 | dataset_object,
17 | make_scheduler,
18 | pseudolabel_top_k,
19 | seed_worker,
20 | )
21 |
22 |
23 | g = torch.Generator()
24 | g.manual_seed(0)
25 |
26 | log = logging.getLogger(__name__)
27 |
28 |
29 | class TextualFPL(TextualPrompt):
30 | def __init__(
31 | self,
32 | config,
33 | label_to_idx,
34 | data_folder,
35 | unlabeled_files,
36 | classes,
37 | seen_classes,
38 | unseen_classes,
39 | device,
40 | ):
41 | """This class define Coop baseline.
42 |
43 | :param config: dictionaries of prameters in models_config/coop_baseline_config.yml
44 | :param label_to_idx: dictionary (key, value):(class name, id)
45 | :param classes: list of class names
46 | :param seen_classes: list of seen classes' names
47 | :param unseen_classes: list of unseen classes' names
48 | :param device: device in use
49 | """
50 | super().__init__(
51 | config, label_to_idx, classes, seen_classes, unseen_classes, device
52 | )
53 |
54 | self.data_folder = data_folder
55 |
56 | self.check_unlabeled = unlabeled_files
57 |
58 | def create_training_dataset(self, train_data, unlabeled_data=None):
59 | """This function create the dataset for training. Specifically, it
60 | merges pseudo-labels for unseen data and labeled data for seen classes.
61 |
62 | :param train_data: Dataset object - training seen classes (defined in zsl_jpl line 323)
63 | :param unlabeled_data: Dataset object - dataset of unlabeled data for
64 | unseen classes (defined in zsl_jpl line 328)
65 | """
66 |
67 | # Get pseudo-labels for unlabeled data from unseen classes
68 | train_unseen_dataset = pseudolabel_top_k(
69 | self.config,
70 | self.config.DATASET_NAME,
71 | self.config.N_PSEUDOSHOTS,
72 | self.config.PROMPT_TEMPLATE,
73 | unlabeled_data,
74 | self.unseen_classes,
75 | self.transform,
76 | self.clip_model,
77 | self.label_to_idx,
78 | self.device,
79 | self.config.VIS_ENCODER,
80 | self.config.SPLIT_SEED,
81 | )
82 |
83 | # Define the lists of training data from seen and unseen classes
84 | unseen_imgs = train_unseen_dataset.filepaths
85 | unseen_labs = train_unseen_dataset.labels
86 |
87 | # Use a portion of the pseudo-labeled data to build a validation set
88 | if self.config.N_PSEUDOSHOTS >= 10:
89 | np.random.seed(self.config.validation_seed)
90 | train_indices = np.random.choice(
91 | range(len(unseen_imgs)),
92 | size=int(len(unseen_imgs) * self.config.ratio_train_val),
93 | replace=False,
94 | )
95 | val_indices = list(
96 | set(range(len(unseen_imgs))).difference(set(train_indices))
97 | )
98 |
99 | self.val_unseen_files = np.array(unseen_imgs)[val_indices]
100 | self.val_unseen_labs = np.array(unseen_labs)[val_indices]
101 |
102 | unseen_imgs = list(np.array(unseen_imgs)[train_indices])
103 | unseen_labs = list(np.array(unseen_labs)[train_indices])
104 |
105 | else:
106 | self.val_unseen_files = None
107 | self.val_unseen_labs = None
108 |
109 | seen_imgs = train_data.filepaths
110 | seen_labs = [self.label_to_idx[l] for l in train_data.labels]
111 |
112 | # self.labeled_weight = 1 / len(seen_imgs)
113 | # self.unlabeled_weight = 1 / len(unseen_imgs)
114 | # log.info(f"UNLABELD: {len(unseen_imgs)} and LABELED {len(seen_imgs)}")
115 | self.balance_param = len(unseen_imgs) / len(seen_imgs)
116 |
117 | train_data.filepaths = list(unseen_imgs) + list(seen_imgs)
118 | train_data.labels = list(unseen_labs) + list(seen_labs)
119 | train_data.label_id = True
120 |
121 | return train_data
122 |
123 | def define_loss_function(self, logits, labs, paths):
124 |
125 | loss_ce_seen = self.cross_entropy(logits, labs, paths, False)
126 | loss_ce_unseen = self.cross_entropy(logits, labs, paths, True)
127 |
128 | return self.balance_param * loss_ce_seen + loss_ce_unseen
129 |
130 | def cross_entropy(self, logits, labels, paths, unlabeled=True):
131 | """This loss computes the probability mass on the
132 | opposite set of classes for each sample.
133 |
134 | :param logits: continuous vector
135 | :param labels: class ids
136 | """
137 |
138 | # log.info(f"IMAGE PATHS: {paths}")
139 | # log.info(f"IMAGE CHECK: {self.check_unlabeled[:10]}")
140 |
141 | # self.check_unlabeled
142 | if unlabeled:
143 | samples = []
144 | for idx in range(len(paths)):
145 | if paths[idx] in self.check_unlabeled:
146 | samples.append(idx)
147 |
148 | # log.info(f"Unlabeled: {len(samples)} {self.balance_param}")
149 | if samples:
150 | error = self.loss_func(logits[samples], labels[samples])
151 | else:
152 | error = 0
153 | else:
154 | samples = []
155 | for idx in range(len(paths)):
156 | if paths[idx] not in self.check_unlabeled:
157 | samples.append(idx)
158 |
159 | # log.info(f"Labeled: {len(samples)} {self.balance_param}")
160 | if samples:
161 | error = self.loss_func(logits[samples], labels[samples])
162 | else:
163 | error = 0
164 |
165 | return error
166 |
167 |
168 | def get_pseudo_labels(self, unlabeled_examples):
169 | log.info(f"Num unlabeled data: {len(unlabeled_examples)}")
170 | # Get prediction on unlabeled data
171 | std_preds = self.test_predictions(
172 | unlabeled_examples, standard_zsl=True
173 | )
174 |
175 | DatasetObject = dataset_object(self.config.DATASET_NAME)
176 | # Take the top-K pseudo-labels (K = N_PSEUDOSHOTS) to fine-tune the student
177 | pseudo_unseen_examples = DatasetObject(
178 | std_preds["id"],
179 | self.data_folder,
180 | transform=self.transform,
181 | augmentations=None,
182 | train=True,
183 | labels=None,
184 | label_map=self.label_to_idx,
185 | class_folder=True,
186 | original_filepaths=unlabeled_examples.filepaths,
187 | )
188 |
189 | pseudo_labels = self.assign_pseudo_labels(
190 | self.config.N_PSEUDOSHOTS, pseudo_unseen_examples
191 | )
192 |
193 | return pseudo_labels
194 |
195 | def assign_pseudo_labels(self, k, unlabeled_data):
196 | # Define text queries
197 | # prompts = [f"{self.template}{' '.join(i.split('_'))}" \
198 | # for i in self.unseen_classes]
199 |
200 | log.info(f"[self.assign_pseudo_labels] Number of prompts: {len(self.unseen_classes)}")
201 |
202 | # Get prompts
203 | self.model.classes = self.unseen_classes
204 | text_features = self.model(self.model.classes)
205 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
206 | log.info(f"TEXT FEATURES SHAPE: {text_features.size()}")
207 |
208 | # To find the top k for each class, each class has its own "leaderboard"
209 | top_k_leaderboard = {
210 | self.label_to_idx[self.unseen_classes[i]]: []
211 | for i in range(len(self.unseen_classes))
212 | } # maps class idx -> (confidence, image_path) tuple
213 |
214 | for img_path in unlabeled_data.filepaths:
215 | # log.info(f"IMAGEPATH: {img_path}")
216 | img = Image.open(img_path).convert("RGB")
217 | img = torch.unsqueeze(self.transform(img), 0).to(self.device)
218 | with torch.no_grad():
219 | image_features = self.clip_model.encode_image(img)
220 | image_features = image_features / image_features.norm(
221 | dim=-1, keepdim=True
222 | )
223 | # cosine similarity as logits
224 |
225 | logit_scale = self.clip_model.logit_scale.exp()
226 | logits = logit_scale * image_features @ text_features.t()
227 | probs = logits.softmax(dim=-1)
228 | idx_preds = torch.argmax(logits, dim=1)
229 | pred_id = idx_preds.item()
230 | pred = self.label_to_idx[self.unseen_classes[idx_preds.item()]]
231 |
232 | """if predicted class has empty leaderboard, or if the confidence is high
233 | enough for predicted class leaderboard, add the new example
234 | """
235 | prob_score = probs[0][pred_id]
236 | if len(top_k_leaderboard[pred]) < k:
237 | top_k_leaderboard[pred].append((prob_score, img_path))
238 | elif (
239 | top_k_leaderboard[pred][-1][0] < prob_score
240 | ): # if the confidence in predicted class "qualifies" for top-k
241 | # default sorting of tuples is by first element
242 | top_k_leaderboard[pred] = sorted(
243 | top_k_leaderboard[pred] + [(probs[0][pred_id], img_path)],
244 | reverse=True,
245 | )[:k]
246 | else:
247 | # sort the other classes by confidence score
248 | order_of_classes = sorted(
249 | [
250 | (probs[0][j], j)
251 | for j in range(len(self.unseen_classes))
252 | if j != pred_id
253 | ],
254 | reverse=True,
255 | )
256 | for score, index in order_of_classes:
257 | index_dict = self.label_to_idx[self.unseen_classes[index]]
258 | # log.info(f"{classnames[index]}")
259 | # log.info(f"{index_dict}")
260 | if len(top_k_leaderboard[index_dict]) < k:
261 | top_k_leaderboard[index_dict].append(
262 | (probs[0][index], img_path)
263 | )
264 | elif top_k_leaderboard[index_dict][-1][0] < probs[0][index]:
265 | # default sorting of tuples is by first element
266 | top_k_leaderboard[index_dict] = sorted(
267 | top_k_leaderboard[index_dict]
250 | + [(probs[0][index], img_path)],
269 | reverse=True,
270 | )[:k]
271 |
272 | new_imgs = []
273 | new_labels = []
274 | # loop through, and rebuild the dataset
275 | for index, leaderboard in top_k_leaderboard.items():
276 | new_imgs += [tup[1] for tup in leaderboard]
277 | new_labels += [index for _ in leaderboard]
278 |
279 | unlabeled_data.filepaths = new_imgs
280 | unlabeled_data.labels = new_labels
281 | unlabeled_data.label_id = True
282 |
283 | return unlabeled_data
284 |
--------------------------------------------------------------------------------
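
As a reading aid, the loss balancing used by define_loss_function above, factored into a plain function (names are illustrative). TextualFPL weights the seen-class term by the raw ratio of pseudo-labeled to labeled examples; MultimodalFPL uses the square root of that ratio.

    def balanced_loss(loss_seen, loss_unseen, n_unseen_imgs, n_seen_imgs):
        # Re-weight the labeled term so a large pseudo-labeled pool does not
        # drown out the (typically much smaller) labeled pool.
        balance = n_unseen_imgs / n_seen_imgs
        return balance * loss_seen + loss_unseen
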
/methods/transductive_zsl/__init__.py:
--------------------------------------------------------------------------------
1 | from .training_strategies import TrainingStrategy
2 | from .visual_prompt import VisualPrompt
3 | from .textual_prompt import TextualPrompt
4 | from .multimodal_prompt import MultimodalPrompt
5 | from .visual_fpl import VisualFPL
6 | from .textual_fpl import TextualFPL
7 | from .multimodal_fpl import MultimodalFPL
8 |
--------------------------------------------------------------------------------
/methods/transductive_zsl/multimodal_fpl.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import clip
4 | import numpy as np
5 | import pandas as pd
6 | import torch
7 | from accelerate import Accelerator
8 | from PIL import Image
9 | from torch import nn
10 | import math
11 |
12 | accelerator = Accelerator()
13 |
14 | from methods.transductive_zsl import MultimodalPrompt
15 | from utils import (
16 | dataset_object,
17 | make_scheduler,
18 | pseudolabel_top_k,
19 | )
20 |
21 |
22 | log = logging.getLogger(__name__)
23 |
24 |
25 | class MultimodalFPL(MultimodalPrompt):
26 | def __init__(
27 | self,
28 | config,
29 | label_to_idx,
30 | data_folder,
31 | classes,
32 | seen_classes,
33 | unseen_classes,
34 | device
35 | ):
36 | """This class defines self-trainig UPT's training and evaluation.
37 | :param config: dictionaries of prameters in models_config/upt_baseline_pseudo_config.yml
38 | :param label_to_idx: dictionary (key, value):(class name, id)
39 | :param classes: list of class names
40 | :param seen_classes: list of seen classes' names
41 | :param unseen_classes: list of unseen classes' names
42 | :param device: device in use
43 | """
44 |
45 | super().__init__(
46 | config, label_to_idx, classes, seen_classes, unseen_classes, device
47 | )
48 |
49 | self.data_folder = data_folder
50 |
51 | def create_training_dataset(self, train_data, unlabeled_data=None):
52 | """This function creates the dataset for training. Specifically, it
53 | merges pseudo-labels for unseen data and labeled data for seen classes.
54 | :param train_data: Dataset object - labeled training data for seen classes
55 | :param unlabeled_data: Dataset object - dataset of unlabeled data for
56 | unseen classes
57 | """
58 |
59 | # Get pseudo-labels for unlabeled data from unseen classes
60 | train_unseen_dataset = pseudolabel_top_k(
61 | self.config,
62 | self.config.DATASET_NAME,
63 | self.config.N_PSEUDOSHOTS,
64 | self.config.PROMPT_TEMPLATE,
65 | unlabeled_data,
66 | self.unseen_classes,
67 | self.transform,
68 | self.clip_model,
69 | self.label_to_idx,
70 | self.device,
71 | self.config.VIS_ENCODER,
72 | self.config.SPLIT_SEED
73 | )
74 |
75 | # Define the lists of training data from seen and unseen classes
76 | unseen_imgs = train_unseen_dataset.filepaths
77 | unseen_labs = train_unseen_dataset.labels
78 |
79 | # Use a portion of the pseudo-labeled data to build a validation set
80 | if self.config.N_PSEUDOSHOTS >= 10:
81 | np.random.seed(self.config.validation_seed)
82 | train_indices = np.random.choice(
83 | range(len(unseen_imgs)),
84 | size=int(len(unseen_imgs) * self.config.ratio_train_val),
85 | replace=False,
86 | )
87 | val_indices = list(
88 | set(range(len(unseen_imgs))).difference(set(train_indices))
89 | )
90 |
91 | self.val_unseen_files = np.array(unseen_imgs)[val_indices]
92 | self.val_unseen_labs = np.array(unseen_labs)[val_indices]
93 |
94 | unseen_imgs = list(np.array(unseen_imgs)[train_indices])
95 | unseen_labs = list(np.array(unseen_labs)[train_indices])
96 |
97 | else:
98 | self.val_unseen_files = None
99 | self.val_unseen_labs = None
100 |
101 | seen_imgs = train_data.filepaths
102 | seen_labs = [self.label_to_idx[l] for l in train_data.labels]
103 |
104 | self.balance_param = math.sqrt(len(seen_imgs) / len(unseen_imgs))
105 |
106 | train_data.filepaths = list(unseen_imgs) + list(seen_imgs)
107 | train_data.labels = list(unseen_labs) + list(seen_labs)
108 | train_data.label_id = True
109 |
110 | def define_loss_function(self, logits, labs, teacher=False):
111 |
112 | loss_ce_seen = self.cross_entropy(logits, labs, self.seen_classes)
113 | loss_ce_unseen = self.cross_entropy(logits, labs, self.unseen_classes)
114 |
115 | return loss_ce_seen + self.balance_param * loss_ce_unseen
116 |
117 | def cross_entropy(self, logits, labels, classes):
118 | """This loss computes the probability mass on the
119 | opposite set of classes for each sample.
120 | :param logits: continuous vector
121 | :param labels: class ids
122 | """
123 |
124 | ids = [self.label_to_idx[c] for c in classes]
125 |
126 | # Get indices of the batch samples whose label is in the given classes
127 | samples = []
128 |
129 | for idx, l in enumerate(labels):
130 | if l in ids:
131 | samples.append(idx)
132 |
133 | if samples:
134 | error = self.loss_func(logits[samples], labels[samples])
135 | else:
136 | error = 0
137 |
138 | return error
139 |
140 | def reindex_predicted_labels(self, idx_preds, only_unlabelled=False):
141 | """This function returns the correct index of predictions to compute
142 | model's accuracy.
143 | :param idx_preds: list of prediction ids
144 | :param only_unlabelled: boolean. It is True if the training only involves
145 | pseudo-labeled unseen data
146 | """
147 |
148 | if only_unlabelled:
149 | return [self.unseen_classes[i.item()] for i in idx_preds]
150 | else:
151 | return [self.classes[i.item()] for i in idx_preds]
152 |
153 | def reindex_true_labels(self, label, only_unlabelled=False):
154 | """This function returns the correct index of true labels.
155 | :param label: list of labels from data loader
156 | :param only_unlabelled: boolean. It is True if the training only involves
157 | pseudo-labeled unseen data
158 | """
159 |
160 | if only_unlabelled:
161 | return torch.tensor(
162 | [self.unseen_classes.index(self.classes[l.item()]) for l in label]
163 | )
164 | else:
165 | return torch.tensor([l for l in label])
166 |
167 | def get_pseudo_labels(self, unlabeled_examples):
168 | log.info(f"Num unlabeled data: {len(unlabeled_examples)}")
169 | # Get prediction on unlabeled data
170 | std_preds = self.test_predictions(
171 | unlabeled_examples, standard_zsl=True
172 | )
173 |
174 | DatasetObject = dataset_object(self.config.DATASET_NAME)
175 | # Take the top-K pseudo-labels (K = N_PSEUDOSHOTS) to fine-tune the student
176 | pseudo_unseen_examples = DatasetObject(
177 | std_preds["id"],
178 | self.data_folder,
179 | transform=self.transform,
180 | augmentations=None,
181 | train=True,
182 | labels=None,
183 | label_map=self.label_to_idx,
184 | class_folder=True,
185 | original_filepaths=unlabeled_examples.filepaths,
186 | )
187 |
188 | pseudo_labels = self.assign_pseudo_labels(
189 | self.config.N_PSEUDOSHOTS, pseudo_unseen_examples
190 | )
191 |
192 | return pseudo_labels
193 |
194 | def assign_pseudo_labels(self, k, unlabeled_data):
195 |
196 | # To find the top k for each class, each class has its own "leaderboard"
197 | top_k_leaderboard = {
198 | self.label_to_idx[self.unseen_classes[i]]: []
199 | for i in range(len(self.unseen_classes))
200 | } # maps class idx -> (confidence, image_path) tuple
201 |
202 | classes = self.unseen_classes
203 | for img_path in unlabeled_data.filepaths:
204 | # log.info(f"IMAGEPATH: {img_path}")
205 | img = Image.open(img_path).convert("RGB")
206 | img = torch.unsqueeze(self.transform(img), 0).to(self.device)
207 | with torch.no_grad():
208 | # Get text and image prompts using UPT
209 | # coop_embeddings, vpt_embeddings, vpt_deep_embeddings = self.model(0)
210 | # Calculate text prompts
211 | # text_features = self.text_encoder(coop_embeddings, classes)
212 | text_features, image_features = self.model(img, classes)
213 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
214 | # Calculate image prompts
215 | # image_features = self.image_encoder(img, vpt_embeddings, deep_embds=vpt_deep_embeddings)
216 | image_features = image_features / image_features.norm(dim=-1, keepdim=True)
217 |
218 | logit_scale = self.clip_model.logit_scale.exp()
219 | logits = logit_scale * image_features @ text_features.t()
220 | probs = logits.softmax(dim=-1)
221 | idx_preds = torch.argmax(logits, dim=1)
222 | pred_id = idx_preds.item()
223 | pred = self.label_to_idx[self.unseen_classes[idx_preds.item()]]
224 |
225 | """if predicted class has empty leaderboard, or if the confidence is high
226 | enough for predicted class leaderboard, add the new example
227 | """
228 | prob_score = probs[0][pred_id]
229 | if len(top_k_leaderboard[pred]) < k:
230 | top_k_leaderboard[pred].append((prob_score, img_path))
231 | elif (
232 | top_k_leaderboard[pred][-1][0] < prob_score
233 | ): # if the confidence in predicted class "qualifies" for top-k
234 | # default sorting of tuples is by first element
235 | top_k_leaderboard[pred] = sorted(
236 | top_k_leaderboard[pred] + [(probs[0][pred_id], img_path)],
237 | reverse=True,
238 | )[:k]
239 | else:
240 | # sort the other classes by confidence score
241 | order_of_classes = sorted(
242 | [
243 | (probs[0][j], j)
244 | for j in range(len(self.unseen_classes))
245 | if j != pred_id
246 | ],
247 | reverse=True,
248 | )
249 | for score, index in order_of_classes:
250 | index_dict = self.label_to_idx[self.unseen_classes[index]]
251 | # log.info(f"{classnames[index]}")
252 | # log.info(f"{index_dict}")
253 | if len(top_k_leaderboard[index_dict]) < k:
254 | top_k_leaderboard[index_dict].append(
255 | (probs[0][index], img_path)
256 | )
257 | elif top_k_leaderboard[index_dict][-1][0] < probs[0][index]:
258 | # default sorting of tuples is by first element
259 | top_k_leaderboard[index_dict] = sorted(
260 | top_k_leaderboard[index_dict]
261 | + [(probs[0][index], img_path)],
262 | reverse=True,
263 | )[:k]
264 |
265 | new_imgs = []
266 | new_labels = []
267 | # loop through, and rebuild the dataset
268 | for index, leaderboard in top_k_leaderboard.items():
269 | new_imgs += [tup[1] for tup in leaderboard]
270 | new_labels += [index for _ in leaderboard]
271 |
272 | unlabeled_data.filepaths = new_imgs
273 | unlabeled_data.labels = new_labels
274 | unlabeled_data.label_id = True
275 |
276 | return unlabeled_data
277 |
--------------------------------------------------------------------------------
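
A sketch of the class-subset cross-entropy that the transductive ZSL classes implement above: score only the batch elements whose target falls in a given class set. All names are illustrative.

    import torch
    import torch.nn.functional as F

    def subset_cross_entropy(logits, labels, class_ids):
        keep = [i for i, l in enumerate(labels) if l.item() in class_ids]
        if not keep:
            return torch.tensor(0.0)
        return F.cross_entropy(logits[keep], labels[keep])
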
/methods/transductive_zsl/pseudo_iterative.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import logging
3 | import math
4 | import pickle
5 |
6 | import clip
7 | import numpy as np
8 | import pandas as pd
9 | import scipy.stats as st
10 | import torch
11 | from accelerate import Accelerator
12 | from PIL import Image
13 | from torch import nn
14 |
15 |
16 | accelerator = Accelerator()
17 |
18 | from data import CustomDataset
19 | from models import CustomImageEncoder, ImagePrefixModel
20 | from methods import TeacherStudent
21 | from utils import (
22 | dataset_object,
23 | evaluate_predictions,
24 | make_scheduler,
25 | pseudolabel_top_k,
26 | seed_worker,
27 | save_parameters,
28 | save_pseudo_labels,
29 | )
30 |
31 | g = torch.Generator()
32 | g.manual_seed(0)
33 |
34 | log = logging.getLogger(__name__)
35 |
36 |
37 | class PseudoIterative(TeacherStudent):
38 | def __init__(
39 | self,
40 | config,
41 | label_to_idx,
42 | data_folder,
43 | classes,
44 | seen_classes,
45 | unseen_classes,
46 | device,
47 | ):
48 | super().__init__(
49 | config, label_to_idx, data_folder, classes, seen_classes, unseen_classes, device
50 | )
51 |
52 |
53 | def train(
54 | self,
55 | train_data,
56 | val_data,
57 | unlabeled_data,
58 | test_data,
59 | test_labeled_files,
60 | test_labeles,
61 | ):
62 | # Number of total iterations to cover all unlabeled data
63 | num_iter = int(100/self.config.STEP_QUANTILE)
64 | num_samples = int(len(unlabeled_data) / num_iter)
65 | # Initialize the number of pseudo-labels per class
66 | n_per_class = int(num_samples / len(self.unseen_classes))
67 | n_unseen = len(self.unseen_classes)
68 | if n_per_class * n_unseen <= len(unlabeled_data.filepaths):
69 | # self.num_pseudo_labels_per_class = n_per_class
70 | self.config.N_PSEUDOSHOTS = n_per_class
71 | else:
72 | # self.num_pseudo_labels_per_class = math.floor(len(unlabeled_data.filepaths)/n_unseen)
73 | self.config.N_PSEUDOSHOTS = math.floor(
74 | len(unlabeled_data.filepaths) / n_unseen
75 | )
76 |
77 | log.info(f"We select {self.config.N_PSEUDOSHOTS} pseudolabel per each unseen classes.")
78 | log.info(f"The number of unseen classes is: {len(self.unseen_classes)}.")
79 | log.info(f"Thus we expect an initial number of pseudo labeles equal to {len(self.unseen_classes) * self.config.N_PSEUDOSHOTS}.")
80 | # Create a safe copy of labeled/unlabeled data
81 | original_train_data = copy.deepcopy(train_data)
82 | # log.info(f"Training data labels: {original_train_data.labels}")
83 | original_unlabeled_data = copy.deepcopy(unlabeled_data)
84 | # Original val
85 | original_val_data = copy.deepcopy(val_data)
86 |
87 | # Initialize here first batch of pseudo labels
88 | #self.create_training_dataset(train_data, unlabeled_data)
89 | #log.info(f"The original train data has size: {len(original_train_data.filepaths)}.")
90 | #log.info(f"Plus: {len(unlabeled_data.filepaths)}.")
91 |
92 | for niter in range(1, num_iter + 1):
93 | log.info(f"NUM PSEUDO SHOTS: {self.config.N_PSEUDOSHOTS}")
94 | pseudolabel_top_k(
95 | self.config, self.config.DATASET_NAME,  # config is passed first, as in the FPL pseudolabel_top_k calls
96 | self.config.N_PSEUDOSHOTS,
97 | self.config.PROMPT_TEMPLATE,
98 | unlabeled_data,
99 | self.unseen_classes,
100 | self.transform,
101 | self.clip_model,
102 | self.label_to_idx,
103 | self.device,
104 | self.config.VIS_ENCODER,
105 | self.config.SPLIT_SEED,
106 | )
107 |
108 | log.info(f"Plus: {len(unlabeled_data.filepaths)}.")
109 | filename = f"pseudolabels/{self.config.DATASET_NAME}_CLIP_{self.config.VIS_ENCODER.replace('/', '')}_iter_{niter}_pseudolabels_spl_{self.config.SPLIT_SEED}.pickle"
110 | with open(filename, "wb") as f:
111 | pickle.dump({"filepaths": unlabeled_data.filepaths, "labels": unlabeled_data.labels}, f)
112 |
113 | # Exploit all the available unlabeled data
114 | if self.config.ALL_UNLABELED:
115 | n_per_class = int((niter + 1) * num_samples / n_unseen)
116 | log.info(f"n_per_class: {n_per_class}")
117 | if n_per_class * n_unseen <= len(original_unlabeled_data.filepaths):
118 | log.info(f"if n_per_class: {n_per_class}")
119 | self.config.N_PSEUDOSHOTS = n_per_class
120 | else:
121 | log.info(f"else new val: {len(original_unlabeled_data.filepaths) / n_unseen}")
122 | # We are making a stong assumption about the distribution of unlabeled data
123 | self.config.N_PSEUDOSHOTS = math.floor(
124 | len(original_unlabeled_data.filepaths) / n_unseen
125 | )
126 |
127 | unlabeled_data = original_unlabeled_data
128 | original_unlabeled_data = copy.deepcopy(unlabeled_data)
--------------------------------------------------------------------------------
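
Reading back one of the pseudo-label snapshots that train() pickles above; the concrete filename is illustrative, but the pattern follows the f-string in the code.

    import pickle

    path = "pseudolabels/EuroSAT_CLIP_ViT-B32_iter_1_pseudolabels_spl_500.pickle"
    with open(path, "rb") as f:
        snapshot = pickle.load(f)

    print(len(snapshot["filepaths"]), snapshot["labels"][:5])
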
/methods/transductive_zsl/textual_fpl.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import clip
4 | import numpy as np
5 | import pandas as pd
6 | import torch
7 | from accelerate import Accelerator
8 | from PIL import Image
9 | from torch import nn
10 | from tqdm import tqdm
11 |
12 | accelerator = Accelerator()
13 |
14 | from methods.transductive_zsl import TextualPrompt
15 | from utils import (
16 | dataset_object,
17 | make_scheduler,
18 | pseudolabel_top_k,
19 | seed_worker,
20 | )
21 |
22 |
23 | g = torch.Generator()
24 | g.manual_seed(0)
25 |
26 | log = logging.getLogger(__name__)
27 |
28 |
29 | class TextualFPL(TextualPrompt):
30 | def __init__(
31 | self,
32 | config,
33 | label_to_idx,
34 | data_folder,
35 | classes,
36 | seen_classes,
37 | unseen_classes,
38 | device,
39 | ):
40 | """This class define Coop baseline.
41 |
42 | :param config: dictionaries of prameters in models_config/coop_baseline_config.yml
43 | :param label_to_idx: dictionary (key, value):(class name, id)
44 | :param classes: list of class names
45 | :param seen_classes: list of seen classes' names
46 | :param unseen_classes: list of unseen classes' names
47 | :param device: device in use
48 | """
49 | super().__init__(
50 | config, label_to_idx, classes, seen_classes, unseen_classes, device
51 | )
52 |
53 | self.data_folder = data_folder
54 |
55 | def create_training_dataset(self, train_data, unlabeled_data=None):
56 | """This function create the dataset for training. Specifically, it
57 | merges pseudo-labels for unseen data and labeled data for seen classes.
58 |
59 | :param train_data: Dataset object - training seen classes (defined in zsl_jpl line 323)
60 | :param unlabeled_data: Dataset object - dataset of unlabeled data for
61 | unseen classes (defined in zsl_jpl line 328)
62 | """
63 |
64 | # Get pseudo-labels for unlabeled data from unseen classes
65 | train_unseen_dataset = pseudolabel_top_k(
66 | self.config,
67 | self.config.DATASET_NAME,
68 | self.config.N_PSEUDOSHOTS,
69 | self.config.PROMPT_TEMPLATE,
70 | unlabeled_data,
71 | self.unseen_classes,
72 | self.transform,
73 | self.clip_model,
74 | self.label_to_idx,
75 | self.device,
76 | self.config.VIS_ENCODER,
77 | self.config.SPLIT_SEED,
78 | )
79 |
80 | # Define the lists of training data from seen and unseen classes
81 | unseen_imgs = train_unseen_dataset.filepaths
82 | unseen_labs = train_unseen_dataset.labels
83 |
84 | # Use a portion of the pseudo-labeled data to build a validation set
85 | if self.config.N_PSEUDOSHOTS >= 10:
86 | np.random.seed(self.config.validation_seed)
87 | train_indices = np.random.choice(
88 | range(len(unseen_imgs)),
89 | size=int(len(unseen_imgs) * self.config.ratio_train_val),
90 | replace=False,
91 | )
92 | val_indices = list(
93 | set(range(len(unseen_imgs))).difference(set(train_indices))
94 | )
95 |
96 | self.val_unseen_files = np.array(unseen_imgs)[val_indices]
97 | self.val_unseen_labs = np.array(unseen_labs)[val_indices]
98 |
99 | unseen_imgs = list(np.array(unseen_imgs)[train_indices])
100 | unseen_labs = list(np.array(unseen_labs)[train_indices])
101 |
102 | else:
103 | self.val_unseen_files = None
104 | self.val_unseen_labs = None
105 |
106 | seen_imgs = train_data.filepaths
107 | seen_labs = [self.label_to_idx[l] for l in train_data.labels]
108 |
109 | self.balance_param = len(seen_imgs) / len(unseen_imgs)
110 |
111 | train_data.filepaths = list(unseen_imgs) + list(seen_imgs)
112 | train_data.labels = list(unseen_labs) + list(seen_labs)
113 | train_data.label_id = True
114 |
115 | return train_data
116 |
117 | def define_loss_function(self, logits, labs):
118 |
119 | loss_ce_seen = self.cross_entropy(logits, labs, self.seen_classes)
120 | loss_ce_unseen = self.cross_entropy(logits, labs, self.unseen_classes)
121 |
122 | return loss_ce_seen + self.balance_param * loss_ce_unseen
123 |
124 | def cross_entropy(self, logits, labels, classes):
125 | """This loss computes the probability mass on the
126 | opposite set of classes for each sample.
127 |
128 | :param logits: continuous vector
129 | :param labels: class ids
130 | """
131 |
132 | ids = [self.label_to_idx[c] for c in classes]
133 |
134 | # Get indices of the batch samples whose label is in the given classes
135 | samples = []
136 |
137 | for idx, l in enumerate(labels):
138 | if l in ids:
139 | samples.append(idx)
140 |
141 | # Compute the loss on the selected samples
142 | if samples:
143 | error = self.loss_func(logits[samples], labels[samples])
144 | else:
145 | error = 0
146 |
147 | return error
148 |
149 |
150 | def get_pseudo_labels(self, unlabeled_examples):
151 | log.info(f"Num unlabeled data: {len(unlabeled_examples)}")
152 | # Get prediction on unlabeled data
153 | std_preds = self.test_predictions(
154 | unlabeled_examples, standard_zsl=True
155 | )
156 |
157 | DatasetObject = dataset_object(self.config.DATASET_NAME)
158 | # Take the top-K pseudo-labels (K = N_PSEUDOSHOTS) to fine-tune the student
159 | pseudo_unseen_examples = DatasetObject(
160 | std_preds["id"],
161 | self.data_folder,
162 | transform=self.transform,
163 | augmentations=None,
164 | train=True,
165 | labels=None,
166 | label_map=self.label_to_idx,
167 | class_folder=True,
168 | original_filepaths=unlabeled_examples.filepaths,
169 | )
170 |
171 | pseudo_labels = self.assign_pseudo_labels(
172 | self.config.N_PSEUDOSHOTS, pseudo_unseen_examples
173 | )
174 |
175 | return pseudo_labels
176 |
177 | def assign_pseudo_labels(self, k, unlabeled_data):
178 | # Define text queries
179 | # prompts = [f"{self.template}{' '.join(i.split('_'))}" \
180 | # for i in self.unseen_classes]
181 |
182 | log.info(f"[self.assign_pseudo_labels] Number of prompts: {len(self.unseen_classes)}")
183 |
184 | # Get prompts
185 | self.model.classes = self.unseen_classes
186 | text_features = self.model(self.model.classes)
187 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
188 | log.info(f"TEXT FEATURES SHAPE: {text_features.size()}")
189 |
190 | # To find the top k for each class, each class has its own "leaderboard"
191 | top_k_leaderboard = {
192 | self.label_to_idx[self.unseen_classes[i]]: []
193 | for i in range(len(self.unseen_classes))
194 | } # maps class idx -> (confidence, image_path) tuple
195 |
196 | for img_path in unlabeled_data.filepaths:
197 | # log.info(f"IMAGEPATH: {img_path}")
198 | img = Image.open(img_path).convert("RGB")
199 | img = torch.unsqueeze(self.transform(img), 0).to(self.device)
200 | with torch.no_grad():
201 | image_features = self.clip_model.encode_image(img)
202 | image_features = image_features / image_features.norm(
203 | dim=-1, keepdim=True
204 | )
205 | # cosine similarity as logits
206 |
207 | logit_scale = self.clip_model.logit_scale.exp()
208 | logits = logit_scale * image_features @ text_features.t()
209 | probs = logits.softmax(dim=-1)
210 | idx_preds = torch.argmax(logits, dim=1)
211 | pred_id = idx_preds.item()
212 | pred = self.label_to_idx[self.unseen_classes[idx_preds.item()]]
213 |
214 | """if predicted class has empty leaderboard, or if the confidence is high
215 | enough for predicted class leaderboard, add the new example
216 | """
217 | prob_score = probs[0][pred_id]
218 | if len(top_k_leaderboard[pred]) < k:
219 | top_k_leaderboard[pred].append((prob_score, img_path))
220 | elif (
221 | top_k_leaderboard[pred][-1][0] < prob_score
222 | ): # if the confidence in predicted class "qualifies" for top-k
223 | # default sorting of tuples is by first element
224 | top_k_leaderboard[pred] = sorted(
225 | top_k_leaderboard[pred] + [(probs[0][pred_id], img_path)],
226 | reverse=True,
227 | )[:k]
228 | else:
229 | # sort the other classes by confidence score
230 | order_of_classes = sorted(
231 | [
232 | (probs[0][j], j)
233 | for j in range(len(self.unseen_classes))
234 | if j != pred_id
235 | ],
236 | reverse=True,
237 | )
238 | for score, index in order_of_classes:
239 | index_dict = self.label_to_idx[self.unseen_classes[index]]
240 | # log.info(f"{classnames[index]}")
241 | # log.info(f"{index_dict}")
242 | if len(top_k_leaderboard[index_dict]) < k:
243 | top_k_leaderboard[index_dict].append(
244 | (probs[0][index], img_path)
245 | )
246 | elif top_k_leaderboard[index_dict][-1][0] < probs[0][index]:
247 | # default sorting of tuples is by first element
248 | top_k_leaderboard[index_dict] = sorted(
249 | top_k_leaderboard[index_dict]
250 | + [(probs[0][index], img_path)],
251 | reverse=True,
252 | )[:k]
253 |
254 | new_imgs = []
255 | new_labels = []
256 | # loop through, and rebuild the dataset
257 | for index, leaderboard in top_k_leaderboard.items():
258 | new_imgs += [tup[1] for tup in leaderboard]
259 | new_labels += [index for _ in leaderboard]
260 |
261 | unlabeled_data.filepaths = new_imgs
262 | unlabeled_data.labels = new_labels
263 | unlabeled_data.label_id = True
264 |
265 | return unlabeled_data
266 |
--------------------------------------------------------------------------------
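
The CLIP scoring step that assign_pseudo_labels repeats per image, factored out for clarity. `clip_model`, `transform`, and `device` stand in for the attributes used by the class above; `text_features` is assumed to be already L2-normalized, as in the caller.

    import torch
    from PIL import Image

    @torch.no_grad()
    def clip_probs(clip_model, transform, device, img_path, text_features):
        img = transform(Image.open(img_path).convert("RGB")).unsqueeze(0).to(device)
        image_features = clip_model.encode_image(img)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        logits = clip_model.logit_scale.exp() * image_features @ text_features.t()
        return logits.softmax(dim=-1)  # shape: (1, num_classes)
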
/methods/transductive_zsl/visual_fpl.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import clip
4 | import numpy as np
5 | import pandas as pd
6 | import torch
7 | from accelerate import Accelerator
8 | from PIL import Image
9 | from torch import nn
10 |
11 | accelerator = Accelerator()
12 |
13 | from methods.transductive_zsl import VisualPrompt
14 | from utils import (
15 | dataset_object,
16 | make_scheduler,
17 | pseudolabel_top_k,
18 | )
19 |
20 |
21 | log = logging.getLogger(__name__)
22 |
23 |
24 | class VisualFPL(VisualPrompt):
25 | def __init__(
26 | self,
27 | config,
28 | label_to_idx,
29 | data_folder,
30 | classes,
31 | seen_classes,
32 | unseen_classes,
33 | device
34 | ):
35 | """This class defines few-pseudo-learning (FPL) VPT's training and evaluation.
36 |
37 | :param config: dictionary of parameters in methods_config/visual_fpl_config.yml
38 | :param label_to_idx: dictionary (key, value):(class name, id)
39 | :param classes: list of class names
40 | :param seen_classes: list of seen classes' names
41 | :param unseen_classes: list of unseen classes' names
42 | :param device: device in use
43 | """
44 |
45 | super().__init__(
46 | config, label_to_idx, classes, seen_classes, unseen_classes, device
47 | )
48 |
49 | self.data_folder = data_folder
50 |
51 | def create_training_dataset(self, train_data, unlabeled_data=None):
52 | """This function creates the dataset for training. Specifically, it
53 | merges pseudo-labels for unseen data and labeled data for seen classes.
54 |
55 | :param train_data: Dataset object - labeled training data for seen classes
56 | :param unlabeled_data: Dataset object - dataset of unlabeled data for
57 | unseen classes
58 | """
59 |
60 | # Get pseudo-labels for unlabeled data from unseen classes
61 | train_unseen_dataset = pseudolabel_top_k(
62 | self.config,
63 | self.config.DATASET_NAME,
64 | self.config.N_PSEUDOSHOTS,
65 | self.config.PROMPT_TEMPLATE,
66 | unlabeled_data,
67 | self.unseen_classes,
68 | self.transform,
69 | self.clip_model,
70 | self.label_to_idx,
71 | self.device,
72 | self.config.VIS_ENCODER,
73 | self.config.SPLIT_SEED,
74 | )
75 |
76 | # Define the lists of training data from seen and unseen classes
77 | unseen_imgs = train_unseen_dataset.filepaths
78 | unseen_labs = train_unseen_dataset.labels
79 |
80 | # Use a portion of the pseudo-labeled data to build a validation set
81 | if self.config.N_PSEUDOSHOTS >= 10:
82 | np.random.seed(self.config.validation_seed)
83 | train_indices = np.random.choice(
84 | range(len(unseen_imgs)),
85 | size=int(len(unseen_imgs) * self.config.ratio_train_val),
86 | replace=False,
87 | )
88 | val_indices = list(
89 | set(range(len(unseen_imgs))).difference(set(train_indices))
90 | )
91 |
92 | self.val_unseen_files = np.array(unseen_imgs)[val_indices]
93 | self.val_unseen_labs = np.array(unseen_labs)[val_indices]
94 |
95 | unseen_imgs = list(np.array(unseen_imgs)[train_indices])
96 | unseen_labs = list(np.array(unseen_labs)[train_indices])
97 |
98 | else:
99 | self.val_unseen_files = None
100 | self.val_unseen_labs = None
101 |
102 | seen_imgs = train_data.filepaths
103 | seen_labs = [self.label_to_idx[l] for l in train_data.labels]
104 |
105 | self.balance_param = len(seen_imgs) / len(unseen_imgs)
106 |
107 | train_data.filepaths = list(unseen_imgs) + list(seen_imgs)
108 | train_data.labels = list(unseen_labs) + list(seen_labs)
109 | train_data.label_id = True
110 |
111 |
112 | def define_loss_function(self, logits, labs):
113 |
114 | loss_ce_seen = self.cross_entropy(logits, labs, self.seen_classes)
115 | loss_ce_unseen = self.cross_entropy(logits, labs, self.unseen_classes)
116 |
117 | return loss_ce_seen + self.balance_param * loss_ce_unseen
118 |
119 | def cross_entropy(self, logits, labels, classes):
120 | """This loss computes the probability mass on the
121 | opposite set of classes for each sample.
122 |
123 | :param logits: continuous vector
124 | :param labels: class ids
125 | """
126 |
127 | ids = [self.label_to_idx[c] for c in classes]
128 |
129 | # Get indices of the batch samples whose label is in the given classes
130 | samples = []
131 |
132 | for idx, l in enumerate(labels):
133 | if l in ids:
134 | samples.append(idx)
135 |
136 | if samples:
137 | error = self.loss_func(logits[samples], labels[samples])
138 | else:
139 | error = 0
140 |
141 | return error
142 |
143 | def define_textual_prompts(self, only_unlabelled=None, validation=False):
144 | """This function returns the textual prompts. You can modify the list
145 | of classes of interest.
146 |
147 | :param only_unlabelled: boolean. It is True if the training only involves
148 | pseudo-labeled unseen data
149 | """
150 |
151 | if only_unlabelled:
152 | return [
153 | self.template.format(" ".join(i.split("_")))
154 | for i in self.unseen_classes
155 | ]
156 | else:
157 | if validation:
158 | return [
159 | self.template.format(" ".join(i.split("_")))
160 | for i in self.seen_classes
161 | ]
162 | else:
163 | return [
164 | self.template.format(" ".join(i.split("_"))) for i in self.classes
165 | ]
166 |
167 | def reindex_predicted_labels(self, idx_preds, only_unlabelled=False):
168 | """This function returns the correct index of predictions to compute
169 | model's accuracy.
170 |
171 | :param idx_preds: list of prediction ids
172 | :param only_unlabelled: boolean. It is True if the training only involves
173 | pseudo-labeled unseen data
174 | """
175 |
176 | if only_unlabelled:
177 | return [self.unseen_classes[i.item()] for i in idx_preds]
178 | else:
179 | return [self.classes[i.item()] for i in idx_preds]
180 |
181 | def reindex_true_labels(self, label, only_unlabelled=False):
182 | """This function returns the correct index of true labels.
183 |
184 | :param label: list of labels from data loader
185 | :param only_unlabelled: boolean. It is True if the training only involves
186 | pseudo-labeled unseen data
187 | """
188 |
189 | if only_unlabelled:
190 | return torch.tensor(
191 | [self.unseen_classes.index(self.classes[l.item()]) for l in label]
192 | )
193 | else:
194 | return torch.tensor([l for l in label])
195 |
196 |
197 | def get_pseudo_labels(self, unlabeled_examples):
198 | log.info(f"Num unlabeled data: {len(unlabeled_examples)}")
199 | # Get prediction on unlabeled data
200 | std_preds = self.test_predictions(
201 | unlabeled_examples, standard_zsl=True
202 | )
203 |
204 | DatasetObject = dataset_object(self.config.DATASET_NAME)
205 | # Take the top-K pseudo-labels (K = N_PSEUDOSHOTS) to fine-tune the student
206 | pseudo_unseen_examples = DatasetObject(
207 | std_preds["id"],
208 | self.data_folder,
209 | transform=self.transform,
210 | augmentations=None,
211 | train=True,
212 | labels=None,
213 | label_map=self.label_to_idx,
214 | class_folder=True,
215 | original_filepaths=unlabeled_examples.filepaths,
216 | )
217 |
218 | pseudo_labels = self.assign_pseudo_labels(
219 | self.config.N_PSEUDOSHOTS, pseudo_unseen_examples
220 | )
221 |
222 | return pseudo_labels
223 |
224 | def assign_pseudo_labels(self, k, unlabeled_data):
225 | # Define text queries
226 | # prompts = [f"{self.template}{' '.join(i.split('_'))}" \
227 | # for i in self.unseen_classes]
228 | prompts = [
229 | self.template.format(" ".join(i.split("_"))) for i in self.unseen_classes
230 | ]
231 | log.info(f"Number of prompts: {len(prompts)}")
232 |
233 | # Encode text
234 | text = clip.tokenize(prompts).to(self.device)
235 | text_features = self.clip_model.encode_text(text)
236 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
237 |
238 | # To find the top k for each class, each class has its own "leaderboard"
239 | top_k_leaderboard = {
240 | self.label_to_idx[self.unseen_classes[i]]: []
241 | for i in range(len(self.unseen_classes))
242 | } # maps class idx -> (confidence, image_path) tuple
243 |
244 | for img_path in unlabeled_data.filepaths:
245 | # log.info(f"IMAGEPATH: {img_path}")
246 | img = Image.open(img_path).convert("RGB")
247 | img = torch.unsqueeze(self.transform(img), 0).to(self.device)
248 | with torch.no_grad():
249 | image_features = self.model(img)
250 | image_features = image_features / image_features.norm(
251 | dim=-1, keepdim=True
252 | )
253 | # cosine similarity as logits
254 |
255 | logit_scale = self.clip_model.logit_scale.exp()
256 | logits = logit_scale * image_features @ text_features.t()
257 | probs = logits.softmax(dim=-1)
258 | idx_preds = torch.argmax(logits, dim=1)
259 | pred_id = idx_preds.item()
260 | pred = self.label_to_idx[self.unseen_classes[idx_preds.item()]]
261 |
262 | """if predicted class has empty leaderboard, or if the confidence is high
263 | enough for predicted class leaderboard, add the new example
264 | """
265 | prob_score = probs[0][pred_id]
266 | if len(top_k_leaderboard[pred]) < k:
267 | top_k_leaderboard[pred].append((prob_score, img_path))
268 | elif (
269 | top_k_leaderboard[pred][-1][0] < prob_score
270 | ): # if the confidence in predicted class "qualifies" for top-k
271 | # default sorting of tuples is by first element
272 | top_k_leaderboard[pred] = sorted(
273 | top_k_leaderboard[pred] + [(probs[0][pred_id], img_path)],
274 | reverse=True,
275 | )[:k]
276 | else:
277 | # sort the other classes by confidence score
278 | order_of_classes = sorted(
279 | [
280 | (probs[0][j], j)
281 | for j in range(len(self.unseen_classes))
282 | if j != pred_id
283 | ],
284 | reverse=True,
285 | )
286 | for score, index in order_of_classes:
287 | index_dict = self.label_to_idx[self.unseen_classes[index]]
288 | # log.info(f"{classnames[index]}")
289 | # log.info(f"{index_dict}")
290 | if len(top_k_leaderboard[index_dict]) < k:
291 | top_k_leaderboard[index_dict].append(
292 | (probs[0][index], img_path)
293 | )
294 | elif top_k_leaderboard[index_dict][-1][0] < probs[0][index]:
295 | # default sorting of tuples is by first element
296 | top_k_leaderboard[index_dict] = sorted(
297 | top_k_leaderboard[index_dict]
298 | + [(probs[0][index], img_path)],
299 | reverse=True,
300 | )[:k]
301 |
302 | new_imgs = []
303 | new_labels = []
304 | # loop through, and rebuild the dataset
305 | for index, leaderboard in top_k_leaderboard.items():
306 | new_imgs += [tup[1] for tup in leaderboard]
307 | new_labels += [index for _ in leaderboard]
308 |
309 | unlabeled_data.filepaths = new_imgs
310 | unlabeled_data.labels = new_labels
311 | unlabeled_data.label_id = True
312 |
313 | return unlabeled_data
314 |
--------------------------------------------------------------------------------
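
The assign_pseudo_labels loop above maintains one confidence-sorted "leaderboard" per class and keeps only the k most confident examples in each. The core selection step can be distilled into a standalone helper; this is a minimal sketch, where the function name and the (class_idx, confidence, image_path) input format are chosen for illustration:

from collections import defaultdict

def top_k_per_class(predictions, k):
    # predictions: iterable of (class_idx, confidence, image_path) tuples
    leaderboards = defaultdict(list)  # class_idx -> [(confidence, path), ...]
    for class_idx, confidence, path in predictions:
        board = leaderboards[class_idx]
        board.append((confidence, path))
        # Tuples sort by their first element, so this keeps the k most
        # confident (confidence, path) pairs for each class.
        board.sort(reverse=True)
        del board[k:]
    return leaderboards

Unlike this sketch, the loop above also falls back to a sample's lower-ranked classes (the order_of_classes branch) when its predicted class already has a full leaderboard, so an image can still be assigned to its second-best class rather than being discarded outright.
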
/methods/unsupervised_learning/__init__.py:
--------------------------------------------------------------------------------
1 | from .training_strategies import TrainingStrategy
2 | from .visual_prompt import VisualPrompt
3 | from .textual_prompt import TextualPrompt
4 | from .multimodal_prompt import MultimodalPrompt
5 | from .visual_fpl import VisualFPL
6 | from .textual_fpl import TextualFPL
7 | from .multimodal_fpl import MultimodalFPL
8 |
--------------------------------------------------------------------------------
/methods/unsupervised_learning/multimodal_fpl.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import clip
4 | import numpy as np
5 | import pandas as pd
6 | import torch
7 | from accelerate import Accelerator
8 | from PIL import Image
9 | from torch import nn
10 | import math
11 |
12 | accelerator = Accelerator()
13 |
14 | from methods.unsupervised_learning import MultimodalPrompt
15 | from utils import (
16 | dataset_object,
17 | make_scheduler,
18 | pseudolabel_top_k,
19 | )
20 |
21 |
22 | log = logging.getLogger(__name__)
23 |
24 |
25 | class MultimodalFPL(MultimodalPrompt):
26 | def __init__(
27 | self,
28 | config,
29 | label_to_idx,
30 | data_folder,
31 | classes,
32 | seen_classes,
33 | unseen_classes,
34 | device
35 | ):
36 | """This class defines self-trainig UPT's training and evaluation.
37 | :param config: dictionaries of prameters in models_config/upt_baseline_pseudo_config.yml
38 | :param label_to_idx: dictionary (key, value):(class name, id)
39 | :param classes: list of class names
40 | :param seen_classes: list of seen classes' names
41 | :param unseen_classes: list of unseen classes' names
42 | :param device: device in use
43 | """
44 |
45 | super().__init__(
46 | config, label_to_idx, classes, seen_classes, unseen_classes, device
47 | )
48 |
49 | self.data_folder = data_folder
50 |
51 | def create_training_dataset(self, train_data, unlabeled_data=None):
52 | """This function creates the dataset for training. Specifically, it
53 | merges pseudo-labels for unseen data and labeled data for seen classes.
54 | :param train_data: Dataset object - training seen classes (defined in zsl_jpl line 323)
55 | :param unlabeled_data: Dataset object - dataset of unlabeled data for
56 | unseen classes (defined in zsl_jpl line 328)
57 | """
58 |
59 | # Get pseudo-labels for unlabeled data from unseen classes
60 | train_unseen_dataset = pseudolabel_top_k(
61 | self.config,
62 | self.config.DATASET_NAME,
63 | self.config.N_PSEUDOSHOTS,
64 | self.config.PROMPT_TEMPLATE,
65 | unlabeled_data,
66 | self.classes,
67 | self.transform,
68 | self.clip_model,
69 | self.label_to_idx,
70 | self.device,
71 | self.config.VIS_ENCODER,
72 | self.config.SPLIT_SEED
73 | )
74 |
75 | # Define the lists of training data from seen and unseen classes
76 | unseen_imgs = train_unseen_dataset.filepaths
77 | unseen_labs = train_unseen_dataset.labels
78 | log.info(f"Number of classes in pseudolabels: {len(set(unseen_labs))}")
79 |
80 | # Use a portion of the pseudo-labeled data to build a validation set
81 | if self.config.N_PSEUDOSHOTS >= 10:
82 | np.random.seed(self.config.validation_seed)
83 | train_indices = np.random.choice(
84 | range(len(unseen_imgs)),
85 | size=int(len(unseen_imgs) * self.config.ratio_train_val),
86 | replace=False,
87 | )
88 | val_indices = list(
89 | set(range(len(unseen_imgs))).difference(set(train_indices))
90 | )
91 |
92 | self.val_unseen_files = np.array(unseen_imgs)[val_indices]
93 | self.val_unseen_labs = np.array(unseen_labs)[val_indices]
94 |
95 | unseen_imgs = list(np.array(unseen_imgs)[train_indices])
96 | unseen_labs = list(np.array(unseen_labs)[train_indices])
97 |
98 | else:
99 | self.val_unseen_files = None
100 | self.val_unseen_labs = None
101 |
102 | train_data.filepaths = list(unseen_imgs)
103 | train_data.labels = list(unseen_labs)
104 | train_data.label_id = True
105 |
106 | def define_loss_function(self, logits, labs):
107 |
108 | loss_ce = self.cross_entropy(logits, labs, self.classes)
109 |
110 | return loss_ce
111 |
112 | def cross_entropy(self, logits, labels, classes):
113 | """This loss computes the probability mass on the
114 | opposite set of classes for each sample.
115 | :param logits: continuous vector
116 | :param labels: class ids
117 | """
118 |
119 | error = self.loss_func(logits, labels)
120 |
121 | return error
122 |
123 | def reindex_predicted_labels(self, idx_preds, only_unlabelled=False):
124 | """This function returns the correct index of predictions to compute
125 | model's accuracy.
126 | :param idx_pred: list of predictions ids
127 | :param only_unlabelled: boolean. It is True if the training only involves
128 | pseudo-labeled unseen data
129 | """
130 |
131 | return [self.classes[i.item()] for i in idx_preds]
132 |
133 | def reindex_true_labels(self, label, only_unlabelled=False):
134 | """This function returns the correct index of true labels.
135 | :param label: list of labels from data loader
136 | :param only_unlabelled: boolean. It is True if the training only involves
137 | pseudo-labeled unseen data
138 | """
139 |
140 | return torch.tensor([l for l in label])
141 |
142 | def get_pseudo_labels(self, unlabeled_examples):
143 | log.info(f"[self.get_pseudo_labels] Num unlabeled data: {len(unlabeled_examples)}")
144 | # Get prediction on unlabeled data
145 | std_preds = self.test_predictions(
146 | unlabeled_examples, standard_zsl=True
147 | )
148 |
149 | DatasetObject = dataset_object(self.config.DATASET_NAME)
150 | # Take the top-k (k = N_PSEUDOSHOTS) pseudo-labels to fine-tune the student
151 | pseudo_unseen_examples = DatasetObject(
152 | std_preds["id"],
153 | self.data_folder,
154 | transform=self.transform,
155 | augmentations=None,
156 | train=True,
157 | labels=None,
158 | label_map=self.label_to_idx,
159 | class_folder=True,
160 | original_filepaths=unlabeled_examples.filepaths,
161 | )
162 |
163 | pseudo_labels = self.assign_pseudo_labels(
164 | self.config.N_PSEUDOSHOTS, pseudo_unseen_examples
165 | )
166 |
167 | return pseudo_labels
168 |
169 | def assign_pseudo_labels(self, k, unlabeled_data):
170 |
171 | # To find the top k for each class, each class has its own "leaderboard"
172 | top_k_leaderboard = {
173 | self.label_to_idx[self.classes[i]]: []
174 | for i in range(len(self.classes))
175 | } # maps class idx -> (confidence, image_path) tuple
176 |
177 | classes = self.classes
178 | for img_path in unlabeled_data.filepaths:
179 | # log.info(f"IMAGEPATH: {img_path}")
180 | img = Image.open(img_path).convert("RGB")
181 | img = torch.unsqueeze(self.transform(img), 0).to(self.device)
182 | with torch.no_grad():
183 | # Get text and image prompts using UPT
184 | # coop_embeddings, vpt_embeddings, vpt_deep_embeddings = self.model(0)
185 | # Calculate text prompts
186 | # text_features = self.text_encoder(coop_embeddings, classes)
187 | text_features, image_features = self.model(img, classes)
188 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
189 | # Calculate image prompts
190 | # image_features = self.image_encoder(img, vpt_embeddings, deep_embds=vpt_deep_embeddings)
191 | image_features = image_features / image_features.norm(dim=-1, keepdim=True)
192 |
193 | logit_scale = self.clip_model.logit_scale.exp()
194 | logits = logit_scale * image_features @ text_features.t()
195 | probs = logits.softmax(dim=-1)
196 | idx_preds = torch.argmax(logits, dim=1)
197 | pred_id = idx_preds.item()
198 | pred = self.label_to_idx[self.classes[idx_preds.item()]]
199 |
200 | """if predicted class has empty leaderboard, or if the confidence is high
201 | enough for predicted class leaderboard, add the new example
202 | """
203 | prob_score = probs[0][pred_id]
204 | if len(top_k_leaderboard[pred]) < k:
205 | top_k_leaderboard[pred].append((prob_score, img_path))
206 | elif (
207 | top_k_leaderboard[pred][-1][0] < prob_score
208 | ): # if the confidence in predicted class "qualifies" for top-k
209 | # default sorting of tuples is by first element
210 | top_k_leaderboard[pred] = sorted(
211 | top_k_leaderboard[pred] + [(probs[0][pred_id], img_path)],
212 | reverse=True,
213 | )[:k]
214 | else:
215 | # sort the other classes by confidence score
216 | order_of_classes = sorted(
217 | [
218 | (probs[0][j], j)
219 | for j in range(len(self.classes))
220 | if j != pred_id
221 | ],
222 | reverse=True,
223 | )
224 | for score, index in order_of_classes:
225 | index_dict = self.label_to_idx[self.classes[index]]
226 | # log.info(f"{classnames[index]}")
227 | # log.info(f"{index_dict}")
228 | if len(top_k_leaderboard[index_dict]) < k:
229 | top_k_leaderboard[index_dict].append(
230 | (probs[0][index], img_path)
231 | )
232 | elif top_k_leaderboard[index_dict][-1][0] < probs[0][index]:
233 | # default sorting of tuples is by first element
234 | top_k_leaderboard[index_dict] = sorted(
235 | top_k_leaderboard[index_dict]
236 | + [(probs[0][index], img_path)],
237 | reverse=True,
238 | )[:k]
239 |
240 | new_imgs = []
241 | new_labels = []
242 | # loop through, and rebuild the dataset
243 | for index, leaderboard in top_k_leaderboard.items():
244 | new_imgs += [tup[1] for tup in leaderboard]
245 | new_labels += [index for _ in leaderboard]
246 |
247 | unlabeled_data.filepaths = new_imgs
248 | unlabeled_data.labels = new_labels
249 | unlabeled_data.label_id = True
250 |
251 | return unlabeled_data
252 |
--------------------------------------------------------------------------------
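
When there are at least 10 pseudo-shots per class, create_training_dataset above holds out part of the pseudo-labels as a validation set via np.random.choice. A minimal sketch of that split as a self-contained helper (names are illustrative; it mirrors validation_seed and ratio_train_val, but uses NumPy's Generator API instead of the global seed):

import numpy as np

def split_pseudo_labels(filepaths, labels, ratio_train_val=0.8, seed=0):
    # Sample train indices without replacement; the rest become validation.
    rng = np.random.default_rng(seed)
    n = len(filepaths)
    train_idx = rng.choice(n, size=int(n * ratio_train_val), replace=False)
    val_idx = np.setdiff1d(np.arange(n), train_idx)
    pick = lambda idx: ([filepaths[i] for i in idx], [labels[i] for i in idx])
    return pick(train_idx), pick(val_idx)
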
/methods/unsupervised_learning/pseudo_iterative.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import logging
3 | import math
4 | import pickle
5 |
6 | import clip
7 | import numpy as np
8 | import pandas as pd
9 | import scipy.stats as st
10 | import torch
11 | from accelerate import Accelerator
12 | from PIL import Image
13 | from torch import nn
14 |
15 |
16 | accelerator = Accelerator()
17 |
18 | from data import CustomDataset
19 | from models import CustomImageEncoder, ImagePrefixModel
20 | from methods import TeacherStudent
21 | from utils import (
22 | dataset_object,
23 | evaluate_predictions,
24 | make_scheduler,
25 | pseudolabel_top_k,
26 | seed_worker,
27 | save_parameters,
28 | save_pseudo_labels,
29 | )
30 |
31 | g = torch.Generator()
32 | g.manual_seed(0)
33 |
34 | log = logging.getLogger(__name__)
35 |
36 |
37 | class PseudoIterative(TeacherStudent):
38 | def __init__(
39 | self,
40 | config,
41 | label_to_idx,
42 | data_folder,
43 | classes,
44 | seen_classes,
45 | unseen_classes,
46 | device,
47 | ):
48 | super().__init__(
49 | config, label_to_idx, data_folder, classes, seen_classes, unseen_classes, device
50 | )
51 |
52 |
53 | def train(
54 | self,
55 | train_data,
56 | val_data,
57 | unlabeled_data,
58 | test_data,
59 | test_labeled_files,
60 | test_labeles,
61 | ):
62 | # Number of total iterations to cover all unlabeled data
63 | num_iter = int(100/self.config.STEP_QUANTILE)
64 | num_samples = int(len(unlabeled_data) / num_iter)
65 | # Initialize the number of pseudo-labels per class
66 | n_per_class = int(num_samples / len(self.unseen_classes))
67 | n_unseen = len(self.unseen_classes)
68 | if n_per_class * n_unseen <= len(unlabeled_data.filepaths):
69 | # self.num_pseudo_labels_per_class = n_per_class
70 | self.config.N_PSEUDOSHOTS = n_per_class
71 | else:
72 | # self.num_pseudo_labels_per_class = math.floor(len(unlabeled_data.filepaths)/n_unseen)
73 | self.config.N_PSEUDOSHOTS = math.floor(
74 | len(unlabeled_data.filepaths) / n_unseen
75 | )
76 |
77 | log.info(f"We select {self.config.N_PSEUDOSHOTS} pseudolabel per each unseen classes.")
78 | log.info(f"The number of unseen classes is: {len(self.unseen_classes)}.")
79 | log.info(f"Thus we expect an initial number of pseudo labeles equal to {len(self.unseen_classes) * self.config.N_PSEUDOSHOTS}.")
80 | # Create a safe copy of labeled/unlabeled data
81 | original_train_data = copy.deepcopy(train_data)
82 | # log.info(f"Training data labels: {original_train_data.labels}")
83 | original_unlabeled_data = copy.deepcopy(unlabeled_data)
84 | # Original val
85 | original_val_data = copy.deepcopy(val_data)
86 |
87 | # Initialize here first batch of pseudo labels
88 | #self.create_training_dataset(train_data, unlabeled_data)
89 | #log.info(f"The original train data has size: {len(original_train_data.filepaths)}.")
90 | #log.info(f"Plus: {len(unlabeled_data.filepaths)}.")
91 |
92 | for niter in range(1, num_iter + 1):
93 | log.info(f"NUM PSEUDO SHOTS: {self.config.N_PSEUDOSHOTS}")
94 | pseudolabel_top_k(
95 | self.config, self.config.DATASET_NAME,
96 | self.config.N_PSEUDOSHOTS,
97 | self.config.PROMPT_TEMPLATE,
98 | unlabeled_data,
99 | self.unseen_classes,
100 | self.transform,
101 | self.clip_model,
102 | self.label_to_idx,
103 | self.device,
104 | self.config.VIS_ENCODER,
105 | self.config.SPLIT_SEED,
106 | )
107 |
108 | log.info(f"Plus: {len(unlabeled_data.filepaths)}.")
109 | filename = f"pseudolabels/{self.config.DATASET_NAME}_CLIP_{self.config.VIS_ENCODER.replace('/', '')}_iter_{niter}_pseudolabels_spl_{self.config.SPLIT_SEED}.pickle"
110 | with open(filename, "wb") as f:
111 | pickle.dump({"filepaths": unlabeled_data.filepaths, "labels": unlabeled_data.labels}, f)
112 |
113 | # Exploit all the available unlabeled data
114 | if self.config.ALL_UNLABELED:
115 | n_per_class = int((niter + 1) * num_samples / n_unseen)
116 | log.info(f"n_per_class: {n_per_class}")
117 | if n_per_class * n_unseen <= len(original_unlabeled_data.filepaths):
118 | log.info(f"if n_per_class: {n_per_class}")
119 | self.config.N_PSEUDOSHOTS = n_per_class
120 | else:
121 | log.info(f"else new val: {len(original_unlabeled_data.filepaths) / n_unseen}")
122 | # We are making a strong assumption about the distribution of unlabeled data
123 | self.config.N_PSEUDOSHOTS = math.floor(
124 | len(original_unlabeled_data.filepaths) / n_unseen
125 | )
126 |
127 | unlabeled_data = original_unlabeled_data
128 | original_unlabeled_data = copy.deepcopy(unlabeled_data)
--------------------------------------------------------------------------------
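
To make the schedule in train above concrete, here is a worked example with assumed numbers (5,000 unlabeled images, 10 unseen classes, STEP_QUANTILE of 10):

step_quantile = 10
n_unlabeled, n_unseen = 5000, 10              # assumed pool size and classes

num_iter = int(100 / step_quantile)           # 10 iterations
num_samples = int(n_unlabeled / num_iter)     # 500 candidates per iteration
n_pseudoshots = int(num_samples / n_unseen)   # initial budget: 50 per class

for niter in range(1, num_iter + 1):
    # With ALL_UNLABELED set, the per-class budget grows every iteration,
    # capped by what the unlabeled pool can supply per class.
    n_per_class = int((niter + 1) * num_samples / n_unseen)
    n_pseudoshots = min(n_per_class, n_unlabeled // n_unseen)
    print(niter, n_pseudoshots)               # 100, 150, ..., capped at 500
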
/methods/unsupervised_learning/textual_fpl.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import clip
4 | import numpy as np
5 | import pandas as pd
6 | import torch
7 | from accelerate import Accelerator
8 | from PIL import Image
9 | from torch import nn
10 | from tqdm import tqdm
11 |
12 | accelerator = Accelerator()
13 |
14 | from methods.unsupervised_learning import TextualPrompt
15 | from utils import (
16 | dataset_object,
17 | make_scheduler,
18 | pseudolabel_top_k,
19 | seed_worker,
20 | )
21 |
22 |
23 | g = torch.Generator()
24 | g.manual_seed(0)
25 |
26 | log = logging.getLogger(__name__)
27 |
28 |
29 | class TextualFPL(TextualPrompt):
30 | def __init__(
31 | self,
32 | config,
33 | label_to_idx,
34 | data_folder,
35 | classes,
36 | seen_classes,
37 | unseen_classes,
38 | device,
39 | ):
40 | """This class define Coop baseline.
41 |
42 | :param config: dictionaries of prameters in models_config/coop_baseline_config.yml
43 | :param label_to_idx: dictionary (key, value):(class name, id)
44 | :param classes: list of class names
45 | :param seen_classes: list of seen classes' names
46 | :param unseen_classes: list of unseen classes' names
47 | :param device: device in use
48 | """
49 | super().__init__(
50 | config, label_to_idx, classes, seen_classes, unseen_classes, device
51 | )
52 |
53 | self.data_folder = data_folder
54 |
55 | def create_training_dataset(self, train_data, unlabeled_data=None):
56 | """This function create the dataset for training. Specifically, it
57 | merges pseudo-labels for unseen data and labeled data for seen classes.
58 |
59 | :param train_data: Dataset object - training seen classes (defined in zsl_jpl line 323)
60 | :param unlabeled_data: Dataset object - dataset of unlabeled data for
61 | unseen classes (defined in zsl_jpl line 328)
62 | """
63 |
64 | # Get pseudo-labels for unlabeled data from unseen classes
65 | train_unseen_dataset = pseudolabel_top_k(
66 | self.config,
67 | self.config.DATASET_NAME,
68 | self.config.N_PSEUDOSHOTS,
69 | self.config.PROMPT_TEMPLATE,
70 | unlabeled_data,
71 | self.classes,
72 | self.transform,
73 | self.clip_model,
74 | self.label_to_idx,
75 | self.device,
76 | self.config.VIS_ENCODER,
77 | self.config.SPLIT_SEED,
78 | )
79 |
80 | # Define the lists of training data from seen and unseen classes
81 | unseen_imgs = train_unseen_dataset.filepaths
82 | unseen_labs = train_unseen_dataset.labels
83 |
84 | # Use a portion of the pseudo-labeled data to build a validation set
85 | if self.config.N_PSEUDOSHOTS >= 10:
86 | np.random.seed(self.config.validation_seed)
87 | train_indices = np.random.choice(
88 | range(len(unseen_imgs)),
89 | size=int(len(unseen_imgs) * self.config.ratio_train_val),
90 | replace=False,
91 | )
92 | val_indices = list(
93 | set(range(len(unseen_imgs))).difference(set(train_indices))
94 | )
95 |
96 | self.val_unseen_files = np.array(unseen_imgs)[val_indices]
97 | self.val_unseen_labs = np.array(unseen_labs)[val_indices]
98 |
99 | unseen_imgs = list(np.array(unseen_imgs)[train_indices])
100 | unseen_labs = list(np.array(unseen_labs)[train_indices])
101 |
102 | else:
103 | self.val_unseen_files = None
104 | self.val_unseen_labs = None
105 |
106 | train_data.filepaths = list(unseen_imgs)
107 | train_data.labels = list(unseen_labs)
108 | train_data.label_id = True
109 |
110 | return train_data
111 |
112 | def define_loss_function(self, logits, labs):
113 |
114 | loss_ce = self.cross_entropy(logits, labs, self.classes)
115 |
116 | return loss_ce
117 |
118 | def cross_entropy(self, logits, labels, classes):
119 | """This loss computes the probability mass on the
120 | opposite set of classes for each sample.
121 |
122 | :param logits: continuous vector
123 | :param labels: class ids
124 | """
125 |
126 | error = self.loss_func(logits, labels)
127 |
128 | return error
129 |
130 |
131 | def get_pseudo_labels(self, unlabeled_examples):
132 | log.info(f"Num unlabeled data: {len(unlabeled_examples)}")
133 | # Get prediction on unlabeled data
134 | std_preds = self.test_predictions(
135 | unlabeled_examples, standard_zsl=True
136 | )
137 |
138 | DatasetObject = dataset_object(self.config.DATASET_NAME)
139 | # Take the top-k (k = N_PSEUDOSHOTS) pseudo-labels to fine-tune the student
140 | pseudo_unseen_examples = DatasetObject(
141 | std_preds["id"],
142 | self.data_folder,
143 | transform=self.transform,
144 | augmentations=None,
145 | train=True,
146 | labels=None,
147 | label_map=self.label_to_idx,
148 | class_folder=True,
149 | original_filepaths=unlabeled_examples.filepaths,
150 | )
151 |
152 | pseudo_labels = self.assign_pseudo_labels(
153 | self.config.N_PSEUDOSHOTS, pseudo_unseen_examples
154 | )
155 |
156 | return pseudo_labels
157 |
158 | def assign_pseudo_labels(self, k, unlabeled_data):
159 | # Define text queries
160 | # prompts = [f"{self.template}{' '.join(i.split('_'))}" \
161 | # for i in self.unseen_classes]
162 |
163 | log.info(f"[self.assign_pseudo_labels] Number of prompts: {len(self.classes)}")
164 |
165 | # Get prompts
166 | self.model.classes = self.classes
167 | text_features = self.model(self.model.classes)
168 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
169 | log.info(f"TEXT FEATURES SHAPE: {text_features.size()}")
170 |
171 | # To find the top k for each class, each class has its own "leaderboard"
172 | top_k_leaderboard = {
173 | self.label_to_idx[self.classes[i]]: []
174 | for i in range(len(self.classes))
175 | } # maps class idx -> (confidence, image_path) tuple
176 |
177 | for img_path in unlabeled_data.filepaths:
178 | # log.info(f"IMAGEPATH: {img_path}")
179 | img = Image.open(img_path).convert("RGB")
180 | img = torch.unsqueeze(self.transform(img), 0).to(self.device)
181 | with torch.no_grad():
182 | image_features = self.clip_model.encode_image(img)
183 | image_features = image_features / image_features.norm(
184 | dim=-1, keepdim=True
185 | )
186 | # cosine similarity as logits
187 |
188 | logit_scale = self.clip_model.logit_scale.exp()
189 | logits = logit_scale * image_features @ text_features.t()
190 | probs = logits.softmax(dim=-1)
191 | idx_preds = torch.argmax(logits, dim=1)
192 | pred_id = idx_preds.item()
193 | pred = self.label_to_idx[self.classes[idx_preds.item()]]
194 |
195 | """if predicted class has empty leaderboard, or if the confidence is high
196 | enough for predicted class leaderboard, add the new example
197 | """
198 | prob_score = probs[0][pred_id]
199 | if len(top_k_leaderboard[pred]) < k:
200 | top_k_leaderboard[pred].append((prob_score, img_path))
201 | elif (
202 | top_k_leaderboard[pred][-1][0] < prob_score
203 | ): # if the confidence in predicted class "qualifies" for top-k
204 | # default sorting of tuples is by first element
205 | top_k_leaderboard[pred] = sorted(
206 | top_k_leaderboard[pred] + [(probs[0][pred_id], img_path)],
207 | reverse=True,
208 | )[:k]
209 | else:
210 | # sort the other classes by confidence score
211 | order_of_classes = sorted(
212 | [
213 | (probs[0][j], j)
214 | for j in range(len(self.classes))
215 | if j != pred_id
216 | ],
217 | reverse=True,
218 | )
219 | for score, index in order_of_classes:
220 | index_dict = self.label_to_idx[self.classes[index]]
221 | # log.info(f"{classnames[index]}")
222 | # log.info(f"{index_dict}")
223 | if len(top_k_leaderboard[index_dict]) < k:
224 | top_k_leaderboard[index_dict].append(
225 | (probs[0][index], img_path)
226 | )
227 | elif top_k_leaderboard[index_dict][-1][0] < probs[0][index]:
228 | # default sorting of tuples is by first element
229 | top_k_leaderboard[index_dict] = sorted(
230 | top_k_leaderboard[index_dict]
231 | + [(probs[0][index], img_path)],
232 | reverse=True,
233 | )[:k]
234 |
235 | new_imgs = []
236 | new_labels = []
237 | # loop through, and rebuild the dataset
238 | for index, leaderboard in top_k_leaderboard.items():
239 | new_imgs += [tup[1] for tup in leaderboard]
240 | new_labels += [index for _ in leaderboard]
241 |
242 | unlabeled_data.filepaths = new_imgs
243 | unlabeled_data.labels = new_labels
244 | unlabeled_data.label_id = True
245 |
246 | return unlabeled_data
247 |
--------------------------------------------------------------------------------
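
Each FPL variant scores an image against the class prompts the same way: L2-normalize both embeddings, then take the temperature-scaled cosine similarity. Stripped of the prompt models, the plain-CLIP version of that step looks like the sketch below (assuming a clip_model/preprocess pair from clip.load); TextualFPL swaps encode_text for its prompt model, VisualFPL swaps encode_image, and MultimodalFPL swaps both:

import clip
import torch
from PIL import Image

@torch.no_grad()
def clip_class_probs(clip_model, preprocess, image_path, prompts, device):
    image = preprocess(Image.open(image_path).convert("RGB")).unsqueeze(0).to(device)
    text = clip.tokenize(prompts).to(device)
    image_features = clip_model.encode_image(image)
    text_features = clip_model.encode_text(text)
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    # Cosine similarities scaled by CLIP's learned temperature.
    logits = clip_model.logit_scale.exp() * image_features @ text_features.t()
    return logits.softmax(dim=-1)  # shape (1, len(prompts))
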
/methods/unsupervised_learning/visual_fpl.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import clip
4 | import numpy as np
5 | import pandas as pd
6 | import torch
7 | from accelerate import Accelerator
8 | from PIL import Image
9 | from torch import nn
10 |
11 | accelerator = Accelerator()
12 |
13 | from methods.unsupervised_learning import VisualPrompt
14 | from utils import (
15 | dataset_object,
16 | make_scheduler,
17 | pseudolabel_top_k,
18 | )
19 |
20 |
21 | log = logging.getLogger(__name__)
22 |
23 |
24 | class VisualFPL(VisualPrompt):
25 | def __init__(
26 | self,
27 | config,
28 | label_to_idx,
29 | data_folder,
30 | classes,
31 | seen_classes,
32 | unseen_classes,
33 | device
34 | ):
35 | """This class defines few-pseudo-learning (FPL) VPT's training and evaluation.
36 |
37 | :param config: dictionary of parameters in models_config/vpt_baseline_config.yml
38 | :param label_to_idx: dictionary (key, value):(class name, id)
39 | :param classes: list of class names
40 | :param seen_classes: list of seen classes' names
41 | :param unseen_classes: list of unseen classes' names
42 | :param device: device in use
43 | """
44 |
45 | super().__init__(
46 | config, label_to_idx, classes, seen_classes, unseen_classes, device
47 | )
48 |
49 | self.data_folder = data_folder
50 |
51 | def create_training_dataset(self, train_data, unlabeled_data=None):
52 | """This function creates the dataset for training. Specifically, it
53 | merges pseudo-labels for unseen data and labeled data for seen classes.
54 |
55 | :param train_data: Dataset object - training seen classes (defined in zsl_jpl line 323)
56 | :param unlabeled_data: Dataset object - dataset of unlabeled data for
57 | unseen classes (defined in zsl_jpl line 328)
58 | """
59 |
60 | # Get pseudo-labels for unlabeled data from unseen classes
61 | train_unseen_dataset = pseudolabel_top_k(
62 | self.config,
63 | self.config.DATASET_NAME,
64 | self.config.N_PSEUDOSHOTS,
65 | self.config.PROMPT_TEMPLATE,
66 | unlabeled_data,
67 | self.classes,
68 | self.transform,
69 | self.clip_model,
70 | self.label_to_idx,
71 | self.device,
72 | self.config.VIS_ENCODER,
73 | self.config.SPLIT_SEED,
74 | )
75 | log.info(f"[self.create_training_dataset] Training data: {len(train_unseen_dataset.filepaths)}")
76 | # Define the lists of training data from seen and unseen classes
77 | unseen_imgs = train_unseen_dataset.filepaths
78 | unseen_labs = train_unseen_dataset.labels
79 |
80 | # Use a portion of the pseudo-labeled data to build a validation set
81 | if self.config.N_PSEUDOSHOTS >= 10:
82 | np.random.seed(self.config.validation_seed)
83 | train_indices = np.random.choice(
84 | range(len(unseen_imgs)),
85 | size=int(len(unseen_imgs) * self.config.ratio_train_val),
86 | replace=False,
87 | )
88 | val_indices = list(
89 | set(range(len(unseen_imgs))).difference(set(train_indices))
90 | )
91 |
92 | self.val_unseen_files = np.array(unseen_imgs)[val_indices]
93 | self.val_unseen_labs = np.array(unseen_labs)[val_indices]
94 |
95 | unseen_imgs = list(np.array(unseen_imgs)[train_indices])
96 | unseen_labs = list(np.array(unseen_labs)[train_indices])
97 |
98 | else:
99 | self.val_unseen_files = None
100 | self.val_unseen_labs = None
101 |
102 | train_data.filepaths = list(unseen_imgs)
103 | train_data.labels = list(unseen_labs)
104 | train_data.label_id = True
105 |
106 |
107 | def define_loss_function(self, logits, labs):
108 |
109 | loss_ce = self.cross_entropy(logits, labs, self.classes)
110 | return loss_ce
111 |
112 | def cross_entropy(self, logits, labels, classes):
113 | """This loss computes the probability mass on the
114 | opposite set of classes for each sample.
115 |
116 | :param logits: continuous vector
117 | :param labels: class ids
118 | """
119 |
120 | error = self.loss_func(logits, labels)
121 |
122 | return error
123 |
124 | def define_textual_prompts(self, only_unlabelled=None, validation=False):
125 | """This function returns the textual prompts. You can modify the list
126 | of classes of interest.
127 |
128 | :param only_unlabelled: boolean. It is True if the training only involves
129 | pseudo-labeled unseen data
130 | """
131 |
132 | return [
133 | self.template.format(" ".join(i.split("_"))) for i in self.classes
134 | ]
135 |
136 | def reindex_predicted_labels(self, idx_preds, only_unlabelled=False):
137 | """This function returns the correct index of predictions to compute
138 | model's accuracy.
139 |
140 | :param idx_pred: list of predictions ids
141 | :param only_unlabelled: boolean. It is True if the training only involves
142 | pseudo-labeled unseen data
143 | """
144 |
145 | return [self.classes[i.item()] for i in idx_preds]
146 |
147 | def reindex_true_labels(self, label, only_unlabelled=False):
148 | """This function returns the correct index of true labels.
149 |
150 | :param label: list of labels from data loader
151 | :param only_unlabelled: boolean. It is True if the training only involves
152 | pseudo-labeled unseen data
153 | """
154 |
155 |
156 | return torch.tensor([l for l in label])
157 |
158 |
159 | def get_pseudo_labels(self, unlabeled_examples):
160 | log.info(f"Num unlabeled data: {len(unlabeled_examples)}")
161 | # Get prediction on unlabeled data
162 | std_preds = self.test_predictions(
163 | unlabeled_examples, standard_zsl=True
164 | )
165 |
166 | DatasetObject = dataset_object(self.config.DATASET_NAME)
167 | # Take the top-k (k = N_PSEUDOSHOTS) pseudo-labels to fine-tune the student
168 | pseudo_unseen_examples = DatasetObject(
169 | std_preds["id"],
170 | self.data_folder,
171 | transform=self.transform,
172 | augmentations=None,
173 | train=True,
174 | labels=None,
175 | label_map=self.label_to_idx,
176 | class_folder=True,
177 | original_filepaths=unlabeled_examples.filepaths,
178 | )
179 |
180 | pseudo_labels = self.assign_pseudo_labels(
181 | self.config.N_PSEUDOSHOTS, pseudo_unseen_examples
182 | )
183 |
184 | return pseudo_labels
185 |
186 | def assign_pseudo_labels(self, k, unlabeled_data):
187 | # Define text queries
188 | # prompts = [f"{self.template}{' '.join(i.split('_'))}" \
189 | # for i in self.unseen_classes]
190 | prompts = [
191 | self.template.format(" ".join(i.split("_"))) for i in self.classes
192 | ]
193 | log.info(f"Number of prompts: {len(prompts)}")
194 |
195 | # Encode text
196 | text = clip.tokenize(prompts).to(self.device)
197 | text_features = self.clip_model.encode_text(text)
198 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
199 |
200 | # To find the top k for each class, each class has its own "leaderboard"
201 | top_k_leaderboard = {
202 | self.label_to_idx[self.classes[i]]: []
203 | for i in range(len(self.classes))
204 | } # maps class idx -> (confidence, image_path) tuple
205 |
206 | for img_path in unlabeled_data.filepaths:
207 | # log.info(f"IMAGEPATH: {img_path}")
208 | img = Image.open(img_path).convert("RGB")
209 | img = torch.unsqueeze(self.transform(img), 0).to(self.device)
210 | with torch.no_grad():
211 | image_features = self.model(img)
212 | image_features = image_features / image_features.norm(
213 | dim=-1, keepdim=True
214 | )
215 | # cosine similarity as logits
216 |
217 | logit_scale = self.clip_model.logit_scale.exp()
218 | logits = logit_scale * image_features @ text_features.t()
219 | probs = logits.softmax(dim=-1)
220 | idx_preds = torch.argmax(logits, dim=1)
221 | pred_id = idx_preds.item()
222 | pred = self.label_to_idx[self.classes[idx_preds.item()]]
223 |
224 | """if predicted class has empty leaderboard, or if the confidence is high
225 | enough for predicted class leaderboard, add the new example
226 | """
227 | prob_score = probs[0][pred_id]
228 | if len(top_k_leaderboard[pred]) < k:
229 | top_k_leaderboard[pred].append((prob_score, img_path))
230 | elif (
231 | top_k_leaderboard[pred][-1][0] < prob_score
232 | ): # if the confidence in predicted class "qualifies" for top-k
233 | # default sorting of tuples is by first element
234 | top_k_leaderboard[pred] = sorted(
235 | top_k_leaderboard[pred] + [(probs[0][pred_id], img_path)],
236 | reverse=True,
237 | )[:k]
238 | else:
239 | # sort the other classes by confidence score
240 | order_of_classes = sorted(
241 | [
242 | (probs[0][j], j)
243 | for j in range(len(self.classes))
244 | if j != pred_id
245 | ],
246 | reverse=True,
247 | )
248 | for score, index in order_of_classes:
249 | index_dict = self.label_to_idx[self.classes[index]]
250 | # log.info(f"{classnames[index]}")
251 | # log.info(f"{index_dict}")
252 | if len(top_k_leaderboard[index_dict]) < k:
253 | top_k_leaderboard[index_dict].append(
254 | (probs[0][index], img_path)
255 | )
256 | elif top_k_leaderboard[index_dict][-1][0] < probs[0][index]:
257 | # default sorting of tuples is by first element
258 | top_k_leaderboard[index_dict] = sorted(
259 | top_k_leaderboard[index_dict]
260 | + [(probs[0][index], img_path)],
261 | reverse=True,
262 | )[:k]
263 |
264 | new_imgs = []
265 | new_labels = []
266 | # loop through, and rebuild the dataset
267 | for index, leaderboard in top_k_leaderboard.items():
268 | new_imgs += [tup[1] for tup in leaderboard]
269 | new_labels += [index for _ in leaderboard]
270 |
271 | unlabeled_data.filepaths = new_imgs
272 | unlabeled_data.labels = new_labels
273 | unlabeled_data.label_id = True
274 |
275 | return unlabeled_data
276 |
--------------------------------------------------------------------------------
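
The prompt strings fed to clip.tokenize come from self.template.format(...), which also turns underscored class names into natural words. A quick illustration (the template string itself is an assumption, since PROMPT_TEMPLATE is imported in main.py):

template = "This is a photo of a {}"  # assumed; PROMPT_TEMPLATE comes from main.py
class_name = "ice_cream"
prompt = template.format(" ".join(class_name.split("_")))
# -> "This is a photo of a ice cream"
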
/methods_config/accelerate_config.yml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | distributed_type: MULTI_GPU
3 | fp16: false
4 | machine_rank: 0
5 | main_process_ip: null
6 | main_process_port: null
7 | main_training_function: main
8 | num_machines: 1
9 | num_processes: 4
--------------------------------------------------------------------------------
/methods_config/accelerate_localtest_config.yml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | distributed_type: 'NO'
3 | fp16: false
4 | machine_rank: 0
5 | main_process_ip: null
6 | main_process_port: null
7 | main_training_function: main
8 | num_machines: 1
9 | num_processes: 1
--------------------------------------------------------------------------------
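
Both Accelerate configs above are presumably consumed through Accelerate's launcher, e.g. accelerate launch --config_file methods_config/accelerate_config.yml run_main_ul.py (the entry point is one of the repo's runners, shown here as an assumption). The localtest variant runs the same code on a single process with no distribution, which is useful for debugging before moving to the 4-process MULTI_GPU setup.
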
/methods_config/clip_config.yml:
--------------------------------------------------------------------------------
1 | # Dataset settings
2 | DATASET_NAME: "$DATASET_NAME"
3 | DATASET_DIR: "$DATASET_DIR"
4 |
5 | # Model setting
6 | MODEL: "$MODEL"
7 | VIS_ENCODER: "$VIS_ENCODER"
8 |
9 | # Prompt template
10 | PROMPT_TEMPLATE: 'imported in main.py'
11 |
12 | # Train-validation separation
13 | validation_seed: 0
14 | ratio_train_val: 0.8
15 |
16 | # Batch size
17 | BATCH_SIZE: 8
18 |
19 | # Seed settings
20 | OPTIM_SEED: "$OPTIM_SEED"
21 |
22 | # Classes split
23 | CLASSES_SPLIT: SPLIT_SEED
24 | SPLIT_SEED: "$SPLIT_SEED"
25 |
--------------------------------------------------------------------------------
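
Values such as "$DATASET_NAME" in these configs are placeholders, presumably substituted by the runner or shell scripts before the YAML is parsed. One way such substitution could work, sketched as an assumption (the repo's actual loader may differ):

import os
import string

import yaml  # PyYAML assumed available

def load_config(path, **overrides):
    raw = open(path).read()
    # Fill $VAR placeholders from explicit overrides first, then the
    # environment; safe_substitute leaves unknown placeholders untouched.
    values = {**dict(os.environ), **overrides}
    filled = string.Template(raw).safe_substitute(values)
    return yaml.safe_load(filled)

cfg = load_config("methods_config/clip_config.yml", DATASET_NAME="MNIST")
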
/methods_config/evaluation_config.yml:
--------------------------------------------------------------------------------
1 | # Dataset name
2 | DATASET_NAME: "$DATASET_NAME"
3 | DATASET_DIR: '/Users/menga/Desktop/github/zsl_taglets/development' # '/users/cmenghin/data/bats/datasets/classification'
4 | # Model
5 | MODEL: "$MODEL"
6 | MODALITY: 'image'
7 | # Visual encoder
8 | VIS_ENCODER: "$VIS_ENCODER"
9 | # Prompt template
10 | PROMPT_TEMPLATE: 'imported in main.py'
11 | # Number of shots per class in SSL
12 | N_LABEL: 16
13 | ALPHA: 0.3
14 | # Prefix size
15 | # Text Prefix size
16 | TEXT_PREFIX_SIZE: 4
17 | # Vision Prefix size
18 | VISION_PREFIX_SIZE: 4
19 | # Lightweight transformer dim
20 | TRANSFORMER_DIM: 128
21 | # Use VPT-Deep?
22 | VPT_DEEP: False
23 | PREFIX_SIZE: 16
24 | # Prefix initialization: normal/uniform
25 | VIS_PREFIX_INIT: "normal"
26 | # Initialization mean and variance
27 | MEAN_INIT: 0
28 | VAR_INIT: 0.02
29 | # Seed to separate train and validation
30 | validation_seed: 0
31 | # Ratio validation
32 | ratio_train_val: 0.8
33 | # Batch size
34 | BATCH_SIZE: 2
35 | # Number of epochs
36 | EPOCHS: 150
37 | # Scheduler
38 | SCHEDULER: "cosine"
39 | # Scheduler warmup epochs
40 | WARMUP_EPOCHS: 5
41 | WARMUP_LR: 0.0001
42 | # Number of accumulation iter
43 | ACCUMULATION_ITER: 1
44 | # Optimizer
45 | OPTIM: "SGD"
46 | LR: 0.1
47 | DECAY: 0.1
48 | STEP_SIZE: 1
49 | # Set seeds
50 | OPTIM_SEED: "$OPTIM_SEED"
51 | # Classes split
52 | CLASSES_SPLIT: SPLIT_SEED
53 | # Seed split
54 | SPLIT_SEED: "$SPLIT_SEED"
--------------------------------------------------------------------------------
/methods_config/grip_multimodal_config.yml:
--------------------------------------------------------------------------------
1 | DATASET_DIR: '/users/cmenghin/data/bats/datasets/classification' #'/Users/menga/Desktop/github/zsl_taglets/development' #
2 | DATASET_NAME: "$DATASET_NAME"
3 | # Model
4 | MODALITY: 'multi'
5 | MODEL: "$MODEL"
6 | # Visual encoder
7 | VIS_ENCODER: "$VIS_ENCODER"
8 | # Prompt template
9 | PROMPT_TEMPLATE: 'imported in main.py'
10 | # Number of shots per class in SSL
11 | N_LABEL: 2
12 | ALPHA: 0.3
13 | # Text Prefix size
14 | TEXT_PREFIX_SIZE: 4
15 | # Vision Prefix size
16 | VISION_PREFIX_SIZE: 4
17 | # Pseudolabels
18 | N_PSEUDOSHOTS: 16
19 | STEP_QUANTILE: 10
20 | # Lightweight transformer dim
21 | TRANSFORMER_DIM: 128
22 | # Use VPT-Deep?
23 | VPT_DEEP: False
24 | # Prefix initialization: normal/uniform
25 | VIS_PREFIX_INIT: "normal"
26 | # Initialization mean and variance
27 | MEAN_INIT: 0
28 | VAR_INIT: 0.02
29 | # Seed to separate train and validation
30 | validation_seed: 0
31 | # Ratio validation
32 | ratio_train_val: 0.8
33 | # Batch size
34 | BATCH_SIZE: 16
35 | # Number of epochs
36 | EPOCHS: 150
37 | # Scheduler
38 | SCHEDULER: "cosine"
39 | # Scheduler warmup epochs
40 | WARMUP_EPOCHS: 5
41 | WARMUP_LR: 0.0001
42 | # Number of accumulation iter
43 | ACCUMULATION_ITER: 1
44 | # Optimizer teacher
45 | OPTIM: "SGD"
46 | LR: 0.01
47 | DECAY: 0.1
48 | STEP_SIZE: 1
49 | # Set seeds
50 | OPTIM_SEED: "$OPTIM_SEED"
51 | # Classes split
52 | CLASSES_SPLIT: SPLIT_SEED
53 | # Seed split
54 | SPLIT_SEED: "$SPLIT_SEED"
55 |
--------------------------------------------------------------------------------
/methods_config/grip_textual_config.yml:
--------------------------------------------------------------------------------
1 | DATASET_DIR: '/users/cmenghin/data/bats/datasets/classification' #'/Users/menga/Desktop/github/zsl_taglets/development' # '/users/cmenghin/data/bats/datasets/classification'
2 | DATASET_NAME: "$DATASET_NAME"
3 | # Model
4 | MODALITY: 'text'
5 | MODEL: "$MODEL"
6 | # Visual encoder
7 | VIS_ENCODER: "$VIS_ENCODER"
8 | # Prompt template
9 | PROMPT_TEMPLATE: 'imported in main.py'
10 | # Number of shots per class in SSL
11 | N_LABEL: 2
12 | ALPHA: 0.3
13 | # Prefix size
14 | PREFIX_SIZE: 16
15 | N_PSEUDOSHOTS: 16
16 | STEP_QUANTILE: 10
17 | # Prefix initialization: normal/uniform
18 | VIS_PREFIX_INIT: "normal"
19 | # Initialization mean and variance
20 | MEAN_INIT: 0
21 | VAR_INIT: 0.02
22 | # Seed to separate train and validation
23 | validation_seed: 0
24 | # Ratio validation
25 | ratio_train_val: 0.8
26 | # Batch size
27 | BATCH_SIZE: 16
28 | # Number of epochs
29 | EPOCHS: 150
30 | # Scheduler
31 | SCHEDULER: "cosine"
32 | # Scheduler warmup epochs
33 | WARMUP_EPOCHS: 5
34 | WARMUP_LR: 0.0001
35 | # Number of accumulation iter
36 | ACCUMULATION_ITER: 1
37 | # Optimizer teacher
38 | OPTIM: "SGD"
39 | LR: 0.1
40 | DECAY: 0.1
41 | STEP_SIZE: 1
42 | # Set seeds
43 | OPTIM_SEED: "$OPTIM_SEED"
44 | # Classes split
45 | CLASSES_SPLIT: SPLIT_SEED
46 | # Seed split
47 | SPLIT_SEED: "$SPLIT_SEED"
48 |
--------------------------------------------------------------------------------
/methods_config/grip_visual_config.yml:
--------------------------------------------------------------------------------
1 | DATASET_DIR: '/users/cmenghin/data/bats/datasets/classification' #'/Users/menga/Desktop/github/zsl_taglets/development' #
2 | DATASET_NAME: "$DATASET_NAME"
3 | # Model
4 | MODALITY: 'image'
5 | MODEL: "$MODEL"
6 | # Visual encoder
7 | VIS_ENCODER: "$VIS_ENCODER"
8 | # Prompt template
9 | PROMPT_TEMPLATE: 'imported in main.py'
10 | # Number of shots per class in SSL
11 | N_LABEL: 2
12 | ALPHA: 0.3
13 | # Prefix size
14 | PREFIX_SIZE: 16
15 | N_PSEUDOSHOTS: 16
16 | STEP_QUANTILE: 10
17 | # Prefix initialization: normal/uniform
18 | VIS_PREFIX_INIT: "normal"
19 | # Initialization mean and variance
20 | MEAN_INIT: 0
21 | VAR_INIT: 0.02
22 | # Seed to separate train and validation
23 | validation_seed: 0
24 | # Ratio validation
25 | ratio_train_val: 0.8
26 | # Batch size
27 | BATCH_SIZE: 16
28 | # Number of epochs
29 | EPOCHS: 150
30 | # Scheduler
31 | SCHEDULER: "cosine"
32 | # Scheduler warmup epochs
33 | WARMUP_EPOCHS: 5
34 | WARMUP_LR: 0.0001
35 | # Number of accumulation iter
36 | ACCUMULATION_ITER: 1
37 | # Optimizer teacher
38 | OPTIM: "SGD"
39 | LR: 0.1
40 | DECAY: 0.1
41 | STEP_SIZE: 1
42 | # Set seeds
43 | OPTIM_SEED: "$OPTIM_SEED"
44 | # Classes split
45 | CLASSES_SPLIT: SPLIT_SEED
46 | # Seed split
47 | SPLIT_SEED: "$SPLIT_SEED"
48 |
--------------------------------------------------------------------------------
/methods_config/iterative_multimodal_fpl_config.yml:
--------------------------------------------------------------------------------
1 | DATASET_DIR: "$DATASET_DIR" #'/users/cmenghin/data/bats/datasets/classification' # '/Users/menga/Desktop/github/zsl_taglets/development' #
2 | DATASET_NAME: "$DATASET_NAME"
3 | # Model
4 | MODALITY: 'multi'
5 | MODEL: "$MODEL"
6 | # Visual encoder
7 | VIS_ENCODER: "$VIS_ENCODER"
8 | # Prompt template
9 | PROMPT_TEMPLATE: 'imported in main.py'
10 | # Number of shots per class in SSL
11 | N_LABEL: 2
12 | ALPHA: 0.3
13 | # Text Prefix size
14 | TEXT_PREFIX_SIZE: 4
15 | # Vision Prefix size
16 | VISION_PREFIX_SIZE: 4
17 | # Pseudolabels
18 | N_PSEUDOSHOTS: 16
19 | STEP_QUANTILE: 10
20 | # Lightweight transformer dim
21 | TRANSFORMER_DIM: 128
22 | # Use VPT-Deep?
23 | VPT_DEEP: False
24 | # Prefix initialization: normal/uniform
25 | VIS_PREFIX_INIT: "normal"
26 | # Initialization mean and variance
27 | MEAN_INIT: 0
28 | VAR_INIT: 0.02
29 | # Seed to separate train and validation
30 | validation_seed: 0
31 | # Ratio validation
32 | ratio_train_val: 0.8
33 | # Batch size
34 | BATCH_SIZE: 16
35 | # Number of epochs
36 | EPOCHS: 150
37 | # Scheduler
38 | SCHEDULER: "cosine"
39 | # Scheduler warmup epochs
40 | WARMUP_EPOCHS: 5
41 | WARMUP_LR: 0.0001
42 | # Number of accumulation iter
43 | ACCUMULATION_ITER: 1
44 | # Optimizer teacher
45 | OPTIM: "SGD"
46 | LR: 0.01
47 | DECAY: 0.1
48 | STEP_SIZE: 1
49 | # Set seeds
50 | OPTIM_SEED: "$OPTIM_SEED"
51 | # Classes split
52 | CLASSES_SPLIT: SPLIT_SEED
53 | # Seed split
54 | SPLIT_SEED: "$SPLIT_SEED"
55 |
--------------------------------------------------------------------------------
/methods_config/iterative_textual_fpl_config.yml:
--------------------------------------------------------------------------------
1 | DATASET_DIR: "$DATASET_DIR" # '/users/cmenghin/data/bats/datasets/classification' # '/Users/menga/Desktop/github/zsl_taglets/development' #
2 | DATASET_NAME: "$DATASET_NAME"
3 | # Model
4 | MODALITY: 'text'
5 | MODEL: "$MODEL"
6 | # Visual encoder
7 | VIS_ENCODER: "$VIS_ENCODER"
8 | # Prompt template
9 | PROMPT_TEMPLATE: 'imported in main.py'
10 | # Number of shots per class in SSL
11 | N_LABEL: 2
12 | ALPHA: 0.3
13 | # Prefix size
14 | PREFIX_SIZE: 16
15 | N_PSEUDOSHOTS: 16
16 | STEP_QUANTILE: 10
17 | # Prefix initialization: normal/uniform
18 | VIS_PREFIX_INIT: "normal"
19 | # Initialization mean and variance
20 | MEAN_INIT: 0
21 | VAR_INIT: 0.02
22 | # Seed to separate train and validation
23 | validation_seed: 0
24 | # Ratio validation
25 | ratio_train_val: 0.8
26 | # Batch size
27 | BATCH_SIZE: 16
28 | # Number of epochs
29 | EPOCHS: 150
30 | # Scheduler
31 | SCHEDULER: "cosine"
32 | # Scheduler warmup epochs
33 | WARMUP_EPOCHS: 5
34 | WARMUP_LR: 0.0001
35 | # Number of accumulation iter
36 | ACCUMULATION_ITER: 1
37 | # Optimizer teacher
38 | OPTIM: "SGD"
39 | LR: 0.1
40 | DECAY: 0.1
41 | STEP_SIZE: 1
42 | # Set seeds
43 | OPTIM_SEED: "$OPTIM_SEED"
44 | # Classes split
45 | CLASSES_SPLIT: SPLIT_SEED
46 | # Seed split
47 | SPLIT_SEED: "$SPLIT_SEED"
48 |
--------------------------------------------------------------------------------
/methods_config/iterative_visual_fpl_config.yml:
--------------------------------------------------------------------------------
1 | DATASET_DIR: "$DATASET_DIR" # '/users/cmenghin/data/bats/datasets/classification' # '/Users/menga/Desktop/github/zsl_taglets/development' #
2 | DATASET_NAME: "$DATASET_NAME"
3 | # Model
4 | MODALITY: 'image'
5 | MODEL: "$MODEL"
6 | # Visual encoder
7 | VIS_ENCODER: "$VIS_ENCODER"
8 | # Prompt template
9 | PROMPT_TEMPLATE: 'imported in main.py'
10 | # Number of shots per class in SSL
11 | N_LABEL: 2
12 | ALPHA: 0.3
13 | # Prefix size
14 | PREFIX_SIZE: 16
15 | N_PSEUDOSHOTS: 16
16 | STEP_QUANTILE: 10
17 | # Prefix initialization: normal/uniform
18 | VIS_PREFIX_INIT: "normal"
19 | # Initialization mean and variance
20 | MEAN_INIT: 0
21 | VAR_INIT: 0.02
22 | # Seed to separate train and validation
23 | validation_seed: 0
24 | # Ratio validation
25 | ratio_train_val: 0.8
26 | # Batch size
27 | BATCH_SIZE: 16
28 | # Number of epochs
29 | EPOCHS: 150
30 | # Scheduler
31 | SCHEDULER: "cosine"
32 | # Scheduler warmup epochs
33 | WARMUP_EPOCHS: 5
34 | WARMUP_LR: 0.0001
35 | # Number of accumulation iter
36 | ACCUMULATION_ITER: 1
37 | # Optimizer teacher
38 | OPTIM: "SGD"
39 | LR: 0.1
40 | DECAY: 0.1
41 | STEP_SIZE: 1
42 | # Set seeds
43 | OPTIM_SEED: "$OPTIM_SEED"
44 | # Classes split
45 | CLASSES_SPLIT: SPLIT_SEED
46 | # Seed split
47 | SPLIT_SEED: "$SPLIT_SEED"
48 |
--------------------------------------------------------------------------------
/methods_config/multimodal_fpl_config.yml:
--------------------------------------------------------------------------------
1 | # Dataset name
2 | DATASET_DIR: "$DATASET_DIR" #'/users/cmenghin/data/bats/datasets/classification' # '/Users/menga/Desktop/github/zsl_taglets/development' #
3 | DATASET_NAME: "$DATASET_NAME"
4 | # Model
5 | MODALITY: 'multi'
6 | MODEL: "$MODEL"
7 | # Visual encoder
8 | VIS_ENCODER: "$VIS_ENCODER"
9 | # Prompt template
10 | PROMPT_TEMPLATE: 'imported in main.py'
11 | # Number of shots per class in SSL
12 | N_LABEL: 2
13 | ALPHA: 0.3
14 | # Text Prefix size
15 | TEXT_PREFIX_SIZE: 4
16 | # Vision Prefix size
17 | VISION_PREFIX_SIZE: 4
18 | # Pseudolabels
19 | N_PSEUDOSHOTS: 16
20 | STEP_QUANTILE: 10
21 | # Lightweight transformer dim
22 | TRANSFORMER_DIM: 128
23 | # Use VPT-Deep?
24 | VPT_DEEP: False
25 | # Prefix initialization: normal/uniform
26 | VIS_PREFIX_INIT: "normal"
27 | # Initialization mean and variance
28 | MEAN_INIT: 0
29 | VAR_INIT: 0.02
30 | # Seed to separate train and validation
31 | validation_seed: 0
32 | # Ratio validation
33 | ratio_train_val: 0.8
34 | # Batch size
35 | BATCH_SIZE: 16
36 | # Number of epochs
37 | EPOCHS: 150
38 | # Scheduler
39 | SCHEDULER: "cosine"
40 | # Scheduler warmup epochs
41 | WARMUP_EPOCHS: 5
42 | WARMUP_LR: 0.0001
43 | # Number of accumulation iter
44 | ACCUMULATION_ITER: 1
45 | # Optimizer
46 | OPTIM: "SGD"
47 | LR: 0.01
48 | DECAY: 0.1
49 | STEP_SIZE: 1
50 | # Set seeds
51 | OPTIM_SEED: "$OPTIM_SEED"
52 | # Classes split
53 | CLASSES_SPLIT: SPLIT_SEED
54 | # Seed split
55 | SPLIT_SEED: "$SPLIT_SEED"
--------------------------------------------------------------------------------
/methods_config/multimodal_prompt_config.yml:
--------------------------------------------------------------------------------
1 | # Dataset name
2 | DATASET_DIR: "$DATASET_DIR"
3 | DATASET_NAME: "$DATASET_NAME"
4 | # Model
5 | MODALITY: 'multi'
6 | MODEL: "$MODEL"
7 | # Visual encoder
8 | VIS_ENCODER: "$VIS_ENCODER"
9 | # Prompt template
10 | PROMPT_TEMPLATE: 'imported in main.py'
11 | # Number of shots per class in SSL
12 | N_LABEL: 2
13 | ALPHA: 0.3
14 | # Text Prefix size
15 | TEXT_PREFIX_SIZE: 4
16 | # Vision Prefix size
17 | VISION_PREFIX_SIZE: 4
18 | # Lightweight transformer dim
19 | TRANSFORMER_DIM: 128
20 | # Use VPT-Deep?
21 | VPT_DEEP: False
22 | # Prefix initialization: normal/uniform
23 | VIS_PREFIX_INIT: "normal"
24 | # Initialization mean and variance
25 | MEAN_INIT: 0
26 | VAR_INIT: 0.02
27 | # Seed to separate train and validation
28 | validation_seed: 0
29 | # Ratio validation
30 | ratio_train_val: 0.8
31 | # Batch size
32 | BATCH_SIZE: 16
33 | # Number of epochs
34 | EPOCHS: 150
35 | # Scheduler
36 | SCHEDULER: "cosine"
37 | # Scheduler warmup epochs
38 | WARMUP_EPOCHS: 5
39 | WARMUP_LR: 0.0001
40 | # Number of accumulation iter
41 | ACCUMULATION_ITER: 1
42 | # Optimizer
43 | OPTIM: "SGD"
44 | LR: 0.01
45 | DECAY: 0.1
46 | STEP_SIZE: 1
47 | # Set seeds
48 | OPTIM_SEED: "$OPTIM_SEED"
49 | # Classes split
50 | CLASSES_SPLIT: SPLIT_SEED
51 | # Seed split
52 | SPLIT_SEED: "$SPLIT_SEED"
--------------------------------------------------------------------------------
/methods_config/textual_fpl_config.yml:
--------------------------------------------------------------------------------
1 | DATASET_DIR: "$DATASET_DIR" #'/users/cmenghin/data/bats/datasets/classification' # '/Users/menga/Desktop/github/zsl_taglets/development' #
2 | DATASET_NAME: "$DATASET_NAME"
3 | # Model
4 | MODALITY: 'text'
5 | MODEL: "$MODEL"
6 | # Visual encoder
7 | VIS_ENCODER: "$VIS_ENCODER"
8 | # Prompt template
9 | PROMPT_TEMPLATE: 'imported in main.py'
10 | # Number of shots per class in SSL
11 | N_LABEL: 2
12 | ALPHA: 0.3
13 | # Prefix size
14 | PREFIX_SIZE: 16
15 | N_PSEUDOSHOTS: 16
16 | STEP_QUANTILE: 10
17 | # Prefix initialization: normal/uniform
18 | VIS_PREFIX_INIT: "normal"
19 | # Initialization mean and variance
20 | MEAN_INIT: 0
21 | VAR_INIT: 0.02
22 | # Seed to separate train and validation
23 | validation_seed: 0
24 | # Ratio validation
25 | ratio_train_val: 0.8
26 | # Batch size
27 | BATCH_SIZE: 16
28 | # Number of epochs
29 | EPOCHS: 150
30 | # Scheduler
31 | SCHEDULER: "cosine"
32 | # Scheduler warmup epochs
33 | WARMUP_EPOCHS: 5
34 | WARMUP_LR: 0.0001
35 | # Number of accumulation iter
36 | ACCUMULATION_ITER: 1
37 | # Optimizer teacher
38 | OPTIM: "SGD"
39 | LR: 0.1
40 | DECAY: 0.1
41 | STEP_SIZE: 1
42 | # Set seeds
43 | OPTIM_SEED: "$OPTIM_SEED"
44 | # Classes split
45 | CLASSES_SPLIT: SPLIT_SEED
46 | # Seed split
47 | SPLIT_SEED: "$SPLIT_SEED"
48 |
--------------------------------------------------------------------------------
/methods_config/textual_prompt_config.yml:
--------------------------------------------------------------------------------
1 | # Dataset name
2 | DATASET_DIR: "$DATASET_DIR" # '/Users/menga/Desktop/github/zsl_taglets/development' # '/users/cmenghin/data/bats/datasets/classification'
3 | DATASET_NAME: "$DATASET_NAME"
4 | # Model
5 | MODALITY: 'text'
6 | MODEL: "$MODEL"
7 | # Visual encoder
8 | VIS_ENCODER: "$VIS_ENCODER"
9 | # Prompt template
10 | PROMPT_TEMPLATE: 'imported in main.py'
11 | # Number of shots per class in SSL
12 | N_LABEL: 2
13 | ALPHA: 0.3
14 | # Prefix size
15 | PREFIX_SIZE: 16
16 | # Prefix initialization: normal/uniform
17 | VIS_PREFIX_INIT: "normal"
18 | # Initialization mean and variance
19 | MEAN_INIT: 0
20 | VAR_INIT: 0.02
21 | # Seed to separate train and validation
22 | validation_seed: 0
23 | # Ratio validation
24 | ratio_train_val: 0.8
25 | # Batch size
26 | BATCH_SIZE: 16
27 | # Number of epochs
28 | EPOCHS: 150
29 | # Scheduler
30 | SCHEDULER: "cosine"
31 | # Scheduler warmup epochs
32 | WARMUP_EPOCHS: 5
33 | WARMUP_LR: 0.0001
34 | # Number of accumulation iter
35 | ACCUMULATION_ITER: 1
36 | # Optimizer
37 | OPTIM: "SGD"
38 | LR: 0.1
39 | DECAY: 0.1
40 | STEP_SIZE: 1
41 | # Set seeds
42 | OPTIM_SEED: "$OPTIM_SEED"
43 | # Classes split
44 | CLASSES_SPLIT: SPLIT_SEED
45 | # Seed split
46 | SPLIT_SEED: "$SPLIT_SEED"
--------------------------------------------------------------------------------
/methods_config/visual_fpl_config.yml:
--------------------------------------------------------------------------------
1 | DATASET_DIR: "$DATASET_DIR" #'/users/cmenghin/data/bats/datasets/classification' # '/Users/menga/Desktop/github/zsl_taglets/development' #
2 | DATASET_NAME: "$DATASET_NAME"
3 | # Model
4 | MODALITY: 'image'
5 | MODEL: "$MODEL"
6 | # Visual encoder
7 | VIS_ENCODER: "$VIS_ENCODER"
8 | # Prompt template
9 | PROMPT_TEMPLATE: 'imported in main.py'
10 | # Number of shots per class in SSL
11 | N_LABEL: 2
12 | ALPHA: 0.3
13 | # Prefix size
14 | PREFIX_SIZE: 16
15 | N_PSEUDOSHOTS: 16
16 | STEP_QUANTILE: 10
17 | # Prefix initialization: normal/uniform
18 | VIS_PREFIX_INIT: "normal"
19 | # Initialization mean and variance
20 | MEAN_INIT: 0
21 | VAR_INIT: 0.02
22 | # Seed to separate train and validation
23 | validation_seed: 0
24 | # Ratio validation
25 | ratio_train_val: 0.8
26 | # Batch size
27 | BATCH_SIZE: 16
28 | # Number of epochs
29 | EPOCHS: 150
30 | # Scheduler
31 | SCHEDULER: "cosine"
32 | # Scheduler warmup epochs
33 | WARMUP_EPOCHS: 5
34 | WARMUP_LR: 0.0001
35 | # Number of accumulation iter
36 | ACCUMULATION_ITER: 1
37 | # Optimizer teacher
38 | OPTIM: "SGD"
39 | LR: 0.1
40 | DECAY: 0.1
41 | STEP_SIZE: 1
42 | # Set seeds
43 | OPTIM_SEED: "$OPTIM_SEED"
44 | # Classes split
45 | CLASSES_SPLIT: SPLIT_SEED
46 | # Seed split
47 | SPLIT_SEED: "$SPLIT_SEED"
48 |
--------------------------------------------------------------------------------
/methods_config/visual_prompt_config.yml:
--------------------------------------------------------------------------------
1 | # Dataset name
2 | DATASET_DIR: "$DATASET_DIR" # '/Users/menga/Desktop/github/zsl_taglets/development' #
3 | DATASET_NAME: "$DATASET_NAME"
4 | # Model
5 | MODALITY: 'image'
6 | MODEL: "$MODEL"
7 | # Visual encoder
8 | VIS_ENCODER: "$VIS_ENCODER"
9 | # Prompt template
10 | PROMPT_TEMPLATE: 'imported in main.py'
11 | # Number of shots per class in SSL
12 | N_LABEL: 2
13 | ALPHA: 0.3
14 | # Prefix size
15 | PREFIX_SIZE: 16
16 | # Prefix initialization: normal/uniform
17 | VIS_PREFIX_INIT: "normal"
18 | # Initialization mean and variance
19 | MEAN_INIT: 0
20 | VAR_INIT: 0.02
21 | # Seed to separate train and validation
22 | validation_seed: 0
23 | # Ratio validation
24 | ratio_train_val: 0.8
25 | # Batch size
26 | BATCH_SIZE: 16
27 | # Number of epochs
28 | EPOCHS: 150
29 | # Scheduler
30 | SCHEDULER: "cosine"
31 | # Scheduler warmup epochs
32 | WARMUP_EPOCHS: 5
33 | WARMUP_LR: 0.0001
34 | # Number of accumulation iter
35 | ACCUMULATION_ITER: 1
36 | # Optimizer
37 | OPTIM: "SGD"
38 | LR: 0.1
39 | DECAY: 0.1
40 | STEP_SIZE: 1
41 | # Set seeds
42 | OPTIM_SEED: "$OPTIM_SEED"
43 | # Classes split
44 | CLASSES_SPLIT: SPLIT_SEED
45 | # Seed split
46 | SPLIT_SEED: "$SPLIT_SEED"
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .clip_encoders import (
2 | CustomImageEncoder,
3 | CustomTextEncoder,
4 | ImageEncoder,
5 | TextEncoder,
6 | )
7 | from .prompts_models import (
8 | ImagePrefixModel,
9 | TextPrefixModel,
10 | UPTModel,
11 | )
12 |
--------------------------------------------------------------------------------
/models/clip_encoders.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import logging
3 | import os.path as osp
4 |
5 | import torch
6 | import torch.nn as nn
7 | from clip import clip
8 |
9 |
10 | log = logging.getLogger(__name__)
11 |
12 |
13 | class TextEncoder(nn.Module):
14 | """CLIP text encoder"""
15 |
16 | def __init__(self, clip_model):
17 | super(TextEncoder, self).__init__()
18 | self.clip_model = clip_model
19 |
20 | def forward(self, text):
21 | encoded_text = self.clip_model.encode_text(text)
22 | return encoded_text
23 |
24 |
25 | class CustomTextEncoder(torch.nn.Module):
26 | """This class is adapted from the codebase of "Learning to Compose Soft Prompts for Compositional Zero-Shot Learning"
27 | https://github.com/BatsResearch/csp/blob/main/clip_modules/text_encoder.py"""
28 |
29 | def __init__(self, clip_model, device, dtype):
30 | super(CustomTextEncoder, self).__init__()
31 | self.dtype = dtype
32 | self.clip_model = clip_model
33 | self.transformer = clip_model.transformer
34 | self.positional_embedding = clip_model.positional_embedding
35 | self.ln_final = clip_model.ln_final
36 | self.text_projection = clip_model.text_projection
37 | self.token_embedding = clip_model.token_embedding
38 | self.device = device
39 |
40 | def tokenize(self, text):
41 | return torch.cat([clip.tokenize(tok) for tok in text])
42 |
43 | def forward(self, class_embeddings, classes, enable_pos_emb=True):
44 | """The forward function to compute representations for the prompts.
45 |
46 | :param class_embeddings: learned prefix embeddings spliced into each class prompt
47 | :param classes: list of class names (already preprocessed, without underscores)
48 | :param enable_pos_emb: if True, add positional embeddings to account for token order
49 |
50 | Returns:
51 | torch.Tensor: the vector representation of the prompt.
52 | """
53 |
54 | prompts = [
55 | " ".join([" ".join(["X"] * (class_embeddings.size()[1])).strip(), c])
56 | for c in classes
57 | ]
58 | # log.info(f"Extended text: {prompts}")
59 |
60 | token_ids = clip.tokenize(prompts)
61 |
62 | # Get embeddings for the prompt
63 | text_embedding = self.token_embedding(token_ids.to(self.device))
64 | # Replace the "X" placeholder embeddings (positions 1..prefix_len)
65 | # with the learned prefix for every class prompt at once:
66 |
67 | text_embedding[:, 1 : (class_embeddings[0].size()[0] + 1), :] = class_embeddings
68 |
69 | text_features = text_embedding.type(self.dtype)
70 | x = (
71 | text_features + self.positional_embedding.type(self.dtype)
72 | if enable_pos_emb
73 | else text_features
74 | )
75 | x = x.permute(1, 0, 2)
76 |
77 |
78 | # On CUDA the transformer runs in the model's fp16 dtype;
79 | # on CPU it expects fp32, so upcast the activations first.
80 | if torch.cuda.is_available():
81 | x = self.transformer(x)
82 | else:
83 | x = self.transformer(x.float())
84 | x = x.permute(1, 0, 2)
85 | x = self.ln_final(x)
86 | tf = (
87 | x[torch.arange(x.shape[0]), token_ids.argmax(dim=-1)]  # features at the EOT token position
88 | @ self.text_projection
89 | )
90 | return tf
91 |
92 |
93 | class ImageEncoder(nn.Module):
94 | """CLIP image encoder"""
95 |
96 | def __init__(self, clip_model):
97 | super(ImageEncoder, self).__init__()
98 | self.clip_model = clip_model
99 |
100 | def forward(self, image):
101 | encoded_image = self.clip_model.encode_image(image)
102 | return encoded_image
103 |
104 |
105 | class CustomVisionTransformer(nn.Module):
106 | def __init__(self, vision_transformer):
107 | super().__init__()
108 | self.input_resolution = vision_transformer.input_resolution
109 | self.output_dim = vision_transformer.output_dim
110 | self.conv1 = vision_transformer.conv1
111 |
112 | self.class_embedding = vision_transformer.class_embedding
113 | self.positional_embedding = vision_transformer.positional_embedding
114 | self.ln_pre = vision_transformer.ln_pre
115 |
116 | self.transformer = vision_transformer.transformer
117 |
118 | self.ln_post = vision_transformer.ln_post
119 | self.proj = vision_transformer.proj
120 |
121 |
122 |
123 | def forward(
124 | self,
125 | x: torch.Tensor,
126 | image_prefix: torch.Tensor,
127 | pos_emb=True,
128 | deep_embs=None,
129 | ):
130 |
131 | x = self.conv1(x) # shape = [*, width, grid, grid]
132 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
133 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
134 |
135 | x = torch.cat(
136 | [
137 | self.class_embedding.to(x.dtype).to(x.device)
138 | + torch.zeros(
139 | x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
140 | ),
141 | x,
142 | ],
143 | dim=1,
144 | ) # shape = [*, grid ** 2 + 1, width]
145 |
146 | x = x + self.positional_embedding.to(x.dtype) if pos_emb else x
147 |
148 | image_prefix = image_prefix.expand(x.shape[0], -1, -1)
149 | # Insert the prefix between the CLS token and the patch tokens
150 | x = torch.cat([
151 | x[:,:1,:],
152 | image_prefix,
153 | x[:,1:,:],
154 | ],
155 | dim=1,)
156 |
157 | # x = torch.cat([
158 | # image_prefix,
159 | # x,
160 | # ],
161 | # dim=1,)
162 |
163 | x = self.ln_pre(x)
164 |
165 | x = x.permute(1, 0, 2) # NLD -> LND
166 | if deep_embs is not None:
167 | # x is already in LND order here, so the batch dimension is dim 1
168 | B = x.shape[1]
169 | # Deep prompts: refresh the prefix tokens before each block. Assumes
170 | # deep_embs is (n_layers, n_prefix, width), already in encoder width.
171 | # Code adapted from https://github.com/sIncerass/MVLPT/blob/main/trainers/mvlpt.py#L65
172 | for layer_idx in range(self.transformer.layers):
173 | layer = self.transformer.resblocks[layer_idx]
174 | if 1 <= layer_idx <= deep_embs.shape[0]:
175 | vpt_emb_deep = deep_embs[layer_idx - 1].expand(B, -1, -1).to(x.dtype)
176 | vpt_emb_deep = vpt_emb_deep.permute(1, 0, 2) # NLD -> LND
177 | x = torch.cat((
178 | x[:1, :, :],
179 | vpt_emb_deep,
180 | x[(1 + vpt_emb_deep.shape[0]):, :, :]
181 | ), dim=0)
182 | x = layer(x)
183 |
184 | else:
185 | x = self.transformer(x)
186 |
187 | x = x.permute(1, 0, 2) # LND -> NLD
188 |
189 | x = self.ln_post(x[:, 0, :])
190 |
191 | if self.proj is not None:
192 | x = x @ self.proj
193 |
194 | return x
195 |
196 |
197 |
198 | class CustomImageEncoder(nn.Module):
199 | """CLIP image encoder"""
200 |
201 | def __init__(self, visual):
202 | super(CustomImageEncoder, self).__init__()
203 | self.visual = CustomVisionTransformer(visual)
204 | self.dtype = self.visual.conv1.weight.dtype
205 |
206 | def forward(self, image, prefix, deep_embs=None):
207 | encoded_image = self.visual(image.type(self.dtype), prefix.type(self.dtype), deep_embs=deep_embs)
208 | return encoded_image
209 |
--------------------------------------------------------------------------------
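To make the prompt splicing above concrete: TextPrefixModel (defined in the next file) passes its learned prefix to CustomTextEncoder, which builds prompts like "X X ... X river" and overwrites the "X" token embeddings with the prefix. A minimal usage sketch; the prefix shape (1, 16, 512) is an assumption matching PREFIX_SIZE in the configs and ViT-B/32's 512-wide token embeddings:

    import clip
    import torch

    from models import CustomTextEncoder, TextPrefixModel

    device = "cuda" if torch.cuda.is_available() else "cpu"
    clip_model, _ = clip.load("ViT-B/32", device=device)

    # One shared soft prefix, broadcast across all class prompts in forward()
    prefix = 0.02 * torch.randn(1, 16, 512)
    text_encoder = CustomTextEncoder(clip_model, device, clip_model.dtype)
    model = TextPrefixModel(prefix, text_encoder,
                            classes=["forest", "river"], device=device).to(device)

    features = model(model.classes)
    print(features.shape)  # torch.Size([2, 512]), one embedding per class
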
/models/prompts_models.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import clip
4 | import torch
5 | from torch import nn
6 |
7 | log = logging.getLogger(__name__)
8 |
9 |
10 | class TextPrefixModel(nn.Module):
11 | def __init__(
12 | self, initial_prefix, text_encoder, classes, temperature=0.07, device="cpu"
13 | ):
14 | """Define the model for textual prompt tuning.
15 |
16 | :param initial_prefix: initial tensor of floats for the soft prefix
17 | :param text_encoder: text encoder to use
18 | :param classes: list of classes' names
19 | :param temperature: fixed parameter, same as CLIP
20 | :param device: device in use
21 | """
22 |
23 | super(TextPrefixModel, self).__init__()
24 | self.device = device
25 | self.initialized_prefix = initial_prefix
26 | self.classes = classes
27 |
28 | self.prefix = nn.Parameter(initial_prefix)
29 | self.text_encoder = text_encoder
30 |
31 | def forward(self, classes):
32 | out = self.text_encoder(self.prefix, classes)
33 | # Note: the encoded features are returned unnormalized;
34 | # normalization is left to the caller.
35 |
36 | return out
37 |
38 |
39 | class ImagePrefixModel(nn.Module):
40 | def __init__(
41 | self,
42 | initial_prefix,
43 | image_encoder,
44 | temperature=0.07,
45 | device="cpu",
46 | ):
47 | super(ImagePrefixModel, self).__init__()
48 | self.device = device
49 | self.initialized_prefix = initial_prefix
50 |
51 | # Initialize the model's parameters
52 | self.prefix = nn.Parameter(initial_prefix)
53 | self.image_encoder = image_encoder
54 |
55 | def forward(self, x):
56 | # Prepend the learned visual prefix to the image patches and encode.
57 | out = self.image_encoder(x, self.prefix)
58 | # Note: the encoded features are returned unnormalized;
59 | # normalization is left to the caller.
60 |
61 | return out
62 |
63 |
64 | class UPTModel(nn.Module):
65 | def __init__(
66 | self,
67 | coop_embeddings,
68 | vpt_embeddings,
69 | vpt_embeddings_deep,
70 | image_encoder,
71 | text_encoder,
72 | classes,
73 | dim_transformer,
74 | temperature=0.07,
75 | device="cpu",
76 | dtype=torch.float32,
77 | ):
78 | super(UPTModel, self).__init__()
79 | self.device = device
80 | self.classes = classes
81 | self.temperature = temperature
82 | self.dtype = dtype
83 |
84 | # Initialize the model's parameters
85 | self.coop_embeddings = nn.Parameter(coop_embeddings)
86 | self.vpt_embeddings = nn.Parameter(vpt_embeddings)
87 |
88 | self.coop_length = self.coop_embeddings.size()[1]
89 | self.coop_dim = self.coop_embeddings.size()[2]
90 |
91 | self.vpt_length = self.vpt_embeddings.size()[1]
92 | self.vpt_dim = self.vpt_embeddings.size()[2]
93 |
94 | if vpt_embeddings_deep is not None:
95 | self.vpt_embeddings_deep = nn.Parameter(vpt_embeddings_deep)
96 | else:
97 | self.vpt_embeddings_deep = None
98 |
99 | self.proj_coop_pre = nn.Linear(
100 | self.coop_dim,
101 | dim_transformer,
102 | dtype=self.dtype).to(self.device)
103 | self.proj_coop_post = nn.Linear(
104 | dim_transformer,
105 | self.coop_dim,
106 | dtype=self.dtype).to(self.device)
107 | self.proj_vpt_pre = nn.Linear(
108 | self.vpt_dim,
109 | dim_transformer,
110 | dtype=self.dtype).to(self.device)
111 | self.proj_vpt_post = nn.Linear(
112 | dim_transformer,
113 | self.vpt_dim,
114 | dtype=self.dtype).to(self.device)
115 |
116 | self.transformer = clip.model.Transformer(
117 | width=dim_transformer,
118 | layers=1,
119 | heads=1).to(self.device)
120 |
121 | self.image_encoder = image_encoder
122 | self.text_encoder = text_encoder
123 |
124 | # Given coop_embeddings, vpt_embeddings, and vpt_embeddings_deep
125 | # - Project into 128 dim space
126 | # - Run sequence through transformer
127 | # - Project back to CLIP (512) dim space
128 | # (Error when there is no input arg. https://github.com/pytorch/pytorch/pull/37902)
129 | def forward(self, x, classes):
130 | # First, project both prompts into the shared transformer space
131 | coop_embds = self.proj_coop_pre(self.coop_embeddings).to(self.device)
132 | if self.vpt_embeddings_deep is not None:
133 | vpt_embds = torch.cat((self.vpt_embeddings, self.vpt_embeddings_deep), dim=0).to(self.device)
134 | else:
135 | vpt_embds = self.vpt_embeddings
136 | vpt_embds = self.proj_vpt_pre(vpt_embds).to(self.device)
137 | # Concat CoOp and VPT prompts into a single sequence
138 | prompt_seq = torch.cat((coop_embds, vpt_embds), dim=0).to(torch.float32) # TODO: Fix hacky type change
139 |
140 | # Then, we run the sequence through the transformer
141 | output_seq = self.transformer(prompt_seq).to(torch.float16) # TODO: Fix hacky type change
142 |
143 | # Finally, we project the seq back into prompt space
144 | coop_embs = self.proj_coop_post(output_seq[:len(self.coop_embeddings)].to(self.dtype)).reshape(-1, self.coop_length, self.coop_dim)
145 | vpt_embs = self.proj_vpt_post(output_seq[len(self.coop_embeddings):].to(self.dtype)).reshape(-1, self.vpt_length, self.vpt_dim)
146 | vpt_emb_deep = None if vpt_embs.shape[0] == 1 else vpt_embs[1:, :, :]
147 |
148 | text_out = self.text_encoder(coop_embs, classes)
149 | # Pass the shallow visual prompt as the prefix and, when present,
150 | # the deeper prompts for the remaining transformer blocks.
151 | visual_out = self.image_encoder(x, vpt_embs[:1], vpt_emb_deep)
152 | # Note: both feature sets are returned unnormalized.
153 | return text_out, visual_out
--------------------------------------------------------------------------------
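As a shape walk-through of the project/transform/project pipeline described in the comment above UPTModel.forward, here is a standalone sketch (the sizes are assumptions for illustration, not values fixed by the code: 16 prompt tokens, 512-d text space, 768-d visual space, and a 128-d shared space):

    import torch
    from torch import nn

    coop = torch.randn(1, 16, 512)  # textual prompt: (1, coop_length, coop_dim)
    vpt = torch.randn(1, 16, 768)   # visual prompt:  (1, vpt_length, vpt_dim)

    # Project both prompts into the shared transformer width and stack them
    seq = torch.cat([nn.Linear(512, 128)(coop), nn.Linear(768, 128)(vpt)], dim=0)
    print(seq.shape)  # torch.Size([2, 16, 128])

    # (the shared one-layer transformer runs here in the real model)

    # Project each half back into its own prompt space
    coop_out = nn.Linear(128, 512)(seq[:1])  # (1, 16, 512) -> text encoder
    vpt_out = nn.Linear(128, 768)(seq[1:])   # (1, 16, 768) -> image encoder
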
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.15.0
2 | git+https://github.com/openai/CLIP.git
3 | numpy
4 | pandas
5 | Pillow
6 | PyYAML
7 | requests
8 | scipy
9 | tqdm
10 | urllib3
11 |
--------------------------------------------------------------------------------
/run_main_clip.py:
--------------------------------------------------------------------------------
1 | import methods.main_CLIP as pipeline
2 |
3 | if __name__ == '__main__':
4 | pipeline.main()
--------------------------------------------------------------------------------
/run_main_ssl.py:
--------------------------------------------------------------------------------
1 | import methods.main_SSL as pipeline
2 |
3 | if __name__ == '__main__':
4 | pipeline.main()
--------------------------------------------------------------------------------
/run_main_trzsl.py:
--------------------------------------------------------------------------------
1 | import methods.main_TRZSL as pipeline
2 |
3 | if __name__ == '__main__':
4 | pipeline.main()
--------------------------------------------------------------------------------
/run_main_ul.py:
--------------------------------------------------------------------------------
1 | import methods.main_UL as pipeline
2 |
3 | if __name__ == '__main__':
4 | pipeline.main()
--------------------------------------------------------------------------------
/scripts/run_clip.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | for dataset_dir in '/path/data' ; do # add here the path to the folder containing dataset folders
4 | for vis_encoder in 'ViT-B/32'; do # You can choose among 'ViT-B/32' and 'ViT-L/14'
5 | for split_seed in 500; do # This indicates the TRZSL split, i.e., 500, 0, or 200. For the other learning settings the default is 500.
6 | for dataset_name in RESICS45; do # DTD Flowers102 EuroSAT FGVCAircraft MNIST
7 | for model in clip; do
8 | for optim_seed in 1; do # Plain CLIP inference does not need multiple seeds
9 |
10 | export OPTIM_SEED="$optim_seed"
11 | export VIS_ENCODER="$vis_encoder"
12 | export DATASET_NAME="$dataset_name"
13 | export SPLIT_SEED="$split_seed"
14 | export MODEL="$model"
15 | export DATASET_DIR="$dataset_dir"
16 |
17 | python3 ./run_main_clip.py \
18 | --model_config ${model}_config.yml \
19 | --learning_paradigm trzsl # Choose among ul, ssl, and trzsl
20 |
21 | done
22 | done
23 | done
24 | done
25 | done
26 | done
--------------------------------------------------------------------------------
/scripts/run_prompts_ssl.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | for dataset_dir in '/path/data' ; do # add here the path to the folder containing dataset folders
4 | for vis_encoder in 'ViT-B/32'; do # You can choose among 'ViT-B/32' and 'ViT-L/14'
5 | for split_seed in 500; do # This indicates the TRZSL split, i.e., 500, 0, or 200. For the other learning settings the default is 500.
6 | for dataset_name in RESICS45; do # DTD Flowers102 EuroSAT FGVCAircraft MNIST
7 | for model in textual_prompt visual_prompt multimodal_prompt; do
8 | for optim_seed in 1; do # 1 2 3 4 5 are the seeds we used
9 |
10 | export OPTIM_SEED="$optim_seed"
11 | export VIS_ENCODER="$vis_encoder"
12 | export DATASET_NAME="$dataset_name"
13 | export SPLIT_SEED="$split_seed"
14 | export MODEL="$model"
15 | export DATASET_DIR="$dataset_dir"
16 |
17 | # Set the accelerate configuration file to accelerate_config.yml when running on GPUs
18 | accelerate launch --config_file methods_config/accelerate_localtest_config.yml run_main_ssl.py --model_config ${model}_config.yml \
19 | --learning_paradigm ssl # Choose among ul, ssl, and trzsl
20 |
21 | done
22 | done
23 | done
24 | done
25 | done
26 | done
--------------------------------------------------------------------------------
/scripts/run_prompts_trzsl.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | for dataset_dir in '/path/data' ; do # add here the path to the folder containing dataset folders
4 | for vis_encoder in 'ViT-B/32'; do # You can choose among 'ViT-B/32' and 'ViT-L/14'
5 | for split_seed in 500; do # This indicates the TRZSL split, i.e., 500, 0, or 200. For the other learning settings the default is 500.
6 | for dataset_name in RESICS45; do # DTD Flowers102 EuroSAT FGVCAircraft MNIST
7 | for model in textual_prompt visual_prompt multimodal_prompt; do
8 | for optim_seed in 1; do # 1 2 3 4 5 are the seeds we used
9 |
10 | export OPTIM_SEED="$optim_seed"
11 | export VIS_ENCODER="$vis_encoder"
12 | export DATASET_NAME="$dataset_name"
13 | export SPLIT_SEED="$split_seed"
14 | export MODEL="$model"
15 | export DATASET_DIR="$dataset_dir"
16 |
17 | # Set the accelerate configuration file to accelerate_config.yml when running on GPUs
18 | accelerate launch --config_file methods_config/accelerate_localtest_config.yml run_main_trzsl.py --model_config ${model}_config.yml \
19 | --learning_paradigm trzsl # Choose among ul, ssl, and trzsl
20 |
21 | done
22 | done
23 | done
24 | done
25 | done
26 | done
--------------------------------------------------------------------------------
/scripts/run_pseudolabels_ssl.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | for dataset_dir in '/path/data' ; do # add here the path to the folder containing dataset folders
4 | for vis_encoder in 'ViT-B/32'; do # You can choose among 'ViT-B/32' and 'ViT-L/14'
5 | for split_seed in 500; do # This indicates the TRZSL split, i.e., 500, 0, or 200. For the other learning settings the default is 500.
6 | for dataset_name in RESICS45; do # DTD Flowers102 EuroSAT FGVCAircraft MNIST
7 | for model in textual_fpl; do # Choose among: textual_fpl, visual_fpl, multimodal_fpl, iterative_textual_fpl, iterative_visual_fpl, iterative_multimodal_fpl, grip_textual, grip_visual, grip_multimodal
8 | for optim_seed in 1; do # 1 2 3 4 5 are the seeds we used
9 |
10 | export OPTIM_SEED="$optim_seed"
11 | export VIS_ENCODER="$vis_encoder"
12 | export DATASET_NAME="$dataset_name"
13 | export SPLIT_SEED="$split_seed"
14 | export MODEL="$model"
15 | export DATASET_DIR="$dataset_dir"
16 |
17 | # Set the accelerate configuration file to accelerate_config.yml when running on GPUs
18 | accelerate launch --config_file methods_config/accelerate_localtest_config.yml run_main_ssl.py --model_config ${model}_config.yml \
19 | --learning_paradigm ssl # Choose among ul, ssl, and trzsl
20 |
21 | done
22 | done
23 | done
24 | done
25 | done
26 | done
--------------------------------------------------------------------------------
/scripts/run_pseudolabels_trzsl.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | for dataset_dir in '/path/data' ; do # add here the path to the folder containing dataset folders
4 | for vis_encoder in 'ViT-B/32'; do # You can choose among 'ViT-B/32' and 'ViT-L/14'
5 | for split_seed in 500; do # This indicates the TRZSL split, i.e., 500, 0, or 200. For the other learning settings the default is 500.
6 | for dataset_name in RESICS45; do # DTD Flowers102 EuroSAT FGVCAircraft MNIST
7 | for model in textual_fpl; do # Choose among: textual_fpl, visual_fpl, multimodal_fpl, iterative_textual_fpl, iterative_visual_fpl, iterative_multimodal_fpl, grip_textual, grip_visual, grip_multimodal
8 | for optim_seed in 1; do # 1 2 3 4 5 are the seeds we used
9 |
10 | export OPTIM_SEED="$optim_seed"
11 | export VIS_ENCODER="$vis_encoder"
12 | export DATASET_NAME="$dataset_name"
13 | export SPLIT_SEED="$split_seed"
14 | export MODEL="$model"
15 | export DATASET_DIR="$dataset_dir"
16 |
17 | # Set the accelerate configuration file to accelerate_config.yml when running on GPUs
18 | accelerate launch --config_file methods_config/accelerate_localtest_config.yml run_main_trzsl.py --model_config ${model}_config.yml \
19 | --learning_paradigm trzsl # Choose among ul, ssl, and trzsl
20 |
21 | done
22 | done
23 | done
24 | done
25 | done
26 | done
--------------------------------------------------------------------------------
/scripts/run_pseudolabels_ul.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | for dataset_dir in '/path/data' ; do # add here the path to the folder containing dataset folders
4 | for vis_encoder in 'ViT-B/32'; do # You can choose among 'ViT-B/32' and 'ViT-L/14'
5 | for split_seed in 500; do # This indicates the TRZSL split, i.e., 500, 0, or 200. For the other learning settings the default is 500.
6 | for dataset_name in RESICS45; do # DTD Flowers102 EuroSAT FGVCAircraft MNIST
7 | for model in textual_fpl; do # Choose among: textual_fpl, visual_fpl, multimodal_fpl, iterative_textual_fpl, iterative_visual_fpl, iterative_multimodal_fpl, grip_textual, grip_visual, grip_multimodal
8 | for optim_seed in 1; do # 1 2 3 4 5 are the seeds we used
9 |
10 | export OPTIM_SEED="$optim_seed"
11 | export VIS_ENCODER="$vis_encoder"
12 | export DATASET_NAME="$dataset_name"
13 | export SPLIT_SEED="$split_seed"
14 | export MODEL="$model"
15 | export DATASET_DIR="$dataset_dir"
16 |
17 | # Set the accelerate configuration file to accelerate_config.yml when running on GPUs
18 | accelerate launch --config_file methods_config/accelerate_localtest_config.yml run_main_ul.py --model_config ${model}_config.yml \
19 | --learning_paradigm ul # Choose among ul, ssl, and trzsl
20 |
21 | done
22 | done
23 | done
24 | done
25 | done
26 | done
--------------------------------------------------------------------------------
/setup.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | pip install --upgrade pip
4 | pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html
5 | pip install -r requirements.txt
6 | pip install importlib_metadata
7 | pip install pytorch-metric-learning
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .clip_pseudolabels import pseudolabel_top_k
2 | from .compute_metrics import (
3 | evaluate_predictions,
4 | store_results,
5 | save_parameters,
6 | save_predictions,
7 | save_pseudo_labels,
8 | )
9 | from .prepare_data import (
10 | get_class_names,
11 | get_labeled_and_unlabeled_data,
12 | )
13 | from .schedulers import make_scheduler
14 | from .utils import (
15 | Config,
16 | dataset_object,
17 | seed_worker,
18 | )
19 |
--------------------------------------------------------------------------------
/utils/clip_pseudolabels.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import pickle
4 |
5 | import clip
6 | import torch
7 | from PIL import Image
8 | from tqdm import tqdm
9 |
10 | log = logging.getLogger(__name__)
11 |
12 |
13 | def compute_pseudo_labels(
14 | k,
15 | template,
16 | dataset,
17 | classnames,
18 | transform,
19 | clip_model,
20 | label_to_idx,
21 | device,
22 | filename,
23 | ):
24 | prompts = [f"{template}{' '.join(i.split('_'))}" for i in classnames]
25 | text = clip.tokenize(prompts).to(device)
26 |
27 | if k == 10000000: # sentinel value: pseudo-label every unlabeled example
28 | log.info("Computing pseudo-labels on all unlabeled data")
29 | new_labels = []
30 | new_imgs = []
31 | for i, image_path in enumerate(tqdm(dataset.filepaths)):
32 | img = Image.open(image_path).convert("RGB")
33 | img = transform(img).to(device)
34 | with torch.no_grad():
35 | logits_per_image, logits_per_text = clip_model(
36 | torch.unsqueeze(img, 0).to(device), text
37 | )
38 | probs = logits_per_image.softmax(dim=-1)
39 | idx_preds = torch.argmax(probs, dim=1)
40 | pred_id = idx_preds.item()
41 | pred = label_to_idx[classnames[idx_preds.item()]]
42 |
43 | new_labels.append(pred)
44 | new_imgs.append(image_path)
45 |
46 | else:
47 |
48 | # To find the top-k for each class, each class has its own "leaderboard"
49 | top_k_leaderboard = {
50 | label_to_idx[classnames[i]]: [] for i in range(len(classnames))
51 | } # maps class idx -> (confidence, image_path) tuple
52 |
53 | log.info(f"Compute {k} pseudo-labeles")
54 | # log.info(f"{label_to_idx}")
55 | for i, image_path in enumerate(tqdm(dataset.filepaths)):
56 | img = Image.open(image_path).convert("RGB")
57 | img = transform(img).to(device)
58 | with torch.no_grad():
59 | logits_per_image, logits_per_text = clip_model(
60 | torch.unsqueeze(img, 0).to(device), text
61 | )
62 | probs = logits_per_image.softmax(dim=-1)
63 | idx_preds = torch.argmax(probs, dim=1)
64 | pred_id = idx_preds.item()
65 | pred = label_to_idx[classnames[idx_preds.item()]]
66 | # log.info(f"{classnames[idx_preds.item()]}")
67 | # log.info(f"{pred}")
68 |
69 | # If the predicted class has room on its leaderboard, or if the
70 | # confidence is high enough to qualify for its top-k, add the
71 | # new example.
72 | prob_score = probs[0][pred_id]
73 | if len(top_k_leaderboard[pred]) < k:
74 | # keep the leaderboard sorted so the [-1] check below sees the minimum
75 | top_k_leaderboard[pred] = sorted(top_k_leaderboard[pred] + [(prob_score, image_path)], reverse=True)
75 | elif (
76 | top_k_leaderboard[pred][-1][0] < prob_score
77 | ): # if the confidence in predicted class "qualifies" for top-k
78 | # default sorting of tuples is by first element
79 | top_k_leaderboard[pred] = sorted(
80 | top_k_leaderboard[pred] + [(probs[0][pred_id], image_path)],
81 | reverse=True,
82 | )[:k]
83 | else:
84 | # sort the other classes by confidence score
85 | order_of_classes = sorted(
86 | [(probs[0][j], j) for j in range(len(classnames)) if j != pred_id],
87 | reverse=True,
88 | )
89 | for score, index in order_of_classes:
90 | index_dict = label_to_idx[classnames[index]]
91 | # log.info(f"{classnames[index]}")
92 | # log.info(f"{index_dict}")
93 | if len(top_k_leaderboard[index_dict]) < k:
94 | top_k_leaderboard[index_dict] = sorted(top_k_leaderboard[index_dict] + [(probs[0][index], image_path)], reverse=True)
95 | elif top_k_leaderboard[index_dict][-1][0] < probs[0][index]:
96 | # default sorting of tuples is by first element
97 | top_k_leaderboard[index_dict] = sorted(
98 | top_k_leaderboard[index_dict]
99 | + [(probs[0][index], image_path)],
100 | reverse=True,
101 | )[:k]
102 |
103 | new_imgs = []
104 | new_labels = []
105 | # loop through, and rebuild the dataset
106 | for index, leaderboard in top_k_leaderboard.items():
107 | # print(len(dataset.imgs))
108 | new_imgs += [tup[1] for tup in leaderboard]
109 | new_labels += [index for _ in leaderboard]
110 |
111 | dataset.filepaths = new_imgs
112 | dataset.labels = new_labels
113 |
114 | with open(filename, "wb") as f:
115 | pickle.dump({"filepaths": new_imgs, "labels": new_labels}, f)
116 |
117 | return dataset
118 |
119 |
120 | def pseudolabel_top_k(
121 | config,
122 | data_name,
123 | k,
124 | template,
125 | dataset,
126 | classnames,
127 | transform,
128 | clip_model,
129 | label_to_idx,
130 | device,
131 | vis_encoder,
132 | split_seed,
133 | ):
134 | filename = f"pseudolabels/{data_name}_{vis_encoder.replace('/', '')}_{config.LEARNING_PARADIGM}_{config.MODEL}_{k}_pseudolabels_split_{split_seed}.pickle"
135 | if os.path.exists(filename):
136 | # Load cached pseudolabels from disk
137 | with open(filename, "rb") as f:
138 | pseudolabels = pickle.load(f)
139 | new_imgs = pseudolabels["filepaths"]
140 | new_labels = pseudolabels["labels"]
141 |
142 | dataset.filepaths = new_imgs
143 | dataset.labels = new_labels
144 | else:
145 | dataset = compute_pseudo_labels(
146 | k,
147 | template,
148 | dataset,
149 | classnames,
150 | transform,
151 | clip_model,
152 | label_to_idx,
153 | device,
154 | filename,
155 | )
156 |
157 | return dataset
--------------------------------------------------------------------------------
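The per-class leaderboard bookkeeping in compute_pseudo_labels can be hard to follow inline, so here is a toy rerun of the same logic with plain floats (k = 2 and fabricated confidences; an illustration, not repo code):

    k = 2
    leaderboard = {0: [], 1: []}  # class idx -> [(confidence, path), ...] sorted desc

    for pred, conf, path in [(0, 0.9, "a.jpg"), (0, 0.7, "b.jpg"), (0, 0.8, "c.jpg")]:
        board = leaderboard[pred]
        if len(board) < k:
            leaderboard[pred] = sorted(board + [(conf, path)], reverse=True)
        elif board[-1][0] < conf:
            # the new example "qualifies": insert it and drop the weakest entry
            leaderboard[pred] = sorted(board + [(conf, path)], reverse=True)[:k]

    print(leaderboard[0])  # [(0.9, 'a.jpg'), (0.8, 'c.jpg')]: b.jpg was evicted
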
/utils/compute_metrics.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import os
4 | import pickle
5 |
6 | import numpy as np
7 | import pandas as pd
8 | import scipy.stats as st
9 | import torch
10 | from accelerate import Accelerator
11 |
12 | accelerator = Accelerator()
13 |
14 |
15 | log = logging.getLogger(__name__)
16 |
17 |
18 | def evaluate_predictions(
19 | config,
20 | df_predictions,
21 | test_labeled_files,
22 | labels,
23 | unseen_classes,
24 | seen_classes=None
25 | ):
26 | df_test = pd.DataFrame({"id": test_labeled_files, "true": labels})
27 | df_test["id"] = df_test["id"].apply(lambda x: x.split("/")[-1])
28 | # log.info(f"DF TEST: {df_test.head(5)}")
29 | # log.info(f"DF PREDS: {df_predictions.head(5)}")
30 | df_predictions = pd.merge(df_predictions, df_test, on="id")
31 |
32 | if config.LEARNING_PARADIGM == 'ul' or config.LEARNING_PARADIGM == 'ssl':
33 | accuracy = (
34 | np.sum(df_predictions["class"] == df_predictions["true"])
35 | / df_predictions.shape[0]
36 | )
37 |
38 | return accuracy, None, None
39 |
40 | else:
41 | # Compute unseen accuracy
42 | unseen_predictions = df_predictions[df_predictions["true"].isin(unseen_classes)]
43 | unseen_accuracy = (
44 | np.sum(unseen_predictions["class"] == unseen_predictions["true"])
45 | / unseen_predictions.shape[0]
46 | )
47 | # Compute seen accuracy
48 | seen_predictions = df_predictions[df_predictions["true"].isin(seen_classes)]
49 | seen_accuracy = (
50 | np.sum(seen_predictions["class"] == seen_predictions["true"])
51 | / seen_predictions.shape[0]
52 | )
53 |
54 | harmonic_mean = st.hmean([unseen_accuracy, seen_accuracy])
55 |
56 | return unseen_accuracy, seen_accuracy, harmonic_mean
57 |
58 | def store_results(
59 | obj_conf,
60 | std_response
61 | ):
62 | """The function stores results of the model in a json.
63 |
64 | :param obj_conf: class object that stores configurations
65 | :param std_response: for UL and SSL it is a tuple whose first entry
66 | is the accuracy of the model. For TRZSL it is a tuple of unseen,
67 | seen, and harmonic-mean accuracy.
68 | """
69 | if obj_conf.LEARNING_PARADIGM == 'trzsl':
70 | # Store results
71 | if accelerator.is_local_main_process:
72 | results_to_store = {
73 | "model": obj_conf.MODEL,
74 | "config": obj_conf.__dict__,
75 | # "std_accuracy": std_response,
76 | "harmonic_mean": std_response[2], #harmonic_mean,
77 | "seen_accuracy": std_response[1], # seen_accuracy,
78 | "unseen_accuracy": std_response[0] # unseen_accuracy,
79 | }
80 | else:
81 | # Store results
82 | if accelerator.is_local_main_process:
83 | results_to_store = {
84 | "model": obj_conf.MODEL,
85 | "config": obj_conf.__dict__,
86 | "accuracy": std_response[0],
87 | }
88 |
89 |
90 | if accelerator.is_local_main_process:
91 | file_name = f"results_model_{obj_conf.MODEL}.json"
92 |
93 | # Check if the file already exists
94 | if os.path.exists(file_name):
95 | # If the file exists, open it in append mode
96 | with open(file_name, "a") as f:
97 | # Append the res dictionary to the file
98 | f.write(json.dumps(results_to_store) + "\n")
99 | else:
100 | # If the file doesn't exist, create a new file
101 | with open(file_name, "w") as f:
102 | # Write the res dictionary to the file
103 | f.write(json.dumps(results_to_store) + "\n")
104 |
105 | def save_parameters(obj, config, iteration=None):
106 | """ Save in a pickle the parameters used for
107 | evaluation.
108 |
109 | :param obj: object to save
110 | :param config: object with method configurations
111 | :param iteration: iteration number for iterative strategies
112 | """
113 |
114 | if iteration is None:
115 | file_name = f"trained_prompts/{config.DATASET_NAME}_{config.LEARNING_PARADIGM}_{config.MODEL}_{config.VIS_ENCODER.replace('/','')}_opt_{config.OPTIM_SEED}_spl_{config.SPLIT_SEED}.pickle"
116 | else:
117 | file_name = f"trained_prompts/{config.DATASET_NAME}_{config.LEARNING_PARADIGM}_{config.MODEL}_{config.VIS_ENCODER.replace('/','')}_iter_{iteration}_opt_{config.OPTIM_SEED}_spl_{config.SPLIT_SEED}.pickle"
118 |
119 | if config.MODALITY == 'multi':
120 | names = [
121 | 'transformer',
122 | 'proj_coop_pre',
123 | 'proj_coop_post',
124 | 'proj_vpt_pre',
125 | 'proj_vpt_post',
126 | 'coop_embeddings',
127 | 'deep_vpt',
128 | 'vpt_embeddings'
129 | ]
130 | for idx, param in enumerate(obj):
131 | if names[idx] in [
132 | 'transformer',
133 | 'proj_coop_pre',
134 | 'proj_coop_post',
135 | 'proj_vpt_pre',
136 | 'proj_vpt_post',
137 | ]:
138 | ff = file_name.rsplit('.', 1)[0] # strip the .pickle extension
139 | torch.save(obj[idx], f'{ff}_{names[idx]}.pt')
140 | else:
141 | ff = file_name.rsplit('.', 1)[0] # strip the .pickle extension
142 | with open(f'{ff}_{names[idx]}.pickle', 'wb') as f:
143 | pickle.dump(obj[idx], f)
144 |
145 | else:
146 | with open(file_name, 'wb') as f:
147 | pickle.dump(obj, f)
148 |
149 |
150 | def save_pseudo_labels(imgs, labs, config, iteration):
151 |
152 | filename = f"pseudolabels/{config.DATASET_NAME}_{config.LEARNING_PARADIGM}_{config.MODEL}_{config.VIS_ENCODER.replace('/', '')}_iter_{iteration}_opt_{config.OPTIM_SEED}_spl_{config.SPLIT_SEED}.pickle"
153 | with open(filename, "wb") as f:
154 | pickle.dump({"filepaths": imgs, "labels": labs}, f)
155 |
156 |
157 | def save_predictions(obj, config, iteration=None):
158 | """ Save the model predictions used for
159 | evaluation in a pickle.
160 |
161 | :param obj: object to save
162 | :param config: object with method configurations
163 | """
164 |
165 | if iteration is None:
166 | file_name = f"evaluation/{config.DATASET_NAME}_{config.LEARNING_PARADIGM}_{config.MODEL}_{config.VIS_ENCODER.replace('/','')}_opt_{config.OPTIM_SEED}_spl_{config.SPLIT_SEED}.pickle"
167 | else:
168 | file_name = f"evaluation/{config.DATASET_NAME}_{config.LEARNING_PARADIGM}_{config.MODEL}_{config.VIS_ENCODER.replace('/','')}_iter_{iteration}_opt_{config.OPTIM_SEED}_spl_{config.SPLIT_SEED}.pickle"
169 |
170 | with open(file_name, 'wb') as f:
171 | pickle.dump(obj, f)
--------------------------------------------------------------------------------
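A quick worked example of the TRZSL metric computed in evaluate_predictions: the harmonic mean sits below the arithmetic mean whenever seen and unseen accuracy diverge, so it rewards balanced performance.

    import scipy.stats as st

    unseen_accuracy, seen_accuracy = 0.40, 0.80
    print(st.hmean([unseen_accuracy, seen_accuracy]))  # 0.533..., arithmetic mean is 0.60
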
/utils/schedulers.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import math
3 |
4 | import torch.optim as optim
5 | from torch.optim import lr_scheduler as scheduler
6 | from torch.optim.lr_scheduler import LambdaLR
7 |
8 | log = logging.getLogger(__name__)
9 |
10 |
11 | def make_scheduler(optimizer, config, double=False, teacher=False):
12 | warmup = config.WARMUP_EPOCHS
13 | if double:
14 | if teacher:
15 | total_iters = config.t_EPOCHS
16 | else:
17 | total_iters = config.s_EPOCHS
18 | else:
19 | total_iters = config.EPOCHS
20 | if config.SCHEDULER == "cosine":
21 | lr_scheduler = WarmupCosineSchedule(
22 | optimizer, warmup_steps=warmup, t_total=total_iters
23 | )
24 | elif config.SCHEDULER == "one_warmup_epoch":
25 | lr_scheduler = LambdaLR(
26 | optimizer,
27 | lr_lambda=lambda epoch: config.WARMUP_LR / config.LR if epoch == 0 else 1,
28 | )
29 | else:
30 | lr_scheduler = scheduler.StepLR(
31 | optimizer, step_size=config.STEP_SIZE, gamma=0.1
32 | )
33 | return lr_scheduler
34 |
35 |
36 | class WarmupCosineSchedule(LambdaLR):
37 | """Linear warmup and then cosine decay.
38 | Linearly increases learning rate from 0 to 1 over `warmup_steps`.
39 | Decreases learning rate from 1. to 0. over remaining
40 | `t_total - warmup_steps` steps following a cosine curve.
41 | With the default `cycles=0.5`, the decay traces half a cosine period;
42 | other values traverse more or less of the cosine curve after warmup.
43 | """
44 |
45 | def __init__(self, optimizer, warmup_steps, t_total, cycles=0.5, last_epoch=-1):
46 | self.warmup_steps = warmup_steps
47 | self.t_total = t_total
48 | self.cycles = cycles
49 | # log.info(f"vars: {self.warmup_steps}, {self.t_total}")
50 | super(WarmupCosineSchedule, self).__init__(
51 | optimizer, self.lr_lambda, last_epoch=last_epoch, verbose=True
52 | )
53 |
54 | def lr_lambda(self, step):
55 | if step < self.warmup_steps:
56 | # log.info(f"STEP: {step}, LR1: {float(step) / float(max(1.0, self.warmup_steps))}")
57 | return float(step) / float(max(1.0, self.warmup_steps))
58 | # progress after warmup
59 | progress = float(step - self.warmup_steps) / float(
60 | max(1, self.t_total - self.warmup_steps)
61 | )
62 | # log.info(f"STEP: {step}, LR2: {max(0.0, 0.5 * (1. + math.cos(math.pi * float(self.cycles) * 2.0 * progress)))}")
63 | return max(
64 | 0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.cycles) * 2.0 * progress))
65 | )
66 |
--------------------------------------------------------------------------------
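To see the schedule concretely, here is the multiplier lr_lambda produces with the defaults used in the configs above (WARMUP_EPOCHS: 5, EPOCHS: 150). This standalone sketch duplicates the formula for illustration:

    import math

    def lr_lambda(step, warmup_steps=5, t_total=150, cycles=0.5):
        if step < warmup_steps:
            return step / max(1.0, warmup_steps)
        progress = (step - warmup_steps) / max(1, t_total - warmup_steps)
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * cycles * 2.0 * progress)))

    for step in (0, 2, 5, 77, 150):
        print(step, round(lr_lambda(step), 3))
    # 0 0.0, 2 0.4, 5 1.0, 77 0.505, 150 0.0: linear ramp, then cosine decay to zero
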
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import random
3 |
4 | import numpy as np
5 | import torch
6 |
7 |
8 | log = logging.getLogger(__name__)
9 |
10 |
11 | def dataset_object(dataset_name):
12 | if dataset_name == "aPY":
13 | from data import aPY as DataObject
14 | elif dataset_name == "Animals_with_Attributes2":
15 | from data import AwA2 as DataObject
16 | elif dataset_name == "EuroSAT":
17 | from data import EuroSAT as DataObject
18 | elif dataset_name == "DTD":
19 | from data import DTD as DataObject
20 | elif dataset_name == "sun397":
21 | from data import SUN397 as DataObject
22 | elif dataset_name == "CUB":
23 | from data import CUB as DataObject
24 | elif dataset_name == "RESICS45":
25 | from data import RESICS45 as DataObject
26 | elif dataset_name == "FGVCAircraft":
27 | from data import FGVCAircraft as DataObject
28 | elif dataset_name == "MNIST":
29 | from data import MNIST as DataObject
30 | elif dataset_name == "Flowers102":
31 | from data import Flowers102 as DataObject
32 | else:
33 | raise ValueError(f"Unsupported dataset: {dataset_name}")
34 |
35 | return DataObject
36 |
37 |
38 | def seed_worker(worker_id):
39 | worker_seed = torch.initial_seed() % 2**32
40 | np.random.seed(worker_seed)
41 | random.seed(worker_seed)
42 |
43 |
44 | class Config(object):
45 | def __init__(self, config):
46 | for k, v in config.items():
47 | setattr(self, k, v)
48 |
--------------------------------------------------------------------------------
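seed_worker above is intended to be passed as worker_init_fn so that each DataLoader worker derives a reproducible seed. A minimal usage sketch; the TensorDataset and the generator seed are arbitrary example values:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from utils import seed_worker

    g = torch.Generator()
    g.manual_seed(0)
    loader = DataLoader(TensorDataset(torch.arange(10)), batch_size=2, shuffle=True,
                        num_workers=2, worker_init_fn=seed_worker, generator=g)
    for (batch,) in loader:
        print(batch)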