├── LICENSE ├── Makefile ├── README.md ├── dataloaders ├── __init__.py ├── caltech101.py ├── cifar10.py ├── cifar100.py ├── coco.py ├── dtd.py ├── fgvc_aircraft.py ├── flickr30k.py ├── flowers102.py ├── food101.py ├── oxford_pets.py ├── stanford_car.py ├── sun397.py └── utils.py ├── dist_train.sh ├── eval_zs.sh ├── figures ├── moti.png ├── radar.jpg └── radar.png ├── main.py ├── open_clip ├── __init__.py ├── big_vision.py ├── bpe_simple_vocab_16e6.txt.gz ├── coca_model.py ├── constants.py ├── factory.py ├── generation_utils.py ├── hf_configs.py ├── hf_model.py ├── loss.py ├── model.py ├── model_configs │ ├── EVA01-g-14-plus.json │ ├── EVA01-g-14.json │ ├── EVA02-B-16.json │ ├── EVA02-E-14-plus.json │ ├── EVA02-E-14.json │ ├── EVA02-L-14-336.json │ ├── EVA02-L-14.json │ ├── RN101-quickgelu.json │ ├── RN101.json │ ├── RN50-quickgelu.json │ ├── RN50.json │ ├── RN50x16.json │ ├── RN50x4.json │ ├── RN50x64.json │ ├── ViT-B-16-SigLIP-256.json │ ├── ViT-B-16-SigLIP-384.json │ ├── ViT-B-16-SigLIP-512.json │ ├── ViT-B-16-SigLIP-i18n-256.json │ ├── ViT-B-16-SigLIP.json │ ├── ViT-B-16-plus-240.json │ ├── ViT-B-16-plus.json │ ├── ViT-B-16-quickgelu.json │ ├── ViT-B-16.json │ ├── ViT-B-32-256.json │ ├── ViT-B-32-plus-256.json │ ├── ViT-B-32-quickgelu.json │ ├── ViT-B-32.json │ ├── ViT-H-14-378-quickgelu.json │ ├── ViT-H-14-CLIPA-336.json │ ├── ViT-H-14-CLIPA.json │ ├── ViT-H-14-quickgelu.json │ ├── ViT-H-14.json │ ├── ViT-H-16.json │ ├── ViT-L-14-280.json │ ├── ViT-L-14-336.json │ ├── ViT-L-14-CLIPA-336.json │ ├── ViT-L-14-CLIPA.json │ ├── ViT-L-14-quickgelu.json │ ├── ViT-L-14.json │ ├── ViT-L-16-320.json │ ├── ViT-L-16-SigLIP-256.json │ ├── ViT-L-16-SigLIP-384.json │ ├── ViT-L-16.json │ ├── ViT-M-16-alt.json │ ├── ViT-M-16.json │ ├── ViT-M-32-alt.json │ ├── ViT-M-32.json │ ├── ViT-S-16-alt.json │ ├── ViT-S-16.json │ ├── ViT-S-32-alt.json │ ├── ViT-S-32.json │ ├── ViT-SO400M-14-SigLIP-384.json │ ├── ViT-SO400M-14-SigLIP.json │ ├── ViT-bigG-14-CLIPA-336.json │ ├── ViT-bigG-14-CLIPA.json │ ├── ViT-bigG-14.json │ ├── ViT-e-14.json │ ├── ViT-g-14.json │ ├── coca_ViT-B-32.json │ ├── coca_ViT-L-14.json │ ├── coca_base.json │ ├── coca_roberta-ViT-B-32.json │ ├── convnext_base.json │ ├── convnext_base_w.json │ ├── convnext_base_w_320.json │ ├── convnext_large.json │ ├── convnext_large_d.json │ ├── convnext_large_d_320.json │ ├── convnext_small.json │ ├── convnext_tiny.json │ ├── convnext_xlarge.json │ ├── convnext_xxlarge.json │ ├── convnext_xxlarge_320.json │ ├── mt5-base-ViT-B-32.json │ ├── mt5-xl-ViT-H-14.json │ ├── nllb-clip-base.json │ ├── nllb-clip-large.json │ ├── roberta-ViT-B-16.json │ ├── roberta-ViT-B-32.json │ ├── swin_base_patch4_window7_224.json │ ├── vit_medium_patch16_gap_256.json │ ├── vit_relpos_medium_patch16_cls_224.json │ ├── xlm-roberta-base-ViT-B-32.json │ └── xlm-roberta-large-ViT-H-14.json ├── modified_resnet.py ├── openai.py ├── pos_embed.py ├── pretrained.py ├── push_to_hf_hub.py ├── timm_model.py ├── tokenizer.py ├── transform.py ├── transformer.py ├── utils.py ├── version.py ├── zero_shot_classifier.py └── zero_shot_metadata.py ├── requirements-test.txt ├── requirements-training.txt ├── requirements.txt ├── setup.py └── training ├── __init__.py ├── data.py ├── distributed.py ├── file_utils.py ├── logger.py ├── params.py ├── precision.py ├── profiler.py ├── random_aug.py ├── scheduler.py ├── train.py └── zero_shot.py /Makefile: -------------------------------------------------------------------------------- 1 | install: ## [Local development] Upgrade pip, install requirements, 
install package. 2 | python -m pip install -U pip 3 | python -m pip install -e . 4 | 5 | install-training: 6 | python -m pip install -r requirements-training.txt 7 | 8 | install-test: ## [Local development] Install test requirements 9 | python -m pip install -r requirements-test.txt 10 | 11 | test: ## [Local development] Run unit tests 12 | python -m pytest -x -s -v tests 13 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DreamLIP: Language-Image Pre-training with Long Captions 2 | 3 | 4 | > **DreamLIP: Language-Image Pre-training with Long Captions**
5 | Kecheng Zheng, 6 | Yifei Zhang, 7 | Wei Wu, 8 | Fan Lu, 9 | Shuailei Ma, 10 | Xin Jin, 11 | Wei Chen, 12 | Yujun Shen
13 | [Project Page](https://zyf0619sjtu.github.io/dream-lip/) | [Paper](https://arxiv.org/pdf/2403.17007.pdf) 14 | 15 | 16 | ## 📰 News 17 | 18 | [//]: # (- [2024/11/15] Long captions (LLAVA1.5, InstructBLIP and shareGPT4V) of Laion50M and Coyo20M that are used in [NeurIPS'24 LotLIP](https://github.com/wuw2019/LoTLIP) are released in google drive~) 19 | - [2024/11/26] Long captions (LLAVA1.5, InstructBLIP and ShareGPT4V) of COYO24M/LAION49M are released on Hugging Face. 20 | - [2024/08/26] Long captions (LLAVA1.5, InstructBLIP and ShareGPT4V) of CC3M/CC12M/YFCC15M are released on Hugging Face. 21 | - [2024/07/16] Released the pretrained ViT-B/16 weights trained on CC3M, CC12M, YFCC15M, and merged-30M (long captions from ShareGPT4V)! 22 | - [2024/07/08] DreamLIP has been accepted to ECCV 2024! 23 | 24 | ## 💡 Highlights 25 | - 🔥 Exploring how language-image pre-training could benefit from long captions. 26 | - 🔥 Strong improvements on semantic segmentation, image-text retrieval, and image understanding in MLLMs. 27 | 28 | 29 | 30 | - 🔥 DreamLIP trained with 30M image-text pairs achieves performance on par with, or even better than, CLIP trained with 400M pairs. 31 | ![timeline.jpg](figures/moti.png) 32 | 33 | ## 🎨 In-Progress 34 | 35 | - [x] Release long captions of CC3M, CC12M, YFCC15M, COYO24M and LAION49M. 36 | - [ ] Release training code. 37 | 38 | ## 🏝️ Overview of supported long captions: 39 | 40 |
41 | Long Captions of Supported Datasets (5) 42 | 43 | > - [x] [![](https://img.shields.io/badge/CC3M-red?style=for-the-badge)](https://ai.google.com/research/ConceptualCaptions/) 44 | > - [x] [![](https://img.shields.io/badge/CC12M-d0e9ff?style=for-the-badge)](https://github.com/google-research-datasets/conceptual-12m) 45 | > - [x] [![](https://img.shields.io/badge/YFCC15M-yellowgreen?style=for-the-badge)](https://github.com/Sense-GVT/DeCLIP/blob/main/docs/dataset_prepare.md) 46 | > - [x] [![](https://img.shields.io/badge/Laion50M-grey?style=for-the-badge)](https://laion.ai/laion-5b-a-new-era-of-open-large-scale-multi-modal-datasets/) 47 | > - [x] [![](https://img.shields.io/badge/Coyo20M-854?style=for-the-badge)](https://github.com/kakaobrain/coyo-dataset) 48 |
49 |
50 | Long Captions of MLLMs (3) 51 | 52 | > - [x] ![](https://img.shields.io/badge/InstructBLIP-blue?style=for-the-badge) 53 | > - [x] ![](https://img.shields.io/badge/LLAVA1.5-green?style=for-the-badge) 54 | > - [x] ![](https://img.shields.io/badge/SHAREGPT4V-orange?style=for-the-badge) 55 | 56 |
57 | 58 | [//]: # (## Acknowledgement) 59 | 60 | 61 | #### Generated Long Captions 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 |
Dataset | Huggingface Dataset
CC3M | Raw/Long/Short Caption
CC12M | Raw/Long/Short Caption
YFCC15M | Raw/Long/Short Caption
Laion49M | Long Caption
COYO24M | Long Caption
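The caption releases above are hosted as Hugging Face datasets. The snippet below is a minimal sketch of pulling one of them with the `datasets` library; the repository id and field names are hypothetical placeholders, so substitute the actual repository linked in the table.

```python
# Minimal sketch (not the official loader): fetch one of the released
# long-caption sets from the Hugging Face Hub. The repo id and field names
# below are placeholders; use the links in the table above.
from datasets import load_dataset

captions = load_dataset("some-org/dreamlip-cc3m-long-captions", split="train")  # hypothetical repo id
example = captions[0]
# Each record is expected to pair an image reference with its raw, short,
# and long (MLLM-generated) captions; inspect the schema for the exact fields.
print(example.keys())
```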
91 | 92 | ## Pretrained checkpoints 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 |
Dataset | Model | ShareGPT4V | InstructBLIP + LLAVA1.5 + ShareGPT4V
CC3M | ViT-B/16 | Link | TODO
CC12M | ViT-B/16 | Link | TODO
YFCC15M | ViT-B/16 | Link | TODO
CC30M | ViT-B/16 | Link | TODO
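As a quick reference for using these checkpoints, here is a minimal zero-shot classification sketch built on the bundled `open_clip` package; the checkpoint path, image file, and prompts are placeholders, and the full evaluation pipeline is driven by `eval_zs.sh` below.

```python
# Minimal zero-shot classification sketch using the bundled open_clip package.
# The checkpoint path, image path, and prompts are placeholders.
import torch
from PIL import Image
import open_clip

model, _, preprocess = open_clip.create_model_and_transforms(
    'ViT-B-16', pretrained='/path/to/dreamlip_vitb16.pt')
tokenizer = open_clip.get_tokenizer('ViT-B-16')
model.eval()

image = preprocess(Image.open('example.jpg')).unsqueeze(0)
text = tokenizer(['a photo of a dog', 'a photo of a cat'])

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

print(probs)  # similarity of the image to each prompt
```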
127 | 128 | ## 📣 Instructions 129 | Environment installation 130 | ``` 131 | pip install -r requirements.txt 132 | ``` 133 | 134 | Evaluate zero-shot classification 135 | ``` 136 | bash eval_zs.sh 137 | ``` 138 | 139 | [//]: # (You can download checkpoints pre-trained ) 140 | 141 | ## License 142 | 143 | The project is under a standard Creative Commons [CC-BY-4.0 License](./LICENSE). 144 | 145 | ## 📖 Citation 146 | 147 | We open-source this library to the community to facilitate research. If you like our work and use this codebase in your projects, please cite our work as follows. 148 | 149 | ```bibtex 150 | @inproceedings{DreamLIP, 151 | title={DreamLIP: Language-Image Pre-training with Long Captions}, 152 | author={Zheng, Kecheng and Zhang, Yifei and Wu, Wei and Lu, Fan and Ma, Shuailei and Jin, Xin and Chen, Wei and Shen, Yujun}, 153 | booktitle={ECCV}, 154 | year={2024} 155 | } 156 | ``` 157 | 158 | ### Acknowledgements 159 | This project is based on [open_clip](https://github.com/mlfoundations/open_clip/tree/main); thanks for the nice work! 160 | We also thank [InstructBLIP](https://github.com/salesforce/LAVIS), [ShareGPT4V](https://github.com/InternLM/InternLM-XComposer) and [LLAVA](https://github.com/haotian-liu/LLaVA) for their pretrained models and code. 161 | 162 | 163 | 164 | -------------------------------------------------------------------------------- /dataloaders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ant-research/DreamLIP/b4de0a43b6c002033c02873f91a695ab449e464c/dataloaders/__init__.py -------------------------------------------------------------------------------- /dataloaders/caltech101.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | import torch 3 | from torch.utils.data import DataLoader, Dataset, random_split 4 | from torchvision.datasets import Caltech101 5 | from .utils import dataset_root 6 | 7 | 8 | root = dataset_root 9 | num_example_train = 3000 10 | num_example_test = 5677 11 | num_classes = 101 12 | mean_per_class = True 13 | 14 | def get_loader_train( 15 | transform, batch_size, num_workers, seed 16 | ) -> Tuple[Dataset, DataLoader]: 17 | dataset = Caltech101(root, download=False, transform=transform) 18 | dataset_train, dataset_test = random_split( 19 | dataset, 20 | lengths=[num_example_train, 21 | num_example_test], 22 | generator=torch.Generator().manual_seed(seed)) 23 | return (dataset_train, None) 24 | 25 | 26 | def get_loader_test( 27 | transform, batch_size, num_workers, seed 28 | ) -> Tuple[Dataset, DataLoader]: 29 | dataset = Caltech101(root, download=False, transform=transform) 30 | dataset_train, dataset_test = random_split( 31 | dataset, 32 | lengths=[num_example_train, 33 | num_example_test], 34 | generator=torch.Generator().manual_seed(seed)) 35 | return (dataset_test, None) 36 | -------------------------------------------------------------------------------- /dataloaders/cifar10.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | from torch.utils.data import DataLoader, Dataset 3 | from torchvision.datasets import CIFAR10 4 | from .utils import dataset_root 5 | 6 | 7 | root = dataset_root + '/cifar/' 8 | num_example_train = 50000 9 | num_example_test = 10000 10 | num_classes = 10 11 | 12 | 13 | def get_loader_train( 14 | transform, batch_size, num_workers, seed 15 | ) -> Tuple[Dataset, DataLoader]: 16 |
dataset_train = CIFAR10(root, download=False, train=True, transform=transform) 17 | return (dataset_train, None) 18 | 19 | 20 | def get_loader_test( 21 | transform, batch_size, num_workers, seed 22 | ) -> Tuple[Dataset, DataLoader]: 23 | dataset = CIFAR10(root, download=False, train=False, transform=transform) 24 | return (dataset, None) 25 | -------------------------------------------------------------------------------- /dataloaders/cifar100.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | from torch.utils.data import DataLoader, Dataset 3 | from torchvision.datasets import CIFAR100 4 | from .utils import dataset_root 5 | 6 | 7 | root = dataset_root + '/cifar/' 8 | num_example_train_val = 50000 9 | num_example_test = 10000 10 | num_classes = 100 11 | 12 | 13 | def get_loader_train( 14 | transform, batch_size, num_workers, seed 15 | ) -> Tuple[Dataset, DataLoader]: 16 | dataset_train_val = CIFAR100(root, download=False, train=True, transform=transform) 17 | return (dataset_train_val, None) 18 | 19 | 20 | def get_loader_test( 21 | transform, batch_size, num_workers, seed 22 | ) -> Tuple[Dataset, DataLoader]: 23 | dataset_test = CIFAR100(root, download=False, train=False, transform=transform) 24 | return (dataset_test, None) 25 | -------------------------------------------------------------------------------- /dataloaders/coco.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from typing import Any, Callable, Optional, Tuple, List 3 | 4 | from PIL import Image 5 | 6 | from torchvision.datasets.vision import VisionDataset 7 | 8 | from pcache_fileio import fileio 9 | 10 | 11 | class CocoDetection(VisionDataset): 12 | """`MS Coco Detection `_ Dataset. 13 | 14 | It requires the `COCO API to be installed `_. 15 | 16 | Args: 17 | root (string): Root directory where images are downloaded to. 18 | annFile (string): Path to json annotation file. 19 | transform (callable, optional): A function/transform that takes in an PIL image 20 | and returns a transformed version. E.g, ``transforms.ToTensor`` 21 | target_transform (callable, optional): A function/transform that takes in the 22 | target and transforms it. 23 | transforms (callable, optional): A function/transform that takes input sample and its target as entry 24 | and returns a transformed version. 
25 | """ 26 | 27 | def __init__( 28 | self, 29 | root: str, 30 | annFile: str, 31 | transform: Optional[Callable] = None, 32 | target_transform: Optional[Callable] = None, 33 | transforms: Optional[Callable] = None, 34 | ) -> None: 35 | super().__init__(root, transforms, transform, target_transform) 36 | from pycocotools.coco import COCO 37 | 38 | self.coco = COCO(annFile) 39 | self.ids = list(sorted(self.coco.imgs.keys())) 40 | 41 | def _load_image(self, id: int) -> Image.Image: 42 | path = self.coco.loadImgs(id)[0]["file_name"] 43 | return Image.open(os.path.join(self.root, path)).convert("RGB") 44 | 45 | def _load_target(self, id: int) -> List[Any]: 46 | return self.coco.loadAnns(self.coco.getAnnIds(id)) 47 | 48 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 49 | id = self.ids[index] 50 | image = self._load_image(id) 51 | target = self._load_target(id) 52 | 53 | if self.transforms is not None: 54 | image, target = self.transforms(image, target) 55 | 56 | return image, target 57 | 58 | def __len__(self) -> int: 59 | return len(self.ids) 60 | 61 | 62 | class CocoCaptions(CocoDetection): 63 | """`MS Coco Captions `_ Dataset. 64 | 65 | It requires the `COCO API to be installed `_. 66 | 67 | Args: 68 | root (string): Root directory where images are downloaded to. 69 | annFile (string): Path to json annotation file. 70 | transform (callable, optional): A function/transform that takes in an PIL image 71 | and returns a transformed version. E.g, ``transforms.ToTensor`` 72 | target_transform (callable, optional): A function/transform that takes in the 73 | target and transforms it. 74 | transforms (callable, optional): A function/transform that takes input sample and its target as entry 75 | and returns a transformed version. 76 | 77 | Example: 78 | 79 | .. 
code:: python 80 | 81 | import torchvision.datasets as dset 82 | import torchvision.transforms as transforms 83 | cap = dset.CocoCaptions(root = 'dir where images are', 84 | annFile = 'json annotation file', 85 | transform=transforms.ToTensor()) 86 | 87 | print('Number of samples: ', len(cap)) 88 | img, target = cap[3] # load 4th sample 89 | 90 | print("Image Size: ", img.size()) 91 | print(target) 92 | 93 | Output: :: 94 | 95 | Number of samples: 82783 96 | Image Size: (3L, 427L, 640L) 97 | [u'A plane emitting smoke stream flying over a mountain.', 98 | u'A plane darts across a bright blue sky behind a mountain covered in snow', 99 | u'A plane leaves a contrail above the snowy mountain top.', 100 | u'A mountain that has a plane flying overheard in the distance.', 101 | u'A mountain view with a plume of smoke in the background'] 102 | 103 | """ 104 | 105 | def _load_target(self, id: int) -> List[str]: 106 | return [ann["caption"] for ann in super()._load_target(id)] 107 | -------------------------------------------------------------------------------- /dataloaders/dtd.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | from torch.utils.data import DataLoader, Dataset, ConcatDataset 3 | from torchvision.datasets import DTD 4 | import torchvision 5 | from .utils import dataset_root 6 | 7 | root = dataset_root 8 | num_example_train = 3760 9 | num_example_test = 1880 10 | num_classes = 47 11 | 12 | 13 | class Warper(Dataset): 14 | def __init__(self, dataset) -> None: 15 | super().__init__() 16 | self.dataset = dataset 17 | 18 | def __len__(self) -> int: 19 | return len(self.dataset) 20 | 21 | def __getitem__(self, index): 22 | img, label = self.dataset[index] 23 | if torchvision.__version__ >= "0.13.0": 24 | return img, label 25 | else: 26 | return img, label - 1 27 | 28 | 29 | def get_loader_train( 30 | transform, batch_size, num_workers, seed 31 | ) -> Tuple[Dataset, DataLoader]: 32 | dataset_train = DTD(root, download=False, split='train', transform=transform) 33 | dataset_val = DTD(root, download=False, split='val', transform=transform) 34 | dataset_train = ConcatDataset([dataset_train, dataset_val]) 35 | dataset_train = Warper(dataset_train) 36 | return (dataset_train, None) 37 | 38 | 39 | def get_loader_test( 40 | transform, batch_size, num_workers, seed 41 | ) -> Tuple[Dataset, DataLoader]: 42 | dataset_test = DTD(root, download=False, split='test', transform=transform) 43 | return (dataset_test, None) 44 | -------------------------------------------------------------------------------- /dataloaders/fgvc_aircraft.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | from torch.utils.data import DataLoader, Dataset 3 | from torchvision.datasets import FGVCAircraft 4 | from .utils import dataset_root 5 | 6 | 7 | root = dataset_root 8 | num_example_train_val = 6667 9 | num_example_test = 3333 10 | num_classes = 100 11 | mean_per_class = True 12 | 13 | 14 | def get_loader_train( 15 | transform, batch_size, num_workers, seed 16 | ) -> Tuple[Dataset, DataLoader]: 17 | dataset_train_val = FGVCAircraft(root, download=True, annotation_level='variant', split='trainval',transform=transform) 18 | return (dataset_train_val, None) 19 | 20 | 21 | def get_loader_test( 22 | transform, batch_size, num_workers, seed 23 | ) -> Tuple[Dataset, DataLoader]: 24 | dataset_test = FGVCAircraft(root, download=True, annotation_level='variant', split='test', transform=transform) 25 | 
return (dataset_test, None) 26 | 27 | -------------------------------------------------------------------------------- /dataloaders/flickr30k.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | from collections import defaultdict 4 | from html.parser import HTMLParser 5 | from typing import Any, Callable, Dict, List, Optional, Tuple 6 | import json 7 | 8 | from PIL import Image 9 | 10 | from torchvision.datasets.vision import VisionDataset 11 | 12 | from pcache_fileio import fileio 13 | 14 | 15 | class Flickr8kParser(HTMLParser): 16 | """Parser for extracting captions from the Flickr8k dataset web page.""" 17 | 18 | def __init__(self, root: str) -> None: 19 | super().__init__() 20 | 21 | self.root = root 22 | 23 | # Data structure to store captions 24 | self.annotations: Dict[str, List[str]] = {} 25 | 26 | # State variables 27 | self.in_table = False 28 | self.current_tag: Optional[str] = None 29 | self.current_img: Optional[str] = None 30 | 31 | def handle_starttag(self, tag: str, attrs: List[Tuple[str, Optional[str]]]) -> None: 32 | self.current_tag = tag 33 | 34 | if tag == "table": 35 | self.in_table = True 36 | 37 | def handle_endtag(self, tag: str) -> None: 38 | self.current_tag = None 39 | 40 | if tag == "table": 41 | self.in_table = False 42 | 43 | def handle_data(self, data: str) -> None: 44 | if self.in_table: 45 | if data == "Image Not Found": 46 | self.current_img = None 47 | elif self.current_tag == "a": 48 | img_id = data.split("/")[-2] 49 | img_id = os.path.join(self.root, img_id + "_*.jpg") 50 | img_id = glob.glob(img_id)[0] 51 | self.current_img = img_id 52 | self.annotations[img_id] = [] 53 | elif self.current_tag == "li" and self.current_img: 54 | img_id = self.current_img 55 | self.annotations[img_id].append(data.strip()) 56 | 57 | 58 | class Flickr8k(VisionDataset): 59 | """`Flickr8k Entities `_ Dataset. 60 | 61 | Args: 62 | root (string): Root directory where images are downloaded to. 63 | ann_file (string): Path to annotation file. 64 | transform (callable, optional): A function/transform that takes in a PIL image 65 | and returns a transformed version. E.g, ``transforms.ToTensor`` 66 | target_transform (callable, optional): A function/transform that takes in the 67 | target and transforms it. 68 | """ 69 | 70 | def __init__( 71 | self, 72 | root: str, 73 | ann_file: str, 74 | transform: Optional[Callable] = None, 75 | target_transform: Optional[Callable] = None, 76 | ) -> None: 77 | super().__init__(root, transform=transform, target_transform=target_transform) 78 | self.ann_file = os.path.expanduser(ann_file) 79 | 80 | # Read annotations and store in a dict 81 | parser = Flickr8kParser(self.root) 82 | with open(self.ann_file) as fh: 83 | parser.feed(fh.read()) 84 | self.annotations = parser.annotations 85 | 86 | self.ids = list(sorted(self.annotations.keys())) 87 | 88 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 89 | """ 90 | Args: 91 | index (int): Index 92 | 93 | Returns: 94 | tuple: Tuple (image, target). target is a list of captions for the image. 
95 | """ 96 | img_id = self.ids[index] 97 | 98 | # Image 99 | img = Image.open(img_id).convert("RGB") 100 | if self.transform is not None: 101 | img = self.transform(img) 102 | 103 | # Captions 104 | target = self.annotations[img_id] 105 | if self.target_transform is not None: 106 | target = self.target_transform(target) 107 | 108 | return img, target 109 | 110 | def __len__(self) -> int: 111 | return len(self.ids) 112 | 113 | 114 | class Flickr30k(VisionDataset): 115 | """`Flickr30k Entities `_ Dataset. 116 | 117 | Args: 118 | root (string): Root directory where images are downloaded to. 119 | ann_file (string): Path to annotation file. 120 | transform (callable, optional): A function/transform that takes in a PIL image 121 | and returns a transformed version. E.g, ``transforms.ToTensor`` 122 | target_transform (callable, optional): A function/transform that takes in the 123 | target and transforms it. 124 | """ 125 | 126 | def __init__( 127 | self, 128 | root: str, 129 | ann_file: str, 130 | transform: Optional[Callable] = None, 131 | target_transform: Optional[Callable] = None, 132 | ) -> None: 133 | super().__init__(root, transform=transform, target_transform=target_transform) 134 | self.ann_file = os.path.expanduser(ann_file) 135 | self.ann_file = json.load(open(ann_file, 'r')) 136 | # Read annotations and store in a dict 137 | self.annotations = defaultdict(list) 138 | 139 | for line in self.ann_file: 140 | img_id, caption = line['image'], line['caption'] 141 | self.annotations[img_id].append(caption) 142 | 143 | self.ids = list(sorted(self.annotations.keys())) 144 | 145 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 146 | """ 147 | Args: 148 | index (int): Index 149 | 150 | Returns: 151 | tuple: Tuple (image, target). target is a list of captions for the image. 
152 | """ 153 | img_id = self.ids[index] 154 | 155 | # Image 156 | filename = os.path.join(self.root, img_id) 157 | img = Image.open(filename).convert("RGB") 158 | if self.transform is not None: 159 | img = self.transform(img) 160 | 161 | # Captions 162 | targets = self.annotations[img_id] 163 | if self.target_transform is not None: 164 | targets = [self.target_transform(target) for target in targets] 165 | 166 | return img, targets 167 | 168 | def __len__(self) -> int: 169 | return len(self.ids) 170 | -------------------------------------------------------------------------------- /dataloaders/flowers102.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | from torch.utils.data import DataLoader, Dataset, ConcatDataset 3 | from torchvision.datasets import Flowers102 4 | import torchvision 5 | from .utils import dataset_root 6 | 7 | root = dataset_root 8 | num_example_train_val = 2040 9 | num_example_test = 6149 10 | num_classes = 102 11 | mean_per_class = True 12 | 13 | 14 | class Warper(Dataset): 15 | def __init__(self, dataset) -> None: 16 | super().__init__() 17 | self.dataset = dataset 18 | 19 | def __len__(self) -> int: 20 | return len(self.dataset) 21 | 22 | def __getitem__(self, index): 23 | img, label = self.dataset[index] 24 | if torchvision.__version__ >= "0.13.0": 25 | return img, label 26 | else: 27 | return img, label - 1 28 | 29 | 30 | def get_loader_train( 31 | transform, batch_size, num_workers, seed 32 | ) -> Tuple[Dataset, DataLoader]: 33 | dataset_train = Flowers102(root, download=False, split='train', transform=transform) 34 | dataset_val = Flowers102(root, download=False, split='val', transform=transform) 35 | dataset_train = ConcatDataset([dataset_train, dataset_val]) 36 | dataset_train = Warper(dataset_train) 37 | return (dataset_train, None) 38 | 39 | 40 | def get_loader_test( 41 | transform, batch_size, num_workers, seed 42 | ) -> Tuple[Dataset, DataLoader]: 43 | dataset_test = Flowers102(root, download=False, split='test', transform=transform) 44 | dataset_test = Warper(dataset_test) 45 | return (dataset_test, None) 46 | 47 | -------------------------------------------------------------------------------- /dataloaders/food101.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | from torch.utils.data import DataLoader, Dataset 3 | from torchvision.datasets import Food101 4 | from .utils import dataset_root 5 | 6 | 7 | root = dataset_root 8 | num_example_train_val = 75750 9 | num_example_test = 25250 10 | num_classes = 101 11 | 12 | 13 | def get_loader_train( 14 | transform, batch_size, num_workers, seed 15 | ) -> Tuple[Dataset, DataLoader]: 16 | dataset_train_val = Food101(root, download=False, split='train', transform=transform) 17 | return (dataset_train_val, None) 18 | 19 | 20 | def get_loader_test( 21 | transform, batch_size, num_workers, seed 22 | ) -> Tuple[Dataset, DataLoader]: 23 | dataset_test = Food101(root, download=False, split='test', transform=transform) 24 | return (dataset_test, None) 25 | -------------------------------------------------------------------------------- /dataloaders/oxford_pets.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | from torch.utils.data import DataLoader, Dataset 3 | from torchvision.datasets import OxfordIIITPet 4 | from .utils import dataset_root 5 | 6 | 7 | root = dataset_root 8 | num_example_train_val = 3680 9 | 
num_example_test = 3699 10 | num_classes = 37 11 | mean_per_class = True 12 | 13 | 14 | def get_loader_train( 15 | transform, batch_size, num_workers, seed 16 | ) -> Tuple[Dataset, DataLoader]: 17 | dataset_train_val = OxfordIIITPet(root, download=True, split='trainval',transform=transform) 18 | return (dataset_train_val, None) 19 | 20 | 21 | def get_loader_test( 22 | transform, batch_size, num_workers, seed 23 | ) -> Tuple[Dataset, DataLoader]: 24 | dataset_test = OxfordIIITPet(root, download=True, split='test', transform=transform) 25 | return (dataset_test, None) 26 | 27 | -------------------------------------------------------------------------------- /dataloaders/stanford_car.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | from torch.utils.data import DataLoader, Dataset 3 | from torchvision.datasets import StanfordCars 4 | from .utils import dataset_root 5 | 6 | 7 | root = dataset_root 8 | num_example_train_val = 8144 9 | num_example_test = 8041 10 | num_classes = 196 11 | 12 | 13 | def get_loader_train( 14 | transform, batch_size, num_workers, seed 15 | ) -> Tuple[Dataset, DataLoader]: 16 | dataset_train = StanfordCars( 17 | root, download=False, split='train', transform=transform) 18 | return (dataset_train, None) 19 | 20 | 21 | def get_loader_test( 22 | transform, batch_size, num_workers, seed 23 | ) -> Tuple[Dataset, DataLoader]: 24 | dataset_test = StanfordCars( 25 | root, download=False, split='test', transform=transform) 26 | return (dataset_test, None) 27 | -------------------------------------------------------------------------------- /dataloaders/sun397.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | import torch 3 | from torch.utils.data import DataLoader, Dataset, random_split 4 | from torchvision.datasets import SUN397 5 | from .utils import dataset_root 6 | 7 | 8 | num_example_train = 19850 9 | num_example_test = 19850 10 | num_example_others = 69054 11 | num_classes = 397 12 | root = dataset_root + '/sun397/' 13 | 14 | 15 | def get_loader_train( 16 | transform, batch_size, num_workers, seed 17 | ) -> Tuple[Dataset, DataLoader]: 18 | dataset = SUN397(root, transform=transform) 19 | dataset_train, dataset_test, others = random_split( 20 | dataset, 21 | lengths=[num_example_train, 22 | num_example_test, 23 | num_example_others,], 24 | generator=torch.Generator().manual_seed(seed + hash("sun397") % 2048)) 25 | return (dataset_train, None) 26 | 27 | 28 | def get_loader_test( 29 | transform, batch_size, num_workers, seed 30 | ) -> Tuple[Dataset, DataLoader]: 31 | dataset = SUN397(root, transform=transform) 32 | dataset_train, dataset_test, others = random_split( 33 | dataset, 34 | lengths=[num_example_train, 35 | num_example_test, 36 | num_example_others,], 37 | generator=torch.Generator().manual_seed(seed + hash("sun397") % 2048)) 38 | return (dataset_test, None) 39 | -------------------------------------------------------------------------------- /dataloaders/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import numpy as np 4 | import torch 5 | from torch import distributed as dist 6 | from torch.utils.data import DistributedSampler as _DistributedSampler 7 | import os 8 | try: 9 | rank = int(os.environ["RANK"]) 10 | world_size = int(os.environ["WORLD_SIZE"]) 11 | except KeyError: 12 | rank = 0 13 | world_size = 1 14 | 15 | dataset_root = 
"/input_ssd/datasets/" 16 | 17 | def worker_init_fn(worker_id, num_workers, rank, seed): 18 | # The seed of each worker equals to 19 | # num_worker * rank + worker_id + user_seed 20 | worker_seed = num_workers * rank + worker_id + seed 21 | np.random.seed(worker_seed) 22 | random.seed(worker_seed) 23 | torch.manual_seed(worker_seed) 24 | 25 | 26 | def get_dist_info(): 27 | if dist.is_available() and dist.is_initialized(): 28 | rank = dist.get_rank() 29 | world_size = dist.get_world_size() 30 | else: 31 | rank = 0 32 | world_size = 1 33 | 34 | return rank, world_size 35 | 36 | 37 | def sync_random_seed(seed=None, device="cuda"): 38 | """Make sure different ranks share the same seed. 39 | All workers must call this function, otherwise it will deadlock. 40 | This method is generally used in `DistributedSampler`, 41 | because the seed should be identical across all processes 42 | in the distributed group. 43 | In distributed sampling, different ranks should sample non-overlapped 44 | data in the dataset. Therefore, this function is used to make sure that 45 | each rank shuffles the data indices in the same order based 46 | on the same seed. Then different ranks could use different indices 47 | to select non-overlapped data from the same data list. 48 | Args: 49 | seed (int, Optional): The seed. Default to None. 50 | device (str): The device where the seed will be put on. 51 | Default to 'cuda'. 52 | Returns: 53 | int: Seed to be used. 54 | """ 55 | if seed is None: 56 | seed = np.random.randint(2**31) 57 | assert isinstance(seed, int) 58 | 59 | rank, world_size = get_dist_info() 60 | 61 | if world_size == 1: 62 | return seed 63 | 64 | if rank == 0: 65 | random_num = torch.tensor(seed, dtype=torch.int32, device=device) 66 | else: 67 | random_num = torch.tensor(0, dtype=torch.int32, device=device) 68 | 69 | dist.broadcast(random_num, src=0) 70 | 71 | return random_num.item() 72 | 73 | 74 | class DistributedSampler(_DistributedSampler): 75 | def __init__( 76 | self, 77 | dataset, 78 | num_replicas=None, # world_size 79 | rank=None, # local_rank 80 | shuffle=True, 81 | seed=0, 82 | ): 83 | 84 | super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle) 85 | 86 | # In distributed sampling, different ranks should sample 87 | # non-overlapped data in the dataset. Therefore, this function 88 | # is used to make sure that each rank shuffles the data indices 89 | # in the same order based on the same seed. Then different ranks 90 | # could use different indices to select non-overlapped data from the 91 | # same data list. 92 | self.seed = sync_random_seed(seed) 93 | 94 | def __iter__(self): 95 | # deterministically shuffle based on epoch 96 | if self.shuffle: 97 | g = torch.Generator() 98 | # When :attr:`shuffle=True`, this ensures all replicas 99 | # use a different random ordering for each epoch. 100 | # Otherwise, the next iteration of this sampler will 101 | # yield the same ordering. 
102 | g.manual_seed(self.epoch + self.seed) 103 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 104 | else: 105 | indices = torch.arange(len(self.dataset)).tolist() 106 | 107 | # add extra samples to make it evenly divisible 108 | # in case that indices is shorter than half of total_size 109 | indices = (indices * math.ceil(self.total_size / len(indices)))[ 110 | : self.total_size 111 | ] 112 | assert len(indices) == self.total_size 113 | 114 | # subsample 115 | indices = indices[self.rank : self.total_size : self.num_replicas] 116 | assert len(indices) == self.num_samples 117 | 118 | return iter(indices) 119 | -------------------------------------------------------------------------------- /dist_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Detect `python3` command. 4 | # This workaround addresses a common issue: 5 | # `python` points to `python2`, which is deprecated. 6 | export PYTHONS 7 | export RVAL 8 | 9 | PYTHONS=$(compgen -c | grep "^python3$") 10 | 11 | # `$?` is a built-in variable in bash, which is the exit status of the most 12 | # recently-executed command; by convention, 0 means success and anything else 13 | # indicates failure. 14 | RVAL=$? 15 | 16 | if [[ $RVAL -eq 0 ]]; then # if `python3` exist 17 | PYTHON="python3" 18 | else 19 | PYTHON="python" 20 | fi 21 | 22 | # Help message. 23 | if [[ $# -lt 2 ]]; then 24 | echo "This script helps launch distributed training job on local machine." 25 | echo 26 | echo "Usage: $0 GPUS COMMAND [ARGS]" 27 | echo 28 | echo "Example: $0 8 ddpm [--help]" 29 | echo 30 | echo "To enable multi-node training, one can reuse this script" \ 31 | "by simply setting the following environment variables on each" \ 32 | "machine:" 33 | echo " MASTER_IP: The IP address of the master node." 34 | echo " MASTER_PORT: The port of the master node." 35 | echo " NODE_SIZE: Number of nodes (machines) used for training." 36 | echo " NODE_RANK: Node rank of the current machine." 37 | echo 38 | echo "NOTE: In multi-node training, \`GPUS\` refers to the number" \ 39 | "of GPUs on each local machine, or say, GPUs per node." 40 | echo 41 | echo "Example of using 16 GPUs on 2 machines (i.e., 8 GPUs each):" 42 | echo 43 | echo " On machine 0: MASTER_IP=node_0_ip MASTER_PORT=node_0_port" \ 44 | "NODE_SIZE=2 NODE_RANK=0 $0 8 ddpm [--help]" 45 | echo " On machine 1: MASTER_IP=node_0_ip MASTER_PORT=node_0_port" \ 46 | "NODE_SIZE=2 NODE_RANK=1 $0 8 ddpm [--help]" 47 | echo 48 | echo "Detailed instruction on available commands:" 49 | echo "--------------------------------------------------" 50 | ${PYTHON} ./main.py --help 51 | echo 52 | exit 0 53 | fi 54 | 55 | GPUS=$1 56 | COMMAND=$2 57 | 58 | # Help message for a particular command. 59 | if [[ $# -lt 3 || ${*: -1} == "--help" ]]; then 60 | echo "Detailed instruction on the arguments for command \`"${COMMAND}"\`:" 61 | echo "--------------------------------------------------" 62 | ${PYTHON} ./main.py ${COMMAND} --help 63 | echo 64 | exit 0 65 | fi 66 | 67 | # Switch memory allocator if available. 68 | # Search order: jemalloc.so -> tcmalloc.so. 69 | # According to https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html, 70 | # it can get better performance by reusing memory as much as possible than 71 | # default malloc function. 
72 | JEMALLOC=$(ldconfig -p | grep -i "libjemalloc.so$" | tr " " "\n" | grep "/" \ 73 | | head -n 1) 74 | TCMALLOC=$(ldconfig -p | grep -i "libtcmalloc.so.4$" | tr " " "\n" | grep "/" \ 75 | | head -n 1) 76 | if [ -n "$JEMALLOC" ]; then # if found the path to libjemalloc.so 77 | echo "Switch memory allocator to jemalloc." 78 | export LD_PRELOAD=$JEMALLOC:$LD_PRELOAD 79 | elif [ -n "$TCMALLOC" ]; then # if found the path to libtcmalloc.so.4 80 | echo "Switch memory allocator to tcmalloc." 81 | export LD_PRELOAD=$TCMALLOC:$LD_PRELOAD 82 | fi 83 | 84 | # Get an available port for launching distributed training. 85 | # Credit to https://superuser.com/a/1293762. 86 | export DEFAULT_FREE_PORT 87 | DEFAULT_FREE_PORT=$(comm -23 <(seq 49152 65535 | sort) \ 88 | <(ss -Htan | awk '{print $4}' | cut -d':' -f2 | sort -u) \ 89 | | shuf | head -n 1) 90 | 91 | MASTER_IP=${MASTER_IP:-127.0.0.1} 92 | MASTER_PORT=${MASTER_PORT:-$DEFAULT_FREE_PORT} 93 | NODE_SIZE=${NODE_SIZE:-1} 94 | NODE_RANK=${NODE_RANK:-0} 95 | 96 | ${PYTHON} -m torch.distributed.launch \ 97 | --master_addr=${MASTER_IP} \ 98 | --master_port=${MASTER_PORT} \ 99 | --nnodes=${NODE_SIZE} \ 100 | --node_rank=${NODE_RANK} \ 101 | --nproc_per_node=${GPUS} \ 102 | ./main.py \ 103 | ${COMMAND} \ 104 | ${@:3} \ 105 | || exit 1 # Stop the script when it finds exception threw by Python. 106 | -------------------------------------------------------------------------------- /eval_zs.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node 3 \ 2 | --rdzv_endpoint=$HOSTE_NODE_ADDR \ 3 | -m main \ 4 | --imagenet-val '/nas1/datasets/ImageNet1k/val' \ 5 | --logs ../eval/ \ 6 | --pretrained '/home/kecheng/ECCV2024/epoch_32.pt' \ 7 | --batch-size=16 \ 8 | --model "ViT-B-16" \ 9 | -------------------------------------------------------------------------------- /figures/moti.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ant-research/DreamLIP/b4de0a43b6c002033c02873f91a695ab449e464c/figures/moti.png -------------------------------------------------------------------------------- /figures/radar.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ant-research/DreamLIP/b4de0a43b6c002033c02873f91a695ab449e464c/figures/radar.jpg -------------------------------------------------------------------------------- /figures/radar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ant-research/DreamLIP/b4de0a43b6c002033c02873f91a695ab449e464c/figures/radar.png -------------------------------------------------------------------------------- /open_clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .coca_model import CoCa 2 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 3 | from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss 4 | from .factory import list_models, add_model_config, get_model_config, load_checkpoint 5 | from .loss import ClipLoss, DistillClipLoss, CoCaLoss 6 | from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \ 7 | convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype, get_input_dtype, \ 8 | get_model_tokenize_cfg, get_model_preprocess_cfg, set_model_preprocess_cfg 9 | from .openai import load_openai_model, 
list_openai_models 10 | from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \ 11 | get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained 12 | from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub 13 | from .tokenizer import SimpleTokenizer, tokenize, decode 14 | from .transform import image_transform, AugmentationCfg 15 | from .zero_shot_classifier import build_zero_shot_classifier, build_zero_shot_classifier_legacy 16 | from .zero_shot_metadata import OPENAI_IMAGENET_TEMPLATES, SIMPLE_IMAGENET_TEMPLATES, IMAGENET_CLASSNAMES 17 | -------------------------------------------------------------------------------- /open_clip/big_vision.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from .model import CustomTextCLIP 5 | from .transformer import TextTransformer, Transformer 6 | 7 | 8 | @torch.no_grad() 9 | def load_big_vision_weights(model: CustomTextCLIP, checkpoint_path: str): 10 | """ Load weights from .npz checkpoints for official Google big_vision image-text models 11 | 12 | Currently the SigLIP source models are supported and a CustomTextCLIP destination model 13 | w/ timm image encoder. 14 | """ 15 | from timm.layers import resample_patch_embed, resample_abs_pos_embed 16 | 17 | def _n2p(w, t=True): 18 | if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: 19 | w = w.flatten() 20 | if t: 21 | if w.ndim == 4: 22 | w = w.transpose([3, 2, 0, 1]) 23 | elif w.ndim == 3: 24 | w = w.transpose([2, 0, 1]) 25 | elif w.ndim == 2: 26 | w = w.transpose([1, 0]) 27 | return torch.from_numpy(w) 28 | 29 | w = np.load(checkpoint_path) 30 | interpolation = 'bilinear' 31 | antialias = False 32 | 33 | def _convert_timm_img(module, prefix): 34 | embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) 35 | if embed_conv_w.shape[-2:] != module.patch_embed.proj.weight.shape[-2:]: 36 | embed_conv_w = resample_patch_embed( 37 | embed_conv_w, 38 | module.patch_embed.proj.weight.shape[-2:], 39 | interpolation=interpolation, 40 | antialias=antialias, 41 | verbose=True, 42 | ) 43 | module.patch_embed.proj.weight.copy_(embed_conv_w) 44 | module.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) 45 | 46 | if module.cls_token is not None: 47 | module.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) 48 | 49 | pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False) 50 | if pos_embed_w.shape != module.pos_embed.shape: 51 | assert False, f'{pos_embed_w.shape}, {module.pos_embed.shape}' 52 | num_prefix_tokens = 0 if getattr(module, 'no_embed_class', False) else getattr(module, 'num_prefix_tokens', 1) 53 | pos_embed_w = resample_abs_pos_embed( # resize pos embedding when different size from pretrained weights 54 | pos_embed_w, 55 | new_size=module.patch_embed.grid_size, 56 | num_prefix_tokens=num_prefix_tokens, 57 | interpolation=interpolation, 58 | antialias=antialias, 59 | verbose=True, 60 | ) 61 | module.pos_embed.copy_(pos_embed_w) 62 | 63 | mha_sub, b_sub, ln1_sub = (0, 0, 1) 64 | for i, block in enumerate(module.blocks.children()): 65 | block_prefix = f'{prefix}Transformer/encoderblock_{i}/' 66 | mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/' 67 | block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) 68 | block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) 69 | block.attn.qkv.weight.copy_(torch.cat([ 70 | _n2p(w[f'{mha_prefix}{n}/kernel'], 
t=False).flatten(1).T for n in ('query', 'key', 'value')])) 71 | block.attn.qkv.bias.copy_(torch.cat([ 72 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) 73 | block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) 74 | block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) 75 | for r in range(2): 76 | getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel'])) 77 | getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias'])) 78 | block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale'])) 79 | block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias'])) 80 | 81 | module.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) 82 | module.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) 83 | 84 | if module.attn_pool is not None: 85 | block_prefix = f'{prefix}MAPHead_0/' 86 | mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/' 87 | module.attn_pool.latent.copy_(_n2p(w[f'{block_prefix}probe'], t=False)) 88 | module.attn_pool.q.weight.copy_(_n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T) 89 | module.attn_pool.q.bias.copy_(_n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1)) 90 | module.attn_pool.kv.weight.copy_(torch.cat([ 91 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('key', 'value')])) 92 | module.attn_pool.kv.bias.copy_(torch.cat([ 93 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('key', 'value')])) 94 | module.attn_pool.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) 95 | module.attn_pool.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) 96 | module.attn_pool.norm.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) 97 | module.attn_pool.norm.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) 98 | for r in range(2): 99 | getattr(module.attn_pool.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/kernel'])) 100 | getattr(module.attn_pool.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/bias'])) 101 | 102 | def _convert_openclip_transformer(module: Transformer, prefix): 103 | for i, block in enumerate(module.resblocks.children()): 104 | block_prefix = f'{prefix}encoderblock_{i}/' 105 | mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/' 106 | block.ln_1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) 107 | block.ln_1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) 108 | block.attn.in_proj_weight.copy_(torch.cat([ 109 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) 110 | block.attn.in_proj_bias.copy_(torch.cat([ 111 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) 112 | block.attn.out_proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) 113 | block.attn.out_proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) 114 | block.ln_2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/scale'])) 115 | block.ln_2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/bias'])) 116 | block.mlp.c_fc.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/kernel'])) 117 | block.mlp.c_fc.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/bias'])) 118 | block.mlp.c_proj.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/kernel'])) 119 | block.mlp.c_proj.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/bias'])) 120 | 
121 | def _convert_openclip_txt(module: TextTransformer, prefix): 122 | module.token_embedding.weight.copy_(_n2p(w[f'{prefix}Embed_0/embedding'], t=False)) 123 | pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False).squeeze(0) 124 | module.positional_embedding.copy_(pos_embed_w) 125 | _convert_openclip_transformer(module.transformer, prefix=prefix + 'Encoder_0/') 126 | module.ln_final.weight.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/scale'])) 127 | module.ln_final.bias.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/bias'])) 128 | module.text_projection.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) 129 | module.text_projection.bias.copy_(_n2p(w[f'{prefix}head/bias'])) 130 | 131 | _convert_timm_img(model.visual.trunk, 'params/img/') 132 | _convert_openclip_txt(model.text, 'params/txt/') 133 | model.logit_bias.copy_(_n2p(w['params/b'])[0]) 134 | model.logit_scale.copy_(_n2p(w['params/t'])[0]) 135 | 136 | 137 | -------------------------------------------------------------------------------- /open_clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ant-research/DreamLIP/b4de0a43b6c002033c02873f91a695ab449e464c/open_clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /open_clip/constants.py: -------------------------------------------------------------------------------- 1 | OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) 2 | OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) 3 | IMAGENET_MEAN = (0.485, 0.456, 0.406) 4 | IMAGENET_STD = (0.229, 0.224, 0.225) 5 | INCEPTION_MEAN = (0.5, 0.5, 0.5) 6 | INCEPTION_STD = (0.5, 0.5, 0.5) 7 | -------------------------------------------------------------------------------- /open_clip/generation_utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ant-research/DreamLIP/b4de0a43b6c002033c02873f91a695ab449e464c/open_clip/generation_utils.py -------------------------------------------------------------------------------- /open_clip/hf_configs.py: -------------------------------------------------------------------------------- 1 | # HF architecture dict: 2 | arch_dict = { 3 | # https://huggingface.co/docs/transformers/model_doc/roberta#roberta 4 | "roberta": { 5 | "config_names": { 6 | "context_length": "max_position_embeddings", 7 | "vocab_size": "vocab_size", 8 | "width": "hidden_size", 9 | "heads": "num_attention_heads", 10 | "layers": "num_hidden_layers", 11 | "layer_attr": "layer", 12 | "token_embeddings_attr": "embeddings" 13 | }, 14 | "pooler": "mean_pooler", 15 | }, 16 | # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig 17 | "xlm-roberta": { 18 | "config_names": { 19 | "context_length": "max_position_embeddings", 20 | "vocab_size": "vocab_size", 21 | "width": "hidden_size", 22 | "heads": "num_attention_heads", 23 | "layers": "num_hidden_layers", 24 | "layer_attr": "layer", 25 | "token_embeddings_attr": "embeddings" 26 | }, 27 | "pooler": "mean_pooler", 28 | }, 29 | # https://huggingface.co/docs/transformers/model_doc/mt5#mt5 30 | "mt5": { 31 | "config_names": { 32 | # unlimited seqlen 33 | # https://github.com/google-research/text-to-text-transfer-transformer/issues/273 34 | # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374 35 | "context_length": "", 36 | "vocab_size": "vocab_size", 37 | "width": 
"d_model", 38 | "heads": "num_heads", 39 | "layers": "num_layers", 40 | "layer_attr": "block", 41 | "token_embeddings_attr": "embed_tokens" 42 | }, 43 | "pooler": "mean_pooler", 44 | }, 45 | # https://huggingface.co/docs/transformers/model_doc/bert 46 | "bert": { 47 | "config_names": { 48 | "context_length": "max_position_embeddings", 49 | "vocab_size": "vocab_size", 50 | "width": "hidden_size", 51 | "heads": "num_attention_heads", 52 | "layers": "num_hidden_layers", 53 | }, 54 | "pooler": "cls_pooler", 55 | }, 56 | # https://huggingface.co/docs/transformers/model_doc/m2m_100 57 | "m2m_100": { 58 | "config_names": { 59 | "context_length": "max_position_embeddings", 60 | "vocab_size": "vocab_size", 61 | "width": "d_model", 62 | "heads": "encoder_attention_heads", 63 | "layers": "encoder_layers", 64 | }, 65 | "pooler": "cls_pooler", 66 | }, 67 | } 68 | -------------------------------------------------------------------------------- /open_clip/hf_model.py: -------------------------------------------------------------------------------- 1 | """ huggingface model adapter 2 | 3 | Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model. 4 | """ 5 | import re 6 | import os 7 | import torch 8 | import torch.nn as nn 9 | from torch import TensorType 10 | 11 | try: 12 | import transformers 13 | from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig 14 | from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \ 15 | BaseModelOutputWithPoolingAndCrossAttentions 16 | except ImportError as e: 17 | transformers = None 18 | 19 | 20 | class BaseModelOutput: 21 | pass 22 | 23 | 24 | class PretrainedConfig: 25 | pass 26 | 27 | from .hf_configs import arch_dict 28 | 29 | 30 | # utils 31 | def _camel2snake(s): 32 | return re.sub(r'(? 
1 17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | self.act1 = nn.ReLU(inplace=True) 20 | 21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.act2 = nn.ReLU(inplace=True) 24 | 25 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 26 | 27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 29 | self.act3 = nn.ReLU(inplace=True) 30 | 31 | self.downsample = None 32 | self.stride = stride 33 | 34 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 35 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 36 | self.downsample = nn.Sequential(OrderedDict([ 37 | ("-1", nn.AvgPool2d(stride)), 38 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 39 | ("1", nn.BatchNorm2d(planes * self.expansion)) 40 | ])) 41 | 42 | def forward(self, x: torch.Tensor): 43 | identity = x 44 | 45 | out = self.act1(self.bn1(self.conv1(x))) 46 | out = self.act2(self.bn2(self.conv2(out))) 47 | out = self.avgpool(out) 48 | out = self.bn3(self.conv3(out)) 49 | 50 | if self.downsample is not None: 51 | identity = self.downsample(x) 52 | 53 | out += identity 54 | out = self.act3(out) 55 | return out 56 | 57 | 58 | class AttentionPool2d(nn.Module): 59 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 60 | super().__init__() 61 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 62 | self.k_proj = nn.Linear(embed_dim, embed_dim) 63 | self.q_proj = nn.Linear(embed_dim, embed_dim) 64 | self.v_proj = nn.Linear(embed_dim, embed_dim) 65 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 66 | self.num_heads = num_heads 67 | 68 | def forward(self, x): 69 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 70 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 71 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 72 | x, _ = F.multi_head_attention_forward( 73 | query=x, key=x, value=x, 74 | embed_dim_to_check=x.shape[-1], 75 | num_heads=self.num_heads, 76 | q_proj_weight=self.q_proj.weight, 77 | k_proj_weight=self.k_proj.weight, 78 | v_proj_weight=self.v_proj.weight, 79 | in_proj_weight=None, 80 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 81 | bias_k=None, 82 | bias_v=None, 83 | add_zero_attn=False, 84 | dropout_p=0., 85 | out_proj_weight=self.c_proj.weight, 86 | out_proj_bias=self.c_proj.bias, 87 | use_separate_proj_weight=True, 88 | training=self.training, 89 | need_weights=False 90 | ) 91 | 92 | return x[0] 93 | 94 | 95 | class ModifiedResNet(nn.Module): 96 | """ 97 | A ResNet class that is similar to torchvision's but contains the following changes: 98 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
99 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 100 | - The final pooling layer is a QKV attention instead of an average pool 101 | """ 102 | 103 | def __init__(self, layers, output_dim, heads, image_size=224, width=64): 104 | super().__init__() 105 | self.output_dim = output_dim 106 | self.image_size = image_size 107 | 108 | # the 3-layer stem 109 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 110 | self.bn1 = nn.BatchNorm2d(width // 2) 111 | self.act1 = nn.ReLU(inplace=True) 112 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 113 | self.bn2 = nn.BatchNorm2d(width // 2) 114 | self.act2 = nn.ReLU(inplace=True) 115 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 116 | self.bn3 = nn.BatchNorm2d(width) 117 | self.act3 = nn.ReLU(inplace=True) 118 | self.avgpool = nn.AvgPool2d(2) 119 | 120 | # residual layers 121 | self._inplanes = width # this is a *mutable* variable used during construction 122 | self.layer1 = self._make_layer(width, layers[0]) 123 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 124 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 125 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 126 | 127 | embed_dim = width * 32 # the ResNet feature dimension 128 | self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) 129 | 130 | self.init_parameters() 131 | 132 | def _make_layer(self, planes, blocks, stride=1): 133 | layers = [Bottleneck(self._inplanes, planes, stride)] 134 | 135 | self._inplanes = planes * Bottleneck.expansion 136 | for _ in range(1, blocks): 137 | layers.append(Bottleneck(self._inplanes, planes)) 138 | 139 | return nn.Sequential(*layers) 140 | 141 | def init_parameters(self): 142 | if self.attnpool is not None: 143 | std = self.attnpool.c_proj.in_features ** -0.5 144 | nn.init.normal_(self.attnpool.q_proj.weight, std=std) 145 | nn.init.normal_(self.attnpool.k_proj.weight, std=std) 146 | nn.init.normal_(self.attnpool.v_proj.weight, std=std) 147 | nn.init.normal_(self.attnpool.c_proj.weight, std=std) 148 | 149 | for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: 150 | for name, param in resnet_block.named_parameters(): 151 | if name.endswith("bn3.weight"): 152 | nn.init.zeros_(param) 153 | 154 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 155 | assert unlocked_groups == 0, 'partial locking not currently supported for this model' 156 | for param in self.parameters(): 157 | param.requires_grad = False 158 | if freeze_bn_stats: 159 | freeze_batch_norm_2d(self) 160 | 161 | @torch.jit.ignore 162 | def set_grad_checkpointing(self, enable=True): 163 | # FIXME support for non-transformer 164 | pass 165 | 166 | def stem(self, x): 167 | x = self.act1(self.bn1(self.conv1(x))) 168 | x = self.act2(self.bn2(self.conv2(x))) 169 | x = self.act3(self.bn3(self.conv3(x))) 170 | x = self.avgpool(x) 171 | return x 172 | 173 | def forward(self, x): 174 | x = self.stem(x) 175 | x = self.layer1(x) 176 | x = self.layer2(x) 177 | x = self.layer3(x) 178 | x = self.layer4(x) 179 | x = self.attnpool(x) 180 | 181 | return x 182 | -------------------------------------------------------------------------------- /open_clip/openai.py: -------------------------------------------------------------------------------- 1 | """ OpenAI pretrained model functions 2 | 3 | Adapted from https://github.com/openai/CLIP. 
Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | 6 | import os 7 | import warnings 8 | from typing import List, Optional, Union 9 | 10 | import torch 11 | 12 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 13 | from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype 14 | from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url 15 | 16 | __all__ = ["list_openai_models", "load_openai_model"] 17 | 18 | 19 | def list_openai_models() -> List[str]: 20 | """Returns the names of available CLIP models""" 21 | return list_pretrained_models_by_tag('openai') 22 | 23 | 24 | def load_openai_model( 25 | name: str, 26 | precision: Optional[str] = None, 27 | device: Optional[Union[str, torch.device]] = None, 28 | cache_dir: Optional[str] = None, 29 | ): 30 | """Load a CLIP model 31 | 32 | Parameters 33 | ---------- 34 | name : str 35 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 36 | precision: str 37 | Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'. 38 | device : Union[str, torch.device] 39 | The device to put the loaded model 40 | cache_dir : Optional[str] 41 | The directory to cache the downloaded model weights 42 | 43 | Returns 44 | ------- 45 | model : torch.nn.Module 46 | The CLIP model 47 | preprocess : Callable[[PIL.Image], torch.Tensor] 48 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 49 | """ 50 | if device is None: 51 | device = "cuda" if torch.cuda.is_available() else "cpu" 52 | if precision is None: 53 | precision = 'fp32' if device == 'cpu' else 'fp16' 54 | 55 | if get_pretrained_url(name, 'openai'): 56 | model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir) 57 | elif os.path.isfile(name): 58 | model_path = name 59 | else: 60 | raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") 61 | 62 | try: 63 | # loading JIT archive 64 | model = torch.jit.load(model_path, map_location="cpu").eval() 65 | state_dict = None 66 | except RuntimeError: 67 | # loading saved state dict 68 | state_dict = torch.load(model_path, map_location="cpu") 69 | 70 | # Build a non-jit model from the OpenAI jitted model state dict 71 | cast_dtype = get_cast_dtype(precision) 72 | try: 73 | model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype) 74 | except KeyError: 75 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} 76 | model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype) 77 | 78 | # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use 79 | model = model.to(device) 80 | # FIXME support pure fp16/bf16 precision modes 81 | if precision != 'fp16': 82 | model.float() 83 | if precision == 'bf16': 84 | # for bf16, convert back to low-precision 85 | convert_weights_to_lp(model, dtype=torch.bfloat16) 86 | 87 | # add mean / std attributes for consistency with OpenCLIP models 88 | model.visual.image_mean = OPENAI_DATASET_MEAN 89 | model.visual.image_std = OPENAI_DATASET_STD 90 | return model 91 | -------------------------------------------------------------------------------- /open_clip/pos_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # Position embedding utils 8 | # -------------------------------------------------------- 9 | 10 | import numpy as np 11 | 12 | import torch 13 | 14 | # -------------------------------------------------------- 15 | # 2D sine-cosine position embedding 16 | # References: 17 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 18 | # MoCo v3: https://github.com/facebookresearch/moco-v3 19 | # -------------------------------------------------------- 20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 21 | """ 22 | grid_size: int of the grid height and width 23 | return: 24 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 25 | """ 26 | grid_h = np.arange(grid_size, dtype=np.float32) 27 | grid_w = np.arange(grid_size, dtype=np.float32) 28 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 29 | grid = np.stack(grid, axis=0) 30 | 31 | grid = grid.reshape([2, 1, grid_size, grid_size]) 32 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 33 | if cls_token: 34 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 35 | return pos_embed 36 | 37 | 38 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 39 | assert embed_dim % 2 == 0 40 | 41 | # use half of dimensions to encode grid_h 42 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 43 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 44 | 45 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 46 | return emb 47 | 48 | 49 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 50 | """ 51 | embed_dim: output dimension for each position 52 | pos: a list of positions to be encoded: size (M,) 53 | out: (M, D) 54 | """ 55 | assert embed_dim % 2 == 0 56 | omega = np.arange(embed_dim // 2, dtype=float) 57 | omega /= embed_dim / 2. 58 | omega = 1. 
/ 10000**omega # (D/2,) 59 | 60 | pos = pos.reshape(-1) # (M,) 61 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 62 | 63 | emb_sin = np.sin(out) # (M, D/2) 64 | emb_cos = np.cos(out) # (M, D/2) 65 | 66 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 67 | return emb 68 | 69 | 70 | # -------------------------------------------------------- 71 | # Interpolate position embeddings for high-resolution 72 | # References: 73 | # DeiT: https://github.com/facebookresearch/deit 74 | # -------------------------------------------------------- 75 | def interpolate_pos_embed(model, checkpoint_model): 76 | if 'pos_embed' in checkpoint_model: 77 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 78 | embedding_size = pos_embed_checkpoint.shape[-1] 79 | num_patches = model.patch_embed.num_patches 80 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 81 | # height (== width) for the checkpoint position embedding 82 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 83 | # height (== width) for the new position embedding 84 | new_size = int(num_patches ** 0.5) 85 | # class_token and dist_token are kept unchanged 86 | if orig_size != new_size: 87 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 88 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 89 | # only the position tokens are interpolated 90 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 91 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 92 | pos_tokens = torch.nn.functional.interpolate( 93 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 94 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 95 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 96 | checkpoint_model['pos_embed'] = new_pos_embed 97 | -------------------------------------------------------------------------------- /open_clip/push_to_hf_hub.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | from pathlib import Path 5 | from tempfile import TemporaryDirectory 6 | from typing import Optional, Tuple, Union 7 | 8 | import torch 9 | 10 | try: 11 | from huggingface_hub import ( 12 | create_repo, 13 | get_hf_file_metadata, 14 | hf_hub_download, 15 | hf_hub_url, 16 | repo_type_and_id_from_hf_id, 17 | upload_folder, 18 | list_repo_files, 19 | ) 20 | from huggingface_hub.utils import EntryNotFoundError 21 | _has_hf_hub = True 22 | except ImportError: 23 | _has_hf_hub = False 24 | 25 | try: 26 | import safetensors.torch 27 | _has_safetensors = True 28 | except ImportError: 29 | _has_safetensors = False 30 | 31 | from .factory import create_model_from_pretrained, get_model_config, get_tokenizer 32 | from .tokenizer import HFTokenizer 33 | 34 | # Default name for a weights file hosted on the Huggingface Hub. 
35 | HF_WEIGHTS_NAME = "open_clip_pytorch_model.bin" # default pytorch pkl 36 | HF_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors" # safetensors version 37 | HF_CONFIG_NAME = 'open_clip_config.json' 38 | 39 | 40 | def save_config_for_hf( 41 | model, 42 | config_path: str, 43 | model_config: Optional[dict] 44 | ): 45 | preprocess_cfg = { 46 | 'mean': model.visual.image_mean, 47 | 'std': model.visual.image_std, 48 | } 49 | other_pp = getattr(model.visual, 'preprocess_cfg', {}) 50 | if 'interpolation' in other_pp: 51 | preprocess_cfg['interpolation'] = other_pp['interpolation'] 52 | if 'resize_mode' in other_pp: 53 | preprocess_cfg['resize_mode'] = other_pp['resize_mode'] 54 | hf_config = { 55 | 'model_cfg': model_config, 56 | 'preprocess_cfg': preprocess_cfg, 57 | } 58 | 59 | with config_path.open('w') as f: 60 | json.dump(hf_config, f, indent=2) 61 | 62 | 63 | def save_for_hf( 64 | model, 65 | tokenizer: HFTokenizer, 66 | model_config: dict, 67 | save_directory: str, 68 | safe_serialization: Union[bool, str] = 'both', 69 | skip_weights : bool = False, 70 | ): 71 | config_filename = HF_CONFIG_NAME 72 | 73 | save_directory = Path(save_directory) 74 | save_directory.mkdir(exist_ok=True, parents=True) 75 | 76 | if not skip_weights: 77 | tensors = model.state_dict() 78 | if safe_serialization is True or safe_serialization == "both": 79 | assert _has_safetensors, "`pip install safetensors` to use .safetensors" 80 | safetensors.torch.save_file(tensors, save_directory / HF_SAFE_WEIGHTS_NAME) 81 | if safe_serialization is False or safe_serialization == "both": 82 | torch.save(tensors, save_directory / HF_WEIGHTS_NAME) 83 | 84 | tokenizer.save_pretrained(save_directory) 85 | 86 | config_path = save_directory / config_filename 87 | save_config_for_hf(model, config_path, model_config=model_config) 88 | 89 | 90 | def push_to_hf_hub( 91 | model, 92 | tokenizer, 93 | model_config: Optional[dict], 94 | repo_id: str, 95 | commit_message: str = 'Add model', 96 | token: Optional[str] = None, 97 | revision: Optional[str] = None, 98 | private: bool = False, 99 | create_pr: bool = False, 100 | model_card: Optional[dict] = None, 101 | safe_serialization: Union[bool, str] = False, 102 | ): 103 | if not isinstance(tokenizer, HFTokenizer): 104 | # FIXME this makes it awkward to push models with new tokenizers, come up with better soln. 105 | # default CLIP tokenizers use https://huggingface.co/openai/clip-vit-large-patch14 106 | tokenizer = HFTokenizer('openai/clip-vit-large-patch14') 107 | 108 | # Create repo if it doesn't exist yet 109 | repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True) 110 | 111 | # Infer complete repo_id from repo_url 112 | # Can be different from the input `repo_id` if repo_owner was implicit 113 | _, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url) 114 | repo_id = f"{repo_owner}/{repo_name}" 115 | 116 | # Check if repo already exists and determine what needs updating 117 | repo_exists = False 118 | repo_files = {} 119 | try: 120 | repo_files = set(list_repo_files(repo_id)) 121 | repo_exists = True 122 | except Exception as e: 123 | print('Repo does not exist', e) 124 | 125 | try: 126 | get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision)) 127 | has_readme = True 128 | except EntryNotFoundError: 129 | has_readme = False 130 | 131 | # Dump model and push to Hub 132 | with TemporaryDirectory() as tmpdir: 133 | # Save model weights and config. 
134 | save_for_hf( 135 | model, 136 | tokenizer=tokenizer, 137 | model_config=model_config, 138 | save_directory=tmpdir, 139 | safe_serialization=safe_serialization, 140 | ) 141 | 142 | # Add readme if it does not exist 143 | if not has_readme: 144 | model_card = model_card or {} 145 | model_name = repo_id.split('/')[-1] 146 | readme_path = Path(tmpdir) / "README.md" 147 | readme_text = generate_readme(model_card, model_name) 148 | readme_path.write_text(readme_text) 149 | 150 | # Upload model and return 151 | return upload_folder( 152 | repo_id=repo_id, 153 | folder_path=tmpdir, 154 | revision=revision, 155 | create_pr=create_pr, 156 | commit_message=commit_message, 157 | ) 158 | 159 | 160 | def push_pretrained_to_hf_hub( 161 | model_name, 162 | pretrained: str, 163 | repo_id: str, 164 | precision: str = 'fp32', 165 | image_mean: Optional[Tuple[float, ...]] = None, 166 | image_std: Optional[Tuple[float, ...]] = None, 167 | image_interpolation: Optional[str] = None, 168 | image_resize_mode: Optional[str] = None, # only effective for inference 169 | commit_message: str = 'Add model', 170 | token: Optional[str] = None, 171 | revision: Optional[str] = None, 172 | private: bool = False, 173 | create_pr: bool = False, 174 | model_card: Optional[dict] = None, 175 | hf_tokenizer_self: bool = False, 176 | ): 177 | model, preprocess_eval = create_model_from_pretrained( 178 | model_name, 179 | pretrained=pretrained, 180 | precision=precision, 181 | image_mean=image_mean, 182 | image_std=image_std, 183 | image_interpolation=image_interpolation, 184 | image_resize_mode=image_resize_mode, 185 | ) 186 | model_config = get_model_config(model_name) 187 | assert model_config 188 | 189 | tokenizer = get_tokenizer(model_name) 190 | if hf_tokenizer_self: 191 | # make hf tokenizer config in the uploaded model point to self instead of original location 192 | model_config['text']['hf_tokenizer_name'] = repo_id 193 | 194 | push_to_hf_hub( 195 | model=model, 196 | tokenizer=tokenizer, 197 | model_config=model_config, 198 | repo_id=repo_id, 199 | commit_message=commit_message, 200 | token=token, 201 | revision=revision, 202 | private=private, 203 | create_pr=create_pr, 204 | model_card=model_card, 205 | safe_serialization='both', 206 | ) 207 | 208 | 209 | def generate_readme(model_card: dict, model_name: str): 210 | tags = model_card.pop('tags', ('clip',)) 211 | pipeline_tag = model_card.pop('pipeline_tag', 'zero-shot-image-classification') 212 | readme_text = "---\n" 213 | if tags: 214 | readme_text += "tags:\n" 215 | for t in tags: 216 | readme_text += f"- {t}\n" 217 | readme_text += "library_name: open_clip\n" 218 | readme_text += f"pipeline_tag: {pipeline_tag}\n" 219 | readme_text += f"license: {model_card.get('license', 'mit')}\n" 220 | if 'details' in model_card and 'Dataset' in model_card['details']: 221 | readme_text += 'datasets:\n' 222 | readme_text += f"- {model_card['details']['Dataset'].lower()}\n" 223 | readme_text += "---\n" 224 | readme_text += f"# Model card for {model_name}\n" 225 | if 'description' in model_card: 226 | readme_text += f"\n{model_card['description']}\n" 227 | if 'details' in model_card: 228 | readme_text += f"\n## Model Details\n" 229 | for k, v in model_card['details'].items(): 230 | if isinstance(v, (list, tuple)): 231 | readme_text += f"- **{k}:**\n" 232 | for vi in v: 233 | readme_text += f" - {vi}\n" 234 | elif isinstance(v, dict): 235 | readme_text += f"- **{k}:**\n" 236 | for ki, vi in v.items(): 237 | readme_text += f" - {ki}: {vi}\n" 238 | else: 239 | readme_text += f"- 
**{k}:** {v}\n" 240 | if 'usage' in model_card: 241 | readme_text += f"\n## Model Usage\n" 242 | readme_text += model_card['usage'] 243 | readme_text += '\n' 244 | 245 | if 'comparison' in model_card: 246 | readme_text += f"\n## Model Comparison\n" 247 | readme_text += model_card['comparison'] 248 | readme_text += '\n' 249 | 250 | if 'citation' in model_card: 251 | readme_text += f"\n## Citation\n" 252 | if not isinstance(model_card['citation'], (list, tuple)): 253 | citations = [model_card['citation']] 254 | else: 255 | citations = model_card['citation'] 256 | for c in citations: 257 | readme_text += f"```bibtex\n{c}\n```\n" 258 | 259 | return readme_text 260 | 261 | 262 | if __name__ == "__main__": 263 | parser = argparse.ArgumentParser(description="Push to Hugging Face Hub") 264 | parser.add_argument( 265 | "--model", type=str, help="Name of the model to use.", 266 | ) 267 | parser.add_argument( 268 | "--pretrained", type=str, 269 | help="Use a pretrained CLIP model weights with the specified tag or file path.", 270 | ) 271 | parser.add_argument( 272 | "--repo-id", type=str, 273 | help="Destination HF Hub repo-id ie 'organization/model_id'.", 274 | ) 275 | parser.add_argument( 276 | "--precision", type=str, default='fp32', 277 | ) 278 | parser.add_argument( 279 | '--image-mean', type=float, nargs='+', default=None, metavar='MEAN', 280 | help='Override default image mean value of dataset') 281 | parser.add_argument( 282 | '--image-std', type=float, nargs='+', default=None, metavar='STD', 283 | help='Override default image std deviation of of dataset') 284 | parser.add_argument( 285 | '--image-interpolation', 286 | default=None, type=str, choices=['bicubic', 'bilinear', 'random'], 287 | help="image resize interpolation" 288 | ) 289 | parser.add_argument( 290 | '--image-resize-mode', 291 | default=None, type=str, choices=['shortest', 'longest', 'squash'], 292 | help="image resize mode during inference" 293 | ) 294 | parser.add_argument( 295 | "--hf-tokenizer-self", 296 | default=False, 297 | action="store_true", 298 | help="make hf_tokenizer_name point in uploaded config point to itself" 299 | ) 300 | args = parser.parse_args() 301 | 302 | print(f'Saving model {args.model} with pretrained weights {args.pretrained} to Hugging Face Hub at {args.repo_id}') 303 | 304 | # FIXME add support to pass model_card json / template from file via cmd line 305 | 306 | push_pretrained_to_hf_hub( 307 | args.model, 308 | args.pretrained, 309 | args.repo_id, 310 | precision=args.precision, 311 | image_mean=args.image_mean, # override image mean/std if trained w/ non defaults 312 | image_std=args.image_std, 313 | image_interpolation=args.image_interpolation, 314 | image_resize_mode=args.image_resize_mode, 315 | ) 316 | 317 | print(f'{args.model} saved.') 318 | -------------------------------------------------------------------------------- /open_clip/timm_model.py: -------------------------------------------------------------------------------- 1 | """ timm model adapter 2 | 3 | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. 
4 | """ 5 | import logging 6 | from collections import OrderedDict 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | try: 12 | import timm 13 | from timm.models.layers import Mlp, to_2tuple 14 | try: 15 | # old timm imports < 0.8.1 16 | from timm.models.layers.attention_pool2d import RotAttentionPool2d 17 | from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d 18 | except ImportError: 19 | # new timm imports >= 0.8.1 20 | from timm.layers import RotAttentionPool2d 21 | from timm.layers import AttentionPool2d as AbsAttentionPool2d 22 | except ImportError: 23 | timm = None 24 | 25 | from .utils import freeze_batch_norm_2d 26 | 27 | 28 | class TimmModel(nn.Module): 29 | """ timm model adapter 30 | """ 31 | 32 | def __init__( 33 | self, 34 | model_name, 35 | embed_dim, 36 | image_size=224, 37 | pool='avg', 38 | proj='linear', 39 | proj_bias=False, 40 | drop=0., 41 | drop_path=None, 42 | patch_drop=None, 43 | pretrained=False, 44 | ): 45 | super().__init__() 46 | if timm is None: 47 | raise RuntimeError("Please `pip install timm` to use timm models.") 48 | self.image_size = to_2tuple(image_size) 49 | 50 | # setup kwargs that may not be common across all models 51 | timm_kwargs = {} 52 | if drop_path is not None: 53 | timm_kwargs['drop_path_rate'] = drop_path 54 | if patch_drop is not None: 55 | timm_kwargs['patch_drop_rate'] = patch_drop 56 | 57 | custom_pool = pool in ('abs_attn', 'rot_attn') 58 | if proj: 59 | assert proj in ("linear", "mlp", "none") 60 | extra_proj = proj in ("linear", "mlp") 61 | if not extra_proj and not custom_pool: 62 | # use network classifier head as projection if no proj specified and no custom pooling used 63 | # if projection is explicitly set to "none" will be pass through from network trunk 64 | proj_dim = 0 if proj == 'none' else embed_dim 65 | self.trunk = timm.create_model( 66 | model_name, 67 | num_classes=proj_dim, 68 | global_pool=pool, 69 | pretrained=pretrained, 70 | **timm_kwargs, 71 | ) 72 | prev_chs = embed_dim 73 | else: 74 | self.trunk = timm.create_model( 75 | model_name, 76 | pretrained=pretrained, 77 | **timm_kwargs, 78 | ) 79 | feat_size = self.trunk.default_cfg.get('pool_size', None) 80 | feature_ndim = 1 if not feat_size else 2 81 | if custom_pool: 82 | assert feature_ndim == 2 83 | # if attn pooling used, remove both classifier and default pool 84 | self.trunk.reset_classifier(0, global_pool='') 85 | else: 86 | # reset global pool if pool config set, otherwise leave as network default 87 | reset_kwargs = dict(global_pool=pool) if pool else {} 88 | self.trunk.reset_classifier(0, **reset_kwargs) 89 | prev_chs = self.trunk.num_features 90 | 91 | head_layers = OrderedDict() 92 | 93 | # Add custom pooling to head 94 | if pool == 'abs_attn': 95 | head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) 96 | prev_chs = embed_dim 97 | elif pool == 'rot_attn': 98 | head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim) 99 | prev_chs = embed_dim 100 | 101 | # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used 102 | if proj == 'linear': 103 | head_layers['drop'] = nn.Dropout(drop) 104 | head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias) 105 | elif proj == 'mlp': 106 | head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias)) 107 | 108 | self.head = nn.Sequential(head_layers) 109 | 110 | def lock(self, unlocked_groups=0, 
freeze_bn_stats=False): 111 | """ lock modules 112 | Args: 113 | unlocked_groups (int): leave last n layer groups unlocked (default: 0) 114 | """ 115 | if not unlocked_groups: 116 | # lock full model 117 | for param in self.trunk.parameters(): 118 | param.requires_grad = False 119 | if freeze_bn_stats: 120 | freeze_batch_norm_2d(self.trunk) 121 | else: 122 | # NOTE: partial freeze requires latest timm (master) branch and is subject to change 123 | try: 124 | # FIXME import here until API stable and in an official release 125 | from timm.models.helpers import group_parameters, group_modules 126 | except ImportError: 127 | raise RuntimeError( 128 | 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`') 129 | matcher = self.trunk.group_matcher() 130 | gparams = group_parameters(self.trunk, matcher) 131 | max_layer_id = max(gparams.keys()) 132 | max_layer_id = max_layer_id - unlocked_groups 133 | for group_idx in range(max_layer_id + 1): 134 | group = gparams[group_idx] 135 | for param in group: 136 | self.trunk.get_parameter(param).requires_grad = False 137 | if freeze_bn_stats: 138 | gmodules = group_modules(self.trunk, matcher, reverse=True) 139 | gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} 140 | freeze_batch_norm_2d(self.trunk, gmodules) 141 | 142 | @torch.jit.ignore 143 | def set_grad_checkpointing(self, enable=True): 144 | try: 145 | self.trunk.set_grad_checkpointing(enable) 146 | except Exception as e: 147 | logging.warning('grad checkpointing not supported for this timm image tower, continuing without...') 148 | 149 | def forward(self, x): 150 | x = self.trunk(x) 151 | x = self.head(x) 152 | return x 153 | -------------------------------------------------------------------------------- /open_clip/transform.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | import random 3 | import warnings 4 | from dataclasses import dataclass, asdict 5 | from typing import Any, Dict, List, Optional, Sequence, Tuple, Union 6 | 7 | import torch 8 | import torchvision.transforms.functional as F 9 | from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \ 10 | CenterCrop, ColorJitter, Grayscale, RandomApply, RandomGrayscale, RandomHorizontalFlip 11 | from PIL import ImageFilter 12 | from training.random_aug import RandomAugment 13 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 14 | from .utils import to_2tuple 15 | 16 | 17 | @dataclass 18 | class PreprocessCfg: 19 | size: Union[int, Tuple[int, int]] = 224 20 | mode: str = 'RGB' 21 | mean: Tuple[float, ...] = OPENAI_DATASET_MEAN 22 | std: Tuple[float, ...] 
= OPENAI_DATASET_STD 23 | interpolation: str = 'bicubic' 24 | resize_mode: str = 'shortest' 25 | fill_color: int = 0 26 | 27 | def __post_init__(self): 28 | assert self.mode in ('RGB',) 29 | 30 | @property 31 | def num_channels(self): 32 | return 3 33 | 34 | @property 35 | def input_size(self): 36 | return (self.num_channels,) + to_2tuple(self.size) 37 | 38 | _PREPROCESS_KEYS = set(asdict(PreprocessCfg()).keys()) 39 | 40 | class GaussianBlur(object): 41 | def __init__(self, sigma=[.1, 2.]): 42 | self.sigma = sigma 43 | 44 | def __call__(self, x): 45 | sigma = random.uniform(self.sigma[0], self.sigma[1]) 46 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) 47 | return x 48 | 49 | 50 | def merge_preprocess_dict( 51 | base: Union[PreprocessCfg, Dict], 52 | overlay: Dict, 53 | ): 54 | """ Merge overlay key-value pairs on top of base preprocess cfg or dict. 55 | Input dicts are filtered based on PreprocessCfg fields. 56 | """ 57 | if isinstance(base, PreprocessCfg): 58 | base_clean = asdict(base) 59 | else: 60 | base_clean = {k: v for k, v in base.items() if k in _PREPROCESS_KEYS} 61 | if overlay: 62 | overlay_clean = {k: v for k, v in overlay.items() if k in _PREPROCESS_KEYS and v is not None} 63 | base_clean.update(overlay_clean) 64 | return base_clean 65 | 66 | 67 | def merge_preprocess_kwargs(base: PreprocessCfg, **kwargs): 68 | return merge_preprocess_dict(base, kwargs) 69 | 70 | 71 | @dataclass 72 | class AugmentationCfg: 73 | scale: Tuple[float, float] = (0.9, 1.0) 74 | ratio: Optional[Tuple[float, float]] = None 75 | color_jitter: Optional[Union[float, Tuple[float, float, float], Tuple[float, float, float, float]]] = None 76 | re_prob: Optional[float] = None 77 | re_count: Optional[int] = None 78 | use_timm: bool = False 79 | 80 | # params for simclr_jitter_gray 81 | color_jitter_prob: float = None 82 | gray_scale_prob: float = None 83 | 84 | 85 | def _setup_size(size, error_msg): 86 | if isinstance(size, numbers.Number): 87 | return int(size), int(size) 88 | 89 | if isinstance(size, Sequence) and len(size) == 1: 90 | return size[0], size[0] 91 | 92 | if len(size) != 2: 93 | raise ValueError(error_msg) 94 | 95 | return size 96 | 97 | 98 | class ResizeKeepRatio: 99 | """ Resize and Keep Ratio 100 | 101 | Copy & paste from `timm` 102 | """ 103 | 104 | def __init__( 105 | self, 106 | size, 107 | longest=0., 108 | interpolation=InterpolationMode.BICUBIC, 109 | random_scale_prob=0., 110 | random_scale_range=(0.85, 1.05), 111 | random_aspect_prob=0., 112 | random_aspect_range=(0.9, 1.11) 113 | ): 114 | if isinstance(size, (list, tuple)): 115 | self.size = tuple(size) 116 | else: 117 | self.size = (size, size) 118 | self.interpolation = interpolation 119 | self.longest = float(longest) # [0, 1] where 0 == shortest edge, 1 == longest 120 | self.random_scale_prob = random_scale_prob 121 | self.random_scale_range = random_scale_range 122 | self.random_aspect_prob = random_aspect_prob 123 | self.random_aspect_range = random_aspect_range 124 | 125 | @staticmethod 126 | def get_params( 127 | img, 128 | target_size, 129 | longest, 130 | random_scale_prob=0., 131 | random_scale_range=(0.85, 1.05), 132 | random_aspect_prob=0., 133 | random_aspect_range=(0.9, 1.11) 134 | ): 135 | """Get parameters 136 | """ 137 | source_size = img.size[::-1] # h, w 138 | h, w = source_size 139 | target_h, target_w = target_size 140 | ratio_h = h / target_h 141 | ratio_w = w / target_w 142 | ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (1. 
- longest) 143 | if random_scale_prob > 0 and random.random() < random_scale_prob: 144 | ratio_factor = random.uniform(random_scale_range[0], random_scale_range[1]) 145 | ratio_factor = (ratio_factor, ratio_factor) 146 | else: 147 | ratio_factor = (1., 1.) 148 | if random_aspect_prob > 0 and random.random() < random_aspect_prob: 149 | aspect_factor = random.uniform(random_aspect_range[0], random_aspect_range[1]) 150 | ratio_factor = (ratio_factor[0] / aspect_factor, ratio_factor[1] * aspect_factor) 151 | size = [round(x * f / ratio) for x, f in zip(source_size, ratio_factor)] 152 | return size 153 | 154 | def __call__(self, img): 155 | """ 156 | Args: 157 | img (PIL Image): Image to be cropped and resized. 158 | 159 | Returns: 160 | PIL Image: Resized, padded to at least target size, possibly cropped to exactly target size 161 | """ 162 | size = self.get_params( 163 | img, self.size, self.longest, 164 | self.random_scale_prob, self.random_scale_range, 165 | self.random_aspect_prob, self.random_aspect_range 166 | ) 167 | img = F.resize(img, size, self.interpolation) 168 | return img 169 | 170 | def __repr__(self): 171 | format_string = self.__class__.__name__ + '(size={0}'.format(self.size) 172 | format_string += f', interpolation={self.interpolation})' 173 | format_string += f', longest={self.longest:.3f})' 174 | return format_string 175 | 176 | 177 | def center_crop_or_pad(img: torch.Tensor, output_size: List[int], fill=0) -> torch.Tensor: 178 | """Center crops and/or pads the given image. 179 | If the image is torch Tensor, it is expected 180 | to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. 181 | If image size is smaller than output size along any edge, image is padded with 0 and then center cropped. 182 | 183 | Args: 184 | img (PIL Image or Tensor): Image to be cropped. 185 | output_size (sequence or int): (height, width) of the crop box. If int or sequence with single int, 186 | it is used for both directions. 187 | fill (int, Tuple[int]): Padding color 188 | 189 | Returns: 190 | PIL Image or Tensor: Cropped image. 191 | """ 192 | if isinstance(output_size, numbers.Number): 193 | output_size = (int(output_size), int(output_size)) 194 | elif isinstance(output_size, (tuple, list)) and len(output_size) == 1: 195 | output_size = (output_size[0], output_size[0]) 196 | 197 | _, image_height, image_width = F.get_dimensions(img) 198 | crop_height, crop_width = output_size 199 | 200 | if crop_width > image_width or crop_height > image_height: 201 | padding_ltrb = [ 202 | (crop_width - image_width) // 2 if crop_width > image_width else 0, 203 | (crop_height - image_height) // 2 if crop_height > image_height else 0, 204 | (crop_width - image_width + 1) // 2 if crop_width > image_width else 0, 205 | (crop_height - image_height + 1) // 2 if crop_height > image_height else 0, 206 | ] 207 | img = F.pad(img, padding_ltrb, fill=fill) 208 | _, image_height, image_width = F.get_dimensions(img) 209 | if crop_width == image_width and crop_height == image_height: 210 | return img 211 | 212 | crop_top = int(round((image_height - crop_height) / 2.0)) 213 | crop_left = int(round((image_width - crop_width) / 2.0)) 214 | return F.crop(img, crop_top, crop_left, crop_height, crop_width) 215 | 216 | 217 | class CenterCropOrPad(torch.nn.Module): 218 | """Crops the given image at the center. 219 | If the image is torch Tensor, it is expected 220 | to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. 
221 | If image size is smaller than output size along any edge, image is padded with 0 and then center cropped. 222 | 223 | Args: 224 | size (sequence or int): Desired output size of the crop. If size is an 225 | int instead of sequence like (h, w), a square crop (size, size) is 226 | made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]). 227 | """ 228 | 229 | def __init__(self, size, fill=0): 230 | super().__init__() 231 | self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") 232 | self.fill = fill 233 | 234 | def forward(self, img): 235 | """ 236 | Args: 237 | img (PIL Image or Tensor): Image to be cropped. 238 | 239 | Returns: 240 | PIL Image or Tensor: Cropped image. 241 | """ 242 | return center_crop_or_pad(img, self.size, fill=self.fill) 243 | 244 | def __repr__(self) -> str: 245 | return f"{self.__class__.__name__}(size={self.size})" 246 | 247 | 248 | def _convert_to_rgb(image): 249 | return image.convert('RGB') 250 | 251 | 252 | class color_jitter(object): 253 | """ 254 | Apply Color Jitter to the PIL image with a specified probability. 255 | """ 256 | def __init__(self, brightness=0., contrast=0., saturation=0., hue=0., p=0.8): 257 | assert 0. <= p <= 1. 258 | self.p = p 259 | self.transf = ColorJitter(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue) 260 | 261 | def __call__(self, img): 262 | if random.random() < self.p: 263 | return self.transf(img) 264 | else: 265 | return img 266 | 267 | 268 | class gray_scale(object): 269 | """ 270 | Apply Gray Scale to the PIL image with a specified probability. 271 | """ 272 | def __init__(self, p=0.2): 273 | assert 0. <= p <= 1. 274 | self.p = p 275 | self.transf = Grayscale(num_output_channels=3) 276 | 277 | def __call__(self, img): 278 | if random.random() < self.p: 279 | return self.transf(img) 280 | else: 281 | return img 282 | 283 | 284 | def image_transform( 285 | image_size: Union[int, Tuple[int, int]], 286 | is_train: bool, 287 | mean: Optional[Tuple[float, ...]] = None, 288 | std: Optional[Tuple[float, ...]] = None, 289 | resize_mode: Optional[str] = None, 290 | interpolation: Optional[str] = None, 291 | fill_color: int = 0, 292 | aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, 293 | ): 294 | mean = mean or OPENAI_DATASET_MEAN 295 | if not isinstance(mean, (list, tuple)): 296 | mean = (mean,) * 3 297 | 298 | std = std or OPENAI_DATASET_STD 299 | if not isinstance(std, (list, tuple)): 300 | std = (std,) * 3 301 | 302 | interpolation = interpolation or 'bicubic' 303 | assert interpolation in ['bicubic', 'bilinear', 'random'] 304 | # NOTE random is ignored for interpolation_mode, so defaults to BICUBIC for inference if set 305 | interpolation_mode = InterpolationMode.BILINEAR if interpolation == 'bilinear' else InterpolationMode.BICUBIC 306 | 307 | resize_mode = resize_mode or 'shortest' 308 | assert resize_mode in ('shortest', 'longest', 'squash') 309 | 310 | if isinstance(aug_cfg, dict): 311 | aug_cfg = AugmentationCfg(**aug_cfg) 312 | else: 313 | aug_cfg = aug_cfg or AugmentationCfg() 314 | 315 | normalize = Normalize(mean=mean, std=std) 316 | if is_train: 317 | aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None} 318 | use_timm = aug_cfg_dict.pop('use_timm', False) 319 | use_dataaug = aug_cfg_dict.pop('use_dataaug', False) 320 | if use_timm: 321 | from timm.data import create_transform # timm can still be optional 322 | if isinstance(image_size, (tuple, list)): 323 | assert len(image_size) >= 2 
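# keep the trailing (H, W) dims and prepend the channel count, since timm's create_transform expects input_size=(C, H, W)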
324 | input_size = (3,) + image_size[-2:] 325 | else: 326 | input_size = (3, image_size, image_size) 327 | 328 | aug_cfg_dict.setdefault('color_jitter', None) # disable by default 329 | # drop extra non-timm items 330 | aug_cfg_dict.pop('color_jitter_prob', None) 331 | aug_cfg_dict.pop('gray_scale_prob', None) 332 | 333 | train_transform = create_transform( 334 | input_size=input_size, 335 | is_training=True, 336 | hflip=0., 337 | mean=mean, 338 | std=std, 339 | re_mode='pixel', 340 | interpolation=interpolation, 341 | **aug_cfg_dict, 342 | ) 343 | elif use_dataaug: 344 | train_transform = Compose([ 345 | RandomResizedCrop(image_size, scale=(0.2, 1.0), interpolation=InterpolationMode.BICUBIC), 346 | RandomApply([ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8), 347 | RandomGrayscale(p=0.2), 348 | RandomApply([GaussianBlur([.1, 2.])],p=0.5), 349 | RandomHorizontalFlip(), 350 | RandomAugment(2, 7, isPIL=True, augs=['Identity','AutoContrast','Equalize','Brightness','Sharpness', 351 | 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']), 352 | ToTensor(), 353 | normalize 354 | ]) 355 | else: 356 | train_transform = [ 357 | RandomResizedCrop( 358 | image_size, 359 | scale=aug_cfg_dict.pop('scale'), 360 | interpolation=InterpolationMode.BICUBIC, 361 | ), 362 | _convert_to_rgb, 363 | ] 364 | if aug_cfg.color_jitter_prob: 365 | assert aug_cfg.color_jitter is not None and len(aug_cfg.color_jitter) == 4 366 | train_transform.extend([ 367 | color_jitter(*aug_cfg.color_jitter, p=aug_cfg.color_jitter_prob) 368 | ]) 369 | if aug_cfg.gray_scale_prob: 370 | train_transform.extend([ 371 | gray_scale(aug_cfg.gray_scale_prob) 372 | ]) 373 | train_transform.extend([ 374 | ToTensor(), 375 | normalize, 376 | ]) 377 | train_transform = Compose(train_transform) 378 | if aug_cfg_dict: 379 | warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).') 380 | print(train_transform) 381 | return train_transform 382 | else: 383 | if resize_mode == 'longest': 384 | transforms = [ 385 | ResizeKeepRatio(image_size, interpolation=interpolation_mode, longest=1), 386 | CenterCropOrPad(image_size, fill=fill_color) 387 | ] 388 | elif resize_mode == 'squash': 389 | if isinstance(image_size, int): 390 | image_size = (image_size, image_size) 391 | transforms = [ 392 | Resize(image_size, interpolation=interpolation_mode), 393 | ] 394 | else: 395 | assert resize_mode == 'shortest' 396 | if not isinstance(image_size, (tuple, list)): 397 | image_size = (image_size, image_size) 398 | if image_size[0] == image_size[1]: 399 | # simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg) 400 | transforms = [ 401 | Resize(image_size[0], interpolation=interpolation_mode) 402 | ] 403 | else: 404 | # resize shortest edge to matching target dim for non-square target 405 | transforms = [ResizeKeepRatio(image_size)] 406 | transforms += [CenterCrop(image_size)] 407 | 408 | transforms.extend([ 409 | _convert_to_rgb, 410 | ToTensor(), 411 | normalize, 412 | ]) 413 | return Compose(transforms) 414 | 415 | 416 | def image_transform_v2( 417 | cfg: PreprocessCfg, 418 | is_train: bool, 419 | aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, 420 | ): 421 | return image_transform( 422 | image_size=cfg.size, 423 | is_train=is_train, 424 | mean=cfg.mean, 425 | std=cfg.std, 426 | interpolation=cfg.interpolation, 427 | resize_mode=cfg.resize_mode, 428 | fill_color=cfg.fill_color, 429 | aug_cfg=aug_cfg, 430 | ) 431 | 
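For orientation, a minimal usage sketch of the transform factory above (a sketch only: it assumes the repository root is on PYTHONPATH, since this transform.py imports training.random_aug, and the 224-px size and the scale value are illustrative rather than prescribed by the code):

    from PIL import Image
    from open_clip.transform import PreprocessCfg, AugmentationCfg, image_transform_v2

    # 224-px inputs with the default OpenAI mean/std, bicubic interpolation, shortest-edge resize.
    cfg = PreprocessCfg(size=224)

    # Evaluation pipeline: Resize -> CenterCrop -> RGB convert -> ToTensor -> Normalize.
    eval_tf = image_transform_v2(cfg, is_train=False)

    # Training pipeline: RandomResizedCrop -> RGB convert -> ToTensor -> Normalize.
    # Setting color_jitter/color_jitter_prob, gray_scale_prob, or use_timm in AugmentationCfg
    # enables the corresponding augmentation branches of image_transform().
    train_tf = image_transform_v2(cfg, is_train=True, aug_cfg=AugmentationCfg(scale=(0.5, 1.0)))

    img = Image.new('RGB', (640, 480))  # stand-in for a real image
    print(eval_tf(img).shape, train_tf(img).shape)  # torch.Size([3, 224, 224]) for both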
-------------------------------------------------------------------------------- /open_clip/utils.py: -------------------------------------------------------------------------------- 1 | from itertools import repeat 2 | import collections.abc 3 | 4 | import torch 5 | from torch import nn as nn 6 | from torchvision.ops.misc import FrozenBatchNorm2d 7 | 8 | 9 | def freeze_batch_norm_2d(module, module_match={}, name=''): 10 | """ 11 | Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is 12 | itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and 13 | returned. Otherwise, the module is walked recursively and submodules are converted in place. 14 | 15 | Args: 16 | module (torch.nn.Module): Any PyTorch module. 17 | module_match (dict): Dictionary of full module names to freeze (all if empty) 18 | name (str): Full module name (prefix) 19 | 20 | Returns: 21 | torch.nn.Module: Resulting module 22 | 23 | Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 24 | """ 25 | res = module 26 | is_match = True 27 | if module_match: 28 | is_match = name in module_match 29 | if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)): 30 | res = FrozenBatchNorm2d(module.num_features) 31 | res.num_features = module.num_features 32 | res.affine = module.affine 33 | if module.affine: 34 | res.weight.data = module.weight.data.clone().detach() 35 | res.bias.data = module.bias.data.clone().detach() 36 | res.running_mean.data = module.running_mean.data 37 | res.running_var.data = module.running_var.data 38 | res.eps = module.eps 39 | else: 40 | for child_name, child in module.named_children(): 41 | full_child_name = '.'.join([name, child_name]) if name else child_name 42 | new_child = freeze_batch_norm_2d(child, module_match, full_child_name) 43 | if new_child is not child: 44 | res.add_module(child_name, new_child) 45 | return res 46 | 47 | 48 | # From PyTorch internals 49 | def _ntuple(n): 50 | def parse(x): 51 | if isinstance(x, collections.abc.Iterable): 52 | return x 53 | return tuple(repeat(x, n)) 54 | return parse 55 | 56 | 57 | to_1tuple = _ntuple(1) 58 | to_2tuple = _ntuple(2) 59 | to_3tuple = _ntuple(3) 60 | to_4tuple = _ntuple(4) 61 | to_ntuple = lambda n, x: _ntuple(n)(x) 62 | 63 | # Replaces all linear layers with linear_replacement 64 | # TODO: add int8 support for other linear layers including attn and convnets 65 | def replace_linear(model, linear_replacement, include_modules=['c_fc', 'c_proj'], copy_weights=True): 66 | for name, module in model.named_children(): 67 | if len(list(module.children())) > 0: 68 | replace_linear(module, linear_replacement, include_modules, copy_weights) 69 | 70 | if isinstance(module, torch.nn.Linear) and name in include_modules: 71 | old_module = model._modules[name] 72 | model._modules[name] = linear_replacement( 73 | module.in_features, 74 | module.out_features, 75 | module.bias is not None, 76 | ) 77 | if copy_weights: 78 | model._modules[name].weight.data.copy_(old_module.weight.data) 79 | if model._modules[name].bias is not None: 80 | model._modules[name].bias.data.copy_(old_module.bias) 81 | 82 | return model 83 | 84 | def convert_int8_model_to_inference_mode(model): 85 | for m in model.modules(): 86 | if hasattr(m, 'prepare_for_eval'): 87 | int8_original_dtype = m.weight.dtype 88 | m.prepare_for_eval() 89 | m.int8_original_dtype = 
int8_original_dtype -------------------------------------------------------------------------------- /open_clip/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '2.23.0' 2 | -------------------------------------------------------------------------------- /open_clip/zero_shot_classifier.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from itertools import islice 3 | from typing import Callable, List, Optional, Sequence, Union 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | 9 | def batched(iterable, n): 10 | """Batch data into lists of length *n*. The last batch may be shorter. 11 | NOTE based on more-itertools impl, to be replaced by python 3.12 itertools.batched impl 12 | """ 13 | it = iter(iterable) 14 | while True: 15 | batch = list(islice(it, n)) 16 | if not batch: 17 | break 18 | yield batch 19 | 20 | 21 | def build_zero_shot_classifier( 22 | model, 23 | tokenizer, 24 | classnames: Sequence[str], 25 | templates: Sequence[Union[Callable, str]], 26 | num_classes_per_batch: Optional[int] = 10, 27 | device: Union[str, torch.device] = 'cpu', 28 | use_tqdm: bool = False, 29 | ): 30 | """ Build zero-shot classifier weights by iterating over class names in batches 31 | Args: 32 | model: CLIP model instance 33 | tokenizer: CLIP tokenizer instance 34 | classnames: A sequence of class (label) names 35 | templates: A sequence of callables or format() friendly strings to produce templates per class name 36 | num_classes_per_batch: The number of classes to batch together in each forward, all if None 37 | device: Device to use. 38 | use_tqdm: Enable TQDM progress bar. 39 | """ 40 | assert isinstance(templates, Sequence) and len(templates) > 0 41 | assert isinstance(classnames, Sequence) and len(classnames) > 0 42 | use_format = isinstance(templates[0], str) 43 | num_templates = len(templates) 44 | num_classes = len(classnames) 45 | if use_tqdm: 46 | import tqdm 47 | num_iter = 1 if num_classes_per_batch is None else ((num_classes - 1) // num_classes_per_batch + 1) 48 | iter_wrap = partial(tqdm.tqdm, total=num_iter, unit_scale=num_classes_per_batch) 49 | else: 50 | iter_wrap = iter 51 | 52 | def _process_batch(batch_classnames): 53 | num_batch_classes = len(batch_classnames) 54 | texts = [template.format(c) if use_format else template(c) for c in batch_classnames for template in templates] 55 | texts = tokenizer(texts).to(device) 56 | class_embeddings = model.encode_text(texts, normalize=True) 57 | class_embeddings = class_embeddings.reshape(num_batch_classes, num_templates, -1).mean(dim=1) 58 | class_embeddings = class_embeddings / class_embeddings.norm(dim=1, keepdim=True) 59 | class_embeddings = class_embeddings.T 60 | return class_embeddings 61 | 62 | with torch.no_grad(): 63 | if num_classes_per_batch: 64 | batched_embeds = [_process_batch(batch) for batch in iter_wrap(batched(classnames, num_classes_per_batch))] 65 | zeroshot_weights = torch.cat(batched_embeds, dim=1) 66 | else: 67 | zeroshot_weights = _process_batch(classnames) 68 | return zeroshot_weights 69 | 70 | 71 | def build_zero_shot_classifier_legacy( 72 | model, 73 | tokenizer, 74 | classnames: Sequence[str], 75 | templates: Sequence[Union[Callable, str]], 76 | device: Union[str, torch.device] = 'cpu', 77 | use_tqdm: bool = False, 78 | ): 79 | """ Build zero-shot classifier weights by iterating over class names 1 by 1 80 | Args: 81 | model: CLIP model instance 82 | tokenizer: CLIP 
tokenizer instance 83 | classnames: A sequence of class (label) names 84 | templates: A sequence of callables or format() friendly strings to produce templates per class name 85 | device: Device to use. 86 | use_tqdm: Enable TQDM progress bar. 87 | """ 88 | assert isinstance(templates, Sequence) and len(templates) > 0 89 | assert isinstance(classnames, Sequence) and len(classnames) > 0 90 | if use_tqdm: 91 | import tqdm 92 | iter_wrap = tqdm.tqdm 93 | else: 94 | iter_wrap = iter 95 | 96 | use_format = isinstance(templates[0], str) 97 | 98 | with torch.no_grad(): 99 | zeroshot_weights = [] 100 | for classname in iter_wrap(classnames): 101 | texts = [template.format(classname) if use_format else template(classname) for template in templates] 102 | texts = tokenizer(texts).to(device) # tokenize 103 | class_embeddings = model.encode_text(texts) 104 | class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0) 105 | class_embedding /= class_embedding.norm() 106 | zeroshot_weights.append(class_embedding) 107 | zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device) 108 | 109 | return zeroshot_weights 110 | 111 | -------------------------------------------------------------------------------- /requirements-test.txt: -------------------------------------------------------------------------------- 1 | pytest-split==0.8.0 2 | pytest==7.2.0 3 | transformers 4 | timm>=0.9.8 5 | -------------------------------------------------------------------------------- /requirements-training.txt: -------------------------------------------------------------------------------- 1 | torch>=1.9.0 2 | torchvision 3 | webdataset>=0.2.5 4 | regex 5 | ftfy 6 | tqdm 7 | pandas 8 | braceexpand 9 | huggingface_hub 10 | transformers 11 | timm>=0.9.8 12 | fsspec 13 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.9.0 2 | torchvision 3 | regex 4 | ftfy 5 | tqdm 6 | huggingface_hub 7 | sentencepiece 8 | protobuf 9 | timm 10 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ Setup 2 | """ 3 | from setuptools import setup, find_packages 4 | from codecs import open 5 | from os import path 6 | 7 | here = path.abspath(path.dirname(__file__)) 8 | 9 | # Get the long description from the README file 10 | with open(path.join(here, 'README.md'), encoding='utf-8') as f: 11 | long_description = f.read() 12 | 13 | def _read_reqs(relpath): 14 | fullpath = path.join(path.dirname(__file__), relpath) 15 | with open(fullpath) as f: 16 | return [s.strip() for s in f.readlines() if (s.strip() and not s.startswith("#"))] 17 | 18 | REQUIREMENTS = _read_reqs("requirements.txt") 19 | TRAINING_REQUIREMENTS = _read_reqs("requirements-training.txt") 20 | 21 | exec(open('src/open_clip/version.py').read()) 22 | setup( 23 | name='open_clip_torch', 24 | version=__version__, 25 | description='OpenCLIP', 26 | long_description=long_description, 27 | long_description_content_type='text/markdown', 28 | url='https://github.com/mlfoundations/open_clip', 29 | author='', 30 | author_email='', 31 | classifiers=[ 32 | # How mature is this project? 
Common values are 33 | # 3 - Alpha 34 | # 4 - Beta 35 | # 5 - Production/Stable 36 | 'Development Status :: 3 - Alpha', 37 | 'Intended Audience :: Education', 38 | 'Intended Audience :: Science/Research', 39 | 'License :: OSI Approved :: Apache Software License', 40 | 'Programming Language :: Python :: 3.7', 41 | 'Programming Language :: Python :: 3.8', 42 | 'Programming Language :: Python :: 3.9', 43 | 'Programming Language :: Python :: 3.10', 44 | 'Topic :: Scientific/Engineering', 45 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 46 | 'Topic :: Software Development', 47 | 'Topic :: Software Development :: Libraries', 48 | 'Topic :: Software Development :: Libraries :: Python Modules', 49 | ], 50 | 51 | # Note that this is a string of words separated by whitespace, not a list. 52 | keywords='CLIP pretrained', 53 | package_dir={'': 'src'}, 54 | packages=find_packages(where='src'), 55 | include_package_data=True, 56 | install_requires=REQUIREMENTS, 57 | extras_require={ 58 | "training": TRAINING_REQUIREMENTS, 59 | }, 60 | python_requires='>=3.7', 61 | ) 62 | -------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ant-research/DreamLIP/b4de0a43b6c002033c02873f91a695ab449e464c/training/__init__.py -------------------------------------------------------------------------------- /training/distributed.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.distributed as dist 5 | from datetime import timedelta 6 | try: 7 | import horovod.torch as hvd 8 | except ImportError: 9 | hvd = None 10 | 11 | 12 | def is_global_master(args): 13 | return args.rank == 0 14 | 15 | 16 | def is_local_master(args): 17 | return args.local_rank == 0 18 | 19 | 20 | def is_master(args, local=False): 21 | return is_local_master(args) if local else is_global_master(args) 22 | 23 | 24 | def is_using_horovod(): 25 | # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set 26 | # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required... 27 | ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"] 28 | pmi_vars = ["PMI_RANK", "PMI_SIZE"] 29 | if all([var in os.environ for var in ompi_vars]) or all([var in os.environ for var in pmi_vars]): 30 | return True 31 | else: 32 | return False 33 | 34 | 35 | def is_using_distributed(): 36 | if 'WORLD_SIZE' in os.environ: 37 | return int(os.environ['WORLD_SIZE']) > 1 38 | if 'SLURM_NTASKS' in os.environ: 39 | return int(os.environ['SLURM_NTASKS']) > 1 40 | return False 41 | 42 | 43 | def world_info_from_env(): 44 | local_rank = 0 45 | for v in ('LOCAL_RANK', 'MPI_LOCALRANKID', 'SLURM_LOCALID', 'OMPI_COMM_WORLD_LOCAL_RANK'): 46 | if v in os.environ: 47 | local_rank = int(os.environ[v]) 48 | break 49 | global_rank = 0 50 | for v in ('RANK', 'PMI_RANK', 'SLURM_PROCID', 'OMPI_COMM_WORLD_RANK'): 51 | if v in os.environ: 52 | global_rank = int(os.environ[v]) 53 | break 54 | world_size = 1 55 | for v in ('WORLD_SIZE', 'PMI_SIZE', 'SLURM_NTASKS', 'OMPI_COMM_WORLD_SIZE'): 56 | if v in os.environ: 57 | world_size = int(os.environ[v]) 58 | break 59 | 60 | return local_rank, global_rank, world_size 61 | 62 | 63 | def init_distributed_device(args): 64 | # Distributed training = training on more than one GPU. 
65 | # Works in both single and multi-node scenarios. 66 | args.distributed = False 67 | args.world_size = 1 68 | args.rank = 0 # global rank 69 | args.local_rank = 0 70 | if args.horovod: 71 | assert hvd is not None, "Horovod is not installed" 72 | hvd.init() 73 | args.local_rank = int(hvd.local_rank()) 74 | args.rank = hvd.rank() 75 | args.world_size = hvd.size() 76 | args.distributed = True 77 | os.environ['LOCAL_RANK'] = str(args.local_rank) 78 | os.environ['RANK'] = str(args.rank) 79 | os.environ['WORLD_SIZE'] = str(args.world_size) 80 | elif is_using_distributed(): 81 | if 'SLURM_PROCID' in os.environ: 82 | # DDP via SLURM 83 | args.local_rank, args.rank, args.world_size = world_info_from_env() 84 | # SLURM var -> torch.distributed vars in case needed 85 | os.environ['LOCAL_RANK'] = str(args.local_rank) 86 | os.environ['RANK'] = str(args.rank) 87 | os.environ['WORLD_SIZE'] = str(args.world_size) 88 | torch.distributed.init_process_group( 89 | backend=args.dist_backend, 90 | init_method=args.dist_url, 91 | world_size=args.world_size, 92 | rank=args.rank, 93 | timeout=timedelta(days=1), 94 | ) 95 | else: 96 | # DDP via torchrun, torch.distributed.launch 97 | args.local_rank, _, _ = world_info_from_env() 98 | torch.distributed.init_process_group( 99 | backend=args.dist_backend, 100 | init_method=args.dist_url, 101 | timeout=timedelta(days=1) 102 | ) 103 | args.world_size = torch.distributed.get_world_size() 104 | args.rank = torch.distributed.get_rank() 105 | args.distributed = True 106 | print(f'Using torch distributed DDP: {args.distributed}') 107 | else: 108 | print(f'No torch distributed DDP: {args.distributed}') 109 | if torch.cuda.is_available(): 110 | if args.distributed and not args.no_set_device_rank: 111 | device = 'cuda:%d' % args.local_rank 112 | else: 113 | device = 'cuda:0' 114 | torch.cuda.set_device(device) 115 | else: 116 | device = 'cpu' 117 | args.device = device 118 | device = torch.device(device) 119 | return device 120 | 121 | 122 | def broadcast_object(args, obj, src=0): 123 | # broadcast a pickle-able python object from rank-0 to all ranks 124 | if args.horovod: 125 | return hvd.broadcast_object(obj, root_rank=src) 126 | else: 127 | if args.rank == src: 128 | objects = [obj] 129 | else: 130 | objects = [None] 131 | dist.broadcast_object_list(objects, src=src) 132 | return objects[0] 133 | 134 | 135 | def all_gather_object(args, obj, dst=0): 136 | # gather a pickle-able python object across all ranks 137 | if args.horovod: 138 | return hvd.allgather_object(obj) 139 | else: 140 | objects = [None for _ in range(args.world_size)] 141 | dist.all_gather_object(objects, obj) 142 | return objects 143 | -------------------------------------------------------------------------------- /training/file_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import multiprocessing 4 | import subprocess 5 | import time 6 | import fsspec 7 | import torch 8 | from tqdm import tqdm 9 | 10 | def remote_sync_s3(local_dir, remote_dir): 11 | # skip epoch_latest which can change during sync. 
12 | result = subprocess.run(["aws", "s3", "sync", local_dir, remote_dir, '--exclude', '*epoch_latest.pt'], stdout=subprocess.PIPE, stderr=subprocess.PIPE) 13 | if result.returncode != 0: 14 | logging.error(f"Error: Failed to sync with S3 bucket {result.stderr.decode('utf-8')}") 15 | return False 16 | 17 | logging.info(f"Successfully synced with S3 bucket") 18 | return True 19 | 20 | def remote_sync_fsspec(local_dir, remote_dir): 21 | # FIXME currently this is slow and not recommended. Look into speeding up. 22 | a = fsspec.get_mapper(local_dir) 23 | b = fsspec.get_mapper(remote_dir) 24 | 25 | for k in a: 26 | # skip epoch_latest which can change during sync. 27 | if 'epoch_latest.pt' in k: 28 | continue 29 | 30 | logging.info(f'Attempting to sync {k}') 31 | if k in b and len(a[k]) == len(b[k]): 32 | logging.debug(f'Skipping remote sync for {k}.') 33 | continue 34 | 35 | try: 36 | logging.info(f'Successful sync for {k}.') 37 | b[k] = a[k] 38 | except Exception as e: 39 | logging.info(f'Error during remote sync for {k}: {e}') 40 | return False 41 | 42 | return True 43 | 44 | def remote_sync(local_dir, remote_dir, protocol): 45 | logging.info('Starting remote sync.') 46 | if protocol == 's3': 47 | return remote_sync_s3(local_dir, remote_dir) 48 | elif protocol == 'fsspec': 49 | return remote_sync_fsspec(local_dir, remote_dir) 50 | else: 51 | logging.error('Remote protocol not known') 52 | return False 53 | 54 | def keep_running_remote_sync(sync_every, local_dir, remote_dir, protocol): 55 | while True: 56 | time.sleep(sync_every) 57 | remote_sync(local_dir, remote_dir, protocol) 58 | 59 | def start_sync_process(sync_every, local_dir, remote_dir, protocol): 60 | p = multiprocessing.Process(target=keep_running_remote_sync, args=(sync_every, local_dir, remote_dir, protocol)) 61 | return p 62 | 63 | # Note: we are not currently using this save function. 
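# (note: torch.save() below receives file_path directly, not the opened fsspec handle f)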
64 | def pt_save(pt_obj, file_path): 65 | of = fsspec.open(file_path, "wb") 66 | with of as f: 67 | torch.save(pt_obj, f)  # write through the fsspec handle so remote paths work 68 | 69 | def pt_load(file_path, map_location=None): 70 | if file_path.startswith('s3'): 71 | logging.info('Loading remote checkpoint, which may take a bit.') 72 | of = fsspec.open(file_path, "rb") 73 | with of as f: 74 | out = torch.load(f, map_location=map_location) 75 | return out 76 | 77 | def check_exists(file_path): 78 | try: 79 | with fsspec.open(file_path): 80 | pass 81 | except FileNotFoundError: 82 | return False 83 | return True 84 | -------------------------------------------------------------------------------- /training/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | def setup_logging(log_file, level, include_host=False): 5 | if include_host: 6 | import socket 7 | hostname = socket.gethostname() 8 | formatter = logging.Formatter( 9 | f'%(asctime)s | {hostname} | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S') 10 | else: 11 | formatter = logging.Formatter('%(asctime)s | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S') 12 | 13 | logging.root.setLevel(level) 14 | loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict] 15 | for logger in loggers: 16 | logger.setLevel(level) 17 | 18 | stream_handler = logging.StreamHandler() 19 | stream_handler.setFormatter(formatter) 20 | logging.root.addHandler(stream_handler) 21 | 22 | if log_file: 23 | file_handler = logging.FileHandler(filename=log_file) 24 | file_handler.setFormatter(formatter) 25 | logging.root.addHandler(file_handler) 26 | 27 | -------------------------------------------------------------------------------- /training/precision.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from contextlib import suppress 3 | 4 | 5 | def get_autocast(precision): 6 | if precision == 'amp': 7 | return torch.cuda.amp.autocast 8 | elif precision == 'amp_bfloat16' or precision == 'amp_bf16': 9 | # amp_bfloat16 is more stable than amp float16 for clip training 10 | return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16) 11 | else: 12 | return suppress 13 | -------------------------------------------------------------------------------- /training/profiler.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | import open_clip 5 | import pandas as pd 6 | from torch.utils.flop_counter import FlopCounterMode 7 | try: 8 | import fvcore 9 | except ImportError: 10 | fvcore = None 11 | 12 | parser = argparse.ArgumentParser(description='OpenCLIP Profiler') 13 | 14 | # benchmark specific args 15 | parser.add_argument('--model', metavar='NAME', default='', 16 | help='model(s) to profile') 17 | parser.add_argument('--results-file', default='', type=str, metavar='FILENAME', 18 | help='Output csv file for results') 19 | parser.add_argument('--profiler', default='torch', type=str, choices=['torch', 'fvcore']) 20 | parser.add_argument('--batch-size', default=1, type=int, help='Batch size for profiling') 21 | 22 | 23 | def profile_fvcore( 24 | model, 25 | image_input_size=(3, 224, 224), 26 | text_input_size=(77,), 27 | batch_size=1, 28 | detailed=False, 29 | force_cpu=False 30 | ): 31 | if force_cpu: 32 | model = model.to('cpu') 33 | device, dtype = next(model.parameters()).device, next(model.parameters()).dtype 34 | example_image_input = torch.ones((batch_size,) + image_input_size, 
device=device, dtype=dtype) 35 | example_text_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64) 36 | fca = fvcore.nn.FlopCountAnalysis(model, (example_image_input, example_text_input)) 37 | aca = fvcore.nn.ActivationCountAnalysis(model, (example_image_input, example_text_input)) 38 | if detailed: 39 | fcs = fvcore.nn.flop_count_str(fca) 40 | print(fcs) 41 | return fca.total() / batch_size, aca.total() / batch_size 42 | 43 | 44 | def profile_fvcore_text( 45 | model, 46 | text_input_size=(77,), 47 | batch_size=1, 48 | detailed=False, 49 | force_cpu=False 50 | ): 51 | if force_cpu: 52 | model = model.to('cpu') 53 | device = next(model.parameters()).device 54 | example_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64) 55 | fca = fvcore.nn.FlopCountAnalysis(model, example_input) 56 | aca = fvcore.nn.ActivationCountAnalysis(model, example_input) 57 | if detailed: 58 | fcs = fvcore.nn.flop_count_str(fca) 59 | print(fcs) 60 | return fca.total() / batch_size, aca.total() / batch_size 61 | 62 | 63 | def profile_fvcore_image( 64 | model, 65 | image_input_size=(3, 224, 224), 66 | batch_size=1, 67 | detailed=False, 68 | force_cpu=False 69 | ): 70 | if force_cpu: 71 | model = model.to('cpu') 72 | device, dtype = next(model.parameters()).device, next(model.parameters()).dtype 73 | example_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype) 74 | fca = fvcore.nn.FlopCountAnalysis(model, example_input) 75 | aca = fvcore.nn.ActivationCountAnalysis(model, example_input) 76 | if detailed: 77 | fcs = fvcore.nn.flop_count_str(fca) 78 | print(fcs) 79 | return fca.total() / batch_size, aca.total() / batch_size 80 | 81 | 82 | def profile_torch_image(model, image_input_size, batch_size=1, force_cpu=False): 83 | """Profile the image encoder using torch.utils.flop_counter""" 84 | if force_cpu: 85 | model = model.to('cpu') 86 | device, dtype = next(model.parameters()).device, next(model.parameters()).dtype 87 | example_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype) 88 | 89 | flop_counter = FlopCounterMode() 90 | with flop_counter: 91 | model(example_input) 92 | total_flops = sum(flop_counter.get_flop_counts()['Global'].values()) 93 | return total_flops / batch_size 94 | 95 | 96 | def profile_torch_text(model, text_input_size, batch_size=1, force_cpu=False): 97 | """Profile the text encoder using torch.utils.flop_counter""" 98 | if force_cpu: 99 | model = model.to('cpu') 100 | device = next(model.parameters()).device 101 | example_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64) 102 | 103 | flop_counter = FlopCounterMode() 104 | with flop_counter: 105 | model(example_input) 106 | total_flops = sum(flop_counter.get_flop_counts()['Global'].values()) 107 | return total_flops / batch_size 108 | 109 | 110 | def profile_torch(model, text_input_size, image_input_size, batch_size=1, force_cpu=False): 111 | """Profile the full model using torch.utils.flop_counter""" 112 | if force_cpu: 113 | model = model.to('cpu') 114 | device, dtype = next(model.parameters()).device, next(model.parameters()).dtype 115 | image_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype) 116 | text_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64) 117 | 118 | flop_counter = FlopCounterMode() 119 | with flop_counter: 120 | model(image_input, text_input) 121 | total_flops = 
sum(flop_counter.get_flop_counts()['Global'].values()) 122 | return total_flops / batch_size 123 | 124 | 125 | def count_params(model): 126 | return sum([m.numel() for m in model.parameters()]) 127 | 128 | def profile_model(model_name, batch_size=1, profiler='torch'): 129 | assert profiler in ['torch', 'fvcore'], 'Only torch and fvcore profilers are supported' 130 | if profiler == 'fvcore': 131 | assert fvcore is not None, 'Please install fvcore.' 132 | model = open_clip.create_model(model_name, force_custom_text=True, pretrained_hf=False) 133 | model.eval() 134 | if torch.cuda.is_available(): 135 | model = model.cuda() 136 | 137 | if isinstance(model.visual.image_size, (tuple, list)): 138 | image_input_size = (3,) + tuple(model.visual.image_size[-2:]) 139 | else: 140 | image_input_size = (3, model.visual.image_size, model.visual.image_size) 141 | 142 | text_input_size = (77,) 143 | if hasattr(model, 'context_length') and model.context_length: 144 | text_input_size = (model.context_length,) 145 | 146 | results = {} 147 | results['model'] = model_name 148 | results['image_size'] = image_input_size[1] 149 | 150 | model_cfg = open_clip.get_model_config(model_name) 151 | if model_cfg: 152 | vision_cfg = open_clip.CLIPVisionCfg(**model_cfg['vision_cfg']) 153 | text_cfg = open_clip.CLIPTextCfg(**model_cfg['text_cfg']) 154 | results['image_width'] = int(vision_cfg.width) 155 | results['text_width'] = int(text_cfg.width) 156 | results['embed_dim'] = int(model_cfg['embed_dim']) 157 | else: 158 | results['image_width'] = 0 159 | results['text_width'] = 0 160 | results['embed_dim'] = 0 161 | 162 | retries = 2 163 | while retries: 164 | retries -= 1 165 | try: 166 | results['mparams'] = round(count_params(model) / 1e6, 2) 167 | results['image_mparams'] = round(count_params(model.visual) / 1e6, 2) 168 | results['text_mparams'] = round(count_params(model.text) / 1e6, 2) 169 | 170 | if profiler == 'fvcore': 171 | macs, acts = profile_fvcore( 172 | model, image_input_size=image_input_size, text_input_size=text_input_size, force_cpu=not retries, batch_size=batch_size) 173 | 174 | image_macs, image_acts = profile_fvcore_image( 175 | model.visual, image_input_size=image_input_size, force_cpu=not retries, batch_size=batch_size) 176 | 177 | text_macs, text_acts = profile_fvcore_text( 178 | model.text, text_input_size=text_input_size, force_cpu=not retries, batch_size=batch_size) 179 | 180 | results['gmacs'] = round(macs / 1e9, 2) 181 | results['macts'] = round(acts / 1e6, 2) 182 | 183 | results['image_gmacs'] = round(image_macs / 1e9, 2) 184 | results['image_macts'] = round(image_acts / 1e6, 2) 185 | 186 | results['text_gmacs'] = round(text_macs / 1e9, 2) 187 | results['text_macts'] = round(text_acts / 1e6, 2) 188 | elif profiler == 'torch': 189 | image_flops = profile_torch_image( 190 | model.visual, image_input_size=image_input_size, force_cpu=not retries, batch_size=batch_size) 191 | text_flops = profile_torch_text( 192 | model.text, text_input_size=text_input_size, force_cpu=not retries, batch_size=batch_size) 193 | total_flops = profile_torch( 194 | model, image_input_size=image_input_size, text_input_size=text_input_size, force_cpu=not retries, batch_size=batch_size) 195 | 196 | results['gflops'] = round(total_flops / 1e9, 2) 197 | results['image_gflops'] = round(image_flops / 1e9, 2) 198 | results['text_gflops'] = round(text_flops / 1e9, 2) 199 | 200 | except RuntimeError as e: 201 | pass 202 | return results 203 | 204 | 205 | def main(): 206 | args = parser.parse_args() 207 | 208 | # FIXME accept a 
text file name to allow lists of models in txt/csv 209 | if args.model == 'all': 210 | parsed_model = open_clip.list_models() 211 | else: 212 | parsed_model = args.model.split(',') 213 | 214 | results = [] 215 | models_with_errors = [] 216 | for m in parsed_model: 217 | print('='*100) 218 | print(f'Profiling {m}') 219 | try: 220 | row = profile_model(m, batch_size=args.batch_size, profiler=args.profiler) 221 | results.append(row) 222 | except Exception as e: 223 | print(f'Error profiling {m}: {e}') 224 | import traceback 225 | traceback.print_exc() 226 | models_with_errors.append(m) 227 | 228 | df = pd.DataFrame(results, columns=results[0].keys()) 229 | 230 | if 'gmacs' in df.columns: 231 | df = df.sort_values(by=['gmacs', 'mparams', 'model']) 232 | else: 233 | df = df.sort_values(by=['gflops', 'mparams', 'model']) 234 | 235 | print('='*100) 236 | print('Done.') 237 | print(df) 238 | if args.results_file: 239 | df.to_csv(args.results_file, index=False) 240 | 241 | if models_with_errors: 242 | print('Models with errors:', models_with_errors) 243 | 244 | 245 | if __name__ == '__main__': 246 | main() 247 | -------------------------------------------------------------------------------- /training/random_aug.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | from timm.data.constants import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD 5 | from timm.data.transforms import RandomResizedCropAndInterpolation 6 | from torchvision import transforms 7 | import urllib 8 | from tqdm import tqdm 9 | from torch.utils.data import default_collate 10 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union 11 | from typing_extensions import TypedDict 12 | from numpy.typing import NDArray 13 | import importlib.machinery 14 | import importlib.util 15 | import types 16 | import random 17 | 18 | 19 | 20 | 21 | def pad(orig_items, key, max_length=None, padding_value=0, padding_side="left"): 22 | items = [] 23 | if isinstance(orig_items[0][key], list): 24 | assert isinstance(orig_items[0][key][0], torch.Tensor) 25 | for it in orig_items: 26 | for tr in it[key]: 27 | items.append({key: tr}) 28 | else: 29 | assert isinstance(orig_items[0][key], torch.Tensor) 30 | items = orig_items 31 | 32 | batch_size = len(items) 33 | shape = items[0][key].shape 34 | dim = len(shape) 35 | assert dim <= 3 36 | if max_length is None: 37 | max_length = 0 38 | max_length = max(max_length, max(item[key].shape[-1] for item in items)) 39 | min_length = min(item[key].shape[-1] for item in items) 40 | dtype = items[0][key].dtype 41 | 42 | if dim == 1: 43 | return torch.cat([item[key] for item in items], dim=0) 44 | elif dim == 2: 45 | if max_length == min_length: 46 | return torch.cat([item[key] for item in items], dim=0) 47 | tensor = torch.zeros((batch_size, max_length), dtype=dtype) + padding_value 48 | else: 49 | tensor = torch.zeros((batch_size, max_length, shape[-1]), dtype=dtype) + padding_value 50 | 51 | for i, item in enumerate(items): 52 | if dim == 2: 53 | if padding_side == "left": 54 | tensor[i, -len(item[key][0]):] = item[key][0].clone() 55 | else: 56 | tensor[i, : len(item[key][0])] = item[key][0].clone() 57 | elif dim == 3: 58 | if padding_side == "left": 59 | tensor[i, -len(item[key][0]):, :] = item[key][0].clone() 60 | else: 61 | tensor[i, : len(item[key][0]), :] = item[key][0].clone() 62 | 63 | return tensor 64 | 65 | 66 | 67 | class _DictTree(TypedDict): 68 | value: str 69 | children: List["_DictTree"] 70 | 
depth: int 71 | segment_id: int 72 | need_predict: bool 73 | is_image: bool 74 | 75 | 76 | class _PrevExtTableStates(TypedDict): 77 | ext_table: Dict[int, str] 78 | token_id_table: Dict[str, Dict[int, int]] 79 | 80 | 81 | class _TransformFuncDict(TypedDict): 82 | loader: importlib.machinery.SourceFileLoader 83 | module: types.ModuleType 84 | last_m: float 85 | 86 | 87 | 88 | def rel_to_bucket(n_up: int, n_down: int, max_depth: int = 8): 89 | ret = n_up * max_depth + n_down 90 | if ret == 0: 91 | return ret 92 | else: 93 | # bucket 1 is reserved for incontext samples 94 | return ret + 1 95 | 96 | 97 | # aug functions 98 | def identity_func(img): 99 | return img 100 | 101 | 102 | def autocontrast_func(img, cutoff=0): 103 | ''' 104 | same output as PIL.ImageOps.autocontrast 105 | ''' 106 | n_bins = 256 107 | 108 | def tune_channel(ch): 109 | n = ch.size 110 | cut = cutoff * n // 100 111 | if cut == 0: 112 | high, low = ch.max(), ch.min() 113 | else: 114 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) 115 | low = np.argwhere(np.cumsum(hist) > cut) 116 | low = 0 if low.shape[0] == 0 else low[0] 117 | high = np.argwhere(np.cumsum(hist[::-1]) > cut) 118 | high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0] 119 | if high <= low: 120 | table = np.arange(n_bins) 121 | else: 122 | scale = (n_bins - 1) / (high - low) 123 | table = np.arange(n_bins) * scale - low * scale 124 | table[table < 0] = 0 125 | table[table > n_bins - 1] = n_bins - 1 126 | table = table.clip(0, 255).astype(np.uint8) 127 | return table[ch] 128 | 129 | channels = [tune_channel(ch) for ch in cv2.split(img)] 130 | out = cv2.merge(channels) 131 | return out 132 | 133 | 134 | def equalize_func(img): 135 | ''' 136 | same output as PIL.ImageOps.equalize 137 | PIL's implementation is different from cv2.equalize 138 | ''' 139 | n_bins = 256 140 | 141 | def tune_channel(ch): 142 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) 143 | non_zero_hist = hist[hist != 0].reshape(-1) 144 | step = np.sum(non_zero_hist[:-1]) // (n_bins - 1) 145 | if step == 0: 146 | return ch 147 | n = np.empty_like(hist) 148 | n[0] = step // 2 149 | n[1:] = hist[:-1] 150 | table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8) 151 | return table[ch] 152 | 153 | channels = [tune_channel(ch) for ch in cv2.split(img)] 154 | out = cv2.merge(channels) 155 | return out 156 | 157 | 158 | def rotate_func(img, degree, fill=(0, 0, 0)): 159 | ''' 160 | like PIL, rotate by degree, not radians 161 | ''' 162 | H, W = img.shape[0], img.shape[1] 163 | center = W / 2, H / 2 164 | M = cv2.getRotationMatrix2D(center, degree, 1) 165 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill) 166 | return out 167 | 168 | 169 | def solarize_func(img, thresh=128): 170 | ''' 171 | same output as PIL.ImageOps.posterize 172 | ''' 173 | table = np.array([el if el < thresh else 255 - el for el in range(256)]) 174 | table = table.clip(0, 255).astype(np.uint8) 175 | out = table[img] 176 | return out 177 | 178 | 179 | def color_func(img, factor): 180 | ''' 181 | same output as PIL.ImageEnhance.Color 182 | ''' 183 | # implementation according to PIL definition, quite slow 184 | # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis] 185 | # out = blend(degenerate, img, factor) 186 | # M = ( 187 | # np.eye(3) * factor 188 | # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. 
- factor) 189 | # )[np.newaxis, np.newaxis, :] 190 | M = ( 191 | np.float32([ 192 | [0.886, -0.114, -0.114], 193 | [-0.587, 0.413, -0.587], 194 | [-0.299, -0.299, 0.701]]) * factor 195 | + np.float32([[0.114], [0.587], [0.299]]) 196 | ) 197 | out = np.matmul(img, M).clip(0, 255).astype(np.uint8) 198 | return out 199 | 200 | 201 | def contrast_func(img, factor): 202 | """ 203 | same output as PIL.ImageEnhance.Contrast 204 | """ 205 | mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299])) 206 | table = np.array([( 207 | el - mean) * factor + mean 208 | for el in range(256) 209 | ]).clip(0, 255).astype(np.uint8) 210 | out = table[img] 211 | return out 212 | 213 | 214 | def brightness_func(img, factor): 215 | ''' 216 | same output as PIL.ImageEnhance.Contrast 217 | ''' 218 | table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8) 219 | out = table[img] 220 | return out 221 | 222 | 223 | def sharpness_func(img, factor): 224 | ''' 225 | The differences the this result and PIL are all on the 4 boundaries, the center 226 | areas are same 227 | ''' 228 | kernel = np.ones((3, 3), dtype=np.float32) 229 | kernel[1][1] = 5 230 | kernel /= 13 231 | degenerate = cv2.filter2D(img, -1, kernel) 232 | if factor == 0.0: 233 | out = degenerate 234 | elif factor == 1.0: 235 | out = img 236 | else: 237 | out = img.astype(np.float32) 238 | degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :] 239 | out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate) 240 | out = out.astype(np.uint8) 241 | return out 242 | 243 | 244 | def shear_x_func(img, factor, fill=(0, 0, 0)): 245 | H, W = img.shape[0], img.shape[1] 246 | M = np.float32([[1, factor, 0], [0, 1, 0]]) 247 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 248 | return out 249 | 250 | 251 | def translate_x_func(img, offset, fill=(0, 0, 0)): 252 | ''' 253 | same output as PIL.Image.transform 254 | ''' 255 | H, W = img.shape[0], img.shape[1] 256 | M = np.float32([[1, 0, -offset], [0, 1, 0]]) 257 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 258 | return out 259 | 260 | 261 | def translate_y_func(img, offset, fill=(0, 0, 0)): 262 | ''' 263 | same output as PIL.Image.transform 264 | ''' 265 | H, W = img.shape[0], img.shape[1] 266 | M = np.float32([[1, 0, 0], [0, 1, -offset]]) 267 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 268 | return out 269 | 270 | 271 | def posterize_func(img, bits): 272 | ''' 273 | same output as PIL.ImageOps.posterize 274 | ''' 275 | out = np.bitwise_and(img, np.uint8(255 << (8 - bits))) 276 | return out 277 | 278 | 279 | def shear_y_func(img, factor, fill=(0, 0, 0)): 280 | H, W = img.shape[0], img.shape[1] 281 | M = np.float32([[1, 0, 0], [factor, 1, 0]]) 282 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 283 | return out 284 | 285 | 286 | def cutout_func(img, pad_size, replace=(0, 0, 0)): 287 | replace = np.array(replace, dtype=np.uint8) 288 | H, W = img.shape[0], img.shape[1] 289 | rh, rw = np.random.random(2) 290 | pad_size = pad_size // 2 291 | ch, cw = int(rh * H), int(rw * W) 292 | x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H) 293 | y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W) 294 | out = img.copy() 295 | out[x1:x2, y1:y2, :] = replace 296 | return out 297 | 298 | 299 | # level to args 300 | def enhance_level_to_args(MAX_LEVEL): 301 | def 
level_to_args(level): 302 | return ((level / MAX_LEVEL) * 1.8 + 0.1,) 303 | return level_to_args 304 | 305 | 306 | def shear_level_to_args(MAX_LEVEL, replace_value): 307 | def level_to_args(level): 308 | level = (level / MAX_LEVEL) * 0.3 309 | if np.random.random() > 0.5: 310 | level = -level 311 | return (level, replace_value) 312 | 313 | return level_to_args 314 | 315 | 316 | def translate_level_to_args(translate_const, MAX_LEVEL, replace_value): 317 | def level_to_args(level): 318 | level = (level / MAX_LEVEL) * float(translate_const) 319 | if np.random.random() > 0.5: 320 | level = -level 321 | return (level, replace_value) 322 | 323 | return level_to_args 324 | 325 | 326 | def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value): 327 | def level_to_args(level): 328 | level = int((level / MAX_LEVEL) * cutout_const) 329 | return (level, replace_value) 330 | 331 | return level_to_args 332 | 333 | 334 | def solarize_level_to_args(MAX_LEVEL): 335 | def level_to_args(level): 336 | level = int((level / MAX_LEVEL) * 256) 337 | return (level, ) 338 | return level_to_args 339 | 340 | 341 | def none_level_to_args(level): 342 | return () 343 | 344 | 345 | def posterize_level_to_args(MAX_LEVEL): 346 | def level_to_args(level): 347 | level = int((level / MAX_LEVEL) * 4) 348 | return (level, ) 349 | return level_to_args 350 | 351 | 352 | def rotate_level_to_args(MAX_LEVEL, replace_value): 353 | def level_to_args(level): 354 | level = (level / MAX_LEVEL) * 30 355 | if np.random.random() < 0.5: 356 | level = -level 357 | return (level, replace_value) 358 | 359 | return level_to_args 360 | 361 | 362 | func_dict = { 363 | 'Identity': identity_func, 364 | 'AutoContrast': autocontrast_func, 365 | 'Equalize': equalize_func, 366 | 'Rotate': rotate_func, 367 | 'Solarize': solarize_func, 368 | 'Color': color_func, 369 | 'Contrast': contrast_func, 370 | 'Brightness': brightness_func, 371 | 'Sharpness': sharpness_func, 372 | 'ShearX': shear_x_func, 373 | 'TranslateX': translate_x_func, 374 | 'TranslateY': translate_y_func, 375 | 'Posterize': posterize_func, 376 | 'ShearY': shear_y_func, 377 | } 378 | 379 | translate_const = 10 380 | MAX_LEVEL = 10 381 | replace_value = (128, 128, 128) 382 | arg_dict = { 383 | 'Identity': none_level_to_args, 384 | 'AutoContrast': none_level_to_args, 385 | 'Equalize': none_level_to_args, 386 | 'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value), 387 | 'Solarize': solarize_level_to_args(MAX_LEVEL), 388 | 'Color': enhance_level_to_args(MAX_LEVEL), 389 | 'Contrast': enhance_level_to_args(MAX_LEVEL), 390 | 'Brightness': enhance_level_to_args(MAX_LEVEL), 391 | 'Sharpness': enhance_level_to_args(MAX_LEVEL), 392 | 'ShearX': shear_level_to_args(MAX_LEVEL, replace_value), 393 | 'TranslateX': translate_level_to_args( 394 | translate_const, MAX_LEVEL, replace_value 395 | ), 396 | 'TranslateY': translate_level_to_args( 397 | translate_const, MAX_LEVEL, replace_value 398 | ), 399 | 'Posterize': posterize_level_to_args(MAX_LEVEL), 400 | 'ShearY': shear_level_to_args(MAX_LEVEL, replace_value), 401 | } 402 | 403 | 404 | class RandomAugment(object): 405 | 406 | def __init__(self, N=2, M=10, isPIL=False, augs=[]): 407 | self.N = N 408 | self.M = M 409 | self.isPIL = isPIL 410 | if augs: 411 | self.augs = augs 412 | else: 413 | self.augs = list(arg_dict.keys()) 414 | 415 | def get_random_ops(self): 416 | sampled_ops = np.random.choice(self.augs, self.N) 417 | return [(op, 0.5, self.M) for op in sampled_ops] 418 | 419 | def __call__(self, img): 420 | if self.isPIL: 421 | img = 
np.array(img) 422 | ops = self.get_random_ops() 423 | for name, prob, level in ops: 424 | if np.random.random() > prob: 425 | continue 426 | args = arg_dict[name](level) 427 | img = func_dict[name](img, *args) 428 | return img 429 | 430 | -------------------------------------------------------------------------------- /training/scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def assign_learning_rate(optimizer, new_lr): 5 | for param_group in optimizer.param_groups: 6 | param_group["lr"] = new_lr 7 | 8 | 9 | def _warmup_lr(base_lr, warmup_length, step): 10 | return base_lr * (step + 1) / warmup_length 11 | 12 | 13 | def const_lr(optimizer, base_lr, warmup_length, steps): 14 | def _lr_adjuster(step): 15 | if step < warmup_length: 16 | lr = _warmup_lr(base_lr, warmup_length, step) 17 | else: 18 | lr = base_lr 19 | assign_learning_rate(optimizer, lr) 20 | return lr 21 | return _lr_adjuster 22 | 23 | 24 | def const_lr_cooldown(optimizer, base_lr, warmup_length, steps, cooldown_steps, cooldown_power=1.0, cooldown_end_lr=0.): 25 | def _lr_adjuster(step): 26 | start_cooldown_step = steps - cooldown_steps 27 | if step < warmup_length: 28 | lr = _warmup_lr(base_lr, warmup_length, step) 29 | else: 30 | if step < start_cooldown_step: 31 | lr = base_lr 32 | else: 33 | e = step - start_cooldown_step 34 | es = steps - start_cooldown_step 35 | # linear decay if power == 1; polynomial decay otherwise; 36 | decay = (1 - (e/es)) ** cooldown_power 37 | lr = decay * (base_lr - cooldown_end_lr) + cooldown_end_lr 38 | assign_learning_rate(optimizer, lr) 39 | return lr 40 | return _lr_adjuster 41 | 42 | 43 | def cosine_lr(optimizer, base_lr, warmup_length, steps): 44 | def _lr_adjuster(step): 45 | if step < warmup_length: 46 | lr = _warmup_lr(base_lr, warmup_length, step) 47 | else: 48 | e = step - warmup_length 49 | es = steps - warmup_length 50 | lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr 51 | assign_learning_rate(optimizer, lr) 52 | return lr 53 | return _lr_adjuster 54 | -------------------------------------------------------------------------------- /training/zero_shot.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | from tqdm import tqdm 5 | 6 | from open_clip import get_input_dtype, get_tokenizer, build_zero_shot_classifier, \ 7 | IMAGENET_CLASSNAMES, OPENAI_IMAGENET_TEMPLATES 8 | from .precision import get_autocast 9 | 10 | 11 | def accuracy(output, target, topk=(1,)): 12 | pred = output.topk(max(topk), 1, True, True)[1].t() 13 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 14 | return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk] 15 | 16 | 17 | def run(model, classifier, dataloader, args): 18 | autocast = get_autocast(args.precision) 19 | input_dtype = get_input_dtype(args.precision) 20 | 21 | with torch.no_grad(): 22 | top1, top5, n = 0., 0., 0. 23 | for images, target in tqdm(dataloader, unit_scale=args.batch_size): 24 | images = images.to(device=args.device, dtype=input_dtype) 25 | target = target.to(args.device) 26 | 27 | with autocast(): 28 | # predict 29 | output = model(image=images) 30 | image_features = output['image_features'] if isinstance(output, dict) else output[0] 31 | logits = 100. 
* image_features @ classifier 32 | 33 | # measure accuracy 34 | acc1, acc5 = accuracy(logits, target, topk=(1, 5)) 35 | top1 += acc1 36 | top5 += acc5 37 | n += images.size(0) 38 | 39 | top1 = (top1 / n) 40 | top5 = (top5 / n) 41 | return top1, top5 42 | 43 | 44 | def zero_shot_eval(model, data, epoch, args, tokenizer=None): 45 | if 'imagenet-val' not in data and 'imagenet-v2' not in data: 46 | return {} 47 | if args.zeroshot_frequency == 0: 48 | return {} 49 | if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs: 50 | return {} 51 | if args.distributed and not args.horovod: 52 | model = model.module 53 | 54 | logging.info('Starting zero-shot imagenet.') 55 | if tokenizer is None: 56 | tokenizer = get_tokenizer(args.model) 57 | 58 | logging.info('Building zero-shot classifier') 59 | autocast = get_autocast(args.precision) 60 | with autocast(): 61 | classifier = build_zero_shot_classifier( 62 | model, 63 | tokenizer=tokenizer, 64 | classnames=IMAGENET_CLASSNAMES, 65 | templates=OPENAI_IMAGENET_TEMPLATES, 66 | num_classes_per_batch=10, 67 | device=args.device, 68 | use_tqdm=True, 69 | ) 70 | 71 | logging.info('Using classifier') 72 | results = {} 73 | if 'imagenet-val' in data: 74 | top1, top5 = run(model, classifier, data['imagenet-val'].dataloader, args) 75 | results['imagenet-zeroshot-val-top1'] = top1 76 | results['imagenet-zeroshot-val-top5'] = top5 77 | if 'imagenet-v2' in data: 78 | top1, top5 = run(model, classifier, data['imagenet-v2'].dataloader, args) 79 | results['imagenetv2-zeroshot-val-top1'] = top1 80 | results['imagenetv2-zeroshot-val-top5'] = top5 81 | 82 | logging.info('Finished zero-shot imagenet.') 83 | 84 | return results 85 | --------------------------------------------------------------------------------
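For reference, a self-contained sketch of the zero-shot scoring path in training/zero_shot.py above. The batch size, feature dimension, class count, and random tensors are made up for illustration; the logit scaling and the top-k counting mirror run() and accuracy().

import torch

batch, dim, num_classes = 8, 512, 1000                # illustrative sizes only
image_features = torch.randn(batch, dim)
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
classifier = torch.randn(dim, num_classes)            # stand-in for build_zero_shot_classifier output
classifier = classifier / classifier.norm(dim=0, keepdim=True)

logits = 100. * image_features @ classifier           # same scaling as run()
target = torch.randint(0, num_classes, (batch,))

pred = logits.topk(5, 1, True, True)[1].t()           # as in accuracy()
correct = pred.eq(target.view(1, -1).expand_as(pred))
top1 = correct[:1].reshape(-1).float().sum().item()   # counts, not rates
top5 = correct[:5].reshape(-1).float().sum().item()
print(top1 / batch, top5 / batch)                     # run() divides by n the same way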
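And a minimal sketch of how the schedulers in training/scheduler.py are driven: cosine_lr returns a function of the global step that is called once per iteration to set the learning rate before optimizer.step(). The model, optimizer choice, and step counts below are placeholders.

import torch
from training.scheduler import cosine_lr              # assumes the repo root is on PYTHONPATH

model = torch.nn.Linear(16, 16)                       # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0)  # lr is overwritten by the scheduler
total_steps, warmup_steps, base_lr = 1000, 100, 5e-4  # placeholder values

scheduler = cosine_lr(optimizer, base_lr, warmup_steps, total_steps)

for step in range(total_steps):
    scheduler(step)                                   # linear warmup, then cosine decay of param_group['lr']
    loss = model(torch.randn(4, 16)).pow(2).mean()    # stand-in loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()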