├── LICENSE
├── Makefile
├── README.md
├── dataloaders
│   ├── __init__.py
│   ├── caltech101.py
│   ├── cifar10.py
│   ├── cifar100.py
│   ├── coco.py
│   ├── dtd.py
│   ├── fgvc_aircraft.py
│   ├── flickr30k.py
│   ├── flowers102.py
│   ├── food101.py
│   ├── oxford_pets.py
│   ├── stanford_car.py
│   ├── sun397.py
│   └── utils.py
├── dist_train.sh
├── eval_zs.sh
├── figures
│   ├── moti.png
│   ├── radar.jpg
│   └── radar.png
├── main.py
├── open_clip
│   ├── __init__.py
│   ├── big_vision.py
│   ├── bpe_simple_vocab_16e6.txt.gz
│   ├── coca_model.py
│   ├── constants.py
│   ├── factory.py
│   ├── generation_utils.py
│   ├── hf_configs.py
│   ├── hf_model.py
│   ├── loss.py
│   ├── model.py
│   ├── model_configs
│   │   ├── EVA01-g-14-plus.json
│   │   ├── EVA01-g-14.json
│   │   ├── EVA02-B-16.json
│   │   ├── EVA02-E-14-plus.json
│   │   ├── EVA02-E-14.json
│   │   ├── EVA02-L-14-336.json
│   │   ├── EVA02-L-14.json
│   │   ├── RN101-quickgelu.json
│   │   ├── RN101.json
│   │   ├── RN50-quickgelu.json
│   │   ├── RN50.json
│   │   ├── RN50x16.json
│   │   ├── RN50x4.json
│   │   ├── RN50x64.json
│   │   ├── ViT-B-16-SigLIP-256.json
│   │   ├── ViT-B-16-SigLIP-384.json
│   │   ├── ViT-B-16-SigLIP-512.json
│   │   ├── ViT-B-16-SigLIP-i18n-256.json
│   │   ├── ViT-B-16-SigLIP.json
│   │   ├── ViT-B-16-plus-240.json
│   │   ├── ViT-B-16-plus.json
│   │   ├── ViT-B-16-quickgelu.json
│   │   ├── ViT-B-16.json
│   │   ├── ViT-B-32-256.json
│   │   ├── ViT-B-32-plus-256.json
│   │   ├── ViT-B-32-quickgelu.json
│   │   ├── ViT-B-32.json
│   │   ├── ViT-H-14-378-quickgelu.json
│   │   ├── ViT-H-14-CLIPA-336.json
│   │   ├── ViT-H-14-CLIPA.json
│   │   ├── ViT-H-14-quickgelu.json
│   │   ├── ViT-H-14.json
│   │   ├── ViT-H-16.json
│   │   ├── ViT-L-14-280.json
│   │   ├── ViT-L-14-336.json
│   │   ├── ViT-L-14-CLIPA-336.json
│   │   ├── ViT-L-14-CLIPA.json
│   │   ├── ViT-L-14-quickgelu.json
│   │   ├── ViT-L-14.json
│   │   ├── ViT-L-16-320.json
│   │   ├── ViT-L-16-SigLIP-256.json
│   │   ├── ViT-L-16-SigLIP-384.json
│   │   ├── ViT-L-16.json
│   │   ├── ViT-M-16-alt.json
│   │   ├── ViT-M-16.json
│   │   ├── ViT-M-32-alt.json
│   │   ├── ViT-M-32.json
│   │   ├── ViT-S-16-alt.json
│   │   ├── ViT-S-16.json
│   │   ├── ViT-S-32-alt.json
│   │   ├── ViT-S-32.json
│   │   ├── ViT-SO400M-14-SigLIP-384.json
│   │   ├── ViT-SO400M-14-SigLIP.json
│   │   ├── ViT-bigG-14-CLIPA-336.json
│   │   ├── ViT-bigG-14-CLIPA.json
│   │   ├── ViT-bigG-14.json
│   │   ├── ViT-e-14.json
│   │   ├── ViT-g-14.json
│   │   ├── coca_ViT-B-32.json
│   │   ├── coca_ViT-L-14.json
│   │   ├── coca_base.json
│   │   ├── coca_roberta-ViT-B-32.json
│   │   ├── convnext_base.json
│   │   ├── convnext_base_w.json
│   │   ├── convnext_base_w_320.json
│   │   ├── convnext_large.json
│   │   ├── convnext_large_d.json
│   │   ├── convnext_large_d_320.json
│   │   ├── convnext_small.json
│   │   ├── convnext_tiny.json
│   │   ├── convnext_xlarge.json
│   │   ├── convnext_xxlarge.json
│   │   ├── convnext_xxlarge_320.json
│   │   ├── mt5-base-ViT-B-32.json
│   │   ├── mt5-xl-ViT-H-14.json
│   │   ├── nllb-clip-base.json
│   │   ├── nllb-clip-large.json
│   │   ├── roberta-ViT-B-16.json
│   │   ├── roberta-ViT-B-32.json
│   │   ├── swin_base_patch4_window7_224.json
│   │   ├── vit_medium_patch16_gap_256.json
│   │   ├── vit_relpos_medium_patch16_cls_224.json
│   │   ├── xlm-roberta-base-ViT-B-32.json
│   │   └── xlm-roberta-large-ViT-H-14.json
│   ├── modified_resnet.py
│   ├── openai.py
│   ├── pos_embed.py
│   ├── pretrained.py
│   ├── push_to_hf_hub.py
│   ├── timm_model.py
│   ├── tokenizer.py
│   ├── transform.py
│   ├── transformer.py
│   ├── utils.py
│   ├── version.py
│   ├── zero_shot_classifier.py
│   └── zero_shot_metadata.py
├── requirements-test.txt
├── requirements-training.txt
├── requirements.txt
├── setup.py
└── training
    ├── __init__.py
    ├── data.py
    ├── distributed.py
    ├── file_utils.py
    ├── logger.py
    ├── params.py
    ├── precision.py
    ├── profiler.py
    ├── random_aug.py
    ├── scheduler.py
    ├── train.py
    └── zero_shot.py
/Makefile:
--------------------------------------------------------------------------------
1 | install: ## [Local development] Upgrade pip, install requirements, install package.
2 | 	python -m pip install -U pip
3 | 	python -m pip install -e .
4 |
5 | install-training:
6 | 	python -m pip install -r requirements-training.txt
7 |
8 | install-test: ## [Local development] Install test requirements
9 | 	python -m pip install -r requirements-test.txt
10 |
11 | test: ## [Local development] Run unit tests
12 | 	python -m pytest -x -s -v tests
13 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DreamLIP: Language-Image Pre-training with Long Captions
2 |
3 |
4 | > **DreamLIP: Language-Image Pre-training with Long Captions**
5 | Kecheng Zheng,
6 | Yifei Zhang,
7 | Wei Wu,
8 | Fan Lu,
9 | Shuailei Ma,
10 | Xin Jin,
11 | Wei Chen,
12 | Yujun Shen
13 | [Project Page](https://zyf0619sjtu.github.io/dream-lip/) | [Paper](https://arxiv.org/pdf/2403.17007.pdf)
14 |
15 |
16 | ## 📰 News
17 |
18 | [//]: # (- [2024/11/15] Long captions (LLAVA1.5, InstructBLIP and shareGPT4V) of Laion50M and Coyo20M that are used in [NeurIPS'24 LotLIP](https://github.com/wuw2019/LoTLIP) are released in google drive~)
19 | - [2024/11/26] Long captions (LLaVA-1.5, InstructBLIP, and ShareGPT4V) of COYO24M/LAION49M are released on Hugging Face!
20 | - [2024/08/26] Long captions (LLaVA-1.5, InstructBLIP, and ShareGPT4V) of CC3M/CC12M/YFCC15M are released on Hugging Face!
21 | - [2024/07/16] Released the pretrained weights of ViT-B/16 trained on CC3M, CC12M, YFCC15M, and merged-30M (long captions from ShareGPT4V)!
22 | - [2024/07/08] DreamLIP is accepted by ECCV 2024!
23 |
24 | ## 💡 Highlights
25 | - 🔥 Exploring how language-image pre-training could benefit from long captions.
26 | - 🔥 Strong improvements on semantic segmentation, image-text retrieval, and image understanding in MLLMs.
27 |
28 |
29 |
30 | - 🔥 DreamLIP trained with 30M image-text pairs achieves performance on par with, or even better than, CLIP trained with 400M pairs.
31 | 
32 |
33 | ## 🎨 In-Progress
34 |
35 | - [x] Release long captions of CC3M, CC12M, YFCC15M, COYO24M and LAION49M.
36 | - [ ] Release training code.
37 |
38 | ## 🏝️ Overview of supported long captions:
39 |
40 |
41 | Long Captions of Supported Datasets (5)
42 |
43 | > - [x] [CC3M](https://ai.google.com/research/ConceptualCaptions/)
44 | > - [x] [CC12M](https://github.com/google-research-datasets/conceptual-12m)
45 | > - [x] [YFCC15M](https://github.com/Sense-GVT/DeCLIP/blob/main/docs/dataset_prepare.md)
46 | > - [x] [LAION49M](https://laion.ai/laion-5b-a-new-era-of-open-large-scale-multi-modal-datasets/)
47 | > - [x] [COYO24M](https://github.com/kakaobrain/coyo-dataset)
48 |
49 |
50 | Long Captions of MLLMs (3)
51 |
52 | > - [x] InstructBLIP
53 | > - [x] LLaVA-1.5
54 | > - [x] ShareGPT4V
55 |
56 |
57 |
58 | [//]: # (## Acknowledgement)
59 |
60 |
61 | #### Generated Long Captions
62 |
63 |
91 |
92 | ## Pretrained checkpoints
93 |
94 |
95 |
96 | | Dataset | Model    | ShareGPT4V | InstructBLIP + LLAVA1.5 + ShareGPT4V |
97 | |---------|----------|------------|--------------------------------------|
98 | | CC3M    | ViT-B/16 | Link       | TODO                                 |
99 | | CC12M   | ViT-B/16 | Link       | TODO                                 |
100 | | YFCC15M | ViT-B/16 | Link       | TODO                                 |
101 | | CC30M   | ViT-B/16 | Link       | TODO                                 |
102 |
128 | ## 📣 Instructions
129 | Environment installation
130 | ```
131 | pip install -r requirements.txt
132 | ```
133 |
134 | Evaluate zero-shot classification
135 | ```
136 | bash eval_zs.sh
137 | ```
138 |
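
For reference, below is a minimal sketch of loading a released ViT-B/16 checkpoint with the bundled `open_clip` package and scoring an image against a handful of prompts. The checkpoint path, image file, and prompt list are placeholders; the actual benchmark evaluation is driven by `eval_zs.sh` / `main.py`.

```python
import torch
from PIL import Image

import open_clip

# Placeholder path to a downloaded DreamLIP checkpoint (see the table above).
ckpt = "checkpoints/dreamlip_vitb16.pt"

model, _, preprocess = open_clip.create_model_and_transforms("ViT-B-16", pretrained=ckpt)
tokenizer = open_clip.get_tokenizer("ViT-B-16")
model.eval()

image = preprocess(Image.open("example.jpg")).unsqueeze(0)
text = tokenizer(["a photo of a dog", "a photo of a cat", "a diagram"])

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

print(probs)  # similarity distribution over the three prompts
```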
139 | [//]: # (You can download checkpoints pre-trained )
140 |
141 | ## License
142 |
143 | The project is released under the standard Creative Commons [CC-BY-4.0 License](./LICENSE).
144 |
145 | ## 📖 Citation
146 |
147 | We open-source this library to facilitate research in the community. If you like our work and use the codebase in your projects, please cite our work as follows.
148 |
149 | ```bibtex
150 | @inproceedings{DreamLIP,
151 |   title={DreamLIP: Language-Image Pre-training with Long Captions},
152 |   author={Zheng, Kecheng and Zhang, Yifei and Wu, Wei and Lu, Fan and Ma, Shuailei and Jin, Xin and Chen, Wei and Shen, Yujun},
153 |   booktitle={ECCV},
154 |   year={2024}
155 | }
156 | ```
157 |
158 | ### Acknowledgements
159 | This project is built on [open_clip](https://github.com/mlfoundations/open_clip/tree/main); thanks for the nice work!
160 | We also thank [InstructBLIP](https://github.com/salesforce/LAVIS), [ShareGPT4V](https://github.com/InternLM/InternLM-XComposer) and [LLaVA](https://github.com/haotian-liu/LLaVA) for their pretrained models and code.
161 |
162 |
163 |
164 |
--------------------------------------------------------------------------------
/dataloaders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ant-research/DreamLIP/b4de0a43b6c002033c02873f91a695ab449e464c/dataloaders/__init__.py
--------------------------------------------------------------------------------
/dataloaders/caltech101.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | import torch
3 | from torch.utils.data import DataLoader, Dataset, random_split
4 | from torchvision.datasets import Caltech101
5 | from .utils import dataset_root
6 |
7 |
8 | root = dataset_root
9 | num_example_train = 3000
10 | num_example_test = 5677
11 | num_classes = 101
12 | mean_per_class = True
13 |
14 | def get_loader_train(
15 | transform, batch_size, num_workers, seed
16 | ) -> Tuple[Dataset, DataLoader]:
17 | dataset = Caltech101(root, download=False, transform=transform)
18 | dataset_train, dataset_test = random_split(
19 | dataset,
20 | lengths=[num_example_train,
21 | num_example_test],
22 | generator=torch.Generator().manual_seed(seed))
23 | return (dataset_train, None)
24 |
25 |
26 | def get_loader_test(
27 | transform, batch_size, num_workers, seed
28 | ) -> Tuple[Dataset, DataLoader]:
29 | dataset = Caltech101(root, download=False, transform=transform)
30 | dataset_train, dataset_test = random_split(
31 | dataset,
32 | lengths=[num_example_train,
33 | num_example_test],
34 | generator=torch.Generator().manual_seed(seed))
35 | return (dataset_test, None)
36 |
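
Each module in `dataloaders/` follows the same convention: it exposes `get_loader_train` / `get_loader_test` returning a `(Dataset, None)` pair plus module-level metadata such as `num_classes`, and the caller builds the `DataLoader`. A minimal consumer sketch, assuming the dataset already exists under `dataset_root` (hence `download=False`) and using an arbitrary preprocessing pipeline:

```python
from torch.utils.data import DataLoader
from torchvision import transforms

from dataloaders import caltech101  # any module in this package exposes the same interface

preprocess = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.Lambda(lambda img: img.convert("RGB")),  # Caltech101 mixes grayscale and RGB
    transforms.ToTensor(),
])

# The helpers return (dataset, None); batch_size/num_workers are passed through unused here.
dataset_test, _ = caltech101.get_loader_test(
    transform=preprocess, batch_size=64, num_workers=4, seed=0)

loader = DataLoader(dataset_test, batch_size=64, num_workers=4, shuffle=False)
print(caltech101.num_classes, len(dataset_test))
```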
--------------------------------------------------------------------------------
/dataloaders/cifar10.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | from torch.utils.data import DataLoader, Dataset
3 | from torchvision.datasets import CIFAR10
4 | from .utils import dataset_root
5 |
6 |
7 | root = dataset_root + '/cifar/'
8 | num_example_train = 50000
9 | num_example_test = 10000
10 | num_classes = 10
11 |
12 |
13 | def get_loader_train(
14 | transform, batch_size, num_workers, seed
15 | ) -> Tuple[Dataset, DataLoader]:
16 | dataset_train = CIFAR10(root, download=False, train=True, transform=transform)
17 | return (dataset_train, None)
18 |
19 |
20 | def get_loader_test(
21 | transform, batch_size, num_workers, seed
22 | ) -> Tuple[Dataset, DataLoader]:
23 | dataset = CIFAR10(root, download=False, train=False, transform=transform)
24 | return (dataset, None)
25 |
--------------------------------------------------------------------------------
/dataloaders/cifar100.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | from torch.utils.data import DataLoader, Dataset
3 | from torchvision.datasets import CIFAR100
4 | from .utils import dataset_root
5 |
6 |
7 | root = dataset_root + '/cifar/'
8 | num_example_train_val = 50000
9 | num_example_test = 10000
10 | num_classes = 100
11 |
12 |
13 | def get_loader_train(
14 | transform, batch_size, num_workers, seed
15 | ) -> Tuple[Dataset, DataLoader]:
16 | dataset_train_val = CIFAR100(root, download=False, train=True, transform=transform)
17 | return (dataset_train_val, None)
18 |
19 |
20 | def get_loader_test(
21 | transform, batch_size, num_workers, seed
22 | ) -> Tuple[Dataset, DataLoader]:
23 | dataset_test = CIFAR100(root, download=False, train=False, transform=transform)
24 | return (dataset_test, None)
25 |
--------------------------------------------------------------------------------
/dataloaders/coco.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | from typing import Any, Callable, Optional, Tuple, List
3 |
4 | from PIL import Image
5 |
6 | from torchvision.datasets.vision import VisionDataset
7 |
8 | from pcache_fileio import fileio
9 |
10 |
11 | class CocoDetection(VisionDataset):
12 | """`MS Coco Detection `_ Dataset.
13 |
14 | It requires the `COCO API to be installed `_.
15 |
16 | Args:
17 | root (string): Root directory where images are downloaded to.
18 | annFile (string): Path to json annotation file.
19 | transform (callable, optional): A function/transform that takes in an PIL image
20 | and returns a transformed version. E.g, ``transforms.ToTensor``
21 | target_transform (callable, optional): A function/transform that takes in the
22 | target and transforms it.
23 | transforms (callable, optional): A function/transform that takes input sample and its target as entry
24 | and returns a transformed version.
25 | """
26 |
27 | def __init__(
28 | self,
29 | root: str,
30 | annFile: str,
31 | transform: Optional[Callable] = None,
32 | target_transform: Optional[Callable] = None,
33 | transforms: Optional[Callable] = None,
34 | ) -> None:
35 | super().__init__(root, transforms, transform, target_transform)
36 | from pycocotools.coco import COCO
37 |
38 | self.coco = COCO(annFile)
39 | self.ids = list(sorted(self.coco.imgs.keys()))
40 |
41 | def _load_image(self, id: int) -> Image.Image:
42 | path = self.coco.loadImgs(id)[0]["file_name"]
43 | return Image.open(os.path.join(self.root, path)).convert("RGB")
44 |
45 | def _load_target(self, id: int) -> List[Any]:
46 | return self.coco.loadAnns(self.coco.getAnnIds(id))
47 |
48 | def __getitem__(self, index: int) -> Tuple[Any, Any]:
49 | id = self.ids[index]
50 | image = self._load_image(id)
51 | target = self._load_target(id)
52 |
53 | if self.transforms is not None:
54 | image, target = self.transforms(image, target)
55 |
56 | return image, target
57 |
58 | def __len__(self) -> int:
59 | return len(self.ids)
60 |
61 |
62 | class CocoCaptions(CocoDetection):
63 | """`MS Coco Captions `_ Dataset.
64 |
65 | It requires the `COCO API to be installed `_.
66 |
67 | Args:
68 | root (string): Root directory where images are downloaded to.
69 | annFile (string): Path to json annotation file.
70 | transform (callable, optional): A function/transform that takes in an PIL image
71 | and returns a transformed version. E.g, ``transforms.ToTensor``
72 | target_transform (callable, optional): A function/transform that takes in the
73 | target and transforms it.
74 | transforms (callable, optional): A function/transform that takes input sample and its target as entry
75 | and returns a transformed version.
76 |
77 | Example:
78 |
79 | .. code:: python
80 |
81 | import torchvision.datasets as dset
82 | import torchvision.transforms as transforms
83 | cap = dset.CocoCaptions(root = 'dir where images are',
84 | annFile = 'json annotation file',
85 | transform=transforms.ToTensor())
86 |
87 | print('Number of samples: ', len(cap))
88 | img, target = cap[3] # load 4th sample
89 |
90 | print("Image Size: ", img.size())
91 | print(target)
92 |
93 | Output: ::
94 |
95 | Number of samples: 82783
96 | Image Size: (3L, 427L, 640L)
97 | [u'A plane emitting smoke stream flying over a mountain.',
98 | u'A plane darts across a bright blue sky behind a mountain covered in snow',
99 | u'A plane leaves a contrail above the snowy mountain top.',
100 | u'A mountain that has a plane flying overheard in the distance.',
101 | u'A mountain view with a plume of smoke in the background']
102 |
103 | """
104 |
105 | def _load_target(self, id: int) -> List[str]:
106 | return [ann["caption"] for ann in super()._load_target(id)]
107 |
--------------------------------------------------------------------------------
/dataloaders/dtd.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | from torch.utils.data import DataLoader, Dataset, ConcatDataset
3 | from torchvision.datasets import DTD
4 | import torchvision
5 | from .utils import dataset_root
6 |
7 | root = dataset_root
8 | num_example_train = 3760
9 | num_example_test = 1880
10 | num_classes = 47
11 |
12 |
13 | class Warper(Dataset):
14 | def __init__(self, dataset) -> None:
15 | super().__init__()
16 | self.dataset = dataset
17 |
18 | def __len__(self) -> int:
19 | return len(self.dataset)
20 |
21 | def __getitem__(self, index):
22 | img, label = self.dataset[index]
23 | if torchvision.__version__ >= "0.13.0":
24 | return img, label
25 | else:
26 | return img, label - 1
27 |
28 |
29 | def get_loader_train(
30 | transform, batch_size, num_workers, seed
31 | ) -> Tuple[Dataset, DataLoader]:
32 | dataset_train = DTD(root, download=False, split='train', transform=transform)
33 | dataset_val = DTD(root, download=False, split='val', transform=transform)
34 | dataset_train = ConcatDataset([dataset_train, dataset_val])
35 | dataset_train = Warper(dataset_train)
36 | return (dataset_train, None)
37 |
38 |
39 | def get_loader_test(
40 | transform, batch_size, num_workers, seed
41 | ) -> Tuple[Dataset, DataLoader]:
42 | dataset_test = DTD(root, download=False, split='test', transform=transform)
43 | return (dataset_test, None)
44 |
--------------------------------------------------------------------------------
/dataloaders/fgvc_aircraft.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | from torch.utils.data import DataLoader, Dataset
3 | from torchvision.datasets import FGVCAircraft
4 | from .utils import dataset_root
5 |
6 |
7 | root = dataset_root
8 | num_example_train_val = 6667
9 | num_example_test = 3333
10 | num_classes = 100
11 | mean_per_class = True
12 |
13 |
14 | def get_loader_train(
15 | transform, batch_size, num_workers, seed
16 | ) -> Tuple[Dataset, DataLoader]:
17 | dataset_train_val = FGVCAircraft(root, download=True, annotation_level='variant', split='trainval',transform=transform)
18 | return (dataset_train_val, None)
19 |
20 |
21 | def get_loader_test(
22 | transform, batch_size, num_workers, seed
23 | ) -> Tuple[Dataset, DataLoader]:
24 | dataset_test = FGVCAircraft(root, download=True, annotation_level='variant', split='test', transform=transform)
25 | return (dataset_test, None)
26 |
27 |
--------------------------------------------------------------------------------
/dataloaders/flickr30k.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | from collections import defaultdict
4 | from html.parser import HTMLParser
5 | from typing import Any, Callable, Dict, List, Optional, Tuple
6 | import json
7 |
8 | from PIL import Image
9 |
10 | from torchvision.datasets.vision import VisionDataset
11 |
12 | from pcache_fileio import fileio
13 |
14 |
15 | class Flickr8kParser(HTMLParser):
16 | """Parser for extracting captions from the Flickr8k dataset web page."""
17 |
18 | def __init__(self, root: str) -> None:
19 | super().__init__()
20 |
21 | self.root = root
22 |
23 | # Data structure to store captions
24 | self.annotations: Dict[str, List[str]] = {}
25 |
26 | # State variables
27 | self.in_table = False
28 | self.current_tag: Optional[str] = None
29 | self.current_img: Optional[str] = None
30 |
31 | def handle_starttag(self, tag: str, attrs: List[Tuple[str, Optional[str]]]) -> None:
32 | self.current_tag = tag
33 |
34 | if tag == "table":
35 | self.in_table = True
36 |
37 | def handle_endtag(self, tag: str) -> None:
38 | self.current_tag = None
39 |
40 | if tag == "table":
41 | self.in_table = False
42 |
43 | def handle_data(self, data: str) -> None:
44 | if self.in_table:
45 | if data == "Image Not Found":
46 | self.current_img = None
47 | elif self.current_tag == "a":
48 | img_id = data.split("/")[-2]
49 | img_id = os.path.join(self.root, img_id + "_*.jpg")
50 | img_id = glob.glob(img_id)[0]
51 | self.current_img = img_id
52 | self.annotations[img_id] = []
53 | elif self.current_tag == "li" and self.current_img:
54 | img_id = self.current_img
55 | self.annotations[img_id].append(data.strip())
56 |
57 |
58 | class Flickr8k(VisionDataset):
59 | """`Flickr8k Entities `_ Dataset.
60 |
61 | Args:
62 | root (string): Root directory where images are downloaded to.
63 | ann_file (string): Path to annotation file.
64 | transform (callable, optional): A function/transform that takes in a PIL image
65 | and returns a transformed version. E.g, ``transforms.ToTensor``
66 | target_transform (callable, optional): A function/transform that takes in the
67 | target and transforms it.
68 | """
69 |
70 | def __init__(
71 | self,
72 | root: str,
73 | ann_file: str,
74 | transform: Optional[Callable] = None,
75 | target_transform: Optional[Callable] = None,
76 | ) -> None:
77 | super().__init__(root, transform=transform, target_transform=target_transform)
78 | self.ann_file = os.path.expanduser(ann_file)
79 |
80 | # Read annotations and store in a dict
81 | parser = Flickr8kParser(self.root)
82 | with open(self.ann_file) as fh:
83 | parser.feed(fh.read())
84 | self.annotations = parser.annotations
85 |
86 | self.ids = list(sorted(self.annotations.keys()))
87 |
88 | def __getitem__(self, index: int) -> Tuple[Any, Any]:
89 | """
90 | Args:
91 | index (int): Index
92 |
93 | Returns:
94 | tuple: Tuple (image, target). target is a list of captions for the image.
95 | """
96 | img_id = self.ids[index]
97 |
98 | # Image
99 | img = Image.open(img_id).convert("RGB")
100 | if self.transform is not None:
101 | img = self.transform(img)
102 |
103 | # Captions
104 | target = self.annotations[img_id]
105 | if self.target_transform is not None:
106 | target = self.target_transform(target)
107 |
108 | return img, target
109 |
110 | def __len__(self) -> int:
111 | return len(self.ids)
112 |
113 |
114 | class Flickr30k(VisionDataset):
115 | """`Flickr30k Entities `_ Dataset.
116 |
117 | Args:
118 | root (string): Root directory where images are downloaded to.
119 | ann_file (string): Path to annotation file.
120 | transform (callable, optional): A function/transform that takes in a PIL image
121 | and returns a transformed version. E.g, ``transforms.ToTensor``
122 | target_transform (callable, optional): A function/transform that takes in the
123 | target and transforms it.
124 | """
125 |
126 | def __init__(
127 | self,
128 | root: str,
129 | ann_file: str,
130 | transform: Optional[Callable] = None,
131 | target_transform: Optional[Callable] = None,
132 | ) -> None:
133 | super().__init__(root, transform=transform, target_transform=target_transform)
134 | self.ann_file = os.path.expanduser(ann_file)
135 | self.ann_file = json.load(open(ann_file, 'r'))
136 | # Read annotations and store in a dict
137 | self.annotations = defaultdict(list)
138 |
139 | for line in self.ann_file:
140 | img_id, caption = line['image'], line['caption']
141 | self.annotations[img_id].append(caption)
142 |
143 | self.ids = list(sorted(self.annotations.keys()))
144 |
145 | def __getitem__(self, index: int) -> Tuple[Any, Any]:
146 | """
147 | Args:
148 | index (int): Index
149 |
150 | Returns:
151 | tuple: Tuple (image, target). target is a list of captions for the image.
152 | """
153 | img_id = self.ids[index]
154 |
155 | # Image
156 | filename = os.path.join(self.root, img_id)
157 | img = Image.open(filename).convert("RGB")
158 | if self.transform is not None:
159 | img = self.transform(img)
160 |
161 | # Captions
162 | targets = self.annotations[img_id]
163 | if self.target_transform is not None:
164 | targets = [self.target_transform(target) for target in targets]
165 |
166 | return img, targets
167 |
168 | def __len__(self) -> int:
169 | return len(self.ids)
170 |
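
A self-contained sketch of the annotation format `Flickr30k` expects: a JSON list of `{"image": ..., "caption": ...}` records, with captions grouped per image id. The temporary directory and dummy image below are purely illustrative, and the sketch assumes the `pcache_fileio` dependency imported at the top of this module is available (or removed) in your environment:

```python
import json
import os
import tempfile

from PIL import Image

from dataloaders.flickr30k import Flickr30k

# Build a tiny annotation file in the expected list-of-dicts format.
root = tempfile.mkdtemp()
Image.new("RGB", (64, 64)).save(os.path.join(root, "0001.jpg"))

annotations = [
    {"image": "0001.jpg", "caption": "a small test image"},
    {"image": "0001.jpg", "caption": "a synthetic placeholder picture"},
]
ann_file = os.path.join(root, "ann.json")
with open(ann_file, "w") as f:
    json.dump(annotations, f)

dataset = Flickr30k(root=root, ann_file=ann_file)
img, captions = dataset[0]
print(len(dataset), captions)  # 1 image id, both captions grouped under it
```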
--------------------------------------------------------------------------------
/dataloaders/flowers102.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | from torch.utils.data import DataLoader, Dataset, ConcatDataset
3 | from torchvision.datasets import Flowers102
4 | import torchvision
5 | from .utils import dataset_root
6 |
7 | root = dataset_root
8 | num_example_train_val = 2040
9 | num_example_test = 6149
10 | num_classes = 102
11 | mean_per_class = True
12 |
13 |
14 | class Warper(Dataset):
15 | def __init__(self, dataset) -> None:
16 | super().__init__()
17 | self.dataset = dataset
18 |
19 | def __len__(self) -> int:
20 | return len(self.dataset)
21 |
22 | def __getitem__(self, index):
23 | img, label = self.dataset[index]
24 | if torchvision.__version__ >= "0.13.0":
25 | return img, label
26 | else:
27 | return img, label - 1
28 |
29 |
30 | def get_loader_train(
31 | transform, batch_size, num_workers, seed
32 | ) -> Tuple[Dataset, DataLoader]:
33 | dataset_train = Flowers102(root, download=False, split='train', transform=transform)
34 | dataset_val = Flowers102(root, download=False, split='val', transform=transform)
35 | dataset_train = ConcatDataset([dataset_train, dataset_val])
36 | dataset_train = Warper(dataset_train)
37 | return (dataset_train, None)
38 |
39 |
40 | def get_loader_test(
41 | transform, batch_size, num_workers, seed
42 | ) -> Tuple[Dataset, DataLoader]:
43 | dataset_test = Flowers102(root, download=False, split='test', transform=transform)
44 | dataset_test = Warper(dataset_test)
45 | return (dataset_test, None)
46 |
47 |
--------------------------------------------------------------------------------
/dataloaders/food101.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | from torch.utils.data import DataLoader, Dataset
3 | from torchvision.datasets import Food101
4 | from .utils import dataset_root
5 |
6 |
7 | root = dataset_root
8 | num_example_train_val = 75750
9 | num_example_test = 25250
10 | num_classes = 101
11 |
12 |
13 | def get_loader_train(
14 | transform, batch_size, num_workers, seed
15 | ) -> Tuple[Dataset, DataLoader]:
16 | dataset_train_val = Food101(root, download=False, split='train', transform=transform)
17 | return (dataset_train_val, None)
18 |
19 |
20 | def get_loader_test(
21 | transform, batch_size, num_workers, seed
22 | ) -> Tuple[Dataset, DataLoader]:
23 | dataset_test = Food101(root, download=False, split='test', transform=transform)
24 | return (dataset_test, None)
25 |
--------------------------------------------------------------------------------
/dataloaders/oxford_pets.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | from torch.utils.data import DataLoader, Dataset
3 | from torchvision.datasets import OxfordIIITPet
4 | from .utils import dataset_root
5 |
6 |
7 | root = dataset_root
8 | num_example_train_val = 3680
9 | num_example_test = 3699
10 | num_classes = 37
11 | mean_per_class = True
12 |
13 |
14 | def get_loader_train(
15 | transform, batch_size, num_workers, seed
16 | ) -> Tuple[Dataset, DataLoader]:
17 | dataset_train_val = OxfordIIITPet(root, download=True, split='trainval',transform=transform)
18 | return (dataset_train_val, None)
19 |
20 |
21 | def get_loader_test(
22 | transform, batch_size, num_workers, seed
23 | ) -> Tuple[Dataset, DataLoader]:
24 | dataset_test = OxfordIIITPet(root, download=True, split='test', transform=transform)
25 | return (dataset_test, None)
26 |
27 |
--------------------------------------------------------------------------------
/dataloaders/stanford_car.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | from torch.utils.data import DataLoader, Dataset
3 | from torchvision.datasets import StanfordCars
4 | from .utils import dataset_root
5 |
6 |
7 | root = dataset_root
8 | num_example_train_val = 8144
9 | num_example_test = 8041
10 | num_classes = 196
11 |
12 |
13 | def get_loader_train(
14 | transform, batch_size, num_workers, seed
15 | ) -> Tuple[Dataset, DataLoader]:
16 | dataset_train = StanfordCars(
17 | root, download=False, split='train', transform=transform)
18 | return (dataset_train, None)
19 |
20 |
21 | def get_loader_test(
22 | transform, batch_size, num_workers, seed
23 | ) -> Tuple[Dataset, DataLoader]:
24 | dataset_test = StanfordCars(
25 | root, download=False, split='test', transform=transform)
26 | return (dataset_test, None)
27 |
--------------------------------------------------------------------------------
/dataloaders/sun397.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | import torch
3 | from torch.utils.data import DataLoader, Dataset, random_split
4 | from torchvision.datasets import SUN397
5 | from .utils import dataset_root
6 |
7 |
8 | num_example_train = 19850
9 | num_example_test = 19850
10 | num_example_others = 69054
11 | num_classes = 397
12 | root = dataset_root + '/sun397/'
13 |
14 |
15 | def get_loader_train(
16 | transform, batch_size, num_workers, seed
17 | ) -> Tuple[Dataset, DataLoader]:
18 | dataset = SUN397(root, transform=transform)
19 | dataset_train, dataset_test, others = random_split(
20 | dataset,
21 | lengths=[num_example_train,
22 | num_example_test,
23 | num_example_others,],
24 | generator=torch.Generator().manual_seed(seed + hash("sun397") % 2048))
25 | return (dataset_train, None)
26 |
27 |
28 | def get_loader_test(
29 | transform, batch_size, num_workers, seed
30 | ) -> Tuple[Dataset, DataLoader]:
31 | dataset = SUN397(root, transform=transform)
32 | dataset_train, dataset_test, others = random_split(
33 | dataset,
34 | lengths=[num_example_train,
35 | num_example_test,
36 | num_example_others,],
37 | generator=torch.Generator().manual_seed(seed + hash("sun397") % 2048))
38 | return (dataset_test, None)
39 |
--------------------------------------------------------------------------------
/dataloaders/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 | import numpy as np
4 | import torch
5 | from torch import distributed as dist
6 | from torch.utils.data import DistributedSampler as _DistributedSampler
7 | import os
8 | try:
9 | rank = int(os.environ["RANK"])
10 | world_size = int(os.environ["WORLD_SIZE"])
11 | except KeyError:
12 | rank = 0
13 | world_size = 1
14 |
15 | dataset_root = "/input_ssd/datasets/"
16 |
17 | def worker_init_fn(worker_id, num_workers, rank, seed):
18 | # The seed of each worker equals
19 | # num_workers * rank + worker_id + seed
20 | worker_seed = num_workers * rank + worker_id + seed
21 | np.random.seed(worker_seed)
22 | random.seed(worker_seed)
23 | torch.manual_seed(worker_seed)
24 |
25 |
26 | def get_dist_info():
27 | if dist.is_available() and dist.is_initialized():
28 | rank = dist.get_rank()
29 | world_size = dist.get_world_size()
30 | else:
31 | rank = 0
32 | world_size = 1
33 |
34 | return rank, world_size
35 |
36 |
37 | def sync_random_seed(seed=None, device="cuda"):
38 | """Make sure different ranks share the same seed.
39 | All workers must call this function, otherwise it will deadlock.
40 | This method is generally used in `DistributedSampler`,
41 | because the seed should be identical across all processes
42 | in the distributed group.
43 | In distributed sampling, different ranks should sample non-overlapped
44 | data in the dataset. Therefore, this function is used to make sure that
45 | each rank shuffles the data indices in the same order based
46 | on the same seed. Then different ranks could use different indices
47 | to select non-overlapped data from the same data list.
48 | Args:
49 | seed (int, Optional): The seed. Default to None.
50 | device (str): The device where the seed will be put on.
51 | Default to 'cuda'.
52 | Returns:
53 | int: Seed to be used.
54 | """
55 | if seed is None:
56 | seed = np.random.randint(2**31)
57 | assert isinstance(seed, int)
58 |
59 | rank, world_size = get_dist_info()
60 |
61 | if world_size == 1:
62 | return seed
63 |
64 | if rank == 0:
65 | random_num = torch.tensor(seed, dtype=torch.int32, device=device)
66 | else:
67 | random_num = torch.tensor(0, dtype=torch.int32, device=device)
68 |
69 | dist.broadcast(random_num, src=0)
70 |
71 | return random_num.item()
72 |
73 |
74 | class DistributedSampler(_DistributedSampler):
75 | def __init__(
76 | self,
77 | dataset,
78 | num_replicas=None, # world_size
79 | rank=None, # local_rank
80 | shuffle=True,
81 | seed=0,
82 | ):
83 |
84 | super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
85 |
86 | # In distributed sampling, different ranks should sample
87 | # non-overlapped data in the dataset. Therefore, this function
88 | # is used to make sure that each rank shuffles the data indices
89 | # in the same order based on the same seed. Then different ranks
90 | # could use different indices to select non-overlapped data from the
91 | # same data list.
92 | self.seed = sync_random_seed(seed)
93 |
94 | def __iter__(self):
95 | # deterministically shuffle based on epoch
96 | if self.shuffle:
97 | g = torch.Generator()
98 | # When :attr:`shuffle=True`, this ensures all replicas
99 | # use a different random ordering for each epoch.
100 | # Otherwise, the next iteration of this sampler will
101 | # yield the same ordering.
102 | g.manual_seed(self.epoch + self.seed)
103 | indices = torch.randperm(len(self.dataset), generator=g).tolist()
104 | else:
105 | indices = torch.arange(len(self.dataset)).tolist()
106 |
107 | # add extra samples to make it evenly divisible
108 | # in case that indices is shorter than half of total_size
109 | indices = (indices * math.ceil(self.total_size / len(indices)))[
110 | : self.total_size
111 | ]
112 | assert len(indices) == self.total_size
113 |
114 | # subsample
115 | indices = indices[self.rank : self.total_size : self.num_replicas]
116 | assert len(indices) == self.num_samples
117 |
118 | return iter(indices)
119 |
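
A minimal single-process sketch of how these helpers fit together: `DistributedSampler` shares one shuffle seed across ranks via `sync_random_seed`, and `worker_init_fn` derives a distinct seed per dataloader worker. With `num_replicas=1` (no distributed launch) the sampler degenerates to an ordinary shuffled sampler, so this runs standalone; the toy dataset and batch size are arbitrary.

```python
from functools import partial

import torch
from torch.utils.data import DataLoader, TensorDataset

from dataloaders.utils import DistributedSampler, worker_init_fn

if __name__ == "__main__":
    dataset = TensorDataset(torch.arange(10).float())

    # world_size is 1 here, so sync_random_seed() simply returns the seed unchanged.
    sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True, seed=0)
    sampler.set_epoch(0)  # the epoch offsets the shuffle seed inside __iter__

    loader = DataLoader(
        dataset,
        batch_size=4,
        sampler=sampler,
        num_workers=2,
        worker_init_fn=partial(worker_init_fn, num_workers=2, rank=0, seed=0),
    )
    for (batch,) in loader:
        print(batch)
```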
--------------------------------------------------------------------------------
/dist_train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Detect `python3` command.
4 | # This workaround addresses a common issue:
5 | # `python` points to `python2`, which is deprecated.
6 | export PYTHONS
7 | export RVAL
8 |
9 | PYTHONS=$(compgen -c | grep "^python3$")
10 |
11 | # `$?` is a built-in variable in bash, which is the exit status of the most
12 | # recently-executed command; by convention, 0 means success and anything else
13 | # indicates failure.
14 | RVAL=$?
15 |
16 | if [[ $RVAL -eq 0 ]]; then # if `python3` exists
17 | PYTHON="python3"
18 | else
19 | PYTHON="python"
20 | fi
21 |
22 | # Help message.
23 | if [[ $# -lt 2 ]]; then
24 | echo "This script helps launch distributed training job on local machine."
25 | echo
26 | echo "Usage: $0 GPUS COMMAND [ARGS]"
27 | echo
28 | echo "Example: $0 8 ddpm [--help]"
29 | echo
30 | echo "To enable multi-node training, one can reuse this script" \
31 | "by simply setting the following environment variables on each" \
32 | "machine:"
33 | echo " MASTER_IP: The IP address of the master node."
34 | echo " MASTER_PORT: The port of the master node."
35 | echo " NODE_SIZE: Number of nodes (machines) used for training."
36 | echo " NODE_RANK: Node rank of the current machine."
37 | echo
38 | echo "NOTE: In multi-node training, \`GPUS\` refers to the number" \
39 | "of GPUs on each local machine, or say, GPUs per node."
40 | echo
41 | echo "Example of using 16 GPUs on 2 machines (i.e., 8 GPUs each):"
42 | echo
43 | echo " On machine 0: MASTER_IP=node_0_ip MASTER_PORT=node_0_port" \
44 | "NODE_SIZE=2 NODE_RANK=0 $0 8 ddpm [--help]"
45 | echo " On machine 1: MASTER_IP=node_0_ip MASTER_PORT=node_0_port" \
46 | "NODE_SIZE=2 NODE_RANK=1 $0 8 ddpm [--help]"
47 | echo
48 | echo "Detailed instruction on available commands:"
49 | echo "--------------------------------------------------"
50 | ${PYTHON} ./main.py --help
51 | echo
52 | exit 0
53 | fi
54 |
55 | GPUS=$1
56 | COMMAND=$2
57 |
58 | # Help message for a particular command.
59 | if [[ $# -lt 3 || ${*: -1} == "--help" ]]; then
60 | echo "Detailed instruction on the arguments for command \`"${COMMAND}"\`:"
61 | echo "--------------------------------------------------"
62 | ${PYTHON} ./main.py ${COMMAND} --help
63 | echo
64 | exit 0
65 | fi
66 |
67 | # Switch memory allocator if available.
68 | # Search order: jemalloc.so -> tcmalloc.so.
69 | # According to https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html,
70 | # these allocators can achieve better performance than the default malloc
71 | # by reusing memory as much as possible.
72 | JEMALLOC=$(ldconfig -p | grep -i "libjemalloc.so$" | tr " " "\n" | grep "/" \
73 | | head -n 1)
74 | TCMALLOC=$(ldconfig -p | grep -i "libtcmalloc.so.4$" | tr " " "\n" | grep "/" \
75 | | head -n 1)
76 | if [ -n "$JEMALLOC" ]; then # if found the path to libjemalloc.so
77 | echo "Switch memory allocator to jemalloc."
78 | export LD_PRELOAD=$JEMALLOC:$LD_PRELOAD
79 | elif [ -n "$TCMALLOC" ]; then # if found the path to libtcmalloc.so.4
80 | echo "Switch memory allocator to tcmalloc."
81 | export LD_PRELOAD=$TCMALLOC:$LD_PRELOAD
82 | fi
83 |
84 | # Get an available port for launching distributed training.
85 | # Credit to https://superuser.com/a/1293762.
86 | export DEFAULT_FREE_PORT
87 | DEFAULT_FREE_PORT=$(comm -23 <(seq 49152 65535 | sort) \
88 | <(ss -Htan | awk '{print $4}' | cut -d':' -f2 | sort -u) \
89 | | shuf | head -n 1)
90 |
91 | MASTER_IP=${MASTER_IP:-127.0.0.1}
92 | MASTER_PORT=${MASTER_PORT:-$DEFAULT_FREE_PORT}
93 | NODE_SIZE=${NODE_SIZE:-1}
94 | NODE_RANK=${NODE_RANK:-0}
95 |
96 | ${PYTHON} -m torch.distributed.launch \
97 | --master_addr=${MASTER_IP} \
98 | --master_port=${MASTER_PORT} \
99 | --nnodes=${NODE_SIZE} \
100 | --node_rank=${NODE_RANK} \
101 | --nproc_per_node=${GPUS} \
102 | ./main.py \
103 | ${COMMAND} \
104 | ${@:3} \
105 | || exit 1 # Stop the script when an exception is thrown by Python.
106 |
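
For context, a minimal sketch (not the repo's `main.py`) of how a training entry point launched this way typically consumes the environment that `torch.distributed.launch` sets up (`MASTER_ADDR`, `MASTER_PORT`, `RANK`, `WORLD_SIZE`). Depending on the PyTorch version, the local rank arrives either as a `--local_rank` argument or as the `LOCAL_RANK` environment variable; the sketch assumes the latter.

```python
import os

import torch
import torch.distributed as dist


def setup_distributed() -> int:
    """Initialise the default process group from the launcher-provided env vars."""
    rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))

    if world_size > 1:
        # MASTER_ADDR / MASTER_PORT are exported by the launcher and picked up
        # by the default env:// init method.
        dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
        torch.cuda.set_device(local_rank)
    return local_rank


if __name__ == "__main__":
    local_rank = setup_distributed()
    print(f"rank {os.environ.get('RANK', '0')} -> cuda:{local_rank}")
```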
--------------------------------------------------------------------------------
/eval_zs.sh:
--------------------------------------------------------------------------------
1 | torchrun --nproc_per_node 3 \
2 | --rdzv_endpoint=$HOST_NODE_ADDR \
3 | -m main \
4 | --imagenet-val '/nas1/datasets/ImageNet1k/val' \
5 | --logs ../eval/ \
6 | --pretrained '/home/kecheng/ECCV2024/epoch_32.pt' \
7 | --batch-size=16 \
8 | --model "ViT-B-16"
9 |
--------------------------------------------------------------------------------
/figures/moti.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ant-research/DreamLIP/b4de0a43b6c002033c02873f91a695ab449e464c/figures/moti.png
--------------------------------------------------------------------------------
/figures/radar.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ant-research/DreamLIP/b4de0a43b6c002033c02873f91a695ab449e464c/figures/radar.jpg
--------------------------------------------------------------------------------
/figures/radar.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ant-research/DreamLIP/b4de0a43b6c002033c02873f91a695ab449e464c/figures/radar.png
--------------------------------------------------------------------------------
/open_clip/__init__.py:
--------------------------------------------------------------------------------
1 | from .coca_model import CoCa
2 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
3 | from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss
4 | from .factory import list_models, add_model_config, get_model_config, load_checkpoint
5 | from .loss import ClipLoss, DistillClipLoss, CoCaLoss
6 | from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \
7 | convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype, get_input_dtype, \
8 | get_model_tokenize_cfg, get_model_preprocess_cfg, set_model_preprocess_cfg
9 | from .openai import load_openai_model, list_openai_models
10 | from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \
11 | get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
12 | from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub
13 | from .tokenizer import SimpleTokenizer, tokenize, decode
14 | from .transform import image_transform, AugmentationCfg
15 | from .zero_shot_classifier import build_zero_shot_classifier, build_zero_shot_classifier_legacy
16 | from .zero_shot_metadata import OPENAI_IMAGENET_TEMPLATES, SIMPLE_IMAGENET_TEMPLATES, IMAGENET_CLASSNAMES
17 |
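
A sketch of the re-exported zero-shot API, assuming the standard open_clip signature of `build_zero_shot_classifier`. The `pretrained="openai"` tag (which downloads OpenAI CLIP weights) is used only for illustration; a local DreamLIP checkpoint path can be passed instead, and the small classname subset just keeps the demo fast.

```python
import torch

import open_clip
from open_clip import (
    IMAGENET_CLASSNAMES,
    OPENAI_IMAGENET_TEMPLATES,
    build_zero_shot_classifier,
)

model, _, preprocess = open_clip.create_model_and_transforms("ViT-B-16", pretrained="openai")
tokenizer = open_clip.get_tokenizer("ViT-B-16")
model.eval()

with torch.no_grad():
    classifier = build_zero_shot_classifier(
        model,
        tokenizer=tokenizer,
        classnames=IMAGENET_CLASSNAMES[:10],       # small subset for a quick demo
        templates=OPENAI_IMAGENET_TEMPLATES,
        num_classes_per_batch=10,
        device="cpu",
    )

# One text embedding per class; image features are compared against these columns.
print(classifier.shape)
```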
--------------------------------------------------------------------------------
/open_clip/big_vision.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | from .model import CustomTextCLIP
5 | from .transformer import TextTransformer, Transformer
6 |
7 |
8 | @torch.no_grad()
9 | def load_big_vision_weights(model: CustomTextCLIP, checkpoint_path: str):
10 | """ Load weights from .npz checkpoints for official Google big_vision image-text models
11 |
12 | Currently the SigLIP source models are supported and a CustomTextCLIP destination model
13 | w/ timm image encoder.
14 | """
15 | from timm.layers import resample_patch_embed, resample_abs_pos_embed
16 |
17 | def _n2p(w, t=True):
18 | if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
19 | w = w.flatten()
20 | if t:
21 | if w.ndim == 4:
22 | w = w.transpose([3, 2, 0, 1])
23 | elif w.ndim == 3:
24 | w = w.transpose([2, 0, 1])
25 | elif w.ndim == 2:
26 | w = w.transpose([1, 0])
27 | return torch.from_numpy(w)
28 |
29 | w = np.load(checkpoint_path)
30 | interpolation = 'bilinear'
31 | antialias = False
32 |
33 | def _convert_timm_img(module, prefix):
34 | embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
35 | if embed_conv_w.shape[-2:] != module.patch_embed.proj.weight.shape[-2:]:
36 | embed_conv_w = resample_patch_embed(
37 | embed_conv_w,
38 | module.patch_embed.proj.weight.shape[-2:],
39 | interpolation=interpolation,
40 | antialias=antialias,
41 | verbose=True,
42 | )
43 | module.patch_embed.proj.weight.copy_(embed_conv_w)
44 | module.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
45 |
46 | if module.cls_token is not None:
47 | module.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
48 |
49 | pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False)
50 | if pos_embed_w.shape != module.pos_embed.shape:
51 | assert False, f'{pos_embed_w.shape}, {module.pos_embed.shape}'
52 | num_prefix_tokens = 0 if getattr(module, 'no_embed_class', False) else getattr(module, 'num_prefix_tokens', 1)
53 | pos_embed_w = resample_abs_pos_embed( # resize pos embedding when different size from pretrained weights
54 | pos_embed_w,
55 | new_size=module.patch_embed.grid_size,
56 | num_prefix_tokens=num_prefix_tokens,
57 | interpolation=interpolation,
58 | antialias=antialias,
59 | verbose=True,
60 | )
61 | module.pos_embed.copy_(pos_embed_w)
62 |
63 | mha_sub, b_sub, ln1_sub = (0, 0, 1)
64 | for i, block in enumerate(module.blocks.children()):
65 | block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
66 | mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/'
67 | block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
68 | block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
69 | block.attn.qkv.weight.copy_(torch.cat([
70 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
71 | block.attn.qkv.bias.copy_(torch.cat([
72 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
73 | block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
74 | block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
75 | for r in range(2):
76 | getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel']))
77 | getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias']))
78 | block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale']))
79 | block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias']))
80 |
81 | module.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
82 | module.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
83 |
84 | if module.attn_pool is not None:
85 | block_prefix = f'{prefix}MAPHead_0/'
86 | mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/'
87 | module.attn_pool.latent.copy_(_n2p(w[f'{block_prefix}probe'], t=False))
88 | module.attn_pool.q.weight.copy_(_n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T)
89 | module.attn_pool.q.bias.copy_(_n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1))
90 | module.attn_pool.kv.weight.copy_(torch.cat([
91 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('key', 'value')]))
92 | module.attn_pool.kv.bias.copy_(torch.cat([
93 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('key', 'value')]))
94 | module.attn_pool.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
95 | module.attn_pool.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
96 | module.attn_pool.norm.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
97 | module.attn_pool.norm.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
98 | for r in range(2):
99 | getattr(module.attn_pool.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/kernel']))
100 | getattr(module.attn_pool.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/bias']))
101 |
102 | def _convert_openclip_transformer(module: Transformer, prefix):
103 | for i, block in enumerate(module.resblocks.children()):
104 | block_prefix = f'{prefix}encoderblock_{i}/'
105 | mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/'
106 | block.ln_1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
107 | block.ln_1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
108 | block.attn.in_proj_weight.copy_(torch.cat([
109 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
110 | block.attn.in_proj_bias.copy_(torch.cat([
111 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
112 | block.attn.out_proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
113 | block.attn.out_proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
114 | block.ln_2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/scale']))
115 | block.ln_2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/bias']))
116 | block.mlp.c_fc.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/kernel']))
117 | block.mlp.c_fc.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/bias']))
118 | block.mlp.c_proj.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/kernel']))
119 | block.mlp.c_proj.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/bias']))
120 |
121 | def _convert_openclip_txt(module: TextTransformer, prefix):
122 | module.token_embedding.weight.copy_(_n2p(w[f'{prefix}Embed_0/embedding'], t=False))
123 | pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False).squeeze(0)
124 | module.positional_embedding.copy_(pos_embed_w)
125 | _convert_openclip_transformer(module.transformer, prefix=prefix + 'Encoder_0/')
126 | module.ln_final.weight.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/scale']))
127 | module.ln_final.bias.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/bias']))
128 | module.text_projection.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
129 | module.text_projection.bias.copy_(_n2p(w[f'{prefix}head/bias']))
130 |
131 | _convert_timm_img(model.visual.trunk, 'params/img/')
132 | _convert_openclip_txt(model.text, 'params/txt/')
133 | model.logit_bias.copy_(_n2p(w['params/b'])[0])
134 | model.logit_scale.copy_(_n2p(w['params/t'])[0])
135 |
136 |
137 |
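
A small standalone illustration of the layout convention `_n2p` handles when mapping big_vision `.npz` arrays to PyTorch: Flax stores convolution kernels as (H, W, C_in, C_out) and dense kernels as (in, out), while PyTorch expects (C_out, C_in, H, W) and (out, in). The helper below mirrors `_n2p` on dummy arrays and does not require an actual SigLIP checkpoint; the shapes are illustrative.

```python
import numpy as np
import torch


def n2p(w: np.ndarray, t: bool = True) -> torch.Tensor:
    """Reproduce the transpose rules of `_n2p` for a standalone demo."""
    if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
        w = w.flatten()
    if t:
        if w.ndim == 4:       # conv kernel: HWIO -> OIHW
            w = w.transpose([3, 2, 0, 1])
        elif w.ndim == 3:
            w = w.transpose([2, 0, 1])
        elif w.ndim == 2:     # dense kernel: (in, out) -> (out, in)
            w = w.transpose([1, 0])
    return torch.from_numpy(w)


conv_kernel = np.zeros((16, 16, 3, 768), dtype=np.float32)   # patch embedding, Flax layout
dense_kernel = np.zeros((768, 3072), dtype=np.float32)       # MLP kernel, Flax layout

print(n2p(conv_kernel).shape)   # torch.Size([768, 3, 16, 16])
print(n2p(dense_kernel).shape)  # torch.Size([3072, 768])
```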
--------------------------------------------------------------------------------
/open_clip/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ant-research/DreamLIP/b4de0a43b6c002033c02873f91a695ab449e464c/open_clip/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/open_clip/constants.py:
--------------------------------------------------------------------------------
1 | OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
2 | OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
3 | IMAGENET_MEAN = (0.485, 0.456, 0.406)
4 | IMAGENET_STD = (0.229, 0.224, 0.225)
5 | INCEPTION_MEAN = (0.5, 0.5, 0.5)
6 | INCEPTION_STD = (0.5, 0.5, 0.5)
7 |
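
These tuples are per-channel normalisation statistics for image preprocessing. A minimal torchvision sketch using the OpenAI statistics (the resolution and crop are arbitrary here; open_clip's own `image_transform` builds an equivalent pipeline for CLIP-style models):

```python
from torchvision import transforms

from open_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD

preprocess = transforms.Compose([
    transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD),
])
```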
--------------------------------------------------------------------------------
/open_clip/generation_utils.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ant-research/DreamLIP/b4de0a43b6c002033c02873f91a695ab449e464c/open_clip/generation_utils.py
--------------------------------------------------------------------------------
/open_clip/hf_configs.py:
--------------------------------------------------------------------------------
1 | # HF architecture dict:
2 | arch_dict = {
3 | # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
4 | "roberta": {
5 | "config_names": {
6 | "context_length": "max_position_embeddings",
7 | "vocab_size": "vocab_size",
8 | "width": "hidden_size",
9 | "heads": "num_attention_heads",
10 | "layers": "num_hidden_layers",
11 | "layer_attr": "layer",
12 | "token_embeddings_attr": "embeddings"
13 | },
14 | "pooler": "mean_pooler",
15 | },
16 | # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
17 | "xlm-roberta": {
18 | "config_names": {
19 | "context_length": "max_position_embeddings",
20 | "vocab_size": "vocab_size",
21 | "width": "hidden_size",
22 | "heads": "num_attention_heads",
23 | "layers": "num_hidden_layers",
24 | "layer_attr": "layer",
25 | "token_embeddings_attr": "embeddings"
26 | },
27 | "pooler": "mean_pooler",
28 | },
29 | # https://huggingface.co/docs/transformers/model_doc/mt5#mt5
30 | "mt5": {
31 | "config_names": {
32 | # unlimited seqlen
33 | # https://github.com/google-research/text-to-text-transfer-transformer/issues/273
34 | # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
35 | "context_length": "",
36 | "vocab_size": "vocab_size",
37 | "width": "d_model",
38 | "heads": "num_heads",
39 | "layers": "num_layers",
40 | "layer_attr": "block",
41 | "token_embeddings_attr": "embed_tokens"
42 | },
43 | "pooler": "mean_pooler",
44 | },
45 | # https://huggingface.co/docs/transformers/model_doc/bert
46 | "bert": {
47 | "config_names": {
48 | "context_length": "max_position_embeddings",
49 | "vocab_size": "vocab_size",
50 | "width": "hidden_size",
51 | "heads": "num_attention_heads",
52 | "layers": "num_hidden_layers",
53 | },
54 | "pooler": "cls_pooler",
55 | },
56 | # https://huggingface.co/docs/transformers/model_doc/m2m_100
57 | "m2m_100": {
58 | "config_names": {
59 | "context_length": "max_position_embeddings",
60 | "vocab_size": "vocab_size",
61 | "width": "d_model",
62 | "heads": "encoder_attention_heads",
63 | "layers": "encoder_layers",
64 | },
65 | "pooler": "cls_pooler",
66 | },
67 | }
68 |
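
A sketch of how `arch_dict` is intended to be used: given a Hugging Face config, the `config_names` entries translate open_clip's generic field names (width, heads, layers, ...) into the attribute names of that architecture. It assumes `transformers` is installed; `roberta-base` is just a convenient public config for the demo.

```python
from transformers import AutoConfig

from open_clip.hf_configs import arch_dict

config = AutoConfig.from_pretrained("roberta-base")
entry = arch_dict[config.model_type]          # model_type is "roberta" here
names = entry["config_names"]

# Resolve generic names to the architecture-specific config attributes.
width = getattr(config, names["width"])       # hidden_size
heads = getattr(config, names["heads"])       # num_attention_heads
layers = getattr(config, names["layers"])     # num_hidden_layers

print(width, heads, layers, entry["pooler"])  # 768 12 12 mean_pooler
```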
--------------------------------------------------------------------------------
/open_clip/hf_model.py:
--------------------------------------------------------------------------------
1 | """ huggingface model adapter
2 |
3 | Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
4 | """
5 | import re
6 | import os
7 | import torch
8 | import torch.nn as nn
9 | from torch import TensorType
10 |
11 | try:
12 | import transformers
13 | from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig
14 | from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
15 | BaseModelOutputWithPoolingAndCrossAttentions
16 | except ImportError as e:
17 | transformers = None
18 |
19 |
20 | class BaseModelOutput:
21 | pass
22 |
23 |
24 | class PretrainedConfig:
25 | pass
26 |
27 | from .hf_configs import arch_dict
28 |
29 |
30 | # utils
31 | def _camel2snake(s):
32 | return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
--------------------------------------------------------------------------------
/open_clip/modified_resnet.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 |
7 | from .utils import freeze_batch_norm_2d
8 |
9 |
10 | class Bottleneck(nn.Module):
11 | expansion = 4
12 |
13 | def __init__(self, inplanes, planes, stride=1):
14 | super().__init__()
15 |
16 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
18 | self.bn1 = nn.BatchNorm2d(planes)
19 | self.act1 = nn.ReLU(inplace=True)
20 |
21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
22 | self.bn2 = nn.BatchNorm2d(planes)
23 | self.act2 = nn.ReLU(inplace=True)
24 |
25 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
26 |
27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
29 | self.act3 = nn.ReLU(inplace=True)
30 |
31 | self.downsample = None
32 | self.stride = stride
33 |
34 | if stride > 1 or inplanes != planes * Bottleneck.expansion:
35 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
36 | self.downsample = nn.Sequential(OrderedDict([
37 | ("-1", nn.AvgPool2d(stride)),
38 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
39 | ("1", nn.BatchNorm2d(planes * self.expansion))
40 | ]))
41 |
42 | def forward(self, x: torch.Tensor):
43 | identity = x
44 |
45 | out = self.act1(self.bn1(self.conv1(x)))
46 | out = self.act2(self.bn2(self.conv2(out)))
47 | out = self.avgpool(out)
48 | out = self.bn3(self.conv3(out))
49 |
50 | if self.downsample is not None:
51 | identity = self.downsample(x)
52 |
53 | out += identity
54 | out = self.act3(out)
55 | return out
56 |
57 |
58 | class AttentionPool2d(nn.Module):
59 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
60 | super().__init__()
61 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
62 | self.k_proj = nn.Linear(embed_dim, embed_dim)
63 | self.q_proj = nn.Linear(embed_dim, embed_dim)
64 | self.v_proj = nn.Linear(embed_dim, embed_dim)
65 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
66 | self.num_heads = num_heads
67 |
68 | def forward(self, x):
69 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
70 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
71 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
72 | x, _ = F.multi_head_attention_forward(
73 | query=x, key=x, value=x,
74 | embed_dim_to_check=x.shape[-1],
75 | num_heads=self.num_heads,
76 | q_proj_weight=self.q_proj.weight,
77 | k_proj_weight=self.k_proj.weight,
78 | v_proj_weight=self.v_proj.weight,
79 | in_proj_weight=None,
80 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
81 | bias_k=None,
82 | bias_v=None,
83 | add_zero_attn=False,
84 | dropout_p=0.,
85 | out_proj_weight=self.c_proj.weight,
86 | out_proj_bias=self.c_proj.bias,
87 | use_separate_proj_weight=True,
88 | training=self.training,
89 | need_weights=False
90 | )
91 |
92 | return x[0]
93 |
94 |
95 | class ModifiedResNet(nn.Module):
96 | """
97 | A ResNet class that is similar to torchvision's but contains the following changes:
98 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
99 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
100 | - The final pooling layer is a QKV attention instead of an average pool
101 | """
102 |
103 | def __init__(self, layers, output_dim, heads, image_size=224, width=64):
104 | super().__init__()
105 | self.output_dim = output_dim
106 | self.image_size = image_size
107 |
108 | # the 3-layer stem
109 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
110 | self.bn1 = nn.BatchNorm2d(width // 2)
111 | self.act1 = nn.ReLU(inplace=True)
112 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
113 | self.bn2 = nn.BatchNorm2d(width // 2)
114 | self.act2 = nn.ReLU(inplace=True)
115 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
116 | self.bn3 = nn.BatchNorm2d(width)
117 | self.act3 = nn.ReLU(inplace=True)
118 | self.avgpool = nn.AvgPool2d(2)
119 |
120 | # residual layers
121 | self._inplanes = width # this is a *mutable* variable used during construction
122 | self.layer1 = self._make_layer(width, layers[0])
123 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
124 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
125 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
126 |
127 | embed_dim = width * 32 # the ResNet feature dimension
128 | self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
129 |
130 | self.init_parameters()
131 |
132 | def _make_layer(self, planes, blocks, stride=1):
133 | layers = [Bottleneck(self._inplanes, planes, stride)]
134 |
135 | self._inplanes = planes * Bottleneck.expansion
136 | for _ in range(1, blocks):
137 | layers.append(Bottleneck(self._inplanes, planes))
138 |
139 | return nn.Sequential(*layers)
140 |
141 | def init_parameters(self):
142 | if self.attnpool is not None:
143 | std = self.attnpool.c_proj.in_features ** -0.5
144 | nn.init.normal_(self.attnpool.q_proj.weight, std=std)
145 | nn.init.normal_(self.attnpool.k_proj.weight, std=std)
146 | nn.init.normal_(self.attnpool.v_proj.weight, std=std)
147 | nn.init.normal_(self.attnpool.c_proj.weight, std=std)
148 |
149 | for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
150 | for name, param in resnet_block.named_parameters():
151 | if name.endswith("bn3.weight"):
152 | nn.init.zeros_(param)
153 |
154 | def lock(self, unlocked_groups=0, freeze_bn_stats=False):
155 | assert unlocked_groups == 0, 'partial locking not currently supported for this model'
156 | for param in self.parameters():
157 | param.requires_grad = False
158 | if freeze_bn_stats:
159 | freeze_batch_norm_2d(self)
160 |
161 | @torch.jit.ignore
162 | def set_grad_checkpointing(self, enable=True):
163 | # FIXME support for non-transformer
164 | pass
165 |
166 | def stem(self, x):
167 | x = self.act1(self.bn1(self.conv1(x)))
168 | x = self.act2(self.bn2(self.conv2(x)))
169 | x = self.act3(self.bn3(self.conv3(x)))
170 | x = self.avgpool(x)
171 | return x
172 |
173 | def forward(self, x):
174 | x = self.stem(x)
175 | x = self.layer1(x)
176 | x = self.layer2(x)
177 | x = self.layer3(x)
178 | x = self.layer4(x)
179 | x = self.attnpool(x)
180 |
181 | return x
182 |
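A quick way to sanity-check the classes above is to instantiate the model with an RN50-like configuration and run a dummy batch. This is a minimal sketch, assuming the module is importable as `open_clip.modified_resnet`; the layer/width/head values are illustrative, not taken from this repo's configs.

```python
import torch
from open_clip.modified_resnet import ModifiedResNet  # assumed import path

# RN50-like configuration (illustrative): 4 stages, 32 attention heads, 1024-d output.
model = ModifiedResNet(layers=(3, 4, 6, 3), output_dim=1024, heads=32, image_size=224, width=64)
model.eval()

with torch.no_grad():
    feats = model(torch.randn(2, 3, 224, 224))
print(feats.shape)  # torch.Size([2, 1024]) -- attention-pooled image features
```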
--------------------------------------------------------------------------------
/open_clip/openai.py:
--------------------------------------------------------------------------------
1 | """ OpenAI pretrained model functions
2 |
3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4 | """
5 |
6 | import os
7 | import warnings
8 | from typing import List, Optional, Union
9 |
10 | import torch
11 |
12 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
13 | from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype
14 | from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url
15 |
16 | __all__ = ["list_openai_models", "load_openai_model"]
17 |
18 |
19 | def list_openai_models() -> List[str]:
20 | """Returns the names of available CLIP models"""
21 | return list_pretrained_models_by_tag('openai')
22 |
23 |
24 | def load_openai_model(
25 | name: str,
26 | precision: Optional[str] = None,
27 | device: Optional[Union[str, torch.device]] = None,
28 | cache_dir: Optional[str] = None,
29 | ):
30 | """Load a CLIP model
31 |
32 | Parameters
33 | ----------
34 | name : str
35 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
36 | precision: str
37 | Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'.
38 | device : Union[str, torch.device]
39 | The device to put the loaded model
40 | cache_dir : Optional[str]
41 | The directory to cache the downloaded model weights
42 |
43 | Returns
44 | -------
45 | model : torch.nn.Module
46 | The CLIP model
47 |
48 |     Note: only the model is returned; the preprocessing transform is built separately (see `open_clip.transform.image_transform`).
49 |     """
50 | if device is None:
51 | device = "cuda" if torch.cuda.is_available() else "cpu"
52 | if precision is None:
53 | precision = 'fp32' if device == 'cpu' else 'fp16'
54 |
55 | if get_pretrained_url(name, 'openai'):
56 | model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir)
57 | elif os.path.isfile(name):
58 | model_path = name
59 | else:
60 | raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}")
61 |
62 | try:
63 | # loading JIT archive
64 | model = torch.jit.load(model_path, map_location="cpu").eval()
65 | state_dict = None
66 | except RuntimeError:
67 | # loading saved state dict
68 | state_dict = torch.load(model_path, map_location="cpu")
69 |
70 | # Build a non-jit model from the OpenAI jitted model state dict
71 | cast_dtype = get_cast_dtype(precision)
72 | try:
73 | model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype)
74 | except KeyError:
75 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
76 | model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype)
77 |
78 | # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use
79 | model = model.to(device)
80 | # FIXME support pure fp16/bf16 precision modes
81 | if precision != 'fp16':
82 | model.float()
83 | if precision == 'bf16':
84 | # for bf16, convert back to low-precision
85 | convert_weights_to_lp(model, dtype=torch.bfloat16)
86 |
87 | # add mean / std attributes for consistency with OpenCLIP models
88 | model.visual.image_mean = OPENAI_DATASET_MEAN
89 | model.visual.image_std = OPENAI_DATASET_STD
90 | return model
91 |
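For reference, a minimal sketch of how these helpers are typically used; the model name is illustrative, and loading pretrained weights requires either network access or a local checkpoint path passed as `name`.

```python
from open_clip.openai import list_openai_models, load_openai_model  # module shown above

print(list_openai_models())  # names of OpenAI-pretrained CLIP models known to the registry

# precision defaults to 'fp32' on CPU and 'fp16' on CUDA (see load_openai_model above)
model = load_openai_model('ViT-B-32', device='cpu')
model.eval()
```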
--------------------------------------------------------------------------------
/open_clip/pos_embed.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # Position embedding utils
8 | # --------------------------------------------------------
9 |
10 | import numpy as np
11 |
12 | import torch
13 |
14 | # --------------------------------------------------------
15 | # 2D sine-cosine position embedding
16 | # References:
17 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
18 | # MoCo v3: https://github.com/facebookresearch/moco-v3
19 | # --------------------------------------------------------
20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
21 | """
22 | grid_size: int of the grid height and width
23 | return:
24 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
25 | """
26 | grid_h = np.arange(grid_size, dtype=np.float32)
27 | grid_w = np.arange(grid_size, dtype=np.float32)
28 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
29 | grid = np.stack(grid, axis=0)
30 |
31 | grid = grid.reshape([2, 1, grid_size, grid_size])
32 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
33 | if cls_token:
34 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
35 | return pos_embed
36 |
37 |
38 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
39 | assert embed_dim % 2 == 0
40 |
41 | # use half of dimensions to encode grid_h
42 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
43 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
44 |
45 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
46 | return emb
47 |
48 |
49 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
50 | """
51 | embed_dim: output dimension for each position
52 | pos: a list of positions to be encoded: size (M,)
53 | out: (M, D)
54 | """
55 | assert embed_dim % 2 == 0
56 | omega = np.arange(embed_dim // 2, dtype=float)
57 | omega /= embed_dim / 2.
58 | omega = 1. / 10000**omega # (D/2,)
59 |
60 | pos = pos.reshape(-1) # (M,)
61 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
62 |
63 | emb_sin = np.sin(out) # (M, D/2)
64 | emb_cos = np.cos(out) # (M, D/2)
65 |
66 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
67 | return emb
68 |
69 |
70 | # --------------------------------------------------------
71 | # Interpolate position embeddings for high-resolution
72 | # References:
73 | # DeiT: https://github.com/facebookresearch/deit
74 | # --------------------------------------------------------
75 | def interpolate_pos_embed(model, checkpoint_model):
76 | if 'pos_embed' in checkpoint_model:
77 | pos_embed_checkpoint = checkpoint_model['pos_embed']
78 | embedding_size = pos_embed_checkpoint.shape[-1]
79 | num_patches = model.patch_embed.num_patches
80 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches
81 | # height (== width) for the checkpoint position embedding
82 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
83 | # height (== width) for the new position embedding
84 | new_size = int(num_patches ** 0.5)
85 | # class_token and dist_token are kept unchanged
86 | if orig_size != new_size:
87 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
88 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
89 | # only the position tokens are interpolated
90 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
91 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
92 | pos_tokens = torch.nn.functional.interpolate(
93 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
94 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
95 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
96 | checkpoint_model['pos_embed'] = new_pos_embed
97 |
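A small worked example of the sine-cosine embedding above, assuming a 14x14 patch grid (e.g. a 224-px image with 16-px patches) and a class-token slot:

```python
from open_clip.pos_embed import get_2d_sincos_pos_embed  # assumed import path

pos_embed = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14, cls_token=True)
print(pos_embed.shape)     # (197, 768) == (1 + 14*14, 768)
print(pos_embed[0].sum())  # 0.0 -- the cls slot is zero-initialized
```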
--------------------------------------------------------------------------------
/open_clip/push_to_hf_hub.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | from pathlib import Path
5 | from tempfile import TemporaryDirectory
6 | from typing import Optional, Tuple, Union
7 |
8 | import torch
9 |
10 | try:
11 | from huggingface_hub import (
12 | create_repo,
13 | get_hf_file_metadata,
14 | hf_hub_download,
15 | hf_hub_url,
16 | repo_type_and_id_from_hf_id,
17 | upload_folder,
18 | list_repo_files,
19 | )
20 | from huggingface_hub.utils import EntryNotFoundError
21 | _has_hf_hub = True
22 | except ImportError:
23 | _has_hf_hub = False
24 |
25 | try:
26 | import safetensors.torch
27 | _has_safetensors = True
28 | except ImportError:
29 | _has_safetensors = False
30 |
31 | from .factory import create_model_from_pretrained, get_model_config, get_tokenizer
32 | from .tokenizer import HFTokenizer
33 |
34 | # Default name for a weights file hosted on the Huggingface Hub.
35 | HF_WEIGHTS_NAME = "open_clip_pytorch_model.bin" # default pytorch pkl
36 | HF_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors" # safetensors version
37 | HF_CONFIG_NAME = 'open_clip_config.json'
38 |
39 |
40 | def save_config_for_hf(
41 | model,
42 |         config_path: Path,
43 | model_config: Optional[dict]
44 | ):
45 | preprocess_cfg = {
46 | 'mean': model.visual.image_mean,
47 | 'std': model.visual.image_std,
48 | }
49 | other_pp = getattr(model.visual, 'preprocess_cfg', {})
50 | if 'interpolation' in other_pp:
51 | preprocess_cfg['interpolation'] = other_pp['interpolation']
52 | if 'resize_mode' in other_pp:
53 | preprocess_cfg['resize_mode'] = other_pp['resize_mode']
54 | hf_config = {
55 | 'model_cfg': model_config,
56 | 'preprocess_cfg': preprocess_cfg,
57 | }
58 |
59 | with config_path.open('w') as f:
60 | json.dump(hf_config, f, indent=2)
61 |
62 |
63 | def save_for_hf(
64 | model,
65 | tokenizer: HFTokenizer,
66 | model_config: dict,
67 | save_directory: str,
68 | safe_serialization: Union[bool, str] = 'both',
69 |         skip_weights: bool = False,
70 | ):
71 | config_filename = HF_CONFIG_NAME
72 |
73 | save_directory = Path(save_directory)
74 | save_directory.mkdir(exist_ok=True, parents=True)
75 |
76 | if not skip_weights:
77 | tensors = model.state_dict()
78 | if safe_serialization is True or safe_serialization == "both":
79 | assert _has_safetensors, "`pip install safetensors` to use .safetensors"
80 | safetensors.torch.save_file(tensors, save_directory / HF_SAFE_WEIGHTS_NAME)
81 | if safe_serialization is False or safe_serialization == "both":
82 | torch.save(tensors, save_directory / HF_WEIGHTS_NAME)
83 |
84 | tokenizer.save_pretrained(save_directory)
85 |
86 | config_path = save_directory / config_filename
87 | save_config_for_hf(model, config_path, model_config=model_config)
88 |
89 |
90 | def push_to_hf_hub(
91 | model,
92 | tokenizer,
93 | model_config: Optional[dict],
94 | repo_id: str,
95 | commit_message: str = 'Add model',
96 | token: Optional[str] = None,
97 | revision: Optional[str] = None,
98 | private: bool = False,
99 | create_pr: bool = False,
100 | model_card: Optional[dict] = None,
101 | safe_serialization: Union[bool, str] = False,
102 | ):
103 | if not isinstance(tokenizer, HFTokenizer):
104 | # FIXME this makes it awkward to push models with new tokenizers, come up with better soln.
105 | # default CLIP tokenizers use https://huggingface.co/openai/clip-vit-large-patch14
106 | tokenizer = HFTokenizer('openai/clip-vit-large-patch14')
107 |
108 | # Create repo if it doesn't exist yet
109 | repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)
110 |
111 | # Infer complete repo_id from repo_url
112 | # Can be different from the input `repo_id` if repo_owner was implicit
113 | _, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url)
114 | repo_id = f"{repo_owner}/{repo_name}"
115 |
116 | # Check if repo already exists and determine what needs updating
117 | repo_exists = False
118 | repo_files = {}
119 | try:
120 | repo_files = set(list_repo_files(repo_id))
121 | repo_exists = True
122 | except Exception as e:
123 | print('Repo does not exist', e)
124 |
125 | try:
126 | get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision))
127 | has_readme = True
128 | except EntryNotFoundError:
129 | has_readme = False
130 |
131 | # Dump model and push to Hub
132 | with TemporaryDirectory() as tmpdir:
133 | # Save model weights and config.
134 | save_for_hf(
135 | model,
136 | tokenizer=tokenizer,
137 | model_config=model_config,
138 | save_directory=tmpdir,
139 | safe_serialization=safe_serialization,
140 | )
141 |
142 | # Add readme if it does not exist
143 | if not has_readme:
144 | model_card = model_card or {}
145 | model_name = repo_id.split('/')[-1]
146 | readme_path = Path(tmpdir) / "README.md"
147 | readme_text = generate_readme(model_card, model_name)
148 | readme_path.write_text(readme_text)
149 |
150 | # Upload model and return
151 | return upload_folder(
152 | repo_id=repo_id,
153 | folder_path=tmpdir,
154 | revision=revision,
155 | create_pr=create_pr,
156 | commit_message=commit_message,
157 | )
158 |
159 |
160 | def push_pretrained_to_hf_hub(
161 | model_name,
162 | pretrained: str,
163 | repo_id: str,
164 | precision: str = 'fp32',
165 | image_mean: Optional[Tuple[float, ...]] = None,
166 | image_std: Optional[Tuple[float, ...]] = None,
167 | image_interpolation: Optional[str] = None,
168 | image_resize_mode: Optional[str] = None, # only effective for inference
169 | commit_message: str = 'Add model',
170 | token: Optional[str] = None,
171 | revision: Optional[str] = None,
172 | private: bool = False,
173 | create_pr: bool = False,
174 | model_card: Optional[dict] = None,
175 | hf_tokenizer_self: bool = False,
176 | ):
177 | model, preprocess_eval = create_model_from_pretrained(
178 | model_name,
179 | pretrained=pretrained,
180 | precision=precision,
181 | image_mean=image_mean,
182 | image_std=image_std,
183 | image_interpolation=image_interpolation,
184 | image_resize_mode=image_resize_mode,
185 | )
186 | model_config = get_model_config(model_name)
187 | assert model_config
188 |
189 | tokenizer = get_tokenizer(model_name)
190 | if hf_tokenizer_self:
191 | # make hf tokenizer config in the uploaded model point to self instead of original location
192 | model_config['text']['hf_tokenizer_name'] = repo_id
193 |
194 | push_to_hf_hub(
195 | model=model,
196 | tokenizer=tokenizer,
197 | model_config=model_config,
198 | repo_id=repo_id,
199 | commit_message=commit_message,
200 | token=token,
201 | revision=revision,
202 | private=private,
203 | create_pr=create_pr,
204 | model_card=model_card,
205 | safe_serialization='both',
206 | )
207 |
208 |
209 | def generate_readme(model_card: dict, model_name: str):
210 | tags = model_card.pop('tags', ('clip',))
211 | pipeline_tag = model_card.pop('pipeline_tag', 'zero-shot-image-classification')
212 | readme_text = "---\n"
213 | if tags:
214 | readme_text += "tags:\n"
215 | for t in tags:
216 | readme_text += f"- {t}\n"
217 | readme_text += "library_name: open_clip\n"
218 | readme_text += f"pipeline_tag: {pipeline_tag}\n"
219 | readme_text += f"license: {model_card.get('license', 'mit')}\n"
220 | if 'details' in model_card and 'Dataset' in model_card['details']:
221 | readme_text += 'datasets:\n'
222 | readme_text += f"- {model_card['details']['Dataset'].lower()}\n"
223 | readme_text += "---\n"
224 | readme_text += f"# Model card for {model_name}\n"
225 | if 'description' in model_card:
226 | readme_text += f"\n{model_card['description']}\n"
227 | if 'details' in model_card:
228 | readme_text += f"\n## Model Details\n"
229 | for k, v in model_card['details'].items():
230 | if isinstance(v, (list, tuple)):
231 | readme_text += f"- **{k}:**\n"
232 | for vi in v:
233 | readme_text += f" - {vi}\n"
234 | elif isinstance(v, dict):
235 | readme_text += f"- **{k}:**\n"
236 | for ki, vi in v.items():
237 | readme_text += f" - {ki}: {vi}\n"
238 | else:
239 | readme_text += f"- **{k}:** {v}\n"
240 | if 'usage' in model_card:
241 | readme_text += f"\n## Model Usage\n"
242 | readme_text += model_card['usage']
243 | readme_text += '\n'
244 |
245 | if 'comparison' in model_card:
246 | readme_text += f"\n## Model Comparison\n"
247 | readme_text += model_card['comparison']
248 | readme_text += '\n'
249 |
250 | if 'citation' in model_card:
251 | readme_text += f"\n## Citation\n"
252 | if not isinstance(model_card['citation'], (list, tuple)):
253 | citations = [model_card['citation']]
254 | else:
255 | citations = model_card['citation']
256 | for c in citations:
257 | readme_text += f"```bibtex\n{c}\n```\n"
258 |
259 | return readme_text
260 |
261 |
262 | if __name__ == "__main__":
263 | parser = argparse.ArgumentParser(description="Push to Hugging Face Hub")
264 | parser.add_argument(
265 | "--model", type=str, help="Name of the model to use.",
266 | )
267 | parser.add_argument(
268 | "--pretrained", type=str,
269 | help="Use a pretrained CLIP model weights with the specified tag or file path.",
270 | )
271 | parser.add_argument(
272 | "--repo-id", type=str,
273 | help="Destination HF Hub repo-id ie 'organization/model_id'.",
274 | )
275 | parser.add_argument(
276 | "--precision", type=str, default='fp32',
277 | )
278 | parser.add_argument(
279 | '--image-mean', type=float, nargs='+', default=None, metavar='MEAN',
280 | help='Override default image mean value of dataset')
281 | parser.add_argument(
282 | '--image-std', type=float, nargs='+', default=None, metavar='STD',
283 |         help='Override default image std deviation of dataset')
284 | parser.add_argument(
285 | '--image-interpolation',
286 | default=None, type=str, choices=['bicubic', 'bilinear', 'random'],
287 | help="image resize interpolation"
288 | )
289 | parser.add_argument(
290 | '--image-resize-mode',
291 | default=None, type=str, choices=['shortest', 'longest', 'squash'],
292 | help="image resize mode during inference"
293 | )
294 | parser.add_argument(
295 | "--hf-tokenizer-self",
296 | default=False,
297 | action="store_true",
298 | help="make hf_tokenizer_name point in uploaded config point to itself"
299 | )
300 | args = parser.parse_args()
301 |
302 | print(f'Saving model {args.model} with pretrained weights {args.pretrained} to Hugging Face Hub at {args.repo_id}')
303 |
304 | # FIXME add support to pass model_card json / template from file via cmd line
305 |
306 | push_pretrained_to_hf_hub(
307 | args.model,
308 | args.pretrained,
309 | args.repo_id,
310 | precision=args.precision,
311 | image_mean=args.image_mean, # override image mean/std if trained w/ non defaults
312 | image_std=args.image_std,
313 | image_interpolation=args.image_interpolation,
314 | image_resize_mode=args.image_resize_mode,
315 | )
316 |
317 | print(f'{args.model} saved.')
318 |
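Beyond the CLI entry point above, the same upload can be driven from Python. A hedged sketch with placeholder model, checkpoint path, and repo id; a valid Hugging Face token must be available (e.g. via `huggingface-cli login` or the `token` argument of `push_to_hf_hub`):

```python
from open_clip.push_to_hf_hub import push_pretrained_to_hf_hub  # module shown above

# All values below are illustrative placeholders.
push_pretrained_to_hf_hub(
    model_name='ViT-B-32',
    pretrained='/path/to/checkpoint.pt',
    repo_id='my-org/my-clip-model',
    precision='fp32',
)
```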
--------------------------------------------------------------------------------
/open_clip/timm_model.py:
--------------------------------------------------------------------------------
1 | """ timm model adapter
2 |
3 | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
4 | """
5 | import logging
6 | from collections import OrderedDict
7 |
8 | import torch
9 | import torch.nn as nn
10 |
11 | try:
12 | import timm
13 | from timm.models.layers import Mlp, to_2tuple
14 | try:
15 | # old timm imports < 0.8.1
16 | from timm.models.layers.attention_pool2d import RotAttentionPool2d
17 | from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d
18 | except ImportError:
19 | # new timm imports >= 0.8.1
20 | from timm.layers import RotAttentionPool2d
21 | from timm.layers import AttentionPool2d as AbsAttentionPool2d
22 | except ImportError:
23 | timm = None
24 |
25 | from .utils import freeze_batch_norm_2d
26 |
27 |
28 | class TimmModel(nn.Module):
29 | """ timm model adapter
30 | """
31 |
32 | def __init__(
33 | self,
34 | model_name,
35 | embed_dim,
36 | image_size=224,
37 | pool='avg',
38 | proj='linear',
39 | proj_bias=False,
40 | drop=0.,
41 | drop_path=None,
42 | patch_drop=None,
43 | pretrained=False,
44 | ):
45 | super().__init__()
46 | if timm is None:
47 | raise RuntimeError("Please `pip install timm` to use timm models.")
48 | self.image_size = to_2tuple(image_size)
49 |
50 | # setup kwargs that may not be common across all models
51 | timm_kwargs = {}
52 | if drop_path is not None:
53 | timm_kwargs['drop_path_rate'] = drop_path
54 | if patch_drop is not None:
55 | timm_kwargs['patch_drop_rate'] = patch_drop
56 |
57 | custom_pool = pool in ('abs_attn', 'rot_attn')
58 | if proj:
59 | assert proj in ("linear", "mlp", "none")
60 | extra_proj = proj in ("linear", "mlp")
61 | if not extra_proj and not custom_pool:
62 | # use network classifier head as projection if no proj specified and no custom pooling used
63 |             # if projection is explicitly set to "none", the trunk output is passed through unchanged
64 | proj_dim = 0 if proj == 'none' else embed_dim
65 | self.trunk = timm.create_model(
66 | model_name,
67 | num_classes=proj_dim,
68 | global_pool=pool,
69 | pretrained=pretrained,
70 | **timm_kwargs,
71 | )
72 | prev_chs = embed_dim
73 | else:
74 | self.trunk = timm.create_model(
75 | model_name,
76 | pretrained=pretrained,
77 | **timm_kwargs,
78 | )
79 | feat_size = self.trunk.default_cfg.get('pool_size', None)
80 | feature_ndim = 1 if not feat_size else 2
81 | if custom_pool:
82 | assert feature_ndim == 2
83 | # if attn pooling used, remove both classifier and default pool
84 | self.trunk.reset_classifier(0, global_pool='')
85 | else:
86 | # reset global pool if pool config set, otherwise leave as network default
87 | reset_kwargs = dict(global_pool=pool) if pool else {}
88 | self.trunk.reset_classifier(0, **reset_kwargs)
89 | prev_chs = self.trunk.num_features
90 |
91 | head_layers = OrderedDict()
92 |
93 | # Add custom pooling to head
94 | if pool == 'abs_attn':
95 | head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)
96 | prev_chs = embed_dim
97 | elif pool == 'rot_attn':
98 | head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
99 | prev_chs = embed_dim
100 |
101 | # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
102 | if proj == 'linear':
103 | head_layers['drop'] = nn.Dropout(drop)
104 | head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
105 | elif proj == 'mlp':
106 | head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias))
107 |
108 | self.head = nn.Sequential(head_layers)
109 |
110 | def lock(self, unlocked_groups=0, freeze_bn_stats=False):
111 | """ lock modules
112 | Args:
113 | unlocked_groups (int): leave last n layer groups unlocked (default: 0)
114 | """
115 | if not unlocked_groups:
116 | # lock full model
117 | for param in self.trunk.parameters():
118 | param.requires_grad = False
119 | if freeze_bn_stats:
120 | freeze_batch_norm_2d(self.trunk)
121 | else:
122 | # NOTE: partial freeze requires latest timm (master) branch and is subject to change
123 | try:
124 | # FIXME import here until API stable and in an official release
125 | from timm.models.helpers import group_parameters, group_modules
126 | except ImportError:
127 | raise RuntimeError(
128 | 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`')
129 | matcher = self.trunk.group_matcher()
130 | gparams = group_parameters(self.trunk, matcher)
131 | max_layer_id = max(gparams.keys())
132 | max_layer_id = max_layer_id - unlocked_groups
133 | for group_idx in range(max_layer_id + 1):
134 | group = gparams[group_idx]
135 | for param in group:
136 | self.trunk.get_parameter(param).requires_grad = False
137 | if freeze_bn_stats:
138 | gmodules = group_modules(self.trunk, matcher, reverse=True)
139 | gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
140 | freeze_batch_norm_2d(self.trunk, gmodules)
141 |
142 | @torch.jit.ignore
143 | def set_grad_checkpointing(self, enable=True):
144 | try:
145 | self.trunk.set_grad_checkpointing(enable)
146 | except Exception as e:
147 |             logging.warning(f'grad checkpointing not supported for this timm image tower, continuing without... ({e})')
148 |
149 | def forward(self, x):
150 | x = self.trunk(x)
151 | x = self.head(x)
152 | return x
153 |
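A minimal sketch of using the adapter as a vision tower, assuming `timm` is installed; the trunk name and embedding size are illustrative:

```python
import torch
from open_clip.timm_model import TimmModel  # assumed import path

# ResNet-50 trunk, global average pooling, linear projection to a 512-d embedding space.
tower = TimmModel('resnet50', embed_dim=512, image_size=224, pool='avg', proj='linear')
tower.eval()

with torch.no_grad():
    feats = tower(torch.randn(2, 3, 224, 224))
print(feats.shape)  # torch.Size([2, 512])
```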
--------------------------------------------------------------------------------
/open_clip/transform.py:
--------------------------------------------------------------------------------
1 | import numbers
2 | import random
3 | import warnings
4 | from dataclasses import dataclass, asdict
5 | from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
6 |
7 | import torch
8 | import torchvision.transforms.functional as F
9 | from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
10 | CenterCrop, ColorJitter, Grayscale, RandomApply, RandomGrayscale, RandomHorizontalFlip
11 | from PIL import ImageFilter
12 | from training.random_aug import RandomAugment
13 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
14 | from .utils import to_2tuple
15 |
16 |
17 | @dataclass
18 | class PreprocessCfg:
19 | size: Union[int, Tuple[int, int]] = 224
20 | mode: str = 'RGB'
21 | mean: Tuple[float, ...] = OPENAI_DATASET_MEAN
22 | std: Tuple[float, ...] = OPENAI_DATASET_STD
23 | interpolation: str = 'bicubic'
24 | resize_mode: str = 'shortest'
25 | fill_color: int = 0
26 |
27 | def __post_init__(self):
28 | assert self.mode in ('RGB',)
29 |
30 | @property
31 | def num_channels(self):
32 | return 3
33 |
34 | @property
35 | def input_size(self):
36 | return (self.num_channels,) + to_2tuple(self.size)
37 |
38 | _PREPROCESS_KEYS = set(asdict(PreprocessCfg()).keys())
39 |
40 | class GaussianBlur(object):
41 |     def __init__(self, sigma=(.1, 2.)):
42 | self.sigma = sigma
43 |
44 | def __call__(self, x):
45 | sigma = random.uniform(self.sigma[0], self.sigma[1])
46 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
47 | return x
48 |
49 |
50 | def merge_preprocess_dict(
51 | base: Union[PreprocessCfg, Dict],
52 | overlay: Dict,
53 | ):
54 | """ Merge overlay key-value pairs on top of base preprocess cfg or dict.
55 | Input dicts are filtered based on PreprocessCfg fields.
56 | """
57 | if isinstance(base, PreprocessCfg):
58 | base_clean = asdict(base)
59 | else:
60 | base_clean = {k: v for k, v in base.items() if k in _PREPROCESS_KEYS}
61 | if overlay:
62 | overlay_clean = {k: v for k, v in overlay.items() if k in _PREPROCESS_KEYS and v is not None}
63 | base_clean.update(overlay_clean)
64 | return base_clean
65 |
66 |
67 | def merge_preprocess_kwargs(base: PreprocessCfg, **kwargs):
68 | return merge_preprocess_dict(base, kwargs)
69 |
70 |
71 | @dataclass
72 | class AugmentationCfg:
73 | scale: Tuple[float, float] = (0.9, 1.0)
74 | ratio: Optional[Tuple[float, float]] = None
75 | color_jitter: Optional[Union[float, Tuple[float, float, float], Tuple[float, float, float, float]]] = None
76 | re_prob: Optional[float] = None
77 | re_count: Optional[int] = None
78 | use_timm: bool = False
79 |
80 | # params for simclr_jitter_gray
81 |     color_jitter_prob: Optional[float] = None
82 |     gray_scale_prob: Optional[float] = None
83 |
84 |
85 | def _setup_size(size, error_msg):
86 | if isinstance(size, numbers.Number):
87 | return int(size), int(size)
88 |
89 | if isinstance(size, Sequence) and len(size) == 1:
90 | return size[0], size[0]
91 |
92 | if len(size) != 2:
93 | raise ValueError(error_msg)
94 |
95 | return size
96 |
97 |
98 | class ResizeKeepRatio:
99 | """ Resize and Keep Ratio
100 |
101 | Copy & paste from `timm`
102 | """
103 |
104 | def __init__(
105 | self,
106 | size,
107 | longest=0.,
108 | interpolation=InterpolationMode.BICUBIC,
109 | random_scale_prob=0.,
110 | random_scale_range=(0.85, 1.05),
111 | random_aspect_prob=0.,
112 | random_aspect_range=(0.9, 1.11)
113 | ):
114 | if isinstance(size, (list, tuple)):
115 | self.size = tuple(size)
116 | else:
117 | self.size = (size, size)
118 | self.interpolation = interpolation
119 | self.longest = float(longest) # [0, 1] where 0 == shortest edge, 1 == longest
120 | self.random_scale_prob = random_scale_prob
121 | self.random_scale_range = random_scale_range
122 | self.random_aspect_prob = random_aspect_prob
123 | self.random_aspect_range = random_aspect_range
124 |
125 | @staticmethod
126 | def get_params(
127 | img,
128 | target_size,
129 | longest,
130 | random_scale_prob=0.,
131 | random_scale_range=(0.85, 1.05),
132 | random_aspect_prob=0.,
133 | random_aspect_range=(0.9, 1.11)
134 | ):
135 | """Get parameters
136 | """
137 | source_size = img.size[::-1] # h, w
138 | h, w = source_size
139 | target_h, target_w = target_size
140 | ratio_h = h / target_h
141 | ratio_w = w / target_w
142 | ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (1. - longest)
143 | if random_scale_prob > 0 and random.random() < random_scale_prob:
144 | ratio_factor = random.uniform(random_scale_range[0], random_scale_range[1])
145 | ratio_factor = (ratio_factor, ratio_factor)
146 | else:
147 | ratio_factor = (1., 1.)
148 | if random_aspect_prob > 0 and random.random() < random_aspect_prob:
149 | aspect_factor = random.uniform(random_aspect_range[0], random_aspect_range[1])
150 | ratio_factor = (ratio_factor[0] / aspect_factor, ratio_factor[1] * aspect_factor)
151 | size = [round(x * f / ratio) for x, f in zip(source_size, ratio_factor)]
152 | return size
153 |
154 | def __call__(self, img):
155 | """
156 | Args:
157 | img (PIL Image): Image to be cropped and resized.
158 |
159 | Returns:
160 | PIL Image: Resized, padded to at least target size, possibly cropped to exactly target size
161 | """
162 | size = self.get_params(
163 | img, self.size, self.longest,
164 | self.random_scale_prob, self.random_scale_range,
165 | self.random_aspect_prob, self.random_aspect_range
166 | )
167 | img = F.resize(img, size, self.interpolation)
168 | return img
169 |
170 | def __repr__(self):
171 | format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
172 |         format_string += f', interpolation={self.interpolation}'
173 | format_string += f', longest={self.longest:.3f})'
174 | return format_string
175 |
176 |
177 | def center_crop_or_pad(img: torch.Tensor, output_size: List[int], fill=0) -> torch.Tensor:
178 | """Center crops and/or pads the given image.
179 | If the image is torch Tensor, it is expected
180 | to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
181 | If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.
182 |
183 | Args:
184 | img (PIL Image or Tensor): Image to be cropped.
185 | output_size (sequence or int): (height, width) of the crop box. If int or sequence with single int,
186 | it is used for both directions.
187 | fill (int, Tuple[int]): Padding color
188 |
189 | Returns:
190 | PIL Image or Tensor: Cropped image.
191 | """
192 | if isinstance(output_size, numbers.Number):
193 | output_size = (int(output_size), int(output_size))
194 | elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
195 | output_size = (output_size[0], output_size[0])
196 |
197 | _, image_height, image_width = F.get_dimensions(img)
198 | crop_height, crop_width = output_size
199 |
200 | if crop_width > image_width or crop_height > image_height:
201 | padding_ltrb = [
202 | (crop_width - image_width) // 2 if crop_width > image_width else 0,
203 | (crop_height - image_height) // 2 if crop_height > image_height else 0,
204 | (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
205 | (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
206 | ]
207 | img = F.pad(img, padding_ltrb, fill=fill)
208 | _, image_height, image_width = F.get_dimensions(img)
209 | if crop_width == image_width and crop_height == image_height:
210 | return img
211 |
212 | crop_top = int(round((image_height - crop_height) / 2.0))
213 | crop_left = int(round((image_width - crop_width) / 2.0))
214 | return F.crop(img, crop_top, crop_left, crop_height, crop_width)
215 |
216 |
217 | class CenterCropOrPad(torch.nn.Module):
218 | """Crops the given image at the center.
219 | If the image is torch Tensor, it is expected
220 | to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
221 | If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.
222 |
223 | Args:
224 | size (sequence or int): Desired output size of the crop. If size is an
225 | int instead of sequence like (h, w), a square crop (size, size) is
226 | made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
227 | """
228 |
229 | def __init__(self, size, fill=0):
230 | super().__init__()
231 | self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
232 | self.fill = fill
233 |
234 | def forward(self, img):
235 | """
236 | Args:
237 | img (PIL Image or Tensor): Image to be cropped.
238 |
239 | Returns:
240 | PIL Image or Tensor: Cropped image.
241 | """
242 | return center_crop_or_pad(img, self.size, fill=self.fill)
243 |
244 | def __repr__(self) -> str:
245 | return f"{self.__class__.__name__}(size={self.size})"
246 |
247 |
248 | def _convert_to_rgb(image):
249 | return image.convert('RGB')
250 |
251 |
252 | class color_jitter(object):
253 | """
254 | Apply Color Jitter to the PIL image with a specified probability.
255 | """
256 | def __init__(self, brightness=0., contrast=0., saturation=0., hue=0., p=0.8):
257 | assert 0. <= p <= 1.
258 | self.p = p
259 | self.transf = ColorJitter(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue)
260 |
261 | def __call__(self, img):
262 | if random.random() < self.p:
263 | return self.transf(img)
264 | else:
265 | return img
266 |
267 |
268 | class gray_scale(object):
269 | """
270 | Apply Gray Scale to the PIL image with a specified probability.
271 | """
272 | def __init__(self, p=0.2):
273 | assert 0. <= p <= 1.
274 | self.p = p
275 | self.transf = Grayscale(num_output_channels=3)
276 |
277 | def __call__(self, img):
278 | if random.random() < self.p:
279 | return self.transf(img)
280 | else:
281 | return img
282 |
283 |
284 | def image_transform(
285 | image_size: Union[int, Tuple[int, int]],
286 | is_train: bool,
287 | mean: Optional[Tuple[float, ...]] = None,
288 | std: Optional[Tuple[float, ...]] = None,
289 | resize_mode: Optional[str] = None,
290 | interpolation: Optional[str] = None,
291 | fill_color: int = 0,
292 | aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
293 | ):
294 | mean = mean or OPENAI_DATASET_MEAN
295 | if not isinstance(mean, (list, tuple)):
296 | mean = (mean,) * 3
297 |
298 | std = std or OPENAI_DATASET_STD
299 | if not isinstance(std, (list, tuple)):
300 | std = (std,) * 3
301 |
302 | interpolation = interpolation or 'bicubic'
303 | assert interpolation in ['bicubic', 'bilinear', 'random']
304 | # NOTE random is ignored for interpolation_mode, so defaults to BICUBIC for inference if set
305 | interpolation_mode = InterpolationMode.BILINEAR if interpolation == 'bilinear' else InterpolationMode.BICUBIC
306 |
307 | resize_mode = resize_mode or 'shortest'
308 | assert resize_mode in ('shortest', 'longest', 'squash')
309 |
310 | if isinstance(aug_cfg, dict):
311 | aug_cfg = AugmentationCfg(**aug_cfg)
312 | else:
313 | aug_cfg = aug_cfg or AugmentationCfg()
314 |
315 | normalize = Normalize(mean=mean, std=std)
316 | if is_train:
317 | aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None}
318 | use_timm = aug_cfg_dict.pop('use_timm', False)
319 | use_dataaug = aug_cfg_dict.pop('use_dataaug', False)
320 | if use_timm:
321 | from timm.data import create_transform # timm can still be optional
322 | if isinstance(image_size, (tuple, list)):
323 | assert len(image_size) >= 2
324 | input_size = (3,) + image_size[-2:]
325 | else:
326 | input_size = (3, image_size, image_size)
327 |
328 | aug_cfg_dict.setdefault('color_jitter', None) # disable by default
329 | # drop extra non-timm items
330 | aug_cfg_dict.pop('color_jitter_prob', None)
331 | aug_cfg_dict.pop('gray_scale_prob', None)
332 |
333 | train_transform = create_transform(
334 | input_size=input_size,
335 | is_training=True,
336 | hflip=0.,
337 | mean=mean,
338 | std=std,
339 | re_mode='pixel',
340 | interpolation=interpolation,
341 | **aug_cfg_dict,
342 | )
343 | elif use_dataaug:
344 | train_transform = Compose([
345 | RandomResizedCrop(image_size, scale=(0.2, 1.0), interpolation=InterpolationMode.BICUBIC),
346 | RandomApply([ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
347 | RandomGrayscale(p=0.2),
348 |                 RandomApply([GaussianBlur([.1, 2.])], p=0.5),
349 | RandomHorizontalFlip(),
350 | RandomAugment(2, 7, isPIL=True, augs=['Identity','AutoContrast','Equalize','Brightness','Sharpness',
351 | 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
352 | ToTensor(),
353 | normalize
354 | ])
355 | else:
356 | train_transform = [
357 | RandomResizedCrop(
358 | image_size,
359 | scale=aug_cfg_dict.pop('scale'),
360 | interpolation=InterpolationMode.BICUBIC,
361 | ),
362 | _convert_to_rgb,
363 | ]
364 | if aug_cfg.color_jitter_prob:
365 | assert aug_cfg.color_jitter is not None and len(aug_cfg.color_jitter) == 4
366 | train_transform.extend([
367 | color_jitter(*aug_cfg.color_jitter, p=aug_cfg.color_jitter_prob)
368 | ])
369 | if aug_cfg.gray_scale_prob:
370 | train_transform.extend([
371 | gray_scale(aug_cfg.gray_scale_prob)
372 | ])
373 | train_transform.extend([
374 | ToTensor(),
375 | normalize,
376 | ])
377 | train_transform = Compose(train_transform)
378 | if aug_cfg_dict:
379 | warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).')
380 | print(train_transform)
381 | return train_transform
382 | else:
383 | if resize_mode == 'longest':
384 | transforms = [
385 | ResizeKeepRatio(image_size, interpolation=interpolation_mode, longest=1),
386 | CenterCropOrPad(image_size, fill=fill_color)
387 | ]
388 | elif resize_mode == 'squash':
389 | if isinstance(image_size, int):
390 | image_size = (image_size, image_size)
391 | transforms = [
392 | Resize(image_size, interpolation=interpolation_mode),
393 | ]
394 | else:
395 | assert resize_mode == 'shortest'
396 | if not isinstance(image_size, (tuple, list)):
397 | image_size = (image_size, image_size)
398 | if image_size[0] == image_size[1]:
399 | # simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg)
400 | transforms = [
401 | Resize(image_size[0], interpolation=interpolation_mode)
402 | ]
403 | else:
404 | # resize shortest edge to matching target dim for non-square target
405 | transforms = [ResizeKeepRatio(image_size)]
406 | transforms += [CenterCrop(image_size)]
407 |
408 | transforms.extend([
409 | _convert_to_rgb,
410 | ToTensor(),
411 | normalize,
412 | ])
413 | return Compose(transforms)
414 |
415 |
416 | def image_transform_v2(
417 | cfg: PreprocessCfg,
418 | is_train: bool,
419 | aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
420 | ):
421 | return image_transform(
422 | image_size=cfg.size,
423 | is_train=is_train,
424 | mean=cfg.mean,
425 | std=cfg.std,
426 | interpolation=cfg.interpolation,
427 | resize_mode=cfg.resize_mode,
428 | fill_color=cfg.fill_color,
429 | aug_cfg=aug_cfg,
430 | )
431 |
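A short sketch of building an inference transform from a `PreprocessCfg`; note that importing this module pulls in `training.random_aug`, so the repository root must be on `PYTHONPATH`. The config values shown are the defaults:

```python
from PIL import Image
from open_clip.transform import PreprocessCfg, image_transform_v2  # assumed import path

cfg = PreprocessCfg(size=224, interpolation='bicubic', resize_mode='shortest')
preprocess = image_transform_v2(cfg, is_train=False)

img = Image.new('RGB', (640, 480))  # dummy image
tensor = preprocess(img)
print(tensor.shape)                 # torch.Size([3, 224, 224]), normalized
```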
--------------------------------------------------------------------------------
/open_clip/utils.py:
--------------------------------------------------------------------------------
1 | from itertools import repeat
2 | import collections.abc
3 |
4 | import torch
5 | from torch import nn as nn
6 | from torchvision.ops.misc import FrozenBatchNorm2d
7 |
8 |
9 | def freeze_batch_norm_2d(module, module_match={}, name=''):
10 | """
11 | Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
12 | itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
13 | returned. Otherwise, the module is walked recursively and submodules are converted in place.
14 |
15 | Args:
16 | module (torch.nn.Module): Any PyTorch module.
17 | module_match (dict): Dictionary of full module names to freeze (all if empty)
18 | name (str): Full module name (prefix)
19 |
20 | Returns:
21 | torch.nn.Module: Resulting module
22 |
23 | Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
24 | """
25 | res = module
26 | is_match = True
27 | if module_match:
28 | is_match = name in module_match
29 | if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
30 | res = FrozenBatchNorm2d(module.num_features)
31 | res.num_features = module.num_features
32 | res.affine = module.affine
33 | if module.affine:
34 | res.weight.data = module.weight.data.clone().detach()
35 | res.bias.data = module.bias.data.clone().detach()
36 | res.running_mean.data = module.running_mean.data
37 | res.running_var.data = module.running_var.data
38 | res.eps = module.eps
39 | else:
40 | for child_name, child in module.named_children():
41 | full_child_name = '.'.join([name, child_name]) if name else child_name
42 | new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
43 | if new_child is not child:
44 | res.add_module(child_name, new_child)
45 | return res
46 |
47 |
48 | # From PyTorch internals
49 | def _ntuple(n):
50 | def parse(x):
51 | if isinstance(x, collections.abc.Iterable):
52 | return x
53 | return tuple(repeat(x, n))
54 | return parse
55 |
56 |
57 | to_1tuple = _ntuple(1)
58 | to_2tuple = _ntuple(2)
59 | to_3tuple = _ntuple(3)
60 | to_4tuple = _ntuple(4)
61 | to_ntuple = lambda n, x: _ntuple(n)(x)
62 |
63 | # Replaces all linear layers with linear_replacement
64 | # TODO: add int8 support for other linear layers including attn and convnets
65 | def replace_linear(model, linear_replacement, include_modules=['c_fc', 'c_proj'], copy_weights=True):
66 | for name, module in model.named_children():
67 | if len(list(module.children())) > 0:
68 | replace_linear(module, linear_replacement, include_modules, copy_weights)
69 |
70 | if isinstance(module, torch.nn.Linear) and name in include_modules:
71 | old_module = model._modules[name]
72 | model._modules[name] = linear_replacement(
73 | module.in_features,
74 | module.out_features,
75 | module.bias is not None,
76 | )
77 | if copy_weights:
78 | model._modules[name].weight.data.copy_(old_module.weight.data)
79 | if model._modules[name].bias is not None:
80 | model._modules[name].bias.data.copy_(old_module.bias)
81 |
82 | return model
83 |
84 | def convert_int8_model_to_inference_mode(model):
85 | for m in model.modules():
86 | if hasattr(m, 'prepare_for_eval'):
87 | int8_original_dtype = m.weight.dtype
88 | m.prepare_for_eval()
89 | m.int8_original_dtype = int8_original_dtype
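To illustrate `freeze_batch_norm_2d`, a minimal sketch that freezes the batch-norm layers of a small module (the swapped-in layers are `torchvision.ops.misc.FrozenBatchNorm2d`, as imported above):

```python
import torch.nn as nn
from open_clip.utils import freeze_batch_norm_2d  # assumed import path

net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
net = freeze_batch_norm_2d(net)

# The BatchNorm2d child has been replaced in place by a frozen copy.
print(type(net[1]).__name__)  # FrozenBatchNorm2d
```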
--------------------------------------------------------------------------------
/open_clip/version.py:
--------------------------------------------------------------------------------
1 | __version__ = '2.23.0'
2 |
--------------------------------------------------------------------------------
/open_clip/zero_shot_classifier.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from itertools import islice
3 | from typing import Callable, List, Optional, Sequence, Union
4 |
5 | import torch
6 | import torch.nn.functional as F
7 |
8 |
9 | def batched(iterable, n):
10 | """Batch data into lists of length *n*. The last batch may be shorter.
11 |     NOTE: based on the more-itertools implementation; to be replaced by Python 3.12's itertools.batched.
12 | """
13 | it = iter(iterable)
14 | while True:
15 | batch = list(islice(it, n))
16 | if not batch:
17 | break
18 | yield batch
19 |
20 |
21 | def build_zero_shot_classifier(
22 | model,
23 | tokenizer,
24 | classnames: Sequence[str],
25 | templates: Sequence[Union[Callable, str]],
26 | num_classes_per_batch: Optional[int] = 10,
27 | device: Union[str, torch.device] = 'cpu',
28 | use_tqdm: bool = False,
29 | ):
30 | """ Build zero-shot classifier weights by iterating over class names in batches
31 | Args:
32 | model: CLIP model instance
33 | tokenizer: CLIP tokenizer instance
34 | classnames: A sequence of class (label) names
35 | templates: A sequence of callables or format() friendly strings to produce templates per class name
36 | num_classes_per_batch: The number of classes to batch together in each forward, all if None
37 | device: Device to use.
38 | use_tqdm: Enable TQDM progress bar.
39 | """
40 | assert isinstance(templates, Sequence) and len(templates) > 0
41 | assert isinstance(classnames, Sequence) and len(classnames) > 0
42 | use_format = isinstance(templates[0], str)
43 | num_templates = len(templates)
44 | num_classes = len(classnames)
45 | if use_tqdm:
46 | import tqdm
47 | num_iter = 1 if num_classes_per_batch is None else ((num_classes - 1) // num_classes_per_batch + 1)
48 | iter_wrap = partial(tqdm.tqdm, total=num_iter, unit_scale=num_classes_per_batch)
49 | else:
50 | iter_wrap = iter
51 |
52 | def _process_batch(batch_classnames):
53 | num_batch_classes = len(batch_classnames)
54 | texts = [template.format(c) if use_format else template(c) for c in batch_classnames for template in templates]
55 | texts = tokenizer(texts).to(device)
56 | class_embeddings = model.encode_text(texts, normalize=True)
57 | class_embeddings = class_embeddings.reshape(num_batch_classes, num_templates, -1).mean(dim=1)
58 | class_embeddings = class_embeddings / class_embeddings.norm(dim=1, keepdim=True)
59 | class_embeddings = class_embeddings.T
60 | return class_embeddings
61 |
62 | with torch.no_grad():
63 | if num_classes_per_batch:
64 | batched_embeds = [_process_batch(batch) for batch in iter_wrap(batched(classnames, num_classes_per_batch))]
65 | zeroshot_weights = torch.cat(batched_embeds, dim=1)
66 | else:
67 | zeroshot_weights = _process_batch(classnames)
68 | return zeroshot_weights
69 |
70 |
71 | def build_zero_shot_classifier_legacy(
72 | model,
73 | tokenizer,
74 | classnames: Sequence[str],
75 | templates: Sequence[Union[Callable, str]],
76 | device: Union[str, torch.device] = 'cpu',
77 | use_tqdm: bool = False,
78 | ):
79 | """ Build zero-shot classifier weights by iterating over class names 1 by 1
80 | Args:
81 | model: CLIP model instance
82 | tokenizer: CLIP tokenizer instance
83 | classnames: A sequence of class (label) names
84 | templates: A sequence of callables or format() friendly strings to produce templates per class name
85 | device: Device to use.
86 | use_tqdm: Enable TQDM progress bar.
87 | """
88 | assert isinstance(templates, Sequence) and len(templates) > 0
89 | assert isinstance(classnames, Sequence) and len(classnames) > 0
90 | if use_tqdm:
91 | import tqdm
92 | iter_wrap = tqdm.tqdm
93 | else:
94 | iter_wrap = iter
95 |
96 | use_format = isinstance(templates[0], str)
97 |
98 | with torch.no_grad():
99 | zeroshot_weights = []
100 | for classname in iter_wrap(classnames):
101 | texts = [template.format(classname) if use_format else template(classname) for template in templates]
102 | texts = tokenizer(texts).to(device) # tokenize
103 | class_embeddings = model.encode_text(texts)
104 | class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0)
105 | class_embedding /= class_embedding.norm()
106 | zeroshot_weights.append(class_embedding)
107 | zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
108 |
109 | return zeroshot_weights
110 |
111 |
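A hedged end-to-end sketch of building zero-shot weights with the batched helper above. The model/tokenizer creation uses the standard open_clip factory API, which is assumed to be exported by this package as in upstream open_clip; the class names and templates are illustrative:

```python
import open_clip  # assumed package-level factory exports
from open_clip.zero_shot_classifier import build_zero_shot_classifier

model, _, _ = open_clip.create_model_and_transforms('ViT-B-32', pretrained=None)
tokenizer = open_clip.get_tokenizer('ViT-B-32')

classifier = build_zero_shot_classifier(
    model,
    tokenizer,
    classnames=['cat', 'dog'],
    templates=['a photo of a {}.', 'a blurry photo of a {}.'],
    num_classes_per_batch=10,
    device='cpu',
)
print(classifier.shape)  # (embed_dim, num_classes), e.g. torch.Size([512, 2])
```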
--------------------------------------------------------------------------------
/requirements-test.txt:
--------------------------------------------------------------------------------
1 | pytest-split==0.8.0
2 | pytest==7.2.0
3 | transformers
4 | timm>=0.9.8
5 |
--------------------------------------------------------------------------------
/requirements-training.txt:
--------------------------------------------------------------------------------
1 | torch>=1.9.0
2 | torchvision
3 | webdataset>=0.2.5
4 | regex
5 | ftfy
6 | tqdm
7 | pandas
8 | braceexpand
9 | huggingface_hub
10 | transformers
11 | timm>=0.9.8
12 | fsspec
13 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.9.0
2 | torchvision
3 | regex
4 | ftfy
5 | tqdm
6 | huggingface_hub
7 | sentencepiece
8 | protobuf
9 | timm
10 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | """ Setup
2 | """
3 | from setuptools import setup, find_packages
4 | from codecs import open
5 | from os import path
6 |
7 | here = path.abspath(path.dirname(__file__))
8 |
9 | # Get the long description from the README file
10 | with open(path.join(here, 'README.md'), encoding='utf-8') as f:
11 | long_description = f.read()
12 |
13 | def _read_reqs(relpath):
14 | fullpath = path.join(path.dirname(__file__), relpath)
15 | with open(fullpath) as f:
16 | return [s.strip() for s in f.readlines() if (s.strip() and not s.startswith("#"))]
17 |
18 | REQUIREMENTS = _read_reqs("requirements.txt")
19 | TRAINING_REQUIREMENTS = _read_reqs("requirements-training.txt")
20 |
21 | exec(open('src/open_clip/version.py').read())
22 | setup(
23 | name='open_clip_torch',
24 | version=__version__,
25 | description='OpenCLIP',
26 | long_description=long_description,
27 | long_description_content_type='text/markdown',
28 | url='https://github.com/mlfoundations/open_clip',
29 | author='',
30 | author_email='',
31 | classifiers=[
32 | # How mature is this project? Common values are
33 | # 3 - Alpha
34 | # 4 - Beta
35 | # 5 - Production/Stable
36 | 'Development Status :: 3 - Alpha',
37 | 'Intended Audience :: Education',
38 | 'Intended Audience :: Science/Research',
39 | 'License :: OSI Approved :: Apache Software License',
40 | 'Programming Language :: Python :: 3.7',
41 | 'Programming Language :: Python :: 3.8',
42 | 'Programming Language :: Python :: 3.9',
43 | 'Programming Language :: Python :: 3.10',
44 | 'Topic :: Scientific/Engineering',
45 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
46 | 'Topic :: Software Development',
47 | 'Topic :: Software Development :: Libraries',
48 | 'Topic :: Software Development :: Libraries :: Python Modules',
49 | ],
50 |
51 | # Note that this is a string of words separated by whitespace, not a list.
52 | keywords='CLIP pretrained',
53 | package_dir={'': 'src'},
54 | packages=find_packages(where='src'),
55 | include_package_data=True,
56 | install_requires=REQUIREMENTS,
57 | extras_require={
58 | "training": TRAINING_REQUIREMENTS,
59 | },
60 | python_requires='>=3.7',
61 | )
62 |
--------------------------------------------------------------------------------
/training/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ant-research/DreamLIP/b4de0a43b6c002033c02873f91a695ab449e464c/training/__init__.py
--------------------------------------------------------------------------------
/training/distributed.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | import torch.distributed as dist
5 | from datetime import timedelta
6 | try:
7 | import horovod.torch as hvd
8 | except ImportError:
9 | hvd = None
10 |
11 |
12 | def is_global_master(args):
13 | return args.rank == 0
14 |
15 |
16 | def is_local_master(args):
17 | return args.local_rank == 0
18 |
19 |
20 | def is_master(args, local=False):
21 | return is_local_master(args) if local else is_global_master(args)
22 |
23 |
24 | def is_using_horovod():
25 | # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set
26 | # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required...
27 | ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"]
28 | pmi_vars = ["PMI_RANK", "PMI_SIZE"]
29 | if all([var in os.environ for var in ompi_vars]) or all([var in os.environ for var in pmi_vars]):
30 | return True
31 | else:
32 | return False
33 |
34 |
35 | def is_using_distributed():
36 | if 'WORLD_SIZE' in os.environ:
37 | return int(os.environ['WORLD_SIZE']) > 1
38 | if 'SLURM_NTASKS' in os.environ:
39 | return int(os.environ['SLURM_NTASKS']) > 1
40 | return False
41 |
42 |
43 | def world_info_from_env():
44 | local_rank = 0
45 | for v in ('LOCAL_RANK', 'MPI_LOCALRANKID', 'SLURM_LOCALID', 'OMPI_COMM_WORLD_LOCAL_RANK'):
46 | if v in os.environ:
47 | local_rank = int(os.environ[v])
48 | break
49 | global_rank = 0
50 | for v in ('RANK', 'PMI_RANK', 'SLURM_PROCID', 'OMPI_COMM_WORLD_RANK'):
51 | if v in os.environ:
52 | global_rank = int(os.environ[v])
53 | break
54 | world_size = 1
55 | for v in ('WORLD_SIZE', 'PMI_SIZE', 'SLURM_NTASKS', 'OMPI_COMM_WORLD_SIZE'):
56 | if v in os.environ:
57 | world_size = int(os.environ[v])
58 | break
59 |
60 | return local_rank, global_rank, world_size
61 |
62 |
63 | def init_distributed_device(args):
64 | # Distributed training = training on more than one GPU.
65 | # Works in both single and multi-node scenarios.
66 | args.distributed = False
67 | args.world_size = 1
68 | args.rank = 0 # global rank
69 | args.local_rank = 0
70 | if args.horovod:
71 | assert hvd is not None, "Horovod is not installed"
72 | hvd.init()
73 | args.local_rank = int(hvd.local_rank())
74 | args.rank = hvd.rank()
75 | args.world_size = hvd.size()
76 | args.distributed = True
77 | os.environ['LOCAL_RANK'] = str(args.local_rank)
78 | os.environ['RANK'] = str(args.rank)
79 | os.environ['WORLD_SIZE'] = str(args.world_size)
80 | elif is_using_distributed():
81 | if 'SLURM_PROCID' in os.environ:
82 | # DDP via SLURM
83 | args.local_rank, args.rank, args.world_size = world_info_from_env()
84 | # SLURM var -> torch.distributed vars in case needed
85 | os.environ['LOCAL_RANK'] = str(args.local_rank)
86 | os.environ['RANK'] = str(args.rank)
87 | os.environ['WORLD_SIZE'] = str(args.world_size)
88 | torch.distributed.init_process_group(
89 | backend=args.dist_backend,
90 | init_method=args.dist_url,
91 | world_size=args.world_size,
92 | rank=args.rank,
93 | timeout=timedelta(days=1),
94 | )
95 | else:
96 | # DDP via torchrun, torch.distributed.launch
97 | args.local_rank, _, _ = world_info_from_env()
98 | torch.distributed.init_process_group(
99 | backend=args.dist_backend,
100 | init_method=args.dist_url,
101 | timeout=timedelta(days=1)
102 | )
103 | args.world_size = torch.distributed.get_world_size()
104 | args.rank = torch.distributed.get_rank()
105 | args.distributed = True
106 | print(f'Using torch distributed DDP: {args.distributed}')
107 | else:
108 | print(f'No torch distributed DDP: {args.distributed}')
109 | if torch.cuda.is_available():
110 | if args.distributed and not args.no_set_device_rank:
111 | device = 'cuda:%d' % args.local_rank
112 | else:
113 | device = 'cuda:0'
114 | torch.cuda.set_device(device)
115 | else:
116 | device = 'cpu'
117 | args.device = device
118 | device = torch.device(device)
119 | return device
120 |
121 |
122 | def broadcast_object(args, obj, src=0):
123 | # broadcast a pickle-able python object from rank-0 to all ranks
124 | if args.horovod:
125 | return hvd.broadcast_object(obj, root_rank=src)
126 | else:
127 | if args.rank == src:
128 | objects = [obj]
129 | else:
130 | objects = [None]
131 | dist.broadcast_object_list(objects, src=src)
132 | return objects[0]
133 |
134 |
135 | def all_gather_object(args, obj, dst=0):
136 | # gather a pickle-able python object across all ranks
137 | if args.horovod:
138 | return hvd.allgather_object(obj)
139 | else:
140 | objects = [None for _ in range(args.world_size)]
141 | dist.all_gather_object(objects, obj)
142 | return objects
143 |
--------------------------------------------------------------------------------
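A minimal usage sketch for the distributed helpers above (not part of the repo; it assumes they live in training/distributed.py as the project tree suggests). The argparse attributes are exactly the ones init_distributed_device reads; the backend and URL values are typical choices for a torchrun launch, not requirements.

import argparse

from training.distributed import init_distributed_device, world_info_from_env

args = argparse.Namespace(
    horovod=False,             # True only when launching with horovodrun
    dist_backend='nccl',       # assumed backend; 'gloo' is the usual CPU fallback
    dist_url='env://',         # torchrun/SLURM exports supply the rendezvous info
    no_set_device_rank=False,  # let the helper bind cuda:<local_rank>
)
device = init_distributed_device(args)   # fills args.rank/local_rank/world_size/distributed
print(world_info_from_env(), device)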
/training/file_utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import multiprocessing
4 | import subprocess
5 | import time
6 | import fsspec
7 | import torch
8 | from tqdm import tqdm
9 |
10 | def remote_sync_s3(local_dir, remote_dir):
11 | # skip epoch_latest which can change during sync.
12 | result = subprocess.run(["aws", "s3", "sync", local_dir, remote_dir, '--exclude', '*epoch_latest.pt'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
13 | if result.returncode != 0:
14 | logging.error(f"Error: Failed to sync with S3 bucket {result.stderr.decode('utf-8')}")
15 | return False
16 |
17 |     logging.info("Successfully synced with S3 bucket")
18 | return True
19 |
20 | def remote_sync_fsspec(local_dir, remote_dir):
21 | # FIXME currently this is slow and not recommended. Look into speeding up.
22 | a = fsspec.get_mapper(local_dir)
23 | b = fsspec.get_mapper(remote_dir)
24 |
25 | for k in a:
26 | # skip epoch_latest which can change during sync.
27 | if 'epoch_latest.pt' in k:
28 | continue
29 |
30 | logging.info(f'Attempting to sync {k}')
31 | if k in b and len(a[k]) == len(b[k]):
32 | logging.debug(f'Skipping remote sync for {k}.')
33 | continue
34 |
35 | try:
36 |             b[k] = a[k]
37 |             logging.info(f'Successful sync for {k}.')
38 | except Exception as e:
39 | logging.info(f'Error during remote sync for {k}: {e}')
40 | return False
41 |
42 | return True
43 |
44 | def remote_sync(local_dir, remote_dir, protocol):
45 | logging.info('Starting remote sync.')
46 | if protocol == 's3':
47 | return remote_sync_s3(local_dir, remote_dir)
48 | elif protocol == 'fsspec':
49 | return remote_sync_fsspec(local_dir, remote_dir)
50 | else:
51 | logging.error('Remote protocol not known')
52 | return False
53 |
54 | def keep_running_remote_sync(sync_every, local_dir, remote_dir, protocol):
55 | while True:
56 | time.sleep(sync_every)
57 | remote_sync(local_dir, remote_dir, protocol)
58 |
59 | def start_sync_process(sync_every, local_dir, remote_dir, protocol):
60 | p = multiprocessing.Process(target=keep_running_remote_sync, args=(sync_every, local_dir, remote_dir, protocol))
61 | return p
62 |
63 | # Note: we are not currently using this save function.
64 | def pt_save(pt_obj, file_path):
65 | of = fsspec.open(file_path, "wb")
66 | with of as f:
67 |         torch.save(pt_obj, f)
68 |
69 | def pt_load(file_path, map_location=None):
70 | if file_path.startswith('s3'):
71 | logging.info('Loading remote checkpoint, which may take a bit.')
72 | of = fsspec.open(file_path, "rb")
73 | with of as f:
74 | out = torch.load(f, map_location=map_location)
75 | return out
76 |
77 | def check_exists(file_path):
78 | try:
79 | with fsspec.open(file_path):
80 | pass
81 | except FileNotFoundError:
82 | return False
83 | return True
84 |
--------------------------------------------------------------------------------
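A brief sketch of how the helpers above are typically combined (the bucket and paths are hypothetical, and the S3 paths require s3fs plus valid credentials): spawn the background sync process, then read a checkpoint straight from the remote via fsspec.

from training.file_utils import check_exists, pt_load, start_sync_process

p = start_sync_process(
    sync_every=300,                    # seconds between remote syncs
    local_dir='/tmp/logs/run1',        # hypothetical local checkpoint dir
    remote_dir='s3://my-bucket/run1',  # hypothetical remote target
    protocol='s3',
)
p.start()  # start_sync_process only builds the Process; the caller starts it

if check_exists('s3://my-bucket/run1/epoch_5.pt'):
    state = pt_load('s3://my-bucket/run1/epoch_5.pt', map_location='cpu')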
/training/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 |
4 | def setup_logging(log_file, level, include_host=False):
5 | if include_host:
6 | import socket
7 | hostname = socket.gethostname()
8 | formatter = logging.Formatter(
9 | f'%(asctime)s | {hostname} | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S')
10 | else:
11 | formatter = logging.Formatter('%(asctime)s | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S')
12 |
13 | logging.root.setLevel(level)
14 | loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict]
15 | for logger in loggers:
16 | logger.setLevel(level)
17 |
18 | stream_handler = logging.StreamHandler()
19 | stream_handler.setFormatter(formatter)
20 | logging.root.addHandler(stream_handler)
21 |
22 | if log_file:
23 | file_handler = logging.FileHandler(filename=log_file)
24 | file_handler.setFormatter(formatter)
25 | logging.root.addHandler(file_handler)
26 |
27 |
--------------------------------------------------------------------------------
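Example call, as a sketch: send INFO-level records to stderr and to a file, tagging each record with the hostname for multi-node runs. The file name is arbitrary.

import logging

from training.logger import setup_logging

setup_logging('train.log', logging.INFO, include_host=True)
logging.info('logging configured')  # -> "<timestamp> | <hostname> | INFO | logging configured"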
/training/precision.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from contextlib import suppress
3 |
4 |
5 | def get_autocast(precision):
6 | if precision == 'amp':
7 | return torch.cuda.amp.autocast
8 | elif precision == 'amp_bfloat16' or precision == 'amp_bf16':
9 | # amp_bfloat16 is more stable than amp float16 for clip training
10 | return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16)
11 | else:
12 | return suppress
13 |
--------------------------------------------------------------------------------
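A short sketch of how get_autocast is meant to be used inside a train/eval step (the toy Linear model stands in for CLIP and a CUDA device is assumed). Unknown precision strings fall back to contextlib.suppress, i.e. a no-op context.

import torch

from training.precision import get_autocast

model = torch.nn.Linear(8, 2).cuda()       # toy stand-in for the CLIP model
images = torch.randn(4, 8, device='cuda')
autocast = get_autocast('amp_bf16')        # -> torch.cuda.amp.autocast(dtype=torch.bfloat16)
with autocast():
    out = model(images)                    # computed in bfloat16 where supported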
/training/profiler.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 | import open_clip
5 | import pandas as pd
6 | from torch.utils.flop_counter import FlopCounterMode
7 | try:
8 | import fvcore
9 | except ImportError:
10 | fvcore = None
11 |
12 | parser = argparse.ArgumentParser(description='OpenCLIP Profiler')
13 |
14 | # benchmark specific args
15 | parser.add_argument('--model', metavar='NAME', default='',
16 | help='model(s) to profile')
17 | parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
18 | help='Output csv file for results')
19 | parser.add_argument('--profiler', default='torch', type=str, choices=['torch', 'fvcore'])
20 | parser.add_argument('--batch-size', default=1, type=int, help='Batch size for profiling')
21 |
22 |
23 | def profile_fvcore(
24 | model,
25 | image_input_size=(3, 224, 224),
26 | text_input_size=(77,),
27 | batch_size=1,
28 | detailed=False,
29 | force_cpu=False
30 | ):
31 | if force_cpu:
32 | model = model.to('cpu')
33 | device, dtype = next(model.parameters()).device, next(model.parameters()).dtype
34 | example_image_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype)
35 | example_text_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64)
36 | fca = fvcore.nn.FlopCountAnalysis(model, (example_image_input, example_text_input))
37 | aca = fvcore.nn.ActivationCountAnalysis(model, (example_image_input, example_text_input))
38 | if detailed:
39 | fcs = fvcore.nn.flop_count_str(fca)
40 | print(fcs)
41 | return fca.total() / batch_size, aca.total() / batch_size
42 |
43 |
44 | def profile_fvcore_text(
45 | model,
46 | text_input_size=(77,),
47 | batch_size=1,
48 | detailed=False,
49 | force_cpu=False
50 | ):
51 | if force_cpu:
52 | model = model.to('cpu')
53 | device = next(model.parameters()).device
54 | example_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64)
55 | fca = fvcore.nn.FlopCountAnalysis(model, example_input)
56 | aca = fvcore.nn.ActivationCountAnalysis(model, example_input)
57 | if detailed:
58 | fcs = fvcore.nn.flop_count_str(fca)
59 | print(fcs)
60 | return fca.total() / batch_size, aca.total() / batch_size
61 |
62 |
63 | def profile_fvcore_image(
64 | model,
65 | image_input_size=(3, 224, 224),
66 | batch_size=1,
67 | detailed=False,
68 | force_cpu=False
69 | ):
70 | if force_cpu:
71 | model = model.to('cpu')
72 | device, dtype = next(model.parameters()).device, next(model.parameters()).dtype
73 | example_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype)
74 | fca = fvcore.nn.FlopCountAnalysis(model, example_input)
75 | aca = fvcore.nn.ActivationCountAnalysis(model, example_input)
76 | if detailed:
77 | fcs = fvcore.nn.flop_count_str(fca)
78 | print(fcs)
79 | return fca.total() / batch_size, aca.total() / batch_size
80 |
81 |
82 | def profile_torch_image(model, image_input_size, batch_size=1, force_cpu=False):
83 | """Profile the image encoder using torch.utils.flop_counter"""
84 | if force_cpu:
85 | model = model.to('cpu')
86 | device, dtype = next(model.parameters()).device, next(model.parameters()).dtype
87 | example_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype)
88 |
89 | flop_counter = FlopCounterMode()
90 | with flop_counter:
91 | model(example_input)
92 | total_flops = sum(flop_counter.get_flop_counts()['Global'].values())
93 | return total_flops / batch_size
94 |
95 |
96 | def profile_torch_text(model, text_input_size, batch_size=1, force_cpu=False):
97 | """Profile the text encoder using torch.utils.flop_counter"""
98 | if force_cpu:
99 | model = model.to('cpu')
100 | device = next(model.parameters()).device
101 | example_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64)
102 |
103 | flop_counter = FlopCounterMode()
104 | with flop_counter:
105 | model(example_input)
106 | total_flops = sum(flop_counter.get_flop_counts()['Global'].values())
107 | return total_flops / batch_size
108 |
109 |
110 | def profile_torch(model, text_input_size, image_input_size, batch_size=1, force_cpu=False):
111 | """Profile the full model using torch.utils.flop_counter"""
112 | if force_cpu:
113 | model = model.to('cpu')
114 | device, dtype = next(model.parameters()).device, next(model.parameters()).dtype
115 | image_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype)
116 | text_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64)
117 |
118 | flop_counter = FlopCounterMode()
119 | with flop_counter:
120 | model(image_input, text_input)
121 | total_flops = sum(flop_counter.get_flop_counts()['Global'].values())
122 | return total_flops / batch_size
123 |
124 |
125 | def count_params(model):
126 | return sum([m.numel() for m in model.parameters()])
127 |
128 | def profile_model(model_name, batch_size=1, profiler='torch'):
129 | assert profiler in ['torch', 'fvcore'], 'Only torch and fvcore profilers are supported'
130 | if profiler == 'fvcore':
131 | assert fvcore is not None, 'Please install fvcore.'
132 | model = open_clip.create_model(model_name, force_custom_text=True, pretrained_hf=False)
133 | model.eval()
134 | if torch.cuda.is_available():
135 | model = model.cuda()
136 |
137 | if isinstance(model.visual.image_size, (tuple, list)):
138 | image_input_size = (3,) + tuple(model.visual.image_size[-2:])
139 | else:
140 | image_input_size = (3, model.visual.image_size, model.visual.image_size)
141 |
142 | text_input_size = (77,)
143 | if hasattr(model, 'context_length') and model.context_length:
144 | text_input_size = (model.context_length,)
145 |
146 | results = {}
147 | results['model'] = model_name
148 | results['image_size'] = image_input_size[1]
149 |
150 | model_cfg = open_clip.get_model_config(model_name)
151 | if model_cfg:
152 | vision_cfg = open_clip.CLIPVisionCfg(**model_cfg['vision_cfg'])
153 | text_cfg = open_clip.CLIPTextCfg(**model_cfg['text_cfg'])
154 | results['image_width'] = int(vision_cfg.width)
155 | results['text_width'] = int(text_cfg.width)
156 | results['embed_dim'] = int(model_cfg['embed_dim'])
157 | else:
158 | results['image_width'] = 0
159 | results['text_width'] = 0
160 | results['embed_dim'] = 0
161 |
162 | retries = 2
163 | while retries:
164 | retries -= 1
165 | try:
166 | results['mparams'] = round(count_params(model) / 1e6, 2)
167 | results['image_mparams'] = round(count_params(model.visual) / 1e6, 2)
168 | results['text_mparams'] = round(count_params(model.text) / 1e6, 2)
169 |
170 | if profiler == 'fvcore':
171 | macs, acts = profile_fvcore(
172 | model, image_input_size=image_input_size, text_input_size=text_input_size, force_cpu=not retries, batch_size=batch_size)
173 |
174 | image_macs, image_acts = profile_fvcore_image(
175 | model.visual, image_input_size=image_input_size, force_cpu=not retries, batch_size=batch_size)
176 |
177 | text_macs, text_acts = profile_fvcore_text(
178 | model.text, text_input_size=text_input_size, force_cpu=not retries, batch_size=batch_size)
179 |
180 | results['gmacs'] = round(macs / 1e9, 2)
181 | results['macts'] = round(acts / 1e6, 2)
182 |
183 | results['image_gmacs'] = round(image_macs / 1e9, 2)
184 | results['image_macts'] = round(image_acts / 1e6, 2)
185 |
186 | results['text_gmacs'] = round(text_macs / 1e9, 2)
187 | results['text_macts'] = round(text_acts / 1e6, 2)
188 | elif profiler == 'torch':
189 | image_flops = profile_torch_image(
190 | model.visual, image_input_size=image_input_size, force_cpu=not retries, batch_size=batch_size)
191 | text_flops = profile_torch_text(
192 | model.text, text_input_size=text_input_size, force_cpu=not retries, batch_size=batch_size)
193 | total_flops = profile_torch(
194 | model, image_input_size=image_input_size, text_input_size=text_input_size, force_cpu=not retries, batch_size=batch_size)
195 |
196 | results['gflops'] = round(total_flops / 1e9, 2)
197 | results['image_gflops'] = round(image_flops / 1e9, 2)
198 | results['text_gflops'] = round(text_flops / 1e9, 2)
199 |             break  # profiling succeeded on this device; skip the CPU retry fallback
200 | except RuntimeError as e:
201 | pass
202 | return results
203 |
204 |
205 | def main():
206 | args = parser.parse_args()
207 |
208 | # FIXME accept a text file name to allow lists of models in txt/csv
209 | if args.model == 'all':
210 | parsed_model = open_clip.list_models()
211 | else:
212 | parsed_model = args.model.split(',')
213 |
214 | results = []
215 | models_with_errors = []
216 | for m in parsed_model:
217 | print('='*100)
218 | print(f'Profiling {m}')
219 | try:
220 | row = profile_model(m, batch_size=args.batch_size, profiler=args.profiler)
221 | results.append(row)
222 | except Exception as e:
223 | print(f'Error profiling {m}: {e}')
224 | import traceback
225 | traceback.print_exc()
226 | models_with_errors.append(m)
227 |
228 | df = pd.DataFrame(results, columns=results[0].keys())
229 |
230 | if 'gmacs' in df.columns:
231 | df = df.sort_values(by=['gmacs', 'mparams', 'model'])
232 | else:
233 | df = df.sort_values(by=['gflops', 'mparams', 'model'])
234 |
235 | print('='*100)
236 | print('Done.')
237 | print(df)
238 | if args.results_file:
239 | df.to_csv(args.results_file, index=False)
240 |
241 | if models_with_errors:
242 | print('Models with errors:', models_with_errors)
243 |
244 |
245 | if __name__ == '__main__':
246 | main()
247 |
--------------------------------------------------------------------------------
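Typical invocations of the profiler above, as a sketch (the module path assumes the script is run from the repo root with open_clip importable; fvcore is only needed for --profiler fvcore), plus the programmatic entry point.

# CLI, mirroring the argparse flags defined above:
#   python -m training.profiler --model ViT-B-32 --profiler torch
#   python -m training.profiler --model ViT-B-32,ViT-L-14 --profiler fvcore --results-file profile.csv

from training.profiler import profile_model

row = profile_model('ViT-B-32', batch_size=1, profiler='torch')  # dict of params and GFLOPs per tower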
/training/random_aug.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import torch
4 | from timm.data.constants import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
5 | from timm.data.transforms import RandomResizedCropAndInterpolation
6 | from torchvision import transforms
7 | import urllib
8 | from tqdm import tqdm
9 | from torch.utils.data import default_collate
10 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union
11 | from typing_extensions import TypedDict
12 | from numpy.typing import NDArray
13 | import importlib.machinery
14 | import importlib.util
15 | import types
16 | import random
17 |
18 |
19 |
20 |
21 | def pad(orig_items, key, max_length=None, padding_value=0, padding_side="left"):
22 | items = []
23 | if isinstance(orig_items[0][key], list):
24 | assert isinstance(orig_items[0][key][0], torch.Tensor)
25 | for it in orig_items:
26 | for tr in it[key]:
27 | items.append({key: tr})
28 | else:
29 | assert isinstance(orig_items[0][key], torch.Tensor)
30 | items = orig_items
31 |
32 | batch_size = len(items)
33 | shape = items[0][key].shape
34 | dim = len(shape)
35 | assert dim <= 3
36 | if max_length is None:
37 | max_length = 0
38 | max_length = max(max_length, max(item[key].shape[-1] for item in items))
39 | min_length = min(item[key].shape[-1] for item in items)
40 | dtype = items[0][key].dtype
41 |
42 | if dim == 1:
43 | return torch.cat([item[key] for item in items], dim=0)
44 | elif dim == 2:
45 | if max_length == min_length:
46 | return torch.cat([item[key] for item in items], dim=0)
47 | tensor = torch.zeros((batch_size, max_length), dtype=dtype) + padding_value
48 | else:
49 | tensor = torch.zeros((batch_size, max_length, shape[-1]), dtype=dtype) + padding_value
50 |
51 | for i, item in enumerate(items):
52 | if dim == 2:
53 | if padding_side == "left":
54 | tensor[i, -len(item[key][0]):] = item[key][0].clone()
55 | else:
56 | tensor[i, : len(item[key][0])] = item[key][0].clone()
57 | elif dim == 3:
58 | if padding_side == "left":
59 | tensor[i, -len(item[key][0]):, :] = item[key][0].clone()
60 | else:
61 | tensor[i, : len(item[key][0]), :] = item[key][0].clone()
62 |
63 | return tensor
64 |
65 |
66 |
67 | class _DictTree(TypedDict):
68 | value: str
69 | children: List["_DictTree"]
70 | depth: int
71 | segment_id: int
72 | need_predict: bool
73 | is_image: bool
74 |
75 |
76 | class _PrevExtTableStates(TypedDict):
77 | ext_table: Dict[int, str]
78 | token_id_table: Dict[str, Dict[int, int]]
79 |
80 |
81 | class _TransformFuncDict(TypedDict):
82 | loader: importlib.machinery.SourceFileLoader
83 | module: types.ModuleType
84 | last_m: float
85 |
86 |
87 |
88 | def rel_to_bucket(n_up: int, n_down: int, max_depth: int = 8):
89 | ret = n_up * max_depth + n_down
90 | if ret == 0:
91 | return ret
92 | else:
93 | # bucket 1 is reserved for incontext samples
94 | return ret + 1
95 |
96 |
97 | # aug functions
98 | def identity_func(img):
99 | return img
100 |
101 |
102 | def autocontrast_func(img, cutoff=0):
103 | '''
104 | same output as PIL.ImageOps.autocontrast
105 | '''
106 | n_bins = 256
107 |
108 | def tune_channel(ch):
109 | n = ch.size
110 | cut = cutoff * n // 100
111 | if cut == 0:
112 | high, low = ch.max(), ch.min()
113 | else:
114 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
115 | low = np.argwhere(np.cumsum(hist) > cut)
116 | low = 0 if low.shape[0] == 0 else low[0]
117 | high = np.argwhere(np.cumsum(hist[::-1]) > cut)
118 | high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
119 | if high <= low:
120 | table = np.arange(n_bins)
121 | else:
122 | scale = (n_bins - 1) / (high - low)
123 | table = np.arange(n_bins) * scale - low * scale
124 | table[table < 0] = 0
125 | table[table > n_bins - 1] = n_bins - 1
126 | table = table.clip(0, 255).astype(np.uint8)
127 | return table[ch]
128 |
129 | channels = [tune_channel(ch) for ch in cv2.split(img)]
130 | out = cv2.merge(channels)
131 | return out
132 |
133 |
134 | def equalize_func(img):
135 | '''
136 | same output as PIL.ImageOps.equalize
137 |     PIL's implementation is different from cv2.equalizeHist
138 | '''
139 | n_bins = 256
140 |
141 | def tune_channel(ch):
142 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
143 | non_zero_hist = hist[hist != 0].reshape(-1)
144 | step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
145 | if step == 0:
146 | return ch
147 | n = np.empty_like(hist)
148 | n[0] = step // 2
149 | n[1:] = hist[:-1]
150 | table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
151 | return table[ch]
152 |
153 | channels = [tune_channel(ch) for ch in cv2.split(img)]
154 | out = cv2.merge(channels)
155 | return out
156 |
157 |
158 | def rotate_func(img, degree, fill=(0, 0, 0)):
159 | '''
160 |     like PIL, rotate by degrees, not radians
161 | '''
162 | H, W = img.shape[0], img.shape[1]
163 | center = W / 2, H / 2
164 | M = cv2.getRotationMatrix2D(center, degree, 1)
165 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
166 | return out
167 |
168 |
169 | def solarize_func(img, thresh=128):
170 | '''
171 |     same output as PIL.ImageOps.solarize
172 | '''
173 | table = np.array([el if el < thresh else 255 - el for el in range(256)])
174 | table = table.clip(0, 255).astype(np.uint8)
175 | out = table[img]
176 | return out
177 |
178 |
179 | def color_func(img, factor):
180 | '''
181 | same output as PIL.ImageEnhance.Color
182 | '''
183 | # implementation according to PIL definition, quite slow
184 | # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
185 | # out = blend(degenerate, img, factor)
186 | # M = (
187 | # np.eye(3) * factor
188 | # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
189 | # )[np.newaxis, np.newaxis, :]
190 | M = (
191 | np.float32([
192 | [0.886, -0.114, -0.114],
193 | [-0.587, 0.413, -0.587],
194 | [-0.299, -0.299, 0.701]]) * factor
195 | + np.float32([[0.114], [0.587], [0.299]])
196 | )
197 | out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
198 | return out
199 |
200 |
201 | def contrast_func(img, factor):
202 | """
203 | same output as PIL.ImageEnhance.Contrast
204 | """
205 | mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
206 | table = np.array([(
207 | el - mean) * factor + mean
208 | for el in range(256)
209 | ]).clip(0, 255).astype(np.uint8)
210 | out = table[img]
211 | return out
212 |
213 |
214 | def brightness_func(img, factor):
215 | '''
216 |     same output as PIL.ImageEnhance.Brightness
217 | '''
218 | table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
219 | out = table[img]
220 | return out
221 |
222 |
223 | def sharpness_func(img, factor):
224 | '''
225 |     The differences between this result and PIL are all on the 4 boundaries; the
226 |     center areas are the same
227 | '''
228 | kernel = np.ones((3, 3), dtype=np.float32)
229 | kernel[1][1] = 5
230 | kernel /= 13
231 | degenerate = cv2.filter2D(img, -1, kernel)
232 | if factor == 0.0:
233 | out = degenerate
234 | elif factor == 1.0:
235 | out = img
236 | else:
237 | out = img.astype(np.float32)
238 | degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
239 | out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
240 | out = out.astype(np.uint8)
241 | return out
242 |
243 |
244 | def shear_x_func(img, factor, fill=(0, 0, 0)):
245 | H, W = img.shape[0], img.shape[1]
246 | M = np.float32([[1, factor, 0], [0, 1, 0]])
247 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
248 | return out
249 |
250 |
251 | def translate_x_func(img, offset, fill=(0, 0, 0)):
252 | '''
253 | same output as PIL.Image.transform
254 | '''
255 | H, W = img.shape[0], img.shape[1]
256 | M = np.float32([[1, 0, -offset], [0, 1, 0]])
257 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
258 | return out
259 |
260 |
261 | def translate_y_func(img, offset, fill=(0, 0, 0)):
262 | '''
263 | same output as PIL.Image.transform
264 | '''
265 | H, W = img.shape[0], img.shape[1]
266 | M = np.float32([[1, 0, 0], [0, 1, -offset]])
267 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
268 | return out
269 |
270 |
271 | def posterize_func(img, bits):
272 | '''
273 | same output as PIL.ImageOps.posterize
274 | '''
275 | out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
276 | return out
277 |
278 |
279 | def shear_y_func(img, factor, fill=(0, 0, 0)):
280 | H, W = img.shape[0], img.shape[1]
281 | M = np.float32([[1, 0, 0], [factor, 1, 0]])
282 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
283 | return out
284 |
285 |
286 | def cutout_func(img, pad_size, replace=(0, 0, 0)):
287 | replace = np.array(replace, dtype=np.uint8)
288 | H, W = img.shape[0], img.shape[1]
289 | rh, rw = np.random.random(2)
290 | pad_size = pad_size // 2
291 | ch, cw = int(rh * H), int(rw * W)
292 | x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
293 | y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
294 | out = img.copy()
295 | out[x1:x2, y1:y2, :] = replace
296 | return out
297 |
298 |
299 | # level to args
300 | def enhance_level_to_args(MAX_LEVEL):
301 | def level_to_args(level):
302 | return ((level / MAX_LEVEL) * 1.8 + 0.1,)
303 | return level_to_args
304 |
305 |
306 | def shear_level_to_args(MAX_LEVEL, replace_value):
307 | def level_to_args(level):
308 | level = (level / MAX_LEVEL) * 0.3
309 | if np.random.random() > 0.5:
310 | level = -level
311 | return (level, replace_value)
312 |
313 | return level_to_args
314 |
315 |
316 | def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
317 | def level_to_args(level):
318 | level = (level / MAX_LEVEL) * float(translate_const)
319 | if np.random.random() > 0.5:
320 | level = -level
321 | return (level, replace_value)
322 |
323 | return level_to_args
324 |
325 |
326 | def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
327 | def level_to_args(level):
328 | level = int((level / MAX_LEVEL) * cutout_const)
329 | return (level, replace_value)
330 |
331 | return level_to_args
332 |
333 |
334 | def solarize_level_to_args(MAX_LEVEL):
335 | def level_to_args(level):
336 | level = int((level / MAX_LEVEL) * 256)
337 | return (level, )
338 | return level_to_args
339 |
340 |
341 | def none_level_to_args(level):
342 | return ()
343 |
344 |
345 | def posterize_level_to_args(MAX_LEVEL):
346 | def level_to_args(level):
347 | level = int((level / MAX_LEVEL) * 4)
348 | return (level, )
349 | return level_to_args
350 |
351 |
352 | def rotate_level_to_args(MAX_LEVEL, replace_value):
353 | def level_to_args(level):
354 | level = (level / MAX_LEVEL) * 30
355 | if np.random.random() < 0.5:
356 | level = -level
357 | return (level, replace_value)
358 |
359 | return level_to_args
360 |
361 |
362 | func_dict = {
363 | 'Identity': identity_func,
364 | 'AutoContrast': autocontrast_func,
365 | 'Equalize': equalize_func,
366 | 'Rotate': rotate_func,
367 | 'Solarize': solarize_func,
368 | 'Color': color_func,
369 | 'Contrast': contrast_func,
370 | 'Brightness': brightness_func,
371 | 'Sharpness': sharpness_func,
372 | 'ShearX': shear_x_func,
373 | 'TranslateX': translate_x_func,
374 | 'TranslateY': translate_y_func,
375 | 'Posterize': posterize_func,
376 | 'ShearY': shear_y_func,
377 | }
378 |
379 | translate_const = 10
380 | MAX_LEVEL = 10
381 | replace_value = (128, 128, 128)
382 | arg_dict = {
383 | 'Identity': none_level_to_args,
384 | 'AutoContrast': none_level_to_args,
385 | 'Equalize': none_level_to_args,
386 | 'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value),
387 | 'Solarize': solarize_level_to_args(MAX_LEVEL),
388 | 'Color': enhance_level_to_args(MAX_LEVEL),
389 | 'Contrast': enhance_level_to_args(MAX_LEVEL),
390 | 'Brightness': enhance_level_to_args(MAX_LEVEL),
391 | 'Sharpness': enhance_level_to_args(MAX_LEVEL),
392 | 'ShearX': shear_level_to_args(MAX_LEVEL, replace_value),
393 | 'TranslateX': translate_level_to_args(
394 | translate_const, MAX_LEVEL, replace_value
395 | ),
396 | 'TranslateY': translate_level_to_args(
397 | translate_const, MAX_LEVEL, replace_value
398 | ),
399 | 'Posterize': posterize_level_to_args(MAX_LEVEL),
400 | 'ShearY': shear_level_to_args(MAX_LEVEL, replace_value),
401 | }
402 |
403 |
404 | class RandomAugment(object):
405 |
406 |     def __init__(self, N=2, M=10, isPIL=False, augs=None):
407 | self.N = N
408 | self.M = M
409 | self.isPIL = isPIL
410 | if augs:
411 | self.augs = augs
412 | else:
413 | self.augs = list(arg_dict.keys())
414 |
415 | def get_random_ops(self):
416 | sampled_ops = np.random.choice(self.augs, self.N)
417 | return [(op, 0.5, self.M) for op in sampled_ops]
418 |
419 | def __call__(self, img):
420 | if self.isPIL:
421 | img = np.array(img)
422 | ops = self.get_random_ops()
423 | for name, prob, level in ops:
424 | if np.random.random() > prob:
425 | continue
426 | args = arg_dict[name](level)
427 | img = func_dict[name](img, *args)
428 | return img
429 |
430 |
--------------------------------------------------------------------------------
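A sketch of plugging RandomAugment into a torchvision preprocessing pipeline: inputs are PIL images (hence isPIL=True) and the op names must be keys of func_dict/arg_dict above; the crop size and op list are illustrative choices.

from torchvision import transforms

from training.random_aug import RandomAugment

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    RandomAugment(N=2, M=9, isPIL=True,
                  augs=['Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness']),
    transforms.ToTensor(),  # RandomAugment returns a numpy array, which ToTensor accepts
])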
/training/scheduler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def assign_learning_rate(optimizer, new_lr):
5 | for param_group in optimizer.param_groups:
6 | param_group["lr"] = new_lr
7 |
8 |
9 | def _warmup_lr(base_lr, warmup_length, step):
10 | return base_lr * (step + 1) / warmup_length
11 |
12 |
13 | def const_lr(optimizer, base_lr, warmup_length, steps):
14 | def _lr_adjuster(step):
15 | if step < warmup_length:
16 | lr = _warmup_lr(base_lr, warmup_length, step)
17 | else:
18 | lr = base_lr
19 | assign_learning_rate(optimizer, lr)
20 | return lr
21 | return _lr_adjuster
22 |
23 |
24 | def const_lr_cooldown(optimizer, base_lr, warmup_length, steps, cooldown_steps, cooldown_power=1.0, cooldown_end_lr=0.):
25 | def _lr_adjuster(step):
26 | start_cooldown_step = steps - cooldown_steps
27 | if step < warmup_length:
28 | lr = _warmup_lr(base_lr, warmup_length, step)
29 | else:
30 | if step < start_cooldown_step:
31 | lr = base_lr
32 | else:
33 | e = step - start_cooldown_step
34 | es = steps - start_cooldown_step
35 | # linear decay if power == 1; polynomial decay otherwise;
36 | decay = (1 - (e/es)) ** cooldown_power
37 | lr = decay * (base_lr - cooldown_end_lr) + cooldown_end_lr
38 | assign_learning_rate(optimizer, lr)
39 | return lr
40 | return _lr_adjuster
41 |
42 |
43 | def cosine_lr(optimizer, base_lr, warmup_length, steps):
44 | def _lr_adjuster(step):
45 | if step < warmup_length:
46 | lr = _warmup_lr(base_lr, warmup_length, step)
47 | else:
48 | e = step - warmup_length
49 | es = steps - warmup_length
50 | lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
51 | assign_learning_rate(optimizer, lr)
52 | return lr
53 | return _lr_adjuster
54 |
--------------------------------------------------------------------------------
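A sketch of the intended call pattern for the schedulers above: build the adjuster once, then call it with the global step before each optimizer update (the single-parameter optimizer is a stand-in).

import torch

from training.scheduler import cosine_lr

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=5e-4)
total_steps, warmup = 10_000, 500
scheduler = cosine_lr(optimizer, base_lr=5e-4, warmup_length=warmup, steps=total_steps)

for step in range(total_steps):
    lr = scheduler(step)  # linear warmup for `warmup` steps, then cosine decay toward 0
    # ...forward/backward...
    optimizer.step()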
/training/zero_shot.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import torch
4 | from tqdm import tqdm
5 |
6 | from open_clip import get_input_dtype, get_tokenizer, build_zero_shot_classifier, \
7 | IMAGENET_CLASSNAMES, OPENAI_IMAGENET_TEMPLATES
8 | from .precision import get_autocast
9 |
10 |
11 | def accuracy(output, target, topk=(1,)):
12 | pred = output.topk(max(topk), 1, True, True)[1].t()
13 | correct = pred.eq(target.view(1, -1).expand_as(pred))
14 | return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk]
15 |
16 |
17 | def run(model, classifier, dataloader, args):
18 | autocast = get_autocast(args.precision)
19 | input_dtype = get_input_dtype(args.precision)
20 |
21 | with torch.no_grad():
22 | top1, top5, n = 0., 0., 0.
23 | for images, target in tqdm(dataloader, unit_scale=args.batch_size):
24 | images = images.to(device=args.device, dtype=input_dtype)
25 | target = target.to(args.device)
26 |
27 | with autocast():
28 | # predict
29 | output = model(image=images)
30 | image_features = output['image_features'] if isinstance(output, dict) else output[0]
31 | logits = 100. * image_features @ classifier
32 |
33 | # measure accuracy
34 | acc1, acc5 = accuracy(logits, target, topk=(1, 5))
35 | top1 += acc1
36 | top5 += acc5
37 | n += images.size(0)
38 |
39 | top1 = (top1 / n)
40 | top5 = (top5 / n)
41 | return top1, top5
42 |
43 |
44 | def zero_shot_eval(model, data, epoch, args, tokenizer=None):
45 | if 'imagenet-val' not in data and 'imagenet-v2' not in data:
46 | return {}
47 | if args.zeroshot_frequency == 0:
48 | return {}
49 | if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs:
50 | return {}
51 | if args.distributed and not args.horovod:
52 | model = model.module
53 |
54 | logging.info('Starting zero-shot imagenet.')
55 | if tokenizer is None:
56 | tokenizer = get_tokenizer(args.model)
57 |
58 | logging.info('Building zero-shot classifier')
59 | autocast = get_autocast(args.precision)
60 | with autocast():
61 | classifier = build_zero_shot_classifier(
62 | model,
63 | tokenizer=tokenizer,
64 | classnames=IMAGENET_CLASSNAMES,
65 | templates=OPENAI_IMAGENET_TEMPLATES,
66 | num_classes_per_batch=10,
67 | device=args.device,
68 | use_tqdm=True,
69 | )
70 |
71 | logging.info('Using classifier')
72 | results = {}
73 | if 'imagenet-val' in data:
74 | top1, top5 = run(model, classifier, data['imagenet-val'].dataloader, args)
75 | results['imagenet-zeroshot-val-top1'] = top1
76 | results['imagenet-zeroshot-val-top5'] = top5
77 | if 'imagenet-v2' in data:
78 | top1, top5 = run(model, classifier, data['imagenet-v2'].dataloader, args)
79 | results['imagenetv2-zeroshot-val-top1'] = top1
80 | results['imagenetv2-zeroshot-val-top5'] = top5
81 |
82 | logging.info('Finished zero-shot imagenet.')
83 |
84 | return results
85 |
--------------------------------------------------------------------------------