├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── icon.jpg
│   ├── model_pipe.png
│   └── rmis_curve.png
└── models
    ├── base.py
    ├── fisher.py
    ├── images.py
    ├── mae.py
    └── modules.py

/.gitattributes:
--------------------------------------------------------------------------------
*.py text eol=lf
*.md text eol=lf
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
ckpt/
hf*
demo*
__pycache__
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2025 Anbai Jiang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# FISHER

[Figure: Model Performances on the RMIS Benchmark]

## 🔥🔥🔥 Updates

- [2025.7.25] FISHER is now integrated into HuggingFace🤗.

- [2025.7.23] We release the inference code and checkpoints for FISHER-tiny, FISHER-mini and FISHER-small.

## Introduction

[Figure: Model Performances on the RMIS Benchmark]
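
The sub-band scheme described in the paragraph below can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the repository's implementation. It mirrors three things from the source: the band split in `FISHER.forward`, the per-band CLS concatenation in `FISHER.extract_features` (both in `models/fisher.py`), and the STFT settings from the inference example (25 ms window, one-sided spectrum). The helpers `num_bands`, `split_bands` and `concat_band_embeddings` are hypothetical names. `band_width=100` is the `Data2VecMultiConfig` default. `embed_dim` is a placeholder, because the real value depends on the checkpoint.

```python
import numpy as np


def num_bands(sr: int, band_width: int = 100) -> int:
    """Hypothetical helper: number of sub-bands for a signal sampled at `sr` Hz."""
    n_fft = 25 * sr // 1000           # 25 ms window, as in the inference example
    n_freq = n_fft // 2 + 1           # bins of a one-sided STFT
    n_freq = max(n_freq, band_width)  # narrow spectra are zero-padded up to one band
    return n_freq // band_width       # bins past the last full band are dropped


def split_bands(spec: np.ndarray, band_width: int = 100) -> np.ndarray:
    """[B, C, T, F] -> [B * nb, C, T, band_width], band index varying fastest."""
    nb = spec.shape[-1] // band_width
    bands = np.stack(
        [spec[..., i * band_width:(i + 1) * band_width] for i in range(nb)]
    )  # [nb, B, C, T, band_width]
    bands = bands.transpose(1, 0, 2, 3, 4)  # [B, nb, C, T, band_width]
    return bands.reshape(-1, *bands.shape[2:])


def concat_band_embeddings(cls_tokens: np.ndarray, nb: int) -> np.ndarray:
    """[B * nb, D] -> [B, nb * D]: one CLS embedding per band, concatenated."""
    return cls_tokens.reshape(-1, nb * cls_tokens.shape[-1])


if __name__ == "__main__":
    embed_dim = 8  # placeholder, not a real FISHER width
    for sr in (4000, 8000, 16000, 48000):
        print(f"{sr} Hz -> {num_bands(sr)} band(s) -> {num_bands(sr) * embed_dim}-dim output")

    spec = np.random.randn(2, 1, 1024, 201)  # batch of two 16 kHz spectrograms
    bands = split_bands(spec)
    print(bands.shape)  # (4, 1, 1024, 100)
    cls = np.random.randn(bands.shape[0], embed_dim)  # stand-in for encoder CLS outputs
    print(concat_band_embeddings(cls, nb=2).shape)  # (2, 16)
```

The window is a fixed 25 ms, so the number of frequency bins grows linearly with the sampling rate:

- 8 kHz gives 101 bins and one band.
- 16 kHz gives 201 bins and two bands.
- 48 kHz gives 601 bins and six bands.

A higher sampling rate therefore yields more sub-band embeddings, while each band keeps the same input width.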

FISHER is a **F**oundation model for **I**ndustrial **S**ignal compre**HE**nsive **R**epresentation, which models heterogeneous industrial signals (sound, vibration, voltage, etc.) in a unified manner. FISHER accepts arbitrary sampling rates and models an increase in sampling rate as the concatenation of additional sub-band information: it first splits an STFT spectrogram into sub-bands and then processes them with a ViT encoder. FISHER is trained with teacher-student EMA self-distillation.

To evaluate the model, we develop the RMIS benchmark, which will also be open-sourced in the near future. FISHER achieves state-of-the-art performance on the RMIS benchmark, with much more efficient scaling properties.

## Checkpoints

We release the checkpoints of FISHER-tiny, FISHER-mini and FISHER-small.

| Version | ☁️ Tsinghua Cloud | 🤗 HuggingFace | wisemodel |
|------------| :------------: | :--------: | :--------: |
| FISHER-tiny | [Link](https://cloud.tsinghua.edu.cn/f/630a4b1b2962481a9150/?dl=1) | [Link](https://huggingface.co/jiangab/FISHER-tiny-0723) | [Link](https://wisemodel.cn/models/jiangab/FISHER-tiny-0723) |
| FISHER-mini | [Link](https://cloud.tsinghua.edu.cn/f/60b3bfc0977f45f48dff/?dl=1) | [Link](https://huggingface.co/jiangab/FISHER-mini-0723) | [Link](https://wisemodel.cn/models/jiangab/FISHER-mini-0723) |
| FISHER-small | [Link](https://cloud.tsinghua.edu.cn/f/f997a6932b614046915e/?dl=1) | [Link](https://huggingface.co/jiangab/FISHER-small-0723) | [Link](https://wisemodel.cn/models/jiangab/FISHER-small-0723) |

## Inference

Use the following code to extract signal representations with FISHER.

```python
import torch
import torchaudio
import torch.nn.functional as F
from models.fisher import FISHER

wav, sr = torchaudio.load('/path/to/local/signal.wav')
# Replace this with your own loading function for non-audio signals

wav = wav - wav.mean()
STFT = torchaudio.transforms.Spectrogram(
    n_fft=25 * sr // 1000,
    win_length=None,
    hop_length=10 * sr // 1000,
    power=1,
    center=False
)
spec = torch.log(torch.abs(STFT(wav)) + 1e-10)
spec = spec.transpose(-2, -1)  # [channels, time, freq]; [1, time, freq] for mono input
spec = (spec + 3.017344307886898) / (2.1531635155379805 * 2)

model_path = '/path/to/local/fisher/model.pt'  # Please download the checkpoint in advance.
model = FISHER.from_pretrained(model_path)
model = model.cuda()
model.eval()

# time-wise cutoff
if spec.shape[-2] > 1024:
    spec = spec[:, :1024]
# freq-wise padding
if spec.shape[-1] < model.cfg.band_width:
    spec = F.pad(spec, (0, model.cfg.band_width - spec.shape[-1]))
spec = spec.unsqueeze(1).cuda()

with torch.no_grad():
    # Use autocast for mixed precision inference. You can disable it for full precision.
    with torch.autocast('cuda'):
        repre = model.extract_features(spec)
print(repre.shape)  # [1, num_bands * embed_dim] for mono input
```

## Acknowledgements

FISHER is built on [EAT](https://github.com/cwx-worst-one/EAT) and [fairseq](https://github.com/facebookresearch/fairseq). We thank their authors for open-sourcing their work.

## Citation

If you find FISHER useful, please cite the following paper.

```bibtex
@article{fan2025fisher,
  title={FISHER: A Foundation Model for Multi-Modal Industrial Signal Comprehensive Representation},
  author={Fan, Pingyi and Jiang, Anbai and Zhang, Shuwei and Lv, Zhiqiang and Han, Bing and Zheng, Xinhu and Liang, Wenrui and Li, Junjie and Zhang, Wei-Qiang and Qian, Yanmin and Chen, Xie and Lu, Cheng and Liu, Jia},
  journal={arXiv preprint arXiv:2507.16696},
  year={2025}
}
```

--------------------------------------------------------------------------------
/assets/icon.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianganbai/FISHER/31ac067ce5414565902d70bcb0769848fdf0305d/assets/icon.jpg
--------------------------------------------------------------------------------
/assets/model_pipe.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianganbai/FISHER/31ac067ce5414565902d70bcb0769848fdf0305d/assets/model_pipe.png
--------------------------------------------------------------------------------
/assets/rmis_curve.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianganbai/FISHER/31ac067ce5414565902d70bcb0769848fdf0305d/assets/rmis_curve.png
--------------------------------------------------------------------------------
/models/base.py:
--------------------------------------------------------------------------------
import logging
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from collections import namedtuple
from dataclasses import dataclass, field
from functools import partial
from omegaconf import MISSING, II
from typing import Optional, Callable
from enum import Enum, auto

from .modules import D2vDecoderConfig


logger = logging.getLogger(__name__)


class Modality(Enum):
    AUDIO = auto()
    IMAGE = auto()
    TEXT = auto()


@dataclass
class D2vModalityConfig:
    type: Modality = MISSING
    prenet_depth: int = 0
    prenet_layerdrop: float = 0.0
    prenet_dropout: float = 0.0
    start_drop_path_rate: float = 0.0
    end_drop_path_rate: float = 0.0

    num_extra_tokens: int = 1
    init_extra_token_zero: bool = False

    mask_noise_std: float = 0.01
    mask_prob_min: Optional[float] = None
    mask_prob: float = 0.8
    inverse_mask: bool = True
    mask_prob_adjust: float = 0.07
    keep_masked_pct: float = 0.0
    flexible_mask: bool = False

    mask_length: int = 5
    add_masks: bool = False
    remove_masks: bool = False
    mask_dropout: float = 0.0
    encoder_zero_mask: bool = True

    mask_channel_prob: float = 0.0
    mask_channel_length: int = 64

    ema_local_encoder: bool = True  # used in data2vec_multi
    ema_local_decoder: bool = False
    local_grad_mult: float = 1.0
    flatten: str = 'freq'
    max_length: int = 128
    max_freq: int = 50

    use_alibi_encoder: bool = False
    alibi_scale: float = 1.0
    learned_alibi: bool = False
    alibi_max_pos: Optional[int] = None
    learned_alibi_scale: bool = False
    learned_alibi_scale_per_head: bool = False
    learned_alibi_scale_per_layer: bool = False

    num_alibi_heads: int = II("model.num_heads")
    model_depth: int = II("model.depth")

    decoder: Optional[D2vDecoderConfig] = field(default_factory=lambda *x: D2vDecoderConfig())


MaskSeed = namedtuple("MaskSeed", ["seed", "update", "ids"])
MaskInfo = namedtuple("MaskInfo", ["x_unmasked", "mask", "ids_restore", "ids_keep"])


class ModalitySpecificEncoder(nn.Module):
    def __init__(
        self,
        modality_cfg: D2vModalityConfig,
        embed_dim: int,
        local_encoder: nn.Module,
        project_features: nn.Module,
        fixed_positional_encoder: Optional[nn.Module],
        relative_positional_encoder: Optional[nn.Module],  # None
        context_encoder: nn.Module,
        decoder: Optional[nn.Module],
        get_alibi_bias: Optional[Callable[[int, int, str, str], torch.Tensor]],
    ):
        super().__init__()

        self.modality_cfg = modality_cfg
        self.local_encoder = local_encoder
        self.project_features = project_features
        self.fixed_positional_encoder = fixed_positional_encoder
        self.relative_positional_encoder = relative_positional_encoder
        self.context_encoder = context_encoder

        self.decoder = decoder
        self.get_alibi_bias = get_alibi_bias if modality_cfg.use_alibi_encoder else None

        self.local_grad_mult = self.modality_cfg.local_grad_mult

        self.extra_tokens = None
        if modality_cfg.num_extra_tokens > 0:
            self.extra_tokens = nn.Parameter(
                torch.zeros(1, modality_cfg.num_extra_tokens, embed_dim)
            )
            if not modality_cfg.init_extra_token_zero:
                nn.init.normal_(self.extra_tokens)
            elif self.extra_tokens.size(1) > 1:
                nn.init.normal_(self.extra_tokens[:, 1:])

        self.alibi_scale = None
        if self.get_alibi_bias is not None:
            self.alibi_scale = nn.Parameter(
                torch.full(
                    (
                        (modality_cfg.prenet_depth + modality_cfg.model_depth)
                        if modality_cfg.learned_alibi_scale_per_layer
                        else 1,
                        1,
                        self.modality_cfg.num_alibi_heads
                        if modality_cfg.learned_alibi_scale_per_head
                        else 1,
                        1,
                        1,
                    ),
                    modality_cfg.alibi_scale,
                    dtype=torch.float,
                ),
                requires_grad=modality_cfg.learned_alibi_scale,
            )

        if modality_cfg.learned_alibi and self.get_alibi_bias is not None:
            assert modality_cfg.alibi_max_pos is not None
            alibi_bias = self.get_alibi_bias(
                batch_size=1,
                time_steps=modality_cfg.alibi_max_pos,
                heads=modality_cfg.num_alibi_heads,
                scale=1.0,
                dtype=torch.float,
                device="cpu",
            )
            self.alibi_bias = nn.Parameter(alibi_bias)
            self.get_alibi_bias = partial(
                _learned_alibi_bias, alibi_bias=self.alibi_bias
            )

    def upgrade_state_dict_named(self, state_dict, name):
        k = f"{name}.alibi_scale"
        if k in state_dict and state_dict[k].dim() == 4:
            state_dict[k] = state_dict[k].unsqueeze(0)

        return state_dict

    def convert_padding_mask(self, x, padding_mask):
        return padding_mask

    def local_features(self, features):
        x = self.local_encoder(features)
        x = self.project_features(x)  # nn.Identity()
        return x

    def contextualized_features(
        self,
        x,
        padding_mask,
        mask,  # True
        remove_masked,  # train: True; infer: False
        clone_batch: int = 1,
        mask_seeds: Optional[torch.Tensor] = None,
        precomputed_mask=None,
    ):

        if padding_mask is not None:
            padding_mask = self.convert_padding_mask(x, padding_mask)  # [b,t,f] => [b,seq]

        local_features = x
        if mask and clone_batch == 1:
            local_features = local_features.clone()

        orig_B, orig_T, _ = x.shape
        pre_mask_B = orig_B
        mask_info = None

        x_pos = None
        # x: [B, seq_len, embed_dim]
        if self.fixed_positional_encoder is not None:  # models.modules.FixedPositionalEncoder
            x = x + self.fixed_positional_encoder(x, padding_mask)[:, :x.size(1), :]

        if self.relative_positional_encoder is not None:
            x_pos = self.relative_positional_encoder(x)

        masked_padding_mask = padding_mask
        if mask and remove_masked:  # only pass unmasked patches to the student
            if mask_info is None:
                # mask generation (which would populate mask_info) is not part of this module
                raise NotImplementedError(
                    "mask=True with remove_masked=True requires mask_info, which is not computed here"
                )
            x = mask_info.x_unmasked
            if x_pos is not None:
                x = x + gather_unmasked(x_pos, mask_info)

            # padding_mask: [bs, seq_len]
            # valid: False; padded: True
            if padding_mask is not None and padding_mask.any():
                # retrieve padding_mask for unmasked patches
                masked_padding_mask = gather_unmasked_mask(padding_mask, mask_info)
                if not masked_padding_mask.any():
                    masked_padding_mask = None
            else:
                masked_padding_mask = None

        elif x_pos is not None:
            x = x + x_pos

        alibi_bias = None
        alibi_scale = self.alibi_scale

        if self.get_alibi_bias is not None:
            alibi_bias = self.get_alibi_bias(
                batch_size=pre_mask_B,
                time_steps=orig_T,
                heads=self.modality_cfg.num_alibi_heads,
                dtype=torch.float32,
                device=x.device,
            )

            if alibi_scale is not None:
                alibi_scale = alibi_scale.clamp_min(0)
                if alibi_scale.size(0) == 1:
                    alibi_bias = alibi_bias * alibi_scale.squeeze(0).type_as(alibi_bias)
                    alibi_scale = None

            if clone_batch > 1:
                alibi_bias = alibi_bias.repeat_interleave(clone_batch, 0)

            if mask_info is not None and remove_masked:
                alibi_bias = masked_alibi(alibi_bias, mask_info)

        if self.extra_tokens is not None:
            num = self.extra_tokens.size(1)
            x = torch.cat([self.extra_tokens.expand(x.size(0), -1, -1), x], dim=1)
            if masked_padding_mask is not None:
                # B x T
                masked_padding_mask = F.pad(masked_padding_mask, (num, 0))
            if alibi_bias is not None:
                # B x H x T x T
                alibi_bias = F.pad(alibi_bias, (num, 0, num, 0))

        x = self.context_encoder(
            x,
            masked_padding_mask,
            alibi_bias,
            alibi_scale[: self.modality_cfg.prenet_depth]
            if alibi_scale is not None
            else None,
        )

        return {
            "x": x,
            "local_features": local_features,
            "padding_mask": masked_padding_mask,
            "alibi_bias": alibi_bias,
            "alibi_scale": alibi_scale[self.modality_cfg.prenet_depth :]
            if alibi_scale is not None and alibi_scale.size(0) > 1
            else alibi_scale,
            "encoder_mask": mask_info,
        }

    def forward(
        self,
        features,
        padding_mask,
        mask: bool,
        remove_masked: bool,
        clone_batch: int = 1,
        mask_seeds: Optional[torch.Tensor] = None,
        precomputed_mask=None,
    ):
        x = self.local_features(features)  # patch embed
        # x: [bs, time*freq, embed_dim], e.g. [12, 512, 768]
        out = self.contextualized_features(
            x,
            padding_mask,
            mask,
            remove_masked,
            clone_batch,
            mask_seeds,
            precomputed_mask,
        )  # apply mask, drop masked patches, run the context encoder (layer norm only)
        return out

    def reset_parameters(self):
        pass

    def remove_pretraining_modules(self, keep_decoder=False):
        if not keep_decoder:
            self.decoder = None


def get_annealed_rate(start, end, curr_step, total_steps):
    if curr_step >= total_steps:
        return end
    r = end - start
    pct_remaining = 1 - curr_step / total_steps
    return end - r * pct_remaining


def gather_unmasked(x: torch.Tensor, mask_info: MaskInfo) -> torch.Tensor:
    return torch.gather(
        x,
        dim=1,
        index=mask_info.ids_keep,
    )


def gather_unmasked_mask(x: torch.Tensor, mask_info: MaskInfo) -> torch.Tensor:
    return torch.gather(
        x,
        dim=1,
        index=mask_info.ids_keep[..., 0],  # ignore the feature dimension
    )


def get_alibi(
    max_positions: int,
    attention_heads: int,
    dims: int = 1,
    distance: str = "manhattan",
):
    def get_slopes(n):
        def get_slopes_power_of_2(n):
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * ratio**i for i in range(n)]

        # In the paper, we only train models that have 2^a heads for some
        # a. This function has some good properties that only occur when
        # the input is a power of 2. To maintain that even when the number
        # of heads is not a power of 2, we use this workaround.
        if math.log2(n).is_integer():
            return get_slopes_power_of_2(n)
        else:
            closest_power_of_2 = 2 ** math.floor(math.log2(n))
            return (
                get_slopes_power_of_2(closest_power_of_2)
                + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
            )

    maxpos = max_positions
    attn_heads = attention_heads
    slopes = torch.Tensor(get_slopes(attn_heads))

    if dims == 1:
        # prepare alibi position linear bias. Note that wav2vec2 is a non-
        # autoregressive model, so we want a symmetric mask with 0 on the
        # diagonal and otherwise linearly decreasing values
        pos_bias = (
            torch.abs(
                torch.arange(maxpos).unsqueeze(0) - torch.arange(maxpos).unsqueeze(1)
            )
            * -1
        )
    elif dims == 2:
        if distance == "manhattan":
            df = lambda x1, y1, x2, y2: abs(x1 - x2) + abs(y1 - y2)
        elif distance == "euclidean":
            df = lambda x1, y1, x2, y2: math.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)

        n = math.sqrt(max_positions)
        assert n.is_integer(), n
        n = int(n)

        pos_bias = torch.zeros((max_positions, max_positions))

        for i in range(n):
            for j in range(n):
                for k in range(n):
                    for l in range(n):
                        new_x = i * n + j
                        new_y = k * n + l
                        pos_bias[new_x, new_y] = -df(i, j, k, l)

    else:
        raise Exception(f"unsupported number of alibi dims: {dims}")

    alibi_bias = slopes.unsqueeze(1).unsqueeze(1) * pos_bias.unsqueeze(0).expand(
        attn_heads, -1, -1
    )

    return alibi_bias


def get_alibi_bias(
    alibi_biases,
    batch_size,
    time_steps,
    heads,
    dtype,
    device,
    dims=1,
    distance="manhattan",
):
    cache_key = f"{dims}_{heads}_{distance}"

    buffered = alibi_biases.get(cache_key, None)

    target_size = heads * batch_size
    if (
        buffered is None
        or buffered.size(0) < target_size
        or buffered.size(1) < time_steps
        or buffered.dtype != dtype
        or buffered.device != device
    ):
        bt = max(time_steps, buffered.size(1) if buffered is not None else 0)
        bn = max(target_size, buffered.size(0) if buffered is not None else 0) // heads

        buffered = (
            get_alibi(bt, heads, dims=dims, distance=distance)
            .to(dtype=dtype, device=device)
            .repeat(bn, 1, 1)
        )

        alibi_biases[cache_key] = buffered

    b = buffered[:target_size, :time_steps, :time_steps]
    b = b.view(batch_size, heads, time_steps, time_steps)
    return b


def _learned_alibi_bias(
    alibi_bias,
    batch_size,
    time_steps,
    heads,
    scale,
    dtype,
    device,
):
    assert alibi_bias.size(1) == heads, alibi_bias.shape
    assert alibi_bias.dtype == dtype, alibi_bias.dtype
    assert alibi_bias.device == device, alibi_bias.device

    if alibi_bias.size(-1) < time_steps:
        psz = math.ceil((time_steps - alibi_bias.size(-1)) / 2)
        alibi_bias = F.pad(alibi_bias, (psz, psz, psz, psz), mode="replicate")

    alibi_bias = alibi_bias.expand(batch_size, -1, -1, -1) * scale
    return alibi_bias[..., :time_steps, :time_steps]


def masked_alibi(alibi_bias, mask_info):
    H = alibi_bias.size(1)

    orig_bias = alibi_bias

    index = mask_info.ids_keep.unsqueeze(1)[..., 0].unsqueeze(-1)
    alibi_bias = torch.gather(
        orig_bias,
        dim=-2,
        index=index.expand(-1, H, -1, mask_info.ids_restore.size(1)),
    )
    alibi_bias = torch.gather(
        alibi_bias,
        dim=-1,
        index=index.transpose(-1, -2).expand(-1, H, alibi_bias.size(-2), -1),
    )

    return alibi_bias

--------------------------------------------------------------------------------
/models/fisher.py:
--------------------------------------------------------------------------------
import torch
import logging
import numpy as np
import torch.nn as nn

from functools import partial
from einops import rearrange
from typing import Callable, Optional
from dataclasses import dataclass, field, is_dataclass


from .base import (
    D2vModalityConfig,
    ModalitySpecificEncoder,
)

from .modules import AltBlock

from .images import (
    D2vImageConfig,
    ImageEncoder,
)

logger = logging.getLogger(__name__)


@dataclass
class D2vModalitiesConfig:
    image: D2vImageConfig = field(default_factory=lambda *args: D2vImageConfig())


@dataclass
class Data2VecMultiConfig:
    depth: int = 12

    # band split
    band_width: int = 100

    # standard vision Transformer
    start_drop_path_rate: float = 0.0
    end_drop_path_rate: float = 0.0
    num_heads: int = 12
    norm_eps: float = 1e-6
    norm_affine: bool = True
    encoder_dropout: float = 0.0
    post_mlp_drop: float = 0.0
    attention_dropout: float = 0.0
    activation_dropout: float = 0.0
    dropout_input: float = 0.0
    layerdrop: float = 0.0
    embed_dim: int = 768
    mlp_ratio: float = 4.0
    layer_norm_first: bool = False

    end_of_block_targets: bool = False

    # clone batch for multi-mask strategy
    clone_batch: int = 8
    max_band_per_sample: int = 64

    # normalization for teacher Transformer layer output
    layer_norm_target_layer: bool = False
    batch_norm_target_layer: bool = False
    instance_norm_target_layer: bool = True
    instance_norm_targets: bool = False
    layer_norm_targets: bool = True

    modalities: D2vModalitiesConfig = field(default_factory=lambda *args: D2vModalitiesConfig())


class FISHER(nn.Module):
    def __init__(self, cfg: Data2VecMultiConfig):
        super().__init__()
        self.cfg = cfg

        make_layer_norm = partial(
            nn.LayerNorm, eps=cfg.norm_eps, elementwise_affine=cfg.norm_affine
        )

        def make_block(drop_path, dim=None, heads=None):
            return AltBlock(
                cfg.embed_dim if dim is None else dim,
                cfg.num_heads if heads is None else heads,
                cfg.mlp_ratio,
                qkv_bias=True,
                drop=cfg.encoder_dropout,
                attn_drop=cfg.attention_dropout,
                mlp_drop=cfg.activation_dropout,
                post_mlp_drop=cfg.post_mlp_drop,
                drop_path=drop_path,
                norm_layer=make_layer_norm,
                layer_norm_first=cfg.layer_norm_first,
                ffn_targets=not cfg.end_of_block_targets,
            )

        self.alibi_biases = {}
        self.modality_encoders = nn.ModuleDict()

        mod_cfg = getattr(cfg.modalities, 'image')
        enc = self.make_modality_encoder(
            mod_cfg,
            cfg.embed_dim,
            make_block,
            make_layer_norm,
            cfg.layer_norm_first,
            self.alibi_biases,
        )
        self.modality_encoders['IMAGE'] = enc

        self.dropout_input = nn.Dropout(cfg.dropout_input)

        dpr = np.linspace(cfg.start_drop_path_rate, cfg.end_drop_path_rate, cfg.depth)

        self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(cfg.depth)])

        self.norm = None
        if cfg.layer_norm_first:
            self.norm = make_layer_norm(cfg.embed_dim)

        # band split
        self.band_width = cfg.band_width
        self.patch_size = cfg.modalities.image.patch_size
        self.num_time_patch = cfg.modalities.image.target_length // self.patch_size
        self.num_band_patch = self.band_width // self.patch_size

    def make_modality_encoder(
        self,
        cfg: D2vModalityConfig,
        embed_dim: int,
        make_block: Callable[[float], nn.ModuleList],
        norm_layer: Callable[[int], nn.LayerNorm],
        layer_norm_first: bool,
        alibi_biases,
        task=None,
    ) -> ModalitySpecificEncoder:
        return ImageEncoder(
            cfg,
            embed_dim,
            make_block,
            norm_layer,
            layer_norm_first,
            alibi_biases,
            task,
        )

    @classmethod
    def from_pretrained(
        cls,
        model_path: str
    ):
        """
        Load a pretrained FISHER model from a checkpoint file.
        """
        def update_dataclass(instance, data_dict):
            if not data_dict:
                return instance

            for field_name, field_value in data_dict.items():
                if hasattr(instance, field_name):
                    current_value = getattr(instance, field_name)
                    if is_dataclass(current_value) and isinstance(field_value, dict):
                        update_dataclass(current_value, field_value)
                    else:
                        setattr(instance, field_name, field_value)
            return instance

        state_dict = torch.load(model_path, map_location='cpu', weights_only=False)
        cfg = Data2VecMultiConfig()
        update_dataclass(cfg, state_dict['cfg']['model'])
        model = cls(cfg)
        load_info = model.load_state_dict(state_dict['model'], strict=True)
        print(load_info)
        return model

    def state_dict(self, **kwargs):
        state = {
            'cfg': self.cfg,
            'model': super().state_dict(**kwargs)
        }
        return state

    def forward(
        self,
        source: torch.Tensor,
        target=None,
        id=None,
        mode='IMAGE',
        padding_mask: Optional[torch.Tensor] = None,
        mask: bool = True,
        features_only: bool = False,
        force_remove_masked=False,
        remove_extra_tokens: bool = True,
        precomputed_mask: Optional[torch.Tensor] = None,
    ):
        # band split
        num_band = source.shape[-1] // self.band_width
        source = torch.stack(source.split(self.band_width, dim=-1)[:num_band])  # drop residual
        source = rearrange(source, 'nb B c t f -> (B nb) c t f')
        clone_batch = self.cfg.max_band_per_sample // num_band

        feature_extractor = self.modality_encoders[mode]  # models.images.ImageEncoder

        # extract (unmasked) features using CNN encoder
        extractor_out = feature_extractor(
            source,
            padding_mask,
            mask,
            remove_masked=not features_only or force_remove_masked,  # train: True; infer: False
            clone_batch=clone_batch if not features_only else 1,
            mask_seeds=None,
            precomputed_mask=precomputed_mask,
        )

        # x in shape (batch_size * clone batch, patch_frame(64) * patch_frequency(8) * unmask_ratio(0.2) + 1(cls_token), 768(feature dimension))
        x = extractor_out["x"]
        # encoder_mask is applied at the sub-band level
        encoder_mask = extractor_out["encoder_mask"]  # models.base.MaskInfo, ["x_unmasked", "mask", "ids_restore", "ids_keep"]
        masked_padding_mask = extractor_out["padding_mask"]
        masked_alibi_bias = extractor_out.get("alibi_bias", None)
        alibi_scale = extractor_out.get("alibi_scale", None)

        if self.dropout_input is not None:
            x = self.dropout_input(x)

        # standard Transformer (for student encoder)
        layer_results = []
        for i, blk in enumerate(self.blocks):
            if (
                not self.training
                or self.cfg.layerdrop == 0
                or (np.random.random() > self.cfg.layerdrop)
            ):
                ab = masked_alibi_bias
                if ab is not None and alibi_scale is not None:
                    scale = (
                        alibi_scale[i]
                        if alibi_scale.size(0) > 1
                        else alibi_scale.squeeze(0)
                    )
                    ab = ab * scale.type_as(ab)

                x, lr = blk(
                    x,
                    padding_mask=masked_padding_mask,
                    alibi_bias=ab,
                )
                if features_only:
                    layer_results.append(lr)

        if self.norm is not None:
            x = self.norm(x)

        # extract features for fine-tuning
        if features_only:
            if remove_extra_tokens:
                x = x[:, feature_extractor.modality_cfg.num_extra_tokens :]
                if masked_padding_mask is not None:
                    masked_padding_mask = masked_padding_mask[
                        :, feature_extractor.modality_cfg.num_extra_tokens :
                    ]

            return {
                "x": x,
                "padding_mask": masked_padding_mask,
                "layer_results": layer_results,
                "mask": encoder_mask,
            }

    def extract_features(
        self, source, mode='IMAGE', padding_mask=None, mask=False, remove_extra_tokens=False
    ):
        num_band = source.shape[-1] // self.band_width
        res = self.forward(
            source,
            mode=mode,
            padding_mask=padding_mask,
            mask=mask,
            features_only=True,
            remove_extra_tokens=remove_extra_tokens,
        )
        x = res['x'][:, 0]
        x = rearrange(x, '(B nb) D -> B (nb D)', nb=num_band)
        return x
--------------------------------------------------------------------------------
/models/images.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np

from functools import partial
from dataclasses import dataclass
from typing import Callable, Dict, Optional
from enum import Enum, auto
from einops import rearrange
from omegaconf import II

from .mae import get_2d_sincos_pos_embed_flexible, PatchEmbed_new


from .base import (
    D2vModalityConfig,
    ModalitySpecificEncoder,
    get_alibi_bias,
)
from .modules import (
    BlockEncoder,
    FixedPositionalEncoder,
)


class Modality(Enum):
    AUDIO = auto()
    IMAGE = auto()
    TEXT = auto()


@dataclass
class D2vImageConfig(D2vModalityConfig):
    type: Modality = Modality.IMAGE

    input_size: int = 224
    in_chans: int = 1
    patch_size: int = 16
    embed_dim: int = II('model.embed_dim')

    alibi_dims: int = 2
    alibi_distance: str = "manhattan"

    fixed_positions: bool = True

    transformer_decoder: bool = False
    enc_dec_transformer: bool = False
    target_length: int = 1024
    max_length: int = 128
    max_freq: int = 50

    band_width: int = II('model.band_width')
    flatten: str = 'freq'  # 'time', 'freq'


class ImageEncoder(ModalitySpecificEncoder):
    # forward() implemented in models.base.ModalitySpecificEncoder

    modality_cfg: D2vImageConfig

    def __init__(
        self,
        modality_cfg: D2vImageConfig,
        embed_dim: int,
        make_block: Callable[[float, Optional[int], Optional[int]], nn.ModuleList],
        norm_layer: Callable[[int], nn.LayerNorm],
        layer_norm_first: bool,
        alibi_biases: Dict,
        task=None,
    ):
        self.patch_size = modality_cfg.patch_size
        self.band_width = modality_cfg.band_width
        self.W = self.band_width // self.patch_size
        self.H = modality_cfg.target_length // self.patch_size  # 64

        # convert spec to patch embed, using conv1d
        local_encoder = PatchEmbed_new(
            patch_size=modality_cfg.patch_size,  # 16
            in_chans=modality_cfg.in_chans,  # 1
            embed_dim=modality_cfg.embed_dim,  # 768
            stride=modality_cfg.patch_size,  # 16
            flatten=modality_cfg.flatten
        )

        # CNN initialize
        w = local_encoder.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        if modality_cfg.embed_dim != embed_dim:
            local_encoder = nn.Sequential(
                local_encoder,
                nn.Linear(modality_cfg.embed_dim, embed_dim),
            )

        project_features = nn.Identity()

        # note: max_length controls the maximum number of time patches (64 patches ~ 10 s of audio
        # at a 10 ms hop); adjust it in the config if you need longer inputs
        max_length = modality_cfg.max_length
        max_freq = modality_cfg.max_freq
        # pos_embed: [1, max_length * max_freq, embed_dim], fixed (not trained)
        pos_embed = nn.Parameter(
            torch.zeros(1, max_length*max_freq, embed_dim), requires_grad=False
        )

        # side_n = int(num_patches ** 0.5)
        # build a fixed 2D sin-cos table over a (max_length, max_freq) patch grid to support variable-length inputs
        emb = get_2d_sincos_pos_embed_flexible(
            pos_embed.shape[-1],
            (max_length, max_freq),
            cls_token=False,
        )

pos_embed.data.copy_(torch.from_numpy(emb[:max_length * max_freq, :]).float().unsqueeze(0)) 114 | fixed_positional_encoder = ( 115 | FixedPositionalEncoder(pos_embed) if modality_cfg.fixed_positions else None # True 116 | ) 117 | 118 | dpr = np.linspace( # drop_path_rate 119 | modality_cfg.start_drop_path_rate, 120 | modality_cfg.end_drop_path_rate, 121 | modality_cfg.prenet_depth, # actual: 0 122 | ) 123 | 124 | # actual: only layer norm 125 | context_encoder = BlockEncoder( 126 | nn.ModuleList(make_block(dpr[i]) for i in range(modality_cfg.prenet_depth)), 127 | norm_layer(embed_dim) if not layer_norm_first else None, 128 | layer_norm_first, 129 | modality_cfg.prenet_layerdrop, 130 | modality_cfg.prenet_dropout, 131 | ) 132 | 133 | alibi_bias_fn = partial( 134 | get_alibi_bias, 135 | alibi_biases=alibi_biases, 136 | heads=modality_cfg.num_alibi_heads, 137 | dims=modality_cfg.alibi_dims, 138 | distance=modality_cfg.alibi_distance, 139 | ) 140 | 141 | super().__init__( 142 | modality_cfg=modality_cfg, 143 | embed_dim=embed_dim, 144 | local_encoder=local_encoder, # patch embed 145 | project_features=project_features, # nn.Identity() 146 | fixed_positional_encoder=fixed_positional_encoder, 147 | relative_positional_encoder=None, 148 | context_encoder=context_encoder, # apply mask 149 | decoder=None, 150 | get_alibi_bias=alibi_bias_fn, 151 | ) 152 | 153 | def reset_parameters(self): 154 | super().reset_parameters() 155 | 156 | @torch.no_grad() 157 | def patchify(self, imgs): 158 | """ 159 | imgs: (N, 3, H, W) audio: (N,1,H,W) 1024/16 = 64 128/16 = 8 160 | x: (N, L, patch_size**2 *3) 161 | """ 162 | if self.modality_cfg.in_chans == 1: # actual: this one 163 | p = self.modality_cfg.patch_size 164 | h = imgs.shape[2] // p 165 | w = imgs.shape[3] // p 166 | # h,w = self.patch_embed.patch_hw 167 | x = imgs.reshape(shape=(imgs.shape[0], 1, h, p, w, p)) 168 | x = torch.einsum('nchpwq->nhwpqc', x) 169 | x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 1)) 170 | 171 | else: 
            p = self.modality_cfg.patch_size
            h = w = imgs.shape[2] // p
            x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
            x = torch.einsum("nchpwq->nhwpqc", x)
            x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3))

        return x

    @torch.no_grad()
    def unpatchify(self, x):
        """
        x: (N, L, patch_size**2 * C)
        imgs: (N, C, H, W)
        """
        p = self.modality_cfg.patch_size
        h = w = int(x.shape[1] ** 0.5)  # patches along each axis (square grid assumed)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, -1))
        x = torch.einsum("nhwpqc->nchpwq", x)
        imgs = x.reshape(shape=(x.shape[0], -1, h * p, w * p))
        return imgs

    def convert_padding_mask(
        self,
        x: torch.Tensor,
        padding_mask: torch.Tensor
    ) -> torch.Tensor:
        '''patchify and serialize padding_mask: [b,t,f] => [b,t_patch,f_patch] => [b,patch_seq]

        Args:
            x (torch.Tensor): input features, [b,t,f]
            padding_mask (torch.Tensor): [b,t,f], 1 for padded frames/bins

        Returns:
            torch.Tensor: serialized padding mask, [b,patch_seq]; 1 for patches that are entirely padding
        '''
        B, T, F = x.shape
        # crop to whole patches; slicing with :T - t_extra (not :-t_extra) stays correct when t_extra == 0
        t_extra, f_extra = T % self.patch_size, F % self.patch_size
        padding_mask = padding_mask[:, :T - t_extra, :F - f_extra]
        padding_mask = rearrange(
            padding_mask,
            'b (tp p) (fp q) -> b tp fp (p q)',
            p=self.patch_size, q=self.patch_size
        )
        padding_mask = padding_mask.all(-1)

        if self.modality_cfg.flatten == 'time':
            padding_mask = padding_mask.transpose(-2, -1).flatten(1)
        else:
            padding_mask = padding_mask.flatten(1)
        return padding_mask

--------------------------------------------------------------------------------
/models/mae.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np

from timm.models.layers import to_2tuple


class PatchEmbed_new(nn.Module):
    """ Flexible Image to Patch Embedding
    """
    def __init__(
        self,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        stride=16,
        flatten='freq'
    ):
        super().__init__()
        self.flatten = flatten
        patch_size = to_2tuple(patch_size)
        stride = to_2tuple(stride)
        assert flatten in ['time', 'freq']

        self.patch_size = patch_size

        # no padding for conv
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride)  # patches overlap only if stride < patch_size

    def forward(self, x):
        x = self.proj(x)  # (B,768,64,8)
        if self.flatten == 'freq':
            x = x.flatten(2).transpose(1, 2)  # (B, D, T', F') -> (B, T'*F', D), frequency index varies fastest
        else:
            x = x.transpose(-2, -1).flatten(2).transpose(1, 2)  # time index varies fastest
        return x


def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/o or w/ cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_flexible(embed_dim, grid_size, cls_token=False):
    """
    grid_size: (height, width) tuple of the grid
    return:
    pos_embed: [grid_size[0]*grid_size[1], embed_dim] or [1+grid_size[0]*grid_size[1], embed_dim] (w/o or w/ cls_token)
    """
    grid_h = np.arange(grid_size[0], dtype=np.float32)
    grid_w = np.arange(grid_size[1], dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size[0], grid_size[1]])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of the dimensions for each axis; grid[0] holds the w coordinate
    # (w goes first in meshgrid) and grid[1] the h coordinate
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float32)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000 ** omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


def interpolate_pos_embed(model, checkpoint_model):
    if "pos_embed" in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model["pos_embed"]
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches ** 0.5)
        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            print(
                "Position interpolate from %dx%d to %dx%d"
                % (orig_size, orig_size, new_size, new_size)
            )
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(
                -1, orig_size, orig_size, embedding_size
            ).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens,
                size=(new_size, new_size),
                mode="bicubic",
                align_corners=False,
            )
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model["pos_embed"] = new_pos_embed

--------------------------------------------------------------------------------
/models/modules.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from dataclasses import dataclass


@dataclass
class D2vDecoderConfig:
    decoder_dim: int = 384
    decoder_groups: int = 16
    decoder_kernel: int = 5
    decoder_layers: int = 5
    input_dropout: float = 0.1

    add_positions_masked: bool = False
    add_positions_all: bool = False

    decoder_residual: bool = True
    projection_layers: int = 1
    projection_ratio: float = 2.0


class FixedPositionalEncoder(nn.Module):
    def __init__(self, pos_embed):
        super().__init__()
        self.positions = pos_embed  # [1, max_t * max_freq, embed_dim]

    def forward(self, x, padding_mask):
        return self.positions


class BlockEncoder(nn.Module):
    def __init__(self, blocks, norm_layer, layer_norm_first, layerdrop, dropout):
        super().__init__()
        self.blocks = blocks
        self.norm = norm_layer
        self.layer_norm_first = layer_norm_first
        self.layerdrop = layerdrop
        self.dropout = nn.Dropout(dropout, inplace=True)

    def forward(self, x, padding_mask, alibi_bias, alibi_scale):
        if self.norm is not None and not self.layer_norm_first:
            x = self.norm(x)

        x = self.dropout(x)

        for i, blk in enumerate(self.blocks):
            if (
                not self.training
                or self.layerdrop == 0
                or (np.random.random() > self.layerdrop)
            ):
                ab = alibi_bias
                if ab is not None and alibi_scale is not None:
                    scale = (
                        alibi_scale[i]
                        if alibi_scale.size(0) > 1
                        else alibi_scale.squeeze(0)
                    )
                    ab = ab * scale.type_as(ab)
                x, _ = blk(x, padding_mask, ab)

        if self.norm is not None and self.layer_norm_first:
            x = self.norm(x)

        return x


class AltBlock(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        mlp_drop=0.0,
        post_mlp_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        layer_norm_first=True,
        ffn_targets=False,
        cosine_attention=False,
    ):
        super().__init__()

        self.layer_norm_first = layer_norm_first
        self.ffn_targets = ffn_targets

        from timm.models.vision_transformer import DropPath, Mlp

        self.norm1 = norm_layer(dim)
        self.attn = AltAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            cosine_attention=cosine_attention,
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=mlp_drop,
        )
        self.post_mlp_dropout = nn.Dropout(post_mlp_drop, inplace=False)

    def forward(self, x, padding_mask=None, alibi_bias=None):
        if self.layer_norm_first:
            x = x + self.drop_path(self.attn(self.norm1(x), padding_mask, alibi_bias))
            r = x = self.mlp(self.norm2(x))
            t = x
            x = r + self.drop_path(self.post_mlp_dropout(x))
            if not self.ffn_targets:
                t = x
        else:
            x = x + self.drop_path(self.attn(x, padding_mask, alibi_bias))
            r = x = self.norm1(x)
            x = self.mlp(x)
            t = x
            x = self.norm2(r + self.drop_path(self.post_mlp_dropout(x)))
            if not self.ffn_targets:
                t = x

        return x, t


class AltAttention(nn.Module):
    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        cosine_attention=False,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.cosine_attention = cosine_attention

        if cosine_attention:
            self.logit_scale = nn.Parameter(
                torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True
            )

    def forward(self, x, padding_mask=None, alibi_bias=None):
        B, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)  # qkv x B x H x L x D
        )
        q, k, v = (
            qkv[0],
            qkv[1],
            qkv[2],
        )  # make torchscript happy (cannot use tensor as tuple)

        dtype = q.dtype

        if self.cosine_attention:
            # cosine attention
            attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
            logit_scale = torch.clamp(
                self.logit_scale, max=torch.log(torch.tensor(1.0 / 0.01))
            ).exp()
            attn = attn * logit_scale
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)

        if alibi_bias is not None:
            attn = attn.type_as(alibi_bias)
            attn[:, : alibi_bias.size(1)] += alibi_bias

        if padding_mask is not None and padding_mask.any():
            attn = attn.masked_fill(
                padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                float("-inf"),
            )

        attn = attn.softmax(dim=-1, dtype=torch.float32).to(dtype=dtype)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2)  # B x L x H x D
        x = x.reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

--------------------------------------------------------------------------------
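The patch-level padding mask that `ImageEncoder.convert_padding_mask` derives from a frame-level mask can be sketched in plain NumPy. This is a minimal illustration, not the repository's code: it assumes a boolean mask of shape `[B, T, F]` (True for padded frames/bins), replaces the einops `rearrange` with an equivalent `reshape`/`transpose`, and uses a hypothetical helper name, `patch_padding_mask`. It crops with `:tp * patch_size`, which stays correct when `T` is already a multiple of the patch size, whereas a negative slice such as `[:-0]` returns an empty axis.

```python
import numpy as np


def patch_padding_mask(padding_mask, patch_size, flatten="freq"):
    """Hypothetical helper: reduce a frame-level mask [B, T, F] to a patch-level mask [B, L].

    Sketch of the logic in ImageEncoder.convert_padding_mask, using NumPy
    reshape/transpose instead of einops. A patch counts as padded only if every
    frame/bin inside it is padded (the `.all(-1)` reduction).
    """
    B, T, F = padding_mask.shape
    # same patch grid as a Conv2d with kernel = stride = patch_size and no padding
    tp, fp = T // patch_size, F // patch_size
    # crop trailing frames/bins that do not fill a whole patch
    m = padding_mask[:, : tp * patch_size, : fp * patch_size]
    # [B, tp*p, fp*q] -> [B, tp, p, fp, q] -> [B, tp, fp, p, q]
    m = m.reshape(B, tp, patch_size, fp, patch_size).transpose(0, 1, 3, 2, 4)
    # [B, tp, fp, p*q] -> [B, tp, fp]: True only for fully padded patches
    m = m.reshape(B, tp, fp, patch_size * patch_size).all(axis=-1)
    if flatten == "time":
        m = m.transpose(0, 2, 1)  # [B, fp, tp]: time index varies fastest
    return m.reshape(B, -1)  # 'freq': frequency index varies fastest


# Toy example: 2 clips of 40 frames x 32 bins, 16x16 patches -> a 2x2 patch grid.
mask = np.zeros((2, 40, 32), dtype=bool)
mask[1, 16:, :] = True  # second clip is padded from frame 16 onward
print(patch_padding_mask(mask, 16).astype(int).tolist())                 # [[0, 0, 0, 0], [0, 0, 1, 1]]
print(patch_padding_mask(mask, 16, flatten="time").astype(int).tolist())  # [[0, 0, 0, 0], [0, 1, 0, 1]]

# A negative-index crop breaks when there is no remainder: [:-0] is an empty slice.
print(np.zeros((1, 32, 32), dtype=bool)[:, :-0].shape)  # (1, 0, 32)
```

The `'freq'` and `'time'` orders read the same `(tp, fp)` grid row by row or column by column, which is how `PatchEmbed_new.forward` orders its tokens for the two `flatten` settings.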