├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── encoders_flux.py ├── eva_clip ├── __init__.py ├── bpe_simple_vocab_16e6.txt.gz ├── constants.py ├── eva_vit_model.py ├── factory.py ├── hf_configs.py ├── hf_model.py ├── loss.py ├── model.py ├── model_configs │ ├── EVA01-CLIP-B-16.json │ ├── EVA01-CLIP-g-14-plus.json │ ├── EVA01-CLIP-g-14.json │ ├── EVA02-CLIP-B-16.json │ ├── EVA02-CLIP-L-14-336.json │ ├── EVA02-CLIP-L-14.json │ ├── EVA02-CLIP-bigE-14-plus.json │ └── EVA02-CLIP-bigE-14.json ├── modified_resnet.py ├── openai.py ├── pretrained.py ├── rope.py ├── timm_model.py ├── tokenizer.py ├── transform.py ├── transformer.py └── utils.py ├── examples ├── einstein.jpg ├── pulid_flux_16bit_simple.json ├── pulid_flux_8bitgguf_simple.json └── pulid_flux_einstein.png ├── pulidflux.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # ignore specific directories 2 | __pycache__/ 3 | \!misc/ 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/balazik/ComfyUI-PuLID-Flux/a80912fc3435c358607bf4b43a58dbcbebdb09ff/README.md -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .pulidflux import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 2 | 3 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS'] 4 | -------------------------------------------------------------------------------- /encoders_flux.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | # FFN 8 | def FeedForward(dim, mult=4): 9 | inner_dim = int(dim * mult) 10 | return nn.Sequential( 11 | nn.LayerNorm(dim), 12 | nn.Linear(dim, inner_dim, bias=False), 13 | nn.GELU(), 14 | nn.Linear(inner_dim, dim, bias=False), 15 | ) 16 | 17 | 18 | def reshape_tensor(x, heads): 19 | bs, length, width = x.shape 20 | # (bs, length, width) --> (bs, length, n_heads, dim_per_head) 21 | x = x.view(bs, length, heads, -1) 22 | # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) 23 | x = x.transpose(1, 2) 24 | # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) 25 | x = x.reshape(bs, heads, length, -1) 26 | return x 27 | 28 | 29 | class PerceiverAttentionCA(nn.Module): 30 | def __init__(self, *, dim=3072, dim_head=128, heads=16, kv_dim=2048): 31 | super().__init__() 32 | self.scale = dim_head ** -0.5 33 | self.dim_head = dim_head 34 | self.heads = heads 35 | inner_dim = dim_head * heads 36 | 37 | self.norm1 = nn.LayerNorm(dim if kv_dim is None else kv_dim) 38 | self.norm2 = nn.LayerNorm(dim) 39 | 40 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 41 | self.to_kv = nn.Linear(dim if kv_dim is None else kv_dim, inner_dim * 2, bias=False) 42 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 43 | 44 | def forward(self, x, latents): 45 | """ 46 | Args: 47 | x (torch.Tensor): image features 48 | shape (b, n1, D) 49 | latent (torch.Tensor): latent features 50 | shape (b, n2, D) 51 | """ 52 | x = self.norm1(x) 53 | latents = self.norm2(latents) 54 | 55 | b, seq_len, _ = latents.shape 56 | 57 | q = self.to_q(latents) 58 | k, v = self.to_kv(x).chunk(2, dim=-1) 59 | 60 | q = reshape_tensor(q, self.heads) 61 | k = reshape_tensor(k, self.heads) 62 | v = reshape_tensor(v, self.heads) 63 | 64 | # attention 65 | scale = 1 / math.sqrt(math.sqrt(self.dim_head)) 66 | weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards 67 | weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) 68 | out = weight @ v 69 | 70 | out = out.permute(0, 2, 1, 3).reshape(b, seq_len, -1) 71 | 72 | return self.to_out(out) 73 | 74 | 75 | class PerceiverAttention(nn.Module): 76 | def __init__(self, *, dim, dim_head=64, heads=8, kv_dim=None): 77 | super().__init__() 78 | self.scale = dim_head ** -0.5 79 | self.dim_head = dim_head 80 | self.heads = heads 81 | inner_dim = dim_head * heads 82 | 83 | self.norm1 = nn.LayerNorm(dim if kv_dim is None else kv_dim) 84 | self.norm2 = nn.LayerNorm(dim) 85 | 86 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 87 | self.to_kv = nn.Linear(dim if kv_dim is None else kv_dim, inner_dim * 2, bias=False) 88 | self.to_out = 
nn.Linear(inner_dim, dim, bias=False) 89 | 90 | def forward(self, x, latents): 91 | """ 92 | Args: 93 | x (torch.Tensor): image features 94 | shape (b, n1, D) 95 | latent (torch.Tensor): latent features 96 | shape (b, n2, D) 97 | """ 98 | x = self.norm1(x) 99 | latents = self.norm2(latents) 100 | 101 | b, seq_len, _ = latents.shape 102 | 103 | q = self.to_q(latents) 104 | kv_input = torch.cat((x, latents), dim=-2) 105 | k, v = self.to_kv(kv_input).chunk(2, dim=-1) 106 | 107 | q = reshape_tensor(q, self.heads) 108 | k = reshape_tensor(k, self.heads) 109 | v = reshape_tensor(v, self.heads) 110 | 111 | # attention 112 | scale = 1 / math.sqrt(math.sqrt(self.dim_head)) 113 | weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards 114 | weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) 115 | out = weight @ v 116 | 117 | out = out.permute(0, 2, 1, 3).reshape(b, seq_len, -1) 118 | 119 | return self.to_out(out) 120 | 121 | 122 | class IDFormer(nn.Module): 123 | """ 124 | - perceiver resampler like arch (compared with previous MLP-like arch) 125 | - we concat id embedding (generated by arcface) and query tokens as latents 126 | - latents will attend each other and interact with vit features through cross-attention 127 | - vit features are multi-scaled and inserted into IDFormer in order, currently, each scale corresponds to two 128 | IDFormer layers 129 | """ 130 | def __init__( 131 | self, 132 | dim=1024, 133 | depth=10, 134 | dim_head=64, 135 | heads=16, 136 | num_id_token=5, 137 | num_queries=32, 138 | output_dim=2048, 139 | ff_mult=4, 140 | ): 141 | super().__init__() 142 | 143 | self.num_id_token = num_id_token 144 | self.dim = dim 145 | self.num_queries = num_queries 146 | assert depth % 5 == 0 147 | self.depth = depth // 5 148 | scale = dim ** -0.5 149 | 150 | self.latents = nn.Parameter(torch.randn(1, num_queries, dim) * scale) 151 | self.proj_out = nn.Parameter(scale * torch.randn(dim, output_dim)) 152 | 153 | self.layers = nn.ModuleList([]) 154 | for _ in range(depth): 155 | self.layers.append( 156 | nn.ModuleList( 157 | [ 158 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), 159 | FeedForward(dim=dim, mult=ff_mult), 160 | ] 161 | ) 162 | ) 163 | 164 | for i in range(5): 165 | setattr( 166 | self, 167 | f'mapping_{i}', 168 | nn.Sequential( 169 | nn.Linear(1024, 1024), 170 | nn.LayerNorm(1024), 171 | nn.LeakyReLU(), 172 | nn.Linear(1024, 1024), 173 | nn.LayerNorm(1024), 174 | nn.LeakyReLU(), 175 | nn.Linear(1024, dim), 176 | ), 177 | ) 178 | 179 | self.id_embedding_mapping = nn.Sequential( 180 | nn.Linear(1280, 1024), 181 | nn.LayerNorm(1024), 182 | nn.LeakyReLU(), 183 | nn.Linear(1024, 1024), 184 | nn.LayerNorm(1024), 185 | nn.LeakyReLU(), 186 | nn.Linear(1024, dim * num_id_token), 187 | ) 188 | 189 | def forward(self, x, y): 190 | 191 | latents = self.latents.repeat(x.size(0), 1, 1) 192 | 193 | x = self.id_embedding_mapping(x) 194 | x = x.reshape(-1, self.num_id_token, self.dim) 195 | 196 | latents = torch.cat((latents, x), dim=1) 197 | 198 | for i in range(5): 199 | vit_feature = getattr(self, f'mapping_{i}')(y[i]) 200 | ctx_feature = torch.cat((x, vit_feature), dim=1) 201 | for attn, ff in self.layers[i * self.depth: (i + 1) * self.depth]: 202 | latents = attn(ctx_feature, latents) + latents 203 | latents = ff(latents) + latents 204 | 205 | latents = latents[:, :self.num_queries] 206 | latents = latents @ self.proj_out 207 | return latents 208 | 
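
The IDFormer docstring above describes concatenating the ArcFace id embedding with learned query tokens and cross-attending to five scales of ViT features. Below is a minimal shape sketch of that interface under the default constructor arguments; the 1280/1024 input widths are taken from the first Linear layers of id_embedding_mapping and mapping_{i}, while the batch size and the 577-token ViT sequence length are illustrative assumptions only, not values fixed by this module.

import torch
from encoders_flux import IDFormer  # assumes the repository root is on sys.path

# Shape sketch with the default constructor arguments. The 1280-dim id embedding and
# 1024-dim ViT features match the first Linear layers of id_embedding_mapping and
# mapping_{i}; the batch size and 577-token ViT sequence length are illustrative.
idformer = IDFormer(dim=1024, depth=10, num_id_token=5, num_queries=32, output_dim=2048)

batch = 2
id_embed = torch.randn(batch, 1280)                            # ArcFace-style id embedding
vit_feats = [torch.randn(batch, 577, 1024) for _ in range(5)]  # five multi-scale ViT feature maps

out = idformer(id_embed, vit_feats)
print(out.shape)  # torch.Size([2, 32, 2048]) -> (batch, num_queries, output_dim)
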
-------------------------------------------------------------------------------- /eva_clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 2 | from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_transforms 3 | from .factory import list_models, add_model_config, get_model_config, load_checkpoint 4 | from .loss import ClipLoss 5 | from .model import CLIP, CustomCLIP, CLIPTextCfg, CLIPVisionCfg,\ 6 | convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype 7 | from .openai import load_openai_model, list_openai_models 8 | from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model,\ 9 | get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained 10 | from .tokenizer import SimpleTokenizer, tokenize 11 | from .transform import image_transform -------------------------------------------------------------------------------- /eva_clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/balazik/ComfyUI-PuLID-Flux/a80912fc3435c358607bf4b43a58dbcbebdb09ff/eva_clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /eva_clip/constants.py: -------------------------------------------------------------------------------- 1 | OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) 2 | OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) 3 | -------------------------------------------------------------------------------- /eva_clip/factory.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import pathlib 5 | import re 6 | from copy import deepcopy 7 | from pathlib import Path 8 | from typing import Optional, Tuple, Union, Dict, Any 9 | import torch 10 | 11 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 12 | from .model import CLIP, CustomCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\ 13 | get_cast_dtype 14 | from .openai import load_openai_model 15 | from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model 16 | from .transform import image_transform 17 | from .tokenizer import HFTokenizer, tokenize 18 | from .utils import resize_clip_pos_embed, resize_evaclip_pos_embed, resize_visual_pos_embed, resize_eva_pos_embed 19 | 20 | 21 | _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] 22 | _MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs 23 | 24 | 25 | def _natural_key(string_): 26 | return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] 27 | 28 | 29 | def _rescan_model_configs(): 30 | global _MODEL_CONFIGS 31 | 32 | config_ext = ('.json',) 33 | config_files = [] 34 | for config_path in _MODEL_CONFIG_PATHS: 35 | if config_path.is_file() and config_path.suffix in config_ext: 36 | config_files.append(config_path) 37 | elif config_path.is_dir(): 38 | for ext in config_ext: 39 | config_files.extend(config_path.glob(f'*{ext}')) 40 | 41 | for cf in config_files: 42 | with open(cf, "r", encoding="utf8") as f: 43 | model_cfg = json.load(f) 44 | if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')): 45 | 
_MODEL_CONFIGS[cf.stem] = model_cfg 46 | 47 | _MODEL_CONFIGS = dict(sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))) 48 | 49 | 50 | _rescan_model_configs() # initial populate of model config registry 51 | 52 | 53 | def list_models(): 54 | """ enumerate available model architectures based on config files """ 55 | return list(_MODEL_CONFIGS.keys()) 56 | 57 | 58 | def add_model_config(path): 59 | """ add model config path or file and update registry """ 60 | if not isinstance(path, Path): 61 | path = Path(path) 62 | _MODEL_CONFIG_PATHS.append(path) 63 | _rescan_model_configs() 64 | 65 | 66 | def get_model_config(model_name): 67 | if model_name in _MODEL_CONFIGS: 68 | return deepcopy(_MODEL_CONFIGS[model_name]) 69 | else: 70 | return None 71 | 72 | 73 | def get_tokenizer(model_name): 74 | config = get_model_config(model_name) 75 | tokenizer = HFTokenizer(config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize 76 | return tokenizer 77 | 78 | 79 | # loading openai CLIP weights when is_openai=True for training 80 | def load_state_dict(checkpoint_path: str, map_location: str='cpu', model_key: str='model|module|state_dict', is_openai: bool=False, skip_list: list=[]): 81 | if is_openai: 82 | model = torch.jit.load(checkpoint_path, map_location="cpu").eval() 83 | state_dict = model.state_dict() 84 | for key in ["input_resolution", "context_length", "vocab_size"]: 85 | state_dict.pop(key, None) 86 | else: 87 | checkpoint = torch.load(checkpoint_path, map_location=map_location) 88 | for mk in model_key.split('|'): 89 | if isinstance(checkpoint, dict) and mk in checkpoint: 90 | state_dict = checkpoint[mk] 91 | break 92 | else: 93 | state_dict = checkpoint 94 | if next(iter(state_dict.items()))[0].startswith('module'): 95 | state_dict = {k[7:]: v for k, v in state_dict.items()} 96 | 97 | for k in skip_list: 98 | if k in list(state_dict.keys()): 99 | logging.info(f"Removing key {k} from pretrained checkpoint") 100 | del state_dict[k] 101 | 102 | if os.getenv('RoPE') == '1': 103 | for k in list(state_dict.keys()): 104 | if 'freqs_cos' in k or 'freqs_sin' in k: 105 | del state_dict[k] 106 | return state_dict 107 | 108 | 109 | 110 | def load_checkpoint(model, checkpoint_path, model_key="model|module|state_dict", strict=True): 111 | state_dict = load_state_dict(checkpoint_path, model_key=model_key, is_openai=False) 112 | # detect old format and make compatible with new format 113 | if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'): 114 | state_dict = convert_to_custom_text_state_dict(state_dict) 115 | if 'text.logit_scale' in state_dict and hasattr(model, 'logit_scale'): 116 | state_dict['logit_scale'] = state_dict['text.logit_scale'] 117 | del state_dict['text.logit_scale'] 118 | 119 | # resize_clip_pos_embed for CLIP and open CLIP 120 | if 'visual.positional_embedding' in state_dict: 121 | resize_clip_pos_embed(state_dict, model) 122 | # specified to eva_vit_model 123 | elif 'visual.pos_embed' in state_dict: 124 | resize_evaclip_pos_embed(state_dict, model) 125 | 126 | # resize_clip_pos_embed(state_dict, model) 127 | incompatible_keys = model.load_state_dict(state_dict, strict=strict) 128 | logging.info(f"incompatible_keys.missing_keys: {incompatible_keys.missing_keys}") 129 | return incompatible_keys 130 | 131 | def load_clip_visual_state_dict(checkpoint_path: str, map_location: str='cpu', is_openai: bool=False, skip_list:list=[]): 132 | state_dict = load_state_dict(checkpoint_path, map_location=map_location, 
is_openai=is_openai, skip_list=skip_list) 133 | 134 | for k in list(state_dict.keys()): 135 | if not k.startswith('visual.'): 136 | del state_dict[k] 137 | for k in list(state_dict.keys()): 138 | if k.startswith('visual.'): 139 | new_k = k[7:] 140 | state_dict[new_k] = state_dict[k] 141 | del state_dict[k] 142 | return state_dict 143 | 144 | def load_clip_text_state_dict(checkpoint_path: str, map_location: str='cpu', is_openai: bool=False, skip_list:list=[]): 145 | state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list) 146 | 147 | for k in list(state_dict.keys()): 148 | if k.startswith('visual.'): 149 | del state_dict[k] 150 | return state_dict 151 | 152 | def get_pretrained_tag(pretrained_model): 153 | pretrained_model = pretrained_model.lower() 154 | if "laion" in pretrained_model or "open_clip" in pretrained_model: 155 | return "open_clip" 156 | elif "openai" in pretrained_model: 157 | return "clip" 158 | elif "eva" in pretrained_model and "clip" in pretrained_model: 159 | return "eva_clip" 160 | else: 161 | return "other" 162 | 163 | def load_pretrained_checkpoint( 164 | model, 165 | visual_checkpoint_path, 166 | text_checkpoint_path, 167 | strict=True, 168 | visual_model=None, 169 | text_model=None, 170 | model_key="model|module|state_dict", 171 | skip_list=[]): 172 | visual_tag = get_pretrained_tag(visual_model) 173 | text_tag = get_pretrained_tag(text_model) 174 | 175 | logging.info(f"num of model state_dict keys: {len(model.state_dict().keys())}") 176 | visual_incompatible_keys, text_incompatible_keys = None, None 177 | if visual_checkpoint_path: 178 | if visual_tag == "eva_clip" or visual_tag == "open_clip": 179 | visual_state_dict = load_clip_visual_state_dict(visual_checkpoint_path, is_openai=False, skip_list=skip_list) 180 | elif visual_tag == "clip": 181 | visual_state_dict = load_clip_visual_state_dict(visual_checkpoint_path, is_openai=True, skip_list=skip_list) 182 | else: 183 | visual_state_dict = load_state_dict(visual_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list) 184 | 185 | # resize_clip_pos_embed for CLIP and open CLIP 186 | if 'positional_embedding' in visual_state_dict: 187 | resize_visual_pos_embed(visual_state_dict, model) 188 | # specified to EVA model 189 | elif 'pos_embed' in visual_state_dict: 190 | resize_eva_pos_embed(visual_state_dict, model) 191 | 192 | visual_incompatible_keys = model.visual.load_state_dict(visual_state_dict, strict=strict) 193 | logging.info(f"num of loaded visual_state_dict keys: {len(visual_state_dict.keys())}") 194 | logging.info(f"visual_incompatible_keys.missing_keys: {visual_incompatible_keys.missing_keys}") 195 | 196 | if text_checkpoint_path: 197 | if text_tag == "eva_clip" or text_tag == "open_clip": 198 | text_state_dict = load_clip_text_state_dict(text_checkpoint_path, is_openai=False, skip_list=skip_list) 199 | elif text_tag == "clip": 200 | text_state_dict = load_clip_text_state_dict(text_checkpoint_path, is_openai=True, skip_list=skip_list) 201 | else: 202 | text_state_dict = load_state_dict(visual_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list) 203 | 204 | text_incompatible_keys = model.text.load_state_dict(text_state_dict, strict=strict) 205 | 206 | logging.info(f"num of loaded text_state_dict keys: {len(text_state_dict.keys())}") 207 | logging.info(f"text_incompatible_keys.missing_keys: {text_incompatible_keys.missing_keys}") 208 | 209 | return visual_incompatible_keys, text_incompatible_keys 210 | 211 | 
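
load_state_dict above resolves the checkpoint container by trying each candidate key from model_key ("model|module|state_dict") in order, falling back to the raw object, and then strips the leading "module." prefix that DistributedDataParallel leaves on parameter names. The following is a small standalone sketch of just that resolution logic; the toy checkpoint dict is illustrative, not a real EVA-CLIP file.

import torch

def resolve_state_dict(checkpoint, model_key="model|module|state_dict"):
    # Mirror load_state_dict(): pick the first matching container key, else use the raw object.
    state_dict = checkpoint
    for mk in model_key.split('|'):
        if isinstance(checkpoint, dict) and mk in checkpoint:
            state_dict = checkpoint[mk]
            break
    # Strip the 'module.' prefix added by torch.nn.parallel.DistributedDataParallel.
    if next(iter(state_dict.items()))[0].startswith('module'):
        state_dict = {k[len('module.'):]: v for k, v in state_dict.items()}
    return state_dict

toy_ckpt = {"state_dict": {"module.visual.proj": torch.zeros(3), "module.logit_scale": torch.ones(1)}}
print(list(resolve_state_dict(toy_ckpt).keys()))  # ['visual.proj', 'logit_scale']
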
def create_model( 212 | model_name: str, 213 | pretrained: Optional[str] = None, 214 | precision: str = 'fp32', 215 | device: Union[str, torch.device] = 'cpu', 216 | jit: bool = False, 217 | force_quick_gelu: bool = False, 218 | force_custom_clip: bool = False, 219 | force_patch_dropout: Optional[float] = None, 220 | pretrained_image: str = '', 221 | pretrained_text: str = '', 222 | pretrained_hf: bool = True, 223 | pretrained_visual_model: str = None, 224 | pretrained_text_model: str = None, 225 | cache_dir: Optional[str] = None, 226 | skip_list: list = [], 227 | ): 228 | model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names 229 | if isinstance(device, str): 230 | device = torch.device(device) 231 | 232 | if pretrained and pretrained.lower() == 'openai': 233 | logging.info(f'Loading pretrained {model_name} from OpenAI.') 234 | model = load_openai_model( 235 | model_name, 236 | precision=precision, 237 | device=device, 238 | jit=jit, 239 | cache_dir=cache_dir, 240 | ) 241 | else: 242 | model_cfg = get_model_config(model_name) 243 | if model_cfg is not None: 244 | logging.info(f'Loaded {model_name} model config.') 245 | else: 246 | logging.error(f'Model config for {model_name} not found; available models {list_models()}.') 247 | raise RuntimeError(f'Model config for {model_name} not found.') 248 | 249 | if 'rope' in model_cfg.get('vision_cfg', {}): 250 | if model_cfg['vision_cfg']['rope']: 251 | os.environ['RoPE'] = "1" 252 | else: 253 | os.environ['RoPE'] = "0" 254 | 255 | if force_quick_gelu: 256 | # override for use of QuickGELU on non-OpenAI transformer models 257 | model_cfg["quick_gelu"] = True 258 | 259 | if force_patch_dropout is not None: 260 | # override the default patch dropout value 261 | model_cfg['vision_cfg']["patch_dropout"] = force_patch_dropout 262 | 263 | cast_dtype = get_cast_dtype(precision) 264 | custom_clip = model_cfg.pop('custom_text', False) or force_custom_clip or ('hf_model_name' in model_cfg['text_cfg']) 265 | 266 | 267 | if custom_clip: 268 | if 'hf_model_name' in model_cfg.get('text_cfg', {}): 269 | model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf 270 | model = CustomCLIP(**model_cfg, cast_dtype=cast_dtype) 271 | else: 272 | model = CLIP(**model_cfg, cast_dtype=cast_dtype) 273 | 274 | pretrained_cfg = {} 275 | if pretrained: 276 | checkpoint_path = '' 277 | pretrained_cfg = get_pretrained_cfg(model_name, pretrained) 278 | if pretrained_cfg: 279 | checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir) 280 | elif os.path.exists(pretrained): 281 | checkpoint_path = pretrained 282 | 283 | if checkpoint_path: 284 | logging.info(f'Loading pretrained {model_name} weights ({pretrained}).') 285 | load_checkpoint(model, 286 | checkpoint_path, 287 | model_key="model|module|state_dict", 288 | strict=False 289 | ) 290 | else: 291 | error_str = ( 292 | f'Pretrained weights ({pretrained}) not found for model {model_name}.' 
293 | f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.') 294 | logging.warning(error_str) 295 | raise RuntimeError(error_str) 296 | else: 297 | visual_checkpoint_path = '' 298 | text_checkpoint_path = '' 299 | 300 | if pretrained_image: 301 | pretrained_visual_model = pretrained_visual_model.replace('/', '-') # for callers using old naming with / in ViT names 302 | pretrained_image_cfg = get_pretrained_cfg(pretrained_visual_model, pretrained_image) 303 | if 'timm_model_name' in model_cfg.get('vision_cfg', {}): 304 | # pretrained weight loading for timm models set via vision_cfg 305 | model_cfg['vision_cfg']['timm_model_pretrained'] = True 306 | elif pretrained_image_cfg: 307 | visual_checkpoint_path = download_pretrained(pretrained_image_cfg, cache_dir=cache_dir) 308 | elif os.path.exists(pretrained_image): 309 | visual_checkpoint_path = pretrained_image 310 | else: 311 | logging.warning(f'Pretrained weights ({visual_checkpoint_path}) not found for model {model_name}.visual.') 312 | raise RuntimeError(f'Pretrained weights ({visual_checkpoint_path}) not found for model {model_name}.visual.') 313 | 314 | if pretrained_text: 315 | pretrained_text_model = pretrained_text_model.replace('/', '-') # for callers using old naming with / in ViT names 316 | pretrained_text_cfg = get_pretrained_cfg(pretrained_text_model, pretrained_text) 317 | if pretrained_image_cfg: 318 | text_checkpoint_path = download_pretrained(pretrained_text_cfg, cache_dir=cache_dir) 319 | elif os.path.exists(pretrained_text): 320 | text_checkpoint_path = pretrained_text 321 | else: 322 | logging.warning(f'Pretrained weights ({text_checkpoint_path}) not found for model {model_name}.text.') 323 | raise RuntimeError(f'Pretrained weights ({text_checkpoint_path}) not found for model {model_name}.text.') 324 | 325 | if visual_checkpoint_path: 326 | logging.info(f'Loading pretrained {model_name}.visual weights ({visual_checkpoint_path}).') 327 | if text_checkpoint_path: 328 | logging.info(f'Loading pretrained {model_name}.text weights ({text_checkpoint_path}).') 329 | 330 | if visual_checkpoint_path or text_checkpoint_path: 331 | load_pretrained_checkpoint( 332 | model, 333 | visual_checkpoint_path, 334 | text_checkpoint_path, 335 | strict=False, 336 | visual_model=pretrained_visual_model, 337 | text_model=pretrained_text_model, 338 | model_key="model|module|state_dict", 339 | skip_list=skip_list 340 | ) 341 | 342 | if "fp16" in precision or "bf16" in precision: 343 | logging.info(f'convert precision to {precision}') 344 | model = model.to(torch.bfloat16) if 'bf16' in precision else model.to(torch.float16) 345 | 346 | model.to(device=device) 347 | 348 | # set image / mean metadata from pretrained_cfg if available, or use default 349 | model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN 350 | model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD 351 | 352 | if jit: 353 | model = torch.jit.script(model) 354 | 355 | return model 356 | 357 | 358 | def create_model_and_transforms( 359 | model_name: str, 360 | pretrained: Optional[str] = None, 361 | precision: str = 'fp32', 362 | device: Union[str, torch.device] = 'cpu', 363 | jit: bool = False, 364 | force_quick_gelu: bool = False, 365 | force_custom_clip: bool = False, 366 | force_patch_dropout: Optional[float] = None, 367 | pretrained_image: str = '', 368 | pretrained_text: str = '', 369 | pretrained_hf: bool = True, 370 | pretrained_visual_model: str = None, 371 | pretrained_text_model: str = None, 372 | 
image_mean: Optional[Tuple[float, ...]] = None, 373 | image_std: Optional[Tuple[float, ...]] = None, 374 | cache_dir: Optional[str] = None, 375 | skip_list: list = [], 376 | ): 377 | model = create_model( 378 | model_name, 379 | pretrained, 380 | precision=precision, 381 | device=device, 382 | jit=jit, 383 | force_quick_gelu=force_quick_gelu, 384 | force_custom_clip=force_custom_clip, 385 | force_patch_dropout=force_patch_dropout, 386 | pretrained_image=pretrained_image, 387 | pretrained_text=pretrained_text, 388 | pretrained_hf=pretrained_hf, 389 | pretrained_visual_model=pretrained_visual_model, 390 | pretrained_text_model=pretrained_text_model, 391 | cache_dir=cache_dir, 392 | skip_list=skip_list, 393 | ) 394 | 395 | image_mean = image_mean or getattr(model.visual, 'image_mean', None) 396 | image_std = image_std or getattr(model.visual, 'image_std', None) 397 | preprocess_train = image_transform( 398 | model.visual.image_size, 399 | is_train=True, 400 | mean=image_mean, 401 | std=image_std 402 | ) 403 | preprocess_val = image_transform( 404 | model.visual.image_size, 405 | is_train=False, 406 | mean=image_mean, 407 | std=image_std 408 | ) 409 | 410 | return model, preprocess_train, preprocess_val 411 | 412 | 413 | def create_transforms( 414 | model_name: str, 415 | pretrained: Optional[str] = None, 416 | precision: str = 'fp32', 417 | device: Union[str, torch.device] = 'cpu', 418 | jit: bool = False, 419 | force_quick_gelu: bool = False, 420 | force_custom_clip: bool = False, 421 | force_patch_dropout: Optional[float] = None, 422 | pretrained_image: str = '', 423 | pretrained_text: str = '', 424 | pretrained_hf: bool = True, 425 | pretrained_visual_model: str = None, 426 | pretrained_text_model: str = None, 427 | image_mean: Optional[Tuple[float, ...]] = None, 428 | image_std: Optional[Tuple[float, ...]] = None, 429 | cache_dir: Optional[str] = None, 430 | skip_list: list = [], 431 | ): 432 | model = create_model( 433 | model_name, 434 | pretrained, 435 | precision=precision, 436 | device=device, 437 | jit=jit, 438 | force_quick_gelu=force_quick_gelu, 439 | force_custom_clip=force_custom_clip, 440 | force_patch_dropout=force_patch_dropout, 441 | pretrained_image=pretrained_image, 442 | pretrained_text=pretrained_text, 443 | pretrained_hf=pretrained_hf, 444 | pretrained_visual_model=pretrained_visual_model, 445 | pretrained_text_model=pretrained_text_model, 446 | cache_dir=cache_dir, 447 | skip_list=skip_list, 448 | ) 449 | 450 | 451 | image_mean = image_mean or getattr(model.visual, 'image_mean', None) 452 | image_std = image_std or getattr(model.visual, 'image_std', None) 453 | preprocess_train = image_transform( 454 | model.visual.image_size, 455 | is_train=True, 456 | mean=image_mean, 457 | std=image_std 458 | ) 459 | preprocess_val = image_transform( 460 | model.visual.image_size, 461 | is_train=False, 462 | mean=image_mean, 463 | std=image_std 464 | ) 465 | del model 466 | 467 | return preprocess_train, preprocess_val 468 | 469 | def create_model_from_pretrained( 470 | model_name: str, 471 | pretrained: str, 472 | precision: str = 'fp32', 473 | device: Union[str, torch.device] = 'cpu', 474 | jit: bool = False, 475 | force_quick_gelu: bool = False, 476 | force_custom_clip: bool = False, 477 | force_patch_dropout: Optional[float] = None, 478 | return_transform: bool = True, 479 | image_mean: Optional[Tuple[float, ...]] = None, 480 | image_std: Optional[Tuple[float, ...]] = None, 481 | cache_dir: Optional[str] = None, 482 | is_frozen: bool = False, 483 | ): 484 | if not 
is_pretrained_cfg(model_name, pretrained) and not os.path.exists(pretrained): 485 | raise RuntimeError( 486 | f'{pretrained} is not a valid pretrained cfg or checkpoint for {model_name}.' 487 | f' Use open_clip.list_pretrained() to find one.') 488 | 489 | model = create_model( 490 | model_name, 491 | pretrained, 492 | precision=precision, 493 | device=device, 494 | jit=jit, 495 | force_quick_gelu=force_quick_gelu, 496 | force_custom_clip=force_custom_clip, 497 | force_patch_dropout=force_patch_dropout, 498 | cache_dir=cache_dir, 499 | ) 500 | 501 | if is_frozen: 502 | for param in model.parameters(): 503 | param.requires_grad = False 504 | 505 | if not return_transform: 506 | return model 507 | 508 | image_mean = image_mean or getattr(model.visual, 'image_mean', None) 509 | image_std = image_std or getattr(model.visual, 'image_std', None) 510 | preprocess = image_transform( 511 | model.visual.image_size, 512 | is_train=False, 513 | mean=image_mean, 514 | std=image_std 515 | ) 516 | 517 | return model, preprocess 518 | -------------------------------------------------------------------------------- /eva_clip/hf_configs.py: -------------------------------------------------------------------------------- 1 | # HF architecture dict: 2 | arch_dict = { 3 | # https://huggingface.co/docs/transformers/model_doc/roberta#roberta 4 | "roberta": { 5 | "config_names": { 6 | "context_length": "max_position_embeddings", 7 | "vocab_size": "vocab_size", 8 | "width": "hidden_size", 9 | "heads": "num_attention_heads", 10 | "layers": "num_hidden_layers", 11 | "layer_attr": "layer", 12 | "token_embeddings_attr": "embeddings" 13 | }, 14 | "pooler": "mean_pooler", 15 | }, 16 | # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig 17 | "xlm-roberta": { 18 | "config_names": { 19 | "context_length": "max_position_embeddings", 20 | "vocab_size": "vocab_size", 21 | "width": "hidden_size", 22 | "heads": "num_attention_heads", 23 | "layers": "num_hidden_layers", 24 | "layer_attr": "layer", 25 | "token_embeddings_attr": "embeddings" 26 | }, 27 | "pooler": "mean_pooler", 28 | }, 29 | # https://huggingface.co/docs/transformers/model_doc/mt5#mt5 30 | "mt5": { 31 | "config_names": { 32 | # unlimited seqlen 33 | # https://github.com/google-research/text-to-text-transfer-transformer/issues/273 34 | # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374 35 | "context_length": "", 36 | "vocab_size": "vocab_size", 37 | "width": "d_model", 38 | "heads": "num_heads", 39 | "layers": "num_layers", 40 | "layer_attr": "block", 41 | "token_embeddings_attr": "embed_tokens" 42 | }, 43 | "pooler": "mean_pooler", 44 | }, 45 | "bert": { 46 | "config_names": { 47 | "context_length": "max_position_embeddings", 48 | "vocab_size": "vocab_size", 49 | "width": "hidden_size", 50 | "heads": "num_attention_heads", 51 | "layers": "num_hidden_layers", 52 | "layer_attr": "layer", 53 | "token_embeddings_attr": "embeddings" 54 | }, 55 | "pooler": "mean_pooler", 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /eva_clip/hf_model.py: -------------------------------------------------------------------------------- 1 | """ huggingface model adapter 2 | 3 | Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model. 
4 | """ 5 | 6 | import re 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch.nn import functional as F 11 | from torch import TensorType 12 | try: 13 | import transformers 14 | from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer, AutoConfig, PretrainedConfig 15 | from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \ 16 | BaseModelOutputWithPoolingAndCrossAttentions 17 | except ImportError as e: 18 | transformers = None 19 | 20 | 21 | class BaseModelOutput: 22 | pass 23 | 24 | 25 | class PretrainedConfig: 26 | pass 27 | 28 | from .hf_configs import arch_dict 29 | 30 | # utils 31 | def _camel2snake(s): 32 | return re.sub(r'(? TensorType: 140 | # image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(x.device) 141 | # attn_mask = (x != self.config.pad_token_id).long() 142 | # out = self.transformer( 143 | # input_ids=x, 144 | # attention_mask=attn_mask, 145 | # encoder_hidden_states = image_embeds, 146 | # encoder_attention_mask = image_atts, 147 | # ) 148 | # pooled_out = self.pooler(out, attn_mask) 149 | 150 | # return self.itm_proj(pooled_out) 151 | 152 | def mask(self, input_ids, vocab_size, device, targets=None, masked_indices=None, probability_matrix=None): 153 | if masked_indices is None: 154 | masked_indices = torch.bernoulli(probability_matrix).bool() 155 | 156 | masked_indices[input_ids == self.tokenizer.pad_token_id] = False 157 | masked_indices[input_ids == self.tokenizer.cls_token_id] = False 158 | 159 | if targets is not None: 160 | targets[~masked_indices] = -100 # We only compute loss on masked tokens 161 | 162 | # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) 163 | indices_replaced = torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices 164 | input_ids[indices_replaced] = self.tokenizer.mask_token_id 165 | 166 | # 10% of the time, we replace masked input tokens with random word 167 | indices_random = torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool() & masked_indices & ~indices_replaced 168 | random_words = torch.randint(vocab_size, input_ids.shape, dtype=torch.long).to(device) 169 | input_ids[indices_random] = random_words[indices_random] 170 | # The rest of the time (10% of the time) we keep the masked input tokens unchanged 171 | 172 | if targets is not None: 173 | return input_ids, targets 174 | else: 175 | return input_ids 176 | 177 | def forward_mlm(self, input_ids, image_embeds, mlm_probability=0.25): 178 | labels = input_ids.clone() 179 | attn_mask = (input_ids != self.config.pad_token_id).long() 180 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(input_ids.device) 181 | vocab_size = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["vocab_size"]) 182 | probability_matrix = torch.full(labels.shape, mlm_probability) 183 | input_ids, labels = self.mask(input_ids, vocab_size, input_ids.device, targets=labels, 184 | probability_matrix = probability_matrix) 185 | mlm_output = self.transformer(input_ids, 186 | attention_mask = attn_mask, 187 | encoder_hidden_states = image_embeds, 188 | encoder_attention_mask = image_atts, 189 | return_dict = True, 190 | labels = labels, 191 | ) 192 | return mlm_output.loss 193 | # mlm_output = self.transformer(input_ids, 194 | # attention_mask = attn_mask, 195 | # encoder_hidden_states = image_embeds, 196 | # encoder_attention_mask = image_atts, 197 | # return_dict = True, 198 | # ).last_hidden_state 199 | # logits = self.mlm_proj(mlm_output) 200 | 201 
| # # logits = logits[:, :-1, :].contiguous().view(-1, vocab_size) 202 | # logits = logits[:, 1:, :].contiguous().view(-1, vocab_size) 203 | # labels = labels[:, 1:].contiguous().view(-1) 204 | 205 | # mlm_loss = F.cross_entropy( 206 | # logits, 207 | # labels, 208 | # # label_smoothing=0.1, 209 | # ) 210 | # return mlm_loss 211 | 212 | 213 | def forward(self, x:TensorType) -> TensorType: 214 | attn_mask = (x != self.config.pad_token_id).long() 215 | out = self.transformer(input_ids=x, attention_mask=attn_mask) 216 | pooled_out = self.pooler(out, attn_mask) 217 | 218 | return self.proj(pooled_out) 219 | 220 | def lock(self, unlocked_layers:int=0, freeze_layer_norm:bool=True): 221 | if not unlocked_layers: # full freezing 222 | for n, p in self.transformer.named_parameters(): 223 | p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False 224 | return 225 | 226 | encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer 227 | layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"]) 228 | print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model") 229 | embeddings = getattr( 230 | self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"]) 231 | modules = [embeddings, *layer_list][:-unlocked_layers] 232 | # freeze layers 233 | for module in modules: 234 | for n, p in module.named_parameters(): 235 | p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False 236 | 237 | 238 | @torch.jit.ignore 239 | def set_grad_checkpointing(self, enable=True): 240 | self.transformer.gradient_checkpointing_enable() 241 | 242 | def get_num_layers(self): 243 | encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer 244 | layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"]) 245 | return len(layer_list) 246 | 247 | def init_parameters(self): 248 | pass 249 | -------------------------------------------------------------------------------- /eva_clip/loss.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import functional as F 5 | 6 | try: 7 | import torch.distributed.nn 8 | from torch import distributed as dist 9 | has_distributed = True 10 | except ImportError: 11 | has_distributed = False 12 | 13 | try: 14 | import horovod.torch as hvd 15 | except ImportError: 16 | hvd = None 17 | 18 | from timm.loss import LabelSmoothingCrossEntropy 19 | 20 | 21 | def gather_features( 22 | image_features, 23 | text_features, 24 | local_loss=False, 25 | gather_with_grad=False, 26 | rank=0, 27 | world_size=1, 28 | use_horovod=False 29 | ): 30 | assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.' 
31 | if use_horovod: 32 | assert hvd is not None, 'Please install horovod' 33 | if gather_with_grad: 34 | all_image_features = hvd.allgather(image_features) 35 | all_text_features = hvd.allgather(text_features) 36 | else: 37 | with torch.no_grad(): 38 | all_image_features = hvd.allgather(image_features) 39 | all_text_features = hvd.allgather(text_features) 40 | if not local_loss: 41 | # ensure grads for local rank when all_* features don't have a gradient 42 | gathered_image_features = list(all_image_features.chunk(world_size, dim=0)) 43 | gathered_text_features = list(all_text_features.chunk(world_size, dim=0)) 44 | gathered_image_features[rank] = image_features 45 | gathered_text_features[rank] = text_features 46 | all_image_features = torch.cat(gathered_image_features, dim=0) 47 | all_text_features = torch.cat(gathered_text_features, dim=0) 48 | else: 49 | # We gather tensors from all gpus 50 | if gather_with_grad: 51 | all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0) 52 | all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0) 53 | # all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features, async_op=True), dim=0) 54 | # all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features, async_op=True), dim=0) 55 | else: 56 | gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)] 57 | gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)] 58 | dist.all_gather(gathered_image_features, image_features) 59 | dist.all_gather(gathered_text_features, text_features) 60 | if not local_loss: 61 | # ensure grads for local rank when all_* features don't have a gradient 62 | gathered_image_features[rank] = image_features 63 | gathered_text_features[rank] = text_features 64 | all_image_features = torch.cat(gathered_image_features, dim=0) 65 | all_text_features = torch.cat(gathered_text_features, dim=0) 66 | 67 | return all_image_features, all_text_features 68 | 69 | 70 | class ClipLoss(nn.Module): 71 | 72 | def __init__( 73 | self, 74 | local_loss=False, 75 | gather_with_grad=False, 76 | cache_labels=False, 77 | rank=0, 78 | world_size=1, 79 | use_horovod=False, 80 | smoothing=0., 81 | ): 82 | super().__init__() 83 | self.local_loss = local_loss 84 | self.gather_with_grad = gather_with_grad 85 | self.cache_labels = cache_labels 86 | self.rank = rank 87 | self.world_size = world_size 88 | self.use_horovod = use_horovod 89 | self.label_smoothing_cross_entropy = LabelSmoothingCrossEntropy(smoothing=smoothing) if smoothing > 0 else None 90 | 91 | # cache state 92 | self.prev_num_logits = 0 93 | self.labels = {} 94 | 95 | def forward(self, image_features, text_features, logit_scale=1.): 96 | device = image_features.device 97 | if self.world_size > 1: 98 | all_image_features, all_text_features = gather_features( 99 | image_features, text_features, 100 | self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) 101 | 102 | if self.local_loss: 103 | logits_per_image = logit_scale * image_features @ all_text_features.T 104 | logits_per_text = logit_scale * text_features @ all_image_features.T 105 | else: 106 | logits_per_image = logit_scale * all_image_features @ all_text_features.T 107 | logits_per_text = logits_per_image.T 108 | else: 109 | logits_per_image = logit_scale * image_features @ text_features.T 110 | logits_per_text = logit_scale * text_features @ image_features.T 111 | # calculated ground-truth and cache if 
enabled 112 | num_logits = logits_per_image.shape[0] 113 | if self.prev_num_logits != num_logits or device not in self.labels: 114 | labels = torch.arange(num_logits, device=device, dtype=torch.long) 115 | if self.world_size > 1 and self.local_loss: 116 | labels = labels + num_logits * self.rank 117 | if self.cache_labels: 118 | self.labels[device] = labels 119 | self.prev_num_logits = num_logits 120 | else: 121 | labels = self.labels[device] 122 | 123 | if self.label_smoothing_cross_entropy: 124 | total_loss = ( 125 | self.label_smoothing_cross_entropy(logits_per_image, labels) + 126 | self.label_smoothing_cross_entropy(logits_per_text, labels) 127 | ) / 2 128 | else: 129 | total_loss = ( 130 | F.cross_entropy(logits_per_image, labels) + 131 | F.cross_entropy(logits_per_text, labels) 132 | ) / 2 133 | 134 | acc = None 135 | i2t_acc = (logits_per_image.argmax(-1) == labels).sum() / len(logits_per_image) 136 | t2i_acc = (logits_per_text.argmax(-1) == labels).sum() / len(logits_per_text) 137 | acc = {"i2t": i2t_acc, "t2i": t2i_acc} 138 | return total_loss, acc -------------------------------------------------------------------------------- /eva_clip/model.py: -------------------------------------------------------------------------------- 1 | """ CLIP Model 2 | 3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | import os 6 | from dataclasses import dataclass 7 | from typing import Optional, Tuple, Union 8 | from functools import partial 9 | 10 | import numpy as np 11 | import torch 12 | import torch.nn.functional as F 13 | from torch import nn 14 | 15 | try: 16 | from .hf_model import HFTextEncoder 17 | except: 18 | HFTextEncoder = None 19 | from .modified_resnet import ModifiedResNet 20 | from .timm_model import TimmModel 21 | from .eva_vit_model import EVAVisionTransformer 22 | from .transformer import LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer 23 | 24 | try: 25 | from apex.normalization import FusedLayerNorm 26 | except: 27 | FusedLayerNorm = LayerNorm 28 | print("Nvidia APEX normalization not installed, using PyTorch LayerNorm") 29 | 30 | try: 31 | import xformers.ops as xops 32 | except ImportError: 33 | xops = None 34 | #print("Please 'pip install xformers'") 35 | 36 | @dataclass 37 | class CLIPVisionCfg: 38 | layers: Union[Tuple[int, int, int, int], int] = 12 39 | width: int = 768 40 | head_width: int = 64 41 | mlp_ratio: float = 4.0 42 | patch_size: int = 16 43 | image_size: Union[Tuple[int, int], int] = 224 44 | ls_init_value: Optional[float] = None # layer scale initial value 45 | patch_dropout: float = 0. 
# what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results 46 | global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580) 47 | drop_path_rate: Optional[float] = None # drop path rate 48 | timm_model_name: str = None # a valid model name overrides layers, width, patch_size 49 | timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model 50 | timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') 51 | timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '') 52 | timm_proj_bias: bool = False # enable bias final projection 53 | eva_model_name: str = None # a valid eva model name overrides layers, width, patch_size 54 | qkv_bias: bool = True 55 | fusedLN: bool = False 56 | xattn: bool = False 57 | postnorm: bool = False 58 | rope: bool = False 59 | pt_hw_seq_len: int = 16 # 224/14 60 | intp_freq: bool = False 61 | naiveswiglu: bool = False 62 | subln: bool = False 63 | 64 | 65 | @dataclass 66 | class CLIPTextCfg: 67 | context_length: int = 77 68 | vocab_size: int = 49408 69 | width: int = 512 70 | heads: int = 8 71 | layers: int = 12 72 | ls_init_value: Optional[float] = None # layer scale initial value 73 | hf_model_name: str = None 74 | hf_tokenizer_name: str = None 75 | hf_model_pretrained: bool = True 76 | proj: str = 'mlp' 77 | pooler_type: str = 'mean_pooler' 78 | masked_language_modeling: bool = False 79 | fusedLN: bool = False 80 | xattn: bool = False 81 | attn_mask: bool = True 82 | 83 | def get_cast_dtype(precision: str): 84 | cast_dtype = None 85 | if precision == 'bf16': 86 | cast_dtype = torch.bfloat16 87 | elif precision == 'fp16': 88 | cast_dtype = torch.float16 89 | return cast_dtype 90 | 91 | 92 | def _build_vision_tower( 93 | embed_dim: int, 94 | vision_cfg: CLIPVisionCfg, 95 | quick_gelu: bool = False, 96 | cast_dtype: Optional[torch.dtype] = None 97 | ): 98 | if isinstance(vision_cfg, dict): 99 | vision_cfg = CLIPVisionCfg(**vision_cfg) 100 | 101 | # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more 102 | # memory efficient in recent PyTorch releases (>= 1.10). 103 | # NOTE: timm models always use native GELU regardless of quick_gelu flag. 
104 | act_layer = QuickGELU if quick_gelu else nn.GELU 105 | 106 | if vision_cfg.eva_model_name: 107 | vision_heads = vision_cfg.width // vision_cfg.head_width 108 | norm_layer = LayerNorm 109 | 110 | visual = EVAVisionTransformer( 111 | img_size=vision_cfg.image_size, 112 | patch_size=vision_cfg.patch_size, 113 | num_classes=embed_dim, 114 | use_mean_pooling=vision_cfg.global_average_pool, #False 115 | init_values=vision_cfg.ls_init_value, 116 | patch_dropout=vision_cfg.patch_dropout, 117 | embed_dim=vision_cfg.width, 118 | depth=vision_cfg.layers, 119 | num_heads=vision_heads, 120 | mlp_ratio=vision_cfg.mlp_ratio, 121 | qkv_bias=vision_cfg.qkv_bias, 122 | drop_path_rate=vision_cfg.drop_path_rate, 123 | norm_layer= partial(FusedLayerNorm, eps=1e-6) if vision_cfg.fusedLN else partial(norm_layer, eps=1e-6), 124 | xattn=vision_cfg.xattn, 125 | rope=vision_cfg.rope, 126 | postnorm=vision_cfg.postnorm, 127 | pt_hw_seq_len= vision_cfg.pt_hw_seq_len, # 224/14 128 | intp_freq= vision_cfg.intp_freq, 129 | naiveswiglu= vision_cfg.naiveswiglu, 130 | subln= vision_cfg.subln 131 | ) 132 | elif vision_cfg.timm_model_name: 133 | visual = TimmModel( 134 | vision_cfg.timm_model_name, 135 | pretrained=vision_cfg.timm_model_pretrained, 136 | pool=vision_cfg.timm_pool, 137 | proj=vision_cfg.timm_proj, 138 | proj_bias=vision_cfg.timm_proj_bias, 139 | embed_dim=embed_dim, 140 | image_size=vision_cfg.image_size 141 | ) 142 | act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models 143 | elif isinstance(vision_cfg.layers, (tuple, list)): 144 | vision_heads = vision_cfg.width * 32 // vision_cfg.head_width 145 | visual = ModifiedResNet( 146 | layers=vision_cfg.layers, 147 | output_dim=embed_dim, 148 | heads=vision_heads, 149 | image_size=vision_cfg.image_size, 150 | width=vision_cfg.width 151 | ) 152 | else: 153 | vision_heads = vision_cfg.width // vision_cfg.head_width 154 | norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm 155 | visual = VisionTransformer( 156 | image_size=vision_cfg.image_size, 157 | patch_size=vision_cfg.patch_size, 158 | width=vision_cfg.width, 159 | layers=vision_cfg.layers, 160 | heads=vision_heads, 161 | mlp_ratio=vision_cfg.mlp_ratio, 162 | ls_init_value=vision_cfg.ls_init_value, 163 | patch_dropout=vision_cfg.patch_dropout, 164 | global_average_pool=vision_cfg.global_average_pool, 165 | output_dim=embed_dim, 166 | act_layer=act_layer, 167 | norm_layer=norm_layer, 168 | ) 169 | 170 | return visual 171 | 172 | 173 | def _build_text_tower( 174 | embed_dim: int, 175 | text_cfg: CLIPTextCfg, 176 | quick_gelu: bool = False, 177 | cast_dtype: Optional[torch.dtype] = None, 178 | ): 179 | if isinstance(text_cfg, dict): 180 | text_cfg = CLIPTextCfg(**text_cfg) 181 | 182 | if text_cfg.hf_model_name: 183 | text = HFTextEncoder( 184 | text_cfg.hf_model_name, 185 | output_dim=embed_dim, 186 | tokenizer_name=text_cfg.hf_tokenizer_name, 187 | proj=text_cfg.proj, 188 | pooler_type=text_cfg.pooler_type, 189 | masked_language_modeling=text_cfg.masked_language_modeling 190 | ) 191 | else: 192 | act_layer = QuickGELU if quick_gelu else nn.GELU 193 | norm_layer = LayerNorm 194 | 195 | text = TextTransformer( 196 | context_length=text_cfg.context_length, 197 | vocab_size=text_cfg.vocab_size, 198 | width=text_cfg.width, 199 | heads=text_cfg.heads, 200 | layers=text_cfg.layers, 201 | ls_init_value=text_cfg.ls_init_value, 202 | output_dim=embed_dim, 203 | act_layer=act_layer, 204 | norm_layer= FusedLayerNorm if text_cfg.fusedLN else norm_layer, 
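# NOTE (editor's sketch): FusedLayerNorm here refers to NVIDIA Apex's fused kernel; when
# Apex is not installed, the import fallback at the top of this file rebinds FusedLayerNorm
# to the plain LayerNorm, so configs with "fusedLN": true still load (just without the
# fused kernel).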
205 | xattn=text_cfg.xattn, 206 | attn_mask=text_cfg.attn_mask, 207 | ) 208 | return text 209 | 210 | class CLIP(nn.Module): 211 | def __init__( 212 | self, 213 | embed_dim: int, 214 | vision_cfg: CLIPVisionCfg, 215 | text_cfg: CLIPTextCfg, 216 | quick_gelu: bool = False, 217 | cast_dtype: Optional[torch.dtype] = None, 218 | ): 219 | super().__init__() 220 | self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) 221 | 222 | text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) 223 | self.transformer = text.transformer 224 | self.vocab_size = text.vocab_size 225 | self.token_embedding = text.token_embedding 226 | self.positional_embedding = text.positional_embedding 227 | self.ln_final = text.ln_final 228 | self.text_projection = text.text_projection 229 | self.register_buffer('attn_mask', text.attn_mask, persistent=False) 230 | 231 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 232 | 233 | def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): 234 | # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 235 | self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) 236 | 237 | @torch.jit.ignore 238 | def set_grad_checkpointing(self, enable=True): 239 | self.visual.set_grad_checkpointing(enable) 240 | self.transformer.grad_checkpointing = enable 241 | 242 | @torch.jit.ignore 243 | def no_weight_decay(self): 244 | return {'logit_scale'} 245 | 246 | def encode_image(self, image, normalize: bool = False): 247 | features = self.visual(image) 248 | return F.normalize(features, dim=-1) if normalize else features 249 | 250 | def encode_text(self, text, normalize: bool = False): 251 | cast_dtype = self.transformer.get_cast_dtype() 252 | 253 | x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] 254 | 255 | x = x + self.positional_embedding.to(cast_dtype) 256 | x = x.permute(1, 0, 2) # NLD -> LND 257 | x = self.transformer(x, attn_mask=self.attn_mask) 258 | x = x.permute(1, 0, 2) # LND -> NLD 259 | x = self.ln_final(x) # [batch_size, n_ctx, transformer.width] 260 | # take features from the eot embedding (eot_token is the highest number in each sequence) 261 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 262 | return F.normalize(x, dim=-1) if normalize else x 263 | 264 | def forward(self, image, text): 265 | image_features = self.encode_image(image, normalize=True) 266 | text_features = self.encode_text(text, normalize=True) 267 | return image_features, text_features, self.logit_scale.exp() 268 | 269 | 270 | class CustomCLIP(nn.Module): 271 | def __init__( 272 | self, 273 | embed_dim: int, 274 | vision_cfg: CLIPVisionCfg, 275 | text_cfg: CLIPTextCfg, 276 | quick_gelu: bool = False, 277 | cast_dtype: Optional[torch.dtype] = None, 278 | itm_task: bool = False, 279 | ): 280 | super().__init__() 281 | self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) 282 | self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) 283 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 284 | 285 | def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): 286 | # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 287 | self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) 288 | 289 | def lock_text_tower(self, unlocked_layers:int=0, freeze_layer_norm:bool=True): 290 | self.text.lock(unlocked_layers, freeze_layer_norm) 291 | 292 | 
@torch.jit.ignore 293 | def set_grad_checkpointing(self, enable=True): 294 | self.visual.set_grad_checkpointing(enable) 295 | self.text.set_grad_checkpointing(enable) 296 | 297 | @torch.jit.ignore 298 | def no_weight_decay(self): 299 | return {'logit_scale'} 300 | 301 | def encode_image(self, image, normalize: bool = False): 302 | features = self.visual(image) 303 | return F.normalize(features, dim=-1) if normalize else features 304 | 305 | def encode_text(self, text, normalize: bool = False): 306 | features = self.text(text) 307 | return F.normalize(features, dim=-1) if normalize else features 308 | 309 | def forward(self, image, text): 310 | image_features = self.encode_image(image, normalize=True) 311 | text_features = self.encode_text(text, normalize=True) 312 | return image_features, text_features, self.logit_scale.exp() 313 | 314 | 315 | def convert_weights_to_lp(model: nn.Module, dtype=torch.float16): 316 | """Convert applicable model parameters to low-precision (bf16 or fp16)""" 317 | 318 | def _convert_weights(l): 319 | 320 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 321 | l.weight.data = l.weight.data.to(dtype) 322 | if l.bias is not None: 323 | l.bias.data = l.bias.data.to(dtype) 324 | 325 | if isinstance(l, (nn.MultiheadAttention, Attention)): 326 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 327 | tensor = getattr(l, attr, None) 328 | if tensor is not None: 329 | tensor.data = tensor.data.to(dtype) 330 | 331 | if isinstance(l, nn.Parameter): 332 | l.data = l.data.to(dtype) 333 | 334 | for name in ["text_projection", "proj"]: 335 | if hasattr(l, name) and isinstance(l, nn.Parameter): 336 | attr = getattr(l, name, None) 337 | if attr is not None: 338 | attr.data = attr.data.to(dtype) 339 | 340 | model.apply(_convert_weights) 341 | 342 | 343 | convert_weights_to_fp16 = convert_weights_to_lp # backwards compat 344 | 345 | 346 | # used to maintain checkpoint compatibility 347 | def convert_to_custom_text_state_dict(state_dict: dict): 348 | if 'text_projection' in state_dict: 349 | # old format state_dict, move text tower -> .text 350 | new_state_dict = {} 351 | for k, v in state_dict.items(): 352 | if any(k.startswith(p) for p in ( 353 | 'text_projection', 354 | 'positional_embedding', 355 | 'token_embedding', 356 | 'transformer', 357 | 'ln_final', 358 | 'logit_scale' 359 | )): 360 | k = 'text.' 
+ k 361 | new_state_dict[k] = v 362 | return new_state_dict 363 | return state_dict 364 | 365 | 366 | def build_model_from_openai_state_dict( 367 | state_dict: dict, 368 | quick_gelu=True, 369 | cast_dtype=torch.float16, 370 | ): 371 | vit = "visual.proj" in state_dict 372 | 373 | if vit: 374 | vision_width = state_dict["visual.conv1.weight"].shape[0] 375 | vision_layers = len( 376 | [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 377 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 378 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 379 | image_size = vision_patch_size * grid_size 380 | else: 381 | counts: list = [ 382 | len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 383 | vision_layers = tuple(counts) 384 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 385 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 386 | vision_patch_size = None 387 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 388 | image_size = output_width * 32 389 | 390 | embed_dim = state_dict["text_projection"].shape[1] 391 | context_length = state_dict["positional_embedding"].shape[0] 392 | vocab_size = state_dict["token_embedding.weight"].shape[0] 393 | transformer_width = state_dict["ln_final.weight"].shape[0] 394 | transformer_heads = transformer_width // 64 395 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) 396 | 397 | vision_cfg = CLIPVisionCfg( 398 | layers=vision_layers, 399 | width=vision_width, 400 | patch_size=vision_patch_size, 401 | image_size=image_size, 402 | ) 403 | text_cfg = CLIPTextCfg( 404 | context_length=context_length, 405 | vocab_size=vocab_size, 406 | width=transformer_width, 407 | heads=transformer_heads, 408 | layers=transformer_layers 409 | ) 410 | model = CLIP( 411 | embed_dim, 412 | vision_cfg=vision_cfg, 413 | text_cfg=text_cfg, 414 | quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU 415 | cast_dtype=cast_dtype, 416 | ) 417 | 418 | for key in ["input_resolution", "context_length", "vocab_size"]: 419 | state_dict.pop(key, None) 420 | 421 | convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16 422 | model.load_state_dict(state_dict) 423 | return model.eval() 424 | 425 | 426 | def trace_model(model, batch_size=256, device=torch.device('cpu')): 427 | model.eval() 428 | image_size = model.visual.image_size 429 | example_images = torch.ones((batch_size, 3, image_size, image_size), device=device) 430 | example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device) 431 | model = torch.jit.trace_module( 432 | model, 433 | inputs=dict( 434 | forward=(example_images, example_text), 435 | encode_text=(example_text,), 436 | encode_image=(example_images,) 437 | )) 438 | model.visual.image_size = image_size 439 | return model 440 | -------------------------------------------------------------------------------- /eva_clip/model_configs/EVA01-CLIP-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 16, 8 | "eva_model_name": "eva-clip-b-16", 9 | "ls_init_value": 0.1, 10 | "drop_path_rate": 0.0 11 | }, 12 | "text_cfg": { 13 | 
"context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /eva_clip/model_configs/EVA01-CLIP-g-14-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 40, 6 | "width": 1408, 7 | "head_width": 88, 8 | "mlp_ratio": 4.3637, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-g-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "fusedLN": true 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 1024, 19 | "heads": 16, 20 | "layers": 24, 21 | "xattn": false, 22 | "fusedLN": true 23 | } 24 | } -------------------------------------------------------------------------------- /eva_clip/model_configs/EVA01-CLIP-g-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 40, 6 | "width": 1408, 7 | "head_width": 88, 8 | "mlp_ratio": 4.3637, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-g-14-x", 11 | "drop_path_rate": 0.4, 12 | "xattn": true, 13 | "fusedLN": true 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 768, 19 | "heads": 12, 20 | "layers": 12, 21 | "xattn": false, 22 | "fusedLN": true 23 | } 24 | } -------------------------------------------------------------------------------- /eva_clip/model_configs/EVA02-CLIP-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "head_width": 64, 8 | "patch_size": 16, 9 | "mlp_ratio": 2.6667, 10 | "eva_model_name": "eva-clip-b-16-X", 11 | "drop_path_rate": 0.0, 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 512, 24 | "heads": 8, 25 | "layers": 12, 26 | "xattn": true, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /eva_clip/model_configs/EVA02-CLIP-L-14-336.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "layers": 24, 6 | "width": 1024, 7 | "drop_path_rate": 0, 8 | "head_width": 64, 9 | "mlp_ratio": 2.6667, 10 | "patch_size": 14, 11 | "eva_model_name": "eva-clip-l-14-336", 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 768, 24 | "heads": 12, 25 | "layers": 12, 26 | "xattn": false, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /eva_clip/model_configs/EVA02-CLIP-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "drop_path_rate": 0, 8 | "head_width": 64, 9 | "mlp_ratio": 2.6667, 10 | "patch_size": 14, 11 | "eva_model_name": 
"eva-clip-l-14", 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 768, 24 | "heads": 12, 25 | "layers": 12, 26 | "xattn": false, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 64, 6 | "width": 1792, 7 | "head_width": 112, 8 | "mlp_ratio": 8.571428571428571, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-4b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": true, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1280, 20 | "heads": 20, 21 | "layers": 32, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /eva_clip/model_configs/EVA02-CLIP-bigE-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 64, 6 | "width": 1792, 7 | "head_width": 112, 8 | "mlp_ratio": 8.571428571428571, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-4b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": true, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1024, 20 | "heads": 16, 21 | "layers": 24, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } -------------------------------------------------------------------------------- /eva_clip/modified_resnet.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from .utils import freeze_batch_norm_2d 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, inplanes, planes, stride=1): 14 | super().__init__() 15 | 16 | # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | self.act1 = nn.ReLU(inplace=True) 20 | 21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.act2 = nn.ReLU(inplace=True) 24 | 25 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 26 | 27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 29 | self.act3 = nn.ReLU(inplace=True) 30 | 31 | self.downsample = None 32 | self.stride = stride 33 | 34 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 35 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 36 | self.downsample = nn.Sequential(OrderedDict([ 37 | ("-1", nn.AvgPool2d(stride)), 38 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 39 | ("1", nn.BatchNorm2d(planes * self.expansion)) 40 | ])) 41 | 42 | def forward(self, x: torch.Tensor): 43 | identity = x 44 | 45 | out = self.act1(self.bn1(self.conv1(x))) 46 | out = self.act2(self.bn2(self.conv2(out))) 47 | out = self.avgpool(out) 48 | out = self.bn3(self.conv3(out)) 49 | 50 | if self.downsample is not None: 51 | identity = self.downsample(x) 52 | 53 | out += identity 54 | out = self.act3(out) 55 | return out 56 | 57 | 58 | class AttentionPool2d(nn.Module): 59 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 60 | super().__init__() 61 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 62 | self.k_proj = nn.Linear(embed_dim, embed_dim) 63 | self.q_proj = nn.Linear(embed_dim, embed_dim) 64 | self.v_proj = nn.Linear(embed_dim, embed_dim) 65 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 66 | self.num_heads = num_heads 67 | 68 | def forward(self, x): 69 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 70 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 71 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 72 | x, _ = F.multi_head_attention_forward( 73 | query=x, key=x, value=x, 74 | embed_dim_to_check=x.shape[-1], 75 | num_heads=self.num_heads, 76 | q_proj_weight=self.q_proj.weight, 77 | k_proj_weight=self.k_proj.weight, 78 | v_proj_weight=self.v_proj.weight, 79 | in_proj_weight=None, 80 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 81 | bias_k=None, 82 | bias_v=None, 83 | add_zero_attn=False, 84 | dropout_p=0., 85 | out_proj_weight=self.c_proj.weight, 86 | out_proj_bias=self.c_proj.bias, 87 | use_separate_proj_weight=True, 88 | training=self.training, 89 | need_weights=False 90 | ) 91 | 92 | return x[0] 93 | 94 | 95 | class ModifiedResNet(nn.Module): 96 | """ 97 | A ResNet class that is similar to torchvision's but contains the following changes: 98 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
99 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 100 | - The final pooling layer is a QKV attention instead of an average pool 101 | """ 102 | 103 | def __init__(self, layers, output_dim, heads, image_size=224, width=64): 104 | super().__init__() 105 | self.output_dim = output_dim 106 | self.image_size = image_size 107 | 108 | # the 3-layer stem 109 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 110 | self.bn1 = nn.BatchNorm2d(width // 2) 111 | self.act1 = nn.ReLU(inplace=True) 112 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 113 | self.bn2 = nn.BatchNorm2d(width // 2) 114 | self.act2 = nn.ReLU(inplace=True) 115 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 116 | self.bn3 = nn.BatchNorm2d(width) 117 | self.act3 = nn.ReLU(inplace=True) 118 | self.avgpool = nn.AvgPool2d(2) 119 | 120 | # residual layers 121 | self._inplanes = width # this is a *mutable* variable used during construction 122 | self.layer1 = self._make_layer(width, layers[0]) 123 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 124 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 125 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 126 | 127 | embed_dim = width * 32 # the ResNet feature dimension 128 | self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) 129 | 130 | self.init_parameters() 131 | 132 | def _make_layer(self, planes, blocks, stride=1): 133 | layers = [Bottleneck(self._inplanes, planes, stride)] 134 | 135 | self._inplanes = planes * Bottleneck.expansion 136 | for _ in range(1, blocks): 137 | layers.append(Bottleneck(self._inplanes, planes)) 138 | 139 | return nn.Sequential(*layers) 140 | 141 | def init_parameters(self): 142 | if self.attnpool is not None: 143 | std = self.attnpool.c_proj.in_features ** -0.5 144 | nn.init.normal_(self.attnpool.q_proj.weight, std=std) 145 | nn.init.normal_(self.attnpool.k_proj.weight, std=std) 146 | nn.init.normal_(self.attnpool.v_proj.weight, std=std) 147 | nn.init.normal_(self.attnpool.c_proj.weight, std=std) 148 | 149 | for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: 150 | for name, param in resnet_block.named_parameters(): 151 | if name.endswith("bn3.weight"): 152 | nn.init.zeros_(param) 153 | 154 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 155 | assert unlocked_groups == 0, 'partial locking not currently supported for this model' 156 | for param in self.parameters(): 157 | param.requires_grad = False 158 | if freeze_bn_stats: 159 | freeze_batch_norm_2d(self) 160 | 161 | @torch.jit.ignore 162 | def set_grad_checkpointing(self, enable=True): 163 | # FIXME support for non-transformer 164 | pass 165 | 166 | def stem(self, x): 167 | x = self.act1(self.bn1(self.conv1(x))) 168 | x = self.act2(self.bn2(self.conv2(x))) 169 | x = self.act3(self.bn3(self.conv3(x))) 170 | x = self.avgpool(x) 171 | return x 172 | 173 | def forward(self, x): 174 | x = self.stem(x) 175 | x = self.layer1(x) 176 | x = self.layer2(x) 177 | x = self.layer3(x) 178 | x = self.layer4(x) 179 | x = self.attnpool(x) 180 | 181 | return x 182 | -------------------------------------------------------------------------------- /eva_clip/openai.py: -------------------------------------------------------------------------------- 1 | """ OpenAI pretrained model functions 2 | 3 | Adapted from https://github.com/openai/CLIP. 
Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | 6 | import os 7 | import warnings 8 | from typing import List, Optional, Union 9 | 10 | import torch 11 | 12 | from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype 13 | from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url 14 | 15 | __all__ = ["list_openai_models", "load_openai_model"] 16 | 17 | 18 | def list_openai_models() -> List[str]: 19 | """Returns the names of available CLIP models""" 20 | return list_pretrained_models_by_tag('openai') 21 | 22 | 23 | def load_openai_model( 24 | name: str, 25 | precision: Optional[str] = None, 26 | device: Optional[Union[str, torch.device]] = None, 27 | jit: bool = True, 28 | cache_dir: Optional[str] = None, 29 | ): 30 | """Load a CLIP model 31 | 32 | Parameters 33 | ---------- 34 | name : str 35 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 36 | precision: str 37 | Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'. 38 | device : Union[str, torch.device] 39 | The device to put the loaded model 40 | jit : bool 41 | Whether to load the optimized JIT model (default) or more hackable non-JIT model. 42 | cache_dir : Optional[str] 43 | The directory to cache the downloaded model weights 44 | 45 | Returns 46 | ------- 47 | model : torch.nn.Module 48 | The CLIP model 49 | preprocess : Callable[[PIL.Image], torch.Tensor] 50 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 51 | """ 52 | if device is None: 53 | device = "cuda" if torch.cuda.is_available() else "cpu" 54 | if precision is None: 55 | precision = 'fp32' if device == 'cpu' else 'fp16' 56 | 57 | if get_pretrained_url(name, 'openai'): 58 | model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir) 59 | elif os.path.isfile(name): 60 | model_path = name 61 | else: 62 | raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") 63 | 64 | try: 65 | # loading JIT archive 66 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 67 | state_dict = None 68 | except RuntimeError: 69 | # loading saved state dict 70 | if jit: 71 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 72 | jit = False 73 | state_dict = torch.load(model_path, map_location="cpu") 74 | 75 | if not jit: 76 | # Build a non-jit model from the OpenAI jitted model state dict 77 | cast_dtype = get_cast_dtype(precision) 78 | try: 79 | model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype) 80 | except KeyError: 81 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} 82 | model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype) 83 | 84 | # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use 85 | model = model.to(device) 86 | if precision.startswith('amp') or precision == 'fp32': 87 | model.float() 88 | elif precision == 'bf16': 89 | convert_weights_to_lp(model, dtype=torch.bfloat16) 90 | 91 | return model 92 | 93 | # patch the device names 94 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 95 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 96 | 97 | def patch_device(module): 98 | try: 99 | graphs = [module.graph] if hasattr(module, "graph") else [] 100 | except RuntimeError: 101 | graphs = [] 102 | 103 | if hasattr(module, "forward1"): 104 | graphs.append(module.forward1.graph) 105 | 106 | for graph in graphs: 107 | for node in graph.findAllNodes("prim::Constant"): 108 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 109 | node.copyAttributes(device_node) 110 | 111 | model.apply(patch_device) 112 | patch_device(model.encode_image) 113 | patch_device(model.encode_text) 114 | 115 | # patch dtype to float32 (typically for CPU) 116 | if precision == 'fp32': 117 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 118 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 119 | float_node = float_input.node() 120 | 121 | def patch_float(module): 122 | try: 123 | graphs = [module.graph] if hasattr(module, "graph") else [] 124 | except RuntimeError: 125 | graphs = [] 126 | 127 | if hasattr(module, "forward1"): 128 | graphs.append(module.forward1.graph) 129 | 130 | for graph in graphs: 131 | for node in graph.findAllNodes("aten::to"): 132 | inputs = list(node.inputs()) 133 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 134 | if inputs[i].node()["value"] == 5: 135 | inputs[i].node().copyAttributes(float_node) 136 | 137 | model.apply(patch_float) 138 | patch_float(model.encode_image) 139 | patch_float(model.encode_text) 140 | model.float() 141 | 142 | # ensure image_size attr available at consistent location for both jit and non-jit 143 | model.visual.image_size = model.input_resolution.item() 144 | return model 145 | -------------------------------------------------------------------------------- /eva_clip/pretrained.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from functools import partial 6 | from typing import Dict, Union 7 | 8 | from tqdm import tqdm 9 | 10 | try: 11 | from huggingface_hub import hf_hub_download 12 | _has_hf_hub = True 13 | except ImportError: 14 | hf_hub_download = None 15 | _has_hf_hub = False 16 | 17 | 18 | def _pcfg(url='', hf_hub='', filename='', mean=None, std=None): 19 | return dict( 20 | url=url, 21 | hf_hub=hf_hub, 22 | mean=mean, 23 | std=std, 24 | ) 25 | 26 | _VITB32 = dict( 27 | openai=_pcfg( 
28 | "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), 29 | laion400m_e31=_pcfg( 30 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), 31 | laion400m_e32=_pcfg( 32 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), 33 | laion2b_e16=_pcfg( 34 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"), 35 | laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/') 36 | ) 37 | 38 | _VITB32_quickgelu = dict( 39 | openai=_pcfg( 40 | "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), 41 | laion400m_e31=_pcfg( 42 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), 43 | laion400m_e32=_pcfg( 44 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), 45 | ) 46 | 47 | _VITB16 = dict( 48 | openai=_pcfg( 49 | "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"), 50 | laion400m_e31=_pcfg( 51 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"), 52 | laion400m_e32=_pcfg( 53 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"), 54 | laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'), 55 | ) 56 | 57 | _EVAB16 = dict( 58 | eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_B_psz14to16.pt'), 59 | eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_B_psz14to16.pt'), 60 | eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt'), 61 | eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt'), 62 | ) 63 | 64 | _VITB16_PLUS_240 = dict( 65 | laion400m_e31=_pcfg( 66 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"), 67 | laion400m_e32=_pcfg( 68 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"), 69 | ) 70 | 71 | _VITL14 = dict( 72 | openai=_pcfg( 73 | "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"), 74 | laion400m_e31=_pcfg( 75 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"), 76 | laion400m_e32=_pcfg( 77 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"), 78 | laion2b_s32b_b82k=_pcfg( 79 | hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/', 80 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), 81 | ) 82 | 83 | _EVAL14 = dict( 84 | eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_L_psz14.pt'), 85 | eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_L_psz14.pt'), 86 | eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt'), 87 | eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt'), 88 | ) 89 | 90 | _VITL14_336 = dict( 91 | openai=_pcfg( 92 | "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"), 93 | ) 94 | 95 | _EVAL14_336 = dict( 96 | 
eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt'), 97 | eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt'), 98 | eva_clip_224to336=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336.pt'), 99 | eva02_clip_224to336=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336.pt'), 100 | ) 101 | 102 | _VITH14 = dict( 103 | laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'), 104 | ) 105 | 106 | _VITg14 = dict( 107 | laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'), 108 | laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'), 109 | ) 110 | 111 | _EVAg14 = dict( 112 | eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/'), 113 | eva01=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_g_psz14.pt'), 114 | eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt'), 115 | eva01_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt'), 116 | ) 117 | 118 | _EVAg14_PLUS = dict( 119 | eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/'), 120 | eva01=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_g_psz14.pt'), 121 | eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt'), 122 | eva01_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt'), 123 | ) 124 | 125 | _VITbigG14 = dict( 126 | laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'), 127 | ) 128 | 129 | _EVAbigE14 = dict( 130 | eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'), 131 | eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'), 132 | eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt'), 133 | eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt'), 134 | ) 135 | 136 | _EVAbigE14_PLUS = dict( 137 | eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'), 138 | eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'), 139 | eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt'), 140 | eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt'), 141 | ) 142 | 143 | 144 | _PRETRAINED = { 145 | # "ViT-B-32": _VITB32, 146 | "OpenaiCLIP-B-32": _VITB32, 147 | "OpenCLIP-B-32": _VITB32, 148 | 149 | # "ViT-B-32-quickgelu": _VITB32_quickgelu, 150 | "OpenaiCLIP-B-32-quickgelu": _VITB32_quickgelu, 151 | "OpenCLIP-B-32-quickgelu": _VITB32_quickgelu, 152 | 153 | # "ViT-B-16": _VITB16, 154 | "OpenaiCLIP-B-16": _VITB16, 155 | "OpenCLIP-B-16": _VITB16, 156 | 157 | "EVA02-B-16": _EVAB16, 158 | "EVA02-CLIP-B-16": _EVAB16, 159 | 160 | # "ViT-B-16-plus-240": _VITB16_PLUS_240, 161 | "OpenCLIP-B-16-plus-240": _VITB16_PLUS_240, 162 | 163 | # "ViT-L-14": _VITL14, 164 | "OpenaiCLIP-L-14": _VITL14, 165 | "OpenCLIP-L-14": _VITL14, 166 | 167 | "EVA02-L-14": _EVAL14, 168 | "EVA02-CLIP-L-14": _EVAL14, 169 | 170 | # "ViT-L-14-336": _VITL14_336, 171 | "OpenaiCLIP-L-14-336": _VITL14_336, 172 | 173 | "EVA02-CLIP-L-14-336": _EVAL14_336, 174 | 175 | # "ViT-H-14": _VITH14, 176 | # "ViT-g-14": _VITg14, 177 | "OpenCLIP-H-14": _VITH14, 178 | "OpenCLIP-g-14": _VITg14, 179 | 180 | "EVA01-CLIP-g-14": _EVAg14, 181 | "EVA01-CLIP-g-14-plus": _EVAg14_PLUS, 182 | 183 | # "ViT-bigG-14": _VITbigG14, 184 | "OpenCLIP-bigG-14": _VITbigG14, 185 | 186 | "EVA02-CLIP-bigE-14": _EVAbigE14, 187 | "EVA02-CLIP-bigE-14-plus": _EVAbigE14_PLUS, 188 | } 189 | 190 | 191 | def _clean_tag(tag: str): 192 | # normalize pretrained tags 193 | return tag.lower().replace('-', '_') 194 | 195 | 196 | def list_pretrained(as_str: bool = False): 197 | """ returns list of pretrained models 198 | Returns a tuple 
(model_name, pretrain_tag) by default or 'name:tag' if as_str == True 199 | """ 200 | return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()] 201 | 202 | 203 | def list_pretrained_models_by_tag(tag: str): 204 | """ return all models having the specified pretrain tag """ 205 | models = [] 206 | tag = _clean_tag(tag) 207 | for k in _PRETRAINED.keys(): 208 | if tag in _PRETRAINED[k]: 209 | models.append(k) 210 | return models 211 | 212 | 213 | def list_pretrained_tags_by_model(model: str): 214 | """ return all pretrain tags for the specified model architecture """ 215 | tags = [] 216 | if model in _PRETRAINED: 217 | tags.extend(_PRETRAINED[model].keys()) 218 | return tags 219 | 220 | 221 | def is_pretrained_cfg(model: str, tag: str): 222 | if model not in _PRETRAINED: 223 | return False 224 | return _clean_tag(tag) in _PRETRAINED[model] 225 | 226 | 227 | def get_pretrained_cfg(model: str, tag: str): 228 | if model not in _PRETRAINED: 229 | return {} 230 | model_pretrained = _PRETRAINED[model] 231 | return model_pretrained.get(_clean_tag(tag), {}) 232 | 233 | 234 | def get_pretrained_url(model: str, tag: str): 235 | cfg = get_pretrained_cfg(model, _clean_tag(tag)) 236 | return cfg.get('url', '') 237 | 238 | 239 | def download_pretrained_from_url( 240 | url: str, 241 | cache_dir: Union[str, None] = None, 242 | ): 243 | if not cache_dir: 244 | cache_dir = os.path.expanduser("~/.cache/clip") 245 | os.makedirs(cache_dir, exist_ok=True) 246 | filename = os.path.basename(url) 247 | 248 | if 'openaipublic' in url: 249 | expected_sha256 = url.split("/")[-2] 250 | elif 'mlfoundations' in url: 251 | expected_sha256 = os.path.splitext(filename)[0].split("-")[-1] 252 | else: 253 | expected_sha256 = '' 254 | 255 | download_target = os.path.join(cache_dir, filename) 256 | 257 | if os.path.exists(download_target) and not os.path.isfile(download_target): 258 | raise RuntimeError(f"{download_target} exists and is not a regular file") 259 | 260 | if os.path.isfile(download_target): 261 | if expected_sha256: 262 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): 263 | return download_target 264 | else: 265 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 266 | else: 267 | return download_target 268 | 269 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 270 | with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: 271 | while True: 272 | buffer = source.read(8192) 273 | if not buffer: 274 | break 275 | 276 | output.write(buffer) 277 | loop.update(len(buffer)) 278 | 279 | if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): 280 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 281 | 282 | return download_target 283 | 284 | 285 | def has_hf_hub(necessary=False): 286 | if not _has_hf_hub and necessary: 287 | # if no HF Hub module installed, and it is necessary to continue, raise error 288 | raise RuntimeError( 289 | 'Hugging Face hub model specified but package not installed. 
Run `pip install huggingface_hub`.') 290 | return _has_hf_hub 291 | 292 | 293 | def download_pretrained_from_hf( 294 | model_id: str, 295 | filename: str = 'open_clip_pytorch_model.bin', 296 | revision=None, 297 | cache_dir: Union[str, None] = None, 298 | ): 299 | has_hf_hub(True) 300 | cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir) 301 | return cached_file 302 | 303 | 304 | def download_pretrained( 305 | cfg: Dict, 306 | force_hf_hub: bool = False, 307 | cache_dir: Union[str, None] = None, 308 | ): 309 | target = '' 310 | if not cfg: 311 | return target 312 | 313 | download_url = cfg.get('url', '') 314 | download_hf_hub = cfg.get('hf_hub', '') 315 | if download_hf_hub and force_hf_hub: 316 | # use HF hub even if url exists 317 | download_url = '' 318 | 319 | if download_url: 320 | target = download_pretrained_from_url(download_url, cache_dir=cache_dir) 321 | elif download_hf_hub: 322 | has_hf_hub(True) 323 | # we assume the hf_hub entries in pretrained config combine model_id + filename in 324 | # 'org/model_name/filename.pt' form. To specify just the model id w/o filename and 325 | # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'. 326 | model_id, filename = os.path.split(download_hf_hub) 327 | if filename: 328 | target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir) 329 | else: 330 | target = download_pretrained_from_hf(model_id, cache_dir=cache_dir) 331 | 332 | return target 333 | -------------------------------------------------------------------------------- /eva_clip/rope.py: -------------------------------------------------------------------------------- 1 | from math import pi 2 | import torch 3 | from torch import nn 4 | from einops import rearrange, repeat 5 | import logging 6 | 7 | def broadcat(tensors, dim = -1): 8 | num_tensors = len(tensors) 9 | shape_lens = set(list(map(lambda t: len(t.shape), tensors))) 10 | assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions' 11 | shape_len = list(shape_lens)[0] 12 | dim = (dim + shape_len) if dim < 0 else dim 13 | dims = list(zip(*map(lambda t: list(t.shape), tensors))) 14 | expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] 15 | assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatentation' 16 | max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) 17 | expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims)) 18 | expanded_dims.insert(dim, (dim, dims[dim])) 19 | expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims))) 20 | tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes))) 21 | return torch.cat(tensors, dim = dim) 22 | 23 | def rotate_half(x): 24 | x = rearrange(x, '... (d r) -> ... d r', r = 2) 25 | x1, x2 = x.unbind(dim = -1) 26 | x = torch.stack((-x2, x1), dim = -1) 27 | return rearrange(x, '... d r -> ... (d r)') 28 | 29 | 30 | class VisionRotaryEmbedding(nn.Module): 31 | def __init__( 32 | self, 33 | dim, 34 | pt_seq_len, 35 | ft_seq_len=None, 36 | custom_freqs = None, 37 | freqs_for = 'lang', 38 | theta = 10000, 39 | max_freq = 10, 40 | num_freqs = 1, 41 | ): 42 | super().__init__() 43 | if custom_freqs: 44 | freqs = custom_freqs 45 | elif freqs_for == 'lang': 46 | freqs = 1. 
/ (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) 47 | elif freqs_for == 'pixel': 48 | freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi 49 | elif freqs_for == 'constant': 50 | freqs = torch.ones(num_freqs).float() 51 | else: 52 | raise ValueError(f'unknown modality {freqs_for}') 53 | 54 | if ft_seq_len is None: ft_seq_len = pt_seq_len 55 | t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len 56 | 57 | freqs_h = torch.einsum('..., f -> ... f', t, freqs) 58 | freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2) 59 | 60 | freqs_w = torch.einsum('..., f -> ... f', t, freqs) 61 | freqs_w = repeat(freqs_w, '... n -> ... (n r)', r = 2) 62 | 63 | freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim = -1) 64 | 65 | self.register_buffer("freqs_cos", freqs.cos()) 66 | self.register_buffer("freqs_sin", freqs.sin()) 67 | 68 | logging.info(f'Shape of rope freq: {self.freqs_cos.shape}') 69 | 70 | def forward(self, t, start_index = 0): 71 | rot_dim = self.freqs_cos.shape[-1] 72 | end_index = start_index + rot_dim 73 | assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}' 74 | t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:] 75 | t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin) 76 | 77 | return torch.cat((t_left, t, t_right), dim = -1) 78 | 79 | class VisionRotaryEmbeddingFast(nn.Module): 80 | def __init__( 81 | self, 82 | dim, 83 | pt_seq_len, 84 | ft_seq_len=None, 85 | custom_freqs = None, 86 | freqs_for = 'lang', 87 | theta = 10000, 88 | max_freq = 10, 89 | num_freqs = 1, 90 | patch_dropout = 0. 91 | ): 92 | super().__init__() 93 | if custom_freqs: 94 | freqs = custom_freqs 95 | elif freqs_for == 'lang': 96 | freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) 97 | elif freqs_for == 'pixel': 98 | freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi 99 | elif freqs_for == 'constant': 100 | freqs = torch.ones(num_freqs).float() 101 | else: 102 | raise ValueError(f'unknown modality {freqs_for}') 103 | 104 | if ft_seq_len is None: ft_seq_len = pt_seq_len 105 | t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len 106 | 107 | freqs = torch.einsum('..., f -> ... f', t, freqs) 108 | freqs = repeat(freqs, '... n -> ... 
(n r)', r = 2) 109 | freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim = -1) 110 | 111 | freqs_cos = freqs.cos().view(-1, freqs.shape[-1]) 112 | freqs_sin = freqs.sin().view(-1, freqs.shape[-1]) 113 | 114 | self.patch_dropout = patch_dropout 115 | 116 | self.register_buffer("freqs_cos", freqs_cos) 117 | self.register_buffer("freqs_sin", freqs_sin) 118 | 119 | logging.info(f'Shape of rope freq: {self.freqs_cos.shape}') 120 | 121 | def forward(self, t, patch_indices_keep=None): 122 | if patch_indices_keep is not None: 123 | batch = t.size()[0] 124 | batch_indices = torch.arange(batch) 125 | batch_indices = batch_indices[..., None] 126 | 127 | freqs_cos = repeat(self.freqs_cos, 'i j -> n i m j', n=t.shape[0], m=t.shape[1]) 128 | freqs_sin = repeat(self.freqs_sin, 'i j -> n i m j', n=t.shape[0], m=t.shape[1]) 129 | 130 | freqs_cos = freqs_cos[batch_indices, patch_indices_keep] 131 | freqs_cos = rearrange(freqs_cos, 'n i m j -> n m i j') 132 | freqs_sin = freqs_sin[batch_indices, patch_indices_keep] 133 | freqs_sin = rearrange(freqs_sin, 'n i m j -> n m i j') 134 | 135 | return t * freqs_cos + rotate_half(t) * freqs_sin 136 | 137 | return t * self.freqs_cos + rotate_half(t) * self.freqs_sin -------------------------------------------------------------------------------- /eva_clip/timm_model.py: -------------------------------------------------------------------------------- 1 | """ timm model adapter 2 | 3 | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. 4 | """ 5 | import logging 6 | from collections import OrderedDict 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | try: 12 | import timm 13 | from timm.models.layers import Mlp, to_2tuple 14 | try: 15 | # old timm imports < 0.8.1 16 | from timm.models.layers.attention_pool2d import RotAttentionPool2d 17 | from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d 18 | except ImportError: 19 | # new timm imports >= 0.8.1 20 | from timm.layers import RotAttentionPool2d 21 | from timm.layers import AttentionPool2d as AbsAttentionPool2d 22 | except ImportError: 23 | timm = None 24 | 25 | from .utils import freeze_batch_norm_2d 26 | 27 | 28 | class TimmModel(nn.Module): 29 | """ timm model adapter 30 | # FIXME this adapter is a work in progress, may change in ways that break weight compat 31 | """ 32 | 33 | def __init__( 34 | self, 35 | model_name, 36 | embed_dim, 37 | image_size=224, 38 | pool='avg', 39 | proj='linear', 40 | proj_bias=False, 41 | drop=0., 42 | pretrained=False): 43 | super().__init__() 44 | if timm is None: 45 | raise RuntimeError("Please `pip install timm` to use timm models.") 46 | 47 | self.image_size = to_2tuple(image_size) 48 | self.trunk = timm.create_model(model_name, pretrained=pretrained) 49 | feat_size = self.trunk.default_cfg.get('pool_size', None) 50 | feature_ndim = 1 if not feat_size else 2 51 | if pool in ('abs_attn', 'rot_attn'): 52 | assert feature_ndim == 2 53 | # if attn pooling used, remove both classifier and default pool 54 | self.trunk.reset_classifier(0, global_pool='') 55 | else: 56 | # reset global pool if pool config set, otherwise leave as network default 57 | reset_kwargs = dict(global_pool=pool) if pool else {} 58 | self.trunk.reset_classifier(0, **reset_kwargs) 59 | prev_chs = self.trunk.num_features 60 | 61 | head_layers = OrderedDict() 62 | if pool == 'abs_attn': 63 | head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) 64 | prev_chs = embed_dim 65 
| elif pool == 'rot_attn': 66 | head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim) 67 | prev_chs = embed_dim 68 | else: 69 | assert proj, 'projection layer needed if non-attention pooling is used.' 70 | 71 | # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used 72 | if proj == 'linear': 73 | head_layers['drop'] = nn.Dropout(drop) 74 | head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias) 75 | elif proj == 'mlp': 76 | head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop, bias=(True, proj_bias)) 77 | 78 | self.head = nn.Sequential(head_layers) 79 | 80 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 81 | """ lock modules 82 | Args: 83 | unlocked_groups (int): leave last n layer groups unlocked (default: 0) 84 | """ 85 | if not unlocked_groups: 86 | # lock full model 87 | for param in self.trunk.parameters(): 88 | param.requires_grad = False 89 | if freeze_bn_stats: 90 | freeze_batch_norm_2d(self.trunk) 91 | else: 92 | # NOTE: partial freeze requires latest timm (master) branch and is subject to change 93 | try: 94 | # FIXME import here until API stable and in an official release 95 | from timm.models.helpers import group_parameters, group_modules 96 | except ImportError: 97 | raise RuntimeError( 98 | 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`') 99 | matcher = self.trunk.group_matcher() 100 | gparams = group_parameters(self.trunk, matcher) 101 | max_layer_id = max(gparams.keys()) 102 | max_layer_id = max_layer_id - unlocked_groups 103 | for group_idx in range(max_layer_id + 1): 104 | group = gparams[group_idx] 105 | for param in group: 106 | self.trunk.get_parameter(param).requires_grad = False 107 | if freeze_bn_stats: 108 | gmodules = group_modules(self.trunk, matcher, reverse=True) 109 | gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} 110 | freeze_batch_norm_2d(self.trunk, gmodules) 111 | 112 | @torch.jit.ignore 113 | def set_grad_checkpointing(self, enable=True): 114 | try: 115 | self.trunk.set_grad_checkpointing(enable) 116 | except Exception as e: 117 | logging.warning('grad checkpointing not supported for this timm image tower, continuing without...') 118 | 119 | def forward(self, x): 120 | x = self.trunk(x) 121 | x = self.head(x) 122 | return x 123 | -------------------------------------------------------------------------------- /eva_clip/tokenizer.py: -------------------------------------------------------------------------------- 1 | """ CLIP tokenizer 2 | 3 | Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | import gzip 6 | import html 7 | import os 8 | from functools import lru_cache 9 | from typing import Union, List 10 | 11 | import ftfy 12 | import regex as re 13 | import torch 14 | 15 | # https://stackoverflow.com/q/62691279 16 | import os 17 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 18 | 19 | 20 | @lru_cache() 21 | def default_bpe(): 22 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 23 | 24 | 25 | @lru_cache() 26 | def bytes_to_unicode(): 27 | """ 28 | Returns list of utf-8 byte and a corresponding list of unicode strings. 29 | The reversible bpe codes work on unicode strings. 30 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 
31 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 32 | This is a significant percentage of your normal, say, 32K bpe vocab. 33 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 34 | And avoids mapping to whitespace/control characters the bpe code barfs on. 35 | """ 36 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 37 | cs = bs[:] 38 | n = 0 39 | for b in range(2**8): 40 | if b not in bs: 41 | bs.append(b) 42 | cs.append(2**8+n) 43 | n += 1 44 | cs = [chr(n) for n in cs] 45 | return dict(zip(bs, cs)) 46 | 47 | 48 | def get_pairs(word): 49 | """Return set of symbol pairs in a word. 50 | Word is represented as tuple of symbols (symbols being variable-length strings). 51 | """ 52 | pairs = set() 53 | prev_char = word[0] 54 | for char in word[1:]: 55 | pairs.add((prev_char, char)) 56 | prev_char = char 57 | return pairs 58 | 59 | 60 | def basic_clean(text): 61 | text = ftfy.fix_text(text) 62 | text = html.unescape(html.unescape(text)) 63 | return text.strip() 64 | 65 | 66 | def whitespace_clean(text): 67 | text = re.sub(r'\s+', ' ', text) 68 | text = text.strip() 69 | return text 70 | 71 | 72 | class SimpleTokenizer(object): 73 | def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): 74 | self.byte_encoder = bytes_to_unicode() 75 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 76 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 77 | merges = merges[1:49152-256-2+1] 78 | merges = [tuple(merge.split()) for merge in merges] 79 | vocab = list(bytes_to_unicode().values()) 80 | vocab = vocab + [v+'</w>' for v in vocab] 81 | for merge in merges: 82 | vocab.append(''.join(merge)) 83 | if not special_tokens: 84 | special_tokens = ['<start_of_text>', '<end_of_text>'] 85 | else: 86 | special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens 87 | vocab.extend(special_tokens) 88 | self.encoder = dict(zip(vocab, range(len(vocab)))) 89 | self.decoder = {v: k for k, v in self.encoder.items()} 90 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 91 | self.cache = {t:t for t in special_tokens} 92 | special = "|".join(special_tokens) 93 | self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 94 | 95 | self.vocab_size = len(self.encoder) 96 | self.all_special_ids = [self.encoder[t] for t in special_tokens] 97 | 98 | def bpe(self, token): 99 | if token in self.cache: 100 | return self.cache[token] 101 | word = tuple(token[:-1]) + ( token[-1] + '</w>',) 102 | pairs = get_pairs(word) 103 | 104 | if not pairs: 105 | return token+'</w>' 106 | 107 | while True: 108 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 109 | if bigram not in self.bpe_ranks: 110 | break 111 | first, second = bigram 112 | new_word = [] 113 | i = 0 114 | while i < len(word): 115 | try: 116 | j = word.index(first, i) 117 | new_word.extend(word[i:j]) 118 | i = j 119 | except: 120 | new_word.extend(word[i:]) 121 | break 122 | 123 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 124 | new_word.append(first+second) 125 | i += 2 126 | else: 127 | new_word.append(word[i]) 128 | i += 1 129 | new_word = tuple(new_word) 130 | word = new_word 131 | if len(word) == 1: 132 | break 133 | else: 134 | pairs = get_pairs(word) 135 | word = ' '.join(word) 136 | self.cache[token] = word 137 | return word 138 | 139 | def encode(self, text): 140 | bpe_tokens = [] 141 | text = whitespace_clean(basic_clean(text)).lower()
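# The regex below splits the cleaned text into word/punctuation chunks; each chunk is mapped
# to its byte-level unicode form and merged by bpe(), whose '</w>' suffix marks end-of-word
# so that decode() can restore the spaces between words.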
whitespace_clean(basic_clean(text)).lower() 142 | for token in re.findall(self.pat, text): 143 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 144 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 145 | return bpe_tokens 146 | 147 | def decode(self, tokens): 148 | text = ''.join([self.decoder[token] for token in tokens]) 149 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 150 | return text 151 | 152 | 153 | _tokenizer = SimpleTokenizer() 154 | 155 | 156 | def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: 157 | """ 158 | Returns the tokenized representation of given input string(s) 159 | 160 | Parameters 161 | ---------- 162 | texts : Union[str, List[str]] 163 | An input string or a list of input strings to tokenize 164 | context_length : int 165 | The context length to use; all CLIP models use 77 as the context length 166 | 167 | Returns 168 | ------- 169 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 170 | """ 171 | if isinstance(texts, str): 172 | texts = [texts] 173 | 174 | sot_token = _tokenizer.encoder[""] 175 | eot_token = _tokenizer.encoder[""] 176 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 177 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 178 | 179 | for i, tokens in enumerate(all_tokens): 180 | if len(tokens) > context_length: 181 | tokens = tokens[:context_length] # Truncate 182 | tokens[-1] = eot_token 183 | result[i, :len(tokens)] = torch.tensor(tokens) 184 | 185 | return result 186 | 187 | 188 | class HFTokenizer: 189 | "HuggingFace tokenizer wrapper" 190 | def __init__(self, tokenizer_name:str): 191 | from transformers import AutoTokenizer 192 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) 193 | 194 | def __call__(self, texts:Union[str, List[str]], context_length:int=77) -> torch.Tensor: 195 | # same cleaning as for default tokenizer, except lowercasing 196 | # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance 197 | if isinstance(texts, str): 198 | texts = [texts] 199 | texts = [whitespace_clean(basic_clean(text)) for text in texts] 200 | input_ids = self.tokenizer(texts, return_tensors='pt', max_length=context_length, padding='max_length', truncation=True).input_ids 201 | return input_ids 202 | -------------------------------------------------------------------------------- /eva_clip/transform.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Sequence, Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torchvision.transforms.functional as F 6 | 7 | from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \ 8 | CenterCrop 9 | 10 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 11 | 12 | 13 | class ResizeMaxSize(nn.Module): 14 | 15 | def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0): 16 | super().__init__() 17 | if not isinstance(max_size, int): 18 | raise TypeError(f"Size should be int. 
Got {type(max_size)}") 19 | self.max_size = max_size 20 | self.interpolation = interpolation 21 | self.fn = min if fn == 'min' else min 22 | self.fill = fill 23 | 24 | def forward(self, img): 25 | if isinstance(img, torch.Tensor): 26 | height, width = img.shape[:2] 27 | else: 28 | width, height = img.size 29 | scale = self.max_size / float(max(height, width)) 30 | if scale != 1.0: 31 | new_size = tuple(round(dim * scale) for dim in (height, width)) 32 | img = F.resize(img, new_size, self.interpolation) 33 | pad_h = self.max_size - new_size[0] 34 | pad_w = self.max_size - new_size[1] 35 | img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill) 36 | return img 37 | 38 | 39 | def _convert_to_rgb(image): 40 | return image.convert('RGB') 41 | 42 | 43 | # class CatGen(nn.Module): 44 | # def __init__(self, num=4): 45 | # self.num = num 46 | # def mixgen_batch(image, text): 47 | # batch_size = image.shape[0] 48 | # index = np.random.permutation(batch_size) 49 | 50 | # cat_images = [] 51 | # for i in range(batch_size): 52 | # # image mixup 53 | # image[i,:] = lam * image[i,:] + (1 - lam) * image[index[i],:] 54 | # # text concat 55 | # text[i] = tokenizer((str(text[i]) + " " + str(text[index[i]])))[0] 56 | # text = torch.stack(text) 57 | # return image, text 58 | 59 | 60 | def image_transform( 61 | image_size: int, 62 | is_train: bool, 63 | mean: Optional[Tuple[float, ...]] = None, 64 | std: Optional[Tuple[float, ...]] = None, 65 | resize_longest_max: bool = False, 66 | fill_color: int = 0, 67 | ): 68 | mean = mean or OPENAI_DATASET_MEAN 69 | if not isinstance(mean, (list, tuple)): 70 | mean = (mean,) * 3 71 | 72 | std = std or OPENAI_DATASET_STD 73 | if not isinstance(std, (list, tuple)): 74 | std = (std,) * 3 75 | 76 | if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]: 77 | # for square size, pass size as int so that Resize() uses aspect preserving shortest edge 78 | image_size = image_size[0] 79 | 80 | normalize = Normalize(mean=mean, std=std) 81 | if is_train: 82 | return Compose([ 83 | RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC), 84 | _convert_to_rgb, 85 | ToTensor(), 86 | normalize, 87 | ]) 88 | else: 89 | if resize_longest_max: 90 | transforms = [ 91 | ResizeMaxSize(image_size, fill=fill_color) 92 | ] 93 | else: 94 | transforms = [ 95 | Resize(image_size, interpolation=InterpolationMode.BICUBIC), 96 | CenterCrop(image_size), 97 | ] 98 | transforms.extend([ 99 | _convert_to_rgb, 100 | ToTensor(), 101 | normalize, 102 | ]) 103 | return Compose(transforms) 104 | -------------------------------------------------------------------------------- /eva_clip/utils.py: -------------------------------------------------------------------------------- 1 | from itertools import repeat 2 | import collections.abc 3 | import logging 4 | import math 5 | import numpy as np 6 | 7 | import torch 8 | from torch import nn as nn 9 | from torchvision.ops.misc import FrozenBatchNorm2d 10 | import torch.nn.functional as F 11 | 12 | # open CLIP 13 | def resize_clip_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1): 14 | # Rescale the grid of position embeddings when loading from state_dict 15 | old_pos_embed = state_dict.get('visual.positional_embedding', None) 16 | if old_pos_embed is None or not hasattr(model.visual, 'grid_size'): 17 | return 18 | grid_size = to_2tuple(model.visual.grid_size) 19 | extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more) 20 | 
new_seq_len = grid_size[0] * grid_size[1] + extra_tokens 21 | if new_seq_len == old_pos_embed.shape[0]: 22 | return 23 | 24 | if extra_tokens: 25 | pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:] 26 | else: 27 | pos_emb_tok, pos_emb_img = None, old_pos_embed 28 | old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img)))) 29 | 30 | logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size) 31 | pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2) 32 | pos_emb_img = F.interpolate( 33 | pos_emb_img, 34 | size=grid_size, 35 | mode=interpolation, 36 | align_corners=True, 37 | ) 38 | pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0] 39 | if pos_emb_tok is not None: 40 | new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0) 41 | else: 42 | new_pos_embed = pos_emb_img 43 | state_dict['visual.positional_embedding'] = new_pos_embed 44 | 45 | 46 | def resize_visual_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1): 47 | # Rescale the grid of position embeddings when loading from state_dict 48 | old_pos_embed = state_dict.get('positional_embedding', None) 49 | if old_pos_embed is None or not hasattr(model.visual, 'grid_size'): 50 | return 51 | grid_size = to_2tuple(model.visual.grid_size) 52 | extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more) 53 | new_seq_len = grid_size[0] * grid_size[1] + extra_tokens 54 | if new_seq_len == old_pos_embed.shape[0]: 55 | return 56 | 57 | if extra_tokens: 58 | pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:] 59 | else: 60 | pos_emb_tok, pos_emb_img = None, old_pos_embed 61 | old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img)))) 62 | 63 | logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size) 64 | pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2) 65 | pos_emb_img = F.interpolate( 66 | pos_emb_img, 67 | size=grid_size, 68 | mode=interpolation, 69 | align_corners=True, 70 | ) 71 | pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0] 72 | if pos_emb_tok is not None: 73 | new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0) 74 | else: 75 | new_pos_embed = pos_emb_img 76 | state_dict['positional_embedding'] = new_pos_embed 77 | 78 | def resize_evaclip_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1): 79 | all_keys = list(state_dict.keys()) 80 | # interpolate position embedding 81 | if 'visual.pos_embed' in state_dict: 82 | pos_embed_checkpoint = state_dict['visual.pos_embed'] 83 | embedding_size = pos_embed_checkpoint.shape[-1] 84 | num_patches = model.visual.patch_embed.num_patches 85 | num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches 86 | # height (== width) for the checkpoint position embedding 87 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 88 | # height (== width) for the new position embedding 89 | new_size = int(num_patches ** 0.5) 90 | # class_token and dist_token are kept unchanged 91 | if orig_size != new_size: 92 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 93 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 94 | # only the position tokens are interpolated 95 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 96 | pos_tokens = 
pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 97 | pos_tokens = torch.nn.functional.interpolate( 98 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 99 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 100 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 101 | state_dict['visual.pos_embed'] = new_pos_embed 102 | 103 | patch_embed_proj = state_dict['visual.patch_embed.proj.weight'] 104 | patch_size = model.visual.patch_embed.patch_size 105 | state_dict['visual.patch_embed.proj.weight'] = torch.nn.functional.interpolate( 106 | patch_embed_proj.float(), size=patch_size, mode='bicubic', align_corners=False) 107 | 108 | 109 | def resize_eva_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1): 110 | all_keys = list(state_dict.keys()) 111 | # interpolate position embedding 112 | if 'pos_embed' in state_dict: 113 | pos_embed_checkpoint = state_dict['pos_embed'] 114 | embedding_size = pos_embed_checkpoint.shape[-1] 115 | num_patches = model.visual.patch_embed.num_patches 116 | num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches 117 | # height (== width) for the checkpoint position embedding 118 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 119 | # height (== width) for the new position embedding 120 | new_size = int(num_patches ** 0.5) 121 | # class_token and dist_token are kept unchanged 122 | if orig_size != new_size: 123 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 124 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 125 | # only the position tokens are interpolated 126 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 127 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 128 | pos_tokens = torch.nn.functional.interpolate( 129 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 130 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 131 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 132 | state_dict['pos_embed'] = new_pos_embed 133 | 134 | patch_embed_proj = state_dict['patch_embed.proj.weight'] 135 | patch_size = model.visual.patch_embed.patch_size 136 | state_dict['patch_embed.proj.weight'] = torch.nn.functional.interpolate( 137 | patch_embed_proj.float(), size=patch_size, mode='bicubic', align_corners=False) 138 | 139 | 140 | def resize_rel_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1): 141 | all_keys = list(state_dict.keys()) 142 | for key in all_keys: 143 | if "relative_position_index" in key: 144 | state_dict.pop(key) 145 | 146 | if "relative_position_bias_table" in key: 147 | rel_pos_bias = state_dict[key] 148 | src_num_pos, num_attn_heads = rel_pos_bias.size() 149 | dst_num_pos, _ = model.visual.state_dict()[key].size() 150 | dst_patch_shape = model.visual.patch_embed.patch_shape 151 | if dst_patch_shape[0] != dst_patch_shape[1]: 152 | raise NotImplementedError() 153 | num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1) 154 | src_size = int((src_num_pos - num_extra_tokens) ** 0.5) 155 | dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5) 156 | if src_size != dst_size: 157 | print("Position interpolate for %s from %dx%d to %dx%d" % ( 158 | key, src_size, src_size, dst_size, dst_size)) 159 | extra_tokens = rel_pos_bias[-num_extra_tokens:, :] 160 | rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :] 161 | 
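                # NOTE (explanatory comments added here; a sketch of the idea, not part of the upstream code path):
                # the block below re-samples the relative position bias table. Source coordinates are laid out on a
                # geometric progression so that spacing grows away from the centre, and a binary search over the
                # common ratio q picks the spacing whose partial sum a * (1 - q**n) / (1 - q), with n = src_size // 2,
                # just reaches dst_size // 2; the bias is then interpolated onto the evenly spaced target
                # coordinates dx, dy with 2-D cubic interpolation.
                # Caveat (assumption, not verified against the rest of this repo): the call
                # `F.interpolate.interp2d(...)` further down is not a torch.nn.functional API; the BEiT/EVA code this
                # mirrors uses SciPy, roughly:
                #
                #     from scipy import interpolate  # interp2d is deprecated in recent SciPy releases
                #     f = interpolate.interp2d(x, y, z, kind='cubic')
                #     resized = torch.Tensor(f(dx, dy)).contiguous().view(-1, 1)
                #
                # so a SciPy import (or a RegularGridInterpolator-based equivalent) would likely be needed for this
                # branch to actually run.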
162 | def geometric_progression(a, r, n): 163 | return a * (1.0 - r ** n) / (1.0 - r) 164 | 165 | left, right = 1.01, 1.5 166 | while right - left > 1e-6: 167 | q = (left + right) / 2.0 168 | gp = geometric_progression(1, q, src_size // 2) 169 | if gp > dst_size // 2: 170 | right = q 171 | else: 172 | left = q 173 | 174 | # if q > 1.090307: 175 | # q = 1.090307 176 | 177 | dis = [] 178 | cur = 1 179 | for i in range(src_size // 2): 180 | dis.append(cur) 181 | cur += q ** (i + 1) 182 | 183 | r_ids = [-_ for _ in reversed(dis)] 184 | 185 | x = r_ids + [0] + dis 186 | y = r_ids + [0] + dis 187 | 188 | t = dst_size // 2.0 189 | dx = np.arange(-t, t + 0.1, 1.0) 190 | dy = np.arange(-t, t + 0.1, 1.0) 191 | 192 | print("Original positions = %s" % str(x)) 193 | print("Target positions = %s" % str(dx)) 194 | 195 | all_rel_pos_bias = [] 196 | 197 | for i in range(num_attn_heads): 198 | z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy() 199 | f = F.interpolate.interp2d(x, y, z, kind='cubic') 200 | all_rel_pos_bias.append( 201 | torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device)) 202 | 203 | rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) 204 | 205 | new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0) 206 | state_dict[key] = new_rel_pos_bias 207 | 208 | # interpolate position embedding 209 | if 'pos_embed' in state_dict: 210 | pos_embed_checkpoint = state_dict['pos_embed'] 211 | embedding_size = pos_embed_checkpoint.shape[-1] 212 | num_patches = model.visual.patch_embed.num_patches 213 | num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches 214 | # height (== width) for the checkpoint position embedding 215 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 216 | # height (== width) for the new position embedding 217 | new_size = int(num_patches ** 0.5) 218 | # class_token and dist_token are kept unchanged 219 | if orig_size != new_size: 220 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 221 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 222 | # only the position tokens are interpolated 223 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 224 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 225 | pos_tokens = torch.nn.functional.interpolate( 226 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 227 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 228 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 229 | state_dict['pos_embed'] = new_pos_embed 230 | 231 | patch_embed_proj = state_dict['patch_embed.proj.weight'] 232 | patch_size = model.visual.patch_embed.patch_size 233 | state_dict['patch_embed.proj.weight'] = torch.nn.functional.interpolate( 234 | patch_embed_proj.float(), size=patch_size, mode='bicubic', align_corners=False) 235 | 236 | 237 | def freeze_batch_norm_2d(module, module_match={}, name=''): 238 | """ 239 | Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is 240 | itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and 241 | returned. Otherwise, the module is walked recursively and submodules are converted in place. 242 | 243 | Args: 244 | module (torch.nn.Module): Any PyTorch module. 
245 | module_match (dict): Dictionary of full module names to freeze (all if empty) 246 | name (str): Full module name (prefix) 247 | 248 | Returns: 249 | torch.nn.Module: Resulting module 250 | 251 | Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 252 | """ 253 | res = module 254 | is_match = True 255 | if module_match: 256 | is_match = name in module_match 257 | if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)): 258 | res = FrozenBatchNorm2d(module.num_features) 259 | res.num_features = module.num_features 260 | res.affine = module.affine 261 | if module.affine: 262 | res.weight.data = module.weight.data.clone().detach() 263 | res.bias.data = module.bias.data.clone().detach() 264 | res.running_mean.data = module.running_mean.data 265 | res.running_var.data = module.running_var.data 266 | res.eps = module.eps 267 | else: 268 | for child_name, child in module.named_children(): 269 | full_child_name = '.'.join([name, child_name]) if name else child_name 270 | new_child = freeze_batch_norm_2d(child, module_match, full_child_name) 271 | if new_child is not child: 272 | res.add_module(child_name, new_child) 273 | return res 274 | 275 | 276 | # From PyTorch internals 277 | def _ntuple(n): 278 | def parse(x): 279 | if isinstance(x, collections.abc.Iterable): 280 | return x 281 | return tuple(repeat(x, n)) 282 | return parse 283 | 284 | 285 | to_1tuple = _ntuple(1) 286 | to_2tuple = _ntuple(2) 287 | to_3tuple = _ntuple(3) 288 | to_4tuple = _ntuple(4) 289 | to_ntuple = lambda n, x: _ntuple(n)(x) 290 | 291 | 292 | def is_logging(args): 293 | def is_global_master(args): 294 | return args.rank == 0 295 | 296 | def is_local_master(args): 297 | return args.local_rank == 0 298 | 299 | def is_master(args, local=False): 300 | return is_local_master(args) if local else is_global_master(args) 301 | return is_master 302 | 303 | 304 | class AllGather(torch.autograd.Function): 305 | """An autograd function that performs allgather on a tensor. 306 | Performs all_gather operation on the provided tensors. 307 | *** Warning ***: torch.distributed.all_gather has no gradient. 
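    The custom backward below compensates for that: it returns only the slice of
    grad_output corresponding to this rank's local batch
    (rows batch_size * rank .. batch_size * (rank + 1)), so gradients still flow
    to the locally produced tensor.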
308 | """ 309 | 310 | @staticmethod 311 | def forward(ctx, tensor, rank, world_size): 312 | tensors_gather = [torch.empty_like(tensor) for _ in range(world_size)] 313 | torch.distributed.all_gather(tensors_gather, tensor) 314 | ctx.rank = rank 315 | ctx.batch_size = tensor.shape[0] 316 | return torch.cat(tensors_gather, 0) 317 | 318 | @staticmethod 319 | def backward(ctx, grad_output): 320 | return ( 321 | grad_output[ctx.batch_size * ctx.rank: ctx.batch_size * (ctx.rank + 1)], 322 | None, 323 | None 324 | ) 325 | 326 | allgather = AllGather.apply -------------------------------------------------------------------------------- /examples/einstein.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/balazik/ComfyUI-PuLID-Flux/a80912fc3435c358607bf4b43a58dbcbebdb09ff/examples/einstein.jpg -------------------------------------------------------------------------------- /examples/pulid_flux_16bit_simple.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 64, 3 | "last_link_id": 132, 4 | "nodes": [ 5 | { 6 | "id": 25, 7 | "type": "RandomNoise", 8 | "pos": { 9 | "0": 6, 10 | "1": -135 11 | }, 12 | "size": { 13 | "0": 315, 14 | "1": 82 15 | }, 16 | "flags": {}, 17 | "order": 0, 18 | "mode": 0, 19 | "inputs": [], 20 | "outputs": [ 21 | { 22 | "name": "NOISE", 23 | "type": "NOISE", 24 | "links": [ 25 | 84 26 | ], 27 | "slot_index": 0, 28 | "shape": 3 29 | } 30 | ], 31 | "properties": { 32 | "Node name for S&R": "RandomNoise" 33 | }, 34 | "widgets_values": [ 35 | 186462208016243, 36 | "fixed" 37 | ], 38 | "color": "#2a363b", 39 | "bgcolor": "#3f5159" 40 | }, 41 | { 42 | "id": 26, 43 | "type": "FluxGuidance", 44 | "pos": { 45 | "0": 372, 46 | "1": -171 47 | }, 48 | "size": { 49 | "0": 317.4000244140625, 50 | "1": 58 51 | }, 52 | "flags": { 53 | "collapsed": false 54 | }, 55 | "order": 13, 56 | "mode": 0, 57 | "inputs": [ 58 | { 59 | "name": "conditioning", 60 | "type": "CONDITIONING", 61 | "link": 41 62 | } 63 | ], 64 | "outputs": [ 65 | { 66 | "name": "CONDITIONING", 67 | "type": "CONDITIONING", 68 | "links": [ 69 | 107 70 | ], 71 | "slot_index": 0, 72 | "shape": 3 73 | } 74 | ], 75 | "properties": { 76 | "Node name for S&R": "FluxGuidance" 77 | }, 78 | "widgets_values": [ 79 | 3.5 80 | ], 81 | "color": "#233", 82 | "bgcolor": "#355" 83 | }, 84 | { 85 | "id": 6, 86 | "type": "CLIPTextEncode", 87 | "pos": { 88 | "0": 372, 89 | "1": -55 90 | }, 91 | "size": { 92 | "0": 422.84503173828125, 93 | "1": 164.31304931640625 94 | }, 95 | "flags": {}, 96 | "order": 12, 97 | "mode": 0, 98 | "inputs": [ 99 | { 100 | "name": "clip", 101 | "type": "CLIP", 102 | "link": 132 103 | } 104 | ], 105 | "outputs": [ 106 | { 107 | "name": "CONDITIONING", 108 | "type": "CONDITIONING", 109 | "links": [ 110 | 41 111 | ], 112 | "slot_index": 0 113 | } 114 | ], 115 | "title": "CLIP Text Encode (Positive Prompt)", 116 | "properties": { 117 | "Node name for S&R": "CLIPTextEncode" 118 | }, 119 | "widgets_values": [ 120 | "Half body portrait of 60 years old guy, with an surprised expression, he is lost in vectors of AI models, sourounded by PC monitors and many cables, on his tshirt is a text with words printed in Arial font:\"PuLID Flux\", detailed, glowy background, photorealistic style with skin inperfections, looks like shot with an smartphone, skin details without plastic look, ASUS Keyboard." 
121 | ], 122 | "color": "#232", 123 | "bgcolor": "#353" 124 | }, 125 | { 126 | "id": 27, 127 | "type": "EmptySD3LatentImage", 128 | "pos": { 129 | "0": 383, 130 | "1": 155 131 | }, 132 | "size": { 133 | "0": 315, 134 | "1": 106 135 | }, 136 | "flags": {}, 137 | "order": 1, 138 | "mode": 0, 139 | "inputs": [], 140 | "outputs": [ 141 | { 142 | "name": "LATENT", 143 | "type": "LATENT", 144 | "links": [ 145 | 86 146 | ], 147 | "slot_index": 0, 148 | "shape": 3 149 | } 150 | ], 151 | "properties": { 152 | "Node name for S&R": "EmptySD3LatentImage" 153 | }, 154 | "widgets_values": [ 155 | 768, 156 | 1024, 157 | 1 158 | ], 159 | "color": "#323", 160 | "bgcolor": "#535" 161 | }, 162 | { 163 | "id": 16, 164 | "type": "KSamplerSelect", 165 | "pos": { 166 | "0": 384, 167 | "1": 313 168 | }, 169 | "size": { 170 | "0": 315, 171 | "1": 58 172 | }, 173 | "flags": {}, 174 | "order": 2, 175 | "mode": 0, 176 | "inputs": [], 177 | "outputs": [ 178 | { 179 | "name": "SAMPLER", 180 | "type": "SAMPLER", 181 | "links": [ 182 | 85 183 | ], 184 | "slot_index": 0, 185 | "shape": 3 186 | } 187 | ], 188 | "properties": { 189 | "Node name for S&R": "KSamplerSelect" 190 | }, 191 | "widgets_values": [ 192 | "euler" 193 | ] 194 | }, 195 | { 196 | "id": 17, 197 | "type": "BasicScheduler", 198 | "pos": { 199 | "0": 392, 200 | "1": 424 201 | }, 202 | "size": { 203 | "0": 315, 204 | "1": 106 205 | }, 206 | "flags": { 207 | "collapsed": false 208 | }, 209 | "order": 11, 210 | "mode": 0, 211 | "inputs": [ 212 | { 213 | "name": "model", 214 | "type": "MODEL", 215 | "link": 131, 216 | "slot_index": 0 217 | } 218 | ], 219 | "outputs": [ 220 | { 221 | "name": "SIGMAS", 222 | "type": "SIGMAS", 223 | "links": [ 224 | 93 225 | ], 226 | "slot_index": 0, 227 | "shape": 3 228 | } 229 | ], 230 | "properties": { 231 | "Node name for S&R": "BasicScheduler" 232 | }, 233 | "widgets_values": [ 234 | "simple", 235 | 10, 236 | 1 237 | ] 238 | }, 239 | { 240 | "id": 54, 241 | "type": "LoadImage", 242 | "pos": { 243 | "0": 729, 244 | "1": -490 245 | }, 246 | "size": { 247 | "0": 315, 248 | "1": 314 249 | }, 250 | "flags": {}, 251 | "order": 3, 252 | "mode": 0, 253 | "inputs": [], 254 | "outputs": [ 255 | { 256 | "name": "IMAGE", 257 | "type": "IMAGE", 258 | "links": [ 259 | 126 260 | ], 261 | "slot_index": 0, 262 | "shape": 3 263 | }, 264 | { 265 | "name": "MASK", 266 | "type": "MASK", 267 | "links": null, 268 | "shape": 3 269 | } 270 | ], 271 | "properties": { 272 | "Node name for S&R": "LoadImage" 273 | }, 274 | "widgets_values": [ 275 | "einstein.jpg", 276 | "image" 277 | ] 278 | }, 279 | { 280 | "id": 53, 281 | "type": "PulidFluxInsightFaceLoader", 282 | "pos": { 283 | "0": 822, 284 | "1": -80 285 | }, 286 | "size": { 287 | "0": 365.4000244140625, 288 | "1": 58 289 | }, 290 | "flags": {}, 291 | "order": 4, 292 | "mode": 0, 293 | "inputs": [], 294 | "outputs": [ 295 | { 296 | "name": "FACEANALYSIS", 297 | "type": "FACEANALYSIS", 298 | "links": [ 299 | 124 300 | ], 301 | "slot_index": 0, 302 | "shape": 3 303 | } 304 | ], 305 | "properties": { 306 | "Node name for S&R": "PulidFluxInsightFaceLoader" 307 | }, 308 | "widgets_values": [ 309 | "CPU" 310 | ] 311 | }, 312 | { 313 | "id": 51, 314 | "type": "PulidFluxEvaClipLoader", 315 | "pos": { 316 | "0": 845, 317 | "1": 52 318 | }, 319 | "size": { 320 | "0": 327.5999755859375, 321 | "1": 26 322 | }, 323 | "flags": {}, 324 | "order": 5, 325 | "mode": 0, 326 | "inputs": [], 327 | "outputs": [ 328 | { 329 | "name": "EVA_CLIP", 330 | "type": "EVA_CLIP", 331 | "links": [ 332 | 123 333 | ], 334 | 
"slot_index": 0, 335 | "shape": 3 336 | } 337 | ], 338 | "properties": { 339 | "Node name for S&R": "PulidFluxEvaClipLoader" 340 | } 341 | }, 342 | { 343 | "id": 45, 344 | "type": "PulidFluxModelLoader", 345 | "pos": { 346 | "0": 846, 347 | "1": 137 348 | }, 349 | "size": { 350 | "0": 315, 351 | "1": 58 352 | }, 353 | "flags": {}, 354 | "order": 6, 355 | "mode": 0, 356 | "inputs": [], 357 | "outputs": [ 358 | { 359 | "name": "PULIDFLUX", 360 | "type": "PULIDFLUX", 361 | "links": [ 362 | 125 363 | ], 364 | "slot_index": 0, 365 | "shape": 3 366 | } 367 | ], 368 | "properties": { 369 | "Node name for S&R": "PulidFluxModelLoader" 370 | }, 371 | "widgets_values": [ 372 | "pulid_flux_v0.9.0.safetensors" 373 | ] 374 | }, 375 | { 376 | "id": 62, 377 | "type": "ApplyPulidFlux", 378 | "pos": { 379 | "0": 842, 380 | "1": 258 381 | }, 382 | "size": { 383 | "0": 315, 384 | "1": 206 385 | }, 386 | "flags": {}, 387 | "order": 10, 388 | "mode": 0, 389 | "inputs": [ 390 | { 391 | "name": "model", 392 | "type": "MODEL", 393 | "link": 130 394 | }, 395 | { 396 | "name": "pulid_flux", 397 | "type": "PULIDFLUX", 398 | "link": 125 399 | }, 400 | { 401 | "name": "eva_clip", 402 | "type": "EVA_CLIP", 403 | "link": 123 404 | }, 405 | { 406 | "name": "face_analysis", 407 | "type": "FACEANALYSIS", 408 | "link": 124 409 | }, 410 | { 411 | "name": "image", 412 | "type": "IMAGE", 413 | "link": 126 414 | }, 415 | { 416 | "name": "attn_mask", 417 | "type": "MASK", 418 | "link": null 419 | } 420 | ], 421 | "outputs": [ 422 | { 423 | "name": "MODEL", 424 | "type": "MODEL", 425 | "links": [ 426 | 122 427 | ], 428 | "slot_index": 0, 429 | "shape": 3 430 | } 431 | ], 432 | "properties": { 433 | "Node name for S&R": "ApplyPulidFlux" 434 | }, 435 | "widgets_values": [ 436 | 1, 437 | 0, 438 | 1 439 | ] 440 | }, 441 | { 442 | "id": 47, 443 | "type": "BasicGuider", 444 | "pos": { 445 | "0": 1217, 446 | "1": 401 447 | }, 448 | "size": { 449 | "0": 241.79998779296875, 450 | "1": 46 451 | }, 452 | "flags": {}, 453 | "order": 14, 454 | "mode": 0, 455 | "inputs": [ 456 | { 457 | "name": "model", 458 | "type": "MODEL", 459 | "link": 122 460 | }, 461 | { 462 | "name": "conditioning", 463 | "type": "CONDITIONING", 464 | "link": 107 465 | } 466 | ], 467 | "outputs": [ 468 | { 469 | "name": "GUIDER", 470 | "type": "GUIDER", 471 | "links": [ 472 | 83 473 | ], 474 | "slot_index": 0, 475 | "shape": 3 476 | } 477 | ], 478 | "properties": { 479 | "Node name for S&R": "BasicGuider" 480 | } 481 | }, 482 | { 483 | "id": 48, 484 | "type": "SamplerCustomAdvanced", 485 | "pos": { 486 | "0": 1205, 487 | "1": -39 488 | }, 489 | "size": { 490 | "0": 355.20001220703125, 491 | "1": 326 492 | }, 493 | "flags": {}, 494 | "order": 15, 495 | "mode": 0, 496 | "inputs": [ 497 | { 498 | "name": "noise", 499 | "type": "NOISE", 500 | "link": 84 501 | }, 502 | { 503 | "name": "guider", 504 | "type": "GUIDER", 505 | "link": 83 506 | }, 507 | { 508 | "name": "sampler", 509 | "type": "SAMPLER", 510 | "link": 85 511 | }, 512 | { 513 | "name": "sigmas", 514 | "type": "SIGMAS", 515 | "link": 93 516 | }, 517 | { 518 | "name": "latent_image", 519 | "type": "LATENT", 520 | "link": 86 521 | } 522 | ], 523 | "outputs": [ 524 | { 525 | "name": "output", 526 | "type": "LATENT", 527 | "links": [ 528 | 87 529 | ], 530 | "slot_index": 0, 531 | "shape": 3 532 | }, 533 | { 534 | "name": "denoised_output", 535 | "type": "LATENT", 536 | "links": null, 537 | "shape": 3 538 | } 539 | ], 540 | "properties": { 541 | "Node name for S&R": "SamplerCustomAdvanced" 542 | } 543 | }, 544 | { 545 | 
"id": 49, 546 | "type": "VAEDecode", 547 | "pos": { 548 | "0": 1263, 549 | "1": -137 550 | }, 551 | "size": { 552 | "0": 210, 553 | "1": 46 554 | }, 555 | "flags": {}, 556 | "order": 16, 557 | "mode": 0, 558 | "inputs": [ 559 | { 560 | "name": "samples", 561 | "type": "LATENT", 562 | "link": 87 563 | }, 564 | { 565 | "name": "vae", 566 | "type": "VAE", 567 | "link": 88 568 | } 569 | ], 570 | "outputs": [ 571 | { 572 | "name": "IMAGE", 573 | "type": "IMAGE", 574 | "links": [ 575 | 89 576 | ], 577 | "slot_index": 0, 578 | "shape": 3 579 | } 580 | ], 581 | "properties": { 582 | "Node name for S&R": "VAEDecode" 583 | } 584 | }, 585 | { 586 | "id": 50, 587 | "type": "PreviewImage", 588 | "pos": { 589 | "0": 1587, 590 | "1": -169 591 | }, 592 | "size": { 593 | "0": 841.524169921875, 594 | "1": 698.3060302734375 595 | }, 596 | "flags": {}, 597 | "order": 17, 598 | "mode": 0, 599 | "inputs": [ 600 | { 601 | "name": "images", 602 | "type": "IMAGE", 603 | "link": 89 604 | } 605 | ], 606 | "outputs": [], 607 | "properties": { 608 | "Node name for S&R": "PreviewImage" 609 | } 610 | }, 611 | { 612 | "id": 63, 613 | "type": "UNETLoader", 614 | "pos": { 615 | "0": 6, 616 | "1": -7 617 | }, 618 | "size": { 619 | "0": 315, 620 | "1": 82 621 | }, 622 | "flags": {}, 623 | "order": 7, 624 | "mode": 0, 625 | "inputs": [], 626 | "outputs": [ 627 | { 628 | "name": "MODEL", 629 | "type": "MODEL", 630 | "links": [ 631 | 130, 632 | 131 633 | ], 634 | "shape": 3, 635 | "slot_index": 0 636 | } 637 | ], 638 | "properties": { 639 | "Node name for S&R": "UNETLoader" 640 | }, 641 | "widgets_values": [ 642 | "flux1-dev.safetensors", 643 | "default" 644 | ] 645 | }, 646 | { 647 | "id": 10, 648 | "type": "VAELoader", 649 | "pos": { 650 | "0": 12, 651 | "1": 285 652 | }, 653 | "size": { 654 | "0": 311.81634521484375, 655 | "1": 60.429901123046875 656 | }, 657 | "flags": {}, 658 | "order": 8, 659 | "mode": 0, 660 | "inputs": [], 661 | "outputs": [ 662 | { 663 | "name": "VAE", 664 | "type": "VAE", 665 | "links": [ 666 | 88 667 | ], 668 | "slot_index": 0, 669 | "shape": 3 670 | } 671 | ], 672 | "properties": { 673 | "Node name for S&R": "VAELoader" 674 | }, 675 | "widgets_values": [ 676 | "flux1_vae.safetensors" 677 | ] 678 | }, 679 | { 680 | "id": 64, 681 | "type": "DualCLIPLoader", 682 | "pos": { 683 | "0": 8, 684 | "1": 124 685 | }, 686 | "size": { 687 | "0": 315, 688 | "1": 106 689 | }, 690 | "flags": {}, 691 | "order": 9, 692 | "mode": 0, 693 | "inputs": [], 694 | "outputs": [ 695 | { 696 | "name": "CLIP", 697 | "type": "CLIP", 698 | "links": [ 699 | 132 700 | ], 701 | "shape": 3, 702 | "slot_index": 0 703 | } 704 | ], 705 | "properties": { 706 | "Node name for S&R": "DualCLIPLoader" 707 | }, 708 | "widgets_values": [ 709 | "t5xxl_fp16.safetensors", 710 | "clip_l.safetensors", 711 | "flux" 712 | ] 713 | } 714 | ], 715 | "links": [ 716 | [ 717 | 41, 718 | 6, 719 | 0, 720 | 26, 721 | 0, 722 | "CONDITIONING" 723 | ], 724 | [ 725 | 83, 726 | 47, 727 | 0, 728 | 48, 729 | 1, 730 | "GUIDER" 731 | ], 732 | [ 733 | 84, 734 | 25, 735 | 0, 736 | 48, 737 | 0, 738 | "NOISE" 739 | ], 740 | [ 741 | 85, 742 | 16, 743 | 0, 744 | 48, 745 | 2, 746 | "SAMPLER" 747 | ], 748 | [ 749 | 86, 750 | 27, 751 | 0, 752 | 48, 753 | 4, 754 | "LATENT" 755 | ], 756 | [ 757 | 87, 758 | 48, 759 | 0, 760 | 49, 761 | 0, 762 | "LATENT" 763 | ], 764 | [ 765 | 88, 766 | 10, 767 | 0, 768 | 49, 769 | 1, 770 | "VAE" 771 | ], 772 | [ 773 | 89, 774 | 49, 775 | 0, 776 | 50, 777 | 0, 778 | "IMAGE" 779 | ], 780 | [ 781 | 93, 782 | 17, 783 | 0, 784 | 48, 785 | 3, 786 | 
"SIGMAS" 787 | ], 788 | [ 789 | 107, 790 | 26, 791 | 0, 792 | 47, 793 | 1, 794 | "CONDITIONING" 795 | ], 796 | [ 797 | 122, 798 | 62, 799 | 0, 800 | 47, 801 | 0, 802 | "MODEL" 803 | ], 804 | [ 805 | 123, 806 | 51, 807 | 0, 808 | 62, 809 | 2, 810 | "EVA_CLIP" 811 | ], 812 | [ 813 | 124, 814 | 53, 815 | 0, 816 | 62, 817 | 3, 818 | "FACEANALYSIS" 819 | ], 820 | [ 821 | 125, 822 | 45, 823 | 0, 824 | 62, 825 | 1, 826 | "PULIDFLUX" 827 | ], 828 | [ 829 | 126, 830 | 54, 831 | 0, 832 | 62, 833 | 4, 834 | "IMAGE" 835 | ], 836 | [ 837 | 130, 838 | 63, 839 | 0, 840 | 62, 841 | 0, 842 | "MODEL" 843 | ], 844 | [ 845 | 131, 846 | 63, 847 | 0, 848 | 17, 849 | 0, 850 | "MODEL" 851 | ], 852 | [ 853 | 132, 854 | 64, 855 | 0, 856 | 6, 857 | 0, 858 | "CLIP" 859 | ] 860 | ], 861 | "groups": [], 862 | "config": {}, 863 | "extra": { 864 | "ds": { 865 | "scale": 0.9090909090909091, 866 | "offset": [ 867 | 113.84966682267732, 868 | 547.8597243753773 869 | ] 870 | } 871 | }, 872 | "version": 0.4 873 | } -------------------------------------------------------------------------------- /examples/pulid_flux_8bitgguf_simple.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 62, 3 | "last_link_id": 129, 4 | "nodes": [ 5 | { 6 | "id": 25, 7 | "type": "RandomNoise", 8 | "pos": { 9 | "0": 6, 10 | "1": -135 11 | }, 12 | "size": [ 13 | 315, 14 | 82 15 | ], 16 | "flags": {}, 17 | "order": 0, 18 | "mode": 0, 19 | "inputs": [], 20 | "outputs": [ 21 | { 22 | "name": "NOISE", 23 | "type": "NOISE", 24 | "links": [ 25 | 84 26 | ], 27 | "slot_index": 0, 28 | "shape": 3 29 | } 30 | ], 31 | "properties": { 32 | "Node name for S&R": "RandomNoise" 33 | }, 34 | "widgets_values": [ 35 | 186462208016243, 36 | "fixed" 37 | ], 38 | "color": "#2a363b", 39 | "bgcolor": "#3f5159" 40 | }, 41 | { 42 | "id": 31, 43 | "type": "UnetLoaderGGUF", 44 | "pos": { 45 | "0": 14, 46 | "1": 5 47 | }, 48 | "size": { 49 | "0": 315, 50 | "1": 58 51 | }, 52 | "flags": {}, 53 | "order": 1, 54 | "mode": 0, 55 | "inputs": [], 56 | "outputs": [ 57 | { 58 | "name": "MODEL", 59 | "type": "MODEL", 60 | "links": [ 61 | 127, 62 | 129 63 | ], 64 | "slot_index": 0, 65 | "shape": 3 66 | } 67 | ], 68 | "properties": { 69 | "Node name for S&R": "UnetLoaderGGUF" 70 | }, 71 | "widgets_values": [ 72 | "flux1-dev-Q8_0.gguf" 73 | ] 74 | }, 75 | { 76 | "id": 41, 77 | "type": "DualCLIPLoaderGGUF", 78 | "pos": { 79 | "0": 18, 80 | "1": 114 81 | }, 82 | "size": { 83 | "0": 315, 84 | "1": 106 85 | }, 86 | "flags": {}, 87 | "order": 2, 88 | "mode": 0, 89 | "inputs": [], 90 | "outputs": [ 91 | { 92 | "name": "CLIP", 93 | "type": "CLIP", 94 | "links": [ 95 | 128 96 | ], 97 | "slot_index": 0, 98 | "shape": 3 99 | } 100 | ], 101 | "properties": { 102 | "Node name for S&R": "DualCLIPLoaderGGUF" 103 | }, 104 | "widgets_values": [ 105 | "t5-v1_1-xxl-encoder-Q8_0.gguf", 106 | "clip_l.safetensors", 107 | "flux" 108 | ] 109 | }, 110 | { 111 | "id": 10, 112 | "type": "VAELoader", 113 | "pos": { 114 | "0": 23, 115 | "1": 275 116 | }, 117 | "size": { 118 | "0": 311.81634521484375, 119 | "1": 60.429901123046875 120 | }, 121 | "flags": {}, 122 | "order": 3, 123 | "mode": 0, 124 | "inputs": [], 125 | "outputs": [ 126 | { 127 | "name": "VAE", 128 | "type": "VAE", 129 | "links": [ 130 | 88 131 | ], 132 | "slot_index": 0, 133 | "shape": 3 134 | } 135 | ], 136 | "properties": { 137 | "Node name for S&R": "VAELoader" 138 | }, 139 | "widgets_values": [ 140 | "flux1_vae.safetensors" 141 | ] 142 | }, 143 | { 144 | "id": 26, 145 | "type": 
"FluxGuidance", 146 | "pos": { 147 | "0": 372, 148 | "1": -171 149 | }, 150 | "size": { 151 | "0": 317.4000244140625, 152 | "1": 58 153 | }, 154 | "flags": { 155 | "collapsed": false 156 | }, 157 | "order": 13, 158 | "mode": 0, 159 | "inputs": [ 160 | { 161 | "name": "conditioning", 162 | "type": "CONDITIONING", 163 | "link": 41 164 | } 165 | ], 166 | "outputs": [ 167 | { 168 | "name": "CONDITIONING", 169 | "type": "CONDITIONING", 170 | "links": [ 171 | 107 172 | ], 173 | "slot_index": 0, 174 | "shape": 3 175 | } 176 | ], 177 | "properties": { 178 | "Node name for S&R": "FluxGuidance" 179 | }, 180 | "widgets_values": [ 181 | 3.5 182 | ], 183 | "color": "#233", 184 | "bgcolor": "#355" 185 | }, 186 | { 187 | "id": 6, 188 | "type": "CLIPTextEncode", 189 | "pos": { 190 | "0": 372, 191 | "1": -55 192 | }, 193 | "size": { 194 | "0": 422.84503173828125, 195 | "1": 164.31304931640625 196 | }, 197 | "flags": {}, 198 | "order": 11, 199 | "mode": 0, 200 | "inputs": [ 201 | { 202 | "name": "clip", 203 | "type": "CLIP", 204 | "link": 128 205 | } 206 | ], 207 | "outputs": [ 208 | { 209 | "name": "CONDITIONING", 210 | "type": "CONDITIONING", 211 | "links": [ 212 | 41 213 | ], 214 | "slot_index": 0 215 | } 216 | ], 217 | "title": "CLIP Text Encode (Positive Prompt)", 218 | "properties": { 219 | "Node name for S&R": "CLIPTextEncode" 220 | }, 221 | "widgets_values": [ 222 | "Half body portrait of 60 years old guy, with an surprised expression, he is lost in vectors of AI models, sourounded by PC monitors and many cables, on his tshirt is a text with words printed in Arial font:\"PuLID Flux\", detailed, glowy background, photorealistic style with skin inperfections, looks like shot with an smartphone, skin details without plastic look, ASUS Keyboard." 223 | ], 224 | "color": "#232", 225 | "bgcolor": "#353" 226 | }, 227 | { 228 | "id": 27, 229 | "type": "EmptySD3LatentImage", 230 | "pos": { 231 | "0": 383, 232 | "1": 155 233 | }, 234 | "size": { 235 | "0": 315, 236 | "1": 106 237 | }, 238 | "flags": {}, 239 | "order": 4, 240 | "mode": 0, 241 | "inputs": [], 242 | "outputs": [ 243 | { 244 | "name": "LATENT", 245 | "type": "LATENT", 246 | "links": [ 247 | 86 248 | ], 249 | "slot_index": 0, 250 | "shape": 3 251 | } 252 | ], 253 | "properties": { 254 | "Node name for S&R": "EmptySD3LatentImage" 255 | }, 256 | "widgets_values": [ 257 | 768, 258 | 1024, 259 | 1 260 | ], 261 | "color": "#323", 262 | "bgcolor": "#535" 263 | }, 264 | { 265 | "id": 16, 266 | "type": "KSamplerSelect", 267 | "pos": { 268 | "0": 384, 269 | "1": 313 270 | }, 271 | "size": { 272 | "0": 315, 273 | "1": 58 274 | }, 275 | "flags": {}, 276 | "order": 5, 277 | "mode": 0, 278 | "inputs": [], 279 | "outputs": [ 280 | { 281 | "name": "SAMPLER", 282 | "type": "SAMPLER", 283 | "links": [ 284 | 85 285 | ], 286 | "slot_index": 0, 287 | "shape": 3 288 | } 289 | ], 290 | "properties": { 291 | "Node name for S&R": "KSamplerSelect" 292 | }, 293 | "widgets_values": [ 294 | "euler" 295 | ] 296 | }, 297 | { 298 | "id": 17, 299 | "type": "BasicScheduler", 300 | "pos": { 301 | "0": 392, 302 | "1": 424 303 | }, 304 | "size": { 305 | "0": 315, 306 | "1": 106 307 | }, 308 | "flags": { 309 | "collapsed": false 310 | }, 311 | "order": 10, 312 | "mode": 0, 313 | "inputs": [ 314 | { 315 | "name": "model", 316 | "type": "MODEL", 317 | "link": 129, 318 | "slot_index": 0 319 | } 320 | ], 321 | "outputs": [ 322 | { 323 | "name": "SIGMAS", 324 | "type": "SIGMAS", 325 | "links": [ 326 | 93 327 | ], 328 | "slot_index": 0, 329 | "shape": 3 330 | } 331 | ], 332 | "properties": 
{ 333 | "Node name for S&R": "BasicScheduler" 334 | }, 335 | "widgets_values": [ 336 | "simple", 337 | 10, 338 | 1 339 | ] 340 | }, 341 | { 342 | "id": 54, 343 | "type": "LoadImage", 344 | "pos": { 345 | "0": 729, 346 | "1": -490 347 | }, 348 | "size": { 349 | "0": 315, 350 | "1": 314 351 | }, 352 | "flags": {}, 353 | "order": 6, 354 | "mode": 0, 355 | "inputs": [], 356 | "outputs": [ 357 | { 358 | "name": "IMAGE", 359 | "type": "IMAGE", 360 | "links": [ 361 | 126 362 | ], 363 | "slot_index": 0, 364 | "shape": 3 365 | }, 366 | { 367 | "name": "MASK", 368 | "type": "MASK", 369 | "links": null, 370 | "shape": 3 371 | } 372 | ], 373 | "properties": { 374 | "Node name for S&R": "LoadImage" 375 | }, 376 | "widgets_values": [ 377 | "einstein.jpg", 378 | "image" 379 | ] 380 | }, 381 | { 382 | "id": 53, 383 | "type": "PulidFluxInsightFaceLoader", 384 | "pos": { 385 | "0": 822, 386 | "1": -80 387 | }, 388 | "size": { 389 | "0": 365.4000244140625, 390 | "1": 58 391 | }, 392 | "flags": {}, 393 | "order": 7, 394 | "mode": 0, 395 | "inputs": [], 396 | "outputs": [ 397 | { 398 | "name": "FACEANALYSIS", 399 | "type": "FACEANALYSIS", 400 | "links": [ 401 | 124 402 | ], 403 | "slot_index": 0, 404 | "shape": 3 405 | } 406 | ], 407 | "properties": { 408 | "Node name for S&R": "PulidFluxInsightFaceLoader" 409 | }, 410 | "widgets_values": [ 411 | "CPU" 412 | ] 413 | }, 414 | { 415 | "id": 51, 416 | "type": "PulidFluxEvaClipLoader", 417 | "pos": { 418 | "0": 845, 419 | "1": 52 420 | }, 421 | "size": { 422 | "0": 327.5999755859375, 423 | "1": 26 424 | }, 425 | "flags": {}, 426 | "order": 8, 427 | "mode": 0, 428 | "inputs": [], 429 | "outputs": [ 430 | { 431 | "name": "EVA_CLIP", 432 | "type": "EVA_CLIP", 433 | "links": [ 434 | 123 435 | ], 436 | "slot_index": 0, 437 | "shape": 3 438 | } 439 | ], 440 | "properties": { 441 | "Node name for S&R": "PulidFluxEvaClipLoader" 442 | } 443 | }, 444 | { 445 | "id": 45, 446 | "type": "PulidFluxModelLoader", 447 | "pos": { 448 | "0": 846, 449 | "1": 137 450 | }, 451 | "size": { 452 | "0": 315, 453 | "1": 58 454 | }, 455 | "flags": {}, 456 | "order": 9, 457 | "mode": 0, 458 | "inputs": [], 459 | "outputs": [ 460 | { 461 | "name": "PULIDFLUX", 462 | "type": "PULIDFLUX", 463 | "links": [ 464 | 125 465 | ], 466 | "slot_index": 0, 467 | "shape": 3 468 | } 469 | ], 470 | "properties": { 471 | "Node name for S&R": "PulidFluxModelLoader" 472 | }, 473 | "widgets_values": [ 474 | "pulid_flux_v0.9.0.safetensors" 475 | ] 476 | }, 477 | { 478 | "id": 62, 479 | "type": "ApplyPulidFlux", 480 | "pos": { 481 | "0": 842, 482 | "1": 258 483 | }, 484 | "size": { 485 | "0": 315, 486 | "1": 206 487 | }, 488 | "flags": {}, 489 | "order": 12, 490 | "mode": 0, 491 | "inputs": [ 492 | { 493 | "name": "model", 494 | "type": "MODEL", 495 | "link": 127 496 | }, 497 | { 498 | "name": "pulid_flux", 499 | "type": "PULIDFLUX", 500 | "link": 125 501 | }, 502 | { 503 | "name": "eva_clip", 504 | "type": "EVA_CLIP", 505 | "link": 123 506 | }, 507 | { 508 | "name": "face_analysis", 509 | "type": "FACEANALYSIS", 510 | "link": 124 511 | }, 512 | { 513 | "name": "image", 514 | "type": "IMAGE", 515 | "link": 126 516 | }, 517 | { 518 | "name": "attn_mask", 519 | "type": "MASK", 520 | "link": null 521 | } 522 | ], 523 | "outputs": [ 524 | { 525 | "name": "MODEL", 526 | "type": "MODEL", 527 | "links": [ 528 | 122 529 | ], 530 | "shape": 3, 531 | "slot_index": 0 532 | } 533 | ], 534 | "properties": { 535 | "Node name for S&R": "ApplyPulidFlux" 536 | }, 537 | "widgets_values": [ 538 | 1, 539 | 0, 540 | 1 541 | ] 542 | }, 
543 | { 544 | "id": 47, 545 | "type": "BasicGuider", 546 | "pos": { 547 | "0": 1217, 548 | "1": 401 549 | }, 550 | "size": { 551 | "0": 241.79998779296875, 552 | "1": 46 553 | }, 554 | "flags": {}, 555 | "order": 14, 556 | "mode": 0, 557 | "inputs": [ 558 | { 559 | "name": "model", 560 | "type": "MODEL", 561 | "link": 122 562 | }, 563 | { 564 | "name": "conditioning", 565 | "type": "CONDITIONING", 566 | "link": 107 567 | } 568 | ], 569 | "outputs": [ 570 | { 571 | "name": "GUIDER", 572 | "type": "GUIDER", 573 | "links": [ 574 | 83 575 | ], 576 | "slot_index": 0, 577 | "shape": 3 578 | } 579 | ], 580 | "properties": { 581 | "Node name for S&R": "BasicGuider" 582 | } 583 | }, 584 | { 585 | "id": 48, 586 | "type": "SamplerCustomAdvanced", 587 | "pos": { 588 | "0": 1205, 589 | "1": -39 590 | }, 591 | "size": { 592 | "0": 355.20001220703125, 593 | "1": 326 594 | }, 595 | "flags": {}, 596 | "order": 15, 597 | "mode": 0, 598 | "inputs": [ 599 | { 600 | "name": "noise", 601 | "type": "NOISE", 602 | "link": 84 603 | }, 604 | { 605 | "name": "guider", 606 | "type": "GUIDER", 607 | "link": 83 608 | }, 609 | { 610 | "name": "sampler", 611 | "type": "SAMPLER", 612 | "link": 85 613 | }, 614 | { 615 | "name": "sigmas", 616 | "type": "SIGMAS", 617 | "link": 93 618 | }, 619 | { 620 | "name": "latent_image", 621 | "type": "LATENT", 622 | "link": 86 623 | } 624 | ], 625 | "outputs": [ 626 | { 627 | "name": "output", 628 | "type": "LATENT", 629 | "links": [ 630 | 87 631 | ], 632 | "slot_index": 0, 633 | "shape": 3 634 | }, 635 | { 636 | "name": "denoised_output", 637 | "type": "LATENT", 638 | "links": null, 639 | "shape": 3 640 | } 641 | ], 642 | "properties": { 643 | "Node name for S&R": "SamplerCustomAdvanced" 644 | } 645 | }, 646 | { 647 | "id": 49, 648 | "type": "VAEDecode", 649 | "pos": { 650 | "0": 1263, 651 | "1": -137 652 | }, 653 | "size": { 654 | "0": 210, 655 | "1": 46 656 | }, 657 | "flags": {}, 658 | "order": 16, 659 | "mode": 0, 660 | "inputs": [ 661 | { 662 | "name": "samples", 663 | "type": "LATENT", 664 | "link": 87 665 | }, 666 | { 667 | "name": "vae", 668 | "type": "VAE", 669 | "link": 88 670 | } 671 | ], 672 | "outputs": [ 673 | { 674 | "name": "IMAGE", 675 | "type": "IMAGE", 676 | "links": [ 677 | 89 678 | ], 679 | "slot_index": 0, 680 | "shape": 3 681 | } 682 | ], 683 | "properties": { 684 | "Node name for S&R": "VAEDecode" 685 | } 686 | }, 687 | { 688 | "id": 50, 689 | "type": "PreviewImage", 690 | "pos": { 691 | "0": 1587, 692 | "1": -169 693 | }, 694 | "size": { 695 | "0": 841.524169921875, 696 | "1": 698.3060302734375 697 | }, 698 | "flags": {}, 699 | "order": 17, 700 | "mode": 0, 701 | "inputs": [ 702 | { 703 | "name": "images", 704 | "type": "IMAGE", 705 | "link": 89 706 | } 707 | ], 708 | "outputs": [], 709 | "properties": { 710 | "Node name for S&R": "PreviewImage" 711 | } 712 | } 713 | ], 714 | "links": [ 715 | [ 716 | 41, 717 | 6, 718 | 0, 719 | 26, 720 | 0, 721 | "CONDITIONING" 722 | ], 723 | [ 724 | 83, 725 | 47, 726 | 0, 727 | 48, 728 | 1, 729 | "GUIDER" 730 | ], 731 | [ 732 | 84, 733 | 25, 734 | 0, 735 | 48, 736 | 0, 737 | "NOISE" 738 | ], 739 | [ 740 | 85, 741 | 16, 742 | 0, 743 | 48, 744 | 2, 745 | "SAMPLER" 746 | ], 747 | [ 748 | 86, 749 | 27, 750 | 0, 751 | 48, 752 | 4, 753 | "LATENT" 754 | ], 755 | [ 756 | 87, 757 | 48, 758 | 0, 759 | 49, 760 | 0, 761 | "LATENT" 762 | ], 763 | [ 764 | 88, 765 | 10, 766 | 0, 767 | 49, 768 | 1, 769 | "VAE" 770 | ], 771 | [ 772 | 89, 773 | 49, 774 | 0, 775 | 50, 776 | 0, 777 | "IMAGE" 778 | ], 779 | [ 780 | 93, 781 | 17, 782 | 0, 783 | 
48, 784 | 3, 785 | "SIGMAS" 786 | ], 787 | [ 788 | 107, 789 | 26, 790 | 0, 791 | 47, 792 | 1, 793 | "CONDITIONING" 794 | ], 795 | [ 796 | 122, 797 | 62, 798 | 0, 799 | 47, 800 | 0, 801 | "MODEL" 802 | ], 803 | [ 804 | 123, 805 | 51, 806 | 0, 807 | 62, 808 | 2, 809 | "EVA_CLIP" 810 | ], 811 | [ 812 | 124, 813 | 53, 814 | 0, 815 | 62, 816 | 3, 817 | "FACEANALYSIS" 818 | ], 819 | [ 820 | 125, 821 | 45, 822 | 0, 823 | 62, 824 | 1, 825 | "PULIDFLUX" 826 | ], 827 | [ 828 | 126, 829 | 54, 830 | 0, 831 | 62, 832 | 4, 833 | "IMAGE" 834 | ], 835 | [ 836 | 127, 837 | 31, 838 | 0, 839 | 62, 840 | 0, 841 | "MODEL" 842 | ], 843 | [ 844 | 128, 845 | 41, 846 | 0, 847 | 6, 848 | 0, 849 | "CLIP" 850 | ], 851 | [ 852 | 129, 853 | 31, 854 | 0, 855 | 17, 856 | 0, 857 | "MODEL" 858 | ] 859 | ], 860 | "groups": [], 861 | "config": {}, 862 | "extra": { 863 | "ds": { 864 | "scale": 0.7513148009015777, 865 | "offset": [ 866 | 124.42912136813258, 867 | 743.5079061935592 868 | ] 869 | } 870 | }, 871 | "version": 0.4 872 | } -------------------------------------------------------------------------------- /examples/pulid_flux_einstein.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/balazik/ComfyUI-PuLID-Flux/a80912fc3435c358607bf4b43a58dbcbebdb09ff/examples/pulid_flux_einstein.png -------------------------------------------------------------------------------- /pulidflux.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch import nn, Tensor 4 | from torchvision import transforms 5 | from torchvision.transforms import functional 6 | import os 7 | import logging 8 | import folder_paths 9 | import comfy.utils 10 | from comfy.ldm.flux.layers import timestep_embedding 11 | from insightface.app import FaceAnalysis 12 | from facexlib.parsing import init_parsing_model 13 | from facexlib.utils.face_restoration_helper import FaceRestoreHelper 14 | 15 | from .eva_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 16 | from .encoders_flux import IDFormer, PerceiverAttentionCA 17 | 18 | INSIGHTFACE_DIR = os.path.join(folder_paths.models_dir, "insightface") 19 | 20 | MODELS_DIR = os.path.join(folder_paths.models_dir, "pulid") 21 | if "pulid" not in folder_paths.folder_names_and_paths: 22 | current_paths = [MODELS_DIR] 23 | else: 24 | current_paths, _ = folder_paths.folder_names_and_paths["pulid"] 25 | folder_paths.folder_names_and_paths["pulid"] = (current_paths, folder_paths.supported_pt_extensions) 26 | 27 | class PulidFluxModel(nn.Module): 28 | def __init__(self): 29 | super().__init__() 30 | 31 | self.double_interval = 2 32 | self.single_interval = 4 33 | 34 | # Init encoder 35 | self.pulid_encoder = IDFormer() 36 | 37 | # Init attention 38 | num_ca = 19 // self.double_interval + 38 // self.single_interval 39 | if 19 % self.double_interval != 0: 40 | num_ca += 1 41 | if 38 % self.single_interval != 0: 42 | num_ca += 1 43 | self.pulid_ca = nn.ModuleList([ 44 | PerceiverAttentionCA() for _ in range(num_ca) 45 | ]) 46 | 47 | def from_pretrained(self, path: str): 48 | state_dict = comfy.utils.load_torch_file(path, safe_load=True) 49 | state_dict_dict = {} 50 | for k, v in state_dict.items(): 51 | module = k.split('.')[0] 52 | state_dict_dict.setdefault(module, {}) 53 | new_k = k[len(module) + 1:] 54 | state_dict_dict[module][new_k] = v 55 | 56 | for module in state_dict_dict: 57 | getattr(self, module).load_state_dict(state_dict_dict[module], strict=True) 58 | 59 | del state_dict 60 | 
del state_dict_dict 61 | 62 | def get_embeds(self, face_embed, clip_embeds): 63 | return self.pulid_encoder(face_embed, clip_embeds) 64 | 65 | def forward_orig( 66 | self, 67 | img: Tensor, 68 | img_ids: Tensor, 69 | txt: Tensor, 70 | txt_ids: Tensor, 71 | timesteps: Tensor, 72 | y: Tensor, 73 | guidance: Tensor = None, 74 | control=None, 75 | ) -> Tensor: 76 | if img.ndim != 3 or txt.ndim != 3: 77 | raise ValueError("Input img and txt tensors must have 3 dimensions.") 78 | 79 | # running on sequences img 80 | img = self.img_in(img) 81 | vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype)) 82 | if self.params.guidance_embed: 83 | if guidance is None: 84 | raise ValueError("Didn't get guidance strength for guidance distilled model.") 85 | vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype)) 86 | 87 | vec = vec + self.vector_in(y) 88 | txt = self.txt_in(txt) 89 | 90 | ids = torch.cat((txt_ids, img_ids), dim=1) 91 | pe = self.pe_embedder(ids) 92 | 93 | ca_idx = 0 94 | for i, block in enumerate(self.double_blocks): 95 | img, txt = block(img=img, txt=txt, vec=vec, pe=pe) 96 | 97 | if control is not None: # Controlnet 98 | control_i = control.get("input") 99 | if i < len(control_i): 100 | add = control_i[i] 101 | if add is not None: 102 | img += add 103 | 104 | # PuLID attention 105 | if self.pulid_data: 106 | if i % self.pulid_double_interval == 0: 107 | # Will calculate influence of all pulid nodes at once 108 | for _, node_data in self.pulid_data.items(): 109 | if torch.any((node_data['sigma_start'] >= timesteps) & (timesteps >= node_data['sigma_end'])): 110 | img = img + node_data['weight'] * self.pulid_ca[ca_idx](node_data['embedding'], img) 111 | ca_idx += 1 112 | 113 | img = torch.cat((txt, img), 1) 114 | 115 | for i, block in enumerate(self.single_blocks): 116 | img = block(img, vec=vec, pe=pe) 117 | 118 | if control is not None: # Controlnet 119 | control_o = control.get("output") 120 | if i < len(control_o): 121 | add = control_o[i] 122 | if add is not None: 123 | img[:, txt.shape[1] :, ...] += add 124 | 125 | # PuLID attention 126 | if self.pulid_data: 127 | real_img, txt = img[:, txt.shape[1]:, ...], img[:, :txt.shape[1], ...] 128 | if i % self.pulid_single_interval == 0: 129 | # Will calculate influence of all nodes at once 130 | for _, node_data in self.pulid_data.items(): 131 | if torch.any((node_data['sigma_start'] >= timesteps) & (timesteps >= node_data['sigma_end'])): 132 | real_img = real_img + node_data['weight'] * self.pulid_ca[ca_idx](node_data['embedding'], real_img) 133 | ca_idx += 1 134 | img = torch.cat((txt, real_img), 1) 135 | 136 | img = img[:, txt.shape[1] :, ...] 
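        # Recap of the PuLID injection schedule above (summary comment only; it assumes the patching code attaches
        # pulid_double_interval / pulid_single_interval with the same values as PulidFluxModel.__init__, i.e. 2 and 4,
        # and that the model has Flux's 19 double and 38 single blocks, as the constants in __init__ suggest):
        #   - double blocks: injected at i = 0, 2, ..., 18  -> 10 cross-attention layers
        #   - single blocks: injected at i = 0, 4, ..., 36  -> 10 cross-attention layers
        # which matches num_ca = 19 // 2 + 38 // 4 + 1 + 1 = 20 entries in self.pulid_ca, and every injection is
        # additionally gated by the per-node [sigma_end, sigma_start] timestep window.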
137 | 138 | img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) 139 | return img 140 | 141 | def tensor_to_image(tensor): 142 | image = tensor.mul(255).clamp(0, 255).byte().cpu() 143 | image = image[..., [2, 1, 0]].numpy() 144 | return image 145 | 146 | def image_to_tensor(image): 147 | tensor = torch.clamp(torch.from_numpy(image).float() / 255., 0, 1) 148 | tensor = tensor[..., [2, 1, 0]] 149 | return tensor 150 | 151 | def to_gray(img): 152 | x = 0.299 * img[:, 0:1] + 0.587 * img[:, 1:2] + 0.114 * img[:, 2:3] 153 | x = x.repeat(1, 3, 1, 1) 154 | return x 155 | 156 | """ 157 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 158 | Nodes 159 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 160 | """ 161 | 162 | class PulidFluxModelLoader: 163 | @classmethod 164 | def INPUT_TYPES(s): 165 | return {"required": {"pulid_file": (folder_paths.get_filename_list("pulid"), )}} 166 | 167 | RETURN_TYPES = ("PULIDFLUX",) 168 | FUNCTION = "load_model" 169 | CATEGORY = "pulid" 170 | 171 | def load_model(self, pulid_file): 172 | model_path = folder_paths.get_full_path("pulid", pulid_file) 173 | 174 | # Also initialize the model, takes longer to load but then it doesn't have to be done every time you change parameters in the apply node 175 | model = PulidFluxModel() 176 | 177 | logging.info("Loading PuLID-Flux model.") 178 | model.from_pretrained(path=model_path) 179 | 180 | return (model,) 181 | 182 | class PulidFluxInsightFaceLoader: 183 | @classmethod 184 | def INPUT_TYPES(s): 185 | return { 186 | "required": { 187 | "provider": (["CPU", "CUDA", "ROCM"], ), 188 | }, 189 | } 190 | 191 | RETURN_TYPES = ("FACEANALYSIS",) 192 | FUNCTION = "load_insightface" 193 | CATEGORY = "pulid" 194 | 195 | def load_insightface(self, provider): 196 | model = FaceAnalysis(name="antelopev2", root=INSIGHTFACE_DIR, providers=[provider + 'ExecutionProvider',]) # alternative to buffalo_l 197 | model.prepare(ctx_id=0, det_size=(640, 640)) 198 | 199 | return (model,) 200 | 201 | class PulidFluxEvaClipLoader: 202 | @classmethod 203 | def INPUT_TYPES(s): 204 | return { 205 | "required": {}, 206 | } 207 | 208 | RETURN_TYPES = ("EVA_CLIP",) 209 | FUNCTION = "load_eva_clip" 210 | CATEGORY = "pulid" 211 | 212 | def load_eva_clip(self): 213 | from .eva_clip.factory import create_model_and_transforms 214 | 215 | model, _, _ = create_model_and_transforms('EVA02-CLIP-L-14-336', 'eva_clip', force_custom_clip=True) 216 | 217 | model = model.visual 218 | 219 | eva_transform_mean = getattr(model, 'image_mean', OPENAI_DATASET_MEAN) 220 | eva_transform_std = getattr(model, 'image_std', OPENAI_DATASET_STD) 221 | if not isinstance(eva_transform_mean, (list, tuple)): 222 | model["image_mean"] = (eva_transform_mean,) * 3 223 | if not isinstance(eva_transform_std, (list, tuple)): 224 | model["image_std"] = (eva_transform_std,) * 3 225 | 226 | return (model,) 227 | 228 | class ApplyPulidFlux: 229 | @classmethod 230 | def INPUT_TYPES(s): 231 | return { 232 | "required": { 233 | "model": ("MODEL", ), 234 | "pulid_flux": ("PULIDFLUX", ), 235 | "eva_clip": ("EVA_CLIP", ), 236 | "face_analysis": ("FACEANALYSIS", ), 237 | "image": ("IMAGE", ), 238 | "weight": ("FLOAT", {"default": 1.0, "min": -1.0, "max": 5.0, "step": 0.05 }), 239 | "start_at": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001 }), 240 | "end_at": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001 }), 241 | }, 242 | "optional": { 243 | "attn_mask": ("MASK", ), 244 | }, 245 | "hidden": { 246 
| "unique_id": "UNIQUE_ID" 247 | }, 248 | } 249 | 250 | RETURN_TYPES = ("MODEL",) 251 | FUNCTION = "apply_pulid_flux" 252 | CATEGORY = "pulid" 253 | 254 | def __init__(self): 255 | self.pulid_data_dict = None 256 | 257 | def apply_pulid_flux(self, model, pulid_flux, eva_clip, face_analysis, image, weight, start_at, end_at, attn_mask=None, unique_id=None): 258 | device = comfy.model_management.get_torch_device() 259 | # Why should I care what args say, when the unet model has a different dtype?! 260 | # Am I missing something?! 261 | #dtype = comfy.model_management.unet_dtype() 262 | dtype = model.model.diffusion_model.dtype 263 | # Because of 8bit models we must check what cast type does the unet uses 264 | # ZLUDA (Intel, AMD) & GPUs with compute capability < 8.0 don't support bfloat16 etc. 265 | # Issue: https://github.com/balazik/ComfyUI-PuLID-Flux/issues/6 266 | if model.model.manual_cast_dtype is not None: 267 | dtype = model.model.manual_cast_dtype 268 | 269 | eva_clip.to(device, dtype=dtype) 270 | pulid_flux.to(device, dtype=dtype) 271 | 272 | # TODO: Add masking support! 273 | if attn_mask is not None: 274 | if attn_mask.dim() > 3: 275 | attn_mask = attn_mask.squeeze(-1) 276 | elif attn_mask.dim() < 3: 277 | attn_mask = attn_mask.unsqueeze(0) 278 | attn_mask = attn_mask.to(device, dtype=dtype) 279 | 280 | image = tensor_to_image(image) 281 | 282 | face_helper = FaceRestoreHelper( 283 | upscale_factor=1, 284 | face_size=512, 285 | crop_ratio=(1, 1), 286 | det_model='retinaface_resnet50', 287 | save_ext='png', 288 | device=device, 289 | ) 290 | 291 | face_helper.face_parse = None 292 | face_helper.face_parse = init_parsing_model(model_name='bisenet', device=device) 293 | 294 | bg_label = [0, 16, 18, 7, 8, 9, 14, 15] 295 | cond = [] 296 | 297 | # Analyse multiple images at multiple sizes and combine largest area embeddings 298 | for i in range(image.shape[0]): 299 | # get insightface embeddings 300 | iface_embeds = None 301 | for size in [(size, size) for size in range(640, 256, -64)]: 302 | face_analysis.det_model.input_size = size 303 | face_info = face_analysis.get(image[i]) 304 | if face_info: 305 | # Only use the maximum face 306 | # Removed the reverse=True from original code because we need the largest area not the smallest one! 
307 |                     # Sort the detections in ascending order of bounding-box area,
308 |                     # then take the last element, i.e. the largest face
309 |                     face_info = sorted(face_info, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]))[-1]
310 |                     iface_embeds = torch.from_numpy(face_info.embedding).unsqueeze(0).to(device, dtype=dtype)
311 |                     break
312 |             else:
313 |                 # No face detected at any size, skip this image
314 |                 logging.warning(f'No face detected in image {i}')
315 |                 continue
316 | 
317 |             # get eva_clip embeddings
318 |             face_helper.clean_all()
319 |             face_helper.read_image(image[i])
320 |             face_helper.get_face_landmarks_5(only_center_face=True)
321 |             face_helper.align_warp_face()
322 | 
323 |             if len(face_helper.cropped_faces) == 0:
324 |                 # No face detected, skip this image
325 |                 continue
326 | 
327 |             # Get the aligned face image
328 |             align_face = face_helper.cropped_faces[0]
329 |             # Convert the BGR face image to a tensor
330 |             align_face = image_to_tensor(align_face).unsqueeze(0).permute(0, 3, 1, 2).to(device)
331 |             parsing_out = face_helper.face_parse(functional.normalize(align_face, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))[0]
332 |             parsing_out = parsing_out.argmax(dim=1, keepdim=True)
333 |             bg = sum(parsing_out == i for i in bg_label).bool()
334 |             white_image = torch.ones_like(align_face)
335 |             # Only keep the face features
336 |             face_features_image = torch.where(bg, white_image, to_gray(align_face))
337 | 
338 |             # Transform the image before sending it to eva_clip
339 |             # Apparently MPS only supports NEAREST interpolation?
340 |             face_features_image = functional.resize(face_features_image, eva_clip.image_size, transforms.InterpolationMode.BICUBIC if 'cuda' in device.type else transforms.InterpolationMode.NEAREST).to(device, dtype=dtype)
341 |             face_features_image = functional.normalize(face_features_image, eva_clip.image_mean, eva_clip.image_std)
342 | 
343 |             # eva_clip
344 |             id_cond_vit, id_vit_hidden = eva_clip(face_features_image, return_all_features=False, return_hidden=True, shuffle=False)
345 |             id_cond_vit = id_cond_vit.to(device, dtype=dtype)
346 |             for idx in range(len(id_vit_hidden)):
347 |                 id_vit_hidden[idx] = id_vit_hidden[idx].to(device, dtype=dtype)
348 | 
349 |             id_cond_vit = torch.div(id_cond_vit, torch.norm(id_cond_vit, 2, 1, True))
350 | 
351 |             # Combine embeddings
352 |             id_cond = torch.cat([iface_embeds, id_cond_vit], dim=-1)
353 | 
354 |             # Pulid_encoder
355 |             cond.append(pulid_flux.get_embeds(id_cond, id_vit_hidden))
356 | 
357 |         if not cond:
358 |             # No faces detected, return the original model
359 |             logging.warning("PuLID warning: No faces detected in any of the given images, returning unmodified model.")
360 |             return (model,)
361 | 
362 |         # average embeddings
363 |         cond = torch.cat(cond).to(device, dtype=dtype)
364 |         if cond.shape[0] > 1:
365 |             cond = torch.mean(cond, dim=0, keepdim=True)
366 | 
367 |         sigma_start = model.get_model_object("model_sampling").percent_to_sigma(start_at)
368 |         sigma_end = model.get_model_object("model_sampling").percent_to_sigma(end_at)
369 | 
370 |         # Patch the Flux model (the original diffusion_model) directly
371 |         # instead of going through the official ModelPatcher, which is largely undocumented;
372 |         # note that this in-place patch can interfere with other custom nodes.
373 |         flux_model = model.model.diffusion_model
374 |         # Check whether the underlying flux model has already been patched; if not, apply the patch
375 |         if not hasattr(flux_model, "pulid_ca"):
376 |             # Attach the perceiver attention layers, interval settings and a dict for per-node data (weight, embedding, sigma_start, sigma_end)
377 |             # The pulid_data dict is keyed by the unique node id,
378 |             # so multiple ApplyPulidFlux nodes can be chained!
379 |             flux_model.pulid_ca = pulid_flux.pulid_ca
380 |             flux_model.pulid_double_interval = pulid_flux.double_interval
381 |             flux_model.pulid_single_interval = pulid_flux.single_interval
382 |             flux_model.pulid_data = {}
383 |             # Replace the model's forward_orig with our own
384 |             new_method = forward_orig.__get__(flux_model, flux_model.__class__)
385 |             setattr(flux_model, 'forward_orig', new_method)
386 | 
387 |         # The patch is in place, add this node's data (weight, embedding, sigma_start, sigma_end) under its unique node id
388 |         flux_model.pulid_data[unique_id] = {
389 |             'weight': weight,
390 |             'embedding': cond,
391 |             'sigma_start': sigma_start,
392 |             'sigma_end': sigma_end,
393 |         }
394 | 
395 |         # Keep a reference for the destructor (if the node is deleted, its data is deleted as well)
396 |         self.pulid_data_dict = {'data': flux_model.pulid_data, 'unique_id': unique_id}
397 | 
398 |         return (model,)
399 | 
400 |     def __del__(self):
401 |         # Remove this node's data from the patched model
402 |         if self.pulid_data_dict:
403 |             self.pulid_data_dict['data'].pop(self.pulid_data_dict['unique_id'], None)
404 |             del self.pulid_data_dict
405 | 
406 | 
407 | NODE_CLASS_MAPPINGS = {
408 |     "PulidFluxModelLoader": PulidFluxModelLoader,
409 |     "PulidFluxInsightFaceLoader": PulidFluxInsightFaceLoader,
410 |     "PulidFluxEvaClipLoader": PulidFluxEvaClipLoader,
411 |     "ApplyPulidFlux": ApplyPulidFlux,
412 | }
413 | 
414 | NODE_DISPLAY_NAME_MAPPINGS = {
415 |     "PulidFluxModelLoader": "Load PuLID Flux Model",
416 |     "PulidFluxInsightFaceLoader": "Load InsightFace (PuLID Flux)",
417 |     "PulidFluxEvaClipLoader": "Load Eva Clip (PuLID Flux)",
418 |     "ApplyPulidFlux": "Apply PuLID Flux",
419 | }
420 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | facexlib
2 | insightface
3 | onnxruntime
4 | onnxruntime-gpu
5 | ftfy
6 | timm
7 | 
--------------------------------------------------------------------------------