├── Fig
    ├── README.md
    ├── coco-r-results.png
    ├── coco-results.png
    ├── img-results.png
    ├── lam-line-x.jpg
    ├── lam-line-x.pdf
    ├── lra-results.png
    ├── methods.jpg
    ├── mt-results.png
    ├── public.jpg
    ├── public.pdf
    └── sota.pdf
├── README.md
├── _config.yml
├── _includes
    └── head-custom.html
├── _layouts
    └── default.html
├── core
    ├── GRC_Attention.py
    └── pvt_grc.py
├── ct-public.gif
└── index.md


/Fig/README.md:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/Fig/coco-r-results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annosubmission/GRC-Cache/bdaffd4af9647f3028def2c444757b0e61d2e87f/Fig/coco-r-results.png
--------------------------------------------------------------------------------
/Fig/coco-results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annosubmission/GRC-Cache/bdaffd4af9647f3028def2c444757b0e61d2e87f/Fig/coco-results.png
--------------------------------------------------------------------------------
/Fig/img-results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annosubmission/GRC-Cache/bdaffd4af9647f3028def2c444757b0e61d2e87f/Fig/img-results.png
--------------------------------------------------------------------------------
/Fig/lam-line-x.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annosubmission/GRC-Cache/bdaffd4af9647f3028def2c444757b0e61d2e87f/Fig/lam-line-x.jpg
--------------------------------------------------------------------------------
/Fig/lam-line-x.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annosubmission/GRC-Cache/bdaffd4af9647f3028def2c444757b0e61d2e87f/Fig/lam-line-x.pdf
--------------------------------------------------------------------------------
/Fig/lra-results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annosubmission/GRC-Cache/bdaffd4af9647f3028def2c444757b0e61d2e87f/Fig/lra-results.png
--------------------------------------------------------------------------------
/Fig/methods.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annosubmission/GRC-Cache/bdaffd4af9647f3028def2c444757b0e61d2e87f/Fig/methods.jpg
--------------------------------------------------------------------------------
/Fig/mt-results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annosubmission/GRC-Cache/bdaffd4af9647f3028def2c444757b0e61d2e87f/Fig/mt-results.png
--------------------------------------------------------------------------------
/Fig/public.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annosubmission/GRC-Cache/bdaffd4af9647f3028def2c444757b0e61d2e87f/Fig/public.jpg
--------------------------------------------------------------------------------
/Fig/public.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annosubmission/GRC-Cache/bdaffd4af9647f3028def2c444757b0e61d2e87f/Fig/public.pdf
--------------------------------------------------------------------------------
/Fig/sota.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annosubmission/GRC-Cache/bdaffd4af9647f3028def2c444757b0e61d2e87f/Fig/sota.pdf
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# Cached Transformers
This anonymous repo contains the introduction and code for the paper "Cached Transformers: Improving Transformers with Differentiable Memory Cache".


## Introduction
In this work, we propose a novel family of Transformer models, called Cached Transformers. Each model is equipped with a Gated Recurrent Cache (GRC), a lightweight and flexible module that enables Transformers to access historical knowledge.



#### Behavior
We examine the behavior of GRC in image classification and find that it separates features into two parts: attending over the caches yields instance-invariant features, while attending over the tokens themselves yields instance-specific features (see the visualizations below).

features


#### Results
We conduct extensive experiments on more than **ten** representative Transformer networks from both vision and language tasks, including Long Range Arena, image classification, object detection, instance segmentation, and machine translation. The results demonstrate that our approach significantly improves the performance of recent Transformers.



##### ImageNet Results
imgnet

##### COCO2017 Results (Mask R-CNN 1x)
coco

##### COCO2017 Results (RetinaNet 1x)
coco-r

##### LRA Results
lra

##### Machine Translation Results
mt


## Methods

#### Cached Attention with GRC (GRC-Attention)

meth

Illustration of the proposed GRC-Attention in Cached Transformers.

(a) Details of the update process of the Gated Recurrent Cache. The updated cache $C_t$ is derived from the current tokens $X_t$ and the cache of the previous step $C_{t-1}$. The reset gates $g_r$ reset the previous cache $C_{t-1}$ before it is combined with $X_t$ into a candidate cache. The update gates $g_u$ control the update intensity, that is, how strongly the candidate overwrites $C_{t-1}$ to form $C_t$.
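
The update in (a) can be sketched as a small standalone PyTorch module. This is a minimal sketch under stated assumptions, not the released implementation. The class name `GRCUpdate`, the toy shapes, and the random token summaries are our own. The sketch mirrors the gating in `GRC_Self_Attention.forward` in `core/GRC_Attention.py`:

- sigmoid reset and update gates computed from the concatenated cache and token summary;
- a GELU + LayerNorm candidate built from the reset cache;
- a gate-weighted interpolation between the candidate and the old cache.

In the repository, the token summary $X_t$ is taken from the self-attention output and resampled to the cache length. This sketch omits that step.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GRCUpdate(nn.Module):
    """Hypothetical standalone sketch of the Gated Recurrent Cache update.

    c_prev: previous cache C_{t-1}, shape (B, L, D)
    x_t:    summary of the current tokens X_t, shape (B, L, D)
    """

    def __init__(self, dim: int):
        super().__init__()
        self.linear_reset = nn.Linear(2 * dim, dim)   # produces reset gate g_r
        self.linear_update = nn.Linear(2 * dim, dim)  # produces update gate g_u
        self.linear_add = nn.Linear(2 * dim, dim)     # produces the candidate cache
        self.norm = nn.LayerNorm(dim)

    def forward(self, c_prev: torch.Tensor, x_t: torch.Tensor) -> torch.Tensor:
        joint = torch.cat([c_prev, x_t], dim=-1)
        g_r = torch.sigmoid(self.linear_reset(joint))
        g_u = torch.sigmoid(self.linear_update(joint))
        # Reset: suppress parts of the old cache before forming the candidate.
        c_reset = g_r * c_prev
        c_cand = self.norm(F.gelu(self.linear_add(torch.cat([c_reset, x_t], dim=-1))))
        # Update: g_u decides how much of the candidate replaces the old cache.
        return g_u * c_cand + (1 - g_u) * c_prev


torch.manual_seed(0)
grc = GRCUpdate(dim=8)
cache = torch.zeros(1, 4, 8)  # C_0: an empty cache with 4 slots of width 8
for step in range(3):
    x = torch.randn(2, 4, 8)  # a batch of 2 token summaries
    new_cache = grc(cache.expand(2, -1, -1), x)
    # Keep the batch mean, detached, as the cache for the next step.
    cache = new_cache.mean(dim=0, keepdim=True).detach()
print(cache.shape)  # torch.Size([1, 4, 8])
```

Keeping only the detached batch mean as the next cache follows the training branch of `GRC_Self_Attention.forward`. That branch writes `gr_cache_value.mean(dim=0, ...)` into the `.data` of the `gr_cache` buffer. As a result, gradients flow through the cache within a step but not across steps.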

(b) Overall pipeline of GRC-Attention. Inputs attend over the cache and over themselves separately, and the output is an interpolation of the two attention results.


## Analysis


#### Significance of Cached Attention
lam

To verify that the above performance gains mainly come from attending over caches, we analyze the contribution of $o_{mem}$ by visualizing the learnable attention ratio $\sigma(\lambda^h)$.
Since the output of each head $h$ is the interpolation $o^h = \sigma(\lambda^h)\, o_{mem}^h + (1 - \sigma(\lambda^h))\, o_{self}^h$, $\sigma(\lambda^h)$ represents the relative significance of $o_{mem}^h$ and $o_{self}^h$.
We observe that, for more than half of the layers, $\sigma(\lambda^h)$ is larger than $0.5$, indicating that the outputs of those layers depend heavily on the cached attention.
We also notice that the models always prefer more cached attention, except in the last several layers.

#### Roles of Cached Attention
pub

We investigate the function of GRC-Attention by visualizing its internal feature maps.
We choose the middle layers of cached ViT-S, average the outputs of self-attention ($o_{self}$) and cached attention ($o_{mem}$) across the head and channel dimensions, and then normalize them into $[0, 1]$.
The corresponding results are denoted as $o_{self}$ and $o_{mem}$, respectively.
As $o_{self}$ and $o_{mem}$ are sequences of patches, they are unflattened to a $14 \times 14$ shape for easier comparison.
As shown, the features derived by the two attentions are visually complementary.

In GRC-Attention, $o_{mem}$ is derived by attending over the proposed cache (GRC), which contains compressive representations of historical samples. It is therefore adept at recognizing **public**, frequently recurring patches of this **class**.
In contrast, $o_{self}$ from the self-attention branch can focus on finding the more private and **characteristic** features of the input **instance**.
Under these hypotheses, we can attempt to explain the pattern of $\sigma(\lambda^h)$. Relying more on $o_{mem}$ (larger $\sigma(\lambda^h)$) in earlier layers helps the network distinguish the instance coarsely. Relying more on $o_{self}$ (smaller $\sigma(\lambda^h)$) in the last layers enables the model to make fine-grained decisions.



## Core Code
The PyTorch implementation of the GRC-Attention module is provided in the "core" directory.
Full training and testing code will be released later.

--------------------------------------------------------------------------------
/_config.yml:
--------------------------------------------------------------------------------
theme: jekyll-theme-cayman

markdown: kramdown
kramdown:
  math_engine: katex

--------------------------------------------------------------------------------
/_includes/head-custom.html:
--------------------------------------------------------------------------------
{% include head-custom-google-analytics.html %}

--------------------------------------------------------------------------------
/_layouts/default.html:
--------------------------------------------------------------------------------
{% seo %}

{% include head-custom.html %}

Skip to the content.

{{ content }}

--------------------------------------------------------------------------------
/core/GRC_Attention.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch
import torch.nn.functional as f
from timm.models.layers import trunc_normal_

# Cached Transformers: Improving Vision Transformers with Differentiable Memory Cache

def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows

def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x

class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both shifted and non-shifted windows.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self,
                 dim,
                 window_size=None,
                 num_heads=None,
                 qkv_bias=True,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.):
        super().__init__()

        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # rel pos
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
                        num_heads))

        # pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        # self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, kv, mask=None):
        """
        Args:
            q: query features with shape of (num_windows*B, N, C)
            kv: key/value features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = q.shape
        kv = self.kv(kv).reshape(B_, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]
        q = self.q(q).reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)  # B H N C

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., inner_dim=None,
                 q_in=None, k_in=None, v_in=None, inner_proj=True):
        super().__init__()
        self.num_heads = num_heads
        inner_dim = dim if inner_dim is None else inner_dim
        self.inner_dim = inner_dim

        head_dim = inner_dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.q = nn.Linear(dim, inner_dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, inner_dim * 2, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(inner_dim, dim) if inner_proj else nn.Identity()
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, q, kv=None):
        if kv is None:
            kv = q
        B, N, C = kv.shape
        B_q, N_q, C_q = q.shape
        kv = self.kv(kv).reshape(B, N, 2, self.num_heads, self.inner_dim // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]
        q = self.q(q).reshape(B_q, N_q, self.num_heads, self.inner_dim // self.num_heads).permute(0, 2, 1, 3)  # B H N C

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_q, N_q, self.inner_dim)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class GRC_Self_Attention(nn.Module):
    def __init__(self, dim, attention_func=nn.MultiheadAttention, gr_cache=None, cache_ratio=0.5, decoder_attn=True,
                 spatial_pos_emb=True, cls_dim=1, **kwargs):
        super().__init__()
        self.spatial_pos_emb = spatial_pos_emb
        self.cache_dim = int(dim * cache_ratio)
        self.decoder_attn = decoder_attn
        self.attn_self_func = attention_func(dim, **kwargs)
        self.attn_mem = attention_func(self.cache_dim, **kwargs)
        self.linear_reset = nn.Linear(2 * self.cache_dim, self.cache_dim)
        self.linear_update = nn.Linear(2 * self.cache_dim, self.cache_dim)
        self.linear_add = nn.Linear(2 * self.cache_dim, self.cache_dim)

        self.Norm1 = nn.LayerNorm(self.cache_dim)
        self.cls_dim = cls_dim
        self.heads = kwargs['num_heads']
        # Per-head logits lambda^h of the mixing ratio sigma(lambda^h).
        self.register_parameter('lam', nn.Parameter(torch.zeros([self.heads]) - 1))
        if self.spatial_pos_emb:
            self.pos_emb = nn.Conv2d(dim, dim, kernel_size=3, groups=self.cache_dim, stride=1, padding=1)

        self.memory_length = None
        if gr_cache is not None:
            self.gr_cache = gr_cache
        elif self.memory_length is not None:
            self.register_buffer('gr_cache', torch.zeros((1, self.memory_length, self.cache_dim,)))

    def forward(self, x, **kwargs):
        B, T, C = x.shape

        if 'H' in kwargs:
            H, W = kwargs['H'], kwargs['W']
        else:
            H = W = int((T - self.cls_dim) ** 0.5)

        # Lazily create the cache on the first forward pass, with one slot per input token.
        if not hasattr(self, 'gr_cache'):
            self.register_buffer('gr_cache', torch.zeros((1,) + (x.shape[1], self.cache_dim)).to(x.device))
            self.mH = H
            self.mW = W

        if self.spatial_pos_emb:
            B, T, C = x.shape
            spa_pos_emd = self.pos_emb(x[:, self.cls_dim:, :].view(B, H, W, C).permute(0, 3, 1, 2))
            spa_pos_emd = spa_pos_emd.view(B, -1, T - self.cls_dim).transpose(1, 2)
            spa_pos_emd = torch.cat([torch.zeros([B, self.cls_dim, C]).to(spa_pos_emd.device), spa_pos_emd],
                                    dim=1) if self.cls_dim > 0 else spa_pos_emd
        else:
            spa_pos_emd = 0

        # Self-attention branch: tokens attend over themselves.
        x_self = self.attn_self_func(x, kv=x, **kwargs).view(B, T, self.heads, -1)

        gr_cache = self.gr_cache.to(x.device)
        gr_cache_value = gr_cache.expand((x.shape[0], gr_cache.shape[1], gr_cache.shape[-1]))

        # Resample the first cache_dim channels of the self-attention output to the cache length
        # (f.interpolate defaults to nearest-neighbour mode).
        x_summary = x_self.view_as(x)[:, :, :self.cache_dim]
        x_summary = f.interpolate(x_summary.transpose(1, 2), (self.gr_cache.shape[1])).transpose(1, 2)
        # Gated recurrent update: reset gate, update gate, then a normalized candidate cache.
        reset_gate = torch.sigmoid(self.linear_reset(torch.cat([gr_cache_value, x_summary], dim=-1)))
        z_gate = torch.sigmoid(self.linear_update(torch.cat([gr_cache_value, x_summary], dim=-1)))
        gr_cache_add = reset_gate * gr_cache_value
        gr_cache_add = self.Norm1(f.gelu(self.linear_add(torch.cat([gr_cache_add, x_summary], dim=-1))))
        gr_cache_value = z_gate * gr_cache_add + (1 - z_gate) * gr_cache_value

        # During training, store the batch-averaged cache through .data (outside autograd) for the next step.
        if self.training:
            self.gr_cache.data = gr_cache_value.mean(dim=0, keepdim=True)

        # Cached-attention branch: the first cache_dim channels of x attend over the updated cache.
        x_mem = self.attn_mem(x[:, :, :self.cache_dim], gr_cache_value).view(B, T, self.heads, -1)
        alpha = self.lam.sigmoid().view(1, 1, -1, 1)

        # Zero-pad x_mem to the full channel width, then interpolate per head between the two branches.
        return (alpha * torch.cat([x_mem, torch.zeros(B, T, self.heads, (C - self.cache_dim) // self.heads).to(x.device)],
                                  dim=-1) + (1 - alpha) * x_self).view_as(x) + spa_pos_emd
--------------------------------------------------------------------------------
/core/pvt_grc.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial

from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.registry import register_model
from timm.models.vision_transformer import _cfg
from core.GRC_Attention import GRC_Self_Attention

# Cached Transformers: Improving Vision Transformers with Differentiable Memory Cache

__all__ = [
    'pvt_tiny_grc', 'pvt_small_grc', 'pvt_medium_grc', 'pvt_large_grc'
]


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1,
                 keep_kv=False):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.keep_kv = keep_kv
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.sr_ratio = sr_ratio
        if sr_ratio > 1 and not keep_kv:
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)

    def forward(self, q, kv=None, H=None, W=None):
        B, N, C = q.shape

        if kv is None:
            kv = q

        q = self.q(q).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        if H is None:
            H = W = int(N ** 0.5)

        if self.sr_ratio > 1 and not self.keep_kv:
            # Spatial-reduction attention: downsample keys/values with a strided conv before projecting.
            x_ = kv.permute(0, 2, 1).reshape(B, C, H, W)
            x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
            x_ = self.norm(x_)
            kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        else:
            kv = self.kv(kv).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)

        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x



class Block_grc(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, attn_type=None, tk_mem=None):
        super().__init__()
        self.norm1 = norm_layer(dim)
        Attn_func = Attention(
            dim,
            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)

        # tk_mem is not forwarded: GRC_Self_Attention passes extra kwargs to Attention(),
        # which has no tk_mem parameter and would raise a TypeError.
        self.attn = GRC_Self_Attention(
            dim, attention_func=Attention, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop,
            sr_ratio=sr_ratio) if attn_type is not None else Attn_func

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, H, W):
        x = x + self.drop_path(self.attn(self.norm1(x), H=H, W=W))
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x


class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)

        self.img_size = img_size
        self.patch_size = patch_size
        self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
        self.num_patches = self.H * self.W
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        B, C, H, W = x.shape

        x = self.proj(x).flatten(2).transpose(1, 2)
        x = self.norm(x)
        H, W = H // self.patch_size[0], W // self.patch_size[1]

        return x, (H, W)


class PyramidVisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
                 num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
                 attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, attn_dtn2='',
                 depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], num_stages=4, attn_dtn=False, attn_type=None, **kwargs):
        super().__init__()
        self.num_classes = num_classes
        self.depths = depths
        self.num_stages = num_stages
        self.embed_dim = embed_dims[3]  # used by reset_classifier

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
        cur = 0

        for i in range(num_stages):
            patch_embed = PatchEmbed(img_size=img_size if i == 0 else img_size // (2 ** (i + 1)),
                                     patch_size=patch_size if i == 0 else 2,
                                     in_chans=in_chans if i == 0 else embed_dims[i - 1],
                                     embed_dim=embed_dims[i])
            num_patches = patch_embed.num_patches if i != num_stages - 1 else patch_embed.num_patches + 1
            pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dims[i]))
            pos_drop = nn.Dropout(p=drop_rate)

            block = nn.ModuleList([Block_grc(
                dim=embed_dims[i], num_heads=num_heads[i], mlp_ratio=mlp_ratios[i], qkv_bias=qkv_bias,
                qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + j],
                norm_layer=norm_layer, sr_ratio=sr_ratios[i], attn_type=attn_type)
                for j in range(depths[i])])

            cur += depths[i]

            setattr(self, f"patch_embed{i + 1}", patch_embed)
            setattr(self, f"pos_embed{i + 1}", pos_embed)
            setattr(self, f"pos_drop{i + 1}", pos_drop)
            setattr(self, f"block{i + 1}", block)

        self.norm = norm_layer(embed_dims[3])

        # cls_token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims[3]))

        # classification head
        self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity()

        # init weights
        for i in range(num_stages):
            pos_embed = getattr(self, f"pos_embed{i + 1}")
            trunc_normal_(pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        # return {'pos_embed', 'cls_token'} # has pos_embed may be better
        return {'cls_token'}

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def _get_pos_embed(self, pos_embed, patch_embed, H, W):
        if H * W == self.patch_embed1.num_patches:
            return pos_embed
        else:
            return F.interpolate(
                pos_embed.reshape(1, patch_embed.H, patch_embed.W, -1).permute(0, 3, 1, 2),
                size=(H, W), mode="bilinear").reshape(1, -1, H * W).permute(0, 2, 1)

    def forward_features(self, x):
        B = x.shape[0]

        for i in range(self.num_stages):
            patch_embed = getattr(self, f"patch_embed{i + 1}")
            pos_embed = getattr(self, f"pos_embed{i + 1}")
            pos_drop = getattr(self, f"pos_drop{i + 1}")
            block = getattr(self, f"block{i + 1}")
            x, (H, W) = patch_embed(x)

            # Set cls_dim on every submodule so GRC modules know whether a cls token is prepended.
            if i == self.num_stages - 1:
                block.apply(lambda m: setattr(m, 'cls_dim', 1))
                cls_tokens = self.cls_token.expand(B, -1, -1)
                x = torch.cat((cls_tokens, x), dim=1)
                pos_embed_ = self._get_pos_embed(pos_embed[:, 1:], patch_embed, H, W)
                pos_embed = torch.cat((pos_embed[:, 0:1], pos_embed_), dim=1)
            else:
                block.apply(lambda m: setattr(m, 'cls_dim', 0))
                pos_embed = self._get_pos_embed(pos_embed, patch_embed, H, W)

            x = pos_drop(x + pos_embed)
            for blk in block:
                x = blk(x, H, W)
            if i != self.num_stages - 1:
                x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()

        x = self.norm(x)
        return x[:, 0]

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)

        return x


def _conv_filter(state_dict, patch_size=16):
    """ convert patch embedding weight from manual patchify + linear proj to conv"""
    out_dict = {}
    for k, v in state_dict.items():
        if 'patch_embed.proj.weight' in k:
            v = v.reshape((v.shape[0], 3, patch_size, patch_size))
        out_dict[k] = v

    return out_dict
--------------------------------------------------------------------------------
/ct-public.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/annosubmission/GRC-Cache/bdaffd4af9647f3028def2c444757b0e61d2e87f/ct-public.gif
--------------------------------------------------------------------------------