├── .gitignore ├── LICENSE ├── README.md ├── audio_mae.py ├── mae.PNG ├── modules.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Rishikesh (ऋषिकेश) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Masked Autoencoders that Listen
2 | This repo is an unofficial implementation of the paper [Masked Autoencoders that Listen](https://arxiv.org/abs/2207.06405). Audio-MAE first encodes audio spectrogram patches with a high masking ratio, feeding only the non-masked tokens through the encoder layers. The decoder then re-orders and decodes the encoded context padded with mask tokens in order to reconstruct the input spectrogram.
3 | ![](mae.PNG)
4 | 
5 | * Most of the code is borrowed from the repos listed in the Reference section below.
6 | 
7 | ## Usage:
8 | ```python
9 | import torch
10 | from functools import partial
11 | from audio_mae import AudioMaskedAutoencoderViT
12 | 
13 | audio_mels = torch.ones([2, 1, 1024, 128])
14 | 
15 | # Paper-recommended architecture
16 | model = AudioMaskedAutoencoderViT(
17 |     num_mels=128, mel_len=1024, in_chans=1,
18 |     patch_size=16, embed_dim=768, encoder_depth=12, num_heads=12,
19 |     decoder_embed_dim=512, decoder_depth=16, decoder_num_heads=16,
20 |     mlp_ratio=4, norm_layer=partial(torch.nn.LayerNorm, eps=1e-6))
21 | 
22 | loss, pred, mask = model(audio_mels)
23 | ```
24 | 
25 | ## Citation:
26 | ```
27 | @misc{https://doi.org/10.48550/arxiv.2207.06405,
28 | doi = {10.48550/ARXIV.2207.06405},
29 | 
30 | url = {https://arxiv.org/abs/2207.06405},
31 | 
32 | author = {Huang, Po-Yao and Xu, Hu and Li, Juncheng and Baevski, Alexei and Auli, Michael and Galuba, Wojciech and Metze, Florian and Feichtenhofer, Christoph},
33 | 
34 | keywords = {Sound (cs.SD), Artificial Intelligence (cs.AI), Machine Learning (cs.LG), Audio and Speech Processing (eess.AS), FOS: Computer and information sciences, FOS: Electrical engineering, electronic engineering, information engineering},
35 | 
36 | title = {Masked Autoencoders that Listen},
37 | 
38 | publisher = {arXiv},
39 | 
40 | year = {2022},
41 | 
42 | copyright = {Creative Commons Attribution 4.0 International}
43 | }
44 | ```
45 | ```
46 | @misc{https://doi.org/10.48550/arxiv.2203.16691,
47 | doi = {10.48550/ARXIV.2203.16691},
48 | 
49 | url = {https://arxiv.org/abs/2203.16691},
50 | 
51 | author = {Baade, Alan and Peng, Puyuan and Harwath, David},
52 | 
53 | keywords = {Audio and Speech Processing (eess.AS), Artificial Intelligence (cs.AI), Computation and Language (cs.CL), Machine Learning (cs.LG), Sound (cs.SD), FOS: Electrical engineering, electronic engineering, information engineering, FOS: Computer and information sciences},
54 | 
55 | title = {MAE-AST: Masked Autoencoding Audio Spectrogram Transformer},
56 | 
57 | publisher = {arXiv},
58 | 
59 | year = {2022},
60 | 
61 | copyright = {Creative Commons Attribution 4.0 International}
62 | }
63 | ```
64 | 
65 | ## Reference:
66 | * [Masked
Autoencoders that Listen](https://arxiv.org/abs/2207.06405) 67 | * [MAE-AST: Masked Autoencoding Audio Spectrogram Transformer](https://arxiv.org/abs/2203.16691) 68 | * https://github.com/facebookresearch/mae 69 | * https://github.com/berniwal/swin-transformer-pytorch 70 | * https://github.com/microsoft/Swin-Transformer 71 | * https://github.com/rwightman/pytorch-image-models 72 | -------------------------------------------------------------------------------- /audio_mae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from einops import rearrange 4 | from modules import Attention, PreNorm, FeedForward, SwinBlock, PatchEmbed 5 | from utils import get_2d_sincos_pos_embed 6 | 7 | 8 | class Transformer(nn.Module): 9 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): 10 | super().__init__() 11 | self.layers = nn.ModuleList([]) 12 | for _ in range(depth): 13 | self.layers.append(nn.ModuleList([ 14 | PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), 15 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) 16 | ])) 17 | def forward(self, x): 18 | for attn, ff in self.layers: 19 | x = attn(x) + x 20 | x = ff(x) + x 21 | return x 22 | 23 | 24 | class AudioMaskedAutoencoderViT(nn.Module): 25 | """ Masked Autoencoder with VisionTransformer backbone 26 | """ 27 | 28 | def __init__(self, num_mels=128, mel_len=1024, patch_size=16, in_chans=3, 29 | embed_dim=768, encoder_depth=12, num_heads=12, 30 | decoder_embed_dim=512, decoder_depth=16, decoder_num_heads=16, 31 | mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False): 32 | super().__init__() 33 | 34 | # -------------------------------------------------------------------------- 35 | # MAE encoder specifics 36 | self.patch_embed = PatchEmbed((mel_len, num_mels), (patch_size, patch_size), in_chans, embed_dim) 37 | num_patches = self.patch_embed.num_patches 38 | self.grid_h = int(mel_len // patch_size) 39 | self.grid_w = int(num_mels // patch_size) 40 | 41 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 42 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), 43 | requires_grad=False) # fixed sin-cos embedding 44 | 45 | self.encoder = Transformer(embed_dim, encoder_depth, num_heads, embed_dim // num_heads, mlp_ratio * embed_dim) 46 | 47 | self.norm = norm_layer(embed_dim) 48 | # -------------------------------------------------------------------------- 49 | 50 | # -------------------------------------------------------------------------- 51 | # MAE decoder specifics 52 | self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True) 53 | 54 | self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) 55 | 56 | self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches, decoder_embed_dim), 57 | requires_grad=False) # fixed sin-cos embedding 58 | 59 | self.decoder_blocks = nn.ModuleList([ 60 | SwinBlock(decoder_embed_dim, decoder_num_heads, decoder_embed_dim // num_heads, 61 | mlp_ratio * decoder_embed_dim, 62 | shifted=True, window_size=4, relative_pos_embedding=True) 63 | for i in range(decoder_depth)]) 64 | 65 | self.decoder_norm = norm_layer(decoder_embed_dim) 66 | self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * in_chans, bias=True) # decoder to patch 67 | # -------------------------------------------------------------------------- 68 | 69 | self.norm_pix_loss = norm_pix_loss 70 | 71 | self.initialize_weights() 72 | 73 | def 
initialize_weights(self):
74 |         # initialization
75 |         # initialize (and freeze) pos_embed by sin-cos embedding
76 |         pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], (self.grid_h, self.grid_w), cls_token=True)
77 |         self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
78 | 
79 |         decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], (self.grid_h, self.grid_w),
80 |                                                     cls_token=False)
81 |         self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
82 | 
83 |         # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
84 |         w = self.patch_embed.proj.weight.data
85 |         torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
86 | 
87 |         # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
88 |         torch.nn.init.normal_(self.cls_token, std=.02)
89 |         torch.nn.init.normal_(self.mask_token, std=.02)
90 | 
91 |         # initialize nn.Linear and nn.LayerNorm
92 |         self.apply(self._init_weights)
93 | 
94 |     def _init_weights(self, m):
95 |         if isinstance(m, nn.Linear):
96 |             # we use xavier_uniform following official JAX ViT:
97 |             torch.nn.init.xavier_uniform_(m.weight)
98 |             if isinstance(m, nn.Linear) and m.bias is not None:
99 |                 nn.init.constant_(m.bias, 0)
100 |         elif isinstance(m, nn.LayerNorm):
101 |             nn.init.constant_(m.bias, 0)
102 |             nn.init.constant_(m.weight, 1.0)
103 | 
104 |     def patchify(self, imgs):
105 |         """
106 |         imgs: (N, 1, H, W)
107 |         x: (N, L, patch_size**2 * 1)
108 |         """
109 |         p = self.patch_embed.patch_size[0]
110 | 
111 |         h = self.grid_h
112 |         w = self.grid_w
113 |         x = imgs.reshape(shape=(imgs.shape[0], 1, h, p, w, p))
114 |         x = torch.einsum('nchpwq->nhwpqc', x)
115 |         x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2))
116 |         return x
117 | 
118 |     def unpatchify(self, x):
119 |         """
120 |         x: (N, L, patch_size**2 * 1)
121 |         imgs: (N, 1, H, W)
122 |         """
123 |         p = self.patch_embed.patch_size[0]
124 |         h = self.grid_h
125 |         w = self.grid_w
126 |         assert h * w == x.shape[1]
127 |         x = x.reshape(shape=(x.shape[0], h, w, p, p, 1))
128 |         x = torch.einsum('nhwpqc->nchpwq', x)
129 |         imgs = x.reshape(shape=(x.shape[0], 1, h * p, w * p))
130 |         return imgs
131 | 
132 |     def random_masking(self, x, mask_ratio):
133 |         """
134 |         Perform per-sample random masking by per-sample shuffling.
135 |         Per-sample shuffling is done by argsort random noise.
136 | x: [N, L, D], sequence 137 | """ 138 | N, L, D = x.shape # batch, length, dim 139 | len_keep = int(L * (1 - mask_ratio)) 140 | 141 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1] 142 | 143 | # sort noise for each sample 144 | ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove 145 | ids_restore = torch.argsort(ids_shuffle, dim=1) 146 | 147 | # keep the first subset 148 | ids_keep = ids_shuffle[:, :len_keep] 149 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) 150 | 151 | # generate the binary mask: 0 is keep, 1 is remove 152 | mask = torch.ones([N, L], device=x.device) 153 | mask[:, :len_keep] = 0 154 | # unshuffle to get the binary mask 155 | mask = torch.gather(mask, dim=1, index=ids_restore) 156 | 157 | return x_masked, mask, ids_restore 158 | 159 | def forward_encoder(self, x, mask_ratio): 160 | # embed patches 161 | x = self.patch_embed(x) 162 | 163 | # add pos embed w/o cls token 164 | x = x + self.pos_embed[:, 1:, :] 165 | 166 | # masking: length -> length * mask_ratio 167 | x, mask, ids_restore = self.random_masking(x, mask_ratio) 168 | 169 | # append cls token 170 | cls_token = self.cls_token + self.pos_embed[:, :1, :] 171 | cls_tokens = cls_token.expand(x.shape[0], -1, -1) 172 | x = torch.cat((cls_tokens, x), dim=1) 173 | 174 | # apply Transformer blocks 175 | x = self.encoder(x) 176 | x = self.norm(x) 177 | 178 | return x, mask, ids_restore 179 | 180 | def forward_decoder(self, x, ids_restore): 181 | 182 | # embed tokens 183 | x = self.decoder_embed(x[:, 1:, :]) 184 | 185 | # append mask tokens to sequence 186 | 187 | mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] - x.shape[1], 1) 188 | x_ = torch.cat([x, mask_tokens], dim=1) # no cls token 189 | x = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle 190 | 191 | b, l, c = x.shape 192 | 193 | assert l == self.grid_h * self.grid_w, "input feature has wrong size" 194 | 195 | # add pos embed 196 | x = x + self.decoder_pos_embed 197 | x = x.view(b, self.grid_h, self.grid_w, c) 198 | # apply Transformer blocks 199 | for blk in self.decoder_blocks: 200 | x = blk(x) 201 | 202 | x = rearrange(x, 'b h w c -> b (h w) c') 203 | x = self.decoder_norm(x) 204 | 205 | # predictor projection 206 | x = self.decoder_pred(x) 207 | 208 | # remove cls token 209 | # x = x[:, 1:, :] 210 | 211 | return x 212 | 213 | def forward_loss(self, imgs, pred, mask): 214 | """ 215 | imgs: [N, 3, H, W] -> [2, 1, 1024, 128] 216 | pred: [N, L, p*p*1] 217 | mask: [N, L], 0 is keep, 1 is remove, 218 | """ 219 | target = self.patchify(imgs) 220 | if self.norm_pix_loss: 221 | mean = target.mean(dim=-1, keepdim=True) 222 | var = target.var(dim=-1, keepdim=True) 223 | target = (target - mean) / (var + 1.e-6) ** .5 224 | 225 | loss = (pred - target) ** 2 226 | loss = loss.mean(dim=-1) # [N, L], mean loss per patch 227 | 228 | loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches 229 | return loss 230 | 231 | def forward(self, imgs, mask_ratio=0.8): 232 | latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio) 233 | pred = self.forward_decoder(latent, ids_restore) # [N, L, p*p*1] 234 | loss = self.forward_loss(imgs, pred, mask) 235 | return loss, pred, mask -------------------------------------------------------------------------------- /mae.PNG: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/rishikksh20/AudioMAE-pytorch/d46d8a0ecb1e32e05cda523ad3c48e86b3121b71/mae.PNG -------------------------------------------------------------------------------- /modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | import numpy as np 4 | from einops import rearrange 5 | 6 | 7 | 8 | class PatchEmbed(nn.Module): 9 | """ Image to Patch Embedding 10 | """ 11 | def __init__(self, img_size=(1024, 128), patch_size=(16, 16), in_chans=1, embed_dim=768): 12 | super().__init__() 13 | 14 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 15 | self.img_size = img_size 16 | self.patch_size = patch_size 17 | self.num_patches = num_patches 18 | 19 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 20 | 21 | def forward(self, x): 22 | B, C, H, W = x.shape 23 | assert H == self.img_size[0] and W == self.img_size[1], \ 24 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 25 | x = self.proj(x).flatten(2).transpose(1, 2) 26 | return x 27 | 28 | 29 | 30 | class CyclicShift(nn.Module): 31 | def __init__(self, displacement): 32 | super().__init__() 33 | self.displacement = displacement 34 | 35 | def forward(self, x): 36 | return torch.roll(x, shifts=(self.displacement, self.displacement), dims=(1, 2)) 37 | 38 | 39 | class Residual(nn.Module): 40 | def __init__(self, fn): 41 | super().__init__() 42 | self.fn = fn 43 | def forward(self, x, **kwargs): 44 | return self.fn(x, **kwargs) + x 45 | 46 | class PreNorm(nn.Module): 47 | def __init__(self, dim, fn): 48 | super().__init__() 49 | self.norm = nn.LayerNorm(dim) 50 | self.fn = fn 51 | def forward(self, x, **kwargs): 52 | return self.fn(self.norm(x), **kwargs) 53 | 54 | class FeedForward(nn.Module): 55 | def __init__(self, dim, hidden_dim, dropout = 0.): 56 | super().__init__() 57 | self.net = nn.Sequential( 58 | nn.Linear(dim, hidden_dim), 59 | nn.GELU(), 60 | nn.Dropout(dropout), 61 | nn.Linear(hidden_dim, dim), 62 | nn.Dropout(dropout) 63 | ) 64 | def forward(self, x): 65 | return self.net(x) 66 | 67 | class Attention(nn.Module): 68 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 69 | super().__init__() 70 | inner_dim = dim_head * heads 71 | project_out = not (heads == 1 and dim_head == dim) 72 | 73 | self.heads = heads 74 | self.scale = dim_head ** -0.5 75 | 76 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 77 | 78 | self.to_out = nn.Sequential( 79 | nn.Linear(inner_dim, dim), 80 | nn.Dropout(dropout) 81 | ) if project_out else nn.Identity() 82 | 83 | def forward(self, x): 84 | b, n, _, h = *x.shape, self.heads 85 | qkv = self.to_qkv(x).chunk(3, dim = -1) 86 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) 87 | 88 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 89 | 90 | attn = dots.softmax(dim=-1) 91 | 92 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 93 | out = rearrange(out, 'b h n d -> b n (h d)') 94 | out = self.to_out(out) 95 | return out 96 | 97 | 98 | 99 | def create_mask(window_size, displacement, upper_lower, left_right): 100 | mask = torch.zeros(window_size ** 2, window_size ** 2) 101 | 102 | if upper_lower: 103 | mask[-displacement * window_size:, :-displacement * window_size] = float('-inf') 104 | mask[:-displacement * window_size, -displacement * window_size:] = float('-inf') 105 | 106 | if left_right: 107 | mask = rearrange(mask, 
'(h1 w1) (h2 w2) -> h1 w1 h2 w2', h1=window_size, h2=window_size) 108 | mask[:, -displacement:, :, :-displacement] = float('-inf') 109 | mask[:, :-displacement, :, -displacement:] = float('-inf') 110 | mask = rearrange(mask, 'h1 w1 h2 w2 -> (h1 w1) (h2 w2)') 111 | 112 | return mask 113 | 114 | 115 | def get_relative_distances(window_size): 116 | indices = torch.tensor(np.array([[x, y] for x in range(window_size) for y in range(window_size)])) 117 | distances = indices[None, :, :] - indices[:, None, :] 118 | return distances 119 | 120 | 121 | class WindowAttention(nn.Module): 122 | def __init__(self, dim, heads, head_dim, shifted, window_size, relative_pos_embedding): 123 | super().__init__() 124 | inner_dim = head_dim * heads 125 | 126 | self.heads = heads 127 | self.scale = head_dim ** -0.5 128 | self.window_size = window_size 129 | self.relative_pos_embedding = relative_pos_embedding 130 | self.shifted = shifted 131 | 132 | if self.shifted: 133 | displacement = window_size // 2 134 | self.cyclic_shift = CyclicShift(-displacement) 135 | self.cyclic_back_shift = CyclicShift(displacement) 136 | self.upper_lower_mask = nn.Parameter(create_mask(window_size=window_size, displacement=displacement, 137 | upper_lower=True, left_right=False), requires_grad=False) 138 | self.left_right_mask = nn.Parameter(create_mask(window_size=window_size, displacement=displacement, 139 | upper_lower=False, left_right=True), requires_grad=False) 140 | 141 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) 142 | 143 | if self.relative_pos_embedding: 144 | self.relative_indices = get_relative_distances(window_size) + window_size - 1 145 | self.pos_embedding = nn.Parameter(torch.randn(2 * window_size - 1, 2 * window_size - 1)) 146 | else: 147 | self.pos_embedding = nn.Parameter(torch.randn(window_size ** 2, window_size ** 2)) 148 | 149 | self.to_out = nn.Linear(inner_dim, dim) 150 | 151 | def forward(self, x): 152 | if self.shifted: 153 | x = self.cyclic_shift(x) 154 | 155 | b, n_h, n_w, _, h = *x.shape, self.heads 156 | 157 | qkv = self.to_qkv(x).chunk(3, dim=-1) 158 | nw_h = n_h // self.window_size 159 | nw_w = n_w // self.window_size 160 | 161 | q, k, v = map( 162 | lambda t: rearrange(t, 'b (nw_h w_h) (nw_w w_w) (h d) -> b h (nw_h nw_w) (w_h w_w) d', 163 | h=h, w_h=self.window_size, w_w=self.window_size), qkv) 164 | 165 | dots = einsum('b h w i d, b h w j d -> b h w i j', q, k) * self.scale 166 | 167 | if self.relative_pos_embedding: 168 | dots += self.pos_embedding[self.relative_indices[:, :, 0], self.relative_indices[:, :, 1]] 169 | else: 170 | dots += self.pos_embedding 171 | 172 | if self.shifted: 173 | dots[:, :, -nw_w:] += self.upper_lower_mask 174 | dots[:, :, nw_w - 1::nw_w] += self.left_right_mask 175 | 176 | attn = dots.softmax(dim=-1) 177 | 178 | out = einsum('b h w i j, b h w j d -> b h w i d', attn, v) 179 | out = rearrange(out, 'b h (nw_h nw_w) (w_h w_w) d -> b (nw_h w_h) (nw_w w_w) (h d)', 180 | h=h, w_h=self.window_size, w_w=self.window_size, nw_h=nw_h, nw_w=nw_w) 181 | out = self.to_out(out) 182 | 183 | if self.shifted: 184 | out = self.cyclic_back_shift(out) 185 | return out 186 | 187 | 188 | class SwinBlock(nn.Module): 189 | def __init__(self, dim, heads, head_dim, mlp_dim, shifted, window_size, relative_pos_embedding): 190 | super().__init__() 191 | self.attention_block = Residual(PreNorm(dim, WindowAttention(dim=dim, 192 | heads=heads, 193 | head_dim=head_dim, 194 | shifted=shifted, 195 | window_size=window_size, 196 | relative_pos_embedding=relative_pos_embedding))) 197 | 
self.mlp_block = Residual(PreNorm(dim, FeedForward(dim=dim, hidden_dim=mlp_dim)))
198 | 
199 |     def forward(self, x):
200 |         x = self.attention_block(x)
201 |         x = self.mlp_block(x)
202 |         return x
203 | 
204 | 
205 | 
206 | 
207 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | 
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # Position embedding utils
8 | # --------------------------------------------------------
9 | 
10 | import numpy as np
11 | 
12 | import torch
13 | 
14 | # --------------------------------------------------------
15 | # 2D sine-cosine position embedding
16 | # References:
17 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
18 | # MoCo v3: https://github.com/facebookresearch/moco-v3
19 | # --------------------------------------------------------
20 | def get_2d_sincos_pos_embed(embed_dim, grid_sizes, cls_token=False):
21 |     """
22 |     grid_sizes: (grid height, grid width)
23 |     return:
24 |     pos_embed: [grid_h*grid_w, embed_dim] or [1+grid_h*grid_w, embed_dim] (w/ or w/o cls_token)
25 |     """
26 |     grid_h = np.arange(grid_sizes[0], dtype=np.float32)
27 |     grid_w = np.arange(grid_sizes[1], dtype=np.float32)
28 |     grid = np.meshgrid(grid_w, grid_h)  # here w goes first
29 |     grid = np.stack(grid, axis=0)
30 | 
31 |     grid = grid.reshape([2, 1, grid_sizes[0], grid_sizes[1]])
32 |     pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
33 |     if cls_token:
34 |         pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
35 |     return pos_embed
36 | 
37 | 
38 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
39 |     assert embed_dim % 2 == 0
40 | 
41 |     # use half of dimensions to encode grid_h
42 |     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
43 |     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
44 | 
45 |     emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
46 |     return emb
47 | 
48 | 
49 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
50 |     """
51 |     embed_dim: output dimension for each position
52 |     pos: a list of positions to be encoded: size (M,)
53 |     out: (M, D)
54 |     """
55 |     assert embed_dim % 2 == 0
56 |     omega = np.arange(embed_dim // 2, dtype=np.float64)  # np.float was removed in NumPy >= 1.24
57 |     omega /= embed_dim / 2.
58 |     omega = 1.
/ 10000**omega # (D/2,) 59 | 60 | pos = pos.reshape(-1) # (M,) 61 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 62 | 63 | emb_sin = np.sin(out) # (M, D/2) 64 | emb_cos = np.cos(out) # (M, D/2) 65 | 66 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 67 | return emb 68 | 69 | 70 | # -------------------------------------------------------- 71 | # Interpolate position embeddings for high-resolution 72 | # References: 73 | # DeiT: https://github.com/facebookresearch/deit 74 | # -------------------------------------------------------- 75 | def interpolate_pos_embed(model, checkpoint_model): 76 | if 'pos_embed' in checkpoint_model: 77 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 78 | embedding_size = pos_embed_checkpoint.shape[-1] 79 | num_patches = model.patch_embed.num_patches 80 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 81 | # height (== width) for the checkpoint position embedding 82 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 83 | # height (== width) for the new position embedding 84 | new_size = int(num_patches ** 0.5) 85 | # class_token and dist_token are kept unchanged 86 | if orig_size != new_size: 87 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 88 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 89 | # only the position tokens are interpolated 90 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 91 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 92 | pos_tokens = torch.nn.functional.interpolate( 93 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 94 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 95 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 96 | checkpoint_model['pos_embed'] = new_pos_embed 97 | --------------------------------------------------------------------------------
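
The README's usage snippet feeds a dummy tensor of shape `[2, 1, 1024, 128]` into the model. The sketch below shows one way such a log-mel input could be prepared from a raw waveform with torchaudio; the sample rate, FFT size, hop length and log-offset are illustrative assumptions, not values taken from this repo or the paper.

```python
# Input-preparation sketch (assumption: torchaudio is installed; the mel
# parameters below are illustrative, not prescribed by this repo).
import torch
import torchaudio
from functools import partial

from audio_mae import AudioMaskedAutoencoderViT

# Fake 10-second mono waveform at 16 kHz; replace with torchaudio.load(...) output.
sample_rate = 16_000
waveform = torch.randn(1, sample_rate * 10)

# 128 mel bins to match num_mels=128; hop_length chosen so roughly 1000 frames come out.
mel = torchaudio.transforms.MelSpectrogram(
    sample_rate=sample_rate, n_fft=1024, hop_length=160, n_mels=128)(waveform)
log_mel = torch.log(mel + 1e-6)                 # (1, 128, T)
log_mel = log_mel.transpose(1, 2).unsqueeze(1)  # (1, 1, T, 128): (B, C, mel_len, num_mels)

# Pad or crop the time axis to the fixed mel_len=1024 expected by PatchEmbed.
T = log_mel.shape[2]
if T < 1024:
    log_mel = torch.nn.functional.pad(log_mel, (0, 0, 0, 1024 - T))
else:
    log_mel = log_mel[:, :, :1024, :]

model = AudioMaskedAutoencoderViT(
    num_mels=128, mel_len=1024, in_chans=1,
    patch_size=16, embed_dim=768, encoder_depth=12, num_heads=12,
    decoder_embed_dim=512, decoder_depth=16, decoder_num_heads=16,
    mlp_ratio=4, norm_layer=partial(torch.nn.LayerNorm, eps=1e-6))

loss, pred, mask = model(log_mel)  # default mask_ratio=0.8
```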
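
Because `forward()` already returns the masked-reconstruction loss, a pre-training step reduces to the usual backward/step pattern. This is a schematic loop; the `AdamW` hyper-parameters and the stand-in data loader are assumptions for illustration, and `model` is the `AudioMaskedAutoencoderViT` built as in the usage snippet.

```python
# Schematic pre-training loop (assumes `model` from the usage snippet above).
import torch

train_loader = [torch.randn(2, 1, 1024, 128) for _ in range(4)]  # stand-in for a real DataLoader
optimizer = torch.optim.AdamW(model.parameters(), lr=1.5e-4, weight_decay=0.05)  # illustrative settings

model.train()
for epoch in range(2):
    for audio_mels in train_loader:
        loss, pred, mask = model(audio_mels, mask_ratio=0.8)  # loss is the MSE on masked patches
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```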
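
For downstream use, the encoder can be run without masking through `forward_encoder(x, mask_ratio=0.0)`, which keeps every patch token. The pooling choices below (CLS token or mean over patch tokens) are illustrative assumptions, not a recipe from the paper; `model` is again the instance from the usage snippet.

```python
# Feature-extraction sketch: with mask_ratio=0.0, random_masking keeps all patches,
# so the encoder returns tokens of shape (B, 1 + num_patches, embed_dim).
import torch

audio_mels = torch.ones([2, 1, 1024, 128])

model.eval()
with torch.no_grad():
    tokens, _, _ = model.forward_encoder(audio_mels, mask_ratio=0.0)

cls_embedding = tokens[:, 0]                  # (B, 768) CLS token embedding
clip_embedding = tokens[:, 1:].mean(dim=1)    # (B, 768) mean-pooled patch tokens
```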