├── .gitignore ├── ACKNOWLEDGEMENTS.md ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── assets ├── device_measurements.png └── latency.png ├── export.py ├── tests.py └── vision_transformers ├── __init__.py ├── attention_utils.py ├── mbconv.py └── model.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .ruff_cache 3 | .pytest_cache 4 | exported_model/ -------------------------------------------------------------------------------- /ACKNOWLEDGEMENTS.md: -------------------------------------------------------------------------------- 1 | Acknowledgements 2 | Portions of this Software may reference the following copyrighted 3 | material, the use of which is hereby acknowledged. 4 | 5 | _____________________ 6 | Swin Transformer: https://github.com/microsoft/Swin-Transformer 7 | 8 | MOAT: https://github.com/google-research/deeplab2/blob/main/model/pixel_encoder/moat.py -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 
50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4, 71 | available at [https://www.contributor-covenant.org/version/1/4/code-of-conduct.html](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html) -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contribution Guide 2 | 3 | Thanks for your interest in contributing. This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository. 4 | 5 | While we welcome new pull requests and issues please note that our response may be limited. Forks and out-of-tree improvements are strongly encouraged. 6 | 7 | ## Before you get started 8 | 9 | By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE). 10 | 11 | We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md). -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (C) 2023 Apple Inc. All Rights Reserved. 2 | 3 | IMPORTANT: This Apple software is supplied to you by Apple 4 | Inc. ("Apple") in consideration of your agreement to the following 5 | terms, and your use, installation, modification or redistribution of 6 | this Apple software constitutes acceptance of these terms. If you do 7 | not agree with these terms, please do not use, install, modify or 8 | redistribute this Apple software. 
 9 | 
10 | In consideration of your agreement to abide by the following terms, and
11 | subject to these terms, Apple grants you a personal, non-exclusive
12 | license, under Apple's copyrights in this original Apple software (the
13 | "Apple Software"), to use, reproduce, modify and redistribute the Apple
14 | Software, with or without modifications, in source and/or binary forms;
15 | provided that if you redistribute the Apple Software in its entirety and
16 | without modifications, you must retain this notice and the following
17 | text and disclaimers in all such redistributions of the Apple Software.
18 | Neither the name, trademarks, service marks or logos of Apple Inc. may
19 | be used to endorse or promote products derived from the Apple Software
20 | without specific prior written permission from Apple. Except as
21 | expressly stated in this notice, no other rights or licenses, express or
22 | implied, are granted by Apple herein, including but not limited to any
23 | patent rights that may be infringed by your derivative works or by other
24 | works in which the Apple Software may be incorporated.
25 | 
26 | The Apple Software is provided by Apple on an "AS IS" basis. APPLE
27 | MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION
28 | THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS
29 | FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND
30 | OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS.
31 | 
32 | IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL
33 | OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
34 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
35 | INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION,
36 | MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED
37 | AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE),
38 | STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE
39 | POSSIBILITY OF SUCH DAMAGE.
40 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Efficient Vision Transformers on Apple Neural Engine
2 | 
3 | This software project accompanies the Apple ML research article [Deploying Attention-Based Vision Transformers to Apple Neural Engine](https://machinelearning.apple.com/research/vision-transformers).
4 | 
5 | This project implements attention-based vision transformers efficiently on the Apple Neural Engine (ANE). We release the efficient attention module, utility functions such as window_partition / window_reverse, an example hybrid CNN-Transformer architecture based on [MOAT](https://arxiv.org/abs/2210.01820), and an export function that converts the model to a Core ML program (mlprogram).
6 | 
7 | Please refer to our research article for detailed explanations of the optimizations on the partition/reverse tensor ops, the position embedding design, the attention mechanism, and split_softmax.
8 | 
9 | Below is a latency comparison between different models. Our optimized MOAT is several times faster on ANE than the [3rd party implementation](https://github.com/RooKichenn/pytorch-MOAT), and also much faster than the optimized DeiT/16 (tiny).
10 | 
11 | 
12 | 
13 | 
14 | 
15 | ## Getting Started
16 | Install dependencies:
17 | ```
18 | pip install torch coremltools pytest timm
19 | ```
20 | ## Usage
21 | To use the attention module:
22 | ```python
23 | import torch
24 | from vision_transformers.attention_utils import (
25 |     WindowAttention,
26 |     PEType,
27 |     window_partition,
28 |     window_reverse,
29 | )
30 | 
31 | H, W, C = 16, 16, 32
32 | num_heads = 2
33 | # window based attention
34 | window_size = (8, 8)
35 | x = torch.rand((1, H, W, C))
36 | window_attention = WindowAttention(
37 |     dim=C,
38 |     window_size=window_size,
39 |     num_heads=num_heads,
40 |     split_head=True,
41 |     pe_type=PEType.SINGLE_HEAD_RPE,
42 | )
43 | windows = window_partition(x, window_size)
44 | windows_reshape = windows.reshape((-1, window_size[0] * window_size[1], C))
45 | attn_windows = window_attention(windows_reshape)
46 | output = window_reverse(attn_windows, window_size, H, W)
47 | 
48 | # global attention, window size will be the full res of the feature map
49 | global_attention = WindowAttention(
50 |     dim=C,
51 |     window_size=(H, W),
52 |     num_heads=num_heads,
53 |     split_head=True,
54 |     pe_type=PEType.SINGLE_HEAD_RPE,
55 | )
56 | global_reshape = x.reshape(1, H * W, C)
57 | global_output = global_attention(global_reshape)
58 | ```
59 | To construct the MOAT architecture:
60 | ```python
61 | from vision_transformers.model import _build_model
62 | 
63 | image_height, image_width = 256, 256
64 | model_config, model = _build_model(
65 |     shape=(1, 3, image_height, image_width),
66 |     base_arch="tiny-moat-0",
67 |     attention_mode="global",
68 |     output_stride=32,
69 | )
70 | ```
71 | To export an ML model that runs on ANE:
72 | ```
73 | $ python export.py
74 | ```
75 | To verify performance, developers can launch Xcode and simply add this model package file as a resource in their projects. After clicking on the Performance tab, the developer can generate a performance report on locally available devices. The figure below shows a performance report generated for this model on a list of iPhone devices.
76 | 
77 | 
78 | To customize model hyperparameters and profile the exported model, refer to the section **Model Export Walk-Through** in the research article.
79 | 
80 | ## Unit Tests
81 | We provide unit tests to ensure that the model building and export paths function correctly; they can also serve as usage examples.
82 | To run the unit tests:
83 | ```
84 | pytest tests.py
85 | ```
86 | 
--------------------------------------------------------------------------------
/assets/device_measurements.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-vision-transformers-ane/c42abada32e96686f4d0bb37acb79ef7d99602ed/assets/device_measurements.png
--------------------------------------------------------------------------------
/assets/latency.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-vision-transformers-ane/c42abada32e96686f4d0bb37acb79ef7d99602ed/assets/latency.png
--------------------------------------------------------------------------------
/export.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
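# Export script: builds tiny-MOAT variants, traces them with torch.jit.trace, and converts
# the traced graph into a Core ML mlprogram package saved under ./exported_model/.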
4 | # 5 | import torch 6 | import coremltools as ct 7 | 8 | from vision_transformers.attention_utils import ( 9 | PEType, 10 | ) 11 | from vision_transformers.model import _build_model 12 | 13 | 14 | def moat_export( 15 | base_arch="tiny-moat-0", 16 | shape=(1, 3, 256, 256), 17 | pe_type=PEType.LePE_ADD, 18 | attention_mode="local", 19 | ): 20 | """ 21 | 22 | :param base_arch: (Default value = "tiny-moat-0") 23 | :param shape: (Default value = (1) 24 | :param 3: 25 | :param 256: 26 | :param 256): 27 | :param pe_type: (Default value = PEType.LePE_ADD) 28 | :param attention_mode: (Default value = "local") 29 | 30 | """ 31 | split_head = True 32 | batch = shape[0] 33 | print("****** batch_size: ", batch) 34 | pe_type = pe_type if "moat" in base_arch else "ape" 35 | attention_mode = attention_mode if "moat" in base_arch else "global" 36 | local_window_size = [8, 8] if attention_mode == "local" else None 37 | print("****** building model: ", base_arch) 38 | if "tiny-moat" in base_arch: 39 | _, model = _build_model( 40 | base_arch=base_arch, 41 | shape=shape, 42 | split_head=split_head, 43 | pe_type=pe_type, 44 | channel_buffer_align=False, 45 | attention_mode=attention_mode, 46 | local_window_size=local_window_size, 47 | ) 48 | resolution = f"{shape[-2]}x{shape[-1]}" 49 | 50 | x = torch.rand(shape) 51 | 52 | with torch.no_grad(): 53 | model.eval() 54 | traced_optimized_model = torch.jit.trace(model, (x,)) 55 | ane_mlpackage_obj = ct.convert( 56 | traced_optimized_model, 57 | convert_to="mlprogram", 58 | inputs=[ 59 | ct.TensorType("x", shape=x.shape), 60 | ], 61 | ) 62 | 63 | out_name = f"{base_arch}_{attention_mode}Attn_batch{batch}_{resolution}_{pe_type}_split-head_{split_head}" 64 | out_path = f"./exported_model/{out_name}.mlpackage" 65 | ane_mlpackage_obj.save(out_path) 66 | 67 | import shutil 68 | 69 | shutil.make_archive(f"{out_path}", "zip", out_path) 70 | 71 | 72 | if __name__ == "__main__": 73 | base_arch = "tiny-moat-0" 74 | attention_mode = ["global", "local"] 75 | pe_type = PEType.SINGLE_HEAD_RPE 76 | shapes = [[1, 3, 512, 512], [1, 3, 256, 256]] 77 | bs = [1] 78 | for att_mode in attention_mode: 79 | for shape in shapes: 80 | for batch in bs: 81 | shape[0] = batch 82 | moat_export( 83 | base_arch, 84 | shape, 85 | pe_type=pe_type, 86 | attention_mode=att_mode, 87 | ) 88 | -------------------------------------------------------------------------------- /tests.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from vision_transformers.model import _build_model, get_stage_strides 3 | import torch 4 | from export import moat_export 5 | import os 6 | 7 | 8 | @pytest.mark.parametrize("image_shape", [(512, 512), (256, 256)]) 9 | @pytest.mark.parametrize("output_stride", [8, 16, 32]) 10 | @pytest.mark.parametrize("attention_mode", ["global", "local"]) 11 | @pytest.mark.parametrize( 12 | "base_arch", 13 | [ 14 | "tiny-moat-0", 15 | "tiny-moat-1", 16 | "tiny-moat-2", 17 | ], 18 | ) 19 | def test_model(output_stride, base_arch, image_shape, attention_mode): 20 | """MOAT unit test 21 | 22 | :param output_stride: param base_arch: 23 | :param image_shape: param attention_mode: 24 | :param base_arch: 25 | :param attention_mode: 26 | 27 | """ 28 | image_height, image_width = image_shape 29 | with torch.no_grad(): 30 | model_config, model = _build_model( 31 | shape=(1, 3, image_height, image_width), 32 | base_arch=base_arch, 33 | attention_mode=attention_mode, 34 | output_stride=output_stride, 35 | ) 36 | stage_stride = 
get_stage_strides(output_stride) 37 | inputs = torch.zeros((1, 3, image_height, image_width), device="cpu") 38 | output = model(inputs) 39 | assert len(output) == 4 40 | output_h, output_w = image_height // 2, image_width // 2 41 | for stage_idx, stride in enumerate(stage_stride): 42 | assert len(output[stage_idx].shape) == 4 43 | output_h, output_w = output_h // stride, output_w // stride 44 | assert output_h == output[stage_idx].shape[-2] 45 | assert output_w == output[stage_idx].shape[-1] 46 | 47 | 48 | def test_export(): 49 | """ """ 50 | moat_export( 51 | base_arch="tiny-moat-0", 52 | shape=(1, 3, 256, 256), 53 | attention_mode="local", 54 | ) 55 | assert os.path.exists( 56 | "exported_model/tiny-moat-0_localAttn_batch1_256x256_PEType.LePE_ADD_split-head_True.mlpackage" 57 | ) 58 | -------------------------------------------------------------------------------- /vision_transformers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-vision-transformers-ane/c42abada32e96686f4d0bb37acb79ef7d99602ed/vision_transformers/__init__.py -------------------------------------------------------------------------------- /vision_transformers/attention_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | from enum import Enum, unique 6 | import logging 7 | from typing import Optional, Sequence 8 | 9 | import numpy as np 10 | 11 | import torch 12 | from torch import nn 13 | from timm.models.layers import trunc_normal_ 14 | 15 | logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") 16 | logger = logging.getLogger(__name__) 17 | logger.setLevel(logging.INFO) 18 | 19 | """ 20 | Reference: 21 | [1] Swin Transformer: https://arxiv.org/abs/2103.14030 22 | [2] Swin Github: https://github.com/microsoft/Swin-Transformer 23 | [3] Local enhanced position embedding: https://arxiv.org/abs/2107.00652 24 | """ 25 | 26 | 27 | def window_partition(x: torch.Tensor, window_size: Sequence[int]): 28 | """Partition image feature map into small windows, in an ANE friendly manner (w/o resorting to 6D tensors). 29 | 30 | :param x: feature map to be partitioned, (batch_size, H, W, C) 31 | :param window_size: target window_size, (win_h, win_w) 32 | :param x: torch.Tensor: 33 | :param window_size: Sequence[int]: 34 | :returns: (batch_size * num_windows, H, W, C). 35 | :rtype: Partitioned feature map windows 36 | 37 | """ 38 | B, H, W, C = x.shape 39 | # example partition process: 1, 12, 16, 160 -> 1, 2, 6, 16, 160 -> 2, 6, 16, 160 -> 2, 6, 2, 8, 160 -> ... 
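# (cont., window_size=(6, 8)): permute -> 2, 2, 6, 8, 160 -> reshape -> 4, 6, 8, 160,
# i.e. (B * num_windows, win_h, win_w, C). Keeping every intermediate tensor at 5D or
# below (instead of the single 6D view + permute used in the reference Swin code) is
# what makes this partition ANE friendly.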
40 | x = x.reshape( 41 | (B, H // window_size[0], window_size[0], W, C) 42 | ) # B, H//w_size, w_size, W, C 43 | x = x.reshape( 44 | (B * H // window_size[0], window_size[0], W, C) 45 | ) # B * H // w_size, w_size, W, C 46 | x = x.reshape( 47 | ( 48 | B * H // window_size[0], 49 | window_size[0], 50 | W // window_size[1], 51 | window_size[1], 52 | -1, 53 | ) 54 | ) 55 | x = x.permute((0, 2, 1, 3, 4)) 56 | windows = x.reshape((-1, window_size[0], window_size[1], C)) 57 | return windows 58 | 59 | 60 | def window_reverse(windows: torch.Tensor, window_size: Sequence[int], H: int, W: int): 61 | """Merge partitioned windows back to feature map 62 | 63 | :param windows: (num_windows*batch_size, win_h, win_w, C) 64 | :param window_size: Window size 65 | :type window_size: int 66 | :param H: Height of image 67 | :type H: int 68 | :param W: Width of image 69 | :type W: int 70 | :param windows: torch.Tensor: 71 | :param window_size: Sequence[int]: 72 | :param H: int: 73 | :param W: int: 74 | :returns: (batch_size, H, W, C) 75 | :rtype: Feature maos 76 | 77 | """ 78 | B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1])) 79 | x = windows.reshape( 80 | ( 81 | B * H // window_size[0], 82 | W // window_size[1], 83 | window_size[0], 84 | window_size[1], 85 | -1, 86 | ) 87 | ) 88 | x = x.permute((0, 2, 1, 3, 4)).reshape( 89 | (B * H // window_size[0], window_size[0], W, -1) 90 | ) 91 | x = x.reshape((B, H // window_size[0], window_size[0], W, -1)) 92 | x = x.reshape((B, H, W, -1)) 93 | return x 94 | 95 | 96 | @unique 97 | class PEType(Enum): 98 | """ """ 99 | 100 | LePE_ADD = 0 101 | LePE_FUSED = 1 102 | RPE = 2 103 | SINGLE_HEAD_RPE = 3 104 | 105 | 106 | class WindowAttention(nn.Module): 107 | """Window/Global based multi-head self attention (MHSA) module 108 | 109 | Supports only non-shifting window attention as there is no native shifting support for ANE. 110 | Supports attention computation that is efficient on ANE by splitting on the softmax dimension. 111 | 112 | :param dim: Number of input channels. 113 | :param window_size: The height and width of the window. 114 | :param num_heads: Number of attention heads. 115 | :param qkv_bias: If True, add a learnable bias to query. 116 | :param qk_scale: Override default qk scale of head_dim ** -0.5 if set. 117 | :param attn_drop: Dropout ratio of attention weight. 118 | :param proj_drop: Dropout ratio of output. 119 | :param split_head: Whether to split head for softmax. 120 | split_softmax reduces latency significantly, therefore enabled by default. 121 | :param pe_type: position embedding type. 122 | 123 | """ 124 | 125 | def __init__( 126 | self, 127 | dim: int, 128 | window_size: Sequence[int], 129 | num_heads: int, 130 | qkv_bias: bool = True, 131 | qk_scale: Optional[float] = None, 132 | attn_drop: float = 0.0, 133 | proj_drop: float = 0.0, 134 | split_head: bool = True, 135 | pe_type: Enum = PEType.LePE_ADD, 136 | ): 137 | super().__init__() 138 | self.dim = dim 139 | self.window_size = window_size 140 | self.num_heads = num_heads 141 | head_dim = dim // num_heads 142 | 143 | self.scale = qk_scale or head_dim**-0.5 144 | 145 | self.split_head = split_head 146 | self.pe_type = pe_type 147 | 148 | if pe_type == PEType.RPE or pe_type == PEType.SINGLE_HEAD_RPE: 149 | # TODO: single-head RPE. 
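# Swin-style relative position bias: a learnable bias for every relative (dh, dw) offset,
# gathered via relative_position_index below. With SINGLE_HEAD_RPE a single-head bias
# table is shared across all attention heads, shrinking the table by a factor of num_heads.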
150 | self.rpe_num_heads = 1 if PEType.SINGLE_HEAD_RPE else num_heads 151 | logger.info(f"******Using RPE on {self.rpe_num_heads} heads.") 152 | shape = ( 153 | (2 * window_size[0] - 1), 154 | (2 * window_size[1] - 1), 155 | self.rpe_num_heads, 156 | ) 157 | 158 | self.relative_position_bias_table = nn.Parameter( 159 | torch.zeros(shape) 160 | ) # 2*Wh-1 * 2*Ww-1, nH 161 | trunc_normal_(self.relative_position_bias_table, std=0.02) 162 | 163 | # get pair-wise relative position index for each token inside the window 164 | coords_h = np.arange(self.window_size[0]) 165 | coords_w = np.arange(self.window_size[1]) 166 | 167 | mesh = np.meshgrid(coords_h, coords_w) 168 | # mesh grid returns transposed results compared w/ pytorch 169 | coords = np.stack((mesh[0].T, mesh[1].T)) # NOTE: 2, Wh, Ww 170 | coords_flatten = coords.reshape(2, -1) 171 | relative_coords = ( 172 | coords_flatten[:, :, None] - coords_flatten[:, None, :] 173 | ) # 2, Wh*Ww, Wh*Ww 174 | relative_coords = relative_coords.transpose((1, 2, 0)) # Wh*Ww, Wh*Ww, 2 175 | 176 | relative_coords[:, :, 0] += self.window_size[0] - 1 177 | relative_coords[:, :, 1] += self.window_size[1] - 1 178 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 179 | self.relative_position_index = np.sum(relative_coords, -1) # Wh*Ww, Wh*Ww 180 | bias_index_bound = shape[0] if len(shape) == 2 else shape[0] * shape[1] 181 | assert (self.relative_position_index >= 0).all() 182 | assert (self.relative_position_index < bias_index_bound).all() 183 | elif pe_type == PEType.LePE_ADD: 184 | logger.info("******Using LePE_ADD.") 185 | self.LePE_for_Value = nn.Conv2d( 186 | in_channels=dim, 187 | out_channels=dim, 188 | groups=dim, 189 | bias=qkv_bias, 190 | kernel_size=3, 191 | padding="same", 192 | ) 193 | self.abs_pe = nn.Parameter( 194 | torch.zeros(1, window_size[0] * window_size[1], dim) 195 | ) 196 | 197 | # Use separate conv1x1 projection to avoid L2 cache hit 198 | self.q_proj = nn.Conv2d( 199 | in_channels=dim, 200 | out_channels=dim, 201 | kernel_size=1, 202 | bias=qkv_bias, 203 | ) 204 | self.k_proj = nn.Conv2d( 205 | in_channels=dim, 206 | out_channels=dim, 207 | kernel_size=1, 208 | bias=qkv_bias, 209 | ) 210 | self.v_proj = nn.Conv2d( 211 | in_channels=dim, 212 | out_channels=dim, 213 | kernel_size=1, 214 | bias=qkv_bias, 215 | ) 216 | self.proj = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1) 217 | self.softmax = nn.Softmax(dim=1) # TODO: double check 218 | self.attn_drop = nn.Dropout(attn_drop) 219 | self.proj_drop = nn.Dropout(proj_drop) 220 | 221 | def forward(self, x: torch.Tensor): 222 | """ 223 | 224 | :param x: torch.Tensor: 225 | 226 | """ 227 | if self.pe_type == PEType.RPE or self.pe_type == PEType.SINGLE_HEAD_RPE: 228 | local_table = self.relative_position_bias_table.reshape( 229 | (-1, self.rpe_num_heads) 230 | ) 231 | elif self.pe_type == PEType.LePE_ADD: 232 | x += self.abs_pe 233 | 234 | BW, N, C = x.shape # BW=num_windows*B=64*1 235 | assert ( 236 | N == self.window_size[0] * self.window_size[1] 237 | ), "N: {}, num_windows: {}".format(N, self.window_size[0] * self.window_size[1]) 238 | image_shape = (BW, C, self.window_size[0], self.window_size[1]) 239 | x_2d = x.permute((0, 2, 1)).reshape(image_shape) # BCHW 240 | x_flat = torch.unsqueeze(x.permute((0, 2, 1)), 2) # BC1L 241 | 242 | q, k, v_2d = self.q_proj(x_flat), self.k_proj(x_flat), self.v_proj(x_2d) 243 | if self.pe_type == PEType.LePE_ADD: 244 | LePE = self.LePE_for_Value(v_2d).reshape(x_flat.shape) 245 | mh_LePE = torch.split(LePE, self.dim // self.num_heads, 
dim=1) 246 | mh_q = torch.split(q, self.dim // self.num_heads, dim=1) # BC1L 247 | mh_v = torch.split( 248 | v_2d.reshape(x_flat.shape), self.dim // self.num_heads, dim=1 249 | ) 250 | # BL1C, transposeThenSplit is more efficient than the other way around 251 | mh_k = torch.split( 252 | torch.permute(k, (0, 3, 2, 1)), self.dim // self.num_heads, dim=3 253 | ) 254 | 255 | # attn weights in each head. 256 | attn_weights = [ 257 | torch.einsum("bchq, bkhc->bkhq", qi, ki) * self.scale 258 | for qi, ki in zip(mh_q, mh_k) 259 | ] 260 | 261 | # add RPE bias 262 | if self.pe_type == PEType.RPE or self.pe_type == PEType.SINGLE_HEAD_RPE: 263 | relative_position_bias = local_table[ 264 | self.relative_position_index.reshape((-1,)) 265 | ].reshape( 266 | ( 267 | self.window_size[0] * self.window_size[1], 268 | self.window_size[0] * self.window_size[1], 269 | -1, 270 | ) 271 | ) # Wh*Ww, Wh*Ww, nH 272 | relative_position_bias = torch.unsqueeze( 273 | relative_position_bias.permute((2, 0, 1)), 2 274 | ) # nH, Wh*Ww, 1, Wh*Ww 275 | relative_position_bias = torch.split(relative_position_bias, 1, dim=0) 276 | 277 | # split_softmax 278 | for head_idx in range(self.num_heads): 279 | rpe_idx = head_idx if self.pe_type == PEType.RPE else 0 280 | attn_weights[head_idx] = ( 281 | attn_weights[head_idx] + relative_position_bias[rpe_idx] 282 | ) 283 | 284 | attn_weights = [ 285 | self.softmax(aw) for aw in attn_weights 286 | ] # softmax applied on channel "C" 287 | mh_w = [self.attn_drop(aw) for aw in attn_weights] 288 | 289 | # compute attn@v 290 | mh_x = [torch.einsum("bkhq,bchk->bchq", wi, vi) for wi, vi in zip(mh_w, mh_v)] 291 | if self.pe_type == PEType.LePE_ADD: 292 | mh_x = [v + pe for v, pe in zip(mh_x, mh_LePE)] 293 | # concat heads 294 | x = torch.cat(mh_x, dim=1) 295 | 296 | x = self.proj(x) 297 | x = self.proj_drop(x) 298 | x = torch.squeeze(x, dim=2) 299 | x = x.permute((0, 2, 1)) # BLC 300 | return x 301 | -------------------------------------------------------------------------------- /vision_transformers/mbconv.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 
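# MBConv block: optional pre-norm -> 1x1 expansion -> 3x3 depthwise conv -> optional
# squeeze-and-excitation -> 1x1 projection, with an (optionally pooled/projected) shortcut.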
4 | # 5 | """PyTorch MBConv block for MOAT.""" 6 | import torch 7 | from torch import nn 8 | from torch.nn import Conv2d 9 | from typing import Optional 10 | 11 | 12 | class Swish(nn.Module): 13 | """ """ 14 | 15 | def forward(self, x): 16 | """ 17 | 18 | :param x: 19 | 20 | """ 21 | return x * torch.sigmoid(x) 22 | 23 | 24 | class MBConvBlock(nn.Module): 25 | """Mobile Inverted Residual Bottleneck Block 26 | References: 27 | [1] https://arxiv.org/abs/1704.04861 (MobileNet v1) 28 | [2] https://arxiv.org/abs/1801.04381 (MobileNet v2) 29 | [3] https://arxiv.org/abs/1905.02244 (MobileNet v3) 30 | 31 | :param block_args: BlockArgs, see above 32 | :type block_args: namedtuple 33 | :param global_params: GlobalParam, see above 34 | :type global_params: namedtuple 35 | :param name: Block name 36 | :type name: string 37 | 38 | """ 39 | 40 | def __init__( 41 | self, 42 | block_args, 43 | batch_norm_momentum: Optional[float] = 0.99, 44 | batch_norm_epsilon: Optional[float] = 1e-3, 45 | drop_rate: Optional[float] = None, 46 | pre_norm: bool = False, 47 | name: str = "_block_", 48 | activation: str = "swish", 49 | ): 50 | super(MBConvBlock, self).__init__() 51 | self.name = name 52 | self._block_args = block_args 53 | self.block_activation = activation 54 | # in torch.batchnorm, (1-momentum)*running_mean + momentum * x_new 55 | self._bn_mom = 1 - batch_norm_momentum 56 | self._bn_eps = batch_norm_epsilon 57 | self.has_se = (self._block_args.se_ratio is not None) and ( 58 | 0 < self._block_args.se_ratio <= 1 59 | ) 60 | self.drop_rate = drop_rate 61 | self.id_skip = block_args.id_skip # skip connection and drop connect 62 | self.pre_norm = pre_norm 63 | 64 | if self.pre_norm: 65 | self.pre_norm_layer = nn.BatchNorm2d( 66 | num_features=self._block_args.input_filters 67 | ) 68 | 69 | # Expansion phase 70 | inp = self._block_args.input_filters # number of input channels 71 | oup = ( 72 | self._block_args.input_filters * self._block_args.expand_ratio 73 | ) # number of output channels 74 | if self._block_args.expand_ratio != 1: 75 | self._expand_conv = Conv2d( 76 | in_channels=inp, out_channels=oup, kernel_size=1, bias=False 77 | ) 78 | self._bn0 = nn.BatchNorm2d( 79 | num_features=oup, momentum=self._bn_mom, eps=self._bn_eps 80 | ) 81 | 82 | # Depthwise convolution phase 83 | k = self._block_args.kernel_size 84 | s = self._block_args.stride 85 | self._depthwise_conv = Conv2d( 86 | in_channels=oup, 87 | out_channels=oup, 88 | groups=oup, # groups makes it depthwise 89 | kernel_size=k, 90 | stride=s, 91 | padding="same" if s == 1 else 1, 92 | bias=False, 93 | ) 94 | self._bn1 = nn.BatchNorm2d( 95 | num_features=oup, momentum=self._bn_mom, eps=self._bn_eps 96 | ) 97 | 98 | # Squeeze and Excitation layer, if desired 99 | if self.has_se: 100 | num_squeezed_channels = max( 101 | 1, int(self._block_args.input_filters * self._block_args.se_ratio) 102 | ) 103 | self._se_reduce = Conv2d( 104 | in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1 105 | ) 106 | self._se_expand = Conv2d( 107 | in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1 108 | ) 109 | 110 | # Output phase 111 | final_oup = self._block_args.output_filters 112 | self._project_conv = Conv2d( 113 | in_channels=oup, 114 | out_channels=final_oup, 115 | kernel_size=1, 116 | bias=False, 117 | padding="same", 118 | ) 119 | self._bn2 = nn.BatchNorm2d( 120 | num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps 121 | ) 122 | 123 | if self.block_activation == "swish": 124 | self._swish = Swish() 125 | elif 
self.block_activation == "relu": 126 | self._swish = nn.ReLU(inplace=True) 127 | elif self.block_activation == "gelu": 128 | self._swish = nn.GELU() 129 | else: 130 | raise ValueError("Unsupported activation in MBConv block.") 131 | 132 | if self.drop_rate is not None: 133 | self.dropout = nn.Dropout(self.drop_rate) 134 | 135 | if block_args.stride == 2: 136 | self.shortcut_pool = nn.AvgPool2d( 137 | kernel_size=2, 138 | stride=2, 139 | ) 140 | self.shortcut_conv = None 141 | if block_args.input_filters != block_args.output_filters: 142 | self.shortcut_conv = Conv2d( 143 | in_channels=block_args.input_filters, 144 | out_channels=block_args.output_filters, 145 | kernel_size=1, 146 | stride=1, 147 | padding="same", 148 | bias=True, 149 | ) 150 | 151 | def forward(self, inputs): 152 | """param inputs: input tensor (batch_size, C, H, W) 153 | param drop_connect_rate: drop connect rate (float, between 0 and 1) 154 | 155 | :param inputs: 156 | 157 | """ 158 | 159 | shortcut = inputs 160 | x = inputs 161 | 162 | if self.pre_norm: 163 | x = self.pre_norm_layer(x) 164 | # Expansion and Depthwise Convolution 165 | if self._block_args.expand_ratio != 1: 166 | x = self._swish(self._bn0(self._expand_conv(x))) 167 | x = self._swish(self._bn1(self._depthwise_conv(x))) 168 | 169 | # Squeeze and Excitation 170 | if self.has_se: 171 | x_squeezed = nn.AdaptiveAvgPool2d(output_size=(1, 1))(x) 172 | x_squeezed = self._se_expand(self._swish(self._se_reduce(x_squeezed))) 173 | x = torch.sigmoid(x_squeezed) * x 174 | 175 | x = self._bn2(self._project_conv(x)) 176 | 177 | # Skip connection and drop connect 178 | input_filters, output_filters = ( 179 | self._block_args.input_filters, 180 | self._block_args.output_filters, 181 | ) 182 | if self.id_skip: 183 | if self._block_args.stride == 1 and input_filters == output_filters: 184 | if self.drop_rate: 185 | x = self.dropout(x) 186 | elif self._block_args.stride == 2: 187 | shortcut = self.shortcut_pool(inputs) 188 | if self.shortcut_conv is not None: 189 | shortcut = self.shortcut_conv(shortcut) 190 | elif ( 191 | self._block_args.stride == 1 192 | or self._block_args.stride == [1, 1] 193 | and input_filters != output_filters 194 | ): 195 | if self.shortcut_conv is not None: 196 | shortcut = self.shortcut_conv(shortcut) 197 | x = torch.add(x, shortcut) # skip connection 198 | return x 199 | -------------------------------------------------------------------------------- /vision_transformers/model.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | """ 6 | Reference: 7 | [1] MOAT: https://arxiv.org/pdf/2210.01820.pdf. 
8 | [2] Tensorflow official impl: https://github.com/google-research/deeplab2/blob/main/model/pixel_encoder/moat.py 9 | """ 10 | import logging 11 | import math 12 | from typing import Optional, Sequence, Any, Tuple 13 | 14 | import collections 15 | from vision_transformers.attention_utils import ( 16 | WindowAttention, 17 | PEType, 18 | window_partition, 19 | window_reverse, 20 | ) 21 | from vision_transformers.mbconv import MBConvBlock 22 | import torch 23 | from torch import nn 24 | from torch.nn import GELU 25 | from dataclasses import dataclass 26 | 27 | logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") 28 | logger = logging.getLogger(__name__) 29 | logger.setLevel(logging.INFO) 30 | 31 | BlockArgs = collections.namedtuple( 32 | "BlockArgs", 33 | [ 34 | "kernel_size", 35 | "num_repeat", 36 | "input_filters", 37 | "output_filters", 38 | "expand_ratio", 39 | "id_skip", 40 | "stride", 41 | "se_ratio", 42 | ], 43 | ) 44 | 45 | # Change namedtuple defaults 46 | BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields) 47 | 48 | 49 | @dataclass 50 | class MOATConfig: 51 | """MOAT config. Default values are from tiny_moat_0. 52 | 53 | For detailed model hyperparameter configuration, refer to the MOAT paper. 54 | Supports two attention modes of operation, which you may experiment with for tradeoff between PnP and KPI: 55 | 1. For local attention, window size is limited to a fixed window size (up to user configuration depends on input resolution). 56 | 2. For global attention, do attention on the full feature map. 57 | 58 | :param stem_size: hidden size of the conv stem. 59 | :param block_type: type of each stage. 60 | :param num_blocks: number of blocks for each stage. 61 | :param hidden_size: hidden size of each stage. 62 | :param window_size: window_size of each stage if using attention block. 63 | :param activation: activation function to use. 64 | :param attention_mode: use global or local attention. 65 | :param split_head: whether do split head attention. split_head makes it run more efficiently on ANE. 66 | 67 | """ 68 | 69 | stem_size: Sequence[int] = (32, 32) 70 | block_type: Sequence[str] = ("mbconv", "mbconv", "moat", "moat") 71 | num_blocks: Sequence[int] = (2, 3, 7, 2) 72 | hidden_size: Sequence[int] = (32, 64, 128, 256) 73 | window_size: Sequence[Any] = (None, None, (14, 14), (7, 7)) 74 | activation: nn.Module = GELU() 75 | attention_mode: str = "global" 76 | split_head: bool = True 77 | stage_stride: Sequence[int] = (2, 2, 2, 2) 78 | mbconv_block_expand_ratio: int = 4 79 | moat_block_expand_ratio: int = 4 80 | pe_type: PEType = PEType.LePE_ADD 81 | 82 | def __post_init__(self): 83 | if self.attention_mode == "local": 84 | # window_size should be limited to local context 85 | local_context_lower, local_context_upper = 6, 16 86 | for window in self.window_size: 87 | if window is not None: 88 | assert isinstance(window, tuple) or isinstance(window, list) 89 | for hw in window: 90 | assert hw >= local_context_lower and hw <= local_context_upper 91 | 92 | 93 | @dataclass 94 | class MOATBlockConfig: 95 | """MOAT block config. 96 | 97 | :param block_name: name of the block. 98 | :param window_size: attention window size. 99 | :param attn_norm_class: normalization layer in attention. 100 | :param activation: activation function. 101 | :param head_dim: dimension of each head. 102 | :param kernel_size: kernel size for MBConv block 103 | :param stride: stride for MBConv block 104 | :param expand_ratio: expansion ratio in the MBConv block. 
105 | :param id_skip: do skip connection or not. 106 | :param se_ratio: channel reduction ratio in squeeze and excitation, if 0 or None, no SE. 107 | :param attention_mode: use global or local attention. 108 | :param split_head: whether do split head attention. split_head makes it run more efficiently on ANE. 109 | :param pe_type: position embedding type 110 | 111 | """ 112 | 113 | block_name: str = "moat_block" 114 | window_size: Optional[Sequence[int]] = None 115 | attn_norm_class: nn.Module = nn.LayerNorm 116 | head_dim: int = 32 # dim of each head 117 | activation: nn.Module = GELU() 118 | # BlockArgs 119 | kernel_size: int = 3 120 | stride: int = 1 121 | input_filters: int = 32 122 | output_filters: int = 32 123 | expand_ratio: int = 4 124 | id_skip: bool = True 125 | se_ratio: Optional[float] = None 126 | attention_mode: str = "global" 127 | split_head: bool = False 128 | pe_type: PEType = PEType.LePE_ADD 129 | 130 | 131 | class Stem(nn.Sequential): 132 | """Convolutional stem consists of 2 convolution layers. 133 | 134 | :param dims: specifies the dimensions for the convolution stems. 135 | 136 | """ 137 | 138 | def __init__(self, dims: Sequence[int]): 139 | stem_layers = [] 140 | 141 | for i in range(len(dims)): 142 | norm_layer = None 143 | activation_layer = None 144 | 145 | if i == 0: 146 | activation_layer = GELU() 147 | norm_layer = True 148 | 149 | stride = 2 if i == 0 else 1 150 | in_channels = dims[i - 1] if i >= 1 else 3 151 | conv_layer = nn.Conv2d( 152 | in_channels=in_channels, 153 | out_channels=dims[i], 154 | kernel_size=3, 155 | bias=True, 156 | stride=stride, 157 | padding="same" 158 | if stride == 1 159 | else 1, # strided conv does not support "same" 160 | ) 161 | stem_layers.append(conv_layer) 162 | if activation_layer is not None: 163 | stem_layers.append(activation_layer) 164 | if norm_layer: 165 | stem_layers.append(nn.BatchNorm2d(dims[i])) 166 | 167 | super().__init__(*stem_layers) 168 | 169 | 170 | class MOATBlock(nn.Module): 171 | """A MOAT block consists of MBConv (w/o squeeze-excitation blocks) and MHSA. 172 | 173 | :param config: a MOATBlockConfig object to specify block dims, attention mode, window size, etc. 174 | 175 | """ 176 | 177 | def __init__(self, config: MOATBlockConfig): 178 | super().__init__() 179 | block_args = BlockArgs( 180 | kernel_size=config.kernel_size, 181 | stride=config.stride, 182 | se_ratio=None, # MOAT block does not use SE branch 183 | input_filters=config.input_filters, 184 | output_filters=config.output_filters, 185 | id_skip=True, 186 | expand_ratio=config.expand_ratio, 187 | ) 188 | self._mbconv = MBConvBlock( 189 | block_args, 190 | activation="gelu", 191 | pre_norm=True, 192 | ) 193 | 194 | # dim after MBConv block 195 | dim = config.output_filters 196 | 197 | # currently LN apply normalization to the last few dimensions, therefore need NHWC format. 
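# (the forward pass permutes NCHW -> NHWC before this LayerNorm and back to NCHW afterwards)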
198 | # see pytorch issue 71456: https://github.com/pytorch/pytorch/issues/71465 199 | self._attn_norm = config.attn_norm_class( 200 | normalized_shape=dim, 201 | eps=1e-5, 202 | elementwise_affine=True, 203 | ) 204 | assert ( 205 | dim % config.head_dim == 0 206 | ), "tensor dimension: {} can not divide by head_dim: {}.".format( 207 | dim, config.head_dim 208 | ) 209 | num_heads = dim // config.head_dim 210 | print("######pe_type in MOATBlock: ", config.pe_type) 211 | self._window_attention = WindowAttention( 212 | dim, 213 | window_size=config.window_size, 214 | num_heads=num_heads, 215 | split_head=config.split_head, 216 | pe_type=config.pe_type, 217 | ) 218 | self.window_size = config.window_size 219 | self.attention_mode = config.attention_mode 220 | 221 | def forward(self, inputs): 222 | """inputs: (batch_size, C, H, W) 223 | output: ((batch_size, C, H//stride, W//stride) 224 | 225 | :param inputs: 226 | 227 | """ 228 | 229 | # MBConv block may contain downsampling layer 230 | output = self._mbconv(inputs) 231 | N, C, H, W = output.shape 232 | 233 | # shortcut is before LN in MOAT 234 | shortcut = output 235 | # transpose to prepare the tensor for window_partition 236 | output = output.permute((0, 2, 3, 1)) # NHWC 237 | 238 | assert ( 239 | output.shape[-1] % 32 == 0 240 | ), "ANE buffer not aligned, last dim={}.".format(output.shape[-1]) 241 | output = self._attn_norm(output) 242 | 243 | if self.attention_mode == "local": 244 | x_windows = window_partition(output, self.window_size) 245 | x_windows = x_windows.reshape( 246 | (-1, self.window_size[0] * self.window_size[1], C) 247 | ) 248 | attn_windows = self._window_attention(x_windows) 249 | output = window_reverse(attn_windows, self.window_size, H, W) 250 | # No need for window_partion/reverse on low res input 251 | elif self.attention_mode == "global": 252 | global_attention_windows = output.reshape((N, H * W, C)) 253 | output = self._window_attention(global_attention_windows) 254 | 255 | output = output.reshape((N, H, W, C)).permute((0, 3, 1, 2)) # NCHW 256 | 257 | # may add drop_path here for output 258 | output = shortcut + output 259 | return output # NCHW 260 | 261 | 262 | class MOAT(nn.Module): 263 | """MOAT model definition. 264 | 265 | :param config: a MOATConfig object to specify MOAT variant, attention mode, etc. 
266 | 267 | """ 268 | 269 | def __init__(self, config: MOATConfig): 270 | super().__init__() 271 | self._stem = Stem(dims=config.stem_size) 272 | # Need to use ModuleList (instead of vanilla python list) for the module to be properly registered 273 | self._blocks = nn.ModuleList() 274 | self.config = config 275 | for stage_id in range(len(config.block_type)): 276 | stage_blocks = nn.ModuleList() 277 | stage_input_filters = ( 278 | config.hidden_size[stage_id - 1] 279 | if stage_id > 0 280 | else config.stem_size[-1] 281 | ) 282 | stage_output_filters = config.hidden_size[stage_id] 283 | 284 | for local_block_id in range(config.num_blocks[stage_id]): 285 | block_stride = 1 286 | block_name = "block_{:0>2d}_{:0>2d}_".format(stage_id, local_block_id) 287 | 288 | if local_block_id == 0: # downsample in the first block of each stage 289 | block_stride = config.stage_stride[stage_id] 290 | block_input_filters = stage_input_filters 291 | else: 292 | block_input_filters = stage_output_filters 293 | 294 | if config.block_type[stage_id] == "mbconv": 295 | block_args = BlockArgs( 296 | kernel_size=3, 297 | stride=block_stride, 298 | se_ratio=0.25, # SE block reduction ratio 299 | input_filters=block_input_filters, 300 | output_filters=stage_output_filters, 301 | expand_ratio=config.mbconv_block_expand_ratio, 302 | id_skip=True, 303 | ) 304 | block = MBConvBlock( 305 | block_args, 306 | activation="gelu", 307 | pre_norm=True, 308 | ) 309 | elif config.block_type[stage_id] == "moat": 310 | print("######pe_type: ", config.pe_type) 311 | block_config = MOATBlockConfig( 312 | block_name=block_name, 313 | stride=block_stride, 314 | window_size=config.window_size[stage_id], 315 | input_filters=block_input_filters, 316 | output_filters=stage_output_filters, 317 | attention_mode=config.attention_mode, 318 | split_head=config.split_head, 319 | expand_ratio=config.moat_block_expand_ratio, 320 | pe_type=config.pe_type, 321 | ) 322 | block = MOATBlock(block_config) 323 | else: 324 | raise ValueError( 325 | "Network type {} not defined.".format(config.block_type) 326 | ) 327 | 328 | stage_blocks.append(block) 329 | 330 | self._blocks.append(stage_blocks) 331 | 332 | def forward(self, inputs: torch.Tensor, out_indices: Sequence[int] = (0, 1, 2, 3)): 333 | """ 334 | 335 | :param inputs: torch.Tensor: 336 | :param out_indices: Sequence[int]: (Default value = (0) 337 | :param 1: param 2: 338 | :param 3: 339 | :param inputs: torch.Tensor: 340 | :param out_indices: Sequence[int]: (Default value = (0) 341 | :param 2: 342 | :param 3): 343 | 344 | """ 345 | outs = [] 346 | output = self._stem(inputs) 347 | 348 | for stage_id, stage_blocks in enumerate(self._blocks): 349 | for block in stage_blocks: 350 | output = block(output) 351 | if stage_id in out_indices: 352 | outs.append(output) 353 | return outs 354 | 355 | 356 | def get_stage_strides(output_stride): 357 | """ 358 | 359 | :param output_stride: 360 | 361 | """ 362 | if output_stride == 32: 363 | stage_stride = (2, 2, 2, 2) 364 | elif output_stride == 16: 365 | stage_stride = (2, 2, 2, 1) 366 | elif output_stride == 8: 367 | stage_stride = (2, 2, 1, 1) 368 | return stage_stride 369 | 370 | 371 | def _build_model( 372 | shape: Sequence[int] = (1, 3, 192, 256), 373 | base_arch: str = "tiny-moat-2", 374 | attention_mode: str = "global", 375 | split_head: bool = True, 376 | output_stride: int = 32, 377 | channel_buffer_align: bool = True, 378 | num_blocks: Sequence[int] = (2, 3, 7, 2), 379 | mbconv_block_expand_ratio: int = 4, 380 | moat_block_expand_ratio: int = 4, 381 
|     local_window_size: Optional[Sequence[int]] = None,
382 |     pe_type: PEType = PEType.LePE_ADD,
383 | ) -> Tuple[MOATConfig, MOAT]:
384 |     """Construct MOAT models.
385 | 
386 |     :param shape: input shape to the model.
387 |     :param base_arch: architecture variant of MOAT.
388 |     :param attention_mode: global or local (window based) attention.
389 |     :param output_stride: stride of the output with respect to the input resolution, e.g., 32 means the output will be 1/32 of the input resolution.
390 |     :param split_head: whether to do split_head attention. split_head is enabled by default as it is faster on ANE.
391 |     :param channel_buffer_align: if True, make channel sizes divisible by 32.
392 |     :param num_blocks: number of blocks in each stage.
393 |     :param mbconv_block_expand_ratio: expansion ratio of mbconv blocks in the first 2 stages.
394 |     :param moat_block_expand_ratio: expansion ratio of moat blocks in the last 2 stages.
395 |     :param local_window_size: local window size of attention; only used when attention_mode is "local".
396 |     :param pe_type: position embedding type.
424 |     :returns: a (MOATConfig, MOAT) tuple built according to the arguments.
425 | 
426 |     """
427 |     assert shape[-2] % 32 == 0
428 |     assert shape[-1] % 32 == 0
429 | 
430 |     if attention_mode == "global" and local_window_size is not None:
431 |         raise RuntimeError(
432 |             "local_window_size is only valid for local attention and should not be set when attention_mode is 'global'."
433 | ) 434 | 435 | if output_stride == 32: 436 | out_stride_stage3, out_stride_stage4 = 16, 32 437 | else: 438 | out_stride_stage3, out_stride_stage4 = output_stride, output_stride 439 | 440 | stage_stride = get_stage_strides(output_stride) 441 | 442 | feature_hw = [shape[-2] // output_stride, shape[-1] // output_stride] 443 | 444 | def _get_default_local_window_size(feature_hw): 445 | """ 446 | 447 | :param feature_hw: 448 | 449 | """ 450 | window_hw = [] 451 | attention_field_candidates = [6, 8, 10] 452 | for h_or_w in feature_hw: 453 | if h_or_w % attention_field_candidates[0] == 0: 454 | window_hw.append(attention_field_candidates[0]) 455 | elif h_or_w % attention_field_candidates[1] == 0: 456 | window_hw.append(attention_field_candidates[1]) 457 | elif h_or_w % attention_field_candidates[2] == 0: 458 | window_hw.append(attention_field_candidates[2]) 459 | else: 460 | raise RuntimeError( 461 | f"Not a regular feature map size: {feature_hw}, consider other input resolution." 462 | ) 463 | return window_hw 464 | 465 | if attention_mode == "global": 466 | window_size = ( 467 | None, 468 | None, 469 | [shape[-2] // out_stride_stage3, shape[-1] // out_stride_stage3], 470 | [shape[-2] // out_stride_stage4, shape[-1] // out_stride_stage4], 471 | ) 472 | elif attention_mode == "local": 473 | if local_window_size is None: 474 | local_window_size = _get_default_local_window_size(feature_hw) 475 | window_size = (None, None, local_window_size, local_window_size) 476 | else: 477 | raise ValueError("Undefined attention mode.") 478 | 479 | if base_arch == "tiny-moat-0": 480 | tiny_moat_config = MOATConfig( 481 | num_blocks=num_blocks, 482 | window_size=window_size, 483 | attention_mode=attention_mode, 484 | split_head=split_head, 485 | stage_stride=stage_stride, 486 | mbconv_block_expand_ratio=mbconv_block_expand_ratio, 487 | moat_block_expand_ratio=moat_block_expand_ratio, 488 | pe_type=pe_type, 489 | ) 490 | elif base_arch == "tiny-moat-1": 491 | tiny_moat_config = MOATConfig( 492 | stem_size=(40, 40), 493 | hidden_size=(40, 80, 160, 320), 494 | window_size=window_size, 495 | attention_mode=attention_mode, 496 | num_blocks=num_blocks, 497 | split_head=split_head, 498 | stage_stride=stage_stride, 499 | mbconv_block_expand_ratio=mbconv_block_expand_ratio, 500 | moat_block_expand_ratio=moat_block_expand_ratio, 501 | pe_type=pe_type, 502 | ) 503 | elif base_arch == "tiny-moat-2": 504 | tiny_moat_config = MOATConfig( 505 | stem_size=(56, 56), 506 | hidden_size=(56, 112, 224, 448), 507 | window_size=window_size, 508 | num_blocks=num_blocks, 509 | attention_mode=attention_mode, 510 | split_head=split_head, 511 | stage_stride=stage_stride, 512 | mbconv_block_expand_ratio=mbconv_block_expand_ratio, 513 | moat_block_expand_ratio=moat_block_expand_ratio, 514 | pe_type=pe_type, 515 | ) 516 | 517 | if channel_buffer_align: 518 | aligned_hidden_size = [ 519 | math.ceil(h / 32) * 32 for h in tiny_moat_config.hidden_size 520 | ] 521 | aligned_stem_size = [math.ceil(h / 32) * 32 for h in tiny_moat_config.stem_size] 522 | tiny_moat_config.hidden_size = aligned_hidden_size 523 | tiny_moat_config.stem_size = aligned_stem_size 524 | 525 | logger.info("Using config: %s", tiny_moat_config) 526 | tiny_moat = MOAT(tiny_moat_config) 527 | 528 | return tiny_moat_config, tiny_moat 529 | 530 | 531 | if __name__ == "__main__": 532 | config, model = _build_model() 533 | --------------------------------------------------------------------------------