├── .gitignore ├── LICENSE.md ├── README.md ├── assets ├── 4kpro_results.png ├── example_selection_maps │ ├── bottom_up_selection_prob.png │ ├── top_down_selection_prob_1.png │ └── top_down_selection_prob_2.png ├── teaser.png ├── test_images │ └── dock.jpg └── vila_hd_results.png ├── ps3 ├── __init__.py ├── configuration_ps3.py ├── image_processing_ps3.py ├── modeling_ps3.py ├── modeling_ps3_text.py └── tokenization_ps3.py ├── pyproject.toml └── train ├── .gitignore ├── CITATION.cff ├── HISTORY.md ├── Makefile ├── README.md ├── pyproject.toml ├── pytest.ini ├── requirements-test.txt ├── requirements-training.txt ├── requirements.txt ├── scripts ├── ps3_1.5k_siglip.sh └── ps3_4k_siglip.sh └── src ├── open_clip ├── __init__.py ├── bpe_simple_vocab_16e6.txt.gz ├── coca_model.py ├── constants.py ├── convert.py ├── factory.py ├── hf_configs.py ├── hf_model.py ├── loss.py ├── model.py ├── model_configs │ ├── EVA01-g-14-plus.json │ ├── EVA01-g-14.json │ ├── EVA02-B-16.json │ ├── EVA02-E-14-plus.json │ ├── EVA02-E-14.json │ ├── EVA02-L-14-336.json │ ├── EVA02-L-14.json │ ├── MobileCLIP-B.json │ ├── MobileCLIP-S1.json │ ├── MobileCLIP-S2.json │ ├── PS3-1.5K-C-RADIO-v2-L.json │ ├── PS3-1.5K-SigLIP.json │ ├── PS3-1.5K-SigLIP2.json │ ├── PS3-4K-C-RADIO-v2-L.json │ ├── PS3-4K-SigLIP.json │ ├── RN101-quickgelu.json │ ├── RN101.json │ ├── RN50-quickgelu.json │ ├── RN50.json │ ├── RN50x16-quickgelu.json │ ├── RN50x16.json │ ├── RN50x4-quickgelu.json │ ├── RN50x4.json │ ├── RN50x64-quickgelu.json │ ├── RN50x64.json │ ├── ViT-B-16-SigLIP-256.json │ ├── ViT-B-16-SigLIP-384.json │ ├── ViT-B-16-SigLIP-512.json │ ├── ViT-B-16-SigLIP-i18n-256.json │ ├── ViT-B-16-SigLIP.json │ ├── ViT-B-16-SigLIP2-256.json │ ├── ViT-B-16-SigLIP2-384.json │ ├── ViT-B-16-SigLIP2-512.json │ ├── ViT-B-16-SigLIP2.json │ ├── ViT-B-16-plus-240.json │ ├── ViT-B-16-plus.json │ ├── ViT-B-16-quickgelu.json │ ├── ViT-B-16.json │ ├── ViT-B-32-256.json │ ├── ViT-B-32-SigLIP2-256.json │ ├── ViT-B-32-plus-256.json │ ├── ViT-B-32-quickgelu.json │ ├── ViT-B-32.json │ ├── ViT-H-14-378-quickgelu.json │ ├── ViT-H-14-378.json │ ├── ViT-H-14-CLIPA-336.json │ ├── ViT-H-14-CLIPA.json │ ├── ViT-H-14-quickgelu.json │ ├── ViT-H-14.json │ ├── ViT-H-16.json │ ├── ViT-L-14-280.json │ ├── ViT-L-14-336-quickgelu.json │ ├── ViT-L-14-336.json │ ├── ViT-L-14-CLIPA-336-S2-672.json │ ├── ViT-L-14-CLIPA-336-S3-672.json │ ├── ViT-L-14-CLIPA-336.json │ ├── ViT-L-14-CLIPA.json │ ├── ViT-L-14-quickgelu.json │ ├── ViT-L-14.json │ ├── ViT-L-16-320.json │ ├── ViT-L-16-SigLIP-256.json │ ├── ViT-L-16-SigLIP-384.json │ ├── ViT-L-16-SigLIP2-256.json │ ├── ViT-L-16-SigLIP2-384.json │ ├── ViT-L-16-SigLIP2-512.json │ ├── ViT-L-16.json │ ├── ViT-M-16-alt.json │ ├── ViT-M-16.json │ ├── ViT-M-32-alt.json │ ├── ViT-M-32.json │ ├── ViT-S-16-alt.json │ ├── ViT-S-16.json │ ├── ViT-S-32-alt.json │ ├── ViT-S-32.json │ ├── ViT-SO400M-14-SigLIP-378.json │ ├── ViT-SO400M-14-SigLIP-384.json │ ├── ViT-SO400M-14-SigLIP.json │ ├── ViT-SO400M-14-SigLIP2-378.json │ ├── ViT-SO400M-14-SigLIP2.json │ ├── ViT-SO400M-16-SigLIP-i18n-256.json │ ├── ViT-SO400M-16-SigLIP2-256.json │ ├── ViT-SO400M-16-SigLIP2-384.json │ ├── ViT-SO400M-16-SigLIP2-512.json │ ├── ViT-bigG-14-CLIPA-336.json │ ├── ViT-bigG-14-CLIPA.json │ ├── ViT-bigG-14-quickgelu.json │ ├── ViT-bigG-14.json │ ├── ViT-e-14.json │ ├── ViT-g-14.json │ ├── ViT-gopt-16-SigLIP2-256.json │ ├── ViT-gopt-16-SigLIP2-384.json │ ├── ViTamin-B-LTT.json │ ├── ViTamin-B.json │ ├── ViTamin-L-256.json │ ├── ViTamin-L-336.json │ ├── ViTamin-L-384.json │ ├── 
ViTamin-L.json │ ├── ViTamin-L2-256.json │ ├── ViTamin-L2-336.json │ ├── ViTamin-L2-384.json │ ├── ViTamin-L2.json │ ├── ViTamin-S-LTT.json │ ├── ViTamin-S.json │ ├── ViTamin-XL-256.json │ ├── ViTamin-XL-336.json │ ├── ViTamin-XL-384.json │ ├── coca_ViT-B-32.json │ ├── coca_ViT-L-14.json │ ├── coca_base.json │ ├── coca_roberta-ViT-B-32.json │ ├── convnext_base.json │ ├── convnext_base_w.json │ ├── convnext_base_w_320.json │ ├── convnext_large.json │ ├── convnext_large_d.json │ ├── convnext_large_d_320.json │ ├── convnext_small.json │ ├── convnext_tiny.json │ ├── convnext_xlarge.json │ ├── convnext_xxlarge.json │ ├── convnext_xxlarge_320.json │ ├── mt5-base-ViT-B-32.json │ ├── mt5-xl-ViT-H-14.json │ ├── nllb-clip-base-siglip.json │ ├── nllb-clip-base.json │ ├── nllb-clip-large-siglip.json │ ├── nllb-clip-large.json │ ├── roberta-ViT-B-32.json │ ├── swin_base_patch4_window7_224.json │ ├── vit_medium_patch16_gap_256.json │ ├── vit_relpos_medium_patch16_cls_224.json │ ├── xlm-roberta-base-ViT-B-32.json │ └── xlm-roberta-large-ViT-H-14.json ├── modified_resnet.py ├── openai.py ├── pos_embed.py ├── pretrained.py ├── push_to_hf_hub.py ├── save_ps3_hf_ckpt.py ├── timm_model.py ├── tokenizer.py ├── transform.py ├── transformer.py ├── utils.py ├── version.py ├── zero_shot_classifier.py └── zero_shot_metadata.py └── open_clip_train ├── __init__.py ├── data.py ├── distributed.py ├── file_utils.py ├── logger.py ├── main.py ├── params.py ├── precision.py ├── profiler.py ├── scheduler.py ├── train.py └── zero_shot.py /.gitignore: -------------------------------------------------------------------------------- 1 | dist/ 2 | 3 | tests 4 | output 5 | **__pycache__** 6 | **egg-info** 7 | 8 | LICENSE_for_PS3.md 9 | MODEL_CARD* 10 | 11 | **radio_adapter_mlp** 12 | -------------------------------------------------------------------------------- /assets/4kpro_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/PS3/529d66aa30b26438c30ce839197b2c0b2cf765e5/assets/4kpro_results.png -------------------------------------------------------------------------------- /assets/example_selection_maps/bottom_up_selection_prob.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/PS3/529d66aa30b26438c30ce839197b2c0b2cf765e5/assets/example_selection_maps/bottom_up_selection_prob.png -------------------------------------------------------------------------------- /assets/example_selection_maps/top_down_selection_prob_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/PS3/529d66aa30b26438c30ce839197b2c0b2cf765e5/assets/example_selection_maps/top_down_selection_prob_1.png -------------------------------------------------------------------------------- /assets/example_selection_maps/top_down_selection_prob_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/PS3/529d66aa30b26438c30ce839197b2c0b2cf765e5/assets/example_selection_maps/top_down_selection_prob_2.png -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/PS3/529d66aa30b26438c30ce839197b2c0b2cf765e5/assets/teaser.png -------------------------------------------------------------------------------- 
/assets/test_images/dock.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/PS3/529d66aa30b26438c30ce839197b2c0b2cf765e5/assets/test_images/dock.jpg -------------------------------------------------------------------------------- /assets/vila_hd_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/PS3/529d66aa30b26438c30ce839197b2c0b2cf765e5/assets/vila_hd_results.png -------------------------------------------------------------------------------- /ps3/__init__.py: -------------------------------------------------------------------------------- 1 | from .modeling_ps3 import PS3VisionEncoder, PS3VisionModel, PS3TextModel, PS3Model 2 | from .image_processing_ps3 import PS3ImageProcessor 3 | from .configuration_ps3 import PS3VisionConfig, PS3TextConfig, PS3Config 4 | from .tokenization_ps3 import PS3Tokenizer -------------------------------------------------------------------------------- /ps3/configuration_ps3.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | import os 17 | from typing import Union 18 | from packaging import version 19 | 20 | import transformers 21 | from transformers.configuration_utils import PretrainedConfig 22 | from transformers.utils import logging 23 | 24 | 25 | logger = logging.get_logger(__name__) 26 | 27 | 28 | class PS3VisionConfig(PretrainedConfig): 29 | 30 | model_type = "ps3_vision_model" 31 | base_config_key = "vision_config" 32 | 33 | def __init__( 34 | self, 35 | # timm model args 36 | model_name: str = None, 37 | hidden_size: int = 1152, 38 | pool: str = 'avg', 39 | drop_path: float = None, 40 | patch_drop: float = None, 41 | pretrained: bool = False, 42 | dynamic_img_size: bool = True, 43 | # ps3 args 44 | ps3_scales: list[int] = [378, 756, 1512], 45 | select_based_on_layer: list[int] = [0, 9, 18, 26], 46 | max_select_num: int = 1280, 47 | max_select_num_each_scale: list[int] = None, 48 | separate_pos_emb: bool = True, 49 | highres_selection_feature: bool = True, 50 | highres_selection_module_hidden_dim: int = 512, 51 | highres_selection_module_out_dim: int = 512, 52 | highres_selection_module_depth: int = 3, 53 | highres_selection_module_kernel_size: int = 28, 54 | # radio args 55 | radio: bool = False, 56 | radio_adapter_mlp_version: str = None, 57 | radio_adapter_mlp_input_dim: int = None, 58 | radio_adapter_mlp_hidden_dim: int = None, 59 | radio_adapter_mlp_output_dim: int = None, 60 | radio_adapter_mlp_num_inner: int = None, 61 | img_size: int = None, 62 | drop: float = 0.0, 63 | class_token: bool = None, 64 | final_norm: bool = False, 65 | **kwargs, 66 | ): 67 | super().__init__(**kwargs) 68 | 69 | self.model_name = model_name 70 | self.hidden_size = hidden_size 71 | self.pool = pool 72 | self.drop_path = drop_path 73 | self.patch_drop = patch_drop 74 | self.pretrained = pretrained 75 | self.dynamic_img_size = dynamic_img_size 76 | self.ps3_scales = ps3_scales 77 | self.select_based_on_layer = select_based_on_layer 78 | self.max_select_num = max_select_num 79 | self.max_select_num_each_scale = max_select_num_each_scale 80 | self.separate_pos_emb = separate_pos_emb 81 | self.highres_selection_feature = highres_selection_feature 82 | self.highres_selection_module_hidden_dim = highres_selection_module_hidden_dim 83 | self.highres_selection_module_out_dim = highres_selection_module_out_dim 84 | self.highres_selection_module_depth = highres_selection_module_depth 85 | self.highres_selection_module_kernel_size = highres_selection_module_kernel_size 86 | self.radio = radio 87 | self.radio_adapter_mlp_version = radio_adapter_mlp_version 88 | self.radio_adapter_mlp_input_dim = radio_adapter_mlp_input_dim 89 | self.radio_adapter_mlp_hidden_dim = radio_adapter_mlp_hidden_dim 90 | self.radio_adapter_mlp_output_dim = radio_adapter_mlp_output_dim 91 | self.radio_adapter_mlp_num_inner = radio_adapter_mlp_num_inner 92 | self.img_size = img_size 93 | self.drop = drop 94 | self.class_token = class_token 95 | self.final_norm = final_norm 96 | 97 | # Dummy config to make vila training code happy 98 | self.vision_tower_name = model_name 99 | self.image_size = ps3_scales[-1] 100 | self.patch_size = 14 101 | 102 | @classmethod 103 | def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": 104 | if version.parse(version.parse(transformers.__version__).base_version) >= version.parse("4.47.0"): 105 | return super().from_pretrained(pretrained_model_name_or_path, **kwargs) 106 | 107 | cls._set_token_in_kwargs(kwargs) 108 | 109 | config_dict, kwargs = 
cls.get_config_dict(pretrained_model_name_or_path, **kwargs) 110 | 111 | # get the vision config dict if we are loading from SiglipConfig 112 | if config_dict.get("model_type") == "ps3": 113 | config_dict = config_dict["vision_config"] 114 | 115 | if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: 116 | logger.warning( 117 | f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " 118 | f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." 119 | ) 120 | 121 | return cls.from_dict(config_dict, **kwargs) 122 | 123 | 124 | class PS3TextConfig(PretrainedConfig): 125 | 126 | model_type = "ps3_text_model" 127 | base_config_key = "text_config" 128 | 129 | def __init__( 130 | self, 131 | output_dim: int = 1152, 132 | prompt_proj_dim: int = 1152, 133 | context_length: int = 77, 134 | vocab_size: int = 49408, 135 | hf_tokenizer_name: str = None, 136 | tokenizer_kwargs: dict = None, 137 | width: int = 512, 138 | heads: int = 8, 139 | layers: int = 12, 140 | mlp_ratio: float = 4.0, 141 | ls_init_value: float = None, # layer scale initial value 142 | embed_cls: bool = False, 143 | pad_id: int = 0, 144 | no_causal_mask: bool = False, # disable causal masking 145 | final_ln_after_pool: bool = False, # apply final LayerNorm after pooling 146 | pool_type: str = 'argmax', 147 | proj_bias: bool = False, 148 | output_tokens: bool = False, 149 | act_kwargs: dict = {}, 150 | norm_kwargs: dict = {}, 151 | **kwargs 152 | ): 153 | super().__init__(**kwargs) 154 | 155 | self.output_dim = output_dim 156 | self.prompt_proj_dim = prompt_proj_dim 157 | self.context_length = context_length 158 | self.vocab_size = vocab_size 159 | self.hf_tokenizer_name = hf_tokenizer_name 160 | self.tokenizer_kwargs = tokenizer_kwargs 161 | self.width = width 162 | self.heads = heads 163 | self.layers = layers 164 | self.mlp_ratio = mlp_ratio 165 | self.ls_init_value = ls_init_value 166 | self.embed_cls = embed_cls 167 | self.pad_id = pad_id 168 | self.no_causal_mask = no_causal_mask 169 | self.final_ln_after_pool = final_ln_after_pool 170 | self.pool_type = pool_type 171 | self.proj_bias = proj_bias 172 | self.output_tokens = output_tokens 173 | self.act_kwargs = act_kwargs 174 | self.norm_kwargs = norm_kwargs 175 | 176 | @classmethod 177 | def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": 178 | if version.parse(version.parse(transformers.__version__).base_version) >= version.parse("4.47.0"): 179 | return super().from_pretrained(pretrained_model_name_or_path, **kwargs) 180 | 181 | cls._set_token_in_kwargs(kwargs) 182 | 183 | config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) 184 | 185 | # get the text config dict if we are loading from SiglipConfig 186 | if config_dict.get("model_type") == "ps3": 187 | config_dict = config_dict["text_config"] 188 | 189 | if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: 190 | logger.warning( 191 | f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " 192 | f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." 
193 |             )
194 | 
195 |         return cls.from_dict(config_dict, **kwargs)
196 | 
197 | 
198 | class PS3Config(PretrainedConfig):
199 | 
200 |     model_type = "ps3"
201 |     sub_configs = {"text_config": PS3TextConfig, "vision_config": PS3VisionConfig}
202 | 
203 |     def __init__(self, text_config=None, vision_config=None, **kwargs):
204 |         super().__init__(**kwargs)
205 | 
206 |         if text_config is None:
207 |             text_config = {}
208 |             logger.info("`text_config` is `None`. Initializing the `PS3TextConfig` with default values.")
209 | 
210 |         if vision_config is None:
211 |             vision_config = {}
212 |             logger.info("`vision_config` is `None`. Initializing the `PS3VisionConfig` with default values.")
213 | 
214 |         self.text_config = PS3TextConfig(**text_config)
215 |         self.vision_config = PS3VisionConfig(**vision_config)
--------------------------------------------------------------------------------
/ps3/image_processing_ps3.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | from typing import Optional, Tuple, Union, Any
17 | from PIL import Image
18 | 
19 | from torchvision.transforms.v2 import Normalize, InterpolationMode, ToTensor, Resize
20 | from torchvision.transforms.v2.functional import normalize, to_tensor, resize
21 | 
22 | from transformers.image_processing_utils import BaseImageProcessor
23 | from transformers.utils import TensorType
24 | 
25 | OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
26 | OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
27 | 
28 | 
29 | def _convert_to_rgb(image):
30 |     if isinstance(image, tuple):
31 |         return (image[0].convert('RGB'),) + image[1:]
32 |     else:
33 |         return image.convert('RGB')
34 | 
35 | 
36 | 
37 | class PS3ImageProcessor(BaseImageProcessor):
38 |     def __init__(
39 |         self,
40 |         image_size: Union[int, Tuple[int, int]] = None,
41 |         mean: Optional[Tuple[float, ...]] = None,
42 |         std: Optional[Tuple[float, ...]] = None,
43 |         resize_mode: Optional[str] = None,
44 |         interpolation: Optional[str] = None,
45 |         **kwargs,
46 |     ):
47 |         self.image_size = image_size
48 |         if isinstance(self.image_size, int):
49 |             self.image_size = (self.image_size, self.image_size)
50 | 
51 |         self.mean = mean or OPENAI_DATASET_MEAN
52 |         if not isinstance(self.mean, (list, tuple)):
53 |             self.mean = (self.mean,) * 3
54 | 
55 |         self.std = std or OPENAI_DATASET_STD
56 |         if not isinstance(self.std, (list, tuple)):
57 |             self.std = (self.std,) * 3
58 | 
59 |         self.resize_mode = resize_mode or 'squash'
60 |         assert self.resize_mode in ('squash',)  # only 'squash' resizing is currently supported
61 | 
62 |         self.interpolation = interpolation or 'bicubic'
63 |         assert self.interpolation in ['bicubic', 'bilinear', 'random']
64 | 
65 |         # Define some attributes to align with vila code
66 |         self.size = {'shortest_edge': self.image_size[0]}
67 |         self.crop_size = {'height': self.image_size[0],
'width': self.image_size[0]} 68 | self.image_mean = self.mean 69 | self.image_std = self.std 70 | 71 | def preprocess( 72 | self, 73 | image: Any, 74 | return_tensors: Optional[Union[str, TensorType]] = None, 75 | ): 76 | image = Resize(self.image_size, interpolation=InterpolationMode.BILINEAR if self.interpolation == 'bilinear' else InterpolationMode.BICUBIC)(image) 77 | image = _convert_to_rgb(image) 78 | image = ToTensor()(image) 79 | image = Normalize(mean=self.mean, std=self.std)(image) 80 | 81 | data = {"pixel_values": [image]} 82 | return data 83 | 84 | 85 | -------------------------------------------------------------------------------- /ps3/tokenization_ps3.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ 16 | Adapted from https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/tokenizer.py. 17 | Originally license: https://github.com/mlfoundations/open_clip/blob/main/LICENSE. 18 | Copyright (c) 2012-2021 Gabriel Ilharco, Mitchell Wortsman, 19 | Nicholas Carlini, Rohan Taori, Achal Dave, Vaishaal Shankar, 20 | John Miller, Hongseok Namkoong, Hannaneh Hajishirzi, Ali Farhadi, 21 | Ludwig Schmidt 22 | """ 23 | 24 | import html 25 | import os 26 | import string 27 | import json 28 | from typing import List, Optional, Union 29 | import warnings 30 | 31 | import ftfy 32 | import torch 33 | from transformers import PreTrainedTokenizer, AutoTokenizer 34 | 35 | # https://stackoverflow.com/q/62691279 36 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 37 | 38 | DEFAULT_CONTEXT_LENGTH = 77 # default context length for OpenAI CLIP 39 | 40 | 41 | def basic_clean(text): 42 | text = ftfy.fix_text(text) 43 | text = html.unescape(html.unescape(text)) 44 | return text.strip() 45 | 46 | 47 | def whitespace_clean(text): 48 | text = " ".join(text.split()) 49 | text = text.strip() 50 | return text 51 | 52 | 53 | def _clean_canonicalize(x): 54 | # basic, remove whitespace, remove punctuation, lower case 55 | return canonicalize_text(basic_clean(x)) 56 | 57 | 58 | def _clean_lower(x): 59 | # basic, remove whitespace, lower case 60 | return whitespace_clean(basic_clean(x)).lower() 61 | 62 | 63 | def _clean_whitespace(x): 64 | # basic, remove whitespace 65 | return whitespace_clean(basic_clean(x)) 66 | 67 | 68 | def get_clean_fn(type: str): 69 | if type == 'canonicalize': 70 | return _clean_canonicalize 71 | elif type == 'lower': 72 | return _clean_lower 73 | elif type == 'whitespace': 74 | return _clean_whitespace 75 | else: 76 | assert False, f"Invalid clean function ({type})." 
77 | 78 | 79 | def canonicalize_text( 80 | text, 81 | *, 82 | keep_punctuation_exact_string=None, 83 | trans_punctuation: dict = str.maketrans("", "", string.punctuation), 84 | ): 85 | """Returns canonicalized `text` (lowercase and punctuation removed). 86 | 87 | From: https://github.com/google-research/big_vision/blob/53f18caf27a9419231bbf08d3388b07671616d3d/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94 88 | 89 | Args: 90 | text: string to be canonicalized. 91 | keep_punctuation_exact_string: If provided, then this exact string kept. 92 | For example providing '{}' will keep any occurrences of '{}' (but will 93 | still remove '{' and '}' that appear separately). 94 | """ 95 | text = text.replace("_", " ") 96 | if keep_punctuation_exact_string: 97 | text = keep_punctuation_exact_string.join( 98 | part.translate(trans_punctuation) 99 | for part in text.split(keep_punctuation_exact_string) 100 | ) 101 | else: 102 | text = text.translate(trans_punctuation) 103 | text = text.lower() 104 | text = " ".join(text.split()) 105 | return text.strip() 106 | 107 | 108 | class PS3Tokenizer(PreTrainedTokenizer): 109 | """HuggingFace tokenizer wrapper""" 110 | 111 | def __init__( 112 | self, 113 | tokenizer_name: str, 114 | context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH, 115 | clean: str = 'whitespace', 116 | strip_sep_token: bool = False, 117 | language: Optional[str] = None, 118 | **kwargs 119 | ): 120 | self.init_kwargs = { 121 | "tokenizer_name": tokenizer_name, 122 | "context_length": context_length, 123 | "clean": clean, 124 | "strip_sep_token": strip_sep_token, 125 | "language": language, 126 | **kwargs 127 | } 128 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) 129 | set_lang_fn = getattr(self.tokenizer, 'set_src_lang_special_tokens', None) 130 | if callable(set_lang_fn): 131 | self.set_lang_fn = set_lang_fn 132 | if language is not None: 133 | self.set_language(language) 134 | self.context_length = context_length 135 | self.clean_fn = get_clean_fn(clean) 136 | self.strip_sep_token = strip_sep_token 137 | 138 | def save_pretrained(self, save_directory): 139 | # dump init_kwargs into tokenizer_config.json 140 | os.makedirs(save_directory, exist_ok=True) 141 | with open(os.path.join(save_directory, "tokenizer_config.json"), "w") as f: 142 | json.dump(self.init_kwargs, f) 143 | 144 | def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] = None) -> torch.Tensor: 145 | # same cleaning as for default tokenizer, except lowercasing 146 | # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance 147 | if isinstance(texts, str): 148 | texts = [texts] 149 | 150 | context_length = context_length or self.context_length 151 | assert context_length, 'Please set a valid context length in class init or call.' 
152 | 153 | texts = [self.clean_fn(text) for text in texts] 154 | input_ids = self.tokenizer.batch_encode_plus( 155 | texts, 156 | return_tensors='pt', 157 | max_length=context_length, 158 | padding='max_length', 159 | truncation=True, 160 | ).input_ids 161 | 162 | if self.strip_sep_token: 163 | input_ids = torch.where( 164 | input_ids == self.tokenizer.sep_token_id, 165 | torch.zeros_like(input_ids), 166 | input_ids, 167 | ) 168 | 169 | return input_ids 170 | 171 | def set_language(self, src_lang): 172 | if hasattr(self, 'set_lang_fn'): 173 | self.set_lang_fn(src_lang) 174 | else: 175 | warnings.warn('Cannot set language for the tokenizer.') 176 | 177 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "ps3-torch" 7 | version = "0.1.2" 8 | description = "Scaling Vision Pre-Training to 4K Resolution" 9 | readme = "README.md" 10 | authors = [ 11 | { name = "Baifeng Shi", email = "baifeng_shi@berkeley.edu" } 12 | ] 13 | license = { text = "Apache 2.0" } 14 | requires-python = ">=3.8" 15 | classifiers = [ 16 | "Programming Language :: Python :: 3", 17 | "License :: OSI Approved :: Apache Software License", 18 | ] 19 | dependencies = [ 20 | "torch", 21 | "torchvision", 22 | "transformers[sentencepiece]<=4.49.0", 23 | "timm>=1.0.15", 24 | "einops", 25 | "accelerate", 26 | "pillow", 27 | "tiktoken", 28 | "ftfy", 29 | ] 30 | 31 | [tool.setuptools] 32 | packages = ["ps3"] 33 | 34 | [project.urls] 35 | Homepage = "https://nvlabs.github.io/PS3" 36 | Code = "https://github.com/NVlabs/PS3" 37 | -------------------------------------------------------------------------------- /train/.gitignore: -------------------------------------------------------------------------------- 1 | **/logs/ 2 | **/wandb/ 3 | models/ 4 | features/ 5 | results/ 6 | output/ 7 | root/ 8 | 9 | tests/data/ 10 | tests/images 11 | *.pt 12 | 13 | 14 | # Byte-compiled / optimized / DLL files 15 | __pycache__/ 16 | *.py[cod] 17 | *$py.class 18 | 19 | # C extensions 20 | *.so 21 | 22 | # Distribution / packaging 23 | .Python 24 | build/ 25 | develop-eggs/ 26 | dist/ 27 | downloads/ 28 | eggs/ 29 | .eggs/ 30 | lib/ 31 | lib64/ 32 | parts/ 33 | sdist/ 34 | var/ 35 | wheels/ 36 | pip-wheel-metadata/ 37 | share/python-wheels/ 38 | *.egg-info/ 39 | .installed.cfg 40 | *.egg 41 | MANIFEST 42 | 43 | # PyInstaller 44 | # Usually these files are written by a python script from a template 45 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
46 | *.manifest 47 | *.spec 48 | 49 | # Installer logs 50 | pip-log.txt 51 | pip-delete-this-directory.txt 52 | 53 | # Unit test / coverage reports 54 | htmlcov/ 55 | .tox/ 56 | .nox/ 57 | .coverage 58 | .coverage.* 59 | .cache 60 | nosetests.xml 61 | coverage.xml 62 | *.cover 63 | *.py,cover 64 | .hypothesis/ 65 | .pytest_cache/ 66 | 67 | # Translations 68 | *.mo 69 | *.pot 70 | 71 | # Django stuff: 72 | *.log 73 | local_settings.py 74 | db.sqlite3 75 | db.sqlite3-journal 76 | 77 | # Flask stuff: 78 | instance/ 79 | .webassets-cache 80 | 81 | # Scrapy stuff: 82 | .scrapy 83 | 84 | # Sphinx documentation 85 | docs/_build/ 86 | 87 | # PyBuilder 88 | target/ 89 | 90 | # Jupyter Notebook 91 | .ipynb_checkpoints 92 | 93 | # IPython 94 | profile_default/ 95 | ipython_config.py 96 | 97 | # pyenv 98 | .python-version 99 | 100 | # pipenv 101 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 102 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 103 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 104 | # install all needed dependencies. 105 | #Pipfile.lock 106 | 107 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 108 | __pypackages__/ 109 | 110 | # Celery stuff 111 | celerybeat-schedule 112 | celerybeat.pid 113 | 114 | # SageMath parsed files 115 | *.sage.py 116 | 117 | # Environments 118 | .env 119 | .venv 120 | env/ 121 | venv/ 122 | ENV/ 123 | env.bak/ 124 | venv.bak/ 125 | 126 | # Spyder project settings 127 | .spyderproject 128 | .spyproject 129 | 130 | # Rope project settings 131 | .ropeproject 132 | 133 | # mkdocs documentation 134 | /site 135 | 136 | # mypy 137 | .mypy_cache/ 138 | .dmypy.json 139 | dmypy.json 140 | 141 | # Pyre type checker 142 | .pyre/ 143 | sync.sh 144 | gpu1sync.sh 145 | .idea 146 | *.pdf 147 | **/._* 148 | **/*DS_* 149 | **.jsonl 150 | src/sbatch 151 | src/misc 152 | .vscode 153 | src/debug 154 | core.* 155 | 156 | # Allow 157 | !src/evaluation/misc/results_dbs/* 158 | -------------------------------------------------------------------------------- /train/CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.1.0 2 | message: If you use this software, please cite it as below. 
3 | authors: 4 | - family-names: Ilharco 5 | given-names: Gabriel 6 | - family-names: Wortsman 7 | given-names: Mitchell 8 | - family-names: Wightman 9 | given-names: Ross 10 | - family-names: Gordon 11 | given-names: Cade 12 | - family-names: Carlini 13 | given-names: Nicholas 14 | - family-names: Taori 15 | given-names: Rohan 16 | - family-names: Dave 17 | given-names: Achal 18 | - family-names: Shankar 19 | given-names: Vaishaal 20 | - family-names: Namkoong 21 | given-names: Hongseok 22 | - family-names: Miller 23 | given-names: John 24 | - family-names: Hajishirzi 25 | given-names: Hannaneh 26 | - family-names: Farhadi 27 | given-names: Ali 28 | - family-names: Schmidt 29 | given-names: Ludwig 30 | title: OpenCLIP 31 | version: v0.1 32 | doi: 10.5281/zenodo.5143773 33 | date-released: 2021-07-28 34 | -------------------------------------------------------------------------------- /train/HISTORY.md: -------------------------------------------------------------------------------- 1 | ## 2.24.0 2 | 3 | * Fix missing space in error message 4 | * use model flag for normalizing embeddings 5 | * init logit_bias for non siglip pretrained models 6 | * Fix logit_bias load_checkpoint addition 7 | * Make CoCa model match CLIP models for logit scale/bias init 8 | * Fix missing return of "logit_bias" in CoCa.forward 9 | * Add NLLB-CLIP with SigLIP models 10 | * Add get_logits method and NLLB tokenizer 11 | * Remove the empty file src/open_clip/generation_utils.py 12 | * Update params.py: "BatchNorm" -> "LayerNorm" in the description string for "--lock-text-freeze-layer-norm" 13 | 14 | ## 2.23.0 15 | 16 | * Add CLIPA-v2 models 17 | * Add SigLIP models 18 | * Add MetaCLIP models 19 | * Add NLLB-CLIP models 20 | * CLIPA train code 21 | * Minor changes/fixes 22 | * Remove protobuf version limit 23 | * Stop checking model name when loading CoCa models 24 | * Log native wandb step 25 | * Use bool instead of long masks 26 | 27 | ## 2.21.0 28 | 29 | * Add SigLIP loss + training support 30 | * Add more DataComp models (B/16, B/32 and B/32@256) 31 | * Update default num workers 32 | * Update CoCa generation for `transformers>=4.31` 33 | * PyTorch 2.0 `state_dict()` compatibility fix for compiled models 34 | * Fix padding in `ResizeMaxSize` 35 | * Convert JIT model on state dict load for `pretrained='filename…'` 36 | * Other minor changes and fixes (typos, README, dependencies, CI) 37 | 38 | ## 2.20.0 39 | 40 | * Add EVA models 41 | * Support serial worker training 42 | * Fix Python 3.7 compatibility 43 | 44 | ## 2.19.0 45 | 46 | * Add DataComp models 47 | 48 | ## 2.18.0 49 | 50 | * Enable int8 inference without `.weight` attribute 51 | 52 | ## 2.17.2 53 | 54 | * Update push_to_hf_hub 55 | 56 | ## 2.17.0 57 | 58 | * Add int8 support 59 | * Update notebook demo 60 | * Refactor zero-shot classification code 61 | 62 | ## 2.16.2 63 | 64 | * Fixes for context_length and vocab_size attributes 65 | 66 | ## 2.16.1 67 | 68 | * Fixes for context_length and vocab_size attributes 69 | * Fix --train-num-samples logic 70 | * Add HF BERT configs for PubMed CLIP model 71 | 72 | ## 2.16.0 73 | 74 | * Add improved g-14 weights 75 | * Update protobuf version 76 | 77 | ## 2.15.0 78 | 79 | * Add convnext_xxlarge weights 80 | * Fixed import in readme 81 | * Add samples per second per gpu logging 82 | * Fix slurm example 83 | 84 | ## 2.14.0 85 | 86 | * Move dataset mixtures logic to shard level 87 | * Fix CoCa accum-grad training 88 | * Safer transformers import guard 89 | * get_labels refactoring 90 | 91 | ## 2.13.0 92 | 93 | * 
Add support for dataset mixtures with different sampling weights 94 | * Make transformers optional again 95 | 96 | ## 2.12.0 97 | 98 | * Updated convnext configs for consistency 99 | * Added input_patchnorm option 100 | * Clean and improve CoCa generation 101 | * Support model distillation 102 | * Add ConvNeXt-Large 320x320 fine-tune weights 103 | 104 | ## 2.11.1 105 | 106 | * Make transformers optional 107 | * Add MSCOCO CoCa finetunes to pretrained models 108 | 109 | ## 2.11.0 110 | 111 | * coca support and weights 112 | * ConvNeXt-Large weights 113 | 114 | ## 2.10.1 115 | 116 | * `hf-hub:org/model_id` support for loading models w/ config and weights in Hugging Face Hub 117 | 118 | ## 2.10.0 119 | 120 | * Added a ViT-bigG-14 model. 121 | * Added an up-to-date example slurm script for large training jobs. 122 | * Added a option to sync logs and checkpoints to S3 during training. 123 | * New options for LR schedulers, constant and constant with cooldown 124 | * Fix wandb autoresuming when resume is not set 125 | * ConvNeXt `base` & `base_w` pretrained models added 126 | * `timm-` model prefix removed from configs 127 | * `timm` augmentation + regularization (dropout / drop-path) supported 128 | 129 | ## 2.9.3 130 | 131 | * Fix wandb collapsing multiple parallel runs into a single one 132 | 133 | ## 2.9.2 134 | 135 | * Fix braceexpand memory explosion for complex webdataset urls 136 | 137 | ## 2.9.1 138 | 139 | * Fix release 140 | 141 | ## 2.9.0 142 | 143 | * Add training feature to auto-resume from the latest checkpoint on restart via `--resume latest` 144 | * Allow webp in webdataset 145 | * Fix logging for number of samples when using gradient accumulation 146 | * Add model configs for convnext xxlarge 147 | 148 | ## 2.8.2 149 | 150 | * wrapped patchdropout in a torch.nn.Module 151 | 152 | ## 2.8.1 153 | 154 | * relax protobuf dependency 155 | * override the default patch dropout value in 'vision_cfg' 156 | 157 | ## 2.8.0 158 | 159 | * better support for HF models 160 | * add support for gradient accumulation 161 | * CI fixes 162 | * add support for patch dropout 163 | * add convnext configs 164 | 165 | 166 | ## 2.7.0 167 | 168 | * add multilingual H/14 xlm roberta large 169 | 170 | ## 2.6.1 171 | 172 | * fix setup.py _read_reqs 173 | 174 | ## 2.6.0 175 | 176 | * Make openclip training usable from pypi. 177 | * Add xlm roberta large vit h 14 config. 178 | 179 | ## 2.5.0 180 | 181 | * pretrained B/32 xlm roberta base: first multilingual clip trained on laion5B 182 | * pretrained B/32 roberta base: first clip trained using an HF text encoder 183 | 184 | ## 2.4.1 185 | 186 | * Add missing hf_tokenizer_name in CLIPTextCfg. 187 | 188 | ## 2.4.0 189 | 190 | * Fix #211, missing RN50x64 config. Fix type of dropout param for ResNet models 191 | * Bring back LayerNorm impl that casts to input for non bf16/fp16 192 | * zero_shot.py: set correct tokenizer based on args 193 | * training/params.py: remove hf params and get them from model config 194 | 195 | ## 2.3.1 196 | 197 | * Implement grad checkpointing for hf model. 
198 | * custom_text: True if hf_model_name is set
199 | * Disable hf tokenizer parallelism
200 | 
201 | ## 2.3.0
202 | 
203 | * Generalizable Text Transformer with HuggingFace Models (@iejMac)
204 | 
205 | ## 2.2.0
206 | 
207 | * Support for custom text tower
208 | * Add checksum verification for pretrained model weights
209 | 
210 | ## 2.1.0
211 | 
212 | * lot including sota models, bfloat16 option, better loading, better metrics
213 | 
214 | ## 1.2.0
215 | 
216 | * ViT-B/32 trained on Laion2B-en
217 | * add missing openai RN50x64 model
218 | 
219 | ## 1.1.1
220 | 
221 | * ViT-B/16+
222 | * Add grad checkpointing support
223 | * more robust data loader
224 | 
--------------------------------------------------------------------------------
/train/Makefile:
--------------------------------------------------------------------------------
1 | install: ## [Local development] Upgrade pip, install requirements, install package.
2 | 	python -m pip install -U pip
3 | 	python -m pip install -e .
4 | 
5 | install-training:
6 | 	python -m pip install -r requirements-training.txt
7 | 
8 | install-test: ## [Local development] Install test requirements
9 | 	python -m pip install -r requirements-test.txt
10 | 
11 | test: ## [Local development] Run unit tests
12 | 	python -m pytest -x -s -v tests
13 | 
--------------------------------------------------------------------------------
/train/README.md:
--------------------------------------------------------------------------------
1 | # Pre-Training PS3
2 | 
3 | This codebase is built largely on top of [OpenCLIP](https://github.com/mlfoundations/open_clip). The main changes include **1)** adding the snippet to build the PS3 model in `src/open_clip/model.py`, **2)** adding the image-text-box dataset in `src/open_clip_train/data.py`, **3)** adding the PS3 pre-training loss in `src/open_clip/loss.py`, and **4)** slightly modifying `src/open_clip_train/train.py` to support PS3 training. Common practices for training with OpenCLIP generally carry over.
4 | 
5 | ## Installation
6 | 
7 | First make sure to install PS3 as instructed [here](https://github.com/NVLabs/PS3).
8 | 
9 | Then install this codebase as follows (there are also instructions in [OpenCLIP](https://github.com/mlfoundations/open_clip/tree/7260a46e7b4bcf518f5200fea06da5bc85aae025?tab=readme-ov-file#development)):
10 | 
11 | ```bash
12 | make install
13 | make install-training
14 | ```
15 | 
16 | ## Data Preparation
17 | 
18 | Training data should be in webdataset format, following the original OpenCLIP. Specifically, we use two separate webdatasets for images and text-box pairs. The image webdataset should be in the following structure:
19 | 
20 | ```
21 | images_path/
22 | |-- image_00000.tar
23 | |   |-- 00000001.jpg
24 | |   |-- 00000002.jpg
25 | |   |-- ...
26 | |-- image_00001.tar
27 | |   |-- 00010001.jpg
28 | |   |-- 00010002.jpg
29 | |   |-- ...
30 | |-- ...
31 | ```
32 | 
33 | The text-box webdataset should have the exact same structure as the image webdataset, but with the text-box annotations in `json` format instead of images:
34 | 
35 | ```
36 | text_boxes_path/
37 | |-- text_box_00000.tar
38 | |   |-- 00000001.json
39 | |   |-- 00000002.json
40 | |   |-- ...
41 | |-- text_box_00001.tar
42 | |   |-- 00010001.json
43 | |   |-- 00010002.json
44 | |   |-- ...
45 | |-- ...
46 | ```
47 | 
48 | Each json file contains pairs of local captions and local bounding boxes, as well as a global caption of the image. For example:
49 | 
50 | ```json
51 | {
52 |     "text": [
53 |         "The second image is a cropped view of the Ferris wheel and the clock tower from the first image. The Ferris wheel is prominently visible on the left side, with a green and white color scheme. The clock tower, featuring a white clock face and a green roof, is situated in the middle of the image. The sky appears overcast, and there are no other significant objects or text in the frame.",
54 |         "The second image is a cropped view of the first image, focusing on the water and the distant cityscape. The image shows a body of water with a few boats and a dock in the foreground. In the background, there are several buildings and structures, including a large Ferris wheel and a tall skyscraper. The sky is clear with some clouds.",
55 |         "The second image shows a cropped view of the first image, focusing on a section of the cityscape. The buildings are tall and modern, with various signs and advertisements visible. The text on the signs includes \"UBS,\" \"AXA,\" and \"PRUDENTIAL.\" The sky is clear, and the overall color palette is dominated by blues and grays, typical of an urban environment.",
56 |         "The second image is a close-up crop of the first image, focusing on a section of the cityscape. It features tall skyscrapers with reflective glass facades, a few smaller buildings, and a construction site with scaffolding and a crane. There are also some trees and bushes in the foreground, and a few streetlights are visible. The sky is clear with a few clouds."
57 |     ],
58 |     "box": [
59 |         [
60 |             1020,
61 |             780,
62 |             1380,
63 |             1020
64 |         ],
65 |         [
66 |             2010,
67 |             675,
68 |             2190,
69 |             1125
70 |         ],
71 |         [
72 |             3075,
73 |             810,
74 |             3525,
75 |             990
76 |         ],
77 |         [
78 |             510,
79 |             675,
80 |             690,
81 |             1125
82 |         ]
83 |     ],
84 |     "global_text": "The image is a panoramic view of a cityscape with a prominent waterfront. On the left side, there is a tall skyscraper with a distinctive Ferris wheel in front of it. The Ferris wheel is surrounded by several other high-rise buildings. The middle ground features a large body of water, possibly a bay or harbor, with a few boats visible on the water. On the right side, there are more high-rise buildings, including a particularly tall skyscraper that stands out due to its height and design. The sky is clear with a few scattered clouds, suggesting a sunny day. The overall scene is vibrant and bustling, indicative of a major urban area."
85 | }
86 | ```
87 | 
88 | Release of the pre-training data is still under review. You can also build your own data following the pipeline described in the [paper](https://arxiv.org/abs/2503.19903).
89 | 
90 | 
91 | ## Training
92 | 
93 | Example training scripts for PS3-1.5K-SigLIP and PS3-4K-SigLIP are provided in `train/scripts`. The scripts use a single GPU node, but we suggest using multiple nodes. The models in the paper are trained with 16 nodes of 8xA100 GPUs.
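
For multi-node training, the same `open_clip_train.main` entry point and training flags can be reused; only the `torchrun` launch arguments change. As a rough sketch (assuming `NNODES`, `NODE_RANK`, and `MASTER_ADDR` are provided by your cluster scheduler, e.g. SLURM; they are not defined anywhere in this repo), a launcher run on every node might look like:

```bash
#!/bin/bash
# Hypothetical multi-node launcher: run the same command on every node and
# forward the training flags from train/scripts/ps3_4k_siglip.sh via "$@".
torchrun --nnodes="$NNODES" --nproc_per_node=8 \
    --node_rank="$NODE_RANK" \
    --master_addr="$MASTER_ADDR" --master_port=25001 \
    -m open_clip_train.main "$@"
```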
94 | 
95 | ## Convert the checkpoint to use with the PS3 package
96 | 
97 | Since the training checkpoint format is different from the one used by the [PS3 package](https://github.com/NVLabs/PS3) (which is the Hugging Face format), we need to convert the checkpoint with:
98 | 
99 | ```bash
100 | python -m src.open_clip.save_ps3_hf_ckpt --model --pretrained --save-dir
101 | ```
102 | 
103 | After this you will be able to load the checkpoint with the PS3 package, for example:
104 | 
105 | ```python
106 | vision_model = PS3VisionModel.from_pretrained("path_after_conversion")
107 | ```
108 | 
109 | ## More Information
110 | 
111 | We suggest checking out the original [OpenCLIP](https://github.com/mlfoundations/open_clip/tree/7260a46e7b4bcf518f5200fea06da5bc85aae025?tab=readme-ov-file#training-clip) repository for other important tips on training, such as training with multiple data sources and more efficient training.
112 | 
113 | 
114 | ## Citing
115 | 
116 | If you find this repository useful, please consider citing PS3 and the original OpenCLIP repository:
117 | 
118 | ```bibtex
119 | @article{shi2025scaling,
120 |   title={Scaling Vision Pre-Training to 4K Resolution},
121 |   author={Shi, Baifeng and Li, Boyi and Cai, Han and Lu, Yao and Liu, Sifei and Pavone, Marco and Kautz, Jan and Han, Song and Darrell, Trevor and Molchanov, Pavlo and others},
122 |   journal={arXiv preprint arXiv:2503.19903},
123 |   year={2025}
124 | }
125 | ```
126 | 
127 | ```bibtex
128 | @software{ilharco_gabriel_2021_5143773,
129 |   author = {Ilharco, Gabriel and
130 |             Wortsman, Mitchell and
131 |             Wightman, Ross and
132 |             Gordon, Cade and
133 |             Carlini, Nicholas and
134 |             Taori, Rohan and
135 |             Dave, Achal and
136 |             Shankar, Vaishaal and
137 |             Namkoong, Hongseok and
138 |             Miller, John and
139 |             Hajishirzi, Hannaneh and
140 |             Farhadi, Ali and
141 |             Schmidt, Ludwig},
142 |   title = {OpenCLIP},
143 |   month = jul,
144 |   year = 2021,
145 |   note = {If you use this software, please cite it as below.},
146 |   publisher = {Zenodo},
147 |   version = {0.1},
148 |   doi = {10.5281/zenodo.5143773},
149 |   url = {https://doi.org/10.5281/zenodo.5143773}
150 | }
151 | ```
152 | 
153 | ```bibtex
154 | @inproceedings{cherti2023reproducible,
155 |   title={Reproducible scaling laws for contrastive language-image learning},
156 |   author={Cherti, Mehdi and Beaumont, Romain and Wightman, Ross and Wortsman, Mitchell and Ilharco, Gabriel and Gordon, Cade and Schuhmann, Christoph and Schmidt, Ludwig and Jitsev, Jenia},
157 |   booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
158 |   pages={2818--2829},
159 |   year={2023}
160 | }
161 | ```
162 | 
--------------------------------------------------------------------------------
/train/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["pdm-backend"]
3 | build-backend = "pdm.backend"
4 | 
5 | [project]
6 | name = "open_clip_torch"
7 | # NOTE for full list of authors see https://github.com/mlfoundations/open_clip?tab=readme-ov-file#citing
8 | # below covers most active / recent maintainers
9 | authors = [
10 |     {name = "Ross Wightman", email = "ross@huggingface.co"},
11 |     {name = "Gabriel Ilharco"},
12 |     {name = "Mitchell Wortsman"},
13 |     {name = "Romain Beaumont"},
14 | ]
15 | description = "Open reproduction of contrastive language-image pretraining (CLIP) and related."
16 | readme = "README.md" 17 | requires-python = ">=3.8" 18 | keywords = ["pytorch", "clip", "image-text", "language-image", "multimodal"] 19 | license = {text = "MIT"} 20 | classifiers = [ 21 | 'Development Status :: 4 - Beta', 22 | 'Intended Audience :: Education', 23 | 'Intended Audience :: Science/Research', 24 | 'License :: OSI Approved :: MIT License', 25 | 'Programming Language :: Python :: 3.8', 26 | 'Programming Language :: Python :: 3.9', 27 | 'Programming Language :: Python :: 3.10', 28 | 'Programming Language :: Python :: 3.11', 29 | 'Programming Language :: Python :: 3.12', 30 | 'Topic :: Scientific/Engineering', 31 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 32 | 'Topic :: Software Development', 33 | 'Topic :: Software Development :: Libraries', 34 | 'Topic :: Software Development :: Libraries :: Python Modules', 35 | ] 36 | dependencies = [ 37 | 'torch>=1.9.0', 38 | 'torchvision', 39 | 'regex', 40 | 'ftfy', 41 | 'tqdm', 42 | 'huggingface-hub', 43 | 'safetensors', 44 | 'timm', 45 | ] 46 | dynamic = ["version"] 47 | 48 | [project.optional-dependencies] 49 | training = [ 50 | 'torch>=2.0', 51 | 'webdataset>=0.2.5,<=0.2.86', 52 | 'pandas', 53 | 'transformers[sentencepiece]', 54 | 'timm>=1.0.10', 55 | 'fsspec', 56 | ] 57 | test = [ 58 | 'pytest-split', 59 | 'pytest', 60 | 'open_clip_torch[training]' 61 | ] 62 | 63 | [project.urls] 64 | homepage = "https://github.com/mlfoundations/open_clip" 65 | repository = "https://github.com/mlfoundations/open_clip" 66 | 67 | [tool.pdm.version] 68 | source = "file" 69 | path = "src/open_clip/version.py" 70 | 71 | [tool.pdm.build] 72 | excludes = ["./**/.git", "./**/logs/*"] 73 | package-dir = "src" 74 | includes = ["src/open_clip", "src/open_clip_train"] 75 | 76 | [tool.pytest.ini_options] 77 | testpaths = ['tests'] 78 | markers = [ 79 | 'regression_test' 80 | ] -------------------------------------------------------------------------------- /train/pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | markers = 3 | regression_test 4 | -------------------------------------------------------------------------------- /train/requirements-test.txt: -------------------------------------------------------------------------------- 1 | pytest-split==0.8.0 2 | pytest==7.2.0 3 | transformers[sentencepiece] 4 | timm>=1.0.10 5 | -------------------------------------------------------------------------------- /train/requirements-training.txt: -------------------------------------------------------------------------------- 1 | torch>=1.9.0 2 | torchvision 3 | webdataset>=0.2.5,<=0.2.86 4 | regex 5 | ftfy 6 | tqdm 7 | pandas 8 | braceexpand 9 | huggingface_hub 10 | safetensors 11 | transformers[sentencepiece] 12 | timm>=1.0.15 13 | fsspec 14 | -------------------------------------------------------------------------------- /train/requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.9.0 2 | torchvision 3 | regex 4 | ftfy 5 | tqdm 6 | huggingface_hub 7 | safetensors 8 | timm 9 | -------------------------------------------------------------------------------- /train/scripts/ps3_1.5k_siglip.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | torchrun --nnodes=1 --nproc_per_node=8 --master_port=25001 \ 4 | -m open_clip_train.main \ 5 | --model PS3-1.5K-SigLIP \ 6 | --pretrained webli \ 7 | --siglip \ 8 | --train-data-image \ 9 | --train-data-text-box \ 10 | --sentence_masking 
\ 11 | --select_w_gt \ 12 | --use_global_caption \ 13 | --pool_gt_token_only \ 14 | --global_caption_prob 0.25 \ 15 | --train-num-samples 1000000 \ 16 | --dataset-resampled \ 17 | --dataset-type webdataset_separate_image_text_box \ 18 | --batch-size 32 \ 19 | --epochs 75 \ 20 | --warmup 1500 \ 21 | --lr 5e-6 \ 22 | --beta1 0.9 \ 23 | --beta2 0.95 \ 24 | --wd 0.0003 \ 25 | --selection_prob_loss_weight 0 \ 26 | --prior_box_supervision_loss_weight 1 \ 27 | --posterior_box_supervision_loss_weight 1 \ 28 | --precision amp \ 29 | --grad-checkpointing \ 30 | --gather-with-grad \ 31 | --local-loss \ 32 | --workers 4 \ 33 | --name ps3_1.5k_siglip \ 34 | --resume latest \ 35 | --save-frequency 1 \ 36 | --save-most-recent \ 37 | --ddp-static-graph \ 38 | --report-to wandb 39 | -------------------------------------------------------------------------------- /train/scripts/ps3_4k_siglip.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | torchrun --nnodes=1 --nproc_per_node=8 --master_port=25001 \ 4 | -m open_clip_train.main \ 5 | --model PS3-4K-SigLIP \ 6 | --pretrained webli \ 7 | --siglip \ 8 | --train-data-image \ 9 | --train-data-text-box \ 10 | --sentence_masking \ 11 | --select_w_gt \ 12 | --use_global_caption \ 13 | --pool_gt_token_only \ 14 | --global_caption_prob 0.25 \ 15 | --train-num-samples 1000000 \ 16 | --dataset-resampled \ 17 | --dataset-type webdataset_separate_image_text_box \ 18 | --batch-size 32 \ 19 | --epochs 75 \ 20 | --warmup 1500 \ 21 | --lr 5e-6 \ 22 | --beta1 0.9 \ 23 | --beta2 0.95 \ 24 | --wd 0.0003 \ 25 | --selection_prob_loss_weight 0 \ 26 | --prior_box_supervision_loss_weight 1 \ 27 | --posterior_box_supervision_loss_weight 1 \ 28 | --precision amp \ 29 | --grad-checkpointing \ 30 | --gather-with-grad \ 31 | --local-loss \ 32 | --workers 4 \ 33 | --name ps3_4k_siglip \ 34 | --resume latest \ 35 | --save-frequency 1 \ 36 | --save-most-recent \ 37 | --ddp-static-graph \ 38 | --report-to wandb 39 | -------------------------------------------------------------------------------- /train/src/open_clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .version import __version__ 2 | 3 | from .coca_model import CoCa 4 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 5 | from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss 6 | from .factory import list_models, add_model_config, get_model_config, load_checkpoint 7 | from .loss import ClipLoss, DistillClipLoss, CoCaLoss 8 | from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \ 9 | convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype, get_input_dtype, \ 10 | get_model_tokenize_cfg, get_model_preprocess_cfg, set_model_preprocess_cfg 11 | from .openai import load_openai_model, list_openai_models 12 | from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \ 13 | get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained 14 | from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub 15 | from .tokenizer import SimpleTokenizer, tokenize, decode 16 | from .transform import image_transform, AugmentationCfg 17 | from .zero_shot_classifier import build_zero_shot_classifier, build_zero_shot_classifier_legacy 18 | from .zero_shot_metadata import OPENAI_IMAGENET_TEMPLATES, SIMPLE_IMAGENET_TEMPLATES, 
IMAGENET_CLASSNAMES 19 | from .radio_adapter_mlp import create_mlp_from_state, create_mlp_from_config, get_mlp_info_from_state 20 | -------------------------------------------------------------------------------- /train/src/open_clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/PS3/529d66aa30b26438c30ce839197b2c0b2cf765e5/train/src/open_clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /train/src/open_clip/constants.py: -------------------------------------------------------------------------------- 1 | OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) 2 | OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) 3 | IMAGENET_MEAN = (0.485, 0.456, 0.406) 4 | IMAGENET_STD = (0.229, 0.224, 0.225) 5 | INCEPTION_MEAN = (0.5, 0.5, 0.5) 6 | INCEPTION_STD = (0.5, 0.5, 0.5) 7 | 8 | # Default name for a weights file hosted on the Huggingface Hub. 9 | HF_WEIGHTS_NAME = "open_clip_pytorch_model.bin" # default pytorch pkl 10 | HF_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors" # safetensors version 11 | HF_CONFIG_NAME = 'open_clip_config.json' -------------------------------------------------------------------------------- /train/src/open_clip/hf_configs.py: -------------------------------------------------------------------------------- 1 | # HF architecture dict: 2 | arch_dict = { 3 | # https://huggingface.co/docs/transformers/model_doc/roberta#roberta 4 | "roberta": { 5 | "config_names": { 6 | "context_length": "max_position_embeddings", 7 | "vocab_size": "vocab_size", 8 | "width": "hidden_size", 9 | "heads": "num_attention_heads", 10 | "layers": "num_hidden_layers", 11 | "layer_attr": "layer", 12 | "token_embeddings_attr": "embeddings" 13 | }, 14 | "pooler": "mean_pooler", 15 | }, 16 | # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig 17 | "xlm-roberta": { 18 | "config_names": { 19 | "context_length": "max_position_embeddings", 20 | "vocab_size": "vocab_size", 21 | "width": "hidden_size", 22 | "heads": "num_attention_heads", 23 | "layers": "num_hidden_layers", 24 | "layer_attr": "layer", 25 | "token_embeddings_attr": "embeddings" 26 | }, 27 | "pooler": "mean_pooler", 28 | }, 29 | # https://huggingface.co/docs/transformers/model_doc/mt5#mt5 30 | "mt5": { 31 | "config_names": { 32 | # unlimited seqlen 33 | # https://github.com/google-research/text-to-text-transfer-transformer/issues/273 34 | # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374 35 | "context_length": "", 36 | "vocab_size": "vocab_size", 37 | "width": "d_model", 38 | "heads": "num_heads", 39 | "layers": "num_layers", 40 | "layer_attr": "block", 41 | "token_embeddings_attr": "embed_tokens" 42 | }, 43 | "pooler": "mean_pooler", 44 | }, 45 | # https://huggingface.co/docs/transformers/model_doc/bert 46 | "bert": { 47 | "config_names": { 48 | "context_length": "max_position_embeddings", 49 | "vocab_size": "vocab_size", 50 | "width": "hidden_size", 51 | "heads": "num_attention_heads", 52 | "layers": "num_hidden_layers", 53 | }, 54 | "pooler": "cls_pooler", 55 | }, 56 | # https://huggingface.co/docs/transformers/model_doc/m2m_100 57 | "m2m_100": { 58 | "config_names": { 59 | "context_length": "max_position_embeddings", 60 | "vocab_size": "vocab_size", 61 | "width": "d_model", 62 | "heads": "encoder_attention_heads", 63 | "layers": "encoder_layers", 64 | }, 65 | 
"pooler": "cls_pooler", 66 | }, 67 | } 68 | -------------------------------------------------------------------------------- /train/src/open_clip/hf_model.py: -------------------------------------------------------------------------------- 1 | """ huggingface model adapter 2 | 3 | Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model. 4 | """ 5 | import re 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch import TensorType 10 | 11 | try: 12 | import transformers 13 | from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig 14 | from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \ 15 | BaseModelOutputWithPoolingAndCrossAttentions 16 | except ImportError as e: 17 | transformers = None 18 | 19 | 20 | class BaseModelOutput: 21 | pass 22 | 23 | 24 | class PretrainedConfig: 25 | pass 26 | 27 | from .hf_configs import arch_dict 28 | 29 | 30 | # utils 31 | def _camel2snake(s): 32 | return re.sub(r'(? List[str]: 20 | """Returns the names of available CLIP models""" 21 | return list_pretrained_models_by_tag('openai') 22 | 23 | 24 | def load_openai_model( 25 | name: str, 26 | precision: Optional[str] = None, 27 | device: Optional[Union[str, torch.device]] = None, 28 | cache_dir: Optional[str] = None, 29 | ): 30 | """Load a CLIP model 31 | 32 | Parameters 33 | ---------- 34 | name : str 35 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 36 | precision: str 37 | Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'. 38 | device : Union[str, torch.device] 39 | The device to put the loaded model 40 | cache_dir : Optional[str] 41 | The directory to cache the downloaded model weights 42 | 43 | Returns 44 | ------- 45 | model : torch.nn.Module 46 | The CLIP model 47 | preprocess : Callable[[PIL.Image], torch.Tensor] 48 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 49 | """ 50 | if device is None: 51 | device = "cuda" if torch.cuda.is_available() else "cpu" 52 | if precision is None: 53 | precision = 'fp32' if device == 'cpu' else 'fp16' 54 | 55 | if get_pretrained_url(name, 'openai'): 56 | model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir) 57 | elif os.path.isfile(name): 58 | model_path = name 59 | else: 60 | raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") 61 | 62 | try: 63 | # loading JIT archive 64 | model = torch.jit.load(model_path, map_location="cpu").eval() 65 | state_dict = None 66 | except RuntimeError: 67 | # loading saved state dict 68 | state_dict = torch.load(model_path, map_location="cpu") 69 | 70 | # Build a non-jit model from the OpenAI jitted model state dict 71 | cast_dtype = get_cast_dtype(precision) 72 | try: 73 | model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype) 74 | except KeyError: 75 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} 76 | model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype) 77 | 78 | # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use 79 | model = model.to(device) 80 | # FIXME support pure fp16/bf16 precision modes 81 | if precision != 'fp16': 82 | model.float() 83 | if precision == 'bf16': 84 | # for bf16, convert back to low-precision 85 | convert_weights_to_lp(model, 
dtype=torch.bfloat16) 86 | 87 | # add mean / std attributes for consistency with OpenCLIP models 88 | model.visual.image_mean = OPENAI_DATASET_MEAN 89 | model.visual.image_std = OPENAI_DATASET_STD 90 | return model -------------------------------------------------------------------------------- /train/src/open_clip/pos_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # Position embedding utils 8 | # -------------------------------------------------------- 9 | 10 | import numpy as np 11 | 12 | import torch 13 | 14 | # -------------------------------------------------------- 15 | # 2D sine-cosine position embedding 16 | # References: 17 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 18 | # MoCo v3: https://github.com/facebookresearch/moco-v3 19 | # -------------------------------------------------------- 20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 21 | """ 22 | grid_size: int of the grid height and width 23 | return: 24 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 25 | """ 26 | grid_h = np.arange(grid_size, dtype=np.float32) 27 | grid_w = np.arange(grid_size, dtype=np.float32) 28 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 29 | grid = np.stack(grid, axis=0) 30 | 31 | grid = grid.reshape([2, 1, grid_size, grid_size]) 32 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 33 | if cls_token: 34 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 35 | return pos_embed 36 | 37 | 38 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 39 | assert embed_dim % 2 == 0 40 | 41 | # use half of dimensions to encode grid_h 42 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 43 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 44 | 45 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 46 | return emb 47 | 48 | 49 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 50 | """ 51 | embed_dim: output dimension for each position 52 | pos: a list of positions to be encoded: size (M,) 53 | out: (M, D) 54 | """ 55 | assert embed_dim % 2 == 0 56 | omega = np.arange(embed_dim // 2, dtype=float) 57 | omega /= embed_dim / 2. 58 | omega = 1. 
/ 10000**omega # (D/2,) 59 | 60 | pos = pos.reshape(-1) # (M,) 61 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 62 | 63 | emb_sin = np.sin(out) # (M, D/2) 64 | emb_cos = np.cos(out) # (M, D/2) 65 | 66 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 67 | return emb 68 | 69 | 70 | # -------------------------------------------------------- 71 | # Interpolate position embeddings for high-resolution 72 | # References: 73 | # DeiT: https://github.com/facebookresearch/deit 74 | # -------------------------------------------------------- 75 | def interpolate_pos_embed(model, checkpoint_model): 76 | if 'pos_embed' in checkpoint_model: 77 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 78 | embedding_size = pos_embed_checkpoint.shape[-1] 79 | num_patches = model.patch_embed.num_patches 80 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 81 | # height (== width) for the checkpoint position embedding 82 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 83 | # height (== width) for the new position embedding 84 | new_size = int(num_patches ** 0.5) 85 | # class_token and dist_token are kept unchanged 86 | if orig_size != new_size: 87 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 88 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 89 | # only the position tokens are interpolated 90 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 91 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 92 | pos_tokens = torch.nn.functional.interpolate( 93 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 94 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 95 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 96 | checkpoint_model['pos_embed'] = new_pos_embed -------------------------------------------------------------------------------- /train/src/open_clip/save_ps3_hf_ckpt.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | from pathlib import Path 5 | from tempfile import TemporaryDirectory 6 | from typing import Optional, Tuple, Union 7 | 8 | import torch 9 | 10 | try: 11 | from huggingface_hub import ( 12 | create_repo, 13 | get_hf_file_metadata, 14 | hf_hub_download, 15 | hf_hub_url, 16 | repo_type_and_id_from_hf_id, 17 | upload_folder, 18 | list_repo_files, 19 | ) 20 | from huggingface_hub.utils import EntryNotFoundError 21 | _has_hf_hub = True 22 | except ImportError: 23 | _has_hf_hub = False 24 | 25 | try: 26 | import safetensors.torch 27 | _has_safetensors = True 28 | except ImportError: 29 | _has_safetensors = False 30 | 31 | from .factory import create_model_from_pretrained, get_model_config, get_tokenizer 32 | from .tokenizer import HFTokenizer, DEFAULT_CONTEXT_LENGTH 33 | 34 | # Default name for a weights file hosted on the Huggingface Hub. 
35 | HF_WEIGHTS_NAME = "model.bin" # default pytorch pkl 36 | HF_SAFE_WEIGHTS_NAME = "model.safetensors" # safetensors version 37 | HF_CONFIG_NAME = 'config.json' 38 | HF_PROCESSOR_CONFIG_NAME = 'preprocessor_config.json' 39 | HF_TOKENIZER_CONFIG_NAME = 'tokenizer_config.json' 40 | 41 | 42 | def save_tokenizer_config( 43 | model_config, 44 | save_dir: str, 45 | ): 46 | save_dir = Path(save_dir) 47 | save_dir.mkdir(exist_ok=True, parents=True) 48 | 49 | hf_config = { 50 | "tokenizer_name": model_config["text_cfg"]["hf_tokenizer_name"], 51 | "context_length": model_config["text_cfg"].get('context_length', DEFAULT_CONTEXT_LENGTH), 52 | **model_config["text_cfg"].get("tokenizer_kwargs", {}), 53 | } 54 | 55 | with (save_dir / HF_TOKENIZER_CONFIG_NAME).open('w') as f: 56 | json.dump(hf_config, f, indent=2) 57 | 58 | 59 | def save_preprocessor_config( 60 | model, 61 | save_dir: str, 62 | ): 63 | save_dir = Path(save_dir) 64 | save_dir.mkdir(exist_ok=True, parents=True) 65 | 66 | cfg = model.visual.preprocess_cfg 67 | 68 | hf_config = { 69 | "image_size": cfg["size"], 70 | "mean": cfg["mean"], 71 | "std": cfg["std"], 72 | "interpolation": cfg["interpolation"], 73 | "resize_mode": cfg["resize_mode"], 74 | # "fill_color": cfg["fill_color"], 75 | } 76 | 77 | with (save_dir / HF_PROCESSOR_CONFIG_NAME).open('w') as f: 78 | json.dump(hf_config, f, indent=2) 79 | 80 | 81 | def save_config( 82 | model, 83 | save_dir: str, 84 | model_config: Optional[dict], 85 | save_vision_model_only: bool = False, 86 | ): 87 | save_dir = Path(save_dir) 88 | save_dir.mkdir(exist_ok=True, parents=True) 89 | 90 | # Vision model config 91 | vision_hf_config = { 92 | "architectures": ["PS3VisionModel"], 93 | "model_type": "ps3_vision_model", 94 | **model_config["vision_cfg"], 95 | } 96 | 97 | # Text model config 98 | text_hf_config = model_config["text_cfg"] 99 | text_hf_config["architectures"] = ["PS3TextModel"] 100 | text_hf_config["model_type"] = "ps3_text_model" 101 | text_hf_config["output_dim"] = model_config["embed_dim"] 102 | text_hf_config["prompt_proj_dim"] = model.visual.width 103 | 104 | # Merge vision and text configs 105 | if save_vision_model_only: 106 | hf_config = vision_hf_config 107 | else: 108 | hf_config = { 109 | "architectures": ["PS3Model"], 110 | "model_type": "ps3", 111 | "vision_config": vision_hf_config, 112 | "text_config": text_hf_config, 113 | } 114 | 115 | with (save_dir / HF_CONFIG_NAME).open('w') as f: 116 | json.dump(hf_config, f, indent=2) 117 | 118 | 119 | def save_model( 120 | model, 121 | save_dir: str, 122 | safe_serialization: Union[bool, str] = True, 123 | save_vision_model_only: bool = False, 124 | ): 125 | save_dir = Path(save_dir) 126 | save_dir.mkdir(exist_ok=True, parents=True) 127 | 128 | tensors = model.state_dict() 129 | 130 | # process vision model weights 131 | tensors = {k.replace("visual.", "vision_model."): v for k, v in tensors.items()} 132 | 133 | # process text model weights 134 | tensors = {k.replace("text.", "text_model."): v for k, v in tensors.items()} 135 | tensors = {"text_model." + k if k.startswith("prompt_proj.") else k: v for k, v in tensors.items()} 136 | 137 | if save_vision_model_only: 138 | tensors = {k: v for k, v in tensors.items() if "vision_model." 
in k} 139 | 140 | if safe_serialization is True or safe_serialization == "both": 141 | assert _has_safetensors, "`pip install safetensors` to use .safetensors" 142 | safetensors.torch.save_file(tensors, save_dir / HF_SAFE_WEIGHTS_NAME, metadata={'format': 'pt'}) 143 | if safe_serialization is False or safe_serialization == "both": 144 | torch.save(tensors, save_dir / HF_WEIGHTS_NAME) 145 | 146 | 147 | def save_hf_ckpt( 148 | model_name, 149 | pretrained: str, 150 | save_dir: str, 151 | save_vision_model_only: bool = False, 152 | **kwargs, 153 | ): 154 | model, processor = create_model_from_pretrained( 155 | model_name, 156 | pretrained=pretrained, 157 | load_weights_only=False, 158 | ) 159 | 160 | model_config = get_model_config(model_name) 161 | 162 | save_model( 163 | model, 164 | save_dir=save_dir, 165 | save_vision_model_only=save_vision_model_only, 166 | ) 167 | 168 | save_config( 169 | model, 170 | save_dir=save_dir, 171 | model_config=model_config, 172 | save_vision_model_only=save_vision_model_only, 173 | ) 174 | 175 | save_preprocessor_config( 176 | model, 177 | save_dir=save_dir, 178 | ) 179 | 180 | save_tokenizer_config( 181 | model_config, 182 | save_dir=save_dir, 183 | ) 184 | 185 | 186 | def push_hf_ckpt(repo_id, save_dir, token=None, private=False): 187 | # Create repo if it doesn't exist yet 188 | repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True) 189 | 190 | # Infer complete repo_id from repo_url 191 | # Can be different from the input `repo_id` if repo_owner was implicit 192 | _, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url) 193 | repo_id = f"{repo_owner}/{repo_name}" 194 | 195 | upload_folder( 196 | repo_id=repo_id, 197 | folder_path=save_dir, 198 | ) 199 | 200 | 201 | if __name__ == "__main__": 202 | parser = argparse.ArgumentParser(description="Push to Hugging Face Hub") 203 | parser.add_argument( 204 | "--model", type=str, help="Name of the model to use.", 205 | ) 206 | parser.add_argument( 207 | "--pretrained", type=str, 208 | help="Use pretrained CLIP model weights with the specified tag or file path.", 209 | ) 210 | parser.add_argument( 211 | "--save-dir", type=str, 212 | help="Which directory to save the model to.", 213 | ) 214 | parser.add_argument( 215 | "--save-vision-model-only", 216 | help="Whether to save the vision model weights only.", 217 | action="store_true", 218 | ) 219 | parser.add_argument( 220 | "--push-to-repo", type=str, 221 | help="Destination HF Hub repo-id, i.e. 'organization/model_id'.", 222 | default=None 223 | ) 224 | args = parser.parse_args() 225 | 226 | save_hf_ckpt( 227 | args.model, 228 | args.pretrained, 229 | args.save_dir, 230 | save_vision_model_only=args.save_vision_model_only, 231 | ) 232 | print(f'{args.pretrained} saved to {args.save_dir}.') 233 | 234 | if args.push_to_repo is not None: 235 | push_hf_ckpt( 236 | repo_id=args.push_to_repo, 237 | save_dir=args.save_dir, 238 | ) 239 | print(f'{args.pretrained} pushed to {args.push_to_repo}.') 240 | -------------------------------------------------------------------------------- /train/src/open_clip/timm_model.py: -------------------------------------------------------------------------------- 1 | 2 | """ timm model adapter 3 | 4 | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
5 | """ 6 | import logging 7 | from collections import OrderedDict 8 | from typing import Dict, List, Optional, Tuple, Union 9 | 10 | import torch 11 | import torch.nn as nn 12 | 13 | try: 14 | import timm 15 | from timm.layers import RotAttentionPool2d 16 | from timm.layers import AttentionPool2d as AbsAttentionPool2d 17 | from timm.layers import Mlp, to_2tuple 18 | except ImportError: 19 | timm = None 20 | 21 | from .utils import freeze_batch_norm_2d 22 | 23 | 24 | class TimmModel(nn.Module): 25 | """ timm model adapter 26 | """ 27 | 28 | def __init__( 29 | self, 30 | model_name: str, 31 | embed_dim: int, 32 | image_size: Union[int, Tuple[int, int]] = 224, 33 | pool: str = 'avg', 34 | proj: str = 'linear', 35 | proj_bias: bool = False, 36 | drop: float = 0., 37 | drop_path: Optional[float] = None, 38 | patch_drop: Optional[float] = None, 39 | pretrained: bool = False, 40 | ): 41 | super().__init__() 42 | if timm is None: 43 | raise RuntimeError("Please install the latest timm (`pip install timm`) to use timm based models.") 44 | self.image_size = to_2tuple(image_size) 45 | 46 | # setup kwargs that may not be common across all models 47 | timm_kwargs = {} 48 | if drop_path is not None: 49 | timm_kwargs['drop_path_rate'] = drop_path 50 | if patch_drop is not None: 51 | timm_kwargs['patch_drop_rate'] = patch_drop 52 | 53 | custom_pool = pool in ('abs_attn', 'rot_attn') 54 | if proj: 55 | assert proj in ("linear", "mlp", "none") 56 | extra_proj = proj in ("linear", "mlp") 57 | if not extra_proj and not custom_pool: 58 | # use network classifier head as projection if no proj specified and no custom pooling used 59 | # if projection is explicitly set to "none" will be pass through from network trunk 60 | proj_dim = 0 if proj == 'none' else embed_dim 61 | self.trunk = timm.create_model( 62 | model_name, 63 | num_classes=proj_dim, 64 | global_pool=pool, 65 | pretrained=pretrained, 66 | **timm_kwargs, 67 | ) 68 | prev_chs = embed_dim 69 | else: 70 | self.trunk = timm.create_model( 71 | model_name, 72 | pretrained=pretrained, 73 | **timm_kwargs, 74 | ) 75 | feat_size = self.trunk.default_cfg.get('pool_size', None) 76 | feature_ndim = 1 if not feat_size else 2 77 | if custom_pool: 78 | assert feature_ndim == 2 79 | # if attn pooling used, remove both classifier and default pool 80 | self.trunk.reset_classifier(0, global_pool='') 81 | else: 82 | # reset global pool if pool config set, otherwise leave as network default 83 | reset_kwargs = dict(global_pool=pool) if pool else {} 84 | self.trunk.reset_classifier(0, **reset_kwargs) 85 | prev_chs = self.trunk.num_features 86 | 87 | head_layers = OrderedDict() 88 | 89 | # Add custom pooling to head 90 | if pool == 'abs_attn': 91 | head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) 92 | prev_chs = embed_dim 93 | elif pool == 'rot_attn': 94 | head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim) 95 | prev_chs = embed_dim 96 | 97 | # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used 98 | if proj == 'linear': 99 | head_layers['drop'] = nn.Dropout(drop) 100 | head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias) 101 | elif proj == 'mlp': 102 | head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias)) 103 | 104 | self.head = nn.Sequential(head_layers) 105 | 106 | def lock(self, unlocked_groups: int = 0, freeze_bn_stats: bool = False): 107 | """ lock modules 108 | Args: 109 | 
unlocked_groups (int): leave last n layer groups unlocked (default: 0) 110 | """ 111 | if not unlocked_groups: 112 | # lock full model 113 | for param in self.trunk.parameters(): 114 | param.requires_grad = False 115 | if freeze_bn_stats: 116 | freeze_batch_norm_2d(self.trunk) 117 | else: 118 | # NOTE: partial freeze requires latest timm (master) branch and is subject to change 119 | try: 120 | # FIXME import here until API stable and in an official release 121 | from timm.models.helpers import group_parameters, group_modules 122 | except ImportError: 123 | raise RuntimeError( 124 | 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`') 125 | matcher = self.trunk.group_matcher() 126 | gparams = group_parameters(self.trunk, matcher) 127 | max_layer_id = max(gparams.keys()) 128 | max_layer_id = max_layer_id - unlocked_groups 129 | for group_idx in range(max_layer_id + 1): 130 | group = gparams[group_idx] 131 | for param in group: 132 | self.trunk.get_parameter(param).requires_grad = False 133 | if freeze_bn_stats: 134 | gmodules = group_modules(self.trunk, matcher, reverse=True) 135 | gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} 136 | freeze_batch_norm_2d(self.trunk, gmodules) 137 | 138 | @torch.jit.ignore 139 | def set_grad_checkpointing(self, enable: bool = True): 140 | try: 141 | self.trunk.set_grad_checkpointing(enable) 142 | except Exception as e: 143 | logging.warning('grad checkpointing not supported for this timm image tower, continuing without...') 144 | 145 | def forward_intermediates( 146 | self, 147 | x: torch.Tensor, 148 | indices: Optional[Union[int, List[int]]] = None, 149 | stop_early: bool = False, 150 | normalize_intermediates: bool = False, 151 | intermediates_only: bool = False, 152 | output_fmt: str = 'NCHW', 153 | output_extra_tokens: bool = False, 154 | ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]: 155 | """ Forward features that returns intermediates. 
156 | 157 | Args: 158 | x: Input image tensor 159 | indices: Take last n blocks if int, all if None, select matching indices if sequence 160 | stop_early: Stop iterating over blocks when last desired intermediate hit 161 | normalize_intermediates: Apply norm layer to all intermediates 162 | intermediates_only: Only return intermediate features 163 | output_fmt: Shape of intermediate feature outputs 164 | output_extra_tokens: Return both prefix and spatial intermediate tokens 165 | Returns: 166 | """ 167 | extra_args = {} 168 | if output_extra_tokens: 169 | extra_args['return_prefix_tokens'] = True 170 | trunk_output = self.trunk.forward_intermediates( 171 | x, 172 | indices=indices, 173 | intermediates_only=intermediates_only, 174 | norm=normalize_intermediates, 175 | stop_early=stop_early, 176 | output_fmt=output_fmt, 177 | **extra_args, 178 | ) 179 | 180 | return_dict = {} 181 | intermediates = trunk_output if intermediates_only else trunk_output[1] 182 | if output_extra_tokens and intermediates and isinstance(intermediates[0], tuple): 183 | intermediates_prefix = [xi[1] for xi in intermediates] 184 | intermediates = [xi[0] for xi in intermediates] 185 | return_dict['image_intermediates_prefix'] = intermediates_prefix 186 | 187 | return_dict['image_intermediates'] = intermediates 188 | if intermediates_only: 189 | return return_dict 190 | 191 | image_features = self.trunk.forward_head(trunk_output[0]) # run through timm pooling / projection 192 | image_features = self.head(image_features) # run through adapter pooling / projection 193 | return_dict['image_features'] = image_features 194 | return return_dict 195 | 196 | def forward(self, x): 197 | x = self.trunk(x) 198 | x = self.head(x) 199 | return x 200 | -------------------------------------------------------------------------------- /train/src/open_clip/utils.py: -------------------------------------------------------------------------------- 1 | import collections.abc 2 | from itertools import repeat 3 | from typing import List, Optional, Tuple, Union 4 | 5 | import torch 6 | from torch import nn as nn 7 | from torch import _assert 8 | from torchvision.ops.misc import FrozenBatchNorm2d 9 | 10 | 11 | def freeze_batch_norm_2d(module, module_match={}, name=''): 12 | """ 13 | Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is 14 | itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and 15 | returned. Otherwise, the module is walked recursively and submodules are converted in place. 16 | 17 | Args: 18 | module (torch.nn.Module): Any PyTorch module. 
19 | module_match (dict): Dictionary of full module names to freeze (all if empty) 20 | name (str): Full module name (prefix) 21 | 22 | Returns: 23 | torch.nn.Module: Resulting module 24 | 25 | Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 26 | """ 27 | res = module 28 | is_match = True 29 | if module_match: 30 | is_match = name in module_match 31 | if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)): 32 | res = FrozenBatchNorm2d(module.num_features) 33 | res.num_features = module.num_features 34 | res.affine = module.affine 35 | if module.affine: 36 | res.weight.data = module.weight.data.clone().detach() 37 | res.bias.data = module.bias.data.clone().detach() 38 | res.running_mean.data = module.running_mean.data 39 | res.running_var.data = module.running_var.data 40 | res.eps = module.eps 41 | else: 42 | for child_name, child in module.named_children(): 43 | full_child_name = '.'.join([name, child_name]) if name else child_name 44 | new_child = freeze_batch_norm_2d(child, module_match, full_child_name) 45 | if new_child is not child: 46 | res.add_module(child_name, new_child) 47 | return res 48 | 49 | 50 | # From PyTorch internals 51 | def _ntuple(n): 52 | def parse(x): 53 | if isinstance(x, collections.abc.Iterable): 54 | return x 55 | return tuple(repeat(x, n)) 56 | return parse 57 | 58 | 59 | to_1tuple = _ntuple(1) 60 | to_2tuple = _ntuple(2) 61 | to_3tuple = _ntuple(3) 62 | to_4tuple = _ntuple(4) 63 | to_ntuple = lambda n, x: _ntuple(n)(x) 64 | 65 | # Replaces all linear layers with linear_replacement 66 | # TODO: add int8 support for other linear layers including attn and convnets 67 | def replace_linear(model, linear_replacement, include_modules=['c_fc', 'c_proj'], copy_weights=True): 68 | for name, module in model.named_children(): 69 | if len(list(module.children())) > 0: 70 | replace_linear(module, linear_replacement, include_modules, copy_weights) 71 | 72 | if isinstance(module, torch.nn.Linear) and name in include_modules: 73 | old_module = model._modules[name] 74 | model._modules[name] = linear_replacement( 75 | module.in_features, 76 | module.out_features, 77 | module.bias is not None, 78 | ) 79 | if copy_weights: 80 | model._modules[name].weight.data.copy_(old_module.weight.data) 81 | if model._modules[name].bias is not None: 82 | model._modules[name].bias.data.copy_(old_module.bias) 83 | 84 | return model 85 | 86 | def convert_int8_model_to_inference_mode(model): 87 | for m in model.modules(): 88 | if hasattr(m, 'prepare_for_eval'): 89 | int8_original_dtype = m.weight.dtype 90 | m.prepare_for_eval() 91 | m.int8_original_dtype = int8_original_dtype 92 | 93 | 94 | def feature_take_indices( 95 | num_features: int, 96 | indices: Optional[Union[int, List[int]]] = None, 97 | as_set: bool = False, 98 | ) -> Tuple[List[int], int]: 99 | """ Determine the absolute feature indices to 'take' from. 100 | 101 | Note: This function can be called in forward() so must be torchscript compatible, 102 | which requires some incomplete typing and workaround hacks. 
103 | 104 | Args: 105 | num_features: total number of features to select from 106 | indices: indices to select, 107 | None -> select all 108 | int -> select last n 109 | list/tuple of int -> return specified (-ve indices specify from end) 110 | as_set: return as a set 111 | 112 | Returns: 113 | List (or set) of absolute (from beginning) indices, Maximum index 114 | """ 115 | if indices is None: 116 | indices = num_features # all features if None 117 | 118 | if isinstance(indices, int): 119 | # convert int -> last n indices 120 | _assert(0 < indices <= num_features, f'last-n ({indices}) is out of range (1 to {num_features})') 121 | take_indices = [num_features - indices + i for i in range(indices)] 122 | else: 123 | take_indices: List[int] = [] 124 | for i in indices: 125 | idx = num_features + i if i < 0 else i 126 | _assert(0 <= idx < num_features, f'feature index {idx} is out of range (0 to {num_features - 1})') 127 | take_indices.append(idx) 128 | 129 | if not torch.jit.is_scripting() and as_set: 130 | return set(take_indices), max(take_indices) 131 | 132 | return take_indices, max(take_indices) 133 | 134 | 135 | def _out_indices_as_tuple(x: Union[int, Tuple[int, ...]]) -> Tuple[int, ...]: 136 | if isinstance(x, int): 137 | # if indices is an int, take last N features 138 | return tuple(range(-x, 0)) 139 | return tuple(x) -------------------------------------------------------------------------------- /train/src/open_clip/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '2.31.0' 2 | -------------------------------------------------------------------------------- /train/src/open_clip/zero_shot_classifier.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from itertools import islice 3 | from typing import Callable, List, Optional, Sequence, Union 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | 9 | def batched(iterable, n): 10 | """Batch data into lists of length *n*. The last batch may be shorter. 11 | NOTE based on more-itertools impl, to be replaced by python 3.12 itertools.batched impl 12 | """ 13 | it = iter(iterable) 14 | while True: 15 | batch = list(islice(it, n)) 16 | if not batch: 17 | break 18 | yield batch 19 | 20 | 21 | def build_zero_shot_classifier( 22 | model, 23 | tokenizer, 24 | classnames: Sequence[str], 25 | templates: Sequence[Union[Callable, str]], 26 | num_classes_per_batch: Optional[int] = 10, 27 | device: Union[str, torch.device] = 'cpu', 28 | use_tqdm: bool = False, 29 | ): 30 | """ Build zero-shot classifier weights by iterating over class names in batches 31 | Args: 32 | model: CLIP model instance 33 | tokenizer: CLIP tokenizer instance 34 | classnames: A sequence of class (label) names 35 | templates: A sequence of callables or format() friendly strings to produce templates per class name 36 | num_classes_per_batch: The number of classes to batch together in each forward, all if None 37 | device: Device to use. 38 | use_tqdm: Enable TQDM progress bar. 
39 | """ 40 | assert isinstance(templates, Sequence) and len(templates) > 0 41 | assert isinstance(classnames, Sequence) and len(classnames) > 0 42 | use_format = isinstance(templates[0], str) 43 | num_templates = len(templates) 44 | num_classes = len(classnames) 45 | if use_tqdm: 46 | import tqdm 47 | num_iter = 1 if num_classes_per_batch is None else ((num_classes - 1) // num_classes_per_batch + 1) 48 | iter_wrap = partial(tqdm.tqdm, total=num_iter, unit_scale=num_classes_per_batch) 49 | else: 50 | iter_wrap = iter 51 | 52 | def _process_batch(batch_classnames): 53 | num_batch_classes = len(batch_classnames) 54 | texts = [template.format(c) if use_format else template(c) for c in batch_classnames for template in templates] 55 | texts = tokenizer(texts).to(device) 56 | class_embeddings = model.encode_text(texts, normalize=True) 57 | class_embeddings = class_embeddings.reshape(num_batch_classes, num_templates, -1).mean(dim=1) 58 | class_embeddings = class_embeddings / class_embeddings.norm(dim=1, keepdim=True) 59 | class_embeddings = class_embeddings.T 60 | return class_embeddings 61 | 62 | with torch.no_grad(): 63 | if num_classes_per_batch: 64 | batched_embeds = [_process_batch(batch) for batch in iter_wrap(batched(classnames, num_classes_per_batch))] 65 | zeroshot_weights = torch.cat(batched_embeds, dim=1) 66 | else: 67 | zeroshot_weights = _process_batch(classnames) 68 | return zeroshot_weights 69 | 70 | 71 | def build_zero_shot_classifier_legacy( 72 | model, 73 | tokenizer, 74 | classnames: Sequence[str], 75 | templates: Sequence[Union[Callable, str]], 76 | device: Union[str, torch.device] = 'cpu', 77 | use_tqdm: bool = False, 78 | ): 79 | """ Build zero-shot classifier weights by iterating over class names 1 by 1 80 | Args: 81 | model: CLIP model instance 82 | tokenizer: CLIP tokenizer instance 83 | classnames: A sequence of class (label) names 84 | templates: A sequence of callables or format() friendly strings to produce templates per class name 85 | device: Device to use. 86 | use_tqdm: Enable TQDM progress bar. 
87 | """ 88 | assert isinstance(templates, Sequence) and len(templates) > 0 89 | assert isinstance(classnames, Sequence) and len(classnames) > 0 90 | if use_tqdm: 91 | import tqdm 92 | iter_wrap = tqdm.tqdm 93 | else: 94 | iter_wrap = iter 95 | 96 | use_format = isinstance(templates[0], str) 97 | 98 | with torch.no_grad(): 99 | zeroshot_weights = [] 100 | for classname in iter_wrap(classnames): 101 | texts = [template.format(classname) if use_format else template(classname) for template in templates] 102 | texts = tokenizer(texts).to(device) # tokenize 103 | class_embeddings = model.encode_text(texts) 104 | class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0) 105 | class_embedding /= class_embedding.norm() 106 | zeroshot_weights.append(class_embedding) 107 | zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device) 108 | 109 | return zeroshot_weights -------------------------------------------------------------------------------- /train/src/open_clip_train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/PS3/529d66aa30b26438c30ce839197b2c0b2cf765e5/train/src/open_clip_train/__init__.py -------------------------------------------------------------------------------- /train/src/open_clip_train/distributed.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | from typing import Optional 4 | 5 | import torch 6 | import torch.distributed as dist 7 | 8 | try: 9 | import horovod.torch as hvd 10 | except ImportError: 11 | hvd = None 12 | 13 | 14 | def is_global_master(args): 15 | return args.rank == 0 16 | 17 | 18 | def is_local_master(args): 19 | return args.local_rank == 0 20 | 21 | 22 | def is_master(args, local=False): 23 | return is_local_master(args) if local else is_global_master(args) 24 | 25 | 26 | def is_device_available(device): 27 | device_type = torch.device(device).type 28 | is_avail = False 29 | is_known = False 30 | if device_type == 'cuda': 31 | is_avail = torch.cuda.is_available() 32 | is_known = True 33 | elif device_type == 'npu': 34 | # NOTE autoload device extension needed for this not to error out on this check 35 | is_avail = torch.npu.is_available() 36 | is_known = True 37 | elif device_type == 'mps': 38 | is_avail = torch.backends.mps.is_available() 39 | is_known = True 40 | elif device_type == 'cpu': 41 | is_avail = True 42 | is_known = True 43 | 44 | return is_avail, is_known 45 | 46 | 47 | def set_device(device): 48 | if device.startswith('cuda:'): 49 | torch.cuda.set_device(device) 50 | elif device.startswith('npu:'): 51 | torch.npu.set_device(device) 52 | 53 | 54 | def is_using_horovod(): 55 | # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set 56 | # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required... 
57 | ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"] 58 | pmi_vars = ["PMI_RANK", "PMI_SIZE"] 59 | if all([var in os.environ for var in ompi_vars]) or all([var in os.environ for var in pmi_vars]): 60 | return True 61 | else: 62 | return False 63 | 64 | 65 | def is_using_distributed(): 66 | if 'WORLD_SIZE' in os.environ: 67 | return int(os.environ['WORLD_SIZE']) > 1 68 | if 'SLURM_NTASKS' in os.environ: 69 | return int(os.environ['SLURM_NTASKS']) > 1 70 | return False 71 | 72 | 73 | def world_info_from_env(): 74 | local_rank = 0 75 | for v in ('LOCAL_RANK', 'MPI_LOCALRANKID', 'SLURM_LOCALID', 'OMPI_COMM_WORLD_LOCAL_RANK'): 76 | if v in os.environ: 77 | local_rank = int(os.environ[v]) 78 | break 79 | global_rank = 0 80 | for v in ('RANK', 'PMI_RANK', 'SLURM_PROCID', 'OMPI_COMM_WORLD_RANK'): 81 | if v in os.environ: 82 | global_rank = int(os.environ[v]) 83 | break 84 | world_size = 1 85 | for v in ('WORLD_SIZE', 'PMI_SIZE', 'SLURM_NTASKS', 'OMPI_COMM_WORLD_SIZE'): 86 | if v in os.environ: 87 | world_size = int(os.environ[v]) 88 | break 89 | 90 | return local_rank, global_rank, world_size 91 | 92 | 93 | def init_distributed_device(args): 94 | # Distributed training = training on more than one GPU. 95 | # Works in both single and multi-node scenarios. 96 | args.distributed = False 97 | args.world_size = 1 98 | args.rank = 0 # global rank 99 | args.local_rank = 0 100 | result = init_distributed_device_so( 101 | device=getattr(args, 'device', 'cuda'), 102 | dist_backend=getattr(args, 'dist_backend', None), 103 | dist_url=getattr(args, 'dist_url', None), 104 | horovod=getattr(args, 'horovod', False), 105 | no_set_device_rank=getattr(args, 'no_set_device_rank', False), 106 | ) 107 | args.device = result['device'] 108 | args.world_size = result['world_size'] 109 | args.rank = result['global_rank'] 110 | args.local_rank = result['local_rank'] 111 | args.distributed = result['distributed'] 112 | device = torch.device(args.device) 113 | return device 114 | 115 | 116 | def init_distributed_device_so( 117 | device: str = 'cuda', 118 | dist_backend: Optional[str] = None, 119 | dist_url: Optional[str] = None, 120 | horovod: bool = False, 121 | no_set_device_rank: bool = False, 122 | ): 123 | # Distributed training = training on more than one GPU. 124 | # Works in both single and multi-node scenarios. 
125 | distributed = False 126 | world_size = 1 127 | global_rank = 0 128 | local_rank = 0 129 | device_type, *device_idx = device.split(':', maxsplit=1) 130 | is_avail, is_known = is_device_available(device_type) 131 | if not is_known: 132 | warnings.warn(f"Device {device} was not known and checked for availability, trying anyways.") 133 | elif not is_avail: 134 | warnings.warn(f"Device {device} was not available, falling back to CPU.") 135 | device_type = device = 'cpu' 136 | 137 | if horovod: 138 | import horovod.torch as hvd 139 | assert hvd is not None, "Horovod is not installed" 140 | hvd.init() 141 | local_rank = int(hvd.local_rank()) 142 | global_rank = hvd.rank() 143 | world_size = hvd.size() 144 | distributed = True 145 | elif is_using_distributed(): 146 | if dist_backend is None: 147 | dist_backends = { 148 | "cuda": "nccl", 149 | "hpu": "hccl", 150 | "npu": "hccl", 151 | "xpu": "ccl", 152 | } 153 | dist_backend = dist_backends.get(device_type, 'gloo') 154 | 155 | dist_url = dist_url or 'env://' 156 | 157 | if 'SLURM_PROCID' in os.environ: 158 | # DDP via SLURM 159 | local_rank, global_rank, world_size = world_info_from_env() 160 | # SLURM var -> torch.distributed vars in case needed 161 | os.environ['LOCAL_RANK'] = str(local_rank) 162 | os.environ['RANK'] = str(global_rank) 163 | os.environ['WORLD_SIZE'] = str(world_size) 164 | torch.distributed.init_process_group( 165 | backend=dist_backend, 166 | init_method=dist_url, 167 | world_size=world_size, 168 | rank=global_rank, 169 | ) 170 | else: 171 | # DDP via torchrun, torch.distributed.launch 172 | local_rank, _, _ = world_info_from_env() 173 | torch.distributed.init_process_group( 174 | backend=dist_backend, 175 | init_method=dist_url, 176 | ) 177 | world_size = torch.distributed.get_world_size() 178 | global_rank = torch.distributed.get_rank() 179 | distributed = True 180 | 181 | if distributed and not no_set_device_rank and device_type not in ('cpu', 'mps'): 182 | # Ignore manually specified device index in distributed mode and 183 | # override with resolved local rank, fewer headaches in most setups. 
184 | if device_idx: 185 | warnings.warn(f'device index {device_idx[0]} removed from specified ({device}).') 186 | device = f'{device_type}:{local_rank}' 187 | set_device(device) 188 | 189 | return dict( 190 | device=device, 191 | global_rank=global_rank, 192 | local_rank=local_rank, 193 | world_size=world_size, 194 | distributed=distributed, 195 | ) 196 | 197 | 198 | def broadcast_object(args, obj, src=0): 199 | # broadcast a pickle-able python object from rank-0 to all ranks 200 | if args.horovod: 201 | return hvd.broadcast_object(obj, root_rank=src) 202 | else: 203 | if args.rank == src: 204 | objects = [obj] 205 | else: 206 | objects = [None] 207 | dist.broadcast_object_list(objects, src=src) 208 | return objects[0] 209 | 210 | 211 | def all_gather_object(args, obj, dst=0): 212 | # gather a pickle-able python object across all ranks 213 | if args.horovod: 214 | return hvd.allgather_object(obj) 215 | else: 216 | objects = [None for _ in range(args.world_size)] 217 | dist.all_gather_object(objects, obj) 218 | return objects -------------------------------------------------------------------------------- /train/src/open_clip_train/file_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import multiprocessing 4 | import subprocess 5 | import time 6 | import fsspec 7 | import torch 8 | from tqdm import tqdm 9 | 10 | def remote_sync_s3(local_dir, remote_dir): 11 | # skip epoch_latest which can change during sync. 12 | result = subprocess.run(["aws", "s3", "sync", local_dir, remote_dir, '--exclude', '*epoch_latest.pt'], stdout=subprocess.PIPE, stderr=subprocess.PIPE) 13 | if result.returncode != 0: 14 | logging.error(f"Error: Failed to sync with S3 bucket {result.stderr.decode('utf-8')}") 15 | return False 16 | 17 | logging.info(f"Successfully synced with S3 bucket") 18 | return True 19 | 20 | def remote_sync_fsspec(local_dir, remote_dir): 21 | # FIXME currently this is slow and not recommended. Look into speeding up. 22 | a = fsspec.get_mapper(local_dir) 23 | b = fsspec.get_mapper(remote_dir) 24 | 25 | for k in a: 26 | # skip epoch_latest which can change during sync. 27 | if 'epoch_latest.pt' in k: 28 | continue 29 | 30 | logging.info(f'Attempting to sync {k}') 31 | if k in b and len(a[k]) == len(b[k]): 32 | logging.debug(f'Skipping remote sync for {k}.') 33 | continue 34 | 35 | try: 36 | logging.info(f'Successful sync for {k}.') 37 | b[k] = a[k] 38 | except Exception as e: 39 | logging.info(f'Error during remote sync for {k}: {e}') 40 | return False 41 | 42 | return True 43 | 44 | def remote_sync(local_dir, remote_dir, protocol): 45 | logging.info('Starting remote sync.') 46 | if protocol == 's3': 47 | return remote_sync_s3(local_dir, remote_dir) 48 | elif protocol == 'fsspec': 49 | return remote_sync_fsspec(local_dir, remote_dir) 50 | else: 51 | logging.error('Remote protocol not known') 52 | return False 53 | 54 | def keep_running_remote_sync(sync_every, local_dir, remote_dir, protocol): 55 | while True: 56 | time.sleep(sync_every) 57 | remote_sync(local_dir, remote_dir, protocol) 58 | 59 | def start_sync_process(sync_every, local_dir, remote_dir, protocol): 60 | p = multiprocessing.Process(target=keep_running_remote_sync, args=(sync_every, local_dir, remote_dir, protocol)) 61 | return p 62 | 63 | # Note: we are not currently using this save function. 
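# Illustrative usage of the remote-sync helpers above (not part of the original file; the interval, local dir, and bucket are hypothetical):
#   p = start_sync_process(sync_every=300, local_dir='/tmp/run1', remote_dir='s3://bucket/run1', protocol='s3')
#   p.start()  # the returned multiprocessing.Process is not started automatically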
64 | def pt_save(pt_obj, file_path): 65 | of = fsspec.open(file_path, "wb") 66 | with of as f: 67 | torch.save(pt_obj, file_path) 68 | 69 | def pt_load(file_path, map_location=None): 70 | if file_path.startswith('s3'): 71 | logging.info('Loading remote checkpoint, which may take a bit.') 72 | of = fsspec.open(file_path, "rb") 73 | with of as f: 74 | out = torch.load(f, map_location=map_location) 75 | return out 76 | 77 | def check_exists(file_path): 78 | try: 79 | with fsspec.open(file_path): 80 | pass 81 | except FileNotFoundError: 82 | return False 83 | return True -------------------------------------------------------------------------------- /train/src/open_clip_train/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | def setup_logging(log_file, level, include_host=False): 5 | if include_host: 6 | import socket 7 | hostname = socket.gethostname() 8 | formatter = logging.Formatter( 9 | f'%(asctime)s | {hostname} | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S') 10 | else: 11 | formatter = logging.Formatter('%(asctime)s | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S') 12 | 13 | logging.root.setLevel(level) 14 | loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict] 15 | for logger in loggers: 16 | logger.setLevel(level) 17 | 18 | stream_handler = logging.StreamHandler() 19 | stream_handler.setFormatter(formatter) 20 | logging.root.addHandler(stream_handler) 21 | 22 | if log_file: 23 | file_handler = logging.FileHandler(filename=log_file) 24 | file_handler.setFormatter(formatter) 25 | logging.root.addHandler(file_handler) -------------------------------------------------------------------------------- /train/src/open_clip_train/precision.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from contextlib import suppress 3 | from functools import partial 4 | 5 | 6 | def get_autocast(precision, device_type='cuda'): 7 | if precision =='amp': 8 | amp_dtype = torch.float16 9 | elif precision == 'amp_bfloat16' or precision == 'amp_bf16': 10 | amp_dtype = torch.bfloat16 11 | else: 12 | return suppress 13 | 14 | return partial(torch.amp.autocast, device_type=device_type, dtype=amp_dtype) -------------------------------------------------------------------------------- /train/src/open_clip_train/scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | def assign_learning_rate(optimizer, new_lr): 5 | for param_group in optimizer.param_groups: 6 | param_group["lr"] = new_lr 7 | 8 | 9 | def _warmup_lr(base_lr, warmup_length, step): 10 | return base_lr * (step + 1) / warmup_length 11 | 12 | 13 | def const_lr(optimizer, base_lr, warmup_length, steps): 14 | def _lr_adjuster(step): 15 | if step < warmup_length: 16 | lr = _warmup_lr(base_lr, warmup_length, step) 17 | else: 18 | lr = base_lr 19 | assign_learning_rate(optimizer, lr) 20 | return lr 21 | 22 | return _lr_adjuster 23 | 24 | 25 | def const_lr_cooldown(optimizer, base_lr, warmup_length, steps, cooldown_steps, cooldown_power=1.0, cooldown_end_lr=0.): 26 | def _lr_adjuster(step): 27 | start_cooldown_step = steps - cooldown_steps 28 | if step < warmup_length: 29 | lr = _warmup_lr(base_lr, warmup_length, step) 30 | else: 31 | if step < start_cooldown_step: 32 | lr = base_lr 33 | else: 34 | e = step - start_cooldown_step 35 | es = steps - start_cooldown_step 36 | # linear decay if power == 1; polynomial decay 
otherwise; 37 | decay = (1 - (e / es)) ** cooldown_power 38 | lr = decay * (base_lr - cooldown_end_lr) + cooldown_end_lr 39 | assign_learning_rate(optimizer, lr) 40 | return lr 41 | 42 | return _lr_adjuster 43 | 44 | 45 | def cosine_lr(optimizer, base_lr, warmup_length, steps): 46 | def _lr_adjuster(step): 47 | if step < warmup_length: 48 | lr = _warmup_lr(base_lr, warmup_length, step) 49 | else: 50 | e = step - warmup_length 51 | es = steps - warmup_length 52 | lr = 0.5 * (1 + math.cos(math.pi * e / es)) * base_lr 53 | assign_learning_rate(optimizer, lr) 54 | return lr 55 | 56 | return _lr_adjuster 57 | 58 | -------------------------------------------------------------------------------- /train/src/open_clip_train/zero_shot.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | from tqdm import tqdm 5 | 6 | from open_clip import get_input_dtype, get_tokenizer, build_zero_shot_classifier, \ 7 | IMAGENET_CLASSNAMES, OPENAI_IMAGENET_TEMPLATES 8 | from open_clip_train.precision import get_autocast 9 | 10 | 11 | def accuracy(output, target, topk=(1,)): 12 | pred = output.topk(max(topk), 1, True, True)[1].t() 13 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 14 | return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk] 15 | 16 | 17 | def run(model, classifier, dataloader, args): 18 | device = torch.device(args.device) 19 | autocast = get_autocast(args.precision, device_type=device.type) 20 | input_dtype = get_input_dtype(args.precision) 21 | 22 | with torch.inference_mode(): 23 | top1, top5, n = 0., 0., 0. 24 | for images, target in tqdm(dataloader, unit_scale=args.batch_size): 25 | images = images.to(device=device, dtype=input_dtype) 26 | target = target.to(device) 27 | 28 | with autocast(): 29 | # predict 30 | output = model(image=images) 31 | image_features = output['image_features'] if isinstance(output, dict) else output[0] 32 | logits = 100. 
* image_features @ classifier 33 | 34 | # measure accuracy 35 | acc1, acc5 = accuracy(logits, target, topk=(1, 5)) 36 | top1 += acc1 37 | top5 += acc5 38 | n += images.size(0) 39 | 40 | top1 = (top1 / n) 41 | top5 = (top5 / n) 42 | return top1, top5 43 | 44 | 45 | def zero_shot_eval(model, data, epoch, args, tokenizer=None): 46 | if 'imagenet-val' not in data and 'imagenet-v2' not in data: 47 | return {} 48 | if args.zeroshot_frequency == 0: 49 | return {} 50 | if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs: 51 | return {} 52 | if args.distributed and not args.horovod: 53 | model = model.module 54 | 55 | logging.info('Starting zero-shot imagenet.') 56 | if tokenizer is None: 57 | tokenizer = get_tokenizer(args.model) 58 | 59 | logging.info('Building zero-shot classifier') 60 | device = torch.device(args.device) 61 | autocast = get_autocast(args.precision, device_type=device.type) 62 | with autocast(): 63 | classifier = build_zero_shot_classifier( 64 | model, 65 | tokenizer=tokenizer, 66 | classnames=IMAGENET_CLASSNAMES, 67 | templates=OPENAI_IMAGENET_TEMPLATES, 68 | num_classes_per_batch=10, 69 | device=device, 70 | use_tqdm=True, 71 | ) 72 | 73 | logging.info('Using classifier') 74 | results = {} 75 | if 'imagenet-val' in data: 76 | top1, top5 = run(model, classifier, data['imagenet-val'].dataloader, args) 77 | results['imagenet-zeroshot-val-top1'] = top1 78 | results['imagenet-zeroshot-val-top5'] = top5 79 | if 'imagenet-v2' in data: 80 | top1, top5 = run(model, classifier, data['imagenet-v2'].dataloader, args) 81 | results['imagenetv2-zeroshot-val-top1'] = top1 82 | results['imagenetv2-zeroshot-val-top5'] = top5 83 | 84 | logging.info('Finished zero-shot imagenet.') 85 | 86 | return results 87 | --------------------------------------------------------------------------------
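The zero-shot path above composes pieces exported from open_clip: a tokenizer, a prompt-ensemble classifier built by build_zero_shot_classifier (one L2-normalized embedding column per class, averaged over templates), and a scaled dot product between normalized image features and those columns. A minimal standalone sketch of the same flow, assuming a generic OpenCLIP checkpoint, is given below; the model tag ('ViT-B-32' / 'laion2b_s34b_b79k'), the class names, the templates, and the image path are illustrative placeholders rather than values taken from the PS3 training setup.

import torch
from PIL import Image

from open_clip import build_zero_shot_classifier, create_model_and_transforms, get_tokenizer

# Hypothetical model tag and labels, for illustration only.
model, _, preprocess = create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
tokenizer = get_tokenizer('ViT-B-32')
model.eval()

# Classifier weights of shape (embed_dim, num_classes); built under torch.no_grad() internally.
classifier = build_zero_shot_classifier(
    model,
    tokenizer=tokenizer,
    classnames=['cat', 'dog', 'truck'],
    templates=['a photo of a {}.', 'a blurry photo of a {}.'],
    device='cpu',
)

image = preprocess(Image.open('example.jpg')).unsqueeze(0)  # hypothetical image path
with torch.no_grad():
    image_features = model.encode_image(image)
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    logits = 100. * image_features @ classifier  # same scaling as run() in zero_shot.py
print(logits.softmax(dim=-1))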