├── .gitmodules
├── LICENSE
├── README.md
└── models
    ├── __init__.py
    ├── caa.py
    └── modules
        ├── __init__.py
        └── caa_head.py
/.gitmodules:
--------------------------------------------------------------------------------
[submodule "iseg"]
	path = iseg
	url = https://github.com/edwardyehuang/iSeg.git
[submodule "ids"]
	path = ids
	url = https://github.com/edwardyehuang/iDS
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 edwardyehuang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Channelized Axial Attention for Semantic Segmentation (AAAI-2022)

## News
Mar-16-2022: Based on the latest [Class-aware Regularizations (CAR)](https://github.com/edwardyehuang/CAR), CAA + ConvNeXt-Large achieved 64.12% mIOU on Pascal Context! The CAR [repo](https://github.com/edwardyehuang/CAR) also contains the training code for CAA.

[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/caa-channelized-axial-attention-for-semantic/semantic-segmentation-on-coco-stuff-test)](https://paperswithcode.com/sota/semantic-segmentation-on-coco-stuff-test?p=caa-channelized-axial-attention-for-semantic)

[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/car-class-aware-regularizations-for-semantic-1/semantic-segmentation-on-pascal-context)](https://paperswithcode.com/sota/semantic-segmentation-on-pascal-context?p=car-class-aware-regularizations-for-semantic-1)

[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/caa-channelized-axial-attention-for-semantic/semantic-segmentation-on-cityscapes)](https://paperswithcode.com/sota/semantic-segmentation-on-cityscapes?p=caa-channelized-axial-attention-for-semantic)

Some segmentation results on Flickr images:

## Installation
1. Install TensorFlow (>= 2.4; 2.3 is not recommended for GPU, but is okay for TPU)
2. Install iSeg (my personal segmentation codebase, to be updated soon at https://github.com/edwardyehuang/iSeg)
3. Install iDS (dataset support for iSeg, to be updated soon at https://github.com/edwardyehuang/iDS)
4. Clone this repo (use `git clone --recurse-submodules` so the iSeg and iDS submodules listed in `.gitmodules` are fetched)
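
After installation, a minimal forward pass looks like the sketch below. This is illustrative only: it assumes iSeg and iDS are importable, `num_class` depends on the target dataset, and the checkpoint path is a placeholder (see the Model Zoo below for trained weights).

```python
import tensorflow as tf
import iseg.static_strings as ss

from models.caa import CAA

model = CAA(backbone_name=ss.RESNET50, num_class=21)

# Optionally restore a checkpoint downloaded from the Model Zoo (placeholder path)
# model.load_weights("path/to/checkpoint")

images = tf.random.uniform([1, 513, 513, 3])

# use_aux_loss defaults to True, so the model returns [logits, aux_logits]
logits, aux_logits = model(images, training=False)
```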

## Model Zoo
Since some of the original experiments (especially for ResNet-101) were conducted a long time ago, the checkpoints listed below may perform slightly differently from the numbers reported in the paper.

### Pascal Context

| Backbone | ckpts | mIOU% | configs |
| ---- | ---- | ---- | ---- |
| ResNet-101 | [weiyun](https://share.weiyun.com/nSUwp76n) | 55.0 | configs |
| EfficientNet-B7 | [weiyun](https://share.weiyun.com/uMXjsmXf) | 60.3 | configs |

### COCOStuff-10k

| Backbone | ckpts | mIOU% | configs |
| ---- | ---- | ---- | ---- |
| ResNet-101 | [weiyun](https://share.weiyun.com/LtcKwuhK) | 41.2 | configs |

### COCOStuff-164k

| Backbone | ckpts | mIOU% | configs |
| ---- | ---- | ---- | ---- |
| EfficientNet-B5 | [weiyun](https://share.weiyun.com/p5xbCE55) | 47.27 | configs |

## Please cite us

```
@InProceedings{cCAA,
    author = "Ye Huang and Di Kang and Wenjing Jia and Xiangjian He and Liu Liu",
    title = "Channelized Axial Attention - Considering Channel Relation within Spatial Attention for Semantic Segmentation",
    booktitle = "Proceedings of the AAAI Conference on Artificial Intelligence",
    year = "2022",
    doi = "10.1609/aaai.v36i1.19985",
}
```
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edwardyehuang/CAA/74a1de8b1cfca6a6669e4fe76dc2130e37fba19d/models/__init__.py
--------------------------------------------------------------------------------
/models/caa.py:
--------------------------------------------------------------------------------
# ================================================================
# MIT License
# Copyright (c) 2021 edwardyehuang (https://github.com/edwardyehuang)
# ================================================================

import iseg.static_strings as ss
import tensorflow as tf

from iseg.backbones.feature_extractor import get_backbone
from iseg.utils import resize_image
from iseg.layers.model_builder import ConvBnRelu

from iseg import SegFoundation
from models.modules.caa_head import ChannelizedAxialAttentionHead

# CAA was implemented in 2020, a long time ago.
# SegFoundation will be replaced by segmanaged in the future (when I have time).
class CAA(SegFoundation):
    def __init__(
        self,
        backbone_name=ss.RESNET50,
        backbone_weights_path=None,
        num_class=2,
        output_stride=16,
        use_channel_attention=True,
        use_image_level=True,
        use_aux_loss=True,
        aux_loss_rate=0.4,
        use_typical_aux_feature=False,
        always_map_fn=False,
        num_parallel_group_fn=16,
        **kwargs,
    ):

        super().__init__(
            num_class=num_class,
            num_aux_loss=1 if use_aux_loss else 0,
            aux_loss_rate=aux_loss_rate,
        )

        self.output_stride = output_stride
        self.use_aux_loss = use_aux_loss
        self.use_image_level = use_image_level
        self.use_typical_aux_feature = use_typical_aux_feature

        self.backbone = get_backbone(
            backbone_name,
            output_stride=output_stride,
            resnet_multi_grids=[1, 2, 4],
            resnet_slim=True,
            weights_path=backbone_weights_path,
            return_endpoints=True,
        )

        self.seg_head = ChannelizedAxialAttentionHead(
            filters=512,
            use_channel_attention=use_channel_attention,
            use_image_level=self.use_image_level,
            num_parallel_group_fn=num_parallel_group_fn,
            fallback_concat=always_map_fn,
        )

        self.seg_head_convbnrelu = ConvBnRelu(
            256, (1, 1), dropout_rate=0.1, name="seg_head_conv"
        )

        self.logits_conv = tf.keras.layers.Conv2D(num_class, (1, 1), name="logits_conv")

        if self.use_aux_loss:
            aux_conv_name = "aux_feature_conv0"

            if self.use_typical_aux_feature:
                aux_conv_name = "aux_feature_typical_conv0"
                self.aux_down_conv = ConvBnRelu(512, (3, 3), name="aux_down_conv")

            self.aux_feature_convbnrelu0 = ConvBnRelu(256, (1, 1), dropout_rate=0.1, name=aux_conv_name)
            self.aux_loss_logits_conv = tf.keras.layers.Conv2D(num_class, (1, 1), name="aux_loss_logits_conv")
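
    # Dataflow sketch (descriptive): backbone -> endpoints[-1] -> CAA head
    # -> 1x1 ConvBnRelu (256 filters) -> 1x1 conv to num_class logits
    # -> resize back to the input resolution. With use_aux_loss enabled, a
    # parallel 1x1 head on the same endpoint produces auxiliary logits that
    # are weighted by aux_loss_rate.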

    def call(self, inputs, training=None, **kwargs):

        input_shape = tf.shape(inputs)
        input_size = input_shape[1:3]

        x = inputs

        endpoints = self.backbone(x, training=training, **kwargs)
        endpoints = endpoints[1:]  # drop the first (lowest-level) endpoint

        if self.use_aux_loss:
            # aux_layer_index = -1 if not self.use_typical_aux_feature else -2
            aux_layer_feature = endpoints[-1]  # self.aux_layer()

            if self.use_typical_aux_feature:
                aux_layer_feature = self.aux_down_conv(aux_layer_feature, training=training)

        x = endpoints[-1]
        x = self.seg_head(x, training=training)
        x = self.seg_head_convbnrelu(x, training=training)
        x = self.logits_conv(x)

        x = resize_image(x, size=input_size)
        x = tf.cast(x, tf.float32)

        if self.use_aux_loss:
            aux_feature = self.aux_feature_convbnrelu0(aux_layer_feature, training=training)
            aux_feature = self.aux_loss_logits_conv(aux_feature)
            aux_feature = resize_image(aux_feature, size=input_size)

            aux_feature = tf.cast(aux_feature, tf.float32)
            return [x, aux_feature]

        return x
--------------------------------------------------------------------------------
/models/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edwardyehuang/CAA/74a1de8b1cfca6a6669e4fe76dc2130e37fba19d/models/modules/__init__.py
--------------------------------------------------------------------------------
/models/modules/caa_head.py:
--------------------------------------------------------------------------------
# ================================================================
# MIT License
# Copyright (c) 2021 edwardyehuang (https://github.com/edwardyehuang)
# ================================================================

import tensorflow as tf
from iseg.layers import DenseExt

from iseg.utils.attention_utils import get_axial_attention
from iseg.layers.model_builder import ConvBnRelu, get_training_value

from iseg.vis.vismanager import get_visualization_manager


class ChannelizedAxialAttentionHead(tf.keras.Model):
    def __init__(
        self,
        filters=512,
        attention_blocks_num=1,
        use_channel_attention=True,
        use_image_level=True,
        num_parallel_group_fn=16,
        fallback_concat=False,
        use_entry_conv=True,
        name=None,
    ):

        if name is None:
            name = "ChannelizedAxialAttentionHead"

        super(ChannelizedAxialAttentionHead, self).__init__(name=name)

        self.use_image_level = use_image_level
        self.filters = filters

        self.entry_convbnrelu = (
            None if not use_entry_conv else ConvBnRelu(self.filters, kernel_size=(3, 3), name="entry")
        )

        self.end_convbnrelu = ConvBnRelu(self.filters, (1, 1), name="end")

        self.ca_blocks = [
            AxialAttentionBlock(
                guided_filters=64,
                filters=self.filters,
                use_channel_attention=use_channel_attention,
                num_parallel_group_fn=num_parallel_group_fn,
                fallback_concat=fallback_concat,
                name="ca_block{}".format(i),
            )
            for i in range(attention_blocks_num)
        ]

        if self.use_image_level:
            self.image_level_block = ImageLevelBlock(self.filters, (1, 2))

    def call(self, inputs, training=None, **kwargs):

        x = inputs

        if self.entry_convbnrelu is not None:
            x = self.entry_convbnrelu(x, training=training)

        for ca_block in self.ca_blocks:
            x = ca_block(x, training=training)

        x = [self.end_convbnrelu(x, training=training)]

        if self.use_image_level:
            x += [self.image_level_block(inputs, training=training)]

        x = tf.concat(x, axis=-1)

        return x
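
# Shape walkthrough for ChannelizedAxialAttentionHead (descriptive sketch):
#   inputs                    [N, H, W, C_in]
#   entry 3x3 ConvBnRelu      [N, H, W, filters]
#   AxialAttentionBlock(s)    [N, H, W, filters]
#   end 1x1 ConvBnRelu        [N, H, W, filters]
#   image-level branch        [N, H, W, filters]   (broadcast global pooling)
#   concat                    [N, H, W, 2 * filters]   (when use_image_level)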


class ImageLevelBlock(tf.keras.Model):
    def __init__(self, filters=256, pooling_axis=(1, 2), name=None):
        super(ImageLevelBlock, self).__init__(name="ImageLevelBlock" if name is None else name)

        self.convbnrelu = ConvBnRelu(filters, (1, 1), name="conv")
        self.pooling_axis = pooling_axis

    def call(self, inputs, training=None):

        x = inputs
        inputs_dtype = inputs.dtype
        inputs_size = tf.shape(inputs)[1:3]

        x = tf.reduce_mean(x, axis=self.pooling_axis, keepdims=True, name="pool")
        x = self.convbnrelu(x, training=training)

        # Broadcast the pooled feature back to the input spatial size
        one = tf.ones((), dtype=tf.int32)
        target_height = tf.cast(inputs_size[0], dtype=tf.int32)
        target_width = tf.cast(inputs_size[1], dtype=tf.int32)
        target_shape = tf.stack([one, target_height, target_width, one], axis=0)

        x = tf.ones(target_shape, dtype=x.dtype) * x

        x = tf.cast(x, inputs_dtype)

        return x


class ChannelAttentionBlock(tf.keras.Model):
    def __init__(
        self,
        hidden_filters=256,
        end_filters=256,
        hidden_layers_count=1,
        hidden_activation=tf.nn.leaky_relu,
        name=None,
    ):

        super(ChannelAttentionBlock, self).__init__(name=name if name is not None else "ChannelAttentionBlock")

        self.hidden_activation = hidden_activation

        # can be replaced by 1x1 convs (same result)
        self.denses = [
            DenseExt(hidden_filters, use_bias=False, name="dense{}".format(i)) for i in range(hidden_layers_count)
        ]

        self.dense_end = DenseExt(end_filters, use_bias=False, name="dense_end")

    def call(self, inputs, training=None):

        inputs_rank = len(inputs.shape)

        if inputs_rank == 4:
            x = tf.reduce_mean(inputs, axis=(1, 2))  # [N, C]
        elif inputs_rank == 3:
            x = tf.reduce_mean(inputs, axis=1)  # [N, C]
        elif inputs_rank == 2:
            x = inputs
        else:
            raise ValueError("Incorrect inputs rank")

        for i in range(len(self.denses)):
            x = self.denses[i](x)
            x = self.hidden_activation(x)

        x = self.dense_end(x)
        x = tf.nn.sigmoid(x)

        x = tf.expand_dims(x, axis=1)

        if inputs_rank == 4:
            x = tf.expand_dims(x, axis=2)

        return x
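
# Gating sketch for ChannelAttentionBlock above (descriptive):
#   s = global_average_pool(x)                  # [N, C]
#   s = leaky_relu(W_i s) for each hidden layer
#   g = sigmoid(W_end s)                        # [N, C], gates in (0, 1)
# The gate g is returned reshaped to [N, 1, 1, C] so callers can rescale
# features per channel via elementwise multiplication.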


class AxialAttentionBlock(tf.keras.Model):
    def __init__(
        self,
        guided_filters=64,
        filters=512,
        hidden_layers_count=5,
        hidden_layer_ratio=0.25,
        use_channel_attention=True,
        channel_attention_norm=False,  # reserved, currently unused
        num_parallel_group_fn=16,
        fallback_concat=False,
        name=None,
    ):

        super(AxialAttentionBlock, self).__init__(name=name if name is not None else "AxialAttentionBlock")

        self.filters = filters
        self.use_channel_attention = use_channel_attention
        self.fallback_concat = fallback_concat
        self.num_parallel_group_fn = num_parallel_group_fn

        self.v_querykey_convbnrelu = ConvBnRelu(guided_filters, (1, 1), name="v_querykey_conv")
        self.h_querykey_convbnrelu = ConvBnRelu(guided_filters, (1, 1), name="h_querykey_conv")

        self.v_c_attention = ChannelAttentionBlock(
            int(filters * hidden_layer_ratio),
            filters,
            hidden_layers_count,
            name="v_c_attention",
        )
        self.h_c_attention = ChannelAttentionBlock(
            int(filters * hidden_layer_ratio),
            filters,
            hidden_layers_count,
            name="h_c_attention",
        )

        self.value_convbnrelu = ConvBnRelu(filters, (1, 1), name="value_conv")

    def call(self, inputs, training=None):

        x = inputs

        v_logits = self.compute_v_rate(x, training=training)  # [N, W, H, H]
        h_logits = self.compute_h_rate(x, training=training)  # [N, H, W, W]

        x = self.value_convbnrelu(x, training=training)

        vis_manager = get_visualization_manager()

        if vis_manager.recording:
            vis_manager.easy_add(x, name="before_x")

        x = self.compute_v_result(v_logits, x, training=training)
        x = self.compute_h_result(h_logits, x, training=training)

        if vis_manager.recording:
            vis_manager.easy_add(x, name="augmented_x")

        return x

    def compute_v_result(self, v_logits, features, training=None):

        v_logits = tf.transpose(v_logits, [2, 0, 3, 1], name="v_weights")  # [N, W, H, H] => [H, N, H, W]
        x = self.apply_attention_map(
            v_logits, features, axial_axis=1, use_channel_attention=self.use_channel_attention, training=training
        )  # [H, N, W, C]

        x = tf.transpose(x, [1, 0, 2, 3], name="v_features_result")  # [N, H, W, C]

        return x

    def compute_h_result(self, h_logits, features, training=None):

        h_logits = tf.transpose(h_logits, [2, 0, 1, 3], name="h_weights")  # [N, H, W, W] => [W, N, H, W]
        x = self.apply_attention_map(
            h_logits, features, axial_axis=2, use_channel_attention=self.use_channel_attention, training=training
        )  # [W, N, H, C]

        x = tf.transpose(x, [1, 2, 0, 3], name="h_features_result")  # [N, H, W, C]

        return x

    def compute_v_rate(self, features, training=None):

        q = k = self.v_querykey_convbnrelu(features, training=training)

        return get_axial_attention(q, k, axis=1)

    def compute_h_rate(self, features, training=None):

        q = k = self.h_querykey_convbnrelu(features, training=training)

        return get_axial_attention(q, k, axis=2)
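
    # Axial attention in brief: the vertical pass (axis=1) lets every pixel
    # attend only to pixels in its own column, so the logits are [N, W, H, H]
    # rather than the [N, HW, HW] map of full non-local attention; the
    # horizontal pass (axis=2) does the same per row with [N, H, W, W].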
    def apply_attention_map(self, attention_map, features, axial_axis=1, use_channel_attention=True, training=None):

        """
        attention_map : [H, N, H, W] or [W, N, H, W]
        features : [N, H, W, C]
        """

        channel_attention_func = self.v_c_attention if axial_axis == 1 else self.h_c_attention

        attention_map = tf.expand_dims(attention_map, axis=-1)  # [H, N, H, W, 1] or [W, N, H, W, 1]

        return self.group_fn(
            attention_map,
            features,
            num_group=self.num_parallel_group_fn,
            channel_attention_fn=channel_attention_func if use_channel_attention else None,
            reduce_axis=axial_axis,
            fallback_concat=self.fallback_concat,
            training=training,
        )

    def group_fn(
        self,
        attention_map,
        features,
        num_group=4,
        channel_attention_fn=None,
        reduce_axis=1,
        fallback_concat=False,
        training=None,
    ):

        # attention_map: [H, N, H, W, 1] or [W, N, H, W, 1]
        # features:      [N, H, W, C]

        features = tf.expand_dims(features, axis=0)  # [1, N, H, W, C]

        attention_map_shape = tf.shape(attention_map)
        total_length = attention_map_shape[0]

        # At inference time, process one slice per step
        if not get_training_value(training):
            num_group = total_length

        batch_size = attention_map_shape[1]
        height = attention_map_shape[2]
        width = attention_map_shape[3]
        channels = features.shape[-1]

        height_or_width = width if reduce_axis == 1 else height

        group_size = total_length // num_group
        group_size_remain = total_length % group_size

        # Pad so the total length divides evenly into groups; the padded
        # slices are trimmed off again after the loop
        padding_len = group_size - group_size_remain

        attention_map = tf.pad(
            attention_map, [[0, padding_len], [0, 0], [0, 0], [0, 0], [0, 0]], name="pad_attention_map"
        )

        pad_num_group = tf.shape(attention_map)[0] // group_size

        groups = tf.TensorArray(dtype=features.dtype, size=pad_num_group, clear_after_read=False, name="groups")

        def compute_weighted_map(start_index, sliced_size):

            end_index = start_index + sliced_size

            sub_attention_map = attention_map[start_index:end_index]  # [group_size, N, H, W, 1]
            weighted_map = tf.raw_ops.Mul(
                x=features, y=sub_attention_map, name="weighted_mul"
            )  # [group_size, N, H, W, C]
            weighted_map = tf.reshape(
                weighted_map, [group_size * batch_size, height, width, channels]
            )  # [group_size * N, H, W, C]

            if channel_attention_fn is not None:
                # [group_size * N, H, W, C]
                weighted_map = tf.multiply(
                    weighted_map, channel_attention_fn(weighted_map, training=training), name="channel_mul"
                )

            weighted_map = tf.reduce_sum(
                weighted_map, axis=reduce_axis
            )  # [group_size * N, W, C] or [group_size * N, H, C]
            weighted_map = tf.reshape(
                weighted_map, [sliced_size, batch_size, height_or_width, channels]
            )  # [group_size, N, W, C] or [group_size, N, H, C]

            return weighted_map

        i = tf.constant(0)

        def loop_body(i, _groups):
            start_index = i * group_size
            weighted_map = compute_weighted_map(
                start_index, group_size
            )  # [group_size, N, W, C] or [group_size, N, H, C]

            return tf.add(i, 1), _groups.write(i, weighted_map)

        _, groups = tf.while_loop(lambda i, _: tf.less(i, pad_num_group), loop_body, [i, groups])

        # For Apple CoreML support
        if fallback_concat:
            results = self.fallback_stack_tensor_array(groups)
            results = tf.reshape(results, [pad_num_group * group_size, batch_size, height_or_width, channels])
        else:
            results = groups.concat(name="groups_concat")

        groups.close(name="groups_close")

        results = results[:total_length]

        return results
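
    # Why grouping: materialising the full [H, N, H, W, C] weighted map at
    # once is memory-hungry, so it is built in num_parallel_group_fn chunks
    # during training; at inference num_group equals the axial length, i.e.
    # one slice per loop step for minimal peak memory.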
    def fallback_stack_tensor_array(self, arr: tf.TensorArray):

        # Used when fallback_concat is set (for Apple CoreML support, see
        # group_fn above): stack via gather instead of TensorArray.concat
        arr_size = arr.size()
        results = arr.gather(tf.range(arr_size))

        return results
--------------------------------------------------------------------------------