# ============================================================================
# README.md
# ----------------------------------------------------------------------------
# EAFNet (Efficient Attention-bridged Fusion) — tensorflow 2.0
#
# EAFNet is a semantic segmentation (SS) model designed for multimodal sensor
# fusion.  Polarization information is used as a complementary cue to improve
# segmentation of categories with strong polarization signatures, such as
# glass and cars.  (Paper: "Polarization-driven Semantic Segmentation via
# Efficient Attention-bridged Fusion".)
#
# Main dependencies: tensorflow 2.0, OpenCV, Python 3.6.5
# ============================================================================
# pmodel.py
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.regularizers import l2
import numpy as np
import math


def conv(numout, kernel_size=3, strides=1, kernel_regularizer=0.0005,
         padding='same', use_bias=False, name='conv'):
    """Return a Conv2D layer with L2 weight decay and N(0, 0.1) kernel init.

    Note: ``kernel_regularizer`` is the L2 decay *coefficient* (a float);
    it is wrapped into an ``l2(...)`` regularizer object here.
    """
    return tf.keras.layers.Conv2D(
        name=name, filters=numout, kernel_size=kernel_size, strides=strides,
        padding=padding, use_bias=use_bias,
        kernel_regularizer=l2(kernel_regularizer),
        kernel_initializer=tf.random_normal_initializer(stddev=0.1))


def bn(name, momentum=0.9):
    """Return a BatchNormalization layer with the project-wide momentum."""
    return tf.keras.layers.BatchNormalization(name=name, momentum=momentum)


def ln(name, momentum=0.9):
    """Return a normalization layer.

    NOTE(review): despite the name, this is identical to ``bn`` — it returns
    *Batch*Normalization, presumably a placeholder for LayerNormalization.
    Left unchanged so any external caller keeps its current behavior.
    """
    return tf.keras.layers.BatchNormalization(name=name, momentum=momentum)


def mish(x):
    """Mish activation: x * tanh(softplus(x)).  (Unused in this file.)"""
    return x * tf.math.tanh(tf.math.softplus(x))


class branch1(keras.Model):
    """Main residual branch: conv-BN-ReLU-conv-BN (no final activation)."""

    def __init__(self, scope: str = "branch1", numout: int = 16,
                 strides: int = 1, reg: float = 0.0005):
        super(branch1, self).__init__(name=scope)
        self.conv1 = conv(numout=numout, strides=strides,
                          kernel_regularizer=reg, name='conv1/conv')
        self.bn1 = bn(name='conv1/bn')
        self.conv2 = conv(numout=numout, kernel_regularizer=reg,
                          name='conv2/conv')
        self.bn2 = bn(name='conv2/bn')

    def call(self, x, training=False):
        x = self.conv1(x)
        x = self.bn1(x, training=training)
        x = tf.nn.relu(x)
        x = self.conv2(x)
        x = self.bn2(x, training=training)
        return x


class branch2(keras.Model):
    """Projection shortcut: 1x1 conv + BN to match channels/stride."""

    def __init__(self, scope: str = "branch2", numout: int = 16,
                 strides: int = 1, reg: float = 0.0005):
        super(branch2, self).__init__(name=scope)
        self.conv1 = conv(numout=numout, kernel_size=1, strides=strides,
                          kernel_regularizer=reg)
        self.bn1 = bn(name='bn')

    def call(self, x, training=False):
        x = self.conv1(x)
        x = self.bn1(x, training=training)
        return x


class residual(keras.Model):
    """Basic ResNet block: main branch plus (optionally projected) skip.

    The projection shortcut (``branch2``) is used when the input channel
    count differs from ``numout`` or when ``branch=True`` forces it.
    """

    def __init__(self, scope: str = "block", numout: int = 16,
                 strides: int = 1, reg: float = 0.0005, branch=False):
        super(residual, self).__init__(name=scope)
        self.branch1 = branch1(scope="branch1", numout=numout,
                               strides=strides, reg=reg)
        self.branch2 = branch2(scope='convshortcut', numout=numout,
                               strides=strides, reg=reg)
        self.branch = branch
        self.numout = numout

    def call(self, x, training=False):
        block = self.branch1(x, training=training)
        # x.shape[-1] is the static channel count (TF2 idiom for the
        # original TF1-style get_shape().as_list()[3]; same value).
        if x.shape[-1] != self.numout or self.branch:
            skip = self.branch2(x, training=training)
            return tf.nn.relu(block + skip)
        return tf.nn.relu(x + block)


