├── .gitignore
├── ACKNOWLEDGEMENTS.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── assets
│   ├── device_measurements.png
│   └── latency.png
├── export.py
├── tests.py
└── vision_transformers
    ├── __init__.py
    ├── attention_utils.py
    ├── mbconv.py
    └── model.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | .ruff_cache
3 | .pytest_cache
4 | exported_model/
--------------------------------------------------------------------------------
/ACKNOWLEDGEMENTS.md:
--------------------------------------------------------------------------------
1 | Acknowledgements
2 | Portions of this Software may reference the following copyrighted
3 | material, the use of which is hereby acknowledged.
4 |
5 | _____________________
6 | Swin Transformer: https://github.com/microsoft/Swin-Transformer
7 |
8 | MOAT: https://github.com/google-research/deeplab2/blob/main/model/pixel_encoder/moat.py
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | In the interest of fostering an open and welcoming environment, we as
6 | contributors and maintainers pledge to making participation in our project and
7 | our community a harassment-free experience for everyone, regardless of age, body
8 | size, disability, ethnicity, sex characteristics, gender identity and expression,
9 | level of experience, education, socio-economic status, nationality, personal
10 | appearance, race, religion, or sexual identity and orientation.
11 |
12 | ## Our Standards
13 |
14 | Examples of behavior that contributes to creating a positive environment
15 | include:
16 |
17 | * Using welcoming and inclusive language
18 | * Being respectful of differing viewpoints and experiences
19 | * Gracefully accepting constructive criticism
20 | * Focusing on what is best for the community
21 | * Showing empathy towards other community members
22 |
23 | Examples of unacceptable behavior by participants include:
24 |
25 | * The use of sexualized language or imagery and unwelcome sexual attention or
26 | advances
27 | * Trolling, insulting/derogatory comments, and personal or political attacks
28 | * Public or private harassment
29 | * Publishing others' private information, such as a physical or electronic
30 | address, without explicit permission
31 | * Other conduct which could reasonably be considered inappropriate in a
32 | professional setting
33 |
34 | ## Our Responsibilities
35 |
36 | Project maintainers are responsible for clarifying the standards of acceptable
37 | behavior and are expected to take appropriate and fair corrective action in
38 | response to any instances of unacceptable behavior.
39 |
40 | Project maintainers have the right and responsibility to remove, edit, or
41 | reject comments, commits, code, wiki edits, issues, and other contributions
42 | that are not aligned to this Code of Conduct, or to ban temporarily or
43 | permanently any contributor for other behaviors that they deem inappropriate,
44 | threatening, offensive, or harmful.
45 |
46 | ## Scope
47 |
48 | This Code of Conduct applies within all project spaces, and it also applies when
49 | an individual is representing the project or its community in public spaces.
50 | Examples of representing a project or community include using an official
51 | project e-mail address, posting via an official social media account, or acting
52 | as an appointed representative at an online or offline event. Representation of
53 | a project may be further defined and clarified by project maintainers.
54 |
55 | ## Enforcement
56 |
57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
58 | reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All
59 | complaints will be reviewed and investigated and will result in a response that
60 | is deemed necessary and appropriate to the circumstances. The project team is
61 | obligated to maintain confidentiality with regard to the reporter of an incident.
62 | Further details of specific enforcement policies may be posted separately.
63 |
64 | Project maintainers who do not follow or enforce the Code of Conduct in good
65 | faith may face temporary or permanent repercussions as determined by other
66 | members of the project's leadership.
67 |
68 | ## Attribution
69 |
70 | This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4,
71 | available at [https://www.contributor-covenant.org/version/1/4/code-of-conduct.html](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html)
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contribution Guide
2 |
3 | Thanks for your interest in contributing. This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository.
4 |
5 | While we welcome new pull requests and issues, please note that our response may be limited. Forks and out-of-tree improvements are strongly encouraged.
6 |
7 | ## Before you get started
8 |
9 | By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE).
10 |
11 | We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md).
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (C) 2023 Apple Inc. All Rights Reserved.
2 |
3 | IMPORTANT: This Apple software is supplied to you by Apple
4 | Inc. ("Apple") in consideration of your agreement to the following
5 | terms, and your use, installation, modification or redistribution of
6 | this Apple software constitutes acceptance of these terms. If you do
7 | not agree with these terms, please do not use, install, modify or
8 | redistribute this Apple software.
9 |
10 | In consideration of your agreement to abide by the following terms, and
11 | subject to these terms, Apple grants you a personal, non-exclusive
12 | license, under Apple's copyrights in this original Apple software (the
13 | "Apple Software"), to use, reproduce, modify and redistribute the Apple
14 | Software, with or without modifications, in source and/or binary forms;
15 | provided that if you redistribute the Apple Software in its entirety and
16 | without modifications, you must retain this notice and the following
17 | text and disclaimers in all such redistributions of the Apple Software.
18 | Neither the name, trademarks, service marks or logos of Apple Inc. may
19 | be used to endorse or promote products derived from the Apple Software
20 | without specific prior written permission from Apple. Except as
21 | expressly stated in this notice, no other rights or licenses, express or
22 | implied, are granted by Apple herein, including but not limited to any
23 | patent rights that may be infringed by your derivative works or by other
24 | works in which the Apple Software may be incorporated.
25 |
26 | The Apple Software is provided by Apple on an "AS IS" basis. APPLE
27 | MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION
28 | THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS
29 | FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND
30 | OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS.
31 |
32 | IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL
33 | OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
34 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
35 | INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION,
36 | MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED
37 | AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE),
38 | STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE
39 | POSSIBILITY OF SUCH DAMAGE.
40 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Efficient Vision Transformers on Apple Neural Engine
2 |
3 | This software project accompanies the Apple ML research article [Deploying Attention-Based Vision Transformers to Apple Neural Engine](https://machinelearning.apple.com/research/vision-transformers).
4 |
5 | This project implements attention-based vision transformers efficiently on the Apple Neural Engine (ANE). We release the efficient attention module, utility functions such as window_partition / window_reverse, an example hybrid CNN-Transformer architecture based on [MOAT](https://arxiv.org/abs/2210.01820), and an export function that converts the model to an ML program.
6 |
7 | Please refer to our research article for detailed explanations of the optimizations to the partition/reverse tensor ops, the position embedding design, the attention mechanism, and split_softmax.
8 |
9 | Below is a latency comparison between different models. Our optimized MOAT is multiple times faster than the [3rd party implementation](https://github.com/RooKichenn/pytorch-MOAT) on the ANE, and also much faster than the optimized DeiT/16 (tiny).
10 |
11 |
12 | ![Latency comparison](assets/latency.png)
13 |
14 |
15 | ## Getting Started
16 | Install dependencies:
17 | ```
18 | pip install torch coremltools pytest timm
19 | ```
20 | ## Usage
21 | To use the attention module:
22 | ```python
23 | import torch
24 | from vision_transformers.attention_utils import (
25 | WindowAttention,
26 | PEType,
27 | window_partition,
28 | window_reverse,
29 | )
30 |
31 | H, W, C = 16, 16, 32
32 | num_heads = 2
33 | # window based attention
34 | window_size = (8, 8)
35 | x = torch.rand((1, H, W, C))
36 | window_attention = WindowAttention(
37 | dim=C,
38 | window_size=window_size,
39 | num_heads=num_heads,
40 | split_head=True,
41 | pe_type=PEType.SINGLE_HEAD_RPE,
42 | )
43 | windows = window_partition(x, window_size)
44 | windows_reshape = windows.reshape((-1, window_size[0] * window_size[1], C))
45 | attn_windows = window_attention(windows_reshape)
46 | output = window_reverse(attn_windows, window_size, H, W)
47 |
48 | # global attention: the window size is the full resolution of the feature map
49 | global_attention = WindowAttention(
50 | dim=C,
51 | window_size=(H, W),
52 | num_heads=num_heads,
53 | split_head=True,
54 | pe_type=PEType.SINGLE_HEAD_RPE,
55 | )
56 | global_reshape = x.reshape(1, H * W, C)
57 | global_output = global_attention(global_reshape)
58 | ```
59 | To construct MOAT architecture:
60 | ```python
61 | from vision_transformers.model import _build_model
62 |
63 | image_height, image_width = 256, 256
64 | model_config, model = _build_model(
65 | shape=(1, 3, image_height, image_width),
66 | base_arch="tiny-moat-0",
67 | attention_mode="global",
68 | output_stride=32,
69 | )
70 | ```
71 | To export an ML model that runs on the ANE:
72 | ```
73 | $ python export.py
74 | ```
75 | To verify performance, developers can launch Xcode and simply add this model package file as a resource in their projects. Under the Performance tab, they can generate a performance report on locally available devices. The figure below shows a performance report generated for this model on a list of iPhone devices.
76 |
77 | ![Performance report measured on devices](assets/device_measurements.png)
78 | To customize model hyperparameters and profile the exported model, refer to the section **Model Export Walk-Through** in the research article; a small customization example is also included at the end of this README.
79 |
80 | ## Unit Tests
81 | We provide unit tests to ensure that the model building and export paths function correctly; they can also serve as usage examples.
82 | To run the unit tests:
83 | ```
84 | pytest tests.py
85 | ```
86 |
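87 | ## Customizing the Model
88 | The hyperparameters exposed by `_build_model` in `vision_transformers/model.py` can be combined to build other variants. Below is a minimal sketch (the values shown are only an example; for local attention the window size must evenly divide the stage 3 and stage 4 feature maps):
89 | ```python
90 | from vision_transformers.attention_utils import PEType
91 | from vision_transformers.model import _build_model
92 |
93 | # tiny-moat-1 with 8x8 window (local) attention and single-head RPE
94 | model_config, model = _build_model(
95 |     shape=(1, 3, 256, 256),
96 |     base_arch="tiny-moat-1",
97 |     attention_mode="local",
98 |     local_window_size=(8, 8),
99 |     pe_type=PEType.SINGLE_HEAD_RPE,
100 |     output_stride=32,
101 | )
102 | ```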
--------------------------------------------------------------------------------
/assets/device_measurements.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-vision-transformers-ane/c42abada32e96686f4d0bb37acb79ef7d99602ed/assets/device_measurements.png
--------------------------------------------------------------------------------
/assets/latency.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-vision-transformers-ane/c42abada32e96686f4d0bb37acb79ef7d99602ed/assets/latency.png
--------------------------------------------------------------------------------
/export.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
4 | #
5 | import torch
6 | import coremltools as ct
7 |
8 | from vision_transformers.attention_utils import (
9 | PEType,
10 | )
11 | from vision_transformers.model import _build_model
12 |
13 |
14 | def moat_export(
15 | base_arch="tiny-moat-0",
16 | shape=(1, 3, 256, 256),
17 | pe_type=PEType.LePE_ADD,
18 | attention_mode="local",
19 | ):
20 | """
21 |
22 | :param base_arch: (Default value = "tiny-moat-0")
23 | :param shape: (Default value = (1)
24 | :param 3:
25 | :param 256:
26 | :param 256):
27 | :param pe_type: (Default value = PEType.LePE_ADD)
28 | :param attention_mode: (Default value = "local")
29 |
30 | """
31 | split_head = True
32 | batch = shape[0]
33 | print("****** batch_size: ", batch)
34 | pe_type = pe_type if "moat" in base_arch else "ape"
35 | attention_mode = attention_mode if "moat" in base_arch else "global"
36 | local_window_size = [8, 8] if attention_mode == "local" else None
37 | print("****** building model: ", base_arch)
38 | if "tiny-moat" in base_arch:
39 | _, model = _build_model(
40 | base_arch=base_arch,
41 | shape=shape,
42 | split_head=split_head,
43 | pe_type=pe_type,
44 | channel_buffer_align=False,
45 | attention_mode=attention_mode,
46 | local_window_size=local_window_size,
47 | )
48 | resolution = f"{shape[-2]}x{shape[-1]}"
49 |
50 | x = torch.rand(shape)
51 |
52 | with torch.no_grad():
53 | model.eval()
54 | traced_optimized_model = torch.jit.trace(model, (x,))
55 | ane_mlpackage_obj = ct.convert(
56 | traced_optimized_model,
57 | convert_to="mlprogram",
58 | inputs=[
59 | ct.TensorType("x", shape=x.shape),
60 | ],
61 | )
62 |
63 | out_name = f"{base_arch}_{attention_mode}Attn_batch{batch}_{resolution}_{pe_type}_split-head_{split_head}"
64 | out_path = f"./exported_model/{out_name}.mlpackage"
65 | ane_mlpackage_obj.save(out_path)
66 |
67 | import shutil
68 |
69 | shutil.make_archive(f"{out_path}", "zip", out_path)
70 |
71 |
72 | if __name__ == "__main__":
73 | base_arch = "tiny-moat-0"
74 | attention_mode = ["global", "local"]
75 | pe_type = PEType.SINGLE_HEAD_RPE
76 | shapes = [[1, 3, 512, 512], [1, 3, 256, 256]]
77 | bs = [1]
78 | for att_mode in attention_mode:
79 | for shape in shapes:
80 | for batch in bs:
81 | shape[0] = batch
82 | moat_export(
83 | base_arch,
84 | shape,
85 | pe_type=pe_type,
86 | attention_mode=att_mode,
87 | )
88 |
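89 | # Example (not part of the original script): the exported package can be reloaded
90 | # with coremltools on macOS for a quick sanity check, e.g.:
91 | #   mlmodel = ct.models.MLModel("./exported_model/<package name>.mlpackage")
92 | #   prediction = mlmodel.predict({"x": torch.rand(1, 3, 256, 256).numpy()})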
--------------------------------------------------------------------------------
/tests.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from vision_transformers.model import _build_model, get_stage_strides
3 | import torch
4 | from export import moat_export
5 | import os
6 |
7 |
8 | @pytest.mark.parametrize("image_shape", [(512, 512), (256, 256)])
9 | @pytest.mark.parametrize("output_stride", [8, 16, 32])
10 | @pytest.mark.parametrize("attention_mode", ["global", "local"])
11 | @pytest.mark.parametrize(
12 | "base_arch",
13 | [
14 | "tiny-moat-0",
15 | "tiny-moat-1",
16 | "tiny-moat-2",
17 | ],
18 | )
19 | def test_model(output_stride, base_arch, image_shape, attention_mode):
20 | """MOAT unit test
21 |
22 | :param output_stride: param base_arch:
23 | :param image_shape: param attention_mode:
24 | :param base_arch:
25 | :param attention_mode:
26 |
27 | """
28 | image_height, image_width = image_shape
29 | with torch.no_grad():
30 | model_config, model = _build_model(
31 | shape=(1, 3, image_height, image_width),
32 | base_arch=base_arch,
33 | attention_mode=attention_mode,
34 | output_stride=output_stride,
35 | )
36 | stage_stride = get_stage_strides(output_stride)
37 | inputs = torch.zeros((1, 3, image_height, image_width), device="cpu")
38 | output = model(inputs)
39 | assert len(output) == 4
40 | output_h, output_w = image_height // 2, image_width // 2
41 | for stage_idx, stride in enumerate(stage_stride):
42 | assert len(output[stage_idx].shape) == 4
43 | output_h, output_w = output_h // stride, output_w // stride
44 | assert output_h == output[stage_idx].shape[-2]
45 | assert output_w == output[stage_idx].shape[-1]
46 |
47 |
48 | def test_export():
49 | """ """
50 | moat_export(
51 | base_arch="tiny-moat-0",
52 | shape=(1, 3, 256, 256),
53 | attention_mode="local",
54 | )
55 | assert os.path.exists(
56 | "exported_model/tiny-moat-0_localAttn_batch1_256x256_PEType.LePE_ADD_split-head_True.mlpackage"
57 | )
58 |
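59 |
60 | # A possible additional sanity check (not part of the original test suite):
61 | # window_partition followed by window_reverse should reconstruct the input exactly.
62 | def test_window_partition_reverse_roundtrip():
63 |     from vision_transformers.attention_utils import window_partition, window_reverse
64 |
65 |     H, W, C = 16, 16, 32
66 |     window_size = (8, 8)
67 |     x = torch.rand((1, H, W, C))
68 |     windows = window_partition(x, window_size)
69 |     assert windows.shape == (4, window_size[0], window_size[1], C)
70 |     assert torch.equal(window_reverse(windows, window_size, H, W), x)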
--------------------------------------------------------------------------------
/vision_transformers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-vision-transformers-ane/c42abada32e96686f4d0bb37acb79ef7d99602ed/vision_transformers/__init__.py
--------------------------------------------------------------------------------
/vision_transformers/attention_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
4 | #
5 | from enum import Enum, unique
6 | import logging
7 | from typing import Optional, Sequence
8 |
9 | import numpy as np
10 |
11 | import torch
12 | from torch import nn
13 | from timm.models.layers import trunc_normal_
14 |
15 | logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
16 | logger = logging.getLogger(__name__)
17 | logger.setLevel(logging.INFO)
18 |
19 | """
20 | Reference:
21 | [1] Swin Transformer: https://arxiv.org/abs/2103.14030
22 | [2] Swin Github: https://github.com/microsoft/Swin-Transformer
23 | [3] Local enhanced position embedding: https://arxiv.org/abs/2107.00652
24 | """
25 |
26 |
27 | def window_partition(x: torch.Tensor, window_size: Sequence[int]):
28 | """Partition image feature map into small windows, in an ANE friendly manner (w/o resorting to 6D tensors).
29 |
30 | :param x: feature map to be partitioned, (batch_size, H, W, C)
31 | :param window_size: target window size, (win_h, win_w)
32 | :type x: torch.Tensor
33 | :type window_size: Sequence[int]
34 | :returns: partitioned feature map windows, (batch_size * num_windows, win_h, win_w, C)
35 | :rtype: torch.Tensor
36 |
37 | """
38 | B, H, W, C = x.shape
39 | # example partition process: 1, 12, 16, 160 -> 1, 2, 6, 16, 160 -> 2, 6, 16, 160 -> 2, 6, 2, 8, 160 -> ...
40 | x = x.reshape(
41 | (B, H // window_size[0], window_size[0], W, C)
42 | ) # B, H//w_size, w_size, W, C
43 | x = x.reshape(
44 | (B * H // window_size[0], window_size[0], W, C)
45 | ) # B * H // w_size, w_size, W, C
46 | x = x.reshape(
47 | (
48 | B * H // window_size[0],
49 | window_size[0],
50 | W // window_size[1],
51 | window_size[1],
52 | -1,
53 | )
54 | )
55 | x = x.permute((0, 2, 1, 3, 4))
56 | windows = x.reshape((-1, window_size[0], window_size[1], C))
57 | return windows
58 |
59 |
60 | def window_reverse(windows: torch.Tensor, window_size: Sequence[int], H: int, W: int):
61 | """Merge partitioned windows back to feature map
62 |
63 | :param windows: partitioned windows, (num_windows * batch_size, win_h, win_w, C)
64 | :type windows: torch.Tensor
65 | :param window_size: window size, (win_h, win_w)
66 | :type window_size: Sequence[int]
67 | :param H: height of the original feature map
68 | :type H: int
69 | :param W: width of the original feature map
70 | :type W: int
71 | :returns: merged feature map, (batch_size, H, W, C)
72 | :rtype: torch.Tensor
76 |
77 | """
78 | B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1]))
79 | x = windows.reshape(
80 | (
81 | B * H // window_size[0],
82 | W // window_size[1],
83 | window_size[0],
84 | window_size[1],
85 | -1,
86 | )
87 | )
88 | x = x.permute((0, 2, 1, 3, 4)).reshape(
89 | (B * H // window_size[0], window_size[0], W, -1)
90 | )
91 | x = x.reshape((B, H // window_size[0], window_size[0], W, -1))
92 | x = x.reshape((B, H, W, -1))
93 | return x
94 |
95 |
96 | @unique
97 | class PEType(Enum):
98 | """ """
99 |
100 | LePE_ADD = 0
101 | LePE_FUSED = 1
102 | RPE = 2
103 | SINGLE_HEAD_RPE = 3
104 |
105 |
106 | class WindowAttention(nn.Module):
107 | """Window/Global based multi-head self attention (MHSA) module
108 |
109 | Supports only non-shifting window attention as there is no native shifting support for ANE.
110 | Supports attention computation that is efficient on ANE by splitting on the softmax dimension.
111 |
112 | :param dim: Number of input channels.
113 | :param window_size: The height and width of the window.
114 | :param num_heads: Number of attention heads.
115 | :param qkv_bias: If True, add a learnable bias to the query, key and value projections.
116 | :param qk_scale: Override default qk scale of head_dim ** -0.5 if set.
117 | :param attn_drop: Dropout ratio of attention weight.
118 | :param proj_drop: Dropout ratio of output.
119 | :param split_head: Whether to split heads before the softmax (split_softmax).
120 | split_softmax reduces latency significantly and is therefore enabled by default.
121 | :param pe_type: position embedding type.
122 |
123 | """
124 |
125 | def __init__(
126 | self,
127 | dim: int,
128 | window_size: Sequence[int],
129 | num_heads: int,
130 | qkv_bias: bool = True,
131 | qk_scale: Optional[float] = None,
132 | attn_drop: float = 0.0,
133 | proj_drop: float = 0.0,
134 | split_head: bool = True,
135 | pe_type: Enum = PEType.LePE_ADD,
136 | ):
137 | super().__init__()
138 | self.dim = dim
139 | self.window_size = window_size
140 | self.num_heads = num_heads
141 | head_dim = dim // num_heads
142 |
143 | self.scale = qk_scale or head_dim**-0.5
144 |
145 | self.split_head = split_head
146 | self.pe_type = pe_type
147 |
148 | if pe_type == PEType.RPE or pe_type == PEType.SINGLE_HEAD_RPE:
149 | # single-head RPE shares one relative position bias table across all heads
150 | self.rpe_num_heads = 1 if pe_type == PEType.SINGLE_HEAD_RPE else num_heads
151 | logger.info(f"******Using RPE on {self.rpe_num_heads} heads.")
152 | shape = (
153 | (2 * window_size[0] - 1),
154 | (2 * window_size[1] - 1),
155 | self.rpe_num_heads,
156 | )
157 |
158 | self.relative_position_bias_table = nn.Parameter(
159 | torch.zeros(shape)
160 | ) # 2*Wh-1 * 2*Ww-1, nH
161 | trunc_normal_(self.relative_position_bias_table, std=0.02)
162 |
163 | # get pair-wise relative position index for each token inside the window
164 | coords_h = np.arange(self.window_size[0])
165 | coords_w = np.arange(self.window_size[1])
166 |
167 | mesh = np.meshgrid(coords_h, coords_w)
168 | # mesh grid returns transposed results compared w/ pytorch
169 | coords = np.stack((mesh[0].T, mesh[1].T)) # NOTE: 2, Wh, Ww
170 | coords_flatten = coords.reshape(2, -1)
171 | relative_coords = (
172 | coords_flatten[:, :, None] - coords_flatten[:, None, :]
173 | ) # 2, Wh*Ww, Wh*Ww
174 | relative_coords = relative_coords.transpose((1, 2, 0)) # Wh*Ww, Wh*Ww, 2
175 |
176 | relative_coords[:, :, 0] += self.window_size[0] - 1
177 | relative_coords[:, :, 1] += self.window_size[1] - 1
178 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
179 | self.relative_position_index = np.sum(relative_coords, -1) # Wh*Ww, Wh*Ww
180 | bias_index_bound = shape[0] if len(shape) == 2 else shape[0] * shape[1]
181 | assert (self.relative_position_index >= 0).all()
182 | assert (self.relative_position_index < bias_index_bound).all()
183 | elif pe_type == PEType.LePE_ADD:
184 | logger.info("******Using LePE_ADD.")
185 | self.LePE_for_Value = nn.Conv2d(
186 | in_channels=dim,
187 | out_channels=dim,
188 | groups=dim,
189 | bias=qkv_bias,
190 | kernel_size=3,
191 | padding="same",
192 | )
193 | self.abs_pe = nn.Parameter(
194 | torch.zeros(1, window_size[0] * window_size[1], dim)
195 | )
196 |
197 | # Use separate conv1x1 projection to avoid L2 cache hit
198 | self.q_proj = nn.Conv2d(
199 | in_channels=dim,
200 | out_channels=dim,
201 | kernel_size=1,
202 | bias=qkv_bias,
203 | )
204 | self.k_proj = nn.Conv2d(
205 | in_channels=dim,
206 | out_channels=dim,
207 | kernel_size=1,
208 | bias=qkv_bias,
209 | )
210 | self.v_proj = nn.Conv2d(
211 | in_channels=dim,
212 | out_channels=dim,
213 | kernel_size=1,
214 | bias=qkv_bias,
215 | )
216 | self.proj = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1)
217 | self.softmax = nn.Softmax(dim=1)  # softmax over the key dimension (dim=1) of each per-head (B, K, 1, Q) attention weight
218 | self.attn_drop = nn.Dropout(attn_drop)
219 | self.proj_drop = nn.Dropout(proj_drop)
220 |
221 | def forward(self, x: torch.Tensor):
222 | """
223 |
224 | :param x: torch.Tensor:
225 |
226 | """
227 | if self.pe_type == PEType.RPE or self.pe_type == PEType.SINGLE_HEAD_RPE:
228 | local_table = self.relative_position_bias_table.reshape(
229 | (-1, self.rpe_num_heads)
230 | )
231 | elif self.pe_type == PEType.LePE_ADD:
232 | x += self.abs_pe
233 |
234 | BW, N, C = x.shape # BW=num_windows*B=64*1
235 | assert (
236 | N == self.window_size[0] * self.window_size[1]
237 | ), "N: {}, num_windows: {}".format(N, self.window_size[0] * self.window_size[1])
238 | image_shape = (BW, C, self.window_size[0], self.window_size[1])
239 | x_2d = x.permute((0, 2, 1)).reshape(image_shape) # BCHW
240 | x_flat = torch.unsqueeze(x.permute((0, 2, 1)), 2) # BC1L
241 |
242 | q, k, v_2d = self.q_proj(x_flat), self.k_proj(x_flat), self.v_proj(x_2d)
243 | if self.pe_type == PEType.LePE_ADD:
244 | LePE = self.LePE_for_Value(v_2d).reshape(x_flat.shape)
245 | mh_LePE = torch.split(LePE, self.dim // self.num_heads, dim=1)
246 | mh_q = torch.split(q, self.dim // self.num_heads, dim=1) # BC1L
247 | mh_v = torch.split(
248 | v_2d.reshape(x_flat.shape), self.dim // self.num_heads, dim=1
249 | )
250 | # BL1C, transposeThenSplit is more efficient than the other way around
251 | mh_k = torch.split(
252 | torch.permute(k, (0, 3, 2, 1)), self.dim // self.num_heads, dim=3
253 | )
254 |
255 | # attn weights in each head.
256 | attn_weights = [
257 | torch.einsum("bchq, bkhc->bkhq", qi, ki) * self.scale
258 | for qi, ki in zip(mh_q, mh_k)
259 | ]
260 |
261 | # add RPE bias
262 | if self.pe_type == PEType.RPE or self.pe_type == PEType.SINGLE_HEAD_RPE:
263 | relative_position_bias = local_table[
264 | self.relative_position_index.reshape((-1,))
265 | ].reshape(
266 | (
267 | self.window_size[0] * self.window_size[1],
268 | self.window_size[0] * self.window_size[1],
269 | -1,
270 | )
271 | ) # Wh*Ww, Wh*Ww, nH
272 | relative_position_bias = torch.unsqueeze(
273 | relative_position_bias.permute((2, 0, 1)), 2
274 | ) # nH, Wh*Ww, 1, Wh*Ww
275 | relative_position_bias = torch.split(relative_position_bias, 1, dim=0)
276 |
277 | # split_softmax
278 | for head_idx in range(self.num_heads):
279 | rpe_idx = head_idx if self.pe_type == PEType.RPE else 0
280 | attn_weights[head_idx] = (
281 | attn_weights[head_idx] + relative_position_bias[rpe_idx]
282 | )
283 |
284 | attn_weights = [
285 | self.softmax(aw) for aw in attn_weights
286 | ] # softmax applied on channel "C"
287 | mh_w = [self.attn_drop(aw) for aw in attn_weights]
288 |
289 | # compute attn@v
290 | mh_x = [torch.einsum("bkhq,bchk->bchq", wi, vi) for wi, vi in zip(mh_w, mh_v)]
291 | if self.pe_type == PEType.LePE_ADD:
292 | mh_x = [v + pe for v, pe in zip(mh_x, mh_LePE)]
293 | # concat heads
294 | x = torch.cat(mh_x, dim=1)
295 |
296 | x = self.proj(x)
297 | x = self.proj_drop(x)
298 | x = torch.squeeze(x, dim=2)
299 | x = x.permute((0, 2, 1)) # BLC
300 | return x
301 |
--------------------------------------------------------------------------------
/vision_transformers/mbconv.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
4 | #
5 | """PyTorch MBConv block for MOAT."""
6 | import torch
7 | from torch import nn
8 | from torch.nn import Conv2d
9 | from typing import Optional
10 |
11 |
12 | class Swish(nn.Module):
13 | """ """
14 |
15 | def forward(self, x):
16 | """
17 |
18 | :param x: input tensor
19 |
20 | """
21 | return x * torch.sigmoid(x)
22 |
23 |
24 | class MBConvBlock(nn.Module):
25 | """Mobile Inverted Residual Bottleneck Block
26 | References:
27 | [1] https://arxiv.org/abs/1704.04861 (MobileNet v1)
28 | [2] https://arxiv.org/abs/1801.04381 (MobileNet v2)
29 | [3] https://arxiv.org/abs/1905.02244 (MobileNet v3)
30 |
31 | :param block_args: block hyperparameters; see the BlockArgs namedtuple in vision_transformers/model.py
32 | :type block_args: namedtuple
33 | :param activation: activation function, one of "swish", "relu", or "gelu"
34 | :type activation: string
35 | :param name: Block name
36 | :type name: string
37 |
38 | """
39 |
40 | def __init__(
41 | self,
42 | block_args,
43 | batch_norm_momentum: Optional[float] = 0.99,
44 | batch_norm_epsilon: Optional[float] = 1e-3,
45 | drop_rate: Optional[float] = None,
46 | pre_norm: bool = False,
47 | name: str = "_block_",
48 | activation: str = "swish",
49 | ):
50 | super(MBConvBlock, self).__init__()
51 | self.name = name
52 | self._block_args = block_args
53 | self.block_activation = activation
54 | # in torch.batchnorm, (1-momentum)*running_mean + momentum * x_new
55 | self._bn_mom = 1 - batch_norm_momentum
56 | self._bn_eps = batch_norm_epsilon
57 | self.has_se = (self._block_args.se_ratio is not None) and (
58 | 0 < self._block_args.se_ratio <= 1
59 | )
60 | self.drop_rate = drop_rate
61 | self.id_skip = block_args.id_skip # skip connection and drop connect
62 | self.pre_norm = pre_norm
63 |
64 | if self.pre_norm:
65 | self.pre_norm_layer = nn.BatchNorm2d(
66 | num_features=self._block_args.input_filters
67 | )
68 |
69 | # Expansion phase
70 | inp = self._block_args.input_filters # number of input channels
71 | oup = (
72 | self._block_args.input_filters * self._block_args.expand_ratio
73 | ) # number of output channels
74 | if self._block_args.expand_ratio != 1:
75 | self._expand_conv = Conv2d(
76 | in_channels=inp, out_channels=oup, kernel_size=1, bias=False
77 | )
78 | self._bn0 = nn.BatchNorm2d(
79 | num_features=oup, momentum=self._bn_mom, eps=self._bn_eps
80 | )
81 |
82 | # Depthwise convolution phase
83 | k = self._block_args.kernel_size
84 | s = self._block_args.stride
85 | self._depthwise_conv = Conv2d(
86 | in_channels=oup,
87 | out_channels=oup,
88 | groups=oup, # groups makes it depthwise
89 | kernel_size=k,
90 | stride=s,
91 | padding="same" if s == 1 else 1,
92 | bias=False,
93 | )
94 | self._bn1 = nn.BatchNorm2d(
95 | num_features=oup, momentum=self._bn_mom, eps=self._bn_eps
96 | )
97 |
98 | # Squeeze and Excitation layer, if desired
99 | if self.has_se:
100 | num_squeezed_channels = max(
101 | 1, int(self._block_args.input_filters * self._block_args.se_ratio)
102 | )
103 | self._se_reduce = Conv2d(
104 | in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1
105 | )
106 | self._se_expand = Conv2d(
107 | in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1
108 | )
109 |
110 | # Output phase
111 | final_oup = self._block_args.output_filters
112 | self._project_conv = Conv2d(
113 | in_channels=oup,
114 | out_channels=final_oup,
115 | kernel_size=1,
116 | bias=False,
117 | padding="same",
118 | )
119 | self._bn2 = nn.BatchNorm2d(
120 | num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps
121 | )
122 |
123 | if self.block_activation == "swish":
124 | self._swish = Swish()
125 | elif self.block_activation == "relu":
126 | self._swish = nn.ReLU(inplace=True)
127 | elif self.block_activation == "gelu":
128 | self._swish = nn.GELU()
129 | else:
130 | raise ValueError("Unsupported activation in MBConv block.")
131 |
132 | if self.drop_rate is not None:
133 | self.dropout = nn.Dropout(self.drop_rate)
134 |
135 | if block_args.stride == 2:
136 | self.shortcut_pool = nn.AvgPool2d(
137 | kernel_size=2,
138 | stride=2,
139 | )
140 | self.shortcut_conv = None
141 | if block_args.input_filters != block_args.output_filters:
142 | self.shortcut_conv = Conv2d(
143 | in_channels=block_args.input_filters,
144 | out_channels=block_args.output_filters,
145 | kernel_size=1,
146 | stride=1,
147 | padding="same",
148 | bias=True,
149 | )
150 |
151 | def forward(self, inputs):
152 | """param inputs: input tensor (batch_size, C, H, W)
153 | param drop_connect_rate: drop connect rate (float, between 0 and 1)
154 |
155 | :param inputs:
156 |
157 | """
158 |
159 | shortcut = inputs
160 | x = inputs
161 |
162 | if self.pre_norm:
163 | x = self.pre_norm_layer(x)
164 | # Expansion and Depthwise Convolution
165 | if self._block_args.expand_ratio != 1:
166 | x = self._swish(self._bn0(self._expand_conv(x)))
167 | x = self._swish(self._bn1(self._depthwise_conv(x)))
168 |
169 | # Squeeze and Excitation
170 | if self.has_se:
171 | x_squeezed = nn.AdaptiveAvgPool2d(output_size=(1, 1))(x)
172 | x_squeezed = self._se_expand(self._swish(self._se_reduce(x_squeezed)))
173 | x = torch.sigmoid(x_squeezed) * x
174 |
175 | x = self._bn2(self._project_conv(x))
176 |
177 | # Skip connection and drop connect
178 | input_filters, output_filters = (
179 | self._block_args.input_filters,
180 | self._block_args.output_filters,
181 | )
182 | if self.id_skip:
183 | if self._block_args.stride == 1 and input_filters == output_filters:
184 | if self.drop_rate:
185 | x = self.dropout(x)
186 | elif self._block_args.stride == 2:
187 | shortcut = self.shortcut_pool(inputs)
188 | if self.shortcut_conv is not None:
189 | shortcut = self.shortcut_conv(shortcut)
190 | elif (
191 | self._block_args.stride == 1
192 | or self._block_args.stride == [1, 1]
193 | and input_filters != output_filters
194 | ):
195 | if self.shortcut_conv is not None:
196 | shortcut = self.shortcut_conv(shortcut)
197 | x = torch.add(x, shortcut) # skip connection
198 | return x
199 |
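200 | # Example (not part of the original file): constructing a standalone MBConv block,
201 | # using the BlockArgs namedtuple defined in vision_transformers/model.py:
202 | #
203 | #   from vision_transformers.model import BlockArgs
204 | #   block = MBConvBlock(
205 | #       BlockArgs(kernel_size=3, stride=1, input_filters=32, output_filters=32,
206 | #                 expand_ratio=4, id_skip=True, se_ratio=0.25),
207 | #       activation="gelu",
208 | #       pre_norm=True,
209 | #   )
210 | #   y = block(torch.rand(1, 32, 56, 56))  # output shape: (1, 32, 56, 56)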
--------------------------------------------------------------------------------
/vision_transformers/model.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
4 | #
5 | """
6 | Reference:
7 | [1] MOAT: https://arxiv.org/pdf/2210.01820.pdf.
8 | [2] Tensorflow official impl: https://github.com/google-research/deeplab2/blob/main/model/pixel_encoder/moat.py
9 | """
10 | import logging
11 | import math
12 | from typing import Optional, Sequence, Any, Tuple
13 |
14 | import collections
15 | from vision_transformers.attention_utils import (
16 | WindowAttention,
17 | PEType,
18 | window_partition,
19 | window_reverse,
20 | )
21 | from vision_transformers.mbconv import MBConvBlock
22 | import torch
23 | from torch import nn
24 | from torch.nn import GELU
25 | from dataclasses import dataclass
26 |
27 | logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
28 | logger = logging.getLogger(__name__)
29 | logger.setLevel(logging.INFO)
30 |
31 | BlockArgs = collections.namedtuple(
32 | "BlockArgs",
33 | [
34 | "kernel_size",
35 | "num_repeat",
36 | "input_filters",
37 | "output_filters",
38 | "expand_ratio",
39 | "id_skip",
40 | "stride",
41 | "se_ratio",
42 | ],
43 | )
44 |
45 | # Change namedtuple defaults
46 | BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields)
47 |
48 |
49 | @dataclass
50 | class MOATConfig:
51 | """MOAT config. Default values are from tiny_moat_0.
52 |
53 | For detailed model hyperparameter configuration, refer to the MOAT paper.
54 | Supports two attention modes, which can be used to trade off latency and accuracy:
55 | 1. Local attention: attention is computed within fixed-size windows (user configurable, depending on the input resolution).
56 | 2. Global attention: attention is computed over the full feature map.
57 |
58 | :param stem_size: hidden size of the conv stem.
59 | :param block_type: type of each stage.
60 | :param num_blocks: number of blocks for each stage.
61 | :param hidden_size: hidden size of each stage.
62 | :param window_size: window_size of each stage if using attention block.
63 | :param activation: activation function to use.
64 | :param attention_mode: use global or local attention.
65 | :param split_head: whether to do split-head attention. split_head makes the model run more efficiently on the ANE.
66 |
67 | """
68 |
69 | stem_size: Sequence[int] = (32, 32)
70 | block_type: Sequence[str] = ("mbconv", "mbconv", "moat", "moat")
71 | num_blocks: Sequence[int] = (2, 3, 7, 2)
72 | hidden_size: Sequence[int] = (32, 64, 128, 256)
73 | window_size: Sequence[Any] = (None, None, (14, 14), (7, 7))
74 | activation: nn.Module = GELU()
75 | attention_mode: str = "global"
76 | split_head: bool = True
77 | stage_stride: Sequence[int] = (2, 2, 2, 2)
78 | mbconv_block_expand_ratio: int = 4
79 | moat_block_expand_ratio: int = 4
80 | pe_type: PEType = PEType.LePE_ADD
81 |
82 | def __post_init__(self):
83 | if self.attention_mode == "local":
84 | # window_size should be limited to local context
85 | local_context_lower, local_context_upper = 6, 16
86 | for window in self.window_size:
87 | if window is not None:
88 | assert isinstance(window, tuple) or isinstance(window, list)
89 | for hw in window:
90 | assert hw >= local_context_lower and hw <= local_context_upper
91 |
92 |
93 | @dataclass
94 | class MOATBlockConfig:
95 | """MOAT block config.
96 |
97 | :param block_name: name of the block.
98 | :param window_size: attention window size.
99 | :param attn_norm_class: normalization layer in attention.
100 | :param activation: activation function.
101 | :param head_dim: dimension of each head.
102 | :param kernel_size: kernel size for MBConv block
103 | :param stride: stride for MBConv block
104 | :param expand_ratio: expansion ratio in the MBConv block.
105 | :param id_skip: do skip connection or not.
106 | :param se_ratio: channel reduction ratio in squeeze and excitation, if 0 or None, no SE.
107 | :param attention_mode: use global or local attention.
108 | :param split_head: whether to do split-head attention. split_head makes the model run more efficiently on the ANE.
109 | :param pe_type: position embedding type
110 |
111 | """
112 |
113 | block_name: str = "moat_block"
114 | window_size: Optional[Sequence[int]] = None
115 | attn_norm_class: nn.Module = nn.LayerNorm
116 | head_dim: int = 32 # dim of each head
117 | activation: nn.Module = GELU()
118 | # BlockArgs
119 | kernel_size: int = 3
120 | stride: int = 1
121 | input_filters: int = 32
122 | output_filters: int = 32
123 | expand_ratio: int = 4
124 | id_skip: bool = True
125 | se_ratio: Optional[float] = None
126 | attention_mode: str = "global"
127 | split_head: bool = False
128 | pe_type: PEType = PEType.LePE_ADD
129 |
130 |
131 | class Stem(nn.Sequential):
132 | """Convolutional stem consists of 2 convolution layers.
133 |
134 | :param dims: specifies the dimensions for the convolution stems.
135 |
136 | """
137 |
138 | def __init__(self, dims: Sequence[int]):
139 | stem_layers = []
140 |
141 | for i in range(len(dims)):
142 | norm_layer = None
143 | activation_layer = None
144 |
145 | if i == 0:
146 | activation_layer = GELU()
147 | norm_layer = True
148 |
149 | stride = 2 if i == 0 else 1
150 | in_channels = dims[i - 1] if i >= 1 else 3
151 | conv_layer = nn.Conv2d(
152 | in_channels=in_channels,
153 | out_channels=dims[i],
154 | kernel_size=3,
155 | bias=True,
156 | stride=stride,
157 | padding="same"
158 | if stride == 1
159 | else 1, # strided conv does not support "same"
160 | )
161 | stem_layers.append(conv_layer)
162 | if activation_layer is not None:
163 | stem_layers.append(activation_layer)
164 | if norm_layer:
165 | stem_layers.append(nn.BatchNorm2d(dims[i]))
166 |
167 | super().__init__(*stem_layers)
168 |
169 |
170 | class MOATBlock(nn.Module):
171 | """A MOAT block consists of MBConv (w/o squeeze-excitation blocks) and MHSA.
172 |
173 | :param config: a MOATBlockConfig object to specify block dims, attention mode, window size, etc.
174 |
175 | """
176 |
177 | def __init__(self, config: MOATBlockConfig):
178 | super().__init__()
179 | block_args = BlockArgs(
180 | kernel_size=config.kernel_size,
181 | stride=config.stride,
182 | se_ratio=None, # MOAT block does not use SE branch
183 | input_filters=config.input_filters,
184 | output_filters=config.output_filters,
185 | id_skip=True,
186 | expand_ratio=config.expand_ratio,
187 | )
188 | self._mbconv = MBConvBlock(
189 | block_args,
190 | activation="gelu",
191 | pre_norm=True,
192 | )
193 |
194 | # dim after MBConv block
195 | dim = config.output_filters
196 |
197 | # LayerNorm currently normalizes over the last few dimensions, therefore the tensor needs to be in NHWC format.
198 | # see pytorch issue 71465: https://github.com/pytorch/pytorch/issues/71465
199 | self._attn_norm = config.attn_norm_class(
200 | normalized_shape=dim,
201 | eps=1e-5,
202 | elementwise_affine=True,
203 | )
204 | assert (
205 | dim % config.head_dim == 0
206 | ), "tensor dimension: {} can not divide by head_dim: {}.".format(
207 | dim, config.head_dim
208 | )
209 | num_heads = dim // config.head_dim
210 | print("######pe_type in MOATBlock: ", config.pe_type)
211 | self._window_attention = WindowAttention(
212 | dim,
213 | window_size=config.window_size,
214 | num_heads=num_heads,
215 | split_head=config.split_head,
216 | pe_type=config.pe_type,
217 | )
218 | self.window_size = config.window_size
219 | self.attention_mode = config.attention_mode
220 |
221 | def forward(self, inputs):
222 | """inputs: (batch_size, C, H, W)
223 | output: ((batch_size, C, H//stride, W//stride)
224 |
225 | :param inputs:
226 |
227 | """
228 |
229 | # MBConv block may contain downsampling layer
230 | output = self._mbconv(inputs)
231 | N, C, H, W = output.shape
232 |
233 | # shortcut is before LN in MOAT
234 | shortcut = output
235 | # transpose to prepare the tensor for window_partition
236 | output = output.permute((0, 2, 3, 1)) # NHWC
237 |
238 | assert (
239 | output.shape[-1] % 32 == 0
240 | ), "ANE buffer not aligned, last dim={}.".format(output.shape[-1])
241 | output = self._attn_norm(output)
242 |
243 | if self.attention_mode == "local":
244 | x_windows = window_partition(output, self.window_size)
245 | x_windows = x_windows.reshape(
246 | (-1, self.window_size[0] * self.window_size[1], C)
247 | )
248 | attn_windows = self._window_attention(x_windows)
249 | output = window_reverse(attn_windows, self.window_size, H, W)
250 | # No need for window_partition/reverse on low resolution input
251 | elif self.attention_mode == "global":
252 | global_attention_windows = output.reshape((N, H * W, C))
253 | output = self._window_attention(global_attention_windows)
254 |
255 | output = output.reshape((N, H, W, C)).permute((0, 3, 1, 2)) # NCHW
256 |
257 | # may add drop_path here for output
258 | output = shortcut + output
259 | return output # NCHW
260 |
261 |
262 | class MOAT(nn.Module):
263 | """MOAT model definition.
264 |
265 | :param config: a MOATConfig object to specify MOAT variant, attention mode, etc.
266 |
267 | """
268 |
269 | def __init__(self, config: MOATConfig):
270 | super().__init__()
271 | self._stem = Stem(dims=config.stem_size)
272 | # Need to use ModuleList (instead of vanilla python list) for the module to be properly registered
273 | self._blocks = nn.ModuleList()
274 | self.config = config
275 | for stage_id in range(len(config.block_type)):
276 | stage_blocks = nn.ModuleList()
277 | stage_input_filters = (
278 | config.hidden_size[stage_id - 1]
279 | if stage_id > 0
280 | else config.stem_size[-1]
281 | )
282 | stage_output_filters = config.hidden_size[stage_id]
283 |
284 | for local_block_id in range(config.num_blocks[stage_id]):
285 | block_stride = 1
286 | block_name = "block_{:0>2d}_{:0>2d}_".format(stage_id, local_block_id)
287 |
288 | if local_block_id == 0: # downsample in the first block of each stage
289 | block_stride = config.stage_stride[stage_id]
290 | block_input_filters = stage_input_filters
291 | else:
292 | block_input_filters = stage_output_filters
293 |
294 | if config.block_type[stage_id] == "mbconv":
295 | block_args = BlockArgs(
296 | kernel_size=3,
297 | stride=block_stride,
298 | se_ratio=0.25, # SE block reduction ratio
299 | input_filters=block_input_filters,
300 | output_filters=stage_output_filters,
301 | expand_ratio=config.mbconv_block_expand_ratio,
302 | id_skip=True,
303 | )
304 | block = MBConvBlock(
305 | block_args,
306 | activation="gelu",
307 | pre_norm=True,
308 | )
309 | elif config.block_type[stage_id] == "moat":
310 | print("######pe_type: ", config.pe_type)
311 | block_config = MOATBlockConfig(
312 | block_name=block_name,
313 | stride=block_stride,
314 | window_size=config.window_size[stage_id],
315 | input_filters=block_input_filters,
316 | output_filters=stage_output_filters,
317 | attention_mode=config.attention_mode,
318 | split_head=config.split_head,
319 | expand_ratio=config.moat_block_expand_ratio,
320 | pe_type=config.pe_type,
321 | )
322 | block = MOATBlock(block_config)
323 | else:
324 | raise ValueError(
325 | "Network type {} not defined.".format(config.block_type)
326 | )
327 |
328 | stage_blocks.append(block)
329 |
330 | self._blocks.append(stage_blocks)
331 |
332 | def forward(self, inputs: torch.Tensor, out_indices: Sequence[int] = (0, 1, 2, 3)):
333 | """
334 |
335 | :param inputs: torch.Tensor:
336 | :param out_indices: Sequence[int]: (Default value = (0)
337 | :param 1: param 2:
338 | :param 3:
339 | :param inputs: torch.Tensor:
340 | :param out_indices: Sequence[int]: (Default value = (0)
341 | :param 2:
342 | :param 3):
343 |
344 | """
345 | outs = []
346 | output = self._stem(inputs)
347 |
348 | for stage_id, stage_blocks in enumerate(self._blocks):
349 | for block in stage_blocks:
350 | output = block(output)
351 | if stage_id in out_indices:
352 | outs.append(output)
353 | return outs
354 |
355 |
356 | def get_stage_strides(output_stride):
357 | """
358 |
359 | :param output_stride:
360 |
361 | """
362 | if output_stride == 32:
363 | stage_stride = (2, 2, 2, 2)
364 | elif output_stride == 16:
365 | stage_stride = (2, 2, 2, 1)
366 | elif output_stride == 8:
367 | stage_stride = (2, 2, 1, 1)
368 | return stage_stride
369 |
370 |
371 | def _build_model(
372 | shape: Sequence[int] = (1, 3, 192, 256),
373 | base_arch: str = "tiny-moat-2",
374 | attention_mode: str = "global",
375 | split_head: bool = True,
376 | output_stride: int = 32,
377 | channel_buffer_align: bool = True,
378 | num_blocks: Sequence[int] = (2, 3, 7, 2),
379 | mbconv_block_expand_ratio: int = 4,
380 | moat_block_expand_ratio: int = 4,
381 | local_window_size: Optional[Sequence[int]] = None,
382 | pe_type: PEType = PEType.LePE_ADD,
383 | ) -> Tuple[MOATConfig, MOAT]:
384 | """Construct MOAT models.
385 |
387 | :param shape: input shape to the model, (batch, channels, height, width).
388 | :param base_arch: architecture variant of MOAT ("tiny-moat-0", "tiny-moat-1", or "tiny-moat-2").
389 | :param attention_mode: global or local (window based) attention.
390 | :param output_stride: stride of the output with respect to the input resolution, e.g., 32 means the final output is 1/32 of the input resolution.
391 | :param split_head: whether to do split_head attention. split_head is enabled by default as it is faster on the ANE.
392 | :param channel_buffer_align: if True, make channel counts divisible by 32.
393 | :param num_blocks: number of blocks in each stage.
394 | :param mbconv_block_expand_ratio: expansion ratio of the mbconv blocks in the first 2 stages.
395 | :param moat_block_expand_ratio: expansion ratio of the moat blocks in the last 2 stages.
396 | :param local_window_size: local window size for local attention.
397 | :param pe_type: position embedding type.
398 | :returns: the MOAT config and the tiny moat model built according to it.
425 |
426 | """
427 | assert shape[-2] % 32 == 0
428 | assert shape[-1] % 32 == 0
429 |
430 | if attention_mode == "global" and local_window_size is not None:
431 | raise RuntimeError(
432 | "global attention should not have local_window_size for local attention."
433 | )
434 |
435 | if output_stride == 32:
436 | out_stride_stage3, out_stride_stage4 = 16, 32
437 | else:
438 | out_stride_stage3, out_stride_stage4 = output_stride, output_stride
439 |
440 | stage_stride = get_stage_strides(output_stride)
441 |
442 | feature_hw = [shape[-2] // output_stride, shape[-1] // output_stride]
443 |
444 | def _get_default_local_window_size(feature_hw):
445 | """
446 |
447 | :param feature_hw:
448 |
449 | """
450 | window_hw = []
451 | attention_field_candidates = [6, 8, 10]
452 | for h_or_w in feature_hw:
453 | if h_or_w % attention_field_candidates[0] == 0:
454 | window_hw.append(attention_field_candidates[0])
455 | elif h_or_w % attention_field_candidates[1] == 0:
456 | window_hw.append(attention_field_candidates[1])
457 | elif h_or_w % attention_field_candidates[2] == 0:
458 | window_hw.append(attention_field_candidates[2])
459 | else:
460 | raise RuntimeError(
461 | f"Not a regular feature map size: {feature_hw}, consider other input resolution."
462 | )
463 | return window_hw
464 |
465 | if attention_mode == "global":
466 | window_size = (
467 | None,
468 | None,
469 | [shape[-2] // out_stride_stage3, shape[-1] // out_stride_stage3],
470 | [shape[-2] // out_stride_stage4, shape[-1] // out_stride_stage4],
471 | )
472 | elif attention_mode == "local":
473 | if local_window_size is None:
474 | local_window_size = _get_default_local_window_size(feature_hw)
475 | window_size = (None, None, local_window_size, local_window_size)
476 | else:
477 | raise ValueError("Undefined attention mode.")
478 |
479 | if base_arch == "tiny-moat-0":
480 | tiny_moat_config = MOATConfig(
481 | num_blocks=num_blocks,
482 | window_size=window_size,
483 | attention_mode=attention_mode,
484 | split_head=split_head,
485 | stage_stride=stage_stride,
486 | mbconv_block_expand_ratio=mbconv_block_expand_ratio,
487 | moat_block_expand_ratio=moat_block_expand_ratio,
488 | pe_type=pe_type,
489 | )
490 | elif base_arch == "tiny-moat-1":
491 | tiny_moat_config = MOATConfig(
492 | stem_size=(40, 40),
493 | hidden_size=(40, 80, 160, 320),
494 | window_size=window_size,
495 | attention_mode=attention_mode,
496 | num_blocks=num_blocks,
497 | split_head=split_head,
498 | stage_stride=stage_stride,
499 | mbconv_block_expand_ratio=mbconv_block_expand_ratio,
500 | moat_block_expand_ratio=moat_block_expand_ratio,
501 | pe_type=pe_type,
502 | )
503 | elif base_arch == "tiny-moat-2":
504 | tiny_moat_config = MOATConfig(
505 | stem_size=(56, 56),
506 | hidden_size=(56, 112, 224, 448),
507 | window_size=window_size,
508 | num_blocks=num_blocks,
509 | attention_mode=attention_mode,
510 | split_head=split_head,
511 | stage_stride=stage_stride,
512 | mbconv_block_expand_ratio=mbconv_block_expand_ratio,
513 | moat_block_expand_ratio=moat_block_expand_ratio,
514 | pe_type=pe_type,
515 | )
516 |
517 | if channel_buffer_align:
518 | aligned_hidden_size = [
519 | math.ceil(h / 32) * 32 for h in tiny_moat_config.hidden_size
520 | ]
521 | aligned_stem_size = [math.ceil(h / 32) * 32 for h in tiny_moat_config.stem_size]
522 | tiny_moat_config.hidden_size = aligned_hidden_size
523 | tiny_moat_config.stem_size = aligned_stem_size
524 |
525 | logger.info("Using config: %s", tiny_moat_config)
526 | tiny_moat = MOAT(tiny_moat_config)
527 |
528 | return tiny_moat_config, tiny_moat
529 |
530 |
531 | if __name__ == "__main__":
532 | config, model = _build_model()
533 |
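534 |     # Example (not part of the original script): run a forward pass and inspect
535 |     # the four multi-scale outputs (strides 4, 8, 16, 32 relative to the input).
536 |     out = model(torch.rand(1, 3, 192, 256))
537 |     print([tuple(o.shape) for o in out])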
--------------------------------------------------------------------------------