├── README.md ├── assets ├── method_compare_github.png ├── methodology_github.png └── puppy.jpg ├── datasets ├── EVAL_DATASETS.md ├── Urban1k │ └── annotations │ │ └── annotations.json ├── coco │ └── annotations │ │ └── captions_val2017.json ├── dci │ └── densely_captioned_images │ │ └── splits.json ├── docci │ └── annotations │ │ └── test_annotations.json ├── flickr30k-images │ └── flickr30k_val.json ├── imageinwords │ └── test_annotations.json └── share4v │ └── share4v_sam_10k.json ├── preprocess ├── convert_to_parquet.py ├── presplit_captions.py └── scraping_cc3m.sh └── src ├── flair ├── __init__.py ├── data.py ├── factory.py ├── loss.py ├── model.py ├── model_configs │ └── ViT-B-16-FLAIR.json ├── params.py ├── train.py └── transformer.py ├── inference.sh ├── main.py ├── minimal_example.py ├── requirements.txt ├── train_cc12m_slurm.sh ├── train_cc3m_slurm.sh ├── train_example.sh ├── train_merged30m_slurm.sh └── train_yfcc15m_slurm.sh /README.md: -------------------------------------------------------------------------------- 1 | # [CVPR 2025] FLAIR: VLM with Fine-grained Language-informed Image Representations 2 | [![Paper](https://img.shields.io/badge/paper-arxiv.2412.03561-B31B1B.svg)](https://arxiv.org/abs/2412.03561) 3 | [![Hugging Face](https://img.shields.io/badge/HuggingFace-FLAIR-FFD700?logo=huggingface&logoColor=yellow)](https://huggingface.co/xiaorui638/flair) 4 | 5 | 6 | **Authors:** [Rui Xiao](https://www.eml-munich.de/people/rui-xiao), [Sanghwan Kim](https://kim-sanghwan.github.io/), [Mariana-Iuliana Georgescu](https://lilygeorgescu.github.io/), [Zeynep Akata](https://www.eml-munich.de/people/zeynep-akata), [Stephan Alaniz](https://www.eml-munich.de/people/stephan-alaniz) 7 | 8 | ## News 9 | - **[2025-03-31]** 🍻 Check out [**COSMOS**](https://github.com/ExplainableML/cosmos), a self-distillation approach to be presented at **CVPR 2025**. 10 | - **[2025-03-02]** ⭐️ Training code & scripts released. 11 | - **[2025-02-26]** 🎉 Our paper was accepted to **CVPR 2025**. 12 | - **[2025-01-20]** ⭐️ Inference code & models released. 13 | 14 | ## Abstract 15 | CLIP has shown impressive results in aligning images and 16 | texts at scale. However, its ability to capture detailed visual features remains limited because CLIP matches images and texts at a global level. To address this issue, we propose **FLAIR**, **F**ine-grained **La**nguage-informed **I**mage 17 | **R**epresentations, an approach that utilizes long and detailed image descriptions to learn localized image embeddings. By sampling diverse sub-captions that describe fine-grained details about an image, we train our vision-language model to produce not only global embeddings but also text-specific image representations. Our model introduces text-conditioned attention pooling on top of local image tokens to produce fine-grained image representations that excel at retrieving detailed image content. We achieve state-of-the-art performance on both, existing multimodal retrieval benchmarks, as well as, our newly introduced fine-grained retrieval task which evaluates vision-language models’ ability to retrieve partial image content. Furthermore, our experiments demonstrate the effectiveness of FLAIR trained on 30M image-text pairs in capturing fine-grained visual information, including zero-shot semantic segmentation, outperforming models trained on billions of pairs. 
18 | 
19 | ## Methodology
20 | ![](assets/methodology_github.png "A general workflow for FLAIR")
21 | 
22 | ## Pre-trained Models
23 | 
24 | We released the pre-trained FLAIR models on [Huggingface](https://huggingface.co/xiaorui638/flair). The pre-trained models, their corresponding pre-training datasets, and R@1 retrieval results on COCO and Flickr are listed below. For the full results, please see the [paper](https://arxiv.org/pdf/2412.03561). FLAIR shares a similar architecture to the `ViT-B-16` model from [OpenCLIP](https://github.com/mlfoundations/open_clip) and therefore has a similar number of parameters (150M vs. 149M); the extra 1M parameters come from the text-conditioned attention pooling layer in FLAIR.
25 | 
26 | | **Checkpoints** | **Pre-trained Datasets** | **COCO T2I** | **COCO I2T** | **Flickr T2I** | **Flickr I2T** |
27 | |------------------------------------------------------------------------------------------------------------|--------------------------|--------------|--------------|----------------|----------------|
28 | | [flair-cc3m-recap](https://huggingface.co/xiaorui638/flair/resolve/main/flair-cc3m-recap.pt?download=true) | CC3M-recap | 37.7 | 51.6 | 65.7 | 78.7 |
29 | | [flair-cc12m-recap](https://huggingface.co/xiaorui638/flair/resolve/main/flair-cc12m-recap.pt?download=true) | CC12M-recap | 47.8 | 64.1 | 75.4 | 90.8 |
30 | | [flair-yfcc15m-recap](https://huggingface.co/xiaorui638/flair/resolve/main/flair-yfcc15m-recap.pt?download=true) | YFCC15M-recap | 51.2 | 67.3 | 79.2 | 93.3 |
31 | | [flair-merged30m](https://huggingface.co/xiaorui638/flair/resolve/main/flair-merged30m.pt?download=true) | Merged30M | 53.3 | 68.0 | 81.1 | 94.7 |
32 | 
33 | 
34 | ⚠️ You don't need to download the pre-trained weights manually to run inference; they are downloaded automatically when you specify `--huggingface-model-name` in `src/inference.sh` (more details in the 'Inference with FLAIR' section). However, if you would like to store the pretrained weights somewhere other than the default path, you can download them manually and pass the `--pretrained path/to/pretrained_weights` flag in `src/inference.sh` instead (as OpenCLIP originally does).
35 | 
36 | ## Dependencies
37 | The following short tutorial helps you set up a simple Python virtual environment to run our code. Since our main dependency is [OpenCLIP](https://github.com/mlfoundations/open_clip), which is still updated frequently, you can always check their repo for a detailed tutorial on creating an environment best suited to your system. A conda environment with the same Python and PyTorch versions also works.
38 | ### 1. Create a Virtual Environment
39 | First, navigate to the project’s root directory `flair/` and create a virtual environment using Python 3.12:
40 | ```bash
41 | cd flair/
42 | python3.12 -m venv flair_env
43 | ```
44 | ### 2. Activate and Navigate to src/
45 | Activate the virtual environment and navigate to `src/`:
46 | ```bash
47 | source flair_env/bin/activate
48 | cd src/
49 | ```
50 | 
51 | ### 3. Install Dependencies
52 | Our dependencies mainly consist of `open_clip_torch` and `open_clip_torch[training]`.
53 | ```bash
54 | pip install --upgrade pip
55 | pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
56 | pip install -r requirements.txt
57 | ```
58 | The code has been tested with Python 3.12, PyTorch 2.5.1, and CUDA 12.4.
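If you want to double-check which versions ended up in your environment, the following one-liner (a convenience check we add here, not a script shipped with the repo) prints them:
```bash
# Print the installed PyTorch version, its CUDA build, and the open_clip version.
python3 -c "import torch, open_clip; print(torch.__version__, torch.version.cuda, open_clip.__version__)"
```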
Since [OpenCLIP](https://github.com/mlfoundations/open_clip) is quite dependency-friendly, we expect other up-to-date versions to work as well.
59 | 
60 | ## Usage
61 | A minimal usage example of FLAIR is provided in `src/minimal_example.py`, which shows that FLAIR has two ways of generating logits:
62 | 1. `model.get_logits()`: First query the local image tokens with the global text token via attention pooling, then compute the logits. This is the primary way FLAIR computes logits.
63 | 2. `model.get_logits_as_clip()`: Directly compute the similarity between global-level image and text features, without attention pooling.
64 | 
65 | Run the example with:
66 | ```bash
67 | source flair_env/bin/activate
68 | python3 src/minimal_example.py
69 | ```
70 | 
71 | ## Inference Datasets Preparation
72 | Check [EVAL_DATASETS.md](datasets/EVAL_DATASETS.md) to prepare all the inference datasets. For clarity, we provide an example datasets folder with annotation files in `datasets/`. However, the datasets do not have to be stored in a single directory; you can place them freely and point to them by changing the arguments in `src/inference.sh`.
73 | 
74 | ## Inference with FLAIR
75 | To reproduce the retrieval results in the FLAIR paper, we provide an example inference bash script: `src/inference.sh`. Below are detailed explanations of the important flags:
76 | 
77 | - `--huggingface-repo-name`: Name of the Huggingface repo where the pre-trained models are stored. Should be kept as `'xiaorui638/flair'`.
78 | - `--huggingface-model-name`: Name of the pretrained model. Options include:
79 |   - `flair-cc3m-recap.pt`
80 |   - `flair-cc12m-recap.pt`
81 |   - `flair-yfcc15m-recap.pt`
82 |   - `flair-merged30m.pt`
83 | - `--inference-with-flair`: Enable this flag when using the FLAIR model.
84 | - `--precision`: Fixed as `amp` in our paper.
85 | - `--workers`: Adjustable according to your system.
86 | 
87 | ### Retrieval Tasks
88 | Enable the following flags in `src/inference.sh` for the different retrieval tasks (an example configuration sketch follows the list):
89 | 
90 | 1. **Standard Retrieval**
91 |    - `--coco-data-root-dir`: Root directory of the COCO dataset.
92 |    - `--flickr-data-root-dir`: Root directory of the Flickr30k dataset.
93 |    - `--retrieval-coco`: Activate the COCO retrieval task.
94 |    - `--retrieval-flickr`: Activate the Flickr retrieval task.
95 | 2. **Fine-grained Retrieval**
96 |    - `--iiw-retrieval-dir`: Root directory of the Image-in-Words dataset.
97 |    - `--docci-retrieval-dir`: Root directory of the DOCCI dataset.
98 |    - `--retrieval-iiw`: Activate the Image-in-Words retrieval task.
99 |    - `--retrieval-docci`: Activate the DOCCI retrieval task.
100 | 3. **Long Retrieval**
101 |    - `--dci-retrieval-dir`: Root directory of the DCI dataset.
102 |    - `--urban-1k-retrieval-dir`: Root directory of the Urban-1K dataset.
103 |    - `--sharegpt4v-retrieval-dir`: Root directory of the ShareGPT4V dataset.
104 |    - `--retrieval-dci`: Activate the DCI retrieval task.
105 |    - `--retrieval-urban-1k`: Activate the Urban1K retrieval task.
106 |    - `--retrieval-sharegpt4v-1k`: Activate the ShareGPT4V-1K retrieval task.
107 |    - `--retrieval-sharegpt4v-10k`: Activate the ShareGPT4V-10K retrieval task.
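For illustration, a standard COCO + Flickr retrieval run could be configured as in the sketch below. It assumes `main.py` in `src/` is the entry point (as used by the provided scripts); the exact invocation in the shipped `src/inference.sh` may differ, and the paths and worker count are placeholders.

```bash
#!/bin/bash
# Sketch of a standard-retrieval configuration; adapt paths and workers to your setup.
python3 main.py \
    --huggingface-repo-name 'xiaorui638/flair' \
    --huggingface-model-name 'flair-merged30m.pt' \
    --inference-with-flair \
    --precision amp \
    --workers 8 \
    --coco-data-root-dir /path/to/coco \
    --flickr-data-root-dir /path/to/flickr30k-images \
    --retrieval-coco \
    --retrieval-flickr
```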
108 | 
109 | ## Training FLAIR
110 | For the results in the main paper, FLAIR was trained on [DreamLIP](https://github.com/ant-research/DreamLIP)'s recaptioned [CC3M-recap](https://huggingface.co/datasets/qidouxiong619/dreamlip_long_captions), [CC12M-recap](https://huggingface.co/datasets/qidouxiong619/dreamlip_long_captions), [YFCC15M-recap](https://huggingface.co/datasets/qidouxiong619/dreamlip_long_captions), and their combination (Merged-30M). To verify that FLAIR handles various data distributions, it was also trained on the original [CC3M](https://huggingface.co/datasets/pixparse/cc3m-wds) and [PixelProse](https://huggingface.co/datasets/tomg-group-umd/pixelprose); those results are presented in the appendix of the paper. Notably, FLAIR requires all pre-training datasets to be processed into the [webdataset](https://github.com/webdataset/webdataset) format to achieve higher I/O efficiency for large-scale training. Below, we take [CC3M-recap](https://huggingface.co/datasets/qidouxiong619/dreamlip_long_captions) as an example to demonstrate how to prepare the pre-training data; the preparation for the other datasets is similar.
111 | 
112 | ### Prepare Pre-training Data
113 | 1. Download DreamLIP's annotations for CC3M-recap:
114 |    `wget https://huggingface.co/datasets/qidouxiong619/dreamlip_long_captions/resolve/main/cc3m_3long_3short_1raw_captions_url.csv`
115 | 2. Convert the CSV to `.parquet` format (the script takes the input CSV and output Parquet paths as positional arguments): `python3 preprocess/convert_to_parquet.py /path/to/csv /path/to/parquet`
116 | 3. Scrape the images from the URL links using [img2dataset](https://github.com/rom1504/img2dataset), replacing the paths accordingly:
117 |    `bash preprocess/scraping_cc3m.sh`
118 | 4. The captions are now stored inside each shard in `.json` format. We then pre-split all the captions and re-write the shards (arguments: shards directory and number of worker processes):
119 | ```bash
120 | python3 preprocess/presplit_captions.py /path/to/cc3m 24
121 | ```
122 | **Remarks**: FLAIR requires the captions to be stored in `.json` format inside each shard so that they can be handled by the `sample_dict()` function in `src/flair/data.py`. Instead of pre-splitting the captions in step 4, an alternative approach is to split the captions inside the `sample_dict()` function (see this [issue](https://github.com/ExplainableML/flair/issues/6)).
123 | To minimize the loss of images, you can also download existing HuggingFace datasets to avoid the scraping in step 3.
124 | 
125 | ### Single-node training script
126 | An example single-node training script, `src/train_example.sh`, is provided in this repo to test whether training runs. Important flags:
127 | - `--train-data`: Root directory where the training data (shards) is stored.
128 | - `--train-num-samples`: In the example file we set it to `2823019` because that is the total number of image-text pairs we obtained for CC3M-recap. Adjust it to match your data.
129 | 
130 | The single-node training script `src/train_example.sh` has been tested to run without problems. We recommend running your job on a single node first, before starting multi-node training:
131 | ```bash
132 | source flair_env/bin/activate
133 | bash src/train_example.sh
134 | ```
135 | 
136 | ### Multi-node training script (Slurm)
137 | In practice, FLAIR was trained on 8 NVIDIA A100 40GB GPUs (for CC3M) or 32 NVIDIA A100 40GB GPUs (for all larger datasets), and all experiments were run with Slurm.
In `src/`, we provide example Slurm training scripts for each of the datasets: `train_cc3m_slurm.sh`, `train_cc12m_slurm.sh`, `train_yfcc15m_slurm.sh`, and `train_merged30m_slurm.sh`.
138 | 
139 | These training scripts contain all the hyperparameters needed to reproduce the training. However, you might need to modify them to run on your cluster. Please set `--train-data` to the directory storing the dataset shards and `--train-num-samples` to the actual number of valid samples in that dataset. When training on the Merged-30M dataset, note that `--train-data` should combine the dataset paths of `cc3m-recap`, `cc12m-recap`, and `yfcc15m-recap`, separated by `::`, such as:
140 | 
141 | `--train-data '/datasets/yfcc15m_recap/yfcc15m-train-{0000..3636}.tar::/datasets/cc12m_recap/cc12m-train-{0000..2175}.tar::/datasets/cc3m_recap/cc3m-train-{0001..0575}.tar'`
142 | 
143 | After configuring the Slurm scripts correctly, you can run an experiment by (taking CC3M-recap as an example):
144 | ```bash
145 | source flair_env/bin/activate
146 | sbatch src/train_cc3m_slurm.sh
147 | ```
148 | 
149 | 
150 | ## Acknowledgements
151 | We thank [OpenCLIP](https://github.com/mlfoundations/open_clip) for providing the amazing code base. We also acknowledge [DreamLIP](https://github.com/zyf0619sjtu/DreamLIP) and [PixelProse](https://huggingface.co/datasets/tomg-group-umd/pixelprose) for providing various pre-training datasets with captions from MLLMs, and we are grateful to [LoTLIP](https://github.com/wuw2019/LoTLIP) for providing the detailed scheme for the long image-text retrieval task.
152 | 
153 | ## Citations
154 | If you find our work useful, please star this repo and cite:
155 | 
156 | ```bibtex
157 | @article{xiao2024flair,
158 |   title={FLAIR: VLM with Fine-grained Language-informed Image Representations},
159 |   author={Xiao, Rui and Kim, Sanghwan and Georgescu, Mariana-Iuliana and Akata, Zeynep and Alaniz, Stephan},
160 |   journal={CVPR},
161 |   year={2025}
162 | }
163 | ```
164 | 
--------------------------------------------------------------------------------
/assets/method_compare_github.png:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/ExplainableML/flair/bf0cf64d7cee7295f05a3f6e00cf315a04a3c883/assets/method_compare_github.png
--------------------------------------------------------------------------------
/assets/methodology_github.png:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/ExplainableML/flair/bf0cf64d7cee7295f05a3f6e00cf315a04a3c883/assets/methodology_github.png
--------------------------------------------------------------------------------
/assets/puppy.jpg:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/ExplainableML/flair/bf0cf64d7cee7295f05a3f6e00cf315a04a3c883/assets/puppy.jpg
--------------------------------------------------------------------------------
/datasets/EVAL_DATASETS.md:
--------------------------------------------------------------------------------
1 | # Data Preparation for Text-Image Retrieval
2 | 
3 | ### Annotation Files
4 | We pre-processed and unified the annotations for the various datasets into `.json` format to standardize them. These annotation files are stored under the `datasets/` directory of this repo.
To use our inference code properly, you should use the same annotation files; the detailed instructions are as follows:
5 | 
6 | ### Datasets list:
7 | - [MSCOCO](#coco)
8 | - [FLICKR30K](#flickr)
9 | - [DOCCI](#docci)
10 | - [IIW](#IIW)
11 | - [ShareGPT4v](#share4v)
12 | - [DCI](#DCI)
13 | - [Urban1k](#urban1k)
14 | 
15 | 
16 | ### MSCOCO dataset
17 | ```
18 | $coco/
19 | |–– images/
20 | |–––– val2017/
21 | |–––––– 000000134722.jpg
22 | |–––––– 000000177015.jpg
23 | |–––––– ...
24 | |–– annotations/
25 | |–––– captions_val2017.json
26 | ```
27 | Step 1. Download the validation images from [COCO 2017 Val Images](https://cocodataset.org/#download) and unzip them to `coco/images/val2017`.
28 | 
29 | Step 2. Download the 2017 Val annotations and place them under `coco/annotations/captions_val2017.json`.
30 | 
31 | ### FLICKR30K dataset
32 | ```
33 | $flickr30k-images/
34 | |–– 2217728745.jpg
35 | |–– 2217728745.jpg
36 | |–– ...
37 | |–– flickr30k_val.json
38 | |–– flickr30k_test.json
39 | ```
40 | Step 1. Download the [flickr30k dataset](https://huggingface.co/datasets/nlphuji/flickr30k) and unzip it under `flickr30k-images/`; all the images and annotation files will be structured as above.
41 | 
42 | ### DOCCI dataset
43 | ```
44 | $docci/
45 | |–– images/
46 | |–––– test_01427.jpg
47 | |–––– test_01428.jpg
48 | |–––– ...
49 | |–– annotations/
50 | |–––– test_annotations.json
51 | ```
52 | Step 1. Download the [DOCCI Images](https://storage.googleapis.com/docci/data/docci_images.tar.gz) and unzip them under `docci/images/`; note that we only need the 5K test images here.
53 | 
54 | Step 2. Directly copy the `test_annotations.json` in this repo and put it under `docci/annotations`. This annotation file documents the mapping between all test images and all fine-grained captions.
55 | 
56 | ### IIW dataset
57 | 
58 | ```
59 | $imageinwords/
60 | |–– dci/
61 | |–– docci/
62 | |–– docci_aar/
63 | |–– finegrained_annotations.json
64 | ```
65 | 
66 | **Download the human-annotated data following [IIW](https://github.com/google/imageinwords/tree/main/datasets), including IIW-400, DCI-Test, and DOCCI-Test**:
67 | 
68 | Step 1: Download [DCI](https://github.com/facebookresearch/DCI) to `path_to_dci_dataset`.
69 | 
70 | Step 2: Download the DOCCI images and AAR images from the [DOCCI](https://google.github.io/docci/#downloads) dataset. Unzip the files to `path_to_docci_dataset/images` and `path_to_docci_dataset/images_aar`, respectively.
71 | 
72 | Step 3: Directly copy `finegrained_annotations.json` in this repo and put it under `imageinwords/`.
73 | 
74 | 
75 | ### ShareGPT4v dataset
76 | 
77 | ```
78 | $share4v/
79 | |–– sa_000000/
80 | |–––– images/
81 | |–––––– sa_1.jpg
82 | |–––––– sa_2.jpg
83 | |–––––– ...
84 | |–– sa_000001/
85 | |–– ...
86 | ```
87 | 
88 | Step 1. Download the tar files from [SA-1B](https://huggingface.co/datasets/sailvideo/SA-1B) to `share4v/`.
89 | 
90 | Step 2. Unzip all tar files.
91 | 
92 | For the annotations, we have resaved the top 10k samples from [share-captioner_coco_lcs_sam_1246k_1107.json](https://huggingface.co/datasets/Lin-Chen/ShareGPT4V/tree/main) in `datasets/share4v/share4v_sam_10k.json`.
93 | 
94 | 
95 | ### DCI dataset
96 | 
97 | ```
98 | $dci/
99 | |–– densely_captioned_images/
100 | |–––– annotations/
101 | |–––– photos/
102 | |–––– splits.json
103 | 
104 | ```
105 | 
106 | **Download the data following [DCI](https://github.com/facebookresearch/DCI)**:
107 | 
108 | Step 1. Download [dci.tar.gz](https://dl.fbaipublicfiles.com/densely_captioned_images/dci.tar.gz) and unzip the file in `dci/densely_captioned_images`.
109 | 
110 | Step 2. Download the archive sa_000138.tar and extract the images to the `dci/densely_captioned_images/photos` folder.
111 | 
112 | 
113 | ### Urban1k dataset
114 | ```
115 | $Urban1k/
116 | |–– images/
117 | |–––– 221.jpg
118 | |–––– 222.jpg
119 | |–––– ...
120 | |–– annotations/
121 | |–––– annotations.json
122 | ```
123 | Step 1. Download [Urban1K](https://huggingface.co/datasets/BeichenZhang/Urban1k), unzip it, and put only the images (without the caption folder) under `Urban1k/images/`.
124 | 
125 | Step 2. Directly copy the `annotations.json` in this repo and put it under `Urban1k/annotations`. This annotation file documents the mapping between each image and its corresponding long caption.
126 | 
--------------------------------------------------------------------------------
/preprocess/convert_to_parquet.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import argparse
3 | 
4 | 
5 | def csv_to_parquet(input_path: str, output_path: str):
6 |     print("Start converting, this may take a while...")
7 |     df = pd.read_csv(input_path)
8 |     df.to_parquet(output_path, index=False)
9 |     print("Conversion complete.")
10 | 
11 | 
12 | if __name__ == "__main__":
13 |     parser = argparse.ArgumentParser(description="Convert a CSV file to Parquet format.")
14 |     parser.add_argument("input_path", type=str, help="Path to the input CSV file")
15 |     parser.add_argument("output_path", type=str, help="Path to save the output Parquet file")
16 |     args = parser.parse_args()
17 | 
18 |     csv_to_parquet(args.input_path, args.output_path)
--------------------------------------------------------------------------------
/preprocess/presplit_captions.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import tarfile
4 | import re
5 | import json
6 | from multiprocessing import Pool
7 | from tqdm import tqdm
8 | from io import BytesIO
9 | 
10 | CAPTION_KEYS = [
11 |     "raw_caption",
12 |     "shortIB_captions", "longIB_captions",
13 |     "shortSV_captions", "longSV_captions",
14 |     "shortLLA_captions", "longLLA_captions"
15 | ]
16 | 
17 | def split_caption(text):
18 |     texts = re.split(r'\n|[.]', text)  # split captions on newlines and periods
19 |     subcap = []
20 |     for text_prompt in texts:
21 |         text_prompt = text_prompt.strip()
22 |         if len(text_prompt) != 0:
23 |             subcap.append(text_prompt)
24 |     return subcap
25 | 
26 | def process_tar(tar_path):
27 |     tmp_tar_path = tar_path + ".tmp"
28 |     try:
29 |         with tarfile.open(tar_path, 'r') as in_tar, tarfile.open(tmp_tar_path, 'w') as out_tar:
30 |             for member in in_tar.getmembers():
31 |                 file_bytes = in_tar.extractfile(member).read()
32 | 
33 |                 if member.name.endswith('.json'):
34 |                     json_obj = json.loads(file_bytes.decode('utf-8'))
35 |                     for k in CAPTION_KEYS:
36 |                         if k in json_obj and isinstance(json_obj[k], str):
37 |                             json_obj[k] = split_caption(json_obj[k])
38 |                     file_bytes = json.dumps(json_obj).encode('utf-8')
39 | 
40 |                 info = tarfile.TarInfo(name=member.name)
41 |                 info.size = len(file_bytes)
42 |                 out_tar.addfile(info, BytesIO(file_bytes))
43 | 
44 |         os.replace(tmp_tar_path, tar_path)
45 |         return (tar_path, "success")
46 | 
47 |     except Exception as e:
48 |         return (tar_path, f"failed: {e}")
49 | 
50 | def main(shards_dir, num_processes):
51 |     shard_paths = [
52 |         os.path.join(shards_dir, f)
53 |         for f in os.listdir(shards_dir)
54 |         if f.endswith('.tar')
55 |     ]
56 |     with Pool(processes=num_processes) as pool:
57 |         results = list(tqdm(pool.imap_unordered(process_tar, shard_paths),
                                   total=len(shard_paths)))
58 |     for tar, status in results:
59 |         print(f"{tar}: {status}")
60 | 
61 | if __name__ == "__main__":
62 | 
63 |     parser = argparse.ArgumentParser(description="Split caption fields inside .json files of WebDataset shards.")
64 |     parser.add_argument("shards_dir", type=str, help="Path to directory containing .tar shards")
65 |     parser.add_argument("num_processes", type=int, help="Number of parallel processes to use")
66 |     args = parser.parse_args()
67 |     main(args.shards_dir, args.num_processes)
68 | 
--------------------------------------------------------------------------------
/preprocess/scraping_cc3m.sh:
--------------------------------------------------------------------------------
1 | img2dataset \
2 |     --url_list ./datasets/cc3m_3long_3short_1raw_captions_url.parquet \
3 |     --input_format "parquet" \
4 |     --url_col "Image Path" \
5 |     --output_format "webdataset" \
6 |     --output_folder ./datasets/scraped_cc3m \
7 |     --processes_count 32 \
8 |     --thread_count 64 \
9 |     --number_sample_per_shard 5000 \
10 |     --save_additional_columns '["raw_caption", "shortIB_captions", "longIB_captions", "shortSV_captions", "longSV_captions", "shortLLA_captions", "longLLA_captions"]'
--------------------------------------------------------------------------------
/src/flair/__init__.py:
--------------------------------------------------------------------------------
1 | from .factory import create_model, create_model_and_transforms, get_tokenizer
2 | from .factory import get_model_config, load_checkpoint, download_weights_from_hf
3 | from .model import CLIPTextCfg, CLIPVisionCfg, \
4 |     convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype, get_input_dtype, \
5 |     get_model_tokenize_cfg, get_model_preprocess_cfg, set_model_preprocess_cfg
6 | 
--------------------------------------------------------------------------------
/src/flair/factory.py:
--------------------------------------------------------------------------------
1 | """
2 | This code is adapted from OpenCLIP:
3 | https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/factory.py
4 | 
5 | The code integrates additional modifications and extensions to support the FLAIR models.
6 | Original authors: ML Foundations.
7 | """ 8 | import json 9 | import logging 10 | import os 11 | import re 12 | from copy import deepcopy 13 | from dataclasses import asdict 14 | from pathlib import Path 15 | from typing import Any, Dict, Optional, Tuple, Union 16 | from huggingface_hub import hf_hub_download 17 | 18 | import torch 19 | 20 | from open_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 21 | from open_clip.convert import convert_state_dict 22 | from open_clip.model import CLIP, CustomTextCLIP 23 | from .model import convert_weights_to_lp, convert_to_custom_text_state_dict, \ 24 | resize_pos_embed, get_cast_dtype, resize_text_pos_embed, set_model_preprocess_cfg, FLAIR 25 | 26 | from open_clip.openai import load_openai_model 27 | from open_clip.pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, \ 28 | list_pretrained_tags_by_model, download_pretrained_from_hf 29 | from open_clip.transform import image_transform_v2, AugmentationCfg, PreprocessCfg, merge_preprocess_dict, \ 30 | merge_preprocess_kwargs, image_transform 31 | from open_clip.tokenizer import HFTokenizer, SimpleTokenizer, DEFAULT_CONTEXT_LENGTH 32 | from .loss import FlairLoss 33 | 34 | HF_HUB_PREFIX = 'hf-hub:' 35 | _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] 36 | _MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs 37 | 38 | def download_weights_from_hf(model_repo, filename): 39 | # Define the custom cache directory relative to the current script 40 | cache_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "pretrained") 41 | if not os.path.exists(cache_dir): 42 | os.makedirs(cache_dir, exist_ok=True) 43 | local_path = hf_hub_download(repo_id=model_repo, filename=filename, cache_dir=cache_dir) 44 | return local_path 45 | 46 | def _natural_key(string_): 47 | return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] 48 | 49 | 50 | def _rescan_model_configs(): 51 | global _MODEL_CONFIGS 52 | 53 | config_ext = ('.json',) 54 | config_files = [] 55 | for config_path in _MODEL_CONFIG_PATHS: 56 | if config_path.is_file() and config_path.suffix in config_ext: 57 | config_files.append(config_path) 58 | elif config_path.is_dir(): 59 | for ext in config_ext: 60 | config_files.extend(config_path.glob(f'*{ext}')) 61 | 62 | for cf in config_files: 63 | with open(cf, 'r') as f: 64 | model_cfg = json.load(f) 65 | if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')): 66 | _MODEL_CONFIGS[cf.stem] = model_cfg 67 | 68 | _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))} 69 | 70 | 71 | _rescan_model_configs() # initial populate of model config registry 72 | 73 | 74 | def list_models(): 75 | """ enumerate available model architectures based on config files """ 76 | return list(_MODEL_CONFIGS.keys()) 77 | 78 | 79 | def add_model_config(path): 80 | """ add model config path or file and update registry """ 81 | if not isinstance(path, Path): 82 | path = Path(path) 83 | _MODEL_CONFIG_PATHS.append(path) 84 | _rescan_model_configs() 85 | 86 | 87 | def get_tokenizer( 88 | model_name: str = '', 89 | context_length: Optional[int] = None, 90 | **kwargs, 91 | ): 92 | if model_name.startswith(HF_HUB_PREFIX): 93 | model_name = model_name[len(HF_HUB_PREFIX):] 94 | try: 95 | config = _get_hf_config(model_name)['model_cfg'] 96 | except Exception: 97 | tokenizer = HFTokenizer( 98 | model_name, 99 | context_length=context_length or DEFAULT_CONTEXT_LENGTH, 100 | **kwargs, 101 | ) 102 | return 
tokenizer 103 | else: 104 | config = get_model_config(model_name) 105 | assert config is not None, f"No valid model config found for {model_name}." 106 | 107 | text_config = config.get('text_cfg', {}) 108 | if 'tokenizer_kwargs' in text_config: 109 | tokenizer_kwargs = dict(text_config['tokenizer_kwargs'], **kwargs) 110 | else: 111 | tokenizer_kwargs = kwargs 112 | 113 | if context_length is None: 114 | context_length = text_config.get('context_length', DEFAULT_CONTEXT_LENGTH) 115 | 116 | if 'hf_tokenizer_name' in text_config: 117 | tokenizer = HFTokenizer( 118 | text_config['hf_tokenizer_name'], 119 | context_length=context_length, 120 | **tokenizer_kwargs, 121 | ) 122 | else: 123 | tokenizer = SimpleTokenizer( 124 | context_length=context_length, 125 | **tokenizer_kwargs, 126 | ) 127 | 128 | return tokenizer 129 | 130 | 131 | def _get_hf_config(model_id, cache_dir=None): 132 | config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir) 133 | with open(config_path, 'r', encoding='utf-8') as f: 134 | config = json.load(f) 135 | return config 136 | 137 | def get_model_config(model_name): 138 | if model_name in _MODEL_CONFIGS: 139 | return deepcopy(_MODEL_CONFIGS[model_name]) 140 | else: 141 | return None 142 | 143 | 144 | def load_state_dict(checkpoint_path: str, map_location='cpu'): 145 | checkpoint = torch.load(checkpoint_path, map_location=map_location) 146 | if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: 147 | state_dict = checkpoint['state_dict'] 148 | elif isinstance(checkpoint, torch.jit.ScriptModule): 149 | state_dict = checkpoint.state_dict() 150 | for key in ["input_resolution", "context_length", "vocab_size"]: 151 | state_dict.pop(key, None) 152 | else: 153 | state_dict = checkpoint 154 | if next(iter(state_dict.items()))[0].startswith('module'): 155 | state_dict = {k[7:]: v for k, v in state_dict.items()} 156 | return state_dict 157 | 158 | 159 | def load_checkpoint( 160 | model: Union[CLIP, CustomTextCLIP], 161 | checkpoint_path: str, 162 | strict: bool = True, 163 | ): 164 | if Path(checkpoint_path).suffix in ('.npz', '.npy'): 165 | # Separate path loading numpy big_vision (SigLIP) weights 166 | from open_clip.convert import load_big_vision_weights 167 | load_big_vision_weights(model, checkpoint_path) 168 | return {} 169 | 170 | state_dict = load_state_dict(checkpoint_path) 171 | 172 | # Detect & convert 3rd party state_dicts -> open_clip 173 | state_dict = convert_state_dict(model, state_dict) 174 | 175 | # Detect old format and make compatible with new format 176 | if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'): 177 | state_dict = convert_to_custom_text_state_dict(state_dict) 178 | 179 | # If loading a non-SigLIP model for SigLIP training. 
See https://github.com/mlfoundations/open_clip/issues/712 180 | if 'logit_bias' not in state_dict and model.logit_bias is not None: 181 | state_dict["logit_bias"] = torch.zeros_like(state_dict["logit_scale"]) 182 | 183 | # Certain text transformers no longer expect position_ids after transformers==4.31 184 | position_id_key = 'text.transformer.embeddings.position_ids' 185 | if position_id_key in state_dict and not hasattr(model, position_id_key): 186 | del state_dict[position_id_key] 187 | 188 | resize_pos_embed(state_dict, model) 189 | resize_text_pos_embed(state_dict, model) 190 | 191 | # Finally, load the massaged state_dict into model 192 | incompatible_keys = model.load_state_dict(state_dict, strict=strict) 193 | return incompatible_keys 194 | 195 | 196 | def create_model( 197 | model_name: str, 198 | pretrained: Optional[str] = None, 199 | precision: str = 'fp32', 200 | device: Union[str, torch.device] = 'cpu', 201 | jit: bool = False, 202 | force_quick_gelu: bool = False, 203 | force_custom_text: bool = False, 204 | force_patch_dropout: Optional[float] = None, 205 | force_image_size: Optional[Union[int, Tuple[int, int]]] = None, 206 | force_preprocess_cfg: Optional[Dict[str, Any]] = None, 207 | pretrained_image: bool = False, 208 | pretrained_hf: bool = True, 209 | cache_dir: Optional[str] = None, 210 | output_dict: Optional[bool] = None, 211 | require_pretrained: bool = False, 212 | **model_kwargs, 213 | ): 214 | force_preprocess_cfg = force_preprocess_cfg or {} 215 | preprocess_cfg = asdict(PreprocessCfg()) 216 | has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX) 217 | if has_hf_hub_prefix: 218 | model_id = model_name[len(HF_HUB_PREFIX):] 219 | checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir) 220 | config = _get_hf_config(model_id, cache_dir) 221 | preprocess_cfg = merge_preprocess_dict(preprocess_cfg, config['preprocess_cfg']) 222 | model_cfg = config['model_cfg'] 223 | pretrained_hf = False # override, no need to load original HF text weights 224 | else: 225 | model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names 226 | checkpoint_path = None 227 | model_cfg = None 228 | 229 | if isinstance(device, str): 230 | device = torch.device(device) 231 | 232 | if pretrained and pretrained.lower() == 'openai': 233 | logging.info(f'Loading pretrained {model_name} from OpenAI.') 234 | model = load_openai_model( 235 | model_name, 236 | precision=precision, 237 | device=device, 238 | cache_dir=cache_dir, 239 | ) 240 | else: 241 | model_cfg = model_cfg or get_model_config(model_name) 242 | if model_cfg is not None: 243 | logging.info(f'Loaded {model_name} model config.') 244 | else: 245 | logging.error(f'Model config for {model_name} not found.') 246 | raise RuntimeError(f'Model config for {model_name} not found.') 247 | 248 | if force_quick_gelu: 249 | # override for use of QuickGELU on non-OpenAI transformer models 250 | model_cfg["quick_gelu"] = True 251 | 252 | if force_patch_dropout is not None: 253 | # override the default patch dropout value 254 | model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout 255 | 256 | if force_image_size is not None: 257 | # override model config's image size 258 | model_cfg["vision_cfg"]["image_size"] = force_image_size 259 | 260 | is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {}) 261 | if pretrained_image: 262 | if is_timm_model: 263 | # pretrained weight loading for timm models set via vision_cfg 264 | model_cfg['vision_cfg']['timm_model_pretrained'] = True 
265 | else: 266 | assert False, 'pretrained image towers currently only supported for timm models' 267 | 268 | # cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes 269 | cast_dtype = get_cast_dtype(precision) 270 | is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {}) 271 | if is_hf_model: 272 | # load pretrained weights for HF text model IFF no CLIP weights being loaded 273 | model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf and not pretrained 274 | custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model 275 | 276 | model_cfg = dict(model_cfg, **model_kwargs) # merge cfg dict w/ kwargs (kwargs overrides cfg) 277 | 278 | if "FLAIR" in model_name: 279 | model = FLAIR(**model_cfg, cast_dtype=cast_dtype) 280 | else: 281 | model = CLIP(**model_cfg, cast_dtype=cast_dtype) 282 | 283 | if precision in ("fp16", "bf16"): 284 | dtype = torch.float16 if 'fp16' in precision else torch.bfloat16 285 | # manual mixed precision that matches original OpenAI behaviour 286 | if is_timm_model: 287 | # FIXME this is a bit janky, create timm based model in low-precision and 288 | # then cast only LayerNormFp32 instances back to float32 so they don't break. 289 | # Why? The convert_weights_to_lp fn only works with native models. 290 | model.to(device=device, dtype=dtype) 291 | from .transformer import LayerNormFp32 292 | 293 | def _convert_ln(m): 294 | if isinstance(m, LayerNormFp32): 295 | m.weight.data = m.weight.data.to(torch.float32) 296 | m.bias.data = m.bias.data.to(torch.float32) 297 | 298 | model.apply(_convert_ln) 299 | else: 300 | model.to(device=device) 301 | convert_weights_to_lp(model, dtype=dtype) 302 | elif precision in ("pure_fp16", "pure_bf16"): 303 | dtype = torch.float16 if 'fp16' in precision else torch.bfloat16 304 | model.to(device=device, dtype=dtype) 305 | else: 306 | model.to(device=device) 307 | 308 | pretrained_loaded = False 309 | if pretrained: 310 | checkpoint_path = '' 311 | pretrained_cfg = get_pretrained_cfg(model_name, pretrained) 312 | if pretrained_cfg: 313 | checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir) 314 | preprocess_cfg = merge_preprocess_dict(preprocess_cfg, pretrained_cfg) 315 | elif os.path.exists(pretrained): 316 | checkpoint_path = pretrained 317 | 318 | if checkpoint_path: 319 | logging.info(f'Loading pretrained {model_name} weights ({pretrained}).') 320 | load_checkpoint(model, checkpoint_path, strict=True) 321 | else: 322 | error_str = ( 323 | f'Pretrained weights ({pretrained}) not found for model {model_name}.' 
324 | f' Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.') 325 | logging.warning(error_str) 326 | raise RuntimeError(error_str) 327 | pretrained_loaded = True 328 | elif has_hf_hub_prefix: 329 | logging.info(f'Loading pretrained {model_name} weights ({checkpoint_path}).') 330 | load_checkpoint(model, checkpoint_path) 331 | pretrained_loaded = True 332 | 333 | if require_pretrained and not pretrained_loaded: 334 | # callers of create_model_from_pretrained always expect pretrained weights 335 | raise RuntimeError( 336 | f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.') 337 | 338 | if output_dict and hasattr(model, "output_dict"): 339 | model.output_dict = True 340 | 341 | if jit: 342 | model = torch.jit.script(model) 343 | 344 | # set image preprocessing configuration in model attributes for convenience 345 | if getattr(model.visual, 'image_size', None) is not None: 346 | # use image_size set on model creation (via config or force_image_size arg) 347 | force_preprocess_cfg['size'] = model.visual.image_size 348 | set_model_preprocess_cfg(model, merge_preprocess_dict(preprocess_cfg, force_preprocess_cfg)) 349 | 350 | return model 351 | 352 | 353 | def create_model_and_transforms( 354 | model_name: str, 355 | pretrained: Optional[str] = None, 356 | precision: str = 'fp32', 357 | device: Union[str, torch.device] = 'cpu', 358 | jit: bool = False, 359 | force_quick_gelu: bool = False, 360 | force_custom_text: bool = False, 361 | force_patch_dropout: Optional[float] = None, 362 | force_image_size: Optional[Union[int, Tuple[int, int]]] = None, 363 | image_mean: Optional[Tuple[float, ...]] = None, 364 | image_std: Optional[Tuple[float, ...]] = None, 365 | image_interpolation: Optional[str] = None, 366 | image_resize_mode: Optional[str] = None, # only effective for inference 367 | aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, 368 | pretrained_image: bool = False, 369 | pretrained_hf: bool = True, 370 | cache_dir: Optional[str] = None, 371 | output_dict: Optional[bool] = None, 372 | **model_kwargs, 373 | ): 374 | force_preprocess_cfg = merge_preprocess_kwargs( 375 | {}, mean=image_mean, std=image_std, interpolation=image_interpolation, resize_mode=image_resize_mode) 376 | 377 | model = create_model( 378 | model_name, 379 | pretrained, 380 | precision=precision, 381 | device=device, 382 | jit=jit, 383 | force_quick_gelu=force_quick_gelu, 384 | force_custom_text=force_custom_text, 385 | force_patch_dropout=force_patch_dropout, 386 | force_image_size=force_image_size, 387 | force_preprocess_cfg=force_preprocess_cfg, 388 | pretrained_image=pretrained_image, 389 | pretrained_hf=pretrained_hf, 390 | cache_dir=cache_dir, 391 | output_dict=output_dict, 392 | **model_kwargs, 393 | ) 394 | 395 | pp_cfg = PreprocessCfg(**model.visual.preprocess_cfg) 396 | 397 | preprocess_train = image_transform_v2( 398 | pp_cfg, 399 | is_train=True, 400 | aug_cfg=aug_cfg, 401 | ) 402 | preprocess_val = image_transform_v2( 403 | pp_cfg, 404 | is_train=False, 405 | ) 406 | 407 | return model, preprocess_train, preprocess_val 408 | 409 | 410 | def create_loss(args): 411 | 412 | if args.use_flair_loss: 413 | if args.add_mps_loss: 414 | return FlairLoss(rank=args.rank, 415 | world_size=args.world_size, 416 | num_cap_per_img=args.num_sampled_captions, 417 | added_mps_loss=True) 418 | else: 419 | return FlairLoss(rank=args.rank, 420 | world_size=args.world_size, 421 | num_cap_per_img=args.num_sampled_captions, 422 | added_mps_loss=False) 
423 | 424 | else: 425 | raise NotImplementedError("Loss function for the given configuration is not implemented.") -------------------------------------------------------------------------------- /src/flair/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | from torch.utils.checkpoint import checkpoint 5 | import math 6 | 7 | try: 8 | import torch.distributed.nn 9 | from torch import distributed as dist 10 | 11 | has_distributed = True 12 | except ImportError: 13 | has_distributed = False 14 | 15 | try: 16 | import horovod.torch as hvd 17 | except ImportError: 18 | hvd = None 19 | 20 | 21 | def gather_features( 22 | image_features, 23 | text_features, 24 | local_loss=False, 25 | gather_with_grad=False, 26 | rank=0, 27 | world_size=1, 28 | use_horovod=False 29 | ): 30 | assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.' 31 | if use_horovod: 32 | assert hvd is not None, 'Please install horovod' 33 | if gather_with_grad: 34 | all_image_features = hvd.allgather(image_features) 35 | all_text_features = hvd.allgather(text_features) 36 | else: 37 | with torch.no_grad(): 38 | all_image_features = hvd.allgather(image_features) 39 | all_text_features = hvd.allgather(text_features) 40 | if not local_loss: 41 | # ensure grads for local rank when all_* features don't have a gradient 42 | gathered_image_features = list(all_image_features.chunk(world_size, dim=0)) 43 | gathered_text_features = list(all_text_features.chunk(world_size, dim=0)) 44 | gathered_image_features[rank] = image_features 45 | gathered_text_features[rank] = text_features 46 | all_image_features = torch.cat(gathered_image_features, dim=0) 47 | all_text_features = torch.cat(gathered_text_features, dim=0) 48 | else: 49 | # We gather tensors from all gpus 50 | if gather_with_grad: 51 | all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0) 52 | all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0) 53 | else: 54 | gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)] 55 | gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)] 56 | dist.all_gather(gathered_image_features, image_features) 57 | dist.all_gather(gathered_text_features, text_features) 58 | if not local_loss: 59 | # ensure grads for local rank when all_* features don't have a gradient 60 | gathered_image_features[rank] = image_features 61 | gathered_text_features[rank] = text_features 62 | all_image_features = torch.cat(gathered_image_features, dim=0) 63 | all_text_features = torch.cat(gathered_text_features, dim=0) 64 | 65 | return all_image_features, all_text_features 66 | 67 | 68 | 69 | 70 | def neighbour_exchange(from_rank, to_rank, tensor, group=None): 71 | tensor_recv = torch.zeros_like(tensor) 72 | send_op = torch.distributed.P2POp( 73 | torch.distributed.isend, 74 | tensor, 75 | to_rank, 76 | group=group, 77 | ) 78 | recv_op = torch.distributed.P2POp( 79 | torch.distributed.irecv, 80 | tensor_recv, 81 | from_rank, 82 | group=group, 83 | ) 84 | reqs = torch.distributed.batch_isend_irecv([send_op, recv_op]) 85 | for req in reqs: 86 | req.wait() 87 | return tensor_recv 88 | 89 | 90 | def neighbour_exchange_bidir(left_rank, right_rank, tensor_to_left, tensor_to_right, group=None): 91 | tensor_from_left = torch.zeros_like(tensor_to_right) 92 | tensor_from_right = 
torch.zeros_like(tensor_to_left) 93 | send_op_left = torch.distributed.P2POp( 94 | torch.distributed.isend, 95 | tensor_to_left, 96 | left_rank, 97 | group=group, 98 | ) 99 | send_op_right = torch.distributed.P2POp( 100 | torch.distributed.isend, 101 | tensor_to_right, 102 | right_rank, 103 | group=group, 104 | ) 105 | recv_op_left = torch.distributed.P2POp( 106 | torch.distributed.irecv, 107 | tensor_from_left, 108 | left_rank, 109 | group=group, 110 | ) 111 | recv_op_right = torch.distributed.P2POp( 112 | torch.distributed.irecv, 113 | tensor_from_right, 114 | right_rank, 115 | group=group, 116 | ) 117 | reqs = torch.distributed.batch_isend_irecv([send_op_right, send_op_left, recv_op_right, recv_op_left]) 118 | for req in reqs: 119 | req.wait() 120 | return tensor_from_right, tensor_from_left 121 | 122 | 123 | class NeighbourExchange(torch.autograd.Function): 124 | @staticmethod 125 | def forward(ctx, from_rank, to_rank, group, tensor): 126 | ctx.group = group 127 | ctx.from_rank = from_rank 128 | ctx.to_rank = to_rank 129 | return neighbour_exchange(from_rank, to_rank, tensor, group=group) 130 | 131 | @staticmethod 132 | def backward(ctx, grad_output): 133 | return (None, None, None) + (NeighbourExchange.apply(ctx.to_rank, ctx.from_rank, ctx.group, grad_output),) 134 | 135 | 136 | def neighbour_exchange_with_grad(from_rank, to_rank, tensor, group=None): 137 | return NeighbourExchange.apply(from_rank, to_rank, group, tensor) 138 | 139 | 140 | class NeighbourExchangeBidir(torch.autograd.Function): 141 | @staticmethod 142 | def forward(ctx, left_rank, right_rank, group, tensor_to_left, tensor_to_right): 143 | ctx.group = group 144 | ctx.left_rank = left_rank 145 | ctx.right_rank = right_rank 146 | return neighbour_exchange_bidir(left_rank, right_rank, tensor_to_left, tensor_to_right, group=group) 147 | 148 | @staticmethod 149 | def backward(ctx, *grad_outputs): 150 | return (None, None, None) + \ 151 | NeighbourExchangeBidir.apply(ctx.right_rank, ctx.left_rank, ctx.group, *grad_outputs) 152 | 153 | 154 | def neighbour_exchange_bidir_with_grad(left_rank, right_rank, tensor_to_left, tensor_to_right, group=None): 155 | return NeighbourExchangeBidir.apply(left_rank, right_rank, group, tensor_to_left, tensor_to_right) 156 | 157 | 158 | 159 | def get_multi_positive_mps(target, k): 160 | """ 161 | :param target: tensor of shape (b, b*k), all with values -1 at each entry 162 | :param k 163 | :return: tensor of shape (b, b*k), for each row i, the col [i*k, (i+1)*k] should be ones 164 | """ 165 | for i in range(target.shape[0]): 166 | target[i, i * k:(i + 1) * k] = 1 167 | return target 168 | 169 | 170 | 171 | def get_multi_positive_tcs(target, k): 172 | """ 173 | :param target: tensor of shape (b, b+k-1), all with values -1 at each entry 174 | :param k 175 | :return: tensor of shape (b, b+k-1), for each row i, the col [i, i+k) should be ones 176 | """ 177 | for i in range(target.shape[0]): 178 | target[i, i: i + k] = 1 179 | return target 180 | 181 | 182 | 183 | 184 | def get_mps_logits(image_features, text_features, logit_scale, logit_bias=None): 185 | logits = logit_scale * image_features @ text_features.T # if multi-cap: (B, B*K) 186 | if logit_bias is not None: 187 | logits += logit_bias 188 | return logits 189 | 190 | def get_mps_ground_truth(device, dtype, target_shape, negative_only=False, 191 | num_captions=4): 192 | dim0, dim1 = target_shape # (B, B*K) 193 | labels = -torch.ones((dim0, dim1), device=device, dtype=dtype) # (B, B*K) 194 | if not negative_only: 195 | labels = 
get_multi_positive_mps(target=labels, k=num_captions)
196 |     return labels
197 | 
198 | def get_intra_logits(image_features, text_features, logit_scale, logit_bias=None):
199 |     """
200 |     image_features: (B, K, D),
201 |     text_features: (B, K, D).
202 |     Target: (B, K, K)
203 |     """
204 |     logits = logit_scale * torch.einsum('bkd,bjd->bkj', image_features, text_features)
205 |     # logits = logit_scale * image_features @ text_features.T
206 |     if logit_bias is not None:
207 |         logits += logit_bias
208 |     return logits
209 | 
210 | def get_tcs_ground_truth(device, dtype, target_shape, negative_only=False, num_captions=4):
211 |     dim0, dim1 = target_shape  # (B, B+K-1)
212 |     labels = -torch.ones((dim0, dim1), device=device, dtype=dtype)  # (B, B+K-1)
213 |     if not negative_only:
214 |         labels = get_multi_positive_tcs(target=labels, k=num_captions)
215 |     return labels
216 | 
217 | def get_tcs_logits(features_0, features_1, logit_scale, logit_bias=None):
218 |     logits = logit_scale * torch.einsum('bij,bij->bi', features_0, features_1)
219 |     if logit_bias is not None:
220 |         logits += logit_bias
221 |     return logits
222 | 
223 | 
224 | class FlairLoss(nn.Module):
225 |     """
226 |     Implementation of the FLAIR loss in: https://arxiv.org/pdf/2412.03561
227 |     When added_mps_loss=False, this class is simply the text-conditioned sigmoid loss;
228 |     when added_mps_loss=True, it is the 'text-conditioned sigmoid loss + multi-positive sigmoid loss'.
229 |     """
230 | 
231 |     def __init__(
232 |             self,
233 |             cache_labels=False,
234 |             rank=0,
235 |             world_size=1,
236 |             bidir=True,
237 |             use_horovod=False,
238 |             num_cap_per_img=8,
239 |             added_mps_loss=False,
240 |     ):
241 |         super().__init__()
242 |         self.cache_labels = cache_labels
243 |         self.rank = rank
244 |         self.world_size = world_size
245 |         assert not use_horovod  # FIXME need to look at hvd ops for ring transfers
246 |         self.use_horovod = use_horovod
247 |         self.bidir = bidir
248 | 
249 |         # cache state FIXME cache not currently used, worthwhile?
250 | self.prev_num_logits = 0 251 | self.labels = {} 252 | self.num_cap_per_img = num_cap_per_img 253 | self.added_mps_loss = added_mps_loss 254 | 255 | 256 | def _loss_with_attn_pool(self, image_features, image_tokens, text_features, logit_scale, 257 | logit_bias=None, negative_only=False, visual_proj=None, g_text_features=None): 258 | 259 | local_image_features = visual_proj(text_features, image_tokens, image_tokens) # (B, B+K-1, D) 260 | 261 | local_image_features = F.normalize(local_image_features, dim=-1) 262 | global_text_features = F.normalize(text_features, dim=-1) 263 | 264 | i2t_logits = get_tcs_logits(local_image_features, global_text_features, logit_scale, logit_bias) 265 | 266 | i2t_labels = get_tcs_ground_truth(device=text_features.device, 267 | dtype=text_features.dtype, 268 | target_shape=i2t_logits.size(), 269 | negative_only=negative_only, 270 | num_captions=self.num_cap_per_img) 271 | 272 | tcs_loss = -F.logsigmoid(i2t_labels * i2t_logits).sum() / text_features.shape[1] # text-conditioned sigmoid loss 273 | 274 | 275 | if self.added_mps_loss: 276 | g_image_features = F.normalize(image_features, dim=-1) #(B, D) 277 | g_text_features = F.normalize(g_text_features, dim=-1) #(B*K, D) 278 | mps_logits = get_mps_logits(image_features=g_image_features, text_features=g_text_features, 279 | logit_scale=logit_scale, logit_bias=logit_bias) 280 | g2g_labels = get_mps_ground_truth(device=g_text_features.device, 281 | dtype=g_text_features.dtype, 282 | target_shape=mps_logits.size(), 283 | negative_only=negative_only, 284 | num_captions=self.num_cap_per_img) 285 | mps_loss = -F.logsigmoid(g2g_labels * mps_logits).sum() / g_text_features.shape[0] # multi-positive sigmoid loss 286 | 287 | loss = (tcs_loss + mps_loss) / 2 288 | else: 289 | loss = tcs_loss 290 | 291 | 292 | return loss 293 | 294 | def forward(self, image_features, text_features, logit_scale, logit_bias, image_tokens=None, 295 | visual_proj=None, output_dict=False): 296 | ''' 297 | expected shape: text_features: (B*K, D), image_embeddings: (B, L, D) 298 | ''' 299 | if self.added_mps_loss: 300 | g_text_features = text_features # (B*K, D) 301 | else: 302 | g_text_features = None 303 | 304 | 305 | # We don't change the shape of image tokens anywhere before the loss function. 
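        # (Shape notes added for clarity) image_tokens holds the local patch tokens, shape (B, L, D);
        # text_features arrives flattened as (B*K, D) with K = num_cap_per_img sub-captions per image.
        # downsample_text_features() below keeps each image's own K captions plus one randomly sampled
        # caption from every other image, yielding (B, K + B - 1, D) text features that
        # _loss_with_attn_pool() scores against the attention-pooled image tokens.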
306 | batch_size = image_tokens.shape[0] 307 | num_captions = self.num_cap_per_img 308 | caption_indices = torch.arange(batch_size * num_captions).view(batch_size, num_captions).to( 309 | text_features.device) 310 | 311 | text_features = downsample_text_features(text_features=text_features, batch_size=batch_size, 312 | caption_indices=caption_indices, 313 | num_captions=num_captions) 314 | 315 | loss = self._loss_with_attn_pool(image_features=image_features, 316 | image_tokens=image_tokens, 317 | text_features=text_features, 318 | visual_proj=visual_proj, 319 | logit_scale=logit_scale, 320 | logit_bias=logit_bias, 321 | g_text_features=g_text_features) 322 | 323 | if self.world_size > 1: 324 | # exchange text features w/ neighbour world_size - 1 times 325 | right_rank = (self.rank + 1) % self.world_size 326 | left_rank = (self.rank - 1 + self.world_size) % self.world_size 327 | if self.bidir: 328 | text_features_to_right = text_features_to_left = text_features 329 | if self.added_mps_loss: 330 | g_text_features_to_right = g_text_features_to_left = g_text_features 331 | 332 | num_bidir, remainder = divmod(self.world_size - 1, 2) 333 | 334 | g_text_features_recv = None # predefine it to be None 335 | 336 | for i in range(num_bidir): 337 | text_features_recv = neighbour_exchange_bidir_with_grad( 338 | left_rank, 339 | right_rank, 340 | text_features_to_left, 341 | text_features_to_right, 342 | ) 343 | if self.added_mps_loss: 344 | g_text_features_recv = neighbour_exchange_bidir_with_grad( 345 | left_rank, 346 | right_rank, 347 | g_text_features_to_left, 348 | g_text_features_to_right, 349 | ) 350 | for j in range(len(text_features_recv)): 351 | loss += self._loss_with_attn_pool( 352 | image_features=image_features, 353 | image_tokens=image_tokens, 354 | text_features=text_features_recv[j], 355 | visual_proj=visual_proj, 356 | logit_scale=logit_scale, 357 | logit_bias=logit_bias, 358 | negative_only=True, 359 | g_text_features=g_text_features_recv[j] 360 | ) 361 | else: 362 | for f in text_features_recv: 363 | loss += self._loss_with_attn_pool( 364 | image_features=image_features, 365 | image_tokens=image_tokens, 366 | text_features=f, 367 | visual_proj=visual_proj, 368 | logit_scale=logit_scale, 369 | logit_bias=logit_bias, 370 | negative_only=True, 371 | g_text_features=None) 372 | text_features_to_left, text_features_to_right = text_features_recv 373 | if self.added_mps_loss: 374 | g_text_features_to_left, g_text_features_to_right = g_text_features_recv 375 | 376 | if remainder: 377 | text_features_recv = neighbour_exchange_with_grad( 378 | left_rank, right_rank, text_features_to_right) 379 | if self.added_mps_loss: 380 | g_text_features_recv = neighbour_exchange_with_grad( 381 | left_rank, right_rank, g_text_features_to_right) 382 | loss += self._loss_with_attn_pool( 383 | image_features=image_features, 384 | image_tokens=image_tokens, 385 | text_features=text_features_recv, 386 | visual_proj=visual_proj, 387 | logit_scale=logit_scale, 388 | logit_bias=logit_bias, 389 | negative_only=True, 390 | g_text_features=g_text_features_recv 391 | ) 392 | else: 393 | loss += self._loss_with_attn_pool( 394 | image_features=image_features, 395 | image_tokens=image_tokens, 396 | text_features=text_features_recv, 397 | visual_proj=visual_proj, 398 | logit_scale=logit_scale, 399 | logit_bias=logit_bias, 400 | negative_only=True, 401 | g_text_features=None) 402 | else: 403 | text_features_to_right = text_features 404 | if self.added_mps_loss: 405 | g_text_features_to_right = g_text_features 406 | 407 | 
for i in range(self.world_size - 1): 408 | text_features_from_left = neighbour_exchange_with_grad( 409 | left_rank, right_rank, text_features_to_right) 410 | 411 | if self.added_mps_loss: 412 | g_text_features_from_left = neighbour_exchange_with_grad( 413 | left_rank, right_rank, g_text_features_to_right) 414 | else: 415 | g_text_features_from_left = None 416 | 417 | loss += self._loss_with_attn_pool( 418 | image_features=image_features, 419 | image_tokens=image_tokens, 420 | text_features=text_features_from_left, 421 | visual_proj=visual_proj, 422 | logit_scale=logit_scale, 423 | logit_bias=logit_bias, 424 | negative_only=True, 425 | g_text_features=g_text_features_from_left) 426 | 427 | text_features_to_right = text_features_from_left 428 | 429 | return {"contrastive_loss": loss} if output_dict else loss 430 | 431 | 432 | 433 | 434 | def downsample_text_features(text_features, batch_size, caption_indices, num_captions): 435 | device = text_features.device 436 | own_caption_indices = caption_indices # Shape: (B, K) 437 | 438 | mask = torch.ones(batch_size, batch_size, dtype=torch.bool, device=device) 439 | mask.fill_diagonal_(False) 440 | 441 | other_image_indices = torch.arange(batch_size, device=device).unsqueeze(0).expand(batch_size, batch_size) 442 | other_image_indices = other_image_indices[mask].view(batch_size, batch_size - 1) 443 | random_offsets = torch.randint(0, num_captions, (batch_size, batch_size - 1), device=device) # (B, B-1) 444 | other_caption_indices = caption_indices[other_image_indices, random_offsets] # sampled indices (B, B-1) 445 | 446 | combined_indices = torch.cat([own_caption_indices, other_caption_indices], dim=1) 447 | combined_indices, _ = combined_indices.sort(dim=1) 448 | flat_combined_indices = combined_indices.view(-1) # flatten to take the text_features out 449 | 450 | downsampled_text_features = text_features[flat_combined_indices] 451 | 452 | embed_dim = text_features.shape[-1] # Reshape to (B, K + B - 1, D) 453 | downsampled_text_features = downsampled_text_features.view(batch_size, num_captions + batch_size - 1, embed_dim) 454 | return downsampled_text_features 455 | -------------------------------------------------------------------------------- /src/flair/model.py: -------------------------------------------------------------------------------- 1 | """ CLIP Model 2 | 3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 
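FLAIR additions in this file: the FLAIR module defined below pairs the CLIP towers with a text-conditioned
attention-pooling head (`visual_proj`, a PureAttentionPoolingBlock) and light-weight post-processing heads
(VisionPostProcess / TextPostProcess) applied to the global and local image/text tokens.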
4 | """ 5 | import copy 6 | import logging 7 | import math 8 | from dataclasses import dataclass 9 | from typing import Any, Dict, Optional, Tuple, Union 10 | 11 | import numpy as np 12 | import torch 13 | import torch.nn.functional as F 14 | from torch import nn 15 | from torch.utils.checkpoint import checkpoint 16 | from functools import partial 17 | 18 | from open_clip.hf_model import HFTextEncoder 19 | from open_clip.modified_resnet import ModifiedResNet 20 | from open_clip.timm_model import TimmModel 21 | from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer, \ 22 | text_global_pool, \ 23 | PureAttentionPoolingBlock, VisionPostProcess, TextPostProcess 24 | from open_clip.utils import to_2tuple 25 | 26 | 27 | @dataclass 28 | class CLIPVisionCfg: 29 | layers: Union[Tuple[int, int, int, int], int] = 12 30 | width: int = 768 31 | head_width: int = 64 32 | mlp_ratio: float = 4.0 33 | patch_size: int = 16 34 | image_size: Union[Tuple[int, int], int] = 224 35 | 36 | ls_init_value: Optional[float] = None # layer scale initial value 37 | patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results 38 | attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer (overrides pool_type) 39 | attn_pooler_queries: int = 256 # n_queries for attentional pooler 40 | attn_pooler_heads: int = 8 # n heads for attentional_pooling 41 | no_ln_pre: bool = False # disable pre transformer LayerNorm 42 | pos_embed_type: str = 'learnable' 43 | final_ln_after_pool: bool = False # apply final LayerNorm after pooling 44 | pool_type: str = 'tok' 45 | output_tokens: bool = False 46 | act_kwargs: Optional[dict] = None 47 | norm_kwargs: Optional[dict] = None 48 | 49 | timm_model_name: Optional[str] = None # a valid model name overrides layers, width, patch_size 50 | timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model 51 | timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') 52 | timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '') 53 | timm_proj_bias: bool = False # enable bias final projection 54 | timm_drop: float = 0. 
# head dropout 55 | timm_drop_path: Optional[float] = None # backbone stochastic depth 56 | 57 | 58 | @dataclass 59 | class CLIPTextCfg: 60 | context_length: int = 77 61 | vocab_size: int = 49408 62 | hf_tokenizer_name: Optional[str] = None 63 | tokenizer_kwargs: Optional[dict] = None 64 | 65 | width: int = 512 66 | heads: int = 8 67 | layers: int = 12 68 | mlp_ratio: float = 4.0 69 | ls_init_value: Optional[float] = None # layer scale initial value 70 | embed_cls: bool = False 71 | pad_id: int = 0 72 | no_causal_mask: bool = False # disable causal masking 73 | final_ln_after_pool: bool = False # apply final LayerNorm after pooling 74 | pool_type: str = 'argmax' 75 | proj_bias: bool = False 76 | output_tokens: bool = False 77 | act_kwargs: dict = None 78 | norm_kwargs: dict = None 79 | 80 | # HuggingFace specific text tower config 81 | hf_model_name: Optional[str] = None 82 | hf_model_pretrained: bool = True 83 | hf_proj_type: str = 'mlp' 84 | hf_pooler_type: str = 'mean_pooler' # attentional pooling for HF models 85 | 86 | 87 | def get_cast_dtype(precision: str): 88 | cast_dtype = None 89 | if precision == 'bf16': 90 | cast_dtype = torch.bfloat16 91 | elif precision == 'fp16': 92 | cast_dtype = torch.float16 93 | return cast_dtype 94 | 95 | 96 | def get_input_dtype(precision: str): 97 | input_dtype = None 98 | if precision in ('bf16', 'pure_bf16'): 99 | input_dtype = torch.bfloat16 100 | elif precision in ('fp16', 'pure_fp16'): 101 | input_dtype = torch.float16 102 | return input_dtype 103 | 104 | 105 | def _build_vision_tower( 106 | embed_dim: int, 107 | vision_cfg: CLIPVisionCfg, 108 | quick_gelu: bool = False, 109 | cast_dtype: Optional[torch.dtype] = None, 110 | project_tokens: bool = False, 111 | text_con: bool = False, 112 | skip_final_pooling: bool = False 113 | ): 114 | if isinstance(vision_cfg, dict): 115 | vision_cfg = CLIPVisionCfg(**vision_cfg) 116 | 117 | # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more 118 | # memory efficient in recent PyTorch releases (>= 1.10). 119 | # NOTE: timm models always use native GELU regardless of quick_gelu flag. 
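# FLAIR-specific builder arguments (description inferred from their names and usage in this repo; see
# transformer.py for the exact behaviour):
#   project_tokens     - whether the local patch tokens are also projected into the joint embedding space
#   text_con           - keep local image tokens available for text-conditioned attention pooling (FLAIR passes True)
#   skip_final_pooling - bypass the vision tower's default pooling (FLAIR passes False)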
120 | act_layer = QuickGELU if quick_gelu else nn.GELU 121 | 122 | if vision_cfg.timm_model_name: 123 | visual = TimmModel( 124 | vision_cfg.timm_model_name, 125 | pretrained=vision_cfg.timm_model_pretrained, 126 | pool=vision_cfg.timm_pool, 127 | proj=vision_cfg.timm_proj, 128 | proj_bias=vision_cfg.timm_proj_bias, 129 | drop=vision_cfg.timm_drop, 130 | drop_path=vision_cfg.timm_drop_path, 131 | patch_drop=vision_cfg.patch_dropout if vision_cfg.patch_dropout > 0 else None, 132 | embed_dim=embed_dim, 133 | image_size=vision_cfg.image_size, 134 | ) 135 | elif isinstance(vision_cfg.layers, (tuple, list)): 136 | vision_heads = vision_cfg.width * 32 // vision_cfg.head_width 137 | visual = ModifiedResNet( 138 | layers=vision_cfg.layers, 139 | output_dim=embed_dim, 140 | heads=vision_heads, 141 | image_size=vision_cfg.image_size, 142 | width=vision_cfg.width, 143 | ) 144 | else: 145 | vision_heads = vision_cfg.width // vision_cfg.head_width 146 | norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm 147 | if vision_cfg.norm_kwargs: 148 | norm_layer = partial(norm_layer, **vision_cfg.norm_kwargs) 149 | if vision_cfg.act_kwargs is not None: 150 | act_layer = partial(act_layer, **vision_cfg.act_kwargs) 151 | 152 | visual = VisionTransformer( 153 | image_size=vision_cfg.image_size, 154 | patch_size=vision_cfg.patch_size, 155 | width=vision_cfg.width, 156 | layers=vision_cfg.layers, 157 | heads=vision_heads, 158 | mlp_ratio=vision_cfg.mlp_ratio, 159 | ls_init_value=vision_cfg.ls_init_value, 160 | patch_dropout=vision_cfg.patch_dropout, 161 | attentional_pool=vision_cfg.attentional_pool, 162 | attn_pooler_queries=vision_cfg.attn_pooler_queries, 163 | attn_pooler_heads=vision_cfg.attn_pooler_heads, 164 | pos_embed_type=vision_cfg.pos_embed_type, 165 | no_ln_pre=vision_cfg.no_ln_pre, 166 | final_ln_after_pool=vision_cfg.final_ln_after_pool, 167 | pool_type=vision_cfg.pool_type, 168 | output_tokens=vision_cfg.output_tokens, 169 | output_dim=embed_dim, 170 | act_layer=act_layer, 171 | norm_layer=norm_layer, 172 | project_tokens=project_tokens, 173 | text_con=text_con, 174 | skip_final_pooling=skip_final_pooling 175 | ) 176 | 177 | return visual 178 | 179 | 180 | def _build_text_tower( 181 | embed_dim: int, 182 | text_cfg: CLIPTextCfg, 183 | quick_gelu: bool = False, 184 | cast_dtype: Optional[torch.dtype] = None, 185 | text_con_pooling: bool = False 186 | ): 187 | if isinstance(text_cfg, dict): 188 | text_cfg = CLIPTextCfg(**text_cfg) 189 | 190 | if text_cfg.hf_model_name: 191 | text = HFTextEncoder( 192 | text_cfg.hf_model_name, 193 | output_dim=embed_dim, 194 | proj_type=text_cfg.hf_proj_type, 195 | pooler_type=text_cfg.hf_pooler_type, 196 | pretrained=text_cfg.hf_model_pretrained, 197 | output_tokens=text_cfg.output_tokens, 198 | ) 199 | else: 200 | act_layer = QuickGELU if quick_gelu else nn.GELU 201 | norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm 202 | if text_cfg.norm_kwargs: 203 | norm_layer = partial(norm_layer, **text_cfg.norm_kwargs) 204 | if text_cfg.act_kwargs is not None: 205 | act_layer = partial(act_layer, **text_cfg.act_kwargs) 206 | 207 | text = TextTransformer( 208 | context_length=text_cfg.context_length, 209 | vocab_size=text_cfg.vocab_size, 210 | width=text_cfg.width, 211 | heads=text_cfg.heads, 212 | layers=text_cfg.layers, 213 | mlp_ratio=text_cfg.mlp_ratio, 214 | ls_init_value=text_cfg.ls_init_value, 215 | output_dim=embed_dim, 216 | embed_cls=text_cfg.embed_cls, 217 | 
no_causal_mask=text_cfg.no_causal_mask, 218 | pad_id=text_cfg.pad_id, 219 | pool_type=text_cfg.pool_type, 220 | proj_bias=text_cfg.proj_bias, 221 | output_tokens=text_cfg.output_tokens, 222 | act_layer=act_layer, 223 | norm_layer=norm_layer 224 | ) 225 | return text 226 | 227 | 228 | def convert_weights_to_lp(model: nn.Module, dtype=torch.float16): 229 | """Convert applicable model parameters to low-precision (bf16 or fp16)""" 230 | 231 | def _convert_weights(l): 232 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 233 | l.weight.data = l.weight.data.to(dtype) 234 | if l.bias is not None: 235 | l.bias.data = l.bias.data.to(dtype) 236 | 237 | if isinstance(l, (nn.MultiheadAttention, Attention)): 238 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 239 | tensor = getattr(l, attr) 240 | if tensor is not None: 241 | tensor.data = tensor.data.to(dtype) 242 | 243 | if isinstance(l, (CLIP, TextTransformer)): 244 | # convert text nn.Parameter projections 245 | attr = getattr(l, "text_projection", None) 246 | if attr is not None: 247 | attr.data = attr.data.to(dtype) 248 | 249 | if isinstance(l, VisionTransformer): 250 | # convert vision nn.Parameter projections 251 | attr = getattr(l, "proj", None) 252 | if attr is not None: 253 | attr.data = attr.data.to(dtype) 254 | 255 | model.apply(_convert_weights) 256 | 257 | 258 | convert_weights_to_fp16 = convert_weights_to_lp # backwards compat 259 | 260 | 261 | # used to maintain checkpoint compatibility 262 | def convert_to_custom_text_state_dict(state_dict: dict): 263 | if 'text_projection' in state_dict: 264 | # old format state_dict, move text tower -> .text 265 | new_state_dict = {} 266 | for k, v in state_dict.items(): 267 | if any(k.startswith(p) for p in ( 268 | 'text_projection', 269 | 'positional_embedding', 270 | 'token_embedding', 271 | 'transformer', 272 | 'ln_final', 273 | )): 274 | k = 'text.' 
+ k 275 | new_state_dict[k] = v 276 | return new_state_dict 277 | return state_dict 278 | 279 | 280 | def build_model_from_openai_state_dict( 281 | state_dict: dict, 282 | quick_gelu=True, 283 | cast_dtype=torch.float16, 284 | ): 285 | vit = "visual.proj" in state_dict 286 | 287 | if vit: 288 | vision_width = state_dict["visual.conv1.weight"].shape[0] 289 | vision_layers = len( 290 | [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 291 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 292 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 293 | image_size = vision_patch_size * grid_size 294 | else: 295 | counts: list = [ 296 | len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 297 | vision_layers = tuple(counts) 298 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 299 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 300 | vision_patch_size = None 301 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 302 | image_size = output_width * 32 303 | 304 | embed_dim = state_dict["text_projection"].shape[1] 305 | context_length = state_dict["positional_embedding"].shape[0] 306 | vocab_size = state_dict["token_embedding.weight"].shape[0] 307 | transformer_width = state_dict["ln_final.weight"].shape[0] 308 | transformer_heads = transformer_width // 64 309 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) 310 | 311 | vision_cfg = CLIPVisionCfg( 312 | layers=vision_layers, 313 | width=vision_width, 314 | patch_size=vision_patch_size, 315 | image_size=image_size, 316 | ) 317 | text_cfg = CLIPTextCfg( 318 | context_length=context_length, 319 | vocab_size=vocab_size, 320 | width=transformer_width, 321 | heads=transformer_heads, 322 | layers=transformer_layers, 323 | ) 324 | model = CLIP( 325 | embed_dim, 326 | vision_cfg=vision_cfg, 327 | text_cfg=text_cfg, 328 | quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU 329 | cast_dtype=cast_dtype, 330 | ) 331 | 332 | for key in ["input_resolution", "context_length", "vocab_size"]: 333 | state_dict.pop(key, None) 334 | convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16 335 | model.load_state_dict(state_dict) 336 | return model.eval() 337 | 338 | 339 | def trace_model(model, batch_size=256, device=torch.device('cpu')): 340 | model.eval() 341 | image_size = model.visual.image_size 342 | example_images = torch.ones((batch_size, 3, image_size, image_size), device=device) 343 | example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device) 344 | model = torch.jit.trace_module( 345 | model, 346 | inputs=dict( 347 | forward=(example_images, example_text), 348 | encode_text=(example_text,), 349 | encode_image=(example_images,) 350 | )) 351 | model.visual.image_size = image_size 352 | return model 353 | 354 | 355 | def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True): 356 | # Rescale the grid of position embeddings when loading from state_dict 357 | old_pos_embed = state_dict.get('visual.positional_embedding', None) 358 | if old_pos_embed is None or not hasattr(model.visual, 'grid_size'): 359 | return 360 | grid_size = to_2tuple(model.visual.grid_size) 361 | extra_tokens = 1 # FIXME detect different token configs (ie no class 
token, or more) 362 | new_seq_len = grid_size[0] * grid_size[1] + extra_tokens 363 | if new_seq_len == old_pos_embed.shape[0]: 364 | return 365 | 366 | if extra_tokens: 367 | pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:] 368 | else: 369 | pos_emb_tok, pos_emb_img = None, old_pos_embed 370 | old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img)))) 371 | 372 | logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size) 373 | pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2) 374 | pos_emb_img = F.interpolate( 375 | pos_emb_img, 376 | size=grid_size, 377 | mode=interpolation, 378 | antialias=antialias, 379 | align_corners=False, 380 | ) 381 | pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0] 382 | if pos_emb_tok is not None: 383 | new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0) 384 | else: 385 | new_pos_embed = pos_emb_img 386 | state_dict['visual.positional_embedding'] = new_pos_embed 387 | 388 | 389 | def resize_text_pos_embed(state_dict, model, interpolation: str = 'linear', antialias: bool = False): 390 | old_pos_embed = state_dict.get('positional_embedding', None) 391 | if old_pos_embed is None: 392 | return 393 | # FIXME add support for text cls_token 394 | model_pos_embed = getattr(model, 'positional_embedding', None) 395 | if model_pos_embed is None: 396 | model_pos_embed = getattr(model.text, 'positional_embedding', None) 397 | 398 | old_num_pos = old_pos_embed.shape[0] 399 | old_width = old_pos_embed.shape[1] 400 | num_pos = model_pos_embed.shape[0] 401 | width = model_pos_embed.shape[1] 402 | assert old_width == width, 'text pos_embed width changed!' 403 | if old_num_pos == num_pos: 404 | return 405 | 406 | logging.info('Resizing text position embedding num_pos from %s to %s', old_num_pos, num_pos) 407 | old_pos_embed = old_pos_embed.reshape(1, old_num_pos, old_width).permute(0, 2, 1) 408 | old_pos_embed = F.interpolate( 409 | old_pos_embed, 410 | size=num_pos, 411 | mode=interpolation, 412 | antialias=antialias, 413 | align_corners=False, 414 | ) 415 | old_pos_embed = old_pos_embed.permute(0, 2, 1)[0] 416 | new_pos_embed = old_pos_embed 417 | 418 | state_dict['positional_embedding'] = new_pos_embed 419 | 420 | 421 | def get_model_preprocess_cfg(model): 422 | module = getattr(model, 'visual', model) 423 | preprocess_cfg = getattr(module, 'preprocess_cfg', {}) 424 | if not preprocess_cfg: 425 | # use separate legacy attributes if preprocess_cfg dict not found 426 | size = getattr(module, 'image_size') 427 | if size is not None: 428 | preprocess_cfg['size'] = size 429 | mean = getattr(module, 'image_mean', None) 430 | if mean is not None: 431 | preprocess_cfg['mean'] = mean 432 | std = getattr(module, 'image_std', None) 433 | if std is not None: 434 | preprocess_cfg['std'] = std 435 | return preprocess_cfg 436 | 437 | 438 | def set_model_preprocess_cfg(model, preprocess_cfg: Dict[str, Any]): 439 | module = getattr(model, 'visual', model) 440 | module.image_mean = preprocess_cfg['mean'] # legacy attribute, keeping for bwd compat 441 | module.image_std = preprocess_cfg['std'] # legacy attribute, keeping for bwd compat 442 | module.preprocess_cfg = copy.deepcopy(preprocess_cfg) # new attr, package all pp cfg as dict 443 | 444 | 445 | def get_model_tokenize_cfg(model): 446 | module = getattr(model, 'text', model) 447 | cfg = {} 448 | context_length = getattr(module, 'context_length', None) 449 | if 
context_length is not None: 450 | cfg['context_length'] = context_length 451 | vocab_size = getattr(module, 'vocab_size', None) 452 | if vocab_size is not None: 453 | cfg['vocab_size'] = vocab_size 454 | return cfg 455 | 456 | 457 | 458 | class FLAIR(nn.Module): 459 | output_dict: torch.jit.Final[bool] 460 | 461 | def __init__( 462 | self, 463 | embed_dim: int, 464 | vision_cfg: CLIPVisionCfg, 465 | text_cfg: CLIPTextCfg, 466 | quick_gelu: bool = False, 467 | init_logit_scale: float = np.log(1 / 0.07), 468 | init_logit_bias: Optional[float] = None, 469 | cast_dtype: Optional[torch.dtype] = None, 470 | output_dict: bool = False 471 | ): 472 | super().__init__() 473 | self.output_dict = output_dict 474 | 475 | self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype, 476 | project_tokens=False, text_con=True, skip_final_pooling=False) 477 | 478 | self.visual_proj = PureAttentionPoolingBlock(context_dim=embed_dim) 479 | 480 | 481 | text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype, text_con_pooling=True) 482 | self.transformer = text.transformer 483 | self.context_length = text.context_length 484 | self.vocab_size = text.vocab_size 485 | self.token_embedding = text.token_embedding 486 | self.positional_embedding = text.positional_embedding 487 | self.ln_final = text.ln_final 488 | self.text_pool_type = text.pool_type 489 | self.register_buffer('attn_mask', text.attn_mask, persistent=False) 490 | 491 | self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale) 492 | if init_logit_bias is not None: 493 | self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias) 494 | else: 495 | self.logit_bias = None 496 | 497 | self.image_post = VisionPostProcess(context_dim=self.visual.width, 498 | output_dim=embed_dim, 499 | normalize_final=False, 500 | skip_ln=True) 501 | self.text_post = TextPostProcess(context_dim=text.width, 502 | output_dim=embed_dim, 503 | normalize_final=False, 504 | skip_ln=True) 505 | self.text_projection = None 506 | # we don't need text_projection at this point, text_post already does it 507 | 508 | def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): 509 | # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 510 | self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) 511 | 512 | @torch.jit.ignore 513 | def set_grad_checkpointing(self, enable=True): 514 | self.visual.set_grad_checkpointing(enable) 515 | self.transformer.grad_checkpointing = enable 516 | 517 | def encode_image(self, image, normalize: bool = False): 518 | global_image_token, local_image_tokens = self.visual(image) 519 | return global_image_token, local_image_tokens 520 | 521 | def encode_text(self, text, normalize: bool = False, return_tokens: bool = True): 522 | txt_dim = text.shape[-1] 523 | cast_dtype = self.transformer.get_cast_dtype() 524 | 525 | text = text.view(-1, txt_dim) # (B*K, 77) 526 | 527 | x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] 528 | 529 | x = x + self.positional_embedding.to(cast_dtype) 530 | x = self.transformer(x, attn_mask=self.attn_mask) 531 | x = self.ln_final(x) # [B*K, n_ctx, transformer.width] 532 | global_text_token, local_text_tokens = text_global_pool(x, text, self.text_pool_type) # (B*K, L, D) 533 | if self.text_projection is not None: 534 | if isinstance(self.text_projection, nn.Linear): 535 | global_text_token = self.text_projection(global_text_token) # (B*K, N) as queries 536 | local_text_tokens = 
self.text_projection(local_text_tokens) 537 | else: 538 | global_text_token = global_text_token @ self.text_projection 539 | local_text_tokens = local_text_tokens @ self.text_projection 540 | 541 | return global_text_token, local_text_tokens 542 | 543 | def get_logits(self, image, text): 544 | """ 545 | FLAIR's way to get the logits. Only used as a minimal example to get the logits, not used in training or inference at this stage. 546 | """ 547 | global_image_token, local_image_tokens = self.encode_image(image) 548 | global_text_token, _ = self.encode_text(text) 549 | global_text_token = self.text_post(global_text_token) # (B*K, D) 550 | global_image_token, local_image_tokens = self.image_post(global_image_token), self.image_post( 551 | local_image_tokens) # (B, D), (B, L, D) 552 | batch_size = global_image_token.shape[0] 553 | 554 | # Broadcast the global text token to (B, B*K, D), this is too costly in large-scale training, so we downsample them to (B, B+K-1, D) in training 555 | global_text_token = global_text_token.unsqueeze(0).expand(batch_size, -1, -1) 556 | 557 | local_image_features = self.visual_proj(global_text_token, local_image_tokens, local_image_tokens) # (B, B*K, D) 558 | 559 | text_features, image_features = F.normalize(global_text_token, dim=-1), F.normalize(local_image_features, dim=-1) 560 | 561 | image_logits = self.logit_scale.exp() * torch.einsum('bij,bij->bi', image_features, text_features) # (B, B*K) 562 | image_logits += self.logit_bias 563 | 564 | text_logits = image_logits.T 565 | 566 | return image_logits, text_logits 567 | 568 | def get_logits_as_clip(self, image, text): 569 | """ 570 | FLAIR could also generate the global-to-global logits as the original CLIP does. 571 | """ 572 | global_image_token, _ = self.encode_image(image) 573 | global_text_token, _ = self.encode_text(text) 574 | 575 | 576 | global_image_token = self.image_post(global_image_token) # (B, D) 577 | global_text_token = self.text_post(global_text_token) # (B*K, D) 578 | 579 | image_features, text_features = F.normalize(global_image_token, dim=-1), F.normalize(global_text_token, dim=-1) 580 | 581 | image_logits = self.logit_scale.exp() * image_features @ text_features.t() 582 | text_logits = image_logits.T 583 | 584 | return image_logits, text_logits 585 | 586 | def forward( 587 | self, 588 | image: Optional[torch.Tensor] = None, 589 | text: Optional[torch.Tensor] = None, 590 | ): 591 | global_image_token, local_image_tokens = self.encode_image(image) 592 | global_text_token, local_text_tokens = self.encode_text(text) 593 | global_text_token, local_text_tokens = self.text_post(global_text_token), self.text_post(local_text_tokens) 594 | global_image_token, local_image_tokens = self.image_post(global_image_token), self.image_post(local_image_tokens) 595 | 596 | out_dict = { 597 | "image_features": global_image_token, 598 | "image_tokens": local_image_tokens, 599 | "text_features": global_text_token, 600 | "logit_scale": self.logit_scale.exp(), 601 | "visual_proj": self.visual_proj 602 | } 603 | 604 | if self.logit_bias is not None: 605 | out_dict['logit_bias'] = self.logit_bias 606 | return out_dict 607 | -------------------------------------------------------------------------------- /src/flair/model_configs/ViT-B-16-FLAIR.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "init_logit_bias": -10, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": 12, 7 | "width": 768, 8 | "patch_size": 16, 9 | "output_tokens": true 10 | 
}, 11 | "text_cfg": { 12 | "context_length": 77, 13 | "vocab_size": 49408, 14 | "width": 512, 15 | "heads": 8, 16 | "layers": 12 17 | } 18 | } -------------------------------------------------------------------------------- /src/flair/params.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is adapted from OpenCLIP: 3 | https://github.com/mlfoundations/open_clip/blob/main/src/open_clip_train/params.py 4 | 5 | The code integrates additional modifications and extensions to support the FLAIR models. 6 | Original authors: ML Foundations. 7 | """ 8 | import argparse 9 | import ast 10 | 11 | 12 | def get_default_params(model_name): 13 | # Params from paper (https://arxiv.org/pdf/2103.00020.pdf) 14 | model_name = model_name.lower() 15 | if "vit" in model_name: 16 | return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.98, "eps": 1.0e-6} 17 | else: 18 | return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.999, "eps": 1.0e-8} 19 | 20 | 21 | class ParseKwargs(argparse.Action): 22 | def __call__(self, parser, namespace, values, option_string=None): 23 | kw = {} 24 | for value in values: 25 | key, value = value.split('=') 26 | try: 27 | kw[key] = ast.literal_eval(value) 28 | except ValueError: 29 | kw[key] = str(value) # fallback to string (avoid need to escape on command line) 30 | setattr(namespace, self.dest, kw) 31 | 32 | 33 | def parse_args(args): 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument( 36 | "--train-data", 37 | type=str, 38 | default=None, 39 | help="Path to file(s) with training data. When using webdataset, multiple datasources can be combined using the `::` separator.", 40 | ) 41 | parser.add_argument( 42 | "--train-data-upsampling-factors", 43 | type=str, 44 | default=None, 45 | help=( 46 | "When using multiple data sources with webdataset and sampling with replacement, this can be used to upsample specific data sources. " 47 | "Similar to --train-data, this should be a string with as many numbers as there are data sources, separated by `::` (e.g. 1::2::0.5) " 48 | "By default, datapoints are sampled uniformly regardless of the dataset sizes." 49 | ) 50 | ) 51 | parser.add_argument( 52 | "--data-root-dir", 53 | type=str, 54 | default='', 55 | help=( 56 | "Root directory to your dataset, especially the COCO dataset." 57 | ) 58 | ) 59 | parser.add_argument( 60 | "--cc3m-train-retrieval-dir", 61 | type=str, 62 | default='', 63 | help=( 64 | "Root directory to the train retrieval dataset subsampled from cc3m." 65 | ) 66 | ) 67 | parser.add_argument( 68 | "--sharegpt4v-retrieval-dir", 69 | type=str, 70 | default='', 71 | help=( 72 | "Root directory to the share4v dataset." 73 | ) 74 | ) 75 | parser.add_argument( 76 | "--dci-retrieval-dir", 77 | type=str, 78 | default='', 79 | help=( 80 | "Root directory to the train dci daatset." 81 | ) 82 | ) 83 | parser.add_argument( 84 | "--iiw-retrieval-dir", 85 | type=str, 86 | default='', 87 | help=( 88 | "Root directory to the image in words dataset." 89 | ) 90 | ) 91 | parser.add_argument( 92 | "--docci-retrieval-dir", 93 | type=str, 94 | default='', 95 | help=( 96 | "Root directory to fine-grained docci retrieval." 97 | ) 98 | ) 99 | parser.add_argument( 100 | "--urban-1k-retrieval-dir", 101 | type=str, 102 | default='', 103 | help=( 104 | "Root directory to fine-grained urban-1k retrieval." 105 | ) 106 | ) 107 | parser.add_argument( 108 | "--zeroshot-eval-datasets", 109 | type=str, 110 | default=None, 111 | help=( 112 | "Datasets that you want to do retrieval." 
113 | ) 114 | ) 115 | parser.add_argument( 116 | "--coco-data-root-dir", 117 | type=str, 118 | default='', 119 | help=( 120 | "Root directory to the COCO dataset." 121 | ) 122 | ) 123 | parser.add_argument( 124 | "--flickr-data-root-dir", 125 | type=str, 126 | default='', 127 | help=( 128 | "Root directory to the flickr datasets (but we simply use the root of the whole dataset)." 129 | ) 130 | ) 131 | parser.add_argument( 132 | "--val-data", 133 | type=str, 134 | default=None, 135 | help="Path to file(s) with validation data.", 136 | ) 137 | parser.add_argument( 138 | "--dict-root-dir", 139 | type=str, 140 | default=None, 141 | help="Path to the preprocessed dictionaries to filter the dataset.", 142 | ) 143 | parser.add_argument( 144 | "--train-num-samples", 145 | type=int, 146 | default=None, 147 | help="Number of samples in dataset. Required for webdataset if not available in info file.", 148 | ) 149 | parser.add_argument( 150 | "--input-size", 151 | type=int, 152 | default=224, 153 | help="Input size for the zero-shot evaluation task", 154 | ) 155 | parser.add_argument( 156 | "--output-file", 157 | type=str, 158 | default=None, 159 | help="Output txt file for documenting the results", 160 | ) 161 | parser.add_argument( 162 | "--val-num-samples", 163 | type=int, 164 | default=None, 165 | help="Number of samples in dataset. Useful for webdataset if not available in info file.", 166 | ) 167 | parser.add_argument( 168 | "--train-val-num-samples", 169 | type=int, 170 | default=None, 171 | help="Number of samples in train eval dataset. Useful for webdataset if not available in info file.", 172 | ) 173 | parser.add_argument( 174 | "--train-eval-data", 175 | type=str, 176 | default=None, 177 | help="Path to evaluation set within train data with training data. When using webdataset, multiple datasources can be combined using the `::` separator.", 178 | ) 179 | parser.add_argument( 180 | "--dataset-type", 181 | choices=["webdataset", "csv", "synthetic", "auto", "coco", "flickr"], 182 | default="coco", 183 | help="Which type of dataset to process." 184 | ) 185 | parser.add_argument( 186 | "--train-dataset-type", 187 | choices=["webdataset", "csv", "synthetic", "auto", "coco", "flickr"], 188 | default="csv", 189 | help="Which type of dataset to process." 190 | ) 191 | parser.add_argument( 192 | "--val-dataset-type", 193 | choices=["webdataset", "csv", "synthetic", "auto", "coco", "flickr"], 194 | default="coco", 195 | help="Which type of dataset to process." 196 | ) 197 | parser.add_argument( 198 | "--retrieval-dataset-type", 199 | choices=["coco", "flickr"], 200 | default="coco", 201 | help="Which type of dataset for the retrieval task." 202 | ) 203 | parser.add_argument( 204 | "--dataset-resampled", 205 | default=False, 206 | action="store_true", 207 | help="Whether to use sampling with replacement for webdataset shard selection." 208 | ) 209 | parser.add_argument( 210 | "--inference-with-flair", 211 | default=False, 212 | action="store_true", 213 | help="If set to true, then we use FLAIR way of inference." 214 | ) 215 | parser.add_argument( 216 | "--fixed-merged-num", 217 | default=False, 218 | action="store_true", 219 | help="If true, then we fix the merging number in variable merge." 220 | ) 221 | parser.add_argument( 222 | "--all-subsequent-merged", 223 | default=False, 224 | action="store_true", 225 | help="If enabled, then we only merge subsequent captions." 
226 | ) 227 | parser.add_argument( 228 | "--use-flair-loss", 229 | default=False, 230 | action="store_true", 231 | help="Whether to use the text-conditioned sigmoid loss or not." 232 | ) 233 | parser.add_argument( 234 | "--add-mps-loss", 235 | default=False, 236 | action="store_true", 237 | help="Whether to add the multi-positive loss or not." 238 | ) 239 | parser.add_argument( 240 | "--directly-use-attn-weights", 241 | default=False, 242 | action="store_true", 243 | help="Directly use attn weights for segmentation or not." 244 | ) 245 | parser.add_argument( 246 | "--sampled-textcon-siglip-loss", 247 | default=False, 248 | action="store_true", 249 | help="Whether to use the sampled textcon siglip loss or not." 250 | ) 251 | parser.add_argument( 252 | "--add-global-loss-textcon", 253 | default=False, 254 | action="store_true", 255 | help="Whether to add the global loss or not." 256 | ) 257 | parser.add_argument( 258 | "--only-global-loss-attn-pool", 259 | default=False, 260 | action="store_true", 261 | help="Whether to add the global loss or not." 262 | ) 263 | parser.add_argument( 264 | "--add-global-loss-textcon-with-attn-pool", 265 | default=False, 266 | action="store_true", 267 | help="Whether to add the global loss with extra attn pool or not." 268 | ) 269 | parser.add_argument( 270 | "--pixelprose", 271 | default=False, 272 | action="store_true", 273 | help="Set to true to remind te webdataset to adapt to the pixelprose format." 274 | ) 275 | parser.add_argument( 276 | "--datacomp", 277 | default=False, 278 | action="store_true", 279 | help="Set to true to remind te webdataset to adapt to the datacomp format." 280 | ) 281 | parser.add_argument( 282 | "--add-global-loss", 283 | default=False, 284 | action="store_true", 285 | help="Whether to add the global loss implementation or not." 286 | ) 287 | parser.add_argument( 288 | "--add-intra-sample-loss", 289 | default=False, 290 | action="store_true", 291 | help="Whether to add the intra sample loss or not." 292 | ) 293 | parser.add_argument( 294 | "--cross-con", 295 | default=False, 296 | action="store_true", 297 | help="Using Cross conditioned model and loss or not." 298 | ) 299 | parser.add_argument( 300 | "--text-con-with-down-proj", 301 | default=False, 302 | action="store_true", 303 | help="Using the new Text-conditioned model with down-proj or not." 304 | ) 305 | parser.add_argument( 306 | "--use-csa", 307 | default=False, 308 | action="store_true", 309 | help='For segmentation evaluation use correlative self-attention by SCLIP.' 310 | ) 311 | parser.add_argument( 312 | "--seg-model", 313 | type=str, 314 | default="", 315 | help=''' For segmentation evaluation, name of the openAI model otherwise evaluate from resume checkpoint 316 | ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px']''' 317 | ) 318 | parser.add_argument( 319 | "--show-dir", 320 | type=str, 321 | default="", 322 | help=''' Directory for saving the visualizations for segmentation.''' 323 | ) 324 | parser.add_argument( 325 | "--max-merged-num", type=int, default=3, help="Maximum number of merging." 326 | ) 327 | parser.add_argument( 328 | "--cross-con-with-post-process", 329 | default=False, 330 | action="store_true", 331 | help="Using Cross conditioned model with post processing and loss or not." 332 | ) 333 | parser.add_argument( 334 | "--add-global-loss-in-sampled-cross-con", 335 | default=False, 336 | action="store_true", 337 | help="Using Cross conditioned model with post processing and loss or not." 
338 | ) 339 | parser.add_argument( 340 | "--cross-con-with-down-proj", 341 | default=False, 342 | action="store_true", 343 | help="Whether to use the corss conditioned model with down-projected embed dim or not." 344 | ) 345 | parser.add_argument( 346 | "--txt-con-attn-pool", 347 | default=False, 348 | action="store_true", 349 | help="To use text-conditioned attention pooling or not." 350 | ) 351 | parser.add_argument( 352 | "--evaluate-as-original-clip", 353 | default=False, 354 | action="store_true", 355 | help="Though text-conditioned, still evaluate in the original CLIP way." 356 | ) 357 | parser.add_argument( 358 | "--evaluate-as-text-conditioned", 359 | default=False, 360 | action="store_true", 361 | help="Using Text-conditioned and evaluated as text-conditioned way." 362 | ) 363 | parser.add_argument( 364 | "--retrieval-coco", 365 | default=False, 366 | action="store_true", 367 | help="If true, then we enable the coco retrieval task.") 368 | 369 | parser.add_argument( 370 | "--retrieval-dci", 371 | default=False, 372 | action="store_true", 373 | help="If true, then we enable the dci retrieval task.") 374 | 375 | parser.add_argument( 376 | "--retrieval-iiw", 377 | default=False, 378 | action="store_true", 379 | help="If true, then we enable the iiw retrieval task.") 380 | 381 | parser.add_argument( 382 | "--use-finegrained-iiw", 383 | default=True, 384 | action="store_true", 385 | help="If set to true, under the condition that we enable iiw, we further use the fine-grained iiw mode.") 386 | 387 | parser.add_argument( 388 | "--retrieval-sharegpt4v-1k", 389 | default=False, 390 | action="store_true", 391 | help="If true, then we enable the sharegpt4v retrieval task with 1k data size.") 392 | 393 | parser.add_argument( 394 | "--retrieval-sharegpt4v-10k", 395 | default=False, 396 | action="store_true", 397 | help="If true, then we enable the sharegpt4v retrieval task with 10k data size.") 398 | 399 | parser.add_argument( 400 | "--retrieval-flickr", 401 | default=False, 402 | action="store_true", 403 | help="If true, then we enable the flickr retrieval task.") 404 | parser.add_argument( 405 | "--add-global-loss-cross-con", 406 | default=False, 407 | action="store_true", 408 | help="If true, then we add global loss to cross-condition setting.") 409 | parser.add_argument( 410 | "--add-global-loss-cross-con-mean", 411 | default=False, 412 | action="store_true", 413 | help="If true, then we add global loss to cross-condition mean setting.") 414 | parser.add_argument( 415 | "--add-pooled-global-loss-cross-con", 416 | default=False, 417 | action="store_true", 418 | help="If true, then we add pooled global loss to cross-condition setting.") 419 | parser.add_argument( 420 | "--txt-self-attn", 421 | default=False, 422 | action="store_true", 423 | help="If true, then we add pooled attn pooling also to generate the text embeddings.") 424 | 425 | parser.add_argument( 426 | "--retrieval-urban-1k", 427 | default=False, 428 | action="store_true", 429 | help="If true, then we enable the urban-1k retrieval task.") 430 | parser.add_argument( 431 | "--retrieval-data-cc3m-train", 432 | default=False, 433 | action="store_true", 434 | help="If true, then we enable the cc3m retrieval task.") 435 | parser.add_argument( 436 | "--retrieval-docci", 437 | default=False, 438 | action="store_true", 439 | help="If true, then we enable the DOCCI retrieval task.") 440 | 441 | parser.add_argument( 442 | "--use-original-openclip-csv-dataset", 443 | default=False, 444 | action="store_true", 445 | help="Whether to use the 
original openclip csv dataset or not, if false, then use new csv dataset." 446 | ) 447 | parser.add_argument( 448 | "--csv-separator", 449 | type=str, 450 | default="\t", 451 | help="For csv-like datasets, which separator to use." 452 | ) 453 | parser.add_argument( 454 | "--csv-img-key", 455 | type=str, 456 | default="filepath", 457 | help="For csv-like datasets, the name of the key for the image paths." 458 | ) 459 | parser.add_argument( 460 | "--csv-caption-key", 461 | type=str, 462 | default="title", 463 | help="For csv-like datasets, the name of the key for the captions." 464 | ) 465 | parser.add_argument( 466 | "--imagenet-val", 467 | type=str, 468 | default=None, 469 | help="Path to imagenet val set for conducting zero shot evaluation.", 470 | ) 471 | parser.add_argument( 472 | "--imagenet-v2", 473 | type=str, 474 | default=None, 475 | help="Path to imagenet v2 for conducting zero shot evaluation.", 476 | ) 477 | parser.add_argument( 478 | "--logs-dir", 479 | type=str, 480 | default="./logs/", 481 | help="Where to store tensorboard logs. Use None to avoid storing logs.", 482 | ) 483 | parser.add_argument( 484 | "--flickr-val-or-test", 485 | type=str, 486 | default='val', 487 | choices=['val', 'testing'], 488 | help="Which dataset to be used for inference, default choices are val or test.", 489 | ) 490 | parser.add_argument( 491 | "--huggingface-model-name", 492 | type=str, 493 | default="", 494 | help="Name of the huggingface model." 495 | ) 496 | parser.add_argument( 497 | "--huggingface-repo-name", 498 | type=str, 499 | default="", 500 | help="Name of the huggingface repo." 501 | ) 502 | parser.add_argument( 503 | "--ablation-negative-type", 504 | type=str, 505 | default=None, 506 | choices=['ijj', 'iji', 'ijk', 'iij', 'intra'], 507 | help="Denote the ablation negative type for the abation study.", 508 | ) 509 | parser.add_argument( 510 | "--log-local", 511 | action="store_true", 512 | default=False, 513 | help="Log files on local master, otherwise global master only.", 514 | ) 515 | parser.add_argument( 516 | "--random-select-text-tokens", 517 | action="store_true", 518 | default=False, 519 | help="To randomly select the text tokens for the img-con text tokens pooling or not.", 520 | ) 521 | parser.add_argument( 522 | "--use-siglip", 523 | action="store_true", 524 | default=False, 525 | help="Whether to use the siglip loss for text conditioned model or not.", 526 | ) 527 | parser.add_argument( 528 | "--name", 529 | type=str, 530 | default=None, 531 | help="Optional identifier for the experiment when storing logs. Otherwise use current time.", 532 | ) 533 | parser.add_argument( 534 | "--workers", type=int, default=4, help="Number of dataloader workers per GPU." 535 | ) 536 | parser.add_argument( 537 | "--batch-size", type=int, default=64, help="Batch size per GPU." 538 | ) 539 | parser.add_argument( 540 | "--epochs", type=int, default=32, help="Number of epochs to train for." 541 | ) 542 | parser.add_argument( 543 | "--epochs-cooldown", type=int, default=None, 544 | help="When scheduler w/ cooldown used, perform cooldown from total_epochs - cooldown_epochs onwards." 
545 | ) 546 | parser.add_argument("--lr", type=float, default=None, help="Learning rate.") 547 | parser.add_argument("--beta1", type=float, default=None, help="Adam beta 1.") 548 | parser.add_argument("--beta2", type=float, default=None, help="Adam beta 2.") 549 | parser.add_argument("--eps", type=float, default=None, help="Adam epsilon.") 550 | parser.add_argument("--wd", type=float, default=0.2, help="Weight decay.") 551 | parser.add_argument( 552 | "--warmup", type=int, default=10000, help="Number of steps to warmup for." 553 | ) 554 | parser.add_argument( 555 | "--use-bn-sync", 556 | default=False, 557 | action="store_true", 558 | help="Whether to use batch norm sync.") 559 | parser.add_argument( 560 | "--skip-scheduler", 561 | action="store_true", 562 | default=False, 563 | help="Use this flag to skip the learning rate decay.", 564 | ) 565 | parser.add_argument( 566 | "--lr-scheduler", 567 | type=str, 568 | default='cosine', 569 | help="LR scheduler. One of: 'cosine', 'const' (constant), 'const-cooldown' (constant w/ cooldown). Default: cosine", 570 | ) 571 | parser.add_argument( 572 | "--lr-cooldown-end", type=float, default=0.0, 573 | help="End learning rate for cooldown schedule. Default: 0" 574 | ) 575 | parser.add_argument( 576 | "--lr-cooldown-power", type=float, default=1.0, 577 | help="Power for polynomial cooldown schedule. Default: 1.0 (linear decay)" 578 | ) 579 | parser.add_argument( 580 | "--save-frequency", type=int, default=1, help="How often to save checkpoints." 581 | ) 582 | parser.add_argument( 583 | "--save-most-recent", 584 | action="store_true", 585 | default=False, 586 | help="Always save the most recent model trained to epoch_latest.pt.", 587 | ) 588 | parser.add_argument( 589 | "--negative-sampling-in-forward", 590 | action="store_true", 591 | default=False, 592 | help="Doing negative sampling in SampledCrossConSigLipLoss in the forward function.", 593 | ) 594 | parser.add_argument( 595 | "--negative-sampling-in-gpu", 596 | action="store_true", 597 | default=False, 598 | help="Doing negative sampling in SampledCrossConSigLipLoss inside the GPU.", 599 | ) 600 | parser.add_argument( 601 | "--merged-captions-num", type=int, default=1, help="Number of merged captions." 602 | ) 603 | parser.add_argument( 604 | "--zeroshot-frequency", type=int, default=2, help="How often to run zero shot." 605 | ) 606 | parser.add_argument( 607 | "--val-frequency", type=int, default=1, help="How often to run evaluation with val data." 608 | ) 609 | parser.add_argument( 610 | "--resume", 611 | default=None, 612 | type=str, 613 | help="Path to latest checkpoint (default: none).", 614 | ) 615 | parser.add_argument( 616 | "--coco-random-subset", 617 | default=None, 618 | type=int, 619 | help="Randomly subset how many k number of samples in COCO dataset." 620 | ) 621 | parser.add_argument( 622 | "--coco-sliding-window", 623 | default=None, 624 | type=int, 625 | help="Number to specify the kth window to be used." 626 | ) 627 | 628 | parser.add_argument( 629 | "--precision", 630 | choices=["amp", "amp_bf16", "amp_bfloat16", "bf16", "fp16", "pure_bf16", "pure_fp16", "fp32"], 631 | default="amp", 632 | help="Floating point precision." 633 | ) 634 | parser.add_argument( 635 | "--caption-sampling-mode", 636 | choices=["only_raw_caption", "raw_and_random", "random", "short", "dreamlip", "short-long-mixed-random", "diverse_sampling"], 637 | default="random", 638 | help="Floating point precision." 
639 | ) 640 | parser.add_argument( 641 | "--negative-type", 642 | choices=["ijj", "iji"], 643 | default="ijj", 644 | help="Main type of negatives in text conditioned pooling." 645 | ) 646 | parser.add_argument( 647 | "--model", 648 | type=str, 649 | default="RN50", 650 | help="Name of the vision backbone to use.", 651 | ) 652 | 653 | parser.add_argument( 654 | "--target-model", 655 | type=str, 656 | default="RN50", 657 | help="Name of the target vision backbone to be copied.", 658 | ) 659 | 660 | parser.add_argument( 661 | "--pretrained", 662 | default='', 663 | type=str, 664 | help="Use a pretrained CLIP model weights with the specified tag or file path.", 665 | ) 666 | parser.add_argument( 667 | "--pretrained-image", 668 | default=False, 669 | action='store_true', 670 | help="Load imagenet pretrained weights for image tower backbone if available.", 671 | ) 672 | parser.add_argument( 673 | "--mixed-sampling-cross-con", 674 | default=False, 675 | action='store_true', 676 | help="Whether to use the mixed sampling mode for cross-con or not.", 677 | ) 678 | parser.add_argument( 679 | "--use-dreamlip-loss", 680 | default=False, 681 | action='store_true', 682 | help="If true, then we use dreamlip loss.", 683 | ) 684 | parser.add_argument( 685 | "--dreamlip-model", 686 | default=False, 687 | action='store_true', 688 | help="If true, then we use dreamlip model.", 689 | ) 690 | parser.add_argument( 691 | "--normal-clip-with-multi-cap", 692 | default=False, 693 | action='store_true', 694 | help="If the CLIP model is using multi captions or not.", 695 | ) 696 | parser.add_argument( 697 | "--lock-image", 698 | default=False, 699 | action='store_true', 700 | help="Lock full image tower by disabling gradients.", 701 | ) 702 | parser.add_argument( 703 | "--lock-image-unlocked-groups", 704 | type=int, 705 | default=0, 706 | help="Leave last n image tower layer groups unlocked.", 707 | ) 708 | parser.add_argument( 709 | "--lock-image-freeze-bn-stats", 710 | default=False, 711 | action='store_true', 712 | help="Freeze BatchNorm running stats in image tower for any locked layers.", 713 | ) 714 | parser.add_argument( 715 | "--text-conditioned-loss", 716 | default=False, 717 | action='store_true', 718 | help="Whether to use the text-conditioned loss for text conditioned CLIP model.", 719 | ) 720 | parser.add_argument( 721 | '--image-mean', type=float, nargs='+', default=None, metavar='MEAN', 722 | help='Override default image mean value of dataset.') 723 | parser.add_argument( 724 | '--image-std', type=float, nargs='+', default=None, metavar='STD', 725 | help='Override default image std deviation of of dataset.') 726 | parser.add_argument( 727 | '--image-interpolation', 728 | default=None, type=str, choices=['bicubic', 'bilinear', 'random'], 729 | help="Override default image resize interpolation." 730 | ) 731 | parser.add_argument( 732 | '--image-resize-mode', 733 | default=None, type=str, choices=['shortest', 'longest', 'squash'], 734 | help="Override default image resize (& crop) mode during inference." 735 | ) 736 | parser.add_argument('--aug-cfg', nargs='*', default={}, action=ParseKwargs) 737 | parser.add_argument( 738 | "--grad-checkpointing", 739 | default=False, 740 | action='store_true', 741 | help="Enable gradient checkpointing.", 742 | ) 743 | parser.add_argument( 744 | "--local-loss", 745 | default=False, 746 | action="store_true", 747 | help="Calculate loss w/ local features @ global (instead of realizing full global @ global matrix)." 
748 | ) 749 | parser.add_argument( 750 | "--gather-with-grad", 751 | default=False, 752 | action="store_true", 753 | help="Enable full distributed gradient for feature gather." 754 | ) 755 | parser.add_argument( 756 | '--force-image-size', type=int, nargs='+', default=None, 757 | help='Override default image size.' 758 | ) 759 | parser.add_argument( 760 | "--force-quick-gelu", 761 | default=False, 762 | action='store_true', 763 | help="Force use of QuickGELU activation for non-OpenAI transformer models.", 764 | ) 765 | parser.add_argument( 766 | "--force-patch-dropout", 767 | default=None, 768 | type=float, 769 | help="Override the patch dropout during training, for fine tuning with no dropout near the end as in the paper.", 770 | ) 771 | parser.add_argument( 772 | "--force-custom-text", 773 | default=False, 774 | action='store_true', 775 | help="Force use of CustomTextCLIP model (separate text-tower).", 776 | ) 777 | parser.add_argument( 778 | "--torchscript", 779 | default=False, 780 | action='store_true', 781 | help="torch.jit.script the model, also uses jit version of OpenAI models if pretrained=='openai'.", 782 | ) 783 | parser.add_argument( 784 | "--torchcompile", 785 | default=False, 786 | action='store_true', 787 | help="torch.compile() the model, requires pytorch 2.0 or later.", 788 | ) 789 | parser.add_argument( 790 | "--trace", 791 | default=False, 792 | action='store_true', 793 | help="torch.jit.trace the model for inference / eval only.", 794 | ) 795 | parser.add_argument( 796 | "--accum-freq", type=int, default=1, help="Update the model every --acum-freq steps." 797 | ) 798 | # arguments for distributed training 799 | parser.add_argument( 800 | "--dist-url", 801 | default="env://", 802 | type=str, 803 | help="url used to set up distributed training", 804 | ) 805 | parser.add_argument( 806 | "--dist-backend", default="nccl", type=str, help="distributed backend" 807 | ) 808 | parser.add_argument( 809 | "--report-to", 810 | default='', 811 | type=str, 812 | help="Options are ['wandb', 'tensorboard', 'wandb,tensorboard']" 813 | ) 814 | parser.add_argument( 815 | "--wandb-notes", 816 | default='', 817 | type=str, 818 | help="Notes if logging with wandb." 819 | ) 820 | parser.add_argument( 821 | "--wandb-project-name", 822 | type=str, 823 | default='open-clip', 824 | help="Name of the project if logging with wandb.", 825 | ) 826 | parser.add_argument( 827 | "--debug", 828 | default=False, 829 | action="store_true", 830 | help="If true, more information is logged." 831 | ) 832 | parser.add_argument( 833 | "--copy-codebase", 834 | default=False, 835 | action="store_true", 836 | help="If true, we copy the entire base on the log directory, and execute from there." 837 | ) 838 | parser.add_argument( 839 | "--horovod", 840 | default=False, 841 | action="store_true", 842 | help="Use horovod for distributed training." 843 | ) 844 | parser.add_argument( 845 | "--ddp-static-graph", 846 | default=False, 847 | action='store_true', 848 | help="Enable static graph optimization for DDP in PyTorch >= 1.11.", 849 | ) 850 | parser.add_argument( 851 | "--no-set-device-rank", 852 | default=False, 853 | action="store_true", 854 | help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc)." 855 | ) 856 | parser.add_argument( 857 | "--num-sampled-captions", type=int, default=8, help="Number of sampled captions per image." 
858 | ) 859 | parser.add_argument( 860 | "--num-sampled-long-captions", default=0, type=int, 861 | help="Number of sampled long captions per image. Set to 0 for the default setting." 862 | ) 863 | 864 | parser.add_argument( 865 | "--merged-num", type=int, default=1, help="Merged number" 866 | ) 867 | parser.add_argument( 868 | "--add-attn-pooling", 869 | default=False, 870 | action="store_true", 871 | help="To add attn-pooling in the end of vision encoder or not, note that this will set the original pooling to Identity." 872 | ) 873 | parser.add_argument( 874 | "--text-con-attn-pool", 875 | default=False, 876 | action="store_true", 877 | help="Indicating whether the model is using text conditioning or not. Must be specified if using text-conditioned models." 878 | ) 879 | parser.add_argument( 880 | "--seed", type=int, default=0, help="Default random seed." 881 | ) 882 | parser.add_argument( 883 | "--grad-clip-norm", type=float, default=None, help="Gradient clip." 884 | ) 885 | parser.add_argument( 886 | "--lock-text", 887 | default=False, 888 | action='store_true', 889 | help="Lock full text tower by disabling gradients.", 890 | ) 891 | parser.add_argument( 892 | "--lock-text-unlocked-layers", 893 | type=int, 894 | default=0, 895 | help="Leave last n text tower layer groups unlocked.", 896 | ) 897 | parser.add_argument( 898 | "--lock-text-freeze-layer-norm", 899 | default=False, 900 | action='store_true', 901 | help="Freeze LayerNorm running stats in text tower for any locked layers.", 902 | ) 903 | parser.add_argument( 904 | "--log-every-n-steps", 905 | type=int, 906 | default=100, 907 | help="Log every n steps to tensorboard/console/wandb.", 908 | ) 909 | parser.add_argument( 910 | "--coca-caption-loss-weight", 911 | type=float, 912 | default=2.0, 913 | help="Weight assigned to caption loss in CoCa." 914 | ) 915 | parser.add_argument( 916 | "--coca-contrastive-loss-weight", 917 | type=float, 918 | default=1.0, 919 | help="Weight assigned to contrastive loss when training CoCa." 920 | ) 921 | parser.add_argument( 922 | "--remote-sync", 923 | type=str, 924 | default=None, 925 | help="Optinoally sync with a remote path specified by this arg.", 926 | ) 927 | parser.add_argument( 928 | "--remote-sync-frequency", 929 | type=int, 930 | default=300, 931 | help="How frequently to sync to a remote directly if --remote-sync is not None.", 932 | ) 933 | parser.add_argument( 934 | "--remote-sync-protocol", 935 | choices=["s3", "fsspec"], 936 | default="s3", 937 | help="How to do the remote sync backup if --remote-sync is not None.", 938 | ) 939 | parser.add_argument( 940 | "--delete-previous-checkpoint", 941 | default=False, 942 | action="store_true", 943 | help="If true, delete previous checkpoint after storing a new one." 944 | ) 945 | parser.add_argument( 946 | "--distill-model", 947 | default=None, 948 | help='Which model arch to distill from, if any.' 949 | ) 950 | parser.add_argument( 951 | "--distill-pretrained", 952 | default=None, 953 | help='Which pre-trained weights to distill from, if any.' 954 | ) 955 | parser.add_argument( 956 | "--use-bnb-linear", 957 | default=None, 958 | help='Replace the network linear layers from the bitsandbytes library. ' 959 | 'Allows int8 training/inference, etc.' 960 | ) 961 | parser.add_argument( 962 | "--siglip", 963 | default=False, 964 | action="store_true", 965 | help='Use SigLip (sigmoid) loss.' 
966 | ) 967 | parser.add_argument( 968 | "--same-row", 969 | default=False, 970 | action="store_true", 971 | help='If same row, the use a different Text-conditioned SigLIP loss, it also means that you are using a different model.' 972 | ) 973 | 974 | args = parser.parse_args(args) 975 | 976 | # If some params are not passed, we use the default values based on model name. 977 | default_params = get_default_params(args.model) 978 | for name, val in default_params.items(): 979 | if getattr(args, name) is None: 980 | setattr(args, name, val) 981 | 982 | return args 983 | -------------------------------------------------------------------------------- /src/flair/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is adapted from OpenCLIP: 3 | https://github.com/mlfoundations/open_clip/blob/main/src/open_clip_train/train.py 4 | 5 | The code integrates additional modifications and extensions to support the FLAIR models. 6 | Original authors: ML Foundations. 7 | """ 8 | import json 9 | import logging 10 | import math 11 | import os 12 | import time 13 | from unicodedata import normalize 14 | 15 | import torch.nn as nn 16 | import numpy as np 17 | import torch 18 | import torch.nn.functional as F 19 | from torch.nn.parallel.distributed import DistributedDataParallel 20 | from typing import Any, Dict, Optional, Tuple, Union 21 | 22 | try: 23 | import wandb 24 | except ImportError: 25 | wandb = None 26 | 27 | from open_clip import get_input_dtype 28 | from open_clip_train.distributed import is_master 29 | from open_clip_train.zero_shot import zero_shot_eval 30 | from open_clip_train.precision import get_autocast 31 | from tqdm import tqdm 32 | 33 | 34 | class AverageMeter(object): 35 | """Computes and stores the average and current value""" 36 | 37 | def __init__(self): 38 | self.reset() 39 | 40 | def reset(self): 41 | self.val = 0 42 | self.avg = 0 43 | self.sum = 0 44 | self.count = 0 45 | 46 | def update(self, val, n=1): 47 | self.val = val 48 | self.sum += val * n 49 | self.count += n 50 | self.avg = self.sum / self.count 51 | 52 | 53 | def postprocess_clip_output(model_out): 54 | return { 55 | "image_features": model_out[0], 56 | "text_features": model_out[1], 57 | "logit_scale": model_out[2] 58 | } 59 | 60 | 61 | def unwrap_model(model): 62 | if hasattr(model, 'module'): 63 | return model.module 64 | else: 65 | return model 66 | 67 | 68 | def backward(total_loss, scaler): 69 | if scaler is not None: 70 | scaler.scale(total_loss).backward() 71 | else: 72 | total_loss.backward() 73 | 74 | 75 | def get_reordered_indices(batch_size, num_batches): 76 | """ 77 | The original order: [(I_1, T_1), ..., (I_1, T_B), (I_2, T_1), ..., (I_2, T_B), ... 78 | (T_1, T_B+1)...] 79 | reorder to [(I_1, T_1), ..., (I_1, T_N), (I_2, T_1), ..., (I_2, T_N), ... 
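# ---------------------------------------------------------------------------
# [Editor's sketch] Tiny usage example for AverageMeter above (a weighted
# running mean over batches); nothing beyond the class as defined is assumed.
meter = AverageMeter()
meter.update(0.9, n=4)        # e.g. mean loss 0.9 over a batch of 4 samples
meter.update(0.5, n=4)
print(meter.val, meter.avg)   # 0.5 (latest value), 0.7 (weighted average)
# ---------------------------------------------------------------------------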
, (I_N, T_1), ..., I(I_N, T_N)] 80 | returning a list of reordered indices 81 | """ 82 | reordered_indices = [] 83 | for k in range(batch_size): 84 | for n in range(num_batches): 85 | base_idx = n * batch_size * batch_size 86 | img_idx_start = base_idx + k * batch_size 87 | img_idx_end = img_idx_start + batch_size 88 | reordered_indices.extend(list(range(img_idx_start, img_idx_end))) 89 | 90 | return reordered_indices 91 | 92 | 93 | def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist_model, args, tb_writer=None): 94 | device = torch.device(args.device) 95 | autocast = get_autocast(args.precision) 96 | input_dtype = get_input_dtype(args.precision) 97 | 98 | model.train() 99 | 100 | data['train'].set_epoch(epoch) # set epoch in process safe manner via sampler or shared_epoch 101 | dataloader = data['train'].dataloader 102 | 103 | num_batches_per_epoch = dataloader.num_batches // args.accum_freq 104 | sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10)) 105 | 106 | if args.accum_freq > 1: 107 | accum_images, accum_texts, accum_features = [], [], {} 108 | 109 | losses_m = {} 110 | batch_time_m = AverageMeter() 111 | data_time_m = AverageMeter() 112 | end = time.time() 113 | for i, batch in enumerate(dataloader): 114 | i_accum = i // args.accum_freq 115 | step = num_batches_per_epoch * epoch + i_accum 116 | 117 | if not args.skip_scheduler: 118 | scheduler(step) 119 | 120 | #images, texts, img_ids = batch 121 | images, texts = batch 122 | # values, counts = torch.unique(img_ids, return_counts=True) 123 | images = images.to(device=device, dtype=input_dtype, non_blocking=True) 124 | texts = texts.to(device=device, non_blocking=True) 125 | 126 | data_time_m.update(time.time() - end) 127 | optimizer.zero_grad() 128 | 129 | if args.accum_freq == 1: 130 | with autocast(): 131 | model_out = model(images, texts) 132 | logit_scale = model_out["logit_scale"] 133 | losses = loss(**model_out, output_dict=True) 134 | 135 | total_loss = sum(losses.values()) 136 | losses["loss"] = total_loss 137 | 138 | backward(total_loss, scaler) 139 | else: 140 | # First, cache the features without any gradient tracking. 141 | with torch.no_grad(): 142 | with autocast(): 143 | model_out = model(images, texts) 144 | 145 | for f in ("logit_scale", "logit_bias"): 146 | model_out.pop(f, None) 147 | 148 | for key, val in model_out.items(): 149 | if key in accum_features: 150 | accum_features[key].append(val) 151 | else: 152 | accum_features[key] = [val] 153 | 154 | accum_images.append(images) 155 | accum_texts.append(texts) 156 | 157 | # If (i + 1) % accum_freq is not zero, move on to the next batch. 158 | if ((i + 1) % args.accum_freq) > 0: 159 | # FIXME this makes data time logging unreliable when accumulating 160 | continue 161 | 162 | # Now, ready to take gradients for the last accum_freq batches. 163 | # Re-do the forward pass for those batches, and use the cached features from the other batches as negatives. 164 | # Call backwards each time, but only step optimizer at the end. 
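# ---------------------------------------------------------------------------
# [Editor's sketch] Worked example for get_reordered_indices() above. With
# batch_size=2 and num_batches=2, the flat per-batch image-text scores are
# regrouped so that each image's scores against all texts become contiguous:
print(get_reordered_indices(batch_size=2, num_batches=2))
# -> [0, 1, 4, 5, 2, 3, 6, 7]
# i.e. image I_1 vs (T_1, T_2, T_3, T_4) first, then image I_2 vs all texts.
# ---------------------------------------------------------------------------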
165 | optimizer.zero_grad() 166 | for j in range(args.accum_freq): 167 | images = accum_images[j] 168 | texts = accum_texts[j] 169 | with autocast(): 170 | model_out = model(images, texts) 171 | inputs_no_accum = {} 172 | inputs_no_accum["logit_scale"] = logit_scale = model_out.pop("logit_scale") 173 | if "logit_bias" in model_out: 174 | inputs_no_accum["logit_bias"] = model_out.pop("logit_bias") 175 | 176 | inputs = {} 177 | for key, val in accum_features.items(): 178 | accumulated = accum_features[key] 179 | inputs[key] = torch.cat(accumulated[:j] + [model_out[key]] + accumulated[j + 1:]) 180 | 181 | losses = loss(**inputs, **inputs_no_accum, output_dict=True) 182 | del inputs 183 | del inputs_no_accum 184 | total_loss = sum(losses.values()) 185 | losses["loss"] = total_loss 186 | 187 | backward(total_loss, scaler) 188 | 189 | if scaler is not None: 190 | if args.horovod: 191 | optimizer.synchronize() 192 | scaler.unscale_(optimizer) 193 | if args.grad_clip_norm is not None: 194 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0) 195 | with optimizer.skip_synchronize(): 196 | scaler.step(optimizer) 197 | else: 198 | if args.grad_clip_norm is not None: 199 | scaler.unscale_(optimizer) 200 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0) 201 | scaler.step(optimizer) 202 | scaler.update() 203 | else: 204 | if args.grad_clip_norm is not None: 205 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0) 206 | optimizer.step() 207 | 208 | # reset gradient accum, if enabled 209 | if args.accum_freq > 1: 210 | accum_images, accum_texts, accum_features = [], [], {} 211 | 212 | # Note: we clamp to 4.6052 = ln(100), as in the original paper. 213 | with torch.no_grad(): 214 | unwrap_model(model).logit_scale.clamp_(0, math.log(100)) 215 | 216 | batch_time_m.update(time.time() - end) 217 | end = time.time() 218 | batch_count = i_accum + 1 219 | if is_master(args) and (i_accum % args.log_every_n_steps == 0 or batch_count == num_batches_per_epoch): 220 | batch_size = len(images) 221 | num_samples = batch_count * batch_size * args.accum_freq * args.world_size 222 | samples_per_epoch = dataloader.num_samples 223 | percent_complete = 100.0 * batch_count / num_batches_per_epoch 224 | 225 | # NOTE loss is coarsely sampled, just master node and per log update 226 | for key, val in losses.items(): 227 | if key not in losses_m: 228 | losses_m[key] = AverageMeter() 229 | losses_m[key].update(val.item(), batch_size) 230 | 231 | logit_scale_scalar = logit_scale.item() 232 | loss_log = " ".join( 233 | [ 234 | f"{loss_name.capitalize()}: {loss_m.val:#.5g} ({loss_m.avg:#.5g})" 235 | for loss_name, loss_m in losses_m.items() 236 | ] 237 | ) 238 | samples_per_second = args.accum_freq * args.batch_size * args.world_size / batch_time_m.val 239 | samples_per_second_per_gpu = args.accum_freq * args.batch_size / batch_time_m.val 240 | logging.info( 241 | f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " 242 | f"Data (t): {data_time_m.avg:.3f} " 243 | f"Batch (t): {batch_time_m.avg:.3f}, {samples_per_second:#g}/s, {samples_per_second_per_gpu:#g}/s/gpu " 244 | f"LR: {optimizer.param_groups[0]['lr']:5f} " 245 | f"Logit Scale: {math.log(logit_scale_scalar):.3f} " + loss_log 246 | ) 247 | 248 | # Save train loss / etc. 
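# ---------------------------------------------------------------------------
# [Editor's sketch] The scaler branch above is the standard torch.cuda.amp
# recipe: unscale before clipping so the norm is measured on true gradients,
# then step (which skips on inf/nan) and update. A minimal standalone version
# of that pattern; enabled=False only so the sketch also runs on CPU, and the
# toy model / clip value are illustrative.
import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler

toy_model = nn.Linear(4, 4)
toy_opt = torch.optim.SGD(toy_model.parameters(), lr=0.1)
toy_scaler = GradScaler(enabled=False)

toy_loss = toy_model(torch.randn(2, 4)).pow(2).mean()
toy_scaler.scale(toy_loss).backward()
toy_scaler.unscale_(toy_opt)
torch.nn.utils.clip_grad_norm_(toy_model.parameters(), 1.0, norm_type=2.0)
toy_scaler.step(toy_opt)
toy_scaler.update()
# ---------------------------------------------------------------------------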
Using non avg meter values as loggers have their own smoothing 249 | log_data = { 250 | "data_time": data_time_m.val, 251 | "batch_time": batch_time_m.val, 252 | "samples_per_second": samples_per_second, 253 | "samples_per_second_per_gpu": samples_per_second_per_gpu, 254 | "scale": math.log(logit_scale_scalar), 255 | "lr": optimizer.param_groups[0]["lr"] 256 | } 257 | log_data.update({name: val.val for name, val in losses_m.items()}) 258 | 259 | log_data = {"train/" + name: val for name, val in log_data.items()} 260 | 261 | if tb_writer is not None: 262 | for name, val in log_data.items(): 263 | tb_writer.add_scalar(name, val, step) 264 | 265 | if args.wandb: 266 | assert wandb is not None, 'Please install wandb.' 267 | log_data['step'] = step # for backwards compatibility 268 | wandb.log(log_data, step=step) 269 | 270 | # resetting batch / data time meters per log window 271 | batch_time_m.reset() 272 | data_time_m.reset() 273 | 274 | 275 | def evaluate(model, data, epoch, args, tb_writer=None, tokenizer=None): 276 | metrics = {} 277 | if not is_master(args): 278 | return metrics 279 | device = torch.device(args.device) 280 | model.eval() 281 | zero_shot_metrics = zero_shot_eval( 282 | model, data, epoch, args, tokenizer=tokenizer) 283 | metrics.update(zero_shot_metrics) 284 | 285 | autocast = get_autocast(args.precision) 286 | input_dtype = get_input_dtype(args.precision) 287 | 288 | if args.val_frequency and ((epoch % args.val_frequency) == 0 or epoch == args.epochs): 289 | if 'retrieval_coco' in data: 290 | txt_data, img_data, img2txt_dict, txt2img_dict = data['retrieval_coco'] 291 | txt_loader, img_loader = txt_data.dataloader, img_data.dataloader 292 | metrics = retrieval_on_split('retrieval_coco', model, txt_loader, img_loader, img2txt_dict, txt2img_dict, 293 | args, epoch, metrics, device, input_dtype, autocast) 294 | if 'retrieval_flickr' in data: 295 | txt_data, img_data, img2txt_dict, txt2img_dict = data['retrieval_flickr'] 296 | txt_loader, img_loader = txt_data.dataloader, img_data.dataloader 297 | metrics = retrieval_on_split('retrieval_flickr', model, txt_loader, img_loader, img2txt_dict, txt2img_dict, 298 | args, epoch, metrics, device, input_dtype, autocast) 299 | if 'retrieval_cc3m_train' in data: 300 | txt_data, img_data, img2txt_dict, txt2img_dict = data['retrieval_cc3m_train'] 301 | txt_loader, img_loader = txt_data.dataloader, img_data.dataloader 302 | metrics = retrieval_on_split('retrieval_cc3m_train', model, txt_loader, img_loader, img2txt_dict, 303 | txt2img_dict, 304 | args, epoch, metrics, device, input_dtype, autocast) 305 | 306 | if 'retrieval_docci' in data: 307 | txt_data, img_data, img2txt_dict, txt2img_dict = data['retrieval_docci'] 308 | txt_loader, img_loader = txt_data.dataloader, img_data.dataloader 309 | metrics = retrieval_on_split('retrieval_docci', model, txt_loader, img_loader, img2txt_dict, 310 | txt2img_dict, 311 | args, epoch, metrics, device, input_dtype, autocast) 312 | 313 | if 'retrieval_urban_1k' in data: 314 | txt_data, img_data, img2txt_dict, txt2img_dict = data['retrieval_urban_1k'] 315 | txt_loader, img_loader = txt_data.dataloader, img_data.dataloader 316 | metrics = retrieval_on_split('retrieval_urban_1k', model, txt_loader, img_loader, img2txt_dict, 317 | txt2img_dict, 318 | args, epoch, metrics, device, input_dtype, autocast) 319 | 320 | if 'retrieval_iiw' in data: 321 | txt_data, img_data, img2txt_dict, txt2img_dict = data['retrieval_iiw'] 322 | txt_loader, img_loader = txt_data.dataloader, img_data.dataloader 323 | metrics = 
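# ---------------------------------------------------------------------------
# [Editor's note] The per-benchmark blocks in evaluate() all follow the same
# shape. A behaviour-equivalent, loop-based sketch (the helper name is made
# up; the split keys and call signature are exactly the ones used here):
def _run_retrieval_splits(model, data, args, epoch, metrics, device, input_dtype, autocast):
    split_keys = [
        'retrieval_coco', 'retrieval_flickr', 'retrieval_cc3m_train',
        'retrieval_docci', 'retrieval_urban_1k', 'retrieval_iiw',
        'retrieval_dci', 'retrieval_sharegpt4v-1k', 'retrieval_sharegpt4v-10k',
    ]
    for key in split_keys:
        if key not in data:
            continue
        txt_data, img_data, img2txt_dict, txt2img_dict = data[key]
        metrics = retrieval_on_split(key, model, txt_data.dataloader, img_data.dataloader,
                                     img2txt_dict, txt2img_dict, args, epoch, metrics,
                                     device, input_dtype, autocast)
    return metrics
# ---------------------------------------------------------------------------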
retrieval_on_split('retrieval_iiw', model, txt_loader, img_loader, img2txt_dict, 324 | txt2img_dict, 325 | args, epoch, metrics, device, input_dtype, autocast) 326 | 327 | if 'retrieval_dci' in data: 328 | txt_data, img_data, img2txt_dict, txt2img_dict = data['retrieval_dci'] 329 | txt_loader, img_loader = txt_data.dataloader, img_data.dataloader 330 | metrics = retrieval_on_split('retrieval_dci', model, txt_loader, img_loader, img2txt_dict, 331 | txt2img_dict, 332 | args, epoch, metrics, device, input_dtype, autocast) 333 | 334 | if 'retrieval_sharegpt4v-1k' in data: 335 | txt_data, img_data, img2txt_dict, txt2img_dict = data['retrieval_sharegpt4v-1k'] 336 | txt_loader, img_loader = txt_data.dataloader, img_data.dataloader 337 | metrics = retrieval_on_split('retrieval_sharegpt4v-1k', model, txt_loader, img_loader, img2txt_dict, 338 | txt2img_dict, 339 | args, epoch, metrics, device, input_dtype, autocast) 340 | 341 | if 'retrieval_sharegpt4v-10k' in data: 342 | txt_data, img_data, img2txt_dict, txt2img_dict = data['retrieval_sharegpt4v-10k'] 343 | txt_loader, img_loader = txt_data.dataloader, img_data.dataloader 344 | metrics = retrieval_on_split('retrieval_sharegpt4v-10k', model, txt_loader, img_loader, img2txt_dict, 345 | txt2img_dict, 346 | args, epoch, metrics, device, input_dtype, autocast) 347 | 348 | if not metrics: 349 | return metrics 350 | 351 | logging.info( 352 | f"Eval Epoch: {epoch} " 353 | + "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()]) 354 | ) 355 | 356 | log_data = {"val/" + name: val for name, val in metrics.items()} 357 | 358 | if args.save_logs: 359 | if tb_writer is not None: 360 | for name, val in log_data.items(): 361 | tb_writer.add_scalar(name, val, epoch) 362 | 363 | with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f: 364 | f.write(json.dumps(metrics)) 365 | f.write("\n") 366 | 367 | if args.wandb: 368 | assert wandb is not None, 'Please install wandb.' 369 | if 'train' in data: 370 | dataloader = data['train'].dataloader 371 | num_batches_per_epoch = dataloader.num_batches // args.accum_freq 372 | step = num_batches_per_epoch * epoch 373 | else: 374 | step = None 375 | log_data['epoch'] = epoch 376 | wandb.log(log_data, step=step) 377 | 378 | return metrics 379 | 380 | 381 | def get_clip_metrics(image_features, text_features, logit_scale): 382 | metrics = {} 383 | logits_per_image = (logit_scale * image_features @ text_features.t()).detach().cpu() 384 | logits_per_text = logits_per_image.t().detach().cpu() 385 | 386 | logits = {"image_to_text": logits_per_image, "text_to_image": logits_per_text} 387 | ground_truth = torch.arange(len(text_features)).view(-1, 1) 388 | 389 | for name, logit in logits.items(): 390 | ranking = torch.argsort(logit, descending=True) 391 | preds = torch.where(ranking == ground_truth)[1] 392 | preds = preds.detach().cpu().numpy() 393 | metrics[f"{name}_mean_rank"] = preds.mean() + 1 394 | metrics[f"{name}_median_rank"] = np.floor(np.median(preds)) + 1 395 | for k in [1, 5, 10]: 396 | metrics[f"{name}_R@{k}"] = np.mean(preds < k) 397 | 398 | return metrics 399 | 400 | 401 | def get_conditioned_clip_metrics(logits_per_image_list): 402 | """ 403 | :param logits_per_image_list: A list of containing all logits_per_images. 404 | Totally N batches, each batch contains B samples. Each logits_per_image should be of shape (B, N*B), already ordered. 
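# ---------------------------------------------------------------------------
# [Editor's sketch] The rank trick used in get_clip_metrics() above, on a toy
# 3x3 similarity matrix, to make the R@k and mean-rank definitions explicit:
import torch
import numpy as np

toy_logits = torch.tensor([[0.9, 0.1, 0.0],   # query 0: true match ranked 1st
                           [0.2, 0.1, 0.7],   # query 1: true match ranked 3rd
                           [0.3, 0.8, 0.6]])  # query 2: true match ranked 2nd
gt = torch.arange(3).view(-1, 1)
ranking = torch.argsort(toy_logits, descending=True)
preds = torch.where(ranking == gt)[1].numpy()  # 0-based rank of the true match per row
print(preds)               # [0 2 1]
print(np.mean(preds < 1))  # R@1 = 0.333...
print(preds.mean() + 1)    # mean rank = 2.0
# ---------------------------------------------------------------------------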
405 | :return: 406 | """ 407 | metrics = {} 408 | logits_per_image = torch.cat(logits_per_image_list, dim=0) # shape: (N*B, N*B) 409 | logits_per_text = logits_per_image.t().detach().cpu() 410 | 411 | logits = {"image_to_text": logits_per_image, "text_to_image": logits_per_text} 412 | ground_truth = torch.arange(len(logits_per_image)).view(-1, 1) 413 | 414 | for name, logit in logits.items(): 415 | ranking = torch.argsort(logit, descending=True) 416 | preds = torch.where(ranking == ground_truth)[1] 417 | preds = preds.detach().cpu().numpy() 418 | metrics[f"{name}_mean_rank"] = preds.mean() + 1 419 | metrics[f"{name}_median_rank"] = np.floor(np.median(preds)) + 1 420 | for k in [1, 5, 10]: 421 | metrics[f"{name}_R@{k}"] = np.mean(preds < k) 422 | 423 | return metrics 424 | 425 | 426 | def maybe_compute_generative_loss(model_out): 427 | if "logits" in model_out and "labels" in model_out: 428 | token_logits = model_out["logits"] 429 | token_labels = model_out["labels"] 430 | return F.cross_entropy(token_logits.permute(0, 2, 1), token_labels) 431 | 432 | 433 | def remap_indices(merged_img_ids, cap_ids, img2txt_dict, txt2img_dict): 434 | """ 435 | params: 436 | merged_img_ids: tensor of shape (M, D) 437 | cap_ids: tensor of shape (N) (But the ordering might be random) 438 | img2txt_dict: dict mapping each img_id to a list of cap_ids 439 | txt2img_dict: dict mappint each cap_id to an img_id (a list of one element) 440 | text_features: tensor of shape (N, D) 441 | """ 442 | # so now ideally the cap_ids should be (0, ...N), so do the text_features 443 | # step2: re-index the merged_image_ids and re-do the mapping in the dict. 444 | # As the original image ids might just be random numbers, they don't represent the real ordering. 445 | 446 | img_id_mapping = {old_id.item(): new_idx for new_idx, old_id in enumerate(merged_img_ids)} 447 | reindexed_img_ids = torch.tensor([img_id_mapping[img_id.item()] for img_id in merged_img_ids]) 448 | 449 | # Update the img2txt_dict and txt2img_dict with new indices 450 | new_img2txt_dict = {img_id_mapping[img_id]: [cap_id for cap_id in cap_id_list] 451 | for img_id, cap_id_list in img2txt_dict.items()} 452 | 453 | new_txt2img_dict = {cap_id: img_id_mapping[txt2img_dict[cap_id][0]] 454 | for cap_id in txt2img_dict.keys()} 455 | 456 | return new_img2txt_dict, new_txt2img_dict 457 | 458 | 459 | def compute_retrieval(similarity_scores, txt2img, img2txt): 460 | if isinstance(similarity_scores, tuple): 461 | i2t_similarity_score, t2i_similarity_score = similarity_scores 462 | else: 463 | # Otherwise, treat similarity_scores as a single matrix for t2i 464 | t2i_similarity_score = similarity_scores.t() 465 | i2t_similarity_score = similarity_scores 466 | 467 | t2i_ranks = torch.zeros(t2i_similarity_score.shape[0]) 468 | 469 | for index, score in enumerate(t2i_similarity_score): 470 | inds = torch.argsort(score, descending=True) 471 | t2i_ranks[index] = torch.where(inds == txt2img[index])[0][0] 472 | 473 | # Compute metrics 474 | tr1 = len(torch.where(t2i_ranks < 1)[0]) / len(t2i_ranks) 475 | tr5 = len(torch.where(t2i_ranks < 5)[0]) / len(t2i_ranks) 476 | tr10 = len(torch.where(t2i_ranks < 10)[0]) / len(t2i_ranks) 477 | t2i_report_dict = { 478 | "text_to_image_R@1": tr1, 479 | "text_to_image_R@5": tr5, 480 | "text_to_image_R@10": tr10, 481 | "text_to_image_mean_rank": t2i_ranks.mean().item() + 1, 482 | "text_to_image_median_rank": np.floor(np.median(t2i_ranks.numpy())) + 1 483 | } 484 | 485 | # comput image -> text 486 | i2t_ranks = 
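# ---------------------------------------------------------------------------
# [Editor's sketch] Toy example for remap_indices() above: arbitrary image ids
# are re-based onto contiguous 0..M-1 indices (all values below are made up):
import torch

toy_img_ids = torch.tensor([42, 7, 19])                   # images in encoding order
toy_img2txt = {42: [0, 1], 7: [2], 19: [3, 4]}            # image id -> caption ids
toy_txt2img = {0: [42], 1: [42], 2: [7], 3: [19], 4: [19]}
new_i2t, new_t2i = remap_indices(toy_img_ids, torch.arange(5), toy_img2txt, toy_txt2img)
print(new_i2t)  # {0: [0, 1], 1: [2], 2: [3, 4]}
print(new_t2i)  # {0: 0, 1: 0, 2: 1, 3: 2, 4: 2}
# ---------------------------------------------------------------------------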
torch.zeros(i2t_similarity_score.shape[0]) 487 | for index, score in enumerate(i2t_similarity_score): 488 | inds = torch.argsort(score, descending=True) 489 | # Score 490 | rank = 1e10 491 | for i in img2txt[index]: 492 | tmp = torch.where(inds == i)[0][0] 493 | if tmp < rank: 494 | rank = tmp 495 | i2t_ranks[index] = rank 496 | 497 | # Compute metrics 498 | ir1 = len(torch.where(i2t_ranks < 1)[0]) / len(i2t_ranks) 499 | ir5 = len(torch.where(i2t_ranks < 5)[0]) / len(i2t_ranks) 500 | ir10 = len(torch.where(i2t_ranks < 10)[0]) / len(i2t_ranks) 501 | 502 | i2t_report_dict = { 503 | "image_to_text_R@1": ir1, 504 | "image_to_text_R@5": ir5, 505 | "image_to_text_R@10": ir10, 506 | "image_to_text_mean_rank": i2t_ranks.mean().item() + 1, 507 | "image_to_text_median_rank": np.floor(np.median(i2t_ranks.numpy())) + 1 508 | } 509 | metrics = {**t2i_report_dict, **i2t_report_dict} 510 | return metrics 511 | 512 | def retrieval_on_split(keyword, model, txt_loader, img_loader, img2txt_dict, txt2img_dict, args, epoch, metrics, device, 513 | input_dtype, autocast): 514 | num_txt_samples = txt_loader.num_samples 515 | num_img_samples = img_loader.num_samples 516 | all_image_features, all_text_tokens, all_text_features = [], [], [] 517 | all_local_text_tokens = [] 518 | all_img_ids, all_cap_ids = [], [] 519 | 520 | with torch.no_grad(): 521 | # first loop over the text dataloader to store all text embeddings 522 | #for i, batch in tqdm(enumerate(txt_loader), total=len(txt_loader), desc="Processing Texts"): 523 | for i, batch in enumerate(txt_loader): 524 | texts, cap_id = batch 525 | texts = texts.to(device=device, non_blocking=True) 526 | with autocast(): 527 | if args.inference_with_flair: 528 | global_text_token, local_text_tokens = unwrap_model(model).encode_text(texts, normalize=False) 529 | global_text_token, local_text_tokens = unwrap_model(model).text_post( 530 | global_text_token), unwrap_model(model).text_post(local_text_tokens) 531 | text_features = F.normalize(global_text_token, dim=-1) 532 | all_text_tokens.append(global_text_token.squeeze(1)) # GPU 533 | all_local_text_tokens.append(local_text_tokens) # GPU 534 | else: 535 | text_features = unwrap_model(model).encode_text(texts, normalize=True) 536 | 537 | all_text_features.append(text_features.detach().cpu()) # cpu list of N, each of shape (B, D) 538 | all_cap_ids.append(cap_id.detach().cpu()) 539 | all_text_features_tensor = torch.cat(all_text_features) # (N, 512) 540 | cap_ids = torch.cat(all_cap_ids) 541 | 542 | 543 | if args.inference_with_flair: 544 | mode = "inference_with_flair" 545 | all_text_tokens_tensor = torch.cat(all_text_tokens) # on GPU 546 | all_local_text_tokens_tensor = torch.cat(all_local_text_tokens) 547 | 548 | similarity_scores, img_ids = compute_similarity_scores_attn_pool( 549 | model, img_loader, all_text_features_tensor, all_text_tokens_tensor, device, input_dtype, autocast, mode 550 | ) 551 | else: 552 | similarity_scores, img_ids = compute_similarity_scores_original_clip(model, img_loader, 553 | all_text_features_tensor, device, 554 | input_dtype, 555 | autocast, 556 | mode='original_clip') 557 | new_img2txt_dict, new_txt2img_dict = remap_indices(merged_img_ids=img_ids, cap_ids=cap_ids, 558 | img2txt_dict=img2txt_dict, txt2img_dict=txt2img_dict) 559 | 560 | retrieval_metrics = compute_retrieval(similarity_scores=similarity_scores, 561 | txt2img=new_txt2img_dict, 562 | img2txt=new_img2txt_dict) 563 | 564 | if keyword != '': 565 | temp_retrieval_metrics = {} 566 | keyword = keyword + '_' 567 | for k, v in 
retrieval_metrics.items(): 568 | temp_retrieval_metrics[keyword + k] = v 569 | retrieval_metrics = temp_retrieval_metrics 570 | 571 | if "epoch" in metrics: # we only need one epoch information 572 | metrics.update( 573 | {**retrieval_metrics, 574 | f"{keyword}num_text_samples": num_txt_samples, 575 | f"{keyword}num_image_samples": num_img_samples 576 | } 577 | ) 578 | else: 579 | metrics.update( 580 | {**retrieval_metrics, 581 | f"epoch": epoch, 582 | f"{keyword}num_text_samples": num_txt_samples, 583 | f"{keyword}num_image_samples": num_img_samples 584 | } 585 | ) 586 | 587 | return metrics 588 | 589 | 590 | def compute_similarity_scores_original_clip(model, img_loader, all_text_features_tensor, device, input_dtype, 591 | autocast, mode='original_clip'): 592 | all_image_features = [] 593 | all_img_ids = [] 594 | 595 | for i, batch in enumerate(img_loader): 596 | images, img_id = batch 597 | images = images.to(device=device, dtype=input_dtype, non_blocking=True) 598 | all_img_ids.append(img_id.detach().cpu()) 599 | 600 | with autocast(): 601 | if mode == 'original_clip': 602 | image_features = unwrap_model(model).encode_image(images, normalize=True) 603 | elif mode == 'imgcon': 604 | _, local_image_tokens = unwrap_model(model).encode_image(images) 605 | local_image_tokens = unwrap_model(model).image_post(local_image_tokens) 606 | image_features = unwrap_model(model).visual_proj(local_image_tokens.mean(dim=1, keepdim=True), local_image_tokens, local_image_tokens) 607 | image_features = image_features.squeeze(1) 608 | image_features = F.normalize(image_features, dim=-1) 609 | logit_scale = unwrap_model(model).logit_scale.exp() 610 | all_image_features.append(image_features.detach().cpu()) 611 | 612 | all_image_features_tensor = torch.cat(all_image_features) 613 | img_ids = torch.cat(all_img_ids) 614 | 615 | similarity_scores = logit_scale.cpu() * all_image_features_tensor @ all_text_features_tensor.t() 616 | return similarity_scores, img_ids 617 | 618 | 619 | def compute_similarity_scores_attn_pool(model, img_loader, all_text_features_tensor, all_text_tokens_tensor, device, 620 | input_dtype, 621 | autocast, mode): 622 | logits_per_image_list = [] 623 | all_img_ids = [] 624 | 625 | for i, batch in enumerate(img_loader): 626 | images, img_id = batch 627 | images = images.to(device=device, dtype=input_dtype, non_blocking=True) 628 | all_img_ids.append(img_id.detach().cpu()) 629 | with autocast(): 630 | if mode == 'inference_with_flair': 631 | _, image_embeddings = unwrap_model(model).encode_image(images, normalize=False) 632 | image_embeddings = unwrap_model(model).image_post(image_embeddings) # down proj to 256 633 | img_features_after_conditioning = unwrap_model(model).visual_proj( 634 | all_text_tokens_tensor.unsqueeze(0), 635 | image_embeddings, 636 | image_embeddings 637 | ) 638 | img_features_after_conditioning = F.normalize(img_features_after_conditioning, dim=-1).detach().cpu() 639 | embed_dim = img_features_after_conditioning.shape[-1] 640 | img_features_after_conditioning = img_features_after_conditioning.contiguous().view(-1, embed_dim) 641 | else: 642 | embed_dim = all_text_features_tensor.shape[-1] 643 | img_features_after_conditioning = unwrap_model(model).visual_proj( 644 | all_text_tokens_tensor.unsqueeze(0), 645 | image_embeddings, 646 | image_embeddings 647 | ).detach().cpu().contiguous().view(-1, embed_dim) 648 | 649 | logit_scale = unwrap_model(model).logit_scale.exp() 650 | logits_per_image = (logit_scale.cpu() * torch.einsum('ij,ij->i', img_features_after_conditioning, 
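# ---------------------------------------------------------------------------
# [Editor's sketch] The einsum used here computes a row-wise dot product: the
# i-th text-conditioned image feature is scored only against the i-th text
# feature, i.e. the diagonal of the full similarity matrix. Standalone check:
import torch

a = torch.randn(5, 4)  # 5 text-conditioned image features (one per text)
b = torch.randn(5, 4)  # the 5 corresponding text features
diag = torch.einsum('ij,ij->i', a, b)
assert torch.allclose(diag, (a * b).sum(dim=-1))
assert torch.allclose(diag, torch.diagonal(a @ b.t()), atol=1e-6)
# ---------------------------------------------------------------------------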
651 | all_text_features_tensor)).unsqueeze(0).detach().cpu() 652 | logits_per_image_list.append(logits_per_image) 653 | 654 | img_ids = torch.cat(all_img_ids) # shape (M) 655 | similarity_scores = torch.cat(logits_per_image_list) # shape (M, N) 656 | return similarity_scores, img_ids -------------------------------------------------------------------------------- /src/flair/transformer.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is adapted from OpenCLIP: 3 | https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/transformer.py 4 | 5 | The code integrates additional modifications and extensions to support the FLAIR models. 6 | Original authors: ML Foundations. 7 | """ 8 | 9 | from collections import OrderedDict 10 | import math 11 | from typing import Callable, List, Optional, Sequence, Tuple, Union 12 | from functools import partial 13 | 14 | import torch 15 | from torch import nn 16 | from torch.nn import functional as F 17 | from torch.utils.checkpoint import checkpoint 18 | 19 | from open_clip.utils import to_2tuple 20 | from open_clip.pos_embed import get_2d_sincos_pos_embed 21 | import torch.distributed as dist 22 | 23 | 24 | 25 | class LayerNormFp32(nn.LayerNorm): 26 | """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back).""" 27 | 28 | def forward(self, x: torch.Tensor): 29 | orig_type = x.dtype 30 | x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps) 31 | return x.to(orig_type) 32 | 33 | 34 | class LayerNorm(nn.LayerNorm): 35 | """Subclass torch's LayerNorm (with cast back to input dtype).""" 36 | 37 | def forward(self, x: torch.Tensor): 38 | orig_type = x.dtype 39 | x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 40 | return x.to(orig_type) 41 | 42 | 43 | class QuickGELU(nn.Module): 44 | # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory 45 | def forward(self, x: torch.Tensor): 46 | return x * torch.sigmoid(1.702 * x) 47 | 48 | 49 | class LayerScale(nn.Module): 50 | def __init__(self, dim, init_values=1e-5, inplace=False): 51 | super().__init__() 52 | self.inplace = inplace 53 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 54 | 55 | def forward(self, x): 56 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 57 | 58 | 59 | class PatchDropout(nn.Module): 60 | """ 61 | https://arxiv.org/abs/2212.00794 62 | """ 63 | 64 | def __init__(self, prob, exclude_first_token=True): 65 | super().__init__() 66 | assert 0 <= prob < 1. 
67 | self.prob = prob 68 | self.exclude_first_token = exclude_first_token # exclude CLS token 69 | 70 | def forward(self, x): 71 | if not self.training or self.prob == 0.: 72 | return x 73 | 74 | if self.exclude_first_token: 75 | cls_tokens, x = x[:, :1], x[:, 1:] 76 | else: 77 | cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1]) 78 | 79 | batch = x.size()[0] 80 | num_tokens = x.size()[1] 81 | 82 | batch_indices = torch.arange(batch) 83 | batch_indices = batch_indices[..., None] 84 | 85 | keep_prob = 1 - self.prob 86 | num_patches_keep = max(1, int(num_tokens * keep_prob)) 87 | 88 | rand = torch.randn(batch, num_tokens) 89 | patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices 90 | 91 | x = x[batch_indices, patch_indices_keep] 92 | 93 | if self.exclude_first_token: 94 | x = torch.cat((cls_tokens, x), dim=1) 95 | 96 | return x 97 | 98 | 99 | class Attention(nn.Module): 100 | def __init__( 101 | self, 102 | dim: int, 103 | num_heads: int = 8, 104 | qkv_bias: bool = True, 105 | scaled_cosine: bool = False, 106 | scale_heads: bool = False, 107 | logit_scale_max: float = math.log(1. / 0.01), 108 | batch_first: bool = True, 109 | attn_drop: float = 0., 110 | proj_drop: float = 0. 111 | ): 112 | super().__init__() 113 | self.scaled_cosine = scaled_cosine 114 | self.scale_heads = scale_heads 115 | assert dim % num_heads == 0, 'dim should be divisible by num_heads' 116 | self.num_heads = num_heads 117 | self.head_dim = dim // num_heads 118 | self.scale = self.head_dim ** -0.5 119 | self.logit_scale_max = logit_scale_max 120 | self.batch_first = batch_first 121 | self.use_fsdpa = hasattr(nn.functional, 'scaled_dot_product_attention') 122 | 123 | # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original 124 | self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale) 125 | if qkv_bias: 126 | self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3)) 127 | else: 128 | self.in_proj_bias = None 129 | 130 | if self.scaled_cosine: 131 | self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) 132 | else: 133 | self.logit_scale = None 134 | self.attn_drop = nn.Dropout(attn_drop) 135 | if self.scale_heads: 136 | self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1))) 137 | else: 138 | self.head_scale = None 139 | self.out_proj = nn.Linear(dim, dim) 140 | self.out_drop = nn.Dropout(proj_drop) 141 | 142 | def forward(self, x, attn_mask: Optional[torch.Tensor] = None): 143 | if self.batch_first: 144 | x = x.transpose(0, 1) 145 | 146 | L, N, C = x.shape 147 | q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1) 148 | q = q.reshape(L, N * self.num_heads, -1).transpose(0, 1) 149 | k = k.reshape(L, N * self.num_heads, -1).transpose(0, 1) 150 | v = v.reshape(L, N * self.num_heads, -1).transpose(0, 1) 151 | 152 | if attn_mask is not None and attn_mask.dtype == torch.bool: 153 | new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) 154 | new_attn_mask.masked_fill_(attn_mask, float("-inf")) 155 | attn_mask = new_attn_mask 156 | 157 | if self.logit_scale is not None: 158 | attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2)) 159 | logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp() 160 | attn = attn.view(N, self.num_heads, L, L) * logit_scale 161 | attn = attn.view(-1, L, L) 162 | if attn_mask is not None: 163 | attn = attn + attn_mask 164 | attn = attn.softmax(dim=-1) 165 | attn = self.attn_drop(attn) 166 | x = torch.bmm(attn, v) 167 | else: 168 
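# ---------------------------------------------------------------------------
# [Editor's sketch] The two attention branches here are numerically
# equivalent: F.scaled_dot_product_attention (fast path, used when available)
# and the manual softmax(QK^T / sqrt(d)) V fallback. Standalone check with
# toy tensors:
import torch
import torch.nn.functional as F

q, k, v = (torch.randn(2, 5, 8) for _ in range(3))
out_sdpa = F.scaled_dot_product_attention(q, k, v)
attn = (q @ k.transpose(-1, -2)) / (8 ** 0.5)
out_manual = attn.softmax(dim=-1) @ v
assert torch.allclose(out_sdpa, out_manual, atol=1e-5)
# ---------------------------------------------------------------------------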
| if self.use_fsdpa: 169 | x = F.scaled_dot_product_attention( 170 | q, k, v, 171 | attn_mask=attn_mask, 172 | dropout_p=self.attn_drop.p if self.training else 0., 173 | ) 174 | else: 175 | q = q * self.scale 176 | attn = torch.bmm(q, k.transpose(-1, -2)) 177 | if attn_mask is not None: 178 | attn += attn_mask 179 | attn = attn.softmax(dim=-1) 180 | attn = self.attn_drop(attn) 181 | x = torch.bmm(attn, v) 182 | 183 | if self.head_scale is not None: 184 | x = x.view(N, self.num_heads, L, C) * self.head_scale 185 | x = x.view(-1, L, C) 186 | 187 | x = x.transpose(0, 1).reshape(L, N, C) 188 | 189 | if self.batch_first: 190 | x = x.transpose(0, 1) 191 | 192 | x = self.out_proj(x) 193 | x = self.out_drop(x) 194 | return x 195 | 196 | 197 | class AttentionalPooler(nn.Module): 198 | def __init__( 199 | self, 200 | d_model: int, 201 | context_dim: int, 202 | n_head: int = 8, 203 | n_queries: int = 256, 204 | norm_layer: Callable = LayerNorm, 205 | ): 206 | super().__init__() 207 | self.query = nn.Parameter(torch.randn(n_queries, d_model)) 208 | self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim, batch_first=True) 209 | self.ln_q = norm_layer(d_model) 210 | self.ln_k = norm_layer(context_dim) 211 | 212 | def forward(self, x: torch.Tensor): 213 | N = x.shape[0] 214 | x = self.ln_k(x) 215 | q = self.ln_q(self.query) 216 | # print("query shape:", q.shape) # (256, 512) -> (128, 256, 512) after unsqueezed 217 | # print("x shape for key and value", x.shape) # (128, 197, 768) 218 | #out = self.attn(q.unsqueeze(0).expand(N, -1, -1), x, x, need_weights=False)[0] 219 | attn_outputs, attn_output_weights = self.attn(q.unsqueeze(0).expand(N, -1, -1), x, x, need_weights=False) 220 | # print("attn_outputs", attn_outputs.shape) 221 | out = attn_outputs 222 | return out 223 | 224 | 225 | 226 | class PureAttentionPoolingBlock(nn.Module): 227 | """ 228 | Just a pure attn_pooling implementation, without ln_post, without projection, no mormalized_final 229 | """ 230 | 231 | def __init__( 232 | self, 233 | context_dim: int, 234 | n_head: int = 8, 235 | norm_layer: Callable = LayerNorm, 236 | need_weights: bool = False 237 | ): 238 | super().__init__() 239 | self.attn = nn.MultiheadAttention(context_dim, n_head, kdim=context_dim, vdim=context_dim, batch_first=True, 240 | add_zero_attn=True) 241 | #self.attn = nn.MultiheadAttention(context_dim, n_head, kdim=context_dim, vdim=context_dim, batch_first=True) 242 | self.ln_q = norm_layer(context_dim) 243 | self.ln_k = norm_layer(context_dim) 244 | self.ln_v = norm_layer(context_dim) 245 | self.need_weights=need_weights 246 | 247 | def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output_attn_weights=False, average_attn_weights=True): 248 | batch_size, seg_length, embed_dim = k.size() 249 | q = self.ln_q(q) 250 | k = self.ln_k(k) 251 | v = self.ln_v(v) 252 | 253 | if self.need_weights or output_attn_weights: 254 | out, attn_weights = self.attn(q, k, v, need_weights=True, average_attn_weights=average_attn_weights) 255 | return out, attn_weights 256 | else: 257 | out = self.attn(q, k, v, need_weights=False)[0] 258 | # we can directly normalize the output, without setting a flag 259 | #return F.normalize(out, dim=-1) 260 | return out 261 | 262 | 263 | class VisionPostProcess(nn.Module): 264 | def __init__( 265 | self, 266 | context_dim: int, 267 | output_dim, 268 | norm_layer: Callable = LayerNorm, 269 | normalize_final: bool = True, 270 | skip_ln = False 271 | ): 272 | super().__init__() 273 | self.skip_ln = skip_ln 274 | 
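# ---------------------------------------------------------------------------
# [Editor's sketch] Shape check for PureAttentionPoolingBlock above: queries
# are text tokens, keys/values are local image tokens, and the output keeps
# the query shape, the same (q, k, v) interface the visual_proj(...) calls in
# train.py use. Assumes src/ is importable and the repo's requirements are
# installed; all sizes below are illustrative.
import torch
from flair.transformer import PureAttentionPoolingBlock

pool = PureAttentionPoolingBlock(context_dim=64, n_head=8)
text_queries = torch.randn(2, 5, 64)    # 2 images x 5 text queries x dim
image_tokens = torch.randn(2, 197, 64)  # 2 images x 197 local tokens x dim
print(pool(text_queries, image_tokens, image_tokens).shape)  # torch.Size([2, 5, 64])
# ---------------------------------------------------------------------------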
self.scale = context_dim ** -0.5 275 | self.proj = nn.Parameter(self.scale * torch.randn(context_dim, output_dim)) 276 | if not self.skip_ln: 277 | self.ln_post = norm_layer(context_dim) 278 | self.normalize_final = normalize_final 279 | 280 | def forward(self, x): 281 | if self.skip_ln: 282 | out = x @ self.proj 283 | else: 284 | out = self.ln_post(x) @ self.proj 285 | 286 | if self.normalize_final: 287 | return F.normalize(out, dim=-1) 288 | else: 289 | return out 290 | 291 | 292 | class TextPostProcess(nn.Module): 293 | def __init__( 294 | self, 295 | context_dim: int, 296 | output_dim, 297 | norm_layer: Callable = LayerNorm, 298 | normalize_final: bool = True, 299 | skip_ln: bool = False 300 | ): 301 | super().__init__() 302 | self.skip_ln = skip_ln 303 | self.scale = context_dim ** -0.5 304 | self.proj = nn.Parameter(self.scale * torch.randn(context_dim, output_dim)) 305 | if not self.skip_ln: 306 | self.ln_post = norm_layer(context_dim) 307 | self.normalize_final = normalize_final 308 | 309 | def forward(self, x): 310 | if self.skip_ln: 311 | out = x @ self.proj 312 | else: 313 | out = self.ln_post(x) @ self.proj 314 | 315 | if self.normalize_final: 316 | return F.normalize(out, dim=-1) 317 | else: 318 | return out 319 | 320 | 321 | class ResidualAttentionBlock(nn.Module): 322 | def __init__( 323 | self, 324 | d_model: int, 325 | n_head: int, 326 | mlp_ratio: float = 4.0, 327 | ls_init_value: float = None, 328 | act_layer: Callable = nn.GELU, 329 | norm_layer: Callable = LayerNorm, 330 | is_cross_attention: bool = False, 331 | batch_first: bool = True, 332 | ): 333 | super().__init__() 334 | 335 | self.ln_1 = norm_layer(d_model) 336 | self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=batch_first) 337 | self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() 338 | if is_cross_attention: 339 | self.ln_1_kv = norm_layer(d_model) 340 | 341 | self.ln_2 = norm_layer(d_model) 342 | mlp_width = int(d_model * mlp_ratio) 343 | self.mlp = nn.Sequential(OrderedDict([ 344 | ("c_fc", nn.Linear(d_model, mlp_width)), 345 | ("gelu", act_layer()), 346 | ("c_proj", nn.Linear(mlp_width, d_model)) 347 | ])) 348 | self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() 349 | 350 | def attention( 351 | self, 352 | q_x: torch.Tensor, 353 | k_x: Optional[torch.Tensor] = None, 354 | v_x: Optional[torch.Tensor] = None, 355 | attn_mask: Optional[torch.Tensor] = None, 356 | ): 357 | k_x = k_x if k_x is not None else q_x 358 | v_x = v_x if v_x is not None else q_x 359 | 360 | attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None 361 | return self.attn( 362 | q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask 363 | )[0] 364 | 365 | def forward( 366 | self, 367 | q_x: torch.Tensor, 368 | k_x: Optional[torch.Tensor] = None, 369 | v_x: Optional[torch.Tensor] = None, 370 | attn_mask: Optional[torch.Tensor] = None, 371 | ): 372 | k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None 373 | v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None 374 | x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)) 375 | x = x + self.ls_2(self.mlp(self.ln_2(x))) 376 | return x 377 | 378 | 379 | class CustomResidualAttentionBlock(nn.Module): 380 | def __init__( 381 | self, 382 | d_model: int, 383 | n_head: int, 384 | mlp_ratio: float = 4.0, 385 | ls_init_value: float = None, 386 | act_layer: Callable = nn.GELU, 387 | norm_layer: 
Callable = LayerNorm, 388 | scale_cosine_attn: bool = False, 389 | scale_heads: bool = False, 390 | scale_attn: bool = False, 391 | scale_fc: bool = False, 392 | batch_first: bool = True, 393 | ): 394 | super().__init__() 395 | 396 | self.ln_1 = norm_layer(d_model) 397 | self.attn = Attention( 398 | d_model, 399 | n_head, 400 | scaled_cosine=scale_cosine_attn, 401 | scale_heads=scale_heads, 402 | batch_first=batch_first, 403 | ) 404 | self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity() 405 | self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() 406 | 407 | self.ln_2 = norm_layer(d_model) 408 | mlp_width = int(d_model * mlp_ratio) 409 | self.mlp = nn.Sequential(OrderedDict([ 410 | ("c_fc", nn.Linear(d_model, mlp_width)), 411 | ("gelu", act_layer()), 412 | ('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()), 413 | ("c_proj", nn.Linear(mlp_width, d_model)) 414 | ])) 415 | self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() 416 | 417 | def get_reference_weight(self): 418 | return self.mlp.c_fc.weight 419 | 420 | def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): 421 | x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask))) 422 | x = x + self.ls_2(self.mlp(self.ln_2(x))) 423 | return x 424 | 425 | 426 | def _expand_token(token, batch_size: int): 427 | return token.view(1, 1, -1).expand(batch_size, -1, -1) 428 | 429 | 430 | class Transformer(nn.Module): 431 | def __init__( 432 | self, 433 | width: int, 434 | layers: int, 435 | heads: int, 436 | mlp_ratio: float = 4.0, 437 | ls_init_value: float = None, 438 | act_layer: Callable = nn.GELU, 439 | norm_layer: Callable = LayerNorm, 440 | batch_first: bool = True, 441 | ): 442 | super().__init__() 443 | self.width = width 444 | self.layers = layers 445 | self.batch_first = batch_first 446 | self.grad_checkpointing = False 447 | 448 | self.resblocks = nn.ModuleList([ 449 | ResidualAttentionBlock( 450 | width, 451 | heads, 452 | mlp_ratio, 453 | ls_init_value=ls_init_value, 454 | act_layer=act_layer, 455 | norm_layer=norm_layer, 456 | batch_first=batch_first, 457 | ) 458 | for _ in range(layers) 459 | ]) 460 | 461 | def get_cast_dtype(self) -> torch.dtype: 462 | if hasattr(self.resblocks[0].mlp.c_fc, 'int8_original_dtype'): 463 | return self.resblocks[0].mlp.c_fc.int8_original_dtype 464 | return self.resblocks[0].mlp.c_fc.weight.dtype 465 | 466 | def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): 467 | if not self.batch_first: 468 | x = x.transpose(0, 1).contiguous() # NLD -> LND 469 | for r in self.resblocks: 470 | if self.grad_checkpointing and not torch.jit.is_scripting(): 471 | # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 472 | x = checkpoint(r, x, None, None, attn_mask) 473 | else: 474 | x = r(x, attn_mask=attn_mask) 475 | if not self.batch_first: 476 | x = x.transpose(0, 1) # LND -> NLD 477 | return x 478 | 479 | 480 | class VisionTransformer(nn.Module): 481 | output_tokens: torch.jit.Final[bool] 482 | 483 | def __init__( 484 | self, 485 | image_size: int, 486 | patch_size: int, 487 | width: int, 488 | layers: int, 489 | heads: int, 490 | mlp_ratio: float, 491 | ls_init_value: float = None, 492 | attentional_pool: bool = False, 493 | attn_pooler_queries: int = 256, 494 | attn_pooler_heads: int = 8, 495 | output_dim: int = 512, 496 | patch_dropout: float = 0., 497 | no_ln_pre: bool = False, 498 | pos_embed_type: str = 
'learnable', 499 | pool_type: str = 'tok', 500 | final_ln_after_pool: bool = False, 501 | act_layer: Callable = nn.GELU, 502 | norm_layer: Callable = LayerNorm, 503 | output_tokens: bool = False, 504 | project_tokens: bool = False, 505 | text_con: bool = False, 506 | skip_final_pooling: bool = False 507 | ): 508 | super().__init__() 509 | assert pool_type in ('tok', 'avg', 'none') 510 | # Additional Assertions 511 | 512 | self.output_tokens = output_tokens 513 | image_height, image_width = self.image_size = to_2tuple(image_size) 514 | patch_height, patch_width = self.patch_size = to_2tuple(patch_size) 515 | self.grid_size = (image_height // patch_height, image_width // patch_width) 516 | self.final_ln_after_pool = final_ln_after_pool # currently ignored w/ attn pool enabled 517 | self.output_dim = output_dim 518 | self.width = width #added, to indicate the dim before pooling 519 | self.skip_final_pooling = skip_final_pooling 520 | 521 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 522 | 523 | # class embeddings and positional embeddings 524 | scale = width ** -0.5 525 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 526 | if pos_embed_type == 'learnable': 527 | self.positional_embedding = nn.Parameter( 528 | scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width)) 529 | elif pos_embed_type == 'sin_cos_2d': 530 | # fixed sin-cos embedding 531 | assert self.grid_size[0] == self.grid_size[1], \ 532 | 'currently sin cos 2d pos embedding only supports square input' 533 | self.positional_embedding = nn.Parameter( 534 | torch.zeros(self.grid_size[0] * self.grid_size[1] + 1, width), requires_grad=False) 535 | pos_embed_type = get_2d_sincos_pos_embed(width, self.grid_size[0], cls_token=True) 536 | self.positional_embedding.data.copy_(torch.from_numpy(pos_embed_type).float()) 537 | else: 538 | raise ValueError 539 | 540 | # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn 541 | self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. 
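# ---------------------------------------------------------------------------
# [Editor's sketch] The geometry set up above, made concrete assuming the
# standard ViT-B/16 configuration (224 px images, 16 px patches, width 768,
# which matches the (128, 197, 768) shapes noted in the pooler comments above):
image_size, patch_size = 224, 16
grid = image_size // patch_size   # 14
num_tokens = grid * grid + 1      # 197 rows in positional_embedding (196 patches + CLS)
print(grid, num_tokens)           # 14 197
# ---------------------------------------------------------------------------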
else nn.Identity() 542 | 543 | self.ln_pre = nn.Identity() if no_ln_pre else norm_layer(width) 544 | self.transformer = Transformer( 545 | width, 546 | layers, 547 | heads, 548 | mlp_ratio, 549 | ls_init_value=ls_init_value, 550 | act_layer=act_layer, 551 | norm_layer=norm_layer, 552 | ) 553 | self.attn_pool_type = None 554 | self.proj_tokens = project_tokens 555 | self.text_con = text_con 556 | 557 | if attentional_pool: 558 | if isinstance(attentional_pool, str): 559 | self.attn_pool_type = attentional_pool 560 | self.pool_type = 'none' 561 | if attentional_pool in ('parallel', 'cascade'): 562 | self.attn_pool = AttentionalPooler( 563 | output_dim, 564 | width, 565 | n_head=attn_pooler_heads, 566 | n_queries=attn_pooler_queries, 567 | ) 568 | self.attn_pool_contrastive = AttentionalPooler( 569 | output_dim, 570 | width, 571 | n_head=attn_pooler_heads, 572 | n_queries=1, 573 | ) 574 | self.attn_pool_contrastive = None 575 | pool_dim = width 576 | else: 577 | assert False 578 | else: 579 | self.attn_pool_type = '' 580 | #self.pool_type = pool_type 581 | self.pool_type = 'none' 582 | self.attn_pool = AttentionalPooler( 583 | output_dim, 584 | width, 585 | n_head=attn_pooler_heads, 586 | n_queries=1, #origin: attn_pooler_queries 587 | ) 588 | self.attn_pool_contrastive = None 589 | pool_dim = output_dim 590 | else: 591 | self.attn_pool = None 592 | pool_dim = width 593 | self.pool_type = pool_type 594 | 595 | if self.text_con: 596 | self.ln_post = norm_layer(pool_dim) 597 | else: 598 | self.ln_post = norm_layer(pool_dim) 599 | self.proj = nn.Parameter(scale * torch.randn(pool_dim, output_dim)) 600 | self.init_parameters() 601 | 602 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 603 | for param in self.parameters(): 604 | param.requires_grad = False 605 | 606 | if unlocked_groups != 0: 607 | groups = [ 608 | [ 609 | self.conv1, 610 | self.class_embedding, 611 | self.positional_embedding, 612 | self.ln_pre, 613 | ], 614 | *self.transformer.resblocks[:-1], 615 | [ 616 | self.transformer.resblocks[-1], 617 | self.ln_post, 618 | ], 619 | self.proj, 620 | ] 621 | 622 | def _unlock(x): 623 | if isinstance(x, Sequence): 624 | for g in x: 625 | _unlock(g) 626 | else: 627 | if isinstance(x, torch.nn.Parameter): 628 | x.requires_grad = True 629 | else: 630 | for p in x.parameters(): 631 | p.requires_grad = True 632 | 633 | _unlock(groups[-unlocked_groups:]) 634 | 635 | def init_parameters(self): 636 | # FIXME OpenAI CLIP did not define an init for the VisualTransformer 637 | # TODO experiment if default PyTorch init, below, or alternate init is best. 
638 | 639 | # nn.init.normal_(self.class_embedding, std=self.scale) 640 | # nn.init.normal_(self.positional_embedding, std=self.scale) 641 | # 642 | # proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 643 | # attn_std = self.transformer.width ** -0.5 644 | # fc_std = (2 * self.transformer.width) ** -0.5 645 | # for block in self.transformer.resblocks: 646 | # nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 647 | # nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 648 | # nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 649 | # nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 650 | # 651 | # if self.text_projection is not None: 652 | # nn.init.normal_(self.text_projection, std=self.scale) 653 | pass 654 | 655 | @torch.jit.ignore 656 | def set_grad_checkpointing(self, enable=True): 657 | self.transformer.grad_checkpointing = enable 658 | 659 | def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 660 | if self.pool_type == 'avg': 661 | pooled, tokens = x[:, 1:].mean(dim=1), x[:, 1:] 662 | elif self.pool_type == 'tok': 663 | pooled, tokens = x[:, 0], x[:, 1:] 664 | else: 665 | pooled = tokens = x 666 | 667 | return pooled, tokens 668 | 669 | def forward(self, x: torch.Tensor, return_all=False): 670 | x = self.conv1(x) # shape = [*, width, grid, grid] 671 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 672 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 673 | 674 | # class embeddings and positional embeddings 675 | x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1) 676 | # shape = [*, grid ** 2 + 1, width] 677 | x = x + self.positional_embedding.to(x.dtype) 678 | 679 | x = self.patch_dropout(x) 680 | x = self.ln_pre(x) 681 | x = self.transformer(x) 682 | 683 | if return_all: 684 | return x #directly returning the output after transformer, mainly used for segmentation 685 | 686 | if self.attn_pool is not None: # for FLAIR, self.attn_pool is set to be 'None', because we are not using the implementation from open_clip 687 | if self.attn_pool_contrastive is not None: 688 | # This is untested, WIP pooling that should match paper 689 | x = self.ln_post(x) # TBD LN first or separate one after each pool? 
690 | tokens = self.attn_pool(x) 691 | if self.attn_pool_type == 'parallel': 692 | pooled = self.attn_pool_contrastive(x) 693 | else: 694 | assert self.attn_pool_type == 'cascade' 695 | pooled = self.attn_pool_contrastive(tokens) 696 | elif self.attn_pool_type == 'base_attn_pool' or 'learnable_query': 697 | x_mean = x.mean(dim=1, keepdims=True) #in learnable query, the input q does not matter 698 | pooled = self.attn_pool(x_mean, x, x).squeeze(1) 699 | else: 700 | # this is the original OpenCLIP CoCa setup, does not match paper 701 | x = self.attn_pool(x) 702 | # print("shape after attn_pool:", x.shape) 703 | x = self.ln_post(x) # ln_post() originbally after self.attn_pool() 704 | pooled, tokens = self._global_pool(x) 705 | # print("shape of pooled", pooled.shape) 706 | pooled = pooled.squeeze(1) 707 | # print("shape of pooled after squeeze", pooled.shape) 708 | elif self.text_con: 709 | x = self.ln_post(x) 710 | if self.skip_final_pooling: 711 | pooled = None 712 | tokens = x 713 | else: 714 | pooled, tokens = self._global_pool(x) 715 | return pooled, tokens 716 | else: 717 | if self.final_ln_after_pool: 718 | pooled, tokens = self._global_pool(x) 719 | pooled = self.ln_post(pooled) 720 | else: 721 | x = self.ln_post(x) 722 | pooled, tokens = self._global_pool(x) 723 | 724 | if self.proj is not None: 725 | pooled = pooled @ self.proj 726 | # print("shape after projection:", pooled.shape) 727 | if self.output_tokens: 728 | if self.proj_tokens: 729 | tokens = tokens @ self.proj 730 | return pooled, tokens 731 | return pooled 732 | 733 | 734 | def text_global_pool(x, text: Optional[torch.Tensor] = None, pool_type: str = 'argmax'): 735 | if pool_type == 'first': 736 | pooled, tokens = x[:, 0], x[:, 1:] 737 | elif pool_type == 'last': 738 | pooled, tokens = x[:, -1], x[:, :-1] 739 | elif pool_type == 'argmax': 740 | # take features from the eot embedding (eot_token is the highest number in each sequence) 741 | assert text is not None 742 | pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x 743 | else: 744 | pooled = tokens = x 745 | 746 | return pooled, tokens 747 | 748 | 749 | class TextTransformer(nn.Module): 750 | output_tokens: torch.jit.Final[bool] 751 | 752 | def __init__( 753 | self, 754 | context_length: int = 77, 755 | vocab_size: int = 49408, 756 | width: int = 512, 757 | heads: int = 8, 758 | layers: int = 12, 759 | mlp_ratio: float = 4.0, 760 | ls_init_value: float = None, 761 | output_dim: int = 512, 762 | embed_cls: bool = False, 763 | no_causal_mask: bool = False, 764 | pad_id: int = 0, 765 | pool_type: str = 'argmax', 766 | proj_bias: bool = False, 767 | act_layer: Callable = nn.GELU, 768 | norm_layer: Callable = LayerNorm, 769 | output_tokens: bool = False 770 | ): 771 | super().__init__() 772 | assert pool_type in ('first', 'last', 'argmax', 'none') 773 | self.output_tokens = output_tokens 774 | self.num_pos = self.context_length = context_length 775 | self.vocab_size = vocab_size 776 | self.width = width 777 | self.output_dim = output_dim 778 | self.heads = heads 779 | self.pad_id = pad_id 780 | self.pool_type = pool_type 781 | 782 | self.token_embedding = nn.Embedding(vocab_size, width) 783 | if embed_cls: 784 | self.cls_emb = nn.Parameter(torch.empty(width)) 785 | self.num_pos += 1 786 | else: 787 | self.cls_emb = None 788 | self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width)) 789 | self.transformer = Transformer( 790 | width=width, 791 | layers=layers, 792 | heads=heads, 793 | mlp_ratio=mlp_ratio, 794 | ls_init_value=ls_init_value, 
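# ---------------------------------------------------------------------------
# [Editor's sketch] text_global_pool(..., pool_type='argmax') above pools at
# the end-of-text token, which the CLIP BPE tokenizer encodes as the largest
# id in each sequence (vocab_size 49408, EOT id 49407). Standalone check of
# the indexing trick:
import torch

x = torch.arange(2 * 4 * 3, dtype=torch.float32).view(2, 4, 3)  # (batch, seq, dim)
text = torch.tensor([[49406, 320, 49407, 0],   # EOT at position 2
                     [49406, 49407, 0, 0]])    # EOT at position 1
pooled = x[torch.arange(x.shape[0]), text.argmax(dim=-1)]
print(text.argmax(dim=-1))  # tensor([2, 1])
print(pooled.shape)         # torch.Size([2, 3])
# ---------------------------------------------------------------------------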
795 | act_layer=act_layer, 796 | norm_layer=norm_layer, 797 | ) 798 | self.ln_final = norm_layer(width) 799 | 800 | if no_causal_mask: 801 | self.attn_mask = None 802 | else: 803 | self.register_buffer('attn_mask', self.build_causal_mask(), persistent=False) 804 | 805 | if proj_bias: 806 | self.text_projection = nn.Linear(width, output_dim) 807 | else: 808 | self.text_projection = nn.Parameter(torch.empty(width, output_dim)) 809 | 810 | self.init_parameters() 811 | 812 | def init_parameters(self): 813 | nn.init.normal_(self.token_embedding.weight, std=0.02) 814 | nn.init.normal_(self.positional_embedding, std=0.01) 815 | if self.cls_emb is not None: 816 | nn.init.normal_(self.cls_emb, std=0.01) 817 | 818 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 819 | attn_std = self.transformer.width ** -0.5 820 | fc_std = (2 * self.transformer.width) ** -0.5 821 | for block in self.transformer.resblocks: 822 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 823 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 824 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 825 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 826 | 827 | if self.text_projection is not None: 828 | if isinstance(self.text_projection, nn.Linear): 829 | nn.init.normal_(self.text_projection.weight, std=self.transformer.width ** -0.5) 830 | if self.text_projection.bias is not None: 831 | nn.init.zeros_(self.text_projection.bias) 832 | else: 833 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 834 | 835 | @torch.jit.ignore 836 | def set_grad_checkpointing(self, enable=True): 837 | self.transformer.grad_checkpointing = enable 838 | 839 | def build_causal_mask(self): 840 | # lazily create causal attention mask, with full attention between the tokens 841 | # pytorch uses additive attention mask; fill with -inf 842 | mask = torch.empty(self.num_pos, self.num_pos) 843 | mask.fill_(float("-inf")) 844 | mask.triu_(1) # zero out the lower diagonal 845 | return mask 846 | 847 | def build_cls_mask(self, text, cast_dtype: torch.dtype): 848 | cls_mask = (text != self.pad_id).unsqueeze(1) 849 | cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True) 850 | additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device) 851 | additive_mask.fill_(0) 852 | additive_mask.masked_fill_(~cls_mask, float("-inf")) 853 | additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0) 854 | return additive_mask 855 | 856 | def forward(self, text): 857 | cast_dtype = self.transformer.get_cast_dtype() 858 | seq_len = text.shape[1] 859 | 860 | x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] 861 | attn_mask = self.attn_mask 862 | if self.cls_emb is not None: 863 | seq_len += 1 864 | x = torch.cat([x, _expand_token(self.cls_emb, x.shape[0])], dim=1) 865 | cls_mask = self.build_cls_mask(text, cast_dtype) 866 | if attn_mask is not None: 867 | attn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len] 868 | 869 | x = x + self.positional_embedding[:seq_len].to(cast_dtype) 870 | x = self.transformer(x, attn_mask=attn_mask) 871 | # x.shape = [batch_size, n_ctx, transformer.width] 872 | if self.cls_emb is not None: 873 | # presence of appended cls embed (CoCa) overrides pool_type, always take last token 874 | pooled, tokens = text_global_pool(x, pool_type='last') 875 | pooled = self.ln_final(pooled) # final LN applied after pooling in this case 876 | else: 877 | x = self.ln_final(x) 
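# ---------------------------------------------------------------------------
# [Editor's sketch] build_causal_mask() above builds the standard additive
# causal mask: 0 on and below the diagonal, -inf strictly above. For 4 tokens:
import torch

mask = torch.empty(4, 4).fill_(float("-inf")).triu_(1)
print(mask)
# tensor([[0., -inf, -inf, -inf],
#         [0.,   0., -inf, -inf],
#         [0.,   0.,   0., -inf],
#         [0.,   0.,   0.,   0.]])
# ---------------------------------------------------------------------------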
878 | pooled, tokens = text_global_pool(x, text, pool_type=self.pool_type) 879 | if self.text_projection is not None: 880 | if isinstance(self.text_projection, nn.Linear): 881 | pooled = self.text_projection(pooled) 882 | else: 883 | pooled = pooled @ self.text_projection 884 | 885 | if self.output_tokens: 886 | return pooled, tokens 887 | else: 888 | return pooled 889 | 890 | -------------------------------------------------------------------------------- /src/inference.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node 1 -m main \ 2 | --model ViT-B-16-FLAIR \ 3 | --huggingface-repo-name xiaorui638/flair \ 4 | --huggingface-model-name flair-cc3m-recap.pt \ 5 | --inference-with-flair \ 6 | --coco-data-root-dir ./datasets/coco \ 7 | --flickr-data-root-dir ./datasets/flickr30k-images \ 8 | --iiw-retrieval-dir ./datasets/imageinwords/ \ 9 | --docci-retrieval-dir ./datasets/docci \ 10 | --urban-1k-retrieval-dir ./datasets/Urban1k \ 11 | --sharegpt4v-retrieval-dir ./datasets/share4v \ 12 | --retrieval-coco \ 13 | --retrieval-flickr \ 14 | --retrieval-docci \ 15 | --retrieval-iiw \ 16 | --retrieval-urban-1k \ 17 | --retrieval-sharegpt4v-1k \ 18 | --retrieval-sharegpt4v-10k \ 19 | --batch-size 128 \ 20 | --precision amp \ 21 | --workers 25 \ 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import logging 3 | import os 4 | import re 5 | import subprocess 6 | import sys 7 | import random 8 | from datetime import datetime 9 | import numpy as np 10 | import torch 11 | from torch import optim 12 | from torch.cuda.amp import GradScaler 13 | from huggingface_hub import hf_hub_download 14 | 15 | try: 16 | import wandb 17 | except ImportError: 18 | wandb = None 19 | 20 | try: 21 | import torch.utils.tensorboard as tensorboard 22 | except ImportError: 23 | tensorboard = None 24 | 25 | try: 26 | import horovod.torch as hvd 27 | except ImportError: 28 | hvd = None 29 | 30 | from open_clip_train.distributed import is_master, init_distributed_device, broadcast_object 31 | from open_clip_train.logger import setup_logging 32 | from open_clip_train.scheduler import cosine_lr, const_lr, const_lr_cooldown 33 | from open_clip_train.file_utils import pt_load, check_exists, start_sync_process, remote_sync 34 | 35 | from flair.params import parse_args 36 | from flair.factory import create_model_and_transforms, get_tokenizer, create_loss 37 | from flair.train import train_one_epoch, evaluate 38 | from flair.data import get_data 39 | 40 | LATEST_CHECKPOINT_NAME = "epoch_latest.pt" 41 | 42 | 43 | def random_seed(seed=42, rank=0): 44 | torch.manual_seed(seed + rank) 45 | np.random.seed(seed + rank) 46 | random.seed(seed + rank) 47 | 48 | 49 | def natural_key(string_): 50 | """See http://www.codinghorror.com/blog/archives/001018.html""" 51 | return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] 52 | 53 | 54 | def get_latest_checkpoint(path: str, remote: bool): 55 | # as writen, this glob recurses, so can pick up checkpoints across multiple sub-folders 56 | if remote: 57 | result = subprocess.run(["aws", "s3", "ls", path + "/"], stdout=subprocess.PIPE, stderr=subprocess.PIPE) 58 | print(result) 59 | if result.returncode == 1: 60 | return None 61 | checkpoints = [os.path.join(path, x.split(' ')[-1]) for x in result.stdout.decode().split('\n')[:-1]] 62 
| else: 63 | checkpoints = glob.glob(path + '**/*.pt', recursive=True) 64 | if checkpoints: 65 | checkpoints = sorted(checkpoints, key=natural_key) 66 | return checkpoints[-1] 67 | return None 68 | 69 | 70 | def download_weights_from_hf(model_repo, filename): 71 | # Define the custom cache directory relative to the current script 72 | cache_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "pretrained") 73 | if not os.path.exists(cache_dir): 74 | os.makedirs(cache_dir, exist_ok=True) 75 | local_path = hf_hub_download(repo_id=model_repo, filename=filename, cache_dir=cache_dir) 76 | return local_path 77 | 78 | 79 | def main(args): 80 | args = parse_args(args) 81 | if torch.cuda.is_available(): 82 | # This enables tf32 on Ampere GPUs which is only 8% slower than 83 | # float16 and almost as accurate as float32 84 | # This was a default in pytorch until 1.12 85 | torch.backends.cuda.matmul.allow_tf32 = True 86 | torch.backends.cudnn.benchmark = True 87 | torch.backends.cudnn.deterministic = False 88 | 89 | # fully initialize distributed device environment 90 | device = init_distributed_device(args) 91 | 92 | # get the name of the experiment 93 | if args.name is None: 94 | # sanitize model name for filesystem / uri use, easier if we don't use / in name as a rule? 95 | model_name_safe = args.model.replace('/', '-') 96 | date_str = datetime.now().strftime("%Y_%m_%d-%H_%M_%S") 97 | if args.distributed: 98 | # sync date_str from master to all ranks 99 | date_str = broadcast_object(args, date_str) 100 | args.name = '-'.join([ 101 | date_str, 102 | f"model_{model_name_safe}", 103 | f"lr_{args.lr}", 104 | f"b_{args.batch_size}", 105 | f"j_{args.workers}", 106 | f"p_{args.precision}", 107 | ]) 108 | 109 | resume_latest = args.resume == 'latest' 110 | log_base_path = os.path.join(args.logs_dir, args.name) 111 | args.log_path = None 112 | if is_master(args, local=args.log_local): 113 | os.makedirs(log_base_path, exist_ok=True) 114 | log_filename = f'out-{args.rank}' if args.log_local else 'out.log' 115 | args.log_path = os.path.join(log_base_path, log_filename) 116 | if os.path.exists(args.log_path) and not resume_latest: 117 | print( 118 | "Error. Experiment already exists. Use --name to specify a new experiment." 119 | ) 120 | return -1 121 | 122 | # Setup text logger 123 | args.log_level = logging.DEBUG if args.debug else logging.INFO 124 | setup_logging(args.log_path, args.log_level) 125 | 126 | # Setup wandb, tensorboard, checkpoint logging 127 | args.wandb = 'wandb' in args.report_to or 'all' in args.report_to 128 | args.tensorboard = 'tensorboard' in args.report_to or 'all' in args.report_to 129 | args.checkpoint_path = os.path.join(log_base_path, "checkpoints") 130 | if is_master(args): 131 | args.tensorboard_path = os.path.join(log_base_path, "tensorboard") if args.tensorboard else '' 132 | for dirname in [args.tensorboard_path, args.checkpoint_path]: 133 | if dirname: 134 | os.makedirs(dirname, exist_ok=True) 135 | else: 136 | args.tensorboard_path = '' 137 | 138 | if resume_latest: 139 | resume_from = None 140 | checkpoint_path = args.checkpoint_path 141 | # If using remote_sync, need to check the remote instead of the local checkpoints folder. 142 | if args.remote_sync is not None: 143 | checkpoint_path = os.path.join(args.remote_sync, args.name, "checkpoints") 144 | if args.save_most_recent: 145 | print('Error. Cannot use save-most-recent with remote_sync and resume latest.') 146 | return -1 147 | if args.remote_sync_protocol != 's3': 148 | print('Error. 
Sync protocol not supported when using resume latest.') 149 | return -1 150 | if is_master(args): 151 | # Checking for existing checkpoint via master rank only. It is possible for 152 | # different rank processes to see different files if a shared file-system is under 153 | # stress, however it's very difficult to fully work around such situations. 154 | if args.save_most_recent: 155 | # if --save-most-recent flag is set, look for latest at a fixed filename 156 | resume_from = os.path.join(checkpoint_path, LATEST_CHECKPOINT_NAME) 157 | if not os.path.exists(resume_from): 158 | # If no latest checkpoint has been saved yet, don't try to resume 159 | resume_from = None 160 | else: 161 | # otherwise, list checkpoint dir contents and pick the newest checkpoint 162 | resume_from = get_latest_checkpoint(checkpoint_path, remote=args.remote_sync is not None) 163 | if resume_from: 164 | logging.info(f'Found latest resume checkpoint at {resume_from}.') 165 | else: 166 | logging.info(f'No latest resume checkpoint found in {checkpoint_path}.') 167 | if args.distributed: 168 | # sync found checkpoint path to all ranks 169 | resume_from = broadcast_object(args, resume_from) 170 | args.resume = resume_from 171 | 172 | if args.copy_codebase: 173 | copy_codebase(args) 174 | 175 | # start the sync process if remote-sync is not None 176 | remote_sync_process = None 177 | if is_master(args) and args.remote_sync is not None: 178 | # first make sure it works 179 | result = remote_sync( 180 | os.path.join(args.logs_dir, args.name), 181 | os.path.join(args.remote_sync, args.name), 182 | args.remote_sync_protocol 183 | ) 184 | if result: 185 | logging.info('remote sync successful.') 186 | else: 187 | logging.info('Error: remote sync failed. Exiting.') 188 | return -1 189 | # if all looks good, start a process to do this every args.remote_sync_frequency seconds 190 | remote_sync_process = start_sync_process( 191 | args.remote_sync_frequency, 192 | os.path.join(args.logs_dir, args.name), 193 | os.path.join(args.remote_sync, args.name), 194 | args.remote_sync_protocol 195 | ) 196 | remote_sync_process.start() 197 | 198 | if args.precision == 'fp16': 199 | logging.warning( 200 | 'It is recommended to use AMP mixed-precision instead of FP16. ' 201 | 'FP16 support needs further verification and tuning, especially for train.') 202 | 203 | if args.horovod: 204 | logging.info( 205 | f'Running in horovod mode with multiple processes / nodes. Device: {args.device}.' 206 | f'Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}.') 207 | elif args.distributed: 208 | logging.info( 209 | f'Running in distributed mode with multiple processes. Device: {args.device}.' 210 | f'Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}.') 211 | else: 212 | logging.info(f'Running with a single process. 
Device {args.device}.') 213 | 214 | dist_model = None 215 | 216 | if isinstance(args.force_image_size, (tuple, list)) and len(args.force_image_size) == 1: 217 | # arg is nargs, single (square) image size list -> int 218 | args.force_image_size = args.force_image_size[0] 219 | 220 | random_seed(args.seed, 0) 221 | model_kwargs = {} 222 | 223 | # based on huggingface model name, download the pre-trained weights, the downloaded path is passed as the 'pretrained' arguments 224 | if args.huggingface_model_name != '': 225 | huggingface_model_name, huggingface_repo_name = args.huggingface_model_name, args.huggingface_repo_name 226 | args.pretrained = download_weights_from_hf(model_repo=huggingface_repo_name, filename=huggingface_model_name) 227 | 228 | model, preprocess_train, preprocess_val = create_model_and_transforms( 229 | args.model, 230 | args.pretrained, 231 | precision=args.precision, 232 | device=device, 233 | jit=args.torchscript, 234 | force_quick_gelu=args.force_quick_gelu, 235 | force_custom_text=args.force_custom_text, 236 | force_patch_dropout=args.force_patch_dropout, 237 | force_image_size=args.force_image_size, 238 | image_mean=args.image_mean, 239 | image_std=args.image_std, 240 | image_interpolation=args.image_interpolation, 241 | image_resize_mode=args.image_resize_mode, # only effective for inference 242 | aug_cfg=args.aug_cfg, 243 | pretrained_image=args.pretrained_image, 244 | output_dict=True, 245 | **model_kwargs, 246 | ) 247 | model.to(device) 248 | random_seed(args.seed, args.rank) 249 | 250 | if args.grad_checkpointing: 251 | model.set_grad_checkpointing() 252 | 253 | if is_master(args): 254 | logging.info("Model:") 255 | logging.info(f"{str(model)}") 256 | logging.info("Params:") 257 | params_file = os.path.join(args.logs_dir, args.name, "params.txt") 258 | with open(params_file, "w") as f: 259 | for name in sorted(vars(args)): 260 | val = getattr(args, name) 261 | logging.info(f" {name}: {val}") 262 | f.write(f"{name}: {val}\n") 263 | 264 | if args.distributed and not args.horovod: 265 | if args.use_bn_sync: 266 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 267 | ddp_args = {} 268 | if args.ddp_static_graph: 269 | # this doesn't exist in older PyTorch, arg only added if enabled 270 | ddp_args['static_graph'] = True 271 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device], **ddp_args) 272 | 273 | # create optimizer and scaler 274 | optimizer = None 275 | scaler = None 276 | 277 | if args.train_data or args.dataset_type == "synthetic": 278 | assert not args.trace, 'Cannot train with traced model' 279 | 280 | exclude = lambda n, p: p.ndim < 2 or "bn" in n or "ln" in n or "bias" in n or 'logit_scale' in n 281 | include = lambda n, p: not exclude(n, p) 282 | 283 | named_parameters = list(model.named_parameters()) 284 | gain_or_bias_params = [p for n, p in named_parameters if exclude(n, p) and p.requires_grad] 285 | rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad] 286 | 287 | optimizer = optim.AdamW( 288 | [ 289 | {"params": gain_or_bias_params, "weight_decay": 0.}, 290 | {"params": rest_params, "weight_decay": args.wd}, 291 | ], 292 | lr=args.lr, 293 | betas=(args.beta1, args.beta2), 294 | eps=args.eps, 295 | ) 296 | if args.horovod: 297 | optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters()) 298 | hvd.broadcast_parameters(model.state_dict(), root_rank=0) 299 | hvd.broadcast_optimizer_state(optimizer, root_rank=0) 300 | 301 | scaler = GradScaler() if 
args.precision == "amp" else None 302 | 303 | # optionally resume from a checkpoint 304 | start_epoch = 0 305 | if args.resume is not None: 306 | checkpoint = pt_load(args.resume, map_location='cpu') 307 | if 'epoch' in checkpoint: 308 | # resuming a train checkpoint w/ epoch and optimizer state 309 | start_epoch = checkpoint["epoch"] 310 | sd = checkpoint["state_dict"] 311 | if not args.distributed and next(iter(sd.items()))[0].startswith('module'): 312 | sd = {k[len('module.'):]: v for k, v in sd.items()} 313 | model.load_state_dict(sd) 314 | if optimizer is not None: 315 | optimizer.load_state_dict(checkpoint["optimizer"]) 316 | if scaler is not None and 'scaler' in checkpoint: 317 | scaler.load_state_dict(checkpoint['scaler']) 318 | logging.info(f"=> resuming checkpoint '{args.resume}' (epoch {start_epoch})") 319 | else: 320 | # loading a bare (model only) checkpoint for fine-tune or evaluation 321 | model.load_state_dict(checkpoint) 322 | logging.info(f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})") 323 | 324 | # initialize datasets 325 | tokenizer = get_tokenizer(args.model) 326 | data = get_data( 327 | args, 328 | (preprocess_train, preprocess_val), 329 | epoch=start_epoch, 330 | tokenizer=tokenizer, 331 | ) 332 | assert len(data), 'At least one train or eval dataset must be specified.' 333 | 334 | # create scheduler if train 335 | scheduler = None 336 | if 'train' in data and optimizer is not None: 337 | total_steps = (data["train"].dataloader.num_batches // args.accum_freq) * args.epochs 338 | if args.lr_scheduler == "cosine": 339 | scheduler = cosine_lr(optimizer, args.lr, args.warmup, total_steps) 340 | elif args.lr_scheduler == "const": 341 | scheduler = const_lr(optimizer, args.lr, args.warmup, total_steps) 342 | elif args.lr_scheduler == "const-cooldown": 343 | assert args.epochs_cooldown is not None, \ 344 | "Please specify the number of cooldown epochs for this lr schedule." 345 | cooldown_steps = (data["train"].dataloader.num_batches // args.accum_freq) * args.epochs_cooldown 346 | scheduler = const_lr_cooldown( 347 | optimizer, args.lr, args.warmup, total_steps, 348 | cooldown_steps, args.lr_cooldown_power, args.lr_cooldown_end) 349 | else: 350 | logging.error( 351 | f'Unknown scheduler, {args.lr_scheduler}. Available options are: cosine, const, const-cooldown.') 352 | exit(1) 353 | 354 | # determine if this worker should save logs and checkpoints. only do so if it is rank == 0 355 | args.save_logs = args.logs_dir and args.logs_dir.lower() != 'none' and is_master(args) 356 | writer = None 357 | if args.save_logs and args.tensorboard: 358 | assert tensorboard is not None, "Please install tensorboard." 359 | writer = tensorboard.SummaryWriter(args.tensorboard_path) 360 | 361 | if args.wandb and is_master(args): 362 | assert wandb is not None, 'Please install wandb.' 363 | logging.debug('Starting wandb.') 364 | args.train_sz = data["train"].dataloader.num_samples 365 | if args.val_data is not None: 366 | args.val_sz = data["val"].dataloader.num_samples 367 | # you will have to configure this for your project! 368 | wandb.init( 369 | project=args.wandb_project_name, 370 | name=args.name, 371 | id=args.name, 372 | notes=args.wandb_notes, 373 | tags=[], 374 | resume='auto' if args.resume == "latest" else None, 375 | config=vars(args), 376 | ) 377 | if args.debug: 378 | wandb.watch(model, log='all') 379 | wandb.save(params_file) 380 | logging.debug('Finished loading wandb.') 381 | 382 | # Pytorch 2.0 adds '_orig_mod.' 
prefix to keys of state_dict() of compiled models. 383 | # For compatibility, we save state_dict() of the original model, which shares the 384 | # weights without the prefix. 385 | original_model = model 386 | if args.torchcompile: 387 | logging.info('Compiling model...') 388 | model = torch.compile(original_model) 389 | 390 | if 'train' not in data: 391 | # If using int8, convert to inference mode. 392 | if args.use_bnb_linear is not None: 393 | from open_clip.utils import convert_int8_model_to_inference_mode 394 | convert_int8_model_to_inference_mode(model) 395 | # Evaluate. 396 | evaluate(model, data, start_epoch, args, tb_writer=writer, tokenizer=tokenizer) 397 | return 398 | 399 | loss = create_loss(args) 400 | 401 | for epoch in range(start_epoch, args.epochs): 402 | if is_master(args): 403 | logging.info(f'Start epoch {epoch}') 404 | 405 | train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist_model, args, tb_writer=writer) 406 | completed_epoch = epoch + 1 407 | 408 | if any(v in data for v in ( 409 | 'val', 'imagenet-val', 'imagenet-v2', 'retrieval_coco', 'retrieval_flickr', 410 | 'retrieval_urban_1k', 'retrieval_iiw', 'retrieval_dci', 'sharegpt4v-1k', 'sharegpt4v-10k')): 411 | evaluate(model, data, completed_epoch, args, tb_writer=writer, tokenizer=tokenizer) 412 | 413 | # Saving checkpoints. 414 | if args.save_logs: 415 | checkpoint_dict = { 416 | "epoch": completed_epoch, 417 | "name": args.name, 418 | "state_dict": original_model.state_dict(), 419 | "optimizer": optimizer.state_dict(), 420 | } 421 | if scaler is not None: 422 | checkpoint_dict["scaler"] = scaler.state_dict() 423 | 424 | if completed_epoch == args.epochs or ( 425 | args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0 426 | ): 427 | torch.save( 428 | checkpoint_dict, 429 | os.path.join(args.checkpoint_path, f"epoch_{completed_epoch}.pt"), 430 | ) 431 | if args.delete_previous_checkpoint: 432 | previous_checkpoint = os.path.join(args.checkpoint_path, f"epoch_{completed_epoch - 1}.pt") 433 | if os.path.exists(previous_checkpoint): 434 | os.remove(previous_checkpoint) 435 | 436 | if args.save_most_recent: 437 | # try not to corrupt the latest checkpoint if save fails 438 | tmp_save_path = os.path.join(args.checkpoint_path, "tmp.pt") 439 | latest_save_path = os.path.join(args.checkpoint_path, LATEST_CHECKPOINT_NAME) 440 | torch.save(checkpoint_dict, tmp_save_path) 441 | os.replace(tmp_save_path, latest_save_path) 442 | 443 | if args.wandb and is_master(args): 444 | wandb.finish() 445 | 446 | # run a final sync. 447 | if remote_sync_process is not None: 448 | logging.info('Final remote sync.') 449 | remote_sync_process.terminate() 450 | result = remote_sync( 451 | os.path.join(args.logs_dir, args.name), 452 | os.path.join(args.remote_sync, args.name), 453 | args.remote_sync_protocol 454 | ) 455 | if result: 456 | logging.info('Final remote sync successful.') 457 | else: 458 | logging.info('Final remote sync failed.') 459 | 460 | 461 | def copy_codebase(args): 462 | from shutil import copytree, ignore_patterns 463 | new_code_path = os.path.join(args.logs_dir, args.name, "code") 464 | if os.path.exists(new_code_path): 465 | print( 466 | f"Error. Experiment already exists at {new_code_path}. Use --name to specify a new experiment." 
467 | ) 468 | return -1 469 | print(f"Copying codebase to {new_code_path}") 470 | current_code_path = os.path.realpath(__file__) 471 | for _ in range(3): 472 | current_code_path = os.path.dirname(current_code_path) 473 | copytree(current_code_path, new_code_path, ignore=ignore_patterns('log', 'logs', 'wandb')) 474 | print("Done copying code.") 475 | return 1 476 | 477 | 478 | if __name__ == "__main__": 479 | main(sys.argv[1:]) 480 | -------------------------------------------------------------------------------- /src/minimal_example.py: -------------------------------------------------------------------------------- 1 | import flair 2 | from PIL import Image 3 | import torch 4 | 5 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 6 | print(f"Using device: {device}") 7 | 8 | pretrained = flair.download_weights_from_hf(model_repo='xiaorui638/flair', filename='flair-cc3m-recap.pt') 9 | model, _, preprocess = flair.create_model_and_transforms('ViT-B-16-FLAIR', pretrained=pretrained) 10 | 11 | model.to(device) 12 | model.eval() 13 | 14 | tokenizer = flair.get_tokenizer('ViT-B-16-FLAIR') 15 | 16 | image = preprocess(Image.open("../assets/puppy.jpg")).unsqueeze(0).to(device) 17 | 18 | text = tokenizer(["In the image, a small white puppy with black ears and eyes is the main subject", # ground-truth caption 19 | "The white door behind the puppy is closed, and there's a window on the right side of the door", # ground-truth caption 20 | "A red ladybug is surrounded by green glass beads", # non-ground-truth caption 21 | "Dominating the scene is a white desk, positioned against a white brick wall"]).to(device) # non-ground-truth caption 22 | 23 | with torch.no_grad(), torch.cuda.amp.autocast(): 24 | flair_logits = model.get_logits(image=image, text=text) 25 | clip_logits = model.get_logits_as_clip(image=image, text=text) 26 | 27 | print("logits get using flair's way:", flair_logits) # [4.4062, 6.9531, -20.5000, -18.1719] 28 | print("logits get using clip's way:", clip_logits) # [12.4609, 15.6797, -3.8535, -0.2281] 29 | 30 | 31 | -------------------------------------------------------------------------------- /src/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | ipykernel 3 | open_clip_torch 4 | open_clip_torch[training] -------------------------------------------------------------------------------- /src/train_cc12m_slurm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -x 2 | #SBATCH --nodes=8 3 | #SBATCH --gres=gpu:4 4 | #SBATCH --ntasks-per-node=4 5 | #SBATCH --cpus-per-task=48 6 | #SBATCH --job-name=train_cc12m_slurm 7 | #SBATCH --account=ACCOUNT_NAME 8 | #SBATCH --partition PARTITION_NAME 9 | 10 | source flair_env/bin/activate 11 | cd flair/src 12 | 13 | export CUDA_VISIBLE_DEVICES=0,1,2,3 14 | export MASTER_PORT=12802 15 | 16 | master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) 17 | export MASTER_ADDR=$master_addr 18 | 19 | 20 | srun env -u CUDA_VISIBLE_DEVICES python -u -m torchrun \ 21 | --nnode=$SLURM_JOB_NUM_NODES --nproc_per_node=gpu --rdzv_id=$SLURM_JOB_ID \ 22 | --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT --rdzv_backend=c10d -m main \ 23 | --logs-dir ./logs \ 24 | --model ViT-B-16-FLAIR \ 25 | --use-flair-loss \ 26 | --add-mps-loss \ 27 | --train-dataset-type webdataset \ 28 | --lr 5e-4 \ 29 | --warmup 2000 \ 30 | --epochs 32 \ 31 | --caption-sampling-mode diverse_sampling \ 32 | --num-sampled-captions 8 \ 33 | --log-every-n-steps 100 
\ 34 | --train-data 'datasets/cc12m_recap/cc12m-train-{0000..2175}.tar' \ 35 | --train-num-samples 10010225 \ 36 | --delete-previous-checkpoint \ 37 | --batch-size 192 \ 38 | --precision amp \ 39 | --workers 48 \ 40 | --beta1 0.9 \ 41 | --beta2 0.98 \ 42 | --wd 0.5 \ 43 | --eps 1e-8 \ 44 | 45 | -------------------------------------------------------------------------------- /src/train_cc3m_slurm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -x 2 | #SBATCH --nodes=2 3 | #SBATCH --gres=gpu:4 4 | #SBATCH --ntasks-per-node=4 5 | #SBATCH --cpus-per-task=48 6 | #SBATCH --job-name=train_cc3m_slurm 7 | #SBATCH --account=ACCOUNT_NAME 8 | #SBATCH --partition PARTITION_NAME 9 | 10 | source flair_env/bin/activate 11 | cd flair/src 12 | 13 | export CUDA_VISIBLE_DEVICES=0,1,2,3 14 | export MASTER_PORT=12802 15 | 16 | master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) 17 | export MASTER_ADDR=$master_addr 18 | 19 | 20 | srun env -u CUDA_VISIBLE_DEVICES python -u -m torchrun \ 21 | --nnode=$SLURM_JOB_NUM_NODES --nproc_per_node=gpu --rdzv_id=$SLURM_JOB_ID \ 22 | --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT --rdzv_backend=c10d -m main \ 23 | --logs-dir ./logs \ 24 | --model ViT-B-16-FLAIR \ 25 | --use-flair-loss \ 26 | --add-mps-loss \ 27 | --train-dataset-type webdataset \ 28 | --lr 5e-4 \ 29 | --warmup 2000 \ 30 | --epochs 32 \ 31 | --caption-sampling-mode diverse_sampling \ 32 | --num-sampled-captions 8 \ 33 | --log-every-n-steps 100 \ 34 | --train-data 'datasets/cc3m_recap/cc3m-train-{0000..0575}.tar' \ 35 | --train-num-samples 2823019 \ 36 | --delete-previous-checkpoint \ 37 | --batch-size 128 \ 38 | --precision amp \ 39 | --workers 48 \ 40 | --beta1 0.9 \ 41 | --beta2 0.98 \ 42 | --wd 0.5 \ 43 | --eps 1e-8 \ 44 | 45 | -------------------------------------------------------------------------------- /src/train_example.sh: -------------------------------------------------------------------------------- 1 | torchrun --nnode 1 --nproc_per_node 4 -m main \ 2 | --model ViT-B-16-FLAIR \ 3 | --train-data './datasets/cc3m_recaptioned/' \ 4 | --train-num-samples 2823019 \ 5 | --train-dataset-type webdataset \ 6 | --use-flair-loss \ 7 | --add-mps-loss \ 8 | --num-sampled-captions 8 \ 9 | --log-every-n-steps 200 \ 10 | --caption-sampling-mode diverse_sampling \ 11 | --batch-size 128 \ 12 | --precision amp \ 13 | --workers 48 \ 14 | --delete-previous-checkpoint \ 15 | --beta1 0.9 \ 16 | --beta2 0.98 \ 17 | --wd 0.5 \ 18 | --eps 1e-8 \ 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /src/train_merged30m_slurm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -x 2 | #SBATCH --nodes=2 3 | #SBATCH --gres=gpu:4 4 | #SBATCH --ntasks-per-node=4 5 | #SBATCH --cpus-per-task=48 6 | #SBATCH --wait-all-nodes=1 7 | #SBATCH --job-name=train_cc3m_slurm 8 | #SBATCH --account=ACCOUNT_NAME 9 | #SBATCH --partition PARTITION_NAME 10 | 11 | source flair_env/bin/activate 12 | cd flair/src 13 | 14 | export CUDA_VISIBLE_DEVICES=0,1,2,3 15 | export MASTER_PORT=12802 16 | 17 | master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) 18 | export MASTER_ADDR=$master_addr 19 | 20 | 21 | srun env -u CUDA_VISIBLE_DEVICES python -u -m torchrun \ 22 | --nnode=$SLURM_JOB_NUM_NODES --nproc_per_node=gpu --rdzv_id=$SLURM_JOB_ID \ 23 | --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT --rdzv_backend=c10d -m main \ 24 | --logs-dir ./logs \ 25 | --model 
ViT-B-16-FLAIR \ 26 | --use-flair-loss \ 27 | --add-mps-loss \ 28 | --train-dataset-type webdataset \ 29 | --lr 5e-4 \ 30 | --warmup 2000 \ 31 | --epochs 32 \ 32 | --caption-sampling-mode diverse_sampling \ 33 | --num-sampled-captions 10 \ 34 | --log-every-n-steps 100 \ 35 | --train-data '/datasets/yfcc15m_recap/yfcc15m-train-{0000..3636}.tar::/datasets/cc12m_recap/cc12m-train-{0000..2175}.tar::/datasets/cc3m_recap/cc3m-train-{0001..0575}.tar' \ 36 | --train-num-samples 27000000 \ 37 | --delete-previous-checkpoint \ 38 | --batch-size 128 \ 39 | --precision amp \ 40 | --workers 48 \ 41 | --beta1 0.9 \ 42 | --beta2 0.98 \ 43 | --wd 0.2 \ 44 | --eps 1e-6 \ 45 | 46 | -------------------------------------------------------------------------------- /src/train_yfcc15m_slurm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -x 2 | #SBATCH --nodes=2 3 | #SBATCH --gres=gpu:4 4 | #SBATCH --ntasks-per-node=4 5 | #SBATCH --cpus-per-task=48 6 | #SBATCH --job-name=train_yfcc15m_slurm 7 | #SBATCH --account=ACCOUNT_NAME 8 | #SBATCH --partition PARTITION_NAME 9 | 10 | source flair_env/bin/activate 11 | cd flair/src 12 | 13 | export CUDA_VISIBLE_DEVICES=0,1,2,3 14 | export MASTER_PORT=12802 15 | 16 | master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) 17 | export MASTER_ADDR=$master_addr 18 | 19 | 20 | srun env -u CUDA_VISIBLE_DEVICES python -u -m torchrun \ 21 | --nnode=$SLURM_JOB_NUM_NODES --nproc_per_node=gpu --rdzv_id=$SLURM_JOB_ID \ 22 | --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT --rdzv_backend=c10d -m main \ 23 | --logs-dir ./logs \ 24 | --model ViT-B-16-FLAIR \ 25 | --use-flair-loss \ 26 | --add-mps-loss \ 27 | --train-dataset-type webdataset \ 28 | --lr 5e-4 \ 29 | --warmup 2000 \ 30 | --epochs 32 \ 31 | --caption-sampling-mode diverse_sampling \ 32 | --num-sampled-captions 10 \ 33 | --log-every-n-steps 100 \ 34 | --train-data 'datasets/yfcc15m_recap/yfcc15m-train-{0000..3636}.tar' \ 35 | --train-num-samples 14065827 \ 36 | --delete-previous-checkpoint \ 37 | --batch-size 192 \ 38 | --precision amp \ 39 | --workers 48 \ 40 | --beta1 0.9 \ 41 | --beta2 0.98 \ 42 | --wd 0.5 \ 43 | --eps 1e-8 \ 44 | 45 | --------------------------------------------------------------------------------