├── LICENSE ├── README.md ├── adapter_modules.py ├── beit3.spm ├── beit3_adapter.py ├── beit3_seg.py ├── beit3_seg_ov_v2.py ├── image └── overview.png ├── modeling_utils.py ├── test.ipynb ├── train.ipynb └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | # AI²Lab Source Code License (National Taiwan University) 2 | 3 | ## 1. Definitions 4 | 5 | “Licensor” means any person or entity that distributes its Work. 6 | 7 | “Work” means (a) the original work of authorship made available under this license, which may include software, documentation, or other files, and (b) any additions to or derivative works thereof that are made available under this license. 8 | 9 | The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this license, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work. 10 | 11 | Works are “made available” under this license by including in or with the Work either (a) a copyright notice referencing the applicability of this license to the Work, or (b) a copy of this license. 12 | 13 | ## 2. License Grant 14 | 15 | ### 2.1 Copyright Grant 16 | Subject to the terms and conditions of this license, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to use, reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form. 17 | 18 | ## 3. Limitations 19 | 20 | ### 3.1 Redistribution 21 | You may reproduce or distribute the Work only if (a) you do so under this license, (b) you include a complete copy of this license with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work. 22 | 23 | ### 3.2 Derivative Works 24 | You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms. Notwithstanding Your Terms, this license (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself. 25 | 26 | ### 3.3 Use Limitation 27 | The Work and any derivative works thereof only may be used or intended for use non-commercially. As used herein, “non-commercially” means for research or evaluation purposes only. 28 | 29 | ### 3.4 Patent Claims 30 | If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this license from such Licensor (including the grant in Section 2.1) will terminate immediately. 31 | 32 | ### 3.5 Trademarks 33 | This license does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, except as necessary to reproduce the notices described in this license. 34 | 35 | ### 3.6 Termination 36 | If you violate any term of this license, then your rights under this license (including the grant in Section 2.1) will terminate immediately. 37 | 38 | ## 4. Disclaimer of Warranty. 
39 | THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF 40 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE. 41 | 42 | ## 5. Limitation of Liability. 43 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # OMTSeg: Open-Vocabulary Panoptic Segmentation Using BERT Pre-Training of Vision-Language Multiway Transformer Model 2 | [![Paper](https://img.shields.io/badge/Paper-ICIP24-blue)](https://ieeexplore.ieee.org/abstract/document/10647459) [![License](https://img.shields.io/badge/license-custom-lightgrey)](./LICENSE) 3 | 4 | Official implementation of our ICIP 2024 paper: 5 | "Open-Vocabulary Panoptic Segmentation Using BERT Pre-Training of Vision-Language Multiway Transformer Model". 6 | 7 | ## 📄 Abstract 8 | 9 | Open-vocabulary panoptic segmentation remains a challenging problem. One of the biggest difficulties lies in training models to generalize to an unlimited number of classes using limited categorized training data. Recent popular methods involve large-scale vision-language pre-trained foundation models, such as CLIP. In this paper, we propose OMTSeg for open-vocabulary segmentation using another large-scale vision-language pre-trained model called BEiT-3 and leveraging the cross-modal attention between visual and linguistic features in BEiT-3 to achieve better performance. Experimental results demonstrate that OMTSeg performs favorably against state-of-the-art models. 10 | 11 | ## 🚀 Overview 12 | 13 |

14 | ![OMTSeg Overview](image/overview.png) 15 |
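## 🔧 Usage (sketch)

A minimal, hypothetical sketch of how the adapter backbone in this repo is driven (see `beit3_adapter.py` and `beit3_seg_ov_v2.py`). The `EncoderConfig` values and `interaction_indexes` below are assumptions that mirror the standard BEiT-3 base setup; checkpoint loading and tokenization with `beit3.spm` are omitted. Refer to `train.ipynb` / `test.ipynb` for the actual pipeline.

```python
import torch
from torchscale.architecture.config import EncoderConfig
from beit3_adapter import BEiT3Adapter

# Assumed BEiT-3 base settings (not necessarily the paper's exact configuration).
beit3_args = EncoderConfig(
    img_size=640, patch_size=16, vocab_size=64010, multiway=True,
    layernorm_embedding=False, normalize_output=True, no_output_layer=True,
    encoder_embed_dim=768, encoder_layers=12, encoder_attention_heads=12,
    encoder_ffn_embed_dim=3072, drop_path_rate=0.1, checkpoint_activations=False,
)

backbone = BEiT3Adapter(
    beit3_args,
    # assumed: four interaction stages covering the 12 encoder layers
    interaction_indexes=[[0, 2], [3, 5], [6, 8], [9, 11]],
)

pixel_values = torch.randn(1, 3, 640, 640)      # image batch
input_ids = torch.randint(0, 64010, (1, 16))    # e.g. tokenized class names / prompts

out = backbone(
    visual_tokens=pixel_values,
    textual_tokens=input_ids,
    use_vit_adapter=True,
    return_all_hiddens=True,
)
f1, f2, f3, f4 = out["fpn_features"]  # multi-scale features for the Mask2Former-style decoder
text_feature = out["text_feature"]    # text-side features from the multiway encoder
```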

16 | 17 | ## Citation 18 | 19 | If you find this work useful in your research, please cite our paper: 20 | 21 | ``` 22 | @inproceedings{chen2024open, 23 | title={Open-Vocabulary Panoptic Segmentation Using Bert Pre-Training of Vision-Language Multiway Transformer Model}, 24 | author={Chen, Yi-Chia and Li, Wei-Hua and Chen, Chu-Song}, 25 | booktitle={2024 IEEE International Conference on Image Processing (ICIP)}, 26 | pages={2494--2500}, 27 | year={2024}, 28 | organization={IEEE} 29 | } 30 | ``` 31 | 32 | 33 | ## License 34 | 35 | This project is released under a custom license. 36 | Please see the [LICENSE](./LICENSE) file for the full terms and conditions. 37 | 38 | For academic or commercial use, please contact the authors. 39 | 40 | 41 | -------------------------------------------------------------------------------- /adapter_modules.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from functools import partial 3 | import warnings 4 | 5 | import math 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.utils.checkpoint as cp 10 | # from ops.modules import MSDeformAttn 11 | from transformers import DeformableDetrConfig 12 | # from transformers.models.deformable_detr.modeling_deformable_detr import DeformableDetrMultiscaleDeformableAttention 13 | from transformers.models.deformable_detr.modeling_deformable_detr import multi_scale_deformable_attention, MultiScaleDeformableAttentionFunction 14 | from timm.models.layers import DropPath, LayerNorm2d 15 | 16 | 17 | 18 | _logger = logging.getLogger(__name__) 19 | 20 | 21 | def get_reference_points(spatial_shapes, device): 22 | reference_points_list = [] 23 | for lvl, (H_, W_) in enumerate(spatial_shapes): 24 | ref_y, ref_x = torch.meshgrid( 25 | torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device), 26 | torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device)) 27 | ref_y = ref_y.reshape(-1)[None] / H_ 28 | ref_x = ref_x.reshape(-1)[None] / W_ 29 | ref = torch.stack((ref_x, ref_y), -1) 30 | reference_points_list.append(ref) 31 | reference_points = torch.cat(reference_points_list, 1) 32 | reference_points = reference_points[:, :, None] 33 | return reference_points 34 | 35 | 36 | def deform_inputs(x, ss=None): 37 | bs, c, h, w = x.shape 38 | if ss is None: 39 | ss = (h, w) 40 | spatial_shapes = torch.as_tensor([(h // 8, w // 8), 41 | (h // 16, w // 16), 42 | (h // 32, w // 32)], 43 | dtype=torch.long, device=x.device) 44 | level_start_index = torch.cat((spatial_shapes.new_zeros( 45 | (1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) 46 | reference_points = get_reference_points([(ss[0] // 16, ss[1] // 16)], x.device) 47 | deform_inputs1 = [reference_points, spatial_shapes, level_start_index] 48 | 49 | spatial_shapes = torch.as_tensor([(ss[0] // 16, ss[1] // 16)], dtype=torch.long, device=x.device) 50 | level_start_index = torch.cat((spatial_shapes.new_zeros( 51 | (1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) 52 | reference_points = get_reference_points([(h // 8, w // 8), 53 | (h // 16, w // 16), 54 | (h // 32, w // 32)], x.device) 55 | deform_inputs2 = [reference_points, spatial_shapes, level_start_index] 56 | 57 | return deform_inputs1, deform_inputs2 58 | 59 | 60 | class ConvFFN(nn.Module): 61 | def __init__(self, in_features, hidden_features=None, out_features=None, 62 | act_layer=nn.GELU, drop=0.): 63 | super().__init__() 64 | out_features = out_features or in_features 65 | hidden_features = hidden_features or 
in_features 66 | self.fc1 = nn.Linear(in_features, hidden_features) 67 | self.dwconv = DWConv(hidden_features) 68 | self.act = act_layer() 69 | self.fc2 = nn.Linear(hidden_features, out_features) 70 | self.drop = nn.Dropout(drop) 71 | 72 | def forward(self, x, H, W): 73 | x = self.fc1(x) 74 | x = self.dwconv(x, H, W) 75 | x = self.act(x) 76 | x = self.drop(x) 77 | x = self.fc2(x) 78 | x = self.drop(x) 79 | return x 80 | 81 | 82 | class DWConv(nn.Module): 83 | def __init__(self, dim=768): 84 | super().__init__() 85 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) 86 | 87 | def forward(self, x, H, W): 88 | B, N, C = x.shape 89 | n = N // 21 90 | x1 = x[:, 0:16 * n, :].transpose(1, 2).view(B, C, H * 2, W * 2).contiguous() 91 | x2 = x[:, 16 * n:20 * n, :].transpose(1, 2).view(B, C, H, W).contiguous() 92 | x3 = x[:, 20 * n:, :].transpose(1, 2).view(B, C, H // 2, W // 2).contiguous() 93 | x1 = self.dwconv(x1).flatten(2).transpose(1, 2) 94 | x2 = self.dwconv(x2).flatten(2).transpose(1, 2) 95 | x3 = self.dwconv(x3).flatten(2).transpose(1, 2) 96 | x = torch.cat([x1, x2, x3], dim=1) 97 | return x 98 | 99 | 100 | 101 | class DeformableDetrMultiscaleDeformableAttention(nn.Module): 102 | """ 103 | Multiscale deformable attention as proposed in Deformable DETR. 104 | """ 105 | 106 | def __init__(self, config: DeformableDetrConfig, num_heads: int, n_points: int, ratio=1.0): 107 | super().__init__() 108 | if config.d_model % num_heads != 0: 109 | raise ValueError( 110 | f"embed_dim (d_model) must be divisible by num_heads, but got {config.d_model} and {num_heads}" 111 | ) 112 | dim_per_head = config.d_model // num_heads 113 | # check if dim_per_head is power of 2 114 | if not ((dim_per_head & (dim_per_head - 1) == 0) and dim_per_head != 0): 115 | warnings.warn( 116 | "You'd better set embed_dim (d_model) in DeformableDetrMultiscaleDeformableAttention to make the" 117 | " dimension of each attention head a power of 2 which is more efficient in the authors' CUDA" 118 | " implementation." 
119 | ) 120 | 121 | self.im2col_step = 64 122 | 123 | self.ratio = ratio 124 | self.d_model = config.d_model 125 | self.n_levels = config.num_feature_levels 126 | self.n_heads = num_heads 127 | self.n_points = n_points 128 | 129 | self.sampling_offsets = nn.Linear(config.d_model, num_heads * self.n_levels * n_points * 2) 130 | self.attention_weights = nn.Linear(config.d_model, num_heads * self.n_levels * n_points) 131 | self.value_proj = nn.Linear(config.d_model, int(config.d_model * ratio)) 132 | self.output_proj = nn.Linear(int(config.d_model * ratio), config.d_model) 133 | 134 | self.disable_custom_kernels = config.disable_custom_kernels 135 | 136 | self._reset_parameters() 137 | 138 | def _reset_parameters(self): 139 | nn.init.constant_(self.sampling_offsets.weight.data, 0.0) 140 | thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) 141 | grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) 142 | grid_init = ( 143 | (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) 144 | .view(self.n_heads, 1, 1, 2) 145 | .repeat(1, self.n_levels, self.n_points, 1) 146 | ) 147 | for i in range(self.n_points): 148 | grid_init[:, :, i, :] *= i + 1 149 | with torch.no_grad(): 150 | self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) 151 | nn.init.constant_(self.attention_weights.weight.data, 0.0) 152 | nn.init.constant_(self.attention_weights.bias.data, 0.0) 153 | nn.init.xavier_uniform_(self.value_proj.weight.data) 154 | nn.init.constant_(self.value_proj.bias.data, 0.0) 155 | nn.init.xavier_uniform_(self.output_proj.weight.data) 156 | nn.init.constant_(self.output_proj.bias.data, 0.0) 157 | 158 | def with_pos_embed(self, tensor: torch.Tensor, position_embeddings): 159 | return tensor if position_embeddings is None else tensor + position_embeddings 160 | 161 | def forward( 162 | self, 163 | hidden_states: torch.Tensor, 164 | attention_mask = None, 165 | encoder_hidden_states=None, 166 | encoder_attention_mask=None, 167 | position_embeddings = None, 168 | reference_points=None, 169 | spatial_shapes=None, 170 | level_start_index=None, 171 | output_attentions: bool = False, 172 | ): 173 | # add position embeddings to the hidden states before projecting to queries and keys 174 | if position_embeddings is not None: 175 | hidden_states = self.with_pos_embed(hidden_states, position_embeddings) 176 | 177 | batch_size, num_queries, _ = hidden_states.shape 178 | batch_size, sequence_length, _ = encoder_hidden_states.shape 179 | if (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length: 180 | raise ValueError( 181 | "Make sure to align the spatial shapes with the sequence length of the encoder hidden states" 182 | ) 183 | 184 | value = self.value_proj(encoder_hidden_states) 185 | if attention_mask is not None: 186 | # we invert the attention_mask 187 | value = value.masked_fill(~attention_mask[..., None], float(0)) 188 | value = value.view(batch_size, sequence_length, self.n_heads, int(self.d_model * self.ratio) // self.n_heads) 189 | sampling_offsets = self.sampling_offsets(hidden_states).view( 190 | batch_size, num_queries, self.n_heads, self.n_levels, self.n_points, 2 191 | ) 192 | attention_weights = self.attention_weights(hidden_states).view( 193 | batch_size, num_queries, self.n_heads, self.n_levels * self.n_points 194 | ) 195 | attention_weights = F.softmax(attention_weights, -1).view( 196 | batch_size, num_queries, self.n_heads, self.n_levels, self.n_points 197 | ) 198 | # batch_size, num_queries, n_heads, n_levels, n_points, 2 199 | 
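# Convert the predicted offsets into absolute sampling locations in normalized [0, 1]
# coordinates: with 2-d reference points, the offsets are divided by each level's
# (width, height); with 4-d (box) reference points, they are scaled by half of the
# box size divided by n_points.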
if reference_points.shape[-1] == 2: 200 | offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1) 201 | sampling_locations = ( 202 | reference_points[:, :, None, :, None, :] 203 | + sampling_offsets / offset_normalizer[None, None, None, :, None, :] 204 | ) 205 | elif reference_points.shape[-1] == 4: 206 | sampling_locations = ( 207 | reference_points[:, :, None, :, None, :2] 208 | + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 209 | ) 210 | else: 211 | raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}") 212 | 213 | if self.disable_custom_kernels: 214 | # PyTorch implementation 215 | output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights) 216 | else: 217 | try: 218 | # custom kernel 219 | output = MultiScaleDeformableAttentionFunction.apply( 220 | value, 221 | spatial_shapes, 222 | level_start_index, 223 | sampling_locations, 224 | attention_weights, 225 | self.im2col_step, 226 | ) 227 | except Exception: 228 | # PyTorch implementation 229 | output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights) 230 | output = self.output_proj(output) 231 | 232 | return output, attention_weights 233 | 234 | class AdapterMultiscaleDeformableAttention(DeformableDetrMultiscaleDeformableAttention): 235 | def __init__(self, 236 | d_model=256, 237 | n_levels=4, 238 | n_heads=8, 239 | n_points=4, 240 | ratio=1.0): 241 | 242 | fake_config = DeformableDetrConfig( 243 | d_model=d_model, 244 | num_feature_levels=n_levels, 245 | disable_custom_kernels=False, 246 | ) 247 | 248 | super().__init__(fake_config, num_heads=n_heads, n_points=n_points, ratio=ratio) 249 | 250 | def _reset_parameters(self): 251 | nn.init.constant_(self.sampling_offsets.weight.data, 0.0) 252 | thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) 253 | grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) 254 | grid_init = ( 255 | (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) 256 | .view(self.n_heads, 1, 1, 2) 257 | .repeat(1, self.n_levels, self.n_points, 1) 258 | ) 259 | for i in range(self.n_points): 260 | grid_init[:, :, i, :] *= i + 1 261 | with torch.no_grad(): 262 | self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) 263 | nn.init.constant_(self.attention_weights.weight.data, 0.0) 264 | nn.init.constant_(self.attention_weights.bias.data, 0.0) 265 | nn.init.xavier_uniform_(self.value_proj.weight.data) 266 | nn.init.constant_(self.value_proj.bias.data, 0.0) 267 | nn.init.xavier_uniform_(self.output_proj.weight.data) 268 | nn.init.constant_(self.output_proj.bias.data, 0.0) 269 | 270 | def with_pos_embed(self, tensor: torch.Tensor, position_embeddings): 271 | return tensor if position_embeddings is None else tensor + position_embeddings 272 | 273 | 274 | def forward(self, query, reference_points, input_flatten, input_spatial_shapes, 275 | input_level_start_index, input_padding_mask=None): 276 | 277 | hidden_states = query 278 | attention_mask=None 279 | encoder_hidden_states = input_flatten 280 | encoder_attention_mask = input_padding_mask 281 | position_embeddings=None 282 | reference_points=reference_points 283 | spatial_shapes=input_spatial_shapes 284 | level_start_index=input_level_start_index 285 | output_attentions=False 286 | 287 | output, attention_weights = super().forward( 288 | hidden_states=hidden_states, 289 | attention_mask=attention_mask, 290 | 
encoder_hidden_states=encoder_hidden_states, 291 | encoder_attention_mask=encoder_attention_mask, 292 | position_embeddings=position_embeddings, 293 | reference_points=reference_points, 294 | spatial_shapes=spatial_shapes, 295 | level_start_index=level_start_index, 296 | output_attentions=False, 297 | ) 298 | 299 | return output 300 | 301 | 302 | class Extractor(nn.Module): 303 | def __init__(self, dim, num_heads=6, n_points=4, n_levels=1, deform_ratio=1.0, 304 | with_cffn=True, cffn_ratio=0.25, drop=0., drop_path=0., 305 | norm_layer=partial(nn.LayerNorm, eps=1e-6), with_cp=False): 306 | super().__init__() 307 | self.query_norm = norm_layer(dim) 308 | self.feat_norm = norm_layer(dim) 309 | self.attn = AdapterMultiscaleDeformableAttention(d_model=dim, n_levels=n_levels, n_heads=num_heads, 310 | n_points=n_points, ratio=deform_ratio) 311 | self.with_cffn = with_cffn 312 | self.with_cp = with_cp 313 | if with_cffn: 314 | self.ffn = ConvFFN(in_features=dim, hidden_features=int(dim * cffn_ratio), drop=drop) 315 | self.ffn_norm = norm_layer(dim) 316 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 317 | 318 | def forward(self, query, reference_points, feat, spatial_shapes, level_start_index, H, W): 319 | 320 | def _inner_forward(query, feat): 321 | 322 | attn = self.attn(self.query_norm(query), reference_points, 323 | self.feat_norm(feat), spatial_shapes, 324 | level_start_index, None) 325 | query = query + attn 326 | 327 | if self.with_cffn: 328 | query = query + self.drop_path(self.ffn(self.ffn_norm(query), H, W)) 329 | return query 330 | 331 | if self.with_cp and query.requires_grad: 332 | query = cp.checkpoint(_inner_forward, query, feat) 333 | else: 334 | query = _inner_forward(query, feat) 335 | 336 | return query 337 | 338 | 339 | class Injector(nn.Module): 340 | def __init__(self, dim, num_heads=6, n_points=4, n_levels=1, deform_ratio=1.0, 341 | norm_layer=partial(nn.LayerNorm, eps=1e-6), init_values=0., with_cp=False): 342 | super().__init__() 343 | self.with_cp = with_cp 344 | self.query_norm = norm_layer(dim) 345 | self.feat_norm = norm_layer(dim) 346 | self.attn = AdapterMultiscaleDeformableAttention(d_model=dim, n_levels=n_levels, n_heads=num_heads, 347 | n_points=n_points, ratio=deform_ratio) 348 | self.gamma = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) 349 | 350 | def forward(self, query, reference_points, feat, spatial_shapes, level_start_index): 351 | 352 | def _inner_forward(query, feat): 353 | 354 | attn = self.attn(self.query_norm(query), reference_points, 355 | self.feat_norm(feat), spatial_shapes, 356 | level_start_index, None) 357 | return query + self.gamma * attn 358 | 359 | if self.with_cp and query.requires_grad: 360 | query = cp.checkpoint(_inner_forward, query, feat) 361 | else: 362 | query = _inner_forward(query, feat) 363 | 364 | return query 365 | 366 | 367 | class InteractionBlock(nn.Module): 368 | def __init__(self, dim, num_heads=6, n_points=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), 369 | drop=0., drop_path=0., with_cffn=True, cffn_ratio=0.25, init_values=0., 370 | deform_ratio=1.0, extra_extractor=False, with_cp=False): 371 | super().__init__() 372 | 373 | self.injector = Injector(dim=dim, n_levels=3, num_heads=num_heads, init_values=init_values, 374 | n_points=n_points, norm_layer=norm_layer, deform_ratio=deform_ratio, 375 | with_cp=with_cp) 376 | self.extractor = Extractor(dim=dim, n_levels=1, num_heads=num_heads, n_points=n_points, 377 | norm_layer=norm_layer, deform_ratio=deform_ratio, 
with_cffn=with_cffn, 378 | cffn_ratio=cffn_ratio, drop=drop, drop_path=drop_path, with_cp=with_cp) 379 | if extra_extractor: 380 | self.extra_extractors = nn.Sequential(*[ 381 | Extractor(dim=dim, num_heads=num_heads, n_points=n_points, norm_layer=norm_layer, 382 | with_cffn=with_cffn, cffn_ratio=cffn_ratio, deform_ratio=deform_ratio, 383 | drop=drop, drop_path=drop_path, with_cp=with_cp) 384 | for _ in range(2) 385 | ]) 386 | else: 387 | self.extra_extractors = None 388 | 389 | def forward(self, x, c, blocks, deform_inputs1, deform_inputs2, H, W): 390 | x = self.injector(query=x, reference_points=deform_inputs1[0], 391 | feat=c, spatial_shapes=deform_inputs1[1], 392 | level_start_index=deform_inputs1[2]) 393 | for idx, blk in enumerate(blocks): 394 | x = blk(x, H, W) 395 | c = self.extractor(query=c, reference_points=deform_inputs2[0], 396 | feat=x, spatial_shapes=deform_inputs2[1], 397 | level_start_index=deform_inputs2[2], H=H, W=W) 398 | if self.extra_extractors is not None: 399 | for extractor in self.extra_extractors: 400 | c = extractor(query=c, reference_points=deform_inputs2[0], 401 | feat=x, spatial_shapes=deform_inputs2[1], 402 | level_start_index=deform_inputs2[2], H=H, W=W) 403 | return x, c 404 | 405 | 406 | class InteractionBlockWithCls(nn.Module): 407 | def __init__(self, dim, num_heads=6, n_points=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), 408 | drop=0., drop_path=0., with_cffn=True, cffn_ratio=0.25, init_values=0., 409 | deform_ratio=1.0, extra_extractor=False, with_cp=False): 410 | super().__init__() 411 | 412 | self.injector = Injector(dim=dim, n_levels=3, num_heads=num_heads, init_values=init_values, 413 | n_points=n_points, norm_layer=norm_layer, deform_ratio=deform_ratio, 414 | with_cp=with_cp) 415 | self.extractor = Extractor(dim=dim, n_levels=1, num_heads=num_heads, n_points=n_points, 416 | norm_layer=norm_layer, deform_ratio=deform_ratio, with_cffn=with_cffn, 417 | cffn_ratio=cffn_ratio, drop=drop, drop_path=drop_path, with_cp=with_cp) 418 | if extra_extractor: 419 | self.extra_extractors = nn.Sequential(*[ 420 | Extractor(dim=dim, num_heads=num_heads, n_points=n_points, norm_layer=norm_layer, 421 | with_cffn=with_cffn, cffn_ratio=cffn_ratio, deform_ratio=deform_ratio, 422 | drop=drop, drop_path=drop_path, with_cp=with_cp) 423 | for _ in range(2) 424 | ]) 425 | else: 426 | self.extra_extractors = None 427 | 428 | def forward(self, x, c, cls, blocks, deform_inputs1, deform_inputs2, H, W): 429 | x = self.injector(query=x, reference_points=deform_inputs1[0], 430 | feat=c, spatial_shapes=deform_inputs1[1], 431 | level_start_index=deform_inputs1[2]) 432 | x = torch.cat((cls, x), dim=1) 433 | for idx, blk in enumerate(blocks): 434 | x = blk(x, H, W) 435 | cls, x = x[:, :1, ], x[:, 1:, ] 436 | c = self.extractor(query=c, reference_points=deform_inputs2[0], 437 | feat=x, spatial_shapes=deform_inputs2[1], 438 | level_start_index=deform_inputs2[2], H=H, W=W) 439 | if self.extra_extractors is not None: 440 | for extractor in self.extra_extractors: 441 | c = extractor(query=c, reference_points=deform_inputs2[0], 442 | feat=x, spatial_shapes=deform_inputs2[1], 443 | level_start_index=deform_inputs2[2], H=H, W=W) 444 | return x, c, cls 445 | 446 | 447 | class InteractionBlockWithClsAndMultiWay(nn.Module): 448 | def __init__(self, dim, num_heads=6, n_points=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), 449 | drop=0., drop_path=0., with_cffn=True, cffn_ratio=0.25, init_values=0., 450 | deform_ratio=1.0, extra_extractor=False, with_cp=False): 451 | super().__init__() 452 | 
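# Same injector/extractor interaction as InteractionBlockWithCls, except that forward()
# also takes BEiT-3's multiway split position, so only the visual tokens (not the text
# tokens) exchange information with the convolutional SPM features.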
453 | self.injector = Injector(dim=dim, n_levels=3, num_heads=num_heads, init_values=init_values, 454 | n_points=n_points, norm_layer=norm_layer, deform_ratio=deform_ratio, 455 | with_cp=with_cp) 456 | self.extractor = Extractor(dim=dim, n_levels=1, num_heads=num_heads, n_points=n_points, 457 | norm_layer=norm_layer, deform_ratio=deform_ratio, with_cffn=with_cffn, 458 | cffn_ratio=cffn_ratio, drop=drop, drop_path=drop_path, with_cp=with_cp) 459 | if extra_extractor: 460 | self.extra_extractors = nn.Sequential(*[ 461 | Extractor(dim=dim, num_heads=num_heads, n_points=n_points, norm_layer=norm_layer, 462 | with_cffn=with_cffn, cffn_ratio=cffn_ratio, deform_ratio=deform_ratio, 463 | drop=drop, drop_path=drop_path, with_cp=with_cp) 464 | for _ in range(2) 465 | ]) 466 | else: 467 | self.extra_extractors = None 468 | 469 | def forward(self, x, c, cls, multiway_split_position, blocks, deform_inputs1, deform_inputs2, H, W, return_hiddens=False): 470 | 471 | if multiway_split_position == -1: 472 | x = self.injector(query=x, reference_points=deform_inputs1[0], 473 | feat=c, spatial_shapes=deform_inputs1[1], 474 | level_start_index=deform_inputs1[2]) 475 | x = torch.cat((cls, x), dim=1) 476 | else: 477 | x_visual, x_text = x[:, :multiway_split_position-1], x[:, multiway_split_position-1:] 478 | x_visual = self.injector(query=x_visual, reference_points=deform_inputs1[0], 479 | feat=c, spatial_shapes=deform_inputs1[1], 480 | level_start_index=deform_inputs1[2]) 481 | x = torch.cat((cls, x_visual, x_text), dim=1) 482 | 483 | hiddens = [] 484 | for idx, blk in enumerate(blocks): 485 | x = blk(x, H, W) 486 | if return_hiddens: 487 | hiddens.append(x) 488 | 489 | if multiway_split_position == -1: 490 | cls, x = x[:, :1, ], x[:, 1:, ] 491 | c = self.extractor(query=c, reference_points=deform_inputs2[0], 492 | feat=x, spatial_shapes=deform_inputs2[1], 493 | level_start_index=deform_inputs2[2], H=H, W=W) 494 | if self.extra_extractors is not None: 495 | for extractor in self.extra_extractors: 496 | c = extractor(query=c, reference_points=deform_inputs2[0], 497 | feat=x, spatial_shapes=deform_inputs2[1], 498 | level_start_index=deform_inputs2[2], H=H, W=W) 499 | else: 500 | cls, x_visual, x_text = x[:, :1, ], x[:, 1:multiway_split_position, ], x[:, multiway_split_position:, ] 501 | c = self.extractor(query=c, reference_points=deform_inputs2[0], 502 | feat=x_visual, spatial_shapes=deform_inputs2[1], 503 | level_start_index=deform_inputs2[2], H=H, W=W) 504 | if self.extra_extractors is not None: 505 | for extractor in self.extra_extractors: 506 | c = extractor(query=c, reference_points=deform_inputs2[0], 507 | feat=x_visual, spatial_shapes=deform_inputs2[1], 508 | level_start_index=deform_inputs2[2], H=H, W=W) 509 | x = torch.cat((x_visual, x_text), dim=1) 510 | 511 | return x, c, cls, hiddens 512 | 513 | 514 | class SpatialPriorModule(nn.Module): 515 | def __init__(self, inplanes=64, embed_dim=384, with_cp=False): 516 | super().__init__() 517 | self.with_cp = with_cp 518 | 519 | self.stem = nn.Sequential(*[ 520 | nn.Conv2d(3, inplanes, kernel_size=3, stride=2, padding=1, bias=False), 521 | LayerNorm2d(inplanes), 522 | nn.ReLU(inplace=True), 523 | nn.Conv2d(inplanes, inplanes, kernel_size=3, stride=1, padding=1, bias=False), 524 | LayerNorm2d(inplanes), 525 | nn.ReLU(inplace=True), 526 | nn.Conv2d(inplanes, inplanes, kernel_size=3, stride=1, padding=1, bias=False), 527 | LayerNorm2d(inplanes), 528 | nn.ReLU(inplace=True), 529 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 530 | ]) 531 | self.conv2 = 
nn.Sequential(*[ 532 | nn.Conv2d(inplanes, 2 * inplanes, kernel_size=3, stride=2, padding=1, bias=False), 533 | LayerNorm2d(2 * inplanes), 534 | nn.ReLU(inplace=True) 535 | ]) 536 | self.conv3 = nn.Sequential(*[ 537 | nn.Conv2d(2 * inplanes, 4 * inplanes, kernel_size=3, stride=2, padding=1, bias=False), 538 | LayerNorm2d(4 * inplanes), 539 | nn.ReLU(inplace=True) 540 | ]) 541 | self.conv4 = nn.Sequential(*[ 542 | nn.Conv2d(4 * inplanes, 4 * inplanes, kernel_size=3, stride=2, padding=1, bias=False), 543 | LayerNorm2d(4 * inplanes), 544 | nn.ReLU(inplace=True) 545 | ]) 546 | self.fc1 = nn.Conv2d(inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True) 547 | self.fc2 = nn.Conv2d(2 * inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True) 548 | self.fc3 = nn.Conv2d(4 * inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True) 549 | self.fc4 = nn.Conv2d(4 * inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True) 550 | 551 | def forward(self, x): 552 | 553 | def _inner_forward(x): 554 | c1 = self.stem(x) 555 | c2 = self.conv2(c1) 556 | c3 = self.conv3(c2) 557 | c4 = self.conv4(c3) 558 | c1 = self.fc1(c1) 559 | c2 = self.fc2(c2) 560 | c3 = self.fc3(c3) 561 | c4 = self.fc4(c4) 562 | 563 | bs, dim, _, _ = c1.shape 564 | # c1 = c1.view(bs, dim, -1).transpose(1, 2) # 4s 565 | c2 = c2.view(bs, dim, -1).transpose(1, 2) # 8s 566 | c3 = c3.view(bs, dim, -1).transpose(1, 2) # 16s 567 | c4 = c4.view(bs, dim, -1).transpose(1, 2) # 32s 568 | 569 | return c1, c2, c3, c4 570 | 571 | if self.with_cp and x.requires_grad: 572 | outs = cp.checkpoint(_inner_forward, x) 573 | else: 574 | outs = _inner_forward(x) 575 | return outs 576 | 577 | -------------------------------------------------------------------------------- /beit3.spm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-Application-and-Integration-Lab/OMTSeg/3e06c9e6f2e65b0656e9fa7f47424149f911d84d/beit3.spm -------------------------------------------------------------------------------- /beit3_adapter.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/microsoft/torchscale/blob/main/torchscale/model/BEiT3.py 2 | # Copyright (c) 2022 Microsoft 3 | # Licensed under The MIT License [see LICENSE for details] 4 | 5 | import math 6 | from functools import partial 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torch.nn import LayerNorm 12 | from torch.nn.init import normal_ 13 | from timm.models.layers import trunc_normal_, LayerNorm2d 14 | 15 | from torchscale.model.BEiT3 import BEiT3 16 | from torchscale.architecture.config import EncoderConfig 17 | from torchscale.component.multiway_network import set_split_position 18 | 19 | from adapter_modules import SpatialPriorModule, deform_inputs 20 | from adapter_modules import InteractionBlockWithClsAndMultiWay as InteractionBlock 21 | from adapter_modules import AdapterMultiscaleDeformableAttention 22 | 23 | def block_decorator(f): 24 | def wrapper(x, H, W): 25 | x, l_aux = f(x) 26 | return x 27 | return wrapper 28 | 29 | class BEiT3Adapter(BEiT3): 30 | def __init__(self, beit3_args, conv_inplane=64, n_points=4, 31 | deform_num_heads=6, init_values=0., interaction_indexes=None, with_cffn=True, 32 | cffn_ratio=0.25, deform_ratio=1.0, add_vit_feature=True, pretrained=None, 33 | use_extra_extractor=True, with_cp=None, asymetric_input=False, beit_resolution=None, 34 | intepolate_pos=False, 
num_segments=None, **kwargs): 35 | 36 | super().__init__(beit3_args) 37 | 38 | if num_segments is not None: 39 | self.segment_embed = nn.Embedding(num_segments, beit3_args.encoder_embed_dim) 40 | 41 | self.asymetric_input = asymetric_input 42 | self.beit_resolution = beit_resolution 43 | self.intepolate_pos = intepolate_pos 44 | 45 | with_cp = with_cp if with_cp is not None else beit3_args.checkpoint_activations 46 | 47 | self.norm_layer = LayerNorm 48 | self.interaction_indexes = interaction_indexes 49 | self.add_vit_feature = add_vit_feature 50 | self.embed_dim = beit3_args.encoder_embed_dim 51 | embed_dim = beit3_args.encoder_embed_dim 52 | 53 | self.level_embed = nn.Parameter(torch.zeros(3, embed_dim)) 54 | self.spm = SpatialPriorModule(inplanes=conv_inplane, embed_dim=embed_dim, with_cp=False) 55 | self.interactions = nn.Sequential(*[ 56 | InteractionBlock(dim=embed_dim, num_heads=deform_num_heads, n_points=n_points, 57 | init_values=init_values, drop_path=beit3_args.drop_path_rate, 58 | norm_layer=self.norm_layer, with_cffn=with_cffn, 59 | cffn_ratio=cffn_ratio, deform_ratio=deform_ratio, 60 | extra_extractor=((True if i == len(interaction_indexes) - 1 61 | else False) and use_extra_extractor), 62 | with_cp=with_cp) 63 | for i in range(len(interaction_indexes)) 64 | ]) 65 | self.up = nn.ConvTranspose2d(embed_dim, embed_dim, 2, 2) 66 | self.norm1 = LayerNorm2d(embed_dim) 67 | self.norm2 = LayerNorm2d(embed_dim) 68 | self.norm3 = LayerNorm2d(embed_dim) 69 | self.norm4 = LayerNorm2d(embed_dim) 70 | 71 | self.init_weights() 72 | 73 | def init_weights(self): 74 | self.up.apply(self._init_weights) 75 | self.spm.apply(self._init_weights) 76 | self.interactions.apply(self._init_weights) 77 | self.apply(self._init_deform_weights) 78 | normal_(self.level_embed) 79 | 80 | def _init_weights(self, m): 81 | if isinstance(m, nn.Linear): 82 | trunc_normal_(m.weight, std=.02) 83 | if isinstance(m, nn.Linear) and m.bias is not None: 84 | nn.init.constant_(m.bias, 0) 85 | elif isinstance(m, nn.LayerNorm) or isinstance(m, nn.BatchNorm2d): 86 | nn.init.constant_(m.bias, 0) 87 | nn.init.constant_(m.weight, 1.0) 88 | elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 89 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 90 | fan_out //= m.groups 91 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 92 | if m.bias is not None: 93 | m.bias.data.zero_() 94 | 95 | def _init_deform_weights(self, m): 96 | if isinstance(m, AdapterMultiscaleDeformableAttention): 97 | m._reset_parameters() 98 | 99 | def _add_level_embed(self, c2, c3, c4): 100 | c2 = c2 + self.level_embed[0] 101 | c3 = c3 + self.level_embed[1] 102 | c4 = c4 + self.level_embed[2] 103 | return c2, c3, c4 104 | 105 | def forward_encoder( 106 | self, 107 | src_tokens, 108 | encoder_padding_mask=None, 109 | attn_mask=None, 110 | return_all_hiddens=False, 111 | token_embeddings=None, 112 | multiway_split_position=None, 113 | features_only=False, 114 | incremental_state=None, 115 | positions=None, 116 | **kwargs 117 | ): 118 | assert src_tokens is not None or token_embeddings is not None 119 | 120 | if encoder_padding_mask is None: 121 | if src_tokens is not None: 122 | encoder_padding_mask = torch.zeros_like( 123 | src_tokens, device=src_tokens.device 124 | ).bool() 125 | else: 126 | encoder_padding_mask = torch.zeros( 127 | [token_embeddings.size(0), token_embeddings.size(1)], 128 | device=token_embeddings.device, 129 | ).bool() 130 | 131 | if multiway_split_position is not None: 132 | assert self.encoder.args.multiway 
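# set_split_position marks where the visual tokens end, so every Multiway module routes
# the leading (visual) tokens through its vision expert and the remaining (text) tokens
# through its language expert.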
133 | self.encoder.apply(set_split_position(multiway_split_position)) 134 | 135 | x, encoder_embedding = self.encoder.forward_embedding(src_tokens, token_embeddings, positions) 136 | x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x)) 137 | 138 | encoder_states = [] 139 | 140 | if return_all_hiddens: 141 | encoder_states.append(x) 142 | 143 | rel_pos_bias = None 144 | if self.encoder.relative_position is not None: 145 | rel_pos_bias = self.encoder.relative_position( 146 | batch_size=x.size(0), qlen=x.size(1), klen=x.size(1) 147 | ) 148 | 149 | # incremental_state is not None during inference if we use the bidirectional encoder as a generator as in s2s-ft (https://arxiv.org/abs/2110.13640) 150 | l_aux = [] 151 | for idx, layer in enumerate(self.encoder.layers): 152 | x, l_aux_i = layer( 153 | x, 154 | encoder_padding_mask=encoder_padding_mask if incremental_state is None else None, 155 | attn_mask=attn_mask, 156 | rel_pos=rel_pos_bias, 157 | multiway_split_position=multiway_split_position, 158 | incremental_state=incremental_state[idx] if incremental_state is not None else None, 159 | ) 160 | if return_all_hiddens: 161 | assert encoder_states is not None 162 | encoder_states.append(x) 163 | l_aux.append(l_aux_i) 164 | 165 | if self.encoder.layer_norm is not None: 166 | x = self.encoder.layer_norm(x) 167 | 168 | if not features_only and self.encoder.output_projection is not None: 169 | x = self.encoder.output_projection(x) 170 | 171 | return { 172 | "encoder_out": x, 173 | "encoder_embedding": encoder_embedding, 174 | "encoder_padding_mask": encoder_padding_mask, 175 | "encoder_states": encoder_states, 176 | "l_aux": l_aux, 177 | } 178 | 179 | def forward_beit3( 180 | self, 181 | textual_tokens=None, 182 | visual_tokens=None, 183 | text_padding_position=None, 184 | attn_mask=None, 185 | vision_masked_position=None, 186 | incremental_state=None, 187 | positions=None, 188 | ): 189 | assert textual_tokens is not None or visual_tokens is not None 190 | 191 | if textual_tokens is None: 192 | x = self.vision_embed(visual_tokens, vision_masked_position) 193 | encoder_padding_mask = None 194 | multiway_split_position = -1 195 | elif visual_tokens is None: 196 | x = self.text_embed(textual_tokens) 197 | encoder_padding_mask = text_padding_position 198 | multiway_split_position = 0 199 | else: 200 | x1 = self.vision_embed(visual_tokens, vision_masked_position) 201 | multiway_split_position = x1.size(1) 202 | x2 = self.text_embed(textual_tokens) 203 | x = torch.cat([x1, x2], dim=1) 204 | 205 | if text_padding_position is not None: 206 | encoder_padding_mask = torch.cat( 207 | [ 208 | torch.zeros(x1.shape[:-1]).to(x1.device).bool(), 209 | text_padding_position, 210 | ], 211 | dim=1, 212 | ) 213 | else: 214 | encoder_padding_mask = None 215 | 216 | encoder_out = self.forward_encoder( 217 | src_tokens=None, 218 | encoder_padding_mask=encoder_padding_mask, 219 | attn_mask=attn_mask, 220 | token_embeddings=x, 221 | multiway_split_position=multiway_split_position, 222 | incremental_state=incremental_state, 223 | positions=positions, 224 | ) 225 | encoder_out["multiway_split_position"] = multiway_split_position 226 | 227 | return encoder_out 228 | 229 | def forward( 230 | self, 231 | textual_tokens=None, 232 | visual_tokens=None, 233 | text_padding_position=None, 234 | attn_mask=None, 235 | vision_masked_position=None, 236 | incremental_state=None, 237 | positions=None, 238 | segment_ids=None, 239 | use_vit_adapter=False, 240 | return_all_hiddens=False, 241 | ): 242 | if not use_vit_adapter: 243 | 
return self.forward_beit3( 244 | textual_tokens=textual_tokens, 245 | visual_tokens=visual_tokens, 246 | text_padding_position=text_padding_position, 247 | attn_mask=attn_mask, 248 | vision_masked_position=vision_masked_position, 249 | incremental_state=incremental_state, 250 | positions=positions, 251 | ) 252 | 253 | bsz, _, H, W = visual_tokens.shape 254 | H, W = H//16, W//16 255 | 256 | deform_inputs1, deform_inputs2 = deform_inputs(visual_tokens, ss=self.beit_resolution) 257 | 258 | # SPM forward 259 | c1, c2, c3, c4 = self.spm(visual_tokens) 260 | c2, c3, c4 = self._add_level_embed(c2, c3, c4) 261 | c = torch.cat([c2, c3, c4], dim=1) 262 | 263 | # original beit patch embedding 264 | if self.asymetric_input: 265 | beit_input = F.interpolate( 266 | visual_tokens, size=self.beit_resolution, mode="bilinear" 267 | ) 268 | beit_H, beit_W = self.beit_resolution[0]//16, self.beit_resolution[1]//16 269 | else: 270 | beit_input = visual_tokens 271 | beit_H, beit_W = H, W 272 | 273 | if self.intepolate_pos: 274 | v_cls = self.vision_embed.cls_token.expand(bsz, -1, -1) 275 | v_feat = self.vision_embed.proj(beit_input) 276 | _, _, vH, vW = v_feat.shape 277 | multiway_split_position = (vH * vW) + 1 278 | x1 = torch.cat( 279 | (v_cls, v_feat.flatten(2).transpose(1, 2)), dim=1, 280 | ) 281 | # print(x1.shape) 282 | 283 | v_extra_pos = self.encoder.embed_positions.A.ori_weight[:3] 284 | v_pos = self.encoder.embed_positions.A.ori_weight[3:].reshape(40, 40, -1).permute(2, 0, 1).unsqueeze(0) 285 | v_pos = F.interpolate( 286 | v_pos, 287 | size=(vH, vW), 288 | mode='bicubic', 289 | antialias=False, 290 | align_corners=False, 291 | ).squeeze(0).flatten(1).transpose(0, 1) 292 | self.encoder.embed_positions.A.weight = torch.nn.Parameter(torch.cat((v_extra_pos, v_pos), dim=0)) 293 | 294 | if textual_tokens is not None: 295 | x2 = self.text_embed(textual_tokens) 296 | if segment_ids is not None: 297 | x2 = x2 + self.segment_embed(segment_ids) 298 | 299 | x = torch.cat([x1, x2], dim=1) 300 | if text_padding_position is not None: 301 | encoder_padding_mask = torch.cat( 302 | [ 303 | torch.zeros(x1.shape[:-1]).to(x1.device).bool(), 304 | text_padding_position, 305 | ], 306 | dim=1, 307 | ) 308 | else: 309 | encoder_padding_mask = None 310 | else: 311 | x = x1 312 | encoder_padding_mask = None 313 | multiway_split_position = -1 314 | 315 | else: 316 | if textual_tokens is None: 317 | x = self.vision_embed(beit_input, vision_masked_position) 318 | encoder_padding_mask = None 319 | multiway_split_position = -1 320 | else: 321 | x1 = self.vision_embed(beit_input, vision_masked_position) 322 | multiway_split_position = x1.size(1) 323 | x2 = self.text_embed(textual_tokens) 324 | if segment_ids is not None: 325 | x2 = x2 + self.segment_embed(segment_ids) 326 | 327 | x = torch.cat([x1, x2], dim=1) 328 | if text_padding_position is not None: 329 | encoder_padding_mask = torch.cat( 330 | [ 331 | torch.zeros(x1.shape[:-1]).to(x1.device).bool(), 332 | text_padding_position, 333 | ], 334 | dim=1, 335 | ) 336 | else: 337 | encoder_padding_mask = None 338 | 339 | # beit3 output to encoder input 340 | src_tokens=None 341 | token_embeddings=x 342 | 343 | # original encoder embedding 344 | assert src_tokens is not None or token_embeddings is not None 345 | 346 | if encoder_padding_mask is None: 347 | if src_tokens is not None: 348 | encoder_padding_mask = torch.zeros_like( 349 | src_tokens, device=src_tokens.device 350 | ).bool() 351 | else: 352 | encoder_padding_mask = torch.zeros( 353 | [token_embeddings.size(0), 
token_embeddings.size(1)], 354 | device=token_embeddings.device, 355 | ).bool() 356 | 357 | if multiway_split_position is not None: 358 | assert self.encoder.args.multiway 359 | self.encoder.apply(set_split_position(multiway_split_position)) 360 | 361 | x, encoder_embedding = self.encoder.forward_embedding(src_tokens, token_embeddings, positions) 362 | x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x)) 363 | bs, n, dim = x.shape 364 | 365 | rel_pos_bias = None 366 | if self.encoder.relative_position is not None: 367 | rel_pos_bias = self.encoder.relative_position( 368 | batch_size=x.size(0), qlen=x.size(1), klen=x.size(1) 369 | ) 370 | 371 | encoder_states = [] 372 | if return_all_hiddens: 373 | encoder_states.append(x) 374 | 375 | # Interaction 376 | l_aux = None 377 | outs = list() 378 | cls, x = x[:, :1, ], x[:, 1:, ] 379 | for i, layer in enumerate(self.interactions): 380 | indexes = self.interaction_indexes[i] 381 | wrapped_blocks = [ 382 | block_decorator(partial( 383 | self.encoder.layers[i_layer], 384 | encoder_padding_mask=encoder_padding_mask if incremental_state is None else None, 385 | attn_mask=attn_mask, rel_pos=rel_pos_bias, 386 | multiway_split_position=multiway_split_position, 387 | incremental_state=incremental_state[i_layer] if incremental_state is not None else None, 388 | )) 389 | for i_layer in range(indexes[0], indexes[-1]+1) 390 | ] 391 | x, c, cls, hiddens = layer(x, c, cls, multiway_split_position, wrapped_blocks, 392 | deform_inputs1, deform_inputs2, H, W, return_hiddens=return_all_hiddens) 393 | encoder_states.extend(hiddens) 394 | 395 | if multiway_split_position == -1: 396 | outs.append(x.transpose(1, 2).view(bs, dim, beit_H, beit_W).contiguous()) 397 | else: 398 | x_visual, x_text = x[:, :multiway_split_position-1], x[:, multiway_split_position-1:] 399 | outs.append(x_visual.transpose(1, 2).view(bs, dim, beit_H, beit_W).contiguous()) 400 | 401 | # Split & Reshape 402 | c2 = c[:, 0:c2.size(1), :] 403 | c3 = c[:, c2.size(1):c2.size(1) + c3.size(1), :] 404 | c4 = c[:, c2.size(1) + c3.size(1):, :] 405 | 406 | c2 = c2.transpose(1, 2).view(bs, dim, H * 2, W * 2).contiguous() 407 | c3 = c3.transpose(1, 2).view(bs, dim, H, W).contiguous() 408 | c4 = c4.transpose(1, 2).view(bs, dim, H // 2, W // 2).contiguous() 409 | c1 = self.up(c2) + c1 410 | 411 | if self.add_vit_feature: 412 | x1, x2, x3, x4 = outs 413 | x1 = F.interpolate(x1, size=c1.shape[-2:], mode='bilinear', align_corners=False) 414 | x2 = F.interpolate(x2, size=c2.shape[-2:], mode='bilinear', align_corners=False) 415 | x3 = F.interpolate(x3, size=c3.shape[-2:], mode='bilinear', align_corners=False) 416 | x4 = F.interpolate(x4, size=c4.shape[-2:], mode='bilinear', align_corners=False) 417 | c1, c2, c3, c4 = c1 + x1, c2 + x2, c3 + x3, c4 + x4 418 | 419 | # Final Norm 420 | f1 = self.norm1(c1) 421 | f2 = self.norm2(c2) 422 | f3 = self.norm3(c3) 423 | f4 = self.norm4(c4) 424 | 425 | x = torch.cat((cls, x), dim=1) 426 | if self.encoder.layer_norm is not None: 427 | x = self.encoder.layer_norm(x) 428 | visual_feature = x[:, 1:multiway_split_position, :] 429 | text_feature = x[:, multiway_split_position:, :] 430 | 431 | outputs = { 432 | "fpn_features": (f1, f2, f3, f4), 433 | "visual_feature": visual_feature, 434 | "text_feature": text_feature, 435 | "encoder_out": x, 436 | "encoder_embedding": encoder_embedding, 437 | "encoder_padding_mask": encoder_padding_mask, 438 | "encoder_states": encoder_states, 439 | "multiway_split_position": multiway_split_position, 440 | } 441 | 442 | return outputs 
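# Usage note: BEiT3SegPixelLevelModule in beit3_seg_ov_v2.py drives this adapter via
#   self.encoder(visual_tokens=pixel_values, textual_tokens=input_ids,
#                text_padding_position=text_padding_position,
#                use_vit_adapter=True, return_all_hiddens=True)
# and feeds "fpn_features" to the Mask2Former pixel decoder, while the text-side outputs
# ("text_feature" and the split encoder states) are kept for later use (e.g. the
# decoder's optional text cross-attention).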
-------------------------------------------------------------------------------- /beit3_seg_ov_v2.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Dict, List, Tuple 2 | import numpy as np 3 | 4 | import torch 5 | from torch import nn, Tensor 6 | import torch.nn.functional as F 7 | 8 | from dataclasses import dataclass 9 | from transformers.file_utils import ModelOutput, requires_backends 10 | from transformers.activations import ACT2FN 11 | from transformers.models.mask2former.modeling_mask2former import ( 12 | Mask2FormerPixelDecoder, Mask2FormerModel, Mask2FormerPreTrainedModel, 13 | Mask2FormerTransformerModule, Mask2FormerLoss, Mask2FormerSinePositionEmbedding, 14 | Mask2FormerMaskedAttentionDecoder, Mask2FormerMaskedAttentionDecoderOutput, Mask2FormerAttention, 15 | Mask2FormerMaskPredictor, Mask2FormerHungarianMatcher) 16 | 17 | from .beit3_adapter import BEiT3Adapter 18 | 19 | @dataclass 20 | class BEiT3SegMaskedAttentionDecoderOutput(Mask2FormerMaskedAttentionDecoderOutput): 21 | text_last_hidden_state: torch.FloatTensor = None 22 | text_hidden_states: Optional[Tuple[torch.FloatTensor]] = None 23 | text_intermediate_hidden_states: Optional[Tuple[torch.FloatTensor]] = None 24 | 25 | 26 | @dataclass 27 | class BEiT3SegPixelLevelModuleOutput(ModelOutput): 28 | fpn_features: Tuple[torch.FloatTensor] = None 29 | text_feature: torch.FloatTensor = None 30 | encoder_last_hidden_state: torch.FloatTensor = None 31 | encoder_visual_last_hidden_state: torch.FloatTensor = None 32 | encoder_text_last_hidden_state: torch.FloatTensor = None 33 | 34 | encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None # beit3 new 35 | decoder_last_hidden_state: torch.FloatTensor = None 36 | decoder_hidden_states: Tuple[torch.FloatTensor] = None 37 | 38 | @dataclass 39 | class BEiT3SegModelOutput(ModelOutput): 40 | encoder_visual_last_hidden_state: torch.FloatTensor = None # beit3 new 41 | encoder_text_last_hidden_state: torch.FloatTensor = None # beit3 new 42 | encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None # beit3 new 43 | fpn_features: Tuple[torch.FloatTensor] = None # beit3 new 44 | 45 | transformer_decoder_text_last_hidden_state: torch.FloatTensor = None 46 | transformer_decoder_text_hidden_states: Optional[Tuple[torch.FloatTensor]] = None 47 | transformer_decoder_text_intermediate_states: Tuple[torch.FloatTensor] = None 48 | 49 | pixel_decoder_last_hidden_state: torch.FloatTensor = None 50 | transformer_decoder_last_hidden_state: torch.FloatTensor = None 51 | pixel_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None 52 | transformer_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None 53 | transformer_decoder_intermediate_states: Tuple[torch.FloatTensor] = None 54 | masks_queries_logits: Tuple[torch.FloatTensor] = None 55 | attentions: Optional[Tuple[torch.FloatTensor]] = None 56 | 57 | 58 | @dataclass 59 | class BEiT3SegForUniversalSegmentationOutput(ModelOutput): 60 | loss: Optional[torch.FloatTensor] = None 61 | class_queries_logits: torch.FloatTensor = None 62 | masks_queries_logits: torch.FloatTensor = None 63 | auxiliary_logits: Optional[List[Dict[str, torch.FloatTensor]]] = None 64 | 65 | encoder_visual_last_hidden_state: torch.FloatTensor = None # beit3 new 66 | encoder_text_last_hidden_state: torch.FloatTensor = None # beit3 new 67 | fpn_features: Tuple[torch.FloatTensor] = None # beit3 new 68 | encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None # beit3 new 69 | 
70 | pixel_decoder_last_hidden_state: torch.FloatTensor = None 71 | transformer_decoder_last_hidden_state: torch.FloatTensor = None 72 | pixel_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None 73 | transformer_decoder_hidden_states: Optional[torch.FloatTensor] = None 74 | attentions: Optional[Tuple[torch.FloatTensor]] = None 75 | 76 | loss_dict: Optional[Dict[str, torch.FloatTensor]] = None 77 | 78 | # Adapted from https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/criterion.py 79 | class BEiT3SegLoss(Mask2FormerLoss): 80 | def __init__(self, config, weight_dict: Dict[str, float]): 81 | """ 82 | The Mask2Former Loss. The loss is computed very similar to DETR. The process happens in two steps: 1) we 83 | compute hungarian assignment between ground truth masks and the outputs of the model 2) we supervise each pair 84 | of matched ground-truth / prediction (supervise class and mask) 85 | 86 | Args: 87 | config (`Mask2FormerConfig`): 88 | The configuration for Mask2Former model also containing loss calculation specific parameters. 89 | weight_dict (`Dict[str, float]`): 90 | A dictionary of weights to be applied to the different losses. 91 | """ 92 | super().__init__(config, weight_dict) 93 | self.match_once_only = config.match_once_only 94 | self.drop_first_ce_loss = config.drop_first_ce_loss 95 | self.use_objectness_loss = config.use_objectness_loss 96 | 97 | if self.use_objectness_loss: 98 | self.obj_loss_weight = torch.ones(2) 99 | self.obj_loss_weight[-1] = self.eos_coef 100 | self.empty_weight = self.empty_weight[:-1] 101 | 102 | def forward( 103 | self, 104 | masks_queries_logits: torch.Tensor, 105 | class_queries_logits: torch.Tensor, 106 | mask_labels: List[torch.Tensor], 107 | class_labels: List[torch.Tensor], 108 | auxiliary_predictions: Optional[Dict[str, torch.Tensor]] = None, 109 | ) -> Dict[str, torch.Tensor]: 110 | 111 | if self.use_objectness_loss: 112 | class_queries_logits_cls, class_queries_logits_obj = class_queries_logits 113 | # all obj_labels = 0 (first class) 114 | obj_labels = [torch.zeros_like(class_label) for class_label in class_labels] 115 | indices = self.matcher(masks_queries_logits, class_queries_logits_obj, mask_labels, obj_labels) 116 | num_masks = self.get_num_masks(obj_labels, device=obj_labels[0].device) 117 | 118 | # print('class_queries_logits_cls', class_queries_logits_cls.shape) 119 | # print('class_queries_logits_obj', class_queries_logits_obj.shape) 120 | 121 | losses: Dict[str, Tensor] = { 122 | **self.loss_masks(masks_queries_logits, mask_labels, indices, num_masks), 123 | **self.loss_labels( 124 | class_queries_logits_obj, obj_labels, indices, 125 | empty_weight=self.obj_loss_weight.to(class_queries_logits_obj.device), 126 | fill_value=1, loss_name="loss_objectness", 127 | ), 128 | **self.loss_labels( 129 | class_queries_logits_cls, class_labels, indices, 130 | empty_weight=self.empty_weight.to(class_queries_logits_cls.device), 131 | fill_value=-100, loss_name="loss_cross_entropy", 132 | ), 133 | } 134 | 135 | else: 136 | # retrieve the matching between the outputs of the last layer and the labels 137 | # print('run match') 138 | indices = self.matcher(masks_queries_logits, class_queries_logits, mask_labels, class_labels) 139 | # compute the average number of target masks for normalization purposes 140 | num_masks = self.get_num_masks(class_labels, device=class_labels[0].device) 141 | # get all the losses 142 | losses: Dict[str, Tensor] = { 143 | **self.loss_masks(masks_queries_logits, mask_labels, 
indices, num_masks), 144 | **self.loss_labels(class_queries_logits, class_labels, indices), 145 | } 146 | 147 | # in case of auxiliary losses, we repeat this process with the output of each intermediate layer. 148 | if auxiliary_predictions is not None: 149 | for idx, aux_outputs in enumerate(auxiliary_predictions): 150 | masks_queries_logits = aux_outputs["masks_queries_logits"] 151 | class_queries_logits = aux_outputs["class_queries_logits"] 152 | if self.match_once_only: 153 | raise NotImplementedError('not implement match_once_only') 154 | loss_dict = { 155 | **self.loss_masks(masks_queries_logits, mask_labels, indices, num_masks), 156 | **self.loss_labels(class_queries_logits, class_labels, indices), 157 | } 158 | else: 159 | loss_dict = self.forward(masks_queries_logits, class_queries_logits, mask_labels, class_labels) 160 | 161 | if idx == 0 and self.drop_first_ce_loss: 162 | if 'loss_cross_entropy' in loss_dict: 163 | del loss_dict['loss_cross_entropy'] 164 | if 'loss_objectness' in loss_dict: 165 | del loss_dict['loss_objectness'] 166 | 167 | loss_dict = {f"{key}_{idx}": value for key, value in loss_dict.items()} 168 | losses.update(loss_dict) 169 | 170 | return losses 171 | 172 | def loss_labels( 173 | self, class_queries_logits: Tensor, class_labels: List[Tensor], indices: Tuple[np.array], 174 | empty_weight=None, fill_value=None, loss_name="loss_cross_entropy", 175 | ): 176 | if empty_weight is None: 177 | empty_weight = self.empty_weight 178 | if fill_value is None: 179 | fill_value = self.num_labels 180 | 181 | pred_logits = class_queries_logits 182 | batch_size, num_queries, _ = pred_logits.shape 183 | criterion = nn.CrossEntropyLoss(empty_weight) 184 | idx = self._get_predictions_permutation_indices(indices) # shape of (batch_size, num_queries) 185 | target_classes_o = torch.cat( 186 | [target[j] for target, (_, j) in zip(class_labels, indices)] 187 | ) # shape of (batch_size, num_queries) 188 | target_classes = torch.full( 189 | (batch_size, num_queries), fill_value=fill_value, dtype=torch.int64, device=pred_logits.device 190 | ) 191 | target_classes[idx] = target_classes_o 192 | # Permute target_classes (batch_size, num_queries, num_labels) -> (batch_size, num_labels, num_queries) 193 | pred_logits_transposed = pred_logits.transpose(1, 2) 194 | loss_ce = criterion(pred_logits_transposed, target_classes) 195 | losses = {loss_name: loss_ce} 196 | return losses 197 | 198 | class BEiT3SegPixelLevelModule(nn.Module): 199 | def __init__(self, config): 200 | super().__init__() 201 | 202 | self.encoder = BEiT3Adapter(**config.backbone_config) 203 | self.decoder = Mask2FormerPixelDecoder(config, feature_channels=[self.encoder.embed_dim] * 4) 204 | 205 | def forward( 206 | self, pixel_values, 207 | input_ids=None, text_padding_position=None, 208 | output_hidden_states=False, 209 | ): 210 | backbone_features = self.encoder( 211 | visual_tokens=pixel_values, 212 | textual_tokens=input_ids, 213 | text_padding_position=text_padding_position, 214 | use_vit_adapter=True, 215 | return_all_hiddens=True, 216 | ) 217 | 218 | fpn_features = backbone_features['fpn_features'] 219 | text_feature = backbone_features['text_feature'] 220 | encoder_out = backbone_features['encoder_out'] 221 | multiway_split_position = backbone_features['multiway_split_position'] 222 | if multiway_split_position == -1: 223 | encoder_visual_last_hidden_state = encoder_out 224 | encoder_text_last_hidden_state = None 225 | else: 226 | encoder_visual_last_hidden_state = encoder_out[:, :multiway_split_position] 227 | 
encoder_text_last_hidden_state = encoder_out[:, multiway_split_position:] 228 | 229 | encoder_encoder_states = backbone_features['encoder_states'] 230 | if multiway_split_position == -1: 231 | encoder_encoder_states = [[state[:, :1], state[:, 1:]] for state in encoder_encoder_states] 232 | else: 233 | encoder_encoder_states = [ 234 | [state[:, :1], state[:, 1:multiway_split_position], state[:, multiway_split_position:]] 235 | for state in encoder_encoder_states] 236 | 237 | decoder_output = self.decoder(fpn_features, output_hidden_states=output_hidden_states) 238 | 239 | return BEiT3SegPixelLevelModuleOutput( 240 | fpn_features=fpn_features, 241 | text_feature=text_feature, 242 | 243 | encoder_last_hidden_state=encoder_out, 244 | encoder_visual_last_hidden_state=encoder_visual_last_hidden_state, 245 | encoder_text_last_hidden_state=encoder_text_last_hidden_state, 246 | encoder_hidden_states=encoder_encoder_states, 247 | 248 | decoder_last_hidden_state=decoder_output.mask_features, 249 | decoder_hidden_states=decoder_output.multi_scale_features, 250 | ) 251 | 252 | class BEiT3SegMaskedAttentionDecoderLayer(nn.Module): 253 | """ 254 | The Mask2FormerMaskedAttentionDecoderLayer is made up of self-attention, cross (masked) attention as well as FFN 255 | blocks. The cross attention block used as part of `Mask2FormerMaskedAttentionDecoderLayer` is actually a `masked 256 | attention` block that restricts the attention to localized features centered around predicted segments which leads 257 | to faster convergence and improved performance. The order of self and cross (i.e. masked) attention blocks have 258 | also been swapped in Mask2FormerMaskedAttentionDecoder compared to a standard DetrDecoder as an optimization 259 | improvement. 260 | 261 | Args: 262 | config (`Mask2FormerConfig`): 263 | The configuration used to initialize the Mask2FormerMaskedAttentionDecoder. 
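    Note (OMTSeg variant): when `config.use_text_cross_attn` is enabled, an additional text cross-attention
    block is inserted between the masked cross-attention and the self-attention, so the object queries can
    also attend to the BEiT-3 text features (see `forward_post`: masked cross-attention -> text
    cross-attention -> self-attention -> FFN).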
264 | """ 265 | 266 | def __init__(self, config): 267 | super().__init__() 268 | self.config = config 269 | self.embed_dim = self.config.hidden_dim 270 | self.pre_norm = self.config.pre_norm 271 | self.self_attn = nn.MultiheadAttention( 272 | self.embed_dim, 273 | self.config.num_attention_heads, 274 | self.config.dropout, 275 | ) 276 | # self.self_attn = Mask2FormerAttention( 277 | # embed_dim=self.embed_dim, 278 | # num_heads=config.num_attention_heads, 279 | # dropout=config.dropout, 280 | # is_decoder=True, 281 | # ) 282 | 283 | self.dropout = self.config.dropout 284 | self.activation_fn = ACT2FN[self.config.activation_function] 285 | self.activation_dropout = self.config.dropout 286 | 287 | self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) 288 | self.cross_attn = nn.MultiheadAttention(self.embed_dim, self.config.num_attention_heads, self.config.dropout) 289 | self.cross_attn_layer_norm = nn.LayerNorm(self.embed_dim) 290 | self.fc1 = nn.Linear(self.embed_dim, self.config.dim_feedforward) 291 | self.fc2 = nn.Linear(self.config.dim_feedforward, self.embed_dim) 292 | self.final_layer_norm = nn.LayerNorm(self.embed_dim) 293 | 294 | self.use_text_cross_attn = self.config.use_text_cross_attn 295 | if self.use_text_cross_attn: 296 | self.text_cross_attn = nn.MultiheadAttention(self.embed_dim, self.config.num_attention_heads, self.config.dropout) 297 | self.text_cross_attn_layer_norm = nn.LayerNorm(self.embed_dim) 298 | 299 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 300 | return tensor if pos is None else tensor + pos 301 | 302 | def forward_post( 303 | self, 304 | hidden_states: torch.Tensor, 305 | level_index: int = None, 306 | attention_mask: Optional[torch.Tensor] = None, 307 | position_embeddings: Optional[torch.Tensor] = None, 308 | text_position_embeddings: Optional[torch.Tensor] = None, 309 | query_position_embeddings: Optional[torch.Tensor] = None, 310 | encoder_hidden_states: Optional[torch.Tensor] = None, 311 | encoder_text_hidden_states: Optional[torch.Tensor] = None, 312 | encoder_attention_mask: Optional[torch.Tensor] = None, 313 | output_attentions: Optional[bool] = False, 314 | ): 315 | # Masked(Cross)-Attention Block 316 | cross_attn_weights = None 317 | self_attn_weights = None 318 | 319 | residual = hidden_states 320 | 321 | hidden_states, cross_attn_weights = self.cross_attn( 322 | query=self.with_pos_embed(hidden_states, query_position_embeddings), 323 | key=self.with_pos_embed(encoder_hidden_states[level_index], position_embeddings[level_index]), 324 | value=encoder_hidden_states[level_index], 325 | attn_mask=encoder_attention_mask, 326 | key_padding_mask=None, 327 | ) 328 | 329 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 330 | hidden_states = residual + hidden_states 331 | hidden_states = self.cross_attn_layer_norm(hidden_states) 332 | 333 | # Text Cross-Attention Block 334 | if self.use_text_cross_attn: 335 | residual = hidden_states 336 | hidden_states, _ = self.text_cross_attn( 337 | query=self.with_pos_embed(hidden_states, query_position_embeddings), 338 | key=self.with_pos_embed(encoder_text_hidden_states, text_position_embeddings), 339 | value=encoder_text_hidden_states, 340 | attn_mask=None, 341 | key_padding_mask=None, 342 | ) 343 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 344 | hidden_states = residual + hidden_states 345 | hidden_states = self.text_cross_attn_layer_norm(hidden_states) 346 | 347 | # Self Attention Block 348 | residual = 
hidden_states 349 | hidden_states, self_attn_weights = self.self_attn( 350 | query=self.with_pos_embed(hidden_states, query_position_embeddings), 351 | key=self.with_pos_embed(hidden_states, query_position_embeddings), 352 | value=hidden_states, 353 | attn_mask=None, 354 | key_padding_mask=None, 355 | ) 356 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 357 | hidden_states = residual + hidden_states 358 | hidden_states = self.self_attn_layer_norm(hidden_states) 359 | 360 | # Fully Connected 361 | residual = hidden_states 362 | hidden_states = self.activation_fn(self.fc1(hidden_states)) 363 | hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) 364 | hidden_states = self.fc2(hidden_states) 365 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 366 | hidden_states = residual + hidden_states 367 | hidden_states = self.final_layer_norm(hidden_states) 368 | outputs = (hidden_states, None, ) 369 | 370 | if output_attentions: 371 | outputs += (self_attn_weights, cross_attn_weights) 372 | 373 | return outputs 374 | 375 | def forward( 376 | self, 377 | hidden_states: torch.Tensor, 378 | level_index: int = None, 379 | attention_mask: Optional[torch.Tensor] = None, 380 | position_embeddings: Optional[torch.Tensor] = None, 381 | text_position_embeddings: Optional[torch.Tensor] = None, 382 | query_position_embeddings: Optional[torch.Tensor] = None, 383 | encoder_hidden_states: Optional[torch.Tensor] = None, 384 | encoder_text_hidden_states: Optional[torch.Tensor] = None, 385 | encoder_attention_mask: Optional[torch.Tensor] = None, 386 | output_attentions: Optional[bool] = False, 387 | ): 388 | """ 389 | Args: 390 | hidden_states (`torch.FloatTensor`): 391 | Input to the layer of shape `(seq_len, batch, embed_dim)`. 392 | attention_mask (`torch.FloatTensor`): 393 | Attention mask of shape `(1, seq_len, tgt_len, src_len)`. 394 | position_embeddings (`torch.FloatTensor`, *optional*): 395 | Position embeddings that are added to the keys in the masked-attention layer. 396 | query_position_embeddings (`torch.FloatTensor`, *optional*): 397 | Position embeddings that are added to the queries and keys in the self-attention layer. 398 | encoder_hidden_states (`torch.FloatTensor`): 399 | Cross attention input to the layer of shape `(seq_len, batch, embed_dim)`. 400 | encoder_attention_mask (`torch.FloatTensor`): 401 | Encoder attention mask of size`(1, seq_len, tgt_len, src_len)`. 402 | output_attentions (`bool`, *optional*): 403 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 404 | returned tensors for more detail. 
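            level_index (`int`, *optional*):
                Index of the multi-scale feature level whose features serve as keys and values in the masked
                cross-attention of this layer.
            text_position_embeddings (`torch.FloatTensor`, *optional*):
                Position embeddings added to the text keys in the text cross-attention layer.
            encoder_text_hidden_states (`torch.FloatTensor`, *optional*):
                BEiT-3 text features of shape `(text_seq_len, batch, embed_dim)` used as keys and values in the
                optional text cross-attention block (`config.use_text_cross_attn`).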
405 | """ 406 | 407 | if self.pre_norm: 408 | raise NotImplementedError('not implement pre-norm') 409 | else: 410 | outputs = self.forward_post( 411 | hidden_states=hidden_states, 412 | level_index=level_index, 413 | position_embeddings=position_embeddings, 414 | text_position_embeddings=text_position_embeddings, 415 | query_position_embeddings=query_position_embeddings, 416 | encoder_hidden_states=encoder_hidden_states, 417 | encoder_text_hidden_states=encoder_text_hidden_states, 418 | encoder_attention_mask=encoder_attention_mask, 419 | output_attentions=output_attentions, 420 | ) 421 | 422 | return outputs 423 | 424 | 425 | class BEiT3SegMaskPredictor(Mask2FormerMaskPredictor): 426 | def forward(self, outputs: torch.Tensor, pixel_embeddings: torch.Tensor, attention_mask_target_size: int = None): 427 | mask_embeddings = self.mask_embedder(outputs.transpose(0, 1)) 428 | 429 | outputs_mask = torch.einsum('bqc, bchw -> bqhw', mask_embeddings, pixel_embeddings) 430 | 431 | attention_mask = nn.functional.interpolate( 432 | outputs_mask, size=attention_mask_target_size, mode="bilinear", align_corners=False 433 | ) 434 | 435 | attention_mask = attention_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1) 436 | attention_mask = (attention_mask.flatten(0, 1) < 0.5).bool() 437 | attention_mask = attention_mask.detach() 438 | 439 | return outputs_mask, attention_mask 440 | 441 | class BEiT3SegMaskedAttentionDecoder(nn.Module): 442 | 443 | def __init__(self, config): 444 | super().__init__() 445 | 446 | self.config = config 447 | self.mask_feature_size = config.mask_feature_size 448 | self.dropout = config.dropout 449 | self.layerdrop = config.dropout 450 | self.num_feature_levels = 3 # level embedding (3 scales) 451 | self.decoder_layers = config.decoder_layers - 1 452 | 453 | self.layers = nn.ModuleList( 454 | [BEiT3SegMaskedAttentionDecoderLayer(self.config) for _ in range(self.decoder_layers)] 455 | ) 456 | self.layernorm = nn.LayerNorm(config.hidden_dim) 457 | 458 | self.mask_predictor = BEiT3SegMaskPredictor( 459 | hidden_size=config.hidden_dim, 460 | num_heads=config.num_attention_heads, 461 | mask_feature_size=self.mask_feature_size, 462 | ) 463 | 464 | self.gradient_checkpointing = False 465 | 466 | def forward( 467 | self, 468 | inputs_embeds: torch.Tensor = None, 469 | multi_stage_positional_embeddings: torch.Tensor = None, 470 | text_positional_embeddings: torch.Tensor = None, 471 | pixel_embeddings: torch.Tensor = None, 472 | encoder_hidden_states: torch.Tensor = None, 473 | encoder_text_hidden_states: torch.Tensor = None, 474 | query_position_embeddings: torch.Tensor = None, 475 | feature_size_list: List = None, 476 | output_attentions: Optional[bool] = None, 477 | output_hidden_states: Optional[bool] = None, 478 | return_dict: Optional[bool] = None, 479 | ): 480 | r""" 481 | Args: 482 | inputs_embeds (`torch.FloatTensor` of shape `(num_queries, batch_size, hidden_size)`): 483 | The query embeddings that are passed into the decoder. 484 | multi_stage_positional_embeddings (`torch.FloatTensor` of shape `(height*width, batch_size, num_channels)`): 485 | Position embeddings that are added to the keys in each cross(masked)-attention layer. 486 | pixel_embeddings (`torch.FloatTensor`): 487 | Tensor of shape `(batch_size, num_channels, height, width)`, 1/4 scale features from the last Pixel 488 | Decoder. 
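            text_positional_embeddings (`torch.FloatTensor`, *optional*):
                Position embeddings for the text tokens, added to the keys of the text cross-attention layers
                (passed as `None` by `BEiT3SegTransformerModule` in the current code).
            encoder_text_hidden_states (`torch.FloatTensor` of shape `(text_seq_len, batch_size, hidden_size)`, *optional*):
                Projected BEiT-3 text features that each decoder layer attends to when
                `config.use_text_cross_attn` is enabled.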
489 | query_position_embeddings (`torch.FloatTensor` of shape `(num_queries, batch_size, hidden_size)`): 490 | , *optional*): Position embeddings that are added to the queries and keys in each self-attention layer. 491 | encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`): 492 | Sequence of hidden-states at the output of the last layer of the encoder. Used in the 493 | cross(masked)-attention of the decoder. 494 | feature_size_list (`List[torch.Size]` ): 495 | This is a list containing shapes (height & width) of multi-scale features from the Pixel Decoder. 496 | output_attentions (`bool`, *optional*): 497 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 498 | returned tensors for more detail. 499 | output_hidden_states (`bool`, *optional*): 500 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors 501 | for more detail. 502 | return_dict (`bool`, *optional*): 503 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 504 | """ 505 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 506 | output_hidden_states = ( 507 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 508 | ) 509 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 510 | 511 | if inputs_embeds is not None: 512 | hidden_states = inputs_embeds 513 | 514 | # intermediate hidden states with layernorm applied - required for predicting class logits 515 | intermediate = () 516 | text_intermediate = () 517 | 518 | # decoder layers 519 | all_hidden_states = () if output_hidden_states else None 520 | all_text_hidden_states = () if output_hidden_states else None 521 | attentions = () if output_attentions else None 522 | 523 | # intermediate mask predictions from transformer decoder layers 524 | intermediate_mask_predictions = () 525 | 526 | intermediate_hidden_states = self.layernorm(inputs_embeds) 527 | intermediate += (intermediate_hidden_states,) 528 | 529 | predicted_mask, attention_mask = self.mask_predictor( 530 | intermediate_hidden_states, pixel_embeddings, feature_size_list[0] 531 | ) 532 | intermediate_mask_predictions += (predicted_mask,) 533 | 534 | for idx, decoder_layer in enumerate(self.layers): 535 | if output_hidden_states: 536 | all_hidden_states += (hidden_states,) 537 | # all_text_hidden_states += (text_hidden_states,) 538 | 539 | dropout_probability = torch.rand([]) 540 | 541 | if self.training and (dropout_probability < self.layerdrop): 542 | continue 543 | 544 | if self.gradient_checkpointing and self.training: 545 | raise NotImplementedError('no grad checkpointing') 546 | layer_outputs = self._gradient_checkpointing_func( 547 | decoder_layer.__call__, 548 | hidden_states, 549 | attention_mask, 550 | encoder_hidden_states, 551 | None, 552 | None, 553 | output_attentions, 554 | ) 555 | 556 | else: 557 | level_index = idx % self.num_feature_levels 558 | 559 | attention_mask[torch.where(attention_mask.sum(-1) == attention_mask.shape[-1])] = False 560 | 561 | layer_outputs = decoder_layer( 562 | hidden_states, 563 | level_index=level_index, 564 | position_embeddings=multi_stage_positional_embeddings, 565 | text_position_embeddings=text_positional_embeddings, 566 | query_position_embeddings=query_position_embeddings, 567 | encoder_hidden_states=encoder_hidden_states, 568 | 
encoder_text_hidden_states=encoder_text_hidden_states, 569 | encoder_attention_mask=attention_mask, 570 | output_attentions=output_attentions, 571 | ) 572 | 573 | intermediate_hidden_states = self.layernorm(layer_outputs[0]) 574 | 575 | predicted_mask, attention_mask = self.mask_predictor( 576 | intermediate_hidden_states, 577 | pixel_embeddings, 578 | feature_size_list[(idx + 1) % self.num_feature_levels], 579 | ) 580 | 581 | intermediate_mask_predictions += (predicted_mask,) 582 | 583 | # add intermediate hidden states with layer norm applied which will be used for predicting class logits 584 | intermediate += (intermediate_hidden_states,) 585 | 586 | hidden_states = layer_outputs[0] 587 | text_hidden_states = layer_outputs[1] 588 | 589 | if output_attentions: 590 | attentions += (layer_outputs[2],) 591 | 592 | # add hidden states from the last decoder layer 593 | if output_hidden_states: 594 | all_hidden_states += (hidden_states,) 595 | all_text_hidden_states += (text_hidden_states,) 596 | 597 | hidden_states = hidden_states.transpose(1, 0) 598 | # text_hidden_states = text_hidden_states.transpose(1, 0) 599 | if not return_dict: 600 | outputs = [hidden_states, all_hidden_states, attentions, intermediate, intermediate_mask_predictions] 601 | return tuple(v for v in outputs if v is not None) 602 | 603 | return BEiT3SegMaskedAttentionDecoderOutput( 604 | last_hidden_state=hidden_states, 605 | hidden_states=all_hidden_states, 606 | # text_last_hidden_state=text_hidden_states, 607 | text_hidden_states=all_text_hidden_states, 608 | attentions=attentions, 609 | intermediate_hidden_states=intermediate, 610 | text_intermediate_hidden_states=text_intermediate, 611 | masks_queries_logits=intermediate_mask_predictions, 612 | ) 613 | 614 | class BEiT3SegTransformerModule(nn.Module): 615 | def __init__(self, in_features, config): 616 | super().__init__() 617 | hidden_dim = config.hidden_dim 618 | self.num_feature_levels = 3 619 | self.position_embedder = Mask2FormerSinePositionEmbedding(num_pos_feats=hidden_dim // 2, normalize=True) 620 | self.queries_embedder = nn.Embedding(config.num_queries, hidden_dim) 621 | self.queries_features = nn.Embedding(config.num_queries, hidden_dim) 622 | self.input_projections = [] 623 | 624 | # for text 625 | self.use_text_features = config.use_text_features 626 | if self.use_text_features: 627 | self.text_position_embedding = nn.Embedding(1000, hidden_dim) 628 | self.text_queries_features = nn.Embedding(config.num_queries, hidden_dim) 629 | # self.psuedo_class_embedder = nn.Embedding(config.num_labels + 1, hidden_dim) 630 | self.text_input_projections = nn.Linear(config.backbone_dim, hidden_dim) 631 | 632 | for _ in range(self.num_feature_levels): 633 | if in_features != hidden_dim or config.enforce_input_projection: 634 | self.input_projections.append(nn.Conv2d(in_features, hidden_dim, kernel_size=1)) 635 | else: 636 | self.input_projections.append(nn.Sequential()) 637 | 638 | self.decoder = BEiT3SegMaskedAttentionDecoder(config=config) 639 | self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim) 640 | 641 | def forward( 642 | self, 643 | multi_scale_features: List[Tensor], 644 | mask_features: Tensor, 645 | text_features: Tensor, 646 | output_hidden_states: bool = False, 647 | output_attentions: bool = False, 648 | ) -> BEiT3SegMaskedAttentionDecoderOutput: 649 | multi_stage_features = [] 650 | multi_stage_positional_embeddings = [] 651 | size_list = [] 652 | 653 | for i in range(self.num_feature_levels): 654 | 
size_list.append(multi_scale_features[i].shape[-2:]) 655 | multi_stage_positional_embeddings.append(self.position_embedder(multi_scale_features[i], None).flatten(2)) 656 | multi_stage_features.append( 657 | self.input_projections[i](multi_scale_features[i]).flatten(2) 658 | + self.level_embed.weight[i][None, :, None] 659 | ) 660 | 661 | # Flatten (batch_size, num_channels, height, width) -> (height*width, batch_size, num_channels) 662 | multi_stage_positional_embeddings[-1] = multi_stage_positional_embeddings[-1].permute(2, 0, 1) 663 | multi_stage_features[-1] = multi_stage_features[-1].permute(2, 0, 1) 664 | 665 | # for text 666 | if text_features is not None and self.use_text_features: 667 | batch_size, text_seq_len, _ = text_features.shape 668 | # bsz, text_len, c -> text_len, bsz, c 669 | text_embeddings = self.text_input_projections(text_features.transpose(0, 1)) 670 | text_pos_embedding = None 671 | # text_pos_embedding = self.text_position_embedding.weight[:text_seq_len].unsqueeze(1).repeat(1, batch_size, 1) 672 | # text_embeddings = text_embeddings + text_pos_embedding 673 | else: 674 | text_embeddings = None 675 | text_pos_embedding = None 676 | 677 | _, batch_size, _ = multi_stage_features[0].shape 678 | 679 | # [num_queries, batch_size, num_channels] 680 | query_embeddings = self.queries_embedder.weight.unsqueeze(1).repeat(1, batch_size, 1) 681 | query_features = self.queries_features.weight.unsqueeze(1).repeat(1, batch_size, 1) 682 | # text_query_features = self.text_queries_features.weight.unsqueeze(1).repeat(1, batch_size, 1) 683 | 684 | decoder_output = self.decoder( 685 | inputs_embeds=query_features, 686 | # text_inputs_embeds=text_query_features, 687 | multi_stage_positional_embeddings=multi_stage_positional_embeddings, 688 | text_positional_embeddings=text_pos_embedding, 689 | pixel_embeddings=mask_features, 690 | encoder_hidden_states=multi_stage_features, 691 | encoder_text_hidden_states=text_embeddings, 692 | query_position_embeddings=query_embeddings, 693 | feature_size_list=size_list, 694 | output_hidden_states=output_hidden_states, 695 | output_attentions=output_attentions, 696 | return_dict=True, 697 | ) 698 | 699 | return decoder_output 700 | 701 | 702 | class BEiT3SegModel(Mask2FormerPreTrainedModel): 703 | main_input_name = "pixel_values" 704 | 705 | def __init__(self, config): 706 | super().__init__(config) 707 | self.pixel_level_module = BEiT3SegPixelLevelModule(config) 708 | self.transformer_module = BEiT3SegTransformerModule(in_features=config.feature_size, config=config) 709 | self.post_init() 710 | 711 | def forward( 712 | self, 713 | pixel_values: Tensor, 714 | input_ids: Optional[Tensor] = None, 715 | cat_input_ids: Optional[Tensor] = None, 716 | text_padding_position=None, 717 | pixel_mask: Optional[Tensor] = None, 718 | output_hidden_states: Optional[bool] = None, 719 | output_attentions: Optional[bool] = None, 720 | return_dict: Optional[bool] = None, 721 | ) -> BEiT3SegModelOutput: 722 | 723 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 724 | output_hidden_states = ( 725 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 726 | ) 727 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 728 | 729 | batch_size, _, height, width = pixel_values.shape 730 | 731 | if pixel_mask is None: 732 | pixel_mask = torch.ones((batch_size, height, width), device=pixel_values.device) 733 | 734 | pixel_level_module_output = 
self.pixel_level_module( 735 | pixel_values=pixel_values, 736 | input_ids=input_ids, text_padding_position=text_padding_position, 737 | output_hidden_states=output_hidden_states, 738 | ) 739 | 740 | if input_ids is None: 741 | pixel_level_module_output.text_feature = None 742 | 743 | if cat_input_ids is not None: 744 | pixel_level_module_output.text_feature = pixel_level_module_output.text_feature[torch.arange(batch_size).unsqueeze(-1), cat_input_ids] 745 | 746 | transformer_module_output = self.transformer_module( 747 | multi_scale_features=pixel_level_module_output.decoder_hidden_states, 748 | mask_features=pixel_level_module_output.decoder_last_hidden_state, 749 | text_features=pixel_level_module_output.text_feature, 750 | output_hidden_states=True, 751 | output_attentions=output_attentions, 752 | ) 753 | 754 | fpn_features = None 755 | pixel_decoder_hidden_states = None 756 | transformer_decoder_hidden_states = None 757 | transformer_decoder_intermediate_states = None 758 | transformer_decoder_text_hidden_states = None 759 | transformer_decoder_text_intermediate_states = None 760 | 761 | if output_hidden_states: 762 | fpn_features = pixel_level_module_output.fpn_features 763 | pixel_decoder_hidden_states = pixel_level_module_output.decoder_hidden_states 764 | transformer_decoder_hidden_states = transformer_module_output.hidden_states 765 | transformer_decoder_intermediate_states = transformer_module_output.intermediate_hidden_states 766 | 767 | transformer_decoder_text_hidden_states = transformer_module_output.text_hidden_states 768 | transformer_decoder_text_intermediate_states = transformer_module_output.text_intermediate_hidden_states 769 | 770 | output = BEiT3SegModelOutput( 771 | encoder_visual_last_hidden_state=pixel_level_module_output.encoder_visual_last_hidden_state, 772 | encoder_text_last_hidden_state=pixel_level_module_output.encoder_text_last_hidden_state, 773 | encoder_hidden_states=pixel_level_module_output.encoder_hidden_states, 774 | pixel_decoder_last_hidden_state=pixel_level_module_output.decoder_last_hidden_state, 775 | transformer_decoder_last_hidden_state=transformer_module_output.last_hidden_state, 776 | fpn_features=fpn_features, 777 | pixel_decoder_hidden_states=pixel_decoder_hidden_states, 778 | transformer_decoder_hidden_states=transformer_decoder_hidden_states, 779 | transformer_decoder_intermediate_states=transformer_decoder_intermediate_states, 780 | attentions=transformer_module_output.attentions, 781 | masks_queries_logits=transformer_module_output.masks_queries_logits, 782 | 783 | transformer_decoder_text_last_hidden_state=transformer_module_output.text_last_hidden_state, 784 | transformer_decoder_text_hidden_states=transformer_decoder_text_hidden_states, 785 | transformer_decoder_text_intermediate_states=transformer_decoder_text_intermediate_states, 786 | ) 787 | 788 | if not return_dict: 789 | output = tuple(v for v in output.values() if v is not None) 790 | 791 | return output 792 | 793 | class BEiT3SegForUniversalSegmentation(Mask2FormerPreTrainedModel): 794 | main_input_name = "pixel_values" 795 | 796 | def __init__(self, config): 797 | super().__init__(config) 798 | self.model = BEiT3SegModel(config) 799 | 800 | self.weight_dict: Dict[str, float] = { 801 | "loss_objectness": config.objectness_weight if config.use_objectness_loss else 0.0, 802 | "loss_cross_entropy": config.class_weight, 803 | "loss_mask": config.mask_weight, 804 | "loss_dice": config.dice_weight, 805 | } 806 | 807 | self.use_text_contrastive_loss = config.use_text_contrastive_loss 
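        # Classification heads - a brief guide to the three branches configured below:
        #   1) use_objectness_loss and use_text_contrastive_loss: a 2-way objectness predictor
        #      (`class_predictor`) plus an open-vocabulary head that scores each query against the
        #      BEiT-3 text embeddings,
        #   2) use_text_contrastive_loss only: the contrastive head, with one extra logit (`mask_prob`)
        #      concatenated to the text-similarity logits,
        #   3) otherwise: a closed-set linear classifier over `num_labels + 1` classes.
        # In `forward`, the contrastive score is (roughly) a scaled cosine similarity:
        #   q = F.normalize(self.query_head(queries), dim=-1)              # (bsz, num_queries, hidden_dim)
        #   t = F.normalize(self.text_target_head(text_features), dim=-1)  # (bsz, num_text_tokens, hidden_dim)
        #   logits = self.logit_scale.exp() * q @ t.transpose(1, 2)        # (bsz, num_queries, num_text_tokens)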
808 | self.use_objectness_loss = config.use_objectness_loss 809 | 810 | if self.use_objectness_loss and self.use_text_contrastive_loss: 811 | self.class_predictor = nn.Sequential( 812 | nn.Linear(config.hidden_dim, 2), 813 | ) 814 | self.query_head = nn.Sequential( 815 | nn.Linear(config.hidden_dim, config.hidden_dim), 816 | nn.GELU(), 817 | nn.Linear(config.hidden_dim, config.hidden_dim), 818 | nn.GELU(), 819 | nn.Linear(config.hidden_dim, config.hidden_dim), 820 | ) 821 | self.text_target_head = nn.Sequential( 822 | nn.Linear(config.backbone_dim, config.hidden_dim), 823 | ) 824 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07), requires_grad=True) 825 | 826 | elif self.use_text_contrastive_loss: 827 | self.class_predictor = nn.Sequential( 828 | nn.Linear(config.hidden_dim, 1), 829 | ) 830 | self.query_head = nn.Sequential( 831 | nn.Linear(config.hidden_dim, config.hidden_dim), 832 | nn.GELU(), 833 | nn.Linear(config.hidden_dim, config.hidden_dim), 834 | nn.GELU(), 835 | nn.Linear(config.hidden_dim, config.hidden_dim), 836 | ) 837 | self.text_target_head = nn.Sequential( 838 | nn.Linear(config.backbone_dim, config.hidden_dim), 839 | ) 840 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07), requires_grad=True) 841 | # self.null_embed = nn.Parameter(torch.randn(1, 1, config.hidden_dim)) 842 | # # self.target_embedding = nn.Embedding(config.num_labels+1, config.hidden_dim) 843 | 844 | else: 845 | self.class_predictor = nn.Sequential( 846 | nn.Linear(config.hidden_dim, config.num_labels + 1) 847 | ) 848 | 849 | self.criterion = BEiT3SegLoss(config=config, weight_dict=self.weight_dict) 850 | self.post_init() 851 | 852 | def get_loss_dict( 853 | self, 854 | masks_queries_logits: Tensor, 855 | class_queries_logits: Tensor, 856 | mask_labels: Tensor, 857 | class_labels: Tensor, 858 | auxiliary_predictions: Dict[str, Tensor], 859 | ) -> Dict[str, Tensor]: 860 | loss_dict: Dict[str, Tensor] = self.criterion( 861 | masks_queries_logits=masks_queries_logits, 862 | class_queries_logits=class_queries_logits, 863 | mask_labels=mask_labels, 864 | class_labels=class_labels, 865 | auxiliary_predictions=auxiliary_predictions, 866 | ) 867 | 868 | # weight each loss by `self.weight_dict[]` including auxiliary losses 869 | for key, weight in self.weight_dict.items(): 870 | for loss_key, loss in loss_dict.items(): 871 | if key in loss_key: 872 | loss *= weight 873 | 874 | return loss_dict 875 | 876 | def get_loss(self, loss_dict: Dict[str, Tensor]) -> Tensor: 877 | return sum(loss_dict.values()) 878 | 879 | def get_auxiliary_logits(self, classes: torch.Tensor, output_masks: torch.Tensor): 880 | auxiliary_logits: List[Dict(str, Tensor)] = [] 881 | 882 | for aux_binary_masks, aux_classes in zip(output_masks[:-1], classes[:-1]): 883 | auxiliary_logits.append({"masks_queries_logits": aux_binary_masks, "class_queries_logits": aux_classes}) 884 | 885 | return auxiliary_logits 886 | 887 | def forward( 888 | self, 889 | pixel_values: Tensor, 890 | input_ids: Optional[Tensor] = None, 891 | text_padding_position=None, 892 | cat_input_ids: Optional[Tensor] = None, 893 | mask_labels: Optional[List[Tensor]] = None, 894 | class_labels: Optional[List[Tensor]] = None, 895 | pixel_mask: Optional[Tensor] = None, 896 | output_hidden_states: Optional[bool] = None, 897 | output_auxiliary_logits: Optional[bool] = None, 898 | output_attentions: Optional[bool] = None, 899 | return_dict: Optional[bool] = None, 900 | return_loss_dict: Optional[bool] = None, 901 | ) -> 
BEiT3SegForUniversalSegmentationOutput: 902 | 903 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 904 | output_hidden_states = ( 905 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 906 | ) 907 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 908 | 909 | outputs = self.model( 910 | pixel_values=pixel_values, 911 | input_ids=input_ids, 912 | cat_input_ids=cat_input_ids, 913 | text_padding_position=text_padding_position, 914 | pixel_mask=pixel_mask, 915 | output_hidden_states=output_hidden_states or self.config.use_auxiliary_loss, 916 | output_attentions=output_attentions, 917 | return_dict=True, 918 | ) 919 | 920 | loss, loss_dict, auxiliary_logits = None, None, None 921 | class_queries_logits = () 922 | 923 | if self.use_objectness_loss and self.use_text_contrastive_loss: 924 | bsz, _, _, _ = pixel_values.shape 925 | logit_scale = self.logit_scale.exp() 926 | 927 | target_embedding = outputs.encoder_text_last_hidden_state 928 | if cat_input_ids is not None: 929 | target_embedding = target_embedding[torch.arange(bsz).unsqueeze(-1), cat_input_ids] 930 | target_embedding = self.text_target_head(target_embedding) 931 | target_embedding = F.normalize(target_embedding, dim=-1) 932 | 933 | for decoder_output in outputs.transformer_decoder_intermediate_states: 934 | class_embedding = self.query_head(decoder_output.transpose(0, 1)) 935 | class_embedding = F.normalize(class_embedding, dim=-1) 936 | 937 | class_prediction = logit_scale * class_embedding @ target_embedding.transpose(1, 2) 938 | 939 | mask_prob = self.class_predictor(decoder_output.transpose(0, 1)) 940 | class_prediction = (class_prediction, mask_prob) 941 | 942 | class_queries_logits += (class_prediction,) 943 | 944 | elif self.use_text_contrastive_loss: 945 | # contrastive loss 946 | bsz, _, _, _ = pixel_values.shape 947 | logit_scale = self.logit_scale.exp() 948 | 949 | target_embedding = outputs.encoder_text_last_hidden_state 950 | if cat_input_ids is not None: 951 | target_embedding = target_embedding[torch.arange(bsz).unsqueeze(-1), cat_input_ids] 952 | target_embedding = self.text_target_head(target_embedding) 953 | target_embedding = F.normalize(target_embedding, dim=-1) 954 | 955 | for decoder_output in outputs.transformer_decoder_intermediate_states: 956 | class_embedding = self.query_head(decoder_output.transpose(0, 1)) 957 | class_embedding = F.normalize(class_embedding, dim=-1) 958 | 959 | class_prediction = logit_scale * class_embedding @ target_embedding.transpose(1, 2) 960 | 961 | mask_prob = self.class_predictor(decoder_output.transpose(0, 1)) 962 | class_prediction = torch.cat([class_prediction, mask_prob], dim=-1) 963 | 964 | class_queries_logits += (class_prediction,) 965 | else: 966 | # cls loss 967 | for decoder_output in outputs.transformer_decoder_intermediate_states: 968 | class_prediction = self.class_predictor(decoder_output.transpose(0, 1)) 969 | class_queries_logits += (class_prediction,) 970 | 971 | # cls loss use text out 972 | # for decoder_output in outputs.transformer_decoder_text_intermediate_states: 973 | # class_prediction = self.class_predictor(decoder_output.transpose(0, 1)) 974 | # class_queries_logits += (class_prediction,) 975 | 976 | masks_queries_logits = outputs.masks_queries_logits 977 | 978 | auxiliary_logits = self.get_auxiliary_logits(class_queries_logits, masks_queries_logits) 979 | 980 | # print(len(class_queries_logits), 
len(masks_queries_logits)) 981 | 982 | if mask_labels is not None and class_labels is not None: 983 | loss_dict = self.get_loss_dict( 984 | masks_queries_logits=masks_queries_logits[-1], 985 | class_queries_logits=class_queries_logits[-1], 986 | mask_labels=mask_labels, 987 | class_labels=class_labels, 988 | auxiliary_predictions=auxiliary_logits, 989 | ) 990 | loss = self.get_loss(loss_dict) 991 | 992 | fpn_features = None 993 | pixel_decoder_hidden_states = None 994 | transformer_decoder_hidden_states = None 995 | 996 | if output_hidden_states: 997 | fpn_features = outputs.fpn_features 998 | pixel_decoder_hidden_states = outputs.pixel_decoder_hidden_states 999 | transformer_decoder_hidden_states = outputs.transformer_decoder_hidden_states 1000 | 1001 | output_auxiliary_logits = ( 1002 | self.config.output_auxiliary_logits if output_auxiliary_logits is None else output_auxiliary_logits 1003 | ) 1004 | if not output_auxiliary_logits: 1005 | auxiliary_logits = None 1006 | 1007 | output = BEiT3SegForUniversalSegmentationOutput( 1008 | loss=loss, 1009 | class_queries_logits=class_queries_logits[-1], 1010 | masks_queries_logits=masks_queries_logits[-1], 1011 | auxiliary_logits=auxiliary_logits, 1012 | 1013 | encoder_visual_last_hidden_state=outputs.encoder_visual_last_hidden_state, 1014 | encoder_text_last_hidden_state=outputs.encoder_text_last_hidden_state, 1015 | encoder_hidden_states=outputs.encoder_hidden_states, 1016 | pixel_decoder_last_hidden_state=outputs.pixel_decoder_last_hidden_state, 1017 | transformer_decoder_last_hidden_state=outputs.transformer_decoder_last_hidden_state, 1018 | fpn_features=fpn_features, 1019 | pixel_decoder_hidden_states=pixel_decoder_hidden_states, 1020 | transformer_decoder_hidden_states=transformer_decoder_hidden_states, 1021 | attentions=outputs.attentions, 1022 | 1023 | loss_dict=loss_dict if return_loss_dict else None, 1024 | ) 1025 | 1026 | if not return_dict: 1027 | output = tuple(v for v in output.values() if v is not None) 1028 | if loss is not None: 1029 | output = ((loss)) + output 1030 | return output -------------------------------------------------------------------------------- /image/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-Application-and-Integration-Lab/OMTSeg/3e06c9e6f2e65b0656e9fa7f47424149f911d84d/image/overview.png -------------------------------------------------------------------------------- /modeling_utils.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks (https://arxiv.org/abs/2208.10442) 3 | # Github source: https://github.com/microsoft/unilm/tree/master/beit3 4 | # Copyright (c) 2023 Microsoft 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # --------------------------------------------------------' 7 | 8 | import math 9 | import torch 10 | import torch.nn as nn 11 | from timm.models.layers import trunc_normal_ as __call_trunc_normal_ 12 | 13 | from torchscale.model.BEiT3 import BEiT3 14 | from torchscale.architecture.config import EncoderConfig 15 | 16 | 17 | def trunc_normal_(tensor, mean=0., std=1.): 18 | __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std) 19 | 20 | 21 | def _get_base_config( 22 | img_size=224, patch_size=16, drop_path_rate=0, 23 | checkpoint_activations=None, mlp_ratio=4, vocab_size=64010, **kwargs 24 | ): 
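    # BEiT-3 "base" encoder configuration: 12 multiway (vision/text) transformer layers,
    # 768-dim embeddings, 12 attention heads; `_get_large_config` below is the 24-layer,
    # 1024-dim, 16-head variant.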
25 | return EncoderConfig( 26 | img_size=img_size, patch_size=patch_size, vocab_size=vocab_size, multiway=True, 27 | layernorm_embedding=False, normalize_output=True, no_output_layer=True, 28 | drop_path_rate=drop_path_rate, encoder_embed_dim=768, encoder_attention_heads=12, 29 | encoder_ffn_embed_dim=int(768 * mlp_ratio), encoder_layers=12, 30 | checkpoint_activations=checkpoint_activations, 31 | ) 32 | 33 | 34 | def _get_large_config( 35 | img_size=224, patch_size=16, drop_path_rate=0, 36 | checkpoint_activations=None, mlp_ratio=4, vocab_size=64010, **kwargs 37 | ): 38 | return EncoderConfig( 39 | img_size=img_size, patch_size=patch_size, vocab_size=vocab_size, multiway=True, 40 | layernorm_embedding=False, normalize_output=True, no_output_layer=True, 41 | drop_path_rate=drop_path_rate, encoder_embed_dim=1024, encoder_attention_heads=16, 42 | encoder_ffn_embed_dim=int(1024 * mlp_ratio), encoder_layers=24, 43 | checkpoint_activations=checkpoint_activations, 44 | ) 45 | 46 | 47 | class BEiT3Wrapper(nn.Module): 48 | def __init__(self, args, **kwargs): 49 | super().__init__() 50 | self.args = args 51 | self.beit3 = BEiT3(args) 52 | self.apply(self._init_weights) 53 | 54 | def fix_init_weight(self): 55 | def rescale(param, layer_id): 56 | param.div_(math.sqrt(2.0 * layer_id)) 57 | 58 | for layer_id, layer in enumerate(self.blocks): 59 | rescale(layer.attn.proj.weight.data, layer_id + 1) 60 | rescale(layer.mlp.fc2.weight.data, layer_id + 1) 61 | 62 | def get_num_layers(self): 63 | return self.beit3.encoder.num_layers 64 | 65 | @torch.jit.ignore 66 | def no_weight_decay(self): 67 | return {'pos_embed', 'cls_token', 'beit3.encoder.embed_positions.A.weight', 'beit3.vision_embed.cls_token', 'logit_scale'} 68 | 69 | def _init_weights(self, m): 70 | if isinstance(m, nn.Linear): 71 | trunc_normal_(m.weight, std=.02) 72 | if isinstance(m, nn.Linear) and m.bias is not None: 73 | nn.init.constant_(m.bias, 0) 74 | elif isinstance(m, nn.LayerNorm): 75 | nn.init.constant_(m.bias, 0) 76 | nn.init.constant_(m.weight, 1.0) 77 | -------------------------------------------------------------------------------- /test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import json\n", 11 | "import random\n", 12 | "import numpy as np\n", 13 | "from collections import Counter, OrderedDict\n", 14 | "from tqdm.auto import tqdm\n", 15 | "\n", 16 | "import cv2\n", 17 | "from PIL import Image\n", 18 | "from matplotlib import pyplot as plt\n", 19 | "\n", 20 | "import torch\n", 21 | "import torch.nn as nn\n", 22 | "import torch.nn.functional as F\n", 23 | "from torch.utils.data import Dataset, DataLoader, default_collate\n", 24 | "from timm.models.layers import LayerNorm2d\n", 25 | "import torchshow" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 2, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "from transformers import XLMRobertaTokenizer, AutoConfig\n", 35 | "from transformers import AutoImageProcessor, XLMRobertaTokenizer\n", 36 | "from torchscale.architecture.config import EncoderConfig" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 3, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "%load_ext autoreload\n", 46 | "%autoreload 2" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 4, 52 | "metadata": {}, 53 | "outputs": [], 
54 | "source": [ 55 | "class HpConfig:\n", 56 | " img_size = 640\n", 57 | " drop_path = 0.1\n", 58 | " val_batch_size = 1\n", 59 | " lr = 1e-4\n", 60 | " weight_decay = 0.05\n", 61 | " grad_ckpt = False\n", 62 | "\n", 63 | " batch_size = 2\n", 64 | " grad_acc_steps = 4\n", 65 | " num_gpu = 2\n", 66 | " mixed_precision='bf16'" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 5, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "from utils import load_model_and_may_interpolate\n", 76 | "from modeling_utils import _get_large_config" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": 6, 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "from beit3_seg import BEiT3SegForUniversalSegmentation\n", 86 | " \n", 87 | "mask2former_config = AutoConfig.from_pretrained(\"facebook/mask2former-swin-base-coco-panoptic\", )\n", 88 | "mask2former_config.backbone_config = dict(\n", 89 | " beit3_args=_get_large_config(\n", 90 | " img_size=HpConfig.img_size,\n", 91 | " drop_path_rate=HpConfig.drop_path,\n", 92 | " checkpoint_activations=False,\n", 93 | " ),\n", 94 | " deform_num_heads=16,\n", 95 | " deform_ratio=0.5,\n", 96 | " interaction_indexes=[[0, 5], [6, 11], [12, 17], [18, 23]],\n", 97 | "\n", 98 | " init_values=1e-6,\n", 99 | " conv_inplane=64,\n", 100 | " n_points=4,\n", 101 | " cffn_ratio=0.25,\n", 102 | " with_cp=HpConfig.grad_ckpt,\n", 103 | " num_segments = 1000,\n", 104 | ")\n", 105 | "mask2former_config.backbone_dim = 1024\n", 106 | "mask2former_config.num_labels = 3\n", 107 | "\n", 108 | "mask2former_config.use_text_cross_attn = True\n", 109 | "mask2former_config.use_text_features = True\n", 110 | "mask2former_config.use_text_contrastive_loss = True\n", 111 | "mask2former_config.use_objectness_loss = False\n", 112 | "\n", 113 | "mask2former_config.match_once_only = False\n", 114 | "mask2former_config.drop_first_ce_loss = False\n", 115 | "mask2former_config.encoder_layers=6\n", 116 | "mask2former_config.decoder_layers=10\n", 117 | "\n", 118 | "beit3_seg = BEiT3SegForUniversalSegmentation(mask2former_config)\n", 119 | "beit3_seg = beit3_seg.apply(beit3_seg._init_weights)\n", 120 | "beit3_seg.model.pixel_level_module.encoder.init_weights()" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": null, 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "beit3_seg" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": null, 135 | "metadata": {}, 136 | "outputs": [], 137 | "source": [ 138 | "tokenizer = XLMRobertaTokenizer(\"./beit3.spm\")\n", 139 | "tokenizer.add_tokens([\"\"])" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": null, 145 | "metadata": {}, 146 | "outputs": [], 147 | "source": [ 148 | "tokenizer.tokenize(\"dog;cat;rabbit;\")" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": 10, 154 | "metadata": {}, 155 | "outputs": [], 156 | "source": [ 157 | "bs = 4\n", 158 | "pixel_values = torch.randn(bs, 3, HpConfig.img_size, HpConfig.img_size)\n", 159 | "input_ids = tokenizer([\"dog;cat;rabbit;\"]*bs, return_tensors=\"pt\")[\"input_ids\"]\n", 160 | "cat_input_ids = torch.tensor([[0, 3, 6] for _ in range(bs)])\n", 161 | "mask_labels =[torch.randint(0, 2, (2, HpConfig.img_size, HpConfig.img_size)).float().to(\"cuda\") for _ in range(bs)]\n", 162 | "class_labels = torch.tensor([[1,2] for _ in range(bs)])" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": 11, 168 | "metadata": {}, 169 | 
"outputs": [], 170 | "source": [ 171 | "beit3_seg = beit3_seg.to(\"cuda\").eval()" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": null, 177 | "metadata": {}, 178 | "outputs": [], 179 | "source": [ 180 | "with torch.no_grad():\n", 181 | " outputs = beit3_seg(\n", 182 | " pixel_values=pixel_values.to(\"cuda\"),\n", 183 | " input_ids=input_ids.to(\"cuda\"),\n", 184 | " cat_input_ids=cat_input_ids.to(\"cuda\"),\n", 185 | " mask_labels=mask_labels,\n", 186 | " class_labels=class_labels.to(\"cuda\"),\n", 187 | " )" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": null, 193 | "metadata": {}, 194 | "outputs": [], 195 | "source": [ 196 | "outputs.loss" 197 | ] 198 | } 199 | ], 200 | "metadata": { 201 | "kernelspec": { 202 | "display_name": "torch2", 203 | "language": "python", 204 | "name": "python3" 205 | }, 206 | "language_info": { 207 | "codemirror_mode": { 208 | "name": "ipython", 209 | "version": 3 210 | }, 211 | "file_extension": ".py", 212 | "mimetype": "text/x-python", 213 | "name": "python", 214 | "nbconvert_exporter": "python", 215 | "pygments_lexer": "ipython3", 216 | "version": "3.11.5" 217 | } 218 | }, 219 | "nbformat": 4, 220 | "nbformat_minor": 2 221 | } 222 | -------------------------------------------------------------------------------- /train.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import json\n", 11 | "import random\n", 12 | "import numpy as np\n", 13 | "from collections import Counter, OrderedDict\n", 14 | "from tqdm.auto import tqdm\n", 15 | "import wandb\n", 16 | "\n", 17 | "import cv2\n", 18 | "from PIL import Image\n", 19 | "from matplotlib import pyplot as plt\n", 20 | "\n", 21 | "import albumentations as A\n", 22 | "import albumentations.augmentations.functional as F\n", 23 | "from albumentations.pytorch import ToTensorV2\n", 24 | "\n", 25 | "import torch\n", 26 | "import torch.nn as nn\n", 27 | "import torch.nn.functional as F\n", 28 | "torch.set_printoptions(sci_mode=False)\n", 29 | "from torch.utils.data import Dataset, DataLoader, default_collate\n", 30 | "from timm.models.layers import LayerNorm2d\n", 31 | "import torchshow\n", 32 | "\n", 33 | "from utils import load_model_and_may_interpolate\n", 34 | "from modeling_utils import _get_base_config, _get_large_config\n", 35 | "\n", 36 | "torch.backends.cudnn.benchmark = True" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 2, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "logit_scale.exp()" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 4, 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "%load_ext autoreload\n", 64 | "%autoreload 2" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 5, 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "from transformers import XLMRobertaTokenizer, AutoConfig\n", 74 | "from transformers import AutoImageProcessor, XLMRobertaTokenizer\n", 75 | "from torchscale.architecture.config import EncoderConfig\n", 76 | "from lion_pytorch import Lion\n", 77 | "from accelerate import Accelerator\n", 78 | "from 
accelerate.utils import set_seed\n", 79 | "from accelerate import notebook_launcher, DistributedDataParallelKwargs" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 6, 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "from BEiT3_adapter.panoptic_dataset import COCOPanopticDataset" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 7, 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "class HpConfig:\n", 98 | " img_size = 640\n", 99 | " drop_path = 0.1\n", 100 | " batch_size = 2\n", 101 | " val_batch_size = 1\n", 102 | " grad_acc_steps = 1\n", 103 | " lr = 1e-4\n", 104 | " weight_decay = 0.05\n", 105 | " grad_ckpt = False\n", 106 | " num_gpu = 2\n", 107 | " mixed_precision='bf16'\n", 108 | " wls_token = ''\n", 109 | " sep_token = '▁;'" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": 8, 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "def get_dataloaders(accelerator):\n", 119 | " tokenizer = XLMRobertaTokenizer(\"../beit3_weights/beit3.spm\")\n", 120 | " tokenizer.add_tokens([HpConfig.wls_token, HpConfig.sep_token])\n", 121 | " \n", 122 | " coco_mask2former_processor = AutoImageProcessor.from_pretrained(\n", 123 | " \"facebook/mask2former-swin-base-coco-panoptic\",\n", 124 | " do_resize=False, do_rescale=True, do_normalize=True, ignore_index=0,\n", 125 | " )\n", 126 | " ade_mask2former_processor = AutoImageProcessor.from_pretrained(\n", 127 | " \"facebook/mask2former-swin-large-ade-panoptic\",\n", 128 | " do_resize=False, do_rescale=True, do_normalize=True, ignore_index=0,\n", 129 | " )\n", 130 | "\n", 131 | " train_transform = A.Compose(\n", 132 | " [\n", 133 | " A.HorizontalFlip(p=0.5),\n", 134 | " A.SmallestMaxSize([HpConfig.img_size*i//10 for i in range(5, 21)], p=1.0),\n", 135 | " A.PadIfNeeded(\n", 136 | " HpConfig.img_size, HpConfig.img_size,\n", 137 | " position=A.PadIfNeeded.PositionType.TOP_LEFT,\n", 138 | " border_mode=cv2.BORDER_CONSTANT,\n", 139 | " ),\n", 140 | " A.RandomCrop(HpConfig.img_size, HpConfig.img_size),\n", 141 | " ]\n", 142 | " )\n", 143 | "\n", 144 | " val_trainform = A.Compose(\n", 145 | " [\n", 146 | " A.LongestMaxSize(HpConfig.img_size, p=1.0),\n", 147 | " A.PadIfNeeded(\n", 148 | " HpConfig.img_size, HpConfig.img_size,\n", 149 | " position=A.PadIfNeeded.PositionType.TOP_LEFT,\n", 150 | " border_mode=cv2.BORDER_CONSTANT,\n", 151 | " ),\n", 152 | " ]\n", 153 | " )\n", 154 | "\n", 155 | " with open('../../datasets/COCO/annotations/panoptic_train2017.json') as file:\n", 156 | " coco_train_ann = json.load(file)\n", 157 | " with open('../../datasets/COCO/annotations/panoptic_val2017.json') as file:\n", 158 | " coco_val_ann = json.load(file)\n", 159 | " with open('../../datasets/ADE20K/from_mmdet/ADEChallengeData2016/ade20k_panoptic_val.json') as file:\n", 160 | " ade_val_ann = json.load(file)\n", 161 | "\n", 162 | " coco_train_dataset = COCOPanopticDataset(\n", 163 | " coco_train_ann,\n", 164 | " '../../datasets/COCO/train2017',\n", 165 | " '../../datasets/COCO/annotations/panoptic_train2017',\n", 166 | " transform=train_transform,\n", 167 | " processor=coco_mask2former_processor,\n", 168 | " use_text=True,\n", 169 | " tokenizer=tokenizer,\n", 170 | " sep_token=HpConfig.sep_token,\n", 171 | " use_sep=True,\n", 172 | " num_sampled_label=133,\n", 173 | " wls_token=HpConfig.wls_token,\n", 174 | " max_sep_num=3,\n", 175 | " # use_sep=False,\n", 176 | " )\n", 177 | "\n", 178 | " coco_val_dataset = COCOPanopticDataset(\n", 179 | " coco_val_ann,\n", 180 | " 
'../../datasets/COCO/val2017',\n", 181 | " '../../datasets/COCO/annotations/panoptic_val2017',\n", 182 | " transform=val_trainform,\n", 183 | " processor=coco_mask2former_processor,\n", 184 | " use_text=True,\n", 185 | " tokenizer=tokenizer,\n", 186 | " sep_token=HpConfig.sep_token,\n", 187 | " use_sep=True,\n", 188 | " num_sampled_label=133,\n", 189 | " wls_token=HpConfig.wls_token,\n", 190 | " max_sep_num=1,\n", 191 | " # use_sep=False,\n", 192 | " )\n", 193 | "\n", 194 | " ade_val_dataset = COCOPanopticDataset(\n", 195 | " ade_val_ann,\n", 196 | " '../../datasets/ADE20K/from_mmdet/ADEChallengeData2016/images/validation',\n", 197 | " '../../datasets/ADE20K/from_mmdet/ADEChallengeData2016/ade20k_panoptic_val',\n", 198 | " transform=val_trainform,\n", 199 | " processor=ade_mask2former_processor,\n", 200 | " use_text=True,\n", 201 | " tokenizer=tokenizer,\n", 202 | " sep_token=HpConfig.sep_token,\n", 203 | " use_sep=True,\n", 204 | " num_sampled_label=150,\n", 205 | " wls_token=HpConfig.wls_token,\n", 206 | " max_sep_num=1,\n", 207 | " # use_sep=False,\n", 208 | " )\n", 209 | "\n", 210 | " def custom_collate(batch):\n", 211 | " collated_batch = {}\n", 212 | " \n", 213 | " first_elem = batch[0]\n", 214 | " if 'mask_labels' in first_elem:\n", 215 | " collated_batch['mask_labels'] = [b.pop('mask_labels') for b in batch]\n", 216 | " if 'class_labels' in first_elem:\n", 217 | " collated_batch['class_labels'] = [b.pop('class_labels') for b in batch]\n", 218 | " if 'origin_class_labels' in first_elem:\n", 219 | " collated_batch['origin_class_labels'] = [b.pop('origin_class_labels') for b in batch]\n", 220 | " if 'input_ids' in first_elem:\n", 221 | " collated_batch.update(tokenizer.pad(\n", 222 | " [{'input_ids': b.pop('input_ids')} for b in batch],\n", 223 | " max_length=640, padding=True,\n", 224 | " ))\n", 225 | " \n", 226 | " collated_batch.update(default_collate(batch))\n", 227 | " \n", 228 | " return collated_batch\n", 229 | "\n", 230 | " coco_train_loader = DataLoader(\n", 231 | " coco_train_dataset,\n", 232 | " batch_size=HpConfig.batch_size,\n", 233 | " shuffle=True,\n", 234 | " num_workers=8,\n", 235 | " pin_memory=True,\n", 236 | " drop_last=True,\n", 237 | " collate_fn=custom_collate,\n", 238 | " )\n", 239 | "\n", 240 | " coco_val_loader = DataLoader(\n", 241 | " coco_val_dataset,\n", 242 | " batch_size=HpConfig.val_batch_size,\n", 243 | " shuffle=False,\n", 244 | " num_workers=8,\n", 245 | " pin_memory=True,\n", 246 | " drop_last=False,\n", 247 | " collate_fn=custom_collate,\n", 248 | " )\n", 249 | "\n", 250 | " ade_val_loader = DataLoader(\n", 251 | " ade_val_dataset,\n", 252 | " batch_size=HpConfig.val_batch_size,\n", 253 | " shuffle=False,\n", 254 | " num_workers=8,\n", 255 | " pin_memory=True,\n", 256 | " drop_last=False,\n", 257 | " collate_fn=custom_collate,\n", 258 | " )\n", 259 | "\n", 260 | " return coco_train_loader, coco_val_loader, ade_val_loader" 261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": 20, 266 | "metadata": {}, 267 | "outputs": [], 268 | "source": [ 269 | "class CustomEmbedding(nn.Module):\n", 270 | " def __init__(self, old_embedding, new_embedding, split_idx):\n", 271 | " super().__init__()\n", 272 | " self.old_embedding = old_embedding\n", 273 | " self.new_embedding = new_embedding\n", 274 | " self.split_idx = split_idx\n", 275 | "\n", 276 | " def forward(self, input_ids):\n", 277 | " old_embeds = self.old_embedding(\n", 278 | " input_ids.clamp(max=self.old_embedding.num_embeddings - 1))\n", 279 | " new_embeds = 
self.new_embedding(\n", 280 | " (input_ids - self.split_idx).clamp(min=0))\n", 281 | "\n", 282 | " return torch.where(\n", 283 | " input_ids.unsqueeze(-1) < self.split_idx, old_embeds, new_embeds)\n", 284 | "\n", 285 | "def create_model(accelerator, load_weight=True, freeze_backbone=True, interpolate_pos=False, add_new_embedding=False):\n", 286 | " from BEiT3_adapter.beit3_seg_ov_v2 import BEiT3SegForUniversalSegmentation\n", 287 | " \n", 288 | " mask2former_config = AutoConfig.from_pretrained(\"facebook/mask2former-swin-base-coco-panoptic\", )\n", 289 | " mask2former_config.backbone_config = dict(\n", 290 | " beit3_args=_get_large_config(\n", 291 | " img_size=HpConfig.img_size,\n", 292 | " drop_path_rate=HpConfig.drop_path,\n", 293 | " checkpoint_activations=False,\n", 294 | " ),\n", 295 | " deform_num_heads=16,\n", 296 | " deform_ratio=0.5,\n", 297 | " interaction_indexes=[[0, 5], [6, 11], [12, 17], [18, 23]],\n", 298 | "\n", 299 | " init_values=1e-6,\n", 300 | " conv_inplane=64,\n", 301 | " n_points=4,\n", 302 | " cffn_ratio=0.25,\n", 303 | " with_cp=HpConfig.grad_ckpt,\n", 304 | " )\n", 305 | " mask2former_config.backbone_dim = 768\n", 306 | " mask2former_config.num_labels = 133\n", 307 | " mask2former_config.use_text_cross_attn = True\n", 308 | " mask2former_config.use_text_features = True\n", 309 | " mask2former_config.use_text_contrastive_loss = True\n", 310 | " mask2former_config.use_objectness_loss = True\n", 311 | " mask2former_config.match_once_only = False\n", 312 | " mask2former_config.drop_first_ce_loss = True\n", 313 | " mask2former_config.encoder_layers=6\n", 314 | " mask2former_config.decoder_layers=10\n", 315 | " mask2former_config.objectness_weight = 2\n", 316 | "\n", 317 | " beit3_seg = BEiT3SegForUniversalSegmentation(mask2former_config)\n", 318 | " beit3_seg = beit3_seg.apply(beit3_seg._init_weights)\n", 319 | " beit3_seg.model.pixel_level_module.encoder.init_weights()\n", 320 | "\n", 321 | " if load_weight:\n", 322 | " if accelerator.is_main_process:\n", 323 | " print('Loading BEiT3 pretraind weight...')\n", 324 | " load_model_and_may_interpolate(\n", 325 | " '../beit3_weights/beit3_base_patch16_224.pth',\n", 326 | " beit3_seg.model.pixel_level_module.encoder,\n", 327 | " 'model|module',\n", 328 | " 'beit3.',\n", 329 | " )\n", 330 | " print()\n", 331 | " mask2former_pretrained_weigths = torch.load('./training_checkpoints/vit_adapter_mask2former_coco_768.pth')\n", 332 | " beit3_seg_param_shapes = {n:v.shape for n, v in beit3_seg.state_dict().items()}\n", 333 | " for name, v_shape in [(n, v.shape) for n, v in mask2former_pretrained_weigths.items()]:\n", 334 | " if name in beit3_seg_param_shapes and v_shape != beit3_seg_param_shapes[name]:\n", 335 | " print('mismatch:', name, v_shape, beit3_seg_param_shapes[name])\n", 336 | " del mask2former_pretrained_weigths[name]\n", 337 | " r = beit3_seg.load_state_dict(mask2former_pretrained_weigths, strict=False)\n", 338 | " print(r)\n", 339 | "\n", 340 | " if interpolate_pos:\n", 341 | " if accelerator.is_main_process:\n", 342 | " with torch.no_grad():\n", 343 | " origin_pos = beit3_seg.model.pixel_level_module.encoder.encoder.embed_positions.B.weight[2:130].clone()\n", 344 | " new_pos = F.interpolate(origin_pos.unsqueeze(0).permute(0, 2, 1), 640, mode='linear').permute(0, 2, 1)[0]\n", 345 | " beit3_seg.model.pixel_level_module.encoder.encoder.embed_positions.B.weight[2:640+2] = new_pos\n", 346 | " # beit3_seg.model.transformer_module.psuedo_class_embedder[:512] = new_pos\n", 347 | "\n", 348 | " if add_new_embedding:\n", 349 | " 
print('Creating new embedding...')\n", 350 | " old_embedding = beit3_seg.model.pixel_level_module.encoder.text_embed\n", 351 | " new_embedding_init_weight = old_embedding.weight[[0]].detach().clone()\n", 352 | " new_embedding = nn.Embedding(\n", 353 | " 1,\n", 354 | " 768,\n", 355 | " _weight=new_embedding_init_weight,\n", 356 | " )\n", 357 | " beit3_seg.model.pixel_level_module.encoder.text_embed = CustomEmbedding(\n", 358 | " old_embedding,\n", 359 | " new_embedding,\n", 360 | " 64002,\n", 361 | " )\n", 362 | "\n", 363 | " if freeze_backbone:\n", 364 | " freeze_keywords = [\n", 365 | " 'model.pixel_level_module.encoder.text_embed', \n", 366 | " 'model.pixel_level_module.encoder.vision_embed', \n", 367 | " 'model.pixel_level_module.encoder.encoder', \n", 368 | " ]\n", 369 | " for name, param in beit3_seg.named_parameters():\n", 370 | " if any([kw in name for kw in freeze_keywords]):\n", 371 | " param.requires_grad_(False)\n", 372 | " else:\n", 373 | " param.requires_grad_(True)\n", 374 | " beit3_seg.model.pixel_level_module.encoder.encoder.embed_positions.A.requires_grad_(True)\n", 375 | " beit3_seg.model.pixel_level_module.encoder.encoder.embed_positions.B.requires_grad_(True)\n", 376 | " if add_new_embedding:\n", 377 | " beit3_seg.model.pixel_level_module.encoder.text_embed.new_embedding.weight.requires_grad_(True)\n", 378 | " \n", 379 | " train_names = []\n", 380 | " freeze_names = []\n", 381 | " for name, param in beit3_seg.named_parameters():\n", 382 | " if param.requires_grad:\n", 383 | " train_names.append(name)\n", 384 | " else:\n", 385 | " freeze_names.append(name)\n", 386 | "\n", 387 | " if accelerator.is_main_process:\n", 388 | " for name in train_names:\n", 389 | " print('o', name)\n", 390 | " for name in freeze_names:\n", 391 | " print('x', name)\n", 392 | "\n", 393 | " return beit3_seg" 394 | ] 395 | }, 396 | { 397 | "cell_type": "code", 398 | "execution_count": 10, 399 | "metadata": {}, 400 | "outputs": [], 401 | "source": [ 402 | "def configure_optimizer(accelerator, model):\n", 403 | " def get_parameter_names(model, forbidden_layer_types):\n", 404 | " \"\"\"\n", 405 | " Returns the names of the model parameters that are not inside a forbidden layer.\n", 406 | " \"\"\"\n", 407 | " result = []\n", 408 | " for name, child in model.named_children():\n", 409 | " result += [\n", 410 | " f\"{name}.{n}\"\n", 411 | " for n in get_parameter_names(child, forbidden_layer_types)\n", 412 | " if not isinstance(child, tuple(forbidden_layer_types))\n", 413 | " ]\n", 414 | " # Add model specific parameters (defined with nn.Parameter) since they are not in any child.\n", 415 | " result += list(model._parameters.keys())\n", 416 | " return result\n", 417 | " \n", 418 | " lion_lr = HpConfig.lr / 5\n", 419 | " lion_weight_decay = HpConfig.weight_decay * 5\n", 420 | "\n", 421 | " no_decay_names = [\"bias\", \"embed_positions\", \"queries_embedder\", \"psuedo_class_embedder\", \"position_embedding\"]\n", 422 | " decay_parameters = get_parameter_names(model, [nn.LayerNorm, LayerNorm2d])\n", 423 | " decay_parameters = [name for name in decay_parameters if all([not ndn in name for ndn in no_decay_names])]\n", 424 | "\n", 425 | " param_groups = {\n", 426 | " \"backbone_decay\": [],\n", 427 | " \"backbone_no_decay\": [],\n", 428 | " \"head_decay\": [],\n", 429 | " \"head_no_decay\": [],\n", 430 | " }\n", 431 | " for n, p in model.named_parameters():\n", 432 | " if not p.requires_grad:\n", 433 | " continue\n", 434 | " if n in decay_parameters and 'beit3' in n:\n", 435 | " 
param_groups['backbone_decay'].append((n, p))\n", 436 | " elif not n in decay_parameters and 'beit3' in n:\n", 437 | " param_groups['backbone_no_decay'].append((n, p))\n", 438 | " elif n in decay_parameters and not 'beit3' in n:\n", 439 | " param_groups['head_decay'].append((n, p))\n", 440 | " elif not n in decay_parameters and not 'beit3' in n:\n", 441 | " param_groups['head_no_decay'].append((n, p))\n", 442 | " else:\n", 443 | " print(f'Strange param: {n}')\n", 444 | "\n", 445 | " # for group_name, group in param_groups.items():\n", 446 | " # print(group_name, len(group))\n", 447 | " # for n, _ in group:\n", 448 | " # print(f' - {n}')\n", 449 | "\n", 450 | " optimizer_grouped_parameters = [\n", 451 | " {\n", 452 | " \"params\": [p for n, p in param_groups['head_decay']],\n", 453 | " \"weight_decay\": lion_weight_decay,\n", 454 | " \"lr\": lion_lr,\n", 455 | " },\n", 456 | " {\n", 457 | " \"params\": [p for n, p in param_groups['head_no_decay']],\n", 458 | " \"weight_decay\": 0.0,\n", 459 | " \"lr\": lion_lr,\n", 460 | " },\n", 461 | " {\n", 462 | " \"params\": [p for n, p in param_groups['backbone_decay']],\n", 463 | " \"weight_decay\": lion_weight_decay,\n", 464 | " \"lr\": lion_lr*0.2,\n", 465 | " },\n", 466 | " {\n", 467 | " \"params\": [p for n, p in param_groups['backbone_no_decay']],\n", 468 | " \"weight_decay\": 0.0,\n", 469 | " \"lr\": lion_lr*0.2,\n", 470 | " },\n", 471 | " ]\n", 472 | "\n", 473 | " optimizer = Lion(\n", 474 | " optimizer_grouped_parameters,\n", 475 | " # lr=lion_lr,\n", 476 | " # weight_decay=lion_weight_decay,\n", 477 | " )\n", 478 | "\n", 479 | " def lr_lambda(step):\n", 480 | " if step < 2000*HpConfig.num_gpu:\n", 481 | " return step/(2000*HpConfig.num_gpu)\n", 482 | " elif step > 40000*HpConfig.num_gpu:\n", 483 | " return 0.01\n", 484 | " elif step > 30000*HpConfig.num_gpu:\n", 485 | " return 0.1\n", 486 | " else:\n", 487 | " return 1\n", 488 | " \n", 489 | " lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)\n", 490 | "\n", 491 | " return optimizer, lr_scheduler" 492 | ] 493 | }, 494 | { 495 | "cell_type": "code", 496 | "execution_count": 11, 497 | "metadata": {}, 498 | "outputs": [], 499 | "source": [ 500 | "def training_loop(seed: int = 42):\n", 501 | " set_seed(seed)\n", 502 | " # Initialize accelerator\n", 503 | " kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)\n", 504 | " accelerator = Accelerator(\n", 505 | " mixed_precision=HpConfig.mixed_precision,\n", 506 | " gradient_accumulation_steps=HpConfig.grad_acc_steps,\n", 507 | " log_with='wandb',\n", 508 | " kwargs_handlers=[kwargs],\n", 509 | " )\n", 510 | "\n", 511 | " # Build dataloaders\n", 512 | " coco_train_loader, coco_val_loader, ade_val_loader = get_dataloaders(accelerator)\n", 513 | " # model = create_model(accelerator, load_weight=True, freeze_backbone=True, interpolate_pos=True)\n", 514 | " model = create_model(\n", 515 | " accelerator, load_weight=False, freeze_backbone=True,\n", 516 | " interpolate_pos=False, add_new_embedding=True,\n", 517 | " )\n", 518 | "\n", 519 | " # if accelerator.is_local_main_process:\n", 520 | " # model.model.transformer_module = torch.compile(\n", 521 | " # model.model.transformer_module,\n", 522 | " # # mode='max-autotune',\n", 523 | " # )\n", 524 | "\n", 525 | " optimizer, lr_scheduler = configure_optimizer(accelerator, model)\n", 526 | "\n", 527 | " coco_train_loader, coco_val_loader, ade_val_loader = accelerator.prepare(\n", 528 | " coco_train_loader, coco_val_loader, ade_val_loader\n", 529 | " )\n", 530 | " model, 
optimizer, lr_scheduler = accelerator.prepare(\n", 531 | " model, optimizer, lr_scheduler\n", 532 | " )\n", 533 | " # lr_scheduler.step_with_optimizer = False\n", 534 | "\n", 535 | " model.module.model.pixel_level_module.encoder.encoder = torch.compile(\n", 536 | " model.module.model.pixel_level_module.encoder.encoder,\n", 537 | " mode='max-autotune',\n", 538 | " # mode=\"reduce-overhead\",\n", 539 | " )\n", 540 | " # model.module.model.transformer_module = torch.compile(\n", 541 | " # model.module.model.transformer_module,\n", 542 | " # mode='max-autotune',\n", 543 | " # # mode=\"reduce-overhead\",\n", 544 | " # )\n", 545 | "\n", 546 | " # if accelerator.is_local_main_process:\n", 547 | " # accelerator.init_trackers(\n", 548 | " # \"BEiT3_Seg_Acc\",\n", 549 | " # config={\n", 550 | " # 'img_size': 640,\n", 551 | " # },\n", 552 | " # )\n", 553 | " # wandb.run.log_code(\n", 554 | " # \".\",\n", 555 | " # include_fn=lambda path: path.endswith(\".py\") or path.endswith(\".ipynb\"),\n", 556 | " # )\n", 557 | "\n", 558 | " global_step = 0\n", 559 | " while global_step < 50000:\n", 560 | " model.train()\n", 561 | " for loader_step, batch in enumerate(tqdm(coco_train_loader, disable=not accelerator.is_local_main_process)):\n", 562 | " optimizer.zero_grad()\n", 563 | "\n", 564 | " outputs = model(\n", 565 | " pixel_values=batch['pixel_values'],\n", 566 | " input_ids=batch['input_ids'],\n", 567 | " cat_input_ids=batch['cat_token_idxs'],\n", 568 | " text_padding_position=1-batch['attention_mask'],\n", 569 | " class_labels=batch['origin_class_labels'],\n", 570 | " mask_labels=batch['mask_labels'],\n", 571 | " return_loss_dict=True,\n", 572 | " )\n", 573 | "\n", 574 | " accelerator.backward(outputs.loss)\n", 575 | " if accelerator.sync_gradients:\n", 576 | " accelerator.clip_grad_norm_(model.parameters(), 1.0)\n", 577 | " optimizer.step()\n", 578 | " if not accelerator.optimizer_step_was_skipped:\n", 579 | " lr_scheduler.step()\n", 580 | "\n", 581 | " step_log = {\n", 582 | " \"train\": {'loss': outputs.loss},\n", 583 | " \"losses\": outputs.loss_dict,\n", 584 | " \"learning rates\": {f\"group_{i}\":lr for i, lr in enumerate(lr_scheduler.get_last_lr())},\n", 585 | " }\n", 586 | " accelerator.log(step_log, step=global_step)\n", 587 | "\n", 588 | " if global_step != 0 and global_step % 4000 == 0 or global_step == 2000:\n", 589 | " accelerator.print(f'Saving model on step: {global_step}..')\n", 590 | " accelerator.wait_for_everyone()\n", 591 | " accelerator.save_state(f'training_checkpoints/adapter-v11-{global_step}')\n", 592 | " # accelerator.save_model(model, f'training_checkpoints/adapter-v2-{global_step}')\n", 593 | " accelerator.print('Model Saved!')\n", 594 | "\n", 595 | " # accelerator.print(f'Saveing training state on step: {global_step}..')\n", 596 | " # accelerator.save_state(output_dir=\"latest-training-state\")\n", 597 | " # accelerator.print('State Saved!')\n", 598 | "\n", 599 | " if global_step >= 90000:\n", 600 | " break\n", 601 | "\n", 602 | " global_step += 1\n", 603 | "\n", 604 | " accelerator.end_training()" 605 | ] 606 | }, 607 | { 608 | "cell_type": "code", 609 | "execution_count": 12, 610 | "metadata": {}, 611 | "outputs": [], 612 | "source": [ 613 | "# args = (96, )\n", 614 | "# notebook_launcher(training_loop, args, num_processes=HpConfig.num_gpu)" 615 | ] 616 | }, 617 | { 618 | "cell_type": "code", 619 | "execution_count": null, 620 | "metadata": {}, 621 | "outputs": [], 622 | "source": [] 623 | } 624 | ], 625 | "metadata": { 626 | "kernelspec": { 627 | "display_name": "torch2", 
628 | "language": "python", 629 | "name": "python3" 630 | }, 631 | "language_info": { 632 | "codemirror_mode": { 633 | "name": "ipython", 634 | "version": 3 635 | }, 636 | "file_extension": ".py", 637 | "mimetype": "text/x-python", 638 | "name": "python", 639 | "nbconvert_exporter": "python", 640 | "pygments_lexer": "ipython3", 641 | "version": "3.11.5" 642 | } 643 | }, 644 | "nbformat": 4, 645 | "nbformat_minor": 2 646 | } 647 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks (https://arxiv.org/abs/2208.10442) 3 | # Github source: https://github.com/microsoft/unilm/tree/master/beit3 4 | # Copyright (c) 2023 Microsoft 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # --------------------------------------------------------' 7 | 8 | import datetime 9 | import io 10 | import os 11 | import math 12 | import time 13 | import json 14 | import argparse 15 | import numpy as np 16 | from pathlib import Path 17 | from collections import defaultdict, deque 18 | from timm.utils import get_state_dict 19 | 20 | import torch 21 | import torch.distributed as dist 22 | import torch.nn as nn 23 | import torch.nn.functional as F 24 | # from torch._six import inf 25 | from torch import inf 26 | from torchmetrics import Metric 27 | # from tensorboardX import SummaryWriter 28 | 29 | 30 | def bool_flag(s): 31 | """ 32 | Parse boolean arguments from the command line. 33 | """ 34 | FALSY_STRINGS = {"off", "false", "0"} 35 | TRUTHY_STRINGS = {"on", "true", "1"} 36 | if s.lower() in FALSY_STRINGS: 37 | return False 38 | elif s.lower() in TRUTHY_STRINGS: 39 | return True 40 | else: 41 | raise argparse.ArgumentTypeError("invalid value for a boolean flag") 42 | 43 | 44 | class SmoothedValue(object): 45 | """Track a series of values and provide access to smoothed values over a 46 | window or the global series average. 47 | """ 48 | 49 | def __init__(self, window_size=20, fmt=None): 50 | if fmt is None: 51 | fmt = "{median:.4f} ({global_avg:.4f})" 52 | self.deque = deque(maxlen=window_size) 53 | self.total = 0.0 54 | self.count = 0 55 | self.fmt = fmt 56 | 57 | def update(self, value, n=1): 58 | self.deque.append(value) 59 | self.count += n 60 | self.total += value * n 61 | 62 | def synchronize_between_processes(self): 63 | """ 64 | Warning: does not synchronize the deque! 
65 | """ 66 | if not is_dist_avail_and_initialized(): 67 | return 68 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 69 | dist.barrier() 70 | dist.all_reduce(t) 71 | t = t.tolist() 72 | self.count = int(t[0]) 73 | self.total = t[1] 74 | 75 | @property 76 | def median(self): 77 | d = torch.tensor(list(self.deque)) 78 | return d.median().item() 79 | 80 | @property 81 | def avg(self): 82 | d = torch.tensor(list(self.deque), dtype=torch.float32) 83 | return d.mean().item() 84 | 85 | @property 86 | def global_avg(self): 87 | return self.total / self.count 88 | 89 | @property 90 | def max(self): 91 | return max(self.deque) 92 | 93 | @property 94 | def value(self): 95 | return self.deque[-1] 96 | 97 | def __str__(self): 98 | return self.fmt.format( 99 | median=self.median, 100 | avg=self.avg, 101 | global_avg=self.global_avg, 102 | max=self.max, 103 | value=self.value) 104 | 105 | 106 | class MetricLogger(object): 107 | def __init__(self, delimiter="\t"): 108 | self.meters = defaultdict(SmoothedValue) 109 | self.delimiter = delimiter 110 | 111 | def update(self, **kwargs): 112 | for k, v in kwargs.items(): 113 | if v is None: 114 | continue 115 | if isinstance(v, torch.Tensor): 116 | v = v.item() 117 | assert isinstance(v, (float, int)) 118 | self.meters[k].update(v) 119 | 120 | def __getattr__(self, attr): 121 | if attr in self.meters: 122 | return self.meters[attr] 123 | if attr in self.__dict__: 124 | return self.__dict__[attr] 125 | raise AttributeError("'{}' object has no attribute '{}'".format( 126 | type(self).__name__, attr)) 127 | 128 | def __str__(self): 129 | loss_str = [] 130 | for name, meter in self.meters.items(): 131 | loss_str.append( 132 | "{}: {}".format(name, str(meter)) 133 | ) 134 | return self.delimiter.join(loss_str) 135 | 136 | def synchronize_between_processes(self): 137 | for meter in self.meters.values(): 138 | meter.synchronize_between_processes() 139 | 140 | def add_meter(self, name, meter): 141 | self.meters[name] = meter 142 | 143 | def log_every(self, iterable, print_freq, header=None): 144 | i = 0 145 | if not header: 146 | header = '' 147 | start_time = time.time() 148 | end = time.time() 149 | iter_time = SmoothedValue(fmt='{avg:.4f}') 150 | data_time = SmoothedValue(fmt='{avg:.4f}') 151 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 152 | log_msg = [ 153 | header, 154 | '[{0' + space_fmt + '}/{1}]', 155 | 'eta: {eta}', 156 | '{meters}', 157 | 'time: {time}', 158 | 'data: {data}' 159 | ] 160 | if torch.cuda.is_available(): 161 | log_msg.append('max mem: {memory:.0f}') 162 | log_msg = self.delimiter.join(log_msg) 163 | MB = 1024.0 * 1024.0 164 | for obj in iterable: 165 | data_time.update(time.time() - end) 166 | yield obj 167 | iter_time.update(time.time() - end) 168 | if i % print_freq == 0 or i == len(iterable) - 1: 169 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 170 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 171 | if torch.cuda.is_available(): 172 | print(log_msg.format( 173 | i, len(iterable), eta=eta_string, 174 | meters=str(self), 175 | time=str(iter_time), data=str(data_time), 176 | memory=torch.cuda.max_memory_allocated() / MB)) 177 | else: 178 | print(log_msg.format( 179 | i, len(iterable), eta=eta_string, 180 | meters=str(self), 181 | time=str(iter_time), data=str(data_time))) 182 | i += 1 183 | end = time.time() 184 | total_time = time.time() - start_time 185 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 186 | print('{} Total time: {} ({:.4f} s / 
it)'.format( 187 | header, total_time_str, total_time / len(iterable))) 188 | 189 | 190 | class TensorboardLogger(object): 191 | def __init__(self, log_dir): 192 | self.writer = SummaryWriter(logdir=log_dir) 193 | self.step = 0 194 | 195 | def set_step(self, step=None): 196 | if step is not None: 197 | self.step = step 198 | else: 199 | self.step += 1 200 | 201 | def update(self, head='scalar', step=None, **kwargs): 202 | for k, v in kwargs.items(): 203 | if v is None: 204 | continue 205 | if isinstance(v, torch.Tensor): 206 | v = v.item() 207 | assert isinstance(v, (float, int)) 208 | self.writer.add_scalar(head + "/" + k, v, self.step if step is None else step) 209 | 210 | def flush(self): 211 | self.writer.flush() 212 | 213 | 214 | def _load_checkpoint_for_ema(model_ema, checkpoint): 215 | """ 216 | Workaround for ModelEma._load_checkpoint to accept an already-loaded object 217 | """ 218 | mem_file = io.BytesIO() 219 | torch.save(checkpoint, mem_file) 220 | mem_file.seek(0) 221 | model_ema._load_checkpoint(mem_file) 222 | 223 | 224 | def setup_for_distributed(is_master): 225 | """ 226 | This function disables printing when not in master process 227 | """ 228 | import builtins as __builtin__ 229 | builtin_print = __builtin__.print 230 | 231 | def print(*args, **kwargs): 232 | force = kwargs.pop('force', False) 233 | if is_master or force: 234 | builtin_print(*args, **kwargs) 235 | 236 | __builtin__.print = print 237 | 238 | 239 | def is_dist_avail_and_initialized(): 240 | if not dist.is_available(): 241 | return False 242 | if not dist.is_initialized(): 243 | return False 244 | return True 245 | 246 | 247 | def get_world_size(): 248 | if not is_dist_avail_and_initialized(): 249 | return 1 250 | return dist.get_world_size() 251 | 252 | 253 | def get_rank(): 254 | if not is_dist_avail_and_initialized(): 255 | return 0 256 | return dist.get_rank() 257 | 258 | 259 | def is_main_process(): 260 | return get_rank() == 0 261 | 262 | 263 | def save_on_master(*args, **kwargs): 264 | if is_main_process(): 265 | torch.save(*args, **kwargs) 266 | 267 | 268 | def _get_rank_env(): 269 | if "RANK" in os.environ: 270 | return int(os.environ["RANK"]) 271 | else: 272 | return int(os.environ['OMPI_COMM_WORLD_RANK']) 273 | 274 | 275 | def _get_local_rank_env(): 276 | if "LOCAL_RANK" in os.environ: 277 | return int(os.environ["LOCAL_RANK"]) 278 | else: 279 | return int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 280 | 281 | 282 | def _get_world_size_env(): 283 | if "WORLD_SIZE" in os.environ: 284 | return int(os.environ["WORLD_SIZE"]) 285 | else: 286 | return int(os.environ['OMPI_COMM_WORLD_SIZE']) 287 | 288 | 289 | # The implementation code is modified from DeiT (https://github.com/facebookresearch/deit.git) 290 | def init_distributed_mode(args): 291 | if args.dist_on_itp: 292 | args.rank = _get_rank_env() 293 | args.world_size = _get_world_size_env() # int(os.environ['OMPI_COMM_WORLD_SIZE']) 294 | args.gpu = _get_local_rank_env() 295 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) 296 | os.environ['LOCAL_RANK'] = str(args.gpu) 297 | os.environ['RANK'] = str(args.rank) 298 | os.environ['WORLD_SIZE'] = str(args.world_size) 299 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] 300 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 301 | args.rank = int(os.environ["RANK"]) 302 | args.world_size = int(os.environ['WORLD_SIZE']) 303 | args.gpu = int(os.environ['LOCAL_RANK']) 304 | elif 'SLURM_PROCID' in os.environ: 305 | args.rank = 
int(os.environ['SLURM_PROCID']) 306 | args.gpu = args.rank % torch.cuda.device_count() 307 | else: 308 | print('Not using distributed mode') 309 | args.distributed = False 310 | return 311 | 312 | args.distributed = True 313 | 314 | torch.cuda.set_device(args.gpu) 315 | args.dist_backend = 'nccl' 316 | print('| distributed init (rank {}): {}, gpu {}'.format( 317 | args.rank, args.dist_url, args.gpu), flush=True) 318 | torch.distributed.init_process_group( 319 | backend=args.dist_backend, init_method=args.dist_url, 320 | world_size=args.world_size, rank=args.rank, 321 | timeout=datetime.timedelta(0, 7200) 322 | ) 323 | torch.distributed.barrier() 324 | setup_for_distributed(args.rank == 0) 325 | 326 | 327 | def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"): 328 | missing_keys = [] 329 | unexpected_keys = [] 330 | error_msgs = [] 331 | # copy state_dict so _load_from_state_dict can modify it 332 | metadata = getattr(state_dict, '_metadata', None) 333 | state_dict = state_dict.copy() 334 | if metadata is not None: 335 | state_dict._metadata = metadata 336 | 337 | def load(module, prefix=''): 338 | local_metadata = {} if metadata is None else metadata.get( 339 | prefix[:-1], {}) 340 | module._load_from_state_dict( 341 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) 342 | for name, child in module._modules.items(): 343 | if child is not None: 344 | load(child, prefix + name + '.') 345 | 346 | load(model, prefix=prefix) 347 | 348 | warn_missing_keys = [] 349 | ignore_missing_keys = [] 350 | for key in missing_keys: 351 | keep_flag = True 352 | for ignore_key in ignore_missing.split('|'): 353 | if ignore_key in key: 354 | keep_flag = False 355 | break 356 | if keep_flag: 357 | warn_missing_keys.append(key) 358 | else: 359 | ignore_missing_keys.append(key) 360 | 361 | missing_keys = warn_missing_keys 362 | 363 | if len(missing_keys) > 0: 364 | print("Weights of {} not initialized from pretrained model: {}".format( 365 | model.__class__.__name__, missing_keys)) 366 | if len(unexpected_keys) > 0: 367 | print("Weights from pretrained model not used in {}: {}".format( 368 | model.__class__.__name__, unexpected_keys)) 369 | if len(ignore_missing_keys) > 0: 370 | print("Ignored weights of {} not initialized from pretrained model: {}".format( 371 | model.__class__.__name__, ignore_missing_keys)) 372 | if len(error_msgs) > 0: 373 | print('\n'.join(error_msgs)) 374 | 375 | 376 | class NativeScalerWithGradNormCount: 377 | state_dict_key = "amp_scaler" 378 | 379 | def __init__(self): 380 | self._scaler = torch.cuda.amp.GradScaler() 381 | 382 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 383 | self._scaler.scale(loss).backward(create_graph=create_graph) 384 | if update_grad: 385 | if clip_grad is not None: 386 | assert parameters is not None 387 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 388 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 389 | else: 390 | self._scaler.unscale_(optimizer) 391 | norm = get_grad_norm_(parameters) 392 | self._scaler.step(optimizer) 393 | self._scaler.update() 394 | else: 395 | norm = None 396 | return norm 397 | 398 | def state_dict(self): 399 | return self._scaler.state_dict() 400 | 401 | def load_state_dict(self, state_dict): 402 | self._scaler.load_state_dict(state_dict) 403 | 404 | 405 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 406 
| if isinstance(parameters, torch.Tensor): 407 | parameters = [parameters] 408 | parameters = [p for p in parameters if p.grad is not None] 409 | norm_type = float(norm_type) 410 | if len(parameters) == 0: 411 | return torch.tensor(0.) 412 | device = parameters[0].grad.device 413 | if norm_type == inf: 414 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 415 | else: 416 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 417 | return total_norm 418 | 419 | 420 | def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, 421 | start_warmup_value=0, warmup_steps=-1, sched_type="cos"): 422 | warmup_schedule = np.array([]) 423 | warmup_iters = warmup_epochs * niter_per_ep 424 | if warmup_steps > 0: 425 | warmup_iters = warmup_steps 426 | print("Set warmup steps = %d" % warmup_iters) 427 | if warmup_epochs > 0: 428 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) 429 | 430 | if sched_type == "cos": 431 | iters = np.arange(epochs * niter_per_ep - warmup_iters) 432 | schedule = np.array([ 433 | final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters]) 434 | elif sched_type == "linear": 435 | schedule = np.linspace(base_value, final_value, epochs * niter_per_ep - warmup_iters) 436 | else: 437 | raise NotImplementedError() 438 | 439 | schedule = np.concatenate((warmup_schedule, schedule)) 440 | 441 | assert len(schedule) == epochs * niter_per_ep 442 | return schedule 443 | 444 | 445 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, model_ema=None): 446 | output_dir = Path(args.output_dir) 447 | if loss_scaler is not None: 448 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch)] 449 | for checkpoint_path in checkpoint_paths: 450 | to_save = { 451 | 'model': model_without_ddp.state_dict(), 452 | 'optimizer': optimizer.state_dict(), 453 | 'epoch': epoch, 454 | 'scaler': loss_scaler.state_dict(), 455 | 'args': args, 456 | } 457 | 458 | if model_ema is not None: 459 | to_save['model_ema'] = get_state_dict(model_ema) 460 | 461 | save_on_master(to_save, checkpoint_path) 462 | else: 463 | client_state = {'epoch': epoch, "args": args} 464 | if model_ema is not None: 465 | client_state['model_ema'] = get_state_dict(model_ema) 466 | model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch, client_state=client_state) 467 | 468 | 469 | def auto_load_model(args, model, model_without_ddp, optimizer, loss_scaler, model_ema=None): 470 | output_dir = Path(args.output_dir) 471 | if loss_scaler is not None: 472 | # torch.amp 473 | if args.auto_resume and len(args.resume) == 0: 474 | import glob 475 | all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth')) 476 | latest_ckpt = -1 477 | for ckpt in all_checkpoints: 478 | t = ckpt.split('-')[-1].split('.')[0] 479 | if t.isdigit(): 480 | latest_ckpt = max(int(t), latest_ckpt) 481 | if latest_ckpt >= 0: 482 | args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt) 483 | print("Auto resume checkpoint: %s" % args.resume) 484 | 485 | if args.resume: 486 | if args.resume.startswith('https'): 487 | checkpoint = torch.hub.load_state_dict_from_url( 488 | args.resume, map_location='cpu', check_hash=True) 489 | else: 490 | checkpoint = torch.load(args.resume, map_location='cpu') 491 | model_without_ddp.load_state_dict(checkpoint['model']) 492 | print("Resume checkpoint %s" % 
args.resume) 493 | if 'optimizer' in checkpoint and 'epoch' in checkpoint: 494 | optimizer.load_state_dict(checkpoint['optimizer']) 495 | args.start_epoch = checkpoint['epoch'] + 1 496 | if hasattr(args, 'model_ema') and args.model_ema: 497 | _load_checkpoint_for_ema(model_ema, checkpoint['model_ema']) 498 | if 'scaler' in checkpoint: 499 | loss_scaler.load_state_dict(checkpoint['scaler']) 500 | print("With optim & sched!") 501 | else: 502 | # deepspeed, only support '--auto_resume'. 503 | if args.auto_resume: 504 | import glob 505 | all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*')) 506 | latest_ckpt = -1 507 | for ckpt in all_checkpoints: 508 | t = ckpt.split('-')[-1].split('.')[0] 509 | if t.isdigit(): 510 | latest_ckpt = max(int(t), latest_ckpt) 511 | if latest_ckpt >= 0: 512 | args.resume = os.path.join(output_dir, 'checkpoint-%d' % latest_ckpt) 513 | print("Auto resume checkpoint: %d" % latest_ckpt) 514 | _, client_states = model.load_checkpoint(args.output_dir, tag='checkpoint-%d' % latest_ckpt) 515 | args.start_epoch = client_states['epoch'] + 1 516 | if model_ema is not None: 517 | if args.model_ema: 518 | _load_checkpoint_for_ema(model_ema, client_states['model_ema']) 519 | 520 | 521 | # The implementation code is modified from DeiT (https://github.com/facebookresearch/deit.git) 522 | def load_model_and_may_interpolate(ckpt_path, model, model_key, model_prefix): 523 | if ckpt_path.startswith('https'): 524 | checkpoint = torch.hub.load_state_dict_from_url( 525 | ckpt_path, map_location='cpu', check_hash=True) 526 | else: 527 | checkpoint = torch.load(ckpt_path, map_location='cpu') 528 | 529 | print("Load ckpt from %s" % ckpt_path) 530 | checkpoint_model = None 531 | for model_key in model_key.split('|'): 532 | if model_key in checkpoint: 533 | checkpoint_model = checkpoint[model_key] 534 | print("Load state_dict by model_key = %s" % model_key) 535 | break 536 | 537 | if checkpoint_model is None: 538 | checkpoint_model = checkpoint 539 | 540 | state_dict = model.state_dict() 541 | for k in ['head.weight', 'head.bias']: 542 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: 543 | print(f"Removing key {k} from pretrained checkpoint") 544 | del checkpoint_model[k] 545 | 546 | # interpolate position embedding 547 | for pos_embed_key in ("vision_pos_embed", "pos_embed", "beit3.encoder.embed_positions.A.weight"): 548 | if pos_embed_key in checkpoint_model: 549 | pos_embed_checkpoint = checkpoint_model[pos_embed_key] 550 | embedding_size = pos_embed_checkpoint.shape[-1] 551 | if pos_embed_key == "beit3.encoder.embed_positions.A.weight": 552 | # being consistent with Fairseq, which starts from 2 for position embedding 553 | torchscale_model = True 554 | num_patches = model.vision_embed.num_patches 555 | num_extra_tokens = model.vision_embed.num_position_embeddings() + 2 - num_patches 556 | else: 557 | torchscale_model = False 558 | num_patches = model.patch_embed.num_patches 559 | num_extra_tokens = getattr(model, pos_embed_key).shape[-2] - num_patches 560 | # height (== width) for the checkpoint position embedding 561 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 562 | # height (== width) for the new position embedding 563 | new_size = int(num_patches ** 0.5) 564 | # class_token and dist_token are kept unchanged 565 | if orig_size != new_size: 566 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 567 | if torchscale_model: 568 | extra_tokens = 
pos_embed_checkpoint[:num_extra_tokens].unsqueeze(0) 569 | # only the position tokens are interpolated 570 | pos_tokens = pos_embed_checkpoint[num_extra_tokens:] 571 | else: 572 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 573 | # only the position tokens are interpolated 574 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 575 | # pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 576 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2).float() 577 | pos_tokens = torch.nn.functional.interpolate( 578 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 579 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 580 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 581 | if torchscale_model: 582 | new_pos_embed = new_pos_embed.squeeze(0) 583 | checkpoint_model[pos_embed_key] = new_pos_embed 584 | 585 | load_state_dict(model, checkpoint_model, prefix=model_prefix) 586 | 587 | 588 | def create_ds_config(args): 589 | args.deepspeed_config = os.path.join(args.output_dir, "deepspeed_config.json") 590 | with open(args.deepspeed_config, mode="w") as writer: 591 | ds_config = { 592 | "train_batch_size": args.batch_size * args.update_freq * get_world_size(), 593 | "train_micro_batch_size_per_gpu": args.batch_size, 594 | "steps_per_print": 1000, 595 | "optimizer": { 596 | "type": "Adam", 597 | "adam_w_mode": True, 598 | "params": { 599 | "lr": args.lr, 600 | "weight_decay": args.weight_decay, 601 | "bias_correction": True, 602 | "betas": [ 603 | args.opt_betas[0], 604 | args.opt_betas[1] 605 | ], 606 | "eps": args.opt_eps 607 | } 608 | }, 609 | "fp16": { 610 | "enabled": True, 611 | "loss_scale": 0, 612 | "initial_scale_power": getattr(args, "initial_scale_power", 12), 613 | "loss_scale_window": 1000, 614 | "hysteresis": 2, 615 | "min_loss_scale": 1 616 | }, 617 | "amp": { 618 | "enabled": False, 619 | "opt_level": "O2" 620 | } 621 | } 622 | 623 | if args.clip_grad is not None: 624 | ds_config.update({'gradient_clipping': args.clip_grad}) 625 | 626 | if args.zero_stage == 1: 627 | ds_config.update({"zero_optimization": {"stage": args.zero_stage, "reduce_bucket_size": 5e8}}) 628 | elif args.zero_stage > 1: 629 | raise NotImplementedError() 630 | 631 | writer.write(json.dumps(ds_config, indent=2)) 632 | 633 | 634 | def merge_batch_tensors_by_dict_key(batch): 635 | batch_tensors = {} 636 | for tensor_key in batch[0]: 637 | if isinstance(batch[0][tensor_key], torch.Tensor): 638 | batch_tensors[tensor_key] = torch.stack([d[tensor_key] for d in batch]) 639 | else: 640 | batch_tensors[tensor_key] = torch.tensor([d[tensor_key] for d in batch], dtype=torch.long) 641 | return batch_tensors 642 | 643 | 644 | def get_loss_scale_for_deepspeed(model): 645 | optimizer = model.optimizer 646 | loss_scale = None 647 | if hasattr(optimizer, 'loss_scale'): 648 | loss_scale = optimizer.loss_scale 649 | elif hasattr(optimizer, 'cur_scale'): 650 | loss_scale = optimizer.cur_scale 651 | return loss_scale 652 | 653 | 654 | class GatherLayer(torch.autograd.Function): 655 | """ 656 | Gather tensors from all workers with support for backward propagation: 657 | This implementation does not cut the gradients as torch.distributed.all_gather does. 
658 | """ 659 | @staticmethod 660 | def forward(ctx, x): 661 | output = [torch.zeros_like(x) for _ in range(dist.get_world_size())] 662 | dist.all_gather(output, x) 663 | return tuple(output) 664 | @staticmethod 665 | def backward(ctx, *grads): 666 | all_gradients = torch.stack(grads) 667 | dist.all_reduce(all_gradients) 668 | return all_gradients[dist.get_rank()] 669 | 670 | 671 | def gather_features( 672 | image_features, 673 | text_features, 674 | ): 675 | gathered_image_features = GatherLayer.apply(image_features) 676 | gathered_text_features = GatherLayer.apply(text_features) 677 | all_image_features = torch.cat(gathered_image_features) 678 | all_text_features = torch.cat(gathered_text_features) 679 | 680 | return all_image_features, all_text_features 681 | 682 | 683 | # The implementation code is modified from open_clip (https://github.com/mlfoundations/open_clip.git) 684 | class ClipLoss(nn.Module): 685 | 686 | def __init__( 687 | self, 688 | cache_labels=False, 689 | rank=0, 690 | world_size=1, 691 | ): 692 | super().__init__() 693 | self.cache_labels = cache_labels 694 | self.rank = rank 695 | self.world_size = world_size 696 | 697 | # cache state 698 | self.prev_num_logits = 0 699 | self.labels = {} 700 | 701 | def forward(self, image_features, text_features, logit_scale): 702 | device = image_features.device 703 | if self.world_size > 1: 704 | all_image_features, all_text_features = gather_features( 705 | image_features, text_features 706 | ) 707 | 708 | logits_per_image = logit_scale * image_features @ all_text_features.T 709 | logits_per_text = logit_scale * text_features @ all_image_features.T 710 | else: 711 | logits_per_image = logit_scale * image_features @ text_features.T 712 | logits_per_text = logit_scale * text_features @ image_features.T 713 | 714 | # calculated ground-truth and cache if enabled 715 | num_logits = logits_per_image.shape[0] 716 | if self.prev_num_logits != num_logits or device not in self.labels: 717 | labels = torch.arange(num_logits, device=device, dtype=torch.long) 718 | if self.world_size > 1: 719 | labels = labels + num_logits * self.rank 720 | if self.cache_labels: 721 | self.labels[device] = labels 722 | self.prev_num_logits = num_logits 723 | else: 724 | labels = self.labels[device] 725 | 726 | total_loss = ( 727 | F.cross_entropy(logits_per_image, labels) + 728 | F.cross_entropy(logits_per_text, labels) 729 | ) / 2 730 | return total_loss, logits_per_image, logits_per_text 731 | 732 | 733 | def write_result_to_jsonl(test_stats, result_file): 734 | with open(result_file, mode="w", encoding="utf-8") as writer: 735 | writer.write(json.dumps(test_stats, indent=None)) 736 | 737 | 738 | def read_result_from_jsonl(result_file): 739 | with open(result_file, mode="r", encoding="utf-8") as reader: 740 | return json.load(reader) 741 | 742 | 743 | # The implementation code is from ViLT (https://github.com/dandelin/ViLT.git) 744 | class VQAScore(Metric): 745 | def __init__(self, dist_sync_on_step=False): 746 | super().__init__(dist_sync_on_step=dist_sync_on_step) 747 | self.add_state("score", default=torch.tensor(0.0), dist_reduce_fx="sum") 748 | self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum") 749 | 750 | def update(self, logits, target): 751 | logits, target = ( 752 | logits.detach().float().to(self.score.device), 753 | target.detach().float().to(self.score.device), 754 | ) 755 | logits = torch.max(logits, 1)[1] 756 | one_hots = torch.zeros(*target.size()).to(target) 757 | one_hots.scatter_(1, logits.view(-1, 1), 1) 758 | scores = 
one_hots * target 759 | 760 | self.score += scores.sum() 761 | self.total += len(logits) 762 | 763 | def compute(self): 764 | return self.score / self.total 765 | 766 | 767 | class BertCaptioningLoss(nn.Module): 768 | def __init__(self, label_smoothing, drop_worst_ratio, drop_worst_after): 769 | super().__init__() 770 | self.label_smoothing = label_smoothing 771 | self.drop_worst_ratio = drop_worst_ratio 772 | self.drop_worst_after = drop_worst_after 773 | self.log_soft = nn.LogSoftmax(dim=1) 774 | self.kl = nn.KLDivLoss(reduction='none') 775 | self.iter = 0 776 | 777 | def forward(self, logits, target, iter): 778 | eps = self.label_smoothing 779 | n_class = logits.size(1) 780 | one_hot = torch.zeros_like(logits).scatter(1, target.view(-1, 1), 1) 781 | one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1) 782 | log_prb = self.log_soft(logits) 783 | loss = self.kl(log_prb, one_hot).sum(1) 784 | 785 | if self.drop_worst_ratio > 0 and iter > self.drop_worst_after: 786 | loss, _ = torch.topk(loss, 787 | k=int(loss.shape[0] * (1-self.drop_worst_ratio)), 788 | largest=False) 789 | loss = loss.mean() 790 | 791 | return loss 792 | 793 | 794 | class BeamHypotheses(object): 795 | def __init__(self, n_hyp, max_length, length_penalty, early_stopping): 796 | """ 797 | Initialize n-best list of hypotheses. 798 | """ 799 | self.max_length = max_length - 1 # ignoring bos_token 800 | self.length_penalty = length_penalty 801 | self.early_stopping = early_stopping 802 | self.n_hyp = n_hyp 803 | self.hyp = [] 804 | self.worst_score = 1e9 805 | 806 | def __len__(self): 807 | """ 808 | Number of hypotheses in the list. 809 | """ 810 | return len(self.hyp) 811 | 812 | def add(self, hyp, sum_logprobs): 813 | """ 814 | Add a new hypothesis to the list. 815 | """ 816 | score = sum_logprobs / len(hyp) ** self.length_penalty 817 | if len(self) < self.n_hyp or score > self.worst_score: 818 | self.hyp.append((score, hyp)) 819 | if len(self) > self.n_hyp: 820 | sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.hyp)]) 821 | del self.hyp[sorted_scores[0][1]] 822 | self.worst_score = sorted_scores[1][0] 823 | else: 824 | self.worst_score = min(score, self.worst_score) 825 | 826 | def is_done(self, best_sum_logprobs): 827 | """ 828 | If there are enough hypotheses and that none of the hypotheses being generated 829 | can become better than the worst one in the heap, then we are done with this sentence. 
830 | """ 831 | if len(self) < self.n_hyp: 832 | return False 833 | elif self.early_stopping: 834 | return True 835 | else: 836 | return self.worst_score >= best_sum_logprobs / self.max_length ** self.length_penalty 837 | 838 | 839 | def dump_predictions(args, result, file_suffix): 840 | global_rank = get_rank() 841 | jsons = None 842 | if global_rank >= 0: 843 | output_file = os.path.join(args.task_cache_path, f"submit_{global_rank}_{file_suffix}.json") 844 | with open(output_file, "w") as fp: 845 | json.dump(result, fp, indent=2) 846 | torch.distributed.barrier() 847 | 848 | if global_rank == 0: 849 | world_size = get_world_size() 850 | jsons = [] 851 | for i in range(world_size): 852 | each_file = os.path.join(args.task_cache_path, f"submit_{i}_{file_suffix}.json") 853 | with open(each_file, "r") as fp: 854 | jsons += json.load(fp) 855 | 856 | new_jsons = [] 857 | res_dict = dict() 858 | if args.task in ["coco_captioning", "nocaps"]: 859 | qid_key = "image_id" 860 | else: 861 | # for VQAv2 862 | qid_key = "question_id" 863 | for item in jsons: 864 | if item[qid_key] in res_dict: 865 | continue 866 | new_jsons.append(item) 867 | res_dict[item[qid_key]] = item 868 | jsons = new_jsons 869 | 870 | torch.distributed.barrier() 871 | os.remove(output_file) 872 | else: 873 | jsons = result 874 | 875 | result_file = os.path.join(args.output_dir, f"submit_{file_suffix}.json") 876 | if jsons is not None: 877 | with open(result_file, "w") as fp: 878 | json.dump(jsons, fp, indent=2) 879 | print("Infer %d examples into %s" % (len(jsons), result_file)) 880 | return result_file 881 | 882 | 883 | # The evaluation code is from BLIP (https://github.com/salesforce/BLIP) 884 | # For nocaps, please submit the prediction file to the evaluate server (https://eval.ai/web/challenges/challenge-page/355/overview) to obtain the final results 885 | def coco_caption_eval(gt_dir, results_file, split): 886 | from pycocotools.coco import COCO 887 | from pycocoevalcap.eval import COCOEvalCap 888 | from torchvision.datasets.utils import download_url 889 | 890 | urls = {'coco_captioning_val': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val_gt.json', 891 | 'coco_captioning_test': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test_gt.json', 892 | 'nocaps_val': 'https://conversationhub.blob.core.windows.net/beit-share-public/beit3/nocaps/nocaps_val_gt.json?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D'} 893 | filenames = {'coco_captioning_val':'coco_karpathy_val_gt.json', 894 | 'coco_captioning_test':'coco_karpathy_test_gt.json', 895 | 'nocaps_val':'nocaps_val_gt.json'} 896 | 897 | download_url(urls[split], gt_dir) 898 | annotation_file = os.path.join(gt_dir, filenames[split]) 899 | 900 | # create coco object and coco_result object 901 | coco = COCO(annotation_file) 902 | coco_result = coco.loadRes(results_file) 903 | 904 | # create coco_eval object by taking coco and coco_result 905 | coco_eval = COCOEvalCap(coco, coco_result) 906 | 907 | # evaluate results 908 | # SPICE will take a few minutes the first time, but speeds up due to caching 909 | coco_eval.evaluate() 910 | 911 | res_dict = dict() 912 | for metric, score in coco_eval.eval.items(): 913 | res_dict[metric] = score 914 | 915 | return res_dict 916 | --------------------------------------------------------------------------------