class attention(keras.Model):
    """ECA-style channel attention: returns a (B,1,1,C) sigmoid gate.

    ``strides`` and ``branch`` are accepted for signature parity with
    ``residual`` but are unused.
    """

    def __init__(self, scope: str = "attention", numout: int = 16,
                 strides: int = 1, reg: float = 0.0005, branch=False):
        super(attention, self).__init__(name=scope)
        # Adaptive kernel size from the channel count (ECA-Net style),
        # forced to the nearest odd integer.  math.log2 replaces np.log2
        # for a plain scalar; identical result.
        t = int(abs(math.log2(numout) + 1) / 2)
        k = t if t % 2 else t + 1
        self.conv1 = conv(numout=1, kernel_size=k, kernel_regularizer=reg,
                          name='atten_conv')

    def call(self, x, training=False):
        # Global average pool -> (B,1,1,C); swap C into a spatial axis so a
        # 1-D conv over channels can be expressed with Conv2D, then swap back.
        x = tf.math.reduce_mean(x, axis=[1, 2], keepdims=True)
        x = tf.transpose(x, [0, 3, 2, 1])
        x = self.conv1(x)
        x = tf.nn.sigmoid(x)
        x = tf.transpose(x, [0, 3, 2, 1])
        return x
class resnet18(keras.Model):
    """Dual-branch ResNet-18 encoder with attention-bridged fusion (EAFNet).

    Runs parallel RGB and AoP (polarization) ResNet-18 branches plus a third
    "merge" branch.  After the stem and after every stage, channel-attention
    gates from both modalities re-weight their features and the weighted sum
    is added into the merge branch.

    ``call(x, aop)`` returns:
        feature     -- list of the 4 fused stage outputs (output strides
                       4, 8, 16, 32 relative to the input).
        showfeature -- per-stage [x, x_atten, x_aop, x_aop_atten] tensors
                       for visualization/debugging.
    """

    def __init__(self, scope: str = "Resnet18", reg: float = 0.0005):
        super(resnet18, self).__init__(name=scope)
        # --- RGB branch stem + stages, with per-stage attention gates ------
        self.conv1 = conv(numout=64, strides=2, kernel_size=7,
                          kernel_regularizer=reg, name='conv0')
        self.bn1 = bn(name='conv0/bn')
        self.atten_rgb_0 = attention(scope='rgb0_atten', numout=64, reg=reg)
        self.atten_Aop_0 = attention(scope='aop0_atten', numout=64, reg=reg)

        self.res0a = residual(scope='group0/block0', numout=64, branch=True, reg=reg)
        self.res0b = residual(scope='group0/block1', numout=64, reg=reg)
        self.atten_rgb_1 = attention(scope='rgb1_atten', numout=64, reg=reg)
        self.atten_Aop_1 = attention(scope='aop1_atten', numout=64, reg=reg)

        self.res1a = residual(scope='group1/block0', numout=128, strides=2, reg=reg)
        self.res1b = residual(scope='group1/block1', numout=128, reg=reg)
        self.atten_rgb_2 = attention(scope='rgb2_atten', numout=128, reg=reg)
        self.atten_Aop_2 = attention(scope='aop2_atten', numout=128, reg=reg)

        self.res2a = residual(scope='group2/block0', numout=256, strides=2, reg=reg)
        self.res2b = residual(scope='group2/block1', numout=256, reg=reg)
        self.atten_rgb_3 = attention(scope='rgb3_atten', numout=256, reg=reg)
        self.atten_Aop_3 = attention(scope='aop3_atten', numout=256, reg=reg)

        self.res3a = residual(scope='group3/block0', numout=512, strides=2, reg=reg)
        self.res3b = residual(scope='group3/block1', numout=512, reg=reg)
        self.atten_rgb_4 = attention(scope='rgb4_atten', numout=512, reg=reg)
        self.atten_Aop_4 = attention(scope='aop4_atten', numout=512, reg=reg)

        # --- AoP (polarization) branch ------------------------------------
        self.conv1Aop = conv(numout=64, strides=2, kernel_size=7,
                             kernel_regularizer=reg, name='conv0_aop')
        self.bn1Aop = bn(name='conv0/bn_aop')
        self.res0aAop = residual(scope='group0/block0_aop', numout=64, branch=True, reg=reg)
        self.res0bAop = residual(scope='group0/block1_aop', numout=64, reg=reg)
        self.res1aAop = residual(scope='group1/block0_aop', numout=128, strides=2, reg=reg)
        self.res1bAop = residual(scope='group1/block1_aop', numout=128, reg=reg)
        self.res2aAop = residual(scope='group2/block0_aop', numout=256, strides=2, reg=reg)
        self.res2bAop = residual(scope='group2/block1_aop', numout=256, reg=reg)
        self.res3aAop = residual(scope='group3/block0_aop', numout=512, strides=2, reg=reg)
        self.res3bAop = residual(scope='group3/block1_aop', numout=512, reg=reg)

        # --- Merge (fusion) branch ----------------------------------------
        self.res0amerge = residual(scope='group0/block0_merge', numout=64, branch=True, reg=reg)
        self.res0bmerge = residual(scope='group0/block1_merge', numout=64, reg=reg)
        self.res1amerge = residual(scope='group1/block0_merge', numout=128, strides=2, reg=reg)
        self.res1bmerge = residual(scope='group1/block1_merge', numout=128, reg=reg)
        self.res2amerge = residual(scope='group2/block0_merge', numout=256, strides=2, reg=reg)
        self.res2bmerge = residual(scope='group2/block1_merge', numout=256, reg=reg)
        self.res3amerge = residual(scope='group3/block0_merge', numout=512, strides=2, reg=reg)
        self.res3bmerge = residual(scope='group3/block1_merge', numout=512, reg=reg)

    def call(self, x, aop, training=False):
        feature = []
        showfeature = []

        # Stem: 7x7/2 conv + BN + ReLU on each modality, with an attention
        # gate computed on each stem output.
        x = tf.nn.relu(self.bn1(self.conv1(x), training=training))
        x_atten = self.atten_rgb_0(x, training=training)
        showfeature.extend([x, x_atten])

        x_aop = tf.nn.relu(self.bn1Aop(self.conv1Aop(aop), training=training))
        x_aop_atten = self.atten_Aop_0(x_aop, training=training)
        showfeature.extend([x_aop, x_aop_atten])

        # Attention-weighted fusion seeds the merge branch.
        m = x_atten * x + x_aop_atten * x_aop

        x = tf.nn.max_pool(x, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
                           padding='SAME', name='pool1')
        x_aop = tf.nn.max_pool(x_aop, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
                               padding='SAME', name='pool1_aop')
        m = tf.nn.max_pool(m, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
                           padding='SAME', name='pool1_fuse')

        # Each stage: advance the RGB, AoP, and merge branches, then add the
        # attention-weighted modality features into the merge branch.  The
        # per-stage op order matches the original unrolled code exactly.
        stages = (
            ((self.res0a, self.res0b), (self.res0aAop, self.res0bAop),
             (self.res0amerge, self.res0bmerge), self.atten_rgb_1, self.atten_Aop_1),
            ((self.res1a, self.res1b), (self.res1aAop, self.res1bAop),
             (self.res1amerge, self.res1bmerge), self.atten_rgb_2, self.atten_Aop_2),
            ((self.res2a, self.res2b), (self.res2aAop, self.res2bAop),
             (self.res2amerge, self.res2bmerge), self.atten_rgb_3, self.atten_Aop_3),
            ((self.res3a, self.res3b), (self.res3aAop, self.res3bAop),
             (self.res3amerge, self.res3bmerge), self.atten_rgb_4, self.atten_Aop_4),
        )
        for (ra, rb), (aa, ab), (ma, mb), att_rgb, att_aop in stages:
            x = rb(ra(x, training=training), training=training)
            x_aop = ab(aa(x_aop, training=training), training=training)
            x_atten = att_rgb(x, training=training)
            x_aop_atten = att_aop(x_aop, training=training)
            m = mb(ma(m, training=training), training=training)
            m = m + x_atten * x + x_aop_atten * x_aop
            feature.append(m)
            showfeature.extend([x, x_atten, x_aop, x_aop_atten])
        return feature, showfeature


