├── README.md ├── assets ├── .gitkeep ├── cosmos.png ├── framework.png └── qualitative_results_supp.png ├── datasets ├── README.md └── imagenet_organize.py ├── requirements.txt └── src ├── dataloaders ├── __init__.py ├── caltech101.py ├── cifar10.py ├── cifar100.py ├── dtd.py ├── fgvc_aircraft.py ├── flowers102.py ├── food101.py ├── label.json ├── oxford_pets.py ├── stanford_car.py ├── sun397.py ├── templates.json └── utils.py ├── inference_classification.sh ├── inference_retrieval.sh ├── inference_segmentation.sh ├── main.py ├── open_clip ├── __init__.py ├── bpe_simple_vocab_16e6.txt.gz ├── coca_model.py ├── constants.py ├── convert.py ├── factory.py ├── hf_configs.py ├── hf_model.py ├── loss.py ├── model.py ├── model_configs │ ├── ViT-B-16.json │ └── ViT-B-32.json ├── modified_resnet.py ├── openai.py ├── pos_embed.py ├── pretrained.py ├── push_to_hf_hub.py ├── timm_model.py ├── tokenizer.py ├── transform.py ├── transformer.py ├── utils.py ├── version.py ├── zero_shot_classifier.py └── zero_shot_metadata.py ├── seg_eval.py ├── train_cc12m.sh ├── train_cc3m.sh ├── train_merged30m.sh ├── train_pixelprose.sh ├── train_yfcc15m.sh └── training ├── __init__.py ├── clip_segmentor.py ├── custom_datasets.py ├── data.py ├── distributed.py ├── file_utils.py ├── logger.py ├── pamr.py ├── params.py ├── precision.py ├── profiler.py ├── scheduler.py ├── seg_configs ├── base_config.py ├── cfg_ade20k.py ├── cfg_city_scapes.py ├── cfg_coco_object.py ├── cfg_coco_stuff164k.py ├── cfg_context59.py ├── cfg_context60.py ├── cfg_voc20.py ├── cfg_voc21.py ├── cls_ade20k.txt ├── cls_city_scapes.txt ├── cls_coco_object.txt ├── cls_coco_stuff.txt ├── cls_context59.txt ├── cls_context60.txt ├── cls_voc20.txt ├── cls_voc21.txt ├── convert_cityscapes.py └── convert_coco_object.py ├── train.py └── zero_shot.py /README.md: -------------------------------------------------------------------------------- 1 | # [CVPR 2025] COSMOS: Cross-Modality Self-Distillation for Vision Language Pre-training 2 | [![Paper](https://img.shields.io/badge/paper-arxiv.2412.03561-B31B1B.svg)](https://arxiv.org/abs/2412.01814) 3 | [![Hugging Face](https://img.shields.io/badge/HuggingFace-COSMOS-FFD700?logo=huggingface&logoColor=yellow)](https://huggingface.co/sankim2/cosmos) 4 | 5 | 6 | **Authors:** [Sanghwan Kim](https://kim-sanghwan.github.io/), [Rui Xiao](https://www.eml-munich.de/people/rui-xiao), [Mariana-Iuliana Georgescu](https://lilygeorgescu.github.io/), [Stephan Alaniz](https://www.eml-munich.de/people/stephan-alaniz), [Zeynep Akata](https://www.eml-munich.de/people/zeynep-akata) 7 | 8 | ### Abstract 9 | Vision-Language Models (VLMs) trained with contrastive loss have achieved significant advancements in various vision and language tasks. However, the global nature of contrastive loss makes VLMs focus predominantly on foreground objects, neglecting other crucial information in the image, which limits their effectiveness in downstream tasks. To address these challenges, we propose COSMOS: CrOSs-MOdality Self-distillation for vision-language pre-training that integrates a novel text-cropping strategy and cross-attention module into a self-supervised learning framework. We create global and local views of images and texts (i.e., multi-modal augmentations), which are essential for self-distillation in VLMs. We further introduce a cross-attention module, enabling COSMOS to learn comprehensive cross-modal representations optimized via a cross-modality self-distillation loss. 
COSMOS consistently outperforms previous strong baselines on various zero-shot downstream tasks, including retrieval, classification, and semantic segmentation. Additionally, it surpasses CLIP-based models trained on larger datasets in visual perception and contextual understanding tasks. 10 | 11 | ## Methodology 12 | ![](assets/framework.png "An overview of COSMOS") 13 | 14 | ## Pre-trained Model Weights 15 | 16 | We released the pre-trained COSMOS models on [Huggingface](https://huggingface.co/sankim2/cosmos). Our pre-trained models and their corresponding performances on COCO (I2T R@1 and T2I R@1), Flickr (I2T R@1 and T2I R@1) and ImageNet (Top-1) are reported below. For the full results, please refer to our [paper](https://arxiv.org/abs/2412.01814). 17 | 18 | | **Checkpoints** | **Arch.** | **Datasets** | **COCO I2T** | **COCO T2I** | **Flickr I2T** | **Flickr T2I** | **IN Top-1** | 19 | |------------------------------------------------------------------------------------------------------------|------------------|--------------|--------------|----------------|----------------|----------------|----------------| 20 | | [cosmos_vitb16_cc3m](https://huggingface.co/sankim2/cosmos/resolve/main/cosmos_vitb16_cc3m.pt?download=true) | ViT-B/16 | CC3M-recap | 53.1 | 40.1 | 84.1 | 68.6 |37.1 | 21 | | [cosmos_vitb16_cc12m](https://huggingface.co/sankim2/cosmos/resolve/main/cosmos_vitb16_cc12m.pt?download=true) | ViT-B/16 | CC12M-recap | 64.2 | 48.9 | 91.4 | 76.2 |51.4 | 22 | | [cosmos_vitb16_yfcc15m](https://huggingface.co/sankim2/cosmos/resolve/main/cosmos_vitb16_yfcc15m.pt?download=true) | ViT-B/16 | YFCC15M-recap | 67.5 | 50.9 | 92.6 | 79.6 |52.4 | 23 | | [cosmos_vitb16_merged30m](https://huggingface.co/sankim2/cosmos/resolve/main/cosmos_vitb16_merged30m.pt?download=true) | ViT-B/16 | Merged30M | 68.0 | 52.5 | 92.9 | 80.3 |57.6 | 24 | | [cosmos_vitb16_pixelprose](https://huggingface.co/sankim2/cosmos/resolve/main/cosmos_vitb16_pixelprose.pt?download=true) | ViT-B/16 | PixelProse | 62.4 | 43.4 | 89.9 | 73.6 |59.6 | 25 | | [cosmos_vitb32_cc3m](https://huggingface.co/sankim2/cosmos/resolve/main/cosmos_vitb32_cc3m.pt?download=true) | ViT-B/32 | CC3M-recap | 47.6 | 33.5 | 74.3 | 59.2 |33.0 | 26 | | [cosmos_vitb32_cc12m](https://huggingface.co/sankim2/cosmos/resolve/main/cosmos_vitb32_cc12m.pt?download=true) | ViT-B/32 | CC12M-recap | 59.6 | 43.0 | 86.5 | 69.8 |46.7 | 27 | | [cosmos_vitb32_yfcc15m](https://huggingface.co/sankim2/cosmos/resolve/main/cosmos_vitb32_yfcc15m.pt?download=true) | ViT-B/32 | YFCC15M-recap | 64.5 | 46.0 | 90.2 | 73.3 |48.1 | 28 | | [cosmos_vitb32_merged30m](https://huggingface.co/sankim2/cosmos/resolve/main/cosmos_vitb32_merged30m.pt?download=true) | ViT-B/32 | Merged30M | 64.3 | 48.4 | 89.9 | 76.1 |53.4 | 29 | | [cosmos_vitb32_pixelprose](https://huggingface.co/sankim2/cosmos/resolve/main/cosmos_vitb32_pixelprose.pt?download=true) | ViT-B/32 | PixelProse | 57.2 | 38.9 | 85.6 | 66.3 |54.3 | 30 | 31 | ⚠️ You don't need to manually download the pre-trained weights to run the inference, the pre-trained weights will be automatically downloaded by specifying the `--huggingface-model-name` and `--huggingface-repo-name` during inference. 32 | Optionally, you could download each weight separately and set `--resume path/to/pretrained_weights` flag in inference code. 33 | 34 | ## Dependencies 35 | You can set up your virtual environment following the below instructions. 
We built our code repository upon [OpenCLIP](https://github.com/mlfoundations/open_clip), which is still updated frequently. We recommend checking their repo for a detailed tutorial on creating an environment that is best suited for your system. A conda environment is also possible with the same Python and PyTorch versions. 36 | 37 | ### 1. Download our GitHub Repository 38 | First, download the COSMOS GitHub repo and navigate to the project’s root directory `cosmos/`. 39 | ```bash 40 | git clone https://github.com/ExplainableML/cosmos.git 41 | cd cosmos/ 42 | ``` 43 | 44 | ### 2. Create a Virtual Environment 45 | Create a virtual environment using Python 3.12 and activate it. 46 | ```bash 47 | python3.12 -m venv cosmos_env 48 | source cosmos_env/bin/activate 49 | ``` 50 | 51 | ### 3. Install Dependencies 52 | Install all requirements via pip. 53 | ```bash 54 | pip install --upgrade pip 55 | pip install -r requirements.txt 56 | ``` 57 | If you want to run semantic segmentation tasks, please follow [SCLIP](https://github.com/wangf3014/SCLIP) to install their dependencies as well. We list their commands below for completeness. 58 | ```bash 59 | pip install openmim 60 | mim install mmcv==2.0.1 mmengine==0.8.4 mmsegmentation==1.1.1 61 | pip install ftfy regex yapf==0.40.1 62 | ``` 63 | 64 | ### [Optional] Anaconda Environment 65 | You can optionally use Anaconda to set up the environment. 66 | ```bash 67 | conda create --name cosmos_env python=3.12 68 | conda activate cosmos_env 69 | ``` 70 | Then, install all dependencies as follows. 71 | ```bash 72 | pip install --upgrade pip 73 | pip install -r requirements.txt 74 | ``` 75 | 76 | ## Inference Datasets Preparation 77 | Check [datasets/README.md](datasets/README.md) to prepare all the inference datasets for the retrieval, classification, and segmentation tasks. 78 | 79 | ## Inference with COSMOS 80 | To reproduce the results of the downstream tasks (image-text retrieval, image classification, semantic segmentation) in the COSMOS paper, we provide an example inference bash script for each task: `src/inference_retrieval.sh`, `src/inference_classification.sh`, and `src/inference_segmentation.sh`. 81 | 82 | Here are detailed explanations of the important flags (a short checkpoint-download sketch follows the task descriptions below). 83 | 84 | 85 | - `--huggingface-repo-name`: Name of the Hugging Face repo where the pre-trained models are stored. Should be fixed to `sankim2/cosmos`. 86 | - `--huggingface-model-name`: Name of the pretrained model. Options include `cosmos_vitb16_cc3m.pt, cosmos_vitb16_cc12m.pt, cosmos_vitb16_yfcc15m.pt, cosmos_vitb16_merged30m.pt, cosmos_vitb16_pixelprose.pt` for ViT-B/16 and `cosmos_vitb32_cc3m.pt, cosmos_vitb32_cc12m.pt, cosmos_vitb32_yfcc15m.pt, cosmos_vitb32_merged30m.pt, cosmos_vitb32_pixelprose.pt` for ViT-B/32. 87 | - `--model`: Model architecture; must match `--huggingface-model-name`. Options include `ViT-B-16` and `ViT-B-32`. 88 | - `--precision`: Defaults to `amp`, as used in our paper. 89 | - `--workers`: Adjustable according to your system. 90 | 91 | ### Image-Text Retrieval Task 92 | `--data-root-dir` should point to the directory that contains the COCO and Flickr30k validation sets. Please refer to [/src/inference_retrieval.sh](/src/inference_retrieval.sh) for running inference on the retrieval task. 93 | 94 | ### Image Classification Task 95 | `--imagenet-val` should point to the directory that contains the ImageNet validation set. Please refer to [/src/inference_classification.sh](/src/inference_classification.sh) for running inference on the classification task. 
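Before moving on to the individual tasks, it can help to confirm that the checkpoint named by `--huggingface-repo-name` and `--huggingface-model-name` is reachable from your machine. The snippet below is a minimal, illustrative sketch rather than a supported entry point: it only assumes that `huggingface_hub` and `torch` are installed (both are in `requirements.txt`) and makes no claim about the checkpoint's internal layout beyond printing it. The inference scripts above remain the recommended way to build the model and run evaluation.

```python
# Minimal sketch: fetch one COSMOS checkpoint from the Hugging Face repo and inspect it.
import torch
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(
    repo_id="sankim2/cosmos",               # value of --huggingface-repo-name
    filename="cosmos_vitb16_merged30m.pt",  # value of --huggingface-model-name
)
checkpoint = torch.load(ckpt_path, map_location="cpu")

# Whether the file is a bare state_dict or a wrapper dict is not assumed here;
# print the top-level keys to see what was downloaded.
if isinstance(checkpoint, dict):
    print(list(checkpoint.keys())[:10])
```

If you prefer to manage the file yourself, the downloaded path can also be passed to the inference scripts via the `--resume path/to/pretrained_weights` flag mentioned earlier.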
96 | 97 | ### Semantic Segmentation Task 98 | `--seg-w-background` is a flag indicating whether to evaluate on the segmentation benchmarks that include a background class. If `--use-csa` is included, the model uses the Correlative Self-Attention (CSA) block from SCLIP for segmentation. Please refer to [/src/inference_segmentation.sh](/src/inference_segmentation.sh) for running inference on the segmentation task. 99 | 100 | ## Training COSMOS 101 | In order to train COSMOS from scratch, the synthetic long-caption datasets should be downloaded from [DreamLIP](https://github.com/ant-research/DreamLIP)'s recaptioned [CC3M-recap](https://huggingface.co/datasets/qidouxiong619/dreamlip_long_captions), [CC12M-recap](https://huggingface.co/datasets/qidouxiong619/dreamlip_long_captions), [YFCC15M-recap](https://huggingface.co/datasets/qidouxiong619/dreamlip_long_captions) and their combination (Merged-30M), as well as [PixelProse](https://huggingface.co/datasets/tomg-group-umd/pixelprose). Notably, COSMOS requires all pre-training datasets to be processed into the [webdataset](https://github.com/webdataset/webdataset) format to achieve higher I/O efficiency for large-scale training. In the pre-training dataset preparation step, we take [CC3M-recap](https://huggingface.co/datasets/qidouxiong619/dreamlip_long_captions) as an example to demonstrate how to prepare the pre-training data. The preparation for the other datasets is similar. We share the same pre-training datasets as [FLAIR](https://arxiv.org/abs/2412.03561). Please check their [repo](https://github.com/ExplainableML/flair) as well if you find it interesting! 102 | 103 | ### Prepare Pre-training Data 104 | 1. Download DreamLIP's annotations for CC3M-recap: 105 | `wget https://huggingface.co/datasets/qidouxiong619/dreamlip_long_captions/resolve/main/cc3m_3long_3short_1raw_captions_url.csv` 106 | 2. Scrape the images from the URL links using [img2dataset](https://github.com/rom1504/img2dataset). 107 | 108 | ### Training Script 109 | COSMOS is trained on a Slurm GPU cluster with 16 NVIDIA A100 (40GB) GPUs (on CC3M) or 128 NVIDIA A100 (40GB) GPUs (on the other, larger datasets). In `src/`, we provide example Slurm training scripts for each of the datasets: `train_cc3m.sh, train_cc12m.sh, train_yfcc15m.sh, train_merged30m.sh, train_pixelprose.sh`. 110 | 111 | Important flags are described below (a schematic of the teacher momentum update follows the list): 112 | - `--train-data`: Root directory where the training data (shards) is stored. 113 | - `--train-num-samples`: Total number of training samples; adjust this based on your available data. 114 | - `--use-imagecrop-aug`: Use the multi-crop image augmentation described in the paper. 115 | - `--global-crops-number`: Number of global image crops. Fixed to 2. 116 | - `--local-crops-number`: Number of local image crops. 117 | - `--crop-scale`: Determines the scale s of the global and local crops: (0.05, s) for local crops and (s, 1.0) for global crops. Fixed to 0.4. 118 | - `--caption-sampling-mode`: Determines how captions are sampled. Fixed to `textcrop` or `textcrop_pixelprose`. 119 | - `--num-sampled-captions`: Total number of sampled captions (global + local). 120 | - `--momentum-teacher`: Initial momentum value for the teacher. This should be adjusted based on the batch size; we used 0.999 for a 1k batch and 0.99 for a 4k batch. 121 | - `--fix-momentum`: Keep the momentum value fixed during training. 122 | - `--output-all`: Output both patch (or word) tokens and [cls] (or [eot]) tokens. 123 | - `--attentional-pool`: Add the cross-attention module to the model. 124 | - `--cosmos`: Use the COSMOS loss during training. 
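As referenced after the flag list, the schematic below illustrates the kind of exponential-moving-average (EMA) teacher update that `--momentum-teacher` and `--fix-momentum` control. It is a generic self-distillation sketch, not a copy of the code in `src/training/train.py`; the `student` and `teacher` modules and the annealing note are illustrative assumptions.

```python
# Schematic EMA teacher update used in self-distillation (illustrative only).
import torch


@torch.no_grad()
def update_teacher(student: torch.nn.Module, teacher: torch.nn.Module, momentum: float) -> None:
    """teacher <- momentum * teacher + (1 - momentum) * student, applied parameter-wise."""
    for p_student, p_teacher in zip(student.parameters(), teacher.parameters()):
        p_teacher.mul_(momentum).add_(p_student.detach(), alpha=1.0 - momentum)


# With --fix-momentum, the momentum stays at its initial value for the whole run
# (e.g. 0.999 for a 1k batch or 0.99 for a 4k batch); without it, a typical recipe
# anneals the momentum towards 1.0 over the course of training.
```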
125 | 126 | ## Qualitative Results 127 | We visualize the attention weights of image and text cross-attention modules. Patch-wise (image) and token-wise (caption) attention weights are both normalized between 0 and 1. 128 | 129 | ![](assets/qualitative_results_supp.png "Qualitative Results") 130 | 131 | ## Acknowledgements 132 | We thank [OpenCLIP](https://github.com/mlfoundations/open_clip) for providing the amazing code base. Meanwhile, we acknowledge [DreamLIP](https://github.com/zyf0619sjtu/DreamLIP) and [PixelProse](https://huggingface.co/datasets/tomg-group-umd/pixelprose) for providing us with various pre-training datasets with captions from MLLMs. We are also grateful to [SCLIP](https://github.com/wangf3014/SCLIP) for providing the detailed scheme for semantic segmentation task. 133 | 134 | ## Citations 135 | If you find our work useful, please star this repo and cite: 136 | 137 | ```bibtex 138 | @article{kim2025cosmos, 139 | title={COSMOS: Cross-Modality Self-Distillation for Vision Language Pre-training}, 140 | author={Kim, Sanghwan and Xiao, Rui and Georgescu, Mariana-Iuliana and Alaniz, Stephan and Akata, Zeynep}, 141 | journal={CVPR}, 142 | year={2025} 143 | } 144 | ``` 145 | -------------------------------------------------------------------------------- /assets/.gitkeep: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /assets/cosmos.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ExplainableML/cosmos/c80348f08e9b02c5a81adadd7dce486b3c284b78/assets/cosmos.png -------------------------------------------------------------------------------- /assets/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ExplainableML/cosmos/c80348f08e9b02c5a81adadd7dce486b3c284b78/assets/framework.png -------------------------------------------------------------------------------- /assets/qualitative_results_supp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ExplainableML/cosmos/c80348f08e9b02c5a81adadd7dce486b3c284b78/assets/qualitative_results_supp.png -------------------------------------------------------------------------------- /datasets/README.md: -------------------------------------------------------------------------------- 1 | # Data Preparation for Downstream Tasks 2 | ### Datasets list: 3 | - [MSCOCO](#coco) 4 | - [Flickr30k](#flickr) 5 | - [ImageNet](#imagenet) 6 | - [Other Classification datasets](#others) 7 | - [Segmentation datasets](#segmentation) 8 | 9 | 10 | ## Image-Text Retrieval Task 11 | 12 | ### MSCOCO dataset 13 | ``` 14 | $coco/ 15 | |–– images/ 16 | |–––– val2017/ 17 | |–––––– 000000134722.jpg 18 | |–––––– 000000177015.jpg 19 | |–––––– ... 20 | |–– annotations/ 21 | |–––– captions_val2017.json 22 | ``` 23 | Step 1. Download validation images from [COCO 2017 Val Images](https://cocodataset.org/#download), unzip them to `coco/images/val2017`. 24 | 25 | Step 2. Download the 2017 Val annotations, place it under `coco/annotations/captions_val2017.json`. 26 | 27 | ### Flickr30K dataset 28 | ``` 29 | $flickr30k-images/ 30 | |–– 2217728745.jpg 31 | |–– 2217728745.jpg 32 | |–– ... 33 | |–– flickr30k_val.json 34 | |–– flickr30k_test.json 35 | ``` 36 | Step 1. 
Download the [flickr30k dataset](https://huggingface.co/datasets/nlphuji/flickr30k) and unzip it under `flickr30k-images/`; all the images and annotation files will then be structured as above. 37 | 38 | ## Image Classification Task 39 | 40 | ### ImageNet dataset 41 | ``` 42 | $imagenet/ 43 | |–– data/ 44 | |–––– val_images/ 45 | |–––––– n01440764/ 46 | |–––––––– ILSVRC2012_val_00000293_n01440764.JPEG 47 | |–––––––– ILSVRC2012_val_00017699_n01440764.JPEG 48 | |–––––––– ... 49 | |–––––– n01871265/ 50 | |–––––––– ILSVRC2012_val_00000067_n01871265.JPEG 51 | |–––––––– ILSVRC2012_val_00017361_n01871265.JPEG 52 | |–––––––– ... 53 | ``` 54 | 55 | Step 1. Download the validation data `val_images.tar.gz` from [ILSVRC/imagenet-1k](https://huggingface.co/datasets/ILSVRC/imagenet-1k) and unzip it to `imagenet/data/val_images`. 56 | You can manually download `imagenet-1k/data/val_images.tar.gz` or use this command: `huggingface-cli download ILSVRC/imagenet-1k --repo-type dataset --local-dir /directory/to/your/dataset/`. 57 | 58 | Step 2. Change [source_dir](imagenet_organize.py#L5) in `imagenet_organize.py` according to your val_images folder. Then, run `imagenet_organize.py` to organize the images into the above format (an optional sanity-check sketch is given at the end of this README). 59 | 60 | ### Other Classification datasets 61 | 62 | Other classification datasets include `["food101", "cifar10", "cifar100", "sun397", "stanford_car", "aircraft", "dtd", "pets", "caltech101", "flowers"]`. 63 | 64 | Please set an appropriate [dataset_root](/src/dataloaders/utils.py#L17) in `src/dataloaders/utils.py` under which the classification datasets will be saved. 65 | 66 | Then, `torchvision.datasets` will automatically download the datasets into `dataset_root` during inference. 67 | 68 | 69 | ## Semantic Segmentation Task 70 | 71 | ### Segmentation datasets 72 | 73 | We followed the evaluation scheme and config files provided by [SCLIP](https://github.com/wangf3014/SCLIP), as shown [here](/src/training/seg_configs). 74 | 75 | Our segmentation configs include benchmarks with background `['cfg_voc21.py', 'cfg_context60.py', 'cfg_coco_object.py']` and without background `['cfg_voc20.py', 'cfg_city_scapes.py', 'cfg_context59.py', 'cfg_ade20k.py', 'cfg_coco_stuff164k.py']`. 76 | 77 | Please follow the dataset preparation instructions provided by [SCLIP](https://github.com/wangf3014/SCLIP) and [mmsegmentation](https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md) to download the following datasets: `["VOCdevkit/VOC2012", "VOCdevkit/VOC2010", "coco_stuff164k", "cityscapes", "ade"]`. 78 | 79 | Then, change the `data_root` in each segmentation config according to the dataset location. For example, this is the [data_root](/src/training/seg_configs/cfg_ade20k.py#L12) for `cfg_ade20k.py`. 
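As a follow-up to Step 2 of the ImageNet preparation above, the optional sketch below checks that `imagenet_organize.py` produced the expected class-folder layout. It is not part of the repository; `val_dir` is a placeholder that should match the `source_dir` you set in that script.

```python
# Optional sanity check for the organized ImageNet validation folder (illustrative only).
import os

val_dir = "/your/directory/to/imagenet/data/val_images"  # same path as source_dir

class_dirs = [d for d in os.listdir(val_dir) if os.path.isdir(os.path.join(val_dir, d))]
num_images = sum(
    len([f for f in os.listdir(os.path.join(val_dir, d)) if f.endswith(".JPEG")])
    for d in class_dirs
)

# The full ImageNet-1k validation split should yield 1000 class folders and 50000 images.
print(f"{len(class_dirs)} class folders, {num_images} images")
```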
80 | -------------------------------------------------------------------------------- /datasets/imagenet_organize.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | # Define the source directory 5 | source_dir = "/your/directory/to/imagenet/data/val_images" 6 | 7 | # Check if the source directory exists 8 | if not os.path.exists(source_dir): 9 | print(f"Source directory {source_dir} does not exist.") 10 | exit() 11 | 12 | # Get a list of all JPEG files in the source directory 13 | jpeg_files = [f for f in os.listdir(source_dir) if f.endswith('.JPEG')] 14 | 15 | # Process each JPEG file 16 | for jpeg_file in jpeg_files: 17 | # Extract the class name from the file name 18 | # Example) ILSVRC2012_val_00012508_n01843065.JPEG => n01843065 19 | class_name = jpeg_file.split('_')[-1].split('.')[0] 20 | 21 | # Define the target directory for this class 22 | target_dir = os.path.join(source_dir, class_name) 23 | 24 | # Create the target directory if it doesn't exist 25 | if not os.path.exists(target_dir): 26 | os.makedirs(target_dir) 27 | 28 | # Define the source and target file paths 29 | source_file = os.path.join(source_dir, jpeg_file) 30 | target_file = os.path.join(target_dir, jpeg_file) 31 | 32 | # Move the file to the target directory 33 | shutil.move(source_file, target_file) 34 | 35 | print("Files have been classified and moved to their respective directories.") -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.9.0 2 | torchvision 3 | webdataset>=0.2.5 4 | regex 5 | ftfy 6 | tqdm 7 | pandas 8 | braceexpand 9 | huggingface_hub 10 | pytest-split==0.8.0 11 | pytest==7.2.0 12 | transformers[sentencepiece] 13 | timm>=1.0.7 14 | fsspec 15 | wandb -------------------------------------------------------------------------------- /src/dataloaders/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/dataloaders/caltech101.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | import torch 3 | from torch.utils.data import DataLoader, Dataset, random_split 4 | from torchvision.datasets import Caltech101 5 | from .utils import dataset_root 6 | 7 | 8 | root = dataset_root 9 | num_example_train = 3000 10 | num_example_test = 5677 11 | num_classes = 101 12 | mean_per_class = True 13 | 14 | def get_loader_train( 15 | transform, batch_size, num_workers, seed 16 | ) -> Tuple[Dataset, DataLoader]: 17 | dataset = Caltech101(root, download=False, transform=transform) 18 | dataset_train, dataset_test = random_split( 19 | dataset, 20 | lengths=[num_example_train, 21 | num_example_test], 22 | generator=torch.Generator().manual_seed(seed)) 23 | return (dataset_train, None) 24 | 25 | 26 | def get_loader_test( 27 | transform, batch_size, num_workers, seed 28 | ) -> Tuple[Dataset, DataLoader]: 29 | dataset = Caltech101(root, download=False, transform=transform) 30 | dataset_train, dataset_test = random_split( 31 | dataset, 32 | lengths=[num_example_train, 33 | num_example_test], 34 | generator=torch.Generator().manual_seed(seed)) 35 | return (dataset_test, None) -------------------------------------------------------------------------------- /src/dataloaders/cifar10.py: 
-------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | from torch.utils.data import DataLoader, Dataset 3 | from torchvision.datasets import CIFAR10 4 | from .utils import dataset_root 5 | 6 | 7 | root = dataset_root 8 | num_example_train = 50000 9 | num_example_test = 10000 10 | num_classes = 10 11 | 12 | 13 | def get_loader_train( 14 | transform, batch_size, num_workers, seed 15 | ) -> Tuple[Dataset, DataLoader]: 16 | dataset_train = CIFAR10(root, download=False, train=True, transform=transform) 17 | return (dataset_train, None) 18 | 19 | 20 | def get_loader_test( 21 | transform, batch_size, num_workers, seed 22 | ) -> Tuple[Dataset, DataLoader]: 23 | dataset = CIFAR10(root, download=True, train=False, transform=transform) 24 | return (dataset, None) -------------------------------------------------------------------------------- /src/dataloaders/cifar100.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | from torch.utils.data import DataLoader, Dataset 3 | from torchvision.datasets import CIFAR100 4 | from .utils import dataset_root 5 | 6 | 7 | root = dataset_root 8 | num_example_train_val = 50000 9 | num_example_test = 10000 10 | num_classes = 100 11 | 12 | 13 | def get_loader_train( 14 | transform, batch_size, num_workers, seed 15 | ) -> Tuple[Dataset, DataLoader]: 16 | dataset_train_val = CIFAR100(root, download=False, train=True, transform=transform) 17 | return (dataset_train_val, None) 18 | 19 | 20 | def get_loader_test( 21 | transform, batch_size, num_workers, seed 22 | ) -> Tuple[Dataset, DataLoader]: 23 | dataset_test = CIFAR100(root, download=True, train=False, transform=transform) 24 | return (dataset_test, None) -------------------------------------------------------------------------------- /src/dataloaders/dtd.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | from torch.utils.data import DataLoader, Dataset, ConcatDataset 3 | from torchvision.datasets import DTD 4 | import torchvision 5 | from .utils import dataset_root 6 | 7 | root = dataset_root 8 | num_example_train = 3760 9 | num_example_test = 1880 10 | num_classes = 47 11 | 12 | 13 | class Warper(Dataset): 14 | def __init__(self, dataset) -> None: 15 | super().__init__() 16 | self.dataset = dataset 17 | 18 | def __len__(self) -> int: 19 | return len(self.dataset) 20 | 21 | def __getitem__(self, index): 22 | img, label = self.dataset[index] 23 | if torchvision.__version__ >= "0.13.0": 24 | return img, label 25 | else: 26 | return img, label - 1 27 | 28 | 29 | def get_loader_train( 30 | transform, batch_size, num_workers, seed 31 | ) -> Tuple[Dataset, DataLoader]: 32 | dataset_train = DTD(root, download=True, split='train', transform=transform) 33 | dataset_val = DTD(root, download=True, split='val', transform=transform) 34 | dataset_train = ConcatDataset([dataset_train, dataset_val]) 35 | dataset_train = Warper(dataset_train) 36 | return (dataset_train, None) 37 | 38 | 39 | def get_loader_test( 40 | transform, batch_size, num_workers, seed 41 | ) -> Tuple[Dataset, DataLoader]: 42 | dataset_test = DTD(root, download=True, split='test', transform=transform) 43 | return (dataset_test, None) -------------------------------------------------------------------------------- /src/dataloaders/fgvc_aircraft.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | from 
torch.utils.data import DataLoader, Dataset 3 | from torchvision.datasets import FGVCAircraft 4 | from .utils import dataset_root 5 | 6 | 7 | root = dataset_root 8 | num_example_train_val = 6667 9 | num_example_test = 3333 10 | num_classes = 100 11 | mean_per_class = True 12 | 13 | 14 | def get_loader_train( 15 | transform, batch_size, num_workers, seed 16 | ) -> Tuple[Dataset, DataLoader]: 17 | dataset_train_val = FGVCAircraft(root, download=True, annotation_level='variant', split='trainval',transform=transform) 18 | return (dataset_train_val, None) 19 | 20 | 21 | def get_loader_test( 22 | transform, batch_size, num_workers, seed 23 | ) -> Tuple[Dataset, DataLoader]: 24 | dataset_test = FGVCAircraft(root, download=True, annotation_level='variant', split='test', transform=transform) 25 | return (dataset_test, None) 26 | -------------------------------------------------------------------------------- /src/dataloaders/flowers102.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | from torch.utils.data import DataLoader, Dataset, ConcatDataset 3 | from torchvision.datasets import Flowers102 4 | import torchvision 5 | from .utils import dataset_root 6 | 7 | root = dataset_root 8 | num_example_train_val = 2040 9 | num_example_test = 6149 10 | num_classes = 102 11 | mean_per_class = True 12 | 13 | 14 | class Warper(Dataset): 15 | def __init__(self, dataset) -> None: 16 | super().__init__() 17 | self.dataset = dataset 18 | 19 | def __len__(self) -> int: 20 | return len(self.dataset) 21 | 22 | def __getitem__(self, index): 23 | img, label = self.dataset[index] 24 | if torchvision.__version__ >= "0.13.0": 25 | return img, label 26 | else: 27 | return img, label - 1 28 | 29 | 30 | def get_loader_train( 31 | transform, batch_size, num_workers, seed 32 | ) -> Tuple[Dataset, DataLoader]: 33 | dataset_train = Flowers102(root, download=True, split='train', transform=transform) 34 | dataset_val = Flowers102(root, download=True, split='val', transform=transform) 35 | dataset_train = ConcatDataset([dataset_train, dataset_val]) 36 | dataset_train = Warper(dataset_train) 37 | return (dataset_train, None) 38 | 39 | 40 | def get_loader_test( 41 | transform, batch_size, num_workers, seed 42 | ) -> Tuple[Dataset, DataLoader]: 43 | dataset_test = Flowers102(root, download=True, split='test', transform=transform) 44 | dataset_test = Warper(dataset_test) 45 | return (dataset_test, None) 46 | -------------------------------------------------------------------------------- /src/dataloaders/food101.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | from torch.utils.data import DataLoader, Dataset 3 | from torchvision.datasets import Food101 4 | from .utils import dataset_root 5 | 6 | 7 | root = dataset_root 8 | num_example_train_val = 75750 9 | num_example_test = 25250 10 | num_classes = 101 11 | 12 | 13 | def get_loader_train( 14 | transform, batch_size, num_workers, seed 15 | ) -> Tuple[Dataset, DataLoader]: 16 | dataset_train_val = Food101(root, download=True, split='train', transform=transform) 17 | return (dataset_train_val, None) 18 | 19 | 20 | def get_loader_test( 21 | transform, batch_size, num_workers, seed 22 | ) -> Tuple[Dataset, DataLoader]: 23 | dataset_test = Food101(root, download=True, split='test', transform=transform) 24 | return (dataset_test, None) -------------------------------------------------------------------------------- /src/dataloaders/oxford_pets.py: 
-------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | from torch.utils.data import DataLoader, Dataset 3 | from torchvision.datasets import OxfordIIITPet 4 | from .utils import dataset_root 5 | 6 | 7 | root = dataset_root 8 | num_example_train_val = 3680 9 | num_example_test = 3699 10 | num_classes = 37 11 | mean_per_class = True 12 | 13 | 14 | def get_loader_train( 15 | transform, batch_size, num_workers, seed 16 | ) -> Tuple[Dataset, DataLoader]: 17 | dataset_train_val = OxfordIIITPet(root, download=True, split='trainval',transform=transform) 18 | return (dataset_train_val, None) 19 | 20 | 21 | def get_loader_test( 22 | transform, batch_size, num_workers, seed 23 | ) -> Tuple[Dataset, DataLoader]: 24 | dataset_test = OxfordIIITPet(root, download=True, split='test', transform=transform) 25 | return (dataset_test, None) 26 | -------------------------------------------------------------------------------- /src/dataloaders/stanford_car.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | from torch.utils.data import DataLoader, Dataset 3 | from torchvision.datasets import StanfordCars 4 | from .utils import dataset_root 5 | 6 | 7 | root = dataset_root 8 | num_example_train_val = 8144 9 | num_example_test = 8041 10 | num_classes = 196 11 | 12 | 13 | def get_loader_train( 14 | transform, batch_size, num_workers, seed 15 | ) -> Tuple[Dataset, DataLoader]: 16 | dataset_train = StanfordCars( 17 | root, download=False, split='train', transform=transform) 18 | return (dataset_train, None) 19 | 20 | 21 | def get_loader_test( 22 | transform, batch_size, num_workers, seed 23 | ) -> Tuple[Dataset, DataLoader]: 24 | dataset_test = StanfordCars( 25 | root, download=False, split='test', transform=transform) 26 | return (dataset_test, None) -------------------------------------------------------------------------------- /src/dataloaders/sun397.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | import torch 3 | from torch.utils.data import DataLoader, Dataset, random_split 4 | from torchvision.datasets import SUN397 5 | from .utils import dataset_root 6 | 7 | 8 | num_example_train = 19850 9 | num_example_test = 19850 10 | num_example_others = 69054 11 | num_classes = 397 12 | root = dataset_root 13 | 14 | 15 | def get_loader_train( 16 | transform, batch_size, num_workers, seed 17 | ) -> Tuple[Dataset, DataLoader]: 18 | dataset = SUN397(root, transform=transform) 19 | dataset_train, dataset_test, others = random_split( 20 | dataset, 21 | lengths=[num_example_train, 22 | num_example_test, 23 | num_example_others,], 24 | generator=torch.Generator().manual_seed(seed + hash("sun397") % 2048)) 25 | return (dataset_train, None) 26 | 27 | 28 | def get_loader_test( 29 | transform, batch_size, num_workers, seed 30 | ) -> Tuple[Dataset, DataLoader]: 31 | dataset = SUN397(root, download=True, transform=transform) 32 | dataset_train, dataset_test, others = random_split( 33 | dataset, 34 | lengths=[num_example_train, 35 | num_example_test, 36 | num_example_others,], 37 | generator=torch.Generator().manual_seed(seed + hash("sun397") % 2048)) 38 | return (dataset_test, None) -------------------------------------------------------------------------------- /src/dataloaders/templates.json: -------------------------------------------------------------------------------- 1 | { 2 | "food101": [ 3 | "a photo of {}, a type of food." 
4 | ], 5 | "cifar10": [ 6 | "a photo of a {}.", 7 | "a blurry photo of a {}.", 8 | "a black and white photo of a {}.", 9 | "a low contrast photo of a {}.", 10 | "a high contrast photo of a {}.", 11 | "a bad photo of a {}.", 12 | "a good photo of a {}.", 13 | "a photo of a small {}.", 14 | "a photo of a big {}.", 15 | "a photo of the {}.", 16 | "a blurry photo of the {}.", 17 | "a black and white photo of the {}.", 18 | "a low contrast photo of the {}.", 19 | "a high contrast photo of the {}.", 20 | "a bad photo of the {}.", 21 | "a good photo of the {}.", 22 | "a photo of the small {}.", 23 | "a photo of the big {}." 24 | ], 25 | "cifar100": [ 26 | "a photo of a {}.", 27 | "a blurry photo of a {}.", 28 | "a black and white photo of a {}.", 29 | "a low contrast photo of a {}.", 30 | "a high contrast photo of a {}.", 31 | "a bad photo of a {}.", 32 | "a good photo of a {}.", 33 | "a photo of a small {}.", 34 | "a photo of a big {}.", 35 | "a photo of the {}.", 36 | "a blurry photo of the {}.", 37 | "a black and white photo of the {}.", 38 | "a low contrast photo of the {}.", 39 | "a high contrast photo of the {}.", 40 | "a bad photo of the {}.", 41 | "a good photo of the {}.", 42 | "a photo of the small {}.", 43 | "a photo of the big {}." 44 | ], 45 | "birdsnap": [ 46 | "a photo of a {}, a type of bird." 47 | ], 48 | "cub": [ 49 | "a photo of a {}, a type of bird." 50 | ], 51 | "imagenet": [ 52 | "a bad photo of a {}.", 53 | "a photo of many {}.", 54 | "a sculpture of a {}.", 55 | "a photo of the hard to see {}.", 56 | "a low resolution photo of the {}.", 57 | "a rendering of a {}.", 58 | "graffiti of a {}.", 59 | "a bad photo of the {}.", 60 | "a cropped photo of the {}.", 61 | "a tattoo of a {}.", 62 | "the embroidered {}.", 63 | "a photo of a hard to see {}.", 64 | "a bright photo of a {}.", 65 | "a photo of a clean {}.", 66 | "a photo of a dirty {}.", 67 | "a dark photo of the {}.", 68 | "a drawing of a {}.", 69 | "a photo of my {}.", 70 | "the plastic {}.", 71 | "a photo of the cool {}.", 72 | "a close-up photo of a {}.", 73 | "a black and white photo of the {}.", 74 | "a painting of the {}.", 75 | "a painting of a {}.", 76 | "a pixelated photo of the {}.", 77 | "a sculpture of the {}.", 78 | "a bright photo of the {}.", 79 | "a cropped photo of a {}.", 80 | "a plastic {}.", 81 | "a photo of the dirty {}.", 82 | "a jpeg corrupted photo of a {}.", 83 | "a blurry photo of the {}.", 84 | "a photo of the {}.", 85 | "a good photo of the {}.", 86 | "a rendering of the {}.", 87 | "a {} in a video game.", 88 | "a photo of one {}.", 89 | "a doodle of a {}.", 90 | "a close-up photo of the {}.", 91 | "a photo of a {}.", 92 | "the origami {}.", 93 | "the {} in a video game.", 94 | "a sketch of a {}.", 95 | "a doodle of the {}.", 96 | "a origami {}.", 97 | "a low resolution photo of a {}.", 98 | "the toy {}.", 99 | "a rendition of the {}.", 100 | "a photo of the clean {}.", 101 | "a photo of a large {}.", 102 | "a rendition of a {}.", 103 | "a photo of a nice {}.", 104 | "a photo of a weird {}.", 105 | "a blurry photo of a {}.", 106 | "a cartoon {}.", 107 | "art of a {}.", 108 | "a sketch of the {}.", 109 | "a embroidered {}.", 110 | "a pixelated photo of a {}.", 111 | "itap of the {}.", 112 | "a jpeg corrupted photo of the {}.", 113 | "a good photo of a {}.", 114 | "a plushie {}.", 115 | "a photo of the nice {}.", 116 | "a photo of the small {}.", 117 | "a photo of the weird {}.", 118 | "the cartoon {}.", 119 | "art of the {}.", 120 | "a drawing of the {}.", 121 | "a photo of the large {}.", 122 | 
"a black and white photo of a {}.", 123 | "the plushie {}.", 124 | "a dark photo of a {}.", 125 | "itap of a {}.", 126 | "graffiti of the {}.", 127 | "a toy {}.", 128 | "itap of my {}.", 129 | "a photo of a cool {}.", 130 | "a photo of a small {}.", 131 | "a tattoo of the {}." 132 | ], 133 | "imagenet_a": [ 134 | "a bad photo of a {}.", 135 | "a photo of many {}.", 136 | "a sculpture of a {}.", 137 | "a photo of the hard to see {}.", 138 | "a low resolution photo of the {}.", 139 | "a rendering of a {}.", 140 | "graffiti of a {}.", 141 | "a bad photo of the {}.", 142 | "a cropped photo of the {}.", 143 | "a tattoo of a {}.", 144 | "the embroidered {}.", 145 | "a photo of a hard to see {}.", 146 | "a bright photo of a {}.", 147 | "a photo of a clean {}.", 148 | "a photo of a dirty {}.", 149 | "a dark photo of the {}.", 150 | "a drawing of a {}.", 151 | "a photo of my {}.", 152 | "the plastic {}.", 153 | "a photo of the cool {}.", 154 | "a close-up photo of a {}.", 155 | "a black and white photo of the {}.", 156 | "a painting of the {}.", 157 | "a painting of a {}.", 158 | "a pixelated photo of the {}.", 159 | "a sculpture of the {}.", 160 | "a bright photo of the {}.", 161 | "a cropped photo of a {}.", 162 | "a plastic {}.", 163 | "a photo of the dirty {}.", 164 | "a jpeg corrupted photo of a {}.", 165 | "a blurry photo of the {}.", 166 | "a photo of the {}.", 167 | "a good photo of the {}.", 168 | "a rendering of the {}.", 169 | "a {} in a video game.", 170 | "a photo of one {}.", 171 | "a doodle of a {}.", 172 | "a close-up photo of the {}.", 173 | "a photo of a {}.", 174 | "the origami {}.", 175 | "the {} in a video game.", 176 | "a sketch of a {}.", 177 | "a doodle of the {}.", 178 | "a origami {}.", 179 | "a low resolution photo of a {}.", 180 | "the toy {}.", 181 | "a rendition of the {}.", 182 | "a photo of the clean {}.", 183 | "a photo of a large {}.", 184 | "a rendition of a {}.", 185 | "a photo of a nice {}.", 186 | "a photo of a weird {}.", 187 | "a blurry photo of a {}.", 188 | "a cartoon {}.", 189 | "art of a {}.", 190 | "a sketch of the {}.", 191 | "a embroidered {}.", 192 | "a pixelated photo of a {}.", 193 | "itap of the {}.", 194 | "a jpeg corrupted photo of the {}.", 195 | "a good photo of a {}.", 196 | "a plushie {}.", 197 | "a photo of the nice {}.", 198 | "a photo of the small {}.", 199 | "a photo of the weird {}.", 200 | "the cartoon {}.", 201 | "art of the {}.", 202 | "a drawing of the {}.", 203 | "a photo of the large {}.", 204 | "a black and white photo of a {}.", 205 | "the plushie {}.", 206 | "a dark photo of a {}.", 207 | "itap of a {}.", 208 | "graffiti of the {}.", 209 | "a toy {}.", 210 | "itap of my {}.", 211 | "a photo of a cool {}.", 212 | "a photo of a small {}.", 213 | "a tattoo of the {}." 214 | ], 215 | "sst2": [ 216 | "a {} review of a movie." 217 | ], 218 | "hateful_memes": [ 219 | "a {}." 220 | ], 221 | "clevr": [ 222 | "a photo of {} objects." 
223 | ], 224 | "kinetics700": [ 225 | "a photo of {}.", 226 | "a photo of a person {}.", 227 | "a photo of a person using {}.", 228 | "a photo of a person doing {}.", 229 | "a photo of a person during {}.", 230 | "a photo of a person performing {}.", 231 | "a photo of a person practicing {}.", 232 | "a video of {}.", 233 | "a video of a person {}.", 234 | "a video of a person using {}.", 235 | "a video of a person doing {}.", 236 | "a video of a person during {}.", 237 | "a video of a person performing {}.", 238 | "a video of a person practicing {}.", 239 | "a example of {}.", 240 | "a example of a person {}.", 241 | "a example of a person using {}.", 242 | "a example of a person doing {}.", 243 | "a example of a person during {}.", 244 | "a example of a person performing {}.", 245 | "a example of a person practicing {}.", 246 | "a demonstration of {}.", 247 | "a demonstration of a person {}.", 248 | "a demonstration of a person using {}.", 249 | "a demonstration of a person doing {}.", 250 | "a demonstration of a person during {}.", 251 | "a demonstration of a person performing {}.", 252 | "a demonstration of a person practicing {}." 253 | ], 254 | "ucf101": [ 255 | "a photo of a person {}.", 256 | "a video of a person {}.", 257 | "a example of a person {}.", 258 | "a demonstration of a person {}.", 259 | "a photo of the person {}.", 260 | "a video of the person {}.", 261 | "a example of the person {}.", 262 | "a demonstration of the person {}.", 263 | "a photo of a person using {}.", 264 | "a video of a person using {}.", 265 | "a example of a person using {}.", 266 | "a demonstration of a person using {}.", 267 | "a photo of the person using {}.", 268 | "a video of the person using {}.", 269 | "a example of the person using {}.", 270 | "a demonstration of the person using {}.", 271 | "a photo of a person doing {}.", 272 | "a video of a person doing {}.", 273 | "a example of a person doing {}.", 274 | "a demonstration of a person doing {}.", 275 | "a photo of the person doing {}.", 276 | "a video of the person doing {}.", 277 | "a example of the person doing {}.", 278 | "a demonstration of the person doing {}.", 279 | "a photo of a person during {}.", 280 | "a video of a person during {}.", 281 | "a example of a person during {}.", 282 | "a demonstration of a person during {}.", 283 | "a photo of the person during {}.", 284 | "a video of the person during {}.", 285 | "a example of the person during {}.", 286 | "a demonstration of the person during {}.", 287 | "a photo of a person performing {}.", 288 | "a video of a person performing {}.", 289 | "a example of a person performing {}.", 290 | "a demonstration of a person performing {}.", 291 | "a photo of the person performing {}.", 292 | "a video of the person performing {}.", 293 | "a example of the person performing {}.", 294 | "a demonstration of the person performing {}.", 295 | "a photo of a person practicing {}.", 296 | "a video of a person practicing {}.", 297 | "a example of a person practicing {}.", 298 | "a demonstration of a person practicing {}.", 299 | "a photo of the person practicing {}.", 300 | "a video of the person practicing {}.", 301 | "a example of the person practicing {}.", 302 | "a demonstration of the person practicing {}." 303 | ], 304 | "pcam": [ 305 | "this is a photo of {}" 306 | ], 307 | "country211": [ 308 | "a photo i took in {}.", 309 | "a photo i took while visiting {}.", 310 | "a photo from my home country of {}.", 311 | "a photo from my visit to {}.", 312 | "a photo showing the country of {}." 
313 | ], 314 | "kitti": [ 315 | "{}" 316 | ], 317 | "gtsrb": [ 318 | "a zoomed in photo of a \"{}\" traffic sign.", 319 | "a centered photo of a \"{}\" traffic sign.", 320 | "a close up photo of a \"{}\" traffic sign." 321 | ], 322 | "resisc45": [ 323 | "satellite imagery of {}.", 324 | "aerial imagery of {}.", 325 | "satellite photo of {}.", 326 | "aerial photo of {}.", 327 | "satellite view of {}.", 328 | "aerial view of {}.", 329 | "satellite imagery of a {}.", 330 | "aerial imagery of a {}.", 331 | "satellite photo of a {}.", 332 | "aerial photo of a {}.", 333 | "satellite view of a {}.", 334 | "aerial view of a {}.", 335 | "satellite imagery of the {}.", 336 | "aerial imagery of the {}.", 337 | "satellite photo of the {}.", 338 | "aerial photo of the {}.", 339 | "satellite view of the {}.", 340 | "aerial view of the {}." 341 | ], 342 | "eurosat": [ 343 | "a centered satellite photo of {}.", 344 | "a centered satellite photo of a {}.", 345 | "a centered satellite photo of the {}." 346 | ], 347 | "stl10": [ 348 | "a photo of a {}.", 349 | "a photo of the {}." 350 | ], 351 | "fer2013": [ 352 | "a photo of a {} looking face.", 353 | "a photo of a face showing the emotion: {}.", 354 | "a photo of a face looking {}.", 355 | "a face that looks {}.", 356 | "they look {}.", 357 | "look at how {} they are." 358 | ], 359 | "mnist": [ 360 | "a photo of the number: \"{}\"." 361 | ], 362 | "flowers": [ 363 | "a photo of a {}, a type of flower." 364 | ], 365 | "caltech101": [ 366 | "a photo of a {}.", 367 | "a painting of a {}.", 368 | "a plastic {}.", 369 | "a sculpture of a {}.", 370 | "a sketch of a {}.", 371 | "a tattoo of a {}.", 372 | "a toy {}.", 373 | "a rendition of a {}.", 374 | "a embroidered {}.", 375 | "a cartoon {}.", 376 | "a {} in a video game.", 377 | "a plushie {}.", 378 | "a origami {}.", 379 | "art of a {}.", 380 | "graffiti of a {}.", 381 | "a drawing of a {}.", 382 | "a doodle of a {}.", 383 | "a photo of the {}.", 384 | "a painting of the {}.", 385 | "the plastic {}.", 386 | "a sculpture of the {}.", 387 | "a sketch of the {}.", 388 | "a tattoo of the {}.", 389 | "the toy {}.", 390 | "a rendition of the {}.", 391 | "the embroidered {}.", 392 | "the cartoon {}.", 393 | "the {} in a video game.", 394 | "the plushie {}.", 395 | "the origami {}.", 396 | "art of the {}.", 397 | "graffiti of the {}.", 398 | "a drawing of the {}.", 399 | "a doodle of the {}." 400 | ], 401 | "pets": [ 402 | "a photo of a {}, a type of pet." 403 | ], 404 | "dtd": [ 405 | "a photo of a {} texture.", 406 | "a photo of a {} pattern.", 407 | "a photo of a {} thing.", 408 | "a photo of a {} object.", 409 | "a photo of the {} texture.", 410 | "a photo of the {} pattern.", 411 | "a photo of the {} thing.", 412 | "a photo of the {} object." 413 | ], 414 | "voc2007": [ 415 | "a photo of a {}." 416 | ], 417 | "aircraft": [ 418 | "a photo of a {}, a type of aircraft.", 419 | "a photo of the {}, a type of aircraft." 420 | ], 421 | "stanford_car": [ 422 | "a photo of a {}.", 423 | "a photo of the {}.", 424 | "a photo of my {}.", 425 | "i love my {}!", 426 | "a photo of my dirty {}.", 427 | "a photo of my clean {}.", 428 | "a photo of my new {}.", 429 | "a photo of my old {}." 430 | ], 431 | "sun397": [ 432 | "a photo of a {}.", 433 | "a photo of the {}." 
434 | ] 435 | } -------------------------------------------------------------------------------- /src/dataloaders/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import numpy as np 4 | import torch 5 | from torch import distributed as dist 6 | from torch.utils.data import DistributedSampler as _DistributedSampler 7 | import os 8 | try: 9 | rank = int(os.environ["RANK"]) 10 | world_size = int(os.environ["WORLD_SIZE"]) 11 | except KeyError: 12 | rank = 0 13 | world_size = 1 14 | 15 | # torchvision.datasets will automatically download the datasets 16 | # Set your directory to classification dataset. 17 | dataset_root = "/your/directory/to/classification/dataset" 18 | 19 | def worker_init_fn(worker_id, num_workers, rank, seed): 20 | # The seed of each worker equals to 21 | # num_worker * rank + worker_id + user_seed 22 | worker_seed = num_workers * rank + worker_id + seed 23 | np.random.seed(worker_seed) 24 | random.seed(worker_seed) 25 | torch.manual_seed(worker_seed) 26 | 27 | 28 | def get_dist_info(): 29 | if dist.is_available() and dist.is_initialized(): 30 | rank = dist.get_rank() 31 | world_size = dist.get_world_size() 32 | else: 33 | rank = 0 34 | world_size = 1 35 | 36 | return rank, world_size 37 | 38 | 39 | def sync_random_seed(seed=None, device="cuda"): 40 | """Make sure different ranks share the same seed. 41 | All workers must call this function, otherwise it will deadlock. 42 | This method is generally used in `DistributedSampler`, 43 | because the seed should be identical across all processes 44 | in the distributed group. 45 | In distributed sampling, different ranks should sample non-overlapped 46 | data in the dataset. Therefore, this function is used to make sure that 47 | each rank shuffles the data indices in the same order based 48 | on the same seed. Then different ranks could use different indices 49 | to select non-overlapped data from the same data list. 50 | Args: 51 | seed (int, Optional): The seed. Default to None. 52 | device (str): The device where the seed will be put on. 53 | Default to 'cuda'. 54 | Returns: 55 | int: Seed to be used. 56 | """ 57 | if seed is None: 58 | seed = np.random.randint(2**31) 59 | assert isinstance(seed, int) 60 | 61 | rank, world_size = get_dist_info() 62 | 63 | if world_size == 1: 64 | return seed 65 | 66 | if rank == 0: 67 | random_num = torch.tensor(seed, dtype=torch.int32, device=device) 68 | else: 69 | random_num = torch.tensor(0, dtype=torch.int32, device=device) 70 | 71 | dist.broadcast(random_num, src=0) 72 | 73 | return random_num.item() 74 | 75 | 76 | class DistributedSampler(_DistributedSampler): 77 | def __init__( 78 | self, 79 | dataset, 80 | num_replicas=None, # world_size 81 | rank=None, # local_rank 82 | shuffle=True, 83 | seed=0, 84 | ): 85 | 86 | super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle) 87 | 88 | # In distributed sampling, different ranks should sample 89 | # non-overlapped data in the dataset. Therefore, this function 90 | # is used to make sure that each rank shuffles the data indices 91 | # in the same order based on the same seed. Then different ranks 92 | # could use different indices to select non-overlapped data from the 93 | # same data list. 
94 | self.seed = sync_random_seed(seed) 95 | 96 | def __iter__(self): 97 | # deterministically shuffle based on epoch 98 | if self.shuffle: 99 | g = torch.Generator() 100 | # When :attr:`shuffle=True`, this ensures all replicas 101 | # use a different random ordering for each epoch. 102 | # Otherwise, the next iteration of this sampler will 103 | # yield the same ordering. 104 | g.manual_seed(self.epoch + self.seed) 105 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 106 | else: 107 | indices = torch.arange(len(self.dataset)).tolist() 108 | 109 | # add extra samples to make it evenly divisible 110 | # in case that indices is shorter than half of total_size 111 | indices = (indices * math.ceil(self.total_size / len(indices)))[ 112 | : self.total_size 113 | ] 114 | assert len(indices) == self.total_size 115 | 116 | # subsample 117 | indices = indices[self.rank : self.total_size : self.num_replicas] 118 | assert len(indices) == self.num_samples 119 | 120 | return iter(indices) -------------------------------------------------------------------------------- /src/inference_classification.sh: -------------------------------------------------------------------------------- 1 | # COSMOS models 2 | # --model ViT-B-16 3 | # --huggingface-model-name [cosmos_vitb16_cc3m, cosmos_vitb16_cc12m, cosmos_vitb16_yfcc15m, cosmos_vitb16_merged30m, cosmos_vitb16_pixelprose] 4 | # --model ViT-B-32 5 | # --huggingface-model-name [cosmos_vitb32_cc3m, cosmos_vitb32_cc12m, cosmos_vitb32_yfcc15m, cosmos_vitb32_merged30m, cosmos_vitb32_pixelprose] 6 | torchrun --nproc_per_node 1 -m main \ 7 | --model ViT-B-16 \ 8 | --huggingface-repo-name sankim2/cosmos \ 9 | --huggingface-model-name cosmos_vitb16_merged30m.pt \ 10 | --val-data classification \ 11 | --imagenet-val /directory/to/your/imagenet/data/val_images \ 12 | --batch-size 256 \ 13 | --workers 16 \ 14 | --output-all \ 15 | --attentional-pool \ 16 | --cosmos \ 17 | 18 | # OpenCLIP models 19 | # --model ViT-B-16 --pretrained [laion400m_e32, datacomp_xl_s13b_b90k, laion2b_s34b_b88k] 20 | # --model ViT-B-32 --pretrained [laion400m_e32, datacomp_xl_s13b_b90k, laion2b_s34b_b79k] 21 | torchrun --nproc_per_node 1 -m main \ 22 | --model ViT-B-16 \ 23 | --pretrained laion400m_e32 \ 24 | --val-data classification \ 25 | --imagenet-val /directory/to/your/imagenet/data/val_images \ 26 | --batch-size 256 \ 27 | --workers 16 \ -------------------------------------------------------------------------------- /src/inference_retrieval.sh: -------------------------------------------------------------------------------- 1 | # COSMOS models 2 | # --model ViT-B-16 3 | # --huggingface-model-name [cosmos_vitb16_cc3m, cosmos_vitb16_cc12m, cosmos_vitb16_yfcc15m, cosmos_vitb16_merged30m, cosmos_vitb16_pixelprose] 4 | # --model ViT-B-32 5 | # --huggingface-model-name [cosmos_vitb32_cc3m, cosmos_vitb32_cc12m, cosmos_vitb32_yfcc15m, cosmos_vitb32_merged30m, cosmos_vitb32_pixelprose] 6 | torchrun --nproc_per_node 1 -m main \ 7 | --model ViT-B-16 \ 8 | --huggingface-repo-name sankim2/cosmos \ 9 | --huggingface-model-name cosmos_vitb16_merged30m.pt \ 10 | --val-data retrieval \ 11 | --data-root-dir /directory/to/your/coco/and/flickr30k/ \ 12 | --batch-size 256 \ 13 | --workers 16 \ 14 | --output-all \ 15 | --attentional-pool \ 16 | --cosmos \ 17 | 18 | # OpenCLIP models 19 | # --model ViT-B-16 --pretrained [laion400m_e32, datacomp_xl_s13b_b90k, laion2b_s34b_b88k] 20 | # --model ViT-B-32 --pretrained [laion400m_e32, datacomp_xl_s13b_b90k, laion2b_s34b_b79k] 21 | torchrun 
--nproc_per_node 1 -m main \ 22 | --model ViT-B-16 \ 23 | --pretrained laion400m_e32 \ 24 | --val-data retrieval \ 25 | --data-root-dir /directory/to/your/coco/and/flickr30k/ \ 26 | --batch-size 256 \ 27 | --workers 16 \ -------------------------------------------------------------------------------- /src/inference_segmentation.sh: -------------------------------------------------------------------------------- 1 | # COSMOS models 2 | # --model ViT-B-16 3 | # --huggingface-model-name [cosmos_vitb16_cc3m, cosmos_vitb16_cc12m, cosmos_vitb16_yfcc15m, cosmos_vitb16_merged30m, cosmos_vitb16_pixelprose] 4 | # --model ViT-B-32 5 | # --huggingface-model-name [cosmos_vitb32_cc3m, cosmos_vitb32_cc12m, cosmos_vitb32_yfcc15m, cosmos_vitb32_merged30m, cosmos_vitb32_pixelprose] 6 | # --seg-w-background : segmentation with background benchmarks 7 | # --use-csa : using Correlative Self-Attention (CSA) block from SCLIP 8 | torchrun --nproc_per_node 1 -m seg_eval.py \ 9 | --model ViT-B-16 \ 10 | --huggingface-repo-name sankim2/cosmos \ 11 | --huggingface-model-name cosmos_vitb16_merged30m.pt \ 12 | --batch-size 256 \ 13 | --workers 16 \ 14 | --output-all \ 15 | --attentional-pool \ 16 | --cosmos \ 17 | --seg-w-background \ 18 | --use-csa 19 | 20 | # OpenCLIP models 21 | # --model ViT-B-16 --pretrained [laion400m_e32, datacomp_xl_s13b_b90k, laion2b_s34b_b88k] 22 | # --model ViT-B-32 --pretrained [laion400m_e32, datacomp_xl_s13b_b90k, laion2b_s34b_b79k] 23 | torchrun --nproc_per_node 1 -m seg_eval.py \ 24 | --model ViT-B-16 \ 25 | --pretrained laion400m_e32 \ 26 | --batch-size 256 \ 27 | --workers 16 \ -------------------------------------------------------------------------------- /src/open_clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .coca_model import CoCa 2 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 3 | from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss 4 | from .factory import list_models, add_model_config, get_model_config, load_checkpoint 5 | from .loss import ClipLoss, DistillClipLoss, CoCaLoss 6 | from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \ 7 | convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype, get_input_dtype, \ 8 | get_model_tokenize_cfg, get_model_preprocess_cfg, set_model_preprocess_cfg 9 | from .openai import load_openai_model, list_openai_models 10 | from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \ 11 | get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained 12 | from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub 13 | from .tokenizer import SimpleTokenizer, tokenize, decode 14 | from .transform import image_transform, AugmentationCfg 15 | from .zero_shot_classifier import build_zero_shot_classifier, build_zero_shot_classifier_legacy 16 | from .zero_shot_metadata import OPENAI_IMAGENET_TEMPLATES, SIMPLE_IMAGENET_TEMPLATES, IMAGENET_CLASSNAMES 17 | -------------------------------------------------------------------------------- /src/open_clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ExplainableML/cosmos/c80348f08e9b02c5a81adadd7dce486b3c284b78/src/open_clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- 
/src/open_clip/constants.py: -------------------------------------------------------------------------------- 1 | OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) 2 | OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) 3 | IMAGENET_MEAN = (0.485, 0.456, 0.406) 4 | IMAGENET_STD = (0.229, 0.224, 0.225) 5 | INCEPTION_MEAN = (0.5, 0.5, 0.5) 6 | INCEPTION_STD = (0.5, 0.5, 0.5) 7 | -------------------------------------------------------------------------------- /src/open_clip/convert.py: -------------------------------------------------------------------------------- 1 | """ Conversion functions for 3rd part state-dicts and non-torch native checkpoint formats. 2 | """ 3 | from typing import Union 4 | 5 | import torch 6 | import numpy as np 7 | 8 | from .model import CLIP, CustomTextCLIP 9 | from .transformer import TextTransformer, Transformer 10 | 11 | 12 | @torch.no_grad() 13 | def load_big_vision_weights(model: CustomTextCLIP, checkpoint_path: str): 14 | """ Load weights from .npz checkpoints for official Google big_vision image-text models 15 | 16 | Currently the SigLIP source models are supported and a CustomTextCLIP destination model 17 | w/ timm image encoder. 18 | """ 19 | from timm.layers import resample_patch_embed, resample_abs_pos_embed 20 | 21 | def _n2p(w, t=True): 22 | if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: 23 | w = w.flatten() 24 | if t: 25 | if w.ndim == 4: 26 | w = w.transpose([3, 2, 0, 1]) 27 | elif w.ndim == 3: 28 | w = w.transpose([2, 0, 1]) 29 | elif w.ndim == 2: 30 | w = w.transpose([1, 0]) 31 | return torch.from_numpy(w) 32 | 33 | w = np.load(checkpoint_path) 34 | interpolation = 'bilinear' 35 | antialias = False 36 | 37 | def _convert_timm_img(module, prefix): 38 | embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) 39 | if embed_conv_w.shape[-2:] != module.patch_embed.proj.weight.shape[-2:]: 40 | embed_conv_w = resample_patch_embed( 41 | embed_conv_w, 42 | module.patch_embed.proj.weight.shape[-2:], 43 | interpolation=interpolation, 44 | antialias=antialias, 45 | verbose=True, 46 | ) 47 | module.patch_embed.proj.weight.copy_(embed_conv_w) 48 | module.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) 49 | 50 | if module.cls_token is not None: 51 | module.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) 52 | 53 | pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False) 54 | if pos_embed_w.shape != module.pos_embed.shape: 55 | assert False, f'{pos_embed_w.shape}, {module.pos_embed.shape}' 56 | num_prefix_tokens = 0 if getattr(module, 'no_embed_class', False) else getattr(module, 'num_prefix_tokens', 1) 57 | pos_embed_w = resample_abs_pos_embed( # resize pos embedding when different size from pretrained weights 58 | pos_embed_w, 59 | new_size=module.patch_embed.grid_size, 60 | num_prefix_tokens=num_prefix_tokens, 61 | interpolation=interpolation, 62 | antialias=antialias, 63 | verbose=True, 64 | ) 65 | module.pos_embed.copy_(pos_embed_w) 66 | 67 | mha_sub, b_sub, ln1_sub = (0, 0, 1) 68 | for i, block in enumerate(module.blocks.children()): 69 | block_prefix = f'{prefix}Transformer/encoderblock_{i}/' 70 | mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/' 71 | block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) 72 | block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) 73 | block.attn.qkv.weight.copy_(torch.cat([ 74 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) 75 | block.attn.qkv.bias.copy_(torch.cat([ 76 | 
_n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) 77 | block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) 78 | block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) 79 | for r in range(2): 80 | getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel'])) 81 | getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias'])) 82 | block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale'])) 83 | block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias'])) 84 | 85 | module.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) 86 | module.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) 87 | 88 | if module.attn_pool is not None: 89 | block_prefix = f'{prefix}MAPHead_0/' 90 | mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/' 91 | module.attn_pool.latent.copy_(_n2p(w[f'{block_prefix}probe'], t=False)) 92 | module.attn_pool.q.weight.copy_(_n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T) 93 | module.attn_pool.q.bias.copy_(_n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1)) 94 | module.attn_pool.kv.weight.copy_(torch.cat([ 95 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('key', 'value')])) 96 | module.attn_pool.kv.bias.copy_(torch.cat([ 97 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('key', 'value')])) 98 | module.attn_pool.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) 99 | module.attn_pool.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) 100 | module.attn_pool.norm.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) 101 | module.attn_pool.norm.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) 102 | for r in range(2): 103 | getattr(module.attn_pool.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/kernel'])) 104 | getattr(module.attn_pool.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/bias'])) 105 | 106 | def _convert_openclip_transformer(module: Transformer, prefix): 107 | for i, block in enumerate(module.resblocks.children()): 108 | block_prefix = f'{prefix}encoderblock_{i}/' 109 | mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/' 110 | block.ln_1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) 111 | block.ln_1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) 112 | block.attn.in_proj_weight.copy_(torch.cat([ 113 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) 114 | block.attn.in_proj_bias.copy_(torch.cat([ 115 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) 116 | block.attn.out_proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) 117 | block.attn.out_proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) 118 | block.ln_2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/scale'])) 119 | block.ln_2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/bias'])) 120 | block.mlp.c_fc.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/kernel'])) 121 | block.mlp.c_fc.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/bias'])) 122 | block.mlp.c_proj.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/kernel'])) 123 | block.mlp.c_proj.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/bias'])) 124 | 125 | def _convert_openclip_txt(module: TextTransformer, prefix): 126 | 
module.token_embedding.weight.copy_(_n2p(w[f'{prefix}Embed_0/embedding'], t=False)) 127 | pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False).squeeze(0) 128 | module.positional_embedding.copy_(pos_embed_w) 129 | _convert_openclip_transformer(module.transformer, prefix=prefix + 'Encoder_0/') 130 | module.ln_final.weight.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/scale'])) 131 | module.ln_final.bias.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/bias'])) 132 | module.text_projection.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) 133 | module.text_projection.bias.copy_(_n2p(w[f'{prefix}head/bias'])) 134 | 135 | _convert_timm_img(model.visual.trunk, 'params/img/') 136 | _convert_openclip_txt(model.text, 'params/txt/') 137 | model.logit_bias.copy_(_n2p(w['params/b'])[0]) 138 | model.logit_scale.copy_(_n2p(w['params/t'])[0]) 139 | 140 | 141 | @torch.no_grad() 142 | def convert_mobile_clip_state_dict(model: CustomTextCLIP, state_dict, fastvit = True): 143 | 144 | def _convert_timm_img(state_dict): 145 | if fastvit: 146 | from timm.models.fastvit import checkpoint_filter_fn 147 | else: 148 | from timm.models.vision_transformer_hybrid import checkpoint_filter_fn 149 | timm_state_dict = checkpoint_filter_fn(state_dict, model.visual.trunk) 150 | timm_state_dict = {'visual.trunk.' + k: v for k, v in timm_state_dict.items()} 151 | return timm_state_dict 152 | 153 | def _convert_openclip_txt(state_dict, prefix='text_encoder.'): 154 | text_dict = {} 155 | for k, v in state_dict.items(): 156 | if not k.startswith(prefix): 157 | continue 158 | k = k.replace(prefix, '') 159 | k = k.replace('projection_layer', 'text_projection') 160 | k = k.replace('embedding_layer', 'token_embedding') 161 | if k.startswith('positional_embedding.pos_embed.pos_embed'): 162 | k = k.replace('positional_embedding.pos_embed.pos_embed', 'positional_embedding') 163 | v = v.squeeze() 164 | k = k.replace('final_layer_norm', 'ln_final') 165 | k = k.replace('pre_norm_mha.0', 'ln_1') 166 | k = k.replace('pre_norm_mha.1', 'attn') 167 | k = k.replace('pre_norm_ffn.0', 'ln_2') 168 | k = k.replace('pre_norm_ffn.1', 'mlp.c_fc') 169 | k = k.replace('pre_norm_ffn.4', 'mlp.c_proj') 170 | k = k.replace('qkv_proj.weight', 'in_proj_weight') 171 | k = k.replace('qkv_proj.bias', 'in_proj_bias') 172 | k = k.replace('transformer.', 'transformer.resblocks.') 173 | text_dict['text.' 
+ k] = v 174 | return text_dict 175 | 176 | image_dict = _convert_timm_img(state_dict) 177 | text_dict = _convert_openclip_txt(state_dict) 178 | out_dict = {**image_dict, **text_dict} 179 | out_dict['logit_scale'] = state_dict['logit_scale'] 180 | return out_dict 181 | 182 | 183 | def convert_state_dict(model: Union[CustomTextCLIP, CLIP], state_dict): 184 | if 'image_encoder.model.patch_embed.0.rbr_conv.0.conv.weight' in state_dict: 185 | # Apple MobileCLIP s1 & s2 state_dicts (s0 and b not currently supported) 186 | state_dict = convert_mobile_clip_state_dict(model, state_dict) 187 | if 'image_encoder.model.patch_emb.0.block.conv.weight' in state_dict: 188 | # convert b model 189 | state_dict = convert_mobile_clip_state_dict(model, state_dict, fastvit=False) 190 | return state_dict 191 | -------------------------------------------------------------------------------- /src/open_clip/hf_configs.py: -------------------------------------------------------------------------------- 1 | # HF architecture dict: 2 | arch_dict = { 3 | # https://huggingface.co/docs/transformers/model_doc/roberta#roberta 4 | "roberta": { 5 | "config_names": { 6 | "context_length": "max_position_embeddings", 7 | "vocab_size": "vocab_size", 8 | "width": "hidden_size", 9 | "heads": "num_attention_heads", 10 | "layers": "num_hidden_layers", 11 | "layer_attr": "layer", 12 | "token_embeddings_attr": "embeddings" 13 | }, 14 | "pooler": "mean_pooler", 15 | }, 16 | # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig 17 | "xlm-roberta": { 18 | "config_names": { 19 | "context_length": "max_position_embeddings", 20 | "vocab_size": "vocab_size", 21 | "width": "hidden_size", 22 | "heads": "num_attention_heads", 23 | "layers": "num_hidden_layers", 24 | "layer_attr": "layer", 25 | "token_embeddings_attr": "embeddings" 26 | }, 27 | "pooler": "mean_pooler", 28 | }, 29 | # https://huggingface.co/docs/transformers/model_doc/mt5#mt5 30 | "mt5": { 31 | "config_names": { 32 | # unlimited seqlen 33 | # https://github.com/google-research/text-to-text-transfer-transformer/issues/273 34 | # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374 35 | "context_length": "", 36 | "vocab_size": "vocab_size", 37 | "width": "d_model", 38 | "heads": "num_heads", 39 | "layers": "num_layers", 40 | "layer_attr": "block", 41 | "token_embeddings_attr": "embed_tokens" 42 | }, 43 | "pooler": "mean_pooler", 44 | }, 45 | # https://huggingface.co/docs/transformers/model_doc/bert 46 | "bert": { 47 | "config_names": { 48 | "context_length": "max_position_embeddings", 49 | "vocab_size": "vocab_size", 50 | "width": "hidden_size", 51 | "heads": "num_attention_heads", 52 | "layers": "num_hidden_layers", 53 | }, 54 | "pooler": "cls_pooler", 55 | }, 56 | # https://huggingface.co/docs/transformers/model_doc/m2m_100 57 | "m2m_100": { 58 | "config_names": { 59 | "context_length": "max_position_embeddings", 60 | "vocab_size": "vocab_size", 61 | "width": "d_model", 62 | "heads": "encoder_attention_heads", 63 | "layers": "encoder_layers", 64 | }, 65 | "pooler": "cls_pooler", 66 | }, 67 | } 68 | -------------------------------------------------------------------------------- /src/open_clip/hf_model.py: -------------------------------------------------------------------------------- 1 | """ huggingface model adapter 2 | 3 | Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model. 
4 | """ 5 | import re 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch import TensorType 10 | 11 | try: 12 | import transformers 13 | from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig 14 | from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \ 15 | BaseModelOutputWithPoolingAndCrossAttentions 16 | except ImportError as e: 17 | transformers = None 18 | 19 | 20 | class BaseModelOutput: 21 | pass 22 | 23 | 24 | class PretrainedConfig: 25 | pass 26 | 27 | from .hf_configs import arch_dict 28 | 29 | 30 | # utils 31 | def _camel2snake(s): 32 | return re.sub(r'(? 1 17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | self.act1 = nn.ReLU(inplace=True) 20 | 21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.act2 = nn.ReLU(inplace=True) 24 | 25 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 26 | 27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 29 | self.act3 = nn.ReLU(inplace=True) 30 | 31 | self.downsample = None 32 | self.stride = stride 33 | 34 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 35 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 36 | self.downsample = nn.Sequential(OrderedDict([ 37 | ("-1", nn.AvgPool2d(stride)), 38 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 39 | ("1", nn.BatchNorm2d(planes * self.expansion)) 40 | ])) 41 | 42 | def forward(self, x: torch.Tensor): 43 | identity = x 44 | 45 | out = self.act1(self.bn1(self.conv1(x))) 46 | out = self.act2(self.bn2(self.conv2(out))) 47 | out = self.avgpool(out) 48 | out = self.bn3(self.conv3(out)) 49 | 50 | if self.downsample is not None: 51 | identity = self.downsample(x) 52 | 53 | out += identity 54 | out = self.act3(out) 55 | return out 56 | 57 | 58 | class AttentionPool2d(nn.Module): 59 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 60 | super().__init__() 61 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 62 | self.k_proj = nn.Linear(embed_dim, embed_dim) 63 | self.q_proj = nn.Linear(embed_dim, embed_dim) 64 | self.v_proj = nn.Linear(embed_dim, embed_dim) 65 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 66 | self.num_heads = num_heads 67 | 68 | def forward(self, x): 69 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 70 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 71 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 72 | x, _ = F.multi_head_attention_forward( 73 | query=x, key=x, value=x, 74 | embed_dim_to_check=x.shape[-1], 75 | num_heads=self.num_heads, 76 | q_proj_weight=self.q_proj.weight, 77 | k_proj_weight=self.k_proj.weight, 78 | v_proj_weight=self.v_proj.weight, 79 | in_proj_weight=None, 80 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 81 | bias_k=None, 82 | bias_v=None, 83 | add_zero_attn=False, 84 | dropout_p=0., 85 | out_proj_weight=self.c_proj.weight, 86 | out_proj_bias=self.c_proj.bias, 87 | use_separate_proj_weight=True, 88 | training=self.training, 89 | need_weights=False 90 | ) 91 | 92 | return x[0] 93 | 94 | 95 | class ModifiedResNet(nn.Module): 96 | """ 97 | A ResNet 
class that is similar to torchvision's but contains the following changes: 98 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 99 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 100 | - The final pooling layer is a QKV attention instead of an average pool 101 | """ 102 | 103 | def __init__(self, layers, output_dim, heads, image_size=224, width=64): 104 | super().__init__() 105 | self.output_dim = output_dim 106 | self.image_size = image_size 107 | 108 | # the 3-layer stem 109 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 110 | self.bn1 = nn.BatchNorm2d(width // 2) 111 | self.act1 = nn.ReLU(inplace=True) 112 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 113 | self.bn2 = nn.BatchNorm2d(width // 2) 114 | self.act2 = nn.ReLU(inplace=True) 115 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 116 | self.bn3 = nn.BatchNorm2d(width) 117 | self.act3 = nn.ReLU(inplace=True) 118 | self.avgpool = nn.AvgPool2d(2) 119 | 120 | # residual layers 121 | self._inplanes = width # this is a *mutable* variable used during construction 122 | self.layer1 = self._make_layer(width, layers[0]) 123 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 124 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 125 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 126 | 127 | embed_dim = width * 32 # the ResNet feature dimension 128 | self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) 129 | 130 | self.init_parameters() 131 | 132 | def _make_layer(self, planes, blocks, stride=1): 133 | layers = [Bottleneck(self._inplanes, planes, stride)] 134 | 135 | self._inplanes = planes * Bottleneck.expansion 136 | for _ in range(1, blocks): 137 | layers.append(Bottleneck(self._inplanes, planes)) 138 | 139 | return nn.Sequential(*layers) 140 | 141 | def init_parameters(self): 142 | if self.attnpool is not None: 143 | std = self.attnpool.c_proj.in_features ** -0.5 144 | nn.init.normal_(self.attnpool.q_proj.weight, std=std) 145 | nn.init.normal_(self.attnpool.k_proj.weight, std=std) 146 | nn.init.normal_(self.attnpool.v_proj.weight, std=std) 147 | nn.init.normal_(self.attnpool.c_proj.weight, std=std) 148 | 149 | for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: 150 | for name, param in resnet_block.named_parameters(): 151 | if name.endswith("bn3.weight"): 152 | nn.init.zeros_(param) 153 | 154 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 155 | assert unlocked_groups == 0, 'partial locking not currently supported for this model' 156 | for param in self.parameters(): 157 | param.requires_grad = False 158 | if freeze_bn_stats: 159 | freeze_batch_norm_2d(self) 160 | 161 | @torch.jit.ignore 162 | def set_grad_checkpointing(self, enable=True): 163 | # FIXME support for non-transformer 164 | pass 165 | 166 | def stem(self, x): 167 | x = self.act1(self.bn1(self.conv1(x))) 168 | x = self.act2(self.bn2(self.conv2(x))) 169 | x = self.act3(self.bn3(self.conv3(x))) 170 | x = self.avgpool(x) 171 | return x 172 | 173 | def forward(self, x): 174 | x = self.stem(x) 175 | x = self.layer1(x) 176 | x = self.layer2(x) 177 | x = self.layer3(x) 178 | x = self.layer4(x) 179 | x = self.attnpool(x) 180 | 181 | return x 182 | -------------------------------------------------------------------------------- /src/open_clip/openai.py: 
-------------------------------------------------------------------------------- 1 | """ OpenAI pretrained model functions 2 | 3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | 6 | import os 7 | import warnings 8 | from typing import List, Optional, Union 9 | 10 | import torch 11 | 12 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 13 | from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype 14 | from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url 15 | 16 | __all__ = ["list_openai_models", "load_openai_model"] 17 | 18 | 19 | def list_openai_models() -> List[str]: 20 | """Returns the names of available CLIP models""" 21 | return list_pretrained_models_by_tag('openai') 22 | 23 | 24 | def load_openai_model( 25 | name: str, 26 | precision: Optional[str] = None, 27 | device: Optional[Union[str, torch.device]] = None, 28 | cache_dir: Optional[str] = None, 29 | ): 30 | """Load a CLIP model 31 | 32 | Parameters 33 | ---------- 34 | name : str 35 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 36 | precision: str 37 | Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'. 38 | device : Union[str, torch.device] 39 | The device to put the loaded model 40 | cache_dir : Optional[str] 41 | The directory to cache the downloaded model weights 42 | 43 | Returns 44 | ------- 45 | model : torch.nn.Module 46 | The CLIP model 47 | preprocess : Callable[[PIL.Image], torch.Tensor] 48 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 49 | """ 50 | if device is None: 51 | device = "cuda" if torch.cuda.is_available() else "cpu" 52 | if precision is None: 53 | precision = 'fp32' if device == 'cpu' else 'fp16' 54 | 55 | if get_pretrained_url(name, 'openai'): 56 | model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir) 57 | elif os.path.isfile(name): 58 | model_path = name 59 | else: 60 | raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") 61 | 62 | try: 63 | # loading JIT archive 64 | model = torch.jit.load(model_path, map_location="cpu").eval() 65 | state_dict = None 66 | except RuntimeError: 67 | # loading saved state dict 68 | state_dict = torch.load(model_path, map_location="cpu") 69 | 70 | # Build a non-jit model from the OpenAI jitted model state dict 71 | cast_dtype = get_cast_dtype(precision) 72 | try: 73 | model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype) 74 | except KeyError: 75 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} 76 | model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype) 77 | 78 | # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use 79 | model = model.to(device) 80 | # FIXME support pure fp16/bf16 precision modes 81 | if precision != 'fp16': 82 | model.float() 83 | if precision == 'bf16': 84 | # for bf16, convert back to low-precision 85 | convert_weights_to_lp(model, dtype=torch.bfloat16) 86 | 87 | # add mean / std attributes for consistency with OpenCLIP models 88 | model.visual.image_mean = OPENAI_DATASET_MEAN 89 | model.visual.image_std = OPENAI_DATASET_STD 90 | return model 91 | -------------------------------------------------------------------------------- 
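For reference, `load_openai_model` above resolves an OpenAI CLIP tag (or a local checkpoint path) to weights, rebuilds a non-JIT model, and handles precision casting; note that, as written, it returns only the model even though the docstring also mentions a preprocessing transform. A minimal usage sketch, assuming the upstream behaviour is preserved in this fork and that the weights can be downloaded on first use:

```python
# Hedged sketch only; 'ViT-B-16' is one of the OpenAI tags known to open_clip.
from open_clip import list_openai_models, load_openai_model

print(list_openai_models())  # e.g. ['RN50', ..., 'ViT-B-32', 'ViT-B-16', 'ViT-L-14', ...]

# Defaults (see above): device -> 'cuda' if available else 'cpu',
# precision -> 'fp16' on GPU, 'fp32' on CPU.
model = load_openai_model('ViT-B-16')
model.eval()

# Only the model is returned; a preprocessing transform has to be built separately,
# e.g. with open_clip.image_transform and the stats attached at the end of the function.
print(model.visual.image_mean, model.visual.image_std)
```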
/src/open_clip/pos_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # Position embedding utils 8 | # -------------------------------------------------------- 9 | 10 | import numpy as np 11 | 12 | import torch 13 | 14 | # -------------------------------------------------------- 15 | # 2D sine-cosine position embedding 16 | # References: 17 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 18 | # MoCo v3: https://github.com/facebookresearch/moco-v3 19 | # -------------------------------------------------------- 20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 21 | """ 22 | grid_size: int of the grid height and width 23 | return: 24 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 25 | """ 26 | grid_h = np.arange(grid_size, dtype=np.float32) 27 | grid_w = np.arange(grid_size, dtype=np.float32) 28 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 29 | grid = np.stack(grid, axis=0) 30 | 31 | grid = grid.reshape([2, 1, grid_size, grid_size]) 32 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 33 | if cls_token: 34 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 35 | return pos_embed 36 | 37 | 38 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 39 | assert embed_dim % 2 == 0 40 | 41 | # use half of dimensions to encode grid_h 42 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 43 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 44 | 45 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 46 | return emb 47 | 48 | 49 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 50 | """ 51 | embed_dim: output dimension for each position 52 | pos: a list of positions to be encoded: size (M,) 53 | out: (M, D) 54 | """ 55 | assert embed_dim % 2 == 0 56 | omega = np.arange(embed_dim // 2, dtype=float) 57 | omega /= embed_dim / 2. 58 | omega = 1. 
/ 10000**omega # (D/2,) 59 | 60 | pos = pos.reshape(-1) # (M,) 61 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 62 | 63 | emb_sin = np.sin(out) # (M, D/2) 64 | emb_cos = np.cos(out) # (M, D/2) 65 | 66 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 67 | return emb 68 | 69 | 70 | # -------------------------------------------------------- 71 | # Interpolate position embeddings for high-resolution 72 | # References: 73 | # DeiT: https://github.com/facebookresearch/deit 74 | # -------------------------------------------------------- 75 | def interpolate_pos_embed(model, checkpoint_model): 76 | if 'pos_embed' in checkpoint_model: 77 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 78 | embedding_size = pos_embed_checkpoint.shape[-1] 79 | num_patches = model.patch_embed.num_patches 80 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 81 | # height (== width) for the checkpoint position embedding 82 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 83 | # height (== width) for the new position embedding 84 | new_size = int(num_patches ** 0.5) 85 | # class_token and dist_token are kept unchanged 86 | if orig_size != new_size: 87 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 88 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 89 | # only the position tokens are interpolated 90 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 91 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 92 | pos_tokens = torch.nn.functional.interpolate( 93 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 94 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 95 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 96 | checkpoint_model['pos_embed'] = new_pos_embed 97 | -------------------------------------------------------------------------------- /src/open_clip/push_to_hf_hub.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | from pathlib import Path 5 | from tempfile import TemporaryDirectory 6 | from typing import Optional, Tuple, Union 7 | 8 | import torch 9 | 10 | try: 11 | from huggingface_hub import ( 12 | create_repo, 13 | get_hf_file_metadata, 14 | hf_hub_download, 15 | hf_hub_url, 16 | repo_type_and_id_from_hf_id, 17 | upload_folder, 18 | list_repo_files, 19 | ) 20 | from huggingface_hub.utils import EntryNotFoundError 21 | _has_hf_hub = True 22 | except ImportError: 23 | _has_hf_hub = False 24 | 25 | try: 26 | import safetensors.torch 27 | _has_safetensors = True 28 | except ImportError: 29 | _has_safetensors = False 30 | 31 | from .factory import create_model_from_pretrained, get_model_config, get_tokenizer 32 | from .tokenizer import HFTokenizer 33 | 34 | # Default name for a weights file hosted on the Huggingface Hub. 
35 | HF_WEIGHTS_NAME = "open_clip_pytorch_model.bin" # default pytorch pkl 36 | HF_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors" # safetensors version 37 | HF_CONFIG_NAME = 'open_clip_config.json' 38 | 39 | 40 | def save_config_for_hf( 41 | model, 42 | config_path: str, 43 | model_config: Optional[dict] 44 | ): 45 | preprocess_cfg = { 46 | 'mean': model.visual.image_mean, 47 | 'std': model.visual.image_std, 48 | } 49 | other_pp = getattr(model.visual, 'preprocess_cfg', {}) 50 | if 'interpolation' in other_pp: 51 | preprocess_cfg['interpolation'] = other_pp['interpolation'] 52 | if 'resize_mode' in other_pp: 53 | preprocess_cfg['resize_mode'] = other_pp['resize_mode'] 54 | hf_config = { 55 | 'model_cfg': model_config, 56 | 'preprocess_cfg': preprocess_cfg, 57 | } 58 | 59 | with config_path.open('w') as f: 60 | json.dump(hf_config, f, indent=2) 61 | 62 | 63 | def save_for_hf( 64 | model, 65 | tokenizer: HFTokenizer, 66 | model_config: dict, 67 | save_directory: str, 68 | safe_serialization: Union[bool, str] = 'both', 69 | skip_weights : bool = False, 70 | ): 71 | config_filename = HF_CONFIG_NAME 72 | 73 | save_directory = Path(save_directory) 74 | save_directory.mkdir(exist_ok=True, parents=True) 75 | 76 | if not skip_weights: 77 | tensors = model.state_dict() 78 | if safe_serialization is True or safe_serialization == "both": 79 | assert _has_safetensors, "`pip install safetensors` to use .safetensors" 80 | safetensors.torch.save_file(tensors, save_directory / HF_SAFE_WEIGHTS_NAME) 81 | if safe_serialization is False or safe_serialization == "both": 82 | torch.save(tensors, save_directory / HF_WEIGHTS_NAME) 83 | 84 | tokenizer.save_pretrained(save_directory) 85 | 86 | config_path = save_directory / config_filename 87 | save_config_for_hf(model, config_path, model_config=model_config) 88 | 89 | 90 | def push_to_hf_hub( 91 | model, 92 | tokenizer, 93 | model_config: Optional[dict], 94 | repo_id: str, 95 | commit_message: str = 'Add model', 96 | token: Optional[str] = None, 97 | revision: Optional[str] = None, 98 | private: bool = False, 99 | create_pr: bool = False, 100 | model_card: Optional[dict] = None, 101 | safe_serialization: Union[bool, str] = 'both', 102 | ): 103 | if not isinstance(tokenizer, HFTokenizer): 104 | # FIXME this makes it awkward to push models with new tokenizers, come up with better soln. 105 | # default CLIP tokenizers use https://huggingface.co/openai/clip-vit-large-patch14 106 | tokenizer = HFTokenizer('openai/clip-vit-large-patch14') 107 | 108 | # Create repo if it doesn't exist yet 109 | repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True) 110 | 111 | # Infer complete repo_id from repo_url 112 | # Can be different from the input `repo_id` if repo_owner was implicit 113 | _, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url) 114 | repo_id = f"{repo_owner}/{repo_name}" 115 | 116 | # Check if repo already exists and determine what needs updating 117 | repo_exists = False 118 | repo_files = {} 119 | try: 120 | repo_files = set(list_repo_files(repo_id)) 121 | repo_exists = True 122 | except Exception as e: 123 | print('Repo does not exist', e) 124 | 125 | try: 126 | get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision)) 127 | has_readme = True 128 | except EntryNotFoundError: 129 | has_readme = False 130 | 131 | # Dump model and push to Hub 132 | with TemporaryDirectory() as tmpdir: 133 | # Save model weights and config. 
134 | save_for_hf( 135 | model, 136 | tokenizer=tokenizer, 137 | model_config=model_config, 138 | save_directory=tmpdir, 139 | safe_serialization=safe_serialization, 140 | ) 141 | 142 | # Add readme if it does not exist 143 | if not has_readme: 144 | model_card = model_card or {} 145 | model_name = repo_id.split('/')[-1] 146 | readme_path = Path(tmpdir) / "README.md" 147 | readme_text = generate_readme(model_card, model_name) 148 | readme_path.write_text(readme_text) 149 | 150 | # Upload model and return 151 | return upload_folder( 152 | repo_id=repo_id, 153 | folder_path=tmpdir, 154 | revision=revision, 155 | create_pr=create_pr, 156 | commit_message=commit_message, 157 | ) 158 | 159 | 160 | def push_pretrained_to_hf_hub( 161 | model_name, 162 | pretrained: str, 163 | repo_id: str, 164 | precision: str = 'fp32', 165 | image_mean: Optional[Tuple[float, ...]] = None, 166 | image_std: Optional[Tuple[float, ...]] = None, 167 | image_interpolation: Optional[str] = None, 168 | image_resize_mode: Optional[str] = None, # only effective for inference 169 | commit_message: str = 'Add model', 170 | token: Optional[str] = None, 171 | revision: Optional[str] = None, 172 | private: bool = False, 173 | create_pr: bool = False, 174 | model_card: Optional[dict] = None, 175 | hf_tokenizer_self: bool = False, 176 | **kwargs, 177 | ): 178 | model, preprocess_eval = create_model_from_pretrained( 179 | model_name, 180 | pretrained=pretrained, 181 | precision=precision, 182 | image_mean=image_mean, 183 | image_std=image_std, 184 | image_interpolation=image_interpolation, 185 | image_resize_mode=image_resize_mode, 186 | **kwargs, 187 | ) 188 | model_config = get_model_config(model_name) 189 | if pretrained == 'openai': 190 | model_config['quick_gelu'] = True 191 | assert model_config 192 | 193 | tokenizer = get_tokenizer(model_name) 194 | if hf_tokenizer_self: 195 | # make hf tokenizer config in the uploaded model point to self instead of original location 196 | model_config['text']['hf_tokenizer_name'] = repo_id 197 | 198 | push_to_hf_hub( 199 | model=model, 200 | tokenizer=tokenizer, 201 | model_config=model_config, 202 | repo_id=repo_id, 203 | commit_message=commit_message, 204 | token=token, 205 | revision=revision, 206 | private=private, 207 | create_pr=create_pr, 208 | model_card=model_card, 209 | safe_serialization='both', 210 | ) 211 | 212 | 213 | def generate_readme(model_card: dict, model_name: str): 214 | tags = model_card.pop('tags', ('clip',)) 215 | pipeline_tag = model_card.pop('pipeline_tag', 'zero-shot-image-classification') 216 | readme_text = "---\n" 217 | if tags: 218 | readme_text += "tags:\n" 219 | for t in tags: 220 | readme_text += f"- {t}\n" 221 | readme_text += "library_name: open_clip\n" 222 | readme_text += f"pipeline_tag: {pipeline_tag}\n" 223 | readme_text += f"license: {model_card.get('license', 'mit')}\n" 224 | if 'details' in model_card and 'Dataset' in model_card['details']: 225 | readme_text += 'datasets:\n' 226 | readme_text += f"- {model_card['details']['Dataset'].lower()}\n" 227 | readme_text += "---\n" 228 | readme_text += f"# Model card for {model_name}\n" 229 | if 'description' in model_card: 230 | readme_text += f"\n{model_card['description']}\n" 231 | if 'details' in model_card: 232 | readme_text += f"\n## Model Details\n" 233 | for k, v in model_card['details'].items(): 234 | if isinstance(v, (list, tuple)): 235 | readme_text += f"- **{k}:**\n" 236 | for vi in v: 237 | readme_text += f" - {vi}\n" 238 | elif isinstance(v, dict): 239 | readme_text += f"- **{k}:**\n" 240 | 
for ki, vi in v.items(): 241 | readme_text += f" - {ki}: {vi}\n" 242 | else: 243 | readme_text += f"- **{k}:** {v}\n" 244 | if 'usage' in model_card: 245 | readme_text += f"\n## Model Usage\n" 246 | readme_text += model_card['usage'] 247 | readme_text += '\n' 248 | 249 | if 'comparison' in model_card: 250 | readme_text += f"\n## Model Comparison\n" 251 | readme_text += model_card['comparison'] 252 | readme_text += '\n' 253 | 254 | if 'citation' in model_card: 255 | readme_text += f"\n## Citation\n" 256 | if not isinstance(model_card['citation'], (list, tuple)): 257 | citations = [model_card['citation']] 258 | else: 259 | citations = model_card['citation'] 260 | for c in citations: 261 | readme_text += f"```bibtex\n{c}\n```\n" 262 | 263 | return readme_text 264 | 265 | 266 | if __name__ == "__main__": 267 | parser = argparse.ArgumentParser(description="Push to Hugging Face Hub") 268 | parser.add_argument( 269 | "--model", type=str, help="Name of the model to use.", 270 | ) 271 | parser.add_argument( 272 | "--pretrained", type=str, 273 | help="Use a pretrained CLIP model weights with the specified tag or file path.", 274 | ) 275 | parser.add_argument( 276 | "--repo-id", type=str, 277 | help="Destination HF Hub repo-id ie 'organization/model_id'.", 278 | ) 279 | parser.add_argument( 280 | "--precision", type=str, default='fp32', 281 | ) 282 | parser.add_argument( 283 | '--image-mean', type=float, nargs='+', default=None, metavar='MEAN', 284 | help='Override default image mean value of dataset') 285 | parser.add_argument( 286 | '--image-std', type=float, nargs='+', default=None, metavar='STD', 287 | help='Override default image std deviation of of dataset') 288 | parser.add_argument( 289 | '--image-interpolation', 290 | default=None, type=str, choices=['bicubic', 'bilinear', 'random'], 291 | help="image resize interpolation" 292 | ) 293 | parser.add_argument( 294 | '--image-resize-mode', 295 | default=None, type=str, choices=['shortest', 'longest', 'squash'], 296 | help="image resize mode during inference" 297 | ) 298 | parser.add_argument( 299 | "--hf-tokenizer-self", 300 | default=False, 301 | action="store_true", 302 | help="make hf_tokenizer_name point in uploaded config point to itself" 303 | ) 304 | args = parser.parse_args() 305 | 306 | print(f'Saving model {args.model} with pretrained weights {args.pretrained} to Hugging Face Hub at {args.repo_id}') 307 | 308 | # FIXME add support to pass model_card json / template from file via cmd line 309 | 310 | push_pretrained_to_hf_hub( 311 | args.model, 312 | args.pretrained, 313 | args.repo_id, 314 | precision=args.precision, 315 | image_mean=args.image_mean, # override image mean/std if trained w/ non defaults 316 | image_std=args.image_std, 317 | image_interpolation=args.image_interpolation, 318 | image_resize_mode=args.image_resize_mode, 319 | ) 320 | 321 | print(f'{args.model} saved.') 322 | -------------------------------------------------------------------------------- /src/open_clip/timm_model.py: -------------------------------------------------------------------------------- 1 | """ timm model adapter 2 | 3 | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. 
4 | """ 5 | import logging 6 | from collections import OrderedDict 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | try: 12 | import timm 13 | from timm.models.layers import Mlp, to_2tuple 14 | try: 15 | # old timm imports < 0.8.1 16 | from timm.models.layers.attention_pool2d import RotAttentionPool2d 17 | from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d 18 | except ImportError: 19 | # new timm imports >= 0.8.1 20 | from timm.layers import RotAttentionPool2d 21 | from timm.layers import AttentionPool2d as AbsAttentionPool2d 22 | except ImportError: 23 | timm = None 24 | 25 | from .utils import freeze_batch_norm_2d 26 | 27 | 28 | class TimmModel(nn.Module): 29 | """ timm model adapter 30 | """ 31 | 32 | def __init__( 33 | self, 34 | model_name, 35 | embed_dim, 36 | image_size=224, 37 | pool='avg', 38 | proj='linear', 39 | proj_bias=False, 40 | drop=0., 41 | drop_path=None, 42 | patch_drop=None, 43 | pretrained=False, 44 | ): 45 | super().__init__() 46 | if timm is None: 47 | raise RuntimeError("Please `pip install timm` to use timm models.") 48 | self.image_size = to_2tuple(image_size) 49 | 50 | # setup kwargs that may not be common across all models 51 | timm_kwargs = {} 52 | if drop_path is not None: 53 | timm_kwargs['drop_path_rate'] = drop_path 54 | if patch_drop is not None: 55 | timm_kwargs['patch_drop_rate'] = patch_drop 56 | 57 | custom_pool = pool in ('abs_attn', 'rot_attn') 58 | if proj: 59 | assert proj in ("linear", "mlp", "none") 60 | extra_proj = proj in ("linear", "mlp") 61 | if not extra_proj and not custom_pool: 62 | # use network classifier head as projection if no proj specified and no custom pooling used 63 | # if projection is explicitly set to "none" will be pass through from network trunk 64 | proj_dim = 0 if proj == 'none' else embed_dim 65 | self.trunk = timm.create_model( 66 | model_name, 67 | num_classes=proj_dim, 68 | global_pool=pool, 69 | pretrained=pretrained, 70 | **timm_kwargs, 71 | ) 72 | prev_chs = embed_dim 73 | else: 74 | self.trunk = timm.create_model( 75 | model_name, 76 | pretrained=pretrained, 77 | **timm_kwargs, 78 | ) 79 | feat_size = self.trunk.default_cfg.get('pool_size', None) 80 | feature_ndim = 1 if not feat_size else 2 81 | if custom_pool: 82 | assert feature_ndim == 2 83 | # if attn pooling used, remove both classifier and default pool 84 | self.trunk.reset_classifier(0, global_pool='') 85 | else: 86 | # reset global pool if pool config set, otherwise leave as network default 87 | reset_kwargs = dict(global_pool=pool) if pool else {} 88 | self.trunk.reset_classifier(0, **reset_kwargs) 89 | prev_chs = self.trunk.num_features 90 | 91 | head_layers = OrderedDict() 92 | 93 | # Add custom pooling to head 94 | if pool == 'abs_attn': 95 | head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) 96 | prev_chs = embed_dim 97 | elif pool == 'rot_attn': 98 | head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim) 99 | prev_chs = embed_dim 100 | 101 | # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used 102 | if proj == 'linear': 103 | head_layers['drop'] = nn.Dropout(drop) 104 | head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias) 105 | elif proj == 'mlp': 106 | head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias)) 107 | 108 | self.head = nn.Sequential(head_layers) 109 | 110 | def lock(self, unlocked_groups=0, 
freeze_bn_stats=False): 111 | """ lock modules 112 | Args: 113 | unlocked_groups (int): leave last n layer groups unlocked (default: 0) 114 | """ 115 | if not unlocked_groups: 116 | # lock full model 117 | for param in self.trunk.parameters(): 118 | param.requires_grad = False 119 | if freeze_bn_stats: 120 | freeze_batch_norm_2d(self.trunk) 121 | else: 122 | # NOTE: partial freeze requires latest timm (master) branch and is subject to change 123 | try: 124 | # FIXME import here until API stable and in an official release 125 | from timm.models.helpers import group_parameters, group_modules 126 | except ImportError: 127 | raise RuntimeError( 128 | 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`') 129 | matcher = self.trunk.group_matcher() 130 | gparams = group_parameters(self.trunk, matcher) 131 | max_layer_id = max(gparams.keys()) 132 | max_layer_id = max_layer_id - unlocked_groups 133 | for group_idx in range(max_layer_id + 1): 134 | group = gparams[group_idx] 135 | for param in group: 136 | self.trunk.get_parameter(param).requires_grad = False 137 | if freeze_bn_stats: 138 | gmodules = group_modules(self.trunk, matcher, reverse=True) 139 | gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} 140 | freeze_batch_norm_2d(self.trunk, gmodules) 141 | 142 | @torch.jit.ignore 143 | def set_grad_checkpointing(self, enable=True): 144 | try: 145 | self.trunk.set_grad_checkpointing(enable) 146 | except Exception as e: 147 | logging.warning('grad checkpointing not supported for this timm image tower, continuing without...') 148 | 149 | def forward(self, x): 150 | x = self.trunk(x) 151 | x = self.head(x) 152 | return x 153 | -------------------------------------------------------------------------------- /src/open_clip/utils.py: -------------------------------------------------------------------------------- 1 | from itertools import repeat 2 | import collections.abc 3 | 4 | import torch 5 | from torch import nn as nn 6 | from torchvision.ops.misc import FrozenBatchNorm2d 7 | 8 | 9 | def freeze_batch_norm_2d(module, module_match={}, name=''): 10 | """ 11 | Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is 12 | itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and 13 | returned. Otherwise, the module is walked recursively and submodules are converted in place. 14 | 15 | Args: 16 | module (torch.nn.Module): Any PyTorch module. 
17 | module_match (dict): Dictionary of full module names to freeze (all if empty) 18 | name (str): Full module name (prefix) 19 | 20 | Returns: 21 | torch.nn.Module: Resulting module 22 | 23 | Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 24 | """ 25 | res = module 26 | is_match = True 27 | if module_match: 28 | is_match = name in module_match 29 | if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)): 30 | res = FrozenBatchNorm2d(module.num_features) 31 | res.num_features = module.num_features 32 | res.affine = module.affine 33 | if module.affine: 34 | res.weight.data = module.weight.data.clone().detach() 35 | res.bias.data = module.bias.data.clone().detach() 36 | res.running_mean.data = module.running_mean.data 37 | res.running_var.data = module.running_var.data 38 | res.eps = module.eps 39 | else: 40 | for child_name, child in module.named_children(): 41 | full_child_name = '.'.join([name, child_name]) if name else child_name 42 | new_child = freeze_batch_norm_2d(child, module_match, full_child_name) 43 | if new_child is not child: 44 | res.add_module(child_name, new_child) 45 | return res 46 | 47 | 48 | # From PyTorch internals 49 | def _ntuple(n): 50 | def parse(x): 51 | if isinstance(x, collections.abc.Iterable): 52 | return x 53 | return tuple(repeat(x, n)) 54 | return parse 55 | 56 | 57 | to_1tuple = _ntuple(1) 58 | to_2tuple = _ntuple(2) 59 | to_3tuple = _ntuple(3) 60 | to_4tuple = _ntuple(4) 61 | to_ntuple = lambda n, x: _ntuple(n)(x) 62 | 63 | # Replaces all linear layers with linear_replacement 64 | # TODO: add int8 support for other linear layers including attn and convnets 65 | def replace_linear(model, linear_replacement, include_modules=['c_fc', 'c_proj'], copy_weights=True): 66 | for name, module in model.named_children(): 67 | if len(list(module.children())) > 0: 68 | replace_linear(module, linear_replacement, include_modules, copy_weights) 69 | 70 | if isinstance(module, torch.nn.Linear) and name in include_modules: 71 | old_module = model._modules[name] 72 | model._modules[name] = linear_replacement( 73 | module.in_features, 74 | module.out_features, 75 | module.bias is not None, 76 | ) 77 | if copy_weights: 78 | model._modules[name].weight.data.copy_(old_module.weight.data) 79 | if model._modules[name].bias is not None: 80 | model._modules[name].bias.data.copy_(old_module.bias) 81 | 82 | return model 83 | 84 | def convert_int8_model_to_inference_mode(model): 85 | for m in model.modules(): 86 | if hasattr(m, 'prepare_for_eval'): 87 | int8_original_dtype = m.weight.dtype 88 | m.prepare_for_eval() 89 | m.int8_original_dtype = int8_original_dtype 90 | -------------------------------------------------------------------------------- /src/open_clip/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '2.24.0' 2 | -------------------------------------------------------------------------------- /src/open_clip/zero_shot_classifier.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from itertools import islice 3 | from typing import Callable, List, Optional, Sequence, Union 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | 9 | def batched(iterable, n): 10 | """Batch data into lists of length *n*. The last batch may be shorter. 
11 | NOTE based on more-itertools impl, to be replaced by python 3.12 itertools.batched impl 12 | """ 13 | it = iter(iterable) 14 | while True: 15 | batch = list(islice(it, n)) 16 | if not batch: 17 | break 18 | yield batch 19 | 20 | 21 | def build_zero_shot_classifier( 22 | model, 23 | tokenizer, 24 | args, 25 | classnames: Sequence[str], 26 | templates: Sequence[Union[Callable, str]], 27 | num_classes_per_batch: Optional[int] = 10, 28 | device: Union[str, torch.device] = 'cpu', 29 | use_tqdm: bool = False, 30 | ): 31 | """ Build zero-shot classifier weights by iterating over class names in batches 32 | Args: 33 | model: CLIP model instance 34 | tokenizer: CLIP tokenizer instance 35 | classnames: A sequence of class (label) names 36 | templates: A sequence of callables or format() friendly strings to produce templates per class name 37 | num_classes_per_batch: The number of classes to batch together in each forward, all if None 38 | device: Device to use. 39 | use_tqdm: Enable TQDM progress bar. 40 | """ 41 | assert isinstance(templates, Sequence) and len(templates) > 0 42 | assert isinstance(classnames, Sequence) and len(classnames) > 0 43 | use_format = isinstance(templates[0], str) 44 | num_templates = len(templates) 45 | num_classes = len(classnames) 46 | if use_tqdm: 47 | import tqdm 48 | num_iter = 1 if num_classes_per_batch is None else ((num_classes - 1) // num_classes_per_batch + 1) 49 | iter_wrap = partial(tqdm.tqdm, total=num_iter, unit_scale=num_classes_per_batch) 50 | else: 51 | iter_wrap = iter 52 | 53 | def _process_batch(batch_classnames): 54 | num_batch_classes = len(batch_classnames) 55 | texts = [template.format(c) if use_format else template(c) for c in batch_classnames for template in templates] 56 | texts = tokenizer(texts).to(device) 57 | class_embeddings = model.encode_text(texts, normalize=True) 58 | if isinstance(class_embeddings, dict): 59 | class_embeddings = class_embeddings['text_features'] 60 | 61 | class_embeddings = class_embeddings.reshape(num_batch_classes, num_templates, -1).mean(dim=1) 62 | class_embeddings = class_embeddings / class_embeddings.norm(dim=1, keepdim=True) 63 | class_embeddings = class_embeddings.T # (512, num_batch_classes) 64 | return class_embeddings 65 | 66 | with torch.no_grad(): 67 | if num_classes_per_batch: 68 | batched_embeds = [_process_batch(batch) for batch in iter_wrap(batched(classnames, num_classes_per_batch))] 69 | zeroshot_weights = torch.cat(batched_embeds, dim=-1) 70 | else: 71 | zeroshot_weights = _process_batch(classnames) 72 | return zeroshot_weights 73 | 74 | 75 | def build_zero_shot_classifier_legacy( 76 | model, 77 | tokenizer, 78 | classnames: Sequence[str], 79 | templates: Sequence[Union[Callable, str]], 80 | device: Union[str, torch.device] = 'cpu', 81 | use_tqdm: bool = False, 82 | ): 83 | """ Build zero-shot classifier weights by iterating over class names 1 by 1 84 | Args: 85 | model: CLIP model instance 86 | tokenizer: CLIP tokenizer instance 87 | classnames: A sequence of class (label) names 88 | templates: A sequence of callables or format() friendly strings to produce templates per class name 89 | device: Device to use. 90 | use_tqdm: Enable TQDM progress bar. 
91 | """ 92 | assert isinstance(templates, Sequence) and len(templates) > 0 93 | assert isinstance(classnames, Sequence) and len(classnames) > 0 94 | if use_tqdm: 95 | import tqdm 96 | iter_wrap = tqdm.tqdm 97 | else: 98 | iter_wrap = iter 99 | 100 | use_format = isinstance(templates[0], str) 101 | 102 | with torch.no_grad(): 103 | zeroshot_weights = [] 104 | for classname in iter_wrap(classnames): 105 | texts = [template.format(classname) if use_format else template(classname) for template in templates] 106 | texts = tokenizer(texts).to(device) # tokenize 107 | class_embeddings = model.encode_text(texts) 108 | class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0) 109 | class_embedding /= class_embedding.norm() 110 | zeroshot_weights.append(class_embedding) 111 | zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device) 112 | 113 | return zeroshot_weights 114 | 115 | -------------------------------------------------------------------------------- /src/seg_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | import training.clip_segmentor 5 | import training.custom_datasets 6 | from training.params import parse_args 7 | 8 | from mmengine.config import Config 9 | from mmengine.runner import Runner 10 | 11 | 12 | def trigger_visualization_hook(cfg, args): 13 | default_hooks = cfg.default_hooks 14 | if 'visualization' in default_hooks: 15 | visualization_hook = default_hooks['visualization'] 16 | # Turn on visualization 17 | visualization_hook['draw'] = True 18 | if args.show: 19 | visualization_hook['show'] = True 20 | visualization_hook['wait_time'] = args.wait_time 21 | if args.show_dir: 22 | visualizer = cfg.visualizer 23 | visualizer['save_dir'] = args.show_dir 24 | else: 25 | raise RuntimeError( 26 | 'VisualizationHook must be included in default_hooks.' 
27 | 'refer to usage ' 28 | '"visualization=dict(type=\'VisualizationHook\')"') 29 | 30 | return cfg 31 | 32 | 33 | def main(args): 34 | args = parse_args(args) 35 | if args.seg_w_background: # with background 36 | conf_files = ['cfg_voc21.py', 'cfg_context60.py', 'cfg_coco_object.py'] 37 | else: 38 | conf_files = ['cfg_voc20.py', 'cfg_city_scapes.py', 'cfg_context59.py', 'cfg_ade20k.py', 'cfg_coco_stuff164k.py'] 39 | 40 | for conf_f in conf_files: 41 | cfg = Config.fromfile(f'./training/seg_configs/{conf_f}') 42 | cfg.launcher = 'none' 43 | cfg.work_dir = './work_logs/' 44 | 45 | openclip_args = {} 46 | for arg in vars(args): 47 | openclip_args[arg] = getattr(args, arg) 48 | cfg.model.update(openclip_args) 49 | 50 | runner = Runner.from_cfg(cfg) 51 | runner.test() 52 | 53 | 54 | if __name__ == "__main__": 55 | main(sys.argv[1:]) -------------------------------------------------------------------------------- /src/train_cc12m.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -x 2 | #SBATCH --nodes=32 3 | #SBATCH --gres=gpu:4 4 | #SBATCH --ntasks-per-node=4 5 | #SBATCH --cpus-per-task=48 6 | #SBATCH --job-name=train_cc12m 7 | #SBATCH --output=./cluster_logs/train_cc12m.txt 8 | #SBATCH --partition PARTITION_NAME 9 | 10 | source cosmos_env/bin/activate 11 | cd cosmos/src 12 | 13 | export CUDA_VISIBLE_DEVICES=0,1,2,3 14 | export MASTER_PORT=12802 15 | 16 | master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) 17 | export MASTER_ADDR=$master_addr 18 | 19 | srun env -u CUDA_VISIBLE_DEVICES torchrun \ 20 | --nproc_per_node=4 \ 21 | --nnode=$SLURM_JOB_NUM_NODES \ 22 | --rdzv_id=$SLURM_JOB_ID \ 23 | --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \ 24 | --rdzv_backend=c10d \ 25 | -m main \ 26 | --logs-dir ./logs/ \ 27 | --model ViT-B-16 \ 28 | --dataset-type webdataset \ 29 | --lr 5e-4 \ 30 | --warmup 2000 \ 31 | --epochs 32 \ 32 | --train-data 'datasets/cc12m_recap/cc12m-train-{0000..2175}.tar' \ 33 | --train-num-samples 10010225 \ 34 | --val-data 'coco' \ 35 | --data-root-dir directory/to/coco/ \ 36 | --batch-size 32 \ 37 | --precision amp \ 38 | --workers 16 \ 39 | --save-frequency 1 \ 40 | --log-every-n-steps 200 \ 41 | --wd 0.5 \ 42 | --beta1 0.9 \ 43 | --beta2 0.98 \ 44 | --eps 1e-8 \ 45 | --use-imagecrop-aug \ 46 | --global-crops-number 2 \ 47 | --local-crops-number 6 \ 48 | --crop-scale 0.4 \ 49 | --caption-sampling-mode textcrop \ 50 | --num-sampled-captions 8 \ 51 | --momentum-teacher 0.99 \ 52 | --fix-momentum \ 53 | --output-all \ 54 | --attentional-pool \ 55 | --cosmos 56 | -------------------------------------------------------------------------------- /src/train_cc3m.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -x 2 | #SBATCH --nodes=4 3 | #SBATCH --gres=gpu:4 4 | #SBATCH --ntasks-per-node=4 5 | #SBATCH --cpus-per-task=48 6 | #SBATCH --job-name=train_cc3m 7 | #SBATCH --output=./cluster_logs/train_cc3m.txt 8 | #SBATCH --partition PARTITION_NAME 9 | 10 | source cosmos_env/bin/activate 11 | cd cosmos/src 12 | 13 | export CUDA_VISIBLE_DEVICES=0,1,2,3 14 | export MASTER_PORT=12802 15 | 16 | master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) 17 | export MASTER_ADDR=$master_addr 18 | 19 | srun env -u CUDA_VISIBLE_DEVICES torchrun \ 20 | --nproc_per_node=4 \ 21 | --nnode=$SLURM_JOB_NUM_NODES \ 22 | --rdzv_id=$SLURM_JOB_ID \ 23 | --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \ 24 | --rdzv_backend=c10d \ 25 | -m main \ 26 | --logs-dir ./logs/ \ 27 | --model 
ViT-B-16 \ 28 | --dataset-type webdataset \ 29 | --lr 5e-4 \ 30 | --warmup 2000 \ 31 | --epochs 32 \ 32 | --train-data 'datasets/cc3m_recap/cc3m-train-{0000..0575}.tar' \ 33 | --train-num-samples 2823019 \ 34 | --val-data 'datasets/cc3m/cc3m-validation-00{00..15}.tar' \ 35 | --val-num-samples 13443 \ 36 | --batch-size 64 \ 37 | --precision amp \ 38 | --workers 16 \ 39 | --save-frequency 1 \ 40 | --log-every-n-steps 200 \ 41 | --wd 0.5 \ 42 | --beta1 0.9 \ 43 | --beta2 0.98 \ 44 | --eps 1e-8 \ 45 | --use-imagecrop-aug \ 46 | --global-crops-number 2 \ 47 | --local-crops-number 6 \ 48 | --crop-scale 0.4 \ 49 | --caption-sampling-mode textcrop \ 50 | --num-sampled-captions 8 \ 51 | --momentum-teacher 0.999 \ 52 | --fix-momentum \ 53 | --output-all \ 54 | --attentional-pool \ 55 | --cosmos 56 | -------------------------------------------------------------------------------- /src/train_merged30m.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -x 2 | #SBATCH --nodes=32 3 | #SBATCH --gres=gpu:4 4 | #SBATCH --ntasks-per-node=4 5 | #SBATCH --cpus-per-task=48 6 | #SBATCH --job-name=train_merged30m 7 | #SBATCH --output=./cluster_logs/train_merged30m.txt 8 | #SBATCH --partition PARTITION_NAME 9 | 10 | source cosmos_env/bin/activate 11 | cd cosmos/src 12 | 13 | export CUDA_VISIBLE_DEVICES=0,1,2,3 14 | export MASTER_PORT=12802 15 | 16 | master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) 17 | export MASTER_ADDR=$master_addr 18 | 19 | srun env -u CUDA_VISIBLE_DEVICES torchrun \ 20 | --nproc_per_node=4 \ 21 | --nnode=$SLURM_JOB_NUM_NODES \ 22 | --rdzv_id=$SLURM_JOB_ID \ 23 | --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \ 24 | --rdzv_backend=c10d \ 25 | -m main \ 26 | --logs-dir ./logs/ \ 27 | --model ViT-B-16 \ 28 | --dataset-type webdataset \ 29 | --lr 5e-4 \ 30 | --warmup 2000 \ 31 | --epochs 32 \ 32 | --train-data '/datasets/yfcc15m_recap/yfcc15m-train-{0000..3636}.tar::/datasets/cc12m_recap/cc12m-train-{0000..2175}.tar::/datasets/cc3m_recap/cc3m-train-{0001..0575}.tar' \ 33 | --train-num-samples 26899071 \ 34 | --val-data 'coco' \ 35 | --data-root-dir directory/to/coco/ \ 36 | --batch-size 32 \ 37 | --precision amp \ 38 | --workers 16 \ 39 | --save-frequency 1 \ 40 | --log-every-n-steps 200 \ 41 | --wd 0.5 \ 42 | --beta1 0.9 \ 43 | --beta2 0.98 \ 44 | --eps 1e-8 \ 45 | --use-imagecrop-aug \ 46 | --global-crops-number 2 \ 47 | --local-crops-number 6 \ 48 | --crop-scale 0.4 \ 49 | --caption-sampling-mode textcrop \ 50 | --num-sampled-captions 8 \ 51 | --momentum-teacher 0.99 \ 52 | --fix-momentum \ 53 | --output-all \ 54 | --attentional-pool \ 55 | --cosmos 56 | -------------------------------------------------------------------------------- /src/train_pixelprose.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -x 2 | #SBATCH --nodes=32 3 | #SBATCH --gres=gpu:4 4 | #SBATCH --ntasks-per-node=4 5 | #SBATCH --cpus-per-task=48 6 | #SBATCH --job-name=train_pixelprose 7 | #SBATCH --output=./cluster_logs/train_pixelprose.txt 8 | #SBATCH --partition PARTITION_NAME 9 | 10 | source cosmos_env/bin/activate 11 | cd cosmos/src 12 | 13 | export CUDA_VISIBLE_DEVICES=0,1,2,3 14 | export MASTER_PORT=12802 15 | 16 | master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) 17 | export MASTER_ADDR=$master_addr 18 | 19 | srun env -u CUDA_VISIBLE_DEVICES torchrun \ 20 | --nproc_per_node=4 \ 21 | --nnode=$SLURM_JOB_NUM_NODES \ 22 | --rdzv_id=$SLURM_JOB_ID \ 23 | 
--rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \ 24 | --rdzv_backend=c10d \ 25 | -m main \ 26 | --logs-dir ./logs/ \ 27 | --model ViT-B-16 \ 28 | --dataset-type webdataset \ 29 | --lr 5e-4 \ 30 | --warmup 2000 \ 31 | --epochs 32 \ 32 | --train-data '/datasets/pixelprose/data/{00000..01699}.tar' \ 33 | --train-num-samples 15037386 \ 34 | --val-data 'coco' \ 35 | --data-root-dir directory/to/coco/ \ 36 | --batch-size 32 \ 37 | --precision amp \ 38 | --workers 16 \ 39 | --save-frequency 1 \ 40 | --log-every-n-steps 200 \ 41 | --wd 0.5 \ 42 | --beta1 0.9 \ 43 | --beta2 0.98 \ 44 | --eps 1e-8 \ 45 | --use-imagecrop-aug \ 46 | --global-crops-number 2 \ 47 | --local-crops-number 6 \ 48 | --crop-scale 0.4 \ 49 | --caption-sampling-mode textcrop_pixelprose \ 50 | --num-sampled-captions 8 \ 51 | --momentum-teacher 0.99 \ 52 | --fix-momentum \ 53 | --output-all \ 54 | --attentional-pool \ 55 | --cosmos 56 | -------------------------------------------------------------------------------- /src/train_yfcc15m.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -x 2 | #SBATCH --nodes=32 3 | #SBATCH --gres=gpu:4 4 | #SBATCH --ntasks-per-node=4 5 | #SBATCH --cpus-per-task=48 6 | #SBATCH --job-name=train_yfcc15m 7 | #SBATCH --output=./cluster_logs/train_yfcc15m.txt 8 | #SBATCH --partition PARTITION_NAME 9 | 10 | source cosmos_env/bin/activate 11 | cd cosmos/src 12 | 13 | export CUDA_VISIBLE_DEVICES=0,1,2,3 14 | export MASTER_PORT=12802 15 | 16 | master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) 17 | export MASTER_ADDR=$master_addr 18 | 19 | srun env -u CUDA_VISIBLE_DEVICES torchrun \ 20 | --nproc_per_node=4 \ 21 | --nnode=$SLURM_JOB_NUM_NODES \ 22 | --rdzv_id=$SLURM_JOB_ID \ 23 | --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \ 24 | --rdzv_backend=c10d \ 25 | -m main \ 26 | --logs-dir ./logs/ \ 27 | --model ViT-B-16 \ 28 | --dataset-type webdataset \ 29 | --lr 5e-4 \ 30 | --warmup 2000 \ 31 | --epochs 32 \ 32 | --train-data 'datasets/yfcc15m_recap/yfcc15m-train-{0000..3636}.tar' \ 33 | --train-num-samples 14065827 \ 34 | --val-data 'coco' \ 35 | --data-root-dir directory/to/coco/ \ 36 | --batch-size 32 \ 37 | --precision amp \ 38 | --workers 16 \ 39 | --save-frequency 1 \ 40 | --log-every-n-steps 200 \ 41 | --wd 0.5 \ 42 | --beta1 0.9 \ 43 | --beta2 0.98 \ 44 | --eps 1e-8 \ 45 | --use-imagecrop-aug \ 46 | --global-crops-number 2 \ 47 | --local-crops-number 6 \ 48 | --crop-scale 0.4 \ 49 | --caption-sampling-mode textcrop \ 50 | --num-sampled-captions 8 \ 51 | --momentum-teacher 0.99 \ 52 | --fix-momentum \ 53 | --output-all \ 54 | --attentional-pool \ 55 | --cosmos 56 | -------------------------------------------------------------------------------- /src/training/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/training/clip_segmentor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import sys 4 | import os 5 | sys.path.append("..") 6 | 7 | import open_clip 8 | from open_clip import OPENAI_IMAGENET_TEMPLATES 9 | 10 | from mmseg.models.segmentors import BaseSegmentor 11 | from mmseg.models.data_preprocessor import SegDataPreProcessor 12 | from mmengine.structures import PixelData 13 | 14 | from mmseg.registry import MODELS 15 | 16 | from training.pamr import PAMR 17 | 18 | from training.file_utils import pt_load 19 | import copy 
20 | 21 | import numpy as np 22 | from huggingface_hub import hf_hub_download 23 | 24 | def download_weights_from_hf(model_repo, filename): 25 | # Define the custom cache directory relative to the current script 26 | cache_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "pretrained") 27 | if not os.path.exists(cache_dir): 28 | os.makedirs(cache_dir, exist_ok=True) 29 | local_path = hf_hub_download(repo_id=model_repo, filename=filename, cache_dir=cache_dir) 30 | return local_path 31 | 32 | def load_model(**model_kwargs): 33 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 34 | input_model_kwargs = {} 35 | if model_kwargs['siglip']: 36 | input_model_kwargs['init_logit_scale'] = np.log(10) # different from CLIP 37 | input_model_kwargs['init_logit_bias'] = -10 38 | model, _, _ = open_clip.create_model_and_transforms( 39 | model_kwargs['model'], 40 | model_kwargs['pretrained'], 41 | precision=model_kwargs['precision'], 42 | device=device, 43 | jit=model_kwargs['torchscript'], 44 | force_quick_gelu=model_kwargs['force_quick_gelu'], 45 | force_custom_text=model_kwargs['force_custom_text'], 46 | force_patch_dropout=model_kwargs['force_patch_dropout'], 47 | force_image_size=model_kwargs['force_image_size'], 48 | image_mean=model_kwargs['image_mean'], 49 | image_std=model_kwargs['image_std'], 50 | image_interpolation=model_kwargs['image_interpolation'], 51 | image_resize_mode=model_kwargs['image_resize_mode'], 52 | use_imagecrop_aug=model_kwargs['use_imagecrop_aug'], 53 | global_crops_number=model_kwargs['global_crops_number'], 54 | local_crops_number=model_kwargs['local_crops_number'], 55 | crop_scale=model_kwargs['crop_scale'], 56 | aug_cfg=model_kwargs['aug_cfg'], 57 | pretrained_image=model_kwargs['pretrained_image'], 58 | output_dict=True, 59 | output_all=model_kwargs['output_all'], 60 | pool_type=model_kwargs['pool_type'], 61 | attentional_pool=model_kwargs['attentional_pool'], 62 | add_zero_attn=model_kwargs['add_zero_attn'], 63 | cosmos=model_kwargs['cosmos'], 64 | **input_model_kwargs,) 65 | 66 | ema_model = copy.deepcopy(model) 67 | 68 | if model_kwargs['pretrained']: 69 | return model 70 | 71 | if model_kwargs['huggingface_model_name'] != '': 72 | # based on huggingface model name, download the pre-trained weights, the downloaded path is passed as the 'resume' arguments 73 | huggingface_model_name, huggingface_repo_name = model_kwargs['huggingface_model_name'], model_kwargs['huggingface_repo_name'] 74 | model_kwargs['resume'] = download_weights_from_hf(model_repo=huggingface_repo_name, filename=huggingface_model_name) 75 | 76 | checkpoint = pt_load(model_kwargs['resume'], map_location='cpu') 77 | 78 | # resuming a train checkpoint w/ epoch and optimizer state 79 | start_epoch = checkpoint["epoch"] 80 | if "state_dict" in checkpoint: 81 | if model_kwargs['use_ema_model']: 82 | ema_sd = checkpoint["ema_state_dict"] 83 | if next(iter(ema_sd.items()))[0].startswith('module'): 84 | ema_sd = {k[len('module.'):]: v for k, v in ema_sd.items()} 85 | ema_model.load_state_dict(ema_sd) 86 | print(f"=> resuming ema checkpoint '{model_kwargs['resume']}' (epoch {start_epoch})") 87 | return ema_model 88 | sd = checkpoint["state_dict"] 89 | if next(iter(sd.items()))[0].startswith('module'): 90 | sd = {k[len('module.'):]: v for k, v in sd.items()} 91 | model.load_state_dict(sd) 92 | print(f"=> resuming checkpoint '{model_kwargs['resume']}' (epoch {start_epoch})") 93 | return model 94 | else: 95 | """ 96 | sd = checkpoint["student"] 97 | if 
next(iter(sd.items()))[0].startswith('module'): 98 | sd = {k[len('module.'):]: v for k, v in sd.items()} 99 | model.load_state_dict(sd) 100 | print(f"=> resuming checkpoint '{model_kwargs['resume']}' (epoch {start_epoch})") 101 | """ 102 | # Evaluation on Teacher 103 | ema_sd = checkpoint["teacher"] 104 | if next(iter(ema_sd.items()))[0].startswith('module'): 105 | ema_sd = {k[len('module.'):]: v for k, v in ema_sd.items()} 106 | ema_model.load_state_dict(ema_sd) 107 | print(f"=> resuming ema checkpoint '{model_kwargs['resume']}' (epoch {start_epoch})") 108 | 109 | return ema_model 110 | 111 | @MODELS.register_module() 112 | class CLIPForSegmentation(BaseSegmentor): 113 | def __init__(self, name_path, device=torch.device('cuda:0'), 114 | pamr_steps=0, pamr_stride=(8, 16), prob_thd=0.0, logit_scale=40, 115 | slide_stride=112, slide_crop=224, area_thd=None, **model_kwargs): 116 | 117 | data_preprocessor = SegDataPreProcessor( 118 | mean=[122.771, 116.746, 104.094], 119 | std=[68.501, 66.632, 70.323], 120 | rgb_to_bgr=True) 121 | super().__init__(data_preprocessor=data_preprocessor) 122 | 123 | self.net = load_model(**model_kwargs) 124 | query_words, self.query_idx = get_cls_idx(name_path) 125 | self.num_queries = len(query_words) 126 | self.num_classes = max(self.query_idx) + 1 127 | self.query_idx = torch.Tensor(self.query_idx).to(torch.int64).to(device) 128 | 129 | query_features = [] 130 | with torch.no_grad(): 131 | for qw in query_words: 132 | query = open_clip.tokenize([temp(qw) for temp in OPENAI_IMAGENET_TEMPLATES]).to(device) # clip.tokenize([temp(qw) for temp in OPENAI_IMAGENET_TEMPLATES]).to(device) 133 | feature = self.net.encode_text(query) 134 | feature = feature['text_features'] if isinstance(feature, dict) else feature 135 | feature /= feature.norm(dim=-1, keepdim=True) 136 | feature = feature.mean(dim=0) 137 | feature /= feature.norm() 138 | query_features.append(feature.unsqueeze(0)) 139 | self.query_features = torch.cat(query_features, dim=0) 140 | 141 | self.dtype = self.query_features.dtype 142 | self.logit_scale = logit_scale 143 | self.prob_thd = prob_thd 144 | self.area_thd = area_thd 145 | self.slide_stride = slide_stride 146 | self.slide_crop = slide_crop 147 | self.align_corners = False 148 | self.use_csa = model_kwargs['use_csa'] 149 | 150 | if pamr_steps > 0: 151 | self.pamr = PAMR(pamr_steps, dilations=pamr_stride).to(device) 152 | else: 153 | self.pamr = None 154 | 155 | def forward_feature(self, img, logit_size=None): 156 | if type(img) == list: 157 | img = img[0] 158 | if self.use_csa: 159 | csa_image_features, _ = self.net.visual(img, return_all=True, csa=True) 160 | csa_image_features = csa_image_features @ self.net.visual.proj # [B, L-1, C] 161 | 162 | image_features = csa_image_features 163 | image_features /= image_features.norm(dim=-1, keepdim=True) 164 | logits = image_features @ self.query_features.T 165 | else: 166 | image_features, _ = self.net.visual(img, return_all=True, csa=self.use_csa) 167 | image_features = image_features @ self.net.visual.proj # [B, L-1, C] 168 | image_features /= image_features.norm(dim=-1, keepdim=True) 169 | logits = image_features @ self.query_features.T 170 | 171 | patch_size = self.net.visual.patch_size 172 | patch_size = patch_size[0] if isinstance(patch_size, (list, tuple)) else patch_size 173 | 174 | w, h = img[0].shape[-2] // patch_size, img[0].shape[-1] // patch_size 175 | out_dim = logits.shape[-1] 176 | logits = logits.permute(0, 2, 1).reshape(-1, out_dim, w, h) 177 | 178 | if logit_size == None: 179 | logits = 
nn.functional.interpolate(logits, size=img.shape[-2:], mode='bilinear') 180 | else: 181 | logits = nn.functional.interpolate(logits, size=logit_size, mode='bilinear') 182 | 183 | return logits 184 | 185 | def forward_slide(self, img, img_metas, stride=112, crop_size=224): 186 | """Inference by sliding-window with overlap. 187 | If h_crop > h_img or w_crop > w_img, the small patch will be used to 188 | decode without padding. 189 | """ 190 | if type(img) == list: 191 | img = img[0].unsqueeze(0) 192 | if type(stride) == int: 193 | stride = (stride, stride) 194 | if type(crop_size) == int: 195 | crop_size = (crop_size, crop_size) 196 | 197 | h_stride, w_stride = stride 198 | h_crop, w_crop = crop_size 199 | batch_size, _, h_img, w_img = img.shape 200 | out_channels = self.num_queries 201 | h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 202 | w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 203 | preds = img.new_zeros((batch_size, out_channels, h_img, w_img)) 204 | count_mat = img.new_zeros((batch_size, 1, h_img, w_img)) 205 | for h_idx in range(h_grids): 206 | for w_idx in range(w_grids): 207 | y1 = h_idx * h_stride 208 | x1 = w_idx * w_stride 209 | y2 = min(y1 + h_crop, h_img) 210 | x2 = min(x1 + w_crop, w_img) 211 | y1 = max(y2 - h_crop, 0) 212 | x1 = max(x2 - w_crop, 0) 213 | crop_img = img[:, :, y1:y2, x1:x2] 214 | crop_seg_logit = self.forward_feature(crop_img) 215 | preds += nn.functional.pad(crop_seg_logit, 216 | (int(x1), int(preds.shape[3] - x2), int(y1), 217 | int(preds.shape[2] - y2))) 218 | 219 | count_mat[:, :, y1:y2, x1:x2] += 1 220 | assert (count_mat == 0).sum() == 0 221 | 222 | preds = preds / count_mat 223 | img_size = img_metas[0]['ori_shape'][:2] 224 | logits = nn.functional.interpolate(preds, size=img_size, mode='bilinear') 225 | 226 | if self.pamr: 227 | img = nn.functional.interpolate(img, size=img_size, mode='bilinear') 228 | logits = self.pamr(img, logits.to(img.dtype)).to(self.dtype) 229 | 230 | return logits 231 | 232 | def predict(self, inputs, data_samples): 233 | if data_samples is not None: 234 | batch_img_metas = [ 235 | data_sample.metainfo for data_sample in data_samples 236 | ] 237 | else: 238 | batch_img_metas = [ 239 | dict( 240 | ori_shape=inputs.shape[2:], 241 | img_shape=inputs.shape[2:], 242 | pad_shape=inputs.shape[2:], 243 | padding_size=[0, 0, 0, 0]) 244 | ] * inputs.shape[0] 245 | 246 | if self.slide_crop > 0: 247 | seg_logits = self.forward_slide(inputs, batch_img_metas, self.slide_stride, self.slide_crop) 248 | else: 249 | seg_logits = self.forward_feature(inputs, batch_img_metas[0]['ori_shape']) 250 | 251 | return self.postprocess_result(seg_logits, data_samples) 252 | 253 | def postprocess_result(self, seg_logits, data_samples): 254 | batch_size = seg_logits.shape[0] 255 | for i in range(batch_size): 256 | seg_logits = seg_logits[i] * self.logit_scale 257 | seg_logits = seg_logits.softmax(0) # n_queries * w * h 258 | 259 | num_cls, num_queries = max(self.query_idx) + 1, len(self.query_idx) 260 | if num_cls != num_queries: 261 | seg_logits = seg_logits.unsqueeze(0) 262 | cls_index = nn.functional.one_hot(self.query_idx) 263 | cls_index = cls_index.T.view(num_cls, num_queries, 1, 1) 264 | seg_logits = (seg_logits * cls_index).max(1)[0] 265 | seg_pred = seg_logits.argmax(0, keepdim=True) 266 | 267 | if self.area_thd is not None: 268 | # Force segmentations with area < self.area_thd to 0 (background) 269 | predictions = nn.functional.one_hot(seg_logits.argmax(0), num_cls).to(seg_logits.dtype) 270 | area_pred = 
predictions[:, :, 1:].sum((0, 1), keepdim=True) # prune background 271 | area_pred = (area_pred > self.area_thd * area_pred.sum()).to(seg_logits.dtype) 272 | seg_logits[1:] *= area_pred.transpose(0, -1) 273 | 274 | seg_pred = seg_logits.argmax(0, keepdim=True) 275 | seg_pred[seg_logits.max(0, keepdim=True)[0] < self.prob_thd] = 0 276 | 277 | data_samples[i].set_data({ 278 | 'seg_logits': 279 | PixelData(**{'data': seg_logits}), 280 | 'pred_sem_seg': 281 | PixelData(**{'data': seg_pred}) 282 | }) 283 | 284 | return data_samples 285 | 286 | def _forward(self, data_samples): 287 | """ 288 | """ 289 | 290 | def inference(self, img, batch_img_metas): 291 | """ 292 | """ 293 | 294 | def encode_decode(self, inputs, batch_img_metas): 295 | """ 296 | """ 297 | 298 | def extract_feat(self, inputs): 299 | """ 300 | """ 301 | 302 | def loss(self, inputs, data_samples): 303 | """ 304 | """ 305 | 306 | def get_cls_idx(path): 307 | with open(path, 'r') as f: 308 | name_sets = f.readlines() 309 | num_cls = len(name_sets) 310 | 311 | class_names, class_indices = [], [] 312 | for idx in range(num_cls): 313 | names_i = name_sets[idx].split(', ') 314 | class_names += names_i 315 | class_indices += [idx for _ in range(len(names_i))] 316 | class_names = [item.replace('\n', '') for item in class_names] 317 | return class_names, class_indices -------------------------------------------------------------------------------- /src/training/custom_datasets.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import mmengine.fileio as fileio 3 | 4 | from mmseg.registry import DATASETS 5 | from mmseg.datasets import BaseSegDataset 6 | 7 | @DATASETS.register_module() 8 | class PascalVOC20Dataset(BaseSegDataset): 9 | """Pascal VOC dataset. 10 | 11 | Args: 12 | split (str): Split txt file for Pascal VOC. 13 | """ 14 | METAINFO = dict( 15 | classes=('aeroplane', 'bicycle', 'bird', 'boat', 16 | 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 17 | 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 18 | 'sofa', 'train', 'tvmonitor'), 19 | palette=[[128, 0, 0], [0, 128, 0], [0, 0, 192], 20 | [128, 128, 0], [128, 0, 128], [0, 128, 128], [192, 128, 64], 21 | [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0], 22 | [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128], 23 | [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0], 24 | [0, 64, 128]]) 25 | 26 | def __init__(self, 27 | ann_file, 28 | img_suffix='.jpg', 29 | seg_map_suffix='.png', 30 | reduce_zero_label=True, 31 | **kwargs) -> None: 32 | super().__init__( 33 | img_suffix=img_suffix, 34 | seg_map_suffix=seg_map_suffix, 35 | reduce_zero_label=reduce_zero_label, 36 | ann_file=ann_file, 37 | **kwargs) 38 | assert fileio.exists(self.data_prefix['img_path'], 39 | self.backend_args) and osp.isfile(self.ann_file) 40 | 41 | @DATASETS.register_module() 42 | class COCOObjectDataset(BaseSegDataset): 43 | """ 44 | Implementation borrowed from TCL (https://github.com/kakaobrain/tcl) and GroupViT (https://github.com/NVlabs/GroupViT) 45 | COCO-Object dataset. 46 | 1 bg class + first 80 classes from the COCO-Stuff dataset. 
47 | """ 48 | 49 | METAINFO = dict( 50 | 51 | classes = ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 52 | 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 53 | 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 54 | 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 55 | 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 56 | 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 57 | 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 58 | 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 59 | 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'), 60 | 61 | palette = [[0, 0, 0], [0, 192, 64], [0, 192, 64], [0, 64, 96], [128, 192, 192], [0, 64, 64], [0, 192, 224], 62 | [0, 192, 192], [128, 192, 64], [0, 192, 96], [128, 192, 64], [128, 32, 192], [0, 0, 224], [0, 0, 64], 63 | [0, 160, 192], [128, 0, 96], [128, 0, 192], [0, 32, 192], [128, 128, 224], [0, 0, 192], [128, 160, 192], 64 | [128, 128, 0], [128, 0, 32], [128, 32, 0], [128, 0, 128], [64, 128, 32], [0, 160, 0], [0, 0, 0], 65 | [192, 128, 160], [0, 32, 0], [0, 128, 128], [64, 128, 160], [128, 160, 0], [0, 128, 0], [192, 128, 32], 66 | [128, 96, 128], [0, 0, 128], [64, 0, 32], [0, 224, 128], [128, 0, 0], [192, 0, 160], [0, 96, 128], 67 | [128, 128, 128], [64, 0, 160], [128, 224, 128], [128, 128, 64], [192, 0, 32], 68 | [128, 96, 0], [128, 0, 192], [0, 128, 32], [64, 224, 0], [0, 0, 64], [128, 128, 160], [64, 96, 0], 69 | [0, 128, 192], [0, 128, 160], [192, 224, 0], [0, 128, 64], [128, 128, 32], [192, 32, 128], [0, 64, 192], 70 | [0, 0, 32], [64, 160, 128], [128, 64, 64], [128, 0, 160], [64, 32, 128], [128, 192, 192], [0, 0, 160], 71 | [192, 160, 128], [128, 192, 0], [128, 0, 96], [192, 32, 0], [128, 64, 128], [64, 128, 96], [64, 160, 0], 72 | [0, 64, 0], [192, 128, 224], [64, 32, 0], [0, 192, 128], [64, 128, 224], [192, 160, 0]]) 73 | 74 | def __init__(self, **kwargs): 75 | super(COCOObjectDataset, self).__init__(img_suffix='.jpg', seg_map_suffix='_instanceTrainIds.png', **kwargs) 76 | 77 | @DATASETS.register_module() 78 | class PascalContext60Dataset(BaseSegDataset): 79 | METAINFO = dict( 80 | classes=('background', 'aeroplane', 'bag', 'bed', 'bedclothes', 81 | 'bench', 'bicycle', 'bird', 'boat', 'book', 'bottle', 82 | 'building', 'bus', 'cabinet', 'car', 'cat', 'ceiling', 83 | 'chair', 'cloth', 'computer', 'cow', 'cup', 'curtain', 'dog', 84 | 'door', 'fence', 'floor', 'flower', 'food', 'grass', 'ground', 85 | 'horse', 'keyboard', 'light', 'motorbike', 'mountain', 86 | 'mouse', 'person', 'plate', 'platform', 'pottedplant', 'road', 87 | 'rock', 'sheep', 'shelves', 'sidewalk', 'sign', 'sky', 'snow', 88 | 'sofa', 'table', 'track', 'train', 'tree', 'truck', 89 | 'tvmonitor', 'wall', 'water', 'window', 'wood'), 90 | palette=[[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], 91 | [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], 92 | [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], 93 | [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], 94 | [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], 95 | [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], 96 | [255, 7, 71], [255, 
9, 224], [9, 7, 230], [220, 220, 220], 97 | [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], 98 | [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], 99 | [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], 100 | [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], 101 | [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], 102 | [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], 103 | [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], 104 | [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]]) 105 | 106 | def __init__(self, 107 | ann_file: str, 108 | img_suffix='.jpg', 109 | seg_map_suffix='.png', 110 | **kwargs) -> None: 111 | super().__init__( 112 | img_suffix=img_suffix, 113 | seg_map_suffix=seg_map_suffix, 114 | ann_file=ann_file, 115 | reduce_zero_label=False, 116 | **kwargs) 117 | 118 | 119 | @DATASETS.register_module() 120 | class PascalContext59Dataset(BaseSegDataset): 121 | METAINFO = dict( 122 | classes=('aeroplane', 'bag', 'bed', 'bedclothes', 'bench', 'bicycle', 123 | 'bird', 'boat', 'book', 'bottle', 'building', 'bus', 124 | 'cabinet', 'car', 'cat', 'ceiling', 'chair', 'cloth', 125 | 'computer', 'cow', 'cup', 'curtain', 'dog', 'door', 'fence', 126 | 'floor', 'flower', 'food', 'grass', 'ground', 'horse', 127 | 'keyboard', 'light', 'motorbike', 'mountain', 'mouse', 128 | 'person', 'plate', 'platform', 'pottedplant', 'road', 'rock', 129 | 'sheep', 'shelves', 'sidewalk', 'sign', 'sky', 'snow', 'sofa', 130 | 'table', 'track', 'train', 'tree', 'truck', 'tvmonitor', 131 | 'wall', 'water', 'window', 'wood'), 132 | palette=[[180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3], 133 | [120, 120, 80], [140, 140, 140], [204, 5, 255], 134 | [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], 135 | [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], 136 | [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], 137 | [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], 138 | [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], 139 | [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], 140 | [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], 141 | [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], 142 | [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], 143 | [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], 144 | [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], 145 | [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], 146 | [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]]) 147 | 148 | def __init__(self, 149 | ann_file: str, 150 | img_suffix='.jpg', 151 | seg_map_suffix='.png', 152 | reduce_zero_label=True, 153 | **kwargs): 154 | super().__init__( 155 | img_suffix=img_suffix, 156 | seg_map_suffix=seg_map_suffix, 157 | ann_file=ann_file, 158 | reduce_zero_label=reduce_zero_label, 159 | **kwargs) -------------------------------------------------------------------------------- /src/training/distributed.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.distributed as dist 5 | 6 | try: 7 | import horovod.torch as hvd 8 | except ImportError: 9 | hvd = None 10 | 11 | from datetime import timedelta 12 | 13 | def is_global_master(args): 14 | return args.rank == 0 15 | 16 | 17 | def is_local_master(args): 18 | return args.local_rank == 0 19 | 20 | 21 | def is_master(args, local=False): 22 | return is_local_master(args) if local else is_global_master(args) 23 | 
24 | 25 | def is_using_horovod(): 26 | # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set 27 | # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required... 28 | ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"] 29 | pmi_vars = ["PMI_RANK", "PMI_SIZE"] 30 | if all([var in os.environ for var in ompi_vars]) or all([var in os.environ for var in pmi_vars]): 31 | return True 32 | else: 33 | return False 34 | 35 | 36 | def is_using_distributed(): 37 | if 'WORLD_SIZE' in os.environ: 38 | return int(os.environ['WORLD_SIZE']) > 1 39 | if 'SLURM_NTASKS' in os.environ: 40 | return int(os.environ['SLURM_NTASKS']) > 1 41 | return False 42 | 43 | 44 | def world_info_from_env(): 45 | local_rank = 0 46 | for v in ('LOCAL_RANK', 'MPI_LOCALRANKID', 'SLURM_LOCALID', 'OMPI_COMM_WORLD_LOCAL_RANK'): 47 | if v in os.environ: 48 | local_rank = int(os.environ[v]) 49 | break 50 | global_rank = 0 51 | for v in ('RANK', 'PMI_RANK', 'SLURM_PROCID', 'OMPI_COMM_WORLD_RANK'): 52 | if v in os.environ: 53 | global_rank = int(os.environ[v]) 54 | break 55 | world_size = 1 56 | for v in ('WORLD_SIZE', 'PMI_SIZE', 'SLURM_NTASKS', 'OMPI_COMM_WORLD_SIZE'): 57 | if v in os.environ: 58 | world_size = int(os.environ[v]) 59 | break 60 | 61 | return local_rank, global_rank, world_size 62 | 63 | def init_distributed_device(args): 64 | # Distributed training = training on more than one GPU. 65 | # Works in both single and multi-node scenarios. 66 | args.distributed = False 67 | args.world_size = 1 68 | args.rank = 0 # global rank 69 | args.local_rank = 0 70 | if args.horovod: 71 | assert hvd is not None, "Horovod is not installed" 72 | hvd.init() 73 | args.local_rank = int(hvd.local_rank()) 74 | args.rank = hvd.rank() 75 | args.world_size = hvd.size() 76 | args.distributed = True 77 | os.environ['LOCAL_RANK'] = str(args.local_rank) 78 | os.environ['RANK'] = str(args.rank) 79 | os.environ['WORLD_SIZE'] = str(args.world_size) 80 | elif is_using_distributed(): 81 | timeout=None 82 | if 'SLURM_PROCID' in os.environ: 83 | # DDP via SLURM 84 | args.local_rank, args.rank, args.world_size = world_info_from_env() 85 | # SLURM var -> torch.distributed vars in case needed 86 | os.environ['LOCAL_RANK'] = str(args.local_rank) 87 | os.environ['RANK'] = str(args.rank) 88 | os.environ['WORLD_SIZE'] = str(args.world_size) 89 | torch.distributed.init_process_group( 90 | backend=args.dist_backend, 91 | init_method=args.dist_url, 92 | world_size=args.world_size, 93 | rank=args.rank, 94 | timeout=timeout 95 | ) 96 | else: 97 | # DDP via torchrun, torch.distributed.launch 98 | args.local_rank, _, _ = world_info_from_env() 99 | torch.distributed.init_process_group( 100 | backend=args.dist_backend, 101 | init_method=args.dist_url, 102 | timeout=timeout) 103 | args.world_size = torch.distributed.get_world_size() 104 | args.rank = torch.distributed.get_rank() 105 | args.distributed = True 106 | 107 | if torch.cuda.is_available(): 108 | if args.distributed and not args.no_set_device_rank: 109 | device = 'cuda:%d' % args.local_rank 110 | else: 111 | device = 'cuda:0' 112 | torch.cuda.set_device(device) 113 | else: 114 | device = 'cpu' 115 | args.device = device 116 | device = torch.device(device) 117 | return device 118 | 119 | 120 | def broadcast_object(args, obj, src=0): 121 | # broadcast a pickle-able python object from rank-0 to all ranks 122 | if args.horovod: 123 | return hvd.broadcast_object(obj, root_rank=src) 124 | else: 125 | if args.rank == src: 126 | objects = 
[obj] 127 | else: 128 | objects = [None] 129 | dist.broadcast_object_list(objects, src=src) 130 | return objects[0] 131 | 132 | 133 | def all_gather_object(args, obj, dst=0): 134 | # gather a pickle-able python object across all ranks 135 | if args.horovod: 136 | return hvd.allgather_object(obj) 137 | else: 138 | objects = [None for _ in range(args.world_size)] 139 | dist.all_gather_object(objects, obj) 140 | return objects 141 | -------------------------------------------------------------------------------- /src/training/file_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import multiprocessing 4 | import subprocess 5 | import time 6 | import fsspec 7 | import torch 8 | from tqdm import tqdm 9 | 10 | def remote_sync_s3(local_dir, remote_dir): 11 | # skip epoch_latest which can change during sync. 12 | result = subprocess.run(["aws", "s3", "sync", local_dir, remote_dir, '--exclude', '*epoch_latest.pt'], stdout=subprocess.PIPE, stderr=subprocess.PIPE) 13 | if result.returncode != 0: 14 | logging.error(f"Error: Failed to sync with S3 bucket {result.stderr.decode('utf-8')}") 15 | return False 16 | 17 | logging.info(f"Successfully synced with S3 bucket") 18 | return True 19 | 20 | def remote_sync_fsspec(local_dir, remote_dir): 21 | # FIXME currently this is slow and not recommended. Look into speeding up. 22 | a = fsspec.get_mapper(local_dir) 23 | b = fsspec.get_mapper(remote_dir) 24 | 25 | for k in a: 26 | # skip epoch_latest which can change during sync. 27 | if 'epoch_latest.pt' in k: 28 | continue 29 | 30 | logging.info(f'Attempting to sync {k}') 31 | if k in b and len(a[k]) == len(b[k]): 32 | logging.debug(f'Skipping remote sync for {k}.') 33 | continue 34 | 35 | try: 36 | logging.info(f'Successful sync for {k}.') 37 | b[k] = a[k] 38 | except Exception as e: 39 | logging.info(f'Error during remote sync for {k}: {e}') 40 | return False 41 | 42 | return True 43 | 44 | def remote_sync(local_dir, remote_dir, protocol): 45 | logging.info('Starting remote sync.') 46 | if protocol == 's3': 47 | return remote_sync_s3(local_dir, remote_dir) 48 | elif protocol == 'fsspec': 49 | return remote_sync_fsspec(local_dir, remote_dir) 50 | else: 51 | logging.error('Remote protocol not known') 52 | return False 53 | 54 | def keep_running_remote_sync(sync_every, local_dir, remote_dir, protocol): 55 | while True: 56 | time.sleep(sync_every) 57 | remote_sync(local_dir, remote_dir, protocol) 58 | 59 | def start_sync_process(sync_every, local_dir, remote_dir, protocol): 60 | p = multiprocessing.Process(target=keep_running_remote_sync, args=(sync_every, local_dir, remote_dir, protocol)) 61 | return p 62 | 63 | # Note: we are not currently using this save function. 
64 | def pt_save(pt_obj, file_path): 65 | of = fsspec.open(file_path, "wb") 66 | with of as f: 67 | torch.save(pt_obj, f) 68 | 69 | def pt_load(file_path, map_location=None): 70 | if file_path.startswith('s3'): 71 | logging.info('Loading remote checkpoint, which may take a bit.') 72 | of = fsspec.open(file_path, "rb") 73 | with of as f: 74 | out = torch.load(f, map_location=map_location, weights_only=False) 75 | return out 76 | 77 | def check_exists(file_path): 78 | try: 79 | with fsspec.open(file_path): 80 | pass 81 | except FileNotFoundError: 82 | return False 83 | return True 84 | -------------------------------------------------------------------------------- /src/training/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | def setup_logging(log_file, level, include_host=False): 5 | if include_host: 6 | import socket 7 | hostname = socket.gethostname() 8 | formatter = logging.Formatter( 9 | f'%(asctime)s | {hostname} | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S') 10 | else: 11 | formatter = logging.Formatter('%(asctime)s | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S') 12 | 13 | logging.root.setLevel(level) 14 | loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict] 15 | for logger in loggers: 16 | logger.setLevel(level) 17 | 18 | stream_handler = logging.StreamHandler() 19 | stream_handler.setFormatter(formatter) 20 | logging.root.addHandler(stream_handler) 21 | 22 | if log_file: 23 | file_handler = logging.FileHandler(filename=log_file) 24 | file_handler.setFormatter(formatter) 25 | logging.root.addHandler(file_handler) 26 | 27 | -------------------------------------------------------------------------------- /src/training/pamr.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 TU Darmstadt 2 | # License: Apache 2.0 License. 
3 | # https://github.com/visinf/1-stage-wseg/blob/master/models/mods/pamr.py 4 | import torch 5 | import torch.nn.functional as F 6 | import torch.nn as nn 7 | 8 | from functools import partial 9 | 10 | # 11 | # Helper modules 12 | # 13 | class LocalAffinity(nn.Module): 14 | 15 | def __init__(self, dilations=[1]): 16 | super(LocalAffinity, self).__init__() 17 | self.dilations = dilations 18 | weight = self._init_aff() 19 | self.register_buffer('kernel', weight) 20 | 21 | def _init_aff(self): 22 | # initialising the shift kernel 23 | weight = torch.zeros(8, 1, 3, 3) 24 | 25 | for i in range(weight.size(0)): 26 | weight[i, 0, 1, 1] = 1 27 | 28 | weight[0, 0, 0, 0] = -1 29 | weight[1, 0, 0, 1] = -1 30 | weight[2, 0, 0, 2] = -1 31 | 32 | weight[3, 0, 1, 0] = -1 33 | weight[4, 0, 1, 2] = -1 34 | 35 | weight[5, 0, 2, 0] = -1 36 | weight[6, 0, 2, 1] = -1 37 | weight[7, 0, 2, 2] = -1 38 | 39 | self.weight_check = weight.clone() 40 | 41 | return weight 42 | 43 | def forward(self, x): 44 | 45 | self.weight_check = self.weight_check.type_as(x) 46 | assert torch.all(self.weight_check.eq(self.kernel)) 47 | 48 | B,K,H,W = x.size() 49 | x = x.view(B*K,1,H,W) 50 | 51 | x_affs = [] 52 | for d in self.dilations: 53 | x_pad = F.pad(x, [d]*4, mode='replicate') 54 | x_aff = F.conv2d(x_pad, self.kernel, dilation=d) 55 | x_affs.append(x_aff) 56 | 57 | x_aff = torch.cat(x_affs, 1) 58 | return x_aff.view(B,K,-1,H,W) 59 | 60 | class LocalAffinityCopy(LocalAffinity): 61 | 62 | def _init_aff(self): 63 | # initialising the shift kernel 64 | weight = torch.zeros(8, 1, 3, 3) 65 | 66 | weight[0, 0, 0, 0] = 1 67 | weight[1, 0, 0, 1] = 1 68 | weight[2, 0, 0, 2] = 1 69 | 70 | weight[3, 0, 1, 0] = 1 71 | weight[4, 0, 1, 2] = 1 72 | 73 | weight[5, 0, 2, 0] = 1 74 | weight[6, 0, 2, 1] = 1 75 | weight[7, 0, 2, 2] = 1 76 | 77 | self.weight_check = weight.clone() 78 | return weight 79 | 80 | class LocalStDev(LocalAffinity): 81 | 82 | def _init_aff(self): 83 | weight = torch.zeros(9, 1, 3, 3) 84 | weight.zero_() 85 | 86 | weight[0, 0, 0, 0] = 1 87 | weight[1, 0, 0, 1] = 1 88 | weight[2, 0, 0, 2] = 1 89 | 90 | weight[3, 0, 1, 0] = 1 91 | weight[4, 0, 1, 1] = 1 92 | weight[5, 0, 1, 2] = 1 93 | 94 | weight[6, 0, 2, 0] = 1 95 | weight[7, 0, 2, 1] = 1 96 | weight[8, 0, 2, 2] = 1 97 | 98 | self.weight_check = weight.clone() 99 | return weight 100 | 101 | def forward(self, x): 102 | # returns (B,K,P,H,W), where P is the number 103 | # of locations 104 | x = super(LocalStDev, self).forward(x) 105 | 106 | return x.std(2, keepdim=True) 107 | 108 | class LocalAffinityAbs(LocalAffinity): 109 | 110 | def forward(self, x): 111 | x = super(LocalAffinityAbs, self).forward(x) 112 | return torch.abs(x) 113 | 114 | # 115 | # PAMR module 116 | # 117 | class PAMR(nn.Module): 118 | 119 | def __init__(self, num_iter=1, dilations=[1]): 120 | super(PAMR, self).__init__() 121 | 122 | self.num_iter = num_iter 123 | self.aff_x = LocalAffinityAbs(dilations) 124 | self.aff_m = LocalAffinityCopy(dilations) 125 | self.aff_std = LocalStDev(dilations) 126 | 127 | def forward(self, x, mask): 128 | mask = F.interpolate(mask, size=x.size()[-2:], mode="bilinear", align_corners=True) 129 | 130 | # x: [BxKxHxW] 131 | # mask: [BxCxHxW] 132 | B,K,H,W = x.size() 133 | _,C,_,_ = mask.size() 134 | 135 | x_std = self.aff_std(x) 136 | 137 | x = -self.aff_x(x) / (1e-8 + 0.1 * x_std) 138 | x = x.mean(1, keepdim=True) 139 | x = F.softmax(x, 2) 140 | 141 | for _ in range(self.num_iter): 142 | m = self.aff_m(mask) # [BxCxPxHxW] 143 | mask = (m * x).sum(2) 144 | 145 | # xvals: 
[BxCxHxW] 146 | return mask -------------------------------------------------------------------------------- /src/training/precision.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from contextlib import suppress 3 | 4 | 5 | def get_autocast(precision): 6 | if precision == 'amp': 7 | return torch.cuda.amp.autocast 8 | elif precision == 'amp_bfloat16' or precision == 'amp_bf16': 9 | # amp_bfloat16 is more stable than amp float16 for clip training 10 | return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16) 11 | else: 12 | return suppress 13 | -------------------------------------------------------------------------------- /src/training/profiler.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | import open_clip 5 | import pandas as pd 6 | from torch.utils.flop_counter import FlopCounterMode 7 | try: 8 | import fvcore 9 | except: 10 | fvcore = None 11 | 12 | parser = argparse.ArgumentParser(description='OpenCLIP Profiler') 13 | 14 | # benchmark specific args 15 | parser.add_argument('--model', metavar='NAME', default='', 16 | help='model(s) to profile') 17 | parser.add_argument('--results-file', default='', type=str, metavar='FILENAME', 18 | help='Output csv file for results') 19 | parser.add_argument('--profiler', default='torch', type=str, choices=['torch', 'fvcore']) 20 | parser.add_argument('--batch-size', default=1, type=int, help='Batch size for profiling') 21 | 22 | 23 | def profile_fvcore( 24 | model, 25 | image_input_size=(3, 224, 224), 26 | text_input_size=(77,), 27 | batch_size=1, 28 | detailed=False, 29 | force_cpu=False 30 | ): 31 | if force_cpu: 32 | model = model.to('cpu') 33 | device, dtype = next(model.parameters()).device, next(model.parameters()).dtype 34 | example_image_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype) 35 | example_text_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64) 36 | fca = fvcore.nn.FlopCountAnalysis(model, (example_image_input, example_text_input)) 37 | aca = fvcore.nn.ActivationCountAnalysis(model, (example_image_input, example_text_input)) 38 | if detailed: 39 | fcs = fvcore.nn.flop_count_str(fca) 40 | print(fcs) 41 | return fca.total() / batch_size, aca.total() / batch_size 42 | 43 | 44 | def profile_fvcore_text( 45 | model, 46 | text_input_size=(77,), 47 | batch_size=1, 48 | detailed=False, 49 | force_cpu=False 50 | ): 51 | if force_cpu: 52 | model = model.to('cpu') 53 | device = next(model.parameters()).device 54 | example_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64) 55 | fca = fvcore.nn.FlopCountAnalysis(model, example_input) 56 | aca = fvcore.nn.ActivationCountAnalysis(model, example_input) 57 | if detailed: 58 | fcs = fvcore.nn.flop_count_str(fca) 59 | print(fcs) 60 | return fca.total() / batch_size, aca.total() / batch_size 61 | 62 | 63 | def profile_fvcore_image( 64 | model, 65 | image_input_size=(3, 224, 224), 66 | batch_size=1, 67 | detailed=False, 68 | force_cpu=False 69 | ): 70 | if force_cpu: 71 | model = model.to('cpu') 72 | device, dtype = next(model.parameters()).device, next(model.parameters()).dtype 73 | example_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype) 74 | fca = fvcore.nn.FlopCountAnalysis(model, example_input) 75 | aca = fvcore.nn.ActivationCountAnalysis(model, example_input) 76 | if detailed: 77 | fcs = fvcore.nn.flop_count_str(fca) 78 | 
print(fcs) 79 | return fca.total() / batch_size, aca.total() / batch_size 80 | 81 | 82 | def profile_torch_image(model, image_input_size, batch_size=1, force_cpu=False): 83 | """Profile the image encoder using torch.utils.flop_counter""" 84 | if force_cpu: 85 | model = model.to('cpu') 86 | device, dtype = next(model.parameters()).device, next(model.parameters()).dtype 87 | example_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype) 88 | 89 | flop_counter = FlopCounterMode() 90 | with flop_counter: 91 | model(example_input) 92 | total_flops = sum(flop_counter.get_flop_counts()['Global'].values()) 93 | return total_flops / batch_size 94 | 95 | 96 | def profile_torch_text(model, text_input_size, batch_size=1, force_cpu=False): 97 | """Profile the text encoder using torch.utils.flop_counter""" 98 | if force_cpu: 99 | model = model.to('cpu') 100 | device = next(model.parameters()).device 101 | example_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64) 102 | 103 | flop_counter = FlopCounterMode() 104 | with flop_counter: 105 | model(example_input) 106 | total_flops = sum(flop_counter.get_flop_counts()['Global'].values()) 107 | return total_flops / batch_size 108 | 109 | 110 | def profile_torch(model, text_input_size, image_input_size, batch_size=1, force_cpu=False): 111 | """Profile the full model using torch.utils.flop_counter""" 112 | if force_cpu: 113 | model = model.to('cpu') 114 | device, dtype = next(model.parameters()).device, next(model.parameters()).dtype 115 | image_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype) 116 | text_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64) 117 | 118 | flop_counter = FlopCounterMode() 119 | with flop_counter: 120 | model(image_input, text_input) 121 | total_flops = sum(flop_counter.get_flop_counts()['Global'].values()) 122 | return total_flops / batch_size 123 | 124 | 125 | def count_params(model): 126 | return sum(m.numel() for m in model.parameters()) 127 | 128 | def profile_model(model_name, batch_size=1, profiler='torch'): 129 | assert profiler in ['torch', 'fvcore'], 'Only torch and fvcore profilers are supported' 130 | if profiler == 'fvcore': 131 | assert fvcore is not None, 'Please install fvcore.' 
132 | model = open_clip.create_model(model_name, force_custom_text=True, pretrained_hf=False) 133 | model.eval() 134 | if torch.cuda.is_available(): 135 | model = model.cuda() 136 | 137 | if isinstance(model.visual.image_size, (tuple, list)): 138 | image_input_size = (3,) + tuple(model.visual.image_size[-2:]) 139 | else: 140 | image_input_size = (3, model.visual.image_size, model.visual.image_size) 141 | 142 | text_input_size = (77,) 143 | if hasattr(model, 'context_length') and model.context_length: 144 | text_input_size = (model.context_length,) 145 | 146 | results = {} 147 | results['model'] = model_name 148 | results['image_size'] = image_input_size[1] 149 | 150 | model_cfg = open_clip.get_model_config(model_name) 151 | if model_cfg: 152 | vision_cfg = open_clip.CLIPVisionCfg(**model_cfg['vision_cfg']) 153 | text_cfg = open_clip.CLIPTextCfg(**model_cfg['text_cfg']) 154 | results['image_width'] = int(vision_cfg.width) 155 | results['text_width'] = int(text_cfg.width) 156 | results['embed_dim'] = int(model_cfg['embed_dim']) 157 | else: 158 | results['image_width'] = 0 159 | results['text_width'] = 0 160 | results['embed_dim'] = 0 161 | 162 | retries = 2 163 | while retries: 164 | retries -= 1 165 | try: 166 | results['mparams'] = round(count_params(model) / 1e6, 2) 167 | results['image_mparams'] = round(count_params(model.visual) / 1e6, 2) 168 | results['text_mparams'] = round(count_params(model.text) / 1e6, 2) 169 | 170 | if profiler == 'fvcore': 171 | macs, acts = profile_fvcore( 172 | model, image_input_size=image_input_size, text_input_size=text_input_size, force_cpu=not retries, batch_size=batch_size) 173 | 174 | image_macs, image_acts = profile_fvcore_image( 175 | model.visual, image_input_size=image_input_size, force_cpu=not retries, batch_size=batch_size) 176 | 177 | text_macs, text_acts = profile_fvcore_text( 178 | model.text, text_input_size=text_input_size, force_cpu=not retries, batch_size=batch_size) 179 | 180 | results['gmacs'] = round(macs / 1e9, 2) 181 | results['macts'] = round(acts / 1e6, 2) 182 | 183 | results['image_gmacs'] = round(image_macs / 1e9, 2) 184 | results['image_macts'] = round(image_acts / 1e6, 2) 185 | 186 | results['text_gmacs'] = round(text_macs / 1e9, 2) 187 | results['text_macts'] = round(text_acts / 1e6, 2) 188 | elif profiler == 'torch': 189 | image_flops = profile_torch_image( 190 | model.visual, image_input_size=image_input_size, force_cpu=not retries, batch_size=batch_size) 191 | text_flops = profile_torch_text( 192 | model.text, text_input_size=text_input_size, force_cpu=not retries, batch_size=batch_size) 193 | total_flops = profile_torch( 194 | model, image_input_size=image_input_size, text_input_size=text_input_size, force_cpu=not retries, batch_size=batch_size) 195 | 196 | results['gflops'] = round(total_flops / 1e9, 2) 197 | results['image_gflops'] = round(image_flops / 1e9, 2) 198 | results['text_gflops'] = round(text_flops / 1e9, 2) 199 | 200 | except RuntimeError as e: 201 | pass 202 | return results 203 | 204 | 205 | def main(): 206 | args = parser.parse_args() 207 | 208 | # FIXME accept a text file name to allow lists of models in txt/csv 209 | if args.model == 'all': 210 | parsed_model = open_clip.list_models() 211 | else: 212 | parsed_model = args.model.split(',') 213 | 214 | results = [] 215 | models_with_errors = [] 216 | for m in parsed_model: 217 | print('='*100) 218 | print(f'Profiling {m}') 219 | try: 220 | row = profile_model(m, batch_size=args.batch_size, profiler=args.profiler) 221 | results.append(row) 222 | except 
Exception as e: 223 | print(f'Error profiling {m}: {e}') 224 | import traceback 225 | traceback.print_exc() 226 | models_with_errors.append(m) 227 | 228 | df = pd.DataFrame(results, columns=results[0].keys()) 229 | 230 | if 'gmacs' in df.columns: 231 | df = df.sort_values(by=['gmacs', 'mparams', 'model']) 232 | else: 233 | df = df.sort_values(by=['gflops', 'mparams', 'model']) 234 | 235 | print('='*100) 236 | print('Done.') 237 | print(df) 238 | if args.results_file: 239 | df.to_csv(args.results_file, index=False) 240 | 241 | if models_with_errors: 242 | print('Models with errors:', models_with_errors) 243 | 244 | 245 | if __name__ == '__main__': 246 | main() 247 | -------------------------------------------------------------------------------- /src/training/scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def assign_learning_rate(optimizer, new_lr): 5 | for param_group in optimizer.param_groups: 6 | param_group["lr"] = new_lr 7 | 8 | 9 | def _warmup_lr(base_lr, warmup_length, step): 10 | return base_lr * (step + 1) / warmup_length 11 | 12 | 13 | def const_lr(optimizer, base_lr, warmup_length, steps): 14 | def _lr_adjuster(step): 15 | if step < warmup_length: 16 | lr = _warmup_lr(base_lr, warmup_length, step) 17 | else: 18 | lr = base_lr 19 | assign_learning_rate(optimizer, lr) 20 | return lr 21 | return _lr_adjuster 22 | 23 | 24 | def const_lr_cooldown(optimizer, base_lr, warmup_length, steps, cooldown_steps, cooldown_power=1.0, cooldown_end_lr=0.): 25 | def _lr_adjuster(step): 26 | start_cooldown_step = steps - cooldown_steps 27 | if step < warmup_length: 28 | lr = _warmup_lr(base_lr, warmup_length, step) 29 | else: 30 | if step < start_cooldown_step: 31 | lr = base_lr 32 | else: 33 | e = step - start_cooldown_step 34 | es = steps - start_cooldown_step 35 | # linear decay if power == 1; polynomial decay otherwise; 36 | decay = (1 - (e/es)) ** cooldown_power 37 | lr = decay * (base_lr - cooldown_end_lr) + cooldown_end_lr 38 | assign_learning_rate(optimizer, lr) 39 | return lr 40 | return _lr_adjuster 41 | 42 | 43 | def cosine_lr(optimizer, base_lr, warmup_length, steps): 44 | def _lr_adjuster(step): 45 | if step < warmup_length: 46 | lr = _warmup_lr(base_lr, warmup_length, step) 47 | else: 48 | e = step - warmup_length 49 | es = steps - warmup_length 50 | lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr 51 | assign_learning_rate(optimizer, lr) 52 | return lr 53 | return _lr_adjuster 54 | 55 | 56 | def cosine_scheduler(base_value, final_value, warmup_length, steps): 57 | def _adjuster(step): 58 | if step < warmup_length: 59 | lr = base_value * (step + 1) / warmup_length 60 | else: 61 | e = step - warmup_length 62 | es = steps - warmup_length 63 | lr = final_value + 0.5 * (1 + np.cos(np.pi * e / es)) * (base_value - final_value) 64 | return lr 65 | return _adjuster -------------------------------------------------------------------------------- /src/training/seg_configs/base_config.py: -------------------------------------------------------------------------------- 1 | # base configurations 2 | model = dict( 3 | type='CLIPForSegmentation', 4 | clip_path='ViT-B/16' 5 | ) 6 | 7 | test_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU']) 8 | 9 | default_scope = 'mmseg' 10 | env_cfg = dict( 11 | cudnn_benchmark=True, 12 | mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), 13 | dist_cfg=dict(backend='nccl'), 14 | ) 15 | vis_backends = [dict(type='LocalVisBackend')] 16 | visualizer = dict( 17 | 
type='SegLocalVisualizer', vis_backends=vis_backends, name='visualizer') 18 | log_processor = dict(by_epoch=False) 19 | log_level = 'INFO' 20 | load_from = None 21 | resume = False 22 | 23 | test_cfg = dict(type='TestLoop') 24 | 25 | 26 | default_hooks = dict( 27 | timer=dict(type='IterTimerHook'), 28 | logger=dict(type='LoggerHook', interval=200, log_metric_by_epoch=False), 29 | param_scheduler=dict(type='ParamSchedulerHook'), 30 | checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=2000), 31 | sampler_seed=dict(type='DistSamplerSeedHook'), 32 | visualization=dict(type='SegVisualizationHook', interval=1)) 33 | -------------------------------------------------------------------------------- /src/training/seg_configs/cfg_ade20k.py: -------------------------------------------------------------------------------- 1 | # ADE20K config file that define test pipeline 2 | # Please refer to here https://github.com/wangf3014/SCLIP/blob/main/configs/cfg_ade20k.py 3 | _base_ = './base_config.py' 4 | 5 | # model settings 6 | model = dict( 7 | name_path='./training/seg_configs/cls_ade20k.txt' 8 | ) 9 | 10 | # dataset settings 11 | dataset_type = 'ADE20KDataset' 12 | data_root = '/mmsegmentation_datasets/data/ade/ADEChallengeData2016' 13 | 14 | test_pipeline = [ 15 | dict(type='LoadImageFromFile'), 16 | dict(type='Resize', scale=(2048, 336), keep_ratio=True), 17 | dict(type='LoadAnnotations', reduce_zero_label=True), 18 | dict(type='PackSegInputs') 19 | ] 20 | 21 | test_dataloader = dict( 22 | batch_size=1, 23 | num_workers=4, 24 | persistent_workers=True, 25 | sampler=dict(type='DefaultSampler', shuffle=False), 26 | dataset=dict( 27 | type=dataset_type, 28 | data_root=data_root, 29 | data_prefix=dict( 30 | img_path='images/validation', 31 | seg_map_path='annotations/validation'), 32 | pipeline=test_pipeline)) -------------------------------------------------------------------------------- /src/training/seg_configs/cfg_city_scapes.py: -------------------------------------------------------------------------------- 1 | _base_ = './base_config.py' 2 | 3 | # model settings 4 | model = dict( 5 | name_path='./training/seg_configs/cls_city_scapes.txt' 6 | ) 7 | 8 | # dataset settings 9 | dataset_type = 'CityscapesDataset' 10 | data_root = '/mmsegmentation_datasets/data/cityscapes' 11 | 12 | test_pipeline = [ 13 | dict(type='LoadImageFromFile'), 14 | dict(type='Resize', scale=(2048, 560), keep_ratio=True), 15 | # add loading annotation after ``Resize`` because ground truth 16 | # does not need to do resize data transform 17 | dict(type='LoadAnnotations'), 18 | dict(type='PackSegInputs') 19 | ] 20 | 21 | test_dataloader = dict( 22 | batch_size=1, 23 | num_workers=4, 24 | persistent_workers=True, 25 | sampler=dict(type='DefaultSampler', shuffle=False), 26 | dataset=dict( 27 | type=dataset_type, 28 | data_root=data_root, 29 | data_prefix=dict( 30 | img_path='leftImg8bit/val', seg_map_path='gtFine/val'), 31 | pipeline=test_pipeline)) -------------------------------------------------------------------------------- /src/training/seg_configs/cfg_coco_object.py: -------------------------------------------------------------------------------- 1 | _base_ = './base_config.py' 2 | 3 | # model settings 4 | model = dict( 5 | name_path='./training/seg_configs/cls_coco_object.txt', 6 | logit_scale=50, 7 | prob_thd=0.1 8 | ) 9 | 10 | # dataset settings 11 | dataset_type = 'COCOObjectDataset' 12 | data_root = '/mmsegmentation_datasets/data/coco_stuff164k' 13 | 14 | test_pipeline = [ 15 | 
dict(type='LoadImageFromFile'), 16 | dict(type='Resize', scale=(2048, 336), keep_ratio=True), 17 | # add loading annotation after ``Resize`` because ground truth 18 | # does not need to do resize data transform 19 | dict(type='LoadAnnotations'), 20 | dict(type='PackSegInputs') 21 | ] 22 | 23 | test_dataloader = dict( 24 | batch_size=1, 25 | num_workers=4, 26 | persistent_workers=True, 27 | sampler=dict(type='DefaultSampler', shuffle=False), 28 | dataset=dict( 29 | type=dataset_type, 30 | data_root=data_root, 31 | reduce_zero_label=False, 32 | data_prefix=dict( 33 | img_path='images/val2017', seg_map_path='annotations/val2017'), 34 | pipeline=test_pipeline)) -------------------------------------------------------------------------------- /src/training/seg_configs/cfg_coco_stuff164k.py: -------------------------------------------------------------------------------- 1 | _base_ = './base_config.py' 2 | 3 | # model settings 4 | model = dict( 5 | name_path='./training/seg_configs/cls_coco_stuff.txt' 6 | ) 7 | 8 | # dataset settings 9 | dataset_type = 'COCOStuffDataset' 10 | data_root = '/mmsegmentation_datasets/data/coco_stuff164k' 11 | 12 | test_pipeline = [ 13 | dict(type='LoadImageFromFile'), 14 | dict(type='Resize', scale=(2048, 448), keep_ratio=True), 15 | dict(type='LoadAnnotations'), 16 | dict(type='PackSegInputs') 17 | ] 18 | 19 | test_dataloader = dict( 20 | batch_size=1, 21 | num_workers=4, 22 | persistent_workers=True, 23 | sampler=dict(type='DefaultSampler', shuffle=False), 24 | dataset=dict( 25 | type=dataset_type, 26 | data_root=data_root, 27 | data_prefix=dict( 28 | img_path='images/val2017', seg_map_path='annotations/val2017'), 29 | pipeline=test_pipeline)) -------------------------------------------------------------------------------- /src/training/seg_configs/cfg_context59.py: -------------------------------------------------------------------------------- 1 | _base_ = './base_config.py' 2 | 3 | # model settings 4 | model = dict( 5 | name_path='./training/seg_configs/cls_context59.txt' 6 | ) 7 | 8 | # dataset settings 9 | dataset_type = 'PascalContext59Dataset' 10 | data_root = '/mmsegmentation_datasets/data/VOCdevkit/VOC2010' 11 | 12 | test_pipeline = [ 13 | dict(type='LoadImageFromFile'), 14 | dict(type='Resize', scale=(2048, 336), keep_ratio=True), 15 | dict(type='LoadAnnotations', reduce_zero_label=True), 16 | dict(type='PackSegInputs') 17 | ] 18 | 19 | test_dataloader = dict( 20 | batch_size=1, 21 | num_workers=4, 22 | persistent_workers=True, 23 | sampler=dict(type='DefaultSampler', shuffle=False), 24 | dataset=dict( 25 | type=dataset_type, 26 | data_root=data_root, 27 | data_prefix=dict( 28 | img_path='JPEGImages', seg_map_path='SegmentationClassContext'), 29 | ann_file='ImageSets/SegmentationContext/val.txt', 30 | pipeline=test_pipeline)) -------------------------------------------------------------------------------- /src/training/seg_configs/cfg_context60.py: -------------------------------------------------------------------------------- 1 | _base_ = './base_config.py' 2 | 3 | # model settings 4 | model = dict( 5 | name_path='./training/seg_configs/cls_context60.txt', 6 | logit_scale=50, 7 | prob_thd=0.1 8 | ) 9 | 10 | # dataset settings 11 | dataset_type = 'PascalContext60Dataset' 12 | data_root = '/mmsegmentation_datasets/data/VOCdevkit/VOC2010' 13 | 14 | test_pipeline = [ 15 | dict(type='LoadImageFromFile'), 16 | dict(type='Resize', scale=(2048, 336), keep_ratio=True), 17 | dict(type='LoadAnnotations'), 18 | dict(type='PackSegInputs') 19 | ] 20 | 21 | 
test_dataloader = dict( 22 | batch_size=1, 23 | num_workers=4, 24 | persistent_workers=True, 25 | sampler=dict(type='DefaultSampler', shuffle=False), 26 | dataset=dict( 27 | type=dataset_type, 28 | data_root=data_root, 29 | data_prefix=dict( 30 | img_path='JPEGImages', seg_map_path='SegmentationClassContext'), 31 | ann_file='ImageSets/SegmentationContext/val.txt', 32 | pipeline=test_pipeline)) -------------------------------------------------------------------------------- /src/training/seg_configs/cfg_voc20.py: -------------------------------------------------------------------------------- 1 | # VOC20 config file that defines the test pipeline 2 | # Please refer to https://github.com/wangf3014/SCLIP/blob/main/configs/cfg_voc20.py 3 | _base_ = './base_config.py' 4 | 5 | # model settings 6 | model = dict( 7 | name_path='./training/seg_configs/cls_voc20.txt' 8 | ) 9 | 10 | # dataset settings 11 | dataset_type = 'PascalVOC20Dataset' 12 | data_root = '/mmsegmentation_datasets/data/VOCdevkit/VOC2012' 13 | 14 | test_pipeline = [ 15 | dict(type='LoadImageFromFile'), 16 | dict(type='Resize', scale=(2048, 336), keep_ratio=True), 17 | dict(type='LoadAnnotations'), 18 | dict(type='PackSegInputs') 19 | ] 20 | 21 | test_dataloader = dict( 22 | batch_size=1, 23 | num_workers=4, 24 | persistent_workers=True, 25 | sampler=dict(type='DefaultSampler', shuffle=False), 26 | dataset=dict( 27 | type=dataset_type, 28 | data_root=data_root, 29 | data_prefix=dict( 30 | img_path='JPEGImages', seg_map_path='SegmentationClass'), 31 | ann_file='ImageSets/Segmentation/val.txt', 32 | pipeline=test_pipeline)) -------------------------------------------------------------------------------- /src/training/seg_configs/cfg_voc21.py: -------------------------------------------------------------------------------- 1 | _base_ = './base_config.py' 2 | 3 | # model settings 4 | model = dict( 5 | name_path='./training/seg_configs/cls_voc21.txt', 6 | logit_scale=65, 7 | prob_thd=0.1, 8 | area_thd=0.1 9 | ) 10 | 11 | # dataset settings 12 | dataset_type = 'PascalVOCDataset' 13 | data_root = '/mmsegmentation_datasets/data/VOCdevkit/VOC2012' 14 | 15 | test_pipeline = [ 16 | dict(type='LoadImageFromFile'), 17 | dict(type='Resize', scale=(2048, 336), keep_ratio=True), 18 | dict(type='LoadAnnotations'), 19 | dict(type='PackSegInputs') 20 | ] 21 | 22 | test_dataloader = dict( 23 | batch_size=1, 24 | num_workers=4, 25 | persistent_workers=True, 26 | sampler=dict(type='DefaultSampler', shuffle=False), 27 | dataset=dict( 28 | type=dataset_type, 29 | data_root=data_root, 30 | data_prefix=dict( 31 | img_path='JPEGImages', seg_map_path='SegmentationClass'), 32 | ann_file='ImageSets/Segmentation/val.txt', 33 | pipeline=test_pipeline)) -------------------------------------------------------------------------------- /src/training/seg_configs/cls_ade20k.txt: -------------------------------------------------------------------------------- 1 | wall 2 | building 3 | sky 4 | floor 5 | tree 6 | ceiling 7 | road 8 | bed 9 | windowpane 10 | grass 11 | cabinet 12 | sidewalk 13 | person 14 | earth 15 | door 16 | table 17 | mountain 18 | plant 19 | curtain 20 | chair 21 | car 22 | water 23 | painting 24 | sofa 25 | shelf 26 | house 27 | sea 28 | mirror 29 | rug 30 | field 31 | armchair 32 | seat 33 | fence 34 | desk 35 | rock 36 | wardrobe 37 | lamp 38 | bathtub 39 | railing 40 | cushion 41 | base 42 | box 43 | column 44 | signboard 45 | chestofdrawers 46 | counter 47 | sand 48 | sink 49 | skyscraper 50 | fireplace 51 | refrigerator 52 | grandstand 53 | path
54 | stairs 55 | runway 56 | case 57 | pooltable 58 | pillow 59 | screendoor 60 | stairway 61 | river 62 | bridge 63 | bookcase 64 | blind 65 | coffeetable 66 | toilet 67 | flower 68 | book 69 | hill 70 | bench 71 | countertop 72 | stove 73 | palm 74 | kitchenisland 75 | computer 76 | swivelchair 77 | boat 78 | bar 79 | arcademachine 80 | hovel 81 | bus 82 | towel 83 | light 84 | truck 85 | tower 86 | chandelier 87 | awning 88 | streetlight 89 | booth 90 | televisionreceiver 91 | airplane 92 | dirttrack 93 | apparel 94 | pole 95 | land 96 | bannister 97 | escalator 98 | ottoman 99 | bottle 100 | buffet 101 | poster 102 | stage 103 | van 104 | ship 105 | fountain 106 | conveyerbelt 107 | canopy 108 | washer 109 | plaything 110 | swimmingpool 111 | stool 112 | barrel 113 | basket 114 | waterfall 115 | tent 116 | bag 117 | minibike 118 | cradle 119 | oven 120 | ball 121 | food 122 | step 123 | tank 124 | tradename 125 | microwave 126 | pot 127 | animal 128 | bicycle 129 | lake 130 | dishwasher 131 | screen 132 | blanket 133 | sculpture 134 | hood 135 | sconce 136 | vase 137 | trafficlight 138 | tray 139 | ashcan 140 | fan 141 | pier 142 | crtscreen 143 | plate 144 | monitor 145 | bulletinboard 146 | shower 147 | radiator 148 | glass 149 | clock 150 | flag -------------------------------------------------------------------------------- /src/training/seg_configs/cls_city_scapes.txt: -------------------------------------------------------------------------------- 1 | road 2 | sidewalk 3 | building 4 | wall 5 | fence 6 | pole 7 | trafficlight 8 | trafficsign 9 | vegetation 10 | terrain 11 | sky 12 | person 13 | rider 14 | car 15 | truck 16 | bus 17 | train 18 | motorcycle 19 | bicycle -------------------------------------------------------------------------------- /src/training/seg_configs/cls_coco_object.txt: -------------------------------------------------------------------------------- 1 | sky, wall, tree, wood, grass, road, sea, river, mountain, sands, desk, bed, building, cloud, lamp, door, window, wardrobe, ceiling, shelf, curtain, stair, floor, hill, rail, fence 2 | person, person in shirt, person in jeans, person in dress, person in sweater, person in skirt, person in jacket, body 3 | bicycle 4 | car 5 | motorcycle 6 | airplane 7 | bus 8 | train 9 | truck 10 | boat 11 | traffic light 12 | fire hydrant 13 | stop sign 14 | parking meter 15 | bench 16 | bird 17 | cat 18 | dog 19 | horse 20 | sheep 21 | cow 22 | elephant 23 | bear 24 | zebra 25 | giraffe 26 | backpack 27 | umbrella 28 | handbag 29 | tie 30 | suitcase 31 | frisbee 32 | skis 33 | snowboard 34 | sports ball 35 | kite 36 | baseball bat 37 | baseball glove 38 | skateboard 39 | surfboard 40 | tennis racket 41 | bottle 42 | wine glass 43 | cup 44 | fork 45 | knife 46 | spoon 47 | bowl 48 | banana 49 | apple 50 | sandwich 51 | orange 52 | broccoli 53 | carrot 54 | hot dog 55 | pizza 56 | donut 57 | cake 58 | chair 59 | couch 60 | potted plant 61 | bed 62 | dining table 63 | toilet 64 | tv 65 | laptop 66 | mouse 67 | remote 68 | keyboard 69 | cell phone 70 | microwave 71 | oven 72 | toaster 73 | sink 74 | refrigerator 75 | book 76 | clock 77 | vase 78 | scissors 79 | teddy bear 80 | hair drier 81 | toothbrush -------------------------------------------------------------------------------- /src/training/seg_configs/cls_coco_stuff.txt: -------------------------------------------------------------------------------- 1 | person 2 | bicycle 3 | car 4 | motorcycle 5 | airplane 6 | bus 7 | train 8 | truck 9 | boat 10 | trafficlight 11 | 
firehydrant 12 | stopsign 13 | parkingmeter 14 | bench 15 | bird 16 | cat 17 | dog 18 | horse 19 | sheep 20 | cow 21 | elephant 22 | bear 23 | zebra 24 | giraffe 25 | backpack 26 | umbrella 27 | handbag 28 | tie 29 | suitcase 30 | frisbee 31 | skis 32 | snowboard 33 | sportsball 34 | kite 35 | baseballbat 36 | baseballglove 37 | skateboard 38 | surfboard 39 | tennisracket 40 | bottle 41 | wineglass 42 | cup 43 | fork 44 | knife 45 | spoon 46 | bowl 47 | banana 48 | apple 49 | sandwich 50 | orange 51 | broccoli 52 | carrot 53 | hotdog 54 | pizza 55 | donut 56 | cake 57 | chair 58 | couch 59 | pottedplant 60 | bed 61 | diningtable 62 | toilet 63 | tv 64 | laptop 65 | mouse 66 | remote 67 | keyboard 68 | cellphone 69 | microwave 70 | oven 71 | toaster 72 | sink 73 | refrigerator 74 | book 75 | clock 76 | vase 77 | scissors 78 | teddybear 79 | hairdrier 80 | toothbrush 81 | banner 82 | blanket 83 | branch 84 | bridge 85 | building-other 86 | bush 87 | cabinet 88 | cage 89 | cardboard 90 | carpet 91 | ceiling-other 92 | ceiling-tile 93 | cloth 94 | clothes 95 | clouds 96 | counter 97 | cupboard 98 | curtain 99 | desk-stuff 100 | dirt 101 | door-stuff 102 | fence 103 | floor-marble 104 | floor-other 105 | floor-stone 106 | floor-tile 107 | floor-wood 108 | flower 109 | fog 110 | food-other 111 | fruit 112 | furniture-other 113 | grass 114 | gravel 115 | ground-other 116 | hill 117 | house 118 | leaves 119 | light 120 | mat 121 | metal 122 | mirror-stuff 123 | moss 124 | mountain 125 | mud 126 | napkin 127 | net 128 | paper 129 | pavement 130 | pillow 131 | plant-other 132 | plastic 133 | platform 134 | playingfield 135 | railing 136 | railroad 137 | river 138 | road 139 | rock 140 | roof 141 | rug 142 | salad 143 | sand 144 | sea 145 | shelf 146 | sky-other 147 | skyscraper 148 | snow 149 | solid-other 150 | stairs 151 | stone 152 | straw 153 | structural-other 154 | table 155 | tent 156 | textile-other 157 | towel 158 | tree 159 | vegetable 160 | wall-brick 161 | wall-concrete 162 | wall-other 163 | wall-panel 164 | wall-stone 165 | wall-tile 166 | wall-wood 167 | water-other 168 | waterdrops 169 | window-blind 170 | window-other 171 | wood -------------------------------------------------------------------------------- /src/training/seg_configs/cls_context59.txt: -------------------------------------------------------------------------------- 1 | aeroplane 2 | bag 3 | bed 4 | bedclothes 5 | bench 6 | bicycle 7 | bird 8 | boat 9 | book 10 | bottle 11 | building 12 | bus 13 | cabinet 14 | car 15 | cat 16 | ceiling 17 | chair 18 | cloth 19 | computer 20 | cow 21 | cup 22 | curtain 23 | dog 24 | door 25 | fence 26 | floor 27 | flower 28 | food 29 | grass 30 | ground 31 | horse 32 | keyboard 33 | light 34 | motorbike 35 | mountain 36 | mouse 37 | person 38 | plate 39 | platform 40 | pottedplant 41 | road 42 | rock 43 | sheep 44 | shelves 45 | sidewalk 46 | sign 47 | sky 48 | snow 49 | sofa 50 | table 51 | track 52 | train 53 | tree 54 | truck 55 | tvmonitor 56 | wall 57 | water 58 | window 59 | wood -------------------------------------------------------------------------------- /src/training/seg_configs/cls_context60.txt: -------------------------------------------------------------------------------- 1 | background 2 | aeroplane 3 | bag 4 | bed 5 | bedclothes 6 | bench 7 | bicycle 8 | bird 9 | boat 10 | book 11 | bottle 12 | building 13 | bus 14 | cabinet 15 | car 16 | cat 17 | ceiling 18 | chair 19 | cloth 20 | computer 21 | cow 22 | cup 23 | curtain 24 | dog 25 | door 26 | fence 27 | floor 28 | 
flower 29 | food 30 | grass 31 | ground 32 | horse 33 | keyboard 34 | light 35 | motorbike 36 | mountain 37 | mouse 38 | person 39 | plate 40 | platform 41 | pottedplant 42 | road 43 | rock 44 | sheep 45 | shelves 46 | sidewalk 47 | sign 48 | sky 49 | snow 50 | sofa 51 | table 52 | track 53 | train 54 | tree 55 | truck 56 | tvmonitor 57 | wall 58 | water 59 | window 60 | wood -------------------------------------------------------------------------------- /src/training/seg_configs/cls_voc20.txt: -------------------------------------------------------------------------------- 1 | aeroplane 2 | bicycle 3 | bird 4 | ship 5 | bottle 6 | bus 7 | car 8 | cat 9 | chair 10 | cow 11 | table 12 | dog 13 | horse 14 | motorbike 15 | person, person in shirt, person in jeans, person in dress, person in sweater, person in skirt, person in jacket 16 | pottedplant 17 | sheep 18 | sofa 19 | train 20 | television monitor, tv monitor, monitor, television, screen -------------------------------------------------------------------------------- /src/training/seg_configs/cls_voc21.txt: -------------------------------------------------------------------------------- 1 | sky, wall, tree, wood, grass, road, sea, river, mountain, sands, desk, bed, building, cloud, lamp, door, window, wardrobe, ceiling, shelf, curtain, stair, floor, hill, rail, fence 2 | aeroplane 3 | bicycle 4 | bird 5 | ship 6 | bottle 7 | bus 8 | car 9 | cat 10 | chair 11 | cow 12 | table 13 | dog 14 | horse 15 | motorbike 16 | person, person in shirt, person in jeans, person in dress, person in sweater, person in skirt, person in jacket 17 | pottedplant 18 | sheep 19 | sofa 20 | train 21 | television monitor, tv monitor, monitor, television, screen -------------------------------------------------------------------------------- /src/training/seg_configs/convert_coco_object.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # GroupViT (https://github.com/NVlabs/GroupViT) 3 | # Copyright (c) 2021-22, NVIDIA Corporation & affiliates. All Rights Reserved. 
4 | # ------------------------------------------------------------------------------ 5 | 6 | import argparse 7 | import os.path as osp 8 | import shutil 9 | from functools import partial 10 | from glob import glob 11 | 12 | from mmengine.utils import (mkdir_or_exist, track_parallel_progress, 13 | track_progress) 14 | import numpy as np 15 | from PIL import Image 16 | 17 | COCO_LEN = 123287 18 | 19 | clsID_to_trID = { 20 | 0: 0, 21 | 1: 1, 22 | 2: 2, 23 | 3: 3, 24 | 4: 4, 25 | 5: 5, 26 | 6: 6, 27 | 7: 7, 28 | 8: 8, 29 | 9: 9, 30 | 10: 10, 31 | 12: 11, 32 | 13: 12, 33 | 14: 13, 34 | 15: 14, 35 | 16: 15, 36 | 17: 16, 37 | 18: 17, 38 | 19: 18, 39 | 20: 19, 40 | 21: 20, 41 | 22: 21, 42 | 23: 22, 43 | 24: 23, 44 | 26: 24, 45 | 27: 25, 46 | 30: 26, 47 | 31: 27, 48 | 32: 28, 49 | 33: 29, 50 | 34: 30, 51 | 35: 31, 52 | 36: 32, 53 | 37: 33, 54 | 38: 34, 55 | 39: 35, 56 | 40: 36, 57 | 41: 37, 58 | 42: 38, 59 | 43: 39, 60 | 45: 40, 61 | 46: 41, 62 | 47: 42, 63 | 48: 43, 64 | 49: 44, 65 | 50: 45, 66 | 51: 46, 67 | 52: 47, 68 | 53: 48, 69 | 54: 49, 70 | 55: 50, 71 | 56: 51, 72 | 57: 52, 73 | 58: 53, 74 | 59: 54, 75 | 60: 55, 76 | 61: 56, 77 | 62: 57, 78 | 63: 58, 79 | 64: 59, 80 | 66: 60, 81 | 69: 61, 82 | 71: 62, 83 | 72: 63, 84 | 73: 64, 85 | 74: 65, 86 | 75: 66, 87 | 76: 67, 88 | 77: 68, 89 | 78: 69, 90 | 79: 70, 91 | 80: 71, 92 | 81: 72, 93 | 83: 73, 94 | 84: 74, 95 | 85: 75, 96 | 86: 76, 97 | 87: 77, 98 | 88: 78, 99 | 89: 79, 100 | 91: 80, 101 | 92: 81, 102 | 93: 82, 103 | 94: 83, 104 | 95: 84, 105 | 96: 85, 106 | 97: 86, 107 | 98: 87, 108 | 99: 88, 109 | 100: 89, 110 | 101: 90, 111 | 102: 91, 112 | 103: 92, 113 | 104: 93, 114 | 105: 94, 115 | 106: 95, 116 | 107: 96, 117 | 108: 97, 118 | 109: 98, 119 | 110: 99, 120 | 111: 100, 121 | 112: 101, 122 | 113: 102, 123 | 114: 103, 124 | 115: 104, 125 | 116: 105, 126 | 117: 106, 127 | 118: 107, 128 | 119: 108, 129 | 120: 109, 130 | 121: 110, 131 | 122: 111, 132 | 123: 112, 133 | 124: 113, 134 | 125: 114, 135 | 126: 115, 136 | 127: 116, 137 | 128: 117, 138 | 129: 118, 139 | 130: 119, 140 | 131: 120, 141 | 132: 121, 142 | 133: 122, 143 | 134: 123, 144 | 135: 124, 145 | 136: 125, 146 | 137: 126, 147 | 138: 127, 148 | 139: 128, 149 | 140: 129, 150 | 141: 130, 151 | 142: 131, 152 | 143: 132, 153 | 144: 133, 154 | 145: 134, 155 | 146: 135, 156 | 147: 136, 157 | 148: 137, 158 | 149: 138, 159 | 150: 139, 160 | 151: 140, 161 | 152: 141, 162 | 153: 142, 163 | 154: 143, 164 | 155: 144, 165 | 156: 145, 166 | 157: 146, 167 | 158: 147, 168 | 159: 148, 169 | 160: 149, 170 | 161: 150, 171 | 162: 151, 172 | 163: 152, 173 | 164: 153, 174 | 165: 154, 175 | 166: 155, 176 | 167: 156, 177 | 168: 157, 178 | 169: 158, 179 | 170: 159, 180 | 171: 160, 181 | 172: 161, 182 | 173: 162, 183 | 174: 163, 184 | 175: 164, 185 | 176: 165, 186 | 177: 166, 187 | 178: 167, 188 | 179: 168, 189 | 180: 169, 190 | 181: 170, 191 | 255: 255 192 | } 193 | 194 | # set to background 195 | for k, v in clsID_to_trID.items(): 196 | clsID_to_trID[k] = v + 1 197 | if k > 90: 198 | clsID_to_trID[k] = 0 199 | 200 | 201 | def convert_to_trainID(maskpath, out_mask_dir, is_train): 202 | mask = np.array(Image.open(maskpath)) 203 | mask_copy = mask.copy() 204 | for clsID, trID in clsID_to_trID.items(): 205 | mask_copy[mask == clsID] = trID 206 | seg_filename = osp.join( 207 | out_mask_dir, 'train2017', 208 | osp.basename(maskpath).split('.')[0] + 209 | '_instanceTrainIds.png') if is_train else osp.join( 210 | out_mask_dir, 'val2017', 211 | osp.basename(maskpath).split('.')[0] + '_instanceTrainIds.png') 212 | 
Image.fromarray(mask_copy).save(seg_filename, 'PNG') 213 | 214 | 215 | def parse_args(): 216 | parser = argparse.ArgumentParser( 217 | description=\ 218 | 'Convert COCO Stuff 164k annotations to COCO Objects') # noqa 219 | parser.add_argument('coco_path', help='coco stuff path') 220 | parser.add_argument('-o', '--out_dir', help='output path') 221 | parser.add_argument( 222 | '--nproc', default=16, type=int, help='number of process') 223 | args = parser.parse_args() 224 | return args 225 | 226 | 227 | def main(): 228 | args = parse_args() 229 | coco_path = args.coco_path 230 | nproc = args.nproc 231 | 232 | out_dir = args.out_dir or coco_path 233 | out_img_dir = osp.join(out_dir, 'images') 234 | out_mask_dir = osp.join(out_dir, 'annotations') 235 | 236 | mkdir_or_exist(osp.join(out_mask_dir, 'train2017')) 237 | mkdir_or_exist(osp.join(out_mask_dir, 'val2017')) 238 | 239 | if out_dir != coco_path: 240 | shutil.copytree(osp.join(coco_path, 'images'), out_img_dir) 241 | 242 | #train_list = glob(osp.join(coco_path, 'annotations', 'train2017', '*.png')) 243 | #train_list = [file for file in train_list if 'TrainIds' not in file] 244 | test_list = glob(osp.join(coco_path, 'annotations', 'val2017', '*.png')) 245 | test_list = [file for file in test_list if 'TrainIds' not in file] 246 | 247 | #assert (len(train_list) + 248 | # len(test_list)) == COCO_LEN, 'Wrong length of list {} & {}'.format( 249 | # len(train_list), len(test_list)) 250 | 251 | if args.nproc > 1: 252 | """ 253 | track_parallel_progress( 254 | partial( 255 | convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True), 256 | train_list, 257 | nproc=nproc) 258 | """ 259 | track_parallel_progress( 260 | partial( 261 | convert_to_trainID, out_mask_dir=out_mask_dir, is_train=False), 262 | test_list, 263 | nproc=nproc) 264 | else: 265 | """ 266 | track_progress( 267 | partial( 268 | convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True), 269 | train_list) 270 | """ 271 | track_progress( 272 | partial( 273 | convert_to_trainID, out_mask_dir=out_mask_dir, is_train=False), 274 | test_list) 275 | 276 | print('Done!') 277 | 278 | 279 | if __name__ == '__main__': 280 | main() -------------------------------------------------------------------------------- /src/training/zero_shot.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | from tqdm import tqdm 5 | 6 | from open_clip import get_input_dtype, get_tokenizer, build_zero_shot_classifier, \ 7 | IMAGENET_CLASSNAMES, OPENAI_IMAGENET_TEMPLATES 8 | from .precision import get_autocast 9 | 10 | import json 11 | 12 | def accuracy(output, target, topk=(1,)): 13 | pred = output.topk(max(topk), 1, True, True)[1].t() 14 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 15 | return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk] 16 | 17 | def run(model, classifier, dataloader, args): 18 | autocast = get_autocast(args.precision) 19 | input_dtype = get_input_dtype(args.precision) 20 | 21 | with torch.no_grad(): 22 | top1, top5, n = 0., 0., 0. 23 | for images, target in tqdm(dataloader, unit_scale=args.batch_size): 24 | images = images.to(device=args.device, dtype=input_dtype) 25 | target = target.to(args.device) 26 | 27 | with autocast(): 28 | # predict 29 | output = model(image=images) 30 | image_features = output['image_features'] if isinstance(output, dict) else output 31 | # image_features (B, 512), classifier (512, 1000) 32 | logits = 100. 
* image_features @ classifier 33 | 34 | # measure accuracy 35 | acc1, acc5 = accuracy(logits, target, topk=(1, 5)) 36 | top1 += acc1 37 | top5 += acc5 38 | n += images.size(0) 39 | 40 | top1 = (top1 / n) 41 | top5 = (top5 / n) 42 | return top1, top5 43 | 44 | def zero_shot_eval(model, data, epoch, args, tokenizer=None): 45 | if 'imagenet-val' not in data and 'imagenet-v2' not in data: 46 | return {} 47 | if args.zeroshot_frequency == 0: 48 | return {} 49 | if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs: 50 | return {} 51 | if args.distributed and not args.horovod: 52 | model = model.module 53 | 54 | logging.info('Starting zero-shot imagenet.') 55 | if tokenizer is None: 56 | tokenizer = get_tokenizer(args.model) 57 | 58 | logging.info('Building zero-shot classifier') 59 | autocast = get_autocast(args.precision) 60 | with autocast(): 61 | classifier = build_zero_shot_classifier( 62 | model, 63 | tokenizer=tokenizer, 64 | args=args, 65 | classnames=IMAGENET_CLASSNAMES, 66 | templates=OPENAI_IMAGENET_TEMPLATES, 67 | num_classes_per_batch=10, 68 | device=args.device, 69 | use_tqdm=True, 70 | ) 71 | 72 | logging.info('Using classifier') 73 | results = {} 74 | if 'imagenet-val' in data: 75 | top1, top5 = run(model, classifier, data['imagenet-val'].dataloader, args) 76 | results['imagenet-zeroshot-val-top1'] = top1 77 | results['imagenet-zeroshot-val-top5'] = top5 78 | if 'imagenet-v2' in data: 79 | top1, top5 = run(model, classifier, data['imagenet-v2'].dataloader, args) 80 | results['imagenetv2-zeroshot-val-top1'] = top1 81 | results['imagenetv2-zeroshot-val-top5'] = top5 82 | 83 | logging.info('Finished zero-shot imagenet.') 84 | 85 | return results 86 | 87 | def zero_shot_classification_eval(model, data_name, dataloader, dataset_labels, dataset_templates, epoch, args, tokenizer=None): 88 | if args.zeroshot_frequency == 0: 89 | return {} 90 | if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs: 91 | return {} 92 | if args.distributed and not args.horovod: 93 | model = model.module 94 | 95 | logging.info(f'Starting zero-shot {data_name}.') 96 | if tokenizer is None: 97 | tokenizer = get_tokenizer(args.model) 98 | 99 | logging.info('Building zero-shot classifier') 100 | autocast = get_autocast(args.precision) 101 | with autocast(): 102 | classifier = build_zero_shot_classifier( 103 | model, 104 | tokenizer=tokenizer, 105 | args=args, 106 | classnames=dataset_labels[data_name], 107 | templates=dataset_templates[data_name], 108 | num_classes_per_batch=10, 109 | device=args.device, 110 | use_tqdm=True, 111 | ) 112 | 113 | logging.info('Using classifier') 114 | results = {} 115 | top1, top5 = run(model, classifier, dataloader, args) 116 | results[f'{data_name}-zeroshot-val-top1'] = top1 117 | results[f'{data_name}-zeroshot-val-top5'] = top5 118 | 119 | logging.info(f'Finished zero-shot {data_name}.') 120 | 121 | return results --------------------------------------------------------------------------------
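
The zero-shot classification path above (`accuracy()` and `run()` in `src/training/zero_shot.py`) reduces to a scaled cosine-similarity matmul between image features and a text-derived classifier matrix, followed by top-k counting. The standalone sketch below reproduces only that scoring step on random placeholder tensors; in the repository the classifier comes from `build_zero_shot_classifier` and the features from the COSMOS model, so the tensor values here carry no meaning.

```python
# Standalone sketch of the scoring step in run()/accuracy() from zero_shot.py.
# Random tensors stand in for model outputs; only the math mirrors the repo code.
import torch
import torch.nn.functional as F

def topk_accuracy(logits: torch.Tensor, target: torch.Tensor, topk=(1, 5)):
    # Same counting logic as accuracy(): how many samples have the true label
    # among their k highest-scoring classes.
    pred = logits.topk(max(topk), dim=1, largest=True, sorted=True)[1].t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    return [correct[:k].reshape(-1).float().sum().item() for k in topk]

batch_size, embed_dim, num_classes = 8, 512, 1000  # 512-d embedding, as noted in run()
image_features = F.normalize(torch.randn(batch_size, embed_dim), dim=-1)  # (B, 512)
classifier = F.normalize(torch.randn(embed_dim, num_classes), dim=0)      # (512, 1000)
target = torch.randint(0, num_classes, (batch_size,))

logits = 100. * image_features @ classifier  # as in run(): scaled cosine similarity
acc1, acc5 = topk_accuracy(logits, target)
print(f"top-1: {acc1 / batch_size:.3f}  top-5: {acc5 / batch_size:.3f}")
```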
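
Similarly, the core of `convert_coco_object.py` (shown further above) is a per-pixel label remapping: the 80 COCO "thing" categories are compacted into train IDs 1-80, while all "stuff" IDs (original ID > 90) and the 255 ignore value collapse into 0, the background class expected by `cls_coco_object.txt`. Below is a minimal sketch of that remapping; the mask and the small mapping dictionary are made-up examples, not the script's full `clsID_to_trID` table.

```python
# Minimal illustrative sketch (not part of the repo) of the remapping done by
# convert_to_trainID() in convert_coco_object.py.
import numpy as np

def remap_mask(mask: np.ndarray, id_map: dict) -> np.ndarray:
    """Replace every class ID in the mask by its train ID, one class at a time."""
    out = mask.copy()
    for cls_id, train_id in id_map.items():
        out[mask == cls_id] = train_id
    return out

# Hypothetical subset of the mapping: thing IDs shift into 1..80, stuff and 255 -> 0.
toy_map = {0: 1, 1: 2, 17: 17, 100: 0, 140: 0, 255: 0}
toy_mask = np.array([[0, 17, 100], [255, 1, 140]], dtype=np.uint8)
print(remap_mask(toy_mask, toy_map))  # -> [[1 17 0], [0 2 0]]
```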