├── .gitattributes ├── DSC.py ├── DSC_sr.py ├── IRNN_Backward_cuda.cu ├── IRNN_Forward_cuda.cu ├── LICENSE ├── README.md ├── SBU_model └── README.md ├── backbone └── resnext │ ├── __init__.py │ ├── resnext101_regular.py │ └── resnext_101_32x4d_.py ├── dataset.py ├── dataset_sr.py ├── irnn.py ├── main.py ├── main_sr.py ├── misc.py ├── randomcrop.py ├── tensorboard.sh └── utils.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /DSC.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from collections import OrderedDict 5 | 6 | from irnn import irnn 7 | from backbone.resnext.resnext101_regular import ResNeXt101 8 | 9 | def conv1x1(in_channels, out_channels, stride=1): 10 | return nn.Conv2d(in_channels,out_channels,kernel_size=1, 11 | stride=stride, padding=0,bias=False) 12 | 13 | def conv3x3(in_channels, out_channels, stride=1): 14 | return nn.Conv2d(in_channels,out_channels,kernel_size=3, 15 | stride=stride, padding=1,bias=False) 16 | 17 | class Spatial_IRNN(nn.Module): 18 | def __init__(self,in_channels,alpha=1.0): 19 | super(Spatial_IRNN,self).__init__() 20 | self.left_weight = nn.Conv2d(in_channels,in_channels,kernel_size=1,stride=1,groups=in_channels,padding=0) 21 | self.right_weight = nn.Conv2d(in_channels,in_channels,kernel_size=1,stride=1,groups=in_channels,padding=0) 22 | self.up_weight = nn.Conv2d(in_channels,in_channels,kernel_size=1,stride=1,groups=in_channels,padding=0) 23 | self.down_weight = nn.Conv2d(in_channels,in_channels,kernel_size=1,stride=1,groups=in_channels,padding=0) 24 | self.left_weight.weight = nn.Parameter(torch.tensor([[[[alpha]]]]*in_channels)) 25 | self.right_weight.weight = nn.Parameter(torch.tensor([[[[alpha]]]]*in_channels)) 26 | self.up_weight.weight = nn.Parameter(torch.tensor([[[[alpha]]]]*in_channels)) 27 | self.down_weight.weight = nn.Parameter(torch.tensor([[[[alpha]]]]*in_channels)) 28 | 29 | def forward(self,input): 30 | return irnn.apply(input,self.up_weight.weight,self.right_weight.weight,self.down_weight.weight,self.left_weight.weight, self.up_weight.bias,self.right_weight.bias,self.down_weight.bias,self.left_weight.bias) 31 | 32 | class Attention(nn.Module): 33 | def __init__(self,in_channels): 34 | super(Attention,self).__init__() 35 | self.out_channels = int(in_channels/2) 36 | self.conv1 = nn.Conv2d(in_channels,self.out_channels,kernel_size=3,padding=1,stride=1) 37 | self.relu1 = nn.ReLU() 38 | self.conv2 = nn.Conv2d(self.out_channels,self.out_channels,kernel_size=3,padding=1,stride=1) 39 | self.relu2 = nn.ReLU() 40 | self.conv3 = nn.Conv2d(self.out_channels,4,kernel_size=1,padding=0,stride=1) 41 | self.sigmoid = nn.Sigmoid() 42 | 43 | def forward(self,x): 44 | out = self.conv1(x) 45 | out = self.relu1(out) 46 | out = self.conv2(out) 47 | out = self.relu2(out) 48 | out = self.conv3(out) 49 | out = self.sigmoid(out) 50 | return out 51 | 52 | class DSC_Module(nn.Module): 53 | def __init__(self,in_channels,out_channels,attention=1,alpha=1.0): 54 | super(DSC_Module,self).__init__() 55 | self.out_channels = out_channels 56 | self.irnn1 = Spatial_IRNN(self.out_channels,alpha) 57 | self.irnn2 = Spatial_IRNN(self.out_channels,alpha) 58 | self.conv_in = conv1x1(in_channels,in_channels) 59 | self.conv2 = 
conv1x1(in_channels*4,in_channels) 60 | self.conv3 = conv1x1(in_channels*4,in_channels) 61 | self.relu2 = nn.ReLU(True) 62 | self.attention = attention 63 | if self.attention: 64 | self.attention_layer = Attention(in_channels) 65 | 66 | 67 | 68 | def forward(self,x): 69 | if self.attention: 70 | weight = self.attention_layer(x) 71 | out = self.conv_in(x) 72 | top_up,top_right,top_down,top_left = self.irnn1(out) 73 | 74 | # direction attention (assignment needed: torch.mul is out-of-place) 75 | if self.attention: 76 | top_up = top_up.mul(weight[:,0:1,:,:]) 77 | top_right = top_right.mul(weight[:,1:2,:,:]) 78 | top_down = top_down.mul(weight[:,2:3,:,:]) 79 | top_left = top_left.mul(weight[:,3:4,:,:]) 80 | out = torch.cat([top_up,top_right,top_down,top_left],dim=1) 81 | out = self.conv2(out) 82 | top_up,top_right,top_down,top_left = self.irnn2(out) 83 | 84 | # direction attention (assignment needed: torch.mul is out-of-place) 85 | if self.attention: 86 | top_up = top_up.mul(weight[:,0:1,:,:]) 87 | top_right = top_right.mul(weight[:,1:2,:,:]) 88 | top_down = top_down.mul(weight[:,2:3,:,:]) 89 | top_left = top_left.mul(weight[:,3:4,:,:]) 90 | 91 | out = torch.cat([top_up,top_right,top_down,top_left],dim=1) 92 | out = self.conv3(out) 93 | out = self.relu2(out) 94 | 95 | return out 96 | 97 | class LayerConv(nn.Module): 98 | def __init__(self, in_planes, out_planes, kernel_size, stride, padding, relu): 99 | super(LayerConv, self).__init__() 100 | self.conv = nn.Conv2d(in_channels=in_planes, out_channels=out_planes, kernel_size=kernel_size, 101 | stride=stride, padding=padding) 102 | self.relu = nn.ReLU() if relu else None 103 | 104 | def forward(self, x): 105 | x = self.conv(x) 106 | if self.relu is not None: 107 | x = self.relu(x) 108 | 109 | return x 110 | 111 | 112 | 113 | 114 | class Predict(nn.Module): 115 | def __init__(self, in_planes=32, out_planes=1, kernel_size=1): 116 | super(Predict, self).__init__() 117 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size) 118 | 119 | def forward(self, x): 120 | y = self.conv(x) 121 | 122 | return y 123 | 124 | class DSC(nn.Module): 125 | def __init__(self): 126 | super(DSC,self).__init__() 127 | 128 | resnext = ResNeXt101() 129 | self.layer0 = resnext.layer0 130 | self.layer1 = resnext.layer1 131 | self.layer2 = resnext.layer2 132 | self.layer3 = resnext.layer3 133 | self.layer4 = resnext.layer4 134 | 135 | 136 | self.layer4_conv1 = LayerConv(2048, 512, 7, 1, 3, True) 137 | self.layer4_conv2 = LayerConv(512, 512, 7, 1, 3, True) 138 | self.layer4_dsc = DSC_Module(512, 512) 139 | self.layer4_conv3 = LayerConv(1024, 32, 1, 1, 0, False) 140 | 141 | self.layer3_conv1 = LayerConv(1024, 256, 5, 1, 2, True) 142 | self.layer3_conv2 = LayerConv(256, 256, 5, 1, 2, True) 143 | self.layer3_dsc = DSC_Module(256, 256) 144 | self.layer3_conv3 = LayerConv(512, 32, 1, 1, 0, False) 145 | 146 | self.layer2_conv1 = LayerConv(512, 128, 5, 1, 2, True) 147 | self.layer2_conv2 = LayerConv(128, 128, 5, 1, 2, True) 148 | self.layer2_dsc = DSC_Module(128, 128) 149 | self.layer2_conv3 = LayerConv(256, 32, 1, 1, 0, False) 150 | 151 | self.layer1_conv1 = LayerConv(256, 64, 3, 1, 1, True) 152 | self.layer1_conv2 = LayerConv(64, 64, 3, 1, 1, True) 153 | self.layer1_dsc = DSC_Module(64, 64,alpha=0.8) 154 | self.layer1_conv3 = LayerConv(128, 32, 1, 1, 0, False) 155 | 156 | self.layer0_conv1 = LayerConv(64, 64, 3, 1, 1, True) 157 | self.layer0_conv2 = LayerConv(64, 64, 3, 1, 1, True) 158 | self.layer0_dsc = DSC_Module(64, 64,alpha=0.8) 159 | self.layer0_conv3 = LayerConv(128, 32, 1, 1, 0, False) 160 | 161 | self.relu = nn.ReLU() 162 | 163 | self.global_conv = LayerConv(160, 32, 1, 1, 0, True) 164 | 165 | self.layer4_predict = Predict(32, 1, 1) 166 | 
self.layer3_predict_ori = Predict(32, 1, 1) 167 | self.layer3_predict = Predict(2, 1, 1) 168 | self.layer2_predict_ori = Predict(32, 1, 1) 169 | self.layer2_predict = Predict(3, 1, 1) 170 | self.layer1_predict_ori = Predict(32, 1, 1) 171 | self.layer1_predict = Predict(4, 1, 1) 172 | self.layer0_predict_ori = Predict(32, 1, 1) 173 | self.layer0_predict = Predict(5, 1, 1) 174 | self.global_predict = Predict(32, 1, 1) 175 | self.fusion_predict = Predict(6, 1, 1) 176 | 177 | 178 | def forward(self, x): 179 | layer0 = self.layer0(x) 180 | layer1 = self.layer1(layer0) 181 | layer2 = self.layer2(layer1) 182 | layer3 = self.layer3(layer2) 183 | layer4 = self.layer4(layer3) 184 | 185 | layer4_conv1 = self.layer4_conv1(layer4) 186 | layer4_conv2 = self.layer4_conv2(layer4_conv1) 187 | layer4_dsc = self.layer4_dsc(layer4_conv2) 188 | layer4_context = torch.cat((layer4_conv2, layer4_dsc), 1) 189 | layer4_conv3 = self.layer4_conv3(layer4_context) 190 | layer4_up = F.interpolate(layer4_conv3, size=x.size()[2:], mode='bilinear', align_corners=True) 191 | layer4_up = self.relu(layer4_up) 192 | 193 | layer3_conv1 = self.layer3_conv1(layer3) 194 | layer3_conv2 = self.layer3_conv2(layer3_conv1) 195 | layer3_dsc = self.layer3_dsc(layer3_conv2) 196 | layer3_context = torch.cat((layer3_conv2, layer3_dsc), 1) 197 | layer3_conv3 = self.layer3_conv3(layer3_context) 198 | layer3_up = F.interpolate(layer3_conv3, size=x.size()[2:], mode='bilinear', align_corners=True) 199 | layer3_up = self.relu(layer3_up) 200 | 201 | layer2_conv1 = self.layer2_conv1(layer2) 202 | layer2_conv2 = self.layer2_conv2(layer2_conv1) 203 | layer2_dsc = self.layer2_dsc(layer2_conv2) 204 | layer2_context = torch.cat((layer2_conv2, layer2_dsc), 1) 205 | layer2_conv3 = self.layer2_conv3(layer2_context) 206 | layer2_up = F.interpolate(layer2_conv3, size=x.size()[2:], mode='bilinear', align_corners=True) 207 | layer2_up = self.relu(layer2_up) 208 | 209 | layer1_conv1 = self.layer1_conv1(layer1) 210 | layer1_conv2 = self.layer1_conv2(layer1_conv1) 211 | layer1_dsc = self.layer1_dsc(layer1_conv2) 212 | layer1_context = torch.cat((layer1_conv2, layer1_dsc), 1) 213 | layer1_conv3 = self.layer1_conv3(layer1_context) 214 | layer1_up = F.interpolate(layer1_conv3, size=x.size()[2:], mode='bilinear', align_corners=True) 215 | layer1_up = self.relu(layer1_up) 216 | 217 | layer0_conv1 = self.layer0_conv1(layer0) 218 | layer0_conv2 = self.layer0_conv2(layer0_conv1) 219 | layer0_dsc = self.layer0_dsc(layer0_conv2) 220 | layer0_context = torch.cat((layer0_conv2, layer0_dsc), 1) 221 | layer0_conv3 = self.layer0_conv3(layer0_context) 222 | layer0_up = F.interpolate(layer0_conv3, size=x.size()[2:], mode='bilinear', align_corners=True) 223 | layer0_up = self.relu(layer0_up) 224 | 225 | global_concat = torch.cat((layer0_up, layer1_up, layer2_up, layer3_up, layer4_up), 1) 226 | global_conv = self.global_conv(global_concat) 227 | 228 | layer4_predict = self.layer4_predict(layer4_up) 229 | 230 | layer3_predict_ori = self.layer3_predict_ori(layer3_up) 231 | layer3_concat = torch.cat((layer3_predict_ori, layer4_predict), 1) 232 | layer3_predict = self.layer3_predict(layer3_concat) 233 | 234 | layer2_predict_ori = self.layer2_predict_ori(layer2_up) 235 | layer2_concat = torch.cat((layer2_predict_ori, layer3_predict_ori, layer4_predict), 1) 236 | layer2_predict = self.layer2_predict(layer2_concat) 237 | 238 | layer1_predict_ori = self.layer1_predict_ori(layer1_up) 239 | layer1_concat = torch.cat((layer1_predict_ori, layer2_predict_ori, layer3_predict_ori, layer4_predict), 1) 240 | 
layer1_predict = self.layer1_predict(layer1_concat) 241 | 242 | layer0_predict_ori = self.layer0_predict_ori(layer0_up) 243 | layer0_concat = torch.cat((layer0_predict_ori, layer1_predict_ori, layer2_predict_ori, 244 | layer3_predict_ori, layer4_predict), 1) 245 | layer0_predict = self.layer0_predict(layer0_concat) 246 | 247 | global_predict = self.global_predict(global_conv) 248 | 249 | # fusion 250 | fusion_concat = torch.cat((layer0_predict, layer1_predict, layer2_predict, layer3_predict, 251 | layer4_predict, global_predict), 1) 252 | fusion_predict = self.fusion_predict(fusion_concat) 253 | 254 | 255 | return layer4_predict, layer3_predict, layer2_predict, layer1_predict, layer0_predict, \ 256 | global_predict, fusion_predict 257 | 258 | -------------------------------------------------------------------------------- /DSC_sr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from collections import OrderedDict 5 | 6 | from irnn import irnn 7 | from backbone.resnext.resnext101_regular import ResNeXt101 8 | 9 | def conv1x1(in_channels, out_channels, stride=1): 10 | return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=False) 11 | 12 | def conv3x3(in_channels, out_channels, stride=1): 13 | return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False) 14 | 15 | class Spatial_IRNN(nn.Module): 16 | def __init__(self, in_channels, alpha=1.0): 17 | super(Spatial_IRNN, self).__init__() 18 | self.left_weight = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, groups=in_channels, padding=0) 19 | self.right_weight = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, groups=in_channels, padding=0) 20 | self.up_weight = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, groups=in_channels, padding=0) 21 | self.down_weight = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, groups=in_channels, padding=0) 22 | self.left_weight.weight = nn.Parameter(torch.tensor([[[[alpha]]]]*in_channels)) 23 | self.right_weight.weight = nn.Parameter(torch.tensor([[[[alpha]]]]*in_channels)) 24 | self.up_weight.weight = nn.Parameter(torch.tensor([[[[alpha]]]]*in_channels)) 25 | self.down_weight.weight = nn.Parameter(torch.tensor([[[[alpha]]]]*in_channels)) 26 | 27 | def forward(self, input): 28 | return irnn.apply(input, self.up_weight.weight, self.right_weight.weight, self.down_weight.weight, self.left_weight.weight, self.up_weight.bias, self.right_weight.bias, self.down_weight.bias, self.left_weight.bias) 29 | 30 | class Attention(nn.Module): 31 | def __init__(self, in_channels): 32 | super(Attention, self).__init__() 33 | self.out_channels = int(in_channels / 2) 34 | self.conv1 = nn.Conv2d(in_channels, self.out_channels, kernel_size=3, padding=1, stride=1) 35 | self.relu1 = nn.ReLU() 36 | self.conv2 = nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, padding=1, stride=1) 37 | self.relu2 = nn.ReLU() 38 | self.conv3 = nn.Conv2d(self.out_channels, 4, kernel_size=1, padding=0, stride=1) 39 | self.sigmoid = nn.Sigmoid() 40 | 41 | def forward(self, x): 42 | out = self.conv1(x) 43 | out = self.relu1(out) 44 | out = self.conv2(out) 45 | out = self.relu2(out) 46 | out = self.conv3(out) 47 | out = self.sigmoid(out) 48 | return out 49 | 50 | class DSC_Module(nn.Module): 51 | def __init__(self, in_channels, out_channels, attention=1, alpha=1.0): 52 | super(DSC_Module, self).__init__() 53 | 
self.out_channels = out_channels 54 | self.irnn1 = Spatial_IRNN(self.out_channels, alpha) 55 | self.irnn2 = Spatial_IRNN(self.out_channels, alpha) 56 | self.conv_in = conv1x1(in_channels, in_channels) 57 | self.conv2 = conv1x1(in_channels * 4, in_channels) 58 | self.conv3 = conv1x1(in_channels * 4, in_channels) 59 | self.relu2 = nn.ReLU(True) 60 | self.attention = attention 61 | if self.attention: 62 | self.attention_layer = Attention(in_channels) 63 | 64 | def forward(self, x): 65 | if self.attention: 66 | weight = self.attention_layer(x) 67 | out = self.conv_in(x) 68 | top_up, top_right, top_down, top_left = self.irnn1(out) 69 | 70 | # direction attention (assignment needed: torch.mul is out-of-place) 71 | if self.attention: 72 | top_up = top_up.mul(weight[:, 0:1, :, :]) 73 | top_right = top_right.mul(weight[:, 1:2, :, :]) 74 | top_down = top_down.mul(weight[:, 2:3, :, :]) 75 | top_left = top_left.mul(weight[:, 3:4, :, :]) 76 | out = torch.cat([top_up, top_right, top_down, top_left], dim=1) 77 | out = self.conv2(out) 78 | top_up, top_right, top_down, top_left = self.irnn2(out) 79 | 80 | # direction attention (assignment needed: torch.mul is out-of-place) 81 | if self.attention: 82 | top_up = top_up.mul(weight[:, 0:1, :, :]) 83 | top_right = top_right.mul(weight[:, 1:2, :, :]) 84 | top_down = top_down.mul(weight[:, 2:3, :, :]) 85 | top_left = top_left.mul(weight[:, 3:4, :, :]) 86 | 87 | out = torch.cat([top_up, top_right, top_down, top_left], dim=1) 88 | out = self.conv3(out) 89 | out = self.relu2(out) 90 | 91 | return out 92 | 93 | class LayerConv(nn.Module): 94 | def __init__(self, in_planes, out_planes, kernel_size, stride, padding, relu): 95 | super(LayerConv, self).__init__() 96 | self.conv = nn.Conv2d(in_channels=in_planes, out_channels=out_planes, kernel_size=kernel_size, 97 | stride=stride, padding=padding) 98 | self.relu = nn.ReLU() if relu else None 99 | 100 | def forward(self, x): 101 | x = self.conv(x) 102 | if self.relu is not None: 103 | x = self.relu(x) 104 | 105 | return x 106 | 107 | class Predict(nn.Module): 108 | def __init__(self, in_planes=32, out_planes=1, kernel_size=1): 109 | super(Predict, self).__init__() 110 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size) 111 | 112 | def forward(self, x): 113 | y = self.conv(x) 114 | 115 | return y 116 | 117 | class DSC(nn.Module): 118 | def __init__(self): 119 | super(DSC, self).__init__() 120 | 121 | resnext = ResNeXt101() 122 | self.layer0 = resnext.layer0 123 | self.layer1 = resnext.layer1 124 | self.layer2 = resnext.layer2 125 | self.layer3 = resnext.layer3 126 | self.layer4 = resnext.layer4 127 | 128 | self.layer4_conv1 = LayerConv(2048, 512, 7, 1, 3, True) 129 | self.layer4_conv2 = LayerConv(512, 512, 7, 1, 3, True) 130 | self.layer4_dsc = DSC_Module(512, 512) 131 | self.layer4_conv3 = LayerConv(1024, 32, 1, 1, 0, False) 132 | 133 | self.layer3_conv1 = LayerConv(1024, 256, 5, 1, 2, True) 134 | self.layer3_conv2 = LayerConv(256, 256, 5, 1, 2, True) 135 | self.layer3_dsc = DSC_Module(256, 256) 136 | self.layer3_conv3 = LayerConv(512, 32, 1, 1, 0, False) 137 | 138 | self.layer2_conv1 = LayerConv(512, 128, 5, 1, 2, True) 139 | self.layer2_conv2 = LayerConv(128, 128, 5, 1, 2, True) 140 | self.layer2_dsc = DSC_Module(128, 128) 141 | self.layer2_conv3 = LayerConv(256, 32, 1, 1, 0, False) 142 | 143 | self.layer1_conv1 = LayerConv(256, 64, 3, 1, 1, True) 144 | self.layer1_conv2 = LayerConv(64, 64, 3, 1, 1, True) 145 | self.layer1_dsc = DSC_Module(64, 64, alpha=0.8) 146 | self.layer1_conv3 = LayerConv(128, 32, 1, 1, 0, False) 147 | 148 | self.layer0_conv1 = LayerConv(64, 64, 3, 1, 1, True) 149 | self.layer0_conv2 = LayerConv(64, 64, 3, 1, 1, True) 150 | self.layer0_dsc = DSC_Module(64, 64, 
alpha=0.8) 151 | self.layer0_conv3 = LayerConv(128, 32, 1, 1, 0, False) 152 | 153 | self.relu = nn.ReLU() 154 | 155 | self.global_conv = LayerConv(160, 32, 1, 1, 0, True) 156 | 157 | # output channel to 3 158 | self.layer4_predict = Predict(32, 3, 1) 159 | self.layer3_predict_ori = Predict(32, 3, 1) 160 | self.layer3_predict = Predict(6, 3, 1) 161 | self.layer2_predict_ori = Predict(32, 3, 1) 162 | self.layer2_predict = Predict(9, 3, 1) 163 | self.layer1_predict_ori = Predict(32, 3, 1) 164 | self.layer1_predict = Predict(12, 3, 1) 165 | self.layer0_predict_ori = Predict(32, 3, 1) 166 | self.layer0_predict = Predict(15, 3, 1) 167 | self.global_predict = Predict(32, 3, 1) 168 | self.fusion_predict = Predict(18, 3, 1) 169 | 170 | def forward(self, x, x_non_norm): 171 | layer0 = self.layer0(x) 172 | layer1 = self.layer1(layer0) 173 | layer2 = self.layer2(layer1) 174 | layer3 = self.layer3(layer2) 175 | layer4 = self.layer4(layer3) 176 | 177 | layer4_conv1 = self.layer4_conv1(layer4) 178 | layer4_conv2 = self.layer4_conv2(layer4_conv1) 179 | layer4_dsc = self.layer4_dsc(layer4_conv2) 180 | layer4_context = torch.cat((layer4_conv2, layer4_dsc), 1) 181 | layer4_conv3 = self.layer4_conv3(layer4_context) 182 | layer4_up = F.interpolate(layer4_conv3, size=x.size()[2:], mode='bilinear', align_corners=True) 183 | layer4_up = self.relu(layer4_up) 184 | 185 | layer3_conv1 = self.layer3_conv1(layer3) 186 | layer3_conv2 = self.layer3_conv2(layer3_conv1) 187 | layer3_dsc = self.layer3_dsc(layer3_conv2) 188 | layer3_context = torch.cat((layer3_conv2, layer3_dsc), 1) 189 | layer3_conv3 = self.layer3_conv3(layer3_context) 190 | layer3_up = F.interpolate(layer3_conv3, size=x.size()[2:], mode='bilinear', align_corners=True) 191 | layer3_up = self.relu(layer3_up) 192 | 193 | layer2_conv1 = self.layer2_conv1(layer2) 194 | layer2_conv2 = self.layer2_conv2(layer2_conv1) 195 | layer2_dsc = self.layer2_dsc(layer2_conv2) 196 | layer2_context = torch.cat((layer2_conv2, layer2_dsc), 1) 197 | layer2_conv3 = self.layer2_conv3(layer2_context) 198 | layer2_up = F.interpolate(layer2_conv3, size=x.size()[2:], mode='bilinear', align_corners=True) 199 | layer2_up = self.relu(layer2_up) 200 | 201 | layer1_conv1 = self.layer1_conv1(layer1) 202 | layer1_conv2 = self.layer1_conv2(layer1_conv1) 203 | layer1_dsc = self.layer1_dsc(layer1_conv2) 204 | layer1_context = torch.cat((layer1_conv2, layer1_dsc), 1) 205 | layer1_conv3 = self.layer1_conv3(layer1_context) 206 | layer1_up = F.interpolate(layer1_conv3, size=x.size()[2:], mode='bilinear', align_corners=True) 207 | layer1_up = self.relu(layer1_up) 208 | 209 | layer0_conv1 = self.layer0_conv1(layer0) 210 | layer0_conv2 = self.layer0_conv2(layer0_conv1) 211 | layer0_dsc = self.layer0_dsc(layer0_conv2) 212 | layer0_context = torch.cat((layer0_conv2, layer0_dsc), 1) 213 | layer0_conv3 = self.layer0_conv3(layer0_context) 214 | layer0_up = F.interpolate(layer0_conv3, size=x.size()[2:], mode='bilinear', align_corners=True) 215 | layer0_up = self.relu(layer0_up) 216 | 217 | global_concat = torch.cat((layer0_up, layer1_up, layer2_up, layer3_up, layer4_up), 1) 218 | global_conv = self.global_conv(global_concat) 219 | 220 | layer4_predict = self.layer4_predict(layer4_up) 221 | 222 | layer3_predict_ori = self.layer3_predict_ori(layer3_up) 223 | layer3_concat = torch.cat((layer3_predict_ori, layer4_predict), 1) 224 | layer3_predict = self.layer3_predict(layer3_concat) 225 | 226 | layer2_predict_ori = self.layer2_predict_ori(layer2_up) 227 | layer2_concat = torch.cat((layer2_predict_ori, 
layer3_predict_ori, layer4_predict), 1) 228 | layer2_predict = self.layer2_predict(layer2_concat) 229 | 230 | layer1_predict_ori = self.layer1_predict_ori(layer1_up) 231 | layer1_concat = torch.cat((layer1_predict_ori, layer2_predict_ori, layer3_predict_ori, layer4_predict), 1) 232 | layer1_predict = self.layer1_predict(layer1_concat) 233 | 234 | layer0_predict_ori = self.layer0_predict_ori(layer0_up) 235 | layer0_concat = torch.cat((layer0_predict_ori, layer1_predict_ori, layer2_predict_ori, layer3_predict_ori, layer4_predict), 1) 236 | layer0_predict = self.layer0_predict(layer0_concat) 237 | 238 | global_predict = self.global_predict(global_conv) 239 | 240 | # fusion 241 | fusion_concat = torch.cat((layer0_predict, layer1_predict, layer2_predict, layer3_predict, layer4_predict, global_predict), 1) 242 | fusion_predict = self.fusion_predict(fusion_concat) 243 | 244 | 245 | # send x_non_norm to device 246 | x_non_norm = x_non_norm.to(x.device) 247 | layer4_predict = layer4_predict + x_non_norm 248 | layer3_predict = layer3_predict + x_non_norm 249 | layer2_predict = layer2_predict + x_non_norm 250 | layer1_predict = layer1_predict + x_non_norm 251 | layer0_predict = layer0_predict + x_non_norm 252 | global_predict = global_predict + x_non_norm 253 | fusion_predict = fusion_predict + x_non_norm 254 | return layer4_predict, layer3_predict, layer2_predict, layer1_predict, layer0_predict, global_predict, fusion_predict 255 | 256 | -------------------------------------------------------------------------------- /IRNN_Backward_cuda.cu: -------------------------------------------------------------------------------- 1 | #define CUDA_KERNEL_LOOP(i, n) \ 2 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ 3 | i < (n); \ 4 | i += blockDim.x * gridDim.x) 5 | 6 | #define INDEX(b,c,h,w,channels,height,width) ((b * channels + c) * height + h) * width+ w 7 | 8 | 9 | extern "C" __global__ void IRNNBackward( 10 | float* grad_input, 11 | 12 | float* grad_weight_up_map, 13 | float* grad_weight_right_map, 14 | float* grad_weight_down_map, 15 | float* grad_weight_left_map, 16 | 17 | float* grad_bias_up_map, 18 | float* grad_bias_right_map, 19 | float* grad_bias_down_map, 20 | float* grad_bias_left_map, 21 | 22 | const float* weight_up, 23 | const float* weight_right, 24 | const float* weight_down, 25 | const float* weight_left, 26 | 27 | const float* grad_output_up, 28 | const float* grad_output_right, 29 | const float* grad_output_down, 30 | const float* grad_output_left, 31 | 32 | const float* output_up, 33 | const float* output_right, 34 | const float* output_down, 35 | const float* output_left, 36 | 37 | const int channels, 38 | const int height, 39 | const int width, 40 | const int n) { 41 | 42 | CUDA_KERNEL_LOOP(index,n){ 43 | 44 | int w = index % width; 45 | int h = index / width % height; 46 | int c = index / width / height % channels; 47 | int b = index / width / height / channels; 48 | 49 | float diff_left = 0; 50 | float diff_right = 0; 51 | float diff_up = 0; 52 | float diff_down = 0; 53 | 54 | //left 55 | 56 | for (int i = 0; i<=w; i++) 57 | { 58 | diff_left *= weight_left[c]; 59 | diff_left += grad_output_left[INDEX(b, c, h, i, channels, height, width)]; 60 | diff_left *= (output_left[INDEX(b, c, h, i, channels, height, width)]<=0)? 
0 : 1; 61 | } 62 | 63 | 64 | float temp = grad_output_left[INDEX(b, c, h, 0, channels, height, width)]; 65 | for (int i = 1; i < w +1 ; i++) 66 | { 67 | temp = (output_left[INDEX(b, c, h, i-1, channels, height, width)] >0?1:0) * temp * weight_left[c] + grad_output_left[INDEX(b, c, h, i, channels, height, width)]; 68 | } 69 | 70 | if (w != width - 1){ 71 | grad_weight_left_map[index] = temp * output_left[INDEX(b, c, h, w+1, channels, height, width)] * (output_left[index] > 0? 1:0); 72 | grad_bias_left_map[index] = diff_left; 73 | } 74 | 75 | // right 76 | 77 | for (int i = width -1; i>=w; i--) 78 | { 79 | diff_right *= weight_right[c]; 80 | diff_right += grad_output_right[INDEX(b, c, h, i, channels, height, width)]; 81 | diff_right *= (output_right[INDEX(b, c, h, i, channels, height, width)]<=0)? 0 : 1; 82 | } 83 | 84 | 85 | temp = grad_output_right[INDEX(b, c, h, width-1, channels, height, width)]; 86 | for (int i = width -2; i > w - 1 ; i--) 87 | { 88 | temp = (output_right[INDEX(b, c, h, i+1, channels, height, width)] >0?1:0) * temp * weight_right[c] + grad_output_right[INDEX(b, c, h, i, channels, height, width)]; 89 | } 90 | 91 | if (w != 0){ 92 | grad_weight_right_map[index] = temp * output_right[INDEX(b, c, h, w-1, channels, height, width)] * (output_right[index] > 0? 1:0); 93 | grad_bias_right_map[index] = diff_right; 94 | } 95 | 96 | // up 97 | 98 | 99 | for (int i = 0; i<=h; i++) 100 | { 101 | diff_up *= weight_up[c]; 102 | diff_up += grad_output_up[INDEX(b, c, i, w, channels, height, width)]; 103 | diff_up *= (output_up[INDEX(b, c, i, w, channels, height, width)]<=0)? 0 : 1; 104 | } 105 | 106 | 107 | temp = grad_output_up[INDEX(b, c, 0, w, channels, height, width)]; 108 | for (int i = 1; i < h +1 ; i++) 109 | { 110 | temp = (output_up[INDEX(b, c, i-1, w, channels, height, width)] >0?1:0) * temp * weight_up[c] + grad_output_up[INDEX(b, c, i, w, channels, height, width)]; 111 | } 112 | 113 | if (h != height - 1){ 114 | grad_weight_up_map[index] = temp * output_up[INDEX(b, c, h+1, w, channels, height, width)] * (output_up[index] > 0? 1:0); 115 | grad_bias_up_map[index] = diff_up; 116 | } 117 | 118 | // down 119 | 120 | for (int i = height -1; i>=h; i--) 121 | { 122 | diff_down *= weight_down[c]; 123 | diff_down += grad_output_down[INDEX(b, c, i, w, channels, height, width)]; 124 | diff_down *= (output_down[INDEX(b, c, i, w, channels, height, width)]<=0)? 0 : 1; 125 | } 126 | 127 | 128 | temp = grad_output_down[INDEX(b, c, height-1, w, channels, height, width)]; 129 | for (int i = height -2; i > h - 1 ; i--) 130 | { 131 | temp = (output_down[INDEX(b, c, i+1, w, channels, height, width)] >0?1:0) * temp * weight_down[c] + grad_output_down[INDEX(b, c, i, w, channels, height, width)]; 132 | } 133 | 134 | if (h != 0){ 135 | grad_weight_down_map[index] = temp * output_down[INDEX(b, c, h-1, w, channels, height, width)] * (output_down[index] > 0? 
1:0); 136 | grad_bias_down_map[index] = diff_down; 137 | } 138 | grad_input[index] = diff_down + diff_left + diff_right + diff_up; 139 | } 140 | } -------------------------------------------------------------------------------- /IRNN_Forward_cuda.cu: -------------------------------------------------------------------------------- 1 | #define CUDA_KERNEL_LOOP(i, n) \ 2 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ 3 | i < (n); \ 4 | i += blockDim.x * gridDim.x) 5 | 6 | #define INDEX(b,c,h,w,channels,height,width) ((b * channels + c) * height + h) * width+ w 7 | 8 | extern "C" __global__ void IRNNForward( 9 | const float* input_feature, 10 | 11 | const float* weight_up, 12 | const float* weight_right, 13 | const float* weight_down, 14 | const float* weight_left, 15 | 16 | const float* bias_up, 17 | const float* bias_right, 18 | const float* bias_down, 19 | const float* bias_left, 20 | 21 | float* output_up, 22 | float* output_right, 23 | float* output_down, 24 | float* output_left, 25 | 26 | const int channels, 27 | const int height, 28 | const int width, 29 | const int n){ 30 | 31 | CUDA_KERNEL_LOOP(index,n){ 32 | int w = index % width; 33 | int h = index / width % height; 34 | int c = index / width / height % channels; 35 | int b = index / width / height / channels; 36 | 37 | float temp = 0; 38 | 39 | // left 40 | output_left[index] = input_feature[INDEX(b, c, h, width-1, channels, height, width)] > 0 ? input_feature[INDEX(b, c, h, width-1, channels, height, width)] : 0; 41 | for (int i = width-2; i>=w; i--) 42 | { 43 | temp = output_left[index] * weight_left[c] + bias_left[c] + input_feature[INDEX(b, c, h, i, channels, height, width)]; 44 | output_left[index] = (temp > 0)? temp : 0; 45 | } 46 | 47 | // right 48 | output_right[index] = input_feature[INDEX(b, c, h, 0, channels, height, width)] > 0 ? input_feature[INDEX(b, c, h, 0, channels, height, width)] : 0; 49 | for (int i = 1; i <= w; i++) 50 | { 51 | temp = output_right[index] * weight_right[c] + bias_right[c] + input_feature[INDEX(b, c, h, i, channels, height, width)]; 52 | output_right[index] = (temp > 0)? temp : 0; 53 | } 54 | 55 | // up 56 | output_up[index] = input_feature[INDEX(b,c,height-1,w,channels,height,width)] > 0 ? input_feature[INDEX(b,c,height-1,w,channels,height,width)] : 0; 57 | for (int i = height-2; i >= h; i--) 58 | { 59 | temp = output_up[index] * weight_up[c] + bias_up[c] + input_feature[INDEX(b, c, i, w, channels, height, width)]; 60 | output_up[index] = (temp > 0)? temp : 0; 61 | } 62 | 63 | // down 64 | output_down[index] = input_feature[INDEX(b, c, 0, w, channels, height, width)] > 0 ? input_feature[INDEX(b, c, 0, w, channels, height, width)] : 0; 65 | for (int i = 1; i <= h; i++) 66 | { 67 | temp = output_down[index] * weight_down[c] + bias_down[c] + input_feature[INDEX(b, c, i, w, channels, height, width)]; 68 | output_down[index] = (temp > 0)? 
temp : 0; 69 | } 70 | } 71 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 steve wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DSC-PyTorch 2 | This is a PyTorch implementation of ["Direction-Aware Spatial Context Features for Shadow Detection, CVPR'18"](https://arxiv.org/abs/1712.04142) and ["Direction-Aware Spatial Context Features for Shadow Detection and Removal, T-PAMI'19"](https://arxiv.org/abs/1805.04635), written by Tianyu Wang and based on [Xiaowei](https://xw-hu.github.io)'s [DSC (Caffe)](https://github.com/xw-hu/DSC). 3 | 4 | The Spatial IRNN is implemented with custom CUDA kernels (tested with CUDA 11.x). The backbone is ResNeXt101 pre-trained on ImageNet, and the loss implementation is from [Quanlong Zheng](https://quanlzheng.github.io). 5 | 6 | ## Results 7 | We trained DSC on the SBU dataset with two GTX 1080Ti GPUs. 8 | 9 | ### SBU 10 | | Methods | BER | Accuracy | 11 | | --- | --- | --- | 12 | | DSC (Caffe) | 5.59 | **0.97** | 13 | | DSC (Ours) | **5.19** | 0.95 | 14 | 15 | **A pre-trained model is available: download it from [OneDrive](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155152065_link_cuhk_edu_hk/EcO20MV0kSVKkEbXO2NVIWMB6jewfk_lJK4SJjDvHcB6Ag?e=6P2h0m) and put it into the `SBU_model` folder (a minimal loading sketch is given in `SBU_model/README.md`).** 16 | 17 | * You can download the ResNeXt101 backbone weights from [Google Drive](https://drive.google.com/open?id=1EDUcaGNiakWO9Xvk9kWgkkcnTYZ6VQoT) and put them in the main folder. 18 | 19 | ## Requirements 20 | * PyTorch == 1.8.1 (training and testing) 21 | * CuPy ([Installation Guide](https://docs-cupy.chainer.org/en/stable/install.html#install-cupy)) 22 | * TensorBoardX 23 | * Python 3 24 | * progressbar2 25 | * scikit-image 26 | * pydensecrf 27 | 28 | ## Train/Test 29 | 1. **Clone this repository** 30 | 31 | ```bash 32 | git clone https://github.com/stevewongv/DSC-PyTorch.git 33 | ``` 34 | 2. **Train** 35 | 36 | ```bash 37 | python3 main.py -a train # For Shadow Detection 38 | python3 main_sr.py -a train # For Shadow Removal 39 | ``` 40 | 3. 
**Test** 41 | 42 | ```bash 43 | python3 main.py -a test # For Shadow Detection 44 | python3 main_sr.py -a test # For Shadow Removal 45 | ``` 46 | 47 | ## Citations 48 | 49 | ``` 50 | @InProceedings{Hu_2018_CVPR, 51 | author = {Hu, Xiaowei and Zhu, Lei and Fu, Chi-Wing and Qin, Jing and Heng, Pheng-Ann}, 52 | title = {Direction-Aware Spatial Context Features for Shadow Detection}, 53 | booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 54 | pages={7454--7462}, 55 | year = {2018} 56 | } 57 | 58 | @article{hu2020direction, 59 | author = {Hu, Xiaowei and Fu, Chi-Wing and Zhu, Lei and Qin, Jing and Heng, Pheng-Ann}, 60 | title = {Direction-Aware Spatial Context Features for Shadow Detection and Removal}, 61 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence}, 62 | pages={2795--2808}, 63 | year = {2020} 64 | } 65 | ``` 66 | The modified DSC module is also used in SPANet: 67 | ``` 68 | @InProceedings{Wang_2019_CVPR, 69 | author = {Wang, Tianyu and Yang, Xin and Xu, Ke and Chen, Shaozhe and Zhang, Qiang and Lau, Rynson W.H.}, 70 | title = {Spatial Attentive Single-Image Deraining with a High Quality Real Rain Dataset}, 71 | booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 72 | month = {June}, 73 | year = {2019} 74 | } 75 | ``` 76 | 77 | ## TODO List 78 | * [x] ResNeXt101 Backbone 79 | * [x] Test on SBU Test Set 80 | * [ ] VGG19 Backbone 81 | * [ ] Test on ISTD Test Set 82 | * [ ] Test on UCF Test Set 83 | * [ ] ... 84 | -------------------------------------------------------------------------------- /SBU_model/README.md: -------------------------------------------------------------------------------- 1 | You can download the pre-trained model from [Google Drive](https://drive.google.com/file/d/17VfUOu5xwHHc3M05N0oCjF2FGGirw7gt/view?usp=sharing). 
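2 | 
3 | Below is a minimal loading-and-inference sketch, not the official test entry point (use `python3 main.py -a test` for that). It assumes the downloaded file is a plain `state_dict` for the `DSC` network in `DSC.py`; the file name `SBU_model/SBU.pth` and the image path are placeholders. `resnext_101_32x4d.pth` must also sit in the repository root, since the `DSC` constructor loads it.
4 | 
5 | ```python
6 | import torch
7 | import torchvision.transforms as transforms
8 | import PIL.Image as Image
9 | 
10 | from DSC import DSC
11 | 
12 | net = DSC().cuda().eval()  # the IRNN kernels are CUDA-only
13 | net.load_state_dict(torch.load('SBU_model/SBU.pth'))  # placeholder file name
14 | 
15 | # Same normalization statistics as dataset.py
16 | trans = transforms.Compose([
17 |     transforms.ToTensor(),
18 |     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
19 | ])
20 | 
21 | img = Image.open('demo.jpg').convert('RGB')  # placeholder image path
22 | with torch.no_grad():
23 |     # DSC.forward returns seven side outputs; the last one is the fused prediction
24 |     *_, fusion_predict = net(trans(img).unsqueeze(0).cuda())
25 | shadow_prob = torch.sigmoid(fusion_predict)[0, 0].cpu().numpy()  # shadow probability map
26 | ```
27 | 
28 | The `Predict` heads output raw logits, so a sigmoid is applied here; `main.py` may additionally run CRF post-processing (`pydensecrf` is in the requirements), so treat this only as a quick sanity check.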
-------------------------------------------------------------------------------- /backbone/resnext/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnext101_regular import ResNeXt101 2 | -------------------------------------------------------------------------------- /backbone/resnext/resnext101_regular.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from backbone.resnext import resnext_101_32x4d_ 5 | 6 | 7 | class ResNeXt101(nn.Module): 8 | def __init__(self, backbone_path='./resnext_101_32x4d.pth'): 9 | super(ResNeXt101, self).__init__() 10 | net = resnext_101_32x4d_.resnext_101_32x4d 11 | if backbone_path is not None: 12 | weights = torch.load(backbone_path) 13 | # del weights['0.weight'] 14 | net.load_state_dict(weights, strict=True) 15 | print("Loaded ResNeXt weights successfully!") 16 | 17 | net = list(net.children()) 18 | self.layer0 = nn.Sequential(*net[:3]) 19 | self.layer1 = nn.Sequential(*net[3: 5]) 20 | self.layer2 = net[5] 21 | self.layer3 = net[6] 22 | self.layer4 = net[7] 23 | 24 | def forward(self, x): 25 | layer0 = self.layer0(x) 26 | layer1 = self.layer1(layer0) 27 | layer2 = self.layer2(layer1) 28 | layer3 = self.layer3(layer2) 29 | layer4 = self.layer4(layer3) 30 | return layer4 31 | -------------------------------------------------------------------------------- /backbone/resnext/resnext_101_32x4d_.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | 3 | import torch.nn as nn 4 | 5 | 6 | class LambdaBase(nn.Sequential): 7 | def __init__(self, fn, *args): 8 | super(LambdaBase, self).__init__(*args) 9 | self.lambda_func = fn 10 | 11 | def forward_prepare(self, input): 12 | output = [] 13 | for module in self._modules.values(): 14 | output.append(module(input)) 15 | return output if output else input 16 | 17 | 18 | class Lambda(LambdaBase): 19 | def forward(self, input): 20 | return self.lambda_func(self.forward_prepare(input)) 21 | 22 | 23 | class LambdaMap(LambdaBase): 24 | def forward(self, input): 25 | return list(map(self.lambda_func, self.forward_prepare(input))) 26 | 27 | 28 | class LambdaReduce(LambdaBase): 29 | def forward(self, input): 30 | return reduce(self.lambda_func, self.forward_prepare(input)) 31 | 32 | 33 | resnext_101_32x4d = nn.Sequential( # Sequential, 34 | nn.Conv2d(3, 64, (7, 7), (2, 2), (3, 3), 1, 1, bias=False), 35 | nn.BatchNorm2d(64), 36 | nn.ReLU(), 37 | nn.MaxPool2d((3, 3), (2, 2), (1, 1)), 38 | nn.Sequential( # Sequential, 39 | nn.Sequential( # Sequential, 40 | LambdaMap(lambda x: x, # ConcatTable, 41 | nn.Sequential( # Sequential, 42 | nn.Sequential( # Sequential, 43 | nn.Conv2d(64, 128, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 44 | nn.BatchNorm2d(128), 45 | nn.ReLU(), 46 | nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 47 | nn.BatchNorm2d(128), 48 | nn.ReLU(), 49 | ), 50 | nn.Conv2d(128, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 51 | nn.BatchNorm2d(256), 52 | ), 53 | nn.Sequential( # Sequential, 54 | nn.Conv2d(64, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 55 | nn.BatchNorm2d(256), 56 | ), 57 | ), 58 | LambdaReduce(lambda x, y: x + y), # CAddTable, 59 | nn.ReLU(), 60 | ), 61 | nn.Sequential( # Sequential, 62 | LambdaMap(lambda x: x, # ConcatTable, 63 | nn.Sequential( # Sequential, 64 | nn.Sequential( # Sequential, 65 | nn.Conv2d(256, 128, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 66 | nn.BatchNorm2d(128), 67 
| nn.ReLU(), 68 | nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 69 | nn.BatchNorm2d(128), 70 | nn.ReLU(), 71 | ), 72 | nn.Conv2d(128, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 73 | nn.BatchNorm2d(256), 74 | ), 75 | Lambda(lambda x: x), # Identity, 76 | ), 77 | LambdaReduce(lambda x, y: x + y), # CAddTable, 78 | nn.ReLU(), 79 | ), 80 | nn.Sequential( # Sequential, 81 | LambdaMap(lambda x: x, # ConcatTable, 82 | nn.Sequential( # Sequential, 83 | nn.Sequential( # Sequential, 84 | nn.Conv2d(256, 128, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 85 | nn.BatchNorm2d(128), 86 | nn.ReLU(), 87 | nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 88 | nn.BatchNorm2d(128), 89 | nn.ReLU(), 90 | ), 91 | nn.Conv2d(128, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 92 | nn.BatchNorm2d(256), 93 | ), 94 | Lambda(lambda x: x), # Identity, 95 | ), 96 | LambdaReduce(lambda x, y: x + y), # CAddTable, 97 | nn.ReLU(), 98 | ), 99 | ), 100 | nn.Sequential( # Sequential, 101 | nn.Sequential( # Sequential, 102 | LambdaMap(lambda x: x, # ConcatTable, 103 | nn.Sequential( # Sequential, 104 | nn.Sequential( # Sequential, 105 | nn.Conv2d(256, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 106 | nn.BatchNorm2d(256), 107 | nn.ReLU(), 108 | nn.Conv2d(256, 256, (3, 3), (2, 2), (1, 1), 1, 32, bias=False), 109 | nn.BatchNorm2d(256), 110 | nn.ReLU(), 111 | ), 112 | nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 113 | nn.BatchNorm2d(512), 114 | ), 115 | nn.Sequential( # Sequential, 116 | nn.Conv2d(256, 512, (1, 1), (2, 2), (0, 0), 1, 1, bias=False), 117 | nn.BatchNorm2d(512), 118 | ), 119 | ), 120 | LambdaReduce(lambda x, y: x + y), # CAddTable, 121 | nn.ReLU(), 122 | ), 123 | nn.Sequential( # Sequential, 124 | LambdaMap(lambda x: x, # ConcatTable, 125 | nn.Sequential( # Sequential, 126 | nn.Sequential( # Sequential, 127 | nn.Conv2d(512, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 128 | nn.BatchNorm2d(256), 129 | nn.ReLU(), 130 | nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 131 | nn.BatchNorm2d(256), 132 | nn.ReLU(), 133 | ), 134 | nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 135 | nn.BatchNorm2d(512), 136 | ), 137 | Lambda(lambda x: x), # Identity, 138 | ), 139 | LambdaReduce(lambda x, y: x + y), # CAddTable, 140 | nn.ReLU(), 141 | ), 142 | nn.Sequential( # Sequential, 143 | LambdaMap(lambda x: x, # ConcatTable, 144 | nn.Sequential( # Sequential, 145 | nn.Sequential( # Sequential, 146 | nn.Conv2d(512, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 147 | nn.BatchNorm2d(256), 148 | nn.ReLU(), 149 | nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 150 | nn.BatchNorm2d(256), 151 | nn.ReLU(), 152 | ), 153 | nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 154 | nn.BatchNorm2d(512), 155 | ), 156 | Lambda(lambda x: x), # Identity, 157 | ), 158 | LambdaReduce(lambda x, y: x + y), # CAddTable, 159 | nn.ReLU(), 160 | ), 161 | nn.Sequential( # Sequential, 162 | LambdaMap(lambda x: x, # ConcatTable, 163 | nn.Sequential( # Sequential, 164 | nn.Sequential( # Sequential, 165 | nn.Conv2d(512, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 166 | nn.BatchNorm2d(256), 167 | nn.ReLU(), 168 | nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 169 | nn.BatchNorm2d(256), 170 | nn.ReLU(), 171 | ), 172 | nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 173 | nn.BatchNorm2d(512), 174 | ), 175 | Lambda(lambda x: x), # Identity, 176 | ), 177 | LambdaReduce(lambda x, y: x + y), # CAddTable, 178 | 
nn.ReLU(), 179 | ), 180 | ), 181 | nn.Sequential( # Sequential, 182 | nn.Sequential( # Sequential, 183 | LambdaMap(lambda x: x, # ConcatTable, 184 | nn.Sequential( # Sequential, 185 | nn.Sequential( # Sequential, 186 | nn.Conv2d(512, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 187 | nn.BatchNorm2d(512), 188 | nn.ReLU(), 189 | nn.Conv2d(512, 512, (3, 3), (2, 2), (1, 1), 1, 32, bias=False), 190 | nn.BatchNorm2d(512), 191 | nn.ReLU(), 192 | ), 193 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 194 | nn.BatchNorm2d(1024), 195 | ), 196 | nn.Sequential( # Sequential, 197 | nn.Conv2d(512, 1024, (1, 1), (2, 2), (0, 0), 1, 1, bias=False), 198 | nn.BatchNorm2d(1024), 199 | ), 200 | ), 201 | LambdaReduce(lambda x, y: x + y), # CAddTable, 202 | nn.ReLU(), 203 | ), 204 | nn.Sequential( # Sequential, 205 | LambdaMap(lambda x: x, # ConcatTable, 206 | nn.Sequential( # Sequential, 207 | nn.Sequential( # Sequential, 208 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 209 | nn.BatchNorm2d(512), 210 | nn.ReLU(), 211 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 212 | nn.BatchNorm2d(512), 213 | nn.ReLU(), 214 | ), 215 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 216 | nn.BatchNorm2d(1024), 217 | ), 218 | Lambda(lambda x: x), # Identity, 219 | ), 220 | LambdaReduce(lambda x, y: x + y), # CAddTable, 221 | nn.ReLU(), 222 | ), 223 | nn.Sequential( # Sequential, 224 | LambdaMap(lambda x: x, # ConcatTable, 225 | nn.Sequential( # Sequential, 226 | nn.Sequential( # Sequential, 227 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 228 | nn.BatchNorm2d(512), 229 | nn.ReLU(), 230 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 231 | nn.BatchNorm2d(512), 232 | nn.ReLU(), 233 | ), 234 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 235 | nn.BatchNorm2d(1024), 236 | ), 237 | Lambda(lambda x: x), # Identity, 238 | ), 239 | LambdaReduce(lambda x, y: x + y), # CAddTable, 240 | nn.ReLU(), 241 | ), 242 | nn.Sequential( # Sequential, 243 | LambdaMap(lambda x: x, # ConcatTable, 244 | nn.Sequential( # Sequential, 245 | nn.Sequential( # Sequential, 246 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 247 | nn.BatchNorm2d(512), 248 | nn.ReLU(), 249 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 250 | nn.BatchNorm2d(512), 251 | nn.ReLU(), 252 | ), 253 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 254 | nn.BatchNorm2d(1024), 255 | ), 256 | Lambda(lambda x: x), # Identity, 257 | ), 258 | LambdaReduce(lambda x, y: x + y), # CAddTable, 259 | nn.ReLU(), 260 | ), 261 | nn.Sequential( # Sequential, 262 | LambdaMap(lambda x: x, # ConcatTable, 263 | nn.Sequential( # Sequential, 264 | nn.Sequential( # Sequential, 265 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 266 | nn.BatchNorm2d(512), 267 | nn.ReLU(), 268 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 269 | nn.BatchNorm2d(512), 270 | nn.ReLU(), 271 | ), 272 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 273 | nn.BatchNorm2d(1024), 274 | ), 275 | Lambda(lambda x: x), # Identity, 276 | ), 277 | LambdaReduce(lambda x, y: x + y), # CAddTable, 278 | nn.ReLU(), 279 | ), 280 | nn.Sequential( # Sequential, 281 | LambdaMap(lambda x: x, # ConcatTable, 282 | nn.Sequential( # Sequential, 283 | nn.Sequential( # Sequential, 284 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 285 | nn.BatchNorm2d(512), 286 | nn.ReLU(), 287 | nn.Conv2d(512, 512, 
(3, 3), (1, 1), (1, 1), 1, 32, bias=False), 288 | nn.BatchNorm2d(512), 289 | nn.ReLU(), 290 | ), 291 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 292 | nn.BatchNorm2d(1024), 293 | ), 294 | Lambda(lambda x: x), # Identity, 295 | ), 296 | LambdaReduce(lambda x, y: x + y), # CAddTable, 297 | nn.ReLU(), 298 | ), 299 | nn.Sequential( # Sequential, 300 | LambdaMap(lambda x: x, # ConcatTable, 301 | nn.Sequential( # Sequential, 302 | nn.Sequential( # Sequential, 303 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 304 | nn.BatchNorm2d(512), 305 | nn.ReLU(), 306 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 307 | nn.BatchNorm2d(512), 308 | nn.ReLU(), 309 | ), 310 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 311 | nn.BatchNorm2d(1024), 312 | ), 313 | Lambda(lambda x: x), # Identity, 314 | ), 315 | LambdaReduce(lambda x, y: x + y), # CAddTable, 316 | nn.ReLU(), 317 | ), 318 | nn.Sequential( # Sequential, 319 | LambdaMap(lambda x: x, # ConcatTable, 320 | nn.Sequential( # Sequential, 321 | nn.Sequential( # Sequential, 322 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 323 | nn.BatchNorm2d(512), 324 | nn.ReLU(), 325 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 326 | nn.BatchNorm2d(512), 327 | nn.ReLU(), 328 | ), 329 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 330 | nn.BatchNorm2d(1024), 331 | ), 332 | Lambda(lambda x: x), # Identity, 333 | ), 334 | LambdaReduce(lambda x, y: x + y), # CAddTable, 335 | nn.ReLU(), 336 | ), 337 | nn.Sequential( # Sequential, 338 | LambdaMap(lambda x: x, # ConcatTable, 339 | nn.Sequential( # Sequential, 340 | nn.Sequential( # Sequential, 341 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 342 | nn.BatchNorm2d(512), 343 | nn.ReLU(), 344 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 345 | nn.BatchNorm2d(512), 346 | nn.ReLU(), 347 | ), 348 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 349 | nn.BatchNorm2d(1024), 350 | ), 351 | Lambda(lambda x: x), # Identity, 352 | ), 353 | LambdaReduce(lambda x, y: x + y), # CAddTable, 354 | nn.ReLU(), 355 | ), 356 | nn.Sequential( # Sequential, 357 | LambdaMap(lambda x: x, # ConcatTable, 358 | nn.Sequential( # Sequential, 359 | nn.Sequential( # Sequential, 360 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 361 | nn.BatchNorm2d(512), 362 | nn.ReLU(), 363 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 364 | nn.BatchNorm2d(512), 365 | nn.ReLU(), 366 | ), 367 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 368 | nn.BatchNorm2d(1024), 369 | ), 370 | Lambda(lambda x: x), # Identity, 371 | ), 372 | LambdaReduce(lambda x, y: x + y), # CAddTable, 373 | nn.ReLU(), 374 | ), 375 | nn.Sequential( # Sequential, 376 | LambdaMap(lambda x: x, # ConcatTable, 377 | nn.Sequential( # Sequential, 378 | nn.Sequential( # Sequential, 379 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 380 | nn.BatchNorm2d(512), 381 | nn.ReLU(), 382 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 383 | nn.BatchNorm2d(512), 384 | nn.ReLU(), 385 | ), 386 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 387 | nn.BatchNorm2d(1024), 388 | ), 389 | Lambda(lambda x: x), # Identity, 390 | ), 391 | LambdaReduce(lambda x, y: x + y), # CAddTable, 392 | nn.ReLU(), 393 | ), 394 | nn.Sequential( # Sequential, 395 | LambdaMap(lambda x: x, # ConcatTable, 396 | nn.Sequential( # Sequential, 397 | 
nn.Sequential( # Sequential, 398 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 399 | nn.BatchNorm2d(512), 400 | nn.ReLU(), 401 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 402 | nn.BatchNorm2d(512), 403 | nn.ReLU(), 404 | ), 405 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 406 | nn.BatchNorm2d(1024), 407 | ), 408 | Lambda(lambda x: x), # Identity, 409 | ), 410 | LambdaReduce(lambda x, y: x + y), # CAddTable, 411 | nn.ReLU(), 412 | ), 413 | nn.Sequential( # Sequential, 414 | LambdaMap(lambda x: x, # ConcatTable, 415 | nn.Sequential( # Sequential, 416 | nn.Sequential( # Sequential, 417 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 418 | nn.BatchNorm2d(512), 419 | nn.ReLU(), 420 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 421 | nn.BatchNorm2d(512), 422 | nn.ReLU(), 423 | ), 424 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 425 | nn.BatchNorm2d(1024), 426 | ), 427 | Lambda(lambda x: x), # Identity, 428 | ), 429 | LambdaReduce(lambda x, y: x + y), # CAddTable, 430 | nn.ReLU(), 431 | ), 432 | nn.Sequential( # Sequential, 433 | LambdaMap(lambda x: x, # ConcatTable, 434 | nn.Sequential( # Sequential, 435 | nn.Sequential( # Sequential, 436 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 437 | nn.BatchNorm2d(512), 438 | nn.ReLU(), 439 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 440 | nn.BatchNorm2d(512), 441 | nn.ReLU(), 442 | ), 443 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 444 | nn.BatchNorm2d(1024), 445 | ), 446 | Lambda(lambda x: x), # Identity, 447 | ), 448 | LambdaReduce(lambda x, y: x + y), # CAddTable, 449 | nn.ReLU(), 450 | ), 451 | nn.Sequential( # Sequential, 452 | LambdaMap(lambda x: x, # ConcatTable, 453 | nn.Sequential( # Sequential, 454 | nn.Sequential( # Sequential, 455 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 456 | nn.BatchNorm2d(512), 457 | nn.ReLU(), 458 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 459 | nn.BatchNorm2d(512), 460 | nn.ReLU(), 461 | ), 462 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 463 | nn.BatchNorm2d(1024), 464 | ), 465 | Lambda(lambda x: x), # Identity, 466 | ), 467 | LambdaReduce(lambda x, y: x + y), # CAddTable, 468 | nn.ReLU(), 469 | ), 470 | nn.Sequential( # Sequential, 471 | LambdaMap(lambda x: x, # ConcatTable, 472 | nn.Sequential( # Sequential, 473 | nn.Sequential( # Sequential, 474 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 475 | nn.BatchNorm2d(512), 476 | nn.ReLU(), 477 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 478 | nn.BatchNorm2d(512), 479 | nn.ReLU(), 480 | ), 481 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 482 | nn.BatchNorm2d(1024), 483 | ), 484 | Lambda(lambda x: x), # Identity, 485 | ), 486 | LambdaReduce(lambda x, y: x + y), # CAddTable, 487 | nn.ReLU(), 488 | ), 489 | nn.Sequential( # Sequential, 490 | LambdaMap(lambda x: x, # ConcatTable, 491 | nn.Sequential( # Sequential, 492 | nn.Sequential( # Sequential, 493 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 494 | nn.BatchNorm2d(512), 495 | nn.ReLU(), 496 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 497 | nn.BatchNorm2d(512), 498 | nn.ReLU(), 499 | ), 500 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 501 | nn.BatchNorm2d(1024), 502 | ), 503 | Lambda(lambda x: x), # Identity, 504 | ), 505 | LambdaReduce(lambda x, y: x 
+ y), # CAddTable, 506 | nn.ReLU(), 507 | ), 508 | nn.Sequential( # Sequential, 509 | LambdaMap(lambda x: x, # ConcatTable, 510 | nn.Sequential( # Sequential, 511 | nn.Sequential( # Sequential, 512 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 513 | nn.BatchNorm2d(512), 514 | nn.ReLU(), 515 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 516 | nn.BatchNorm2d(512), 517 | nn.ReLU(), 518 | ), 519 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 520 | nn.BatchNorm2d(1024), 521 | ), 522 | Lambda(lambda x: x), # Identity, 523 | ), 524 | LambdaReduce(lambda x, y: x + y), # CAddTable, 525 | nn.ReLU(), 526 | ), 527 | nn.Sequential( # Sequential, 528 | LambdaMap(lambda x: x, # ConcatTable, 529 | nn.Sequential( # Sequential, 530 | nn.Sequential( # Sequential, 531 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 532 | nn.BatchNorm2d(512), 533 | nn.ReLU(), 534 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 535 | nn.BatchNorm2d(512), 536 | nn.ReLU(), 537 | ), 538 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 539 | nn.BatchNorm2d(1024), 540 | ), 541 | Lambda(lambda x: x), # Identity, 542 | ), 543 | LambdaReduce(lambda x, y: x + y), # CAddTable, 544 | nn.ReLU(), 545 | ), 546 | nn.Sequential( # Sequential, 547 | LambdaMap(lambda x: x, # ConcatTable, 548 | nn.Sequential( # Sequential, 549 | nn.Sequential( # Sequential, 550 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 551 | nn.BatchNorm2d(512), 552 | nn.ReLU(), 553 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 554 | nn.BatchNorm2d(512), 555 | nn.ReLU(), 556 | ), 557 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 558 | nn.BatchNorm2d(1024), 559 | ), 560 | Lambda(lambda x: x), # Identity, 561 | ), 562 | LambdaReduce(lambda x, y: x + y), # CAddTable, 563 | nn.ReLU(), 564 | ), 565 | nn.Sequential( # Sequential, 566 | LambdaMap(lambda x: x, # ConcatTable, 567 | nn.Sequential( # Sequential, 568 | nn.Sequential( # Sequential, 569 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 570 | nn.BatchNorm2d(512), 571 | nn.ReLU(), 572 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 573 | nn.BatchNorm2d(512), 574 | nn.ReLU(), 575 | ), 576 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 577 | nn.BatchNorm2d(1024), 578 | ), 579 | Lambda(lambda x: x), # Identity, 580 | ), 581 | LambdaReduce(lambda x, y: x + y), # CAddTable, 582 | nn.ReLU(), 583 | ), 584 | nn.Sequential( # Sequential, 585 | LambdaMap(lambda x: x, # ConcatTable, 586 | nn.Sequential( # Sequential, 587 | nn.Sequential( # Sequential, 588 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 589 | nn.BatchNorm2d(512), 590 | nn.ReLU(), 591 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 592 | nn.BatchNorm2d(512), 593 | nn.ReLU(), 594 | ), 595 | nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 596 | nn.BatchNorm2d(1024), 597 | ), 598 | Lambda(lambda x: x), # Identity, 599 | ), 600 | LambdaReduce(lambda x, y: x + y), # CAddTable, 601 | nn.ReLU(), 602 | ), 603 | nn.Sequential( # Sequential, 604 | LambdaMap(lambda x: x, # ConcatTable, 605 | nn.Sequential( # Sequential, 606 | nn.Sequential( # Sequential, 607 | nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 608 | nn.BatchNorm2d(512), 609 | nn.ReLU(), 610 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 611 | nn.BatchNorm2d(512), 612 | nn.ReLU(), 613 | ), 614 | nn.Conv2d(512, 
1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 615 | nn.BatchNorm2d(1024), 616 | ), 617 | Lambda(lambda x: x), # Identity, 618 | ), 619 | LambdaReduce(lambda x, y: x + y), # CAddTable, 620 | nn.ReLU(), 621 | ), 622 | ), 623 | nn.Sequential( # Sequential, 624 | nn.Sequential( # Sequential, 625 | LambdaMap(lambda x: x, # ConcatTable, 626 | nn.Sequential( # Sequential, 627 | nn.Sequential( # Sequential, 628 | nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 629 | nn.BatchNorm2d(1024), 630 | nn.ReLU(), 631 | nn.Conv2d(1024, 1024, (3, 3), (2, 2), (1, 1), 1, 32, bias=False), 632 | nn.BatchNorm2d(1024), 633 | nn.ReLU(), 634 | ), 635 | nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 636 | nn.BatchNorm2d(2048), 637 | ), 638 | nn.Sequential( # Sequential, 639 | nn.Conv2d(1024, 2048, (1, 1), (2, 2), (0, 0), 1, 1, bias=False), 640 | nn.BatchNorm2d(2048), 641 | ), 642 | ), 643 | LambdaReduce(lambda x, y: x + y), # CAddTable, 644 | nn.ReLU(), 645 | ), 646 | nn.Sequential( # Sequential, 647 | LambdaMap(lambda x: x, # ConcatTable, 648 | nn.Sequential( # Sequential, 649 | nn.Sequential( # Sequential, 650 | nn.Conv2d(2048, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 651 | nn.BatchNorm2d(1024), 652 | nn.ReLU(), 653 | nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 654 | nn.BatchNorm2d(1024), 655 | nn.ReLU(), 656 | ), 657 | nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 658 | nn.BatchNorm2d(2048), 659 | ), 660 | Lambda(lambda x: x), # Identity, 661 | ), 662 | LambdaReduce(lambda x, y: x + y), # CAddTable, 663 | nn.ReLU(), 664 | ), 665 | nn.Sequential( # Sequential, 666 | LambdaMap(lambda x: x, # ConcatTable, 667 | nn.Sequential( # Sequential, 668 | nn.Sequential( # Sequential, 669 | nn.Conv2d(2048, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 670 | nn.BatchNorm2d(1024), 671 | nn.ReLU(), 672 | nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), 673 | nn.BatchNorm2d(1024), 674 | nn.ReLU(), 675 | ), 676 | nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), 677 | nn.BatchNorm2d(2048), 678 | ), 679 | Lambda(lambda x: x), # Identity, 680 | ), 681 | LambdaReduce(lambda x, y: x + y), # CAddTable, 682 | nn.ReLU(), 683 | ), 684 | ), 685 | nn.AvgPool2d((7, 7), (1, 1)), 686 | Lambda(lambda x: x.view(x.size(0), -1)), # View, 687 | nn.Sequential(Lambda(lambda x: x.view(1, -1) if 1 == len(x.size()) else x), nn.Linear(2048, 1000)), # Linear, 688 | ) 689 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | from numpy.random import RandomState 5 | from torch.utils.data import Dataset 6 | import torchvision.transforms as transforms 7 | import PIL.Image as Image 8 | from randomcrop import RandomHorizontallyFlip 9 | 10 | class TrainValDataset(Dataset): 11 | def __init__(self, name): 12 | super().__init__() 13 | self.dataset = name 14 | self.root = '../SBU-shadow/SBUTrain4KRecoveredSmall/' 15 | self.imgs = open(self.dataset).readlines() 16 | self.file_num = len(self.imgs) 17 | 18 | self.hflip = RandomHorizontallyFlip() 19 | self.trans = transforms.Compose([ 20 | transforms.ToTensor(), 21 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 22 | ]) 23 | 24 | def __len__(self): 25 | return self.file_num * 100 26 | 27 | def __getitem__(self, index): 28 | image_path,label_path = self.imgs[index % self.file_num][:-1].split(' ') 29 | image 
= Image.open(self.root + image_path).convert('RGB').resize((400,400)) 30 | label = Image.open(self.root + label_path).convert('L').resize((400,400)) 31 | 32 | image,label = self.hflip(image,label) 33 | 34 | label = np.array(label,dtype='float32') / 255.0 35 | if len(label.shape) > 2: 36 | label = label[:,:,0] 37 | 38 | image_nom = self.trans(image) 39 | label = np.array([label]) 40 | 41 | sample = {'O': image_nom,'B':label,'image':np.array(image,dtype='float32').transpose(2,0,1)/255} 42 | return sample 43 | 44 | 45 | 46 | class TestDataset(Dataset): 47 | def __init__(self, name): 48 | super().__init__() 49 | self.dataset = name 50 | self.root = '../SBU-shadow/SBU-Test/' 51 | self.imgs = open(self.root + 'SBU.txt').readlines() 52 | self.file_num = len(self.imgs) 53 | 54 | self.hflip = RandomHorizontallyFlip() 55 | self.trans = transforms.Compose([ 56 | transforms.ToTensor(), 57 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 58 | ]) 59 | 60 | def __len__(self): 61 | return self.file_num 62 | 63 | def __getitem__(self, index): 64 | 65 | image_path,label_path = self.imgs[index % self.file_num][:-1].split(' ') 66 | image = Image.open(self.root + image_path).convert('RGB').resize((400,400)) 67 | label = Image.open(self.root + label_path).convert('L').resize((400,400)) 68 | 69 | label = np.array(label,dtype='float32') / 255.0 70 | if len(label.shape) > 2: 71 | label = label[:,:,0] 72 | 73 | image_nom = self.trans(image) 74 | 75 | sample = {'O': image_nom,'B':label,'image':np.array(image)} 76 | 77 | return sample 78 | -------------------------------------------------------------------------------- /dataset_sr.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | from torch.utils.data import Dataset 5 | import torchvision.transforms as transforms 6 | import PIL.Image as Image 7 | from randomcrop import RandomHorizontallyFlip 8 | 9 | class TrainValDataset(Dataset): 10 | def __init__(self, name): 11 | super().__init__() 12 | self.dataset = name 13 | self.root = '/home/zhxing/Datasets/SRD_inpaint4shadow_fix/' 14 | # self.root = '/home/zhxing/Datasets/ISTD+/' 15 | 16 | self.imgs = open(self.dataset).readlines() 17 | self.file_num = len(self.imgs) 18 | 19 | self.hflip = RandomHorizontallyFlip() 20 | self.trans = transforms.Compose([ 21 | transforms.ToTensor(), 22 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 23 | ]) 24 | 25 | def __len__(self): 26 | return self.file_num * 100 27 | 28 | def __getitem__(self, index): 29 | line = self.imgs[index % self.file_num].strip() 30 | parts = line.split() 31 | image_path, label_path = parts[0], parts[1] 32 | 33 | image = cv2.imread(self.root + image_path) 34 | label = cv2.imread(self.root + label_path) 35 | 36 | # Convert to LAB color space 37 | image_lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB) 38 | label_lab = cv2.cvtColor(label, cv2.COLOR_BGR2LAB) 39 | 40 | # Resize images 41 | image_lab = cv2.resize(image_lab, (512, 512)) 42 | label_lab = cv2.resize(label_lab, (512, 512)) 43 | 44 | # Convert to PIL Image for transformations 45 | image_lab = Image.fromarray(image_lab) 46 | label_lab = Image.fromarray(label_lab) 47 | 48 | # max and min value of image_lab 49 | # print("image_lab max: ", np.max(image_lab)) 50 | # print("image_lab min: ", np.min(image_lab)) 51 | 52 | image_lab, label_lab = self.hflip(image_lab, label_lab) 53 | 54 | # label_lab = np.array(label_lab, dtype='float32') / 255.0 55 | label_lab = np.array(label_lab, 
dtype='float32') 56 | 57 | image_nom = self.trans(image_lab) 58 | # print("image_nom max: ", image_nom.max()) 59 | # print("image_nom min: ", image_nom.min()) 60 | label_lab = np.array([label_lab]) 61 | # print("image_nom shape: ", image_nom.shape) 62 | # label_lab shape: (1, 512, 512, 3) 63 | # image_nom shape: torch.Size([3, 512, 512]) 64 | # align the shape of label_lab to image_nom 65 | label_lab = label_lab.transpose(3, 0, 1, 2) 66 | # label_lab shape: (3, 1, 512, 512) 67 | # align the shape of label_lab to image_nom 68 | label_lab = np.squeeze(label_lab) 69 | # print("label_lab shape: ", label_lab.shape) 70 | 71 | image_ori = np.array(image_lab, dtype='float32').transpose(2, 0, 1) 72 | sample = {'O': image_nom, 'B': label_lab, 'image': np.array(image_lab, dtype='float32').transpose(2, 0, 1) / 255, "image_ori": np.array(image_lab, dtype='float32').transpose(2, 0, 1)} 73 | 74 | return sample 75 | 76 | 77 | class TestDataset(Dataset): 78 | def __init__(self, name): 79 | super().__init__() 80 | self.dataset = name 81 | self.root = '/home/zhxing/Datasets/SRD_inpaint4shadow_fix/' 82 | # self.root = '/home/zhxing/Datasets/ISTD+/' 83 | # self.root = '/home/zhxing/Datasets/DESOBA_xvision/' 84 | 85 | self.imgs = open(self.root + 'test_dsc.txt').readlines() 86 | self.file_num = len(self.imgs) 87 | 88 | self.trans = transforms.Compose([ 89 | transforms.ToTensor(), 90 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 91 | ]) 92 | 93 | def __len__(self): 94 | return self.file_num 95 | 96 | def __getitem__(self, index): 97 | image_path, label_path = self.imgs[index % self.file_num][:-1].split(' ') 98 | image = cv2.imread(self.root + image_path) 99 | label = cv2.imread(self.root + label_path) 100 | 101 | # Convert to LAB color space 102 | image_lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB) 103 | label_lab = cv2.cvtColor(label, cv2.COLOR_BGR2LAB) 104 | 105 | # Resize images 106 | image_lab = cv2.resize(image_lab, (512, 512)) 107 | label_lab = cv2.resize(label_lab, (512, 512)) 108 | 109 | # Convert to PIL Image for transformations 110 | image_lab = Image.fromarray(image_lab) 111 | 112 | label_lab = np.array(label_lab, dtype='float32') / 255.0 113 | image_nom = self.trans(image_lab) 114 | 115 | label_lab = np.array([label_lab]) 116 | # print("image_nom shape: ", image_nom.shape) 117 | # label_lab shape: (1, 512, 512, 3) 118 | # image_nom shape: torch.Size([3, 512, 512]) 119 | # align the shape of label_lab to image_nom 120 | label_lab = label_lab.transpose(3, 0, 1, 2) 121 | # label_lab shape: (3, 1, 512, 512) 122 | # align the shape of label_lab to image_nom 123 | label_lab = np.squeeze(label_lab) 124 | # print("label_lab shape: ", label_lab.shape) 125 | 126 | # print the range of image_nom 127 | # print("image_nom max: ", image_nom.max()) 128 | 129 | image_ori = np.array(image_lab, dtype='float32').transpose(2, 0, 1) 130 | 131 | sample = {'O': image_nom, 'B': label_lab, 'image': np.array(image_lab), "image_ori": image_ori} 132 | 133 | 134 | return sample 135 | -------------------------------------------------------------------------------- /irnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import cupy 4 | from torch import nn 5 | import torch.nn.functional as F 6 | class Stream: 7 | ptr = torch.cuda.current_stream().cuda_stream 8 | 9 | IRNNForward = open('./IRNN_Forward_cuda.cu','r').read() 10 | 11 | IRNNBackward = open('./IRNN_Backward_cuda.cu','r').read() 12 | 13 | # IRNNWeightBaisBackward = 
open('./IRNN_Weight_Bias_Backward_cuda.cu','r').read() 14 | 15 | 16 | @cupy.memoize(for_each_device=True) 17 | def cunnex(strFunction): 18 | return cupy.cuda.compile_with_cache(globals()[strFunction]).get_function(strFunction) 19 | # end 20 | 21 | class irnn(torch.autograd.Function): 22 | def __init__(self): 23 | super(irnn, self).__init__() 24 | 25 | @staticmethod 26 | def forward(self, input_feature, weight_up, weight_right, weight_down, weight_left, bias_up, bias_right, bias_down, bias_left): 27 | 28 | 29 | assert(input_feature.is_contiguous() == True) 30 | assert(weight_left.is_contiguous() == True) 31 | assert(weight_right.is_contiguous() == True) 32 | assert(weight_down.is_contiguous() == True) 33 | 34 | assert(weight_up.is_contiguous() == True) 35 | assert(bias_left.is_contiguous() ==True) 36 | assert(bias_right.is_contiguous() == True) 37 | assert(bias_up.is_contiguous() == True) 38 | assert(bias_down.is_contiguous() == True) 39 | 40 | output_left = input_feature.clone() 41 | output_right = input_feature.clone() 42 | output_up = input_feature.clone() 43 | output_down = input_feature.clone() 44 | 45 | if input_feature.is_cuda == True: 46 | n = input_feature.nelement() 47 | cuda_num_threads = 1024 48 | cunnex('IRNNForward')( 49 | grid=tuple([ int((n + cuda_num_threads - 1) / cuda_num_threads ), 1, 1 ]), 50 | block=tuple([ cuda_num_threads , 1, 1 ]), 51 | args=[ 52 | input_feature.data_ptr(), 53 | 54 | weight_up.data_ptr(), 55 | weight_right.data_ptr(), 56 | weight_down.data_ptr(), 57 | weight_left.data_ptr(), 58 | 59 | bias_up.data_ptr(), 60 | bias_right.data_ptr(), 61 | bias_down.data_ptr(), 62 | bias_left.data_ptr(), 63 | 64 | output_up.data_ptr(), 65 | output_right.data_ptr(), 66 | output_down.data_ptr(), 67 | output_left.data_ptr(), 68 | 69 | input_feature.size(1), 70 | input_feature.size(2), 71 | input_feature.size(3), 72 | n], 73 | stream=Stream 74 | ) 75 | elif input_feature.is_cuda == False: 76 | raise NotImplementedError() 77 | 78 | 79 | self.save_for_backward(input_feature,weight_up,weight_right,weight_down,weight_left,output_up,output_right,output_down,output_left) 80 | 81 | 82 | return output_up,output_right,output_down,output_left 83 | # end 84 | 85 | @staticmethod 86 | def backward(self, grad_output_up,grad_output_right,grad_output_down,grad_output_left): 87 | 88 | input_feature,weight_up,weight_right,weight_down,weight_left,output_up,output_right,output_down,output_left = self.saved_tensors 89 | # print(weight_left) 90 | if grad_output_up.is_contiguous() != True: 91 | grad_output_up = grad_output_up.contiguous() 92 | if grad_output_right.is_contiguous() != True: 93 | grad_output_right = grad_output_right.contiguous() 94 | if grad_output_down.is_contiguous() != True: 95 | grad_output_down = grad_output_down.contiguous() 96 | if grad_output_left.is_contiguous() != True: 97 | grad_output_left = grad_output_left.contiguous() 98 | 99 | # init gradient of input_feature 100 | grad_input = torch.zeros_like(input_feature) 101 | # init gradient map of weights 102 | grad_weight_up_map = torch.zeros_like(input_feature) 103 | grad_weight_right_map = torch.zeros_like(input_feature) 104 | grad_weight_down_map = torch.zeros_like(input_feature) 105 | grad_weight_left_map = torch.zeros_like(input_feature) 106 | # init gradient of weights 107 | grad_weight_left = torch.zeros_like(weight_left) 108 | grad_weight_right = torch.zeros_like(weight_left) 109 | grad_weight_up = torch.zeros_like(weight_left) 110 | grad_weight_down = torch.zeros_like(weight_left) 111 | 112 | grad_bias_up_map = 
torch.zeros_like(input_feature) 113 | grad_bias_right_map = torch.zeros_like(input_feature) 114 | grad_bias_down_map = torch.zeros_like(input_feature) 115 | grad_bias_left_map = torch.zeros_like(input_feature) 116 | 117 | if input_feature.is_cuda == True: 118 | 119 | n = grad_input.nelement() 120 | 121 | cuda_num_threads = 1024 122 | cunnex('IRNNBackward')( 123 | grid=tuple([ int((n + cuda_num_threads - 1) / cuda_num_threads), 1, 1 ]), 124 | block=tuple([ cuda_num_threads , 1, 1 ]), 125 | args=[ 126 | grad_input.data_ptr(), 127 | 128 | grad_weight_up_map.data_ptr(), 129 | grad_weight_right_map.data_ptr(), 130 | grad_weight_down_map.data_ptr(), 131 | grad_weight_left_map.data_ptr(), 132 | 133 | grad_bias_up_map.data_ptr(), 134 | grad_bias_right_map.data_ptr(), 135 | grad_bias_down_map.data_ptr(), 136 | grad_bias_left_map.data_ptr(), 137 | 138 | weight_up.data_ptr(), 139 | weight_right.data_ptr(), 140 | weight_down.data_ptr(), 141 | weight_left.data_ptr(), 142 | 143 | grad_output_up.data_ptr(), 144 | grad_output_right.data_ptr(), 145 | grad_output_down.data_ptr(), 146 | grad_output_left.data_ptr(), 147 | 148 | output_up.data_ptr(), 149 | output_right.data_ptr(), 150 | output_down.data_ptr(), 151 | output_left.data_ptr(), 152 | 153 | input_feature.size(1), 154 | input_feature.size(2), 155 | input_feature.size(3), 156 | n], 157 | stream=Stream 158 | ) 159 | # print(grad_weight_left_map,"<-- grad weight map") 160 | 161 | grad_bias_up = torch.zeros_like(weight_left).reshape(weight_left.size(0)) 162 | grad_bias_right = torch.zeros_like(weight_left).reshape(weight_left.size(0)) 163 | grad_bias_down = torch.zeros_like(weight_left).reshape(weight_left.size(0)) 164 | grad_bias_left = torch.zeros_like(weight_left).reshape(weight_left.size(0)) 165 | 166 | grad_weight_left = grad_weight_left_map.sum(2).sum(2).sum(0).resize_as_(grad_weight_left) 167 | grad_weight_right = grad_weight_right_map.sum(2).sum(2).sum(0).resize_as_(grad_weight_left) 168 | grad_weight_up = grad_weight_up_map.sum(2).sum(2).sum(0).resize_as_(grad_weight_left) 169 | grad_weight_down = grad_weight_down_map.sum(2).sum(2).sum(0).resize_as_(grad_weight_left) 170 | 171 | grad_bias_up = grad_bias_up_map.sum(2).sum(2).sum(0).resize_as_(grad_bias_up) 172 | grad_bias_right = grad_bias_right_map.sum(2).sum(2).sum(0).resize_as_(grad_bias_up) 173 | grad_bias_down = grad_bias_down_map.sum(2).sum(2).sum(0).resize_as_(grad_bias_up) 174 | grad_bias_left = grad_bias_left_map.sum(2).sum(2).sum(0).resize_as_(grad_bias_up) 175 | 176 | 177 | 178 | # n = input_feature.size(1) 179 | # cuda_num_threads = n 180 | # cunnex('IRNNWeightBaisBackward')( 181 | # grid=tuple([ int((n + cuda_num_threads - 1) / cuda_num_threads), 1, 1 ]), 182 | # block=tuple([ cuda_num_threads , 1, 1 ]), 183 | # args=[ 184 | # grad_weight_up_map.data_ptr(), 185 | # grad_weight_right_map.data_ptr(), 186 | # grad_weight_down_map.data_ptr(), 187 | # grad_weight_left_map.data_ptr(), 188 | 189 | # grad_bias_up_map.data_ptr(), 190 | # grad_bias_right_map.data_ptr(), 191 | # grad_bias_down_map.data_ptr(), 192 | # grad_bias_left_map.data_ptr(), 193 | 194 | # grad_weight_up.data_ptr(), 195 | # grad_weight_right.data_ptr(), 196 | # grad_weight_down.data_ptr(), 197 | # grad_weight_left.data_ptr(), 198 | 199 | # grad_bias_up.data_ptr(), 200 | # grad_bias_right.data_ptr(), 201 | # grad_bias_down.data_ptr(), 202 | # grad_bias_left.data_ptr(), 203 | 204 | # input_feature.size(0), 205 | # input_feature.size(1), 206 | # input_feature.size(2), 207 | # input_feature.size(3), 208 | # n], 209 | # 
stream=Stream 210 | # ) 211 | 212 | elif input_feature.is_cuda == False: 213 | raise NotImplementedError() 214 | 215 | 216 | return grad_input, grad_weight_up,grad_weight_right,grad_weight_down,grad_weight_left,grad_bias_up, grad_bias_right, grad_bias_down, grad_bias_left 217 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import argparse 5 | import numpy as np 6 | import logging 7 | import time 8 | from DSC import DSC 9 | import torch 10 | from torch import nn 11 | from torch.nn import MSELoss 12 | from torch import optim 13 | import torch.nn.functional as F 14 | from torch.optim.lr_scheduler import MultiStepLR 15 | from torch.utils.data import DataLoader 16 | 17 | from tensorboardX import SummaryWriter 18 | import skimage.measure as ms 19 | import progressbar 20 | import skimage.io as io 21 | import PIL.Image as I 22 | from dataset import TrainValDataset, TestDataset 23 | from misc import crf_refine 24 | import shutil 25 | from utils import MyWcploss 26 | 27 | logger = logging.getLogger('train') 28 | logger.setLevel(logging.INFO) 29 | ch = logging.StreamHandler() 30 | ch.setLevel(logging.INFO) 31 | formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') 32 | ch.setFormatter(formatter) 33 | logger.addHandler(ch) 34 | torch.cuda.manual_seed_all(2018) 35 | torch.manual_seed(2018) 36 | torch.backends.cudnn.benchmark = True 37 | 38 | 39 | def ensure_dir(dir_path): 40 | if not os.path.isdir(dir_path): 41 | os.makedirs(dir_path) 42 | 43 | 44 | class Session: 45 | def __init__(self): 46 | self.device = torch.device("cuda") 47 | 48 | self.log_dir = './logdir' 49 | self.model_dir = './SBU_model' 50 | ensure_dir(self.log_dir) 51 | ensure_dir(self.model_dir) 52 | self.log_name = 'train_SBU_alpha_1' 53 | self.val_log_name = 'val_SBU_alpha_1' 54 | logger.info('set log dir as %s' % self.log_dir) 55 | logger.info('set model dir as %s' % self.model_dir) 56 | 57 | self.test_data_path = '../SBU-shadow/SBU-Test/' # test dataset txt file path 58 | self.train_data_path = '../SBU-shadow/SBUTrain4KRecoveredSmall/SBU.txt' # train dataset txt file path 59 | 60 | self.multi_gpu = True 61 | self.net = DSC().to(self.device) 62 | self.bce = MyWcploss().to(self.device) 63 | 64 | self.step = 0 65 | self.save_steps = 200 66 | self.num_workers = 16 67 | self.batch_size = 4 68 | self.writers = {} 69 | self.dataloaders = {} 70 | self.shuffle = True 71 | self.opt = optimizer = optim.SGD([ 72 | {'params': [param for name, param in self.net.named_parameters() if name[-4:] == 'bias'], 73 | 'lr': 2 * 5e-3}, 74 | {'params': [param for name, param in self.net.named_parameters() if name[-4:] != 'bias'], 75 | 'lr': 5e-3, 'weight_decay': 5e-4} 76 | ], momentum= 0.9) 77 | 78 | def tensorboard(self, name): 79 | self.writers[name] = SummaryWriter(os.path.join(self.log_dir, name + '.events')) 80 | return self.writers[name] 81 | 82 | def write(self, name, out): 83 | for k, v in out.items(): 84 | self.writers[name].add_scalar(k, v, self.step) 85 | 86 | out['lr'] = self.opt.param_groups[0]['lr'] 87 | out['step'] = self.step 88 | outputs = [ 89 | "{}:{:.4g}".format(k, v) 90 | for k, v in out.items() 91 | ] 92 | logger.info(name + '--' + ' '.join(outputs)) 93 | 94 | def get_dataloader(self, dataset_name, train_mode=True): 95 | dataset = { 96 | True: TrainValDataset, 97 | False: TestDataset, 98 | }[train_mode](dataset_name) 99 | 
self.dataloaders[dataset_name] = \ 100 | DataLoader(dataset, batch_size=self.batch_size, 101 | shuffle=self.shuffle, num_workers=self.num_workers, drop_last=True) 102 | if train_mode: 103 | return iter(self.dataloaders[dataset_name]) 104 | else: 105 | return self.dataloaders[dataset_name] 106 | 107 | def save_checkpoints(self, name): 108 | ckp_path = os.path.join(self.model_dir, name) 109 | if self.multi_gpu : 110 | obj = { 111 | 'net': self.net.module.state_dict(), 112 | 'clock': self.step, 113 | 'opt': self.opt.state_dict(), 114 | } 115 | else: 116 | obj = { 117 | 'net': self.net.state_dict(), 118 | 'clock': self.step, 119 | 'opt': self.opt.state_dict(), 120 | } 121 | torch.save(obj, ckp_path) 122 | 123 | def load_checkpoints(self, name,mode='train'): 124 | ckp_path = os.path.join(self.model_dir, name) 125 | try: 126 | obj = torch.load(ckp_path) 127 | except FileNotFoundError: 128 | return 129 | self.net.load_state_dict(obj['net']) 130 | if mode == 'train': 131 | self.step = obj['clock'] 132 | if mode == 'test': 133 | path = '../realtest/{}/'.format(self.model_dir[2:]) 134 | ensure_dir(path) 135 | shutil.copy(ckp_path,path) 136 | 137 | 138 | def inf_batch(self, name, batch): 139 | if name == 'test': 140 | torch.set_grad_enabled(False) 141 | O, B,= batch['O'], batch['B'] 142 | O, B = O.to(self.device), B.to(self.device) 143 | 144 | predicts= self.net(O) 145 | predict_4, predict_3, predict_2, predict_1, predict_0, predict_g, predict_f = predicts 146 | if name == 'test': 147 | predicts = [F.sigmoid(predict_4), F.sigmoid(predict_3), F.sigmoid(predict_2), \ 148 | F.sigmoid(predict_1), F.sigmoid(predict_0), F.sigmoid(predict_g), \ 149 | F.sigmoid(predict_f)] 150 | return predicts 151 | 152 | loss_4 = self.bce(predict_4, B) 153 | loss_3 = self.bce(predict_3, B) 154 | loss_2 = self.bce(predict_2, B) 155 | loss_1 = self.bce(predict_1, B) 156 | loss_0 = self.bce(predict_0, B) 157 | loss_g = self.bce(predict_g, B) 158 | loss_f = self.bce(predict_f, B) 159 | predicts = [F.sigmoid(predict_4), F.sigmoid(predict_3), F.sigmoid(predict_2), \ 160 | F.sigmoid(predict_1), F.sigmoid(predict_0), F.sigmoid(predict_g), \ 161 | F.sigmoid(predict_f)] 162 | loss = loss_4 + loss_3 + loss_2 + loss_1 + loss_0 + loss_g + loss_f 163 | # log 164 | losses = { 165 | 'loss_all' : loss.item(), 166 | 'loss_0' : loss_0.item(), 167 | 'loss_1' : loss_1.item(), 168 | 'loss_2' : loss_2.item(), 169 | 'loss_3' : loss_3.item(), 170 | 'loss_4' : loss_4.item(), 171 | 'loss_g' : loss_g.item(), 172 | 'loss_f' : loss_f.item() 173 | } 174 | 175 | return predicts, loss, losses 176 | 177 | 178 | def save_mask(self, name, img_lists,m = 0): 179 | data, label, predicts = img_lists 180 | 181 | data, label= (data.numpy() * 255).astype('uint8'), (label.numpy() * 255).astype('uint8') 182 | 183 | label = np.tile(label,(3,1,1)) 184 | 185 | h, w = 400,400 186 | 187 | gen_num = (2,1) 188 | 189 | predict_4, predict_3, predict_2, predict_1, predict_0, predict_g, predict_f = predicts 190 | 191 | predict_4, predict_3, predict_2, predict_1, predict_0, predict_g, predict_f = \ 192 | (np.tile(predict_4.cpu().data * 255,(3,1,1))).astype('uint8'), \ 193 | (np.tile(predict_3.cpu().data * 255,(3,1,1))).astype('uint8'), \ 194 | (np.tile(predict_2.cpu().data * 255,(3,1,1))).astype('uint8'), \ 195 | (np.tile(predict_1.cpu().data * 255,(3,1,1))).astype('uint8'), \ 196 | (np.tile(predict_0.cpu().data * 255,(3,1,1))).astype('uint8'), \ 197 | (np.tile(predict_g.cpu().data * 255,(3,1,1))).astype('uint8'), \ 198 | (np.tile(predict_f.cpu().data * 
255,(3,1,1))).astype('uint8') 199 | 200 | img = np.zeros((gen_num[0] * h, gen_num[1] * 9 * w, 3)) 201 | for img_list in img_lists: 202 | for i in range(gen_num[0]): 203 | row = i * h 204 | for j in range(gen_num[1]): 205 | idx = i * gen_num[1] + j 206 | tmp_list = [data[idx], label[idx],predict_4[idx], predict_3[idx], predict_2[idx], predict_1[idx], predict_0[idx], predict_g[idx], predict_f[idx]] 207 | for k in range(9): 208 | col = (j * 9 + k) * w 209 | tmp = np.transpose(tmp_list[k], (1, 2, 0)) 210 | # print(tmp.shape) 211 | img[row: row+h, col: col+w] = tmp 212 | 213 | img_file = os.path.join(self.log_dir, '%d_%s.jpg' % (self.step, name)) 214 | io.imsave(img_file, img) 215 | 216 | 217 | 218 | def run_train_val(ckp_name='latest'): 219 | sess = Session() 220 | sess.load_checkpoints(ckp_name) 221 | if sess.multi_gpu : 222 | sess.net = nn.DataParallel(sess.net) 223 | sess.tensorboard(sess.log_name) 224 | sess.tensorboard(sess.val_log_name) 225 | 226 | dt_train = sess.get_dataloader(sess.train_data_path) 227 | dt_val = sess.get_dataloader(sess.train_data_path) 228 | 229 | while sess.step <= 5000: 230 | # sess.sche.step() 231 | sess.opt.param_groups[0]['lr'] = 2 * 5e-3 * (1 - float(sess.step) / 5000 232 | ) ** 0.9 233 | sess.opt.param_groups[1]['lr'] = 5e-3 * (1 - float(sess.step) / 5000 234 | ) ** 0.9 235 | sess.net.train() 236 | sess.net.zero_grad() 237 | 238 | batch_t = next(dt_train) 239 | 240 | # out, loss, losses, predicts 241 | pred_t, loss_t, losses_t = sess.inf_batch(sess.log_name, batch_t) 242 | sess.write(sess.log_name, losses_t) 243 | 244 | loss_t.backward() 245 | 246 | sess.opt.step() 247 | if sess.step % 10 == 0: 248 | sess.net.eval() 249 | batch_v = next(dt_val) 250 | pred_v, loss_v, losses_v = sess.inf_batch(sess.val_log_name, batch_v) 251 | sess.write(sess.val_log_name, losses_v) 252 | if sess.step % int(sess.save_steps / 5) == 0: 253 | sess.save_checkpoints('latest') 254 | if sess.step % int(sess.save_steps / 10) == 0: 255 | sess.save_mask(sess.log_name, [batch_t['image'], batch_t['B'],pred_t]) 256 | if sess.step % 10 == 0: 257 | sess.save_mask(sess.val_log_name, [batch_v['image'], batch_v['B'],pred_v]) 258 | logger.info('save image as step_%d' % sess.step) 259 | if sess.step % (sess.save_steps * 5) == 0: 260 | sess.save_checkpoints('step_%d' % sess.step) 261 | logger.info('save model as step_%d' % sess.step) 262 | sess.step += 1 263 | 264 | 265 | def run_test(ckp_name): 266 | sess = Session() 267 | sess.net.eval() 268 | sess.load_checkpoints(ckp_name,'test') 269 | if sess.multi_gpu : 270 | sess.net = nn.DataParallel(sess.net) 271 | sess.batch_size = 1 272 | sess.shuffle = False 273 | sess.outs = -1 274 | dt = sess.get_dataloader(sess.test_data_path, train_mode=False) 275 | 276 | 277 | input_names = open(sess.test_data_path+'SBU.txt').readlines() 278 | widgets = [progressbar.Percentage(),progressbar.Bar(),progressbar.ETA()] 279 | bar = progressbar.ProgressBar(widgets=widgets,maxval=len(dt)).start() 280 | for i, batch in enumerate(dt): 281 | 282 | pred = sess.inf_batch('test', batch) 283 | image = I.open(sess.test_data_path+input_names[i].split(' ')[0]).convert('RGB') 284 | final = I.fromarray((pred[-1].cpu().data * 255).numpy().astype('uint8')[0,0,:,:]) 285 | final = np.array(final.resize(image.size)) 286 | final_crf = crf_refine(np.array(image),final) 287 | ensure_dir('./results') 288 | io.imsave('./results/'+input_names[i].split(' ')[0].split('/')[1][:-3]+'png',final_crf) 289 | bar.update(i+1) 290 | 291 | 292 | 293 | 294 | if __name__ == '__main__': 295 | parser = 
argparse.ArgumentParser() 296 | parser.add_argument('-a', '--action', default='test') 297 | parser.add_argument('-m', '--model', default='latest') 298 | 299 | args = parser.parse_args(sys.argv[1:]) 300 | 301 | if args.action == 'train': 302 | run_train_val(args.model) 303 | elif args.action == 'test': 304 | run_test(args.model) 305 | 306 | -------------------------------------------------------------------------------- /main_sr.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import argparse 5 | import numpy as np 6 | import logging 7 | import time 8 | from DSC_sr import DSC 9 | import torch 10 | from torch import nn 11 | from torch.nn import MSELoss 12 | from torch import optim 13 | import torch.nn.functional as F 14 | from torch.optim.lr_scheduler import MultiStepLR 15 | from torch.utils.data import DataLoader 16 | 17 | from tensorboardX import SummaryWriter 18 | import skimage.measure as ms 19 | import progressbar 20 | import skimage.io as io 21 | import PIL.Image as I 22 | from dataset_sr import TrainValDataset, TestDataset 23 | import shutil 24 | from utils import MyWcploss, ShadowRemovalL1Loss 25 | 26 | 27 | logger = logging.getLogger('train') 28 | logger.setLevel(logging.INFO) 29 | ch = logging.StreamHandler() 30 | ch.setLevel(logging.INFO) 31 | formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') 32 | ch.setFormatter(formatter) 33 | logger.addHandler(ch) 34 | torch.cuda.manual_seed_all(2018) 35 | torch.manual_seed(2018) 36 | torch.backends.cudnn.benchmark = True 37 | torch.cuda.set_device(1) 38 | 39 | iter_num = 320000 #160000 40 | 41 | def ensure_dir(dir_path): 42 | if not os.path.isdir(dir_path): 43 | os.makedirs(dir_path) 44 | 45 | 46 | class L2Loss(nn.Module): 47 | def __init__(self): 48 | super(L2Loss, self).__init__() 49 | 50 | def forward(self, predicted, target): 51 | return torch.mean((predicted - target) ** 2) 52 | 53 | class Session: 54 | def __init__(self): 55 | self.device = torch.device("cuda") 56 | 57 | # SRD 58 | self.log_dir = './SRD512_logdir' 59 | self.model_dir = './SRD512_model' 60 | # ensure_dir(self.log_dir) 61 | ensure_dir(self.model_dir) 62 | self.log_name = 'train_SRD512_alpha_1' 63 | self.val_log_name = 'val_SRD512_alpha_1' 64 | logger.info('set log dir as %s' % self.log_dir) 65 | logger.info('set model dir as %s' % self.model_dir) 66 | self.test_data_path = '/home/zhxing/Datasets/SRD_inpaint4shadow_fix/' 67 | # self.test_data_path = '/home/zhxing/Datasets/DESOBA_xvision/' 68 | self.train_data_path = '/home/zhxing/Datasets/SRD_inpaint4shadow_fix/train_dsc.txt' 69 | 70 | # ISTD 71 | # self.log_dir = './ISTD+512_logdir' 72 | # self.model_dir = './ISTD+512_model' 73 | # ensure_dir(self.log_dir) 74 | # ensure_dir(self.model_dir) 75 | # self.log_name = 'train_ISTD+512_alpha_1' 76 | # self.val_log_name = 'val_ISTD+512_alpha_1' 77 | # logger.info('set log dir as %s' % self.log_dir) 78 | # logger.info('set model dir as %s' % self.model_dir) 79 | # self.test_data_path = '/home/zhxing/Datasets/ISTD+/' 80 | # self.train_data_path = '/home/zhxing/Datasets/ISTD+/train_dsc.txt' 81 | 82 | self.multi_gpu = False 83 | self.net = DSC().to(self.device) 84 | self.l2_loss = L2Loss().to(self.device) 85 | 86 | 87 | 88 | self.step = 0 89 | self.save_steps = 20000 90 | self.num_workers = 16 91 | self.batch_size = 2 # 92 | self.writers = {} 93 | self.dataloaders = {} 94 | self.shuffle = True 95 | self.opt = optim.Adam([ 96 | {'params': [param for name, param in 
self.net.named_parameters() if name[-4:] == 'bias'], 97 | 'lr': 1e-5}, # Adjust learning rate for Adam 98 | {'params': [param for name, param in self.net.named_parameters() if name[-4:] != 'bias'], 99 | 'lr': 1e-5, 'weight_decay': 5e-4} 100 | ], betas=(0.9, 0.999)) # Typical beta values for Adam 101 | 102 | def tensorboard(self, name): 103 | self.writers[name] = SummaryWriter(os.path.join(self.log_dir, name + '.events')) 104 | return self.writers[name] 105 | 106 | def write(self, name, out): 107 | for k, v in out.items(): 108 | self.writers[name].add_scalar(k, v, self.step) 109 | 110 | out['lr'] = self.opt.param_groups[0]['lr'] 111 | out['step'] = self.step 112 | outputs = [ 113 | "{}:{:.4g}".format(k, v) 114 | for k, v in out.items() 115 | ] 116 | logger.info(name + '--' + ' '.join(outputs)) 117 | 118 | def get_dataloader(self, dataset_name, train_mode=True): 119 | dataset = { 120 | True: TrainValDataset, 121 | False: TestDataset, 122 | }[train_mode](dataset_name) 123 | self.dataloaders[dataset_name] = \ 124 | DataLoader(dataset, batch_size=self.batch_size, 125 | shuffle=self.shuffle, num_workers=self.num_workers, drop_last=True) 126 | if train_mode: 127 | return iter(self.dataloaders[dataset_name]) 128 | else: 129 | return self.dataloaders[dataset_name] 130 | 131 | def save_checkpoints(self, name): 132 | ckp_path = os.path.join(self.model_dir, name) 133 | if self.multi_gpu : 134 | obj = { 135 | 'net': self.net.module.state_dict(), 136 | 'clock': self.step, 137 | 'opt': self.opt.state_dict(), 138 | } 139 | else: 140 | obj = { 141 | 'net': self.net.state_dict(), 142 | 'clock': self.step, 143 | 'opt': self.opt.state_dict(), 144 | } 145 | torch.save(obj, ckp_path) 146 | 147 | def load_checkpoints(self, name,mode='train'): 148 | ckp_path = os.path.join(self.model_dir, name) 149 | try: 150 | obj = torch.load(ckp_path) 151 | except FileNotFoundError: 152 | return 153 | self.net.load_state_dict(obj['net']) 154 | if mode == 'train': 155 | self.step = obj['clock'] 156 | if mode == 'test': 157 | path = '../realtest/{}/'.format(self.model_dir[2:]) 158 | ensure_dir(path) 159 | shutil.copy(ckp_path,path) 160 | 161 | 162 | def inf_batch(self, name, batch): 163 | if name == 'test': 164 | torch.set_grad_enabled(False) 165 | O, B = batch['O'], batch['B'] 166 | O, B = O.to(self.device), B.to(self.device) 167 | 168 | predicts = self.net(O, batch['image_ori']) 169 | predict_4, predict_3, predict_2, predict_1, predict_0, predict_g, predict_f = predicts 170 | 171 | if name == 'test': 172 | # No sigmoid for shadow removal task 173 | return predicts 174 | 175 | # Calculate losses without sigmoid 176 | loss_4 = self.l2_loss(predict_4, B) 177 | loss_3 = self.l2_loss(predict_3, B) 178 | loss_2 = self.l2_loss(predict_2, B) 179 | loss_1 = self.l2_loss(predict_1, B) 180 | loss_0 = self.l2_loss(predict_0, B) 181 | loss_g = self.l2_loss(predict_g, B) 182 | loss_f = self.l2_loss(predict_f, B) 183 | 184 | loss = loss_4 + loss_3 + loss_2 + loss_1 + loss_0 + loss_g + loss_f 185 | 186 | # Log the losses 187 | losses = { 188 | 'loss_all': loss.item(), 189 | 'loss_0': loss_0.item(), 190 | 'loss_1': loss_1.item(), 191 | 'loss_2': loss_2.item(), 192 | 'loss_3': loss_3.item(), 193 | 'loss_4': loss_4.item(), 194 | 'loss_g': loss_g.item(), 195 | 'loss_f': loss_f.item() 196 | } 197 | 198 | return predicts, loss, losses 199 | 200 | 201 | def save_mask(self, name, img_lists): 202 | data, label, predicts = img_lists 203 | 204 | # Convert the data and labels from LAB back to RGB, making sure they are scaled and cast correctly 205 | data = (data.numpy().transpose(0, 2, 3, 1) * 255).astype('uint8') # assuming the data layout is
(N, C, H, W) 206 | # label = (label.numpy().transpose(0, 2, 3, 1) * 255).astype('uint8') # assuming the label layout is (N, C, H, W) 207 | label = (label.numpy().transpose(0, 2, 3, 1)).astype('uint8') # assuming the label layout is (N, C, H, W) 208 | 209 | # Convert the predictions to numpy arrays, making sure they are 3-channel images scaled to 255 210 | # predicts = [ 211 | # (predict.cpu().data.numpy().transpose(0, 2, 3, 1) * 255).astype('float32') # assuming the prediction layout is (N, C, H, W) 212 | # for predict in predicts 213 | # ] 214 | 215 | predicts = [ 216 | (predict.cpu().data.numpy().transpose(0, 2, 3, 1)).astype('float32') # assuming the prediction layout is (N, C, H, W) 217 | for predict in predicts 218 | ] 219 | 220 | # LAB-to-RGB conversion (see the round-trip sketch at the end of this document) 221 | def lab_to_rgb(lab_img): 222 | lab_img = lab_img.astype('float32') 223 | lab_img[:, :, 0] = lab_img[:, :, 0] * 100 / 255.0 # L channel range [0, 100] 224 | lab_img[:, :, 1] = lab_img[:, :, 1] - 128 # a channel range [-128, 127] 225 | lab_img[:, :, 2] = lab_img[:, :, 2] - 128 # b channel range [-128, 127] 226 | lab_img = cv2.cvtColor(lab_img, cv2.COLOR_LAB2RGB) 227 | lab_img = np.clip(lab_img * 255, 0, 255).astype('uint8') 228 | return lab_img 229 | 230 | data = np.array([lab_to_rgb(img) for img in data]) 231 | label = np.array([lab_to_rgb(img) for img in label]) 232 | predicts = [np.array([lab_to_rgb(img) for img in predict]) for predict in predicts] 233 | 234 | h, w = predicts[-1].shape[1:3] 235 | num_preds = len(predicts) 236 | gen_num = (2, 1) if len(data) > 1 else (1, 1) 237 | 238 | # Prepare the output image 239 | img = np.zeros((gen_num[0] * h, gen_num[1] * (2 + num_preds) * w, 3), dtype='uint8') 240 | 241 | for i in range(gen_num[0]): 242 | row = i * h 243 | for j in range(gen_num[1]): 244 | idx = i * gen_num[1] + j 245 | tmp_list = [data[idx], label[idx]] + [predict[idx] for predict in predicts] 246 | 247 | for k in range(len(tmp_list)): 248 | col = (j * (2 + num_preds) + k) * w 249 | tmp = tmp_list[k] 250 | img[row: row + h, col: col + w] = tmp 251 | 252 | # Save the image 253 | img_file = os.path.join(self.log_dir, f'{self.step}_{name}.jpg') 254 | io.imsave(img_file, img) 255 | 256 | 257 | 258 | 259 | def run_train_val(ckp_name='latest'): 260 | sess = Session() 261 | sess.load_checkpoints(ckp_name) 262 | if sess.multi_gpu : 263 | sess.net = nn.DataParallel(sess.net) 264 | sess.tensorboard(sess.log_name) 265 | sess.tensorboard(sess.val_log_name) 266 | 267 | dt_train = sess.get_dataloader(sess.train_data_path) 268 | dt_val = sess.get_dataloader(sess.train_data_path) 269 | 270 | while sess.step <= iter_num: 271 | # sess.sche.step() 272 | 273 | sess.opt.param_groups[0]['lr'] = 2 * 5e-4 * (1 - float(sess.step) / iter_num 274 | ) ** 0.9 275 | sess.opt.param_groups[1]['lr'] = 5e-4 * (1 - float(sess.step) / iter_num 276 | ) ** 0.9 277 | 278 | 279 | sess.net.train() 280 | sess.net.zero_grad() 281 | 282 | try: 283 | batch_t = next(dt_train) 284 | except StopIteration: 285 | dt_train = iter(sess.get_dataloader(sess.train_data_path)) 286 | batch_t = next(dt_train) 287 | 288 | # out, loss, losses, predicts 289 | pred_t, loss_t, losses_t = sess.inf_batch(sess.log_name, batch_t) 290 | sess.write(sess.log_name, losses_t) 291 | 292 | loss_t.backward() 293 | 294 | sess.opt.step() 295 | if sess.step % 10 == 0: 296 | sess.net.eval() 297 | batch_v = next(dt_val) 298 | pred_v, loss_v, losses_v = sess.inf_batch(sess.val_log_name, batch_v) 299 | sess.write(sess.val_log_name, losses_v) 300 | if sess.step % int(sess.save_steps / 5) == 0: 301 | sess.save_checkpoints('latest') 302 | if sess.step % int(sess.save_steps / 10) == 0: 303 | sess.save_mask(sess.log_name, [batch_t['image'], batch_t['B'],pred_t]) 304 | if sess.step % 10 == 0: 305 |
sess.save_mask(sess.val_log_name, [batch_v['image'], batch_v['B'],pred_v]) 306 | logger.info('save image as step_%d' % sess.step) 307 | if sess.step % (sess.save_steps * 5) == 0: 308 | sess.save_checkpoints('step_%d' % sess.step) 309 | logger.info('save model as step_%d' % sess.step) 310 | sess.step += 1 311 | sess.save_checkpoints('final') 312 | 313 | # for the run_test function 314 | def ensure_dir(path): 315 | if not os.path.exists(path): 316 | os.makedirs(path) 317 | 318 | import os 319 | import numpy as np 320 | import torch.nn as nn 321 | import progressbar 322 | from PIL import Image 323 | import cv2 324 | 325 | def run_test(ckp_name): 326 | sess = Session() 327 | sess.net.eval() 328 | sess.load_checkpoints(ckp_name, 'test') 329 | 330 | num_params = sum(p.numel() for p in sess.net.parameters()) 331 | print(f'Number of model parameters: {num_params}') 332 | 333 | if sess.multi_gpu: 334 | sess.net = nn.DataParallel(sess.net) 335 | 336 | sess.batch_size = 1 337 | sess.shuffle = False 338 | sess.outs = -1 339 | dt = sess.get_dataloader(sess.test_data_path, train_mode=False) 340 | 341 | input_names = open(os.path.join(sess.test_data_path, 'test_dsc.txt')).readlines() # "test.txt" 342 | widgets = [progressbar.Percentage(), progressbar.Bar(), progressbar.ETA()] 343 | bar = progressbar.ProgressBar(widgets=widgets, maxval=len(dt)).start() 344 | 345 | for i, batch in enumerate(dt): 346 | pred = sess.inf_batch('test', batch) 347 | saved_pred = pred[-1] # tensor of shape (1, 3, 512, 512) with values in [-1, 1]; it should be scaled back to LAB space and then converted to RGB space to save the image 348 | 349 | # Scale the prediction to LAB space 350 | saved_pred = (saved_pred.cpu().data.numpy().transpose(0, 2, 3, 1)).astype('float32') # (N, C, H, W) to (N, H, W, C) 351 | 352 | 353 | # LAB to RGB conversion 354 | def lab_to_rgb(lab_img): 355 | lab_img = lab_img.astype('float32') 356 | lab_img[:, :, 0] = lab_img[:, :, 0] * 100 / 255.0 # L channel range [0, 100] 357 | lab_img[:, :, 1] = lab_img[:, :, 1] - 128 # a channel range [-128, 127] 358 | lab_img[:, :, 2] = lab_img[:, :, 2] - 128 # b channel range [-128, 127] 359 | lab_img = cv2.cvtColor(lab_img, cv2.COLOR_LAB2RGB) 360 | lab_img = np.clip(lab_img * 255, 0, 255).astype('uint8') 361 | return lab_img 362 | 363 | saved_pred_rgb = np.array([lab_to_rgb(img) for img in saved_pred]) 364 | 365 | # Save the image 366 | image_name = input_names[i].strip().split('/')[-1] 367 | output_path = os.path.join('./test_sr/SRD512_DESOBA', image_name) 368 | # output_path = os.path.join('./test_sr/ISTD+512', image_name) 369 | os.makedirs(os.path.dirname(output_path), exist_ok=True) 370 | Image.fromarray(saved_pred_rgb[0]).save(output_path) 371 | 372 | bar.update(i + 1) 373 | 374 | 375 | bar.finish() 376 | 377 | 378 | 379 | 380 | 381 | if __name__ == '__main__': 382 | parser = argparse.ArgumentParser() 383 | parser.add_argument('-a', '--action', default='test') 384 | parser.add_argument('-m', '--model', default='latest') 385 | 386 | args = parser.parse_args(sys.argv[1:]) 387 | 388 | if args.action == 'train': 389 | run_train_val(args.model) 390 | elif args.action == 'test': 391 | run_test(args.model) 392 | 393 | -------------------------------------------------------------------------------- /misc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pydensecrf.densecrf as dcrf 3 | 4 | def _sigmoid(x): 5 | return 1 / (1 + np.exp(-x)) 6 | 7 | 8 | def crf_refine(img, annos): 9 | assert img.dtype == np.uint8 10 | assert annos.dtype ==
np.uint8 11 | assert img.shape[:2] == annos.shape 12 | 13 | # img and annos should be np array with data type uint8 14 | 15 | EPSILON = 1e-8 16 | 17 | M = 2 # salient or not 18 | tau = 1.05 19 | # Setup the CRF model 20 | d = dcrf.DenseCRF2D(img.shape[1], img.shape[0], M) 21 | 22 | anno_norm = annos / 255. 23 | 24 | n_energy = -np.log((1.0 - anno_norm + EPSILON)) / (tau * _sigmoid(1 - anno_norm)) 25 | p_energy = -np.log(anno_norm + EPSILON) / (tau * _sigmoid(anno_norm)) 26 | 27 | U = np.zeros((M, img.shape[0] * img.shape[1]), dtype='float32') 28 | U[0, :] = n_energy.flatten() 29 | U[1, :] = p_energy.flatten() 30 | 31 | d.setUnaryEnergy(U) 32 | 33 | d.addPairwiseGaussian(sxy=3, compat=3) 34 | d.addPairwiseBilateral(sxy=60, srgb=5, rgbim=img, compat=5) 35 | 36 | # Do the inference 37 | infer = np.array(d.inference(1)).astype('float32') 38 | res = infer[1, :] 39 | 40 | res = res * 255 41 | res = res.reshape(img.shape[:2]) 42 | return res.astype('uint8') 43 | 44 | -------------------------------------------------------------------------------- /randomcrop.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # from __future__ import division 3 | import torch 4 | import math 5 | import random 6 | from PIL import Image, ImageOps, ImageEnhance 7 | import numbers 8 | import torchvision.transforms.functional as F 9 | import numpy as np 10 | 11 | class RandomCrop(object): 12 | """Crop the given PIL Image at a random location. 13 | 14 | Args: 15 | size (sequence or int): Desired output size of the crop. If size is an 16 | int instead of sequence like (h, w), a square crop (size, size) is 17 | made. 18 | padding (int or sequence, optional): Optional padding on each border 19 | of the image. Default is 0, i.e no padding. If a sequence of length 20 | 4 is provided, it is used to pad left, top, right, bottom borders 21 | respectively. 22 | pad_if_needed (boolean): It will pad the image if smaller than the 23 | desired size to avoid raising an exception. 24 | """ 25 | 26 | def __init__(self, size, padding=0, pad_if_needed=False): 27 | if isinstance(size, numbers.Number): 28 | self.size = (int(size), int(size)) 29 | else: 30 | self.size = size 31 | self.padding = padding 32 | self.pad_if_needed = pad_if_needed 33 | 34 | @staticmethod 35 | def get_params(img, output_size): 36 | """Get parameters for ``crop`` for a random crop. 37 | 38 | Args: 39 | img (PIL Image): Image to be cropped. 40 | output_size (tuple): Expected output size of the crop. 41 | 42 | Returns: 43 | tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. 44 | """ 45 | h,w,_ = img.shape 46 | th, tw = output_size 47 | if w == tw and h == th: 48 | return 0, 0, h, w 49 | 50 | i = random.randint(0, h - th) 51 | j = random.randint(0, w - tw) 52 | return i, j, th, tw 53 | 54 | def __call__(self, img,img_gt): 55 | """ 56 | Args: 57 | img (PIL Image): Image to be cropped. 58 | 59 | Returns: 60 | PIL Image: Cropped image. 
61 | """ 62 | if self.padding > 0: 63 | img = F.pad(img, self.padding) 64 | 65 | # pad the width if needed 66 | if self.pad_if_needed and img.size[0] < self.size[1]: 67 | img = F.pad(img, (int((1 + self.size[1] - img.size[0]) / 2), 0)) 68 | # pad the height if needed 69 | if self.pad_if_needed and img.size[1] < self.size[0]: 70 | img = F.pad(img, (0, int((1 + self.size[0] - img.size[1]) / 2))) 71 | 72 | i, j, h, w = self.get_params(img, self.size) 73 | 74 | return img[i:i+self.size[0],j:j+self.size[1],:],img_gt[i:i+self.size[0],j:j+self.size[1],:] 75 | # return F.crop(img, i, j, h, w),F.crop(img_gt, i, j, h, w) 76 | 77 | def __repr__(self): 78 | return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding) 79 | 80 | 81 | 82 | 83 | 84 | class RandomResizedCrop(object): 85 | """Crop the given PIL Image to random size and aspect ratio. 86 | 87 | A crop of random size (default: 0.8 to 1.0) of the original size and a random 88 | aspect ratio (default: 3/4 to 4/3) of the original aspect ratio is made. This crop 89 | is finally resized to the given size. 90 | This is popularly used to train the Inception networks. 91 | 92 | Args: 93 | size: expected output size of each edge 94 | scale: range of size of the origin size cropped 95 | ratio: range of aspect ratio of the origin aspect ratio cropped 96 | interpolation: Default: PIL.Image.BICUBIC 97 | """ 98 | 99 | def __init__(self, size, scale=(0.8, 1), ratio=(3/4., 4/3), interpolation=Image.BICUBIC): 100 | self.size = (size, size) 101 | self.interpolation = interpolation 102 | self.scale = scale 103 | self.ratio = ratio 104 | 105 | @staticmethod 106 | def get_params(img, scale, ratio): 107 | """Get parameters for ``crop`` for a random sized crop. 108 | 109 | Args: 110 | img (PIL Image): Image to be cropped. 111 | scale (tuple): range of size of the origin size cropped 112 | ratio (tuple): range of aspect ratio of the origin aspect ratio cropped 113 | 114 | Returns: 115 | tuple: params (i, j, h, w) to be passed to ``crop`` for a random 116 | sized crop. 117 | """ 118 | for attempt in range(10): 119 | area = img.size[0] * img.size[1] 120 | target_area = random.uniform(*scale) * area 121 | aspect_ratio = random.uniform(*ratio) 122 | 123 | w = int(round(math.sqrt(target_area * aspect_ratio))) 124 | h = int(round(math.sqrt(target_area / aspect_ratio))) 125 | 126 | if random.random() < 0.5: 127 | w, h = h, w 128 | 129 | if w <= img.size[0] and h <= img.size[1]: 130 | i = random.randint(0, img.size[1] - h) 131 | j = random.randint(0, img.size[0] - w) 132 | return i, j, h, w 133 | 134 | # Fallback 135 | w = min(img.size[0], img.size[1]) 136 | i = (img.size[1] - w) // 2 137 | j = (img.size[0] - w) // 2 138 | return i, j, w, w 139 | 140 | def __call__(self, img,img_gt): 141 | """ 142 | Args: 143 | img (PIL Image): Image to be cropped and resized. 144 | 145 | Returns: 146 | PIL Image: Randomly cropped and resized image.
147 | """ 148 | i, j, h, w = self.get_params(img, self.scale, self.ratio) 149 | return (F.resized_crop(img, i, j, h, w, self.size, self.interpolation),F.resized_crop(img_gt, i, j, h, w, self.size, self.interpolation)) 150 | 151 | def __repr__(self): 152 | interpolate_str = str(self.interpolation) # _pil_interpolation_to_str is not defined in this file 153 | format_string = self.__class__.__name__ + '(size={0}'.format(self.size) 154 | format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale)) 155 | format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio)) 156 | format_string += ', interpolation={0})'.format(interpolate_str) 157 | return format_string 158 | class RandomRotation(object): 159 | """Rotate the image by angle. 160 | 161 | Args: 162 | degrees (sequence or float or int): Range of degrees to select from. 163 | If degrees is a number instead of sequence like (min, max), the range of degrees 164 | will be (-degrees, +degrees). 165 | resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional): 166 | An optional resampling filter. 167 | See http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#filters 168 | If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST. 169 | expand (bool, optional): Optional expansion flag. 170 | If true, expands the output to make it large enough to hold the entire rotated image. 171 | If false or omitted, make the output image the same size as the input image. 172 | Note that the expand flag assumes rotation around the center and no translation. 173 | center (2-tuple, optional): Optional center of rotation. 174 | Origin is the upper left corner. 175 | Default is the center of the image. 176 | """ 177 | 178 | def __init__(self, degrees, resample=Image.BICUBIC, expand=1, center=None): 179 | if isinstance(degrees, numbers.Number): 180 | if degrees < 0: 181 | raise ValueError("If degrees is a single number, it must be positive.") 182 | self.degrees = (-degrees, degrees) 183 | else: 184 | if len(degrees) != 2: 185 | raise ValueError("If degrees is a sequence, it must be of len 2.") 186 | self.degrees = degrees 187 | 188 | self.resample = resample 189 | self.expand = expand 190 | self.center = center 191 | 192 | @staticmethod 193 | def get_params(degrees): 194 | """Get parameters for ``rotate`` for a random rotation. 195 | 196 | Returns: 197 | sequence: params to be passed to ``rotate`` for random rotation. 198 | """ 199 | angle = np.random.uniform(degrees[0], degrees[1]) 200 | 201 | return angle 202 | 203 | def __call__(self, img,img_gt): 204 | """ 205 | img (PIL Image): Image to be rotated. 206 | 207 | Returns: 208 | PIL Image: Rotated image.
209 | """ 210 | 211 | angle = self.get_params(self.degrees) 212 | 213 | return (F.rotate(img, angle, self.resample, self.expand, self.center),F.rotate(img_gt, angle, self.resample, self.expand, self.center)) 214 | 215 | def __repr__(self): 216 | return self.__class__.__name__ + '(degrees={0})'.format(self.degrees) 217 | 218 | 219 | class RandomHorizontallyFlip(object): 220 | def __call__(self, img, mask): 221 | if random.random() < 0.5: 222 | return img.transpose(Image.FLIP_LEFT_RIGHT), mask.transpose(Image.FLIP_LEFT_RIGHT) 223 | return img, mask 224 | 225 | class RandomVerticallyFlip(object): 226 | def __call__(self, img, mask): 227 | if random.random() < 0.5: 228 | return img.transpose(Image.FLIP_TOP_BOTTOM), mask.transpose(Image.FLIP_TOP_BOTTOM) 229 | return img, mask -------------------------------------------------------------------------------- /tensorboard.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | function rand() { 3 | min=$1 4 | max=$(($2-$min+1)) 5 | num=$(($RANDOM+1000000000000)) 6 | echo $(($num%$max+$min)) 7 | } 8 | 9 | rnd=$(rand 3000 12000) 10 | tensorboard --logdir ./logdir --host 0.0.0.0 --port ${rnd} --reload_interval 3 # use the random port generated above 11 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | 6 | 7 | class MyWcploss(nn.Module): 8 | def __init__(self): 9 | super(MyWcploss, self).__init__() 10 | 11 | 12 | def forward(self, pred, gt): 13 | epsilon = 1e-10 14 | sigmoid_pred = torch.sigmoid(pred) 15 | count_pos = torch.sum(gt)*1.0+epsilon 16 | count_neg = torch.sum(1.-gt)*1.0 17 | beta = count_neg/count_pos 18 | beta_back = count_pos / (count_pos + count_neg) 19 | 20 | 21 | bce1 = nn.BCEWithLogitsLoss(pos_weight=beta) 22 | loss = beta_back*bce1(pred, gt) 23 | 24 | return loss 25 | --------------------------------------------------------------------------------
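Usage sketch for the SBU loader in dataset.py, assuming the list file sits at the path main.py hard-codes and that each of its lines holds an "<image_path> <mask_path>" pair relative to the dataset root:

from torch.utils.data import DataLoader
from dataset import TrainValDataset

ds = TrainValDataset('../SBU-shadow/SBUTrain4KRecoveredSmall/SBU.txt')
loader = DataLoader(ds, batch_size=4, shuffle=True, num_workers=0, drop_last=True)
batch = next(iter(loader))
# 'O' is the normalized input, 'B' the shadow mask, 'image' the raw RGB copy
print(batch['O'].shape, batch['B'].shape, batch['image'].shape)
# expected: torch.Size([4, 3, 400, 400]) torch.Size([4, 1, 400, 400]) torch.Size([4, 3, 400, 400])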
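The lab_to_rgb helpers in main_sr.py invert OpenCV's 8-bit LAB encoding, where L is stored as L*255/100 and a/b are offset by +128. A self-contained round-trip sketch on synthetic data (the repo itself reads BGR via cv2.imread and uses COLOR_BGR2LAB; RGB is used here for brevity):

import cv2
import numpy as np

rgb = (np.random.rand(8, 8, 3) * 255).astype('uint8')
lab_u8 = cv2.cvtColor(rgb, cv2.COLOR_RGB2LAB)   # uint8 LAB: L*255/100, a+128, b+128

lab = lab_u8.astype('float32')
lab[:, :, 0] = lab[:, :, 0] * 100 / 255.0       # L back to [0, 100]
lab[:, :, 1] = lab[:, :, 1] - 128               # a back to [-128, 127]
lab[:, :, 2] = lab[:, :, 2] - 128               # b back to [-128, 127]
rgb_back = np.clip(cv2.cvtColor(lab, cv2.COLOR_LAB2RGB) * 255, 0, 255).astype('uint8')
# round-trip error stays within a few counts of quantization noise
print(np.abs(rgb_back.astype(int) - rgb.astype(int)).max())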
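A shape-level sketch of calling the custom IRNN op from irnn.py directly. This assumes a CUDA device, cupy installed, and the .cu kernel sources present in the working directory (the forward pass asserts that every tensor is contiguous); per-direction weights mirror the (C, 1, 1, 1) depthwise-conv weights that Spacial_IRNN passes in:

import torch
from irnn import irnn

C = 8
x = torch.randn(1, C, 16, 16, device='cuda').contiguous()
w = torch.full((C, 1, 1, 1), 1.0, device='cuda').contiguous()  # one scalar per channel
b = torch.zeros(C, device='cuda').contiguous()
# returns four directional recurrence maps, each the same shape as the input
up, right, down, left = irnn.apply(x, w, w, w, w, b, b, b, b)
print(up.shape)  # torch.Size([1, 8, 16, 16])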
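Both training scripts update the two parameter-group learning rates by hand on every step with a polynomial decay. Factored into a helper for clarity (poly_lr is a hypothetical name, not a function the repo defines); the numbers below are main.py's, while main_sr.py uses base 5e-4 over iter_num steps:

def poly_lr(base_lr, step, max_iter, power=0.9):
    # lr = base_lr * (1 - step / max_iter) ** power, as in run_train_val
    return base_lr * (1 - float(step) / max_iter) ** power

# biases run at twice the base rate, mirroring the two param groups
for step in (0, 2500, 5000):
    print(step, poly_lr(2 * 5e-3, step, 5000), poly_lr(5e-3, step, 5000))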
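Hedged usage sketch for misc.crf_refine, which run_test in main.py applies to the resized sigmoid output: it requires pydensecrf, both inputs must be uint8, and the soft mask must match the image's height and width (random arrays stand in for a real image/prediction pair here):

import numpy as np
from misc import crf_refine

img = (np.random.rand(400, 400, 3) * 255).astype(np.uint8)        # RGB image
soft_mask = (np.random.rand(400, 400) * 255).astype(np.uint8)     # sigmoid output scaled to [0, 255]
refined = crf_refine(img, soft_mask)   # uint8 mask sharpened along image edges
print(refined.shape, refined.dtype)    # (400, 400) uint8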
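Finally, a restatement of what MyWcploss in utils.py computes: a BCE-with-logits loss whose positive pixels are up-weighted by the negative/positive count ratio, with the whole term rescaled by the positive fraction. The helper name weighted_bce is illustrative only; it mirrors the class line for line:

import torch
from torch import nn

def weighted_bce(pred, gt, eps=1e-10):
    # positives up-weighted by (#neg / #pos); whole loss rescaled by the
    # positive fraction, exactly as MyWcploss.forward does
    count_pos = gt.sum() * 1.0 + eps
    count_neg = (1.0 - gt).sum() * 1.0
    beta = count_neg / count_pos
    bce = nn.BCEWithLogitsLoss(pos_weight=beta)
    return count_pos / (count_pos + count_neg) * bce(pred, gt)

pred = torch.randn(2, 1, 8, 8)                 # raw logits from a Predict head
gt = (torch.rand(2, 1, 8, 8) > 0.8).float()    # sparse shadow mask
print(weighted_bce(pred, gt))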