"""Learning Statistical Texture for Semantic Segmentation (CVPR 2021).

Official PyTorch implementation of "Learning Statistical Texture for Semantic
Segmentation".
Paper: https://openaccess.thecvf.com/content/CVPR2021/papers/Zhu_Learning_Statistical_Texture_for_Semantic_Segmentation_CVPR_2021_paper.pdf

Repository layout: ``.DS_Store``, ``README.md``, ``STLNet.py``.

Only the model code has been released so far; the training code will be
released later.

Citation
--------
If you find the code helpful in your research or work, please cite the
following paper::

    @inproceedings{zhu2021learning,
      title={Learning Statistical Texture for Semantic Segmentation},
      author={Zhu, Lanyun and Ji, Deyi and Zhu, Shiping and Gan, Weihao and Wu, Wei and Yan, Junjie},
      booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
      pages={12537--12546},
      year={2021}
    }
"""

import math

try:
    # Unused by this module; kept only so the module namespace stays the same.
    # Guarded because ``numpy.core`` is a private, deprecated path and must not
    # make the model unimportable.
    from numpy.core.fromnumeric import size  # noqa: F401
except ImportError:
    pass
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.modules import module  # noqa: F401  (unused, kept for compatibility)


def _quantization_levels(level_num, low, high):
    """Return ``level_num`` bin centres evenly spread over ``[low, high]``.

    ``low`` and ``high`` must have a trailing singleton dimension; the result
    has that dimension expanded to ``level_num``.  The levels are created on
    the same device as ``low`` so the model runs on CPU or on any GPU.
    """
    idx = torch.arange(level_num, device=low.device).float()
    return (2 * idx + 1) / (2 * level_num) * (high - low) + low


def _soft_quantize(values, levels, interval):
    """Soft-assign ``values`` to ``levels``.

    The weight is ``1 - |level - value|`` and is kept only where it exceeds
    ``1 - interval`` (i.e. the value lies within one interval of the level);
    everything else is zeroed.
    """
    quant = 1 - torch.abs(levels - values)
    return quant * (quant > (1 - interval))


class ConvBNReLU(nn.Module):
    """Convolution optionally followed by batch normalisation and ReLU.

    Args:
        c_in, c_out: input / output channel counts.
        kernel_size, stride, padding, dilation: forwarded to the convolution.
        group: number of convolution groups.
        has_bn: append a batch-norm layer when true.
        has_relu: append an in-place ReLU when true.
        mode: ``'2d'`` (Conv2d/BatchNorm2d) or ``'1d'`` (Conv1d/BatchNorm1d).

    Raises:
        ValueError: if ``mode`` is neither ``'1d'`` nor ``'2d'``.
    """

    def __init__(self, c_in, c_out, kernel_size, stride, padding, dilation=1, group=1,
                 has_bn=True, has_relu=True, mode='2d'):
        super(ConvBNReLU, self).__init__()
        if mode == '2d':
            conv_layer, norm_layer = nn.Conv2d, nn.BatchNorm2d
        elif mode == '1d':
            conv_layer, norm_layer = nn.Conv1d, nn.BatchNorm1d
        else:
            # Previously an unknown mode left ``self.conv`` undefined and
            # failed later with an unrelated error.
            raise ValueError("mode must be '1d' or '2d', got {!r}".format(mode))
        self.has_bn = has_bn
        self.has_relu = has_relu
        self.conv = conv_layer(
            c_in, c_out, kernel_size=kernel_size, stride=stride,
            padding=padding, dilation=dilation, bias=False, groups=group)
        if self.has_bn:
            self.bn = norm_layer(c_out)
        if self.has_relu:
            self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        if self.has_bn:
            x = self.bn(x)
        if self.has_relu:
            x = self.relu(x)
        return x


