├── DiffCLIP__Differential_Attention_Meets_CLIP.pdf
├── README.md
├── __init__.py
├── __pycache__
│   ├── diff_attention.cpython-310.pyc
│   ├── diff_clip.cpython-310.pyc
│   └── tokenizer.cpython-310.pyc
├── assets
│   ├── coco_sample.jpg
│   └── images.png
├── bpe_simple_vocab_16e6.txt.gz
├── diff_attention.py
├── diff_clip.py
├── requirements.txt
├── test_models.py
└── tokenizer.py

/DiffCLIP__Differential_Attention_Meets_CLIP.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hammoudhasan/DiffCLIP/f6bd3ff3b2b97f12fecc903e2c1496357242d0c9/DiffCLIP__Differential_Attention_Meets_CLIP.pdf
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DiffCLIP: Differential Attention Meets CLIP
2 | 
3 | Hasan Abed Al Kader Hammoud and Bernard Ghanem
4 | 
5 | King Abdullah University of Science and Technology
6 | 
7 | 
8 | 
9 | [arXiv](https://arxiv.org/abs/2503.06626) · Hugging Face Collection
10 | 
11 | 
12 | 
13 | 
14 | ![DiffCLIP](assets/images.png)
15 | 
16 | 
17 | 
18 | 19 | ## Abstract 20 | 21 | We propose DiffCLIP, a novel vision-language model that extends the differential attention mechanism to CLIP architectures. Differential attention was originally developed for large language models to amplify relevant context while canceling out noisy information. In this work, we integrate this mechanism into CLIP's dual encoder (image and text) framework. With minimal additional parameters, DiffCLIP achieves superior performance on image-text understanding tasks. Across zero-shot classification, retrieval, and robustness benchmarks, DiffCLIP consistently outperforms baseline CLIP models. Notably, these gains come with negligible computational overhead, demonstrating that differential attention can significantly enhance multi-modal representations without sacrificing efficiency. 22 | 23 | ## What is Differential Attention? 24 | 25 | Differential attention, proposed in [Differential Transformer](https://arxiv.org/abs/2410.05258), computes the difference between two attention maps: 26 | 27 | ``` 28 | DiffAttn(X) = (softmax(Q₁K₁ᵀ/√d) − λ · softmax(Q₂K₂ᵀ/√d)) · V 29 | ``` 30 | 31 | where the query and key projections are split as `[Q₁; Q₂] = X·Wᵠ` and `[K₁; K₂] = X·Wᵏ`, and λ is a learnable parameter. This mechanism allows the model to capture complementary information by explicitly modeling the differences between attention patterns, leading to richer multimodal representations. 32 | 33 | ## Structure 34 | 35 | The repository contains two main components: 36 | 37 | 1. **DifferentialVisionTransformer** (in `diff_attention.py`): A Vision Transformer modified to use differential attention. 38 | 39 | 2. **DiffCLIP** (in `diff_clip.py`): A CLIP model that uses differential attention in both its vision and text encoders. 40 | 41 | ## How to Use 42 | 43 | ### Installation 44 | 45 | ```bash 46 | # Clone the repository 47 | git clone https://github.com/yourusername/DiffCLIP.git 48 | cd DiffCLIP 49 | 50 | # Install dependencies 51 | pip install -r requirements.txt 52 | ``` 53 | 54 | ### Basic Usage 55 | 56 | ```python 57 | import torch 58 | from diff_clip import DiffCLIP_VITB16 59 | 60 | # Create model 61 | model = DiffCLIP_VITB16() 62 | 63 | # Process image and text 64 | image = torch.randn(1, 3, 224, 224) 65 | text = torch.randint(0, 49408, (1, 77)) # Tokenized text 66 | 67 | # Get embeddings 68 | with torch.no_grad(): 69 | outputs = model(image, text) 70 | 71 | print(outputs["image_embed"].shape) # Should be [1, 512] 72 | print(outputs["text_embed"].shape) # Should be [1, 512] 73 | ``` 74 | 75 | ### Zero-Shot Classification 76 | 77 | You can use the provided `test_models.py` script to perform zero-shot classification: 78 | 79 | ```bash 80 | # Download the model from Hugging Face and test on a COCO image 81 | python test_models.py 82 | ``` 83 | 84 | This will: 85 | 1. Download the DiffCLIP_ViTB16_CC12M model from Hugging Face 86 | 2. Load a sample image from COCO 87 | 3. Perform zero-shot classification 88 | 4. 
Print the top-5 predicted classes 89 | 90 | ## References 91 | 92 | ``` 93 | @misc{hammoud2025diffclipdifferentialattentionmeets, 94 | title={DiffCLIP: Differential Attention Meets CLIP}, 95 | author={Hasan Abed Al Kader Hammoud and Bernard Ghanem}, 96 | year={2025}, 97 | eprint={2503.06626}, 98 | archivePrefix={arXiv}, 99 | primaryClass={cs.CV}, 100 | url={https://arxiv.org/abs/2503.06626}, 101 | } 102 | ``` 103 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | DiffCLIP Module 4 | 5 | This package implements a differential version of CLIP, featuring: 6 | - DifferentialVisionTransformer: A vision transformer with differential attention 7 | - DiffCLIP: A CLIP model using differential attention in both vision and text encoders 8 | 9 | Main components: 10 | - diff_attention.py: Implements DifferentialVisionTransformer 11 | - diff_clip.py: Implements DiffCLIP 12 | 13 | For more details, see the individual module docstrings. 14 | """ 15 | 16 | from .diff_attention import ( 17 | RMSNorm, 18 | DiffAttention, 19 | LayerScale, 20 | DiffBlock, 21 | DifferentialVisionTransformer, 22 | diff_vit_base_patch16_224 23 | ) 24 | 25 | from .diff_clip import ( 26 | DifferentialMultiheadAttention, 27 | DifferentialResidualAttentionBlock, 28 | DifferentialTextTransformer, 29 | DiffCLIP, 30 | DiffCLIP_VITB16, 31 | ) 32 | 33 | __all__ = [ 34 | 'RMSNorm', 35 | 'DiffAttention', 36 | 'LayerScale', 37 | 'DiffBlock', 38 | 'DifferentialVisionTransformer', 39 | 'diff_vit_base_patch16_224', 40 | 'DifferentialMultiheadAttention', 41 | 'DifferentialResidualAttentionBlock', 42 | 'DifferentialTextTransformer', 43 | 'DiffCLIP', 44 | 'DiffCLIP_VITB16', 45 | ] 46 | -------------------------------------------------------------------------------- /__pycache__/diff_attention.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hammoudhasan/DiffCLIP/f6bd3ff3b2b97f12fecc903e2c1496357242d0c9/__pycache__/diff_attention.cpython-310.pyc -------------------------------------------------------------------------------- /__pycache__/diff_clip.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hammoudhasan/DiffCLIP/f6bd3ff3b2b97f12fecc903e2c1496357242d0c9/__pycache__/diff_clip.cpython-310.pyc -------------------------------------------------------------------------------- /__pycache__/tokenizer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hammoudhasan/DiffCLIP/f6bd3ff3b2b97f12fecc903e2c1496357242d0c9/__pycache__/tokenizer.cpython-310.pyc -------------------------------------------------------------------------------- /assets/coco_sample.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hammoudhasan/DiffCLIP/f6bd3ff3b2b97f12fecc903e2c1496357242d0c9/assets/coco_sample.jpg -------------------------------------------------------------------------------- /assets/images.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hammoudhasan/DiffCLIP/f6bd3ff3b2b97f12fecc903e2c1496357242d0c9/assets/images.png -------------------------------------------------------------------------------- 
/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hammoudhasan/DiffCLIP/f6bd3ff3b2b97f12fecc903e2c1496357242d0c9/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /diff_attention.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | diff_attention.py 4 | 5 | This module implements a Differential Attention Vision Transformer. 6 | The key idea is to replace the standard softmax attention with a 7 | differential attention mechanism as described in the paper: 8 | 9 | DiffAttn(X) = (softmax(Q₁K₁ᵀ/√d) − λ · softmax(Q₂K₂ᵀ/√d)) · V 10 | 11 | where the query and key projections are split as: 12 | [Q₁; Q₂] = X Wᵠ, [K₁; K₂] = X Wᵏ, 13 | and V = X Wᵛ. 14 | 15 | The learnable scalar λ is re-parameterized as: 16 | λ = exp(λ_{q1} ⋅ λ_{k1}) − exp(λ_{q2} ⋅ λ_{k2}) + λ_init 17 | 18 | The multi-head formulation uses "effective heads" computed as: 19 | h_effective = (num_heads // 2) 20 | with the per-head dimension d_head = d_model / num_heads. Note that the value 21 | projection is not split (it remains of dimension d_model), so that its per-head shape 22 | is (2·d_head), aligning with the fact that Q and K are split into two parts. 23 | 24 | The overall transformer block is: 25 | Y = X + DropPath(LayerScale(DiffAttention(LN(X)))) 26 | X' = Y + DropPath(LayerScale(MLP(LN(Y)))) 27 | 28 | The DifferentialVisionTransformer class below inherits from timm's VisionTransformer 29 | and replaces its transformer blocks with blocks using Differential Attention. 30 | A registration function diff_vit_base_patch16_224 is provided with the same default 31 | parameters as ViT-Base (patch16, 224). 32 | 33 | References: 34 | - Vision Transformer: https://arxiv.org/abs/2010.11929 35 | - Differential Transformers: (see paper) 36 | """ 37 | 38 | import math 39 | import torch 40 | import torch.nn as nn 41 | import torch.nn.functional as F 42 | 43 | # --------------------------------------------------------- 44 | # Register model decorator: Try to import timm's version; if unavailable, use a dummy. 45 | try: 46 | from timm.models.registry import register_model 47 | except ImportError: 48 | def register_model(fn): 49 | return fn 50 | 51 | # Import timm's VisionTransformer and common layers. 52 | from timm.models.vision_transformer import VisionTransformer 53 | from timm.models.layers import DropPath, Mlp 54 | 55 | # --------------------------------------------------------- 56 | # RMSNorm (as used in the differential attention paper) 57 | # --------------------------------------------------------- 58 | class RMSNorm(nn.Module): 59 | r""" 60 | RMSNorm normalizes the input tensor by its root-mean-square (RMS) value. 61 | 62 | Given an input x ∈ ℝ^(...×d), it computes: 63 | 64 | RMS(x) = sqrt(mean(x², dim=-1, keepdim=True) + ε) 65 | output = x / RMS(x) 66 | 67 | Optionally, a learnable weight is applied if elementwise_affine is True. 68 | 69 | Args: 70 | dim (int): Dimension to normalize. 71 | eps (float): A value added for numerical stability. 72 | elementwise_affine (bool): If True, multiply by a learnable weight. 
73 | """ 74 | def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = True): 75 | super().__init__() 76 | self.dim = dim 77 | self.eps = eps 78 | if elementwise_affine: 79 | self.weight = nn.Parameter(torch.ones(dim)) 80 | else: 81 | self.register_parameter('weight', None) 82 | 83 | def _norm(self, x): 84 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 85 | 86 | def forward(self, x: torch.Tensor) -> torch.Tensor: 87 | output = self._norm(x.float()).type_as(x) 88 | if self.weight is not None: 89 | output = output * self.weight 90 | return output 91 | 92 | def extra_repr(self) -> str: 93 | return f'dim={self.dim}, eps={self.eps}, elementwise_affine={self.weight is not None}' 94 | 95 | # --------------------------------------------------------- 96 | # Differential Attention Module 97 | # --------------------------------------------------------- 98 | class DiffAttention(nn.Module): 99 | r""" 100 | Differential Attention Module. 101 | 102 | Given an input tensor X ∈ ℝ^(B×N×d_model), we first compute the linear projections: 103 | 104 | Q = X Wᵠ, K = X Wᵏ, V = X Wᵛ 105 | 106 | The queries and keys are then reshaped and split into two parts: 107 | Q → [Q₁; Q₂] ∈ ℝ^(B, N, 2·h_effective, d_head) 108 | K → [K₁; K₂] ∈ ℝ^(B, N, 2·h_effective, d_head) 109 | with h_effective = num_heads // 2 and d_head = d_model / num_heads. 110 | 111 | The value projection is reshaped to: 112 | V ∈ ℝ^(B, N, h_effective, 2·d_head) 113 | 114 | We then compute two attention maps: 115 | A₁ = softmax((Q₁ K₁ᵀ) / √d_head) 116 | A₂ = softmax((Q₂ K₂ᵀ) / √d_head) 117 | 118 | A learnable scalar λ is computed via: 119 | λ = exp(λ_{q1} ⋅ λ_{k1}) − exp(λ_{q2} ⋅ λ_{k2}) + λ_init 120 | 121 | Finally, the differential attention output is: 122 | DiffAttn(X) = (A₁ − λ · A₂) · V 123 | 124 | The per-head outputs are then normalized headwise with RMSNorm and projected back to d_model. 125 | 126 | Args: 127 | dim (int): Embedding dimension (d_model). 128 | num_heads (int): Number of heads in the original transformer (must be even). 129 | qkv_bias (bool): If True, add a bias term to the Q, K, V projections. 130 | attn_drop (float): Dropout probability after softmax. 131 | proj_drop (float): Dropout probability after the output projection. 132 | lambda_init (float): Initial constant for lambda re-parameterization. 133 | """ 134 | def __init__(self, dim, num_heads=8, qkv_bias=True, attn_drop=0., proj_drop=0., lambda_init=0.8): 135 | super().__init__() 136 | if num_heads % 2 != 0: 137 | raise ValueError("num_heads must be even for Differential Attention.") 138 | self.dim = dim 139 | self.num_heads = num_heads # original number of heads 140 | self.effective_heads = num_heads // 2 # differential attention operates on half as many heads 141 | self.head_dim = dim // num_heads # per-head dimension 142 | self.scaling = self.head_dim ** -0.5 143 | 144 | # Linear projections for Q, K, V: mapping from dim → dim. 
145 | self.q_proj = nn.Linear(dim, dim, bias=qkv_bias) 146 | self.k_proj = nn.Linear(dim, dim, bias=qkv_bias) 147 | self.v_proj = nn.Linear(dim, dim, bias=qkv_bias) 148 | self.out_proj = nn.Linear(dim, dim, bias=True) # final output projection 149 | 150 | self.attn_drop = nn.Dropout(attn_drop) 151 | self.proj_drop = nn.Dropout(proj_drop) 152 | 153 | # RMSNorm for headwise normalization on outputs (each head's output has dimension 2·head_dim) 154 | self.diff_norm = RMSNorm(2 * self.head_dim, eps=1e-5, elementwise_affine=True) 155 | 156 | # Learnable lambda parameters (shared across all heads) 157 | self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1)) 158 | self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1)) 159 | self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1)) 160 | self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1)) 161 | self.lambda_init = lambda_init 162 | 163 | def forward(self, x: torch.Tensor) -> torch.Tensor: 164 | """ 165 | Args: 166 | x (Tensor): Input tensor of shape (B, N, d_model). 167 | 168 | Returns: 169 | Tensor of shape (B, N, d_model) after applying differential attention. 170 | """ 171 | B, N, _ = x.shape 172 | 173 | # Compute Q, K, V projections. 174 | q = self.q_proj(x) # shape: (B, N, d_model) 175 | k = self.k_proj(x) # shape: (B, N, d_model) 176 | v = self.v_proj(x) # shape: (B, N, d_model) 177 | 178 | # Reshape Q and K into (B, N, 2 * h_effective, head_dim) 179 | q = q.view(B, N, 2 * self.effective_heads, self.head_dim) 180 | k = k.view(B, N, 2 * self.effective_heads, self.head_dim) 181 | # Reshape V into (B, N, h_effective, 2 * head_dim) 182 | v = v.view(B, N, self.effective_heads, 2 * self.head_dim) 183 | 184 | # Transpose to bring head dimension forward. 185 | # q, k: (B, 2 * h_effective, N, head_dim) 186 | q = q.transpose(1, 2) 187 | k = k.transpose(1, 2) 188 | # v: (B, h_effective, N, 2 * head_dim) 189 | v = v.transpose(1, 2) 190 | 191 | # Scale Q. 192 | q = q * self.scaling 193 | 194 | # Compute raw attention scores: (B, 2 * h_effective, N, N) 195 | attn_scores = torch.matmul(q, k.transpose(-1, -2)) 196 | 197 | # Compute attention probabilities. 198 | attn_probs = F.softmax(attn_scores, dim=-1) 199 | attn_probs = self.attn_drop(attn_probs) 200 | 201 | # Reshape to separate the two halves: (B, h_effective, 2, N, N) 202 | attn_probs = attn_probs.view(B, self.effective_heads, 2, N, N) 203 | 204 | # Compute lambda via re-parameterization. 205 | lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1)) 206 | lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2)) 207 | lambda_full = lambda_1 - lambda_2 + self.lambda_init 208 | 209 | # Differential attention: subtract the second attention map scaled by lambda_full. 210 | diff_attn = attn_probs[:, :, 0, :, :] - lambda_full * attn_probs[:, :, 1, :, :] # shape: (B, h_effective, N, N) 211 | 212 | # Multiply the differential attention weights with V. 
213 | attn_output = torch.matmul(diff_attn, v) # shape: (B, h_effective, N, 2 * head_dim) 214 | 215 | # Apply RMSNorm (headwise normalization) and scale by (1 - lambda_init) 216 | attn_output = self.diff_norm(attn_output) * (1 - self.lambda_init) 217 | 218 | # Concatenate heads: reshape from (B, h_effective, N, 2 * head_dim) → (B, N, 2 * h_effective * head_dim) 219 | attn_output = attn_output.transpose(1, 2).reshape(B, N, 2 * self.effective_heads * self.head_dim) 220 | 221 | # Final linear projection. 222 | x_out = self.out_proj(attn_output) 223 | x_out = self.proj_drop(x_out) 224 | return x_out 225 | 226 | # --------------------------------------------------------- 227 | # LayerScale module (optional scaling of sublayer outputs) 228 | # --------------------------------------------------------- 229 | class LayerScale(nn.Module): 230 | r""" 231 | LayerScale scales the output of a sublayer by a learnable parameter. 232 | 233 | Equation: 234 | Output = x * γ 235 | 236 | Args: 237 | dim (int): Dimension of the sublayer output. 238 | init_values (float): Initial value for scaling parameter γ. 239 | inplace (bool): Whether to perform the multiplication in-place. 240 | """ 241 | def __init__(self, dim, init_values=1e-5, inplace=False): 242 | super().__init__() 243 | self.inplace = inplace 244 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 245 | 246 | def forward(self, x: torch.Tensor) -> torch.Tensor: 247 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 248 | 249 | # --------------------------------------------------------- 250 | # Transformer Block with Differential Attention 251 | # --------------------------------------------------------- 252 | class DiffBlock(nn.Module): 253 | r""" 254 | Transformer Block with Differential Attention. 255 | 256 | The block consists of two main sublayers: 257 | 258 | 1. Differential Attention sublayer: 259 | Y = X + DropPath(LayerScale(DiffAttention(LayerNorm(X)))) 260 | 261 | 2. MLP sublayer: 262 | X' = Y + DropPath(LayerScale(MLP(LayerNorm(Y)))) 263 | 264 | Equations: 265 | Y = X + DropPath(LS₁(DiffAttention(LN₁(X)))) 266 | X' = Y + DropPath(LS₂(MLP(LN₂(Y)))) 267 | 268 | Args: 269 | dim (int): Embedding dimension. 270 | num_heads (int): Number of heads in the original transformer (must be even). 271 | mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension. 272 | qkv_bias (bool): If True, add bias in Q, K, V projections. 273 | drop (float): Dropout probability. 274 | attn_drop (float): Attention dropout probability. 275 | drop_path (float): Stochastic depth rate. 276 | init_values (float or None): Initial value for LayerScale. If None, LayerScale is disabled. 277 | act_layer (nn.Module): Activation layer. 278 | norm_layer (nn.Module): Normalization layer. 279 | lambda_init (float): Initial lambda value for differential attention. 280 | """ 281 | def __init__(self, dim, num_heads, mlp_ratio=4.0, qkv_bias=True, drop=0., attn_drop=0., drop_path=0., 282 | init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm, lambda_init=0.8): 283 | super().__init__() 284 | self.norm1 = norm_layer(dim) 285 | self.attn = DiffAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, 286 | lambda_init=lambda_init) 287 | self.ls1 = LayerScale(dim, init_values=init_values) if init_values is not None else nn.Identity() 288 | self.drop_path1 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 289 | 290 | self.norm2 = norm_layer(dim) 291 | self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) 292 | self.ls2 = LayerScale(dim, init_values=init_values) if init_values is not None else nn.Identity() 293 | self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() 294 | 295 | def forward(self, x: torch.Tensor) -> torch.Tensor: 296 | x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x)))) 297 | x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) 298 | return x 299 | 300 | # --------------------------------------------------------- 301 | # Differential Vision Transformer 302 | # --------------------------------------------------------- 303 | class DifferentialVisionTransformer(VisionTransformer): 304 | r""" 305 | Vision Transformer with Differential Attention. 306 | 307 | This model is a modification of the standard VisionTransformer (timm) 308 | where the self-attention mechanism is replaced with Differential Attention. 309 | 310 | In each transformer block, the attention sublayer is computed as: 311 | 312 | DiffAttn(X) = (softmax(Q₁K₁ᵀ/√d_head) − λ · softmax(Q₂K₂ᵀ/√d_head)) · V 313 | 314 | with the λ re-parameterization: 315 | λ = exp(λ_{q1}⋅λ_{k1}) − exp(λ_{q2}⋅λ_{k2}) + λ_init 316 | 317 | The overall block structure is: 318 | Y = X + DropPath(LayerScale(DiffAttention(LayerNorm(X)))) 319 | X' = Y + DropPath(LayerScale(MLP(LayerNorm(Y)))) 320 | 321 | Args: 322 | All arguments are as in timm's VisionTransformer. 323 | lambda_init (float): Initial lambda value for differential attention. 324 | """ 325 | def __init__(self, *args, lambda_init=0.8, **kwargs): 326 | super().__init__(*args, **kwargs) 327 | depth = kwargs.get('depth', 12) 328 | embed_dim = self.embed_dim # d_model from VisionTransformer 329 | num_heads = kwargs.get('num_heads', 12) 330 | mlp_ratio = kwargs.get('mlp_ratio', 4.0) 331 | qkv_bias = kwargs.get('qkv_bias', True) 332 | drop_rate = kwargs.get('drop_rate', 0.) 333 | attn_drop_rate = kwargs.get('attn_drop_rate', 0.) 334 | drop_path_rate = kwargs.get('drop_path_rate', 0.) 335 | init_values = kwargs.get('init_values', None) 336 | norm_layer = kwargs.get('norm_layer', None) or nn.LayerNorm 337 | act_layer = kwargs.get('act_layer', None) or nn.GELU 338 | 339 | # Create stochastic depth schedule. 340 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] 341 | blocks = [] 342 | for i in range(depth): 343 | blocks.append( 344 | DiffBlock( 345 | dim=embed_dim, 346 | num_heads=num_heads, 347 | mlp_ratio=mlp_ratio, 348 | qkv_bias=qkv_bias, 349 | drop=drop_rate, 350 | attn_drop=attn_drop_rate, 351 | drop_path=dpr[i], 352 | init_values=init_values, 353 | act_layer=act_layer, 354 | norm_layer=norm_layer, 355 | lambda_init=lambda_init # same for all blocks (or can vary with depth) 356 | ) 357 | ) 358 | self.blocks = nn.Sequential(*blocks) 359 | if hasattr(self, 'norm'): 360 | self.norm = norm_layer(embed_dim) 361 | 362 | # --------------------------------------------------------- 363 | # Model Registration: Differential ViT-Base (Patch16, 224) 364 | # --------------------------------------------------------- 365 | @register_model 366 | def diff_vit_base_patch16_224(pretrained: bool = False, **kwargs) -> DifferentialVisionTransformer: 367 | """ 368 | Differential ViT-Base (ViT-B/16) with Differential Attention. 
369 | 370 | The defaults are set to match the original ViT-Base (patch16, 224): 371 | - patch_size = 16 372 | - embed_dim = 768 373 | - depth = 12 374 | - num_heads = 12 375 | - lambda_init = 0.8 376 | 377 | Args: 378 | pretrained (bool): If True, load pretrained weights (not implemented here). 379 | **kwargs: Additional keyword arguments. 380 | 381 | Returns: 382 | DifferentialVisionTransformer model. 383 | """ 384 | model_args = dict( 385 | patch_size=16, 386 | embed_dim=768, 387 | depth=12, 388 | num_heads=12, 389 | lambda_init=0.8, 390 | ) 391 | # Merge additional kwargs with defaults. 392 | model = DifferentialVisionTransformer(**dict(model_args, **kwargs)) 393 | if pretrained: 394 | # Code to load pretrained weights can be added here. 395 | pass 396 | return model 397 | 398 | # --------------------------------------------------------- 399 | # Main test function 400 | # --------------------------------------------------------- 401 | if __name__ == "__main__": 402 | # Instantiate the Differential ViT-Base (patch16, 224) with default parameters. 403 | model = diff_vit_base_patch16_224() 404 | model.eval() 405 | dummy_input = torch.randn(1, 3, 224, 224) 406 | with torch.no_grad(): 407 | output = model(dummy_input) 408 | print("Differential ViT-Base output shape:", output.shape) 409 | 410 | trainable_params = sum(p.numel() for p in model.parameters()) 411 | print(f"Number of trainable parameters: {trainable_params}") 412 | 413 | from timm.models.vision_transformer import vit_base_patch16_224 414 | model = vit_base_patch16_224() 415 | trainable_params = sum(p.numel() for p in model.parameters()) 416 | print(f"Number of trainable parameters: {trainable_params}") 417 | 418 | -------------------------------------------------------------------------------- /diff_clip.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | diff_clip.py 4 | 5 | This module implements a differential version of CLIP's text encoder. 6 | The idea is to replace the standard softmax multi-head attention with a 7 | differential attention mechanism. The differential attention is defined as: 8 | 9 | DiffAttn(X) = (softmax(Q₁ K₁ᵀ / √d) − λ · softmax(Q₂ K₂ᵀ / √d)) · V 10 | 11 | where the input X ∈ ℝ^(L×N×d) (with L = sequence length, N = batch size, and d = embed_dim) 12 | is projected to query, key, and value as: 13 | 14 | Q = X W^Q, K = X W^K, V = X W^V 15 | 16 | and Q and K are split along the head dimension: 17 | 18 | [Q₁; Q₂] ∈ ℝ^(L, N, 2·h_eff, d_head) where h_eff = num_heads // 2 and d_head = d / num_heads. 19 | 20 | A learnable scalar λ is computed via: 21 | 22 | λ = exp(λ_{q1} ⋅ λ_{k1}) − exp(λ_{q2} ⋅ λ_{k2}) + λ_init 23 | 24 | The final multi-head differential attention is then computed as: 25 | 26 | MultiHeadDiffAttn(X) = Concat( LN( DiffAttn₁(X) ), …, LN( DiffAttn_h_eff(X) ) ) · W^O 27 | 28 | The overall block in the text transformer is structured as: 29 | 30 | X' = X + DifferentialAttention(LayerNorm(X)) 31 | X'' = X' + MLP(LayerNorm(X')) 32 | 33 | This file defines: 34 | 1. DifferentialMultiheadAttention – a drop-in replacement for nn.MultiheadAttention. 35 | 2. DifferentialResidualAttentionBlock – a residual block using differential attention. 36 | 3. DifferentialTextTransformer – a stack of differential residual attention blocks. 37 | 4. DiffCLIP – a version of CLIP that uses DifferentialTextTransformer for text encoding. 
38 | 39 | References: 40 | - CLIP from OpenAI (modified from github.com/openai/CLIP) 41 | - Differential Transformers (see paper) 42 | """ 43 | 44 | import math 45 | from collections import OrderedDict 46 | 47 | import numpy as np 48 | import torch 49 | import torch.nn as nn 50 | import torch.nn.functional as F 51 | from diff_attention import diff_vit_base_patch16_224 52 | 53 | class RMSNorm(nn.Module): 54 | r""" 55 | RMSNorm normalizes the input tensor by its root-mean-square (RMS) value. 56 | 57 | Given an input x ∈ ℝ^(...×d), it computes: 58 | 59 | RMS(x) = sqrt(mean(x², dim=-1, keepdim=True) + ε) 60 | output = x / RMS(x) 61 | 62 | Optionally, a learnable weight is applied if elementwise_affine is True. 63 | 64 | Args: 65 | dim (int): Dimension to normalize. 66 | eps (float): A value added for numerical stability. 67 | elementwise_affine (bool): If True, multiply by a learnable weight. 68 | """ 69 | def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = True): 70 | super().__init__() 71 | self.dim = dim 72 | self.eps = eps 73 | if elementwise_affine: 74 | self.weight = nn.Parameter(torch.ones(dim)) 75 | else: 76 | self.register_parameter('weight', None) 77 | 78 | def _norm(self, x): 79 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 80 | 81 | def forward(self, x: torch.Tensor) -> torch.Tensor: 82 | output = self._norm(x.float()).type_as(x) 83 | if self.weight is not None: 84 | output = output * self.weight 85 | return output 86 | 87 | def extra_repr(self) -> str: 88 | return f'dim={self.dim}, eps={self.eps}, elementwise_affine={self.weight is not None}' 89 | 90 | 91 | # --------------------------------------------------------- 92 | # Utility Layers (LayerNorm and QuickGELU) 93 | # --------------------------------------------------------- 94 | class LayerNorm(nn.LayerNorm): 95 | """Subclass torch's LayerNorm to handle fp16.""" 96 | def forward(self, x: torch.Tensor): 97 | orig_type = x.dtype 98 | ret = super().forward(x.type(torch.float32)) 99 | return ret.type(orig_type) 100 | 101 | class QuickGELU(nn.Module): 102 | def forward(self, x: torch.Tensor): 103 | return x * torch.sigmoid(1.702 * x) 104 | 105 | # --------------------------------------------------------- 106 | # Differential Multihead Attention for Text 107 | # --------------------------------------------------------- 108 | class DifferentialMultiheadAttention(nn.Module): 109 | r""" 110 | Differential Multihead Attention for text inputs. 111 | 112 | This module implements the differential attention mechanism with the following steps: 113 | 114 | 1. Given input X ∈ ℝ^(L×N×d) (L: sequence length, N: batch size, d: embed_dim), 115 | compute linear projections: 116 | Q = X W^Q, K = X W^K, V = X W^V. 117 | 118 | 2. Permute the input to shape (N, L, d) so that we treat the batch dimension as B. 119 | 120 | 3. Reshape Q and K to shape (B, L, 2·h_eff, d_head) and then transpose to 121 | (B, 2·h_eff, L, d_head), where h_eff = num_heads // 2 and d_head = d / num_heads. 122 | 123 | 4. Reshape V to (B, L, h_eff, 2·d_head) and then transpose to (B, h_eff, L, 2·d_head). 124 | 125 | 5. Compute the scaled dot-product attention scores for both splits: 126 | A₁ = softmax((Q₁ K₁ᵀ) / √d_head) 127 | A₂ = softmax((Q₂ K₂ᵀ) / √d_head) 128 | 129 | 6. Compute a learnable scalar: 130 | λ = exp(λ_{q1} ⋅ λ_{k1}) − exp(λ_{q2} ⋅ λ_{k2}) + λ_init 131 | 132 | 7. The differential attention output is: 133 | DiffAttn(X) = (A₁ − λ · A₂) · V 134 | 135 | 8. 
After applying headwise RMSNorm and a final linear projection, the output is 136 | permuted back to (L, N, d). 137 | 138 | Args: 139 | embed_dim (int): Total dimension of the model. 140 | num_heads (int): Number of attention heads (must be even). 141 | qkv_bias (bool): If True, add a bias to the Q, K, V projections. 142 | attn_drop (float): Dropout rate after softmax. 143 | proj_drop (float): Dropout rate after the output projection. 144 | lambda_init (float): Initial scalar for λ. 145 | """ 146 | def __init__(self, embed_dim, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0., lambda_init=0.8): 147 | super().__init__() 148 | if num_heads % 2 != 0: 149 | raise ValueError("num_heads must be even for Differential Attention.") 150 | self.embed_dim = embed_dim 151 | self.num_heads = num_heads 152 | self.effective_heads = num_heads // 2 # differential attention uses half the heads 153 | self.head_dim = embed_dim // num_heads 154 | self.scaling = self.head_dim ** -0.5 155 | 156 | # Linear layers for Q, K, V projections. 157 | self.q_proj = nn.Linear(embed_dim, embed_dim, bias=qkv_bias) 158 | self.k_proj = nn.Linear(embed_dim, embed_dim, bias=qkv_bias) 159 | self.v_proj = nn.Linear(embed_dim, embed_dim, bias=qkv_bias) 160 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) 161 | 162 | self.attn_drop = nn.Dropout(attn_drop) 163 | self.proj_drop = nn.Dropout(proj_drop) 164 | 165 | # RMSNorm for headwise normalization; each head's output has dimension 2 * head_dim. 166 | self.diff_norm = RMSNorm(2 * self.head_dim, eps=1e-5, elementwise_affine=True) 167 | 168 | # Learnable lambda parameters (shared across heads). 169 | self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1)) 170 | self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1)) 171 | self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1)) 172 | self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1)) 173 | self.lambda_init = lambda_init 174 | 175 | def forward(self, query, key, value, attn_mask=None): 176 | """ 177 | Args: 178 | query, key, value (Tensor): Input tensors of shape (L, N, embed_dim), 179 | where L is sequence length and N is batch size. 180 | attn_mask (Tensor, optional): Additive attention mask of shape (L, L). 181 | 182 | Returns: 183 | Tensor of shape (L, N, embed_dim) after applying differential multi-head attention. 184 | """ 185 | # Permute input from (L, N, embed_dim) to (N, L, embed_dim) 186 | x = query.transpose(0, 1) # now (N, L, embed_dim) 187 | B, L, _ = x.shape 188 | 189 | # Compute projections. 190 | q = self.q_proj(x) # (B, L, embed_dim) 191 | k = self.k_proj(x) 192 | v = self.v_proj(x) 193 | 194 | # Reshape Q and K to (B, L, 2*h_eff, head_dim) 195 | q = q.view(B, L, 2 * self.effective_heads, self.head_dim) 196 | k = k.view(B, L, 2 * self.effective_heads, self.head_dim) 197 | # Reshape V to (B, L, h_eff, 2*head_dim) 198 | v = v.view(B, L, self.effective_heads, 2 * self.head_dim) 199 | 200 | # Transpose Q and K to (B, 2*h_eff, L, head_dim) 201 | q = q.transpose(1, 2) 202 | k = k.transpose(1, 2) 203 | # Transpose V to (B, h_eff, L, 2*head_dim) 204 | v = v.transpose(1, 2) 205 | 206 | # Scale Q. 207 | q = q * self.scaling 208 | 209 | # Compute raw attention scores: (B, 2*h_eff, L, L) 210 | attn_scores = torch.matmul(q, k.transpose(-1, -2)) 211 | # If an attention mask is provided, add it. 
212 | if attn_mask is not None: 213 | # attn_mask is expected to be of shape (L, L) 214 | attn_scores = attn_scores + attn_mask.unsqueeze(0).unsqueeze(0) 215 | 216 | # Compute attention probabilities. 217 | attn_probs = F.softmax(attn_scores, dim=-1) 218 | attn_probs = self.attn_drop(attn_probs) 219 | 220 | # Reshape to separate the two halves: (B, h_eff, 2, L, L) 221 | attn_probs = attn_probs.view(B, self.effective_heads, 2, L, L) 222 | 223 | # Compute λ via re-parameterization. 224 | lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1)) 225 | lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2)) 226 | lambda_full = lambda_1 - lambda_2 + self.lambda_init 227 | 228 | # Differential attention: subtract the second attention map scaled by λ. 229 | diff_attn = attn_probs[:, :, 0, :, :] - lambda_full * attn_probs[:, :, 1, :, :] # (B, h_eff, L, L) 230 | 231 | # Compute weighted sum with V. 232 | out = torch.matmul(diff_attn, v) # (B, h_eff, L, 2*head_dim) 233 | # Apply RMSNorm (headwise normalization) and scale by (1 - lambda_init). 234 | out = self.diff_norm(out) * (1 - self.lambda_init) 235 | 236 | # Concatenate heads: transpose to (B, L, h_eff, 2*head_dim) and then reshape to (B, L, embed_dim) 237 | out = out.transpose(1, 2).reshape(B, L, 2 * self.effective_heads * self.head_dim) 238 | # Final linear projection. 239 | out = self.out_proj(out) 240 | out = self.proj_drop(out) 241 | # Permute back to (L, N, embed_dim) 242 | out = out.transpose(0, 1) 243 | return out 244 | 245 | # --------------------------------------------------------- 246 | # Differential Residual Attention Block for Text 247 | # --------------------------------------------------------- 248 | class DifferentialResidualAttentionBlock(nn.Module): 249 | r""" 250 | Residual Attention Block using Differential Multihead Attention. 251 | 252 | This block first applies layer normalization to the input, then 253 | differential multihead attention, and adds the result to the input. 254 | Then it applies another layer normalization, a feed-forward MLP, and adds 255 | the result again. 256 | 257 | Equations: 258 | X' = X + DiffMultiheadAttention(LN(X)) 259 | Y = X' + MLP(LN(X')) 260 | 261 | Args: 262 | d_model (int): Embedding dimension. 263 | n_head (int): Number of attention heads (must be even). 264 | attn_mask (Tensor, optional): Additive attention mask of shape (L, L). 
265 | """ 266 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 267 | super().__init__() 268 | self.attn = DifferentialMultiheadAttention(d_model, n_head) 269 | self.ln_1 = LayerNorm(d_model) 270 | self.mlp = nn.Sequential(OrderedDict([ 271 | ("c_fc", nn.Linear(d_model, d_model * 4)), 272 | ("gelu", QuickGELU()), 273 | ("c_proj", nn.Linear(d_model * 4, d_model)), 274 | ])) 275 | self.ln_2 = LayerNorm(d_model) 276 | self.attn_mask = attn_mask 277 | 278 | def attention(self, x: torch.Tensor): 279 | if self.attn_mask is not None: 280 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) 281 | # DifferentialMultiheadAttention returns output of shape (L, N, d_model) 282 | return self.attn(x, x, x, attn_mask=self.attn_mask) 283 | 284 | def forward(self, x: torch.Tensor): 285 | x = x + self.attention(self.ln_1(x)) 286 | x = x + self.mlp(self.ln_2(x)) 287 | return x 288 | 289 | # --------------------------------------------------------- 290 | # Differential Text Transformer 291 | # --------------------------------------------------------- 292 | class DifferentialTextTransformer(nn.Module): 293 | r""" 294 | Transformer for text built from Differential Residual Attention Blocks. 295 | 296 | Args: 297 | width (int): Embedding dimension. 298 | layers (int): Number of transformer layers. 299 | heads (int): Number of attention heads (must be even). 300 | attn_mask (Tensor, optional): Additive attention mask of shape (L, L). 301 | """ 302 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 303 | super().__init__() 304 | self.width = width 305 | self.layers = layers 306 | self.resblocks = nn.Sequential(*[ 307 | DifferentialResidualAttentionBlock(width, heads, attn_mask) 308 | for _ in range(layers) 309 | ]) 310 | 311 | def forward(self, x: torch.Tensor): 312 | return self.resblocks(x) 313 | 314 | # --------------------------------------------------------- 315 | # DiffCLIP: Differential CLIP Model 316 | # --------------------------------------------------------- 317 | class DiffCLIP(nn.Module): 318 | r""" 319 | DiffCLIP implements a differential version of CLIP, where the text encoder is modified 320 | to use DifferentialTextTransformer. 321 | 322 | The overall architecture is similar to CLIP: 323 | - A vision encoder (can be any vision model, e.g., a differential ViT). 324 | - A text encoder that tokenizes text, adds positional embeddings, and passes through 325 | a stack of differential transformer blocks. 326 | - Final projections for image and text features. 327 | 328 | Args: 329 | embed_dim (int): Dimension of the joint embedding space. 330 | vision_width (int): Width (output dimension) of the vision encoder. 331 | vision_model (nn.Module): Vision encoder model. 332 | context_length (int): Maximum text sequence length. 333 | vocab_size (int): Vocabulary size for the text encoder. 334 | transformer_width (int): Embedding dimension of the text transformer. 335 | transformer_heads (int): Number of heads in the text transformer (must be even). 336 | transformer_layers (int): Number of layers in the text transformer. 
337 | """ 338 | def __init__( 339 | self, 340 | embed_dim: int, 341 | vision_width: int, 342 | vision_model: nn.Module, 343 | context_length: int, 344 | vocab_size: int, 345 | transformer_width: int, 346 | transformer_heads: int, 347 | transformer_layers: int, 348 | **kwargs, 349 | ): 350 | super().__init__() 351 | self.context_length = context_length 352 | self.vision_width = vision_width 353 | 354 | self.visual = vision_model 355 | 356 | # Use DifferentialTextTransformer instead of the standard Transformer. 357 | self.transformer = DifferentialTextTransformer( 358 | width=transformer_width, 359 | layers=transformer_layers, 360 | heads=transformer_heads, 361 | attn_mask=self.build_attention_mask(), 362 | ) 363 | 364 | self.vocab_size = vocab_size 365 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 366 | self.positional_embedding = nn.Parameter( 367 | torch.empty(self.context_length, transformer_width) 368 | ) 369 | self.ln_final = LayerNorm(transformer_width) 370 | 371 | self.image_projection = nn.Parameter(torch.empty(vision_width, embed_dim)) 372 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 373 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 374 | 375 | self.initialize_parameters() 376 | 377 | def initialize_parameters(self): 378 | nn.init.normal_(self.token_embedding.weight, std=0.02) 379 | nn.init.normal_(self.positional_embedding, std=0.01) 380 | 381 | # Initialize transformer parameters. 382 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 383 | attn_std = self.transformer.width ** -0.5 384 | fc_std = (2 * self.transformer.width) ** -0.5 385 | for block in self.transformer.resblocks: 386 | # For our differential attention blocks, initialize the linear layers. 387 | nn.init.normal_(block.attn.q_proj.weight, std=attn_std) 388 | nn.init.normal_(block.attn.k_proj.weight, std=attn_std) 389 | nn.init.normal_(block.attn.v_proj.weight, std=attn_std) 390 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 391 | nn.init.normal_(block.mlp[0].weight, std=fc_std) # c_fc layer 392 | nn.init.normal_(block.mlp[2].weight, std=proj_std) # c_proj layer 393 | 394 | nn.init.normal_(self.image_projection, std=self.vision_width ** -0.5) 395 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 396 | 397 | def build_attention_mask(self): 398 | # Create a causal attention mask for text. 399 | # The mask is of shape (context_length, context_length) with -inf above the diagonal. 400 | mask = torch.empty(self.context_length, self.context_length) 401 | mask.fill_(float("-inf")) 402 | mask.triu_(1) 403 | return mask 404 | 405 | def encode_image(self, image): 406 | x = self.visual(image) 407 | x = x @ self.image_projection 408 | return x 409 | 410 | def encode_text(self, text): 411 | # text: (batch_size, context_length) 412 | x = self.token_embedding(text) # (batch_size, context_length, transformer_width) 413 | x = x + self.positional_embedding 414 | # Permute to (context_length, batch_size, transformer_width) 415 | x = x.permute(1, 0, 2) 416 | x = self.transformer(x) 417 | # Permute back to (batch_size, context_length, transformer_width) 418 | x = x.permute(1, 0, 2) 419 | x = self.ln_final(x) 420 | # Extract the features at the position of the end-of-text token (assumed to be the max token index in each sequence). 
421 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 422 | return x 423 | 424 | def forward(self, image, text): 425 | image_embed = self.encode_image(image) 426 | text_embed = self.encode_text(text) 427 | return { 428 | "image_embed": image_embed, 429 | "text_embed": text_embed, 430 | "logit_scale": self.logit_scale.exp(), 431 | } 432 | 433 | 434 | def DiffCLIP_VITB16(**kwargs): 435 | """ 436 | Factory function to build DiffCLIP with a ViT-B/16 vision encoder. 437 | This function creates a vision model using the differential vision transformer 438 | "diff_vit_base_patch16_224" and then builds DiffCLIP. 439 | 440 | Args: 441 | **kwargs: Additional keyword arguments. 442 | 443 | Returns: 444 | DiffCLIP model. 445 | """ 446 | # Create a vision model using the differential vision transformer 447 | vision_model = diff_vit_base_patch16_224(num_classes=0) 448 | model = DiffCLIP( 449 | embed_dim=512, 450 | vision_width=768, 451 | vision_model=vision_model, 452 | context_length=77, 453 | vocab_size=49408, 454 | transformer_width=512, 455 | transformer_heads=8, # must be even 456 | transformer_layers=12, 457 | **kwargs, 458 | ) 459 | return model 460 | 461 | 462 | # --------------------------------------------------------- 463 | # Main test function 464 | # --------------------------------------------------------- 465 | if __name__ == "__main__": 466 | # Create dummy inputs. 467 | dummy_image = torch.randn(2, 3, 224, 224) # batch of 2 images 468 | # Create dummy text tokens (e.g., 77 tokens per sequence). For simplicity, we simulate tokens as random integers. 469 | dummy_text = torch.randint(low=0, high=49408, size=(2, 77)) 470 | 471 | # Instantiate DiffCLIP using the DiffCLIP_VITB16 factory function. 472 | model = DiffCLIP_VITB16() 473 | model.eval() 474 | 475 | with torch.no_grad(): 476 | outputs = model(dummy_image, dummy_text) 477 | 478 | print("DiffCLIP output keys:", outputs.keys()) 479 | print("Image embed shape:", outputs["image_embed"].shape) 480 | print("Text embed shape:", outputs["text_embed"].shape) 481 | print("Logit scale:", outputs["logit_scale"]) 482 | 483 | trainable_params = sum(p.numel() for p in model.parameters()) 484 | print(f"Number of trainable parameters: {trainable_params}") 485 | 486 | # Note: Original CLIP comparison removed as it's not relevant for the public release 487 | 488 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.9.0 2 | timm>=0.4.12 3 | numpy>=1.19.2 4 | huggingface_hub>=0.11.0 5 | requests>=2.26.0 6 | Pillow>=8.0.0 7 | ftfy>=6.0.0 8 | regex>=2022.3.15 -------------------------------------------------------------------------------- /test_models.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | test_models.py 4 | 5 | This script demonstrates zero-shot prediction with DiffCLIP. It: 6 | 1. Downloads the DiffCLIP_ViTB16_CC12M model from Hugging Face 7 | 2. Loads an image from COCO dataset 8 | 3. 
Performs zero-shot classification 9 | 10 | Usage: 11 | python test_models.py 12 | """ 13 | 14 | import os 15 | import torch 16 | import numpy as np 17 | import requests 18 | from PIL import Image 19 | from io import BytesIO 20 | from huggingface_hub import hf_hub_download 21 | from tokenizer import SimpleTokenizer 22 | 23 | # Import the DiffCLIP models 24 | from diff_clip import DiffCLIP_VITB16 25 | 26 | 27 | def download_model(): 28 | """ 29 | Download the DiffCLIP_ViTB16_CC12M model from Hugging Face to a local folder. 30 | Returns the path to the checkpoint file. 31 | """ 32 | print("Downloading model from Hugging Face...") 33 | model_id = "hammh0a/DiffCLIP_ViTB16_CC12M" 34 | local_dir = "./DiffCLIP_ViTB16_CC12M" 35 | 36 | os.makedirs(local_dir, exist_ok=True) 37 | 38 | # Download the checkpoint file 39 | checkpoint_path = hf_hub_download( 40 | repo_id=model_id, 41 | filename="checkpoint_best.pt", 42 | local_dir=local_dir, 43 | local_dir_use_symlinks=False 44 | ) 45 | 46 | print(f"Model downloaded to {checkpoint_path}") 47 | return checkpoint_path 48 | 49 | 50 | def load_model(checkpoint_path): 51 | """ 52 | Load the DiffCLIP model and checkpoint. 53 | """ 54 | print("Loading model...") 55 | # Create a model instance 56 | model = DiffCLIP_VITB16() 57 | 58 | # Load the checkpoint 59 | checkpoint = torch.load(checkpoint_path, map_location="cpu") 60 | 61 | # If the model was saved with DataParallel, we need to handle that 62 | if list(checkpoint["state_dict"].keys())[0].startswith("module."): 63 | # Create a new state_dict without the 'module.' prefix 64 | new_state_dict = {k[7:]: v for k, v in checkpoint["state_dict"].items()} 65 | load_status = model.load_state_dict(new_state_dict) 66 | else: 67 | load_status = model.load_state_dict(checkpoint["state_dict"]) 68 | 69 | print(f"Model loaded with status: {load_status}") 70 | return model 71 | 72 | 73 | def load_image_from_coco(): 74 | """ 75 | Load a sample image from COCO dataset. 76 | """ 77 | print("Loading sample image from COCO...") 78 | # A sample image URL from COCO 79 | coco_image_url = "http://images.cocodataset.org/val2017/000000039769.jpg" 80 | 81 | # Download the image 82 | response = requests.get(coco_image_url) 83 | img = Image.open(BytesIO(response.content)) 84 | 85 | # Resize and preprocess for the model 86 | img = img.convert("RGB").resize((224, 224)) 87 | img_tensor = torch.tensor(np.array(img)).permute(2, 0, 1).float() # (3, 224, 224) 88 | img_tensor = img_tensor / 255.0 # Normalize to [0, 1] 89 | 90 | # Apply ImageNet normalization 91 | mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1) 92 | std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1) 93 | img_tensor = (img_tensor - mean) / std 94 | 95 | return img_tensor.unsqueeze(0), img # Return tensor and PIL image 96 | 97 | 98 | def zero_shot_prediction(model, image_tensor, tokenizer, classes): 99 | """ 100 | Perform zero-shot prediction on an image with the provided model. 
101 | """ 102 | print("Performing zero-shot prediction...") 103 | model.eval() 104 | 105 | # Create text prompts 106 | prompts = [f"a photo of a {label}" for label in classes] 107 | 108 | # Tokenize text 109 | text_tokens = tokenizer(prompts) 110 | 111 | # Put everything on the same device 112 | device = "cuda" if torch.cuda.is_available() else "cpu" 113 | model = model.to(device) 114 | image_tensor = image_tensor.to(device) 115 | text_tokens = text_tokens.to(device) 116 | 117 | # Get image and text features 118 | with torch.no_grad(): 119 | outputs = model(image_tensor, text_tokens) 120 | image_features = outputs["image_embed"] 121 | text_features = outputs["text_embed"] 122 | logit_scale = outputs["logit_scale"] 123 | 124 | # Normalize features 125 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 126 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 127 | 128 | # Calculate similarity scores 129 | similarity = (logit_scale * image_features @ text_features.T).softmax(dim=-1) 130 | 131 | # Get top predictions 132 | values, indices = similarity[0].topk(min(5, len(classes))) 133 | 134 | # Return prediction results 135 | predictions = [(classes[idx], values[i].item()) for i, idx in enumerate(indices)] 136 | return predictions 137 | 138 | 139 | def main(): 140 | # Define some classes for zero-shot prediction 141 | coco_classes = [ 142 | "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", 143 | "truck", "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", 144 | "bench", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", 145 | "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", 146 | "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat", 147 | "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", 148 | "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", 149 | "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", 150 | "cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet", 151 | "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", 152 | "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", 153 | "teddy bear", "hair drier", "toothbrush" 154 | ] 155 | 156 | # Initialize tokenizer 157 | tokenizer = SimpleTokenizer() 158 | 159 | # Download and load the model 160 | checkpoint_path = download_model() 161 | model = load_model(checkpoint_path) 162 | 163 | # Load a sample image 164 | image_tensor, pil_image = load_image_from_coco() 165 | 166 | # Save the image for reference 167 | # pil_image.save("coco_sample.jpg") 168 | # print("Sample image saved as 'coco_sample.jpg'") 169 | 170 | # Perform zero-shot prediction 171 | predictions = zero_shot_prediction(model, image_tensor, tokenizer, coco_classes) 172 | 173 | # Print results 174 | print("\nZero-shot prediction results:") 175 | for i, (label, score) in enumerate(predictions): 176 | print(f"{i+1}. {label}: {score:.4f}") 177 | 178 | 179 | if __name__ == "__main__": 180 | main() -------------------------------------------------------------------------------- /tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 
7 | # Modified from github.com/openai/CLIP
8 | import gzip
9 | import html
10 | import os
11 | from functools import lru_cache
12 | 
13 | import ftfy
14 | import regex as re
15 | import torch
16 | 
17 | 
18 | @lru_cache()
19 | def default_bpe():
20 |     return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
21 | 
22 | 
23 | @lru_cache()
24 | def bytes_to_unicode():
25 |     """
26 |     Returns list of utf-8 byte and a corresponding list of unicode strings.
27 |     The reversible bpe codes work on unicode strings.
28 |     This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
29 |     When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
30 |     This is a significant percentage of your normal, say, 32K bpe vocab.
31 |     To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
32 |     And avoids mapping to whitespace/control characters the bpe code barfs on.
33 |     """
34 |     bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
35 |     cs = bs[:]
36 |     n = 0
37 |     for b in range(2**8):
38 |         if b not in bs:
39 |             bs.append(b)
40 |             cs.append(2**8+n)
41 |             n += 1
42 |     cs = [chr(n) for n in cs]
43 |     return dict(zip(bs, cs))
44 | 
45 | 
46 | def get_pairs(word):
47 |     """Return set of symbol pairs in a word.
48 |     Word is represented as tuple of symbols (symbols being variable-length strings).
49 |     """
50 |     pairs = set()
51 |     prev_char = word[0]
52 |     for char in word[1:]:
53 |         pairs.add((prev_char, char))
54 |         prev_char = char
55 |     return pairs
56 | 
57 | 
58 | def basic_clean(text):
59 |     text = ftfy.fix_text(text)
60 |     text = html.unescape(html.unescape(text))
61 |     return text.strip()
62 | 
63 | 
64 | def whitespace_clean(text):
65 |     text = re.sub(r'\s+', ' ', text)
66 |     text = text.strip()
67 |     return text
68 | 
69 | 
70 | class SimpleTokenizer(object):
71 |     def __init__(self, bpe_path: str = default_bpe()):
72 |         self.byte_encoder = bytes_to_unicode()
73 |         self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
74 |         merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
75 |         merges = merges[1:49152-256-2+1]
76 |         merges = [tuple(merge.split()) for merge in merges]
77 |         vocab = list(bytes_to_unicode().values())
78 |         vocab = vocab + [v+'</w>' for v in vocab]  # '</w>' marks the end of a word in the BPE vocabulary
79 |         for merge in merges:
80 |             vocab.append(''.join(merge))
81 |         vocab.extend(['<|startoftext|>', '<|endoftext|>'])
82 |         self.encoder = dict(zip(vocab, range(len(vocab))))
83 |         self.decoder = {v: k for k, v in self.encoder.items()}
84 |         self.bpe_ranks = dict(zip(merges, range(len(merges))))
85 |         self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
86 |         self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
87 | 
88 |     def bpe(self, token):
89 |         if token in self.cache:
90 |             return self.cache[token]
91 |         word = tuple(token[:-1]) + ( token[-1] + '</w>',)
92 |         pairs = get_pairs(word)
93 | 
94 |         if not pairs:
95 |             return token+'</w>'
96 | 
97 |         while True:
98 |             bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
99 |             if bigram not in self.bpe_ranks:
100 |                 break
101 |             first, second = bigram
102 |             new_word = []
103 |             i = 0
104 |             while i < len(word):
105 |                 try:
106 |                     j = word.index(first, i)
107 |                     new_word.extend(word[i:j])
108 |                     i = j
109 |                 except:
110 |                     new_word.extend(word[i:])
111 |                     break
112 | 
113 |                 if word[i] == first and i < len(word)-1 and word[i+1] == second:
114 |                     new_word.append(first+second)
115 |                     i += 2
116 |                 else:
117 |                     new_word.append(word[i])
118 |                     i += 1
119 |             new_word = tuple(new_word)
120 |             word = new_word
121 |             if len(word) == 1:
122 |                 break
123 |             else:
124 |                 pairs = get_pairs(word)
125 |         word = ' '.join(word)
126 |         self.cache[token] = word
127 |         return word
128 | 
129 |     def encode(self, text):
130 |         bpe_tokens = []
131 |         text = whitespace_clean(basic_clean(text)).lower()
132 |         for token in re.findall(self.pat, text):
133 |             token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
134 |             bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
135 |         return bpe_tokens
136 | 
137 |     def decode(self, tokens):
138 |         text = ''.join([self.decoder[token] for token in tokens])
139 |         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
140 |         return text
141 | 
142 |     def __call__(self, texts, context_length=77):
143 |         if isinstance(texts, str):
144 |             texts = [texts]
145 | 
146 |         sot_token = self.encoder["<|startoftext|>"]
147 |         eot_token = self.encoder["<|endoftext|>"]
148 |         all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts]
149 |         result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
150 | 
151 |         for i, tokens in enumerate(all_tokens):
152 |             tokens = tokens[:context_length]
153 |             result[i, :len(tokens)] = torch.tensor(tokens)
154 | 
155 |         if len(result) == 1:
156 |             return result[0]
157 |         return result
--------------------------------------------------------------------------------
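
Putting the files above together, here is a minimal sketch of scoring a few captions against one image with `SimpleTokenizer` and `DiffCLIP_VITB16`. It assumes the repository files are on the Python path; the model here is randomly initialized, so for meaningful predictions a checkpoint should be loaded as in `test_models.py`. Note that `SimpleTokenizer.__call__` returns a 1-D tensor of shape (77,) for a single string, so a batch dimension would need to be added before calling the model.

```python
import torch

from diff_clip import DiffCLIP_VITB16
from tokenizer import SimpleTokenizer

# Randomly initialized model; load checkpoint weights as in test_models.py for real use.
model = DiffCLIP_VITB16().eval()
tokenizer = SimpleTokenizer()

image = torch.randn(1, 3, 224, 224)  # placeholder for a preprocessed image tensor
captions = ["a photo of a cat", "a photo of a dog"]

# Tokenizing a list yields shape (num_captions, 77); a single string would yield (77,)
# and would need .unsqueeze(0) before being passed to the model.
text = tokenizer(captions)

with torch.no_grad():
    out = model(image, text)
    img = out["image_embed"] / out["image_embed"].norm(dim=-1, keepdim=True)
    txt = out["text_embed"] / out["text_embed"].norm(dim=-1, keepdim=True)
    probs = (out["logit_scale"] * img @ txt.T).softmax(dim=-1)

print(probs)  # similarity of the image to each caption
```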