├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── icon.jpg
│   ├── model_pipe.png
│   └── rmis_curve.png
└── models
    ├── base.py
    ├── fisher.py
    ├── images.py
    ├── mae.py
    └── modules.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.py text eol=lf
2 | *.md text eol=lf
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | ckpt/
2 | hf*
3 | demo*
4 | __pycache__
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Anbai Jiang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FISHER
2 |
27 | ## 🔥🔥🔥 Updates
28 |
29 | - [2025.7.25] FISHER is now available on HuggingFace🤗.
30 |
31 | - [2025.7.23] We release the inference code and checkpoints for tiny, mini and small.
32 |
33 | ## Introduction
34 |
39 | FISHER is a **F**oundation model for **I**ndustrial **S**ignal compre**HE**nsive **R**epresentation, which models heterogeneous industrial signals (sound, vibration, voltage, etc.) in a unified manner. FISHER accepts arbitrary sampling rates and models an increase in sampling rate as the concatenation of additional sub-band information: the STFT spectrogram is first split into fixed-width sub-bands before being processed by the ViT encoder. FISHER is trained with teacher-student EMA self-distillation.
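
The band split works roughly as follows (a minimal sketch mirroring `forward` in `models/fisher.py`; `split_into_subbands` is an illustrative helper, not part of the repository, and it assumes the default `band_width` of 100 STFT bins):

```python
import torch
from einops import rearrange


def split_into_subbands(spec: torch.Tensor, band_width: int = 100) -> torch.Tensor:
    """spec: [B, 1, time, freq] log-magnitude STFT -> [B * num_band, 1, time, band_width]."""
    num_band = spec.shape[-1] // band_width  # higher sampling rate -> more freq bins -> more sub-bands
    bands = torch.stack(spec.split(band_width, dim=-1)[:num_band])  # drop residual bins that do not fill a band
    return rearrange(bands, 'nb B c t f -> (B nb) c t f')  # each sub-band is encoded independently by the ViT


# e.g. with a 25 ms window, 16 kHz audio gives 201 freq bins -> 2 sub-bands,
# while 48 kHz audio gives 601 bins -> 6 sub-bands, so the representation grows with the sampling rate.
```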
40 |
41 | To evaluate the model, we develop the RMIS benchmark, which will also be open-sourced in the near future. FISHER achieves SOTA performance on the RMIS benchmark with much more efficient scaling properties.
42 |
43 | ## Checkpoints
44 |
45 | We release the checkpoints of FISHER-tiny, FISHER-mini and FISHER-small.
46 |
47 | | Version | ☁️ Tsinghua Cloud | 🤗 HuggingFace | wisemodel |
48 | | ------------ | :------------: | :--------: | :--------: |
49 | | FISHER-tiny | [Link](https://cloud.tsinghua.edu.cn/f/630a4b1b2962481a9150/?dl=1) | [Link](https://huggingface.co/jiangab/FISHER-tiny-0723) | [Link](https://wisemodel.cn/models/jiangab/FISHER-tiny-0723) |
50 | | FISHER-mini | [Link](https://cloud.tsinghua.edu.cn/f/60b3bfc0977f45f48dff/?dl=1) | [Link](https://huggingface.co/jiangab/FISHER-mini-0723) | [Link](https://wisemodel.cn/models/jiangab/FISHER-mini-0723) |
51 | | FISHER-small | [Link](https://cloud.tsinghua.edu.cn/f/f997a6932b614046915e/?dl=1) | [Link](https://huggingface.co/jiangab/FISHER-small-0723) | [Link](https://wisemodel.cn/models/jiangab/FISHER-small-0723) |
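
The ☁️ Tsinghua Cloud links above point directly at the checkpoint files. A minimal download sketch (the target filename `fisher_tiny.pt` is arbitrary; any HTTP client works just as well):

```python
import torch.hub

# Fetch the FISHER-tiny checkpoint from the Tsinghua Cloud link in the table above.
torch.hub.download_url_to_file(
    'https://cloud.tsinghua.edu.cn/f/630a4b1b2962481a9150/?dl=1',
    'fisher_tiny.pt',
)
```

Pass the downloaded path to `FISHER.from_pretrained()` as shown in the Inference section below.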
52 |
53 | ## Inference
54 |
55 | Please use the following code to extract signal representations with FISHER.
56 |
57 | ```python
58 | import torch
59 | import torchaudio
60 | import torch.nn.functional as F
61 | from models.fisher import FISHER
62 |
63 | wav, sr = torchaudio.load('/path/to/local/signal.wav')
64 | # You can replace this with your custom loading function for other signals.
65 |
66 | wav = wav - wav.mean()
67 | STFT = torchaudio.transforms.Spectrogram(
68 |     n_fft=25 * sr // 1000,
69 |     win_length=None,
70 |     hop_length=10 * sr // 1000,
71 |     power=1,
72 |     center=False
73 | )
74 | spec = torch.log(torch.abs(STFT(wav)) + 1e-10)
75 | spec = spec.transpose(-2, -1)  # [1, time, freq]
76 | spec = (spec + 3.017344307886898) / (2.1531635155379805 * 2)
77 |
78 | model_path = '/path/to/local/fisher/model.pt'  # Please download the checkpoint in advance.
79 | model = FISHER.from_pretrained(model_path)
80 | model = model.cuda()
81 | model.eval()
82 |
83 | # time-wise cutoff
84 | if spec.shape[-2] > 1024:
85 |     spec = spec[:, :1024]
86 | # freq-wise padding
87 | if spec.shape[-1] < model.cfg.band_width:
88 |     spec = F.pad(spec, (0, model.cfg.band_width - spec.shape[-1]))
89 | spec = spec.unsqueeze(1).cuda()
90 |
91 | with torch.no_grad():
92 |     # Use autocast for mixed-precision inference; disable it for full precision.
93 |     with torch.autocast('cuda'):
94 |         repre = model.extract_features(spec)
95 | print(repre.shape)
96 | ```
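
`extract_features` returns the per-sub-band `[CLS]` embeddings concatenated along the feature dimension, so the width of `repre` grows with the number of sub-bands (and hence with the sampling rate). A quick sanity check, reusing the variables from the snippet above:

```python
# repre has shape [batch, num_band * embed_dim], one [CLS] embedding per sub-band.
num_band = spec.shape[-1] // model.cfg.band_width
assert repre.shape == (spec.shape[0], num_band * model.cfg.embed_dim)
```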
97 |
98 | ## Acknowledgements
99 |
100 | FISHER is built upon [EAT](https://github.com/cwx-worst-one/EAT) and [fairseq](https://github.com/facebookresearch/fairseq). We thank the authors for open-sourcing their work.
101 |
102 | ## Citation
103 |
104 | If you find FISHER useful, please cite the following paper.
105 |
106 | ```bibtex
107 | @article{fan2025fisher,
108 | title={FISHER: A Foundation Model for Multi-Modal Industrial Signal Comprehensive Representation},
109 | author={Fan, Pingyi and Jiang, Anbai and Zhang, Shuwei and Lv, Zhiqiang and Han, Bing and Zheng, Xinhu and Liang, Wenrui and Li, Junjie and Zhang, Wei-Qiang and Qian, Yanmin and Chen, Xie and Lu, Cheng and Liu, Jia},
110 | journal={arXiv preprint arXiv:2507.16696},
111 | year={2025}
112 | }
113 | ```
114 |
--------------------------------------------------------------------------------
/assets/icon.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianganbai/FISHER/31ac067ce5414565902d70bcb0769848fdf0305d/assets/icon.jpg
--------------------------------------------------------------------------------
/assets/model_pipe.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianganbai/FISHER/31ac067ce5414565902d70bcb0769848fdf0305d/assets/model_pipe.png
--------------------------------------------------------------------------------
/assets/rmis_curve.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianganbai/FISHER/31ac067ce5414565902d70bcb0769848fdf0305d/assets/rmis_curve.png
--------------------------------------------------------------------------------
/models/base.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import math
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from collections import namedtuple
8 | from dataclasses import dataclass, field
9 | from functools import partial
10 | from omegaconf import MISSING, II
11 | from typing import Optional, Callable
12 | from enum import Enum, auto
13 |
14 | from .modules import D2vDecoderConfig
15 |
16 |
17 | logger = logging.getLogger(__name__)
18 |
19 |
20 | class Modality(Enum):
21 | AUDIO = auto()
22 | IMAGE = auto()
23 | TEXT = auto()
24 |
25 |
26 | @dataclass
27 | class D2vModalityConfig:
28 | type: Modality = MISSING
29 | prenet_depth: int = 0
30 | prenet_layerdrop: float = 0.0
31 | prenet_dropout: float = 0.0
32 | start_drop_path_rate: float = 0.0
33 | end_drop_path_rate: float = 0.0
34 |
35 | num_extra_tokens: int = 1
36 | init_extra_token_zero: bool = False
37 |
38 | mask_noise_std: float = 0.01
39 | mask_prob_min: Optional[float] = None
40 | mask_prob: float = 0.8
41 | inverse_mask: bool = True
42 | mask_prob_adjust: float = 0.07
43 | keep_masked_pct: float = 0.0
44 | flexible_mask: bool = False
45 |
46 | mask_length: int = 5
47 | add_masks: bool = False
48 | remove_masks: bool = False
49 | mask_dropout: float = 0.0
50 | encoder_zero_mask: bool = True
51 |
52 | mask_channel_prob: float = 0.0
53 | mask_channel_length: int = 64
54 |
55 | ema_local_encoder: bool = True # used in data2vec_multi
56 | ema_local_decoder: bool = False
57 | local_grad_mult: float = 1.0
58 | flatten: str = 'freq'
59 | max_length: int = 128
60 | max_freq: int = 50
61 |
62 | use_alibi_encoder: bool = False
63 | alibi_scale: float = 1.0
64 | learned_alibi: bool = False
65 | alibi_max_pos: Optional[int] = None
66 | learned_alibi_scale: bool = False
67 | learned_alibi_scale_per_head: bool = False
68 | learned_alibi_scale_per_layer: bool = False
69 |
70 | num_alibi_heads: int = II("model.num_heads")
71 | model_depth: int = II("model.depth")
72 |
73 | decoder: Optional[D2vDecoderConfig] = field(default_factory=lambda *x: D2vDecoderConfig())
74 |
75 |
76 | MaskSeed = namedtuple("MaskSeed", ["seed", "update", "ids"])
77 | MaskInfo = namedtuple("MaskInfo", ["x_unmasked", "mask", "ids_restore", "ids_keep"])
78 |
79 |
80 | class ModalitySpecificEncoder(nn.Module):
81 | def __init__(
82 | self,
83 | modality_cfg: D2vModalityConfig,
84 | embed_dim: int,
85 | local_encoder: nn.Module,
86 | project_features: nn.Module,
87 | fixed_positional_encoder: Optional[nn.Module],
88 | relative_positional_encoder: Optional[nn.Module], # None
89 | context_encoder: nn.Module,
90 | decoder: Optional[nn.Module],
91 | get_alibi_bias: Optional[Callable[[int, int, str, str], torch.Tensor]],
92 | ):
93 | super().__init__()
94 |
95 | self.modality_cfg = modality_cfg
96 | self.local_encoder = local_encoder
97 | self.project_features = project_features
98 | self.fixed_positional_encoder = fixed_positional_encoder
99 | self.relative_positional_encoder = relative_positional_encoder
100 | self.context_encoder = context_encoder
101 |
102 | self.decoder = decoder
103 | self.get_alibi_bias = get_alibi_bias if modality_cfg.use_alibi_encoder else None
104 |
105 | self.local_grad_mult = self.modality_cfg.local_grad_mult
106 |
107 | self.extra_tokens = None
108 | if modality_cfg.num_extra_tokens > 0:
109 | self.extra_tokens = nn.Parameter(
110 | torch.zeros(1, modality_cfg.num_extra_tokens, embed_dim)
111 | )
112 | if not modality_cfg.init_extra_token_zero:
113 | nn.init.normal_(self.extra_tokens)
114 | elif self.extra_tokens.size(1) > 1:
115 | nn.init.normal_(self.extra_tokens[:, 1:])
116 |
117 | self.alibi_scale = None
118 | if self.get_alibi_bias is not None:
119 | self.alibi_scale = nn.Parameter(
120 | torch.full(
121 | (
122 | (modality_cfg.prenet_depth + modality_cfg.model_depth)
123 | if modality_cfg.learned_alibi_scale_per_layer
124 | else 1,
125 | 1,
126 | self.modality_cfg.num_alibi_heads
127 | if modality_cfg.learned_alibi_scale_per_head
128 | else 1,
129 | 1,
130 | 1,
131 | ),
132 | modality_cfg.alibi_scale,
133 | dtype=torch.float,
134 | ),
135 | requires_grad=modality_cfg.learned_alibi_scale,
136 | )
137 |
138 | if modality_cfg.learned_alibi and self.get_alibi_bias is not None:
139 | assert modality_cfg.alibi_max_pos is not None
140 | alibi_bias = self.get_alibi_bias(
141 | batch_size=1,
142 | time_steps=modality_cfg.alibi_max_pos,
143 | heads=modality_cfg.num_alibi_heads,
144 | scale=1.0,
145 | dtype=torch.float,
146 | device="cpu",
147 | )
148 | self.alibi_bias = nn.Parameter(alibi_bias)
149 | self.get_alibi_bias = partial(
150 | _learned_alibi_bias, alibi_bias=self.alibi_bias
151 | )
152 |
153 | def upgrade_state_dict_named(self, state_dict, name):
154 | k = f"{name}.alibi_scale"
155 | if k in state_dict and state_dict[k].dim() == 4:
156 | state_dict[k] = state_dict[k].unsqueeze(0)
157 |
158 | return state_dict
159 |
160 | def convert_padding_mask(self, x, padding_mask):
161 | return padding_mask
162 |
163 | def local_features(self, features):
164 | x = self.local_encoder(features)
165 | x = self.project_features(x) # nn.Identity()
166 | return x
167 |
168 | def contextualized_features(
169 | self,
170 | x,
171 | padding_mask,
172 | mask, # True
173 | remove_masked, # train: True; infer: False
174 | clone_batch: int = 1,
175 | mask_seeds: Optional[torch.Tensor] = None,
176 | precomputed_mask=None,
177 | ):
178 |
179 | if padding_mask is not None:
180 | padding_mask = self.convert_padding_mask(x, padding_mask) # [b,t,f] => [b,seq]
181 |
182 | local_features = x
183 | if mask and clone_batch == 1:
184 | local_features = local_features.clone()
185 |
186 | orig_B, orig_T, _ = x.shape
187 | pre_mask_B = orig_B
188 | mask_info = None
189 |
190 | x_pos = None
191 | # x: [B, seq_len, embed_dim]
192 | if self.fixed_positional_encoder is not None: # models.modules.FixPositionalEncoder
193 | x = x + self.fixed_positional_encoder(x, padding_mask)[:, :x.size(1), :]
194 |
195 | if self.relative_positional_encoder is not None:
196 | x_pos = self.relative_positional_encoder(x)
197 |
198 | masked_padding_mask = padding_mask
199 | if mask and remove_masked:  # training path: keep only the unmasked patches for the student (mask computation is not included in this inference-only release)
200 | x = mask_info.x_unmasked
201 | if x_pos is not None:
202 | x = x + gather_unmasked(x_pos, mask_info)
203 |
204 | # padding_mask: [bs, seq_len]
205 | # valid: False; padded: True
206 | if padding_mask is not None and padding_mask.any():
207 | # retrieve padding_mask for unmasked patch
208 | masked_padding_mask = gather_unmasked_mask(padding_mask, mask_info)
209 | if not masked_padding_mask.any():
210 | masked_padding_mask = None
211 | else:
212 | masked_padding_mask = None
213 |
214 | elif x_pos is not None:
215 | x = x + x_pos
216 |
217 | alibi_bias = None
218 | alibi_scale = self.alibi_scale
219 |
220 | if self.get_alibi_bias is not None:
221 | alibi_bias = self.get_alibi_bias(
222 | batch_size=pre_mask_B,
223 | time_steps=orig_T,
224 | heads=self.modality_cfg.num_alibi_heads,
225 | dtype=torch.float32,
226 | device=x.device,
227 | )
228 |
229 | if alibi_scale is not None:
230 | alibi_scale = alibi_scale.clamp_min(0)
231 | if alibi_scale.size(0) == 1:
232 | alibi_bias = alibi_bias * alibi_scale.squeeze(0).type_as(alibi_bias)
233 | alibi_scale = None
234 |
235 | if clone_batch > 1:
236 | alibi_bias = alibi_bias.repeat_interleave(clone_batch, 0)
237 |
238 | if mask_info is not None and remove_masked:
239 | alibi_bias = masked_alibi(alibi_bias, mask_info)
240 |
241 | if self.extra_tokens is not None:
242 | num = self.extra_tokens.size(1)
243 | x = torch.cat([self.extra_tokens.expand(x.size(0), -1, -1), x], dim=1)
244 | if masked_padding_mask is not None:
245 | # B x T
246 | masked_padding_mask = F.pad(masked_padding_mask, (num, 0))
247 | if alibi_bias is not None:
248 | # B x H x T x T
249 | alibi_bias = F.pad(alibi_bias, (num, 0, num, 0))
250 |
251 | x = self.context_encoder(
252 | x,
253 | masked_padding_mask,
254 | alibi_bias,
255 | alibi_scale[: self.modality_cfg.prenet_depth]
256 | if alibi_scale is not None
257 | else None,
258 | )
259 |
260 | return {
261 | "x": x,
262 | "local_features": local_features,
263 | "padding_mask": masked_padding_mask,
264 | "alibi_bias": alibi_bias,
265 | "alibi_scale": alibi_scale[self.modality_cfg.prenet_depth :]
266 | if alibi_scale is not None and alibi_scale.size(0) > 1
267 | else alibi_scale,
268 | "encoder_mask": mask_info,
269 | }
270 |
271 | def forward(
272 | self,
273 | features,
274 | padding_mask,
275 | mask: bool,
276 | remove_masked: bool,
277 | clone_batch: int = 1,
278 | mask_seeds: Optional[torch.Tensor] = None,
279 | precomputed_mask=None,
280 | ):
281 | x = self.local_features(features) # patch embed
282 | # x: [bs, time*freq, embed_dim], e.g. [12, 512, 768]
283 | out = self.contextualized_features(
284 | x,
285 | padding_mask,
286 | mask,
287 | remove_masked,
288 | clone_batch,
289 | mask_seeds,
290 | precomputed_mask,
291 | ) # add mask, discarded masked, context encoder (only layer norm)
292 | return out
293 |
294 | def reset_parameters(self):
295 | pass
296 |
297 | def remove_pretraining_modules(self, keep_decoder=False):
298 | if not keep_decoder:
299 | self.decoder = None
300 |
301 |
302 | def get_annealed_rate(start, end, curr_step, total_steps):
303 | if curr_step >= total_steps:
304 | return end
305 | r = end - start
306 | pct_remaining = 1 - curr_step / total_steps
307 | return end - r * pct_remaining
308 |
309 |
310 | def gather_unmasked(x: torch.Tensor, mask_info: MaskInfo) -> torch.Tensor:
311 | return torch.gather(
312 | x,
313 | dim=1,
314 | index=mask_info.ids_keep,
315 | )
316 |
317 |
318 | def gather_unmasked_mask(x: torch.Tensor, mask_info: MaskInfo) -> torch.Tensor:
319 | return torch.gather(
320 | x,
321 | dim=1,
322 | index=mask_info.ids_keep[..., 0], # ignore the feature dimension
323 | )
324 |
325 |
326 | def get_alibi(
327 | max_positions: int,
328 | attention_heads: int,
329 | dims: int = 1,
330 | distance: str = "manhattan",
331 | ):
332 | def get_slopes(n):
333 | def get_slopes_power_of_2(n):
334 | start = 2 ** (-(2 ** -(math.log2(n) - 3)))
335 | ratio = start
336 | return [start * ratio**i for i in range(n)]
337 |
338 | # In the paper, we only train models that have 2^a heads for some
339 | # a. This function has some good properties that only occur when
340 | # the input is a power of 2. To maintain that even when the number
341 | # of heads is not a power of 2, we use this workaround.
342 | if math.log2(n).is_integer():
343 | return get_slopes_power_of_2(n)
344 | else:
345 | closest_power_of_2 = 2 ** math.floor(math.log2(n))
346 | return (
347 | get_slopes_power_of_2(closest_power_of_2)
348 | + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
349 | )
350 |
351 | maxpos = max_positions
352 | attn_heads = attention_heads
353 | slopes = torch.Tensor(get_slopes(attn_heads))
354 |
355 | if dims == 1:
356 | # prepare alibi position linear bias. Note that wav2vec2 is a non-
357 | # autoregressive model, so we want a symmetric mask with 0 on the
358 | # diagonal and otherwise linearly decreasing values
359 | pos_bias = (
360 | torch.abs(
361 | torch.arange(maxpos).unsqueeze(0) - torch.arange(maxpos).unsqueeze(1)
362 | )
363 | * -1
364 | )
365 | elif dims == 2:
366 | if distance == "manhattan":
367 | df = lambda x1, y1, x2, y2: abs(x1 - x2) + abs(y1 - y2)
368 | elif distance == "euclidean":
369 | df = lambda x1, y1, x2, y2: math.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)
370 |
371 | n = math.sqrt(max_positions)
372 | assert n.is_integer(), n
373 | n = int(n)
374 |
375 | pos_bias = torch.zeros((max_positions, max_positions))
376 |
377 | for i in range(n):
378 | for j in range(n):
379 | for k in range(n):
380 | for l in range(n):
381 | new_x = i * n + j
382 | new_y = k * n + l
383 | pos_bias[new_x, new_y] = -df(i, j, k, l)
384 |
385 | else:
386 | raise Exception(f"unsupported number of alibi dims: {dims}")
387 |
388 | alibi_bias = slopes.unsqueeze(1).unsqueeze(1) * pos_bias.unsqueeze(0).expand(
389 | attn_heads, -1, -1
390 | )
391 |
392 | return alibi_bias
393 |
394 |
395 | def get_alibi_bias(
396 | alibi_biases,
397 | batch_size,
398 | time_steps,
399 | heads,
400 | dtype,
401 | device,
402 | dims=1,
403 | distance="manhattan",
404 | ):
405 | cache_key = f"{dims}_{heads}_{distance}"
406 |
407 | buffered = alibi_biases.get(cache_key, None)
408 |
409 | target_size = heads * batch_size
410 | if (
411 | buffered is None
412 | or buffered.size(0) < target_size
413 | or buffered.size(1) < time_steps
414 | or buffered.dtype != dtype
415 | or buffered.device != device
416 | ):
417 | bt = max(time_steps, buffered.size(1) if buffered is not None else 0)
418 | bn = max(target_size, buffered.size(0) if buffered is not None else 0) // heads
419 |
420 | buffered = (
421 | get_alibi(bt, heads, dims=dims, distance=distance)
422 | .to(dtype=dtype, device=device)
423 | .repeat(bn, 1, 1)
424 | )
425 |
426 | alibi_biases[cache_key] = buffered
427 |
428 | b = buffered[:target_size, :time_steps, :time_steps]
429 | b = b.view(batch_size, heads, time_steps, time_steps)
430 | return b
431 |
432 |
433 | def _learned_alibi_bias(
434 | alibi_bias,
435 | batch_size,
436 | time_steps,
437 | heads,
438 | scale,
439 | dtype,
440 | device,
441 | ):
442 | assert alibi_bias.size(1) == heads, alibi_bias.shape
443 | assert alibi_bias.dtype == dtype, alibi_bias.dtype
444 | assert alibi_bias.device == device, alibi_bias.device
445 |
446 | if alibi_bias.size(-1) < time_steps:
447 | psz = math.ceil((time_steps - alibi_bias.size(-1)) / 2)
448 | alibi_bias = F.pad(alibi_bias, (psz, psz, psz, psz), mode="replicate")
449 |
450 | alibi_bias = alibi_bias.expand(batch_size, -1, -1, -1) * scale
451 | return alibi_bias[..., :time_steps, :time_steps]
452 |
453 |
454 | def masked_alibi(alibi_bias, mask_info):
455 | H = alibi_bias.size(1)
456 |
457 | orig_bias = alibi_bias
458 |
459 | index = mask_info.ids_keep.unsqueeze(1)[..., 0].unsqueeze(-1)
460 | alibi_bias = torch.gather(
461 | orig_bias,
462 | dim=-2,
463 | index=index.expand(-1, H, -1, mask_info.ids_restore.size(1)),
464 | )
465 | alibi_bias = torch.gather(
466 | alibi_bias,
467 | dim=-1,
468 | index=index.transpose(-1, -2).expand(-1, H, alibi_bias.size(-2), -1),
469 | )
470 |
471 | return alibi_bias
472 |
--------------------------------------------------------------------------------
/models/fisher.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import logging
3 | import numpy as np
4 | import torch.nn as nn
5 |
6 | from functools import partial
7 | from einops import rearrange
8 | from typing import Callable, Optional
9 | from dataclasses import dataclass, field, is_dataclass
10 |
11 |
12 | from .base import (
13 | D2vModalityConfig,
14 | ModalitySpecificEncoder,
15 | )
16 |
17 | from .modules import AltBlock
18 |
19 | from .images import (
20 | D2vImageConfig,
21 | ImageEncoder,
22 | )
23 |
24 | logger = logging.getLogger(__name__)
25 |
26 |
27 | @dataclass
28 | class D2vModalitiesConfig:
29 | image: D2vImageConfig = field(default_factory=lambda *args: D2vImageConfig())
30 |
31 |
32 | @dataclass
33 | class Data2VecMultiConfig:
34 | depth: int = 12
35 |
36 | # band split
37 | band_width: int = 100
38 |
39 | # standard vision Transformer
40 | start_drop_path_rate: float = 0.0
41 | end_drop_path_rate: float = 0.0
42 | num_heads: int = 12
43 | norm_eps: float = 1e-6
44 | norm_affine: bool = True
45 | encoder_dropout: float = 0.0
46 | post_mlp_drop: float = 0.0
47 | attention_dropout: float = 0.0
48 | activation_dropout: float = 0.0
49 | dropout_input: float = 0.0
50 | layerdrop: float = 0.0
51 | embed_dim: int = 768
52 | mlp_ratio: float = 4.0
53 | layer_norm_first: bool = False
54 |
55 | end_of_block_targets: bool = False
56 |
57 | # clone batch for multi-mask strategy
58 | clone_batch: int = 8
59 | max_band_per_sample: int = 64
60 |
61 | # normalization for teacher Transformer layer output
62 | layer_norm_target_layer: bool = False
63 | batch_norm_target_layer: bool = False
64 | instance_norm_target_layer: bool = True
65 | instance_norm_targets: bool = False
66 | layer_norm_targets: bool = True
67 |
68 | modalities: D2vModalitiesConfig = field(default_factory=lambda *args: D2vModalitiesConfig())
69 |
70 |
71 | class FISHER(nn.Module):
72 | def __init__(self, cfg: Data2VecMultiConfig):
73 | super().__init__()
74 | self.cfg = cfg
75 |
76 | make_layer_norm = partial(
77 | nn.LayerNorm, eps=cfg.norm_eps, elementwise_affine=cfg.norm_affine
78 | )
79 |
80 | def make_block(drop_path, dim=None, heads=None):
81 | return AltBlock(
82 | cfg.embed_dim if dim is None else dim,
83 | cfg.num_heads if heads is None else heads,
84 | cfg.mlp_ratio,
85 | qkv_bias=True,
86 | drop=cfg.encoder_dropout,
87 | attn_drop=cfg.attention_dropout,
88 | mlp_drop=cfg.activation_dropout,
89 | post_mlp_drop=cfg.post_mlp_drop,
90 | drop_path=drop_path,
91 | norm_layer=make_layer_norm,
92 | layer_norm_first=cfg.layer_norm_first,
93 | ffn_targets=not cfg.end_of_block_targets,
94 | )
95 |
96 | self.alibi_biases = {}
97 | self.modality_encoders = nn.ModuleDict()
98 |
99 | mod_cfg = getattr(cfg.modalities, 'image')
100 | enc = self.make_modality_encoder(
101 | mod_cfg,
102 | cfg.embed_dim,
103 | make_block,
104 | make_layer_norm,
105 | cfg.layer_norm_first,
106 | self.alibi_biases,
107 | )
108 | self.modality_encoders['IMAGE'] = enc
109 |
110 | self.dropout_input = nn.Dropout(cfg.dropout_input)
111 |
112 | dpr = np.linspace(cfg.start_drop_path_rate, cfg.end_drop_path_rate, cfg.depth)
113 |
114 | self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(cfg.depth)])
115 |
116 | self.norm = None
117 | if cfg.layer_norm_first:
118 | self.norm = make_layer_norm(cfg.embed_dim)
119 |
120 | # band split
121 | self.band_width = cfg.band_width
122 | self.patch_size = cfg.modalities.image.patch_size
123 | self.num_time_patch = cfg.modalities.image.target_length // self.patch_size
124 | self.num_band_patch = self.band_width // self.patch_size
125 |
126 | def make_modality_encoder(
127 | self,
128 | cfg: D2vModalityConfig,
129 | embed_dim: int,
130 | make_block: Callable[[float], nn.ModuleList],
131 | norm_layer: Callable[[int], nn.LayerNorm],
132 | layer_norm_first: bool,
133 | alibi_biases,
134 | task=None,
135 | ) -> ModalitySpecificEncoder:
136 | return ImageEncoder(
137 | cfg,
138 | embed_dim,
139 | make_block,
140 | norm_layer,
141 | layer_norm_first,
142 | alibi_biases,
143 | task,
144 | )
145 |
146 | @classmethod
147 | def from_pretrained(
148 | cls,
149 | model_path: str
150 | ):
151 | """
152 | Load a pretrained FISHER model from a checkpoint file.
153 | """
154 | def update_dataclass(instance, data_dict):
155 | if not data_dict:
156 | return instance
157 |
158 | for field_name, field_value in data_dict.items():
159 | if hasattr(instance, field_name):
160 | current_value = getattr(instance, field_name)
161 | if is_dataclass(current_value) and isinstance(field_value, dict):
162 | update_dataclass(current_value, field_value)
163 | else:
164 | setattr(instance, field_name, field_value)
165 | return instance
166 |
167 | state_dict = torch.load(model_path, map_location='cpu', weights_only=False)
168 | cfg = Data2VecMultiConfig()
169 | update_dataclass(cfg, state_dict['cfg']['model'])
170 | model = cls(cfg)
171 | load_info = model.load_state_dict(state_dict['model'], strict=True)
172 | print(load_info)
173 | return model
174 |
175 | def state_dict(self, **kwargs):
176 | state = {
177 | 'cfg': self.cfg,
178 | 'model': super().state_dict(**kwargs)
179 | }
180 | return state
181 |
182 | def forward(
183 | self,
184 | source: torch.Tensor,
185 | target=None,
186 | id=None,
187 | mode='IMAGE',
188 | padding_mask: Optional[torch.Tensor] = None,
189 | mask: bool = True,
190 | features_only: bool = False,
191 | force_remove_masked=False,
192 | remove_extra_tokens: bool = True,
193 | precomputed_mask: Optional[torch.Tensor] = None,
194 | ):
195 | # band split
196 | num_band = source.shape[-1] // self.band_width
197 | source = torch.stack(source.split(self.band_width, dim=-1)[:num_band]) # drop residual
198 | source = rearrange(source, 'nb B c t f -> (B nb) c t f')
199 | clone_batch = self.cfg.max_band_per_sample // num_band
200 |
201 | feature_extractor = self.modality_encoders[mode] # models.images.ImageEncoder
202 |
203 | # extract (unmasked) features using CNN encoder
204 | extractor_out = feature_extractor(
205 | source,
206 | padding_mask,
207 | mask,
208 | remove_masked=not features_only or force_remove_masked, # train: True; infer: False
209 | clone_batch=clone_batch if not features_only else 1,
210 | mask_seeds=None,
211 | precomputed_mask=precomputed_mask,
212 | )
213 |
214 | # x shape: (batch_size * clone_batch, patch_frame(64) * patch_frequency(8) * unmask_ratio(0.2) + 1 (cls_token), 768 (feature dimension))
215 | x = extractor_out["x"]
216 | # encoder_mask is applied on sub-band level
217 | encoder_mask = extractor_out["encoder_mask"] # models.base.MaskInfo, ["x_unmasked", "mask", "ids_restore", "ids_keep"]
218 | masked_padding_mask = extractor_out["padding_mask"]
219 | masked_alibi_bias = extractor_out.get("alibi_bias", None)
220 | alibi_scale = extractor_out.get("alibi_scale", None)
221 |
222 | if self.dropout_input is not None:
223 | x = self.dropout_input(x)
224 |
225 | # standard Transformer (for student encoder)
226 | layer_results = []
227 | for i, blk in enumerate(self.blocks):
228 | if (
229 | not self.training
230 | or self.cfg.layerdrop == 0
231 | or (np.random.random() > self.cfg.layerdrop)
232 | ):
233 | ab = masked_alibi_bias
234 | if ab is not None and alibi_scale is not None:
235 | scale = (
236 | alibi_scale[i]
237 | if alibi_scale.size(0) > 1
238 | else alibi_scale.squeeze(0)
239 | )
240 | ab = ab * scale.type_as(ab)
241 |
242 | x, lr = blk(
243 | x,
244 | padding_mask=masked_padding_mask,
245 | alibi_bias=ab,
246 | )
247 | if features_only:
248 | layer_results.append(lr)
249 |
250 | if self.norm is not None:
251 | x = self.norm(x)
252 |
253 | # extract features for fine-tuning
254 | if features_only:
255 | if remove_extra_tokens:
256 | x = x[:, feature_extractor.modality_cfg.num_extra_tokens :]
257 | if masked_padding_mask is not None:
258 | masked_padding_mask = masked_padding_mask[
259 | :, feature_extractor.modality_cfg.num_extra_tokens :
260 | ]
261 |
262 | return {
263 | "x": x,
264 | "padding_mask": masked_padding_mask,
265 | "layer_results": layer_results,
266 | "mask": encoder_mask,
267 | }
268 |
269 | def extract_features(
270 | self, source, mode='IMAGE', padding_mask=None, mask=False, remove_extra_tokens=False
271 | ):
272 | num_band = source.shape[-1] // self.band_width
273 | res = self.forward(
274 | source,
275 | mode=mode,
276 | padding_mask=padding_mask,
277 | mask=mask,
278 | features_only=True,
279 | remove_extra_tokens=remove_extra_tokens,
280 | )
281 | x = res['x'][:, 0]
282 | x = rearrange(x, '(B nb) D -> B (nb D)', nb=num_band)
283 | return x
284 |
--------------------------------------------------------------------------------
/models/images.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 | from functools import partial
6 | from dataclasses import dataclass
7 | from typing import Callable, Dict, Optional
8 | from enum import Enum, auto
9 | from einops import rearrange
10 | from omegaconf import II
11 |
12 | from .mae import get_2d_sincos_pos_embed_flexible, PatchEmbed_new
13 |
14 |
15 | from .base import (
16 | D2vModalityConfig,
17 | ModalitySpecificEncoder,
18 | get_alibi_bias,
19 | )
20 | from .modules import (
21 | BlockEncoder,
22 | FixedPositionalEncoder,
23 | )
24 |
25 |
26 | class Modality(Enum):
27 | AUDIO = auto()
28 | IMAGE = auto()
29 | TEXT = auto()
30 |
31 |
32 | @dataclass
33 | class D2vImageConfig(D2vModalityConfig):
34 | type: Modality = Modality.IMAGE
35 |
36 | input_size: int = 224
37 | in_chans: int = 1
38 | patch_size: int = 16
39 | embed_dim: int = II('model.embed_dim')
40 |
41 | alibi_dims: int = 2
42 | alibi_distance: str = "manhattan"
43 |
44 | fixed_positions: bool = True
45 |
46 | transformer_decoder: bool = False
47 | enc_dec_transformer: bool = False
48 | target_length: int = 1024
49 | max_length: int = 128
50 | max_freq: int = 50
51 |
52 | band_width: int = II('model.band_width')
53 | flatten: str = 'freq' # 'time', 'freq'
54 |
55 |
56 | class ImageEncoder(ModalitySpecificEncoder):
57 | # forward() implemented in models.base.ModalitySpecificEncoder
58 |
59 | modality_cfg: D2vImageConfig
60 |
61 | def __init__(
62 | self,
63 | modality_cfg: D2vImageConfig,
64 | embed_dim: int,
65 | make_block: Callable[[float, Optional[int], Optional[int]], nn.ModuleList],
66 | norm_layer: Callable[[int], nn.LayerNorm],
67 | layer_norm_first: bool,
68 | alibi_biases: Dict,
69 | task=None,
70 | ):
71 | self.patch_size = modality_cfg.patch_size
72 | self.band_width = modality_cfg.band_width
73 | self.W = self.band_width // self.patch_size
74 | self.H = modality_cfg.target_length // self.patch_size # 64
75 |
76 | # convert spec to patch embed, using conv1d
77 | local_encoder = PatchEmbed_new(
78 | patch_size=modality_cfg.patch_size, # 16
79 | in_chans=modality_cfg.in_chans, # 1
80 | embed_dim=modality_cfg.embed_dim, # 768
81 | stride=modality_cfg.patch_size, # 16
82 | flatten=modality_cfg.flatten
83 | )
84 |
85 | # CNN initialize
86 | w = local_encoder.proj.weight.data
87 | torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
88 |
89 | if modality_cfg.embed_dim != embed_dim:
90 | local_encoder = nn.Sequential(
91 | local_encoder,
92 | nn.Linear(modality_cfg.embed_dim, embed_dim),
93 | )
94 |
95 | project_features = nn.Identity()
96 |
97 | # note: max_length controls the maximum time length of the audio -> 64 time patches for 10 s; here we allow up to 2 min, and you can change it yourself
98 | max_length = modality_cfg.max_length
99 | max_freq = modality_cfg.max_freq
100 | # pos_embed covers up to max_length * max_freq patch positions
101 | pos_embed = nn.Parameter(
102 | torch.zeros(1, max_length*max_freq, embed_dim), requires_grad=False
103 | )
104 |
105 | # side_n = int(num_patches ** 0.5)
106 | # note: we fix the variable length sequence problem here -> support up to 2min audio
107 | emb = get_2d_sincos_pos_embed_flexible(
108 | pos_embed.shape[-1],
109 | (max_length, max_freq),
110 | cls_token=False,
111 | )
112 |
113 | pos_embed.data.copy_(torch.from_numpy(emb[:max_length * max_freq, :]).float().unsqueeze(0))
114 | fixed_positional_encoder = (
115 | FixedPositionalEncoder(pos_embed) if modality_cfg.fixed_positions else None # True
116 | )
117 |
118 | dpr = np.linspace( # drop_path_rate
119 | modality_cfg.start_drop_path_rate,
120 | modality_cfg.end_drop_path_rate,
121 | modality_cfg.prenet_depth, # actual: 0
122 | )
123 |
124 | # actual: only layer norm
125 | context_encoder = BlockEncoder(
126 | nn.ModuleList(make_block(dpr[i]) for i in range(modality_cfg.prenet_depth)),
127 | norm_layer(embed_dim) if not layer_norm_first else None,
128 | layer_norm_first,
129 | modality_cfg.prenet_layerdrop,
130 | modality_cfg.prenet_dropout,
131 | )
132 |
133 | alibi_bias_fn = partial(
134 | get_alibi_bias,
135 | alibi_biases=alibi_biases,
136 | heads=modality_cfg.num_alibi_heads,
137 | dims=modality_cfg.alibi_dims,
138 | distance=modality_cfg.alibi_distance,
139 | )
140 |
141 | super().__init__(
142 | modality_cfg=modality_cfg,
143 | embed_dim=embed_dim,
144 | local_encoder=local_encoder, # patch embed
145 | project_features=project_features, # nn.Identity()
146 | fixed_positional_encoder=fixed_positional_encoder,
147 | relative_positional_encoder=None,
148 | context_encoder=context_encoder, # apply mask
149 | decoder=None,
150 | get_alibi_bias=alibi_bias_fn,
151 | )
152 |
153 | def reset_parameters(self):
154 | super().reset_parameters()
155 |
156 | @torch.no_grad()
157 | def patchify(self, imgs):
158 | """
159 | imgs: (N, 3, H, W) audio: (N,1,H,W) 1024/16 = 64 128/16 = 8
160 | x: (N, L, patch_size**2 *3)
161 | """
162 | if self.modality_cfg.in_chans == 1: # actual: this one
163 | p = self.modality_cfg.patch_size
164 | h = imgs.shape[2] // p
165 | w = imgs.shape[3] // p
166 | # h,w = self.patch_embed.patch_hw
167 | x = imgs.reshape(shape=(imgs.shape[0], 1, h, p, w, p))
168 | x = torch.einsum('nchpwq->nhwpqc', x)
169 | x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 1))
170 |
171 | else:
172 | p = self.modality_cfg.patch_size
173 | h = w = imgs.shape[2] // p
174 | x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
175 | x = torch.einsum("nchpwq->nhwpqc", x)
176 | x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3))
177 |
178 | return x
179 |
180 | @torch.no_grad()
181 | def unpatchify(self, x):
182 | """
183 | x: (N, L, patch_size**2 *C)
184 | imgs: (N, C, H, W)
185 | """
186 | p = self.modality_cfg.patch_size
187 | h = w = int(x.shape[1] ** 0.5) # num patch along two axis
188 | assert h * w == x.shape[1]
189 |
190 | x = x.reshape(shape=(x.shape[0], h, w, p, p, -1))
191 | x = torch.einsum("nhwpqc->nchpwq", x)
192 | imgs = x.reshape(shape=(x.shape[0], -1, h * p, h * p))
193 | return imgs
194 |
195 | def convert_padding_mask(
196 | self,
197 | x: torch.Tensor,
198 | padding_mask: torch.Tensor
199 | ) -> torch.Tensor:
200 | '''patchify and serialize padding_mask: [b,t,f] => [b,t_patch,f_patch] => [b,patch_seq]
201 |
202 | Args:
203 | x (torch.Tensor): input_features
204 | padding_mask (torch.Tensor): [b,t_patch,f_patch], 1 for padded patch
205 |
206 | Returns:
207 | torch.Tensor: serialized padding mask. [b,patch_seq]
208 | '''
209 | B, T, F = x.shape
210 | t_extra, f_extra = T % self.patch_size, F % self.patch_size
211 | padding_mask = padding_mask[:, :T - t_extra, :F - f_extra]  # drop the remainder that does not fill a whole patch (safe when the remainder is 0)
212 | padding_mask = rearrange(
213 | padding_mask,
214 | 'b (tp p) (fp q) -> b tp fp (p q)',
215 | p=self.patch_size, q=self.patch_size
216 | )
217 | padding_mask = padding_mask.all(-1)
218 |
219 | if self.modality_cfg.flatten == 'time':
220 | padding_mask = padding_mask.transpose(-2, -1).flatten(1)
221 | else:
222 | padding_mask = padding_mask.flatten(1)
223 | return padding_mask
224 |
--------------------------------------------------------------------------------
/models/mae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 | from timm.models.layers import to_2tuple
6 |
7 |
8 | class PatchEmbed_new(nn.Module):
9 | """ Flexible Image to Patch Embedding
10 | """
11 | def __init__(
12 | self,
13 | patch_size=16,
14 | in_chans=3,
15 | embed_dim=768,
16 | stride=16,
17 | flatten='freq'
18 | ):
19 | super().__init__()
20 | self.flatten = flatten
21 | patch_size = to_2tuple(patch_size)
22 | stride = to_2tuple(stride)
23 | assert flatten in ['time', 'freq']
24 |
25 | self.patch_size = patch_size
26 |
27 | # no padding for conv
28 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride)  # patches overlap only when stride < patch_size
29 |
30 | def forward(self, x):
31 | x = self.proj(x) # (B,768,64,8)
32 | if self.flatten == 'freq':
33 | x = x.flatten(2).transpose(1, 2)  # flatten from dim 2: (B, D, H', W') -> (B, H'*W', D)
34 | else:
35 | x = x.transpose(-2, -1).flatten(2).transpose(1, 2)
36 | return x
37 |
38 |
39 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
40 | """
41 | grid_size: int of the grid height and width
42 | return:
43 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
44 | """
45 | grid_h = np.arange(grid_size, dtype=np.float32)
46 | grid_w = np.arange(grid_size, dtype=np.float32)
47 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
48 | grid = np.stack(grid, axis=0)
49 |
50 | grid = grid.reshape([2, 1, grid_size, grid_size])
51 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
52 | if cls_token:
53 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
54 | return pos_embed
55 |
56 |
57 | def get_2d_sincos_pos_embed_flexible(embed_dim, grid_size, cls_token=False):
58 | """
59 | grid_size: int of the grid height and width
60 | return:
61 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
62 | """
63 | grid_h = np.arange(grid_size[0], dtype=np.float32)
64 | grid_w = np.arange(grid_size[1], dtype=np.float32)
65 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
66 | grid = np.stack(grid, axis=0)
67 |
68 | grid = grid.reshape([2, 1, grid_size[0], grid_size[1]])
69 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
70 | if cls_token:
71 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
72 | return pos_embed
73 |
74 |
75 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
76 | assert embed_dim % 2 == 0
77 |
78 | # use half of dimensions to encode grid_h
79 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
80 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
81 |
82 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
83 | return emb
84 |
85 |
86 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
87 | """
88 | embed_dim: output dimension for each position
89 | pos: a list of positions to be encoded: size (M,)
90 | out: (M, D)
91 | """
92 | assert embed_dim % 2 == 0
93 | omega = np.arange(embed_dim // 2, dtype=np.float32)
94 | omega /= embed_dim / 2.0
95 | omega = 1.0 / 10000 ** omega # (D/2,)
96 |
97 | pos = pos.reshape(-1) # (M,)
98 | out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
99 |
100 | emb_sin = np.sin(out) # (M, D/2)
101 | emb_cos = np.cos(out) # (M, D/2)
102 |
103 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
104 | return emb
105 |
106 |
107 | def interpolate_pos_embed(model, checkpoint_model):
108 | if "pos_embed" in checkpoint_model:
109 | pos_embed_checkpoint = checkpoint_model["pos_embed"]
110 | embedding_size = pos_embed_checkpoint.shape[-1]
111 | num_patches = model.patch_embed.num_patches
112 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches
113 | # height (== width) for the checkpoint position embedding
114 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
115 | # height (== width) for the new position embedding
116 | new_size = int(num_patches ** 0.5)
117 | # class_token and dist_token are kept unchanged
118 | if orig_size != new_size:
119 | print(
120 | "Position interpolate from %dx%d to %dx%d"
121 | % (orig_size, orig_size, new_size, new_size)
122 | )
123 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
124 | # only the position tokens are interpolated
125 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
126 | pos_tokens = pos_tokens.reshape(
127 | -1, orig_size, orig_size, embedding_size
128 | ).permute(0, 3, 1, 2)
129 | pos_tokens = torch.nn.functional.interpolate(
130 | pos_tokens,
131 | size=(new_size, new_size),
132 | mode="bicubic",
133 | align_corners=False,
134 | )
135 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
136 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
137 | checkpoint_model["pos_embed"] = new_pos_embed
138 |
--------------------------------------------------------------------------------
/models/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 | from dataclasses import dataclass
7 |
8 |
9 | @dataclass
10 | class D2vDecoderConfig:
11 | decoder_dim: int = 384
12 | decoder_groups: int = 16
13 | decoder_kernel: int = 5
14 | decoder_layers: int = 5
15 | input_dropout: float = 0.1
16 |
17 | add_positions_masked: bool = False
18 | add_positions_all: bool = False
19 |
20 | decoder_residual: bool = True
21 | projection_layers: int = 1
22 | projection_ratio: float = 2.0
23 |
24 |
25 | class FixedPositionalEncoder(nn.Module):
26 | def __init__(self, pos_embed):
27 | super().__init__()
28 | self.positions = pos_embed # [1, max_t * max_freq, embed_dim]
29 |
30 | def forward(self, x, padding_mask):
31 | return self.positions
32 |
33 |
34 | class BlockEncoder(nn.Module):
35 | def __init__(self, blocks, norm_layer, layer_norm_first, layerdrop, dropout):
36 | super().__init__()
37 | self.blocks = blocks
38 | self.norm = norm_layer
39 | self.layer_norm_first = layer_norm_first
40 | self.layerdrop = layerdrop
41 | self.dropout = nn.Dropout(dropout, inplace=True)
42 |
43 | def forward(self, x, padding_mask, alibi_bias, alibi_scale):
44 | if self.norm is not None and not self.layer_norm_first:
45 | x = self.norm(x)
46 |
47 | x = self.dropout(x)
48 |
49 | for i, blk in enumerate(self.blocks):
50 | if (
51 | not self.training
52 | or self.layerdrop == 0
53 | or (np.random.random() > self.layerdrop)
54 | ):
55 | ab = alibi_bias
56 | if ab is not None and alibi_scale is not None:
57 | scale = (
58 | alibi_scale[i]
59 | if alibi_scale.size(0) > 1
60 | else alibi_scale.squeeze(0)
61 | )
62 | ab = ab * scale.type_as(ab)
63 | x, _ = blk(x, padding_mask, ab)
64 |
65 | if self.norm is not None and self.layer_norm_first:
66 | x = self.norm(x)
67 |
68 | return x
69 |
70 |
71 | class AltBlock(nn.Module):
72 | def __init__(
73 | self,
74 | dim,
75 | num_heads,
76 | mlp_ratio=4.0,
77 | qkv_bias=False,
78 | qk_scale=None,
79 | drop=0.0,
80 | attn_drop=0.0,
81 | mlp_drop=0.0,
82 | post_mlp_drop=0.0,
83 | drop_path=0.0,
84 | act_layer=nn.GELU,
85 | norm_layer=nn.LayerNorm,
86 | layer_norm_first=True,
87 | ffn_targets=False,
88 | cosine_attention=False,
89 | ):
90 | super().__init__()
91 |
92 | self.layer_norm_first = layer_norm_first
93 | self.ffn_targets = ffn_targets
94 |
95 | from timm.models.vision_transformer import DropPath, Mlp
96 |
97 | self.norm1 = norm_layer(dim)
98 | self.attn = AltAttention(
99 | dim,
100 | num_heads=num_heads,
101 | qkv_bias=qkv_bias,
102 | qk_scale=qk_scale,
103 | attn_drop=attn_drop,
104 | proj_drop=drop,
105 | cosine_attention=cosine_attention,
106 | )
107 |
108 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
109 | self.norm2 = norm_layer(dim)
110 | mlp_hidden_dim = int(dim * mlp_ratio)
111 | self.mlp = Mlp(
112 | in_features=dim,
113 | hidden_features=mlp_hidden_dim,
114 | act_layer=act_layer,
115 | drop=mlp_drop,
116 | )
117 | self.post_mlp_dropout = nn.Dropout(post_mlp_drop, inplace=False)
118 |
119 | def forward(self, x, padding_mask=None, alibi_bias=None):
120 | if self.layer_norm_first:
121 | x = x + self.drop_path(self.attn(self.norm1(x), padding_mask, alibi_bias))
122 | r = x = self.mlp(self.norm2(x))
123 | t = x
124 | x = r + self.drop_path(self.post_mlp_dropout(x))
125 | if not self.ffn_targets:
126 | t = x
127 | else:
128 | x = x + self.drop_path(self.attn(x, padding_mask, alibi_bias))
129 | r = x = self.norm1(x)
130 | x = self.mlp(x)
131 | t = x
132 | x = self.norm2(r + self.drop_path(self.post_mlp_dropout(x)))
133 | if not self.ffn_targets:
134 | t = x
135 |
136 | return x, t
137 |
138 |
139 | class AltAttention(nn.Module):
140 | def __init__(
141 | self,
142 | dim,
143 | num_heads=8,
144 | qkv_bias=False,
145 | qk_scale=None,
146 | attn_drop=0.0,
147 | proj_drop=0.0,
148 | cosine_attention=False,
149 | ):
150 | super().__init__()
151 | self.num_heads = num_heads
152 | head_dim = dim // num_heads
153 | self.scale = qk_scale or head_dim ** -0.5
154 |
155 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
156 | self.attn_drop = nn.Dropout(attn_drop)
157 | self.proj = nn.Linear(dim, dim)
158 | self.proj_drop = nn.Dropout(proj_drop)
159 |
160 | self.cosine_attention = cosine_attention
161 |
162 | if cosine_attention:
163 | self.logit_scale = nn.Parameter(
164 | torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True
165 | )
166 |
167 | def forward(self, x, padding_mask=None, alibi_bias=None):
168 | B, N, C = x.shape
169 | qkv = (
170 | self.qkv(x)
171 | .reshape(B, N, 3, self.num_heads, C // self.num_heads)
172 | .permute(2, 0, 3, 1, 4) # qkv x B x H x L x D
173 | )
174 | q, k, v = (
175 | qkv[0],
176 | qkv[1],
177 | qkv[2],
178 | ) # make torchscript happy (cannot use tensor as tuple)
179 |
180 | dtype = q.dtype
181 |
182 | if self.cosine_attention:
183 | # cosine attention
184 | attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
185 | logit_scale = torch.clamp(
186 | self.logit_scale, max=torch.log(torch.tensor(1.0 / 0.01))
187 | ).exp()
188 | attn = attn * logit_scale
189 | else:
190 | q = q * self.scale
191 | attn = q @ k.transpose(-2, -1)
192 |
193 | if alibi_bias is not None:
194 | attn = attn.type_as(alibi_bias)
195 | attn[:, : alibi_bias.size(1)] += alibi_bias
196 |
197 | if padding_mask is not None and padding_mask.any():
198 | attn = attn.masked_fill(
199 | padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
200 | float("-inf"),
201 | )
202 |
203 | attn = attn.softmax(dim=-1, dtype=torch.float32).to(dtype=dtype)
204 | attn = self.attn_drop(attn)
205 | x = (attn @ v).transpose(1, 2)  # (B, H, L, D) -> (B, L, H, D)
206 | x = x.reshape(B, N, C)
207 | x = self.proj(x)
208 | x = self.proj_drop(x)
209 | return x
210 |
--------------------------------------------------------------------------------