class SpatialPyramidPooling(keras.Model):
    """SwiftNet-style spatial pyramid pooling head.

    Average-pools the input to each grid resolution in ``grids``, projects
    each level to ``level_channels`` channels with a 1x1 conv, resizes back,
    concatenates all levels with the (projected) input, and blends with a
    final 1x1 conv to 128 channels.

    ``level_channels`` parameterizes the previously hard-coded per-level
    channel count (43); the default preserves the original behavior.
    """

    def __init__(self, scope: str = "spp", reg: float = 0.0005,
                 grids=(8, 4, 2), level_channels: int = 43):
        super(SpatialPyramidPooling, self).__init__(name=scope)
        self.bn0 = bn(name='bn0')
        self.bn1 = bn(name='blendbn')
        self.conv0 = conv(numout=128, kernel_size=1, kernel_regularizer=reg, name='conv0')
        self.conv1 = conv(numout=128, kernel_size=1, kernel_regularizer=reg, name='blendconv')
        self.bngroup = []
        self.convgroup = []
        self.grids = grids
        self.level = len(grids)
        for i in range(self.level):
            self.bngroup.append(bn(name='bn' + str(i + 1)))
            self.convgroup.append(conv(numout=level_channels, kernel_size=1,
                                       kernel_regularizer=reg,
                                       name='conv1' + str(i + 1)))

    def call(self, x, shape=(768, 768), training=False):
        """Run SPP on ``x``; ``shape`` is the network input (H, W).

        Tuple default replaces the original mutable-list default [768, 768];
        callers passing lists are unaffected.
        """
        levels = []
        # The encoder output is at 1/32 of the input resolution.
        height = math.ceil(shape[0] / 32)
        width = math.ceil(shape[1] / 32)
        x = tf.nn.relu(self.bn0(x, training=training))
        x = self.conv0(x)
        levels.append(x)
        for i in range(self.level):
            # Stride/kernel chosen so the pooled output is grids[i] x grids[i]
            # even when the grid does not divide the feature-map size.
            h = height // self.grids[i]
            w = width // self.grids[i]
            kh = height - (self.grids[i] - 1) * h
            kw = width - (self.grids[i] - 1) * w
            y = tf.nn.avg_pool(x, [1, kh, kw, 1], [1, h, w, 1], padding='VALID')
            y = self.bngroup[i](y, training=training)
            y = tf.nn.relu(y)
            y = self.convgroup[i](y)
            y = tf.image.resize(y, [height, width])
            levels.append(y)
        x = tf.concat(levels, -1)
        x = self.bn1(x, training=training)
        x = tf.nn.relu(x)
        x = self.conv1(x)
        return x