class QCO_1d(nn.Module):
    """1-d quantization and counting operator.

    Quantizes the cosine similarity between every pixel feature and the global
    average feature into ``level_num`` levels and encodes the resulting
    histogram.

    Args:
        level_num: number of quantization levels (at least 2, because the
            level spacing is read from the first two levels).

    Forward input: ``(N, 256, H, W)``.
    Forward output: ``(sta, quant)`` with ``sta`` of shape
    ``(N, 128, level_num)`` and ``quant`` of shape ``(N, H*W, level_num)``.
    """

    def __init__(self, level_num):
        super(QCO_1d, self).__init__()
        if level_num < 2:
            raise ValueError("level_num must be >= 2, got {}".format(level_num))
        self.conv1 = nn.Sequential(ConvBNReLU(256, 256, 3, 1, 1, has_relu=False), nn.LeakyReLU(inplace=True))
        self.conv2 = ConvBNReLU(256, 128, 1, 1, 0, has_bn=False, has_relu=False)
        self.f1 = nn.Sequential(ConvBNReLU(2, 64, 1, 1, 0, has_bn=False, has_relu=False, mode='1d'), nn.LeakyReLU(inplace=True))
        self.f2 = ConvBNReLU(64, 128, 1, 1, 0, has_bn=False, mode='1d')
        self.out = ConvBNReLU(256, 128, 1, 1, 0, has_bn=True, mode='1d')
        self.level_num = level_num

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        N, C, H, W = x.shape
        x_ave = F.adaptive_avg_pool2d(x, (1, 1))
        # Cosine similarity of every pixel with the global mean: (N, H*W).
        cos_sim = (F.normalize(x_ave, dim=1) * F.normalize(x, dim=1)).sum(1)
        cos_sim = cos_sim.view(N, -1)
        cos_sim_min = cos_sim.min(-1)[0].unsqueeze(-1)
        cos_sim_max = cos_sim.max(-1)[0].unsqueeze(-1)
        # (N, 1, level_num)
        q_levels = _quantization_levels(self.level_num, cos_sim_min, cos_sim_max).unsqueeze(1)
        q_levels_inter = (q_levels[:, :, 1] - q_levels[:, :, 0]).unsqueeze(-1)
        # (N, H*W, level_num)
        quant = _soft_quantize(cos_sim.unsqueeze(-1), q_levels, q_levels_inter)
        sta = quant.sum(1)
        # Clamp instead of dividing by a raw sum: a constant similarity map
        # yields an all-zero histogram, which used to produce NaN.
        sta = sta / sta.sum(-1).unsqueeze(-1).clamp_min(1e-6)
        sta = sta.unsqueeze(1)
        sta = torch.cat([q_levels, sta], dim=1)  # (N, 2, level_num)
        sta = self.f1(sta)
        sta = self.f2(sta)
        x_ave = x_ave.squeeze(-1).squeeze(-1)
        x_ave = x_ave.expand(self.level_num, N, C).permute(1, 2, 0)
        sta = torch.cat([sta, x_ave], dim=1)  # (N, 256, level_num)
        sta = self.out(sta)
        return sta, quant


