├── .gitignore ├── LICENSE ├── README.md ├── docs ├── Fig1.png └── protoclip_model_structure.png └── src ├── open_clip ├── __init__.py ├── bpe_simple_vocab_16e6.txt.gz ├── factory.py ├── loss.py ├── model.py ├── model_configs │ ├── RN101-quickgelu.json │ ├── RN101.json │ ├── RN50-quickgelu.json │ ├── RN50.json │ ├── RN50x16.json │ ├── RN50x4.json │ ├── ViT-B-16-plus-240.json │ ├── ViT-B-16-plus.json │ ├── ViT-B-16.json │ ├── ViT-B-32-plus-256.json │ ├── ViT-B-32-quickgelu.json │ ├── ViT-B-32.json │ ├── ViT-H-14.json │ ├── ViT-H-16.json │ ├── ViT-L-14-336.json │ ├── ViT-L-14.json │ ├── ViT-L-16-320.json │ ├── ViT-L-16.json │ ├── ViT-g-14.json │ ├── timm-efficientnetv2_rw_s.json │ ├── timm-resnet50d.json │ ├── timm-resnetaa50d.json │ ├── timm-resnetblur50.json │ ├── timm-swin_base_patch4_window7_224.json │ ├── timm-vit_base_patch16_224.json │ ├── timm-vit_base_patch32_224.json │ └── timm-vit_small_patch16_224.json ├── openai.py ├── pretrained.py ├── timm_model.py ├── tokenizer.py ├── transform.py ├── utils.py └── version.py ├── training ├── .gitignore ├── __init__.py ├── clustering.py ├── data.py ├── distributed.py ├── evaluations │ ├── analyze_features.py │ ├── coco_retrieval.py │ ├── evaluation.py │ ├── linear_eval.py │ └── zero_shot.py ├── logger.py ├── loss.py ├── main.py ├── params.py ├── pretrained_transformers.py ├── scheduler.py └── train.py └── utils ├── RoBERTa.py ├── evaluate_checkpoints.py ├── gather_cc.py └── plot_pairs.py /.gitignore: -------------------------------------------------------------------------------- 1 | logs/ 2 | wandb/ 3 | models/ 4 | features/ 5 | results/ 6 | scripts/ 7 | 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | pip-wheel-metadata/ 31 | share/python-wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .nox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | *.py,cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | db.sqlite3-journal 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 102 | __pypackages__/ 103 | 104 | # Celery stuff 105 | celerybeat-schedule 106 | celerybeat.pid 107 | 108 | # SageMath parsed files 109 | *.sage.py 110 | 111 | # Environments 112 | .env 113 | .venv 114 | env/ 115 | venv/ 116 | ENV/ 117 | env.bak/ 118 | venv.bak/ 119 | 120 | # Spyder project settings 121 | .spyderproject 122 | .spyproject 123 | 124 | # Rope project settings 125 | .ropeproject 126 | 127 | # mkdocs documentation 128 | /site 129 | 130 | # mypy 131 | .mypy_cache/ 132 | .dmypy.json 133 | dmypy.json 134 | 135 | # Pyre type checker 136 | .pyre/ 137 | sync.sh 138 | gpu1sync.sh 139 | .idea 140 | *.pdf 141 | **/._* 142 | **/*DS_* 143 | **.jsonl 144 | src/sbatch 145 | src/misc 146 | .vscode 147 | src/debug 148 | core.* 149 | 150 | # Allow 151 | !src/evaluation/misc/results_dbs/* -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | MIT License 3 | 4 | ProtoCLIP 5 | Copyright (c) 2022 Megvii Research 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in all 15 | copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | SOFTWARE 24 | 25 | -------------------------------------------------------------------------- 26 | 27 | OpenCLIP 28 | Copyright (c) 2012-2021 Gabriel Ilharco, Mitchell Wortsman, 29 | Nicholas Carlini, Rohan Taori, Achal Dave, Vaishaal Shankar, 30 | John Miller, Hongseok Namkoong, Hannaneh Hajishirzi, Ali Farhadi, 31 | Ludwig Schmidt 32 | 33 | Permission is hereby granted, free of charge, to any person obtaining 34 | a copy of this software and associated documentation files (the 35 | "Software"), to deal in the Software without restriction, including 36 | without limitation the rights to use, copy, modify, merge, publish, 37 | distribute, sublicense, and/or sell copies of the Software, and to 38 | permit persons to whom the Software is furnished to do so, subject to 39 | the following conditions: 40 | 41 | The above copyright notice and this permission notice shall be 42 | included in all copies or substantial portions of the Software. 43 | 44 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 45 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 46 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 47 | NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 48 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 49 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 50 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # **Prototypical Contrastive Language Image Pretraining** 2 | Welcome to the official PyTorch implementation of ProtoCLIP from our paper *[ProtoCLIP: Prototypical Contrastive Language Image Pretraining](https://arxiv.org/abs/2206.10996)*, published in IEEE Transactions on Neural Networks and Learning Systems (TNNLS). 3 | 4 | by 5 | [Delong Chen](https://chendelong.world/), 6 | Zhao Wu, 7 | [Fan Liu](https://multimodality.group/), 8 | Zaiquan Yang, 9 | Shaoqiu Zheng, 10 | [Ying Tan](https://www.cil.pku.edu.cn/), and 11 | Erjin Zhou 12 | 13 | > **Abstract**: 14 | > Contrastive Language Image Pretraining (CLIP) has received widespread attention since its learned representations can be transferred well to various downstream tasks. During CLIP training, the InfoNCE objective aims to align positive image-text pairs and separate negative ones. In this paper, we show a representation grouping effect during this process: the InfoNCE objective indirectly groups semantically similar representations together via randomly emerged within-modal anchors. 15 | > 16 | > We introduce **Proto**typical **C**ontrastive **L**anguage **I**mage **P**retraining (ProtoCLIP) to enhance such grouping by boosting its efficiency and increasing its robustness against the modality gap. Specifically, ProtoCLIP sets up prototype-level discrimination between image and text spaces, which efficiently transfers higher-level structural knowledge. We further propose **P**rototypical **B**ack **T**ranslation (PBT) to decouple representation grouping from representation alignment, resulting in effective learning of meaningful representations under a large modality gap. PBT also enables us to introduce additional external teachers with richer prior knowledge. ProtoCLIP is trained with an online episodic training strategy, which allows it to be scaled up to unlimited amounts of data. 17 | > 18 | > Combining the above novel designs, we train ProtoCLIP on Conceptual Captions and achieve a +5.81% ImageNet linear probing improvement and a +2.01% ImageNet zero-shot classification improvement. 19 | 20 | ![Fig1](docs/Fig1.png) 21 | 22 | ![protoclip_model_structure](docs/protoclip_model_structure.png) 23 | 24 | 🔔 **Updates** 25 | 26 | - **2022-06-23**: The preprint of the ProtoCLIP paper is available on arXiv. See [this url](https://arxiv.org/abs/2206.10996). 27 | - **2022-06-22**: Initial release of ProtoCLIP training code. 28 | 29 | 🚀 **What can you get from this repo** 30 | - Training CLIP and loading pretrained CLIP weights via OpenCLIP. 31 | - ProtoCLIP implementation. 32 | - Episodic training 33 | - Prototypical losses, Prototype Back Translation (PBT) 34 | - RoBERTa external teacher 35 | - Evaluations of CLIP and ProtoCLIP, including: 36 | - Zero-shot classification evaluations on 10 downstream datasets. 37 | - MS-COCO retrieval and [modality gap](https://arxiv.org/abs/2203.02053) evaluations. 38 | - PyTorch-based fast linear probing. 39 | - Clustering evaluations: Adjusted Rand Index (ARI) and Adjusted Mutual Information (AMI). 40 | - ProtoCLIP visualizations: t-SNE and clustered samples.
41 | - Experimental support for loading a pretrained language model as the text tower initialization. 42 | 43 | # Requirements 44 | 45 | ## 1. Install Dependencies 46 | - Create a conda environment and install PyTorch: 47 | 48 | ```bash 49 | conda create -n protoclip python=3.8 50 | conda activate protoclip 51 | ``` 52 | 53 | This repo requires PyTorch (1.11.0) and torchvision. Please install them via https://pytorch.org/get-started/locally 54 | 55 | - Clone this repo: 56 | 57 | ```bash 58 | git clone https://github.com/megvii-research/protoclip 59 | cd protoclip 60 | export PYTHONPATH="$PYTHONPATH:$PWD/src" 61 | ``` 62 | **Note**: If an import error occurs later, run `export PYTHONPATH="$PYTHONPATH:$PWD/src"` again. 63 | 64 | - Install additional dependencies: 65 | ```bash 66 | conda install pandas scikit-learn faiss-gpu ftfy tqdm matplotlib pycocotools 67 | pip install pytorch-transformers 68 | conda install wandb # if you want to use wandb for better logging 69 | ``` 70 | **Note**: This codebase integrates [pytorch-transformers](https://pypi.org/project/pytorch-transformers) to initialize the text tower with a large pretrained language model (experimental). 71 | 72 | 73 | ## 2. Prepare Pretraining Data 74 | This codebase reads a `CSV` file (separated by `\t`) with two columns: a path to an image ("filepath" by default), and a text caption ("title" by default). A quick sanity-check sketch for this file is given at the end of this section. 75 | 76 | | filepath | title | 77 | |-------------------|----------------------------| 78 | | path/to/image.jpg | A very typical bus station | 79 | | ... | ... | 80 | 81 | The script `src/utils/gather_cc.py` will collect the [Conceptual Captions](https://github.com/google-research-datasets/conceptual-captions) (CC3M) dataset. First, download the Conceptual Captions URLs from [here](https://ai.google.com/research/ConceptualCaptions/download), then run the following script: 82 | 83 | ```bash 84 | python3 src/utils/gather_cc.py path/to/Train_GCC-training.tsv 85 | ``` 86 | 87 | **Note**: The requirement of CC3M validation data of OpenCLIP is removed in this codebase. The CC3M dataset was made public by Google in 2018. As noted in our paper, the number of accessible images keeps dropping due to expired image links. This issue has been raised by several recent works. In this work, since we can only collect 2,643,718 images (concurrent to our work, [CyCLIP](https://arxiv.org/abs/2205.14459) collected 2,631,703 images), we randomly sample a 2,500,000 subset (75\% of full CC3M) from them to train our ProtoCLIP. Considering the dropping accessibility of image links in Conceptual Captions, we call for the use of this dataset size (2.5M) in future benchmarking for better comparability. 88 | 89 | **Note**: `webdataset` is no longer supported in this codebase. 90 | 91 |
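The snippet below is a minimal sanity check of the collected CSV described above. It is illustrative only (not one of this repo's scripts), and the file path is a placeholder to replace with your own.

```python
import pandas as pd

# Minimal sanity check (illustrative, not part of this repo): verify that the
# collected training CSV is tab-separated and has the expected columns.
csv_path = 'path/to/your_training_data.csv'  # placeholder path
df = pd.read_csv(csv_path, sep='\t')

assert {'filepath', 'title'} <= set(df.columns), df.columns
print(f'{len(df)} image-text pairs')
print(df.head())
```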
92 | ## 3. Prepare Downstream Data 93 | - **Zero-shot Classification**. The preprocessed zero-shot datasets can be downloaded from [CLOOB](https://github.com/ml-jku/cloob#downstream-tasks). 94 | 95 | - **Linear Probing**. We perform linear evaluation on ImageNet, CIFAR10, CIFAR100, and STL10. You need to download the full [ImageNet-1k](https://image-net.org/download.php) dataset manually. The latter three datasets are integrated into `torchvision` and will be downloaded automatically. 96 | 97 | - **Image-text Retrieval**. We implement zero-shot image-text retrieval on MS-COCO. Since we do not perform fine-tuning, only the validation split (`/val2017`) is required here. 98 | 99 | 100 | ``` 101 | # All downstream datasets shall be stored under one directory: 102 | 103 | ├── imagenet 104 | │ ├── train 105 | │ └── test 106 | ├── birdsnap 107 | │   └── test 108 | ├── country211 109 | │   └── test 110 | ├── flowers102 111 | │   └── test 112 | ├── gtsrb 113 | │   └── test 114 | ├── stanford_cars 115 | │   └── test 116 | ├── ucf101 117 | │ ├── testlist01 118 | │ ├── testlist02 119 | │ └── testlist03 120 | └── coco2017 121 |    ├── annotations 122 |    └── val2017 123 | ``` 124 | 125 | # Evaluation 126 | - **Evaluation of pretrained ProtoCLIP** 127 | 128 | **TODO**: Model checkpoint release. 129 | 130 | ```bash 131 | python src/training/main.py \ 132 | --eval-data-dir '/data/Datasets' --batch-size 32 --workers 8 \ 133 | --zeroshot-frequency 1 --retrieval-frequency 1 --linear-frequency 1 \ 134 | --linear-prob-mode pytorch \ 135 | --add-projection-head \ 136 | --model RN50 --resume 'RELEASED_CHECKPOINT.pt' 137 | ``` 138 | 139 | - **Evaluation of pretrained CLIPs** 140 | 141 | This implementation is based on an awesome CLIP re-implementation, [OpenCLIP](https://github.com/mlfoundations/open_clip). The codebase is developed from version `1.2.1` of OpenCLIP. Therefore, in this codebase, you can also load various pretrained models, as in OpenCLIP. To evaluate these models, you can specify the `--model` and `--pretrained` arguments to `main.py`: 142 | 143 | ```bash 144 | python src/training/main.py \ 145 | --eval-data-dir '/data/Datasets' --batch-size 32 --workers 8 \ 146 | --zeroshot-frequency 1 --retrieval-frequency 1 --linear-frequency 1 \ 147 | --linear-prob-mode pytorch \ 148 | --model RN50 --pretrained openai 149 | ``` 150 | 151 | To list available pretrained models: 152 | 153 | ```python 154 | >>> import open_clip 155 | >>> open_clip.list_pretrained() 156 | [('RN50', 'openai'), 157 | ('RN50', 'yfcc15m'), 158 | ('RN50', 'cc12m'), 159 | ('RN50-quickgelu', 'openai'), 160 | ('RN50-quickgelu', 'yfcc15m'), 161 | ('RN50-quickgelu', 'cc12m'), 162 | ('RN101', 'openai'), 163 | ('RN101', 'yfcc15m'), 164 | ('RN101-quickgelu', 'openai'), 165 | ('RN101-quickgelu', 'yfcc15m'), 166 | ('RN50x4', 'openai'), 167 | ('RN50x16', 'openai'), 168 | ('RN50x64', 'openai'), 169 | ('ViT-B-32', 'openai'), 170 | ('ViT-B-32', 'laion2b_e16'), 171 | ('ViT-B-32', 'laion400m_e31'), 172 | ('ViT-B-32', 'laion400m_e32'), 173 | ('ViT-B-32-quickgelu', 'openai'), 174 | ('ViT-B-32-quickgelu', 'laion400m_e31'), 175 | ('ViT-B-32-quickgelu', 'laion400m_e32'), 176 | ('ViT-B-16', 'openai'), 177 | ('ViT-B-16', 'laion400m_e31'), 178 | ('ViT-B-16', 'laion400m_e32'), 179 | ('ViT-B-16-plus-240', 'laion400m_e31'), 180 | ('ViT-B-16-plus-240', 'laion400m_e32'), 181 | ('ViT-L-14', 'openai'), 182 | ('ViT-L-14-336', 'openai')] 183 | ``` 184 | 185 |
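For a quick interactive check outside of `main.py`, one of the pretrained models listed above can also be created directly in Python. The snippet below is a minimal sketch (not one of the repo's evaluation scripts); the random tensor is a stand-in for a preprocessed image batch.

```python
import torch
import open_clip

# Minimal sketch: build an OpenAI-pretrained RN50 through this repo's open_clip
# package and compute image-text similarities on dummy inputs.
model = open_clip.create_model('RN50', pretrained='openai')
model.eval()

texts = open_clip.tokenize(["a photo of a dog", "a photo of a cat"])
images = torch.randn(1, 3, 224, 224)  # stand-in for a batch of preprocessed images

with torch.no_grad():
    image_features = model.encode_image(images)
    text_features = model.encode_text(texts)
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    print((100.0 * image_features @ text_features.T).softmax(dim=-1))
```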
186 | # Training 187 | ## Build an External Teacher (optional) 188 | We use a pretrained RoBERTa language model as the external teacher for ProtoCLIP. We load the pretrained RoBERTa-large weights provided by [FAIRSEQ](https://github.com/facebookresearch/fairseq) via [PyTorch Hub](https://pytorch.org/hub/pytorch_fairseq_roberta/). Run the following script to extract text features from a given `.csv` file and reduce the feature dimension from 1024 to 64 by PCA to save memory: 189 | 190 | 191 | ```bash 192 | python src/utils/RoBERTa.py 193 | >>> Input your csv file: 194 | >>> Input your feature file: (e.g., 'features/RoBERTA_features_CC.npy') 195 | ``` 196 | 197 | With a single NVIDIA 2080Ti GPU, extracting RoBERTa features from CC2.5M takes about 3 hours, and the resulting feature file takes 600+ MB of storage. 198 | 199 | **Note**: We use the pooled and normalized features from RoBERTa: 200 | ```python 201 | text_feature = roberta.extract_features(texts) 202 | text_feature = text_feature.mean(dim=1) 203 | text_feature = F.normalize(text_feature, dim=-1) 204 | ``` 205 | 206 | ## Sample Single GPU Training 207 | Running the commands provided below yields the following results. On a machine with a single NVIDIA 2080ti GPU, 16 CPU cores, and 100GB RAM, CLIP training takes 1.03 days, while ProtoCLIP training takes 1.84 days. 208 | 209 | | Model | Backbone | Batch Size |ImageNet Linear Prob | ImageNet Zero-shot | 10 Dataset Zero-shot Avg. | COCO Mean Recall | 210 | |:---------:|:--------:|:-------------:|:----------------------:|:---------------------:|:-------------------------:|:-------------------:| 211 | | CLIP | RN50 | 64 | 38.95 | 12.29 | 15.30 | 26.20 | 212 | | ProtoCLIP | RN50 | 64 | 44.55 | 14.50 | 20.48 | 28.26 | 213 | 214 | 215 | - Train the CLIP baseline with a ResNet-50 backbone on CC2.5M for 8 epochs, using a single NVIDIA 2080ti GPU: 216 | ```bash 217 | python src/training/main.py \ 218 | --dataset-size 2500000 --episode-size 2500000 \ 219 | --train-data '' \ 220 | --eval-data-dir '' \ 221 | --epochs 8 --save-frequency 2 --batch-size 64 --workers 16 \ 222 | --linear-frequency 8 --zeroshot-frequency 8 --retrieval-frequency 8 \ 223 | --model RN50 --lr 5e-5 --warmup 2000 --wd 0.5 --max-grad-norm 10 \ 224 | --report-to 'tensorboard' --logs logs/single_gpu --copy-codebase \ 225 | --name 'CLIP' 226 | ``` 227 | - Train ProtoCLIP with a ResNet-50 backbone on CC2.5M for 8 epochs, using a single NVIDIA 2080ti GPU: 228 | ```bash 229 | python src/training/main.py \ 230 | --dataset-size 2500000 --episode-size 200000 \ 231 | --train-data '' \ 232 | --eval-data-dir '' \ 233 | --augmentation protoclip-light-augmentation \ 234 | --epochs 100 --infonce-warmup-epoch 1 --lit-start-epoch 99 --save-frequency 25 --visualize-frequency 25 --batch-size 64 --workers 16 \ 235 | --clustering-frequency 1 --k 20000 --kmeans-max-iter 20 --kmeans-nredo 3 \ 236 | --linear-frequency 100 --zeroshot-frequency 100 --retrieval-frequency 100 \ 237 | --model RN50 --lr 5e-5 --warmup 2000 --wd 0.5 --max-grad-norm 10 \ 238 | --w-clip 1 --w-proto 1 --w-proto-external 1 --external-teacher '' \ 239 | --add-projection-head --projection-dim 128 --target-temperature 0.01 --PBT \ 240 | --report-to 'tensorboard' --logs logs/single_gpu --copy-codebase \ 241 | --name 'ProtoCLIP' 242 | ``` 243 | 244 | ## Multi GPU Training 245 | CLIP and ProtoCLIP achieve the following downstream performance with CC2.5M: 246 | 247 | | Model | Backbone | Batch Size | ImageNet Linear Prob | ImageNet Zero-shot | 10 Dataset Zero-shot Avg.| COCO Mean Recall | 248 | |:---------:|:--------:|:-------------:|:-----------------------:|:---------------------:|:-------------------------:|:-------------------:| 249 | | CLIP | RN50 | 512 | 49.41 | 19.46 | 21.87 | 36.48 | 250 | | ProtoCLIP | RN50 | 512 | 55.22 | 21.47 | 22.52 | 35.69 | 251 | 252 | - Train the CLIP baseline with a ResNet-50 backbone on CC2.5M for
32 epochs, using 8 NVIDIA 2080ti GPUs: 253 | ```bash 254 | torchrun --nproc_per_node 8 -m training.main \ 255 | --dataset-size 2500000 --episode-size 2500000 \ 256 | --train-data '' \ 257 | --eval-data-dir '' \ 258 | --epochs 32 --save-frequency 4 --batch-size 64 --workers 16 \ 259 | --linear-frequency 32 --zeroshot-frequency 32 --retrieval-frequency 32 \ 260 | --model RN50 --lr 5e-4 --warmup 15625 --wd 0.5 --max-grad-norm 10 \ 261 | --report-to 'tensorboard' --logs logs/CC2.5M_32ep_benchmark --copy-codebase \ 262 | --name 'CLIP' 263 | ``` 264 | - Train ProtoCLIP with a ResNet-50 backbone on CC2.5M for 32 epochs, using 8 NVIDIA 2080ti GPUs: 265 | ```bash 266 | torchrun --nproc_per_node 8 -m training.main \ 267 | --dataset-size 2500000 --episode-size 200000 \ 268 | --train-data '' \ 269 | --eval-data-dir '' \ 270 | --augmentation protoclip-light-augmentation \ 271 | --epochs 400 --infonce-warmup-epoch 1 --lit-start-epoch 399 --save-frequency 50 --visualize-frequency 50 --batch-size 64 --workers 16 \ 272 | --clustering-frequency 1 --k 20000 --kmeans-max-iter 20 --kmeans-nredo 3 \ 273 | --linear-frequency 400 --zeroshot-frequency 400 --retrieval-frequency 400 \ 274 | --model RN50 --lr 5e-4 --warmup 15625 --wd 0.5 --max-grad-norm 10 \ 275 | --w-clip 1 --w-proto 1 --w-proto-external 1 --external-teacher '' \ 276 | --add-projection-head --projection-dim 128 --target-temperature 0.01 --PBT \ 277 | --report-to 'tensorboard' --logs logs/CC2.5M_32ep_benchmark --copy-codebase \ 278 | --name 'ProtoCLIP' 279 | ``` 280 | 281 | **Note**: Multi-node distributed training is not supported yet. 282 | 283 | ## Some Notes on Arguments 284 | Run `python src/training/main.py --help` to see descriptions of all arguments. Here we provide some explanations of our newly added arguments: 285 | 286 | 287 | - `--dataset-size` truncates the dataset to the specified number of samples. To train the model on CC2.5M as in the ProtoCLIP paper, use `--dataset-size=2500000`. 288 | 289 | - `--episode-size` enables episodic training. We randomly select `args.episode_size` (e.g., 0.2M) samples from the entire dataset as an episode by `index_mapping[:] = torch.from_numpy(np.random.choice(args.dataset_size, args.episode_size, replace=True)).share_memory_()`. Note that we use `.share_memory_()` to keep the `index_mapping` consistent across subprocesses when multiple GPUs are used. It can also be used for dataset shuffling (see this [issue](https://github.com/mlfoundations/open_clip/issues/101)) by setting episode size = dataset size. A minimal sketch of this mechanism is given after this list. 290 | 291 | - `--add-projection-head` creates MLP projection heads for the image and text towers. ProtoCLIP checkpoints must be loaded with this argument; otherwise, a state dict mismatch error will occur. 292 | 293 | - `--PBT` enables Prototype Back Translation. It is a prerequisite for applying `--external-teacher`. 294 | 295 | - `--linear-prob-mode` selects the linear probing backend. We implement a faster PyTorch-based linear probe; specify `--linear-prob-mode=sklearn` to run sklearn L-BFGS logistic regression on the CPU (slow). 296 | 297 | - `--visualize-frequency` enables t-SNE and cluster visualization for ProtoCLIP. You may want to visualize sparsely since both are time-consuming. The default `--visualize-frequency=-1` skips visualization. 298 | 299 | - `--pretrained-text` is experimental. It is not used in our ProtoCLIP paper. To load a pretrained $\text{RoBERTa}_\text{base}$, specify `--pretrained-text='roberta-base'`. 300 | 301 | - `--resume`: add `--resume='/path/to/your/checkpoints/epoch_xx.pt'` to resume training. 302 |
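Below is a minimal, self-contained sketch of the episodic sampling idea described under `--episode-size`. The `EpisodicWrapper` class is illustrative only (it is not the actual `training/data.py` implementation); the key point is that `index_mapping` lives in shared memory, so in-place resampling stays consistent across DataLoader worker processes.

```python
import numpy as np
import torch
from torch.utils.data import Dataset

class EpisodicWrapper(Dataset):
    """Illustrative wrapper: expose a resampled 'episode' of a larger dataset."""

    def __init__(self, base_dataset, episode_size):
        self.base = base_dataset
        self.episode_size = episode_size
        # Shared-memory tensor: in-place updates are visible to all subprocesses.
        self.index_mapping = torch.zeros(episode_size, dtype=torch.long).share_memory_()
        self.resample_episode()

    def resample_episode(self):
        # Draw a new episode (with replacement), once per training episode.
        new_ids = np.random.choice(len(self.base), self.episode_size, replace=True)
        self.index_mapping[:] = torch.from_numpy(new_ids)

    def __len__(self):
        return self.episode_size

    def __getitem__(self, i):
        return self.base[int(self.index_mapping[i])]
```

With episode size equal to dataset size, the same mechanism simply reshuffles the full dataset between episodes.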
303 | ## 📈Monitoring Downstream Performance During Training 304 | 305 | Experiments will be logged to the experiment directory as follows: 306 | ``` 307 | 308 | ├── cache 309 | ├── checkpoints 310 | │   ├── epoch_4.pt 311 | │   ├── epoch_8.pt 312 | │   ├── epoch_12.pt 313 | │   ├── epoch_16.pt 314 | │   ├── epoch_20.pt 315 | │   ├── epoch_24.pt 316 | │   ├── epoch_28.pt 317 | │   ├── epoch_32.pt 318 | │   └── epoch_latest.pt 319 | ├── out.log 320 | ├── params.txt 321 | ├── results.jsonl 322 | ├── evaluation_metrics_all.csv 323 | └── tensorboard 324 | └── events.out.tfevents 325 | ``` 326 | 327 | We provide a useful tool for monitoring downstream performance. Run `src/utils/evaluate_checkpoints.py` and specify an experiment logging dir; it will read the configuration from `params.txt` and automatically monitor and evaluate checkpoints. The results are saved to a `.csv` file (`evaluation_metrics_all.csv`). You can also specify an individual checkpoint to evaluate. 328 | ``` 329 | >>> python src/utils/evaluate_checkpoints.py 330 | Please input your experiment dir: 331 | Specify a checkpoint epoch? (press "enter" to scan and evaluate all checkpoints) 332 | ``` 333 | 334 | # 🎈 Acknowledgements 335 | 336 | If you find this project useful for your research, please consider citing our paper: 337 | 338 | ```bibtex 339 | @article{chen2023prototypical, 340 | author = {Delong Chen and 341 | Zhao Wu and 342 | Fan Liu and 343 | Zaiquan Yang and 344 | Shaoqiu Zheng and 345 | Ying Tan and 346 | Erjin Zhou}, 347 | title = {ProtoCLIP: Prototypical Contrastive Language Image Pretraining}, 348 | journal = {IEEE Transactions on Neural Networks and Learning Systems (TNNLS)}, 349 | year = {2023}, 350 | } 351 | ``` 352 | 353 | If you have any questions about the ProtoCLIP algorithm or this implementation, create an issue or email chendelong@hhu.edu.cn.
354 | -------------------------------------------------------------------------------- /docs/Fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/protoclip/fde193381369380bd6ffe278d644d85717e01a80/docs/Fig1.png -------------------------------------------------------------------------------- /docs/protoclip_model_structure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/protoclip/fde193381369380bd6ffe278d644d85717e01a80/docs/protoclip_model_structure.png -------------------------------------------------------------------------------- /src/open_clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .factory import list_models, create_model, create_model_and_transforms, add_model_config 2 | from .loss import ClipLoss 3 | from .model import CLIP, CLIPTextCfg, CLIPVisionCfg, convert_weights_to_fp16, trace_model 4 | from .openai import load_openai_model, list_openai_models 5 | from .pretrained import list_pretrained, list_pretrained_tag_models, list_pretrained_model_tags,\ 6 | get_pretrained_url, download_pretrained 7 | from .tokenizer import SimpleTokenizer, tokenize 8 | from .transform import image_transform 9 | -------------------------------------------------------------------------------- /src/open_clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/protoclip/fde193381369380bd6ffe278d644d85717e01a80/src/open_clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /src/open_clip/factory.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import pathlib 5 | import re 6 | from copy import deepcopy 7 | from pathlib import Path 8 | 9 | import torch 10 | 11 | from .model import CLIP, convert_weights_to_fp16 12 | from .openai import load_openai_model 13 | from .pretrained import get_pretrained_url, download_pretrained 14 | from .transform import image_transform 15 | 16 | 17 | _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] 18 | _MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs 19 | 20 | 21 | def _natural_key(string_): 22 | return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] 23 | 24 | 25 | def _rescan_model_configs(): 26 | global _MODEL_CONFIGS 27 | 28 | config_ext = ('.json',) 29 | config_files = [] 30 | for config_path in _MODEL_CONFIG_PATHS: 31 | if config_path.is_file() and config_path.suffix in config_ext: 32 | config_files.append(config_path) 33 | elif config_path.is_dir(): 34 | for ext in config_ext: 35 | config_files.extend(config_path.glob(f'*{ext}')) 36 | 37 | for cf in config_files: 38 | with open(cf, 'r') as f: 39 | model_cfg = json.load(f) 40 | if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')): 41 | _MODEL_CONFIGS[cf.stem] = model_cfg 42 | 43 | _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))} 44 | 45 | 46 | _rescan_model_configs() # initial populate of model config registry 47 | 48 | 49 | def load_state_dict(checkpoint_path: str, map_location='cpu'): 50 | checkpoint = torch.load(checkpoint_path, map_location=map_location) 51 | if isinstance(checkpoint, 
dict) and 'state_dict' in checkpoint: 52 | state_dict = checkpoint['state_dict'] 53 | else: 54 | state_dict = checkpoint 55 | if next(iter(state_dict.items()))[0].startswith('module'): 56 | state_dict = {k[7:]: v for k, v in state_dict.items()} 57 | return state_dict 58 | 59 | 60 | def create_model( 61 | model_name: str, 62 | pretrained: str = '', 63 | precision: str = 'fp32', 64 | device: torch.device = torch.device('cpu'), 65 | jit: bool = False, 66 | force_quick_gelu: bool = False, 67 | pretrained_image: bool = False, 68 | pretrained_text= None, 69 | args=None 70 | ): 71 | model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names 72 | 73 | if pretrained.lower() == 'openai': 74 | logging.info(f'Loading pretrained {model_name} from OpenAI.') 75 | model = load_openai_model(model_name, device=device, jit=jit) 76 | # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372 77 | if precision == "amp" or precision == "fp32": 78 | model = model.float() 79 | else: 80 | if model_name in _MODEL_CONFIGS: 81 | logging.info(f'Loading {model_name} model config.') 82 | model_cfg = deepcopy(_MODEL_CONFIGS[model_name]) 83 | else: 84 | logging.error(f'Model config for {model_name} not found; available models {list_models()}.') 85 | raise RuntimeError(f'Model config for {model_name} not found.') 86 | 87 | if force_quick_gelu: 88 | # override for use of QuickGELU on non-OpenAI transformer models 89 | model_cfg["quick_gelu"] = True 90 | 91 | if pretrained_image: 92 | if 'timm_model_name' in model_cfg.get('vision_cfg', {}): 93 | # pretrained weight loading for timm models set via vision_cfg 94 | model_cfg['vision_cfg']['timm_model_pretrained'] = True 95 | else: 96 | assert False, 'pretrained image towers currently only supported for timm models' 97 | model_cfg['pretrained_text'] = pretrained_text 98 | model_cfg['args'] = args 99 | model = CLIP(**model_cfg) 100 | 101 | if pretrained: 102 | checkpoint_path = '' 103 | url = get_pretrained_url(model_name, pretrained) 104 | if url: 105 | checkpoint_path = download_pretrained(url) 106 | elif os.path.exists(pretrained): 107 | checkpoint_path = pretrained 108 | 109 | if checkpoint_path: 110 | logging.info(f'Loading pretrained {model_name} weights ({pretrained}).') 111 | model.load_state_dict(load_state_dict(checkpoint_path)) 112 | else: 113 | logging.warning(f'Pretrained weights ({pretrained}) not found for model {model_name}.') 114 | raise RuntimeError(f'Pretrained weights ({pretrained}) not found for model {model_name}.') 115 | 116 | model.to(device=device) 117 | if precision == "fp16": 118 | assert device.type != 'cpu' 119 | convert_weights_to_fp16(model) 120 | 121 | if jit: 122 | model = torch.jit.script(model) 123 | 124 | return model 125 | 126 | 127 | def create_model_and_transforms( 128 | model_name: str, 129 | pretrained: str = '', 130 | precision: str = 'fp32', 131 | device: torch.device = torch.device('cpu'), 132 | jit: bool = False, 133 | force_quick_gelu: bool = False, 134 | pretrained_image: bool = False, 135 | pretrained_text = None, 136 | args=None 137 | ): 138 | model = create_model( 139 | model_name, pretrained, precision, device, jit, 140 | force_quick_gelu=force_quick_gelu, 141 | pretrained_image=pretrained_image, 142 | pretrained_text=pretrained_text, 143 | args=args) 144 | preprocess_train = image_transform(model.visual.image_size, is_train=True, augmentation=args.augmentation) 145 | preprocess_val = image_transform(model.visual.image_size, is_train=False) 146 | return model, 
preprocess_train, preprocess_val 147 | 148 | 149 | def list_models(): 150 | """ enumerate available model architectures based on config files """ 151 | return list(_MODEL_CONFIGS.keys()) 152 | 153 | 154 | def add_model_config(path): 155 | """ add model config path or file and update registry """ 156 | if not isinstance(path, Path): 157 | path = Path(path) 158 | _MODEL_CONFIG_PATHS.append(path) 159 | _rescan_model_configs() 160 | -------------------------------------------------------------------------------- /src/open_clip/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed.nn 3 | from torch import distributed as dist, nn as nn 4 | from torch.nn import functional as F 5 | 6 | try: 7 | import horovod.torch as hvd 8 | except ImportError: 9 | hvd = None 10 | 11 | 12 | def gather_features( 13 | image_features, 14 | text_features, 15 | local_loss=False, 16 | gather_with_grad=False, 17 | rank=0, 18 | world_size=1, 19 | use_horovod=False 20 | ): 21 | if use_horovod: 22 | assert hvd is not None, 'Please install horovod' 23 | if gather_with_grad: 24 | all_image_features = hvd.allgather(image_features) 25 | all_text_features = hvd.allgather(text_features) 26 | else: 27 | with torch.no_grad(): 28 | all_image_features = hvd.allgather(image_features) 29 | all_text_features = hvd.allgather(text_features) 30 | if not local_loss: 31 | # ensure grads for local rank when all_* features don't have a gradient 32 | gathered_image_features = list(all_image_features.chunk(world_size, dim=0)) 33 | gathered_text_features = list(all_text_features.chunk(world_size, dim=0)) 34 | gathered_image_features[rank] = image_features 35 | gathered_text_features[rank] = text_features 36 | all_image_features = torch.cat(gathered_image_features, dim=0) 37 | all_text_features = torch.cat(gathered_text_features, dim=0) 38 | else: 39 | # We gather tensors from all gpus 40 | if gather_with_grad: 41 | all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0) 42 | all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0) 43 | else: 44 | gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)] 45 | gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)] 46 | dist.all_gather(gathered_image_features, image_features) 47 | dist.all_gather(gathered_text_features, text_features) 48 | if not local_loss: 49 | # ensure grads for local rank when all_* features don't have a gradient 50 | gathered_image_features[rank] = image_features 51 | gathered_text_features[rank] = text_features 52 | all_image_features = torch.cat(gathered_image_features, dim=0) 53 | all_text_features = torch.cat(gathered_text_features, dim=0) 54 | 55 | return all_image_features, all_text_features 56 | 57 | 58 | class ClipLoss(nn.Module): 59 | 60 | def __init__( 61 | self, 62 | local_loss=False, 63 | gather_with_grad=False, 64 | cache_labels=False, 65 | rank=0, 66 | world_size=1, 67 | use_horovod=False, 68 | ): 69 | super().__init__() 70 | self.local_loss = local_loss 71 | self.gather_with_grad = gather_with_grad 72 | self.cache_labels = cache_labels 73 | self.rank = rank 74 | self.world_size = world_size 75 | self.use_horovod = use_horovod 76 | 77 | # cache state 78 | self.prev_num_logits = 0 79 | self.labels = {} 80 | 81 | def forward(self, image_features, text_features, logit_scale): 82 | device = image_features.device 83 | if self.world_size > 1: 84 | all_image_features, 
all_text_features = gather_features( 85 | image_features, text_features, 86 | self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) 87 | 88 | if self.local_loss: 89 | logits_per_image = logit_scale * image_features @ all_text_features.T 90 | logits_per_text = logit_scale * text_features @ all_image_features.T 91 | else: 92 | logits_per_image = logit_scale * all_image_features @ all_text_features.T 93 | logits_per_text = logits_per_image.T 94 | else: 95 | logits_per_image = logit_scale * image_features @ text_features.T 96 | logits_per_text = logit_scale * text_features @ image_features.T 97 | 98 | # calculated ground-truth and cache if enabled 99 | num_logits = logits_per_image.shape[0] 100 | if self.prev_num_logits != num_logits or device not in self.labels: 101 | labels = torch.arange(num_logits, device=device, dtype=torch.long) 102 | if self.world_size > 1 and self.local_loss: 103 | labels = labels + num_logits * self.rank 104 | if self.cache_labels: 105 | self.labels[device] = labels 106 | self.prev_num_logits = num_logits 107 | else: 108 | labels = self.labels[device] 109 | 110 | total_loss = ( 111 | F.cross_entropy(logits_per_image, labels) + 112 | F.cross_entropy(logits_per_text, labels) 113 | ) / 2 114 | return total_loss 115 | -------------------------------------------------------------------------------- /src/open_clip/model_configs/RN101-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": [ 7 | 3, 8 | 4, 9 | 23, 10 | 3 11 | ], 12 | "width": 64, 13 | "patch_size": null 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 512, 19 | "heads": 8, 20 | "layers": 12 21 | } 22 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/RN101.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": [ 6 | 3, 7 | 4, 8 | 23, 9 | 3 10 | ], 11 | "width": 64, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 512, 18 | "heads": 8, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/RN50-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": [ 7 | 3, 8 | 4, 9 | 6, 10 | 3 11 | ], 12 | "width": 64, 13 | "patch_size": null 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 512, 19 | "heads": 8, 20 | "layers": 12 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /src/open_clip/model_configs/RN50.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": [ 6 | 3, 7 | 4, 8 | 6, 9 | 3 10 | ], 11 | "width": 64, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 512, 18 | "heads": 8, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/RN50x16.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 384, 5 | "layers": [ 6 | 6, 7 | 8, 8 | 18, 9 | 8 10 | ], 11 | "width": 96, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 768, 18 | "heads": 12, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/RN50x4.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 288, 5 | "layers": [ 6 | 4, 7 | 6, 8 | 10, 9 | 6 10 | ], 11 | "width": 80, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 640, 18 | "heads": 10, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-B-16-plus-240.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 240, 5 | "layers": 12, 6 | "width": 896, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 640, 13 | "heads": 10, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-B-16-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 896, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 640, 13 | "heads": 10, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-B-32-plus-256.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 256, 5 | "layers": 12, 6 | "width": 896, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 640, 13 | "heads": 10, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-B-32-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": 12, 7 | "width": 768, 8 | "patch_size": 32 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-B-32.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-H-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 14 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 1024, 14 | "heads": 16, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-H-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 16 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 1024, 14 | "heads": 16, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-L-14-336.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-L-16-320.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 320, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-L-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/ViT-g-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | 
"vision_cfg": { 4 | "image_size": 224, 5 | "layers": 40, 6 | "width": 1408, 7 | "head_width": 88, 8 | "mlp_ratio": 4.3637, 9 | "patch_size": 14 10 | }, 11 | "text_cfg": { 12 | "context_length": 77, 13 | "vocab_size": 49408, 14 | "width": 1024, 15 | "heads": 16, 16 | "layers": 12 17 | } 18 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/timm-efficientnetv2_rw_s.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "timm_model_name": "efficientnetv2_rw_s", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "abs_attn", 7 | "timm_proj": "", 8 | "image_size": 288 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 768, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/timm-resnet50d.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "timm_model_name": "resnet50d", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "abs_attn", 7 | "timm_proj": "", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /src/open_clip/model_configs/timm-resnetaa50d.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "timm_model_name": "resnetaa50d", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "abs_attn", 7 | "timm_proj": "", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /src/open_clip/model_configs/timm-resnetblur50.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "timm_model_name": "resnetblur50", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "abs_attn", 7 | "timm_proj": "", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /src/open_clip/model_configs/timm-swin_base_patch4_window7_224.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "swin_base_patch4_window7_224", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/timm-vit_base_patch16_224.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "vit_base_patch16_224", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "image_size": 
224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/timm-vit_base_patch32_224.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "vit_base_patch32_224", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /src/open_clip/model_configs/timm-vit_small_patch16_224.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "vit_small_patch16_224", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /src/open_clip/openai.py: -------------------------------------------------------------------------------- 1 | """ OpenAI pretrained model functions 2 | 3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | 6 | import os 7 | import warnings 8 | from typing import Union, List 9 | 10 | import torch 11 | 12 | from .model import build_model_from_openai_state_dict 13 | from .pretrained import get_pretrained_url, list_pretrained_tag_models, download_pretrained 14 | 15 | __all__ = ["list_openai_models", "load_openai_model"] 16 | 17 | 18 | def list_openai_models() -> List[str]: 19 | """Returns the names of available CLIP models""" 20 | return list_pretrained_tag_models('openai') 21 | 22 | 23 | def load_openai_model( 24 | name: str, 25 | device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", 26 | jit=True, 27 | ): 28 | """Load a CLIP model 29 | 30 | Parameters 31 | ---------- 32 | name : str 33 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 34 | device : Union[str, torch.device] 35 | The device to put the loaded model 36 | jit : bool 37 | Whether to load the optimized JIT model (default) or more hackable non-JIT model. 38 | 39 | Returns 40 | ------- 41 | model : torch.nn.Module 42 | The CLIP model 43 | preprocess : Callable[[PIL.Image], torch.Tensor] 44 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 45 | """ 46 | if get_pretrained_url(name, 'openai'): 47 | model_path = download_pretrained(get_pretrained_url(name, 'openai')) 48 | elif os.path.isfile(name): 49 | model_path = name 50 | else: 51 | raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") 52 | 53 | try: 54 | # loading JIT archive 55 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 56 | state_dict = None 57 | except RuntimeError: 58 | # loading saved state dict 59 | if jit: 60 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 61 | jit = False 62 | state_dict = torch.load(model_path, map_location="cpu") 63 | 64 | if not jit: 65 | try: 66 | model = build_model_from_openai_state_dict(state_dict or model.state_dict()).to(device) 67 | except KeyError: 68 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} 69 | model = build_model_from_openai_state_dict(sd).to(device) 70 | 71 | if str(device) == "cpu": 72 | model.float() 73 | return model 74 | 75 | # patch the device names 76 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 77 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 78 | 79 | def patch_device(module): 80 | try: 81 | graphs = [module.graph] if hasattr(module, "graph") else [] 82 | except RuntimeError: 83 | graphs = [] 84 | 85 | if hasattr(module, "forward1"): 86 | graphs.append(module.forward1.graph) 87 | 88 | for graph in graphs: 89 | for node in graph.findAllNodes("prim::Constant"): 90 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 91 | node.copyAttributes(device_node) 92 | 93 | model.apply(patch_device) 94 | patch_device(model.encode_image) 95 | patch_device(model.encode_text) 96 | 97 | # patch dtype to float32 on CPU 98 | if str(device) == "cpu": 99 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 100 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 101 | float_node = float_input.node() 102 | 103 | def patch_float(module): 104 | try: 105 | graphs = [module.graph] if hasattr(module, "graph") else [] 106 | except RuntimeError: 107 | graphs = [] 108 | 109 | if hasattr(module, "forward1"): 110 | graphs.append(module.forward1.graph) 111 | 112 | for graph in graphs: 113 | for node in graph.findAllNodes("aten::to"): 114 | inputs = list(node.inputs()) 115 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 116 | if inputs[i].node()["value"] == 5: 117 | inputs[i].node().copyAttributes(float_node) 118 | 119 | model.apply(patch_float) 120 | patch_float(model.encode_image) 121 | patch_float(model.encode_text) 122 | model.float() 123 | 124 | # ensure image_size attr available at consistent location for both jit and non-jit 125 | model.visual.image_size = model.input_resolution.item() 126 | return model 127 | -------------------------------------------------------------------------------- /src/open_clip/pretrained.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | 6 | from tqdm import tqdm 7 | 8 | _RN50 = dict( 9 | openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 10 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt", 11 | cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt" 12 | ) 13 | 14 | _RN50_quickgelu = dict( 15 | openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 16 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt", 17 | cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt" 18 | ) 19 | 20 | _RN101 = dict( 21 | 
openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 22 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt" 23 | ) 24 | 25 | _RN101_quickgelu = dict( 26 | openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 27 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt" 28 | ) 29 | 30 | _RN50x4 = dict( 31 | openai="https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 32 | ) 33 | 34 | _RN50x16 = dict( 35 | openai="https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 36 | ) 37 | 38 | _RN50x64 = dict( 39 | openai="https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 40 | ) 41 | 42 | _VITB32 = dict( 43 | openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 44 | laion2b_e16="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth", 45 | laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt", 46 | laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt", 47 | ) 48 | 49 | _VITB32_quickgelu = dict( 50 | openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 51 | laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt", 52 | laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt", 53 | ) 54 | 55 | _VITB16 = dict( 56 | openai="https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 57 | laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt", 58 | laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt", 59 | ) 60 | 61 | _VITB16_PLUS_240 = dict( 62 | laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt", 63 | laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt", 64 | ) 65 | 66 | _VITL14 = dict( 67 | openai="https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 68 | ) 69 | 70 | _VITL14_336 = dict( 71 | openai="https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt" 72 | ) 73 | 74 | _PRETRAINED = { 75 | "RN50": _RN50, 76 | "RN50-quickgelu": _RN50_quickgelu, 77 | "RN101": _RN101, 78 | "RN101-quickgelu": _RN101_quickgelu, 79 | "RN50x4": _RN50x4, 80 | "RN50x16": _RN50x16, 81 | "RN50x64": _RN50x64, 82 | "ViT-B-32": _VITB32, 83 | "ViT-B-32-quickgelu": _VITB32_quickgelu, 84 | "ViT-B-16": 
_VITB16, 85 | "ViT-B-16-plus-240": _VITB16_PLUS_240, 86 | "ViT-L-14": _VITL14, 87 | "ViT-L-14-336": _VITL14_336, 88 | } 89 | 90 | 91 | def list_pretrained(as_str: bool = False): 92 | """ returns list of pretrained models 93 | Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True 94 | """ 95 | return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()] 96 | 97 | 98 | def list_pretrained_tag_models(tag: str): 99 | """ return all models having the specified pretrain tag """ 100 | models = [] 101 | for k in _PRETRAINED.keys(): 102 | if tag in _PRETRAINED[k]: 103 | models.append(k) 104 | return models 105 | 106 | 107 | def list_pretrained_model_tags(model: str): 108 | """ return all pretrain tags for the specified model architecture """ 109 | tags = [] 110 | if model in _PRETRAINED: 111 | tags.extend(_PRETRAINED[model].keys()) 112 | return tags 113 | 114 | 115 | def get_pretrained_url(model: str, tag: str): 116 | if model not in _PRETRAINED: 117 | return '' 118 | model_pretrained = _PRETRAINED[model] 119 | tag = tag.lower() 120 | if tag not in model_pretrained: 121 | return '' 122 | return model_pretrained[tag] 123 | 124 | 125 | def download_pretrained(url: str, root: str = os.path.expanduser("~/.cache/clip")): 126 | os.makedirs(root, exist_ok=True) 127 | filename = os.path.basename(url) 128 | 129 | if 'openaipublic' in url: 130 | expected_sha256 = url.split("/")[-2] 131 | else: 132 | expected_sha256 = '' 133 | 134 | download_target = os.path.join(root, filename) 135 | 136 | if os.path.exists(download_target) and not os.path.isfile(download_target): 137 | raise RuntimeError(f"{download_target} exists and is not a regular file") 138 | 139 | if os.path.isfile(download_target): 140 | if expected_sha256: 141 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 142 | return download_target 143 | else: 144 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 145 | else: 146 | return download_target 147 | 148 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 149 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: 150 | while True: 151 | buffer = source.read(8192) 152 | if not buffer: 153 | break 154 | 155 | output.write(buffer) 156 | loop.update(len(buffer)) 157 | 158 | if expected_sha256 and hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 159 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 160 | 161 | return download_target 162 | -------------------------------------------------------------------------------- /src/open_clip/timm_model.py: -------------------------------------------------------------------------------- 1 | """ timm model adapter 2 | 3 | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. 
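A minimal usage sketch (illustrative addition, not part of the original file; the backbone
name, batch size and embedding dimension are arbitrary assumptions, `timm` must be installed
in a version compatible with this adapter, and the `src` directory is assumed to be on the
import path):

    import torch
    from open_clip.timm_model import TimmModel

    tower = TimmModel('resnet50', embed_dim=512, image_size=224, pool='avg', proj='linear')
    images = torch.randn(2, 3, 224, 224)   # dummy batch of RGB images
    features = tower(images)               # trunk pooling + linear head -> shape (2, 512)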
4 | """ 5 | from collections import OrderedDict 6 | 7 | import torch.nn as nn 8 | 9 | try: 10 | import timm 11 | from timm.models.layers import Mlp, to_2tuple 12 | from timm.models.layers.attention_pool2d import RotAttentionPool2d 13 | from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d 14 | except ImportError as e: 15 | timm = None 16 | 17 | from .utils import freeze_batch_norm_2d 18 | 19 | 20 | class TimmModel(nn.Module): 21 | """ timm model adapter 22 | # FIXME this adapter is a work in progress, may change in ways that break weight compat 23 | """ 24 | 25 | def __init__( 26 | self, 27 | model_name, 28 | embed_dim, 29 | image_size=224, 30 | pool='avg', 31 | proj='linear', 32 | drop=0., 33 | pretrained=False): 34 | super().__init__() 35 | if timm is None: 36 | raise RuntimeError("Please `pip install timm` to use timm models.") 37 | 38 | self.image_size = to_2tuple(image_size) 39 | self.trunk = timm.create_model(model_name, pretrained=pretrained) 40 | feat_size = self.trunk.default_cfg.get('pool_size', None) 41 | feature_ndim = 1 if not feat_size else 2 42 | if pool in ('abs_attn', 'rot_attn'): 43 | assert feature_ndim == 2 44 | # if attn pooling used, remove both classifier and default pool 45 | self.trunk.reset_classifier(0, global_pool='') 46 | else: 47 | # reset global pool if pool config set, otherwise leave as network default 48 | reset_kwargs = dict(global_pool=pool) if pool else {} 49 | self.trunk.reset_classifier(0, **reset_kwargs) 50 | prev_chs = self.trunk.num_features 51 | 52 | head_layers = OrderedDict() 53 | if pool == 'abs_attn': 54 | head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) 55 | prev_chs = embed_dim 56 | elif pool == 'rot_attn': 57 | head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim) 58 | prev_chs = embed_dim 59 | else: 60 | assert proj, 'projection layer needed if non-attention pooling is used.' 
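        # Descriptive note (added for clarity, not in the original source) on the head built below:
        #   * pool in ('abs_attn', 'rot_attn'): an attention-pooling layer maps trunk features
        #     straight to embed_dim, so proj is usually left as ''.
        #   * any other pool: the timm trunk keeps its own global pooling and the head is
        #     Dropout -> Linear(num_features, embed_dim) for proj == 'linear', or
        #     Mlp(num_features, 2 * embed_dim, embed_dim) for proj == 'mlp'.
        #   Illustrative example (values are assumptions): a 'vit_base_patch16_224' trunk exposes
        #   num_features == 768, so embed_dim == 512 with proj == 'linear' yields
        #   head = Dropout(0.0) -> Linear(768, 512).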
61 | 62 | # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used 63 | if proj == 'linear': 64 | head_layers['drop'] = nn.Dropout(drop) 65 | head_layers['proj'] = nn.Linear(prev_chs, embed_dim) 66 | elif proj == 'mlp': 67 | head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop) 68 | 69 | self.head = nn.Sequential(head_layers) 70 | 71 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 72 | """ lock modules 73 | Args: 74 | unlocked_groups (int): leave last n layer groups unlocked (default: 0) 75 | """ 76 | if not unlocked_groups: 77 | # lock full model 78 | for param in self.trunk.parameters(): 79 | param.requires_grad = False 80 | if freeze_bn_stats: 81 | freeze_batch_norm_2d(self.trunk) 82 | else: 83 | # NOTE: partial freeze requires latest timm (master) branch and is subject to change 84 | try: 85 | # FIXME import here until API stable and in an official release 86 | from timm.models.helpers import group_parameters, group_modules 87 | except ImportError: 88 | raise RuntimeError( 89 | 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`') 90 | matcher = self.trunk.group_matcher() 91 | gparams = group_parameters(self.trunk, matcher) 92 | max_layer_id = max(gparams.keys()) 93 | max_layer_id = max_layer_id - unlocked_groups 94 | for group_idx in range(max_layer_id + 1): 95 | group = gparams[group_idx] 96 | for param in group: 97 | self.trunk.get_parameter(param).requires_grad = False 98 | if freeze_bn_stats: 99 | gmodules = group_modules(self.trunk, matcher, reverse=True) 100 | gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} 101 | freeze_batch_norm_2d(self.trunk, gmodules) 102 | 103 | def forward(self, x): 104 | x = self.trunk(x) 105 | x = self.head(x) 106 | return x 107 | -------------------------------------------------------------------------------- /src/open_clip/tokenizer.py: -------------------------------------------------------------------------------- 1 | """ CLIP tokenizer 2 | 3 | Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | import gzip 6 | import html 7 | import os 8 | from functools import lru_cache 9 | from typing import Union, List 10 | 11 | import ftfy 12 | import regex as re 13 | import torch 14 | 15 | 16 | @lru_cache() 17 | def default_bpe(): 18 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 19 | 20 | 21 | @lru_cache() 22 | def bytes_to_unicode(): 23 | """ 24 | Returns list of utf-8 byte and a corresponding list of unicode strings. 25 | The reversible bpe codes work on unicode strings. 26 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 27 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 28 | This is a signficant percentage of your normal, say, 32K bpe vocab. 29 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 30 | And avoids mapping to whitespace/control characters the bpe code barfs on. 31 | """ 32 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 33 | cs = bs[:] 34 | n = 0 35 | for b in range(2**8): 36 | if b not in bs: 37 | bs.append(b) 38 | cs.append(2**8+n) 39 | n += 1 40 | cs = [chr(n) for n in cs] 41 | return dict(zip(bs, cs)) 42 | 43 | 44 | def get_pairs(word): 45 | """Return set of symbol pairs in a word. 
46 | Word is represented as tuple of symbols (symbols being variable-length strings). 47 | """ 48 | pairs = set() 49 | prev_char = word[0] 50 | for char in word[1:]: 51 | pairs.add((prev_char, char)) 52 | prev_char = char 53 | return pairs 54 | 55 | 56 | def basic_clean(text): 57 | text = ftfy.fix_text(text) 58 | text = html.unescape(html.unescape(text)) 59 | return text.strip() 60 | 61 | 62 | def whitespace_clean(text): 63 | text = re.sub(r'\s+', ' ', text) 64 | text = text.strip() 65 | return text 66 | 67 | 68 | class SimpleTokenizer(object): 69 | def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): 70 | self.byte_encoder = bytes_to_unicode() 71 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 72 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 73 | merges = merges[1:49152-256-2+1] 74 | merges = [tuple(merge.split()) for merge in merges] 75 | vocab = list(bytes_to_unicode().values()) 76 | vocab = vocab + [v+'' for v in vocab] 77 | for merge in merges: 78 | vocab.append(''.join(merge)) 79 | if not special_tokens: 80 | special_tokens = ['', ''] 81 | else: 82 | special_tokens = ['', ''] + special_tokens 83 | vocab.extend(special_tokens) 84 | self.encoder = dict(zip(vocab, range(len(vocab)))) 85 | self.decoder = {v: k for k, v in self.encoder.items()} 86 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 87 | self.cache = {t:t for t in special_tokens} 88 | special = "|".join(special_tokens) 89 | self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 90 | 91 | self.vocab_size = len(self.encoder) 92 | self.all_special_ids = [self.encoder[t] for t in special_tokens] 93 | 94 | def bpe(self, token): 95 | if token in self.cache: 96 | return self.cache[token] 97 | word = tuple(token[:-1]) + ( token[-1] + '',) 98 | pairs = get_pairs(word) 99 | 100 | if not pairs: 101 | return token+'' 102 | 103 | while True: 104 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 105 | if bigram not in self.bpe_ranks: 106 | break 107 | first, second = bigram 108 | new_word = [] 109 | i = 0 110 | while i < len(word): 111 | try: 112 | j = word.index(first, i) 113 | new_word.extend(word[i:j]) 114 | i = j 115 | except: 116 | new_word.extend(word[i:]) 117 | break 118 | 119 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 120 | new_word.append(first+second) 121 | i += 2 122 | else: 123 | new_word.append(word[i]) 124 | i += 1 125 | new_word = tuple(new_word) 126 | word = new_word 127 | if len(word) == 1: 128 | break 129 | else: 130 | pairs = get_pairs(word) 131 | word = ' '.join(word) 132 | self.cache[token] = word 133 | return word 134 | 135 | def encode(self, text): 136 | bpe_tokens = [] 137 | text = whitespace_clean(basic_clean(text)).lower() 138 | for token in re.findall(self.pat, text): 139 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 140 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 141 | return bpe_tokens 142 | 143 | def decode(self, tokens): 144 | text = ''.join([self.decoder[token] for token in tokens]) 145 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 146 | return text 147 | 148 | 149 | _tokenizer = SimpleTokenizer() 150 | 151 | 152 | def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: 153 | """ 154 | Returns the tokenized representation of given input string(s) 155 | 156 | Parameters 
157 | ---------- 158 | texts : Union[str, List[str]] 159 | An input string or a list of input strings to tokenize 160 | context_length : int 161 | The context length to use; all CLIP models use 77 as the context length 162 | 163 | Returns 164 | ------- 165 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 166 | """ 167 | if isinstance(texts, str): 168 | texts = [texts] 169 | 170 | sot_token = _tokenizer.encoder[""] 171 | eot_token = _tokenizer.encoder[""] 172 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 173 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 174 | 175 | for i, tokens in enumerate(all_tokens): 176 | if len(tokens) > context_length: 177 | tokens = tokens[:context_length] # Truncate 178 | result[i, :len(tokens)] = torch.tensor(tokens) 179 | 180 | return result 181 | -------------------------------------------------------------------------------- /src/open_clip/transform.py: -------------------------------------------------------------------------------- 1 | from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \ 2 | CenterCrop 3 | 4 | import numpy as np 5 | import torch 6 | from torch import nn 7 | from torchvision.transforms import transforms 8 | 9 | def _convert_to_rgb(image): 10 | return image.convert('RGB') 11 | 12 | 13 | def image_transform( 14 | image_size: int, 15 | is_train: bool, 16 | mean=(0.48145466, 0.4578275, 0.40821073), 17 | std=(0.26862954, 0.26130258, 0.27577711), 18 | augmentation=None 19 | ): 20 | normalize = Normalize(mean=mean, std=std) 21 | if is_train: 22 | if not augmentation: 23 | return Compose([ 24 | RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC), 25 | _convert_to_rgb, 26 | ToTensor(), 27 | normalize, 28 | ]) 29 | elif augmentation == 'protoclip-light-augmentation': 30 | s = 1 31 | size = image_size 32 | color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s) 33 | gaussian_blur = transforms.GaussianBlur(kernel_size=21) 34 | return Compose([ 35 | transforms.RandomResizedCrop(size=size, scale=(0.5, 1.0), interpolation=InterpolationMode.BICUBIC), 36 | _convert_to_rgb, 37 | transforms.RandomHorizontalFlip(), 38 | transforms.RandomApply([color_jitter], p=0.2), 39 | transforms.RandomGrayscale(p=0.2), 40 | transforms.RandomApply([gaussian_blur], p=0.2), 41 | transforms.ToTensor(), 42 | normalize 43 | ]) 44 | 45 | else: 46 | return Compose([ 47 | Resize(image_size, interpolation=InterpolationMode.BICUBIC), 48 | CenterCrop(image_size), 49 | _convert_to_rgb, 50 | ToTensor(), 51 | normalize, 52 | ]) 53 | -------------------------------------------------------------------------------- /src/open_clip/utils.py: -------------------------------------------------------------------------------- 1 | from torch import nn as nn 2 | from torchvision.ops.misc import FrozenBatchNorm2d 3 | 4 | 5 | def freeze_batch_norm_2d(module, module_match={}, name=''): 6 | """ 7 | Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is 8 | itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and 9 | returned. Otherwise, the module is walked recursively and submodules are converted in place. 10 | 11 | Args: 12 | module (torch.nn.Module): Any PyTorch module. 
13 | module_match (dict): Dictionary of full module names to freeze (all if empty) 14 | name (str): Full module name (prefix) 15 | 16 | Returns: 17 | torch.nn.Module: Resulting module 18 | 19 | Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 20 | """ 21 | res = module 22 | is_match = True 23 | if module_match: 24 | is_match = name in module_match 25 | if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)): 26 | res = FrozenBatchNorm2d(module.num_features) 27 | res.num_features = module.num_features 28 | res.affine = module.affine 29 | if module.affine: 30 | res.weight.data = module.weight.data.clone().detach() 31 | res.bias.data = module.bias.data.clone().detach() 32 | res.running_mean.data = module.running_mean.data 33 | res.running_var.data = module.running_var.data 34 | res.eps = module.eps 35 | else: 36 | for child_name, child in module.named_children(): 37 | full_child_name = '.'.join([name, child_name]) if name else child_name 38 | new_child = freeze_batch_norm_2d(child, module_match, full_child_name) 39 | if new_child is not child: 40 | res.add_module(child_name, new_child) 41 | return res -------------------------------------------------------------------------------- /src/open_clip/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '1.2.1' 2 | -------------------------------------------------------------------------------- /src/training/.gitignore: -------------------------------------------------------------------------------- 1 | logs/ 2 | -------------------------------------------------------------------------------- /src/training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/protoclip/fde193381369380bd6ffe278d644d85717e01a80/src/training/__init__.py -------------------------------------------------------------------------------- /src/training/clustering.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | import numpy as np 4 | import faiss 5 | 6 | import os 7 | from PIL import Image 8 | try: 9 | import wandb 10 | except ImportError: 11 | wandb = None 12 | from utils.plot_pairs import plot_pairs 13 | import matplotlib 14 | matplotlib.use('Agg') 15 | import matplotlib.pyplot as plt 16 | from openTSNE import TSNE 17 | from training.distributed import is_master 18 | import torch.distributed as dist 19 | from torchvision.transforms import ToPILImage 20 | 21 | from sklearn.metrics import silhouette_score, davies_bouldin_score, calinski_harabasz_score 22 | import _pickle as pickle 23 | 24 | class Clustering(): 25 | 26 | def __init__(self, args): 27 | self.episode_size=args.episode_size 28 | self.feature_dim = args.projection_dim 29 | self.reset(args.k) 30 | 31 | def reset(self, k): 32 | self.img_feature = torch.zeros(size=(self.episode_size, self.feature_dim)).share_memory_() 33 | self.text_feature = torch.zeros(size=(self.episode_size, self.feature_dim)).share_memory_() 34 | 35 | self.img_labels = torch.zeros(self.episode_size, dtype=torch.long) 36 | self.text_labels = torch.zeros(self.episode_size, dtype=torch.long) 37 | self.external_labels = torch.zeros(self.episode_size, dtype=torch.long) 38 | 39 | self.img_centroids = torch.zeros([k, self.feature_dim]) 40 | self.text_centroids = torch.zeros([k, self.feature_dim]) 41 | 
self.external_centroids = torch.zeros([k, self.feature_dim]) 42 | 43 | self.img_centroids_translated_from_text_prototypes = torch.zeros([k, self.feature_dim]) 44 | self.text_centroids_translated_from_image_prototypes = torch.zeros([k, self.feature_dim]) 45 | self.img_centroids_translated_from_external_prototypes = torch.zeros([k, self.feature_dim]) 46 | self.text_centroids_translated_from_external_prototypes = torch.zeros([k, self.feature_dim]) 47 | 48 | 49 | def load_batch(self, index, img_features, text_features): 50 | self.img_feature[index] = img_features.detach().cpu().type(torch.float32) 51 | self.text_feature[index] = text_features.detach().cpu().type(torch.float32) 52 | 53 | 54 | def dump(self, file, item): 55 | f = open(file, 'wb') 56 | pickle.dump(item, f, protocol=4) 57 | f.close() 58 | 59 | def load(self, file): 60 | f = open(file, 'rb') 61 | item = pickle.load(f) 62 | f.close() 63 | return item 64 | 65 | def sync_prototypes(self, args): 66 | if is_master(args): 67 | self.dump(os.path.join(args.cache_path, f'img_labels.pkl'), self.img_labels) 68 | self.dump(os.path.join(args.cache_path, f'img_centroids.pkl'), self.img_centroids) 69 | self.dump(os.path.join(args.cache_path, f'text_labels.pkl'), self.text_labels) 70 | self.dump(os.path.join(args.cache_path, f'text_centroids.pkl'), self.text_centroids) 71 | self.dump(os.path.join(args.cache_path, f'external_labels.pkl'), self.external_labels) 72 | self.dump(os.path.join(args.cache_path, f'external_centroids.pkl'), self.external_centroids) 73 | if args.PBT: 74 | self.dump(os.path.join(args.cache_path, f'img_centroids_translated_from_text_prototypes.pkl'), self.img_centroids_translated_from_text_prototypes) 75 | self.dump(os.path.join(args.cache_path, f'text_centroids_translated_from_image_prototypes.pkl'), self.text_centroids_translated_from_image_prototypes) 76 | self.dump(os.path.join(args.cache_path, f'img_centroids_translated_from_external_prototypes.pkl'), self.img_centroids_translated_from_external_prototypes) 77 | self.dump(os.path.join(args.cache_path, f'text_centroids_translated_from_external_prototypes.pkl'), self.text_centroids_translated_from_external_prototypes) 78 | 79 | if args.distributed: 80 | dist.barrier() 81 | 82 | if not is_master(args): 83 | self.img_labels = self.load(os.path.join(args.cache_path, f'img_labels.pkl')) 84 | self.img_centroids = self.load(os.path.join(args.cache_path, f'img_centroids.pkl')) 85 | self.text_labels = self.load(os.path.join(args.cache_path, f'text_labels.pkl')) 86 | self.text_centroids = self.load(os.path.join(args.cache_path, f'text_centroids.pkl')) 87 | self.external_labels = self.load(os.path.join(args.cache_path, f'external_labels.pkl')) 88 | self.external_centroids = self.load(os.path.join(args.cache_path, f'external_centroids.pkl')) 89 | if args.PBT: 90 | self.img_centroids_translated_from_text_prototypes = self.load(os.path.join(args.cache_path, f'img_centroids_translated_from_text_prototypes.pkl')) 91 | self.text_centroids_translated_from_image_prototypes = self.load(os.path.join(args.cache_path, f'text_centroids_translated_from_image_prototypes.pkl')) 92 | self.img_centroids_translated_from_external_prototypes = self.load(os.path.join(args.cache_path, f'img_centroids_translated_from_external_prototypes.pkl')) 93 | self.text_centroids_translated_from_external_prototypes = self.load(os.path.join(args.cache_path, f'text_centroids_translated_from_external_prototypes.pkl')) 94 | 95 | if args.distributed: 96 | dist.barrier() 97 | 98 | if is_master(args): 99 | 
logging.info(f'Constructed prototypes are synchronized') 100 | for file in os.listdir(args.cache_path): 101 | os.remove(os.path.join(args.cache_path, file)) 102 | logging.info(f'Cache path {args.cache_path} has been cleared') 103 | 104 | 105 | def generate_labels(self, k, args): 106 | # remove possible NaN 107 | self.img_feature = torch.where(torch.isnan(self.img_feature), torch.full_like(self.img_feature,0), self.img_feature) 108 | self.text_feature = torch.where(torch.isnan(self.text_feature), torch.full_like(self.text_feature,0), self.text_feature) 109 | 110 | self.k=k 111 | logging.info(f'Constructing image prototypes with K-Means') 112 | self.img_labels, self.img_centroids, img_error_log, self.img_distance = self.kmeans(self.img_feature, k, args) 113 | logging.info(f'Constructing text prototypes with K-Means') 114 | self.text_labels, self.text_centroids, text_error_log, self.text_distance = self.kmeans(self.text_feature, k, args) 115 | return img_error_log, text_error_log 116 | 117 | def generate_labels_from_external_teacher(self, external_teacher, k, args): 118 | logging.info(f'Constructing external teacher prototypes with K-Means') 119 | external_teacher = torch.from_numpy(external_teacher.astype(np.float32)) 120 | self.external_labels, self.external_centroids, external_error_log, external_distance = self.kmeans(external_teacher, k, args) 121 | 122 | 123 | def kmeans(self, feature, k, args): 124 | feature=feature.cpu().numpy() 125 | 126 | centroids = torch.zeros([k, feature.shape[1]]) 127 | 128 | kmeans = faiss.Kmeans( 129 | d=feature.shape[1], 130 | k=k, 131 | niter=args.kmeans_max_iter, 132 | nredo=args.kmeans_nredo, 133 | verbose=True, 134 | gpu=True) 135 | kmeans.train(feature) 136 | 137 | # in case of derived centroid is less than args.k 138 | centroids[:,:kmeans.centroids.shape[1]] = torch.from_numpy(kmeans.centroids) 139 | distance, labels = kmeans.index.search(feature, 1) 140 | labels = np.array(labels) 141 | labels = np.reshape(labels, labels.shape[0]) 142 | 143 | return torch.from_numpy(labels), centroids, kmeans.iteration_stats, distance.flatten() 144 | 145 | def log_kmeans_error(self, iteration_stats, epoch, writer, args, name): 146 | for i in range(len(iteration_stats)): 147 | if writer is not None: 148 | writer.add_scalar(f'clustering/kmeans_error_log_{name}', iteration_stats[i]['obj'], epoch*args.kmeans_max_iter+i) 149 | if args.wandb: 150 | wandb.log({f'clustering/kmeans_error_log_{name}': iteration_stats[i]['obj'], 'step': epoch*args.kmeans_max_iter+i}) 151 | 152 | 153 | def PBT(self, k, teacher_labels, student_features): 154 | # teacher_centroids: (n_class, feature_dim) 155 | n_sample, feature_dim = student_features.size() 156 | teacher_labels = teacher_labels[:n_sample] 157 | cluster_indexs = np.unique(teacher_labels) 158 | centorids = torch.zeros(k, feature_dim) 159 | 160 | for k in cluster_indexs: 161 | cluster_samples = np.where(teacher_labels==k)[0] 162 | centroid = torch.mean(student_features[cluster_samples], dim=0) 163 | centorids[k] = centroid 164 | 165 | return centorids 166 | 167 | def analyze_labels(self): 168 | metrics = {} 169 | logging.info( "Analyzing pseudo labels.") 170 | for modality in ['image', 'text']: 171 | if modality=='image': 172 | label = self.img_labels 173 | feature = self.img_feature.numpy() 174 | if modality=='text': 175 | label = self.text_labels 176 | feature = self.text_feature.numpy() 177 | 178 | unique_labels, n_samples = np.unique(label, return_counts=True) 179 | metrics[f'{modality}-n_cluster']=len(unique_labels) 180 | 
metrics[f'{modality}-Silhouette Coefficient']=silhouette_score(feature, label, sample_size=5000) 181 | metrics[f'{modality}-Davies-Bouldin Index']=davies_bouldin_score(feature, label) 182 | metrics[f'{modality}-Calinski and Harabasz score']=calinski_harabasz_score(feature, label) 183 | 184 | logging.info( "Pseudo labels metrics:\n" + f"\n".join([f"\t{k}\t{v}" for k, v in metrics.items()])) 185 | return metrics 186 | 187 | def show_tsne(self, file_name, truncate, title): 188 | 189 | logging.info('Fitting T-SNE') 190 | 191 | tsne_img = TSNE(verbose=True, n_jobs=64, n_iter=1000).fit(self.img_feature[:truncate]) 192 | tsne_text = TSNE(verbose=True, n_jobs=64, n_iter=1000).fit(self.text_feature[:truncate]) 193 | 194 | plt.figure(figsize=(30,15)) 195 | plt.rc('font', size=20) 196 | plt.subplots_adjust(top=0.9,wspace=0.05,hspace=0.05) 197 | 198 | plt.subplot(121) 199 | plt.xticks([]) 200 | plt.yticks([]) 201 | plt.title('image features') 202 | plt.scatter(tsne_img[:,0], tsne_img[:,1], s=1.5, c=self.img_labels[:truncate], cmap='tab10', alpha=0.8) 203 | 204 | plt.subplot(122) 205 | plt.xticks([]) 206 | plt.yticks([]) 207 | plt.title('text features') 208 | plt.scatter(tsne_text[:,0], tsne_text[:,1], s=1.5, c=self.text_labels[:truncate], cmap='tab10', alpha=0.8) 209 | 210 | plt.suptitle(title) 211 | plt.savefig(file_name, bbox_inches='tight') 212 | 213 | logging.info(f'T-SNE visuallization saved to: {file_name}') 214 | 215 | def show_samples(self, dataset, modality, file_name, sample_per_class=16, max_rows=16): 216 | images = [] 217 | texts = [] 218 | if modality=='image': 219 | label = self.img_labels 220 | elif modality=='text': 221 | label = self.text_labels 222 | 223 | logging.info(f'Visuallizing {modality} clustering results') 224 | unique_labels, n_samples = np.unique(label, return_counts=True) 225 | 226 | for k in unique_labels[:max_rows]: 227 | cluster_dataset_index = np.squeeze(np.argwhere(label==k)) 228 | if cluster_dataset_index.shape==(): 229 | continue # empty cluster 230 | # show [sample_per_class] samples for each class 231 | for i in range(sample_per_class): 232 | # sometimes there are not much sample in this cluster 233 | if i >= len(cluster_dataset_index): 234 | images.append(Image.new('RGB', (256,256), (255,255,255))) 235 | texts.append(' ') 236 | else: 237 | image, text = dataset.get_data(int(cluster_dataset_index[i])) 238 | image = ToPILImage()(image) 239 | images.append(image) 240 | texts.append(text) 241 | 242 | plot_pairs( 243 | images[:100*sample_per_class], texts[:100*sample_per_class], 244 | suptitle=file_name, file_name=file_name+'.png', 245 | sample_per_row=sample_per_class 246 | ) 247 | logging.info(f'Sample visuallization saved to: {file_name}') 248 | 249 | 250 | 251 | -------------------------------------------------------------------------------- /src/training/data.py: -------------------------------------------------------------------------------- 1 | 2 | import logging 3 | from dataclasses import dataclass 4 | 5 | import numpy as np 6 | import pandas as pd 7 | from PIL import Image 8 | from torchvision.transforms import Compose, Normalize 9 | import torch 10 | from torch.utils.data import Dataset, DataLoader 11 | from torch.utils.data.distributed import DistributedSampler 12 | 13 | 14 | try: 15 | import horovod.torch as hvd 16 | except ImportError: 17 | hvd = None 18 | 19 | 20 | from open_clip import tokenize 21 | def clip_tokenizer(str): 22 | return tokenize([str])[0] 23 | 24 | class CsvDataset(Dataset): 25 | def __init__(self, input_filename, transforms, img_key, 
caption_key, sep="\t", dataset_size=None, index_mapping=None, tokenizer=None): 26 | logging.debug(f'Loading csv data from {input_filename}.') 27 | 28 | df = pd.read_csv(input_filename, sep=sep) 29 | 30 | self.images = df[img_key].tolist() 31 | self.captions = df[caption_key].tolist() 32 | self.transforms = transforms 33 | self.inversed_normalize = Compose([ 34 | Normalize((0.0, 0.0, 0.0), (1/0.26862954, 1/0.26130258, 1/0.27577711)), 35 | Normalize((-0.48145466, -0.4578275, -0.40821073), (1.0, 1.0, 1.0)), 36 | ]) 37 | 38 | # Faster data loading. see https://github.com/pytorch/pytorch/issues/13246#issuecomment-905703662 39 | self.images = np.array(df[img_key].tolist()).astype(np.string_) 40 | self.captions = np.array(df[caption_key].tolist()) 41 | for i in range(len(self.captions)): 42 | self.captions[i] = self.captions[i].encode('ascii',errors='ignore') 43 | self.captions = self.captions.astype(np.string_) 44 | 45 | # use a subset of given dataset 46 | if dataset_size is not None: 47 | self.images = self.images[:dataset_size] 48 | self.captions = self.captions[:dataset_size] 49 | 50 | if index_mapping is None: 51 | self.index_mapping=torch.arange(len(self.captions)) 52 | else: 53 | self.index_mapping = index_mapping 54 | 55 | if tokenizer is None: 56 | self.tokenizer = clip_tokenizer 57 | else: 58 | # using the tokenizer of pretrained NLP model 59 | self.tokenizer = tokenizer 60 | 61 | logging.debug('Done loading data.') 62 | 63 | def __len__(self): 64 | return len(self.index_mapping) 65 | 66 | def __getitem__(self, episodic_index): 67 | index = self.index_mapping[episodic_index] 68 | image = Image.open(str(self.images[index].decode('utf-8'))) 69 | image = self.transforms(image) 70 | texts = self.tokenizer(str(self.captions[index].decode('utf-8'))) 71 | return episodic_index, image, texts 72 | 73 | def get_data(self, episode_index): 74 | idx = self.index_mapping[episode_index] 75 | pic = Image.open(str(self.images[idx].decode('utf-8'))) 76 | image = self.inversed_normalize(self.transforms(pic)) 77 | texts = self.captions[idx].decode('utf-8') 78 | 79 | return image, texts 80 | 81 | 82 | @dataclass 83 | class DataInfo: 84 | dataset: Dataset 85 | dataloader: DataLoader 86 | sampler: DistributedSampler 87 | 88 | 89 | def get_csv_dataset(args, preprocess_fn, is_train, index_mapping, tokenizer): 90 | input_filename = args.train_data if is_train else args.val_data 91 | assert input_filename 92 | dataset = CsvDataset( 93 | input_filename, 94 | preprocess_fn, 95 | img_key=args.csv_img_key, 96 | caption_key=args.csv_caption_key, 97 | sep=args.csv_separator, 98 | dataset_size=args.dataset_size, 99 | index_mapping=index_mapping, 100 | tokenizer=tokenizer) 101 | num_samples = len(dataset) 102 | sampler = DistributedSampler(dataset) if args.distributed and is_train else None 103 | shuffle = is_train and sampler is None 104 | 105 | dataloader = DataLoader( 106 | dataset, 107 | batch_size=args.batch_size, 108 | shuffle=shuffle, 109 | num_workers=args.workers, 110 | pin_memory=True, 111 | sampler=sampler, 112 | drop_last=False, 113 | persistent_workers=True 114 | ) 115 | dataloader.num_samples = num_samples 116 | dataloader.num_batches = len(dataloader) 117 | 118 | return DataInfo(dataset, dataloader, sampler) 119 | 120 | 121 | def get_data(args, preprocess_fns, index_mapping, tokenizer=None): 122 | preprocess_train, preprocess_val = preprocess_fns 123 | data = {} 124 | 125 | if args.train_data: 126 | data["train"] = get_csv_dataset(args, preprocess_train, is_train=True, index_mapping=index_mapping, 
tokenizer=tokenizer) 127 | 128 | return data 129 | -------------------------------------------------------------------------------- /src/training/distributed.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | 5 | try: 6 | import horovod.torch as hvd 7 | except ImportError: 8 | hvd = None 9 | import torch.distributed as dist 10 | 11 | 12 | def get_gathered_item(item, args): 13 | if args.distributed: 14 | world_size = dist.get_world_size() 15 | gathered_item = [torch.zeros_like(item) for _ in range(world_size)] 16 | dist.all_gather(gathered_item, item) 17 | # all_item = torch.cat([item] + gathered_item[:rank] + gathered_item[rank + 1 :]) 18 | all_item = torch.cat(gathered_item, dim=0) 19 | else: 20 | all_item = item 21 | 22 | return all_item 23 | 24 | 25 | def is_global_master(args): 26 | return args.rank == 0 27 | 28 | 29 | def is_local_master(args): 30 | return args.local_rank == 0 31 | 32 | 33 | def is_master(args, local=False): 34 | return is_local_master(args) if local else is_global_master(args) 35 | 36 | 37 | def is_using_horovod(): 38 | # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set 39 | # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required... 40 | ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"] 41 | pmi_vars = ["PMI_RANK", "PMI_SIZE"] 42 | if all([var in os.environ for var in ompi_vars]) or all([var in os.environ for var in pmi_vars]): 43 | return True 44 | else: 45 | return False 46 | 47 | 48 | def is_using_distributed(): 49 | if 'WORLD_SIZE' in os.environ: 50 | return int(os.environ['WORLD_SIZE']) > 1 51 | if 'SLURM_NTASKS' in os.environ: 52 | return int(os.environ['SLURM_NTASKS']) > 1 53 | return False 54 | 55 | 56 | def world_info_from_env(): 57 | local_rank = 0 58 | for v in ('SLURM_LOCALID', 'MPI_LOCALRANKID', 'OMPI_COMM_WORLD_LOCAL_RANK', 'LOCAL_RANK'): 59 | if v in os.environ: 60 | local_rank = int(os.environ[v]) 61 | break 62 | global_rank = 0 63 | for v in ('SLURM_PROCID', 'PMI_RANK', 'OMPI_COMM_WORLD_RANK', 'RANK'): 64 | if v in os.environ: 65 | global_rank = int(os.environ[v]) 66 | break 67 | world_size = 1 68 | for v in ('SLURM_NTASKS', 'PMI_SIZE', 'OMPI_COMM_WORLD_SIZE', 'WORLD_SIZE'): 69 | if v in os.environ: 70 | world_size = int(os.environ[v]) 71 | break 72 | 73 | return local_rank, global_rank, world_size 74 | 75 | 76 | def init_distributed_device(args): 77 | # Distributed training = training on more than one GPU. 78 | # Works in both single and multi-node scenarios. 
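    # How the branches below resolve in practice (explanatory note, not in the original source):
    #   * Horovod: rank / local_rank / world_size are taken from hvd after hvd.init().
    #   * SLURM: detected via SLURM_PROCID; the SLURM_* variables are copied into the
    #     LOCAL_RANK / RANK / WORLD_SIZE variables that torch.distributed expects.
    #   * torchrun / torch.distributed.launch: the launcher already sets those variables and
    #     init_process_group(init_method=args.dist_url) reads them (typically 'env://').
    #   Illustrative single-node launch (the exact CLI flags of main.py are assumptions):
    #       torchrun --nproc_per_node=8 src/training/main.py <training arguments>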
79 | args.distributed = False 80 | args.world_size = 1 81 | args.rank = 0 # global rank 82 | args.local_rank = 0 83 | if args.horovod: 84 | assert hvd is not None, "Horovod is not installed" 85 | hvd.init() 86 | args.local_rank = int(hvd.local_rank()) 87 | args.rank = hvd.rank() 88 | args.world_size = hvd.size() 89 | args.distributed = True 90 | os.environ['LOCAL_RANK'] = str(args.local_rank) 91 | os.environ['RANK'] = str(args.rank) 92 | os.environ['WORLD_SIZE'] = str(args.world_size) 93 | elif is_using_distributed(): 94 | if 'SLURM_PROCID' in os.environ: 95 | # DDP via SLURM 96 | args.local_rank, args.rank, args.world_size = world_info_from_env() 97 | # SLURM var -> torch.distributed vars in case needed 98 | os.environ['LOCAL_RANK'] = str(args.local_rank) 99 | os.environ['RANK'] = str(args.rank) 100 | os.environ['WORLD_SIZE'] = str(args.world_size) 101 | torch.distributed.init_process_group( 102 | backend=args.dist_backend, 103 | init_method=args.dist_url, 104 | world_size=args.world_size, 105 | rank=args.rank, 106 | ) 107 | else: 108 | # DDP via torchrun, torch.distributed.launch 109 | args.local_rank, _, _ = world_info_from_env() 110 | torch.distributed.init_process_group( 111 | backend=args.dist_backend, 112 | init_method=args.dist_url) 113 | args.world_size = torch.distributed.get_world_size() 114 | args.rank = torch.distributed.get_rank() 115 | args.distributed = True 116 | 117 | if torch.cuda.is_available(): 118 | if args.distributed and not args.no_set_device_rank: 119 | device = 'cuda:%d' % args.local_rank 120 | else: 121 | device = 'cuda:0' 122 | torch.cuda.set_device(device) 123 | else: 124 | device = 'cpu' 125 | args.device = device 126 | device = torch.device(device) 127 | return device 128 | -------------------------------------------------------------------------------- /src/training/evaluations/analyze_features.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from sklearn.metrics.pairwise import cosine_similarity 4 | 5 | def analyze_features(all_image_features, all_text_features, args): 6 | if all_image_features is None or all_text_features is None: 7 | return {} 8 | 9 | image_avg_sim = get_self_cosine_similarity(all_image_features) 10 | text_avg_sim = get_self_cosine_similarity(all_text_features) 11 | modality_gap = get_modality_gap(all_image_features, all_text_features) 12 | 13 | results = { 14 | 'image_avg_self_similarity':image_avg_sim, 15 | 'text_avg_self_similarity':text_avg_sim, 16 | 'modality_gap':modality_gap, 17 | 'image_feature_std': float(torch.std(all_image_features, dim=0).mean().item()), 18 | 'text_feature_std': float(torch.std(all_text_features, dim=0).mean().item()), 19 | } 20 | 21 | return results 22 | 23 | 24 | def get_self_cosine_similarity(features): 25 | # reimplement Figure 2(a) in https://arxiv.org/abs/2203.02053v1 26 | features = features.numpy() 27 | similarities = cosine_similarity(features, features).flatten() 28 | similarities = similarities[similarities>0] 29 | 30 | return float(np.average(similarities)) 31 | 32 | 33 | def get_modality_gap(all_image_features, all_text_features): 34 | # reimplement the "modaility gap" in Section 4.2 of https://arxiv.org/abs/2203.02053v1 35 | mean_image_feature = torch.mean(all_image_features, dim=0) 36 | mean_text_feature = torch.mean(all_text_features, dim=0) 37 | delta_gap = mean_image_feature - mean_text_feature 38 | 39 | return float(delta_gap.norm().item()) 40 | 
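# Minimal self-check (an illustrative sketch added here, not part of the original file):
# running this module directly prints the statistics above for random unit-norm features.
# Random features give a near-zero modality gap, whereas trained image/text encoders
# typically show a clearly larger one, so this is a quick sanity check of the metrics.
if __name__ == '__main__':
    import torch.nn.functional as F

    fake_image_features = F.normalize(torch.randn(1000, 512), dim=-1)
    fake_text_features = F.normalize(torch.randn(1000, 512), dim=-1)
    # `args` is accepted for interface consistency but is not used inside analyze_features().
    print(analyze_features(fake_image_features, fake_text_features, args=None))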
-------------------------------------------------------------------------------- /src/training/evaluations/coco_retrieval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | import numpy as np 4 | import tqdm 5 | import os 6 | from torchvision.datasets.coco import CocoCaptions 7 | from torch.utils.data import Dataset, DataLoader 8 | from open_clip import tokenize as clip_tokenizer 9 | from PIL import Image 10 | 11 | 12 | class CocoDataset(Dataset): 13 | # modeified from https://github.com/uta-smile/TCL/blob/main/dataset/caption_dataset.py#L50 14 | # get the ground truth (1 image v.s. multiple captions, hiting each of them is ok) for retrieval 15 | def __init__(self, coco_dataset, coco_val_root, transform, tokenizer): 16 | self.text = [] 17 | self.image = [] 18 | self.txt2img = {} 19 | self.img2txt = {} 20 | self.transform = transform 21 | 22 | txt_id = 0 23 | for index in range(len(coco_dataset)): 24 | ann_ids = coco_dataset.coco.getAnnIds(imgIds=coco_dataset.ids[index]) 25 | anns = coco_dataset.coco.loadAnns(ann_ids) 26 | target = [ann['caption'] for ann in anns] 27 | 28 | path = coco_dataset.coco.loadImgs(coco_dataset.ids[index])[0]['file_name'] 29 | path = os.path.join(coco_val_root, path) 30 | 31 | self.image.append(path) 32 | self.img2txt[index] = [] 33 | 34 | for i, caption in enumerate(target): 35 | if tokenizer is None: 36 | self.text.append(clip_tokenizer(caption)) 37 | else: 38 | self.text.append(torch.stack([tokenizer(caption)])) 39 | self.img2txt[index].append(txt_id) 40 | self.txt2img[txt_id] = index 41 | txt_id += 1 42 | self.text = torch.cat(self.text, dim=0) 43 | 44 | def __len__(self): 45 | return len(self.image) 46 | 47 | def __getitem__(self, index): 48 | image_path = self.image[index] 49 | image = Image.open(image_path).convert('RGB') 50 | image = self.transform(image) 51 | return image, index 52 | 53 | 54 | class CocoTexts(): 55 | def __init__(self, coco_dataset): 56 | self.coco_dataset = coco_dataset 57 | 58 | def __len__(self): 59 | return len(self.coco_dataset.text) 60 | 61 | def __getitem__(self, index): 62 | return self.coco_dataset.text[index] 63 | 64 | 65 | def coco_retrieval_evaluation(model, epoch, preprocess, tokenizer, args): 66 | if args.retrieval_frequency == 0: 67 | return {}, None, None 68 | if (epoch % args.retrieval_frequency) != 0 and epoch != args.epochs: 69 | return {}, None, None 70 | 71 | coco_val_root = os.path.join(args.eval_data_dir, 'coco2017/val2017') 72 | coco_val_json = os.path.join(args.eval_data_dir, 'coco2017/annotations/captions_val2017.json') 73 | 74 | coco_dataset = CocoCaptions(root=coco_val_root, annFile=coco_val_json, transform=preprocess) 75 | coco_dataset = CocoDataset(coco_dataset, coco_val_root=coco_val_root, transform=preprocess, tokenizer=tokenizer) 76 | coco_retrieval_dataloader = DataLoader( 77 | coco_dataset, 78 | batch_size=args.batch_size, shuffle=False, 79 | num_workers=args.workers, pin_memory=True, drop_last=False, 80 | ) 81 | coco_dataset_text = CocoTexts(coco_dataset) 82 | coco_retrieval_text_dataloader = DataLoader( 83 | coco_dataset_text, 84 | batch_size=args.batch_size, shuffle=False, 85 | num_workers=args.workers, pin_memory=True, drop_last=False, 86 | ) 87 | 88 | with torch.no_grad(): 89 | logging.info('extracting COCO text features...') 90 | all_text_features = [] 91 | for texts in tqdm.tqdm(coco_retrieval_text_dataloader): 92 | texts = texts.to(args.device) 93 | if args.distributed and not args.horovod: 94 | text_features = 
model.module.encode_text(texts).detach().cpu() 95 | else: 96 | text_features = model.encode_text(texts).detach().cpu() 97 | all_text_features.append(text_features) 98 | all_text_features = torch.cat(all_text_features,dim=0) 99 | 100 | logging.info('extracting COCO image features...') 101 | all_image_features = [] 102 | for images, img_id in tqdm.tqdm(coco_retrieval_dataloader): 103 | images = images.to(args.device) 104 | 105 | if args.distributed and not args.horovod: 106 | image_features = model.module.encode_image(images).detach().cpu() 107 | else: 108 | image_features = model.encode_image(images).detach().cpu() 109 | 110 | all_image_features.append(image_features) 111 | all_image_features = torch.cat(all_image_features,dim=0) 112 | 113 | # normalization, this step is important 114 | all_image_features = all_image_features / all_image_features.norm(dim=-1, keepdim=True) 115 | all_text_features = all_text_features / all_text_features.norm(dim=-1, keepdim=True) 116 | 117 | scores_img2text = (all_image_features @ all_text_features.t()).detach() 118 | scores_text2img = scores_img2text.t().detach() 119 | 120 | retrieval_metrics = get_retrieval_metrics( 121 | scores_img2text.cpu().numpy(), 122 | scores_text2img.cpu().numpy(), 123 | coco_retrieval_dataloader.dataset.img2txt, 124 | coco_retrieval_dataloader.dataset.txt2img 125 | ) 126 | logging.info('COCO retrieval evaluation: '+ str(retrieval_metrics)) 127 | 128 | deduplicated_text_features = torch.zeros_like(all_image_features) 129 | for i in range(len(coco_retrieval_dataloader.dataset.img2txt)): 130 | deduplicated_text_features[i] = all_text_features[coco_retrieval_dataloader.dataset.img2txt[i][0]] 131 | 132 | return retrieval_metrics, all_image_features, deduplicated_text_features 133 | 134 | 135 | 136 | def get_retrieval_metrics(scores_img2text, scores_text2img, gt_img2text, gt_text2img): 137 | 138 | #Images->Text 139 | ranks = np.zeros(scores_img2text.shape[0]) 140 | for index,score in enumerate(scores_img2text): 141 | inds = np.argsort(score)[::-1] 142 | rank = 1e20 143 | for i in gt_img2text[index]: 144 | tmp = np.where(inds == i)[0][0] 145 | if tmp < rank: 146 | rank = tmp 147 | ranks[index] = rank 148 | 149 | img2text_recall_at_1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) 150 | img2text_recall_at_5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) 151 | img2text_recall_at_10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) 152 | 153 | #Text->Images 154 | ranks = np.zeros(scores_text2img.shape[0]) 155 | for index,score in enumerate(scores_text2img): 156 | inds = np.argsort(score)[::-1] 157 | ranks[index] = np.where(inds == gt_text2img[index])[0][0] 158 | 159 | text2img_recall_at_1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) 160 | text2img_recall_at_5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) 161 | text2img_recall_at_10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) 162 | 163 | tr_mean = (img2text_recall_at_1 + img2text_recall_at_5 + img2text_recall_at_10) / 3 164 | ir_mean = (text2img_recall_at_1 + text2img_recall_at_5 + text2img_recall_at_10) / 3 165 | r_mean = (tr_mean + ir_mean) / 2 166 | 167 | eval_result = { 168 | 'image2text-R@1': img2text_recall_at_1, 169 | 'image2text-R@5': img2text_recall_at_5, 170 | 'image2text-R@10': img2text_recall_at_10, 171 | #'image2text-R-mean': tr_mean, 172 | 'text2image-R@1': text2img_recall_at_1, 173 | 'text2image-R@5': text2img_recall_at_5, 174 | 'text2image-R@10': text2img_recall_at_10, 175 | #'text2image-R-mean': ir_mean, 176 | 'mean-recall': r_mean 177 | } 178 | 
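    # Clarifying note (added, not in the original source): image->text recall counts a query
    # image as a hit at rank K if ANY of its ground-truth captions (MS-COCO provides roughly
    # five per image) appears in the top K, whereas text->image recall has exactly one correct
    # image per caption. 'mean-recall' is the average of the six R@{1,5,10} values above.
    # Worked micro-example: if an image's best-ranked caption sits at 0-indexed position 3,
    # it counts towards R@5 and R@10 but not towards R@1.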
179 | for key, item in eval_result.items(): 180 | eval_result[key] = float(item) 181 | return eval_result 182 | -------------------------------------------------------------------------------- /src/training/evaluations/evaluation.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import json 4 | from training.distributed import is_master 5 | from .linear_eval import linear_eval 6 | from .zero_shot import zero_shot_eval 7 | from .coco_retrieval import coco_retrieval_evaluation 8 | from .analyze_features import analyze_features 9 | 10 | try: 11 | import wandb 12 | except ImportError: 13 | wandb = None 14 | 15 | def evaluate(model, epoch, preprocess, tokenizer, args, tb_writer=None): 16 | if not is_master(args): 17 | return 18 | logging.info( f"Starting evaluation of [{args.name}] at epoch {epoch}") 19 | 20 | linear_eval_datasets = ['imagenet', 'cifar10', 'cifar100', 'stl10'] 21 | zeroshot_datasets = ['imagenet', 'cifar10', 'cifar100', 'stl10', 'birdsnap','country211', 'flowers102', 'gtsrb', 'ucf101','stanford_cars'] 22 | 23 | model.eval() 24 | all_metrics = {} 25 | 26 | # zeroshot classification 27 | metrics = {} 28 | for zeroshot_dataset in zeroshot_datasets: 29 | zeroshot_metrics = zero_shot_eval(model, zeroshot_dataset, epoch, preprocess, tokenizer, args) 30 | metrics.update(zeroshot_metrics) 31 | all_metrics.update(zeroshot_metrics) 32 | for name, val in metrics.items(): 33 | if tb_writer is not None: 34 | tb_writer.add_scalar(f"eval_zero_shot/{name}", val, epoch) 35 | if args.wandb: 36 | wandb.log({f"eval_zero_shot/{name}": val, 'epoch': epoch}) 37 | 38 | # MS-COCO retrieval 39 | metrics = {} 40 | retrieval_metrics, all_image_features, all_text_features= coco_retrieval_evaluation(model, epoch, preprocess, tokenizer, args) 41 | metrics.update(retrieval_metrics) 42 | all_metrics.update(retrieval_metrics) 43 | for name, val in metrics.items(): 44 | if tb_writer is not None: 45 | tb_writer.add_scalar(f"eval_retrieval/{name}", val, epoch) 46 | if args.wandb: 47 | wandb.log({f"eval_retrieval/{name}": val, 'epoch': epoch}) 48 | 49 | # Analyse COCO features 50 | feature_metrics = analyze_features(all_image_features, all_text_features, args) 51 | all_metrics.update(feature_metrics) 52 | for name, val in feature_metrics.items(): 53 | if tb_writer is not None: 54 | tb_writer.add_scalar(f"eval_analyze_features/{name}", val, epoch) 55 | if args.wandb: 56 | wandb.log({f"eval_analyze_features/{name}": val, 'epoch': epoch}) 57 | 58 | # linear evaluation 59 | metrics = {} 60 | if linear_eval_datasets: 61 | linear_metrics = linear_eval(model, linear_eval_datasets, epoch, preprocess, args) 62 | metrics.update(linear_metrics) 63 | all_metrics.update(linear_metrics) 64 | 65 | logging.info( f"Finished evaluation of [{args.name}] at epoch {epoch}\n" + "\n".join([f"\t{k}\t{v}" for k, v in all_metrics.items()])) 66 | 67 | for name, val in metrics.items(): 68 | if tb_writer is not None: 69 | tb_writer.add_scalar(f"eval_linear_prob/{name}", val, epoch) 70 | if args.wandb: 71 | wandb.log({f"eval_linear_prob/{name}": val, 'epoch': epoch}) 72 | 73 | if args.save_logs: 74 | with open(os.path.join(args.logs, args.name, "results.jsonl"), "a+") as f: 75 | f.write(json.dumps(all_metrics)) 76 | f.write("\n") 77 | 78 | return all_metrics 79 | -------------------------------------------------------------------------------- /src/training/evaluations/linear_eval.py: -------------------------------------------------------------------------------- 1 | 2 | 
import torch 3 | import torch.nn as nn 4 | 5 | import numpy as np 6 | from sklearn.linear_model import LogisticRegression as sklearnLogisticRegression 7 | from torch.utils.data import DataLoader 8 | from torchvision.datasets import CIFAR10, CIFAR100, STL10, ImageFolder 9 | from tqdm import tqdm 10 | import logging 11 | 12 | def logistic_regression_pytorch(train_features, train_labels, test_features, test_labels): 13 | 14 | class AverageMeter(object): 15 | """computes and stores the average and current value""" 16 | 17 | def __init__(self): 18 | self.reset() 19 | 20 | def reset(self): 21 | self.val = 0 22 | self.avg = 0 23 | self.sum = 0 24 | self.count = 0 25 | 26 | def update(self, val, n=1): 27 | self.val = val 28 | self.sum += val * n 29 | self.count += n 30 | self.avg = self.sum / self.count 31 | 32 | class TensorDataset(): 33 | def __init__(self, *tensors): 34 | self.tensors = tensors 35 | 36 | def __getitem__(self, index): 37 | return tuple(tensor[index] for tensor in self.tensors) 38 | 39 | def __len__(self): 40 | return self.tensors[0].size(0) 41 | 42 | class Classifier(nn.Module): 43 | def __init__(self, feature_dim, num_labels): 44 | super(Classifier, self).__init__() 45 | 46 | self.linear = nn.Linear(feature_dim, num_labels) 47 | self.linear.weight.data.normal_(mean=0.0, std=0.01) 48 | self.linear.bias.data.zero_() 49 | 50 | def forward(self, x): 51 | x = x.view(x.size(0), -1) 52 | return self.linear(x) 53 | 54 | def accuracy(output, target, topk=(1,)): 55 | """Computes the accuracy over the k top predictions for the specified values of k""" 56 | with torch.no_grad(): 57 | maxk = max(topk) 58 | batch_size = target.size(0) 59 | 60 | _, pred = output.topk(maxk, 1, True, True) 61 | pred = pred.t() 62 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 63 | 64 | res = [] 65 | for k in topk: 66 | correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True) 67 | res.append(correct_k.mul_(100.0 / batch_size)) 68 | return res 69 | 70 | train_dataset = TensorDataset(torch.Tensor(train_features), torch.Tensor(train_labels).long()) 71 | val_dataset = TensorDataset(torch.Tensor(test_features), torch.Tensor(test_labels).long()) 72 | train_loader = DataLoader(train_dataset, batch_size=1024, num_workers=8, pin_memory=True, persistent_workers=True) 73 | val_loader = DataLoader(val_dataset, batch_size=5000, num_workers=8, pin_memory=True, persistent_workers=True) 74 | 75 | num_labels = int(max(train_labels)+1) 76 | classifier = Classifier(train_features.shape[1], num_labels).cuda() 77 | optimizer = torch.optim.SGD(classifier.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-6) 78 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 100, eta_min=0) 79 | 80 | criterion = nn.CrossEntropyLoss().cuda() 81 | best_acc = 0 82 | for epoch in (pbar := tqdm(range(100))): 83 | top1_train = AverageMeter() 84 | top5_train = AverageMeter() 85 | top1 = AverageMeter() 86 | top5 = AverageMeter() 87 | losses = AverageMeter() 88 | 89 | for step, (feature, label) in enumerate(train_loader): 90 | feature = feature.cuda() 91 | label = label.cuda() 92 | output = classifier(feature) 93 | loss = criterion(output, label) 94 | optimizer.zero_grad() 95 | loss.backward() 96 | optimizer.step() 97 | acc1, acc5 = accuracy(output, label, topk=(1, 5)) 98 | losses.update(loss.item(), feature.size(0)) 99 | top1_train.update(acc1[0], feature.size(0)) 100 | top5_train.update(acc5[0], feature.size(0)) 101 | 102 | for step, (feature, label) in enumerate(val_loader): 103 | feature = feature.cuda() 104 
| label = label.cuda() 105 | with torch.no_grad(): 106 | output = classifier(feature) 107 | acc1, acc5 = accuracy(output, label, topk=(1, 5)) 108 | top1.update(acc1[0], feature.size(0)) 109 | top5.update(acc5[0], feature.size(0)) 110 | 111 | scheduler.step() 112 | 113 | if top1.avg.item() > best_acc: 114 | best_acc = top1.avg.item() 115 | pbar.set_description(f'Epoch {epoch+1}, test accuracy {top1.avg.item():.2f}, best accuracy {best_acc:.2f}') 116 | 117 | return best_acc 118 | 119 | 120 | def get_features(model, dataset, args): 121 | 122 | all_features = [] 123 | all_labels = [] 124 | with torch.no_grad(): 125 | for images, labels in tqdm(DataLoader(dataset, batch_size=args.batch_size, num_workers=args.workers)): 126 | images = images.to(args.device) 127 | 128 | if args.distributed and not args.horovod: 129 | image_features = model.module.encode_image(images) 130 | else: 131 | image_features = model.encode_image(images) 132 | 133 | all_features.append(image_features.cpu()) 134 | all_labels.append(labels.cpu()) 135 | 136 | return torch.cat(all_features).numpy(), torch.cat(all_labels).numpy() 137 | 138 | 139 | def get_linear_eval_acc(model, dataset_name, root, preprocess, args): 140 | 141 | if dataset_name=='cifar10': 142 | train = CIFAR10(root, download=True, train=True, transform=preprocess) 143 | test = CIFAR10(root, download=True, train=False, transform=preprocess) 144 | 145 | elif dataset_name=='cifar100': 146 | train = CIFAR100(root, download=True, train=True, transform=preprocess) 147 | test = CIFAR100(root, download=True, train=False, transform=preprocess) 148 | 149 | elif dataset_name=='stl10': 150 | train = STL10(root, download=True, split='train', transform=preprocess) 151 | test = STL10(root, download=True, split='test', transform=preprocess) 152 | else: 153 | train = ImageFolder(f'{args.eval_data_dir}/{dataset_name}/train', transform=preprocess) 154 | test = ImageFolder(f'{args.eval_data_dir}/{dataset_name}/test', transform=preprocess) 155 | 156 | 157 | # Calculate the image features 158 | logging.info(f'extracting features from {dataset_name} training set...') 159 | train_features, train_labels = get_features(model, train, args=args) 160 | logging.info(f'extracting features from {dataset_name} testing set...') 161 | test_features, test_labels = get_features(model, test, args=args) 162 | 163 | if args.linear_prob_mode=='sklearn': 164 | logging.info('Running sklearn-based logistic regression') 165 | classifier = sklearnLogisticRegression(random_state=0, C=0.316, max_iter=1000, verbose=1, n_jobs=32) 166 | classifier.fit(train_features, train_labels) 167 | predictions = classifier.predict(test_features) 168 | accuracy = 100 * np.mean((test_labels == predictions).astype(float)) 169 | 170 | elif args.linear_prob_mode=='pytorch': 171 | logging.info('Running pytorch-based logistic regression') 172 | accuracy = logistic_regression_pytorch(train_features, train_labels, test_features, test_labels) 173 | 174 | return float(accuracy) 175 | 176 | 177 | def linear_eval(model, dataset_names, epoch, preprocess, args): 178 | 179 | if args.linear_frequency == 0: 180 | return {} 181 | if (epoch % args.linear_frequency) != 0 and epoch != args.epochs: 182 | return {} 183 | 184 | results = {} 185 | for dataset_name in dataset_names: 186 | logging.info(f'starting linear evaluation on {dataset_name}...') 187 | accuracy = get_linear_eval_acc(model, dataset_name, args.eval_data_dir, preprocess, args) 188 | results[f'{dataset_name}-linear-eval-acc'] = accuracy 189 | logging.info(f'Finished linear 
evaluation on {dataset_name}. accuracy: {accuracy}') 190 | 191 | return results 192 | 193 | 194 | if __name__=='__main__': 195 | import open_clip 196 | import os 197 | import pickle as pkl 198 | model, _, preprocess = open_clip.create_model_and_transforms('RN50', pretrained='OpenAI') 199 | train_data, val_data = pkl.load(open(os.path.join('train.pkl'),'rb')), pkl.load(open(os.path.join('test.pkl'),'rb')) 200 | train_features, train_labels = train_data['features'], train_data['labels'] 201 | test_features, test_labels = val_data['features'], val_data['labels'] 202 | logistic_regression_pytorch(train_features, train_labels, test_features, test_labels) 203 | -------------------------------------------------------------------------------- /src/training/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | def setup_logging(log_file, level, include_host=False): 5 | if include_host: 6 | import socket 7 | hostname = socket.gethostname() 8 | formatter = logging.Formatter( 9 | f'%(asctime)s | {hostname} | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S') 10 | else: 11 | formatter = logging.Formatter('%(asctime)s | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S') 12 | 13 | logging.root.setLevel(level) 14 | loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict] 15 | for logger in loggers: 16 | logger.setLevel(level) 17 | 18 | stream_handler = logging.StreamHandler() 19 | stream_handler.setFormatter(formatter) 20 | logging.root.addHandler(stream_handler) 21 | 22 | if log_file: 23 | file_handler = logging.FileHandler(filename=log_file) 24 | file_handler.setFormatter(formatter) 25 | logging.root.addHandler(file_handler) 26 | 27 | -------------------------------------------------------------------------------- /src/training/loss.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch 5 | import torch.distributed.nn 6 | from torch import distributed as dist, nn as nn 7 | from torch.nn import functional as F 8 | 9 | try: 10 | import horovod.torch as hvd 11 | except ImportError: 12 | hvd = None 13 | 14 | 15 | def gather_features( 16 | image_features, 17 | text_features, 18 | local_loss=False, 19 | gather_with_grad=False, 20 | rank=0, 21 | world_size=1, 22 | use_horovod=False 23 | ): 24 | if use_horovod: 25 | assert hvd is not None, 'Please install horovod' 26 | if gather_with_grad: 27 | all_image_features = hvd.allgather(image_features) 28 | all_text_features = hvd.allgather(text_features) 29 | else: 30 | with torch.no_grad(): 31 | all_image_features = hvd.allgather(image_features) 32 | all_text_features = hvd.allgather(text_features) 33 | if not local_loss: 34 | # ensure grads for local rank when all_* features don't have a gradient 35 | gathered_image_features = list(all_image_features.chunk(world_size, dim=0)) 36 | gathered_text_features = list(all_text_features.chunk(world_size, dim=0)) 37 | gathered_image_features[rank] = image_features 38 | gathered_text_features[rank] = text_features 39 | all_image_features = torch.cat(gathered_image_features, dim=0) 40 | all_text_features = torch.cat(gathered_text_features, dim=0) 41 | else: 42 | # We gather tensors from all gpus 43 | if gather_with_grad: 44 | all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0) 45 | all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0) 46 | else: 47 | 
gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)] 48 | gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)] 49 | dist.all_gather(gathered_image_features, image_features) 50 | dist.all_gather(gathered_text_features, text_features) 51 | if not local_loss: 52 | # ensure grads for local rank when all_* features don't have a gradient 53 | gathered_image_features[rank] = image_features 54 | gathered_text_features[rank] = text_features 55 | all_image_features = torch.cat(gathered_image_features, dim=0) 56 | all_text_features = torch.cat(gathered_text_features, dim=0) 57 | 58 | return all_image_features, all_text_features 59 | 60 | 61 | class ClipLoss(nn.Module): 62 | 63 | def __init__( 64 | self, 65 | local_loss=False, 66 | gather_with_grad=False, 67 | cache_labels=False, 68 | rank=0, 69 | world_size=1, 70 | use_horovod=False, 71 | ): 72 | super().__init__() 73 | self.local_loss = local_loss 74 | self.gather_with_grad = gather_with_grad 75 | self.cache_labels = cache_labels 76 | self.rank = rank 77 | self.world_size = world_size 78 | self.use_horovod = use_horovod 79 | 80 | # cache state 81 | self.prev_num_logits = 0 82 | self.labels = {} 83 | 84 | def forward(self, image_features, text_features, logit_scale): 85 | device = image_features.device 86 | logits_per_image = logit_scale * image_features @ text_features.T 87 | logits_per_text = logit_scale * text_features @ image_features.T 88 | 89 | # calculated ground-truth and cache if enabled 90 | num_logits = logits_per_image.shape[0] 91 | if self.prev_num_logits != num_logits or device not in self.labels: 92 | labels = torch.arange(num_logits, device=device, dtype=torch.long) 93 | if self.world_size > 1 and self.local_loss: 94 | labels = labels + num_logits * self.rank 95 | if self.cache_labels: 96 | self.labels[device] = labels 97 | self.prev_num_logits = num_logits 98 | else: 99 | labels = self.labels[device] 100 | 101 | total_loss = ( 102 | F.cross_entropy(logits_per_image, labels) + 103 | F.cross_entropy(logits_per_text, labels) 104 | ) / 2 105 | return total_loss 106 | 107 | CrossEntropy = nn.CrossEntropyLoss(ignore_index=-1) 108 | 109 | class ProtoLoss(nn.Module): 110 | 111 | def __init__(self): 112 | super().__init__() 113 | 114 | def forward(self, student_features, student_centroids, teacher_centroids, logit_scale_student, teacher_temperature, labels): 115 | 116 | # student_scores: shape(batchsize x n_class) 117 | student_scores = logit_scale_student * student_features @ (student_centroids.T) 118 | 119 | # teacher_scores: shape(batchsize x n_class) 120 | if teacher_temperature > 0: 121 | # [softmax], smooth label 122 | teacher_scores = teacher_centroids[labels] @ (teacher_centroids.T) / teacher_temperature 123 | teacher_scores = teacher_scores.softmax(dim=1) 124 | loss = CrossEntropy(student_scores, teacher_scores) 125 | 126 | else: 127 | # [hardmax], one-hot label 128 | loss = CrossEntropy(student_scores, labels) 129 | 130 | _,pred = torch.max(student_scores.detach(), dim=1) 131 | correct = torch.sum(pred==labels).item() 132 | accuracy = correct/labels.size(0)*100 133 | 134 | return loss, accuracy -------------------------------------------------------------------------------- /src/training/main.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import random 4 | from datetime import datetime 5 | 6 | import numpy as np 7 | import torch 8 | from torch import optim 9 | from torch.cuda.amp import 
GradScaler 10 | import time 11 | try: 12 | import wandb 13 | except ImportError: 14 | wandb = None 15 | 16 | try: 17 | import torch.utils.tensorboard as tensorboard 18 | except ImportError: 19 | tensorboard = None 20 | 21 | try: 22 | import horovod.torch as hvd 23 | except ImportError: 24 | hvd = None 25 | 26 | from open_clip import create_model_and_transforms, trace_model 27 | from training.pretrained_transformers import get_pretrained_text_encoder_and_tokenizer 28 | 29 | from training.data import get_data 30 | from training.distributed import is_master, init_distributed_device, world_info_from_env 31 | from training.logger import setup_logging 32 | from training.params import parse_args 33 | from training.scheduler import cosine_lr, protoclip_cosine_lr 34 | from training.train import train_one_epoch, feature_extraction_one_epoch 35 | from training.clustering import Clustering 36 | from training.evaluations.evaluation import evaluate 37 | 38 | import torch.distributed as dist 39 | 40 | def random_seed(seed): 41 | torch.manual_seed(seed) 42 | np.random.seed(seed) 43 | torch.cuda.manual_seed_all(seed) 44 | random.seed(seed) 45 | 46 | 47 | def main(): 48 | args = parse_args() 49 | random_seed(args.seed) 50 | 51 | # sanitize model name for filesystem / uri use, easier if we don't use / in name as a rule? 52 | args.model = args.model.replace('/', '-') 53 | 54 | # get the name of the experiments 55 | if args.name is None: 56 | args.name = '-'.join([ 57 | datetime.now().strftime("%Y_%m_%d-%H_%M"), # disabled since it might make different process to have different names 58 | f"model_{args.model}", 59 | f"lr_{args.lr}", 60 | f"b_{args.batch_size}", 61 | f"j_{args.workers}", 62 | f"p_{args.precision}", 63 | ]) 64 | 65 | # discover initial world args early so we can log properly 66 | args.distributed = False 67 | args.local_rank, args.rank, args.world_size = world_info_from_env() 68 | 69 | args.log_path = None 70 | if is_master(args, local=args.log_local): 71 | log_base_path = os.path.join(args.logs, args.name) 72 | os.makedirs(log_base_path, exist_ok=True) 73 | log_filename = f'out-{args.rank}' if args.log_local else 'out.log' 74 | args.log_path = os.path.join(log_base_path, log_filename) 75 | if os.path.exists(args.log_path): 76 | print( 77 | "Error. Experiment already exists. Use --name {} to specify a new experiment." 78 | ) 79 | return -1 80 | 81 | # Set logger 82 | args.log_level = logging.DEBUG if args.debug else logging.INFO 83 | setup_logging(args.log_path, args.log_level) 84 | 85 | # fully initialize distributed device environment 86 | torch.backends.cudnn.benchmark = True 87 | torch.backends.cudnn.deterministic = False 88 | device = init_distributed_device(args) 89 | 90 | args.wandb = 'wandb' in args.report_to or 'all' in args.report_to 91 | args.tensorboard = 'tensorboard' in args.report_to or 'all' in args.report_to 92 | 93 | # NCCL does not support CPU tensor communication. Set up manual multiprocessing communication. 
94 | args.cache_path = os.path.join(args.logs, args.name, "cache") 95 | args.visualization_path = os.path.join(args.logs, args.name, "visualization") 96 | if is_master(args): 97 | args.tensorboard_path = os.path.join(args.logs, args.name, "tensorboard") if args.tensorboard else '' 98 | args.checkpoint_path = os.path.join(args.logs, args.name, "checkpoints") 99 | for dirname in [args.tensorboard_path, args.checkpoint_path, args.cache_path, args.visualization_path]: 100 | if dirname: 101 | os.makedirs(dirname, exist_ok=True) 102 | else: 103 | args.tensorboard_path = '' 104 | args.checkpoint_path = '' 105 | 106 | if args.copy_codebase and is_master(args): 107 | copy_codebase(args) 108 | 109 | assert args.precision in ['amp', 'fp16', 'fp32'] 110 | if args.precision == 'fp16': 111 | logging.warning( 112 | 'It is recommended to use AMP mixed-precision instead of FP16. ' 113 | 'FP16 support needs further verification and tuning, especially for train.') 114 | 115 | if args.horovod: 116 | logging.info( 117 | f'Running in horovod mode with multiple processes / nodes. Device: {args.device}.' 118 | f'Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}.') 119 | elif args.distributed: 120 | logging.info( 121 | f'Running in distributed mode with multiple processes. Device: {args.device}.' 122 | f'Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}.') 123 | else: 124 | logging.info(f'Running with a single process. Device {args.device}.') 125 | 126 | if args.pretrained_text is not None: 127 | logging.info(f'Loading pretrained text transformer structure: {args.pretrained_text}.') 128 | pretrained_text_encoder, tokenizer, args.pretrained_text_feature_dim = get_pretrained_text_encoder_and_tokenizer(args.pretrained_text) 129 | else: 130 | logging.info(f'Using CLIP default text transformer structure.') 131 | pretrained_text_encoder, tokenizer, args.pretrained_text_feature_dim = None, None, None 132 | 133 | model, preprocess_train, preprocess_val = create_model_and_transforms( 134 | args.model, 135 | args.pretrained, 136 | precision=args.precision, 137 | device=device, 138 | jit=args.torchscript, 139 | force_quick_gelu=args.force_quick_gelu, 140 | pretrained_image=args.pretrained_image, 141 | pretrained_text=pretrained_text_encoder, 142 | args=args 143 | ) 144 | if is_master(args): 145 | logging.info(str(model)) 146 | if args.trace: 147 | model = trace_model(model, batch_size=args.batch_size, device=device) 148 | 149 | if args.lock_image: 150 | # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 151 | model.lock_image_tower( 152 | unlocked_groups=args.lock_image_unlocked_groups, 153 | freeze_bn_stats=args.lock_image_freeze_bn_stats) 154 | 155 | if args.grad_checkpointing: 156 | model.set_grad_checkpointing() 157 | 158 | if is_master(args): 159 | logging.info("Params:") 160 | params_file = os.path.join(args.logs, args.name, "params.txt") 161 | with open(params_file, "w") as f: 162 | for name in sorted(vars(args)): 163 | val = getattr(args, name) 164 | logging.info(f" {name}: {val}") 165 | f.write(f"{name}: {val}\n") 166 | 167 | if args.distributed and not args.horovod: 168 | if args.use_bn_sync: 169 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 170 | ddp_args = {} 171 | if args.ddp_static_graph: 172 | # this doesn't exist in older PyTorch, arg only added if enabled 173 | ddp_args['static_graph'] = True 174 | model = torch.nn.parallel.DistributedDataParallel( 175 | model, 176 | device_ids=[device], **ddp_args, 177 | 
find_unused_parameters=bool(pretrained_text_encoder is not None) # TODO: find which parameter is unused 178 | ) 179 | 180 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # 181 | # create optimizer and scaler 182 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # 183 | optimizer = None 184 | scaler = None 185 | if args.train_data: 186 | assert not args.trace, 'Cannot train with traced model' 187 | 188 | visual = lambda n, p: 'visual' in n 189 | non_visual = lambda n, p: not visual(n, p) 190 | 191 | exclude = lambda n, p: p.ndim < 2 or "bn" in n or "ln" in n or "bias" in n or 'logit_scale' in n or 'projection_head' in n 192 | include = lambda n, p: not exclude(n, p) 193 | 194 | # named_parameters = list(model.named_parameters()) 195 | # gain_or_bias_params = [p for n, p in named_parameters if exclude(n, p) and p.requires_grad] 196 | # rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad] 197 | 198 | visual_named_parameters = [(n, p) for n, p in list(model.named_parameters()) if visual(n, p)] 199 | visual_gain_or_bias_params = [p for n, p in visual_named_parameters if exclude(n, p) and p.requires_grad] 200 | visual_rest_params = [p for n, p in visual_named_parameters if include(n, p) and p.requires_grad] 201 | 202 | non_visual_named_parameters = [(n, p) for n, p in list(model.named_parameters()) if non_visual(n, p)] 203 | non_visual_gain_or_bias_params = [p for n, p in non_visual_named_parameters if exclude(n, p) and p.requires_grad] 204 | non_visual_rest_params = [p for n, p in non_visual_named_parameters if include(n, p) and p.requires_grad] 205 | 206 | if is_master(args): 207 | logging.info(f"visual_named_parameters:") 208 | for n, p in visual_named_parameters: 209 | logging.info(f'\t{n}') 210 | logging.info(f"non_visual_named_parameters:") 211 | for n, p in non_visual_named_parameters: 212 | logging.info(f'\t{n}') 213 | 214 | 215 | optimizer = optim.AdamW( 216 | [ 217 | {"params": visual_gain_or_bias_params, "weight_decay": 0.}, 218 | {"params": visual_rest_params, "weight_decay": args.wd}, 219 | {"params": non_visual_gain_or_bias_params, "weight_decay": 0.}, 220 | {"params": non_visual_rest_params, "weight_decay": args.wd}, 221 | ], 222 | lr=args.lr, 223 | betas=(args.beta1, args.beta2), 224 | eps=args.eps, 225 | ) 226 | if args.horovod: 227 | optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters()) 228 | hvd.broadcast_parameters(model.state_dict(), root_rank=0) 229 | hvd.broadcast_optimizer_state(optimizer, root_rank=0) 230 | 231 | scaler = GradScaler() if args.precision == "amp" else None 232 | 233 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # 234 | # optionally resume from a checkpoint 235 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # 236 | start_epoch = 0 237 | if args.resume is not None: 238 | if os.path.isfile(args.resume): 239 | checkpoint = torch.load(args.resume, map_location=device) 240 | if 'epoch' in checkpoint: 241 | # resuming a train checkpoint w/ epoch and optimizer state 242 | start_epoch = checkpoint["epoch"] 243 | sd = checkpoint["state_dict"] 244 | if not args.distributed and next(iter(sd.items()))[0].startswith('module'): 245 | sd = {k[len('module.'):]: v for k, v in sd.items()} 246 | model.load_state_dict(sd) 247 | if optimizer is not None: 248 | optimizer.load_state_dict(checkpoint["optimizer"]) 249 | if scaler is not None and 
'scaler' in checkpoint: 250 | scaler.load_state_dict(checkpoint['scaler']) 251 | logging.info(f"=> resuming checkpoint '{args.resume}' (epoch {start_epoch})") 252 | else: 253 | # loading a bare (model only) checkpoint for fine-tune or evaluation 254 | model.load_state_dict(checkpoint) 255 | logging.info(f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})") 256 | else: 257 | logging.info("=> no checkpoint found at '{}'".format(args.resume)) 258 | 259 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # 260 | # initialize datasets 261 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # 262 | if args.episode_size!=0: 263 | args.episodic_training=True 264 | index_mapping = torch.arange(args.episode_size).share_memory_() 265 | if is_master(args): 266 | logging.info(f"Model will be trained with episodic training strategy (episodic size={args.episode_size}).") 267 | else: 268 | args.episodic_training=False 269 | index_mapping = None 270 | if is_master(args): 271 | logging.info(f"Model will be trained with epoch-wise training strategy.") 272 | 273 | data = get_data(args, (preprocess_train, preprocess_val), index_mapping=index_mapping, tokenizer=tokenizer) 274 | 275 | if args.train_data is not None and args.dataset_size is None: 276 | args.dataset_size = len(data['train'].dataset.captions) 277 | if not args.episodic_training: 278 | args.episode_size = args.dataset_size 279 | 280 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # 281 | # create scheduler if train 282 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # 283 | scheduler = None 284 | if 'train' in data and optimizer is not None: 285 | total_steps = data["train"].dataloader.num_batches * args.epochs 286 | if args.lit_start_epoch < 0: # No LiT 287 | visual_steps = total_steps 288 | else: 289 | visual_steps = data["train"].dataloader.num_batches * (args.lit_start_epoch - 1) 290 | 291 | 292 | text_start_step = data["train"].dataloader.num_batches * args.text_start_epoch 293 | if args.text_end_epoch < 0: 294 | args.text_end_epoch = args.epochs 295 | text_end_step = data["train"].dataloader.num_batches * args.text_end_epoch 296 | 297 | if args.lr_text < 0: 298 | args.lr_text = args.lr 299 | scheduler = protoclip_cosine_lr(optimizer, args.lr, args.lr_text, args.warmup, total_steps, visual_steps, text_start_step, text_end_step) 300 | 301 | if is_master(args): 302 | logging.info(f"Using cosine lr scheduler. Total steps: {total_steps} ({args.epochs} epochs, {data['train'].dataloader.num_batches} steps per epoch)") 303 | if visual_steps!=total_steps: 304 | logging.info(f"\tTotal steps for visual backbone: {visual_steps}") 305 | if text_start_step!=0: 306 | logging.info(f"\tRest parameters are frozen until: {text_start_step}") 307 | 308 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # 309 | # determine if this worker should save logs and checkpoints. only do so if it is rank == 0 310 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # 311 | args.save_logs = args.logs and args.logs.lower() != 'none' and is_master(args) 312 | writer = None 313 | if args.save_logs and args.tensorboard: 314 | assert tensorboard is not None, "Please install tensorboard." 
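# Descriptive note: the SummaryWriter created below receives the per-step scalars logged in train.py (training/*, training_protoclip/*, profiling/*) as well as the per-epoch profiling/* numbers and the K-Means diagnostics logged later in this file.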
315 | writer = tensorboard.SummaryWriter(args.tensorboard_path) 316 | 317 | if args.wandb and is_master(args): 318 | assert wandb is not None, 'Please install wandb.' 319 | logging.debug('Starting wandb.') 320 | args.train_sz = data["train"].dataloader.num_samples 321 | # you will have to configure this for your project! 322 | wandb.init( 323 | project="TrainProtoCLIP", 324 | notes=args.name, 325 | tags=[], 326 | config=vars(args), 327 | ) 328 | if args.debug: 329 | wandb.watch(model, log='all') 330 | #wandb.save(params_file) 331 | logging.debug('Finished loading wandb.') 332 | 333 | if 'train' not in data: 334 | evaluate(model, start_epoch, preprocess_val, tokenizer, args, writer) 335 | return 336 | 337 | 338 | clustering = Clustering(args) 339 | 340 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # 341 | # Start training loop 342 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # 343 | 344 | profiling = { 345 | "epsidoe feature extraction time (m)": 0, 346 | "epsidoe kmeans time (m)": 0, 347 | "epsidoe model training time (m)": 0, 348 | "epsidoe total time (m)": 0, 349 | } 350 | for epoch in range(start_epoch, args.epochs): 351 | if is_master(args): 352 | logging.info(f'Start epoch {epoch}') 353 | epoch_start = time.time() 354 | 355 | if args.episodic_training: 356 | # Random episode sampling 357 | index_mapping[:] = torch.from_numpy(np.random.choice(args.dataset_size, args.episode_size, replace=True)) 358 | if is_master(args): 359 | logging.info(f"Randomly select {args.episode_size} samples from full dataset {args.dataset_size} as current episode.") 360 | 361 | if args.clustering_frequency!=-1 and epoch % args.clustering_frequency == 0: 362 | clustering.reset(args.k) 363 | # --- Episodic Training Step 1: Feature Extraction --- # 364 | start = time.time() 365 | feature_extraction_one_epoch(model, data, epoch, optimizer, scaler, scheduler, clustering, args, writer) 366 | if is_master(args): 367 | duration = (time.time()-start)/60 368 | profiling['epsidoe feature extraction time (m)'] = duration 369 | logging.info(f'[Profiling] Feature extraction finished in {duration:.2f} minute.') 370 | 371 | # --- Episodic Training Step 2: Prototype Construction --- # 372 | if is_master(args): 373 | start = time.time() 374 | img_iteration_stats, text_iteration_stats = clustering.generate_labels(args.k, args) 375 | clustering.log_kmeans_error(img_iteration_stats, epoch, writer, args, 'image') 376 | clustering.log_kmeans_error(text_iteration_stats, epoch, writer, args, 'text') 377 | 378 | duration = (time.time()-start)/60 379 | profiling['epsidoe kmeans time (m)'] = duration 380 | logging.info(f'[Profiling] K-Means clustering finished in {duration:.2f} minute.') 381 | 382 | # metrics = clustering.analyze_labels() 383 | # for name, val in metrics.items(): 384 | # if writer is not None: 385 | # writer.add_scalar('clustering/' + name, val, epoch) 386 | # if args.wandb: 387 | # wandb.log({'clustering/' + name: val, 'step': epoch}) 388 | 389 | if args.visualize_frequency != -1 and epoch % args.visualize_frequency == 0: 390 | visualize_start = time.time() 391 | clustering.show_samples(dataset=data['train'].dataset, modality='image',file_name=os.path.join(args.visualization_path, f'samples_image_label@epoch{epoch+1}')) 392 | clustering.show_samples(dataset=data['train'].dataset, modality='text',file_name=os.path.join(args.visualization_path, f'samples_text_label@epoch{epoch+1}')) 393 | 
clustering.show_tsne(file_name=os.path.join(args.visualization_path, f'TSNE@epoch{epoch+1}'), truncate=20000, title=f"Epoch {epoch+1}") 394 | logging.info(f'[Profiling] Cluster visualization finished in {(time.time()-visualize_start)/60:.2f} minute.') 395 | 396 | if not args.PBT and args.external_teacher is not None: 397 | logging.warning('External teacher supervision can not be applied without PBT. Skip external prototype construction.') 398 | 399 | start = time.time() 400 | if args.PBT: 401 | clustering.img_centroids_translated_from_text_prototypes = clustering.PBT( 402 | k=args.k, 403 | teacher_labels=clustering.text_labels, 404 | student_features=clustering.img_feature 405 | ) 406 | clustering.text_centroids_translated_from_image_prototypes = clustering.PBT( 407 | k=args.k, 408 | teacher_labels=clustering.img_labels, 409 | student_features=clustering.text_feature 410 | ) 411 | if args.external_teacher is not None: 412 | external_teacher = np.load(args.external_teacher) 413 | logging.info(f'Loaded external teacher ({external_teacher.shape}) from file "{args.external_teacher}".') 414 | clustering.generate_labels_from_external_teacher(external_teacher[index_mapping], args.k, args) 415 | clustering.img_centroids_translated_from_external_prototypes = clustering.PBT( 416 | k=args.k, 417 | teacher_labels=clustering.external_labels, 418 | student_features=clustering.img_feature 419 | ) 420 | clustering.text_centroids_translated_from_external_prototypes = clustering.PBT( 421 | k=args.k, 422 | teacher_labels=clustering.external_labels, 423 | student_features=clustering.text_feature 424 | ) 425 | duration = (time.time()-start)/60 426 | profiling['epsidoe PBT time (m)'] = duration 427 | logging.info(f'[Profiling] PBT finished in {duration:.2f} minute.') 428 | 429 | if args.distributed: 430 | dist.barrier() 431 | clustering.sync_prototypes(args) 432 | 433 | # --- Episodic Training Step 3: Model Training --- # 434 | start = time.time() 435 | train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, clustering, args, writer) 436 | if is_master(args): 437 | duration = (time.time()-start)/60 438 | profiling['epsidoe model training time (m)'] = duration 439 | logging.info(f'[Profiling] Model training finished in {duration:.2f} minute.') 440 | duration = (time.time()-epoch_start)/60 441 | profiling['epsidoe total time (m)'] = duration 442 | logging.info(f'[Profiling] Entire epoch/episode takes {duration:.1f} minute.') 443 | 444 | for name, val in profiling.items(): 445 | name = "profiling/" + name 446 | if writer is not None: 447 | writer.add_scalar(name, val, epoch) 448 | if args.wandb: 449 | assert wandb is not None, 'Please install wandb.' 450 | wandb.log({name: val, 'step': epoch}) 451 | 452 | completed_epoch = epoch + 1 453 | 454 | # Saving checkpoints. 
455 | if args.save_logs: 456 | checkpoint_dict = { 457 | "epoch": completed_epoch, 458 | "name": args.name, 459 | "state_dict": model.state_dict(), 460 | "optimizer": optimizer.state_dict(), 461 | } 462 | if scaler is not None: 463 | checkpoint_dict["scaler"] = scaler.state_dict() 464 | 465 | if completed_epoch == args.epochs or ( 466 | args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0 467 | ): 468 | torch.save( 469 | checkpoint_dict, 470 | os.path.join(args.checkpoint_path, f"epoch_{completed_epoch}.pt"), 471 | ) 472 | if args.save_most_recent: 473 | torch.save( 474 | checkpoint_dict, 475 | os.path.join(args.checkpoint_path, f"epoch_latest.pt"), 476 | ) 477 | 478 | evaluate(model, completed_epoch, preprocess_val, tokenizer, args, writer) 479 | 480 | if args.wandb and is_master(args): 481 | wandb.finish() 482 | 483 | 484 | def copy_codebase(args): 485 | from shutil import copytree, ignore_patterns 486 | new_code_path = os.path.join(args.logs, args.name, "code") 487 | if os.path.exists(new_code_path): 488 | print( 489 | f"Error. Experiment already exists at {new_code_path}. Use --name to specify a new experiment." 490 | ) 491 | return -1 492 | print(f"Copying codebase to {new_code_path}") 493 | current_code_path = os.path.realpath(__file__) 494 | for _ in range(3): 495 | current_code_path = os.path.dirname(current_code_path) 496 | copytree(current_code_path, new_code_path, ignore=ignore_patterns('log', 'logs', 'wandb', 'cache', 'features')) 497 | print("Done copying code.") 498 | return 1 499 | 500 | 501 | if __name__ == "__main__": 502 | main() 503 | -------------------------------------------------------------------------------- /src/training/params.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def get_default_params(model_name): 5 | # Params from paper (https://arxiv.org/pdf/2103.00020.pdf) 6 | model_name = model_name.lower() 7 | if "vit" in model_name: 8 | return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.98, "eps": 1.0e-6} 9 | else: 10 | return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.999, "eps": 1.0e-8} 11 | 12 | 13 | def parse_args(): 14 | parser = argparse.ArgumentParser() 15 | 16 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # 17 | # Data and Episodic training 18 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # 19 | parser.add_argument( 20 | "--train-data", 21 | type=str, 22 | default=None, 23 | help="Path to csv filewith training data", 24 | ) 25 | parser.add_argument( 26 | "--augmentation", 27 | choices=[None, "protoclip-light-augmentation"], 28 | default=None, 29 | help="Use lighter augmentation for implicit contrast. Choices: [None, protoclip-light-augmentation]", 30 | ) 31 | parser.add_argument( 32 | "--eval-data-dir", 33 | type=str, 34 | default=None, 35 | help="Path to datasets for evaluation", 36 | ) 37 | parser.add_argument( 38 | "--dataset-size", 39 | type=int, 40 | default=None, 41 | help="Trunck the number of samples in dataset.", 42 | ) 43 | parser.add_argument( 44 | "--csv-separator", 45 | type=str, 46 | default="\t", 47 | help="For csv-like datasets, which separator to use." 48 | ) 49 | parser.add_argument( 50 | "--csv-img-key", 51 | type=str, 52 | default="filepath", 53 | help="For csv-like datasets, the name of the key for the image paths." 
54 | ) 55 | parser.add_argument( 56 | "--csv-caption-key", 57 | type=str, 58 | default="title", 59 | help="For csv-like datasets, the name of the key for the captions." 60 | ) 61 | parser.add_argument( 62 | "--workers", type=int, default=1, help="Number of workers per GPU." 63 | ) 64 | parser.add_argument( 65 | "--episode-size", 66 | type=int, 67 | default=0, 68 | help="Set episode_size to 0 to disable episodic training", 69 | ) 70 | 71 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # 72 | # Prototypical contrast 73 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # 74 | parser.add_argument( 75 | "--external-teacher", 76 | type=str, 77 | default=None, 78 | help="Saved numpy array with shape (dataset_size, feature_dim) as external teacher. Leave it as None to disable the external teacher." 79 | ) 80 | parser.add_argument( 81 | "--add-projection-head", 82 | action="store_true", 83 | default=False, 84 | help="add two projection heads and learnable temperatures to CLIP", 85 | ) 86 | parser.add_argument( 87 | "--PBT", 88 | action="store_true", 89 | default=False, 90 | help="enable Prototype Back Translation", 91 | ) 92 | parser.add_argument( 93 | "--projection-dim", 94 | type=int, 95 | default=128, 96 | help="dimension of projected representations", 97 | ) 98 | parser.add_argument( 99 | "--projection-hidden-dim", 100 | type=int, 101 | default=2048, 102 | help="hidden dimension of the projection head", 103 | ) 104 | parser.add_argument( 105 | "--projection-n-layers", 106 | type=int, 107 | default=1, 108 | help="number of layers in the projection head", 109 | ) 110 | parser.add_argument( 111 | "--target-temperature", 112 | type=float, 113 | default=-1.0, 114 | help="target temperature to calculate teacher scores in proto loss", 115 | ) 116 | parser.add_argument( 117 | "--clustering-frequency", 118 | type=int, 119 | default=-1, 120 | help="how often (in epochs) to update prototypes. Set to -1 for non-ProtoCLIP models", 121 | ) 122 | parser.add_argument( 123 | "--k", 124 | type=int, 125 | default=20000, 126 | help="number of prototypes (K-Means cluster centroids)", 127 | ) 128 | parser.add_argument( 129 | "--kmeans-max-iter", 130 | type=int, 131 | default=20, 132 | help="maximum iterations of K-Means optimization", 133 | ) 134 | parser.add_argument( 135 | "--kmeans-nredo", 136 | type=int, 137 | default=1, 138 | help="random re-initialize and do K-Means for how many times", 139 | ) 140 | 141 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # 142 | # Logging and checkpointing 143 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # 144 | parser.add_argument( 145 | "--logs", 146 | type=str, 147 | default="./logs/", 148 | help="Where to store tensorboard logs. Use None to avoid storing logs.", 149 | ) 150 | parser.add_argument( 151 | "--log-local", 152 | action="store_true", 153 | default=False, 154 | help="log files on local master, otherwise global master only.", 155 | ) 156 | parser.add_argument( 157 | "--name", 158 | type=str, 159 | default=None, 160 | help="Optional identifier for the experiment when storing logs. Otherwise use current time.", 161 | ) 162 | parser.add_argument( 163 | "--save-frequency", type=int, default=1, help="How often to save checkpoints." 
164 | ) 165 | parser.add_argument( 166 | "--save-most-recent", 167 | action="store_true", 168 | default=True, 169 | help="Always save the most recent model trained to epoch_latest.pt.", 170 | ) 171 | parser.add_argument( 172 | "--resume", 173 | default=None, 174 | type=str, 175 | help="path to latest checkpoint (default: none)", 176 | ) 177 | parser.add_argument( 178 | "--pretrained", 179 | default='', 180 | type=str, 181 | help="Use a pretrained CLIP model weights with the specified tag or file path.", 182 | ) 183 | parser.add_argument( 184 | "--pretrained-image", 185 | default=False, 186 | action='store_true', 187 | help="Load imagenet pretrained weights for image tower backbone if available.", 188 | ) 189 | parser.add_argument( 190 | "--pretrained-text", 191 | default=None, 192 | type=str, 193 | help="Load pretrained language model as text tower via pytorch-transformers.", 194 | ) 195 | 196 | # MODELS = [(BertModel, BertTokenizer, 'bert-base-uncased'), 197 | # (OpenAIGPTModel, OpenAIGPTTokenizer, 'openai-gpt'), 198 | # (GPT2Model, GPT2Tokenizer, 'gpt2'), 199 | # (TransfoXLModel, TransfoXLTokenizer, 'transfo-xl-wt103'), 200 | # (XLNetModel, XLNetTokenizer, 'xlnet-base-cased'), 201 | # (XLMModel, XLMTokenizer, 'xlm-mlm-enfr-1024'), 202 | # (RobertaModel, RobertaTokenizer, 'roberta-base')] 203 | 204 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # 205 | # Loss functions 206 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # 207 | 208 | parser.add_argument("--w-clip", type=float, default=1., help="Loss weight.") 209 | parser.add_argument("--w-proto", type=float, default=0., help="Loss weight.") 210 | parser.add_argument("--w-proto-external", type=float, default=0., help="Loss weight.") 211 | parser.add_argument( 212 | "--infonce-warmup-epoch", 213 | default=0, 214 | type=int, 215 | help="InfoNCE-only warmup.", 216 | ) 217 | parser.add_argument( 218 | "--lit-start-epoch", 219 | default=-1, 220 | type=int, 221 | help="Enable ProtoCLIP asymetric learning rate scheduler. Leave it as negative to skip LiT.", 222 | ) 223 | parser.add_argument( 224 | "--text-start-epoch", 225 | default=0, 226 | type=int, 227 | help="Freeze text encoder at the begining of training.", 228 | ) 229 | parser.add_argument( 230 | "--text-end-epoch", 231 | default=-1, 232 | type=int, 233 | help="TODO", 234 | ) 235 | parser.add_argument( 236 | "--lock-image", 237 | default=False, 238 | action='store_true', 239 | help="Lock full image tower by disabling gradients.", 240 | ) 241 | parser.add_argument( 242 | "--lock-image-unlocked-groups", 243 | type=int, 244 | default=0, 245 | help="Leave last n image tower layer groups unlocked.", 246 | ) 247 | parser.add_argument( 248 | "--lock-image-freeze-bn-stats", 249 | default=False, 250 | action='store_true', 251 | help="Freeze BatchNorm running stats in image tower for any locked layers.", 252 | ) 253 | parser.add_argument( 254 | "--report-to", 255 | default='', 256 | type=str, 257 | help="Options are ['wandb', 'tensorboard', 'wandb,tensorboard']" 258 | ) 259 | parser.add_argument( 260 | "--wandb-notes", 261 | default='', 262 | type=str, 263 | help="Notes if logging with wandb" 264 | ) 265 | parser.add_argument( 266 | "--debug", 267 | default=False, 268 | action="store_true", 269 | help="If true, more information is logged." 
270 | ) 271 | parser.add_argument( 272 | "--copy-codebase", 273 | default=False, 274 | action="store_true", 275 | help="If true, we copy the entire codebase to the log directory, and execute from there." 276 | ) 277 | 278 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # 279 | # Optimization 280 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # 281 | parser.add_argument("--batch-size", type=int, default=64, help="Batch size per GPU.") 282 | parser.add_argument("--epochs", type=int, default=32, help="Number of epochs to train for.") 283 | parser.add_argument("--lr", type=float, default=None, help="Learning rate.") 284 | parser.add_argument("--lr-text", type=float, default=-1., help="Separate learning rate for the non-visual parameters (apart from the visual backbone). Leave it as -1 to use the default unified learning rate") 285 | parser.add_argument("--beta1", type=float, default=None, help="Adam beta 1.") 286 | parser.add_argument("--beta2", type=float, default=None, help="Adam beta 2.") 287 | parser.add_argument("--eps", type=float, default=None, help="Adam epsilon.") 288 | parser.add_argument("--wd", type=float, default=0.2, help="Weight decay.") 289 | parser.add_argument("--warmup", type=int, default=10000, help="Number of steps to warmup for.") 290 | parser.add_argument( 291 | "--use-bn-sync", 292 | default=False, 293 | action="store_true", 294 | help="Whether to use batch norm sync.") 295 | parser.add_argument( 296 | "--skip-scheduler", 297 | action="store_true", 298 | default=False, 299 | help="Use this flag to skip the learning rate decay.", 300 | ) 301 | parser.add_argument( 302 | "--precision", 303 | choices=["amp", "fp16", "fp32"], 304 | default="amp", 305 | help="Floating point precision." 306 | ) 307 | parser.add_argument( 308 | "--grad-checkpointing", 309 | default=False, 310 | action='store_true', 311 | help="Enable gradient checkpointing.", 312 | ) 313 | parser.add_argument( 314 | "--max-grad-norm", 315 | default=1e16, 316 | type=float, 317 | help="Maximum gradient norm for gradient clipping.", 318 | ) 319 | parser.add_argument( 320 | "--local-loss", 321 | default=False, 322 | action="store_true", 323 | help="calculate loss w/ local features @ global (instead of realizing full global @ global matrix)" 324 | ) 325 | parser.add_argument( 326 | "--gather-with-grad", 327 | default=False, 328 | action="store_true", 329 | help="enable full distributed gradient for feature gather" 330 | ) 331 | parser.add_argument( 332 | "--force-quick-gelu", 333 | default=False, 334 | action='store_true', 335 | help="Force use of QuickGELU activation for non-OpenAI transformer models.", 336 | ) 337 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # 338 | # Evaluation 339 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # 340 | parser.add_argument("--zeroshot-frequency", type=int, default=0, help="How often to run zero shot.") 341 | parser.add_argument("--retrieval-frequency", type=int, default=0, help="How often to run coco retrieval.") 342 | parser.add_argument("--linear-frequency", type=int, default=0, help="How often to run linear eval.") 343 | parser.add_argument("--visualize-frequency", type=int, default=-1, help="How often to run cluster visualization.") 344 | parser.add_argument("--C", type=float, default=3.16, help="inverse regularizer for logistic reg (sklearn implementation).") 345 | parser.add_argument( 346 | "--linear-prob-mode", 347 | choices=["pytorch", "sklearn"], 348 | default="pytorch", 349 | 
help="Which implementation to use for linear evaluation" 350 | ) 351 | parser.add_argument( 352 | "--model", 353 | type=str, 354 | default="RN50", 355 | help="Name of the vision backbone to use.", 356 | ) 357 | parser.add_argument( 358 | "--torchscript", 359 | default=False, 360 | action='store_true', 361 | help="torch.jit.script the model, also uses jit version of OpenAI models if pretrained=='openai'", 362 | ) 363 | parser.add_argument( 364 | "--trace", 365 | default=False, 366 | action='store_true', 367 | help="torch.jit.trace the model for inference / eval only", 368 | ) 369 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # 370 | # Distributed training 371 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # 372 | parser.add_argument( 373 | "--dist-url", 374 | default="env://", 375 | type=str, 376 | help="url used to set up distributed training", 377 | ) 378 | parser.add_argument( 379 | "--dist-backend", default="nccl", type=str, help="distributed backend" 380 | ) 381 | parser.add_argument( 382 | "--horovod", 383 | default=False, 384 | action="store_true", 385 | help="Use horovod for distributed training." 386 | ) 387 | parser.add_argument( 388 | "--ddp-static-graph", 389 | default=False, 390 | action='store_true', 391 | help="Enable static graph optimization for DDP in PyTorch >= 1.11.", 392 | ) 393 | parser.add_argument( 394 | "--no-set-device-rank", 395 | default=False, 396 | action="store_true", 397 | help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc)." 398 | ) 399 | parser.add_argument( 400 | "--seed", type=int, default=0, help="Default random seed." 401 | ) 402 | args = parser.parse_args() 403 | 404 | # If some params are not passed, we use the default values based on model name. 
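# Illustrative example (not in the original source): with `--model RN50` and no explicit --lr/--beta1/--beta2/--eps,
# the loop below fills in lr=5.0e-4, beta1=0.9, beta2=0.999, eps=1.0e-8 from get_default_params();
# a ViT config such as ViT-B-32 would instead receive beta2=0.98 and eps=1.0e-6.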
405 | default_params = get_default_params(args.model) 406 | for name, val in default_params.items(): 407 | if getattr(args, name) is None: 408 | setattr(args, name, val) 409 | 410 | return args 411 | -------------------------------------------------------------------------------- /src/training/pretrained_transformers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pytorch_transformers import RobertaModel, RobertaTokenizer 3 | 4 | def get_pretrained_text_encoder_and_tokenizer(name): 5 | # TODO: get the width of other models 6 | MODELS = [ 7 | # (BertModel, BertTokenizer, 'bert-base-uncased', ?), 8 | # (OpenAIGPTModel, OpenAIGPTTokenizer, 'openai-gpt', ?), 9 | # (GPT2Model, GPT2Tokenizer, 'gpt2', ?), 10 | # (TransfoXLModel, TransfoXLTokenizer, 'transfo-xl-wt103', ?), 11 | # (XLNetModel, XLNetTokenizer, 'xlnet-base-cased', ?), 12 | # (XLMModel, XLMTokenizer, 'xlm-mlm-enfr-1024', ?), 13 | (RobertaModel, RobertaTokenizer, 'roberta-base', 768) 14 | ] 15 | 16 | for model_class, tokenizer_class, pretrained_weights, feature_dim in MODELS: 17 | if pretrained_weights==name: 18 | tokenizer = tokenizer_class.from_pretrained(pretrained_weights, max_len=75) 19 | model = model_class.from_pretrained(pretrained_weights) 20 | 21 | def tokenize(text): 22 | result = torch.zeros(77) 23 | token = torch.tensor(tokenizer.encode(text, add_special_tokens=True))[:77] 24 | result[:len(token)] = token 25 | return result.long() 26 | 27 | return model, tokenize, feature_dim 28 | 29 | 30 | if __name__=='__main__': 31 | MODELS = [(RobertaModel, RobertaTokenizer, 'roberta-base')] 32 | # Let's encode some text in a sequence of hidden-states using each model: 33 | for model_class, tokenizer_class, pretrained_weights in MODELS: 34 | # Load pretrained model/tokenizer 35 | tokenizer = tokenizer_class.from_pretrained(pretrained_weights) 36 | model = model_class.from_pretrained(pretrained_weights) 37 | print(model) 38 | # Encode text 39 | input_ids = torch.tensor([tokenizer.encode("Here is some text to encode", add_special_tokens=True)]) # Add special tokens takes care of adding [CLS], [SEP], ... tokens in the right way for each model. 
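# For roberta-base, add_special_tokens=True wraps the sequence as <s> ... </s> (token ids 0 and 2 in the standard vocabulary), so the tensor printed below should start with 0 and end with 2.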
40 | print(input_ids) 41 | with torch.no_grad(): 42 | last_hidden_states = model(input_ids)[0] # Models outputs are now tuples 43 | print(last_hidden_states.size()) 44 | -------------------------------------------------------------------------------- /src/training/scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def assign_learning_rate(optimizer, new_lr): 5 | for param_group in optimizer.param_groups: 6 | param_group["lr"] = new_lr 7 | 8 | 9 | def _warmup_lr(base_lr, warmup_length, step): 10 | return base_lr * (step + 1) / warmup_length 11 | 12 | 13 | def cosine_lr(optimizer, base_lr, warmup_length, steps): 14 | def _lr_adjuster(step): 15 | if step < warmup_length: 16 | lr = _warmup_lr(base_lr, warmup_length, step) 17 | else: 18 | e = step - warmup_length 19 | es = steps - warmup_length 20 | lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr 21 | assign_learning_rate(optimizer, lr) 22 | return lr 23 | return _lr_adjuster 24 | 25 | 26 | # - - - - - - - - - - - - - - - - - - - - - - - - - - 27 | 28 | 29 | def assign_learning_rate_seperately(optimizer, lr_visual, lr_non_visual): 30 | optimizer.param_groups[0]["lr"] = lr_visual 31 | optimizer.param_groups[1]["lr"] = lr_visual 32 | optimizer.param_groups[2]["lr"] = lr_non_visual 33 | optimizer.param_groups[3]["lr"] = lr_non_visual 34 | 35 | def protoclip_cosine_lr(optimizer, visual_base_lr, text_base_lr, warmup_length, total_steps, visual_steps, text_start_step, text_end_step): 36 | def _lr_adjuster(step): 37 | 38 | # get lr_visual for visual backbone 39 | if step < warmup_length: 40 | lr_visual = _warmup_lr(visual_base_lr, warmup_length, step) 41 | else: 42 | e_visual = step - warmup_length 43 | es_visual = visual_steps - warmup_length 44 | lr_visual = 0.5 * (1 + np.cos(np.pi * e_visual / es_visual)) * visual_base_lr 45 | if step > visual_steps: 46 | lr_visual = 0 47 | 48 | # get lr_non_visual for rest parameters 49 | if step < text_start_step: 50 | lr_non_visual = 0 51 | else: 52 | step -= text_start_step 53 | if step < warmup_length: 54 | lr_non_visual = _warmup_lr(text_base_lr, warmup_length, step) 55 | else: 56 | e_non_visual = step - warmup_length 57 | es_non_visual = text_end_step - text_start_step - warmup_length 58 | lr_non_visual = 0.5 * (1 + np.cos(np.pi * e_non_visual / es_non_visual)) * text_base_lr 59 | if step > text_end_step: 60 | lr_non_visual = 0 61 | 62 | assign_learning_rate_seperately(optimizer, lr_visual, lr_non_visual) 63 | return [lr_visual, lr_non_visual] 64 | return _lr_adjuster 65 | -------------------------------------------------------------------------------- /src/training/train.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import math 4 | import os 5 | import time 6 | from contextlib import suppress 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | try: 14 | import wandb 15 | except ImportError: 16 | wandb = None 17 | 18 | #from open_clip import ClipLoss 19 | from training.loss import gather_features, ClipLoss, ProtoLoss 20 | from .distributed import is_master, get_gathered_item 21 | import torch.distributed as dist 22 | from training.evaluations.analyze_features import get_modality_gap 23 | 24 | class AverageMeter(object): 25 | """Computes and stores the average and current value""" 26 | def __init__(self): 27 | self.reset() 28 | 29 | def reset(self): 30 | self.val = 0 31 | self.avg = 0 32 
| self.sum = 0 33 | self.count = 0 34 | 35 | def update(self, val, n=1): 36 | self.val = val 37 | self.sum += val * n 38 | self.count += n 39 | self.avg = self.sum / self.count 40 | 41 | 42 | def unwrap_model(model): 43 | if hasattr(model, 'module'): 44 | return model.module 45 | else: 46 | return model 47 | 48 | 49 | def train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, clustering, args, tb_writer=None): 50 | device = torch.device(args.device) 51 | ZERO = torch.zeros(1).to(args.device) 52 | 53 | #autocast = torch.cuda.amp.autocast if args.precision == 'amp' else suppress 54 | autocast = torch.cuda.amp.autocast 55 | 56 | model.train() 57 | clip_loss = ClipLoss( 58 | local_loss=args.local_loss, 59 | gather_with_grad=args.gather_with_grad, 60 | cache_labels=True, 61 | rank=args.rank, 62 | world_size=args.world_size, 63 | use_horovod=args.horovod) 64 | 65 | proto_loss = ProtoLoss() 66 | 67 | w_clip = args.w_clip 68 | w_proto = args.w_proto 69 | w_proto_external = args.w_proto_external 70 | 71 | # optionally entering and LiT epoch 72 | if args.lit_start_epoch > 0 and epoch >= args.lit_start_epoch - 1: 73 | w_clip = 1 74 | w_proto = 0 75 | w_proto_external = 0 76 | if is_master(args): 77 | logging.info('Setting up Locked-image Finetuning (LiT)') 78 | 79 | # optionally warm-up the model with InfoNCE only following PCL 80 | if epoch < args.infonce_warmup_epoch: 81 | w_clip = 1 82 | w_proto = 0 83 | w_proto_external = 0 84 | if is_master(args): 85 | logging.info('Setting up InfoNCE-only warmup') 86 | 87 | 88 | clustering.img_centroids = clustering.img_centroids.cuda() 89 | clustering.text_centroids = clustering.text_centroids.cuda() 90 | clustering.external_centroids = clustering.external_centroids.cuda() 91 | if args.PBT: 92 | clustering.img_centroids_translated_from_text_prototypes = clustering.img_centroids_translated_from_text_prototypes.cuda() 93 | clustering.text_centroids_translated_from_image_prototypes = clustering.text_centroids_translated_from_image_prototypes.cuda() 94 | clustering.img_centroids_translated_from_external_prototypes = clustering.img_centroids_translated_from_external_prototypes.cuda() 95 | clustering.text_centroids_translated_from_external_prototypes = clustering.text_centroids_translated_from_external_prototypes.cuda() 96 | 97 | 98 | dataloader, sampler = data['train'].dataloader, data['train'].sampler 99 | if args.distributed and sampler is not None: 100 | sampler.set_epoch(epoch) 101 | num_batches_per_epoch = dataloader.num_batches 102 | sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10)) 103 | 104 | batch_time_m = AverageMeter() 105 | data_time_m = AverageMeter() 106 | end = time.time() 107 | for i, batch in enumerate(dataloader): 108 | step = num_batches_per_epoch * epoch + i 109 | scheduler(step) 110 | index, images, texts = batch 111 | if len(index)!=args.batch_size: # drop last incomplete small batch 112 | continue 113 | all_index = get_gathered_item(index.cuda(), args) 114 | images = images.to(device=device, non_blocking=True) 115 | texts = texts.to(device=device, non_blocking=True) 116 | img_labels = clustering.img_labels[all_index].to(device=device, non_blocking=True) 117 | text_labels = clustering.text_labels[all_index].to(device=device, non_blocking=True) 118 | external_labels = clustering.external_labels[all_index].to(device=device, non_blocking=True) 119 | 120 | data_time_m.update(time.time() - end) 121 | optimizer.zero_grad() 122 | 123 | with autocast(): 124 | # original CLIP 125 | if not args.add_projection_head: 126 | 
image_features, text_features, logit_scale = model(images, texts) 127 | 128 | if args.distributed: 129 | all_image_features, all_text_features = gather_features(image_features, text_features, 130 | args.local_loss, args.gather_with_grad, args.rank, args.world_size, args.horovod) 131 | else: 132 | all_image_features, all_text_features = image_features, text_features 133 | 134 | L_clip = clip_loss(all_image_features, all_text_features, logit_scale) 135 | total_loss = L_clip 136 | # ProtoCLIP 137 | else: 138 | image_features, text_features, image_features_projected, text_features_projected, logit_scale, logit_scale_proto = model(images, texts) 139 | 140 | if args.distributed: 141 | all_image_features, all_text_features = gather_features(image_features, text_features, 142 | args.local_loss, args.gather_with_grad, args.rank, args.world_size, args.horovod) 143 | all_image_features_projected, all_text_features_projected = gather_features(image_features_projected, text_features_projected, 144 | args.local_loss, args.gather_with_grad, args.rank, args.world_size, args.horovod) 145 | else: 146 | all_image_features, all_text_features = image_features, text_features 147 | all_image_features_projected, all_text_features_projected = image_features_projected, text_features_projected 148 | 149 | L_clip = clip_loss(all_image_features, all_text_features, logit_scale) 150 | 151 | if args.PBT: 152 | img_target = clustering.img_centroids_translated_from_text_prototypes 153 | text_target = clustering.text_centroids_translated_from_image_prototypes 154 | img_target_external = clustering.img_centroids_translated_from_external_prototypes 155 | text_target_external = clustering.text_centroids_translated_from_external_prototypes 156 | else: 157 | img_target = clustering.text_centroids 158 | text_target = clustering.img_centroids 159 | img_target_external = clustering.img_centroids_translated_from_external_prototypes 160 | text_target_external = clustering.text_centroids_translated_from_external_prototypes 161 | 162 | L_proto_img2text, acc_img2text = proto_loss( 163 | student_features=all_image_features_projected, student_centroids=img_target, teacher_centroids=clustering.text_centroids, 164 | logit_scale_student=logit_scale_proto, teacher_temperature=args.target_temperature, labels=text_labels 165 | ) 166 | L_proto_text2img, acc_text2img = proto_loss( 167 | student_features=all_text_features_projected, student_centroids=text_target, teacher_centroids=clustering.img_centroids, 168 | logit_scale_student=logit_scale_proto, teacher_temperature=args.target_temperature, labels=img_labels 169 | ) 170 | 171 | if args.external_teacher is not None: 172 | L_proto_img2external, acc_img2external = proto_loss( 173 | student_features=all_image_features_projected, student_centroids=img_target_external, teacher_centroids=clustering.external_centroids, 174 | logit_scale_student=logit_scale_proto, teacher_temperature=-1, labels=external_labels 175 | ) 176 | L_proto_text2external, acc_text2external = proto_loss( 177 | student_features=all_text_features_projected, student_centroids=text_target_external, teacher_centroids=clustering.external_centroids, 178 | logit_scale_student=logit_scale_proto, teacher_temperature=-1, labels=external_labels 179 | ) 180 | else: 181 | L_proto_img2external, acc_img2external = ZERO, ZERO 182 | L_proto_text2external, acc_text2external = ZERO, ZERO 183 | 184 | L_proto = 0.5 * (L_proto_img2text + L_proto_text2img) 185 | L_proto_external = 0.5 * (L_proto_img2external + L_proto_text2external) 186 | total_loss = 
w_clip * L_clip + w_proto * L_proto + w_proto_external * L_proto_external 187 | 188 | if scaler is not None: 189 | scaler.scale(total_loss).backward() 190 | if args.horovod: 191 | optimizer.synchronize() 192 | scaler.unscale_(optimizer)  # unscale once, after the (optional) horovod sync, so clipping sees the true gradients 193 | norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.max_grad_norm) 194 | if args.horovod: 195 | with optimizer.skip_synchronize(): 196 | scaler.step(optimizer) 197 | else: 198 | scaler.step(optimizer) 199 | scaler.update() 200 | else: 201 | total_loss.backward() 202 | norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.max_grad_norm) 203 | optimizer.step() 204 | 205 | # Note: we clamp to 4.6052 = ln(100), as in the original paper. 206 | with torch.no_grad(): 207 | unwrap_model(model).logit_scale.clamp_(0, math.log(100)) 208 | if args.add_projection_head: 209 | unwrap_model(model).logit_scale_proto.clamp_(0, math.log(100)) 210 | 211 | batch_time_m.update(time.time() - end) 212 | end = time.time() 213 | batch_count = i + 1 214 | if is_master(args) and (i % 10 == 0 or batch_count == num_batches_per_epoch): 215 | batch_size = len(images) 216 | num_samples = batch_count * batch_size * args.world_size 217 | samples_per_epoch = dataloader.num_samples 218 | percent_complete = 100.0 * batch_count / num_batches_per_epoch 219 | 220 | # NOTE: loss is coarsely sampled: master node only, once per log update 221 | logging.info( 222 | f"Train Epoch: {epoch+1}/{args.epochs} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " 223 | f"Loss: {total_loss.item():.5f} " 224 | f"Data (t): {data_time_m.avg:.3f} " 225 | f"Batch (t): {batch_time_m.avg:.3f} " 226 | f"Temperature: {1 / logit_scale.item():.4f} " 227 | f"LR (visual/rest): {optimizer.param_groups[0]['lr']:3f}/{optimizer.param_groups[2]['lr']:3f} " 228 | f"grad: {norm:1f} " 229 | ) 230 | 231 | # Save train loss / etc. Using non-averaged meter values as loggers have their own smoothing 232 | log_data = { 233 | "loss_clip": L_clip.item(), 234 | "temperature": 1 / logit_scale.item(), 235 | "lr_visual": optimizer.param_groups[0]["lr"], 236 | "lr_rest": optimizer.param_groups[2]["lr"], 237 | "gradient-norm": norm, 238 | 239 | "feature_std_image": torch.std(image_features, dim=0).mean().item(), 240 | "feature_std_text": torch.std(text_features, dim=0).mean().item(), 241 | "feature_modality_gap": get_modality_gap(image_features, text_features), 242 | } 243 | profiling = { 244 | "batch data time (s)": data_time_m.val, 245 | "batch total time (s)": batch_time_m.val, 246 | } 247 | 248 | log_data_protoclip = {} 249 | if args.add_projection_head: 250 | log_data_protoclip['loss_proto'] = L_proto.item() 251 | log_data_protoclip['loss_proto_external'] = L_proto_external.item() 252 | log_data_protoclip['acc_img2text'] = acc_img2text 253 | log_data_protoclip['acc_text2img'] = acc_text2img 254 | log_data_protoclip['acc_img2external'] = acc_img2external 255 | log_data_protoclip['acc_text2external'] = acc_text2external 256 | log_data_protoclip['temperature_proto'] = 1 / logit_scale_proto.item() 257 | 258 | 259 | for name, val in log_data.items(): 260 | name = "training/" + name 261 | if tb_writer is not None: 262 | tb_writer.add_scalar(name, val, step) 263 | if args.wandb: 264 | assert wandb is not None, 'Please install wandb.' 
265 | wandb.log({name: val, 'step': step}) 266 | 267 | for name, val in log_data_protoclip.items(): 268 | name = "training_protoclip/" + name 269 | if tb_writer is not None: 270 | tb_writer.add_scalar(name, val, step) 271 | if args.wandb: 272 | assert wandb is not None, 'Please install wandb.' 273 | wandb.log({name: val, 'step': step}) 274 | 275 | for name, val in profiling.items(): 276 | name = "profiling/" + name 277 | if tb_writer is not None: 278 | tb_writer.add_scalar(name, val, step) 279 | if args.wandb: 280 | assert wandb is not None, 'Please install wandb.' 281 | wandb.log({name: val, 'step': step}) 282 | 283 | # resetting batch / data time meters per log window 284 | batch_time_m.reset() 285 | data_time_m.reset() 286 | 287 | 288 | def feature_extraction_one_epoch(model, data, epoch, optimizer, scaler, scheduler, clustering, args, tb_writer=None): 289 | device = torch.device(args.device) 290 | autocast = torch.cuda.amp.autocast if args.precision == 'amp' else suppress 291 | 292 | model.eval() 293 | 294 | dataloader, sampler = data['train'].dataloader, data['train'].sampler 295 | if args.distributed and sampler is not None: 296 | sampler.set_epoch(epoch) 297 | num_batches_per_epoch = dataloader.num_batches 298 | sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10)) 299 | 300 | batch_time_m = AverageMeter() 301 | data_time_m = AverageMeter() 302 | end = time.time() 303 | for i, batch in enumerate(dataloader): 304 | 305 | indexs, images, texts = batch 306 | indexs = indexs.to(device=device, non_blocking=True) 307 | images = images.to(device=device, non_blocking=True) 308 | texts = texts.to(device=device, non_blocking=True) 309 | 310 | data_time_m.update(time.time() - end) 311 | 312 | # forward propagation 313 | with autocast(): 314 | with torch.no_grad(): 315 | image_features, text_features, image_features_projected, text_features_projected, logit_scale, logit_scale_proto = model(images, texts) 316 | 317 | # cache features 318 | indexs = get_gathered_item(indexs, args) 319 | image_features_projected = get_gathered_item(image_features_projected, args) 320 | text_features_projected = get_gathered_item(text_features_projected, args) 321 | if is_master(args): 322 | clustering.load_batch(indexs, image_features_projected, text_features_projected) 323 | 324 | batch_time_m.update(time.time() - end) 325 | end = time.time() 326 | batch_count = i + 1 327 | if is_master(args) and (i % 100 == 0 or batch_count == num_batches_per_epoch): 328 | batch_size = len(images) 329 | num_samples = batch_count * batch_size * args.world_size 330 | samples_per_epoch = dataloader.num_samples 331 | percent_complete = 100.0 * batch_count / num_batches_per_epoch 332 | 333 | logging.info( 334 | f"Feature extraction: {epoch+1}/{args.epochs} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " 335 | f"Data (t): {data_time_m.avg:.3f} " 336 | f"Batch (t): {batch_time_m.avg:.3f} " 337 | ) 338 | 339 | # resetting batch / data time meters per log window 340 | batch_time_m.reset() 341 | data_time_m.reset() 342 | -------------------------------------------------------------------------------- /src/utils/RoBERTa.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.utils.data import Dataset, DataLoader 4 | import numpy as np 5 | import pandas as pd 6 | import tqdm 7 | import logging 8 | import faiss 9 | 10 | 11 | class TextDataset(): 12 | def __init__(self, input_filename, tokenizer, 
caption_key='title', sep="\t") -> None: 13 | 14 | df = pd.read_csv(input_filename, sep=sep) 15 | print(df) 16 | self.captions = np.array(df[caption_key].tolist()) 17 | self.tokenizer = tokenizer 18 | 19 | def __len__(self): 20 | return len(self.captions) 21 | 22 | def __getitem__(self, idx): 23 | tokens_full = torch.zeros(77) 24 | token = self.tokenizer(str(self.captions[idx])) 25 | tokens_full[:len(token)]=token[:77] 26 | 27 | return idx, tokens_full.long() 28 | 29 | 30 | def PCA(dim, feature): 31 | feature = feature.astype(np.float32) 32 | pca = faiss.PCAMatrix(feature.shape[1], dim) 33 | pca.train(feature) 34 | PCAed_feature = pca.apply_py(feature) 35 | 36 | return PCAed_feature 37 | 38 | if __name__ == '__main__': 39 | csv = input('Input your csv file: ') 40 | feature_file = input('Input your feature file: ') 41 | 42 | roberta = torch.hub.load('pytorch/fairseq', 'roberta.large').cuda().eval() 43 | dataset = TextDataset( 44 | input_filename=csv, 45 | caption_key='title', 46 | tokenizer = roberta.encode 47 | ) 48 | 49 | dataloader = DataLoader(dataset, batch_size=256, num_workers=16, drop_last=False) 50 | all_text_features = np.zeros([len(dataset), 1024]) 51 | 52 | 53 | logging.info(f'Start RoBERTa feature extraction for file "{csv}" (total {len(dataset)} samples).') 54 | for step, batch in tqdm.tqdm(enumerate(dataloader), total=len(dataloader)): 55 | idx, texts = batch 56 | idx = idx.numpy() 57 | texts = texts.cuda() 58 | 59 | with torch.no_grad(): 60 | text_feature = roberta.extract_features(texts) 61 | 62 | text_feature = text_feature.mean(dim=1) 63 | text_feature = F.normalize(text_feature, dim=-1) 64 | all_text_features[idx] = text_feature.detach().cpu().numpy() 65 | 66 | logging.info(f'Performing PCA to reduce feature dimension from {all_text_features.shape[1]} to 64') 67 | PCAed_text_features = PCA(64, all_text_features) 68 | logging.info(f'Saving PCA-ed RoBERTa features {PCAed_text_features.shape} to: "{feature_file}".') 69 | np.save(feature_file, PCAed_text_features) 70 | 71 | -------------------------------------------------------------------------------- /src/utils/evaluate_checkpoints.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import pandas as pd 4 | import torch 5 | from training.params import parse_args 6 | import argparse 7 | from training.evaluations.evaluation import evaluate 8 | from open_clip import create_model_and_transforms 9 | import logging 10 | import matplotlib.pyplot as plt 11 | #from openTSNE import TSNE 12 | from training.pretrained_transformers import get_pretrained_text_encoder_and_tokenizer 13 | 14 | 15 | logging.basicConfig(format='%(asctime)s: %(message)s', level=logging.INFO) 16 | 17 | def evaluate_checkpoint(checkpoint_path, epoch): 18 | # load model 19 | 20 | if args.pretrained_text is not None: 21 | logging.info(f'Loading pretrained text transformer: {args.pretrained_text}.') 22 | pretrained_text_encoder, tokenizer, args.pretrained_text_feature_dim = get_pretrained_text_encoder_and_tokenizer(args.pretrained_text) 23 | else: 24 | logging.info(f'Text encoder will be trained from scratch.') 25 | pretrained_text_encoder, tokenizer, args.pretrained_text_feature_dim = None, None, None 26 | model, preprocess_train, preprocess_val = create_model_and_transforms( 27 | args.model, 28 | pretrained_text=pretrained_text_encoder, 29 | args=args 30 | ) 31 | 32 | checkpoint = torch.load(checkpoint_path, map_location=device) 33 | sd = checkpoint["state_dict"] 34 | if not args.distributed and 
next(iter(sd.items()))[0].startswith('module'): 35 | sd = {k[len('module.'):]: v for k, v in sd.items()} 36 | model.load_state_dict(sd) 37 | logging.info(f"=> Loaded checkpoint '{checkpoint_path}' (epoch {checkpoint['epoch']})") 38 | model = model.to(device) 39 | 40 | metrics = evaluate(model, epoch, preprocess_val, tokenizer, args, tb_writer=None) 41 | return metrics 42 | 43 | def load_params(params_file, args): 44 | args = vars(args) 45 | with open(params_file, 'r') as f: 46 | for line in f.readlines(): 47 | line = line.strip().split(': ') 48 | key, value = line[0], ': '.join(line[1:]) 49 | if key in args.keys() and args[key] is not None: 50 | #print(key, value, args[key], type(args[key])) 51 | args[key] = type(args[key])(value) 52 | else: 53 | args[key] = value 54 | if value == 'False': 55 | args[key] = False 56 | if value == 'None': 57 | args[key] = None 58 | return argparse.Namespace(**args) 59 | 60 | 61 | 62 | if __name__ == '__main__': 63 | exp_dir = input('Please input your experiment dir: ') 64 | single_eval = input('Specify a checkpoint epoch? (press "enter" to scan and evaluate all checkpoints) ') 65 | 66 | checkpoint_dir = os.path.join(exp_dir, 'checkpoints') 67 | params_file = os.path.join(exp_dir, 'params.txt') 68 | 69 | args = parse_args() 70 | args = load_params(params_file, args) 71 | 72 | args.zeroshot_frequency = 1 73 | args.linear_frequency = 1 74 | args.retrieval_frequency = 1 75 | args.save_logs = False 76 | args.distributed = False 77 | args.wandb = False 78 | args.rank = 0 79 | args.batch_size = 32 80 | args.workers = 12 81 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 82 | 83 | logging.info(f"Loaded params from file '{params_file}':") 84 | for name in sorted(vars(args)): 85 | val = getattr(args, name) 86 | logging.info(f"  {name}: {val}") 87 | 88 | all_metrics = pd.DataFrame() 89 | 90 | if not single_eval: 91 | finished = ['epoch_latest.pt'] 92 | while True: 93 | checkpoints = os.listdir(checkpoint_dir) 94 | for checkpoint in checkpoints: 95 | if checkpoint not in finished: 96 | logging.info(f'found new checkpoint: {checkpoint}') 97 | time.sleep(10) # in case the checkpoint has not been fully written to disk yet 98 | epoch = int(checkpoint.split('_')[1][:-3]) 99 | checkpoint_path = os.path.join(checkpoint_dir, checkpoint) 100 | 101 | metrics = evaluate_checkpoint(checkpoint_path=checkpoint_path, epoch=epoch) 102 | metrics['epoch'] = epoch 103 | 104 | for key in metrics.keys(): 105 | metrics[key] = [metrics[key]] 106 | metrics = pd.DataFrame.from_dict(metrics) 107 | 108 | all_metrics = pd.concat([all_metrics, metrics]) 109 | all_metrics.to_csv(os.path.join(exp_dir, 'evaluation_metrics_all.csv')) 110 | # all_metrics.to_csv(os.path.join(exp_dir, f'evaluation_metrics@epoch_{epoch}.csv')) 111 | print(all_metrics) 112 | 113 | finished.append(checkpoint) 114 | time.sleep(10) 115 | else: 116 | checkpoint = f'epoch_{single_eval}.pt' 117 | logging.info(f'evaluate single checkpoint: {checkpoint}') 118 | checkpoint_path = os.path.join(checkpoint_dir, checkpoint) 119 | epoch = int(checkpoint.split('_')[1][:-3]) 120 | 121 | metrics = evaluate_checkpoint(checkpoint_path=checkpoint_path, epoch=epoch) 122 | metrics['epoch'] = epoch 123 | for key in metrics.keys(): 124 | metrics[key] = [metrics[key]] 125 | metrics = pd.DataFrame.from_dict(metrics) 126 | 127 | all_metrics = pd.concat([all_metrics, metrics]) 128 | all_metrics.to_csv(os.path.join(exp_dir, f'single_evaluation_metrics@epoch_{epoch}.csv')) 129 | print(all_metrics) 130 | 131 | 
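# Illustrative usage sketch, assuming the experiment directory layout produced by training:
# <exp_dir>/params.txt holds one "key: value" pair per line and <exp_dir>/checkpoints/ holds
# epoch_<N>.pt files. load_params() casts each value back to the type of the corresponding
# argparse default (with the strings 'None'/'False' mapped back to None/False), so a minimal
# params.txt with hypothetical values such as
#
#     model: RN50
#     epochs: 32
#     pretrained_text: None
#
# is enough to drive a re-evaluation: answering the two prompts with, e.g., "logs/my_experiment"
# and "8" evaluates checkpoints/epoch_8.pt and writes single_evaluation_metrics@epoch_8.csv
# into the experiment directory.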
-------------------------------------------------------------------------------- /src/utils/gather_cc.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import os 3 | import multiprocessing as mp 4 | from io import BytesIO 5 | import numpy as np 6 | import PIL 7 | from PIL import Image 8 | import pickle 9 | import sys 10 | 11 | 12 | def grab(line): 13 | """ 14 | Download a single image from the TSV. 15 | """ 16 | uid, split, line = line 17 | try: 18 | caption, url = line.split("\t")[:2] 19 | except: 20 | print("Parse error") 21 | return 22 | 23 | if os.path.exists(ROOT+"/%s/%d/%d.jpg"%(split,uid%1000,uid)): 24 | print("Finished", uid) 25 | return uid, caption, url 26 | 27 | # Let's not crash if anything weird happens 28 | try: 29 | dat = requests.get(url, timeout=20) 30 | if dat.status_code != 200: 31 | print("Bad response", dat.status_code, url) 32 | return 33 | 34 | # Try to parse this as an Image file, we'll fail out if not 35 | im = Image.open(BytesIO(dat.content)) 36 | im.thumbnail((512, 512), PIL.Image.BICUBIC) 37 | if min(*im.size) < max(*im.size)/3: 38 | print("Too small", url) 39 | return 40 | 41 | im.save(ROOT+"/%s/%d/%d.jpg"%(split,uid%1000,uid)) 42 | 43 | # Another try/catch just because sometimes saving and re-loading 44 | # the image is different than loading it once. 45 | try: 46 | o = Image.open(ROOT+"/%s/%d/%d.jpg"%(split,uid%1000,uid)) 47 | o = np.array(o) 48 | 49 | print("Success", o.shape, uid, url) 50 | return uid, caption, url 51 | except: 52 | print("Failed", uid, url) 53 | 54 | except Exception as e: 55 | print("Unknown error", e) 56 | pass 57 | 58 | if __name__ == "__main__": 59 | ROOT = "cc_data" 60 | 61 | if not os.path.exists(ROOT): 62 | os.mkdir(ROOT) 63 | os.mkdir(os.path.join(ROOT,"train")) 64 | os.mkdir(os.path.join(ROOT,"val")) 65 | for i in range(1000): 66 | os.mkdir(os.path.join(ROOT,"train", str(i))) 67 | os.mkdir(os.path.join(ROOT,"val", str(i))) 68 | 69 | 70 | p = mp.Pool(300) 71 | 72 | for tsv in sys.argv[1:]: 73 | print("Processing file", tsv) 74 | assert 'val' in tsv.lower() or 'train' in tsv.lower() 75 | split = 'val' if 'val' in tsv.lower() else 'train' 76 | results = p.map(grab, 77 | [(i,split,x) for i,x in enumerate(open(tsv).read().split("\n"))]) 78 | 79 | out = open(tsv.replace(".tsv","_output.csv"),"w") 80 | out.write("title\tfilepath\n") 81 | 82 | for row in results: 83 | if row is None: continue 84 | id, caption, url = row 85 | fp = os.path.join(ROOT, split, str(id % 1000), str(id) + ".jpg") 86 | if os.path.exists(fp): 87 | out.write("%s\t%s\n"%(caption,fp)) 88 | else: 89 | print("Drop", id) 90 | out.close() 91 | 92 | p.close() 93 | 94 | -------------------------------------------------------------------------------- /src/utils/plot_pairs.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use('Agg') 3 | import matplotlib.pyplot as plt 4 | import textwrap 5 | 6 | 7 | def plot_pairs(imgs, texts, suptitle, file_name='test.png', sample_per_row = 32): 8 | # imgs: a list of PIL-opened images 9 | # texts: a list of str, will be shown as the title (per subplot) 10 | # suptitle: figure super title 11 | 12 | if len(imgs) < sample_per_row: 13 | column = len(imgs) 14 | row = 1 15 | else: 16 | column = sample_per_row 17 | row = (len(imgs) + sample_per_row - 1) // sample_per_row 18 | 19 | # one subplot per image-text pair, at most sample_per_row pairs per row 20 | plt.figure(figsize=(2 * column, 3 * row)) 21 | 22 | for i in range(len(imgs)): 23 | image = imgs[i] 24 | text = texts[i] 25 | if len(text) > 40: 26 | text = text[:40]+'...' 
27 | text = textwrap.fill(text, width=20) 28 | 29 | plt.subplot(row,column, i+1) 30 | plt.imshow(image) 31 | plt.text( 32 | x=int(image.size[0]/2),y=image.size[1]+30,s=text, 33 | fontsize=11, va='top', ha='center', 34 | bbox={'facecolor': 'white', 'edgecolor':'white', 'pad': 4,} 35 | ) 36 | 37 | plt.xticks([]) 38 | plt.yticks([]) 39 | 40 | if suptitle is not None: 41 | plt.suptitle(suptitle, size='x-large') 42 | 43 | 44 | plt.savefig(file_name, bbox_inches='tight') 45 | plt.close() --------------------------------------------------------------------------------
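# Illustrative usage sketch for src/utils/plot_pairs.py: it assumes a tab-separated CSV with
# 'title' and 'filepath' columns, as written by gather_cc.py; the CSV path, sample count, and
# import path below are assumptions, not part of the repository.
import pandas as pd
from PIL import Image

from plot_pairs import plot_pairs  # assumes src/utils is on PYTHONPATH

df = pd.read_csv('cc3m_train_output.csv', sep='\t')  # e.g. the *_output.csv written by gather_cc.py
rows = df.sample(8)                                   # pick a few random image-text pairs
imgs = [Image.open(fp) for fp in rows['filepath']]
texts = rows['title'].tolist()
plot_pairs(imgs, texts, suptitle='Random image-text pairs', file_name='pairs_preview.png', sample_per_row=8)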