├── DiffCLIP__Differential_Attention_Meets_CLIP.pdf
├── README.md
├── __init__.py
├── __pycache__
│   ├── diff_attention.cpython-310.pyc
│   ├── diff_clip.cpython-310.pyc
│   └── tokenizer.cpython-310.pyc
├── assets
│   ├── coco_sample.jpg
│   └── images.png
├── bpe_simple_vocab_16e6.txt.gz
├── diff_attention.py
├── diff_clip.py
├── requirements.txt
├── test_models.py
└── tokenizer.py
/DiffCLIP__Differential_Attention_Meets_CLIP.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hammoudhasan/DiffCLIP/f6bd3ff3b2b97f12fecc903e2c1496357242d0c9/DiffCLIP__Differential_Attention_Meets_CLIP.pdf
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DiffCLIP: Differential Attention Meets CLIP
2 |
3 | Hasan Abed Al Kader Hammoud and Bernard Ghanem
4 |
5 | King Abdullah University of Science and Technology
6 |
19 | ## Abstract
20 |
21 | We propose DiffCLIP, a novel vision-language model that extends the differential attention mechanism to CLIP architectures. Differential attention was originally developed for large language models to amplify relevant context while canceling out noisy information. In this work, we integrate this mechanism into CLIP's dual encoder (image and text) framework. With minimal additional parameters, DiffCLIP achieves superior performance on image-text understanding tasks. Across zero-shot classification, retrieval, and robustness benchmarks, DiffCLIP consistently outperforms baseline CLIP models. Notably, these gains come with negligible computational overhead, demonstrating that differential attention can significantly enhance multi-modal representations without sacrificing efficiency.
22 |
23 | ## What is Differential Attention?
24 |
25 | Differential attention, proposed in [Differential Transformer](https://arxiv.org/abs/2410.05258), computes the difference between two attention maps:
26 |
27 | ```
28 | DiffAttn(X) = (softmax(Q₁K₁ᵀ/√d) − λ · softmax(Q₂K₂ᵀ/√d)) · V
29 | ```
30 |
31 | where the query and key projections are split as `[Q₁; Q₂] = X·Wᵠ` and `[K₁; K₂] = X·Wᵏ`, and λ is a learnable parameter. This mechanism allows the model to capture complementary information by explicitly modeling the differences between attention patterns, leading to richer multimodal representations.
32 |
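For intuition, here is a minimal single-head PyTorch sketch of the equation above. It is illustrative only: the helper name and random weights are ours, `lam` is fixed rather than learned, and the repository's actual multi-head implementation is `DiffAttention` in `diff_attention.py`.

```python
import torch
import torch.nn.functional as F

def diff_attn_single_head(x, Wq, Wk, Wv, lam=0.8):
    """(softmax(Q1·K1ᵀ/√d) − lam · softmax(Q2·K2ᵀ/√d)) · V for a single head."""
    d = Wq.shape[1] // 2                               # dimension of each Q/K split
    q1, q2 = (x @ Wq).chunk(2, dim=-1)                 # [Q1; Q2] = X·Wq
    k1, k2 = (x @ Wk).chunk(2, dim=-1)                 # [K1; K2] = X·Wk
    v = x @ Wv
    a1 = F.softmax(q1 @ k1.transpose(-1, -2) / d**0.5, dim=-1)
    a2 = F.softmax(q2 @ k2.transpose(-1, -2) / d**0.5, dim=-1)
    return (a1 - lam * a2) @ v                         # difference of the two attention maps

x = torch.randn(1, 16, 64)                             # (batch, tokens, dim)
out = diff_attn_single_head(x, *(torch.randn(64, 64) for _ in range(3)))
print(out.shape)                                       # torch.Size([1, 16, 64])
```
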
33 | ## Structure
34 |
35 | The repository contains two main components:
36 |
37 | 1. **DifferentialVisionTransformer** (in `diff_attention.py`): A Vision Transformer modified to use differential attention.
38 |
39 | 2. **DiffCLIP** (in `diff_clip.py`): A CLIP model that uses differential attention in both its vision and text encoders.
40 |
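Both pieces can be used on their own. For example, the vision tower works as a standalone feature extractor; the snippet below is a small sketch mirroring how `DiffCLIP_VITB16` constructs it in `diff_clip.py`.

```python
import torch
from diff_attention import diff_vit_base_patch16_224

# num_classes=0 makes timm return pooled features rather than classification logits,
# which is how DiffCLIP_VITB16 instantiates its vision tower.
vit = diff_vit_base_patch16_224(num_classes=0).eval()
with torch.no_grad():
    feats = vit(torch.randn(1, 3, 224, 224))
print(feats.shape)  # expected: torch.Size([1, 768])
```
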
41 | ## How to Use
42 |
43 | ### Installation
44 |
45 | ```bash
46 | # Clone the repository
47 | git clone https://github.com/hammoudhasan/DiffCLIP.git
48 | cd DiffCLIP
49 |
50 | # Install dependencies
51 | pip install -r requirements.txt
52 | ```
53 |
54 | ### Basic Usage
55 |
56 | ```python
57 | import torch
58 | from diff_clip import DiffCLIP_VITB16
59 |
60 | # Create model
61 | model = DiffCLIP_VITB16()
62 |
63 | # Process image and text
64 | image = torch.randn(1, 3, 224, 224)
65 | text = torch.randint(0, 49408, (1, 77)) # Tokenized text
66 |
67 | # Get embeddings
68 | with torch.no_grad():
69 | outputs = model(image, text)
70 |
71 | print(outputs["image_embed"].shape) # Should be [1, 512]
72 | print(outputs["text_embed"].shape) # Should be [1, 512]
73 | ```
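
The returned dictionary also carries `logit_scale` (see `forward` in `diff_clip.py`), so CLIP-style similarity logits can be computed directly from these embeddings. Continuing the example above:

```python
import torch.nn.functional as F

# Normalize the embeddings and scale the cosine similarity, as in standard CLIP.
image_embed = F.normalize(outputs["image_embed"], dim=-1)
text_embed = F.normalize(outputs["text_embed"], dim=-1)
logits_per_image = outputs["logit_scale"] * image_embed @ text_embed.t()
print(logits_per_image.shape)  # [1, 1] here: one image, one text
```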
74 |
75 | ### Zero-Shot Classification
76 |
77 | You can use the provided `test_models.py` script to perform zero-shot classification:
78 |
79 | ```bash
80 | # Download the model from Hugging Face and test on a COCO image
81 | python test_models.py
82 | ```
83 |
84 | This will:
85 | 1. Download the DiffCLIP_ViTB16_CC12M model from Hugging Face
86 | 2. Load a sample image from COCO
87 | 3. Perform zero-shot classification
88 | 4. Print the top-5 predicted classes
89 |
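If you prefer to script the zero-shot step yourself, the core logic is roughly the condensed sketch below (random weights and a random image for brevity; for real results, load the released checkpoint and apply the ImageNet-style preprocessing used in `test_models.py`):

```python
import torch
from diff_clip import DiffCLIP_VITB16
from tokenizer import SimpleTokenizer

model = DiffCLIP_VITB16().eval()
tokenizer = SimpleTokenizer()

classes = ["cat", "dog", "car"]                             # your own label set
text = tokenizer([f"a photo of a {c}" for c in classes])    # (3, 77) token ids
image = torch.randn(1, 3, 224, 224)                         # stand-in for a preprocessed image

with torch.no_grad():
    out = model(image, text)
    img = out["image_embed"] / out["image_embed"].norm(dim=-1, keepdim=True)
    txt = out["text_embed"] / out["text_embed"].norm(dim=-1, keepdim=True)
    probs = (out["logit_scale"] * img @ txt.T).softmax(dim=-1)

for c, p in zip(classes, probs[0].tolist()):
    print(f"{c}: {p:.3f}")
```
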
90 | ## References
91 |
92 | ```
93 | @misc{hammoud2025diffclipdifferentialattentionmeets,
94 | title={DiffCLIP: Differential Attention Meets CLIP},
95 | author={Hasan Abed Al Kader Hammoud and Bernard Ghanem},
96 | year={2025},
97 | eprint={2503.06626},
98 | archivePrefix={arXiv},
99 | primaryClass={cs.CV},
100 | url={https://arxiv.org/abs/2503.06626},
101 | }
102 | ```
103 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | DiffCLIP Module
4 |
5 | This package implements a differential version of CLIP, featuring:
6 | - DifferentialVisionTransformer: A vision transformer with differential attention
7 | - DiffCLIP: A CLIP model using differential attention in both vision and text encoders
8 |
9 | Main components:
10 | - diff_attention.py: Implements DifferentialVisionTransformer
11 | - diff_clip.py: Implements DiffCLIP
12 |
13 | For more details, see the individual module docstrings.
14 | """
15 |
16 | from .diff_attention import (
17 | RMSNorm,
18 | DiffAttention,
19 | LayerScale,
20 | DiffBlock,
21 | DifferentialVisionTransformer,
22 | diff_vit_base_patch16_224
23 | )
24 |
25 | from .diff_clip import (
26 | DifferentialMultiheadAttention,
27 | DifferentialResidualAttentionBlock,
28 | DifferentialTextTransformer,
29 | DiffCLIP,
30 | DiffCLIP_VITB16,
31 | )
32 |
33 | __all__ = [
34 | 'RMSNorm',
35 | 'DiffAttention',
36 | 'LayerScale',
37 | 'DiffBlock',
38 | 'DifferentialVisionTransformer',
39 | 'diff_vit_base_patch16_224',
40 | 'DifferentialMultiheadAttention',
41 | 'DifferentialResidualAttentionBlock',
42 | 'DifferentialTextTransformer',
43 | 'DiffCLIP',
44 | 'DiffCLIP_VITB16',
45 | ]
46 |
--------------------------------------------------------------------------------
/__pycache__/diff_attention.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hammoudhasan/DiffCLIP/f6bd3ff3b2b97f12fecc903e2c1496357242d0c9/__pycache__/diff_attention.cpython-310.pyc
--------------------------------------------------------------------------------
/__pycache__/diff_clip.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hammoudhasan/DiffCLIP/f6bd3ff3b2b97f12fecc903e2c1496357242d0c9/__pycache__/diff_clip.cpython-310.pyc
--------------------------------------------------------------------------------
/__pycache__/tokenizer.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hammoudhasan/DiffCLIP/f6bd3ff3b2b97f12fecc903e2c1496357242d0c9/__pycache__/tokenizer.cpython-310.pyc
--------------------------------------------------------------------------------
/assets/coco_sample.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hammoudhasan/DiffCLIP/f6bd3ff3b2b97f12fecc903e2c1496357242d0c9/assets/coco_sample.jpg
--------------------------------------------------------------------------------
/assets/images.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hammoudhasan/DiffCLIP/f6bd3ff3b2b97f12fecc903e2c1496357242d0c9/assets/images.png
--------------------------------------------------------------------------------
/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hammoudhasan/DiffCLIP/f6bd3ff3b2b97f12fecc903e2c1496357242d0c9/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/diff_attention.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | diff_attention.py
4 |
5 | This module implements a Differential Attention Vision Transformer.
6 | The key idea is to replace the standard softmax attention with a
7 | differential attention mechanism as described in the paper:
8 |
9 | DiffAttn(X) = (softmax(Q₁K₁ᵀ/√d) − λ · softmax(Q₂K₂ᵀ/√d)) · V
10 |
11 | where the query and key projections are split as:
12 | [Q₁; Q₂] = X Wᵠ, [K₁; K₂] = X Wᵏ,
13 | and V = X Wᵛ.
14 |
15 | The learnable scalar λ is re-parameterized as:
16 | λ = exp(λ_{q1} ⋅ λ_{k1}) − exp(λ_{q2} ⋅ λ_{k2}) + λ_init
17 |
18 | The multi-head formulation uses "effective heads" computed as:
19 | h_effective = (num_heads // 2)
20 | with the per-head dimension d_head = d_model / num_heads. Note that the value
21 | projection is not split (it remains of dimension d_model), so that its per-head shape
22 | is (2·d_head), aligning with the fact that Q and K are split into two parts.
23 |
24 | The overall transformer block is:
25 | Y = X + DropPath(LayerScale(DiffAttention(LN(X))))
26 | X' = Y + DropPath(LayerScale(MLP(LN(Y))))
27 |
28 | The DifferentialVisionTransformer class below inherits from timm's VisionTransformer
29 | and replaces its transformer blocks with blocks using Differential Attention.
30 | A registration function diff_vit_base_patch16_224 is provided with the same default
31 | parameters as ViT-Base (patch16, 224).
32 |
33 | References:
34 | - Vision Transformer: https://arxiv.org/abs/2010.11929
35 | - Differential Transformer: https://arxiv.org/abs/2410.05258
36 | """
37 |
38 | import math
39 | import torch
40 | import torch.nn as nn
41 | import torch.nn.functional as F
42 |
43 | # ---------------------------------------------------------
44 | # Register model decorator: Try to import timm's version; if unavailable, use a dummy.
45 | try:
46 | from timm.models.registry import register_model
47 | except ImportError:
48 | def register_model(fn):
49 | return fn
50 |
51 | # Import timm's VisionTransformer and common layers.
52 | from timm.models.vision_transformer import VisionTransformer
53 | from timm.models.layers import DropPath, Mlp
54 |
55 | # ---------------------------------------------------------
56 | # RMSNorm (as used in the differential attention paper)
57 | # ---------------------------------------------------------
58 | class RMSNorm(nn.Module):
59 | r"""
60 | RMSNorm normalizes the input tensor by its root-mean-square (RMS) value.
61 |
62 | Given an input x ∈ ℝ^(...×d), it computes:
63 |
64 | RMS(x) = sqrt(mean(x², dim=-1, keepdim=True) + ε)
65 | output = x / RMS(x)
66 |
67 | Optionally, a learnable weight is applied if elementwise_affine is True.
68 |
69 | Args:
70 | dim (int): Dimension to normalize.
71 | eps (float): A value added for numerical stability.
72 | elementwise_affine (bool): If True, multiply by a learnable weight.
73 | """
74 | def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = True):
75 | super().__init__()
76 | self.dim = dim
77 | self.eps = eps
78 | if elementwise_affine:
79 | self.weight = nn.Parameter(torch.ones(dim))
80 | else:
81 | self.register_parameter('weight', None)
82 |
83 | def _norm(self, x):
84 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
85 |
86 | def forward(self, x: torch.Tensor) -> torch.Tensor:
87 | output = self._norm(x.float()).type_as(x)
88 | if self.weight is not None:
89 | output = output * self.weight
90 | return output
91 |
92 | def extra_repr(self) -> str:
93 | return f'dim={self.dim}, eps={self.eps}, elementwise_affine={self.weight is not None}'
94 |
95 | # ---------------------------------------------------------
96 | # Differential Attention Module
97 | # ---------------------------------------------------------
98 | class DiffAttention(nn.Module):
99 | r"""
100 | Differential Attention Module.
101 |
102 | Given an input tensor X ∈ ℝ^(B×N×d_model), we first compute the linear projections:
103 |
104 | Q = X Wᵠ, K = X Wᵏ, V = X Wᵛ
105 |
106 | The queries and keys are then reshaped and split into two parts:
107 | Q → [Q₁; Q₂] ∈ ℝ^(B, N, 2·h_effective, d_head)
108 | K → [K₁; K₂] ∈ ℝ^(B, N, 2·h_effective, d_head)
109 | with h_effective = num_heads // 2 and d_head = d_model / num_heads.
110 |
111 | The value projection is reshaped to:
112 | V ∈ ℝ^(B, N, h_effective, 2·d_head)
113 |
114 | We then compute two attention maps:
115 | A₁ = softmax((Q₁ K₁ᵀ) / √d_head)
116 | A₂ = softmax((Q₂ K₂ᵀ) / √d_head)
117 |
118 | A learnable scalar λ is computed via:
119 | λ = exp(λ_{q1} ⋅ λ_{k1}) − exp(λ_{q2} ⋅ λ_{k2}) + λ_init
120 |
121 | Finally, the differential attention output is:
122 | DiffAttn(X) = (A₁ − λ · A₂) · V
123 |
124 | The per-head outputs are then normalized headwise with RMSNorm and projected back to d_model.
125 |
126 | Args:
127 | dim (int): Embedding dimension (d_model).
128 | num_heads (int): Number of heads in the original transformer (must be even).
129 | qkv_bias (bool): If True, add a bias term to the Q, K, V projections.
130 | attn_drop (float): Dropout probability after softmax.
131 | proj_drop (float): Dropout probability after the output projection.
132 | lambda_init (float): Initial constant for lambda re-parameterization.
133 | """
134 | def __init__(self, dim, num_heads=8, qkv_bias=True, attn_drop=0., proj_drop=0., lambda_init=0.8):
135 | super().__init__()
136 | if num_heads % 2 != 0:
137 | raise ValueError("num_heads must be even for Differential Attention.")
138 | self.dim = dim
139 | self.num_heads = num_heads # original number of heads
140 | self.effective_heads = num_heads // 2 # differential attention operates on half as many heads
141 | self.head_dim = dim // num_heads # per-head dimension
142 | self.scaling = self.head_dim ** -0.5
143 |
144 | # Linear projections for Q, K, V: mapping from dim → dim.
145 | self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
146 | self.k_proj = nn.Linear(dim, dim, bias=qkv_bias)
147 | self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)
148 | self.out_proj = nn.Linear(dim, dim, bias=True) # final output projection
149 |
150 | self.attn_drop = nn.Dropout(attn_drop)
151 | self.proj_drop = nn.Dropout(proj_drop)
152 |
153 | # RMSNorm for headwise normalization on outputs (each head's output has dimension 2·head_dim)
154 | self.diff_norm = RMSNorm(2 * self.head_dim, eps=1e-5, elementwise_affine=True)
155 |
156 | # Learnable lambda parameters (shared across all heads)
157 | self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
158 | self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
159 | self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
160 | self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
161 | self.lambda_init = lambda_init
162 |
163 | def forward(self, x: torch.Tensor) -> torch.Tensor:
164 | """
165 | Args:
166 | x (Tensor): Input tensor of shape (B, N, d_model).
167 |
168 | Returns:
169 | Tensor of shape (B, N, d_model) after applying differential attention.
170 | """
171 | B, N, _ = x.shape
172 |
173 | # Compute Q, K, V projections.
174 | q = self.q_proj(x) # shape: (B, N, d_model)
175 | k = self.k_proj(x) # shape: (B, N, d_model)
176 | v = self.v_proj(x) # shape: (B, N, d_model)
177 |
178 | # Reshape Q and K into (B, N, 2 * h_effective, head_dim)
179 | q = q.view(B, N, 2 * self.effective_heads, self.head_dim)
180 | k = k.view(B, N, 2 * self.effective_heads, self.head_dim)
181 | # Reshape V into (B, N, h_effective, 2 * head_dim)
182 | v = v.view(B, N, self.effective_heads, 2 * self.head_dim)
183 |
184 | # Transpose to bring head dimension forward.
185 | # q, k: (B, 2 * h_effective, N, head_dim)
186 | q = q.transpose(1, 2)
187 | k = k.transpose(1, 2)
188 | # v: (B, h_effective, N, 2 * head_dim)
189 | v = v.transpose(1, 2)
190 |
191 | # Scale Q.
192 | q = q * self.scaling
193 |
194 | # Compute raw attention scores: (B, 2 * h_effective, N, N)
195 | attn_scores = torch.matmul(q, k.transpose(-1, -2))
196 |
197 | # Compute attention probabilities.
198 | attn_probs = F.softmax(attn_scores, dim=-1)
199 | attn_probs = self.attn_drop(attn_probs)
200 |
201 | # Reshape to separate the two halves: (B, h_effective, 2, N, N)
202 | attn_probs = attn_probs.view(B, self.effective_heads, 2, N, N)
203 |
204 | # Compute lambda via re-parameterization.
205 | lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1))
206 | lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2))
207 | lambda_full = lambda_1 - lambda_2 + self.lambda_init
208 |
209 | # Differential attention: subtract the second attention map scaled by lambda_full.
210 | diff_attn = attn_probs[:, :, 0, :, :] - lambda_full * attn_probs[:, :, 1, :, :] # shape: (B, h_effective, N, N)
211 |
212 | # Multiply the differential attention weights with V.
213 | attn_output = torch.matmul(diff_attn, v) # shape: (B, h_effective, N, 2 * head_dim)
214 |
215 | # Apply RMSNorm (headwise normalization) and scale by (1 - lambda_init)
216 | attn_output = self.diff_norm(attn_output) * (1 - self.lambda_init)
217 |
218 | # Concatenate heads: reshape from (B, h_effective, N, 2 * head_dim) → (B, N, 2 * h_effective * head_dim)
219 | attn_output = attn_output.transpose(1, 2).reshape(B, N, 2 * self.effective_heads * self.head_dim)
220 |
221 | # Final linear projection.
222 | x_out = self.out_proj(attn_output)
223 | x_out = self.proj_drop(x_out)
224 | return x_out
225 |
226 | # ---------------------------------------------------------
227 | # LayerScale module (optional scaling of sublayer outputs)
228 | # ---------------------------------------------------------
229 | class LayerScale(nn.Module):
230 | r"""
231 | LayerScale scales the output of a sublayer by a learnable parameter.
232 |
233 | Equation:
234 | Output = x * γ
235 |
236 | Args:
237 | dim (int): Dimension of the sublayer output.
238 | init_values (float): Initial value for scaling parameter γ.
239 | inplace (bool): Whether to perform the multiplication in-place.
240 | """
241 | def __init__(self, dim, init_values=1e-5, inplace=False):
242 | super().__init__()
243 | self.inplace = inplace
244 | self.gamma = nn.Parameter(init_values * torch.ones(dim))
245 |
246 | def forward(self, x: torch.Tensor) -> torch.Tensor:
247 | return x.mul_(self.gamma) if self.inplace else x * self.gamma
248 |
249 | # ---------------------------------------------------------
250 | # Transformer Block with Differential Attention
251 | # ---------------------------------------------------------
252 | class DiffBlock(nn.Module):
253 | r"""
254 | Transformer Block with Differential Attention.
255 |
256 | The block consists of two main sublayers:
257 |
258 | 1. Differential Attention sublayer:
259 | Y = X + DropPath(LayerScale(DiffAttention(LayerNorm(X))))
260 |
261 | 2. MLP sublayer:
262 | X' = Y + DropPath(LayerScale(MLP(LayerNorm(Y))))
263 |
264 | Equations:
265 | Y = X + DropPath(LS₁(DiffAttention(LN₁(X))))
266 | X' = Y + DropPath(LS₂(MLP(LN₂(Y))))
267 |
268 | Args:
269 | dim (int): Embedding dimension.
270 | num_heads (int): Number of heads in the original transformer (must be even).
271 | mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
272 | qkv_bias (bool): If True, add bias in Q, K, V projections.
273 | drop (float): Dropout probability.
274 | attn_drop (float): Attention dropout probability.
275 | drop_path (float): Stochastic depth rate.
276 | init_values (float or None): Initial value for LayerScale. If None, LayerScale is disabled.
277 | act_layer (nn.Module): Activation layer.
278 | norm_layer (nn.Module): Normalization layer.
279 | lambda_init (float): Initial lambda value for differential attention.
280 | """
281 | def __init__(self, dim, num_heads, mlp_ratio=4.0, qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
282 | init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm, lambda_init=0.8):
283 | super().__init__()
284 | self.norm1 = norm_layer(dim)
285 | self.attn = DiffAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
286 | lambda_init=lambda_init)
287 | self.ls1 = LayerScale(dim, init_values=init_values) if init_values is not None else nn.Identity()
288 | self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
289 |
290 | self.norm2 = norm_layer(dim)
291 | self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
292 | self.ls2 = LayerScale(dim, init_values=init_values) if init_values is not None else nn.Identity()
293 | self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
294 |
295 | def forward(self, x: torch.Tensor) -> torch.Tensor:
296 | x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
297 | x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
298 | return x
299 |
300 | # ---------------------------------------------------------
301 | # Differential Vision Transformer
302 | # ---------------------------------------------------------
303 | class DifferentialVisionTransformer(VisionTransformer):
304 | r"""
305 | Vision Transformer with Differential Attention.
306 |
307 | This model is a modification of the standard VisionTransformer (timm)
308 | where the self-attention mechanism is replaced with Differential Attention.
309 |
310 | In each transformer block, the attention sublayer is computed as:
311 |
312 | DiffAttn(X) = (softmax(Q₁K₁ᵀ/√d_head) − λ · softmax(Q₂K₂ᵀ/√d_head)) · V
313 |
314 | with the λ re-parameterization:
315 | λ = exp(λ_{q1}⋅λ_{k1}) − exp(λ_{q2}⋅λ_{k2}) + λ_init
316 |
317 | The overall block structure is:
318 | Y = X + DropPath(LayerScale(DiffAttention(LayerNorm(X))))
319 | X' = Y + DropPath(LayerScale(MLP(LayerNorm(Y))))
320 |
321 | Args:
322 | All arguments are as in timm's VisionTransformer.
323 | lambda_init (float): Initial lambda value for differential attention.
324 | """
325 | def __init__(self, *args, lambda_init=0.8, **kwargs):
326 | super().__init__(*args, **kwargs)
327 | depth = kwargs.get('depth', 12)
328 | embed_dim = self.embed_dim # d_model from VisionTransformer
329 | num_heads = kwargs.get('num_heads', 12)
330 | mlp_ratio = kwargs.get('mlp_ratio', 4.0)
331 | qkv_bias = kwargs.get('qkv_bias', True)
332 | drop_rate = kwargs.get('drop_rate', 0.)
333 | attn_drop_rate = kwargs.get('attn_drop_rate', 0.)
334 | drop_path_rate = kwargs.get('drop_path_rate', 0.)
335 | init_values = kwargs.get('init_values', None)
336 | norm_layer = kwargs.get('norm_layer', None) or nn.LayerNorm
337 | act_layer = kwargs.get('act_layer', None) or nn.GELU
338 |
339 | # Create stochastic depth schedule.
340 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
341 | blocks = []
342 | for i in range(depth):
343 | blocks.append(
344 | DiffBlock(
345 | dim=embed_dim,
346 | num_heads=num_heads,
347 | mlp_ratio=mlp_ratio,
348 | qkv_bias=qkv_bias,
349 | drop=drop_rate,
350 | attn_drop=attn_drop_rate,
351 | drop_path=dpr[i],
352 | init_values=init_values,
353 | act_layer=act_layer,
354 | norm_layer=norm_layer,
355 | lambda_init=lambda_init # same for all blocks (or can vary with depth)
356 | )
357 | )
358 | self.blocks = nn.Sequential(*blocks)
359 | if hasattr(self, 'norm'):
360 | self.norm = norm_layer(embed_dim)
361 |
362 | # ---------------------------------------------------------
363 | # Model Registration: Differential ViT-Base (Patch16, 224)
364 | # ---------------------------------------------------------
365 | @register_model
366 | def diff_vit_base_patch16_224(pretrained: bool = False, **kwargs) -> DifferentialVisionTransformer:
367 | """
368 | Differential ViT-Base (ViT-B/16) with Differential Attention.
369 |
370 | The defaults are set to match the original ViT-Base (patch16, 224):
371 | - patch_size = 16
372 | - embed_dim = 768
373 | - depth = 12
374 | - num_heads = 12
375 | - lambda_init = 0.8
376 |
377 | Args:
378 | pretrained (bool): If True, load pretrained weights (not implemented here).
379 | **kwargs: Additional keyword arguments.
380 |
381 | Returns:
382 | DifferentialVisionTransformer model.
383 | """
384 | model_args = dict(
385 | patch_size=16,
386 | embed_dim=768,
387 | depth=12,
388 | num_heads=12,
389 | lambda_init=0.8,
390 | )
391 | # Merge additional kwargs with defaults.
392 | model = DifferentialVisionTransformer(**dict(model_args, **kwargs))
393 | if pretrained:
394 | # Code to load pretrained weights can be added here.
395 | pass
396 | return model
397 |
398 | # ---------------------------------------------------------
399 | # Main test function
400 | # ---------------------------------------------------------
401 | if __name__ == "__main__":
402 | # Instantiate the Differential ViT-Base (patch16, 224) with default parameters.
403 | model = diff_vit_base_patch16_224()
404 | model.eval()
405 | dummy_input = torch.randn(1, 3, 224, 224)
406 | with torch.no_grad():
407 | output = model(dummy_input)
408 | print("Differential ViT-Base output shape:", output.shape)
409 |
410 | trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
411 | print(f"Number of trainable parameters: {trainable_params}")
412 |
413 | from timm.models.vision_transformer import vit_base_patch16_224
414 | baseline = vit_base_patch16_224()  # standard ViT-Base for a parameter-count comparison
415 | baseline_params = sum(p.numel() for p in baseline.parameters() if p.requires_grad)
416 | print(f"Number of trainable parameters (standard ViT-Base): {baseline_params}")
417 |
418 |
--------------------------------------------------------------------------------
/diff_clip.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | diff_clip.py
4 |
5 | This module implements a differential version of CLIP's text encoder.
6 | The idea is to replace the standard softmax multi-head attention with a
7 | differential attention mechanism. The differential attention is defined as:
8 |
9 | DiffAttn(X) = (softmax(Q₁ K₁ᵀ / √d) − λ · softmax(Q₂ K₂ᵀ / √d)) · V
10 |
11 | where the input X ∈ ℝ^(L×N×d) (with L = sequence length, N = batch size, and d = embed_dim)
12 | is projected to query, key, and value as:
13 |
14 | Q = X W^Q, K = X W^K, V = X W^V
15 |
16 | and Q and K are split along the head dimension:
17 |
18 | [Q₁; Q₂] ∈ ℝ^(L, N, 2·h_eff, d_head) where h_eff = num_heads // 2 and d_head = d / num_heads.
19 |
20 | A learnable scalar λ is computed via:
21 |
22 | λ = exp(λ_{q1} ⋅ λ_{k1}) − exp(λ_{q2} ⋅ λ_{k2}) + λ_init
23 |
24 | The final multi-head differential attention is then computed as:
25 |
26 | MultiHeadDiffAttn(X) = Concat( LN( DiffAttn₁(X) ), …, LN( DiffAttn_h_eff(X) ) ) · W^O
27 |
28 | The overall block in the text transformer is structured as:
29 |
30 | X' = X + DifferentialAttention(LayerNorm(X))
31 | X'' = X' + MLP(LayerNorm(X'))
32 |
33 | This file defines:
34 | 1. DifferentialMultiheadAttention – a drop-in replacement for nn.MultiheadAttention.
35 | 2. DifferentialResidualAttentionBlock – a residual block using differential attention.
36 | 3. DifferentialTextTransformer – a stack of differential residual attention blocks.
37 | 4. DiffCLIP – a version of CLIP that uses DifferentialTextTransformer for text encoding.
38 |
39 | References:
40 | - CLIP from OpenAI (modified from github.com/openai/CLIP)
41 | - Differential Transformer: https://arxiv.org/abs/2410.05258
42 | """
43 |
44 | import math
45 | from collections import OrderedDict
46 |
47 | import numpy as np
48 | import torch
49 | import torch.nn as nn
50 | import torch.nn.functional as F
51 | from diff_attention import diff_vit_base_patch16_224
52 |
53 | class RMSNorm(nn.Module):
54 | r"""
55 | RMSNorm normalizes the input tensor by its root-mean-square (RMS) value.
56 |
57 | Given an input x ∈ ℝ^(...×d), it computes:
58 |
59 | RMS(x) = sqrt(mean(x², dim=-1, keepdim=True) + ε)
60 | output = x / RMS(x)
61 |
62 | Optionally, a learnable weight is applied if elementwise_affine is True.
63 |
64 | Args:
65 | dim (int): Dimension to normalize.
66 | eps (float): A value added for numerical stability.
67 | elementwise_affine (bool): If True, multiply by a learnable weight.
68 | """
69 | def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = True):
70 | super().__init__()
71 | self.dim = dim
72 | self.eps = eps
73 | if elementwise_affine:
74 | self.weight = nn.Parameter(torch.ones(dim))
75 | else:
76 | self.register_parameter('weight', None)
77 |
78 | def _norm(self, x):
79 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
80 |
81 | def forward(self, x: torch.Tensor) -> torch.Tensor:
82 | output = self._norm(x.float()).type_as(x)
83 | if self.weight is not None:
84 | output = output * self.weight
85 | return output
86 |
87 | def extra_repr(self) -> str:
88 | return f'dim={self.dim}, eps={self.eps}, elementwise_affine={self.weight is not None}'
89 |
90 |
91 | # ---------------------------------------------------------
92 | # Utility Layers (LayerNorm and QuickGELU)
93 | # ---------------------------------------------------------
94 | class LayerNorm(nn.LayerNorm):
95 | """Subclass torch's LayerNorm to handle fp16."""
96 | def forward(self, x: torch.Tensor):
97 | orig_type = x.dtype
98 | ret = super().forward(x.type(torch.float32))
99 | return ret.type(orig_type)
100 |
101 | class QuickGELU(nn.Module):
102 | def forward(self, x: torch.Tensor):
103 | return x * torch.sigmoid(1.702 * x)
104 |
105 | # ---------------------------------------------------------
106 | # Differential Multihead Attention for Text
107 | # ---------------------------------------------------------
108 | class DifferentialMultiheadAttention(nn.Module):
109 | r"""
110 | Differential Multihead Attention for text inputs.
111 |
112 | This module implements the differential attention mechanism with the following steps:
113 |
114 | 1. Given input X ∈ ℝ^(L×N×d) (L: sequence length, N: batch size, d: embed_dim),
115 | compute linear projections:
116 | Q = X W^Q, K = X W^K, V = X W^V.
117 |
118 | 2. Permute the input to shape (N, L, d) so that we treat the batch dimension as B.
119 |
120 | 3. Reshape Q and K to shape (B, L, 2·h_eff, d_head) and then transpose to
121 | (B, 2·h_eff, L, d_head), where h_eff = num_heads // 2 and d_head = d / num_heads.
122 |
123 | 4. Reshape V to (B, L, h_eff, 2·d_head) and then transpose to (B, h_eff, L, 2·d_head).
124 |
125 | 5. Compute the scaled dot-product attention scores for both splits:
126 | A₁ = softmax((Q₁ K₁ᵀ) / √d_head)
127 | A₂ = softmax((Q₂ K₂ᵀ) / √d_head)
128 |
129 | 6. Compute a learnable scalar:
130 | λ = exp(λ_{q1} ⋅ λ_{k1}) − exp(λ_{q2} ⋅ λ_{k2}) + λ_init
131 |
132 | 7. The differential attention output is:
133 | DiffAttn(X) = (A₁ − λ · A₂) · V
134 |
135 | 8. After applying headwise RMSNorm and a final linear projection, the output is
136 | permuted back to (L, N, d).
137 |
138 | Args:
139 | embed_dim (int): Total dimension of the model.
140 | num_heads (int): Number of attention heads (must be even).
141 | qkv_bias (bool): If True, add a bias to the Q, K, V projections.
142 | attn_drop (float): Dropout rate after softmax.
143 | proj_drop (float): Dropout rate after the output projection.
144 | lambda_init (float): Initial scalar for λ.
145 | """
146 | def __init__(self, embed_dim, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0., lambda_init=0.8):
147 | super().__init__()
148 | if num_heads % 2 != 0:
149 | raise ValueError("num_heads must be even for Differential Attention.")
150 | self.embed_dim = embed_dim
151 | self.num_heads = num_heads
152 | self.effective_heads = num_heads // 2 # differential attention uses half the heads
153 | self.head_dim = embed_dim // num_heads
154 | self.scaling = self.head_dim ** -0.5
155 |
156 | # Linear layers for Q, K, V projections.
157 | self.q_proj = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
158 | self.k_proj = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
159 | self.v_proj = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
160 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
161 |
162 | self.attn_drop = nn.Dropout(attn_drop)
163 | self.proj_drop = nn.Dropout(proj_drop)
164 |
165 | # RMSNorm for headwise normalization; each head's output has dimension 2 * head_dim.
166 | self.diff_norm = RMSNorm(2 * self.head_dim, eps=1e-5, elementwise_affine=True)
167 |
168 | # Learnable lambda parameters (shared across heads).
169 | self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
170 | self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
171 | self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
172 | self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
173 | self.lambda_init = lambda_init
174 |
175 | def forward(self, query, key, value, attn_mask=None):
176 | """
177 | Args:
178 | query, key, value (Tensor): Input tensors of shape (L, N, embed_dim),
179 | where L is sequence length and N is batch size.
180 | attn_mask (Tensor, optional): Additive attention mask of shape (L, L).
181 |
182 | Returns:
183 | Tensor of shape (L, N, embed_dim) after applying differential multi-head attention.
184 | """
185 | # Permute input from (L, N, embed_dim) to (N, L, embed_dim)
186 | x = query.transpose(0, 1) # now (N, L, embed_dim)
187 | B, L, _ = x.shape
188 |
189 | # Compute projections.
190 | q = self.q_proj(x) # (B, L, embed_dim)
191 | k = self.k_proj(x)
192 | v = self.v_proj(x)
193 |
194 | # Reshape Q and K to (B, L, 2*h_eff, head_dim)
195 | q = q.view(B, L, 2 * self.effective_heads, self.head_dim)
196 | k = k.view(B, L, 2 * self.effective_heads, self.head_dim)
197 | # Reshape V to (B, L, h_eff, 2*head_dim)
198 | v = v.view(B, L, self.effective_heads, 2 * self.head_dim)
199 |
200 | # Transpose Q and K to (B, 2*h_eff, L, head_dim)
201 | q = q.transpose(1, 2)
202 | k = k.transpose(1, 2)
203 | # Transpose V to (B, h_eff, L, 2*head_dim)
204 | v = v.transpose(1, 2)
205 |
206 | # Scale Q.
207 | q = q * self.scaling
208 |
209 | # Compute raw attention scores: (B, 2*h_eff, L, L)
210 | attn_scores = torch.matmul(q, k.transpose(-1, -2))
211 | # If an attention mask is provided, add it.
212 | if attn_mask is not None:
213 | # attn_mask is expected to be of shape (L, L)
214 | attn_scores = attn_scores + attn_mask.unsqueeze(0).unsqueeze(0)
215 |
216 | # Compute attention probabilities.
217 | attn_probs = F.softmax(attn_scores, dim=-1)
218 | attn_probs = self.attn_drop(attn_probs)
219 |
220 | # Reshape to separate the two halves: (B, h_eff, 2, L, L)
221 | attn_probs = attn_probs.view(B, self.effective_heads, 2, L, L)
222 |
223 | # Compute λ via re-parameterization.
224 | lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1))
225 | lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2))
226 | lambda_full = lambda_1 - lambda_2 + self.lambda_init
227 |
228 | # Differential attention: subtract the second attention map scaled by λ.
229 | diff_attn = attn_probs[:, :, 0, :, :] - lambda_full * attn_probs[:, :, 1, :, :] # (B, h_eff, L, L)
230 |
231 | # Compute weighted sum with V.
232 | out = torch.matmul(diff_attn, v) # (B, h_eff, L, 2*head_dim)
233 | # Apply RMSNorm (headwise normalization) and scale by (1 - lambda_init).
234 | out = self.diff_norm(out) * (1 - self.lambda_init)
235 |
236 | # Concatenate heads: transpose to (B, L, h_eff, 2*head_dim) and then reshape to (B, L, embed_dim)
237 | out = out.transpose(1, 2).reshape(B, L, 2 * self.effective_heads * self.head_dim)
238 | # Final linear projection.
239 | out = self.out_proj(out)
240 | out = self.proj_drop(out)
241 | # Permute back to (L, N, embed_dim)
242 | out = out.transpose(0, 1)
243 | return out
244 |
245 | # ---------------------------------------------------------
246 | # Differential Residual Attention Block for Text
247 | # ---------------------------------------------------------
248 | class DifferentialResidualAttentionBlock(nn.Module):
249 | r"""
250 | Residual Attention Block using Differential Multihead Attention.
251 |
252 | This block first applies layer normalization to the input, then
253 | differential multihead attention, and adds the result to the input.
254 | Then it applies another layer normalization, a feed-forward MLP, and adds
255 | the result again.
256 |
257 | Equations:
258 | X' = X + DiffMultiheadAttention(LN(X))
259 | Y = X' + MLP(LN(X'))
260 |
261 | Args:
262 | d_model (int): Embedding dimension.
263 | n_head (int): Number of attention heads (must be even).
264 | attn_mask (Tensor, optional): Additive attention mask of shape (L, L).
265 | """
266 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
267 | super().__init__()
268 | self.attn = DifferentialMultiheadAttention(d_model, n_head)
269 | self.ln_1 = LayerNorm(d_model)
270 | self.mlp = nn.Sequential(OrderedDict([
271 | ("c_fc", nn.Linear(d_model, d_model * 4)),
272 | ("gelu", QuickGELU()),
273 | ("c_proj", nn.Linear(d_model * 4, d_model)),
274 | ]))
275 | self.ln_2 = LayerNorm(d_model)
276 | self.attn_mask = attn_mask
277 |
278 | def attention(self, x: torch.Tensor):
279 | if self.attn_mask is not None:
280 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device)
281 | # DifferentialMultiheadAttention returns output of shape (L, N, d_model)
282 | return self.attn(x, x, x, attn_mask=self.attn_mask)
283 |
284 | def forward(self, x: torch.Tensor):
285 | x = x + self.attention(self.ln_1(x))
286 | x = x + self.mlp(self.ln_2(x))
287 | return x
288 |
289 | # ---------------------------------------------------------
290 | # Differential Text Transformer
291 | # ---------------------------------------------------------
292 | class DifferentialTextTransformer(nn.Module):
293 | r"""
294 | Transformer for text built from Differential Residual Attention Blocks.
295 |
296 | Args:
297 | width (int): Embedding dimension.
298 | layers (int): Number of transformer layers.
299 | heads (int): Number of attention heads (must be even).
300 | attn_mask (Tensor, optional): Additive attention mask of shape (L, L).
301 | """
302 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
303 | super().__init__()
304 | self.width = width
305 | self.layers = layers
306 | self.resblocks = nn.Sequential(*[
307 | DifferentialResidualAttentionBlock(width, heads, attn_mask)
308 | for _ in range(layers)
309 | ])
310 |
311 | def forward(self, x: torch.Tensor):
312 | return self.resblocks(x)
313 |
314 | # ---------------------------------------------------------
315 | # DiffCLIP: Differential CLIP Model
316 | # ---------------------------------------------------------
317 | class DiffCLIP(nn.Module):
318 | r"""
319 | DiffCLIP implements a differential version of CLIP, where the text encoder is modified
320 | to use DifferentialTextTransformer.
321 |
322 | The overall architecture is similar to CLIP:
323 | - A vision encoder (can be any vision model, e.g., a differential ViT).
324 | - A text encoder that tokenizes text, adds positional embeddings, and passes through
325 | a stack of differential transformer blocks.
326 | - Final projections for image and text features.
327 |
328 | Args:
329 | embed_dim (int): Dimension of the joint embedding space.
330 | vision_width (int): Width (output dimension) of the vision encoder.
331 | vision_model (nn.Module): Vision encoder model.
332 | context_length (int): Maximum text sequence length.
333 | vocab_size (int): Vocabulary size for the text encoder.
334 | transformer_width (int): Embedding dimension of the text transformer.
335 | transformer_heads (int): Number of heads in the text transformer (must be even).
336 | transformer_layers (int): Number of layers in the text transformer.
337 | """
338 | def __init__(
339 | self,
340 | embed_dim: int,
341 | vision_width: int,
342 | vision_model: nn.Module,
343 | context_length: int,
344 | vocab_size: int,
345 | transformer_width: int,
346 | transformer_heads: int,
347 | transformer_layers: int,
348 | **kwargs,
349 | ):
350 | super().__init__()
351 | self.context_length = context_length
352 | self.vision_width = vision_width
353 |
354 | self.visual = vision_model
355 |
356 | # Use DifferentialTextTransformer instead of the standard Transformer.
357 | self.transformer = DifferentialTextTransformer(
358 | width=transformer_width,
359 | layers=transformer_layers,
360 | heads=transformer_heads,
361 | attn_mask=self.build_attention_mask(),
362 | )
363 |
364 | self.vocab_size = vocab_size
365 | self.token_embedding = nn.Embedding(vocab_size, transformer_width)
366 | self.positional_embedding = nn.Parameter(
367 | torch.empty(self.context_length, transformer_width)
368 | )
369 | self.ln_final = LayerNorm(transformer_width)
370 |
371 | self.image_projection = nn.Parameter(torch.empty(vision_width, embed_dim))
372 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
373 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
374 |
375 | self.initialize_parameters()
376 |
377 | def initialize_parameters(self):
378 | nn.init.normal_(self.token_embedding.weight, std=0.02)
379 | nn.init.normal_(self.positional_embedding, std=0.01)
380 |
381 | # Initialize transformer parameters.
382 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
383 | attn_std = self.transformer.width ** -0.5
384 | fc_std = (2 * self.transformer.width) ** -0.5
385 | for block in self.transformer.resblocks:
386 | # For our differential attention blocks, initialize the linear layers.
387 | nn.init.normal_(block.attn.q_proj.weight, std=attn_std)
388 | nn.init.normal_(block.attn.k_proj.weight, std=attn_std)
389 | nn.init.normal_(block.attn.v_proj.weight, std=attn_std)
390 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
391 | nn.init.normal_(block.mlp[0].weight, std=fc_std) # c_fc layer
392 | nn.init.normal_(block.mlp[2].weight, std=proj_std) # c_proj layer
393 |
394 | nn.init.normal_(self.image_projection, std=self.vision_width ** -0.5)
395 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
396 |
397 | def build_attention_mask(self):
398 | # Create a causal attention mask for text.
399 | # The mask is of shape (context_length, context_length) with -inf above the diagonal.
400 | mask = torch.empty(self.context_length, self.context_length)
401 | mask.fill_(float("-inf"))
402 | mask.triu_(1)
403 | return mask
404 |
405 | def encode_image(self, image):
406 | x = self.visual(image)
407 | x = x @ self.image_projection
408 | return x
409 |
410 | def encode_text(self, text):
411 | # text: (batch_size, context_length)
412 | x = self.token_embedding(text) # (batch_size, context_length, transformer_width)
413 | x = x + self.positional_embedding
414 | # Permute to (context_length, batch_size, transformer_width)
415 | x = x.permute(1, 0, 2)
416 | x = self.transformer(x)
417 | # Permute back to (batch_size, context_length, transformer_width)
418 | x = x.permute(1, 0, 2)
419 | x = self.ln_final(x)
420 | # Extract the features at the position of the end-of-text token (assumed to be the max token index in each sequence).
421 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
422 | return x
423 |
424 | def forward(self, image, text):
425 | image_embed = self.encode_image(image)
426 | text_embed = self.encode_text(text)
427 | return {
428 | "image_embed": image_embed,
429 | "text_embed": text_embed,
430 | "logit_scale": self.logit_scale.exp(),
431 | }
432 |
433 |
434 | def DiffCLIP_VITB16(**kwargs):
435 | """
436 | Factory function to build DiffCLIP with a ViT-B/16 vision encoder.
437 | This function creates a vision model using the differential vision transformer
438 | "diff_vit_base_patch16_224" and then builds DiffCLIP.
439 |
440 | Args:
441 | **kwargs: Additional keyword arguments.
442 |
443 | Returns:
444 | DiffCLIP model.
445 | """
446 | # Create a vision model using the differential vision transformer
447 | vision_model = diff_vit_base_patch16_224(num_classes=0)
448 | model = DiffCLIP(
449 | embed_dim=512,
450 | vision_width=768,
451 | vision_model=vision_model,
452 | context_length=77,
453 | vocab_size=49408,
454 | transformer_width=512,
455 | transformer_heads=8, # must be even
456 | transformer_layers=12,
457 | **kwargs,
458 | )
459 | return model
460 |
461 |
462 | # ---------------------------------------------------------
463 | # Main test function
464 | # ---------------------------------------------------------
465 | if __name__ == "__main__":
466 | # Create dummy inputs.
467 | dummy_image = torch.randn(2, 3, 224, 224) # batch of 2 images
468 | # Create dummy text tokens (e.g., 77 tokens per sequence). For simplicity, we simulate tokens as random integers.
469 | dummy_text = torch.randint(low=0, high=49408, size=(2, 77))
470 |
471 | # Instantiate DiffCLIP using the DiffCLIP_VITB16 factory function.
472 | model = DiffCLIP_VITB16()
473 | model.eval()
474 |
475 | with torch.no_grad():
476 | outputs = model(dummy_image, dummy_text)
477 |
478 | print("DiffCLIP output keys:", outputs.keys())
479 | print("Image embed shape:", outputs["image_embed"].shape)
480 | print("Text embed shape:", outputs["text_embed"].shape)
481 | print("Logit scale:", outputs["logit_scale"])
482 |
483 | trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
484 | print(f"Number of trainable parameters: {trainable_params}")
485 |
486 | # Note: Original CLIP comparison removed as it's not relevant for the public release
487 |
488 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.9.0
2 | timm>=0.4.12
3 | numpy>=1.19.2
4 | huggingface_hub>=0.11.0
5 | requests>=2.26.0
6 | Pillow>=8.0.0
7 | ftfy>=6.0.0
8 | regex>=2022.3.15
--------------------------------------------------------------------------------
/test_models.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | test_models.py
4 |
5 | This script demonstrates zero-shot prediction with DiffCLIP. It:
6 | 1. Downloads the DiffCLIP_ViTB16_CC12M model from Hugging Face
7 | 2. Loads an image from COCO dataset
8 | 3. Performs zero-shot classification
9 |
10 | Usage:
11 | python test_models.py
12 | """
13 |
14 | import os
15 | import torch
16 | import numpy as np
17 | import requests
18 | from PIL import Image
19 | from io import BytesIO
20 | from huggingface_hub import hf_hub_download
21 | from tokenizer import SimpleTokenizer
22 |
23 | # Import the DiffCLIP models
24 | from diff_clip import DiffCLIP_VITB16
25 |
26 |
27 | def download_model():
28 | """
29 | Download the DiffCLIP_ViTB16_CC12M model from Hugging Face to a local folder.
30 | Returns the path to the checkpoint file.
31 | """
32 | print("Downloading model from Hugging Face...")
33 | model_id = "hammh0a/DiffCLIP_ViTB16_CC12M"
34 | local_dir = "./DiffCLIP_ViTB16_CC12M"
35 |
36 | os.makedirs(local_dir, exist_ok=True)
37 |
38 | # Download the checkpoint file
39 | checkpoint_path = hf_hub_download(
40 | repo_id=model_id,
41 | filename="checkpoint_best.pt",
42 | local_dir=local_dir,
43 | local_dir_use_symlinks=False
44 | )
45 |
46 | print(f"Model downloaded to {checkpoint_path}")
47 | return checkpoint_path
48 |
49 |
50 | def load_model(checkpoint_path):
51 | """
52 | Load the DiffCLIP model and checkpoint.
53 | """
54 | print("Loading model...")
55 | # Create a model instance
56 | model = DiffCLIP_VITB16()
57 |
58 | # Load the checkpoint
59 | checkpoint = torch.load(checkpoint_path, map_location="cpu")
60 |
61 | # If the model was saved with DataParallel, we need to handle that
62 | if list(checkpoint["state_dict"].keys())[0].startswith("module."):
63 | # Create a new state_dict without the 'module.' prefix
64 | new_state_dict = {k[7:]: v for k, v in checkpoint["state_dict"].items()}
65 | load_status = model.load_state_dict(new_state_dict)
66 | else:
67 | load_status = model.load_state_dict(checkpoint["state_dict"])
68 |
69 | print(f"Model loaded with status: {load_status}")
70 | return model
71 |
72 |
73 | def load_image_from_coco():
74 | """
75 | Load a sample image from COCO dataset.
76 | """
77 | print("Loading sample image from COCO...")
78 | # A sample image URL from COCO
79 | coco_image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
80 |
81 | # Download the image
82 | response = requests.get(coco_image_url)
83 | img = Image.open(BytesIO(response.content))
84 |
85 | # Resize and preprocess for the model
86 | img = img.convert("RGB").resize((224, 224))
87 | img_tensor = torch.tensor(np.array(img)).permute(2, 0, 1).float() # (3, 224, 224)
88 | img_tensor = img_tensor / 255.0 # Normalize to [0, 1]
89 |
90 | # Apply ImageNet normalization
91 | mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
92 | std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
93 | img_tensor = (img_tensor - mean) / std
94 |
95 | return img_tensor.unsqueeze(0), img # Return tensor and PIL image
96 |
97 |
98 | def zero_shot_prediction(model, image_tensor, tokenizer, classes):
99 | """
100 | Perform zero-shot prediction on an image with the provided model.
101 | """
102 | print("Performing zero-shot prediction...")
103 | model.eval()
104 |
105 | # Create text prompts
106 | prompts = [f"a photo of a {label}" for label in classes]
107 |
108 | # Tokenize text
109 | text_tokens = tokenizer(prompts)
110 |
111 | # Put everything on the same device
112 | device = "cuda" if torch.cuda.is_available() else "cpu"
113 | model = model.to(device)
114 | image_tensor = image_tensor.to(device)
115 | text_tokens = text_tokens.to(device)
116 |
117 | # Get image and text features
118 | with torch.no_grad():
119 | outputs = model(image_tensor, text_tokens)
120 | image_features = outputs["image_embed"]
121 | text_features = outputs["text_embed"]
122 | logit_scale = outputs["logit_scale"]
123 |
124 | # Normalize features
125 | image_features = image_features / image_features.norm(dim=-1, keepdim=True)
126 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
127 |
128 | # Calculate similarity scores
129 | similarity = (logit_scale * image_features @ text_features.T).softmax(dim=-1)
130 |
131 | # Get top predictions
132 | values, indices = similarity[0].topk(min(5, len(classes)))
133 |
134 | # Return prediction results
135 | predictions = [(classes[idx], values[i].item()) for i, idx in enumerate(indices)]
136 | return predictions
137 |
138 |
139 | def main():
140 | # Define some classes for zero-shot prediction
141 | coco_classes = [
142 | "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train",
143 | "truck", "boat", "traffic light", "fire hydrant", "stop sign", "parking meter",
144 | "bench", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear",
145 | "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase",
146 | "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat",
147 | "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
148 | "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
149 | "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut",
150 | "cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet",
151 | "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave",
152 | "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors",
153 | "teddy bear", "hair drier", "toothbrush"
154 | ]
155 |
156 | # Initialize tokenizer
157 | tokenizer = SimpleTokenizer()
158 |
159 | # Download and load the model
160 | checkpoint_path = download_model()
161 | model = load_model(checkpoint_path)
162 |
163 | # Load a sample image
164 | image_tensor, pil_image = load_image_from_coco()
165 |
166 | # Save the image for reference
167 | # pil_image.save("coco_sample.jpg")
168 | # print("Sample image saved as 'coco_sample.jpg'")
169 |
170 | # Perform zero-shot prediction
171 | predictions = zero_shot_prediction(model, image_tensor, tokenizer, coco_classes)
172 |
173 | # Print results
174 | print("\nZero-shot prediction results:")
175 | for i, (label, score) in enumerate(predictions):
176 | print(f"{i+1}. {label}: {score:.4f}")
177 |
178 |
179 | if __name__ == "__main__":
180 | main()
--------------------------------------------------------------------------------
/tokenizer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # Modified from github.com/openai/CLIP
8 | import gzip
9 | import html
10 | import os
11 | from functools import lru_cache
12 |
13 | import ftfy
14 | import regex as re
15 | import torch
16 |
17 |
18 | @lru_cache()
19 | def default_bpe():
20 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
21 |
22 |
23 | @lru_cache()
24 | def bytes_to_unicode():
25 | """
26 | Returns list of utf-8 byte and a corresponding list of unicode strings.
27 | The reversible bpe codes work on unicode strings.
28 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
29 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
30 | This is a significant percentage of your normal, say, 32K bpe vocab.
31 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
32 | And avoids mapping to whitespace/control characters the bpe code barfs on.
33 | """
34 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
35 | cs = bs[:]
36 | n = 0
37 | for b in range(2**8):
38 | if b not in bs:
39 | bs.append(b)
40 | cs.append(2**8+n)
41 | n += 1
42 | cs = [chr(n) for n in cs]
43 | return dict(zip(bs, cs))
44 |
45 |
46 | def get_pairs(word):
47 | """Return set of symbol pairs in a word.
48 | Word is represented as tuple of symbols (symbols being variable-length strings).
49 | """
50 | pairs = set()
51 | prev_char = word[0]
52 | for char in word[1:]:
53 | pairs.add((prev_char, char))
54 | prev_char = char
55 | return pairs
56 |
57 |
58 | def basic_clean(text):
59 | text = ftfy.fix_text(text)
60 | text = html.unescape(html.unescape(text))
61 | return text.strip()
62 |
63 |
64 | def whitespace_clean(text):
65 | text = re.sub(r'\s+', ' ', text)
66 | text = text.strip()
67 | return text
68 |
69 |
70 | class SimpleTokenizer(object):
71 | def __init__(self, bpe_path: str = default_bpe()):
72 | self.byte_encoder = bytes_to_unicode()
73 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
74 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
75 | merges = merges[1:49152-256-2+1]
76 | merges = [tuple(merge.split()) for merge in merges]
77 | vocab = list(bytes_to_unicode().values())
78 | vocab = vocab + [v+'</w>' for v in vocab]
79 | for merge in merges:
80 | vocab.append(''.join(merge))
81 | vocab.extend(['<|startoftext|>', '<|endoftext|>'])
82 | self.encoder = dict(zip(vocab, range(len(vocab))))
83 | self.decoder = {v: k for k, v in self.encoder.items()}
84 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
85 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
86 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
87 |
88 | def bpe(self, token):
89 | if token in self.cache:
90 | return self.cache[token]
91 | word = tuple(token[:-1]) + ( token[-1] + '</w>',)
92 | pairs = get_pairs(word)
93 |
94 | if not pairs:
95 | return token+'</w>'
96 |
97 | while True:
98 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
99 | if bigram not in self.bpe_ranks:
100 | break
101 | first, second = bigram
102 | new_word = []
103 | i = 0
104 | while i < len(word):
105 | try:
106 | j = word.index(first, i)
107 | new_word.extend(word[i:j])
108 | i = j
109 | except:
110 | new_word.extend(word[i:])
111 | break
112 |
113 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
114 | new_word.append(first+second)
115 | i += 2
116 | else:
117 | new_word.append(word[i])
118 | i += 1
119 | new_word = tuple(new_word)
120 | word = new_word
121 | if len(word) == 1:
122 | break
123 | else:
124 | pairs = get_pairs(word)
125 | word = ' '.join(word)
126 | self.cache[token] = word
127 | return word
128 |
129 | def encode(self, text):
130 | bpe_tokens = []
131 | text = whitespace_clean(basic_clean(text)).lower()
132 | for token in re.findall(self.pat, text):
133 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
134 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
135 | return bpe_tokens
136 |
137 | def decode(self, tokens):
138 | text = ''.join([self.decoder[token] for token in tokens])
139 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
140 | return text
141 |
142 | def __call__(self, texts, context_length=77):
143 | if isinstance(texts, str):
144 | texts = [texts]
145 |
146 | sot_token = self.encoder["<|startoftext|>"]
147 | eot_token = self.encoder["<|endoftext|>"]
148 | all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts]
149 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
150 |
151 | for i, tokens in enumerate(all_tokens):
152 | tokens = tokens[:context_length]
153 | result[i, :len(tokens)] = torch.tensor(tokens)
154 |
155 | if len(result) == 1:
156 | return result[0]
157 | return result
--------------------------------------------------------------------------------