class QCO_2d(nn.Module):
    """2-d quantization and counting operator (co-occurrence statistics).

    The map is split into a ``scale x scale`` grid; inside every cell the
    cosine similarity to the cell mean is quantized and the co-occurrence of
    the levels of diagonally adjacent pixels is counted and encoded.

    Args:
        scale: grid size per spatial axis (>= 1).
        level_num: number of quantization levels (>= 2).

    Forward input: ``(N, 256, H, W)``.  Forward output: ``(N, 128, scale, scale)``.
    """

    def __init__(self, scale, level_num):
        super(QCO_2d, self).__init__()
        if scale < 1:
            raise ValueError("scale must be >= 1, got {}".format(scale))
        if level_num < 2:
            raise ValueError("level_num must be >= 2, got {}".format(level_num))
        self.f1 = nn.Sequential(ConvBNReLU(3, 64, 1, 1, 0, has_bn=False, has_relu=False, mode='2d'), nn.LeakyReLU(inplace=True))
        self.f2 = ConvBNReLU(64, 128, 1, 1, 0, has_bn=False, mode='2d')
        self.out = nn.Sequential(ConvBNReLU(256+128, 128, 1, 1, 0, has_bn=True, has_relu=True, mode='2d'), ConvBNReLU(128, 128, 1, 1, 0, has_bn=True, has_relu=False, mode='2d'))
        self.scale = scale
        self.level_num = level_num

    def _fit_length(self, length):
        """Largest usable extent <= ``length`` that the grid reshape accepts.

        The original code rounded down to a multiple of ``level_num`` and then
        required divisibility by ``scale``; rounding to their least common
        multiple gives the same size whenever the original worked and a valid
        size when it did not.  Maps smaller than that fall back to a multiple
        of ``scale`` (at least one pixel per cell).
        """
        step = self.level_num * self.scale // math.gcd(self.level_num, self.scale)
        fitted = length // step * step
        if fitted == 0:
            fitted = max(length // self.scale, 1) * self.scale
        return fitted

    def forward(self, x):
        scale, level_num = self.scale, self.level_num
        H1, W1 = x.shape[2:]
        target = (self._fit_length(H1), self._fit_length(W1))
        if target != (H1, W1):
            x = F.adaptive_avg_pool2d(x, target)
        N, C, H, W = x.shape
        # Cell extent; kept local so forward() does not mutate the module.
        size_h, size_w = H // scale, W // scale

        x_ave = F.adaptive_avg_pool2d(x, (scale, scale))
        x_ave_up = F.adaptive_avg_pool2d(x_ave, (H, W))
        cos_sim = (F.normalize(x_ave_up, dim=1) * F.normalize(x, dim=1)).sum(1)
        # (N, H, W) -> (N, size_h*size_w, scale*scale): pixels grouped by cell.
        cos_sim = cos_sim.reshape(N, scale, size_h, scale, size_w)
        cos_sim = cos_sim.permute(0, 1, 3, 2, 4).reshape(N, scale * scale, size_h * size_w)
        cos_sim = cos_sim.permute(0, 2, 1)
        cos_sim_min = cos_sim.min(1)[0].unsqueeze(-1)
        cos_sim_max = cos_sim.max(1)[0].unsqueeze(-1)
        # (N, scale*scale, level_num)
        q_levels = _quantization_levels(level_num, cos_sim_min, cos_sim_max)
        q_levels_inter = (q_levels[:, :, 1] - q_levels[:, :, 0]).unsqueeze(1).unsqueeze(-1)
        q_levels = q_levels.unsqueeze(1)
        # (N, size_h*size_w, scale*scale, level_num)
        quant = _soft_quantize(cos_sim.unsqueeze(-1), q_levels, q_levels_inter)
        quant = quant.view(N, size_h, size_w, scale * scale, level_num)
        quant = quant.permute(0, 3, 4, 1, 2)
        quant = quant.reshape(N, -1, size_h, size_w)
        # Zero-pad bottom/right so every pixel has a lower-right neighbour.
        quant = F.pad(quant, (0, 1, 0, 1), mode='constant', value=0.)
        quant = quant.view(N, scale * scale, level_num, size_h + 1, size_w + 1)
        quant_left = quant[:, :, :, :size_h, :size_w].unsqueeze(3)
        quant_right = quant[:, :, :, 1:, 1:].unsqueeze(2)
        # Co-occurrence counts: (N, scale*scale, level_num, level_num).
        sta = (quant_left * quant_right).sum(-1).sum(-1)
        sta = sta / (sta.sum(-1).sum(-1).unsqueeze(-1).unsqueeze(-1) + 1e-6)
        sta = sta.unsqueeze(1)
        q_levels = q_levels.expand(level_num, N, 1, scale * scale, level_num)
        q_levels_h = q_levels.permute(1, 2, 3, 0, 4)
        q_levels_w = q_levels_h.permute(0, 1, 2, 4, 3)
        sta = torch.cat([q_levels_h, q_levels_w, sta], dim=1)
        sta = sta.view(N, 3, scale * scale, -1)
        sta = self.f1(sta)
        sta = self.f2(sta)
        x_ave = x_ave.view(N, C, -1)
        x_ave = x_ave.expand(level_num * level_num, N, C, scale * scale)
        x_ave = x_ave.permute(1, 2, 3, 0)
        sta = torch.cat([x_ave, sta], dim=1)
        sta = self.out(sta)
        sta = sta.mean(-1)
        return sta.view(N, sta.shape[1], scale, scale)


class TEM(nn.Module):
    """Texture enhancement module.

    Runs :class:`QCO_1d`, applies attention across the quantization levels and
    projects the per-level features back to the pixels through the soft
    assignment ``quant``.

    Forward input: ``(N, 256, H, W)``.  Forward output: ``(N, 256, H, W)``.
    """

    def __init__(self, level_num):
        super(TEM, self).__init__()
        self.level_num = level_num
        self.qco = QCO_1d(level_num)
        self.k = ConvBNReLU(128, 128, 1, 1, 0, has_bn=False, has_relu=False, mode='1d')
        self.q = ConvBNReLU(128, 128, 1, 1, 0, has_bn=False, has_relu=False, mode='1d')
        self.v = ConvBNReLU(128, 128, 1, 1, 0, has_bn=False, has_relu=False, mode='1d')
        self.out = ConvBNReLU(128, 256, 1, 1, 0, mode='1d')

    def forward(self, x):
        N, C, H, W = x.shape
        sta, quant = self.qco(x)
        k = self.k(sta).permute(0, 2, 1)
        q = self.q(sta)
        v = self.v(sta).permute(0, 2, 1)
        # Attention over levels: (N, level_num, level_num).
        w = F.softmax(torch.bmm(k, q), dim=-1)
        f = torch.bmm(w, v).permute(0, 2, 1)
        f = self.out(f)
        # Scatter level features to pixels: (N, 256, level_num) x (N, level_num, H*W).
        out = torch.bmm(f, quant.permute(0, 2, 1))
        return out.view(N, -1, H, W)


class PTFEM(nn.Module):
    """Pyramid texture feature extraction module.

    Applies :class:`QCO_2d` at grid scales 1, 2, 3 and 6 (8 levels each),
    upsamples every result to the input resolution and fuses them.

    Forward input: ``(N, 512, H, W)``.  Forward output: ``(N, 256, H, W)``.
    """

    def __init__(self):
        super(PTFEM, self).__init__()
        self.conv = ConvBNReLU(512, 256, 1, 1, 0, has_bn=False, has_relu=False)
        self.qco_1 = QCO_2d(1, 8)
        self.qco_2 = QCO_2d(2, 8)
        self.qco_3 = QCO_2d(3, 8)
        self.qco_6 = QCO_2d(6, 8)
        self.out = ConvBNReLU(512, 256, 1, 1, 0)

    def forward(self, x):
        H, W = x.shape[2:]
        x = self.conv(x)
        pyramid = [
            F.interpolate(qco(x), size=(H, W), mode='bilinear', align_corners=True)
            for qco in (self.qco_1, self.qco_2, self.qco_3, self.qco_6)
        ]
        return self.out(torch.cat(pyramid, dim=1))


class STL(nn.Module):
    """Statistical texture learning head (TEM followed by PTFEM).

    Args:
        in_channel: channel count of the input feature map.

    Forward input: ``(N, in_channel, H, W)``.
    Forward output: ``(N, 768, H, W)`` — 256 PTFEM channels, 256 TEM channels
    and the 256 projected input channels.
    """

    def __init__(self, in_channel):
        super().__init__()
        self.conv_start = ConvBNReLU(in_channel, 256, 1, 1, 0)
        self.tem = TEM(128)
        self.ptfem = PTFEM()

    def forward(self, x):
        x = self.conv_start(x)
        x_tem = self.tem(x)
        x = torch.cat([x_tem, x], dim=1)  # 256 + 256 = 512 channels
        x_ptfem = self.ptfem(x)  # 256 channels
        x = torch.cat([x_ptfem, x], dim=1)  # 256 + 512 = 768 channels
        return x