class Upsample(keras.Model):
    """SwiftNet decoder step: project skip, resize x to it, add, blend conv."""

    def __init__(self, scope: str = "up", reg: float = 0.0005):
        super(Upsample, self).__init__(name=scope)
        self.bn0 = bn(name='skipbn')
        self.bn1 = bn(name='blendbn')
        self.conv0 = conv(numout=128, kernel_size=1, kernel_regularizer=reg, name='skipconv')
        self.conv1 = conv(numout=128, kernel_regularizer=reg, name='blendconv')

    def call(self, x, skip, training=False):
        # BN-ReLU-1x1 conv on the skip connection, then upsample x to the
        # skip's spatial size, sum, and blend with a 3x3 conv.
        skip = tf.nn.relu(self.bn0(skip, training=training))
        skip = self.conv0(skip)
        x = tf.image.resize(x, [skip.shape[1], skip.shape[2]])
        x = x + skip
        x = self.bn1(x, training=training)
        x = tf.nn.relu(x)
        x = self.conv1(x)
        return x


class swiftnet(keras.Model):
    """SwiftNet-style decoder over the EAFNet dual-branch encoder.

    call(x, aop) produces per-pixel class probabilities (softmax over the
    last axis) at the input resolution given by ``shape``.
    """

    def __init__(self, scope: str = "swiftnet", reg: float = 0.0005,
                 num_class: int = 19):
        super(swiftnet, self).__init__(name=scope)
        self.bn = bn(name='classbn')
        self.conv = conv(numout=num_class, kernel_regularizer=reg, name='classconv')
        # Encoder gets a smaller weight-decay coefficient (reg/4) —
        # presumably to regularize the backbone less; TODO confirm intent.
        self.basenet = resnet18('Resnet18', reg=reg / 4)
        self.spp = SpatialPyramidPooling('spp', reg=reg, grids=(8, 4, 2))
        self.up1 = Upsample('up1', reg=reg)
        self.up2 = Upsample('up2', reg=reg)
        self.up3 = Upsample('up3', reg=reg)

    def call(self, x, aop, shape=(768, 768), training=False):
        """Segment ``x`` (RGB) fused with ``aop`` (polarization input).

        Tuple default replaces the original mutable-list default [768, 768];
        callers passing lists are unaffected.
        """
        h = shape[0]
        w = shape[1]
        feature, showfeature = self.basenet(x, aop, training=training)
        # Decode from the deepest fused feature, skip-connecting the
        # shallower fused features in reverse order.
        x = self.spp(feature[-1], shape=[h, w], training=training)
        x = self.up1(x, feature[-2], training=training)
        x = self.up2(x, feature[-3], training=training)
        x = self.up3(x, feature[-4], training=training)
        x = self.bn(x, training=training)
        x = tf.nn.relu(x)
        x = self.conv(x)
        x = tf.image.resize(x, [h, w])
        x = tf.nn.softmax(x)
        return x