├── README.md
├── assets
│   ├── .gitkeep
│   ├── cosmos.png
│   ├── framework.png
│   └── qualitative_results_supp.png
├── datasets
│   ├── README.md
│   └── imagenet_organize.py
├── requirements.txt
└── src
    ├── dataloaders
    │   ├── __init__.py
    │   ├── caltech101.py
    │   ├── cifar10.py
    │   ├── cifar100.py
    │   ├── dtd.py
    │   ├── fgvc_aircraft.py
    │   ├── flowers102.py
    │   ├── food101.py
    │   ├── label.json
    │   ├── oxford_pets.py
    │   ├── stanford_car.py
    │   ├── sun397.py
    │   ├── templates.json
    │   └── utils.py
    ├── inference_classification.sh
    ├── inference_retrieval.sh
    ├── inference_segmentation.sh
    ├── main.py
    ├── open_clip
    │   ├── __init__.py
    │   ├── bpe_simple_vocab_16e6.txt.gz
    │   ├── coca_model.py
    │   ├── constants.py
    │   ├── convert.py
    │   ├── factory.py
    │   ├── hf_configs.py
    │   ├── hf_model.py
    │   ├── loss.py
    │   ├── model.py
    │   ├── model_configs
    │   │   ├── ViT-B-16.json
    │   │   └── ViT-B-32.json
    │   ├── modified_resnet.py
    │   ├── openai.py
    │   ├── pos_embed.py
    │   ├── pretrained.py
    │   ├── push_to_hf_hub.py
    │   ├── timm_model.py
    │   ├── tokenizer.py
    │   ├── transform.py
    │   ├── transformer.py
    │   ├── utils.py
    │   ├── version.py
    │   ├── zero_shot_classifier.py
    │   └── zero_shot_metadata.py
    ├── seg_eval.py
    ├── train_cc12m.sh
    ├── train_cc3m.sh
    ├── train_merged30m.sh
    ├── train_pixelprose.sh
    ├── train_yfcc15m.sh
    └── training
        ├── __init__.py
        ├── clip_segmentor.py
        ├── custom_datasets.py
        ├── data.py
        ├── distributed.py
        ├── file_utils.py
        ├── logger.py
        ├── pamr.py
        ├── params.py
        ├── precision.py
        ├── profiler.py
        ├── scheduler.py
        ├── seg_configs
        │   ├── base_config.py
        │   ├── cfg_ade20k.py
        │   ├── cfg_city_scapes.py
        │   ├── cfg_coco_object.py
        │   ├── cfg_coco_stuff164k.py
        │   ├── cfg_context59.py
        │   ├── cfg_context60.py
        │   ├── cfg_voc20.py
        │   ├── cfg_voc21.py
        │   ├── cls_ade20k.txt
        │   ├── cls_city_scapes.txt
        │   ├── cls_coco_object.txt
        │   ├── cls_coco_stuff.txt
        │   ├── cls_context59.txt
        │   ├── cls_context60.txt
        │   ├── cls_voc20.txt
        │   ├── cls_voc21.txt
        │   ├── convert_cityscapes.py
        │   └── convert_coco_object.py
        ├── train.py
        └── zero_shot.py
/README.md:
--------------------------------------------------------------------------------
1 | # [CVPR 2025] COSMOS: Cross-Modality Self-Distillation for Vision Language Pre-training
2 | [arXiv](https://arxiv.org/abs/2412.01814)
3 | [Hugging Face Models](https://huggingface.co/sankim2/cosmos)
4 |
5 |
6 | **Authors:** [Sanghwan Kim](https://kim-sanghwan.github.io/), [Rui Xiao](https://www.eml-munich.de/people/rui-xiao), [Mariana-Iuliana Georgescu](https://lilygeorgescu.github.io/), [Stephan Alaniz](https://www.eml-munich.de/people/stephan-alaniz), [Zeynep Akata](https://www.eml-munich.de/people/zeynep-akata)
7 |
8 | ### Abstract
9 | Vision-Language Models (VLMs) trained with contrastive loss have achieved significant advancements in various vision and language tasks. However, the global nature of contrastive loss makes VLMs focus predominantly on foreground objects, neglecting other crucial information in the image, which limits their effectiveness in downstream tasks. To address these challenges, we propose COSMOS: CrOSs-MOdality Self-distillation for vision-language pre-training that integrates a novel text-cropping strategy and cross-attention module into a self-supervised learning framework. We create global and local views of images and texts (i.e., multi-modal augmentations), which are essential for self-distillation in VLMs. We further introduce a cross-attention module, enabling COSMOS to learn comprehensive cross-modal representations optimized via a cross-modality self-distillation loss. COSMOS consistently outperforms previous strong baselines on various zero-shot downstream tasks, including retrieval, classification, and semantic segmentation. Additionally, it surpasses CLIP-based models trained on larger datasets in visual perception and contextual understanding tasks.
10 |
11 | ## Methodology
12 | ![COSMOS framework overview](assets/framework.png)
13 |
14 | ## Pre-trained Model Weights
15 |
16 | We release the pre-trained COSMOS models on [Huggingface](https://huggingface.co/sankim2/cosmos). The models and their performance on COCO (I2T R@1 and T2I R@1), Flickr (I2T R@1 and T2I R@1), and ImageNet (Top-1) are reported below. For the full results, please refer to our [paper](https://arxiv.org/abs/2412.01814).
17 |
18 | | **Checkpoints** | **Arch.** | **Datasets** | **COCO I2T** | **COCO T2I** | **Flickr I2T** | **Flickr T2I** | **IN Top-1** |
19 | |------------------------------------------------------------------------------------------------------------|------------------|--------------|--------------|----------------|----------------|----------------|----------------|
20 | | [cosmos_vitb16_cc3m](https://huggingface.co/sankim2/cosmos/resolve/main/cosmos_vitb16_cc3m.pt?download=true) | ViT-B/16 | CC3M-recap | 53.1 | 40.1 | 84.1 | 68.6 |37.1 |
21 | | [cosmos_vitb16_cc12m](https://huggingface.co/sankim2/cosmos/resolve/main/cosmos_vitb16_cc12m.pt?download=true) | ViT-B/16 | CC12M-recap | 64.2 | 48.9 | 91.4 | 76.2 |51.4 |
22 | | [cosmos_vitb16_yfcc15m](https://huggingface.co/sankim2/cosmos/resolve/main/cosmos_vitb16_yfcc15m.pt?download=true) | ViT-B/16 | YFCC15M-recap | 67.5 | 50.9 | 92.6 | 79.6 |52.4 |
23 | | [cosmos_vitb16_merged30m](https://huggingface.co/sankim2/cosmos/resolve/main/cosmos_vitb16_merged30m.pt?download=true) | ViT-B/16 | Merged30M | 68.0 | 52.5 | 92.9 | 80.3 |57.6 |
24 | | [cosmos_vitb16_pixelprose](https://huggingface.co/sankim2/cosmos/resolve/main/cosmos_vitb16_pixelprose.pt?download=true) | ViT-B/16 | PixelProse | 62.4 | 43.4 | 89.9 | 73.6 |59.6 |
25 | | [cosmos_vitb32_cc3m](https://huggingface.co/sankim2/cosmos/resolve/main/cosmos_vitb32_cc3m.pt?download=true) | ViT-B/32 | CC3M-recap | 47.6 | 33.5 | 74.3 | 59.2 |33.0 |
26 | | [cosmos_vitb32_cc12m](https://huggingface.co/sankim2/cosmos/resolve/main/cosmos_vitb32_cc12m.pt?download=true) | ViT-B/32 | CC12M-recap | 59.6 | 43.0 | 86.5 | 69.8 |46.7 |
27 | | [cosmos_vitb32_yfcc15m](https://huggingface.co/sankim2/cosmos/resolve/main/cosmos_vitb32_yfcc15m.pt?download=true) | ViT-B/32 | YFCC15M-recap | 64.5 | 46.0 | 90.2 | 73.3 |48.1 |
28 | | [cosmos_vitb32_merged30m](https://huggingface.co/sankim2/cosmos/resolve/main/cosmos_vitb32_merged30m.pt?download=true) | ViT-B/32 | Merged30M | 64.3 | 48.4 | 89.9 | 76.1 |53.4 |
29 | | [cosmos_vitb32_pixelprose](https://huggingface.co/sankim2/cosmos/resolve/main/cosmos_vitb32_pixelprose.pt?download=true) | ViT-B/32 | PixelProse | 57.2 | 38.9 | 85.6 | 66.3 |54.3 |
30 |
31 | ⚠️ You don't need to download the pre-trained weights manually to run inference; they are downloaded automatically when you specify `--huggingface-model-name` and `--huggingface-repo-name` during inference.
32 | Alternatively, you can download a checkpoint yourself and pass it to the inference scripts via the `--resume path/to/pretrained_weights` flag, as sketched below.
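
For example, a minimal sketch of the manual route using `huggingface-cli` (the `./checkpoints` directory is an arbitrary placeholder):
```bash
# Sketch: download one checkpoint from the Hub, then point --resume at the local file.
huggingface-cli download sankim2/cosmos cosmos_vitb16_merged30m.pt --local-dir ./checkpoints
# In the inference command, replace the --huggingface-* flags with:
#   --resume ./checkpoints/cosmos_vitb16_merged30m.pt
```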
33 |
34 | ## Dependencies
35 | You can set up your virtual environment following the instructions below. We built our code repository upon [OpenCLIP](https://github.com/mlfoundations/open_clip), which is still updated frequently, so we recommend checking their repo for a detailed tutorial on creating an environment best suited to your system. A conda environment with the same Python and PyTorch versions also works.
36 |
37 | ### 1. Clone the GitHub Repository
38 | First, clone the COSMOS GitHub repo and navigate to the project’s root directory `cosmos/`.
39 | ```bash
40 | git clone https://github.com/ExplainableML/cosmos.git
41 | cd cosmos/
42 | ```
43 |
44 | ### 2. Create a Virtual Environment
45 | Create a virtual environment using Python 3.12 and activate the virtual environment.
46 | ```bash
47 | python3.12 -m venv cosmos_env
48 | source cosmos_env/bin/activate
49 | ```
50 |
51 | ### 3. Install Dependencies
52 | Install all requirements via pip.
53 | ```bash
54 | pip install --upgrade pip
55 | pip install -r requirements.txt
56 | ```
57 | If you want to run the semantic segmentation task, please also follow [SCLIP](https://github.com/wangf3014/SCLIP) to install its dependencies. Their commands are reproduced below for completeness.
58 | ```bash
59 | pip install openmim
60 | mim install mmcv==2.0.1 mmengine==0.8.4 mmsegmentation==1.1.1
61 | pip install ftfy regex yapf==0.40.1
62 | ```
63 |
64 | ### [Optional] Anaconda Environment
65 | Alternatively, you can use Anaconda to set up the environment.
66 | ```bash
67 | conda create --name cosmos_env python=3.12
68 | conda activate cosmos_env
69 | ```
70 | Then, install all dependencies as follows.
71 | ```bash
72 | pip install --upgrade pip
73 | pip install -r requirements.txt
74 | ```
75 |
76 | ## Inference Datasets Preparation
77 | Check [datasets/README.md](datasets/README.md) to prepare all the inference datasets for retrieval, classification, and segmentation tasks.
78 |
79 | ## Inference with COSMOS
80 | To reproduce the results of downstream tasks (image-text retrieval, image classification, semantic segmentation) in the COSMOS paper, we provide an example inference bash script for each task: `src/inference_retrieval.sh`, `src/inference_classification.sh`, and `src/inference_segmentation.sh`.
81 |
82 | Here are detailed explanations of important flags.
83 |
84 |
85 | - `--huggingface-repo-name`: Name of the Huggingface repo where the pre-trained models are stored; keep it fixed to `sankim2/cosmos`.
86 | - `--huggingface-model-name`: Name of the pre-trained checkpoint. Options include `cosmos_vitb16_cc3m.pt, cosmos_vitb16_cc12m.pt, cosmos_vitb16_yfcc15m.pt, cosmos_vitb16_merged30m.pt, cosmos_vitb16_pixelprose.pt` for ViT-B/16 and `cosmos_vitb32_cc3m.pt, cosmos_vitb32_cc12m.pt, cosmos_vitb32_yfcc15m.pt, cosmos_vitb32_merged30m.pt, cosmos_vitb32_pixelprose.pt` for ViT-B/32.
87 | - `--model`: Model architecture; it must match `--huggingface-model-name`. Options are `ViT-B-16` and `ViT-B-32`.
88 | - `--precision`: Defaults to `amp`, as used in our paper.
89 | - `--workers`: Number of data-loading workers; adjust according to your system.
90 |
91 | ### Image-Text Retrieval Task
92 | `--data-root-dir` should point to the directory containing the COCO and Flickr30k validation sets. Please refer to [/src/inference_retrieval.sh](/src/inference_retrieval.sh) for running inference on the retrieval task.
93 |
94 | ### Image Classification Task
95 | `--imagenet-val` should point to the directory containing the ImageNet validation set. Please refer to [/src/inference_classification.sh](/src/inference_classification.sh) for running inference on the classification task.
96 |
97 | ### Semantic Segmentation Task
98 | `--seg-w-background` controls whether to evaluate on segmentation benchmarks that include a background class. If `--use-csa` is included, the model uses the Correlative Self-Attention (CSA) block from SCLIP for segmentation. Please refer to [/src/inference_segmentation.sh](/src/inference_segmentation.sh) for running inference on the segmentation task.
99 |
100 | ## Training COSMOS
101 | To train COSMOS from scratch, download the synthetic long-caption datasets from [DreamLIP](https://github.com/ant-research/DreamLIP)'s recaptioned [CC3M-recap](https://huggingface.co/datasets/qidouxiong619/dreamlip_long_captions), [CC12M-recap](https://huggingface.co/datasets/qidouxiong619/dreamlip_long_captions), [YFCC15M-recap](https://huggingface.co/datasets/qidouxiong619/dreamlip_long_captions) and their combination (Merged-30M), as well as [PixelProse](https://huggingface.co/datasets/tomg-group-umd/pixelprose). Note that COSMOS requires all pre-training datasets to be processed into the [webdataset](https://github.com/webdataset/webdataset) format to achieve higher I/O efficiency for large-scale training. Below we take [CC3M-recap](https://huggingface.co/datasets/qidouxiong619/dreamlip_long_captions) as an example of how to prepare the pre-training data; the other datasets are prepared similarly. We share the same pre-training datasets as [FLAIR](https://arxiv.org/abs/2412.03561). Please check their [repo](https://github.com/ExplainableML/flair) as well if you find it interesting!
102 |
103 | ### Prepare Pre-training Data
104 | 1. Download DreamLIP's annotations for CC3M-recap:
105 | `wget https://huggingface.co/datasets/qidouxiong619/dreamlip_long_captions/resolve/main/cc3m_3long_3short_1raw_captions_url.csv`
106 | 2. Download the images from the URL links using [img2dataset](https://github.com/rom1504/img2dataset) and pack them into webdataset shards, as sketched below.
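
A minimal `img2dataset` sketch is given below. The URL column name, output folder, and worker counts are assumptions, so check the CSV header and the img2dataset documentation before running.
```bash
# Sketch only: column names, paths, and worker counts are placeholders.
img2dataset \
  --url_list cc3m_3long_3short_1raw_captions_url.csv \
  --input_format csv \
  --url_col url \
  --output_format webdataset \
  --output_folder /path/to/cc3m_shards \
  --processes_count 16 \
  --thread_count 64 \
  --image_size 256
```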
107 |
108 | ### Training Script
109 | COSMOS was trained on a Slurm GPU cluster with 16 NVIDIA A100 (40GB) GPUs on CC3M and 128 NVIDIA A100 (40GB) GPUs on the other, larger datasets. In `src/`, we provide example Slurm training scripts for each dataset: `train_cc3m.sh, train_cc12m.sh, train_yfcc15m.sh, train_merged30m.sh, train_pixelprose.sh`.
110 |
111 | Important flags are described below, followed by a sketch of how they combine:
112 | - `--train-data`: Root directory where the training data (webdataset shards) is stored.
113 | - `--train-num-samples`: Total number of training samples; adjust it to match your available data.
114 | - `--use-imagecrop-aug`: Use the multi-crop image augmentation described in the paper.
115 | - `--global-crops-number`: Number of global image crops. Fixed to 2.
116 | - `--local-crops-number`: Number of local image crops.
117 | - `--crop-scale`: Scale s separating global and local crops: local crops are sampled from (0.05, s) and global crops from (s, 1.0). Fixed to 0.4.
118 | - `--caption-sampling-mode`: Determines how captions are sampled; fixed to `textcrop` or `textcrop_pixelprose`.
119 | - `--num-sampled-captions`: Total number of sampled captions (global + local).
120 | - `--momentum-teacher`: Initial momentum value for the teacher; adjust it based on batch size. We used 0.999 for a 1k batch and 0.99 for a 4k batch.
121 | - `--fix-momentum`: Keep the momentum value fixed during training.
122 | - `--output-all`: Output both patch (or word) tokens and `[cls]` (or `[eot]`) tokens.
123 | - `--attentional-pool`: Enable the cross-attention module in the model.
124 | - `--cosmos`: Use the COSMOS loss during training.
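
The sketch below shows how these flags fit together. The entry point, GPU count, shard path, and numeric values are illustrative placeholders; use the provided `train_*.sh` scripts for the exact settings.
```bash
# Sketch only: paths and values are placeholders based on the flag descriptions above.
# 4 GPUs x 256 per-GPU batch ~ 1k global batch, matching --momentum-teacher 0.999.
torchrun --nproc_per_node 4 -m main \
  --model ViT-B-16 \
  --train-data /path/to/cc3m_shards \
  --train-num-samples 2800000 \
  --batch-size 256 \
  --use-imagecrop-aug \
  --global-crops-number 2 \
  --local-crops-number 6 \
  --crop-scale 0.4 \
  --caption-sampling-mode textcrop \
  --num-sampled-captions 8 \
  --momentum-teacher 0.999 \
  --output-all \
  --attentional-pool \
  --cosmos
```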
125 |
126 | ## Qualitative Results
127 | We visualize the attention weights of image and text cross-attention modules. Patch-wise (image) and token-wise (caption) attention weights are both normalized between 0 and 1.
128 |
129 | ![Qualitative results](assets/qualitative_results_supp.png)
130 |
131 | ## Acknowledgements
132 | We thank [OpenCLIP](https://github.com/mlfoundations/open_clip) for providing the amazing code base. We also acknowledge [DreamLIP](https://github.com/zyf0619sjtu/DreamLIP) and [PixelProse](https://huggingface.co/datasets/tomg-group-umd/pixelprose) for providing various pre-training datasets with captions from MLLMs, and we are grateful to [SCLIP](https://github.com/wangf3014/SCLIP) for providing the detailed evaluation scheme for the semantic segmentation task.
133 |
134 | ## Citations
135 | If you find our work useful, please star this repo and cite:
136 |
137 | ```bibtex
138 | @article{kim2025cosmos,
139 | title={COSMOS: Cross-Modality Self-Distillation for Vision Language Pre-training},
140 | author={Kim, Sanghwan and Xiao, Rui and Georgescu, Mariana-Iuliana and Alaniz, Stephan and Akata, Zeynep},
141 | journal={CVPR},
142 | year={2025}
143 | }
144 | ```
145 |
--------------------------------------------------------------------------------
/assets/.gitkeep:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/assets/cosmos.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ExplainableML/cosmos/c80348f08e9b02c5a81adadd7dce486b3c284b78/assets/cosmos.png
--------------------------------------------------------------------------------
/assets/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ExplainableML/cosmos/c80348f08e9b02c5a81adadd7dce486b3c284b78/assets/framework.png
--------------------------------------------------------------------------------
/assets/qualitative_results_supp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ExplainableML/cosmos/c80348f08e9b02c5a81adadd7dce486b3c284b78/assets/qualitative_results_supp.png
--------------------------------------------------------------------------------
/datasets/README.md:
--------------------------------------------------------------------------------
1 | # Data Preparation for Downstream Tasks
2 | ### Datasets list:
3 | - [MSCOCO](#coco)
4 | - [Flickr30k](#flickr)
5 | - [ImageNet](#imagenet)
6 | - [Other Classification datasets](#others)
7 | - [Segmentation datasets](#segmentation)
8 |
9 |
10 | ## Image-Text Retrieval Task
11 |
12 | ### MSCOCO dataset
13 | ```
14 | $coco/
15 | |–– images/
16 | |–––– val2017/
17 | |–––––– 000000134722.jpg
18 | |–––––– 000000177015.jpg
19 | |–––––– ...
20 | |–– annotations/
21 | |–––– captions_val2017.json
22 | ```
23 | Step 1. Download the validation images from [COCO 2017 Val Images](https://cocodataset.org/#download) and unzip them to `coco/images/val2017`.
24 |
25 | Step 2. Download the 2017 Val annotations and place `captions_val2017.json` under `coco/annotations/` (see the sketch below).
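
A minimal sketch using the standard COCO download links (the `$coco` root is a placeholder for your dataset directory):
```bash
# Sketch: fetch the COCO 2017 val images and captions, then arrange them as above.
cd $coco
wget http://images.cocodataset.org/zips/val2017.zip
wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
unzip val2017.zip -d images/          # -> images/val2017/
unzip annotations_trainval2017.zip    # -> annotations/captions_val2017.json (among others)
```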
26 |
27 | ### Flickr30K dataset
28 | ```
29 | $flickr30k-images/
30 | |–– 2217728745.jpg
31 | |–– 2217728745.jpg
32 | |–– ...
33 | |–– flickr30k_val.json
34 | |–– flickr30k_test.json
35 | ```
36 | Step 1. Download the [flickr30k dataset](https://huggingface.co/datasets/nlphuji/flickr30k) and unzip it under `flickr30k-images/`; the images and annotation files will then be structured as above.
37 |
38 | ## Image Classification Task
39 |
40 | ### ImageNet dataset
41 | ```
42 | $imagenet/
43 | |–– data/
44 | |–––– val_images/
45 | |–––––– n01440764/
46 | |–––––––– ILSVRC2012_val_00000293_n01440764.JPEG
47 | |–––––––– ILSVRC2012_val_00017699_n01440764.JPEG
48 | |–––––––– ...
49 | |–––––– n01871265/
50 | |–––––––– ILSVRC2012_val_00000067_n01871265.JPEG
51 | |–––––––– ILSVRC2012_val_00017361_n01871265.JPEG
52 | |–––––––– ...
53 | ```
54 |
55 | Step 1. Download the validation data `val_images.tar.gz` from [ILSVRC/imagenet-1k](https://huggingface.co/datasets/ILSVRC/imagenet-1k) and extract it to `imagenet/data/val_images`.
56 | You can download `imagenet-1k/data/val_images.tar.gz` manually or use this command: `huggingface-cli download ILSVRC/imagenet-1k --repo-type dataset --local-dir /directory/to/your/dataset/`.
57 |
58 | Step 2. Change [source_dir](imagenet_organize.py#L5) in `imagenet_organize.py` to point to your `val_images` folder, then run `imagenet_organize.py` to organize the images into the format above (see the sketch below).
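
For reference, a sketch of the full preparation; the local paths are placeholders and the archive layout may differ slightly, so adjust as needed.
```bash
# Sketch: download only the validation archive, extract it, then group images by class.
huggingface-cli download ILSVRC/imagenet-1k data/val_images.tar.gz --repo-type dataset --local-dir /path/to/imagenet
mkdir -p /path/to/imagenet/data/val_images
tar -xzf /path/to/imagenet/data/val_images.tar.gz -C /path/to/imagenet/data/val_images
# Set source_dir in imagenet_organize.py to /path/to/imagenet/data/val_images, then run:
python imagenet_organize.py
```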
59 |
60 | ### Other Classification datasets
61 |
62 | Other classification datasets include `["food101", "cifar10", "cifar100", "sun397", "stanford_car", "aircraft", "dtd", "pets", "caltech101", "flowers"]`.
63 |
64 | Please set an appropriate [dataset_root](/src/dataloaders/utils.py#L17) in `src/dataloaders/utils.py`; the classification datasets will be stored there.
65 |
66 | Then, `torchvision.datasets` will automatically download the datasets into `dataset_root` during inference.
67 |
68 |
69 | ## Semantic Segmentation Task
70 |
71 | ### Segmentation datasets
72 |
73 | We followed the evaluation scheme and config files provided by [SCLIP](https://github.com/wangf3014/SCLIP) as shown [here](/src/training/seg_configs).
74 |
75 | Our segmentation configs include benchmarks with background `['cfg_voc21.py', 'cfg_context60.py', 'cfg_coco_object.py']` and without background `['cfg_voc20.py', 'cfg_city_scapes.py', 'cfg_context59.py', 'cfg_ade20k.py', 'cfg_coco_stuff164k.py']`.
76 |
77 | Please follow the dataset preparation instructions provided by [SCLIP](https://github.com/wangf3014/SCLIP) and [mmsegmentation](https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md) to download the following datasets: `["VOCdevkit/VOC2012", "VOCdevkit/VOC2010", "coco_stuff164k", "cityscapes", "ade"]`.
78 |
79 | Then, change `data_root` in each segmentation config to match your dataset location. For example, see [data_root](/src/training/seg_configs/cfg_ade20k.py#L12) in `cfg_ade20k.py`.
80 |
--------------------------------------------------------------------------------
/datasets/imagenet_organize.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 |
4 | # Define the source directory
5 | source_dir = "/your/directory/to/imagenet/data/val_images"
6 |
7 | # Check if the source directory exists
8 | if not os.path.exists(source_dir):
9 | print(f"Source directory {source_dir} does not exist.")
10 | exit()
11 |
12 | # Get a list of all JPEG files in the source directory
13 | jpeg_files = [f for f in os.listdir(source_dir) if f.endswith('.JPEG')]
14 |
15 | # Process each JPEG file
16 | for jpeg_file in jpeg_files:
17 | # Extract the class name from the file name
18 | # Example) ILSVRC2012_val_00012508_n01843065.JPEG => n01843065
19 | class_name = jpeg_file.split('_')[-1].split('.')[0]
20 |
21 | # Define the target directory for this class
22 | target_dir = os.path.join(source_dir, class_name)
23 |
24 | # Create the target directory if it doesn't exist
25 | if not os.path.exists(target_dir):
26 | os.makedirs(target_dir)
27 |
28 | # Define the source and target file paths
29 | source_file = os.path.join(source_dir, jpeg_file)
30 | target_file = os.path.join(target_dir, jpeg_file)
31 |
32 | # Move the file to the target directory
33 | shutil.move(source_file, target_file)
34 |
35 | print("Files have been classified and moved to their respective directories.")
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.9.0
2 | torchvision
3 | webdataset>=0.2.5
4 | regex
5 | ftfy
6 | tqdm
7 | pandas
8 | braceexpand
9 | huggingface_hub
10 | pytest-split==0.8.0
11 | pytest==7.2.0
12 | transformers[sentencepiece]
13 | timm>=1.0.7
14 | fsspec
15 | wandb
--------------------------------------------------------------------------------
/src/dataloaders/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/src/dataloaders/caltech101.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | import torch
3 | from torch.utils.data import DataLoader, Dataset, random_split
4 | from torchvision.datasets import Caltech101
5 | from .utils import dataset_root
6 |
7 |
8 | root = dataset_root
9 | num_example_train = 3000
10 | num_example_test = 5677
11 | num_classes = 101
12 | mean_per_class = True
13 |
14 | def get_loader_train(
15 | transform, batch_size, num_workers, seed
16 | ) -> Tuple[Dataset, DataLoader]:
17 | dataset = Caltech101(root, download=False, transform=transform)
18 | dataset_train, dataset_test = random_split(
19 | dataset,
20 | lengths=[num_example_train,
21 | num_example_test],
22 | generator=torch.Generator().manual_seed(seed))
23 | return (dataset_train, None)
24 |
25 |
26 | def get_loader_test(
27 | transform, batch_size, num_workers, seed
28 | ) -> Tuple[Dataset, DataLoader]:
29 | dataset = Caltech101(root, download=False, transform=transform)
30 | dataset_train, dataset_test = random_split(
31 | dataset,
32 | lengths=[num_example_train,
33 | num_example_test],
34 | generator=torch.Generator().manual_seed(seed))
35 | return (dataset_test, None)
--------------------------------------------------------------------------------
/src/dataloaders/cifar10.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | from torch.utils.data import DataLoader, Dataset
3 | from torchvision.datasets import CIFAR10
4 | from .utils import dataset_root
5 |
6 |
7 | root = dataset_root
8 | num_example_train = 50000
9 | num_example_test = 10000
10 | num_classes = 10
11 |
12 |
13 | def get_loader_train(
14 | transform, batch_size, num_workers, seed
15 | ) -> Tuple[Dataset, DataLoader]:
16 | dataset_train = CIFAR10(root, download=False, train=True, transform=transform)
17 | return (dataset_train, None)
18 |
19 |
20 | def get_loader_test(
21 | transform, batch_size, num_workers, seed
22 | ) -> Tuple[Dataset, DataLoader]:
23 | dataset = CIFAR10(root, download=True, train=False, transform=transform)
24 | return (dataset, None)
--------------------------------------------------------------------------------
/src/dataloaders/cifar100.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | from torch.utils.data import DataLoader, Dataset
3 | from torchvision.datasets import CIFAR100
4 | from .utils import dataset_root
5 |
6 |
7 | root = dataset_root
8 | num_example_train_val = 50000
9 | num_example_test = 10000
10 | num_classes = 100
11 |
12 |
13 | def get_loader_train(
14 | transform, batch_size, num_workers, seed
15 | ) -> Tuple[Dataset, DataLoader]:
16 | dataset_train_val = CIFAR100(root, download=False, train=True, transform=transform)
17 | return (dataset_train_val, None)
18 |
19 |
20 | def get_loader_test(
21 | transform, batch_size, num_workers, seed
22 | ) -> Tuple[Dataset, DataLoader]:
23 | dataset_test = CIFAR100(root, download=True, train=False, transform=transform)
24 | return (dataset_test, None)
--------------------------------------------------------------------------------
/src/dataloaders/dtd.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | from torch.utils.data import DataLoader, Dataset, ConcatDataset
3 | from torchvision.datasets import DTD
4 | import torchvision
5 | from .utils import dataset_root
6 |
7 | root = dataset_root
8 | num_example_train = 3760
9 | num_example_test = 1880
10 | num_classes = 47
11 |
12 |
13 | class Warper(Dataset):
14 | def __init__(self, dataset) -> None:
15 | super().__init__()
16 | self.dataset = dataset
17 |
18 | def __len__(self) -> int:
19 | return len(self.dataset)
20 |
21 | def __getitem__(self, index):
22 | img, label = self.dataset[index]
23 | if torchvision.__version__ >= "0.13.0":
24 | return img, label
25 | else:
26 | return img, label - 1
27 |
28 |
29 | def get_loader_train(
30 | transform, batch_size, num_workers, seed
31 | ) -> Tuple[Dataset, DataLoader]:
32 | dataset_train = DTD(root, download=True, split='train', transform=transform)
33 | dataset_val = DTD(root, download=True, split='val', transform=transform)
34 | dataset_train = ConcatDataset([dataset_train, dataset_val])
35 | dataset_train = Warper(dataset_train)
36 | return (dataset_train, None)
37 |
38 |
39 | def get_loader_test(
40 | transform, batch_size, num_workers, seed
41 | ) -> Tuple[Dataset, DataLoader]:
42 | dataset_test = DTD(root, download=True, split='test', transform=transform)
43 | return (dataset_test, None)
--------------------------------------------------------------------------------
/src/dataloaders/fgvc_aircraft.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | from torch.utils.data import DataLoader, Dataset
3 | from torchvision.datasets import FGVCAircraft
4 | from .utils import dataset_root
5 |
6 |
7 | root = dataset_root
8 | num_example_train_val = 6667
9 | num_example_test = 3333
10 | num_classes = 100
11 | mean_per_class = True
12 |
13 |
14 | def get_loader_train(
15 | transform, batch_size, num_workers, seed
16 | ) -> Tuple[Dataset, DataLoader]:
17 | dataset_train_val = FGVCAircraft(root, download=True, annotation_level='variant', split='trainval',transform=transform)
18 | return (dataset_train_val, None)
19 |
20 |
21 | def get_loader_test(
22 | transform, batch_size, num_workers, seed
23 | ) -> Tuple[Dataset, DataLoader]:
24 | dataset_test = FGVCAircraft(root, download=True, annotation_level='variant', split='test', transform=transform)
25 | return (dataset_test, None)
26 |
--------------------------------------------------------------------------------
/src/dataloaders/flowers102.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | from torch.utils.data import DataLoader, Dataset, ConcatDataset
3 | from torchvision.datasets import Flowers102
4 | import torchvision
5 | from .utils import dataset_root
6 |
7 | root = dataset_root
8 | num_example_train_val = 2040
9 | num_example_test = 6149
10 | num_classes = 102
11 | mean_per_class = True
12 |
13 |
14 | class Warper(Dataset):
15 | def __init__(self, dataset) -> None:
16 | super().__init__()
17 | self.dataset = dataset
18 |
19 | def __len__(self) -> int:
20 | return len(self.dataset)
21 |
22 | def __getitem__(self, index):
23 | img, label = self.dataset[index]
24 | if torchvision.__version__ >= "0.13.0":
25 | return img, label
26 | else:
27 | return img, label - 1
28 |
29 |
30 | def get_loader_train(
31 | transform, batch_size, num_workers, seed
32 | ) -> Tuple[Dataset, DataLoader]:
33 | dataset_train = Flowers102(root, download=True, split='train', transform=transform)
34 | dataset_val = Flowers102(root, download=True, split='val', transform=transform)
35 | dataset_train = ConcatDataset([dataset_train, dataset_val])
36 | dataset_train = Warper(dataset_train)
37 | return (dataset_train, None)
38 |
39 |
40 | def get_loader_test(
41 | transform, batch_size, num_workers, seed
42 | ) -> Tuple[Dataset, DataLoader]:
43 | dataset_test = Flowers102(root, download=True, split='test', transform=transform)
44 | dataset_test = Warper(dataset_test)
45 | return (dataset_test, None)
46 |
--------------------------------------------------------------------------------
/src/dataloaders/food101.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | from torch.utils.data import DataLoader, Dataset
3 | from torchvision.datasets import Food101
4 | from .utils import dataset_root
5 |
6 |
7 | root = dataset_root
8 | num_example_train_val = 75750
9 | num_example_test = 25250
10 | num_classes = 101
11 |
12 |
13 | def get_loader_train(
14 | transform, batch_size, num_workers, seed
15 | ) -> Tuple[Dataset, DataLoader]:
16 | dataset_train_val = Food101(root, download=True, split='train', transform=transform)
17 | return (dataset_train_val, None)
18 |
19 |
20 | def get_loader_test(
21 | transform, batch_size, num_workers, seed
22 | ) -> Tuple[Dataset, DataLoader]:
23 | dataset_test = Food101(root, download=True, split='test', transform=transform)
24 | return (dataset_test, None)
--------------------------------------------------------------------------------
/src/dataloaders/oxford_pets.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | from torch.utils.data import DataLoader, Dataset
3 | from torchvision.datasets import OxfordIIITPet
4 | from .utils import dataset_root
5 |
6 |
7 | root = dataset_root
8 | num_example_train_val = 3680
9 | num_example_test = 3699
10 | num_classes = 37
11 | mean_per_class = True
12 |
13 |
14 | def get_loader_train(
15 | transform, batch_size, num_workers, seed
16 | ) -> Tuple[Dataset, DataLoader]:
17 | dataset_train_val = OxfordIIITPet(root, download=True, split='trainval',transform=transform)
18 | return (dataset_train_val, None)
19 |
20 |
21 | def get_loader_test(
22 | transform, batch_size, num_workers, seed
23 | ) -> Tuple[Dataset, DataLoader]:
24 | dataset_test = OxfordIIITPet(root, download=True, split='test', transform=transform)
25 | return (dataset_test, None)
26 |
--------------------------------------------------------------------------------
/src/dataloaders/stanford_car.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | from torch.utils.data import DataLoader, Dataset
3 | from torchvision.datasets import StanfordCars
4 | from .utils import dataset_root
5 |
6 |
7 | root = dataset_root
8 | num_example_train_val = 8144
9 | num_example_test = 8041
10 | num_classes = 196
11 |
12 |
13 | def get_loader_train(
14 | transform, batch_size, num_workers, seed
15 | ) -> Tuple[Dataset, DataLoader]:
16 | dataset_train = StanfordCars(
17 | root, download=False, split='train', transform=transform)
18 | return (dataset_train, None)
19 |
20 |
21 | def get_loader_test(
22 | transform, batch_size, num_workers, seed
23 | ) -> Tuple[Dataset, DataLoader]:
24 | dataset_test = StanfordCars(
25 | root, download=False, split='test', transform=transform)
26 | return (dataset_test, None)
--------------------------------------------------------------------------------
/src/dataloaders/sun397.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | import torch
3 | from torch.utils.data import DataLoader, Dataset, random_split
4 | from torchvision.datasets import SUN397
5 | from .utils import dataset_root
6 |
7 |
8 | num_example_train = 19850
9 | num_example_test = 19850
10 | num_example_others = 69054
11 | num_classes = 397
12 | root = dataset_root
13 |
14 |
15 | def get_loader_train(
16 | transform, batch_size, num_workers, seed
17 | ) -> Tuple[Dataset, DataLoader]:
18 | dataset = SUN397(root, transform=transform)
19 | dataset_train, dataset_test, others = random_split(
20 | dataset,
21 | lengths=[num_example_train,
22 | num_example_test,
23 | num_example_others,],
24 | generator=torch.Generator().manual_seed(seed + hash("sun397") % 2048))
25 | return (dataset_train, None)
26 |
27 |
28 | def get_loader_test(
29 | transform, batch_size, num_workers, seed
30 | ) -> Tuple[Dataset, DataLoader]:
31 | dataset = SUN397(root, download=True, transform=transform)
32 | dataset_train, dataset_test, others = random_split(
33 | dataset,
34 | lengths=[num_example_train,
35 | num_example_test,
36 | num_example_others,],
37 | generator=torch.Generator().manual_seed(seed + hash("sun397") % 2048))
38 | return (dataset_test, None)
--------------------------------------------------------------------------------
/src/dataloaders/templates.json:
--------------------------------------------------------------------------------
1 | {
2 | "food101": [
3 | "a photo of {}, a type of food."
4 | ],
5 | "cifar10": [
6 | "a photo of a {}.",
7 | "a blurry photo of a {}.",
8 | "a black and white photo of a {}.",
9 | "a low contrast photo of a {}.",
10 | "a high contrast photo of a {}.",
11 | "a bad photo of a {}.",
12 | "a good photo of a {}.",
13 | "a photo of a small {}.",
14 | "a photo of a big {}.",
15 | "a photo of the {}.",
16 | "a blurry photo of the {}.",
17 | "a black and white photo of the {}.",
18 | "a low contrast photo of the {}.",
19 | "a high contrast photo of the {}.",
20 | "a bad photo of the {}.",
21 | "a good photo of the {}.",
22 | "a photo of the small {}.",
23 | "a photo of the big {}."
24 | ],
25 | "cifar100": [
26 | "a photo of a {}.",
27 | "a blurry photo of a {}.",
28 | "a black and white photo of a {}.",
29 | "a low contrast photo of a {}.",
30 | "a high contrast photo of a {}.",
31 | "a bad photo of a {}.",
32 | "a good photo of a {}.",
33 | "a photo of a small {}.",
34 | "a photo of a big {}.",
35 | "a photo of the {}.",
36 | "a blurry photo of the {}.",
37 | "a black and white photo of the {}.",
38 | "a low contrast photo of the {}.",
39 | "a high contrast photo of the {}.",
40 | "a bad photo of the {}.",
41 | "a good photo of the {}.",
42 | "a photo of the small {}.",
43 | "a photo of the big {}."
44 | ],
45 | "birdsnap": [
46 | "a photo of a {}, a type of bird."
47 | ],
48 | "cub": [
49 | "a photo of a {}, a type of bird."
50 | ],
51 | "imagenet": [
52 | "a bad photo of a {}.",
53 | "a photo of many {}.",
54 | "a sculpture of a {}.",
55 | "a photo of the hard to see {}.",
56 | "a low resolution photo of the {}.",
57 | "a rendering of a {}.",
58 | "graffiti of a {}.",
59 | "a bad photo of the {}.",
60 | "a cropped photo of the {}.",
61 | "a tattoo of a {}.",
62 | "the embroidered {}.",
63 | "a photo of a hard to see {}.",
64 | "a bright photo of a {}.",
65 | "a photo of a clean {}.",
66 | "a photo of a dirty {}.",
67 | "a dark photo of the {}.",
68 | "a drawing of a {}.",
69 | "a photo of my {}.",
70 | "the plastic {}.",
71 | "a photo of the cool {}.",
72 | "a close-up photo of a {}.",
73 | "a black and white photo of the {}.",
74 | "a painting of the {}.",
75 | "a painting of a {}.",
76 | "a pixelated photo of the {}.",
77 | "a sculpture of the {}.",
78 | "a bright photo of the {}.",
79 | "a cropped photo of a {}.",
80 | "a plastic {}.",
81 | "a photo of the dirty {}.",
82 | "a jpeg corrupted photo of a {}.",
83 | "a blurry photo of the {}.",
84 | "a photo of the {}.",
85 | "a good photo of the {}.",
86 | "a rendering of the {}.",
87 | "a {} in a video game.",
88 | "a photo of one {}.",
89 | "a doodle of a {}.",
90 | "a close-up photo of the {}.",
91 | "a photo of a {}.",
92 | "the origami {}.",
93 | "the {} in a video game.",
94 | "a sketch of a {}.",
95 | "a doodle of the {}.",
96 | "a origami {}.",
97 | "a low resolution photo of a {}.",
98 | "the toy {}.",
99 | "a rendition of the {}.",
100 | "a photo of the clean {}.",
101 | "a photo of a large {}.",
102 | "a rendition of a {}.",
103 | "a photo of a nice {}.",
104 | "a photo of a weird {}.",
105 | "a blurry photo of a {}.",
106 | "a cartoon {}.",
107 | "art of a {}.",
108 | "a sketch of the {}.",
109 | "a embroidered {}.",
110 | "a pixelated photo of a {}.",
111 | "itap of the {}.",
112 | "a jpeg corrupted photo of the {}.",
113 | "a good photo of a {}.",
114 | "a plushie {}.",
115 | "a photo of the nice {}.",
116 | "a photo of the small {}.",
117 | "a photo of the weird {}.",
118 | "the cartoon {}.",
119 | "art of the {}.",
120 | "a drawing of the {}.",
121 | "a photo of the large {}.",
122 | "a black and white photo of a {}.",
123 | "the plushie {}.",
124 | "a dark photo of a {}.",
125 | "itap of a {}.",
126 | "graffiti of the {}.",
127 | "a toy {}.",
128 | "itap of my {}.",
129 | "a photo of a cool {}.",
130 | "a photo of a small {}.",
131 | "a tattoo of the {}."
132 | ],
133 | "imagenet_a": [
134 | "a bad photo of a {}.",
135 | "a photo of many {}.",
136 | "a sculpture of a {}.",
137 | "a photo of the hard to see {}.",
138 | "a low resolution photo of the {}.",
139 | "a rendering of a {}.",
140 | "graffiti of a {}.",
141 | "a bad photo of the {}.",
142 | "a cropped photo of the {}.",
143 | "a tattoo of a {}.",
144 | "the embroidered {}.",
145 | "a photo of a hard to see {}.",
146 | "a bright photo of a {}.",
147 | "a photo of a clean {}.",
148 | "a photo of a dirty {}.",
149 | "a dark photo of the {}.",
150 | "a drawing of a {}.",
151 | "a photo of my {}.",
152 | "the plastic {}.",
153 | "a photo of the cool {}.",
154 | "a close-up photo of a {}.",
155 | "a black and white photo of the {}.",
156 | "a painting of the {}.",
157 | "a painting of a {}.",
158 | "a pixelated photo of the {}.",
159 | "a sculpture of the {}.",
160 | "a bright photo of the {}.",
161 | "a cropped photo of a {}.",
162 | "a plastic {}.",
163 | "a photo of the dirty {}.",
164 | "a jpeg corrupted photo of a {}.",
165 | "a blurry photo of the {}.",
166 | "a photo of the {}.",
167 | "a good photo of the {}.",
168 | "a rendering of the {}.",
169 | "a {} in a video game.",
170 | "a photo of one {}.",
171 | "a doodle of a {}.",
172 | "a close-up photo of the {}.",
173 | "a photo of a {}.",
174 | "the origami {}.",
175 | "the {} in a video game.",
176 | "a sketch of a {}.",
177 | "a doodle of the {}.",
178 | "a origami {}.",
179 | "a low resolution photo of a {}.",
180 | "the toy {}.",
181 | "a rendition of the {}.",
182 | "a photo of the clean {}.",
183 | "a photo of a large {}.",
184 | "a rendition of a {}.",
185 | "a photo of a nice {}.",
186 | "a photo of a weird {}.",
187 | "a blurry photo of a {}.",
188 | "a cartoon {}.",
189 | "art of a {}.",
190 | "a sketch of the {}.",
191 | "a embroidered {}.",
192 | "a pixelated photo of a {}.",
193 | "itap of the {}.",
194 | "a jpeg corrupted photo of the {}.",
195 | "a good photo of a {}.",
196 | "a plushie {}.",
197 | "a photo of the nice {}.",
198 | "a photo of the small {}.",
199 | "a photo of the weird {}.",
200 | "the cartoon {}.",
201 | "art of the {}.",
202 | "a drawing of the {}.",
203 | "a photo of the large {}.",
204 | "a black and white photo of a {}.",
205 | "the plushie {}.",
206 | "a dark photo of a {}.",
207 | "itap of a {}.",
208 | "graffiti of the {}.",
209 | "a toy {}.",
210 | "itap of my {}.",
211 | "a photo of a cool {}.",
212 | "a photo of a small {}.",
213 | "a tattoo of the {}."
214 | ],
215 | "sst2": [
216 | "a {} review of a movie."
217 | ],
218 | "hateful_memes": [
219 | "a {}."
220 | ],
221 | "clevr": [
222 | "a photo of {} objects."
223 | ],
224 | "kinetics700": [
225 | "a photo of {}.",
226 | "a photo of a person {}.",
227 | "a photo of a person using {}.",
228 | "a photo of a person doing {}.",
229 | "a photo of a person during {}.",
230 | "a photo of a person performing {}.",
231 | "a photo of a person practicing {}.",
232 | "a video of {}.",
233 | "a video of a person {}.",
234 | "a video of a person using {}.",
235 | "a video of a person doing {}.",
236 | "a video of a person during {}.",
237 | "a video of a person performing {}.",
238 | "a video of a person practicing {}.",
239 | "a example of {}.",
240 | "a example of a person {}.",
241 | "a example of a person using {}.",
242 | "a example of a person doing {}.",
243 | "a example of a person during {}.",
244 | "a example of a person performing {}.",
245 | "a example of a person practicing {}.",
246 | "a demonstration of {}.",
247 | "a demonstration of a person {}.",
248 | "a demonstration of a person using {}.",
249 | "a demonstration of a person doing {}.",
250 | "a demonstration of a person during {}.",
251 | "a demonstration of a person performing {}.",
252 | "a demonstration of a person practicing {}."
253 | ],
254 | "ucf101": [
255 | "a photo of a person {}.",
256 | "a video of a person {}.",
257 | "a example of a person {}.",
258 | "a demonstration of a person {}.",
259 | "a photo of the person {}.",
260 | "a video of the person {}.",
261 | "a example of the person {}.",
262 | "a demonstration of the person {}.",
263 | "a photo of a person using {}.",
264 | "a video of a person using {}.",
265 | "a example of a person using {}.",
266 | "a demonstration of a person using {}.",
267 | "a photo of the person using {}.",
268 | "a video of the person using {}.",
269 | "a example of the person using {}.",
270 | "a demonstration of the person using {}.",
271 | "a photo of a person doing {}.",
272 | "a video of a person doing {}.",
273 | "a example of a person doing {}.",
274 | "a demonstration of a person doing {}.",
275 | "a photo of the person doing {}.",
276 | "a video of the person doing {}.",
277 | "a example of the person doing {}.",
278 | "a demonstration of the person doing {}.",
279 | "a photo of a person during {}.",
280 | "a video of a person during {}.",
281 | "a example of a person during {}.",
282 | "a demonstration of a person during {}.",
283 | "a photo of the person during {}.",
284 | "a video of the person during {}.",
285 | "a example of the person during {}.",
286 | "a demonstration of the person during {}.",
287 | "a photo of a person performing {}.",
288 | "a video of a person performing {}.",
289 | "a example of a person performing {}.",
290 | "a demonstration of a person performing {}.",
291 | "a photo of the person performing {}.",
292 | "a video of the person performing {}.",
293 | "a example of the person performing {}.",
294 | "a demonstration of the person performing {}.",
295 | "a photo of a person practicing {}.",
296 | "a video of a person practicing {}.",
297 | "a example of a person practicing {}.",
298 | "a demonstration of a person practicing {}.",
299 | "a photo of the person practicing {}.",
300 | "a video of the person practicing {}.",
301 | "a example of the person practicing {}.",
302 | "a demonstration of the person practicing {}."
303 | ],
304 | "pcam": [
305 | "this is a photo of {}"
306 | ],
307 | "country211": [
308 | "a photo i took in {}.",
309 | "a photo i took while visiting {}.",
310 | "a photo from my home country of {}.",
311 | "a photo from my visit to {}.",
312 | "a photo showing the country of {}."
313 | ],
314 | "kitti": [
315 | "{}"
316 | ],
317 | "gtsrb": [
318 | "a zoomed in photo of a \"{}\" traffic sign.",
319 | "a centered photo of a \"{}\" traffic sign.",
320 | "a close up photo of a \"{}\" traffic sign."
321 | ],
322 | "resisc45": [
323 | "satellite imagery of {}.",
324 | "aerial imagery of {}.",
325 | "satellite photo of {}.",
326 | "aerial photo of {}.",
327 | "satellite view of {}.",
328 | "aerial view of {}.",
329 | "satellite imagery of a {}.",
330 | "aerial imagery of a {}.",
331 | "satellite photo of a {}.",
332 | "aerial photo of a {}.",
333 | "satellite view of a {}.",
334 | "aerial view of a {}.",
335 | "satellite imagery of the {}.",
336 | "aerial imagery of the {}.",
337 | "satellite photo of the {}.",
338 | "aerial photo of the {}.",
339 | "satellite view of the {}.",
340 | "aerial view of the {}."
341 | ],
342 | "eurosat": [
343 | "a centered satellite photo of {}.",
344 | "a centered satellite photo of a {}.",
345 | "a centered satellite photo of the {}."
346 | ],
347 | "stl10": [
348 | "a photo of a {}.",
349 | "a photo of the {}."
350 | ],
351 | "fer2013": [
352 | "a photo of a {} looking face.",
353 | "a photo of a face showing the emotion: {}.",
354 | "a photo of a face looking {}.",
355 | "a face that looks {}.",
356 | "they look {}.",
357 | "look at how {} they are."
358 | ],
359 | "mnist": [
360 | "a photo of the number: \"{}\"."
361 | ],
362 | "flowers": [
363 | "a photo of a {}, a type of flower."
364 | ],
365 | "caltech101": [
366 | "a photo of a {}.",
367 | "a painting of a {}.",
368 | "a plastic {}.",
369 | "a sculpture of a {}.",
370 | "a sketch of a {}.",
371 | "a tattoo of a {}.",
372 | "a toy {}.",
373 | "a rendition of a {}.",
374 | "a embroidered {}.",
375 | "a cartoon {}.",
376 | "a {} in a video game.",
377 | "a plushie {}.",
378 | "a origami {}.",
379 | "art of a {}.",
380 | "graffiti of a {}.",
381 | "a drawing of a {}.",
382 | "a doodle of a {}.",
383 | "a photo of the {}.",
384 | "a painting of the {}.",
385 | "the plastic {}.",
386 | "a sculpture of the {}.",
387 | "a sketch of the {}.",
388 | "a tattoo of the {}.",
389 | "the toy {}.",
390 | "a rendition of the {}.",
391 | "the embroidered {}.",
392 | "the cartoon {}.",
393 | "the {} in a video game.",
394 | "the plushie {}.",
395 | "the origami {}.",
396 | "art of the {}.",
397 | "graffiti of the {}.",
398 | "a drawing of the {}.",
399 | "a doodle of the {}."
400 | ],
401 | "pets": [
402 | "a photo of a {}, a type of pet."
403 | ],
404 | "dtd": [
405 | "a photo of a {} texture.",
406 | "a photo of a {} pattern.",
407 | "a photo of a {} thing.",
408 | "a photo of a {} object.",
409 | "a photo of the {} texture.",
410 | "a photo of the {} pattern.",
411 | "a photo of the {} thing.",
412 | "a photo of the {} object."
413 | ],
414 | "voc2007": [
415 | "a photo of a {}."
416 | ],
417 | "aircraft": [
418 | "a photo of a {}, a type of aircraft.",
419 | "a photo of the {}, a type of aircraft."
420 | ],
421 | "stanford_car": [
422 | "a photo of a {}.",
423 | "a photo of the {}.",
424 | "a photo of my {}.",
425 | "i love my {}!",
426 | "a photo of my dirty {}.",
427 | "a photo of my clean {}.",
428 | "a photo of my new {}.",
429 | "a photo of my old {}."
430 | ],
431 | "sun397": [
432 | "a photo of a {}.",
433 | "a photo of the {}."
434 | ]
435 | }
--------------------------------------------------------------------------------
/src/dataloaders/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 | import numpy as np
4 | import torch
5 | from torch import distributed as dist
6 | from torch.utils.data import DistributedSampler as _DistributedSampler
7 | import os
8 | try:
9 | rank = int(os.environ["RANK"])
10 | world_size = int(os.environ["WORLD_SIZE"])
11 | except KeyError:
12 | rank = 0
13 | world_size = 1
14 |
15 | # torchvision.datasets will automatically download the datasets
16 | # Set your directory to classification dataset.
17 | dataset_root = "/your/directory/to/classification/dataset"
18 |
19 | def worker_init_fn(worker_id, num_workers, rank, seed):
20 | # The seed of each worker equals to
21 | # num_worker * rank + worker_id + user_seed
22 | worker_seed = num_workers * rank + worker_id + seed
23 | np.random.seed(worker_seed)
24 | random.seed(worker_seed)
25 | torch.manual_seed(worker_seed)
26 |
27 |
28 | def get_dist_info():
29 | if dist.is_available() and dist.is_initialized():
30 | rank = dist.get_rank()
31 | world_size = dist.get_world_size()
32 | else:
33 | rank = 0
34 | world_size = 1
35 |
36 | return rank, world_size
37 |
38 |
39 | def sync_random_seed(seed=None, device="cuda"):
40 | """Make sure different ranks share the same seed.
41 | All workers must call this function, otherwise it will deadlock.
42 | This method is generally used in `DistributedSampler`,
43 | because the seed should be identical across all processes
44 | in the distributed group.
45 | In distributed sampling, different ranks should sample non-overlapped
46 | data in the dataset. Therefore, this function is used to make sure that
47 | each rank shuffles the data indices in the same order based
48 | on the same seed. Then different ranks could use different indices
49 | to select non-overlapped data from the same data list.
50 | Args:
51 | seed (int, Optional): The seed. Default to None.
52 | device (str): The device where the seed will be put on.
53 | Default to 'cuda'.
54 | Returns:
55 | int: Seed to be used.
56 | """
57 | if seed is None:
58 | seed = np.random.randint(2**31)
59 | assert isinstance(seed, int)
60 |
61 | rank, world_size = get_dist_info()
62 |
63 | if world_size == 1:
64 | return seed
65 |
66 | if rank == 0:
67 | random_num = torch.tensor(seed, dtype=torch.int32, device=device)
68 | else:
69 | random_num = torch.tensor(0, dtype=torch.int32, device=device)
70 |
71 | dist.broadcast(random_num, src=0)
72 |
73 | return random_num.item()
74 |
75 |
76 | class DistributedSampler(_DistributedSampler):
77 | def __init__(
78 | self,
79 | dataset,
80 | num_replicas=None, # world_size
81 | rank=None, # local_rank
82 | shuffle=True,
83 | seed=0,
84 | ):
85 |
86 | super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
87 |
88 | # In distributed sampling, different ranks should sample
89 | # non-overlapped data in the dataset. Therefore, this function
90 | # is used to make sure that each rank shuffles the data indices
91 | # in the same order based on the same seed. Then different ranks
92 | # could use different indices to select non-overlapped data from the
93 | # same data list.
94 | self.seed = sync_random_seed(seed)
95 |
96 | def __iter__(self):
97 | # deterministically shuffle based on epoch
98 | if self.shuffle:
99 | g = torch.Generator()
100 | # When :attr:`shuffle=True`, this ensures all replicas
101 | # use a different random ordering for each epoch.
102 | # Otherwise, the next iteration of this sampler will
103 | # yield the same ordering.
104 | g.manual_seed(self.epoch + self.seed)
105 | indices = torch.randperm(len(self.dataset), generator=g).tolist()
106 | else:
107 | indices = torch.arange(len(self.dataset)).tolist()
108 |
109 | # add extra samples to make it evenly divisible
110 | # in case that indices is shorter than half of total_size
111 | indices = (indices * math.ceil(self.total_size / len(indices)))[
112 | : self.total_size
113 | ]
114 | assert len(indices) == self.total_size
115 |
116 | # subsample
117 | indices = indices[self.rank : self.total_size : self.num_replicas]
118 | assert len(indices) == self.num_samples
119 |
120 | return iter(indices)
--------------------------------------------------------------------------------
/src/inference_classification.sh:
--------------------------------------------------------------------------------
1 | # COSMOS models
2 | # --model ViT-B-16
3 | # --huggingface-model-name [cosmos_vitb16_cc3m, cosmos_vitb16_cc12m, cosmos_vitb16_yfcc15m, cosmos_vitb16_merged30m, cosmos_vitb16_pixelprose]
4 | # --model ViT-B-32
5 | # --huggingface-model-name [cosmos_vitb32_cc3m, cosmos_vitb32_cc12m, cosmos_vitb32_yfcc15m, cosmos_vitb32_merged30m, cosmos_vitb32_pixelprose]
6 | torchrun --nproc_per_node 1 -m main \
7 | --model ViT-B-16 \
8 | --huggingface-repo-name sankim2/cosmos \
9 | --huggingface-model-name cosmos_vitb16_merged30m.pt \
10 | --val-data classification \
11 | --imagenet-val /directory/to/your/imagenet/data/val_images \
12 | --batch-size 256 \
13 | --workers 16 \
14 | --output-all \
15 | --attentional-pool \
16 | --cosmos \
17 |
18 | # OpenCLIP models
19 | # --model ViT-B-16 --pretrained [laion400m_e32, datacomp_xl_s13b_b90k, laion2b_s34b_b88k]
20 | # --model ViT-B-32 --pretrained [laion400m_e32, datacomp_xl_s13b_b90k, laion2b_s34b_b79k]
21 | torchrun --nproc_per_node 1 -m main \
22 | --model ViT-B-16 \
23 | --pretrained laion400m_e32 \
24 | --val-data classification \
25 | --imagenet-val /directory/to/your/imagenet/data/val_images \
26 | --batch-size 256 \
27 | --workers 16 \
--------------------------------------------------------------------------------
/src/inference_retrieval.sh:
--------------------------------------------------------------------------------
1 | # COSMOS models
2 | # --model ViT-B-16
3 | # --huggingface-model-name [cosmos_vitb16_cc3m, cosmos_vitb16_cc12m, cosmos_vitb16_yfcc15m, cosmos_vitb16_merged30m, cosmos_vitb16_pixelprose]
4 | # --model ViT-B-32
5 | # --huggingface-model-name [cosmos_vitb32_cc3m, cosmos_vitb32_cc12m, cosmos_vitb32_yfcc15m, cosmos_vitb32_merged30m, cosmos_vitb32_pixelprose]
6 | torchrun --nproc_per_node 1 -m main \
7 | --model ViT-B-16 \
8 | --huggingface-repo-name sankim2/cosmos \
9 | --huggingface-model-name cosmos_vitb16_merged30m.pt \
10 | --val-data retrieval \
11 | --data-root-dir /directory/to/your/coco/and/flickr30k/ \
12 | --batch-size 256 \
13 | --workers 16 \
14 | --output-all \
15 | --attentional-pool \
16 | --cosmos \
17 |
18 | # OpenCLIP models
19 | # --model ViT-B-16 --pretrained [laion400m_e32, datacomp_xl_s13b_b90k, laion2b_s34b_b88k]
20 | # --model ViT-B-32 --pretrained [laion400m_e32, datacomp_xl_s13b_b90k, laion2b_s34b_b79k]
21 | torchrun --nproc_per_node 1 -m main \
22 | --model ViT-B-16 \
23 | --pretrained laion400m_e32 \
24 | --val-data retrieval \
25 | --data-root-dir /directory/to/your/coco/and/flickr30k/ \
26 | --batch-size 256 \
27 | --workers 16 \
--------------------------------------------------------------------------------
/src/inference_segmentation.sh:
--------------------------------------------------------------------------------
1 | # COSMOS models
2 | # --model ViT-B-16
3 | # --huggingface-model-name [cosmos_vitb16_cc3m, cosmos_vitb16_cc12m, cosmos_vitb16_yfcc15m, cosmos_vitb16_merged30m, cosmos_vitb16_pixelprose]
4 | # --model ViT-B-32
5 | # --huggingface-model-name [cosmos_vitb32_cc3m, cosmos_vitb32_cc12m, cosmos_vitb32_yfcc15m, cosmos_vitb32_merged30m, cosmos_vitb32_pixelprose]
6 | # --seg-w-background : segmentation with background benchmarks
7 | # --use-csa : using Correlative Self-Attention (CSA) block from SCLIP
8 | torchrun --nproc_per_node 1 -m seg_eval \
9 | --model ViT-B-16 \
10 | --huggingface-repo-name sankim2/cosmos \
11 | --huggingface-model-name cosmos_vitb16_merged30m.pt \
12 | --batch-size 256 \
13 | --workers 16 \
14 | --output-all \
15 | --attentional-pool \
16 | --cosmos \
17 | --seg-w-background \
18 | --use-csa
19 |
20 | # OpenCLIP models
21 | # --model ViT-B-16 --pretrained [laion400m_e32, datacomp_xl_s13b_b90k, laion2b_s34b_b88k]
22 | # --model ViT-B-32 --pretrained [laion400m_e32, datacomp_xl_s13b_b90k, laion2b_s34b_b79k]
23 | torchrun --nproc_per_node 1 -m seg_eval \
24 | --model ViT-B-16 \
25 | --pretrained laion400m_e32 \
26 | --batch-size 256 \
27 | --workers 16 \
--------------------------------------------------------------------------------
/src/open_clip/__init__.py:
--------------------------------------------------------------------------------
1 | from .coca_model import CoCa
2 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
3 | from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss
4 | from .factory import list_models, add_model_config, get_model_config, load_checkpoint
5 | from .loss import ClipLoss, DistillClipLoss, CoCaLoss
6 | from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \
7 | convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype, get_input_dtype, \
8 | get_model_tokenize_cfg, get_model_preprocess_cfg, set_model_preprocess_cfg
9 | from .openai import load_openai_model, list_openai_models
10 | from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \
11 | get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
12 | from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub
13 | from .tokenizer import SimpleTokenizer, tokenize, decode
14 | from .transform import image_transform, AugmentationCfg
15 | from .zero_shot_classifier import build_zero_shot_classifier, build_zero_shot_classifier_legacy
16 | from .zero_shot_metadata import OPENAI_IMAGENET_TEMPLATES, SIMPLE_IMAGENET_TEMPLATES, IMAGENET_CLASSNAMES
17 |
--------------------------------------------------------------------------------
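These exports are the public API used by the training and inference scripts. Below is a minimal usage sketch (not part of the repository), assuming an OpenCLIP pretrained tag such as `laion400m_e32` and the default model configuration without `--output-all`:
```python
import torch
import torch.nn.functional as F
from open_clip import create_model_and_transforms, get_tokenizer

# Build a ViT-B-16 model along with its train/val preprocessing transforms.
model, preprocess_train, preprocess_val = create_model_and_transforms(
    'ViT-B-16', pretrained='laion400m_e32')
tokenizer = get_tokenizer('ViT-B-16')
model.eval()

with torch.no_grad():
    # Dummy batch: one 224x224 RGB image (already preprocessed) and two captions.
    image = torch.randn(1, 3, 224, 224)
    text = tokenizer(['a photo of a dog', 'a photo of a cat'])
    image_features = F.normalize(model.encode_image(image), dim=-1)
    text_features = F.normalize(model.encode_text(text), dim=-1)
    # Cosine similarity between the image and each caption.
    similarity = image_features @ text_features.T
print(similarity.shape)  # torch.Size([1, 2])
```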
/src/open_clip/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ExplainableML/cosmos/c80348f08e9b02c5a81adadd7dce486b3c284b78/src/open_clip/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/src/open_clip/constants.py:
--------------------------------------------------------------------------------
1 | OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
2 | OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
3 | IMAGENET_MEAN = (0.485, 0.456, 0.406)
4 | IMAGENET_STD = (0.229, 0.224, 0.225)
5 | INCEPTION_MEAN = (0.5, 0.5, 0.5)
6 | INCEPTION_STD = (0.5, 0.5, 0.5)
7 |
--------------------------------------------------------------------------------
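These channel statistics are consumed by `transform.py`; as a small illustration (not how this repository builds its transforms), they can be dropped into a standard torchvision pipeline:
```python
from torchvision import transforms
from open_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD

# CLIP-style 224x224 preprocessing normalized with the OpenAI dataset statistics.
preprocess = transforms.Compose([
    transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD),
])
```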
/src/open_clip/convert.py:
--------------------------------------------------------------------------------
1 | """ Conversion functions for 3rd part state-dicts and non-torch native checkpoint formats.
2 | """
3 | from typing import Union
4 |
5 | import torch
6 | import numpy as np
7 |
8 | from .model import CLIP, CustomTextCLIP
9 | from .transformer import TextTransformer, Transformer
10 |
11 |
12 | @torch.no_grad()
13 | def load_big_vision_weights(model: CustomTextCLIP, checkpoint_path: str):
14 | """ Load weights from .npz checkpoints for official Google big_vision image-text models
15 |
16 | Currently, SigLIP source models are supported, targeting a CustomTextCLIP destination model
17 | with a timm image encoder.
18 | """
19 | from timm.layers import resample_patch_embed, resample_abs_pos_embed
20 |
21 | def _n2p(w, t=True):
22 | if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
23 | w = w.flatten()
24 | if t:
25 | if w.ndim == 4:
26 | w = w.transpose([3, 2, 0, 1])
27 | elif w.ndim == 3:
28 | w = w.transpose([2, 0, 1])
29 | elif w.ndim == 2:
30 | w = w.transpose([1, 0])
31 | return torch.from_numpy(w)
32 |
33 | w = np.load(checkpoint_path)
34 | interpolation = 'bilinear'
35 | antialias = False
36 |
37 | def _convert_timm_img(module, prefix):
38 | embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
39 | if embed_conv_w.shape[-2:] != module.patch_embed.proj.weight.shape[-2:]:
40 | embed_conv_w = resample_patch_embed(
41 | embed_conv_w,
42 | module.patch_embed.proj.weight.shape[-2:],
43 | interpolation=interpolation,
44 | antialias=antialias,
45 | verbose=True,
46 | )
47 | module.patch_embed.proj.weight.copy_(embed_conv_w)
48 | module.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
49 |
50 | if module.cls_token is not None:
51 | module.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
52 |
53 | pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False)
54 | if pos_embed_w.shape != module.pos_embed.shape:
55 | assert False, f'{pos_embed_w.shape}, {module.pos_embed.shape}'
56 | num_prefix_tokens = 0 if getattr(module, 'no_embed_class', False) else getattr(module, 'num_prefix_tokens', 1)
57 | pos_embed_w = resample_abs_pos_embed( # resize pos embedding when different size from pretrained weights
58 | pos_embed_w,
59 | new_size=module.patch_embed.grid_size,
60 | num_prefix_tokens=num_prefix_tokens,
61 | interpolation=interpolation,
62 | antialias=antialias,
63 | verbose=True,
64 | )
65 | module.pos_embed.copy_(pos_embed_w)
66 |
67 | mha_sub, b_sub, ln1_sub = (0, 0, 1)
68 | for i, block in enumerate(module.blocks.children()):
69 | block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
70 | mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/'
71 | block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
72 | block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
73 | block.attn.qkv.weight.copy_(torch.cat([
74 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
75 | block.attn.qkv.bias.copy_(torch.cat([
76 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
77 | block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
78 | block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
79 | for r in range(2):
80 | getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel']))
81 | getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias']))
82 | block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale']))
83 | block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias']))
84 |
85 | module.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
86 | module.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
87 |
88 | if module.attn_pool is not None:
89 | block_prefix = f'{prefix}MAPHead_0/'
90 | mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/'
91 | module.attn_pool.latent.copy_(_n2p(w[f'{block_prefix}probe'], t=False))
92 | module.attn_pool.q.weight.copy_(_n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T)
93 | module.attn_pool.q.bias.copy_(_n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1))
94 | module.attn_pool.kv.weight.copy_(torch.cat([
95 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('key', 'value')]))
96 | module.attn_pool.kv.bias.copy_(torch.cat([
97 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('key', 'value')]))
98 | module.attn_pool.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
99 | module.attn_pool.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
100 | module.attn_pool.norm.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
101 | module.attn_pool.norm.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
102 | for r in range(2):
103 | getattr(module.attn_pool.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/kernel']))
104 | getattr(module.attn_pool.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/bias']))
105 |
106 | def _convert_openclip_transformer(module: Transformer, prefix):
107 | for i, block in enumerate(module.resblocks.children()):
108 | block_prefix = f'{prefix}encoderblock_{i}/'
109 | mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/'
110 | block.ln_1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
111 | block.ln_1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
112 | block.attn.in_proj_weight.copy_(torch.cat([
113 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
114 | block.attn.in_proj_bias.copy_(torch.cat([
115 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
116 | block.attn.out_proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
117 | block.attn.out_proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
118 | block.ln_2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/scale']))
119 | block.ln_2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/bias']))
120 | block.mlp.c_fc.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/kernel']))
121 | block.mlp.c_fc.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/bias']))
122 | block.mlp.c_proj.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/kernel']))
123 | block.mlp.c_proj.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/bias']))
124 |
125 | def _convert_openclip_txt(module: TextTransformer, prefix):
126 | module.token_embedding.weight.copy_(_n2p(w[f'{prefix}Embed_0/embedding'], t=False))
127 | pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False).squeeze(0)
128 | module.positional_embedding.copy_(pos_embed_w)
129 | _convert_openclip_transformer(module.transformer, prefix=prefix + 'Encoder_0/')
130 | module.ln_final.weight.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/scale']))
131 | module.ln_final.bias.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/bias']))
132 | module.text_projection.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
133 | module.text_projection.bias.copy_(_n2p(w[f'{prefix}head/bias']))
134 |
135 | _convert_timm_img(model.visual.trunk, 'params/img/')
136 | _convert_openclip_txt(model.text, 'params/txt/')
137 | model.logit_bias.copy_(_n2p(w['params/b'])[0])
138 | model.logit_scale.copy_(_n2p(w['params/t'])[0])
139 |
140 |
141 | @torch.no_grad()
142 | def convert_mobile_clip_state_dict(model: CustomTextCLIP, state_dict, fastvit = True):
143 |
144 | def _convert_timm_img(state_dict):
145 | if fastvit:
146 | from timm.models.fastvit import checkpoint_filter_fn
147 | else:
148 | from timm.models.vision_transformer_hybrid import checkpoint_filter_fn
149 | timm_state_dict = checkpoint_filter_fn(state_dict, model.visual.trunk)
150 | timm_state_dict = {'visual.trunk.' + k: v for k, v in timm_state_dict.items()}
151 | return timm_state_dict
152 |
153 | def _convert_openclip_txt(state_dict, prefix='text_encoder.'):
154 | text_dict = {}
155 | for k, v in state_dict.items():
156 | if not k.startswith(prefix):
157 | continue
158 | k = k.replace(prefix, '')
159 | k = k.replace('projection_layer', 'text_projection')
160 | k = k.replace('embedding_layer', 'token_embedding')
161 | if k.startswith('positional_embedding.pos_embed.pos_embed'):
162 | k = k.replace('positional_embedding.pos_embed.pos_embed', 'positional_embedding')
163 | v = v.squeeze()
164 | k = k.replace('final_layer_norm', 'ln_final')
165 | k = k.replace('pre_norm_mha.0', 'ln_1')
166 | k = k.replace('pre_norm_mha.1', 'attn')
167 | k = k.replace('pre_norm_ffn.0', 'ln_2')
168 | k = k.replace('pre_norm_ffn.1', 'mlp.c_fc')
169 | k = k.replace('pre_norm_ffn.4', 'mlp.c_proj')
170 | k = k.replace('qkv_proj.weight', 'in_proj_weight')
171 | k = k.replace('qkv_proj.bias', 'in_proj_bias')
172 | k = k.replace('transformer.', 'transformer.resblocks.')
173 | text_dict['text.' + k] = v
174 | return text_dict
175 |
176 | image_dict = _convert_timm_img(state_dict)
177 | text_dict = _convert_openclip_txt(state_dict)
178 | out_dict = {**image_dict, **text_dict}
179 | out_dict['logit_scale'] = state_dict['logit_scale']
180 | return out_dict
181 |
182 |
183 | def convert_state_dict(model: Union[CustomTextCLIP, CLIP], state_dict):
184 | if 'image_encoder.model.patch_embed.0.rbr_conv.0.conv.weight' in state_dict:
185 | # Apple MobileCLIP s1 & s2 state_dicts (s0 and b not currently supported)
186 | state_dict = convert_mobile_clip_state_dict(model, state_dict)
187 | if 'image_encoder.model.patch_emb.0.block.conv.weight' in state_dict:
188 | # convert b model
189 | state_dict = convert_mobile_clip_state_dict(model, state_dict, fastvit=False)
190 | return state_dict
191 |
--------------------------------------------------------------------------------
/src/open_clip/hf_configs.py:
--------------------------------------------------------------------------------
1 | # HF architecture dict:
2 | arch_dict = {
3 | # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
4 | "roberta": {
5 | "config_names": {
6 | "context_length": "max_position_embeddings",
7 | "vocab_size": "vocab_size",
8 | "width": "hidden_size",
9 | "heads": "num_attention_heads",
10 | "layers": "num_hidden_layers",
11 | "layer_attr": "layer",
12 | "token_embeddings_attr": "embeddings"
13 | },
14 | "pooler": "mean_pooler",
15 | },
16 | # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
17 | "xlm-roberta": {
18 | "config_names": {
19 | "context_length": "max_position_embeddings",
20 | "vocab_size": "vocab_size",
21 | "width": "hidden_size",
22 | "heads": "num_attention_heads",
23 | "layers": "num_hidden_layers",
24 | "layer_attr": "layer",
25 | "token_embeddings_attr": "embeddings"
26 | },
27 | "pooler": "mean_pooler",
28 | },
29 | # https://huggingface.co/docs/transformers/model_doc/mt5#mt5
30 | "mt5": {
31 | "config_names": {
32 | # unlimited seqlen
33 | # https://github.com/google-research/text-to-text-transfer-transformer/issues/273
34 | # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
35 | "context_length": "",
36 | "vocab_size": "vocab_size",
37 | "width": "d_model",
38 | "heads": "num_heads",
39 | "layers": "num_layers",
40 | "layer_attr": "block",
41 | "token_embeddings_attr": "embed_tokens"
42 | },
43 | "pooler": "mean_pooler",
44 | },
45 | # https://huggingface.co/docs/transformers/model_doc/bert
46 | "bert": {
47 | "config_names": {
48 | "context_length": "max_position_embeddings",
49 | "vocab_size": "vocab_size",
50 | "width": "hidden_size",
51 | "heads": "num_attention_heads",
52 | "layers": "num_hidden_layers",
53 | },
54 | "pooler": "cls_pooler",
55 | },
56 | # https://huggingface.co/docs/transformers/model_doc/m2m_100
57 | "m2m_100": {
58 | "config_names": {
59 | "context_length": "max_position_embeddings",
60 | "vocab_size": "vocab_size",
61 | "width": "d_model",
62 | "heads": "encoder_attention_heads",
63 | "layers": "encoder_layers",
64 | },
65 | "pooler": "cls_pooler",
66 | },
67 | }
68 |
--------------------------------------------------------------------------------
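`arch_dict` translates OpenCLIP's generic text-tower keys into per-architecture attribute names. A small sketch of how such a mapping resolves against a `transformers` config (`hf_model.py` does this internally; the RoBERTa values are shown purely for illustration):
```python
from transformers import AutoConfig
from open_clip.hf_configs import arch_dict

config = AutoConfig.from_pretrained('roberta-base')
names = arch_dict[config.model_type]['config_names']

# Map generic keys ('width', 'heads', 'layers') to RoBERTa attribute names.
width = getattr(config, names['width'])    # hidden_size -> 768
heads = getattr(config, names['heads'])    # num_attention_heads -> 12
layers = getattr(config, names['layers'])  # num_hidden_layers -> 12
```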
/src/open_clip/hf_model.py:
--------------------------------------------------------------------------------
1 | """ huggingface model adapter
2 |
3 | Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
4 | """
5 | import re
6 |
7 | import torch
8 | import torch.nn as nn
9 | from torch import TensorType
10 |
11 | try:
12 | import transformers
13 | from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig
14 | from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
15 | BaseModelOutputWithPoolingAndCrossAttentions
16 | except ImportError as e:
17 | transformers = None
18 |
19 |
20 | class BaseModelOutput:
21 | pass
22 |
23 |
24 | class PretrainedConfig:
25 | pass
26 |
27 | from .hf_configs import arch_dict
28 |
29 |
30 | # utils
31 | def _camel2snake(s):
32 | return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
--------------------------------------------------------------------------------
/src/open_clip/modified_resnet.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 |
7 | from .utils import freeze_batch_norm_2d
8 |
9 |
10 | class Bottleneck(nn.Module):
11 | expansion = 4
12 |
13 | def __init__(self, inplanes, planes, stride=1):
14 | super().__init__()
15 |
16 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
18 | self.bn1 = nn.BatchNorm2d(planes)
19 | self.act1 = nn.ReLU(inplace=True)
20 |
21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
22 | self.bn2 = nn.BatchNorm2d(planes)
23 | self.act2 = nn.ReLU(inplace=True)
24 |
25 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
26 |
27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
29 | self.act3 = nn.ReLU(inplace=True)
30 |
31 | self.downsample = None
32 | self.stride = stride
33 |
34 | if stride > 1 or inplanes != planes * Bottleneck.expansion:
35 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
36 | self.downsample = nn.Sequential(OrderedDict([
37 | ("-1", nn.AvgPool2d(stride)),
38 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
39 | ("1", nn.BatchNorm2d(planes * self.expansion))
40 | ]))
41 |
42 | def forward(self, x: torch.Tensor):
43 | identity = x
44 |
45 | out = self.act1(self.bn1(self.conv1(x)))
46 | out = self.act2(self.bn2(self.conv2(out)))
47 | out = self.avgpool(out)
48 | out = self.bn3(self.conv3(out))
49 |
50 | if self.downsample is not None:
51 | identity = self.downsample(x)
52 |
53 | out += identity
54 | out = self.act3(out)
55 | return out
56 |
57 |
58 | class AttentionPool2d(nn.Module):
59 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
60 | super().__init__()
61 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
62 | self.k_proj = nn.Linear(embed_dim, embed_dim)
63 | self.q_proj = nn.Linear(embed_dim, embed_dim)
64 | self.v_proj = nn.Linear(embed_dim, embed_dim)
65 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
66 | self.num_heads = num_heads
67 |
68 | def forward(self, x):
69 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
70 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
71 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
72 | x, _ = F.multi_head_attention_forward(
73 | query=x, key=x, value=x,
74 | embed_dim_to_check=x.shape[-1],
75 | num_heads=self.num_heads,
76 | q_proj_weight=self.q_proj.weight,
77 | k_proj_weight=self.k_proj.weight,
78 | v_proj_weight=self.v_proj.weight,
79 | in_proj_weight=None,
80 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
81 | bias_k=None,
82 | bias_v=None,
83 | add_zero_attn=False,
84 | dropout_p=0.,
85 | out_proj_weight=self.c_proj.weight,
86 | out_proj_bias=self.c_proj.bias,
87 | use_separate_proj_weight=True,
88 | training=self.training,
89 | need_weights=False
90 | )
91 |
92 | return x[0]
93 |
94 |
95 | class ModifiedResNet(nn.Module):
96 | """
97 | A ResNet class that is similar to torchvision's but contains the following changes:
98 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
99 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
100 | - The final pooling layer is a QKV attention instead of an average pool
101 | """
102 |
103 | def __init__(self, layers, output_dim, heads, image_size=224, width=64):
104 | super().__init__()
105 | self.output_dim = output_dim
106 | self.image_size = image_size
107 |
108 | # the 3-layer stem
109 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
110 | self.bn1 = nn.BatchNorm2d(width // 2)
111 | self.act1 = nn.ReLU(inplace=True)
112 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
113 | self.bn2 = nn.BatchNorm2d(width // 2)
114 | self.act2 = nn.ReLU(inplace=True)
115 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
116 | self.bn3 = nn.BatchNorm2d(width)
117 | self.act3 = nn.ReLU(inplace=True)
118 | self.avgpool = nn.AvgPool2d(2)
119 |
120 | # residual layers
121 | self._inplanes = width # this is a *mutable* variable used during construction
122 | self.layer1 = self._make_layer(width, layers[0])
123 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
124 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
125 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
126 |
127 | embed_dim = width * 32 # the ResNet feature dimension
128 | self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
129 |
130 | self.init_parameters()
131 |
132 | def _make_layer(self, planes, blocks, stride=1):
133 | layers = [Bottleneck(self._inplanes, planes, stride)]
134 |
135 | self._inplanes = planes * Bottleneck.expansion
136 | for _ in range(1, blocks):
137 | layers.append(Bottleneck(self._inplanes, planes))
138 |
139 | return nn.Sequential(*layers)
140 |
141 | def init_parameters(self):
142 | if self.attnpool is not None:
143 | std = self.attnpool.c_proj.in_features ** -0.5
144 | nn.init.normal_(self.attnpool.q_proj.weight, std=std)
145 | nn.init.normal_(self.attnpool.k_proj.weight, std=std)
146 | nn.init.normal_(self.attnpool.v_proj.weight, std=std)
147 | nn.init.normal_(self.attnpool.c_proj.weight, std=std)
148 |
149 | for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
150 | for name, param in resnet_block.named_parameters():
151 | if name.endswith("bn3.weight"):
152 | nn.init.zeros_(param)
153 |
154 | def lock(self, unlocked_groups=0, freeze_bn_stats=False):
155 | assert unlocked_groups == 0, 'partial locking not currently supported for this model'
156 | for param in self.parameters():
157 | param.requires_grad = False
158 | if freeze_bn_stats:
159 | freeze_batch_norm_2d(self)
160 |
161 | @torch.jit.ignore
162 | def set_grad_checkpointing(self, enable=True):
163 | # FIXME support for non-transformer
164 | pass
165 |
166 | def stem(self, x):
167 | x = self.act1(self.bn1(self.conv1(x)))
168 | x = self.act2(self.bn2(self.conv2(x)))
169 | x = self.act3(self.bn3(self.conv3(x)))
170 | x = self.avgpool(x)
171 | return x
172 |
173 | def forward(self, x):
174 | x = self.stem(x)
175 | x = self.layer1(x)
176 | x = self.layer2(x)
177 | x = self.layer3(x)
178 | x = self.layer4(x)
179 | x = self.attnpool(x)
180 |
181 | return x
182 |
--------------------------------------------------------------------------------
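For reference, the RN50 vision tower corresponds roughly to the following instantiation; in practice the constructor arguments are driven by the model config in `model.py`:
```python
import torch
from open_clip.modified_resnet import ModifiedResNet

# RN50-style configuration: four bottleneck stages, stem width 64,
# attention-pooled into a 1024-dim embedding with 32 heads.
visual = ModifiedResNet(
    layers=(3, 4, 6, 3),
    output_dim=1024,
    heads=32,
    image_size=224,
    width=64,
)

with torch.no_grad():
    features = visual(torch.randn(2, 3, 224, 224))
print(features.shape)  # torch.Size([2, 1024])
```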
/src/open_clip/openai.py:
--------------------------------------------------------------------------------
1 | """ OpenAI pretrained model functions
2 |
3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4 | """
5 |
6 | import os
7 | import warnings
8 | from typing import List, Optional, Union
9 |
10 | import torch
11 |
12 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
13 | from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype
14 | from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url
15 |
16 | __all__ = ["list_openai_models", "load_openai_model"]
17 |
18 |
19 | def list_openai_models() -> List[str]:
20 | """Returns the names of available CLIP models"""
21 | return list_pretrained_models_by_tag('openai')
22 |
23 |
24 | def load_openai_model(
25 | name: str,
26 | precision: Optional[str] = None,
27 | device: Optional[Union[str, torch.device]] = None,
28 | cache_dir: Optional[str] = None,
29 | ):
30 | """Load a CLIP model
31 |
32 | Parameters
33 | ----------
34 | name : str
35 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
36 | precision: str
37 | Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'.
38 | device : Union[str, torch.device]
39 | The device to put the loaded model
40 | cache_dir : Optional[str]
41 | The directory to cache the downloaded model weights
42 |
43 | Returns
44 | -------
45 | model : torch.nn.Module
46 | The CLIP model
47 | preprocess : Callable[[PIL.Image], torch.Tensor]
48 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
49 | """
50 | if device is None:
51 | device = "cuda" if torch.cuda.is_available() else "cpu"
52 | if precision is None:
53 | precision = 'fp32' if device == 'cpu' else 'fp16'
54 |
55 | if get_pretrained_url(name, 'openai'):
56 | model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir)
57 | elif os.path.isfile(name):
58 | model_path = name
59 | else:
60 | raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}")
61 |
62 | try:
63 | # loading JIT archive
64 | model = torch.jit.load(model_path, map_location="cpu").eval()
65 | state_dict = None
66 | except RuntimeError:
67 | # loading saved state dict
68 | state_dict = torch.load(model_path, map_location="cpu")
69 |
70 | # Build a non-jit model from the OpenAI jitted model state dict
71 | cast_dtype = get_cast_dtype(precision)
72 | try:
73 | model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype)
74 | except KeyError:
75 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
76 | model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype)
77 |
78 | # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use
79 | model = model.to(device)
80 | # FIXME support pure fp16/bf16 precision modes
81 | if precision != 'fp16':
82 | model.float()
83 | if precision == 'bf16':
84 | # for bf16, convert back to low-precision
85 | convert_weights_to_lp(model, dtype=torch.bfloat16)
86 |
87 | # add mean / std attributes for consistency with OpenCLIP models
88 | model.visual.image_mean = OPENAI_DATASET_MEAN
89 | model.visual.image_std = OPENAI_DATASET_STD
90 | return model
91 |
--------------------------------------------------------------------------------
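A short usage sketch; model names are the OpenCLIP-style tags returned by `list_openai_models()` (e.g. 'ViT-B-32'), and the checkpoint is downloaded to the cache on first use:
```python
import torch
from open_clip.openai import list_openai_models, load_openai_model

print(list_openai_models())  # ['RN50', 'RN101', ..., 'ViT-B-32', 'ViT-B-16', ...]

# Rebuild the OpenAI JIT checkpoint as a regular nn.Module in fp32 on CPU.
model = load_openai_model('ViT-B-32', precision='fp32', device='cpu')
assert isinstance(model, torch.nn.Module)
```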
/src/open_clip/pos_embed.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # Position embedding utils
8 | # --------------------------------------------------------
9 |
10 | import numpy as np
11 |
12 | import torch
13 |
14 | # --------------------------------------------------------
15 | # 2D sine-cosine position embedding
16 | # References:
17 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
18 | # MoCo v3: https://github.com/facebookresearch/moco-v3
19 | # --------------------------------------------------------
20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
21 | """
22 | grid_size: int of the grid height and width
23 | return:
24 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
25 | """
26 | grid_h = np.arange(grid_size, dtype=np.float32)
27 | grid_w = np.arange(grid_size, dtype=np.float32)
28 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
29 | grid = np.stack(grid, axis=0)
30 |
31 | grid = grid.reshape([2, 1, grid_size, grid_size])
32 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
33 | if cls_token:
34 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
35 | return pos_embed
36 |
37 |
38 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
39 | assert embed_dim % 2 == 0
40 |
41 | # use half of dimensions to encode grid_h
42 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
43 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
44 |
45 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
46 | return emb
47 |
48 |
49 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
50 | """
51 | embed_dim: output dimension for each position
52 | pos: a list of positions to be encoded: size (M,)
53 | out: (M, D)
54 | """
55 | assert embed_dim % 2 == 0
56 | omega = np.arange(embed_dim // 2, dtype=float)
57 | omega /= embed_dim / 2.
58 | omega = 1. / 10000**omega # (D/2,)
59 |
60 | pos = pos.reshape(-1) # (M,)
61 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
62 |
63 | emb_sin = np.sin(out) # (M, D/2)
64 | emb_cos = np.cos(out) # (M, D/2)
65 |
66 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
67 | return emb
68 |
69 |
70 | # --------------------------------------------------------
71 | # Interpolate position embeddings for high-resolution
72 | # References:
73 | # DeiT: https://github.com/facebookresearch/deit
74 | # --------------------------------------------------------
75 | def interpolate_pos_embed(model, checkpoint_model):
76 | if 'pos_embed' in checkpoint_model:
77 | pos_embed_checkpoint = checkpoint_model['pos_embed']
78 | embedding_size = pos_embed_checkpoint.shape[-1]
79 | num_patches = model.patch_embed.num_patches
80 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches
81 | # height (== width) for the checkpoint position embedding
82 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
83 | # height (== width) for the new position embedding
84 | new_size = int(num_patches ** 0.5)
85 | # class_token and dist_token are kept unchanged
86 | if orig_size != new_size:
87 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
88 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
89 | # only the position tokens are interpolated
90 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
91 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
92 | pos_tokens = torch.nn.functional.interpolate(
93 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
94 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
95 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
96 | checkpoint_model['pos_embed'] = new_pos_embed
97 |
--------------------------------------------------------------------------------
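A quick shape check for the helper above; a ViT-B/16 at 224 px resolution has a 14x14 patch grid:
```python
from open_clip.pos_embed import get_2d_sincos_pos_embed

# 14x14 grid plus a class token -> (1 + 196, 768) fixed sin-cos embedding.
pos_embed = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14, cls_token=True)
print(pos_embed.shape)  # (197, 768)
```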
/src/open_clip/push_to_hf_hub.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | from pathlib import Path
5 | from tempfile import TemporaryDirectory
6 | from typing import Optional, Tuple, Union
7 |
8 | import torch
9 |
10 | try:
11 | from huggingface_hub import (
12 | create_repo,
13 | get_hf_file_metadata,
14 | hf_hub_download,
15 | hf_hub_url,
16 | repo_type_and_id_from_hf_id,
17 | upload_folder,
18 | list_repo_files,
19 | )
20 | from huggingface_hub.utils import EntryNotFoundError
21 | _has_hf_hub = True
22 | except ImportError:
23 | _has_hf_hub = False
24 |
25 | try:
26 | import safetensors.torch
27 | _has_safetensors = True
28 | except ImportError:
29 | _has_safetensors = False
30 |
31 | from .factory import create_model_from_pretrained, get_model_config, get_tokenizer
32 | from .tokenizer import HFTokenizer
33 |
34 | # Default name for a weights file hosted on the Huggingface Hub.
35 | HF_WEIGHTS_NAME = "open_clip_pytorch_model.bin" # default pytorch pkl
36 | HF_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors" # safetensors version
37 | HF_CONFIG_NAME = 'open_clip_config.json'
38 |
39 |
40 | def save_config_for_hf(
41 | model,
42 | config_path: str,
43 | model_config: Optional[dict]
44 | ):
45 | preprocess_cfg = {
46 | 'mean': model.visual.image_mean,
47 | 'std': model.visual.image_std,
48 | }
49 | other_pp = getattr(model.visual, 'preprocess_cfg', {})
50 | if 'interpolation' in other_pp:
51 | preprocess_cfg['interpolation'] = other_pp['interpolation']
52 | if 'resize_mode' in other_pp:
53 | preprocess_cfg['resize_mode'] = other_pp['resize_mode']
54 | hf_config = {
55 | 'model_cfg': model_config,
56 | 'preprocess_cfg': preprocess_cfg,
57 | }
58 |
59 | with config_path.open('w') as f:
60 | json.dump(hf_config, f, indent=2)
61 |
62 |
63 | def save_for_hf(
64 | model,
65 | tokenizer: HFTokenizer,
66 | model_config: dict,
67 | save_directory: str,
68 | safe_serialization: Union[bool, str] = 'both',
69 | skip_weights : bool = False,
70 | ):
71 | config_filename = HF_CONFIG_NAME
72 |
73 | save_directory = Path(save_directory)
74 | save_directory.mkdir(exist_ok=True, parents=True)
75 |
76 | if not skip_weights:
77 | tensors = model.state_dict()
78 | if safe_serialization is True or safe_serialization == "both":
79 | assert _has_safetensors, "`pip install safetensors` to use .safetensors"
80 | safetensors.torch.save_file(tensors, save_directory / HF_SAFE_WEIGHTS_NAME)
81 | if safe_serialization is False or safe_serialization == "both":
82 | torch.save(tensors, save_directory / HF_WEIGHTS_NAME)
83 |
84 | tokenizer.save_pretrained(save_directory)
85 |
86 | config_path = save_directory / config_filename
87 | save_config_for_hf(model, config_path, model_config=model_config)
88 |
89 |
90 | def push_to_hf_hub(
91 | model,
92 | tokenizer,
93 | model_config: Optional[dict],
94 | repo_id: str,
95 | commit_message: str = 'Add model',
96 | token: Optional[str] = None,
97 | revision: Optional[str] = None,
98 | private: bool = False,
99 | create_pr: bool = False,
100 | model_card: Optional[dict] = None,
101 | safe_serialization: Union[bool, str] = 'both',
102 | ):
103 | if not isinstance(tokenizer, HFTokenizer):
104 | # FIXME this makes it awkward to push models with new tokenizers, come up with better soln.
105 | # default CLIP tokenizers use https://huggingface.co/openai/clip-vit-large-patch14
106 | tokenizer = HFTokenizer('openai/clip-vit-large-patch14')
107 |
108 | # Create repo if it doesn't exist yet
109 | repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)
110 |
111 | # Infer complete repo_id from repo_url
112 | # Can be different from the input `repo_id` if repo_owner was implicit
113 | _, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url)
114 | repo_id = f"{repo_owner}/{repo_name}"
115 |
116 | # Check if repo already exists and determine what needs updating
117 | repo_exists = False
118 | repo_files = {}
119 | try:
120 | repo_files = set(list_repo_files(repo_id))
121 | repo_exists = True
122 | except Exception as e:
123 | print('Repo does not exist', e)
124 |
125 | try:
126 | get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision))
127 | has_readme = True
128 | except EntryNotFoundError:
129 | has_readme = False
130 |
131 | # Dump model and push to Hub
132 | with TemporaryDirectory() as tmpdir:
133 | # Save model weights and config.
134 | save_for_hf(
135 | model,
136 | tokenizer=tokenizer,
137 | model_config=model_config,
138 | save_directory=tmpdir,
139 | safe_serialization=safe_serialization,
140 | )
141 |
142 | # Add readme if it does not exist
143 | if not has_readme:
144 | model_card = model_card or {}
145 | model_name = repo_id.split('/')[-1]
146 | readme_path = Path(tmpdir) / "README.md"
147 | readme_text = generate_readme(model_card, model_name)
148 | readme_path.write_text(readme_text)
149 |
150 | # Upload model and return
151 | return upload_folder(
152 | repo_id=repo_id,
153 | folder_path=tmpdir,
154 | revision=revision,
155 | create_pr=create_pr,
156 | commit_message=commit_message,
157 | )
158 |
159 |
160 | def push_pretrained_to_hf_hub(
161 | model_name,
162 | pretrained: str,
163 | repo_id: str,
164 | precision: str = 'fp32',
165 | image_mean: Optional[Tuple[float, ...]] = None,
166 | image_std: Optional[Tuple[float, ...]] = None,
167 | image_interpolation: Optional[str] = None,
168 | image_resize_mode: Optional[str] = None, # only effective for inference
169 | commit_message: str = 'Add model',
170 | token: Optional[str] = None,
171 | revision: Optional[str] = None,
172 | private: bool = False,
173 | create_pr: bool = False,
174 | model_card: Optional[dict] = None,
175 | hf_tokenizer_self: bool = False,
176 | **kwargs,
177 | ):
178 | model, preprocess_eval = create_model_from_pretrained(
179 | model_name,
180 | pretrained=pretrained,
181 | precision=precision,
182 | image_mean=image_mean,
183 | image_std=image_std,
184 | image_interpolation=image_interpolation,
185 | image_resize_mode=image_resize_mode,
186 | **kwargs,
187 | )
188 | model_config = get_model_config(model_name)
189 | if pretrained == 'openai':
190 | model_config['quick_gelu'] = True
191 | assert model_config
192 |
193 | tokenizer = get_tokenizer(model_name)
194 | if hf_tokenizer_self:
195 | # make hf tokenizer config in the uploaded model point to self instead of original location
196 | model_config['text']['hf_tokenizer_name'] = repo_id
197 |
198 | push_to_hf_hub(
199 | model=model,
200 | tokenizer=tokenizer,
201 | model_config=model_config,
202 | repo_id=repo_id,
203 | commit_message=commit_message,
204 | token=token,
205 | revision=revision,
206 | private=private,
207 | create_pr=create_pr,
208 | model_card=model_card,
209 | safe_serialization='both',
210 | )
211 |
212 |
213 | def generate_readme(model_card: dict, model_name: str):
214 | tags = model_card.pop('tags', ('clip',))
215 | pipeline_tag = model_card.pop('pipeline_tag', 'zero-shot-image-classification')
216 | readme_text = "---\n"
217 | if tags:
218 | readme_text += "tags:\n"
219 | for t in tags:
220 | readme_text += f"- {t}\n"
221 | readme_text += "library_name: open_clip\n"
222 | readme_text += f"pipeline_tag: {pipeline_tag}\n"
223 | readme_text += f"license: {model_card.get('license', 'mit')}\n"
224 | if 'details' in model_card and 'Dataset' in model_card['details']:
225 | readme_text += 'datasets:\n'
226 | readme_text += f"- {model_card['details']['Dataset'].lower()}\n"
227 | readme_text += "---\n"
228 | readme_text += f"# Model card for {model_name}\n"
229 | if 'description' in model_card:
230 | readme_text += f"\n{model_card['description']}\n"
231 | if 'details' in model_card:
232 | readme_text += f"\n## Model Details\n"
233 | for k, v in model_card['details'].items():
234 | if isinstance(v, (list, tuple)):
235 | readme_text += f"- **{k}:**\n"
236 | for vi in v:
237 | readme_text += f" - {vi}\n"
238 | elif isinstance(v, dict):
239 | readme_text += f"- **{k}:**\n"
240 | for ki, vi in v.items():
241 | readme_text += f" - {ki}: {vi}\n"
242 | else:
243 | readme_text += f"- **{k}:** {v}\n"
244 | if 'usage' in model_card:
245 | readme_text += f"\n## Model Usage\n"
246 | readme_text += model_card['usage']
247 | readme_text += '\n'
248 |
249 | if 'comparison' in model_card:
250 | readme_text += f"\n## Model Comparison\n"
251 | readme_text += model_card['comparison']
252 | readme_text += '\n'
253 |
254 | if 'citation' in model_card:
255 | readme_text += f"\n## Citation\n"
256 | if not isinstance(model_card['citation'], (list, tuple)):
257 | citations = [model_card['citation']]
258 | else:
259 | citations = model_card['citation']
260 | for c in citations:
261 | readme_text += f"```bibtex\n{c}\n```\n"
262 |
263 | return readme_text
264 |
265 |
266 | if __name__ == "__main__":
267 | parser = argparse.ArgumentParser(description="Push to Hugging Face Hub")
268 | parser.add_argument(
269 | "--model", type=str, help="Name of the model to use.",
270 | )
271 | parser.add_argument(
272 | "--pretrained", type=str,
273 | help="Use a pretrained CLIP model weights with the specified tag or file path.",
274 | )
275 | parser.add_argument(
276 | "--repo-id", type=str,
277 | help="Destination HF Hub repo-id ie 'organization/model_id'.",
278 | )
279 | parser.add_argument(
280 | "--precision", type=str, default='fp32',
281 | )
282 | parser.add_argument(
283 | '--image-mean', type=float, nargs='+', default=None, metavar='MEAN',
284 | help='Override default image mean value of dataset')
285 | parser.add_argument(
286 | '--image-std', type=float, nargs='+', default=None, metavar='STD',
287 | help='Override default image std deviation of dataset')
288 | parser.add_argument(
289 | '--image-interpolation',
290 | default=None, type=str, choices=['bicubic', 'bilinear', 'random'],
291 | help="image resize interpolation"
292 | )
293 | parser.add_argument(
294 | '--image-resize-mode',
295 | default=None, type=str, choices=['shortest', 'longest', 'squash'],
296 | help="image resize mode during inference"
297 | )
298 | parser.add_argument(
299 | "--hf-tokenizer-self",
300 | default=False,
301 | action="store_true",
302 | help="make hf_tokenizer_name point in uploaded config point to itself"
303 | )
304 | args = parser.parse_args()
305 |
306 | print(f'Saving model {args.model} with pretrained weights {args.pretrained} to Hugging Face Hub at {args.repo_id}')
307 |
308 | # FIXME add support to pass model_card json / template from file via cmd line
309 |
310 | push_pretrained_to_hf_hub(
311 | args.model,
312 | args.pretrained,
313 | args.repo_id,
314 | precision=args.precision,
315 | image_mean=args.image_mean, # override image mean/std if trained w/ non defaults
316 | image_std=args.image_std,
317 | image_interpolation=args.image_interpolation,
318 | image_resize_mode=args.image_resize_mode,
319 | )
320 |
321 | print(f'{args.model} saved.')
322 |
--------------------------------------------------------------------------------
/src/open_clip/timm_model.py:
--------------------------------------------------------------------------------
1 | """ timm model adapter
2 |
3 | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
4 | """
5 | import logging
6 | from collections import OrderedDict
7 |
8 | import torch
9 | import torch.nn as nn
10 |
11 | try:
12 | import timm
13 | from timm.models.layers import Mlp, to_2tuple
14 | try:
15 | # old timm imports < 0.8.1
16 | from timm.models.layers.attention_pool2d import RotAttentionPool2d
17 | from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d
18 | except ImportError:
19 | # new timm imports >= 0.8.1
20 | from timm.layers import RotAttentionPool2d
21 | from timm.layers import AttentionPool2d as AbsAttentionPool2d
22 | except ImportError:
23 | timm = None
24 |
25 | from .utils import freeze_batch_norm_2d
26 |
27 |
28 | class TimmModel(nn.Module):
29 | """ timm model adapter
30 | """
31 |
32 | def __init__(
33 | self,
34 | model_name,
35 | embed_dim,
36 | image_size=224,
37 | pool='avg',
38 | proj='linear',
39 | proj_bias=False,
40 | drop=0.,
41 | drop_path=None,
42 | patch_drop=None,
43 | pretrained=False,
44 | ):
45 | super().__init__()
46 | if timm is None:
47 | raise RuntimeError("Please `pip install timm` to use timm models.")
48 | self.image_size = to_2tuple(image_size)
49 |
50 | # setup kwargs that may not be common across all models
51 | timm_kwargs = {}
52 | if drop_path is not None:
53 | timm_kwargs['drop_path_rate'] = drop_path
54 | if patch_drop is not None:
55 | timm_kwargs['patch_drop_rate'] = patch_drop
56 |
57 | custom_pool = pool in ('abs_attn', 'rot_attn')
58 | if proj:
59 | assert proj in ("linear", "mlp", "none")
60 | extra_proj = proj in ("linear", "mlp")
61 | if not extra_proj and not custom_pool:
62 | # use network classifier head as projection if no proj specified and no custom pooling used
63 | # if projection is explicitly set to "none", features are passed through from the network trunk
64 | proj_dim = 0 if proj == 'none' else embed_dim
65 | self.trunk = timm.create_model(
66 | model_name,
67 | num_classes=proj_dim,
68 | global_pool=pool,
69 | pretrained=pretrained,
70 | **timm_kwargs,
71 | )
72 | prev_chs = embed_dim
73 | else:
74 | self.trunk = timm.create_model(
75 | model_name,
76 | pretrained=pretrained,
77 | **timm_kwargs,
78 | )
79 | feat_size = self.trunk.default_cfg.get('pool_size', None)
80 | feature_ndim = 1 if not feat_size else 2
81 | if custom_pool:
82 | assert feature_ndim == 2
83 | # if attn pooling used, remove both classifier and default pool
84 | self.trunk.reset_classifier(0, global_pool='')
85 | else:
86 | # reset global pool if pool config set, otherwise leave as network default
87 | reset_kwargs = dict(global_pool=pool) if pool else {}
88 | self.trunk.reset_classifier(0, **reset_kwargs)
89 | prev_chs = self.trunk.num_features
90 |
91 | head_layers = OrderedDict()
92 |
93 | # Add custom pooling to head
94 | if pool == 'abs_attn':
95 | head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)
96 | prev_chs = embed_dim
97 | elif pool == 'rot_attn':
98 | head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
99 | prev_chs = embed_dim
100 |
101 | # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
102 | if proj == 'linear':
103 | head_layers['drop'] = nn.Dropout(drop)
104 | head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
105 | elif proj == 'mlp':
106 | head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias))
107 |
108 | self.head = nn.Sequential(head_layers)
109 |
110 | def lock(self, unlocked_groups=0, freeze_bn_stats=False):
111 | """ lock modules
112 | Args:
113 | unlocked_groups (int): leave last n layer groups unlocked (default: 0)
114 | """
115 | if not unlocked_groups:
116 | # lock full model
117 | for param in self.trunk.parameters():
118 | param.requires_grad = False
119 | if freeze_bn_stats:
120 | freeze_batch_norm_2d(self.trunk)
121 | else:
122 | # NOTE: partial freeze requires latest timm (master) branch and is subject to change
123 | try:
124 | # FIXME import here until API stable and in an official release
125 | from timm.models.helpers import group_parameters, group_modules
126 | except ImportError:
127 | raise RuntimeError(
128 | 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`')
129 | matcher = self.trunk.group_matcher()
130 | gparams = group_parameters(self.trunk, matcher)
131 | max_layer_id = max(gparams.keys())
132 | max_layer_id = max_layer_id - unlocked_groups
133 | for group_idx in range(max_layer_id + 1):
134 | group = gparams[group_idx]
135 | for param in group:
136 | self.trunk.get_parameter(param).requires_grad = False
137 | if freeze_bn_stats:
138 | gmodules = group_modules(self.trunk, matcher, reverse=True)
139 | gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
140 | freeze_batch_norm_2d(self.trunk, gmodules)
141 |
142 | @torch.jit.ignore
143 | def set_grad_checkpointing(self, enable=True):
144 | try:
145 | self.trunk.set_grad_checkpointing(enable)
146 | except Exception as e:
147 | logging.warning('grad checkpointing not supported for this timm image tower, continuing without...')
148 |
149 | def forward(self, x):
150 | x = self.trunk(x)
151 | x = self.head(x)
152 | return x
153 |
--------------------------------------------------------------------------------
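A hedged sketch of using the adapter directly (it is normally instantiated from the `timm_model_name` field of a vision config; `timm` must be installed, and the example model name and dimensions are illustrative):
```python
import torch
from open_clip.timm_model import TimmModel

# Wrap a timm ViT-B/16 trunk and project its pooled features to a 512-dim embedding.
tower = TimmModel(
    'vit_base_patch16_224',
    embed_dim=512,
    image_size=224,
    pool='avg',
    proj='linear',
    pretrained=False,
)

with torch.no_grad():
    out = tower(torch.randn(2, 3, 224, 224))
print(out.shape)  # torch.Size([2, 512])
```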
/src/open_clip/utils.py:
--------------------------------------------------------------------------------
1 | from itertools import repeat
2 | import collections.abc
3 |
4 | import torch
5 | from torch import nn as nn
6 | from torchvision.ops.misc import FrozenBatchNorm2d
7 |
8 |
9 | def freeze_batch_norm_2d(module, module_match={}, name=''):
10 | """
11 | Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
12 | itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
13 | returned. Otherwise, the module is walked recursively and submodules are converted in place.
14 |
15 | Args:
16 | module (torch.nn.Module): Any PyTorch module.
17 | module_match (dict): Dictionary of full module names to freeze (all if empty)
18 | name (str): Full module name (prefix)
19 |
20 | Returns:
21 | torch.nn.Module: Resulting module
22 |
23 | Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
24 | """
25 | res = module
26 | is_match = True
27 | if module_match:
28 | is_match = name in module_match
29 | if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
30 | res = FrozenBatchNorm2d(module.num_features)
31 | res.num_features = module.num_features
32 | res.affine = module.affine
33 | if module.affine:
34 | res.weight.data = module.weight.data.clone().detach()
35 | res.bias.data = module.bias.data.clone().detach()
36 | res.running_mean.data = module.running_mean.data
37 | res.running_var.data = module.running_var.data
38 | res.eps = module.eps
39 | else:
40 | for child_name, child in module.named_children():
41 | full_child_name = '.'.join([name, child_name]) if name else child_name
42 | new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
43 | if new_child is not child:
44 | res.add_module(child_name, new_child)
45 | return res
46 |
47 |
48 | # From PyTorch internals
49 | def _ntuple(n):
50 | def parse(x):
51 | if isinstance(x, collections.abc.Iterable):
52 | return x
53 | return tuple(repeat(x, n))
54 | return parse
55 |
56 |
57 | to_1tuple = _ntuple(1)
58 | to_2tuple = _ntuple(2)
59 | to_3tuple = _ntuple(3)
60 | to_4tuple = _ntuple(4)
61 | to_ntuple = lambda n, x: _ntuple(n)(x)
62 |
63 | # Replaces all linear layers with linear_replacement
64 | # TODO: add int8 support for other linear layers including attn and convnets
65 | def replace_linear(model, linear_replacement, include_modules=['c_fc', 'c_proj'], copy_weights=True):
66 | for name, module in model.named_children():
67 | if len(list(module.children())) > 0:
68 | replace_linear(module, linear_replacement, include_modules, copy_weights)
69 |
70 | if isinstance(module, torch.nn.Linear) and name in include_modules:
71 | old_module = model._modules[name]
72 | model._modules[name] = linear_replacement(
73 | module.in_features,
74 | module.out_features,
75 | module.bias is not None,
76 | )
77 | if copy_weights:
78 | model._modules[name].weight.data.copy_(old_module.weight.data)
79 | if model._modules[name].bias is not None:
80 | model._modules[name].bias.data.copy_(old_module.bias)
81 |
82 | return model
83 |
84 | def convert_int8_model_to_inference_mode(model):
85 | for m in model.modules():
86 | if hasattr(m, 'prepare_for_eval'):
87 | int8_original_dtype = m.weight.dtype
88 | m.prepare_for_eval()
89 | m.int8_original_dtype = int8_original_dtype
90 |
--------------------------------------------------------------------------------
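Two small demonstrations of the helpers above, using a torchvision ResNet-18 purely as a stand-in module with BatchNorm layers:
```python
from torchvision.models import resnet18
from open_clip.utils import freeze_batch_norm_2d, to_2tuple

print(to_2tuple(224))         # (224, 224)
print(to_2tuple((224, 336)))  # (224, 336) -- iterables pass through unchanged

# Swap every BatchNorm2d in the network for a FrozenBatchNorm2d (stats and affine
# parameters are copied, but no longer updated during training).
model = freeze_batch_norm_2d(resnet18())
print(type(model.bn1).__name__)  # FrozenBatchNorm2d
```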
/src/open_clip/version.py:
--------------------------------------------------------------------------------
1 | __version__ = '2.24.0'
2 |
--------------------------------------------------------------------------------
/src/open_clip/zero_shot_classifier.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from itertools import islice
3 | from typing import Callable, List, Optional, Sequence, Union
4 |
5 | import torch
6 | import torch.nn.functional as F
7 |
8 |
9 | def batched(iterable, n):
10 | """Batch data into lists of length *n*. The last batch may be shorter.
11 | NOTE based on more-itertools impl, to be replaced by python 3.12 itertools.batched impl
12 | """
13 | it = iter(iterable)
14 | while True:
15 | batch = list(islice(it, n))
16 | if not batch:
17 | break
18 | yield batch
19 |
20 |
21 | def build_zero_shot_classifier(
22 | model,
23 | tokenizer,
24 | args,
25 | classnames: Sequence[str],
26 | templates: Sequence[Union[Callable, str]],
27 | num_classes_per_batch: Optional[int] = 10,
28 | device: Union[str, torch.device] = 'cpu',
29 | use_tqdm: bool = False,
30 | ):
31 | """ Build zero-shot classifier weights by iterating over class names in batches
32 | Args:
33 | model: CLIP model instance
34 | tokenizer: CLIP tokenizer instance
35 | classnames: A sequence of class (label) names
36 | templates: A sequence of callables or format() friendly strings to produce templates per class name
37 | num_classes_per_batch: The number of classes to batch together in each forward, all if None
38 | device: Device to use.
39 | use_tqdm: Enable TQDM progress bar.
40 | """
41 | assert isinstance(templates, Sequence) and len(templates) > 0
42 | assert isinstance(classnames, Sequence) and len(classnames) > 0
43 | use_format = isinstance(templates[0], str)
44 | num_templates = len(templates)
45 | num_classes = len(classnames)
46 | if use_tqdm:
47 | import tqdm
48 | num_iter = 1 if num_classes_per_batch is None else ((num_classes - 1) // num_classes_per_batch + 1)
49 | iter_wrap = partial(tqdm.tqdm, total=num_iter, unit_scale=num_classes_per_batch)
50 | else:
51 | iter_wrap = iter
52 |
53 | def _process_batch(batch_classnames):
54 | num_batch_classes = len(batch_classnames)
55 | texts = [template.format(c) if use_format else template(c) for c in batch_classnames for template in templates]
56 | texts = tokenizer(texts).to(device)
57 | class_embeddings = model.encode_text(texts, normalize=True)
58 | if isinstance(class_embeddings, dict):
59 | class_embeddings = class_embeddings['text_features']
60 |
61 | class_embeddings = class_embeddings.reshape(num_batch_classes, num_templates, -1).mean(dim=1)
62 | class_embeddings = class_embeddings / class_embeddings.norm(dim=1, keepdim=True)
63 | class_embeddings = class_embeddings.T # (512, num_batch_classes)
64 | return class_embeddings
65 |
66 | with torch.no_grad():
67 | if num_classes_per_batch:
68 | batched_embeds = [_process_batch(batch) for batch in iter_wrap(batched(classnames, num_classes_per_batch))]
69 | zeroshot_weights = torch.cat(batched_embeds, dim=-1)
70 | else:
71 | zeroshot_weights = _process_batch(classnames)
72 | return zeroshot_weights
73 |
74 |
75 | def build_zero_shot_classifier_legacy(
76 | model,
77 | tokenizer,
78 | classnames: Sequence[str],
79 | templates: Sequence[Union[Callable, str]],
80 | device: Union[str, torch.device] = 'cpu',
81 | use_tqdm: bool = False,
82 | ):
83 | """ Build zero-shot classifier weights by iterating over class names 1 by 1
84 | Args:
85 | model: CLIP model instance
86 | tokenizer: CLIP tokenizer instance
87 | classnames: A sequence of class (label) names
88 | templates: A sequence of callables or format() friendly strings to produce templates per class name
89 | device: Device to use.
90 | use_tqdm: Enable TQDM progress bar.
91 | """
92 | assert isinstance(templates, Sequence) and len(templates) > 0
93 | assert isinstance(classnames, Sequence) and len(classnames) > 0
94 | if use_tqdm:
95 | import tqdm
96 | iter_wrap = tqdm.tqdm
97 | else:
98 | iter_wrap = iter
99 |
100 | use_format = isinstance(templates[0], str)
101 |
102 | with torch.no_grad():
103 | zeroshot_weights = []
104 | for classname in iter_wrap(classnames):
105 | texts = [template.format(classname) if use_format else template(classname) for template in templates]
106 | texts = tokenizer(texts).to(device) # tokenize
107 | class_embeddings = model.encode_text(texts)
108 | class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0)
109 | class_embedding /= class_embedding.norm()
110 | zeroshot_weights.append(class_embedding)
111 | zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
112 |
113 | return zeroshot_weights
114 |
115 |
--------------------------------------------------------------------------------
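A hedged end-to-end sketch using the legacy builder (the non-legacy `build_zero_shot_classifier` additionally expects the parsed `args` namespace used elsewhere in this repository). The class names and prompt templates here are illustrative:
```python
import torch
import torch.nn.functional as F
from open_clip import create_model_and_transforms, get_tokenizer
from open_clip.zero_shot_classifier import build_zero_shot_classifier_legacy

model, _, preprocess_val = create_model_and_transforms('ViT-B-16', pretrained='laion400m_e32')
tokenizer = get_tokenizer('ViT-B-16')
model.eval()

# One weight column per class, averaged over the prompt templates.
classifier = build_zero_shot_classifier_legacy(
    model, tokenizer,
    classnames=['dog', 'cat', 'car'],
    templates=[lambda c: f'a photo of a {c}.', lambda c: f'a blurry photo of a {c}.'],
    device='cpu',
)
print(classifier.shape)  # torch.Size([512, 3]) for ViT-B-16

with torch.no_grad():
    image_features = F.normalize(model.encode_image(torch.randn(2, 3, 224, 224)), dim=-1)
    logits = 100.0 * image_features @ classifier  # (2, 3) zero-shot logits
```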
/src/seg_eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 | import training.clip_segmentor
5 | import training.custom_datasets
6 | from training.params import parse_args
7 |
8 | from mmengine.config import Config
9 | from mmengine.runner import Runner
10 |
11 |
12 | def trigger_visualization_hook(cfg, args):
13 | default_hooks = cfg.default_hooks
14 | if 'visualization' in default_hooks:
15 | visualization_hook = default_hooks['visualization']
16 | # Turn on visualization
17 | visualization_hook['draw'] = True
18 | if args.show:
19 | visualization_hook['show'] = True
20 | visualization_hook['wait_time'] = args.wait_time
21 | if args.show_dir:
22 | visualizer = cfg.visualizer
23 | visualizer['save_dir'] = args.show_dir
24 | else:
25 | raise RuntimeError(
26 | 'VisualizationHook must be included in default_hooks.'
27 | 'refer to usage '
28 | '"visualization=dict(type=\'VisualizationHook\')"')
29 |
30 | return cfg
31 |
32 |
33 | def main(args):
34 | args = parse_args(args)
35 | if args.seg_w_background: # with background
36 | conf_files = ['cfg_voc21.py', 'cfg_context60.py', 'cfg_coco_object.py']
37 | else:
38 | conf_files = ['cfg_voc20.py', 'cfg_city_scapes.py', 'cfg_context59.py', 'cfg_ade20k.py', 'cfg_coco_stuff164k.py']
39 |
40 | for conf_f in conf_files:
41 | cfg = Config.fromfile(f'./training/seg_configs/{conf_f}')
42 | cfg.launcher = 'none'
43 | cfg.work_dir = './work_logs/'
44 |
45 | openclip_args = {}
46 | for arg in vars(args):
47 | openclip_args[arg] = getattr(args, arg)
48 | cfg.model.update(openclip_args)
49 |
50 | runner = Runner.from_cfg(cfg)
51 | runner.test()
52 |
53 |
54 | if __name__ == "__main__":
55 | main(sys.argv[1:])
--------------------------------------------------------------------------------
/src/train_cc12m.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -x
2 | #SBATCH --nodes=32
3 | #SBATCH --gres=gpu:4
4 | #SBATCH --ntasks-per-node=4
5 | #SBATCH --cpus-per-task=48
6 | #SBATCH --job-name=train_cc12m
7 | #SBATCH --output=./cluster_logs/train_cc12m.txt
8 | #SBATCH --partition PARTITION_NAME
9 |
10 | source cosmos_env/bin/activate
11 | cd cosmos/src
12 |
13 | export CUDA_VISIBLE_DEVICES=0,1,2,3
14 | export MASTER_PORT=12802
15 |
16 | master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
17 | export MASTER_ADDR=$master_addr
18 |
19 | srun env -u CUDA_VISIBLE_DEVICES torchrun \
20 | --nproc_per_node=4 \
21 | --nnode=$SLURM_JOB_NUM_NODES \
22 | --rdzv_id=$SLURM_JOB_ID \
23 | --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \
24 | --rdzv_backend=c10d \
25 | -m main \
26 | --logs-dir ./logs/ \
27 | --model ViT-B-16 \
28 | --dataset-type webdataset \
29 | --lr 5e-4 \
30 | --warmup 2000 \
31 | --epochs 32 \
32 | --train-data 'datasets/cc12m_recap/cc12m-train-{0000..2175}.tar' \
33 | --train-num-samples 10010225 \
34 | --val-data 'coco' \
35 | --data-root-dir directory/to/coco/ \
36 | --batch-size 32 \
37 | --precision amp \
38 | --workers 16 \
39 | --save-frequency 1 \
40 | --log-every-n-steps 200 \
41 | --wd 0.5 \
42 | --beta1 0.9 \
43 | --beta2 0.98 \
44 | --eps 1e-8 \
45 | --use-imagecrop-aug \
46 | --global-crops-number 2 \
47 | --local-crops-number 6 \
48 | --crop-scale 0.4 \
49 | --caption-sampling-mode textcrop \
50 | --num-sampled-captions 8 \
51 | --momentum-teacher 0.99 \
52 | --fix-momentum \
53 | --output-all \
54 | --attentional-pool \
55 | --cosmos
56 |
--------------------------------------------------------------------------------
/src/train_cc3m.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -x
2 | #SBATCH --nodes=4
3 | #SBATCH --gres=gpu:4
4 | #SBATCH --ntasks-per-node=4
5 | #SBATCH --cpus-per-task=48
6 | #SBATCH --job-name=train_cc3m
7 | #SBATCH --output=./cluster_logs/train_cc3m.txt
8 | #SBATCH --partition PARTITION_NAME
9 |
10 | source cosmos_env/bin/activate
11 | cd cosmos/src
12 |
13 | export CUDA_VISIBLE_DEVICES=0,1,2,3
14 | export MASTER_PORT=12802
15 |
16 | master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
17 | export MASTER_ADDR=$master_addr
18 |
19 | srun env -u CUDA_VISIBLE_DEVICES torchrun \
20 | --nproc_per_node=4 \
21 | --nnode=$SLURM_JOB_NUM_NODES \
22 | --rdzv_id=$SLURM_JOB_ID \
23 | --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \
24 | --rdzv_backend=c10d \
25 | -m main \
26 | --logs-dir ./logs/ \
27 | --model ViT-B-16 \
28 | --dataset-type webdataset \
29 | --lr 5e-4 \
30 | --warmup 2000 \
31 | --epochs 32 \
32 | --train-data 'datasets/cc3m_recap/cc3m-train-{0000..0575}.tar' \
33 | --train-num-samples 2823019 \
34 | --val-data 'datasets/cc3m/cc3m-validation-00{00..15}.tar' \
35 | --val-num-samples 13443 \
36 | --batch-size 64 \
37 | --precision amp \
38 | --workers 16 \
39 | --save-frequency 1 \
40 | --log-every-n-steps 200 \
41 | --wd 0.5 \
42 | --beta1 0.9 \
43 | --beta2 0.98 \
44 | --eps 1e-8 \
45 | --use-imagecrop-aug \
46 | --global-crops-number 2 \
47 | --local-crops-number 6 \
48 | --crop-scale 0.4 \
49 | --caption-sampling-mode textcrop \
50 | --num-sampled-captions 8 \
51 | --momentum-teacher 0.999 \
52 | --fix-momentum \
53 | --output-all \
54 | --attentional-pool \
55 | --cosmos
56 |
--------------------------------------------------------------------------------
/src/train_merged30m.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -x
2 | #SBATCH --nodes=32
3 | #SBATCH --gres=gpu:4
4 | #SBATCH --ntasks-per-node=4
5 | #SBATCH --cpus-per-task=48
6 | #SBATCH --job-name=train_merged30m
7 | #SBATCH --output=./cluster_logs/train_merged30m.txt
8 | #SBATCH --partition PARTITION_NAME
9 |
10 | source cosmos_env/bin/activate
11 | cd cosmos/src
12 |
13 | export CUDA_VISIBLE_DEVICES=0,1,2,3
14 | export MASTER_PORT=12802
15 |
16 | master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
17 | export MASTER_ADDR=$master_addr
18 |
19 | srun env -u CUDA_VISIBLE_DEVICES torchrun \
20 | --nproc_per_node=4 \
21 | --nnode=$SLURM_JOB_NUM_NODES \
22 | --rdzv_id=$SLURM_JOB_ID \
23 | --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \
24 | --rdzv_backend=c10d \
25 | -m main \
26 | --logs-dir ./logs/ \
27 | --model ViT-B-16 \
28 | --dataset-type webdataset \
29 | --lr 5e-4 \
30 | --warmup 2000 \
31 | --epochs 32 \
32 | --train-data '/datasets/yfcc15m_recap/yfcc15m-train-{0000..3636}.tar::/datasets/cc12m_recap/cc12m-train-{0000..2175}.tar::/datasets/cc3m_recap/cc3m-train-{0001..0575}.tar' \
33 | --train-num-samples 26899071 \
34 | --val-data 'coco' \
35 | --data-root-dir directory/to/coco/ \
36 | --batch-size 32 \
37 | --precision amp \
38 | --workers 16 \
39 | --save-frequency 1 \
40 | --log-every-n-steps 200 \
41 | --wd 0.5 \
42 | --beta1 0.9 \
43 | --beta2 0.98 \
44 | --eps 1e-8 \
45 | --use-imagecrop-aug \
46 | --global-crops-number 2 \
47 | --local-crops-number 6 \
48 | --crop-scale 0.4 \
49 | --caption-sampling-mode textcrop \
50 | --num-sampled-captions 8 \
51 | --momentum-teacher 0.99 \
52 | --fix-momentum \
53 | --output-all \
54 | --attentional-pool \
55 | --cosmos
56 |
--------------------------------------------------------------------------------
/src/train_pixelprose.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -x
2 | #SBATCH --nodes=32
3 | #SBATCH --gres=gpu:4
4 | #SBATCH --ntasks-per-node=4
5 | #SBATCH --cpus-per-task=48
6 | #SBATCH --job-name=train_pixelprose
7 | #SBATCH --output=./cluster_logs/train_pixelprose.txt
8 | #SBATCH --partition PARTITION_NAME
9 |
10 | source cosmos_env/bin/activate
11 | cd cosmos/src
12 |
13 | export CUDA_VISIBLE_DEVICES=0,1,2,3
14 | export MASTER_PORT=12802
15 |
16 | master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
17 | export MASTER_ADDR=$master_addr
18 |
19 | srun env -u CUDA_VISIBLE_DEVICES torchrun \
20 | --nproc_per_node=4 \
21 | --nnode=$SLURM_JOB_NUM_NODES \
22 | --rdzv_id=$SLURM_JOB_ID \
23 | --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \
24 | --rdzv_backend=c10d \
25 | -m main \
26 | --logs-dir ./logs/ \
27 | --model ViT-B-16 \
28 | --dataset-type webdataset \
29 | --lr 5e-4 \
30 | --warmup 2000 \
31 | --epochs 32 \
32 | --train-data '/datasets/pixelprose/data/{00000..01699}.tar' \
33 | --train-num-samples 15037386 \
34 | --val-data 'coco' \
35 | --data-root-dir directory/to/coco/ \
36 | --batch-size 32 \
37 | --precision amp \
38 | --workers 16 \
39 | --save-frequency 1 \
40 | --log-every-n-steps 200 \
41 | --wd 0.5 \
42 | --beta1 0.9 \
43 | --beta2 0.98 \
44 | --eps 1e-8 \
45 | --use-imagecrop-aug \
46 | --global-crops-number 2 \
47 | --local-crops-number 6 \
48 | --crop-scale 0.4 \
49 | --caption-sampling-mode textcrop_pixelprose \
50 | --num-sampled-captions 8 \
51 | --momentum-teacher 0.99 \
52 | --fix-momentum \
53 | --output-all \
54 | --attentional-pool \
55 | --cosmos
56 |
--------------------------------------------------------------------------------
/src/train_yfcc15m.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -x
2 | #SBATCH --nodes=32
3 | #SBATCH --gres=gpu:4
4 | #SBATCH --ntasks-per-node=4
5 | #SBATCH --cpus-per-task=48
6 | #SBATCH --job-name=train_yfcc15m
7 | #SBATCH --output=./cluster_logs/train_yfcc15m.txt
8 | #SBATCH --partition PARTITION_NAME
9 |
10 | source cosmos_env/bin/activate
11 | cd cosmos/src
12 |
13 | export CUDA_VISIBLE_DEVICES=0,1,2,3
14 | export MASTER_PORT=12802
15 |
16 | master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
17 | export MASTER_ADDR=$master_addr
18 |
19 | srun env -u CUDA_VISIBLE_DEVICES torchrun \
20 | --nproc_per_node=4 \
21 | --nnode=$SLURM_JOB_NUM_NODES \
22 | --rdzv_id=$SLURM_JOB_ID \
23 | --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \
24 | --rdzv_backend=c10d \
25 | -m main \
26 | --logs-dir ./logs/ \
27 | --model ViT-B-16 \
28 | --dataset-type webdataset \
29 | --lr 5e-4 \
30 | --warmup 2000 \
31 | --epochs 32 \
32 | --train-data 'datasets/yfcc15m_recap/yfcc15m-train-{0000..3636}.tar' \
33 | --train-num-samples 14065827 \
34 | --val-data 'coco' \
35 | --data-root-dir directory/to/coco/ \
36 | --batch-size 32 \
37 | --precision amp \
38 | --workers 16 \
39 | --save-frequency 1 \
40 | --log-every-n-steps 200 \
41 | --wd 0.5 \
42 | --beta1 0.9 \
43 | --beta2 0.98 \
44 | --eps 1e-8 \
45 | --use-imagecrop-aug \
46 | --global-crops-number 2 \
47 | --local-crops-number 6 \
48 | --crop-scale 0.4 \
49 | --caption-sampling-mode textcrop \
50 | --num-sampled-captions 8 \
51 | --momentum-teacher 0.99 \
52 | --fix-momentum \
53 | --output-all \
54 | --attentional-pool \
55 | --cosmos
56 |
--------------------------------------------------------------------------------
/src/training/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/src/training/clip_segmentor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import sys
4 | import os
5 | sys.path.append("..")
6 |
7 | import open_clip
8 | from open_clip import OPENAI_IMAGENET_TEMPLATES
9 |
10 | from mmseg.models.segmentors import BaseSegmentor
11 | from mmseg.models.data_preprocessor import SegDataPreProcessor
12 | from mmengine.structures import PixelData
13 |
14 | from mmseg.registry import MODELS
15 |
16 | from training.pamr import PAMR
17 |
18 | from training.file_utils import pt_load
19 | import copy
20 |
21 | import numpy as np
22 | from huggingface_hub import hf_hub_download
23 |
24 | def download_weights_from_hf(model_repo, filename):
25 | # Define the custom cache directory relative to the current script
26 | cache_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "pretrained")
27 | if not os.path.exists(cache_dir):
28 | os.makedirs(cache_dir, exist_ok=True)
29 | local_path = hf_hub_download(repo_id=model_repo, filename=filename, cache_dir=cache_dir)
30 | return local_path
31 |
32 | def load_model(**model_kwargs):
33 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
34 | input_model_kwargs = {}
35 | if model_kwargs['siglip']:
36 | input_model_kwargs['init_logit_scale'] = np.log(10) # different from CLIP
37 | input_model_kwargs['init_logit_bias'] = -10
38 | model, _, _ = open_clip.create_model_and_transforms(
39 | model_kwargs['model'],
40 | model_kwargs['pretrained'],
41 | precision=model_kwargs['precision'],
42 | device=device,
43 | jit=model_kwargs['torchscript'],
44 | force_quick_gelu=model_kwargs['force_quick_gelu'],
45 | force_custom_text=model_kwargs['force_custom_text'],
46 | force_patch_dropout=model_kwargs['force_patch_dropout'],
47 | force_image_size=model_kwargs['force_image_size'],
48 | image_mean=model_kwargs['image_mean'],
49 | image_std=model_kwargs['image_std'],
50 | image_interpolation=model_kwargs['image_interpolation'],
51 | image_resize_mode=model_kwargs['image_resize_mode'],
52 | use_imagecrop_aug=model_kwargs['use_imagecrop_aug'],
53 | global_crops_number=model_kwargs['global_crops_number'],
54 | local_crops_number=model_kwargs['local_crops_number'],
55 | crop_scale=model_kwargs['crop_scale'],
56 | aug_cfg=model_kwargs['aug_cfg'],
57 | pretrained_image=model_kwargs['pretrained_image'],
58 | output_dict=True,
59 | output_all=model_kwargs['output_all'],
60 | pool_type=model_kwargs['pool_type'],
61 | attentional_pool=model_kwargs['attentional_pool'],
62 | add_zero_attn=model_kwargs['add_zero_attn'],
63 | cosmos=model_kwargs['cosmos'],
64 | **input_model_kwargs,)
65 |
66 | ema_model = copy.deepcopy(model)
67 |
68 | if model_kwargs['pretrained']:
69 | return model
70 |
71 | if model_kwargs['huggingface_model_name'] != '':
72 |         # based on the Hugging Face model name, download the pre-trained weights; the downloaded path is passed as the 'resume' argument
73 | huggingface_model_name, huggingface_repo_name = model_kwargs['huggingface_model_name'], model_kwargs['huggingface_repo_name']
74 | model_kwargs['resume'] = download_weights_from_hf(model_repo=huggingface_repo_name, filename=huggingface_model_name)
75 |
76 | checkpoint = pt_load(model_kwargs['resume'], map_location='cpu')
77 |
78 | # resuming a train checkpoint w/ epoch and optimizer state
79 | start_epoch = checkpoint["epoch"]
80 | if "state_dict" in checkpoint:
81 | if model_kwargs['use_ema_model']:
82 | ema_sd = checkpoint["ema_state_dict"]
83 | if next(iter(ema_sd.items()))[0].startswith('module'):
84 | ema_sd = {k[len('module.'):]: v for k, v in ema_sd.items()}
85 | ema_model.load_state_dict(ema_sd)
86 | print(f"=> resuming ema checkpoint '{model_kwargs['resume']}' (epoch {start_epoch})")
87 | return ema_model
88 | sd = checkpoint["state_dict"]
89 | if next(iter(sd.items()))[0].startswith('module'):
90 | sd = {k[len('module.'):]: v for k, v in sd.items()}
91 | model.load_state_dict(sd)
92 | print(f"=> resuming checkpoint '{model_kwargs['resume']}' (epoch {start_epoch})")
93 | return model
94 | else:
95 | """
96 | sd = checkpoint["student"]
97 | if next(iter(sd.items()))[0].startswith('module'):
98 | sd = {k[len('module.'):]: v for k, v in sd.items()}
99 | model.load_state_dict(sd)
100 | print(f"=> resuming checkpoint '{model_kwargs['resume']}' (epoch {start_epoch})")
101 | """
102 | # Evaluation on Teacher
103 | ema_sd = checkpoint["teacher"]
104 | if next(iter(ema_sd.items()))[0].startswith('module'):
105 | ema_sd = {k[len('module.'):]: v for k, v in ema_sd.items()}
106 | ema_model.load_state_dict(ema_sd)
107 | print(f"=> resuming ema checkpoint '{model_kwargs['resume']}' (epoch {start_epoch})")
108 |
109 | return ema_model
110 |
111 | @MODELS.register_module()
112 | class CLIPForSegmentation(BaseSegmentor):
113 | def __init__(self, name_path, device=torch.device('cuda:0'),
114 | pamr_steps=0, pamr_stride=(8, 16), prob_thd=0.0, logit_scale=40,
115 | slide_stride=112, slide_crop=224, area_thd=None, **model_kwargs):
116 |
117 | data_preprocessor = SegDataPreProcessor(
118 | mean=[122.771, 116.746, 104.094],
119 | std=[68.501, 66.632, 70.323],
120 | rgb_to_bgr=True)
121 | super().__init__(data_preprocessor=data_preprocessor)
122 |
123 | self.net = load_model(**model_kwargs)
124 | query_words, self.query_idx = get_cls_idx(name_path)
125 | self.num_queries = len(query_words)
126 | self.num_classes = max(self.query_idx) + 1
127 | self.query_idx = torch.Tensor(self.query_idx).to(torch.int64).to(device)
128 |
129 | query_features = []
130 | with torch.no_grad():
131 | for qw in query_words:
132 | query = open_clip.tokenize([temp(qw) for temp in OPENAI_IMAGENET_TEMPLATES]).to(device) # clip.tokenize([temp(qw) for temp in OPENAI_IMAGENET_TEMPLATES]).to(device)
133 | feature = self.net.encode_text(query)
134 | feature = feature['text_features'] if isinstance(feature, dict) else feature
135 | feature /= feature.norm(dim=-1, keepdim=True)
136 | feature = feature.mean(dim=0)
137 | feature /= feature.norm()
138 | query_features.append(feature.unsqueeze(0))
139 | self.query_features = torch.cat(query_features, dim=0)
140 |
141 | self.dtype = self.query_features.dtype
142 | self.logit_scale = logit_scale
143 | self.prob_thd = prob_thd
144 | self.area_thd = area_thd
145 | self.slide_stride = slide_stride
146 | self.slide_crop = slide_crop
147 | self.align_corners = False
148 | self.use_csa = model_kwargs['use_csa']
149 |
150 | if pamr_steps > 0:
151 | self.pamr = PAMR(pamr_steps, dilations=pamr_stride).to(device)
152 | else:
153 | self.pamr = None
154 |
155 | def forward_feature(self, img, logit_size=None):
156 | if type(img) == list:
157 | img = img[0]
158 | if self.use_csa:
159 | csa_image_features, _ = self.net.visual(img, return_all=True, csa=True)
160 | csa_image_features = csa_image_features @ self.net.visual.proj # [B, L-1, C]
161 |
162 | image_features = csa_image_features
163 | image_features /= image_features.norm(dim=-1, keepdim=True)
164 | logits = image_features @ self.query_features.T
165 | else:
166 | image_features, _ = self.net.visual(img, return_all=True, csa=self.use_csa)
167 | image_features = image_features @ self.net.visual.proj # [B, L-1, C]
168 | image_features /= image_features.norm(dim=-1, keepdim=True)
169 | logits = image_features @ self.query_features.T
170 |
171 | patch_size = self.net.visual.patch_size
172 | patch_size = patch_size[0] if isinstance(patch_size, (list, tuple)) else patch_size
173 |
174 | w, h = img[0].shape[-2] // patch_size, img[0].shape[-1] // patch_size
175 | out_dim = logits.shape[-1]
176 | logits = logits.permute(0, 2, 1).reshape(-1, out_dim, w, h)
177 |
178 | if logit_size == None:
179 | logits = nn.functional.interpolate(logits, size=img.shape[-2:], mode='bilinear')
180 | else:
181 | logits = nn.functional.interpolate(logits, size=logit_size, mode='bilinear')
182 |
183 | return logits
184 |
185 | def forward_slide(self, img, img_metas, stride=112, crop_size=224):
186 | """Inference by sliding-window with overlap.
187 | If h_crop > h_img or w_crop > w_img, the small patch will be used to
188 | decode without padding.
189 | """
190 | if type(img) == list:
191 | img = img[0].unsqueeze(0)
192 | if type(stride) == int:
193 | stride = (stride, stride)
194 | if type(crop_size) == int:
195 | crop_size = (crop_size, crop_size)
196 |
197 | h_stride, w_stride = stride
198 | h_crop, w_crop = crop_size
199 | batch_size, _, h_img, w_img = img.shape
200 | out_channels = self.num_queries
201 | h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
202 | w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
203 | preds = img.new_zeros((batch_size, out_channels, h_img, w_img))
204 | count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
205 | for h_idx in range(h_grids):
206 | for w_idx in range(w_grids):
207 | y1 = h_idx * h_stride
208 | x1 = w_idx * w_stride
209 | y2 = min(y1 + h_crop, h_img)
210 | x2 = min(x1 + w_crop, w_img)
211 | y1 = max(y2 - h_crop, 0)
212 | x1 = max(x2 - w_crop, 0)
213 | crop_img = img[:, :, y1:y2, x1:x2]
214 | crop_seg_logit = self.forward_feature(crop_img)
215 | preds += nn.functional.pad(crop_seg_logit,
216 | (int(x1), int(preds.shape[3] - x2), int(y1),
217 | int(preds.shape[2] - y2)))
218 |
219 | count_mat[:, :, y1:y2, x1:x2] += 1
220 | assert (count_mat == 0).sum() == 0
221 |
222 | preds = preds / count_mat
223 | img_size = img_metas[0]['ori_shape'][:2]
224 | logits = nn.functional.interpolate(preds, size=img_size, mode='bilinear')
225 |
226 | if self.pamr:
227 | img = nn.functional.interpolate(img, size=img_size, mode='bilinear')
228 | logits = self.pamr(img, logits.to(img.dtype)).to(self.dtype)
229 |
230 | return logits
231 |
232 | def predict(self, inputs, data_samples):
233 | if data_samples is not None:
234 | batch_img_metas = [
235 | data_sample.metainfo for data_sample in data_samples
236 | ]
237 | else:
238 | batch_img_metas = [
239 | dict(
240 | ori_shape=inputs.shape[2:],
241 | img_shape=inputs.shape[2:],
242 | pad_shape=inputs.shape[2:],
243 | padding_size=[0, 0, 0, 0])
244 | ] * inputs.shape[0]
245 |
246 | if self.slide_crop > 0:
247 | seg_logits = self.forward_slide(inputs, batch_img_metas, self.slide_stride, self.slide_crop)
248 | else:
249 | seg_logits = self.forward_feature(inputs, batch_img_metas[0]['ori_shape'])
250 |
251 | return self.postprocess_result(seg_logits, data_samples)
252 |
253 | def postprocess_result(self, seg_logits, data_samples):
254 | batch_size = seg_logits.shape[0]
255 | for i in range(batch_size):
256 |             seg_logit = seg_logits[i] * self.logit_scale
257 |             seg_logit = seg_logit.softmax(0)  # n_queries * w * h
258 |
259 |             num_cls, num_queries = max(self.query_idx) + 1, len(self.query_idx)
260 |             if num_cls != num_queries:
261 |                 seg_logit = seg_logit.unsqueeze(0)
262 |                 cls_index = nn.functional.one_hot(self.query_idx)
263 |                 cls_index = cls_index.T.view(num_cls, num_queries, 1, 1)
264 |                 seg_logit = (seg_logit * cls_index).max(1)[0]
265 |                 seg_pred = seg_logit.argmax(0, keepdim=True)
266 |
267 |             if self.area_thd is not None:
268 |                 # Force segmentations with area < self.area_thd to 0 (background)
269 |                 predictions = nn.functional.one_hot(seg_logit.argmax(0), num_cls).to(seg_logit.dtype)
270 |                 area_pred = predictions[:, :, 1:].sum((0, 1), keepdim=True)  # prune background
271 |                 area_pred = (area_pred > self.area_thd * area_pred.sum()).to(seg_logit.dtype)
272 |                 seg_logit[1:] *= area_pred.transpose(0, -1)
273 |
274 |             seg_pred = seg_logit.argmax(0, keepdim=True)
275 |             seg_pred[seg_logit.max(0, keepdim=True)[0] < self.prob_thd] = 0
276 |
277 |             data_samples[i].set_data({
278 |                 'seg_logits':
279 |                 PixelData(**{'data': seg_logit}),
280 | 'pred_sem_seg':
281 | PixelData(**{'data': seg_pred})
282 | })
283 |
284 | return data_samples
285 |
286 |     def _forward(self, data_samples):
287 | """
288 | """
289 |
290 | def inference(self, img, batch_img_metas):
291 | """
292 | """
293 |
294 | def encode_decode(self, inputs, batch_img_metas):
295 | """
296 | """
297 |
298 | def extract_feat(self, inputs):
299 | """
300 | """
301 |
302 | def loss(self, inputs, data_samples):
303 | """
304 | """
305 |
306 | def get_cls_idx(path):
307 | with open(path, 'r') as f:
308 | name_sets = f.readlines()
309 | num_cls = len(name_sets)
310 |
311 | class_names, class_indices = [], []
312 | for idx in range(num_cls):
313 | names_i = name_sets[idx].split(', ')
314 | class_names += names_i
315 | class_indices += [idx for _ in range(len(names_i))]
316 | class_names = [item.replace('\n', '') for item in class_names]
317 | return class_names, class_indices
--------------------------------------------------------------------------------
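`forward_feature` above scores every image patch against the class text embeddings, reshapes the scores back onto the patch grid, and upsamples them to pixel resolution. A self-contained sketch of that pipeline with random tensors in place of the encoder and query features (shapes are illustrative):

```python
import torch
import torch.nn as nn

# Toy setup: 224x224 image, 16x16 patches -> 14x14 patch grid, 512-dim features, 20 classes.
B, H, W, patch, dim, n_cls = 1, 224, 224, 16, 512, 20
patch_feats = torch.randn(B, (H // patch) * (W // patch), dim)   # stand-in for projected patch features
query_feats = torch.randn(n_cls, dim)                            # stand-in for class text embeddings

patch_feats = patch_feats / patch_feats.norm(dim=-1, keepdim=True)
query_feats = query_feats / query_feats.norm(dim=-1, keepdim=True)

logits = patch_feats @ query_feats.T                                        # [B, n_patches, n_cls]
logits = logits.permute(0, 2, 1).reshape(B, n_cls, H // patch, W // patch)  # back onto the patch grid
logits = nn.functional.interpolate(logits, size=(H, W), mode='bilinear')    # dense per-pixel logits
print(logits.shape)  # torch.Size([1, 20, 224, 224])
```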
/src/training/custom_datasets.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import mmengine.fileio as fileio
3 |
4 | from mmseg.registry import DATASETS
5 | from mmseg.datasets import BaseSegDataset
6 |
7 | @DATASETS.register_module()
8 | class PascalVOC20Dataset(BaseSegDataset):
9 | """Pascal VOC dataset.
10 |
11 | Args:
12 | split (str): Split txt file for Pascal VOC.
13 | """
14 | METAINFO = dict(
15 | classes=('aeroplane', 'bicycle', 'bird', 'boat',
16 | 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable',
17 | 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep',
18 | 'sofa', 'train', 'tvmonitor'),
19 | palette=[[128, 0, 0], [0, 128, 0], [0, 0, 192],
20 | [128, 128, 0], [128, 0, 128], [0, 128, 128], [192, 128, 64],
21 | [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
22 | [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
23 | [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
24 | [0, 64, 128]])
25 |
26 | def __init__(self,
27 | ann_file,
28 | img_suffix='.jpg',
29 | seg_map_suffix='.png',
30 | reduce_zero_label=True,
31 | **kwargs) -> None:
32 | super().__init__(
33 | img_suffix=img_suffix,
34 | seg_map_suffix=seg_map_suffix,
35 | reduce_zero_label=reduce_zero_label,
36 | ann_file=ann_file,
37 | **kwargs)
38 | assert fileio.exists(self.data_prefix['img_path'],
39 | self.backend_args) and osp.isfile(self.ann_file)
40 |
41 | @DATASETS.register_module()
42 | class COCOObjectDataset(BaseSegDataset):
43 | """
44 | Implementation borrowed from TCL (https://github.com/kakaobrain/tcl) and GroupViT (https://github.com/NVlabs/GroupViT)
45 | COCO-Object dataset.
46 | 1 bg class + first 80 classes from the COCO-Stuff dataset.
47 | """
48 |
49 | METAINFO = dict(
50 |
51 | classes = ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat',
52 | 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
53 | 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
54 | 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
55 | 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon',
56 | 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut',
57 | 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse',
58 | 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book',
59 | 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'),
60 |
61 | palette = [[0, 0, 0], [0, 192, 64], [0, 192, 64], [0, 64, 96], [128, 192, 192], [0, 64, 64], [0, 192, 224],
62 | [0, 192, 192], [128, 192, 64], [0, 192, 96], [128, 192, 64], [128, 32, 192], [0, 0, 224], [0, 0, 64],
63 | [0, 160, 192], [128, 0, 96], [128, 0, 192], [0, 32, 192], [128, 128, 224], [0, 0, 192], [128, 160, 192],
64 | [128, 128, 0], [128, 0, 32], [128, 32, 0], [128, 0, 128], [64, 128, 32], [0, 160, 0], [0, 0, 0],
65 | [192, 128, 160], [0, 32, 0], [0, 128, 128], [64, 128, 160], [128, 160, 0], [0, 128, 0], [192, 128, 32],
66 | [128, 96, 128], [0, 0, 128], [64, 0, 32], [0, 224, 128], [128, 0, 0], [192, 0, 160], [0, 96, 128],
67 | [128, 128, 128], [64, 0, 160], [128, 224, 128], [128, 128, 64], [192, 0, 32],
68 | [128, 96, 0], [128, 0, 192], [0, 128, 32], [64, 224, 0], [0, 0, 64], [128, 128, 160], [64, 96, 0],
69 | [0, 128, 192], [0, 128, 160], [192, 224, 0], [0, 128, 64], [128, 128, 32], [192, 32, 128], [0, 64, 192],
70 | [0, 0, 32], [64, 160, 128], [128, 64, 64], [128, 0, 160], [64, 32, 128], [128, 192, 192], [0, 0, 160],
71 | [192, 160, 128], [128, 192, 0], [128, 0, 96], [192, 32, 0], [128, 64, 128], [64, 128, 96], [64, 160, 0],
72 | [0, 64, 0], [192, 128, 224], [64, 32, 0], [0, 192, 128], [64, 128, 224], [192, 160, 0]])
73 |
74 | def __init__(self, **kwargs):
75 | super(COCOObjectDataset, self).__init__(img_suffix='.jpg', seg_map_suffix='_instanceTrainIds.png', **kwargs)
76 |
77 | @DATASETS.register_module()
78 | class PascalContext60Dataset(BaseSegDataset):
79 | METAINFO = dict(
80 | classes=('background', 'aeroplane', 'bag', 'bed', 'bedclothes',
81 | 'bench', 'bicycle', 'bird', 'boat', 'book', 'bottle',
82 | 'building', 'bus', 'cabinet', 'car', 'cat', 'ceiling',
83 | 'chair', 'cloth', 'computer', 'cow', 'cup', 'curtain', 'dog',
84 | 'door', 'fence', 'floor', 'flower', 'food', 'grass', 'ground',
85 | 'horse', 'keyboard', 'light', 'motorbike', 'mountain',
86 | 'mouse', 'person', 'plate', 'platform', 'pottedplant', 'road',
87 | 'rock', 'sheep', 'shelves', 'sidewalk', 'sign', 'sky', 'snow',
88 | 'sofa', 'table', 'track', 'train', 'tree', 'truck',
89 | 'tvmonitor', 'wall', 'water', 'window', 'wood'),
90 | palette=[[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
91 | [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
92 | [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
93 | [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
94 | [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
95 | [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
96 | [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
97 | [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
98 | [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
99 | [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
100 | [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
101 | [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
102 | [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
103 | [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
104 | [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]])
105 |
106 | def __init__(self,
107 | ann_file: str,
108 | img_suffix='.jpg',
109 | seg_map_suffix='.png',
110 | **kwargs) -> None:
111 | super().__init__(
112 | img_suffix=img_suffix,
113 | seg_map_suffix=seg_map_suffix,
114 | ann_file=ann_file,
115 | reduce_zero_label=False,
116 | **kwargs)
117 |
118 |
119 | @DATASETS.register_module()
120 | class PascalContext59Dataset(BaseSegDataset):
121 | METAINFO = dict(
122 | classes=('aeroplane', 'bag', 'bed', 'bedclothes', 'bench', 'bicycle',
123 | 'bird', 'boat', 'book', 'bottle', 'building', 'bus',
124 | 'cabinet', 'car', 'cat', 'ceiling', 'chair', 'cloth',
125 | 'computer', 'cow', 'cup', 'curtain', 'dog', 'door', 'fence',
126 | 'floor', 'flower', 'food', 'grass', 'ground', 'horse',
127 | 'keyboard', 'light', 'motorbike', 'mountain', 'mouse',
128 | 'person', 'plate', 'platform', 'pottedplant', 'road', 'rock',
129 | 'sheep', 'shelves', 'sidewalk', 'sign', 'sky', 'snow', 'sofa',
130 | 'table', 'track', 'train', 'tree', 'truck', 'tvmonitor',
131 | 'wall', 'water', 'window', 'wood'),
132 | palette=[[180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3],
133 | [120, 120, 80], [140, 140, 140], [204, 5, 255],
134 | [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
135 | [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
136 | [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
137 | [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
138 | [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
139 | [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
140 | [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
141 | [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
142 | [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
143 | [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
144 | [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
145 | [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
146 | [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]])
147 |
148 | def __init__(self,
149 | ann_file: str,
150 | img_suffix='.jpg',
151 | seg_map_suffix='.png',
152 | reduce_zero_label=True,
153 | **kwargs):
154 | super().__init__(
155 | img_suffix=img_suffix,
156 | seg_map_suffix=seg_map_suffix,
157 | ann_file=ann_file,
158 | reduce_zero_label=reduce_zero_label,
159 | **kwargs)
--------------------------------------------------------------------------------
/src/training/distributed.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | import torch.distributed as dist
5 |
6 | try:
7 | import horovod.torch as hvd
8 | except ImportError:
9 | hvd = None
10 |
11 | from datetime import timedelta
12 |
13 | def is_global_master(args):
14 | return args.rank == 0
15 |
16 |
17 | def is_local_master(args):
18 | return args.local_rank == 0
19 |
20 |
21 | def is_master(args, local=False):
22 | return is_local_master(args) if local else is_global_master(args)
23 |
24 |
25 | def is_using_horovod():
26 | # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set
27 | # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required...
28 | ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"]
29 | pmi_vars = ["PMI_RANK", "PMI_SIZE"]
30 | if all([var in os.environ for var in ompi_vars]) or all([var in os.environ for var in pmi_vars]):
31 | return True
32 | else:
33 | return False
34 |
35 |
36 | def is_using_distributed():
37 | if 'WORLD_SIZE' in os.environ:
38 | return int(os.environ['WORLD_SIZE']) > 1
39 | if 'SLURM_NTASKS' in os.environ:
40 | return int(os.environ['SLURM_NTASKS']) > 1
41 | return False
42 |
43 |
44 | def world_info_from_env():
45 | local_rank = 0
46 | for v in ('LOCAL_RANK', 'MPI_LOCALRANKID', 'SLURM_LOCALID', 'OMPI_COMM_WORLD_LOCAL_RANK'):
47 | if v in os.environ:
48 | local_rank = int(os.environ[v])
49 | break
50 | global_rank = 0
51 | for v in ('RANK', 'PMI_RANK', 'SLURM_PROCID', 'OMPI_COMM_WORLD_RANK'):
52 | if v in os.environ:
53 | global_rank = int(os.environ[v])
54 | break
55 | world_size = 1
56 | for v in ('WORLD_SIZE', 'PMI_SIZE', 'SLURM_NTASKS', 'OMPI_COMM_WORLD_SIZE'):
57 | if v in os.environ:
58 | world_size = int(os.environ[v])
59 | break
60 |
61 | return local_rank, global_rank, world_size
62 |
63 | def init_distributed_device(args):
64 | # Distributed training = training on more than one GPU.
65 | # Works in both single and multi-node scenarios.
66 | args.distributed = False
67 | args.world_size = 1
68 | args.rank = 0 # global rank
69 | args.local_rank = 0
70 | if args.horovod:
71 | assert hvd is not None, "Horovod is not installed"
72 | hvd.init()
73 | args.local_rank = int(hvd.local_rank())
74 | args.rank = hvd.rank()
75 | args.world_size = hvd.size()
76 | args.distributed = True
77 | os.environ['LOCAL_RANK'] = str(args.local_rank)
78 | os.environ['RANK'] = str(args.rank)
79 | os.environ['WORLD_SIZE'] = str(args.world_size)
80 | elif is_using_distributed():
81 | timeout=None
82 | if 'SLURM_PROCID' in os.environ:
83 | # DDP via SLURM
84 | args.local_rank, args.rank, args.world_size = world_info_from_env()
85 | # SLURM var -> torch.distributed vars in case needed
86 | os.environ['LOCAL_RANK'] = str(args.local_rank)
87 | os.environ['RANK'] = str(args.rank)
88 | os.environ['WORLD_SIZE'] = str(args.world_size)
89 | torch.distributed.init_process_group(
90 | backend=args.dist_backend,
91 | init_method=args.dist_url,
92 | world_size=args.world_size,
93 | rank=args.rank,
94 | timeout=timeout
95 | )
96 | else:
97 | # DDP via torchrun, torch.distributed.launch
98 | args.local_rank, _, _ = world_info_from_env()
99 | torch.distributed.init_process_group(
100 | backend=args.dist_backend,
101 | init_method=args.dist_url,
102 | timeout=timeout)
103 | args.world_size = torch.distributed.get_world_size()
104 | args.rank = torch.distributed.get_rank()
105 | args.distributed = True
106 |
107 | if torch.cuda.is_available():
108 | if args.distributed and not args.no_set_device_rank:
109 | device = 'cuda:%d' % args.local_rank
110 | else:
111 | device = 'cuda:0'
112 | torch.cuda.set_device(device)
113 | else:
114 | device = 'cpu'
115 | args.device = device
116 | device = torch.device(device)
117 | return device
118 |
119 |
120 | def broadcast_object(args, obj, src=0):
121 | # broadcast a pickle-able python object from rank-0 to all ranks
122 | if args.horovod:
123 | return hvd.broadcast_object(obj, root_rank=src)
124 | else:
125 | if args.rank == src:
126 | objects = [obj]
127 | else:
128 | objects = [None]
129 | dist.broadcast_object_list(objects, src=src)
130 | return objects[0]
131 |
132 |
133 | def all_gather_object(args, obj, dst=0):
134 | # gather a pickle-able python object across all ranks
135 | if args.horovod:
136 | return hvd.allgather_object(obj)
137 | else:
138 | objects = [None for _ in range(args.world_size)]
139 | dist.all_gather_object(objects, obj)
140 | return objects
141 |
--------------------------------------------------------------------------------
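`world_info_from_env` picks up rank information from whichever launcher exported it (torchrun, SLURM, MPI). A small illustrative sketch that fakes a torchrun-style environment; it assumes the interpreter is started from `src/` so the `training` package is importable:

```python
import os
from training.distributed import world_info_from_env

# Simulate the variables torchrun would export for local rank 1, global rank 5 of an 8-process job.
os.environ.update({'LOCAL_RANK': '1', 'RANK': '5', 'WORLD_SIZE': '8'})

local_rank, global_rank, world_size = world_info_from_env()
print(local_rank, global_rank, world_size)  # 1 5 8
```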
/src/training/file_utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import multiprocessing
4 | import subprocess
5 | import time
6 | import fsspec
7 | import torch
8 | from tqdm import tqdm
9 |
10 | def remote_sync_s3(local_dir, remote_dir):
11 | # skip epoch_latest which can change during sync.
12 | result = subprocess.run(["aws", "s3", "sync", local_dir, remote_dir, '--exclude', '*epoch_latest.pt'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
13 | if result.returncode != 0:
14 | logging.error(f"Error: Failed to sync with S3 bucket {result.stderr.decode('utf-8')}")
15 | return False
16 |
17 | logging.info(f"Successfully synced with S3 bucket")
18 | return True
19 |
20 | def remote_sync_fsspec(local_dir, remote_dir):
21 | # FIXME currently this is slow and not recommended. Look into speeding up.
22 | a = fsspec.get_mapper(local_dir)
23 | b = fsspec.get_mapper(remote_dir)
24 |
25 | for k in a:
26 | # skip epoch_latest which can change during sync.
27 | if 'epoch_latest.pt' in k:
28 | continue
29 |
30 | logging.info(f'Attempting to sync {k}')
31 | if k in b and len(a[k]) == len(b[k]):
32 | logging.debug(f'Skipping remote sync for {k}.')
33 | continue
34 |
35 | try:
36 |             b[k] = a[k]
37 |             logging.info(f'Successful sync for {k}.')
38 | except Exception as e:
39 | logging.info(f'Error during remote sync for {k}: {e}')
40 | return False
41 |
42 | return True
43 |
44 | def remote_sync(local_dir, remote_dir, protocol):
45 | logging.info('Starting remote sync.')
46 | if protocol == 's3':
47 | return remote_sync_s3(local_dir, remote_dir)
48 | elif protocol == 'fsspec':
49 | return remote_sync_fsspec(local_dir, remote_dir)
50 | else:
51 | logging.error('Remote protocol not known')
52 | return False
53 |
54 | def keep_running_remote_sync(sync_every, local_dir, remote_dir, protocol):
55 | while True:
56 | time.sleep(sync_every)
57 | remote_sync(local_dir, remote_dir, protocol)
58 |
59 | def start_sync_process(sync_every, local_dir, remote_dir, protocol):
60 | p = multiprocessing.Process(target=keep_running_remote_sync, args=(sync_every, local_dir, remote_dir, protocol))
61 | return p
62 |
63 | # Note: we are not currently using this save function.
64 | def pt_save(pt_obj, file_path):
65 | of = fsspec.open(file_path, "wb")
66 | with of as f:
67 |         torch.save(pt_obj, f)
68 |
69 | def pt_load(file_path, map_location=None):
70 | if file_path.startswith('s3'):
71 | logging.info('Loading remote checkpoint, which may take a bit.')
72 | of = fsspec.open(file_path, "rb")
73 | with of as f:
74 | out = torch.load(f, map_location=map_location, weights_only=False)
75 | return out
76 |
77 | def check_exists(file_path):
78 | try:
79 | with fsspec.open(file_path):
80 | pass
81 | except FileNotFoundError:
82 | return False
83 | return True
84 |
--------------------------------------------------------------------------------
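`check_exists` and `pt_load` go through fsspec, so they work for local paths as well as remote ones such as `s3://...`. A brief usage sketch with a hypothetical checkpoint path, assuming it is run from `src/`:

```python
from training.file_utils import check_exists, pt_load

ckpt_path = './logs/example_run/checkpoints/epoch_32.pt'  # hypothetical path

if check_exists(ckpt_path):
    checkpoint = pt_load(ckpt_path, map_location='cpu')   # dict with 'epoch', 'state_dict', ...
    print('resuming from epoch', checkpoint['epoch'])
else:
    print('no checkpoint at', ckpt_path)
```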
/src/training/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 |
4 | def setup_logging(log_file, level, include_host=False):
5 | if include_host:
6 | import socket
7 | hostname = socket.gethostname()
8 | formatter = logging.Formatter(
9 | f'%(asctime)s | {hostname} | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S')
10 | else:
11 | formatter = logging.Formatter('%(asctime)s | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S')
12 |
13 | logging.root.setLevel(level)
14 | loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict]
15 | for logger in loggers:
16 | logger.setLevel(level)
17 |
18 | stream_handler = logging.StreamHandler()
19 | stream_handler.setFormatter(formatter)
20 | logging.root.addHandler(stream_handler)
21 |
22 | if log_file:
23 | file_handler = logging.FileHandler(filename=log_file)
24 | file_handler.setFormatter(formatter)
25 | logging.root.addHandler(file_handler)
26 |
27 |
--------------------------------------------------------------------------------
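A short sketch of how `setup_logging` might be called at startup; passing a file path instead of `None` adds a `FileHandler` alongside the console handler:

```python
import logging
from training.logger import setup_logging

setup_logging(log_file=None, level=logging.INFO, include_host=False)  # console-only logging
logging.info('logger initialized')
```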
/src/training/pamr.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 TU Darmstadt
2 | # License: Apache 2.0.
3 | # https://github.com/visinf/1-stage-wseg/blob/master/models/mods/pamr.py
4 | import torch
5 | import torch.nn.functional as F
6 | import torch.nn as nn
7 |
8 | from functools import partial
9 |
10 | #
11 | # Helper modules
12 | #
13 | class LocalAffinity(nn.Module):
14 |
15 | def __init__(self, dilations=[1]):
16 | super(LocalAffinity, self).__init__()
17 | self.dilations = dilations
18 | weight = self._init_aff()
19 | self.register_buffer('kernel', weight)
20 |
21 | def _init_aff(self):
22 | # initialising the shift kernel
23 | weight = torch.zeros(8, 1, 3, 3)
24 |
25 | for i in range(weight.size(0)):
26 | weight[i, 0, 1, 1] = 1
27 |
28 | weight[0, 0, 0, 0] = -1
29 | weight[1, 0, 0, 1] = -1
30 | weight[2, 0, 0, 2] = -1
31 |
32 | weight[3, 0, 1, 0] = -1
33 | weight[4, 0, 1, 2] = -1
34 |
35 | weight[5, 0, 2, 0] = -1
36 | weight[6, 0, 2, 1] = -1
37 | weight[7, 0, 2, 2] = -1
38 |
39 | self.weight_check = weight.clone()
40 |
41 | return weight
42 |
43 | def forward(self, x):
44 |
45 | self.weight_check = self.weight_check.type_as(x)
46 | assert torch.all(self.weight_check.eq(self.kernel))
47 |
48 | B,K,H,W = x.size()
49 | x = x.view(B*K,1,H,W)
50 |
51 | x_affs = []
52 | for d in self.dilations:
53 | x_pad = F.pad(x, [d]*4, mode='replicate')
54 | x_aff = F.conv2d(x_pad, self.kernel, dilation=d)
55 | x_affs.append(x_aff)
56 |
57 | x_aff = torch.cat(x_affs, 1)
58 | return x_aff.view(B,K,-1,H,W)
59 |
60 | class LocalAffinityCopy(LocalAffinity):
61 |
62 | def _init_aff(self):
63 | # initialising the shift kernel
64 | weight = torch.zeros(8, 1, 3, 3)
65 |
66 | weight[0, 0, 0, 0] = 1
67 | weight[1, 0, 0, 1] = 1
68 | weight[2, 0, 0, 2] = 1
69 |
70 | weight[3, 0, 1, 0] = 1
71 | weight[4, 0, 1, 2] = 1
72 |
73 | weight[5, 0, 2, 0] = 1
74 | weight[6, 0, 2, 1] = 1
75 | weight[7, 0, 2, 2] = 1
76 |
77 | self.weight_check = weight.clone()
78 | return weight
79 |
80 | class LocalStDev(LocalAffinity):
81 |
82 | def _init_aff(self):
83 | weight = torch.zeros(9, 1, 3, 3)
84 | weight.zero_()
85 |
86 | weight[0, 0, 0, 0] = 1
87 | weight[1, 0, 0, 1] = 1
88 | weight[2, 0, 0, 2] = 1
89 |
90 | weight[3, 0, 1, 0] = 1
91 | weight[4, 0, 1, 1] = 1
92 | weight[5, 0, 1, 2] = 1
93 |
94 | weight[6, 0, 2, 0] = 1
95 | weight[7, 0, 2, 1] = 1
96 | weight[8, 0, 2, 2] = 1
97 |
98 | self.weight_check = weight.clone()
99 | return weight
100 |
101 | def forward(self, x):
102 | # returns (B,K,P,H,W), where P is the number
103 | # of locations
104 | x = super(LocalStDev, self).forward(x)
105 |
106 | return x.std(2, keepdim=True)
107 |
108 | class LocalAffinityAbs(LocalAffinity):
109 |
110 | def forward(self, x):
111 | x = super(LocalAffinityAbs, self).forward(x)
112 | return torch.abs(x)
113 |
114 | #
115 | # PAMR module
116 | #
117 | class PAMR(nn.Module):
118 |
119 | def __init__(self, num_iter=1, dilations=[1]):
120 | super(PAMR, self).__init__()
121 |
122 | self.num_iter = num_iter
123 | self.aff_x = LocalAffinityAbs(dilations)
124 | self.aff_m = LocalAffinityCopy(dilations)
125 | self.aff_std = LocalStDev(dilations)
126 |
127 | def forward(self, x, mask):
128 | mask = F.interpolate(mask, size=x.size()[-2:], mode="bilinear", align_corners=True)
129 |
130 | # x: [BxKxHxW]
131 | # mask: [BxCxHxW]
132 | B,K,H,W = x.size()
133 | _,C,_,_ = mask.size()
134 |
135 | x_std = self.aff_std(x)
136 |
137 | x = -self.aff_x(x) / (1e-8 + 0.1 * x_std)
138 | x = x.mean(1, keepdim=True)
139 | x = F.softmax(x, 2)
140 |
141 | for _ in range(self.num_iter):
142 | m = self.aff_m(mask) # [BxCxPxHxW]
143 | mask = (m * x).sum(2)
144 |
145 |         # refined mask: [BxCxHxW]
146 | return mask
--------------------------------------------------------------------------------
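PAMR refines coarse mask logits with local affinities computed from the input image. A minimal sketch with random tensors; the iteration count and dilations below are purely illustrative:

```python
import torch
from training.pamr import PAMR

image = torch.rand(1, 3, 224, 224)         # low-level guidance image
mask_logits = torch.rand(1, 21, 224, 224)  # coarse per-class logits to refine

refiner = PAMR(num_iter=10, dilations=[1, 2, 4, 8])  # illustrative settings
refined = refiner(image, mask_logits)                # affinity-weighted refinement
print(refined.shape)  # torch.Size([1, 21, 224, 224])
```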
/src/training/precision.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from contextlib import suppress
3 |
4 |
5 | def get_autocast(precision):
6 | if precision == 'amp':
7 | return torch.cuda.amp.autocast
8 | elif precision == 'amp_bfloat16' or precision == 'amp_bf16':
9 | # amp_bfloat16 is more stable than amp float16 for clip training
10 | return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16)
11 | else:
12 | return suppress
13 |
--------------------------------------------------------------------------------
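`get_autocast` returns a context-manager factory, so the same loop body covers fp32, fp16 AMP, and bf16 AMP. A brief sketch (mixed precision only takes effect when CUDA is available):

```python
import torch
from training.precision import get_autocast

autocast = get_autocast('amp')  # 'amp_bf16' for bfloat16; any other value yields a no-op context

layer, x = torch.nn.Linear(8, 2), torch.randn(4, 8)
if torch.cuda.is_available():
    layer, x = layer.cuda(), x.cuda()

with autocast():
    out = layer(x)   # fp16 under CUDA AMP, fp32 otherwise
print(out.dtype)
```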
/src/training/profiler.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 | import open_clip
5 | import pandas as pd
6 | from torch.utils.flop_counter import FlopCounterMode
7 | try:
8 | import fvcore
9 | except ImportError:
10 | fvcore = None
11 |
12 | parser = argparse.ArgumentParser(description='OpenCLIP Profiler')
13 |
14 | # benchmark specific args
15 | parser.add_argument('--model', metavar='NAME', default='',
16 | help='model(s) to profile')
17 | parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
18 | help='Output csv file for results')
19 | parser.add_argument('--profiler', default='torch', type=str, choices=['torch', 'fvcore'])
20 | parser.add_argument('--batch-size', default=1, type=int, help='Batch size for profiling')
21 |
22 |
23 | def profile_fvcore(
24 | model,
25 | image_input_size=(3, 224, 224),
26 | text_input_size=(77,),
27 | batch_size=1,
28 | detailed=False,
29 | force_cpu=False
30 | ):
31 | if force_cpu:
32 | model = model.to('cpu')
33 | device, dtype = next(model.parameters()).device, next(model.parameters()).dtype
34 | example_image_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype)
35 | example_text_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64)
36 | fca = fvcore.nn.FlopCountAnalysis(model, (example_image_input, example_text_input))
37 | aca = fvcore.nn.ActivationCountAnalysis(model, (example_image_input, example_text_input))
38 | if detailed:
39 | fcs = fvcore.nn.flop_count_str(fca)
40 | print(fcs)
41 | return fca.total() / batch_size, aca.total() / batch_size
42 |
43 |
44 | def profile_fvcore_text(
45 | model,
46 | text_input_size=(77,),
47 | batch_size=1,
48 | detailed=False,
49 | force_cpu=False
50 | ):
51 | if force_cpu:
52 | model = model.to('cpu')
53 | device = next(model.parameters()).device
54 | example_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64)
55 | fca = fvcore.nn.FlopCountAnalysis(model, example_input)
56 | aca = fvcore.nn.ActivationCountAnalysis(model, example_input)
57 | if detailed:
58 | fcs = fvcore.nn.flop_count_str(fca)
59 | print(fcs)
60 | return fca.total() / batch_size, aca.total() / batch_size
61 |
62 |
63 | def profile_fvcore_image(
64 | model,
65 | image_input_size=(3, 224, 224),
66 | batch_size=1,
67 | detailed=False,
68 | force_cpu=False
69 | ):
70 | if force_cpu:
71 | model = model.to('cpu')
72 | device, dtype = next(model.parameters()).device, next(model.parameters()).dtype
73 | example_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype)
74 | fca = fvcore.nn.FlopCountAnalysis(model, example_input)
75 | aca = fvcore.nn.ActivationCountAnalysis(model, example_input)
76 | if detailed:
77 | fcs = fvcore.nn.flop_count_str(fca)
78 | print(fcs)
79 | return fca.total() / batch_size, aca.total() / batch_size
80 |
81 |
82 | def profile_torch_image(model, image_input_size, batch_size=1, force_cpu=False):
83 | """Profile the image encoder using torch.utils.flop_counter"""
84 | if force_cpu:
85 | model = model.to('cpu')
86 | device, dtype = next(model.parameters()).device, next(model.parameters()).dtype
87 | example_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype)
88 |
89 | flop_counter = FlopCounterMode()
90 | with flop_counter:
91 | model(example_input)
92 | total_flops = sum(flop_counter.get_flop_counts()['Global'].values())
93 | return total_flops / batch_size
94 |
95 |
96 | def profile_torch_text(model, text_input_size, batch_size=1, force_cpu=False):
97 | """Profile the text encoder using torch.utils.flop_counter"""
98 | if force_cpu:
99 | model = model.to('cpu')
100 | device = next(model.parameters()).device
101 | example_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64)
102 |
103 | flop_counter = FlopCounterMode()
104 | with flop_counter:
105 | model(example_input)
106 | total_flops = sum(flop_counter.get_flop_counts()['Global'].values())
107 | return total_flops / batch_size
108 |
109 |
110 | def profile_torch(model, text_input_size, image_input_size, batch_size=1, force_cpu=False):
111 | """Profile the full model using torch.utils.flop_counter"""
112 | if force_cpu:
113 | model = model.to('cpu')
114 | device, dtype = next(model.parameters()).device, next(model.parameters()).dtype
115 | image_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype)
116 | text_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64)
117 |
118 | flop_counter = FlopCounterMode()
119 | with flop_counter:
120 | model(image_input, text_input)
121 | total_flops = sum(flop_counter.get_flop_counts()['Global'].values())
122 | return total_flops / batch_size
123 |
124 |
125 | def count_params(model):
126 | return sum(m.numel() for m in model.parameters())
127 |
128 | def profile_model(model_name, batch_size=1, profiler='torch'):
129 | assert profiler in ['torch', 'fvcore'], 'Only torch and fvcore profilers are supported'
130 | if profiler == 'fvcore':
131 | assert fvcore is not None, 'Please install fvcore.'
132 | model = open_clip.create_model(model_name, force_custom_text=True, pretrained_hf=False)
133 | model.eval()
134 | if torch.cuda.is_available():
135 | model = model.cuda()
136 |
137 | if isinstance(model.visual.image_size, (tuple, list)):
138 | image_input_size = (3,) + tuple(model.visual.image_size[-2:])
139 | else:
140 | image_input_size = (3, model.visual.image_size, model.visual.image_size)
141 |
142 | text_input_size = (77,)
143 | if hasattr(model, 'context_length') and model.context_length:
144 | text_input_size = (model.context_length,)
145 |
146 | results = {}
147 | results['model'] = model_name
148 | results['image_size'] = image_input_size[1]
149 |
150 | model_cfg = open_clip.get_model_config(model_name)
151 | if model_cfg:
152 | vision_cfg = open_clip.CLIPVisionCfg(**model_cfg['vision_cfg'])
153 | text_cfg = open_clip.CLIPTextCfg(**model_cfg['text_cfg'])
154 | results['image_width'] = int(vision_cfg.width)
155 | results['text_width'] = int(text_cfg.width)
156 | results['embed_dim'] = int(model_cfg['embed_dim'])
157 | else:
158 | results['image_width'] = 0
159 | results['text_width'] = 0
160 | results['embed_dim'] = 0
161 |
162 | retries = 2
163 | while retries:
164 | retries -= 1
165 | try:
166 | results['mparams'] = round(count_params(model) / 1e6, 2)
167 | results['image_mparams'] = round(count_params(model.visual) / 1e6, 2)
168 | results['text_mparams'] = round(count_params(model.text) / 1e6, 2)
169 |
170 | if profiler == 'fvcore':
171 | macs, acts = profile_fvcore(
172 | model, image_input_size=image_input_size, text_input_size=text_input_size, force_cpu=not retries, batch_size=batch_size)
173 |
174 | image_macs, image_acts = profile_fvcore_image(
175 | model.visual, image_input_size=image_input_size, force_cpu=not retries, batch_size=batch_size)
176 |
177 | text_macs, text_acts = profile_fvcore_text(
178 | model.text, text_input_size=text_input_size, force_cpu=not retries, batch_size=batch_size)
179 |
180 | results['gmacs'] = round(macs / 1e9, 2)
181 | results['macts'] = round(acts / 1e6, 2)
182 |
183 | results['image_gmacs'] = round(image_macs / 1e9, 2)
184 | results['image_macts'] = round(image_acts / 1e6, 2)
185 |
186 | results['text_gmacs'] = round(text_macs / 1e9, 2)
187 | results['text_macts'] = round(text_acts / 1e6, 2)
188 | elif profiler == 'torch':
189 | image_flops = profile_torch_image(
190 | model.visual, image_input_size=image_input_size, force_cpu=not retries, batch_size=batch_size)
191 | text_flops = profile_torch_text(
192 | model.text, text_input_size=text_input_size, force_cpu=not retries, batch_size=batch_size)
193 | total_flops = profile_torch(
194 | model, image_input_size=image_input_size, text_input_size=text_input_size, force_cpu=not retries, batch_size=batch_size)
195 |
196 | results['gflops'] = round(total_flops / 1e9, 2)
197 | results['image_gflops'] = round(image_flops / 1e9, 2)
198 | results['text_gflops'] = round(text_flops / 1e9, 2)
199 |
200 | except RuntimeError as e:
201 | pass
202 | return results
203 |
204 |
205 | def main():
206 | args = parser.parse_args()
207 |
208 | # FIXME accept a text file name to allow lists of models in txt/csv
209 | if args.model == 'all':
210 | parsed_model = open_clip.list_models()
211 | else:
212 | parsed_model = args.model.split(',')
213 |
214 | results = []
215 | models_with_errors = []
216 | for m in parsed_model:
217 | print('='*100)
218 | print(f'Profiling {m}')
219 | try:
220 | row = profile_model(m, batch_size=args.batch_size, profiler=args.profiler)
221 | results.append(row)
222 | except Exception as e:
223 | print(f'Error profiling {m}: {e}')
224 | import traceback
225 | traceback.print_exc()
226 | models_with_errors.append(m)
227 |
228 | df = pd.DataFrame(results, columns=results[0].keys())
229 |
230 | if 'gmacs' in df.columns:
231 | df = df.sort_values(by=['gmacs', 'mparams', 'model'])
232 | else:
233 | df = df.sort_values(by=['gflops', 'mparams', 'model'])
234 |
235 | print('='*100)
236 | print('Done.')
237 | print(df)
238 | if args.results_file:
239 | df.to_csv(args.results_file, index=False)
240 |
241 | if models_with_errors:
242 | print('Models with errors:', models_with_errors)
243 |
244 |
245 | if __name__ == '__main__':
246 | main()
247 |
--------------------------------------------------------------------------------
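The profiler can also be driven programmatically. A hedged sketch that profiles the ViT-B-16 config with the torch FLOP counter, assuming it is run from `src/` on a reasonably recent PyTorch:

```python
from training.profiler import profile_model

# No pretrained weights are needed; the model is built from the local ViT-B-16 config.
row = profile_model('ViT-B-16', batch_size=1, profiler='torch')
print({k: v for k, v in row.items() if k in ('model', 'mparams', 'gflops', 'image_gflops', 'text_gflops')})
```

The same numbers are printed by the CLI entry point, e.g. `python -m training.profiler --model ViT-B-16` from `src/` (optionally with `--results-file` to save a CSV).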
/src/training/scheduler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def assign_learning_rate(optimizer, new_lr):
5 | for param_group in optimizer.param_groups:
6 | param_group["lr"] = new_lr
7 |
8 |
9 | def _warmup_lr(base_lr, warmup_length, step):
10 | return base_lr * (step + 1) / warmup_length
11 |
12 |
13 | def const_lr(optimizer, base_lr, warmup_length, steps):
14 | def _lr_adjuster(step):
15 | if step < warmup_length:
16 | lr = _warmup_lr(base_lr, warmup_length, step)
17 | else:
18 | lr = base_lr
19 | assign_learning_rate(optimizer, lr)
20 | return lr
21 | return _lr_adjuster
22 |
23 |
24 | def const_lr_cooldown(optimizer, base_lr, warmup_length, steps, cooldown_steps, cooldown_power=1.0, cooldown_end_lr=0.):
25 | def _lr_adjuster(step):
26 | start_cooldown_step = steps - cooldown_steps
27 | if step < warmup_length:
28 | lr = _warmup_lr(base_lr, warmup_length, step)
29 | else:
30 | if step < start_cooldown_step:
31 | lr = base_lr
32 | else:
33 | e = step - start_cooldown_step
34 | es = steps - start_cooldown_step
35 | # linear decay if power == 1; polynomial decay otherwise;
36 | decay = (1 - (e/es)) ** cooldown_power
37 | lr = decay * (base_lr - cooldown_end_lr) + cooldown_end_lr
38 | assign_learning_rate(optimizer, lr)
39 | return lr
40 | return _lr_adjuster
41 |
42 |
43 | def cosine_lr(optimizer, base_lr, warmup_length, steps):
44 | def _lr_adjuster(step):
45 | if step < warmup_length:
46 | lr = _warmup_lr(base_lr, warmup_length, step)
47 | else:
48 | e = step - warmup_length
49 | es = steps - warmup_length
50 | lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
51 | assign_learning_rate(optimizer, lr)
52 | return lr
53 | return _lr_adjuster
54 |
55 |
56 | def cosine_scheduler(base_value, final_value, warmup_length, steps):
57 | def _adjuster(step):
58 | if step < warmup_length:
59 | lr = base_value * (step + 1) / warmup_length
60 | else:
61 | e = step - warmup_length
62 | es = steps - warmup_length
63 | lr = final_value + 0.5 * (1 + np.cos(np.pi * e / es)) * (base_value - final_value)
64 | return lr
65 | return _adjuster
--------------------------------------------------------------------------------
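`cosine_lr` writes the scheduled learning rate into the optimizer each step, whereas `cosine_scheduler` only returns the scheduled scalar (handy for quantities such as an EMA-teacher momentum). A toy sketch of both over 100 steps:

```python
import torch
from training.scheduler import cosine_lr, cosine_scheduler

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0)

total_steps, warmup = 100, 10
adjust_lr = cosine_lr(optimizer, base_lr=5e-4, warmup_length=warmup, steps=total_steps)
momentum_at = cosine_scheduler(base_value=0.99, final_value=1.0, warmup_length=0, steps=total_steps)

for step in range(total_steps):
    lr = adjust_lr(step)          # warmup then cosine decay; also updates optimizer.param_groups
    momentum = momentum_at(step)  # cosine ramp from 0.99 towards 1.0
print(round(lr, 6), round(momentum, 4))
```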
/src/training/seg_configs/base_config.py:
--------------------------------------------------------------------------------
1 | # base configurations
2 | model = dict(
3 | type='CLIPForSegmentation',
4 | clip_path='ViT-B/16'
5 | )
6 |
7 | test_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
8 |
9 | default_scope = 'mmseg'
10 | env_cfg = dict(
11 | cudnn_benchmark=True,
12 | mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
13 | dist_cfg=dict(backend='nccl'),
14 | )
15 | vis_backends = [dict(type='LocalVisBackend')]
16 | visualizer = dict(
17 | type='SegLocalVisualizer', vis_backends=vis_backends, name='visualizer')
18 | log_processor = dict(by_epoch=False)
19 | log_level = 'INFO'
20 | load_from = None
21 | resume = False
22 |
23 | test_cfg = dict(type='TestLoop')
24 |
25 |
26 | default_hooks = dict(
27 | timer=dict(type='IterTimerHook'),
28 | logger=dict(type='LoggerHook', interval=200, log_metric_by_epoch=False),
29 | param_scheduler=dict(type='ParamSchedulerHook'),
30 | checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=2000),
31 | sampler_seed=dict(type='DistSamplerSeedHook'),
32 | visualization=dict(type='SegVisualizationHook', interval=1))
33 |
--------------------------------------------------------------------------------
/src/training/seg_configs/cfg_ade20k.py:
--------------------------------------------------------------------------------
1 | # ADE20K config file that defines the test pipeline
2 | # Please refer to https://github.com/wangf3014/SCLIP/blob/main/configs/cfg_ade20k.py
3 | _base_ = './base_config.py'
4 |
5 | # model settings
6 | model = dict(
7 | name_path='./training/seg_configs/cls_ade20k.txt'
8 | )
9 |
10 | # dataset settings
11 | dataset_type = 'ADE20KDataset'
12 | data_root = '/mmsegmentation_datasets/data/ade/ADEChallengeData2016'
13 |
14 | test_pipeline = [
15 | dict(type='LoadImageFromFile'),
16 | dict(type='Resize', scale=(2048, 336), keep_ratio=True),
17 | dict(type='LoadAnnotations', reduce_zero_label=True),
18 | dict(type='PackSegInputs')
19 | ]
20 |
21 | test_dataloader = dict(
22 | batch_size=1,
23 | num_workers=4,
24 | persistent_workers=True,
25 | sampler=dict(type='DefaultSampler', shuffle=False),
26 | dataset=dict(
27 | type=dataset_type,
28 | data_root=data_root,
29 | data_prefix=dict(
30 | img_path='images/validation',
31 | seg_map_path='annotations/validation'),
32 | pipeline=test_pipeline))
--------------------------------------------------------------------------------
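Each dataset config pulls the shared fields from `base_config.py` via mmengine's `_base_` mechanism. A quick sketch for inspecting the merged ADE20K config, assuming it is run from `src/`:

```python
from mmengine.config import Config

cfg = Config.fromfile('./training/seg_configs/cfg_ade20k.py')
print(cfg.model.type)                    # 'CLIPForSegmentation' (inherited from base_config.py)
print(cfg.model.name_path)               # './training/seg_configs/cls_ade20k.txt'
print(cfg.test_dataloader.dataset.type)  # 'ADE20KDataset'
```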
/src/training/seg_configs/cfg_city_scapes.py:
--------------------------------------------------------------------------------
1 | _base_ = './base_config.py'
2 |
3 | # model settings
4 | model = dict(
5 | name_path='./training/seg_configs/cls_city_scapes.txt'
6 | )
7 |
8 | # dataset settings
9 | dataset_type = 'CityscapesDataset'
10 | data_root = '/mmsegmentation_datasets/data/cityscapes'
11 |
12 | test_pipeline = [
13 | dict(type='LoadImageFromFile'),
14 | dict(type='Resize', scale=(2048, 560), keep_ratio=True),
15 |     # add annotation loading after ``Resize`` because the ground truth
16 |     # does not need the resize data transform
17 | dict(type='LoadAnnotations'),
18 | dict(type='PackSegInputs')
19 | ]
20 |
21 | test_dataloader = dict(
22 | batch_size=1,
23 | num_workers=4,
24 | persistent_workers=True,
25 | sampler=dict(type='DefaultSampler', shuffle=False),
26 | dataset=dict(
27 | type=dataset_type,
28 | data_root=data_root,
29 | data_prefix=dict(
30 | img_path='leftImg8bit/val', seg_map_path='gtFine/val'),
31 | pipeline=test_pipeline))
--------------------------------------------------------------------------------
/src/training/seg_configs/cfg_coco_object.py:
--------------------------------------------------------------------------------
1 | _base_ = './base_config.py'
2 |
3 | # model settings
4 | model = dict(
5 | name_path='./training/seg_configs/cls_coco_object.txt',
6 | logit_scale=50,
7 | prob_thd=0.1
8 | )
9 |
10 | # dataset settings
11 | dataset_type = 'COCOObjectDataset'
12 | data_root = '/mmsegmentation_datasets/data/coco_stuff164k'
13 |
14 | test_pipeline = [
15 | dict(type='LoadImageFromFile'),
16 | dict(type='Resize', scale=(2048, 336), keep_ratio=True),
17 | # load annotations after ``Resize`` because the ground truth
18 | # does not need the resize transform
19 | dict(type='LoadAnnotations'),
20 | dict(type='PackSegInputs')
21 | ]
22 |
23 | test_dataloader = dict(
24 | batch_size=1,
25 | num_workers=4,
26 | persistent_workers=True,
27 | sampler=dict(type='DefaultSampler', shuffle=False),
28 | dataset=dict(
29 | type=dataset_type,
30 | data_root=data_root,
31 | reduce_zero_label=False,
32 | data_prefix=dict(
33 | img_path='images/val2017', seg_map_path='annotations/val2017'),
34 | pipeline=test_pipeline))
--------------------------------------------------------------------------------
/src/training/seg_configs/cfg_coco_stuff164k.py:
--------------------------------------------------------------------------------
1 | _base_ = './base_config.py'
2 |
3 | # model settings
4 | model = dict(
5 | name_path='./training/seg_configs/cls_coco_stuff.txt'
6 | )
7 |
8 | # dataset settings
9 | dataset_type = 'COCOStuffDataset'
10 | data_root = '/mmsegmentation_datasets/data/coco_stuff164k'
11 |
12 | test_pipeline = [
13 | dict(type='LoadImageFromFile'),
14 | dict(type='Resize', scale=(2048, 448), keep_ratio=True),
15 | dict(type='LoadAnnotations'),
16 | dict(type='PackSegInputs')
17 | ]
18 |
19 | test_dataloader = dict(
20 | batch_size=1,
21 | num_workers=4,
22 | persistent_workers=True,
23 | sampler=dict(type='DefaultSampler', shuffle=False),
24 | dataset=dict(
25 | type=dataset_type,
26 | data_root=data_root,
27 | data_prefix=dict(
28 | img_path='images/val2017', seg_map_path='annotations/val2017'),
29 | pipeline=test_pipeline))
--------------------------------------------------------------------------------
/src/training/seg_configs/cfg_context59.py:
--------------------------------------------------------------------------------
1 | _base_ = './base_config.py'
2 |
3 | # model settings
4 | model = dict(
5 | name_path='./training/seg_configs/cls_context59.txt'
6 | )
7 |
8 | # dataset settings
9 | dataset_type = 'PascalContext59Dataset'
10 | data_root = '/mmsegmentation_datasets/data/VOCdevkit/VOC2010'
11 |
12 | test_pipeline = [
13 | dict(type='LoadImageFromFile'),
14 | dict(type='Resize', scale=(2048, 336), keep_ratio=True),
15 | dict(type='LoadAnnotations', reduce_zero_label=True),
16 | dict(type='PackSegInputs')
17 | ]
18 |
19 | test_dataloader = dict(
20 | batch_size=1,
21 | num_workers=4,
22 | persistent_workers=True,
23 | sampler=dict(type='DefaultSampler', shuffle=False),
24 | dataset=dict(
25 | type=dataset_type,
26 | data_root=data_root,
27 | data_prefix=dict(
28 | img_path='JPEGImages', seg_map_path='SegmentationClassContext'),
29 | ann_file='ImageSets/SegmentationContext/val.txt',
30 | pipeline=test_pipeline))
--------------------------------------------------------------------------------
/src/training/seg_configs/cfg_context60.py:
--------------------------------------------------------------------------------
1 | _base_ = './base_config.py'
2 |
3 | # model settings
4 | model = dict(
5 | name_path='./training/seg_configs/cls_context60.txt',
6 | logit_scale=50,
7 | prob_thd=0.1
8 | )
9 |
10 | # dataset settings
11 | dataset_type = 'PascalContext60Dataset'
12 | data_root = '/mmsegmentation_datasets/data/VOCdevkit/VOC2010'
13 |
14 | test_pipeline = [
15 | dict(type='LoadImageFromFile'),
16 | dict(type='Resize', scale=(2048, 336), keep_ratio=True),
17 | dict(type='LoadAnnotations'),
18 | dict(type='PackSegInputs')
19 | ]
20 |
21 | test_dataloader = dict(
22 | batch_size=1,
23 | num_workers=4,
24 | persistent_workers=True,
25 | sampler=dict(type='DefaultSampler', shuffle=False),
26 | dataset=dict(
27 | type=dataset_type,
28 | data_root=data_root,
29 | data_prefix=dict(
30 | img_path='JPEGImages', seg_map_path='SegmentationClassContext'),
31 | ann_file='ImageSets/SegmentationContext/val.txt',
32 | pipeline=test_pipeline))
--------------------------------------------------------------------------------
/src/training/seg_configs/cfg_voc20.py:
--------------------------------------------------------------------------------
1 | # VOC20 config file that defines the test pipeline.
2 | # Please refer to https://github.com/wangf3014/SCLIP/blob/main/configs/cfg_voc20.py
3 | _base_ = './base_config.py'
4 |
5 | # model settings
6 | model = dict(
7 | name_path='./training/seg_configs/cls_voc20.txt'
8 | )
9 |
10 | # dataset settings
11 | dataset_type = 'PascalVOC20Dataset'
12 | data_root = '/mmsegmentation_datasets/data/VOCdevkit/VOC2012'
13 |
14 | test_pipeline = [
15 | dict(type='LoadImageFromFile'),
16 | dict(type='Resize', scale=(2048, 336), keep_ratio=True),
17 | dict(type='LoadAnnotations'),
18 | dict(type='PackSegInputs')
19 | ]
20 |
21 | test_dataloader = dict(
22 | batch_size=1,
23 | num_workers=4,
24 | persistent_workers=True,
25 | sampler=dict(type='DefaultSampler', shuffle=False),
26 | dataset=dict(
27 | type=dataset_type,
28 | data_root=data_root,
29 | data_prefix=dict(
30 | img_path='JPEGImages', seg_map_path='SegmentationClass'),
31 | ann_file='ImageSets/Segmentation/val.txt',
32 | pipeline=test_pipeline))
--------------------------------------------------------------------------------
/src/training/seg_configs/cfg_voc21.py:
--------------------------------------------------------------------------------
1 | _base_ = './base_config.py'
2 |
3 | # model settings
4 | model = dict(
5 | name_path='./training/seg_configs/cls_voc21.txt',
6 | logit_scale=65,
7 | prob_thd=0.1,
8 | area_thd=0.1
9 | )
10 |
11 | # dataset settings
12 | dataset_type = 'PascalVOCDataset'
13 | data_root = '/mmsegmentation_datasets/data/VOCdevkit/VOC2012'
14 |
15 | test_pipeline = [
16 | dict(type='LoadImageFromFile'),
17 | dict(type='Resize', scale=(2048, 336), keep_ratio=True),
18 | dict(type='LoadAnnotations'),
19 | dict(type='PackSegInputs')
20 | ]
21 |
22 | test_dataloader = dict(
23 | batch_size=1,
24 | num_workers=4,
25 | persistent_workers=True,
26 | sampler=dict(type='DefaultSampler', shuffle=False),
27 | dataset=dict(
28 | type=dataset_type,
29 | data_root=data_root,
30 | data_prefix=dict(
31 | img_path='JPEGImages', seg_map_path='SegmentationClass'),
32 | ann_file='ImageSets/Segmentation/val.txt',
33 | pipeline=test_pipeline))
--------------------------------------------------------------------------------
/src/training/seg_configs/cls_ade20k.txt:
--------------------------------------------------------------------------------
1 | wall
2 | building
3 | sky
4 | floor
5 | tree
6 | ceiling
7 | road
8 | bed
9 | windowpane
10 | grass
11 | cabinet
12 | sidewalk
13 | person
14 | earth
15 | door
16 | table
17 | mountain
18 | plant
19 | curtain
20 | chair
21 | car
22 | water
23 | painting
24 | sofa
25 | shelf
26 | house
27 | sea
28 | mirror
29 | rug
30 | field
31 | armchair
32 | seat
33 | fence
34 | desk
35 | rock
36 | wardrobe
37 | lamp
38 | bathtub
39 | railing
40 | cushion
41 | base
42 | box
43 | column
44 | signboard
45 | chestofdrawers
46 | counter
47 | sand
48 | sink
49 | skyscraper
50 | fireplace
51 | refrigerator
52 | grandstand
53 | path
54 | stairs
55 | runway
56 | case
57 | pooltable
58 | pillow
59 | screendoor
60 | stairway
61 | river
62 | bridge
63 | bookcase
64 | blind
65 | coffeetable
66 | toilet
67 | flower
68 | book
69 | hill
70 | bench
71 | countertop
72 | stove
73 | palm
74 | kitchenisland
75 | computer
76 | swivelchair
77 | boat
78 | bar
79 | arcademachine
80 | hovel
81 | bus
82 | towel
83 | light
84 | truck
85 | tower
86 | chandelier
87 | awning
88 | streetlight
89 | booth
90 | televisionreceiver
91 | airplane
92 | dirttrack
93 | apparel
94 | pole
95 | land
96 | bannister
97 | escalator
98 | ottoman
99 | bottle
100 | buffet
101 | poster
102 | stage
103 | van
104 | ship
105 | fountain
106 | conveyerbelt
107 | canopy
108 | washer
109 | plaything
110 | swimmingpool
111 | stool
112 | barrel
113 | basket
114 | waterfall
115 | tent
116 | bag
117 | minibike
118 | cradle
119 | oven
120 | ball
121 | food
122 | step
123 | tank
124 | tradename
125 | microwave
126 | pot
127 | animal
128 | bicycle
129 | lake
130 | dishwasher
131 | screen
132 | blanket
133 | sculpture
134 | hood
135 | sconce
136 | vase
137 | trafficlight
138 | tray
139 | ashcan
140 | fan
141 | pier
142 | crtscreen
143 | plate
144 | monitor
145 | bulletinboard
146 | shower
147 | radiator
148 | glass
149 | clock
150 | flag
--------------------------------------------------------------------------------
/src/training/seg_configs/cls_city_scapes.txt:
--------------------------------------------------------------------------------
1 | road
2 | sidewalk
3 | building
4 | wall
5 | fence
6 | pole
7 | trafficlight
8 | trafficsign
9 | vegetation
10 | terrain
11 | sky
12 | person
13 | rider
14 | car
15 | truck
16 | bus
17 | train
18 | motorcycle
19 | bicycle
--------------------------------------------------------------------------------
/src/training/seg_configs/cls_coco_object.txt:
--------------------------------------------------------------------------------
1 | sky, wall, tree, wood, grass, road, sea, river, mountain, sands, desk, bed, building, cloud, lamp, door, window, wardrobe, ceiling, shelf, curtain, stair, floor, hill, rail, fence
2 | person, person in shirt, person in jeans, person in dress, person in sweater, person in skirt, person in jacket, body
3 | bicycle
4 | car
5 | motorcycle
6 | airplane
7 | bus
8 | train
9 | truck
10 | boat
11 | traffic light
12 | fire hydrant
13 | stop sign
14 | parking meter
15 | bench
16 | bird
17 | cat
18 | dog
19 | horse
20 | sheep
21 | cow
22 | elephant
23 | bear
24 | zebra
25 | giraffe
26 | backpack
27 | umbrella
28 | handbag
29 | tie
30 | suitcase
31 | frisbee
32 | skis
33 | snowboard
34 | sports ball
35 | kite
36 | baseball bat
37 | baseball glove
38 | skateboard
39 | surfboard
40 | tennis racket
41 | bottle
42 | wine glass
43 | cup
44 | fork
45 | knife
46 | spoon
47 | bowl
48 | banana
49 | apple
50 | sandwich
51 | orange
52 | broccoli
53 | carrot
54 | hot dog
55 | pizza
56 | donut
57 | cake
58 | chair
59 | couch
60 | potted plant
61 | bed
62 | dining table
63 | toilet
64 | tv
65 | laptop
66 | mouse
67 | remote
68 | keyboard
69 | cell phone
70 | microwave
71 | oven
72 | toaster
73 | sink
74 | refrigerator
75 | book
76 | clock
77 | vase
78 | scissors
79 | teddy bear
80 | hair drier
81 | toothbrush
--------------------------------------------------------------------------------
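Note: each class-name file lists one segmentation class per line; lines with several comma-separated phrases (the background and person lines above) group multiple text prompts under a single class. A minimal parsing sketch under that assumption; how the repository's segmentor actually consumes name_path may differ.

# one class per line; comma-separated entries are treated as synonyms of the same class
with open('./training/seg_configs/cls_coco_object.txt') as f:
    class_groups = [[name.strip() for name in line.split(',')] for line in f if line.strip()]

print(len(class_groups))    # 81 classes for COCO-Object
print(class_groups[1][:3])  # ['person', 'person in shirt', 'person in jeans']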
/src/training/seg_configs/cls_coco_stuff.txt:
--------------------------------------------------------------------------------
1 | person
2 | bicycle
3 | car
4 | motorcycle
5 | airplane
6 | bus
7 | train
8 | truck
9 | boat
10 | trafficlight
11 | firehydrant
12 | stopsign
13 | parkingmeter
14 | bench
15 | bird
16 | cat
17 | dog
18 | horse
19 | sheep
20 | cow
21 | elephant
22 | bear
23 | zebra
24 | giraffe
25 | backpack
26 | umbrella
27 | handbag
28 | tie
29 | suitcase
30 | frisbee
31 | skis
32 | snowboard
33 | sportsball
34 | kite
35 | baseballbat
36 | baseballglove
37 | skateboard
38 | surfboard
39 | tennisracket
40 | bottle
41 | wineglass
42 | cup
43 | fork
44 | knife
45 | spoon
46 | bowl
47 | banana
48 | apple
49 | sandwich
50 | orange
51 | broccoli
52 | carrot
53 | hotdog
54 | pizza
55 | donut
56 | cake
57 | chair
58 | couch
59 | pottedplant
60 | bed
61 | diningtable
62 | toilet
63 | tv
64 | laptop
65 | mouse
66 | remote
67 | keyboard
68 | cellphone
69 | microwave
70 | oven
71 | toaster
72 | sink
73 | refrigerator
74 | book
75 | clock
76 | vase
77 | scissors
78 | teddybear
79 | hairdrier
80 | toothbrush
81 | banner
82 | blanket
83 | branch
84 | bridge
85 | building-other
86 | bush
87 | cabinet
88 | cage
89 | cardboard
90 | carpet
91 | ceiling-other
92 | ceiling-tile
93 | cloth
94 | clothes
95 | clouds
96 | counter
97 | cupboard
98 | curtain
99 | desk-stuff
100 | dirt
101 | door-stuff
102 | fence
103 | floor-marble
104 | floor-other
105 | floor-stone
106 | floor-tile
107 | floor-wood
108 | flower
109 | fog
110 | food-other
111 | fruit
112 | furniture-other
113 | grass
114 | gravel
115 | ground-other
116 | hill
117 | house
118 | leaves
119 | light
120 | mat
121 | metal
122 | mirror-stuff
123 | moss
124 | mountain
125 | mud
126 | napkin
127 | net
128 | paper
129 | pavement
130 | pillow
131 | plant-other
132 | plastic
133 | platform
134 | playingfield
135 | railing
136 | railroad
137 | river
138 | road
139 | rock
140 | roof
141 | rug
142 | salad
143 | sand
144 | sea
145 | shelf
146 | sky-other
147 | skyscraper
148 | snow
149 | solid-other
150 | stairs
151 | stone
152 | straw
153 | structural-other
154 | table
155 | tent
156 | textile-other
157 | towel
158 | tree
159 | vegetable
160 | wall-brick
161 | wall-concrete
162 | wall-other
163 | wall-panel
164 | wall-stone
165 | wall-tile
166 | wall-wood
167 | water-other
168 | waterdrops
169 | window-blind
170 | window-other
171 | wood
--------------------------------------------------------------------------------
/src/training/seg_configs/cls_context59.txt:
--------------------------------------------------------------------------------
1 | aeroplane
2 | bag
3 | bed
4 | bedclothes
5 | bench
6 | bicycle
7 | bird
8 | boat
9 | book
10 | bottle
11 | building
12 | bus
13 | cabinet
14 | car
15 | cat
16 | ceiling
17 | chair
18 | cloth
19 | computer
20 | cow
21 | cup
22 | curtain
23 | dog
24 | door
25 | fence
26 | floor
27 | flower
28 | food
29 | grass
30 | ground
31 | horse
32 | keyboard
33 | light
34 | motorbike
35 | mountain
36 | mouse
37 | person
38 | plate
39 | platform
40 | pottedplant
41 | road
42 | rock
43 | sheep
44 | shelves
45 | sidewalk
46 | sign
47 | sky
48 | snow
49 | sofa
50 | table
51 | track
52 | train
53 | tree
54 | truck
55 | tvmonitor
56 | wall
57 | water
58 | window
59 | wood
--------------------------------------------------------------------------------
/src/training/seg_configs/cls_context60.txt:
--------------------------------------------------------------------------------
1 | background
2 | aeroplane
3 | bag
4 | bed
5 | bedclothes
6 | bench
7 | bicycle
8 | bird
9 | boat
10 | book
11 | bottle
12 | building
13 | bus
14 | cabinet
15 | car
16 | cat
17 | ceiling
18 | chair
19 | cloth
20 | computer
21 | cow
22 | cup
23 | curtain
24 | dog
25 | door
26 | fence
27 | floor
28 | flower
29 | food
30 | grass
31 | ground
32 | horse
33 | keyboard
34 | light
35 | motorbike
36 | mountain
37 | mouse
38 | person
39 | plate
40 | platform
41 | pottedplant
42 | road
43 | rock
44 | sheep
45 | shelves
46 | sidewalk
47 | sign
48 | sky
49 | snow
50 | sofa
51 | table
52 | track
53 | train
54 | tree
55 | truck
56 | tvmonitor
57 | wall
58 | water
59 | window
60 | wood
--------------------------------------------------------------------------------
/src/training/seg_configs/cls_voc20.txt:
--------------------------------------------------------------------------------
1 | aeroplane
2 | bicycle
3 | bird
4 | ship
5 | bottle
6 | bus
7 | car
8 | cat
9 | chair
10 | cow
11 | table
12 | dog
13 | horse
14 | motorbike
15 | person, person in shirt, person in jeans, person in dress, person in sweater, person in skirt, person in jacket
16 | pottedplant
17 | sheep
18 | sofa
19 | train
20 | television monitor, tv monitor, monitor, television, screen
--------------------------------------------------------------------------------
/src/training/seg_configs/cls_voc21.txt:
--------------------------------------------------------------------------------
1 | sky, wall, tree, wood, grass, road, sea, river, mountain, sands, desk, bed, building, cloud, lamp, door, window, wardrobe, ceiling, shelf, curtain, stair, floor, hill, rail, fence
2 | aeroplane
3 | bicycle
4 | bird
5 | ship
6 | bottle
7 | bus
8 | car
9 | cat
10 | chair
11 | cow
12 | table
13 | dog
14 | horse
15 | motorbike
16 | person, person in shirt, person in jeans, person in dress, person in sweater, person in skirt, person in jacket
17 | pottedplant
18 | sheep
19 | sofa
20 | train
21 | television monitor, tv monitor, monitor, television, screen
--------------------------------------------------------------------------------
/src/training/seg_configs/convert_coco_object.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # GroupViT (https://github.com/NVlabs/GroupViT)
3 | # Copyright (c) 2021-22, NVIDIA Corporation & affiliates. All Rights Reserved.
4 | # ------------------------------------------------------------------------------
5 |
6 | import argparse
7 | import os.path as osp
8 | import shutil
9 | from functools import partial
10 | from glob import glob
11 |
12 | from mmengine.utils import (mkdir_or_exist, track_parallel_progress,
13 | track_progress)
14 | import numpy as np
15 | from PIL import Image
16 |
17 | COCO_LEN = 123287
18 |
19 | clsID_to_trID = {
20 | 0: 0,
21 | 1: 1,
22 | 2: 2,
23 | 3: 3,
24 | 4: 4,
25 | 5: 5,
26 | 6: 6,
27 | 7: 7,
28 | 8: 8,
29 | 9: 9,
30 | 10: 10,
31 | 12: 11,
32 | 13: 12,
33 | 14: 13,
34 | 15: 14,
35 | 16: 15,
36 | 17: 16,
37 | 18: 17,
38 | 19: 18,
39 | 20: 19,
40 | 21: 20,
41 | 22: 21,
42 | 23: 22,
43 | 24: 23,
44 | 26: 24,
45 | 27: 25,
46 | 30: 26,
47 | 31: 27,
48 | 32: 28,
49 | 33: 29,
50 | 34: 30,
51 | 35: 31,
52 | 36: 32,
53 | 37: 33,
54 | 38: 34,
55 | 39: 35,
56 | 40: 36,
57 | 41: 37,
58 | 42: 38,
59 | 43: 39,
60 | 45: 40,
61 | 46: 41,
62 | 47: 42,
63 | 48: 43,
64 | 49: 44,
65 | 50: 45,
66 | 51: 46,
67 | 52: 47,
68 | 53: 48,
69 | 54: 49,
70 | 55: 50,
71 | 56: 51,
72 | 57: 52,
73 | 58: 53,
74 | 59: 54,
75 | 60: 55,
76 | 61: 56,
77 | 62: 57,
78 | 63: 58,
79 | 64: 59,
80 | 66: 60,
81 | 69: 61,
82 | 71: 62,
83 | 72: 63,
84 | 73: 64,
85 | 74: 65,
86 | 75: 66,
87 | 76: 67,
88 | 77: 68,
89 | 78: 69,
90 | 79: 70,
91 | 80: 71,
92 | 81: 72,
93 | 83: 73,
94 | 84: 74,
95 | 85: 75,
96 | 86: 76,
97 | 87: 77,
98 | 88: 78,
99 | 89: 79,
100 | 91: 80,
101 | 92: 81,
102 | 93: 82,
103 | 94: 83,
104 | 95: 84,
105 | 96: 85,
106 | 97: 86,
107 | 98: 87,
108 | 99: 88,
109 | 100: 89,
110 | 101: 90,
111 | 102: 91,
112 | 103: 92,
113 | 104: 93,
114 | 105: 94,
115 | 106: 95,
116 | 107: 96,
117 | 108: 97,
118 | 109: 98,
119 | 110: 99,
120 | 111: 100,
121 | 112: 101,
122 | 113: 102,
123 | 114: 103,
124 | 115: 104,
125 | 116: 105,
126 | 117: 106,
127 | 118: 107,
128 | 119: 108,
129 | 120: 109,
130 | 121: 110,
131 | 122: 111,
132 | 123: 112,
133 | 124: 113,
134 | 125: 114,
135 | 126: 115,
136 | 127: 116,
137 | 128: 117,
138 | 129: 118,
139 | 130: 119,
140 | 131: 120,
141 | 132: 121,
142 | 133: 122,
143 | 134: 123,
144 | 135: 124,
145 | 136: 125,
146 | 137: 126,
147 | 138: 127,
148 | 139: 128,
149 | 140: 129,
150 | 141: 130,
151 | 142: 131,
152 | 143: 132,
153 | 144: 133,
154 | 145: 134,
155 | 146: 135,
156 | 147: 136,
157 | 148: 137,
158 | 149: 138,
159 | 150: 139,
160 | 151: 140,
161 | 152: 141,
162 | 153: 142,
163 | 154: 143,
164 | 155: 144,
165 | 156: 145,
166 | 157: 146,
167 | 158: 147,
168 | 159: 148,
169 | 160: 149,
170 | 161: 150,
171 | 162: 151,
172 | 163: 152,
173 | 164: 153,
174 | 165: 154,
175 | 166: 155,
176 | 167: 156,
177 | 168: 157,
178 | 169: 158,
179 | 170: 159,
180 | 171: 160,
181 | 172: 161,
182 | 173: 162,
183 | 174: 163,
184 | 175: 164,
185 | 176: 165,
186 | 177: 166,
187 | 178: 167,
188 | 179: 168,
189 | 180: 169,
190 | 181: 170,
191 | 255: 255
192 | }
193 |
194 | # shift train IDs by 1 so that 0 becomes background; ids > 90 (stuff classes and the ignore label) are mapped to background
195 | for k, v in clsID_to_trID.items():
196 | clsID_to_trID[k] = v + 1
197 | if k > 90:
198 | clsID_to_trID[k] = 0
199 |
200 |
201 | def convert_to_trainID(maskpath, out_mask_dir, is_train):
202 | mask = np.array(Image.open(maskpath))
203 | mask_copy = mask.copy()
204 | for clsID, trID in clsID_to_trID.items():
205 | mask_copy[mask == clsID] = trID
206 | seg_filename = osp.join(
207 | out_mask_dir, 'train2017',
208 | osp.basename(maskpath).split('.')[0] +
209 | '_instanceTrainIds.png') if is_train else osp.join(
210 | out_mask_dir, 'val2017',
211 | osp.basename(maskpath).split('.')[0] + '_instanceTrainIds.png')
212 | Image.fromarray(mask_copy).save(seg_filename, 'PNG')
213 |
214 |
215 | def parse_args():
216 | parser = argparse.ArgumentParser(
217 | description=\
218 | 'Convert COCO Stuff 164k annotations to COCO Objects') # noqa
219 | parser.add_argument('coco_path', help='coco stuff path')
220 | parser.add_argument('-o', '--out_dir', help='output path')
221 | parser.add_argument(
222 |         '--nproc', default=16, type=int, help='number of processes')
223 | args = parser.parse_args()
224 | return args
225 |
226 |
227 | def main():
228 | args = parse_args()
229 | coco_path = args.coco_path
230 | nproc = args.nproc
231 |
232 | out_dir = args.out_dir or coco_path
233 | out_img_dir = osp.join(out_dir, 'images')
234 | out_mask_dir = osp.join(out_dir, 'annotations')
235 |
236 | mkdir_or_exist(osp.join(out_mask_dir, 'train2017'))
237 | mkdir_or_exist(osp.join(out_mask_dir, 'val2017'))
238 |
239 | if out_dir != coco_path:
240 | shutil.copytree(osp.join(coco_path, 'images'), out_img_dir)
241 |
242 | #train_list = glob(osp.join(coco_path, 'annotations', 'train2017', '*.png'))
243 | #train_list = [file for file in train_list if 'TrainIds' not in file]
244 | test_list = glob(osp.join(coco_path, 'annotations', 'val2017', '*.png'))
245 | test_list = [file for file in test_list if 'TrainIds' not in file]
246 |
247 | #assert (len(train_list) +
248 | # len(test_list)) == COCO_LEN, 'Wrong length of list {} & {}'.format(
249 | # len(train_list), len(test_list))
250 |
251 | if args.nproc > 1:
252 | """
253 | track_parallel_progress(
254 | partial(
255 | convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True),
256 | train_list,
257 | nproc=nproc)
258 | """
259 | track_parallel_progress(
260 | partial(
261 | convert_to_trainID, out_mask_dir=out_mask_dir, is_train=False),
262 | test_list,
263 | nproc=nproc)
264 | else:
265 | """
266 | track_progress(
267 | partial(
268 | convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True),
269 | train_list)
270 | """
271 | track_progress(
272 | partial(
273 | convert_to_trainID, out_mask_dir=out_mask_dir, is_train=False),
274 | test_list)
275 |
276 | print('Done!')
277 |
278 |
279 | if __name__ == '__main__':
280 | main()
--------------------------------------------------------------------------------
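Note: the script is invoked as, e.g., python convert_coco_object.py /path/to/coco_stuff164k --nproc 8, and only converts the val2017 split (the train2017 conversion is commented out). A toy check of the remapping performed in convert_to_trainID, using a hand-picked subset of clsID_to_trID with the values it holds after the background adjustment above (the array and the subset are illustrative):

import numpy as np

# after the adjustment loop: thing ids shift by +1, while stuff ids (> 90) and the ignore label collapse to 0
clsID_to_trID_subset = {0: 1, 1: 2, 120: 0, 255: 0}

mask = np.array([[0, 1],
                 [120, 255]], dtype=np.uint8)
mask_copy = mask.copy()
for clsID, trID in clsID_to_trID_subset.items():
    mask_copy[mask == clsID] = trID
print(mask_copy)  # [[1 2]
                  #  [0 0]]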
/src/training/zero_shot.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import torch
4 | from tqdm import tqdm
5 |
6 | from open_clip import get_input_dtype, get_tokenizer, build_zero_shot_classifier, \
7 | IMAGENET_CLASSNAMES, OPENAI_IMAGENET_TEMPLATES
8 | from .precision import get_autocast
9 |
10 | import json
11 |
12 | def accuracy(output, target, topk=(1,)):
13 | pred = output.topk(max(topk), 1, True, True)[1].t()
14 | correct = pred.eq(target.view(1, -1).expand_as(pred))
15 | return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk]
16 |
17 | def run(model, classifier, dataloader, args):
18 | autocast = get_autocast(args.precision)
19 | input_dtype = get_input_dtype(args.precision)
20 |
21 | with torch.no_grad():
22 | top1, top5, n = 0., 0., 0.
23 | for images, target in tqdm(dataloader, unit_scale=args.batch_size):
24 | images = images.to(device=args.device, dtype=input_dtype)
25 | target = target.to(args.device)
26 |
27 | with autocast():
28 | # predict
29 | output = model(image=images)
30 | image_features = output['image_features'] if isinstance(output, dict) else output
31 |                 # image_features: (B, 512); classifier: (512, num_classes), e.g. 1000 for ImageNet
32 | logits = 100. * image_features @ classifier
33 |
34 | # measure accuracy
35 | acc1, acc5 = accuracy(logits, target, topk=(1, 5))
36 | top1 += acc1
37 | top5 += acc5
38 | n += images.size(0)
39 |
40 | top1 = (top1 / n)
41 | top5 = (top5 / n)
42 | return top1, top5
43 |
44 | def zero_shot_eval(model, data, epoch, args, tokenizer=None):
45 | if 'imagenet-val' not in data and 'imagenet-v2' not in data:
46 | return {}
47 | if args.zeroshot_frequency == 0:
48 | return {}
49 | if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs:
50 | return {}
51 | if args.distributed and not args.horovod:
52 | model = model.module
53 |
54 | logging.info('Starting zero-shot imagenet.')
55 | if tokenizer is None:
56 | tokenizer = get_tokenizer(args.model)
57 |
58 | logging.info('Building zero-shot classifier')
59 | autocast = get_autocast(args.precision)
60 | with autocast():
61 | classifier = build_zero_shot_classifier(
62 | model,
63 | tokenizer=tokenizer,
64 | args=args,
65 | classnames=IMAGENET_CLASSNAMES,
66 | templates=OPENAI_IMAGENET_TEMPLATES,
67 | num_classes_per_batch=10,
68 | device=args.device,
69 | use_tqdm=True,
70 | )
71 |
72 | logging.info('Using classifier')
73 | results = {}
74 | if 'imagenet-val' in data:
75 | top1, top5 = run(model, classifier, data['imagenet-val'].dataloader, args)
76 | results['imagenet-zeroshot-val-top1'] = top1
77 | results['imagenet-zeroshot-val-top5'] = top5
78 | if 'imagenet-v2' in data:
79 | top1, top5 = run(model, classifier, data['imagenet-v2'].dataloader, args)
80 | results['imagenetv2-zeroshot-val-top1'] = top1
81 | results['imagenetv2-zeroshot-val-top5'] = top5
82 |
83 | logging.info('Finished zero-shot imagenet.')
84 |
85 | return results
86 |
87 | def zero_shot_classification_eval(model, data_name, dataloader, dataset_labels, dataset_templates, epoch, args, tokenizer=None):
88 | if args.zeroshot_frequency == 0:
89 | return {}
90 | if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs:
91 | return {}
92 | if args.distributed and not args.horovod:
93 | model = model.module
94 |
95 | logging.info(f'Starting zero-shot {data_name}.')
96 | if tokenizer is None:
97 | tokenizer = get_tokenizer(args.model)
98 |
99 | logging.info('Building zero-shot classifier')
100 | autocast = get_autocast(args.precision)
101 | with autocast():
102 | classifier = build_zero_shot_classifier(
103 | model,
104 | tokenizer=tokenizer,
105 | args=args,
106 | classnames=dataset_labels[data_name],
107 | templates=dataset_templates[data_name],
108 | num_classes_per_batch=10,
109 | device=args.device,
110 | use_tqdm=True,
111 | )
112 |
113 | logging.info('Using classifier')
114 | results = {}
115 | top1, top5 = run(model, classifier, dataloader, args)
116 | results[f'{data_name}-zeroshot-val-top1'] = top1
117 | results[f'{data_name}-zeroshot-val-top5'] = top5
118 |
119 | logging.info(f'Finished zero-shot {data_name}.')
120 |
121 | return results
--------------------------------------------------------------------------------
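Note: a toy check of accuracy() above. It returns raw correct-prediction counts rather than rates, which run() then divides by the number of processed samples n (the tensors below are illustrative):

import torch

# reusing accuracy() defined in zero_shot.py above on a batch of 2 samples and 3 classes
logits = torch.tensor([[0.1, 0.9, 0.0],    # top-1 prediction: class 1 (correct)
                       [0.8, 0.15, 0.05]]) # top-1 prediction: class 0 (wrong, target is 2)
targets = torch.tensor([1, 2])

top1, top2 = accuracy(logits, targets, topk=(1, 2))
print(top1, top2)  # 1.0 1.0 -> only the first sample is ranked correctly, even within the top-2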