├── README.md
├── models
│   ├── baseU.py
│   ├── baseline.py
│   └── encoder
│       ├── mobile.py
│       ├── resnet.py
│       └── vgg.py
├── setup.json
├── src
│   ├── Eval.py
│   ├── Experiment.py
│   ├── Loader.py
│   ├── Loss.py
│   ├── Metrics.py
│   ├── Score.py
│   ├── Tester.py
│   ├── Trainer.py
│   ├── __init__.py
│   ├── dataSet.py
│   └── utils.py
├── test.py
└── train.py

/README.md:
--------------------------------------------------------------------------------
1 | # ITSD-pytorch
2 | Code for the CVPR 2020 [paper](https://openaccess.thecvf.com/content_CVPR_2020/papers/Zhou_Interactive_Two-Stream_Decoder_for_Accurate_and_Fast_Saliency_Detection_CVPR_2020_paper.pdf) "Interactive Two-Stream Decoder for Accurate and Fast Saliency Detection"
3 | 
4 | Saliency maps can be downloaded at: VGG ([Baidu Yun](https://pan.baidu.com/s/1AdkLgfOK1jwgcqk06zwOwQ) \[gf1i\]), Resnet ([Baidu Yun](https://pan.baidu.com/s/1Gu9RpKuMdZrj1iJvh4A2og) \[sanf\])
5 | 
6 | ## This code is somewhat outdated; please use the implementation in our [SOD benchmark](https://github.com/moothes/SALOD).
7 | ## We release our new works on Unsupervised Salient Object Detection (USOD) at [A2S-USOD](https://github.com/moothes/A2S-USOD) and [A2S-v2](https://github.com/moothes/A2S-v2).
8 | 
9 | ## Prerequisites
10 | 
11 | - [Pytorch 1.0.0](http://pytorch.org/)
12 | - [torchvision 0.2.1](http://pytorch.org/)
13 | - Thop
14 | - Progress
15 | 
16 | ## Usage:
17 | Official ImageNet-pretrained weights can be downloaded at [Resnet50](https://download.pytorch.org/models/resnet50-19c8e357.pth) and [VGG16](https://download.pytorch.org/models/vgg16-397923af.pth).
18 | 
19 | Our models: Google Drive ([ResNet50](https://drive.google.com/file/d/1qcZOOL7b7DJ0VbtXK0MvDsOTGNSsmXDm/view) and [VGG16](https://drive.google.com/file/d/1zyDqrjIacqK83pyzbq90rys9m732n28j/view)) or Baidu Disk ([Resnet50](https://pan.baidu.com/s/1qKSnPqbNs4--PwB5fA4E-g) [y55w] and [VGG16](https://pan.baidu.com/s/1ceI8lReLozh2WRsylszQgA) [kehh])
20 | 
21 | Please refer to this repo for results evaluation: [SalMetric](https://github.com/Andrew-Qibin/SalMetric).
22 | 
23 | ### Training:
24 | ```bash
25 | python3 train.py --sub=[job_name] --ids=[gpus] --model=[vgg/resnet]
26 | ```
27 | 
28 | ### Testing:
29 | ```bash
30 | 
31 | mv path_to_model ./save/[vgg/resnet]/[job_name]/final.pkl # if testing the provided models
32 | python3 test.py --sub=[job_name] --ids=[gpus] --model=[vgg/resnet]
33 | ```
34 | 
35 | 
36 | ## Contact
37 | If you have any questions, feel free to contact me via: `mootheszhou@gmail.com`.
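## Pretrained weights layout

The encoders load ImageNet weights from a local `../PretrainModel` directory rather than downloading them (see the hard-coded paths in `models/encoder/resnet.py`, `models/encoder/vgg.py` and `models/encoder/mobile.py`). A minimal layout, assuming the repository is cloned next to that directory and uses the official weight files linked above, could look like:

```bash
# Hypothetical setup inferred from the loader paths in models/encoder/*.py;
# the file names must match what the loaders expect.
mkdir -p ../PretrainModel
mv resnet50-19c8e357.pth ../PretrainModel/resnet50.pth   # loaded by resnet.py
mv vgg16-397923af.pth ../PretrainModel/vgg16.pth         # loaded by vgg.py
```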
38 | 
39 | 
40 | ## Bibtex
41 | ```latex
42 | @InProceedings{Zhou_2020_CVPR,
43 | author = {Zhou, Huajun and Xie, Xiaohua and Lai, Jian-Huang and Chen, Zixuan and Yang, Lingxiao},
44 | title = {Interactive Two-Stream Decoder for Accurate and Fast Saliency Detection},
45 | booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
46 | month = {June},
47 | year = {2020}
48 | }
49 | ```
50 | 
--------------------------------------------------------------------------------
/models/baseU.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, autograd, optim, Tensor, cuda
3 | from torch.nn import functional as F
4 | from torch.autograd import Variable
5 | 
6 | from src import utils
7 | 
8 | #VGG-16
9 | NUM = [3, 2, 2, 1, 1]
10 | 
11 | class sep_conv(nn.Module):
12 |     def __init__(self, In, Out):
13 |         super(sep_conv, self).__init__()
14 |         # depthwise 3x3 conv followed by pointwise 1x1 conv (the two attribute names were swapped in the original)
15 |         self.dw_conv = nn.Conv2d(In, In, kernel_size=3, padding=1, groups=In)
16 |         self.bn1 = nn.BatchNorm2d(In)
17 |         self.pw_conv = nn.Conv2d(In, Out, kernel_size=1)
18 |         self.bn2 = nn.BatchNorm2d(Out)
19 |         self.relu = nn.ReLU(inplace=True)
20 | 
21 |     def forward(self, x):
22 |         x = self.bn1(self.dw_conv(x))
23 |         x = self.relu(self.bn2(self.pw_conv(x)))
24 |         return x
25 | 
26 | 
27 | def get_centers(x, scores, K=3):  # K is currently unused; note that sigmoid_ modifies `scores` in place
28 |     score = scores.sigmoid_()
29 |     centers = torch.sum(torch.sum(score * x, dim=-1, keepdim=True), dim=-2, keepdim=True)
30 |     weights = torch.sum(torch.sum(score, dim=-1, keepdim=True), dim=-2, keepdim=True)
31 |     centers = centers / weights
32 |     return centers
33 | 
34 | def cls_atten(x, heat):
35 |     centers = get_centers(x, heat, 1)
36 |     centers = centers.view(centers.size(0), centers.size(1), 1, 1).expand_as(x)
37 |     cos_map = F.cosine_similarity(x, centers)
38 |     #print(x.size(), centers.size())
39 |     return cos_map.unsqueeze(1)
40 | 
41 | def gen_convs(In, Out, num=1):
42 |     for i in range(num):
43 |         yield nn.Conv2d(In, In, 3, padding=1)
44 |         yield nn.ReLU(inplace=True)
45 | 
46 | def gen_fuse(In, Out):
47 |     yield nn.Conv2d(In, Out, 3, padding=1)
48 |     yield nn.GroupNorm(Out//2, Out)
49 |     yield nn.ReLU(inplace=True)
50 | 
51 | def cp(x, n=2):
52 |     batch, cat, w, h = x.size()
53 |     xn = x.view(batch, cat//n, n, w, h)
54 |     xn = torch.max(xn, dim=2)[0]
55 |     return xn
56 | 
57 | def gen_final(In, Out):
58 |     yield nn.Conv2d(In, Out, 3, padding=1)
59 |     yield nn.ReLU(inplace=True)
60 | 
61 | # ---------------------decode method------------------------------
62 | def decode_conv(layer, c):
63 |     for i in range(4 - layer):
64 |         yield nn.Conv2d(c, c, 3, padding=1)
65 |         yield nn.ReLU(inplace=True)
66 |         #yield sep_conv(c, c)
67 |         yield nn.Upsample(scale_factor=2, mode='bilinear' if i == 2 else 'nearest')
68 | 
69 |     yield nn.Conv2d(c, 8, 3, padding=1)
70 |     yield nn.ReLU(inplace=True)
71 | 
72 | def decode_conv_new(layer, c):
73 |     temp = c
74 |     nc = c
75 | 
76 |     for i in range(4 - layer):
77 |         oc = min(temp, nc)
78 |         nc = temp // 2
79 |         temp = temp // 2 if temp > 16 else 16
80 |         yield nn.Conv2d(oc, nc, 3, padding=1)
81 |         yield nn.ReLU(inplace=True)
82 |         yield nn.Upsample(scale_factor=2, mode='nearest')
83 | 
84 |     yield nn.Conv2d(nc, 8, 3, padding=1)
85 |     yield nn.ReLU(inplace=True)
86 | 
87 | class pred_block(nn.Module):
88 |     def __init__(self, In, Out, up=False):
89 |         super(pred_block, self).__init__()
90 | 
91 |         self.final_conv = nn.Conv2d(In, Out, 3, padding=1)
92 |         self.pr_conv = nn.Conv2d(Out, 4, 3, padding=1)
93 |         self.up = up
94 | 
95 |     def forward(self, X):
96 |         a = 
nn.functional.relu(self.final_conv(X)) 97 | a1 = self.pr_conv(a) 98 | pred = torch.max(a1, dim=1)[0] 99 | if self.up: 100 | a = nn.functional.interpolate(a, scale_factor=2, mode='bilinear') 101 | return [a, pred] 102 | 103 | class res_block(nn.Module): 104 | def __init__(self, cat, layer): 105 | super(res_block, self).__init__() 106 | 107 | if layer: 108 | self.conv4 = nn.Sequential(*list(gen_fuse(cat, cat // 2))) 109 | 110 | self.convs = nn.Sequential(*list(gen_convs(cat, cat, NUM[layer]))) 111 | self.conv2 = nn.Sequential(*list(gen_fuse(cat, cat//2))) 112 | 113 | self.final = nn.Sequential(*list(gen_final(cat, cat))) 114 | self.layer = layer 115 | self.initialize() 116 | 117 | def forward(self, X, encoder): 118 | if self.layer: 119 | X = nn.functional.interpolate(X, scale_factor=2, mode='bilinear') 120 | c = cp(X) 121 | d = self.conv4(encoder) 122 | X = torch.cat([c, d], 1) 123 | 124 | X = self.convs(X) 125 | a = cp(X) 126 | b = self.conv2(encoder) 127 | f = torch.cat([a, b], 1) 128 | f = self.final(f) 129 | return f 130 | 131 | def initialize(self): 132 | utils.initModule(self.convs) 133 | utils.initModule(self.conv2) 134 | utils.initModule(self.final) 135 | 136 | if self.layer: 137 | utils.initModule(self.conv4) 138 | 139 | class ctr_block(nn.Module): 140 | def __init__(self, cat, layer): 141 | super(ctr_block, self).__init__() 142 | self.conv1 = nn.Sequential(*list(gen_convs(cat, cat, NUM[layer]))) 143 | self.conv2 = nn.Sequential(*list(gen_fuse(cat, cat))) 144 | self.final = nn.Sequential(*list(gen_final(cat, cat))) 145 | self.layer = layer 146 | self.initialize() 147 | 148 | def forward(self, X): 149 | X = self.conv1(X) 150 | if self.layer: 151 | X = nn.functional.interpolate(X, scale_factor=2, mode='bilinear') 152 | X = self.conv2(X) 153 | x = self.final(X) 154 | return x 155 | 156 | def initialize(self): 157 | utils.initModule(self.conv1) 158 | utils.initModule(self.conv2) 159 | utils.initModule(self.final) 160 | 161 | class final_block(nn.Module): 162 | def __init__(self, backbone, channel): 163 | super(final_block, self).__init__() 164 | self.slc_decode = nn.ModuleList([nn.Sequential(*list(decode_conv(i, channel))) for i in range(5)]) 165 | self.conv = nn.Conv2d(40, 8, 3, padding=1) 166 | self.backbone = backbone 167 | 168 | def forward(self, xs, phase): 169 | feats = [self.slc_decode[i](xs[i]) for i in range(5)] 170 | x = torch.cat(feats, 1) 171 | 172 | x = self.conv(x) 173 | if not self.backbone.startswith('vgg'): 174 | x = nn.functional.interpolate(x, scale_factor=2, mode='bilinear') 175 | 176 | scale = 2 if phase == 'te' else 1 177 | x = torch.max(x, dim=1)[0] * scale 178 | return x 179 | 180 | class baseU(nn.Module): 181 | def __init__(self, backbone=False, channel=64): 182 | super(baseU, self).__init__() 183 | self.name = 'baseU' 184 | self.layer = 5 185 | 186 | self.slc_blocks = nn.ModuleList([res_block(channel, i) for i in range(self.layer)]) 187 | self.slc_preds = nn.ModuleList([pred_block(channel, channel//2) for i in range(self.layer)]) 188 | 189 | self.ctr_blocks = nn.ModuleList([ctr_block(channel, i) for i in range(self.layer)]) 190 | self.ctr_preds = nn.ModuleList([pred_block(channel, channel//2, up=True) for i in range(self.layer)]) 191 | 192 | self.slc_decode = nn.ModuleList([nn.Sequential(*list(decode_conv(i, channel))) for i in range(5)]) 193 | self.final = final_block(backbone, channel) 194 | 195 | def forward(self, encoders, phase='te'): 196 | slcs, slc_maps = [encoders[-1]], [] 197 | ctrs, ctr_maps = [], [] 198 | stc, cts = None, None 199 | 200 | for i in 
range(self.layer):
201 |             slc = self.slc_blocks[i](slcs[-1], encoders[self.layer - 1 - i])
202 |             if cts is not None:
203 |                 slc = torch.cat([cp(slc), cts], dim=1)
204 |             else:
205 |                 ctrs.append(slc)
206 |             stc, slc_map = self.slc_preds[i](slc)
207 | 
208 |             ctr = self.ctr_blocks[i](ctrs[-1])
209 |             ctr = torch.cat([cp(ctr), stc], dim=1)
210 |             cts, ctr_map = self.ctr_preds[i](ctr)
211 | 
212 |             slcs.append(slc)
213 |             ctrs.append(ctr)
214 |             slc_maps.append(slc_map)
215 |             ctr_maps.append(ctr_map)
216 | 
217 |         final = self.final(slcs[1:], phase)
218 | 
219 |         OutPuts = {'final':final, 'preds':slc_maps, 'contour':ctr_maps}
220 |         return OutPuts
221 | 
--------------------------------------------------------------------------------
/models/baseline.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch
3 | 
4 | # Original
5 | #from models.encoder.vgg_our import vgg
6 | #from models.encoder.resnet_our import resnet
7 | 
8 | # Official
9 | from models.encoder.vgg import vgg16 as vgg
10 | from models.encoder.resnet import resnet50 as resnet
11 | 
12 | from models.encoder.mobile import mobilenet  # required by the 'mobile' branch of Encoder below
13 | from .baseU import baseU
14 | from torch.nn import functional as F
15 | 
16 | res_inc = [64, 256, 512, 1024, 2048]
17 | vgg_inc = [64, 128, 256, 512, 512]
18 | mobile_inc = [16, 24, 32, 64, 160]
19 | 
20 | class vgg_adapter(nn.Module):
21 |     def __init__(self, in1, channel=64):
22 |         super(vgg_adapter, self).__init__()
23 |         self.channel = channel
24 | 
25 |     def forward(self, x):
26 |         batch, cat, height, width = x.size()
27 |         x = torch.max(x.view(batch, self.channel, -1, height, width), dim=2)[0]
28 |         return x
29 | 
30 | class resnet_adapter(nn.Module):
31 |     def __init__(self, in1=64, out=64):
32 |         super(resnet_adapter, self).__init__()
33 |         self.reduce = in1 > 64
34 |         self.conv = nn.Conv2d(in1//4 if self.reduce else in1, out, 3, padding=1)
35 |         self.relu = nn.ReLU()
36 | 
37 |     def forward(self, X):
38 |         if self.reduce:
39 |             batch, cat, height, width = X.size()
40 |             X = torch.max(X.view(batch, -1, 4, height, width), dim=2)[0]
41 |         X = self.relu(self.conv(X))
42 | 
43 |         return X
44 | 
45 | class mobile_adapter(nn.Module):
46 |     def __init__(self, in1=64, out=64):
47 |         super(mobile_adapter, self).__init__()
48 |         self.conv = nn.Conv2d(in1, out, 3, padding=1)
49 |         self.relu = nn.ReLU()
50 | 
51 |     def forward(self, X):
52 |         X = self.relu(self.conv(X))
53 |         return X
54 | 
55 | 
56 | class Encoder(nn.Module):
57 |     def __init__(self, backbone, c=64):
58 |         super(Encoder, self).__init__()
59 | 
60 |         # resnet50
61 |         if backbone.startswith('resnet'):
62 |             self.encoder = resnet(pretrained=True)
63 |             self.adapters = nn.ModuleList([resnet_adapter(in1, c) for in1 in res_inc])
64 |         # mobilenet
65 |         elif backbone.startswith('mobile'):
66 |             self.encoder = mobilenet()
67 |             c = 16
68 |             self.adapters = nn.ModuleList([mobile_adapter(in1, c) for in1 in mobile_inc])
69 |         # vgg
70 |         else:
71 |             self.encoder = vgg(pretrained=True)  # `vgg` aliases vgg16, whose first argument is `pretrained`
72 |             self.adapters = nn.ModuleList([vgg_adapter(in1, c) for in1 in vgg_inc])
73 | 
74 |     def forward(self, x):
75 |         enc_feats = self.encoder(x)
76 |         enc_feats = [self.adapters[i](e_feat) for i, e_feat in enumerate(enc_feats)]
77 |         return enc_feats
78 | 
79 | class baseline(nn.Module):
80 |     def __init__(self, backbone, c=64):
81 |         super(baseline, self).__init__()
82 |         self.name = backbone
83 | 
84 |         self.encoder = Encoder(backbone, c)
85 |         self.decoder = baseU(backbone, c)
86 | 
87 |     def forward(self, X, phase='te'):
88 |         encoders = self.encoder(X)
89 |         OutDict = 
self.decoder(encoders, phase) 90 | 91 | 92 | return OutDict 93 | 94 | -------------------------------------------------------------------------------- /models/encoder/mobile.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn import init 7 | 8 | 9 | def _make_divisible(v, divisor, min_value=None): 10 | """ 11 | This function is taken from the original tf repo. 12 | It ensures that all layers have a channel number that is divisible by 8 13 | It can be seen here: 14 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 15 | :param v: 16 | :param divisor: 17 | :param min_value: 18 | :return: 19 | """ 20 | if min_value is None: 21 | min_value = divisor 22 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 23 | # Make sure that round down does not go down by more than 10%. 24 | if new_v < 0.9 * v: 25 | new_v += divisor 26 | return new_v 27 | 28 | 29 | class LinearBottleneck(nn.Module): 30 | def __init__(self, inplanes, outplanes, stride=1, t=6, activation=nn.ReLU6): 31 | super(LinearBottleneck, self).__init__() 32 | self.conv1 = nn.Conv2d(inplanes, inplanes * t, kernel_size=1, bias=False) 33 | self.bn1 = nn.BatchNorm2d(inplanes * t) 34 | self.conv2 = nn.Conv2d(inplanes * t, inplanes * t, kernel_size=3, stride=stride, padding=1, bias=False, 35 | groups=inplanes * t) 36 | self.bn2 = nn.BatchNorm2d(inplanes * t) 37 | self.conv3 = nn.Conv2d(inplanes * t, outplanes, kernel_size=1, bias=False) 38 | self.bn3 = nn.BatchNorm2d(outplanes) 39 | self.activation = activation(inplace=True) 40 | self.stride = stride 41 | self.t = t 42 | self.inplanes = inplanes 43 | self.outplanes = outplanes 44 | 45 | def forward(self, x): 46 | residual = x 47 | 48 | out = self.conv1(x) 49 | out = self.bn1(out) 50 | out = self.activation(out) 51 | 52 | out = self.conv2(out) 53 | out = self.bn2(out) 54 | out = self.activation(out) 55 | 56 | out = self.conv3(out) 57 | out = self.bn3(out) 58 | 59 | if self.stride == 1 and self.inplanes == self.outplanes: 60 | out += residual 61 | 62 | return out 63 | 64 | 65 | class MobileNet2(nn.Module): 66 | """MobileNet2 implementation. 67 | """ 68 | 69 | def __init__(self, scale=1.0, input_size=224, t=6, in_channels=3, num_classes=1000, activation=nn.ReLU6): 70 | """ 71 | MobileNet2 constructor. 72 | :param in_channels: (int, optional): number of channels in the input tensor. 73 | Default is 3 for RGB image inputs. 74 | :param input_size: 75 | :param num_classes: number of classes to predict. Default 76 | is 1000 for ImageNet. 
77 | :param scale: 78 | :param t: 79 | :param activation: 80 | """ 81 | 82 | super(MobileNet2, self).__init__() 83 | 84 | self.scale = scale 85 | self.t = t 86 | self.activation_type = activation 87 | self.activation = activation(inplace=True) 88 | self.num_classes = num_classes 89 | 90 | self.num_of_channels = [32, 16, 24, 32, 64, 96, 160, 320] 91 | # assert (input_size % 32 == 0) 92 | 93 | self.c = [_make_divisible(ch * self.scale, 8) for ch in self.num_of_channels] 94 | self.n = [1, 1, 2, 3, 4, 3, 3, 1] 95 | self.s = [2, 1, 2, 2, 2, 1, 2, 1] 96 | self.conv1 = nn.Conv2d(in_channels, self.c[0], kernel_size=3, bias=False, stride=self.s[0], padding=1) 97 | self.bn1 = nn.BatchNorm2d(self.c[0]) 98 | self.bottlenecks = self._make_bottlenecks() 99 | #print(self.bottlenecks) 100 | 101 | # Last convolution has 1280 output channels for scale <= 1 102 | #self.last_conv_out_ch = 1280 if self.scale <= 1 else _make_divisible(1280 * self.scale, 8) 103 | #self.conv_last = nn.Conv2d(self.c[-1], self.last_conv_out_ch, kernel_size=1, bias=False) 104 | #self.bn_last = nn.BatchNorm2d(self.last_conv_out_ch) 105 | #self.avgpool = nn.AdaptiveAvgPool2d(1) 106 | #self.dropout = nn.Dropout(p=0.2, inplace=True) # confirmed by paper authors 107 | #self.fc = nn.Linear(self.last_conv_out_ch, self.num_classes) 108 | self.init_params() 109 | 110 | 111 | 112 | def init_params(self): 113 | for m in self.modules(): 114 | if isinstance(m, nn.Conv2d): 115 | init.kaiming_normal_(m.weight, mode='fan_out') 116 | if m.bias is not None: 117 | init.constant_(m.bias, 0) 118 | elif isinstance(m, nn.BatchNorm2d): 119 | init.constant_(m.weight, 1) 120 | init.constant_(m.bias, 0) 121 | elif isinstance(m, nn.Linear): 122 | init.normal_(m.weight, std=0.001) 123 | if m.bias is not None: 124 | init.constant_(m.bias, 0) 125 | 126 | def _make_stage(self, inplanes, outplanes, n, stride, t, stage): 127 | modules = OrderedDict() 128 | stage_name = "LinearBottleneck{}".format(stage) 129 | 130 | # First module is the only one utilizing stride 131 | first_module = LinearBottleneck(inplanes=inplanes, outplanes=outplanes, stride=stride, t=t, 132 | activation=self.activation_type) 133 | modules[stage_name + "_0"] = first_module 134 | 135 | # add more LinearBottleneck depending on number of repeats 136 | for i in range(n - 1): 137 | name = stage_name + "_{}".format(i + 1) 138 | module = LinearBottleneck(inplanes=outplanes, outplanes=outplanes, stride=1, t=6, 139 | activation=self.activation_type) 140 | modules[name] = module 141 | 142 | return nn.Sequential(modules) 143 | 144 | def _make_bottlenecks(self): 145 | modules = OrderedDict() 146 | stage_name = "Bottlenecks" 147 | 148 | # First module is the only one with t=1 149 | bottleneck1 = self._make_stage(inplanes=self.c[0], outplanes=self.c[1], n=self.n[1], stride=self.s[1], t=1, 150 | stage=0) 151 | modules[stage_name + "_0"] = bottleneck1 152 | 153 | # add more LinearBottleneck depending on number of repeats 154 | for i in range(1, len(self.c) - 1): 155 | name = stage_name + "_{}".format(i) 156 | module = self._make_stage(inplanes=self.c[i], outplanes=self.c[i + 1], n=self.n[i + 1], 157 | stride=self.s[i + 1], 158 | t=self.t, stage=i) 159 | modules[name] = module 160 | 161 | #print(modules) 162 | return nn.Sequential(modules) 163 | 164 | def forward(self, x): 165 | feat_list = [] 166 | x = self.conv1(x) 167 | x = self.bn1(x) 168 | x = self.activation(x) 169 | 170 | for i, module in enumerate(self.bottlenecks): 171 | x = module(x) 172 | if i in (0, 1, 2, 3, 5): 173 | feat_list.append(x) 174 | #x = 
self.bottlenecks(x) 175 | #x = self.conv_last(x) 176 | #x = self.bn_last(x) 177 | #x = self.activation(x) 178 | #feat_list.append(x) 179 | 180 | #for a in feat_list: 181 | # print(a.size()) 182 | 183 | # average pooling layer 184 | #x = self.avgpool(x) 185 | #x = self.dropout(x) 186 | 187 | # flatten for input to fully-connected layer 188 | #x = x.view(x.size(0), -1) 189 | #x = self.fc(x) 190 | return feat_list#F.log_softmax(x, dim=1) #TODO not needed(?) 191 | 192 | def mobilenet(): 193 | model = MobileNet2() 194 | 195 | pretrain = torch.load('../PretrainModel/mobilev2.pth.tar', map_location={'cuda:1':'cuda:0'})['state_dict'] 196 | 197 | new_pre = {} 198 | for key, val in pretrain.items(): 199 | new_pre[key[7:]] = val 200 | 201 | exist_dict = {k:v for k,v in new_pre.items() if k in model.state_dict()} 202 | model.load_state_dict(exist_dict) 203 | 204 | return model 205 | 206 | if __name__ == "__main__": 207 | """Testing 208 | """ 209 | model1 = MobileNet2() 210 | print(model1) 211 | model2 = MobileNet2(scale=0.35) 212 | print(model2) 213 | model3 = MobileNet2(in_channels=2, num_classes=10) 214 | print(model3) 215 | x = torch.randn(1, 2, 224, 224) 216 | print(model3(x)) 217 | model4_size = 32 * 10 218 | model4 = MobileNet2(input_size=model4_size, num_classes=10) 219 | print(model4) 220 | x2 = torch.randn(1, 3, model4_size, model4_size) 221 | print(model4(x2)) 222 | model5 = MobileNet2(input_size=196, num_classes=10) 223 | x3 = torch.randn(1, 3, 196, 196) 224 | print(model5(x3)) # fail 225 | -------------------------------------------------------------------------------- /models/encoder/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | try: 4 | from torch.hub import load_state_dict_from_url 5 | except ImportError: 6 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 7 | 8 | 9 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 10 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 11 | 'wide_resnet50_2', 'wide_resnet101_2'] 12 | 13 | 14 | model_urls = { 15 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 16 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 17 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 18 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 19 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 20 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 21 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 22 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 23 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 24 | } 25 | 26 | 27 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 28 | """3x3 convolution with padding""" 29 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 30 | padding=dilation, groups=groups, bias=False, dilation=dilation) 31 | 32 | 33 | def conv1x1(in_planes, out_planes, stride=1): 34 | """1x1 convolution""" 35 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 36 | 37 | 38 | class BasicBlock(nn.Module): 39 | expansion = 1 40 | 41 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 42 | base_width=64, dilation=1, norm_layer=None): 43 | super(BasicBlock, 
self).__init__() 44 | if norm_layer is None: 45 | norm_layer = nn.BatchNorm2d 46 | if groups != 1 or base_width != 64: 47 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 48 | if dilation > 1: 49 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 50 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 51 | self.conv1 = conv3x3(inplanes, planes, stride) 52 | self.bn1 = norm_layer(planes) 53 | self.relu = nn.ReLU(inplace=True) 54 | self.conv2 = conv3x3(planes, planes) 55 | self.bn2 = norm_layer(planes) 56 | self.downsample = downsample 57 | self.stride = stride 58 | 59 | def forward(self, x): 60 | identity = x 61 | 62 | out = self.conv1(x) 63 | out = self.bn1(out) 64 | out = self.relu(out) 65 | 66 | out = self.conv2(out) 67 | out = self.bn2(out) 68 | 69 | if self.downsample is not None: 70 | identity = self.downsample(x) 71 | 72 | out += identity 73 | out = self.relu(out) 74 | 75 | return out 76 | 77 | 78 | class Bottleneck(nn.Module): 79 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 80 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 81 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 82 | # This variant is also known as ResNet V1.5 and improves accuracy according to 83 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 84 | 85 | expansion = 4 86 | 87 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 88 | base_width=64, dilation=1, norm_layer=None): 89 | super(Bottleneck, self).__init__() 90 | if norm_layer is None: 91 | norm_layer = nn.BatchNorm2d 92 | width = int(planes * (base_width / 64.)) * groups 93 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 94 | self.conv1 = conv1x1(inplanes, width) 95 | self.bn1 = norm_layer(width) 96 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 97 | self.bn2 = norm_layer(width) 98 | self.conv3 = conv1x1(width, planes * self.expansion) 99 | self.bn3 = norm_layer(planes * self.expansion) 100 | self.relu = nn.ReLU(inplace=True) 101 | self.downsample = downsample 102 | self.stride = stride 103 | 104 | def forward(self, x): 105 | identity = x 106 | 107 | out = self.conv1(x) 108 | out = self.bn1(out) 109 | out = self.relu(out) 110 | 111 | out = self.conv2(out) 112 | out = self.bn2(out) 113 | out = self.relu(out) 114 | 115 | out = self.conv3(out) 116 | out = self.bn3(out) 117 | 118 | if self.downsample is not None: 119 | identity = self.downsample(x) 120 | 121 | out += identity 122 | out = self.relu(out) 123 | 124 | return out 125 | 126 | 127 | class ResNet(nn.Module): 128 | 129 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 130 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 131 | norm_layer=None): 132 | super(ResNet, self).__init__() 133 | if norm_layer is None: 134 | norm_layer = nn.BatchNorm2d 135 | self._norm_layer = norm_layer 136 | 137 | self.inplanes = 64 138 | self.dilation = 1 139 | if replace_stride_with_dilation is None: 140 | # each element in the tuple indicates if we should replace 141 | # the 2x2 stride with a dilated convolution instead 142 | replace_stride_with_dilation = [False, False, False] 143 | if len(replace_stride_with_dilation) != 3: 144 | raise ValueError("replace_stride_with_dilation should be None " 145 | "or a 3-element tuple, got 
{}".format(replace_stride_with_dilation)) 146 | self.groups = groups 147 | self.base_width = width_per_group 148 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 149 | bias=False) 150 | self.bn1 = norm_layer(self.inplanes) 151 | self.relu = nn.ReLU(inplace=True) 152 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 153 | self.layer1 = self._make_layer(block, 64, layers[0]) 154 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 155 | dilate=replace_stride_with_dilation[0]) 156 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 157 | dilate=replace_stride_with_dilation[1]) 158 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 159 | dilate=replace_stride_with_dilation[2]) 160 | #self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 161 | #self.fc = nn.Linear(512 * block.expansion, num_classes) 162 | 163 | for m in self.modules(): 164 | if isinstance(m, nn.Conv2d): 165 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 166 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 167 | nn.init.constant_(m.weight, 1) 168 | nn.init.constant_(m.bias, 0) 169 | 170 | # Zero-initialize the last BN in each residual branch, 171 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 172 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 173 | if zero_init_residual: 174 | for m in self.modules(): 175 | if isinstance(m, Bottleneck): 176 | nn.init.constant_(m.bn3.weight, 0) 177 | elif isinstance(m, BasicBlock): 178 | nn.init.constant_(m.bn2.weight, 0) 179 | 180 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 181 | norm_layer = self._norm_layer 182 | downsample = None 183 | previous_dilation = self.dilation 184 | if dilate: 185 | self.dilation *= stride 186 | stride = 1 187 | if stride != 1 or self.inplanes != planes * block.expansion: 188 | downsample = nn.Sequential( 189 | conv1x1(self.inplanes, planes * block.expansion, stride), 190 | norm_layer(planes * block.expansion), 191 | ) 192 | 193 | layers = [] 194 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 195 | self.base_width, previous_dilation, norm_layer)) 196 | self.inplanes = planes * block.expansion 197 | for _ in range(1, blocks): 198 | layers.append(block(self.inplanes, planes, groups=self.groups, 199 | base_width=self.base_width, dilation=self.dilation, 200 | norm_layer=norm_layer)) 201 | 202 | return nn.Sequential(*layers) 203 | 204 | def _forward_impl(self, x): 205 | # See note [TorchScript super()] 206 | xs = [] 207 | x = self.conv1(x) 208 | x = self.bn1(x) 209 | x = self.relu(x) 210 | xs.append(x) 211 | x = self.maxpool(x) 212 | 213 | x = self.layer1(x) 214 | xs.append(x) 215 | x = self.layer2(x) 216 | xs.append(x) 217 | x = self.layer3(x) 218 | xs.append(x) 219 | x = self.layer4(x) 220 | xs.append(x) 221 | 222 | return xs 223 | 224 | def forward(self, x): 225 | return self._forward_impl(x) 226 | 227 | 228 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 229 | model = ResNet(block, layers, **kwargs) 230 | if pretrained: 231 | state_dict = torch.load('../PretrainModel/'+arch+'.pth', map_location='cpu') 232 | existing_dict = model.state_dict() 233 | #for key in existing_dict.keys(): 234 | # existing_dict[key] = state_dict[key] 235 | existing_dict.update(state_dict) 236 | model.load_state_dict(existing_dict, strict=False) 237 | #state_dict = load_state_dict_from_url(model_urls[arch], 238 | # 
progress=progress) 239 | #model.load_state_dict(state_dict) 240 | return model 241 | 242 | 243 | def resnet18(pretrained=False, progress=True, **kwargs): 244 | r"""ResNet-18 model from 245 | `"Deep Residual Learning for Image Recognition" `_ 246 | Args: 247 | pretrained (bool): If True, returns a model pre-trained on ImageNet 248 | progress (bool): If True, displays a progress bar of the download to stderr 249 | """ 250 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 251 | **kwargs) 252 | 253 | 254 | def resnet34(pretrained=False, progress=True, **kwargs): 255 | r"""ResNet-34 model from 256 | `"Deep Residual Learning for Image Recognition" `_ 257 | Args: 258 | pretrained (bool): If True, returns a model pre-trained on ImageNet 259 | progress (bool): If True, displays a progress bar of the download to stderr 260 | """ 261 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 262 | **kwargs) 263 | 264 | 265 | def resnet50(pretrained=True, progress=True, **kwargs): 266 | r"""ResNet-50 model from 267 | `"Deep Residual Learning for Image Recognition" `_ 268 | Args: 269 | pretrained (bool): If True, returns a model pre-trained on ImageNet 270 | progress (bool): If True, displays a progress bar of the download to stderr 271 | """ 272 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 273 | **kwargs) 274 | 275 | 276 | def resnet101(pretrained=False, progress=True, **kwargs): 277 | r"""ResNet-101 model from 278 | `"Deep Residual Learning for Image Recognition" `_ 279 | Args: 280 | pretrained (bool): If True, returns a model pre-trained on ImageNet 281 | progress (bool): If True, displays a progress bar of the download to stderr 282 | """ 283 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 284 | **kwargs) 285 | 286 | 287 | def resnet152(pretrained=False, progress=True, **kwargs): 288 | r"""ResNet-152 model from 289 | `"Deep Residual Learning for Image Recognition" `_ 290 | Args: 291 | pretrained (bool): If True, returns a model pre-trained on ImageNet 292 | progress (bool): If True, displays a progress bar of the download to stderr 293 | """ 294 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 295 | **kwargs) 296 | 297 | 298 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 299 | r"""ResNeXt-50 32x4d model from 300 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 301 | Args: 302 | pretrained (bool): If True, returns a model pre-trained on ImageNet 303 | progress (bool): If True, displays a progress bar of the download to stderr 304 | """ 305 | kwargs['groups'] = 32 306 | kwargs['width_per_group'] = 4 307 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 308 | pretrained, progress, **kwargs) 309 | 310 | 311 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 312 | r"""ResNeXt-101 32x8d model from 313 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 314 | Args: 315 | pretrained (bool): If True, returns a model pre-trained on ImageNet 316 | progress (bool): If True, displays a progress bar of the download to stderr 317 | """ 318 | kwargs['groups'] = 32 319 | kwargs['width_per_group'] = 8 320 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 321 | pretrained, progress, **kwargs) 322 | 323 | 324 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs): 325 | r"""Wide ResNet-50-2 model from 326 | `"Wide Residual Networks" `_ 327 | The model is the same as ResNet 
except for the bottleneck number of channels 328 | which is twice larger in every block. The number of channels in outer 1x1 329 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 330 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 331 | Args: 332 | pretrained (bool): If True, returns a model pre-trained on ImageNet 333 | progress (bool): If True, displays a progress bar of the download to stderr 334 | """ 335 | kwargs['width_per_group'] = 64 * 2 336 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 337 | pretrained, progress, **kwargs) 338 | 339 | 340 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs): 341 | r"""Wide ResNet-101-2 model from 342 | `"Wide Residual Networks" `_ 343 | The model is the same as ResNet except for the bottleneck number of channels 344 | which is twice larger in every block. The number of channels in outer 1x1 345 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 346 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 347 | Args: 348 | pretrained (bool): If True, returns a model pre-trained on ImageNet 349 | progress (bool): If True, displays a progress bar of the download to stderr 350 | """ 351 | kwargs['width_per_group'] = 64 * 2 352 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 353 | pretrained, progress, **kwargs) 354 | -------------------------------------------------------------------------------- /models/encoder/vgg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | #from .utils import load_state_dict_from_url 4 | 5 | 6 | __all__ = [ 7 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 8 | 'vgg19_bn', 'vgg19', 9 | ] 10 | 11 | 12 | model_urls = { 13 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', 14 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', 15 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 16 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 17 | 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth', 18 | 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth', 19 | 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth', 20 | 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth', 21 | } 22 | 23 | 24 | class VGG(nn.Module): 25 | 26 | def __init__(self, features, num_classes=1000, init_weights=True): 27 | super(VGG, self).__init__() 28 | self.features = features 29 | #print(self.features) 30 | ''' 31 | self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) 32 | self.classifier = nn.Sequential( 33 | nn.Linear(512 * 7 * 7, 4096), 34 | nn.ReLU(True), 35 | nn.Dropout(), 36 | nn.Linear(4096, 4096), 37 | nn.ReLU(True), 38 | nn.Dropout(), 39 | nn.Linear(4096, num_classes), 40 | ) 41 | ''' 42 | if init_weights: 43 | self._initialize_weights() 44 | 45 | def forward(self, x): 46 | xs = [] 47 | 48 | for i in range(len(self.features)): 49 | x = self.features[i](x) 50 | if i in (3, 8, 15, 22, 29): 51 | xs.append(x) 52 | #print(x.size()) 53 | #x = self.features(x) 54 | #x = self.avgpool(x) 55 | #x = torch.flatten(x, 1) 56 | #x = self.classifier(x) 57 | return xs 58 | 59 | def _initialize_weights(self): 60 | for m in self.modules(): 61 | if isinstance(m, nn.Conv2d): 62 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 63 | if m.bias is not None: 64 | nn.init.constant_(m.bias, 0) 65 | elif isinstance(m, 
nn.BatchNorm2d): 66 | nn.init.constant_(m.weight, 1) 67 | nn.init.constant_(m.bias, 0) 68 | elif isinstance(m, nn.Linear): 69 | nn.init.normal_(m.weight, 0, 0.01) 70 | nn.init.constant_(m.bias, 0) 71 | 72 | 73 | def make_layers(cfg, batch_norm=False): 74 | layers = [] 75 | in_channels = 3 76 | for v in cfg: 77 | if v == 'M': 78 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 79 | else: 80 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 81 | if batch_norm: 82 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 83 | else: 84 | layers += [conv2d, nn.ReLU(inplace=True)] 85 | in_channels = v 86 | return nn.Sequential(*layers) 87 | 88 | 89 | cfgs = { 90 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 91 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 92 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 93 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 94 | } 95 | 96 | 97 | def _vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs): 98 | if pretrained: 99 | kwargs['init_weights'] = False 100 | model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) 101 | if pretrained: 102 | #state_dict = load_state_dict_from_url(model_urls[arch], 103 | # progress=progress) 104 | #model.load_state_dict(state_dict) 105 | 106 | model.load_state_dict(torch.load('../PretrainModel/vgg16.pth', map_location='cpu'), strict=False) 107 | return model 108 | 109 | 110 | def vgg11(pretrained=False, progress=True, **kwargs): 111 | r"""VGG 11-layer model (configuration "A") from 112 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 113 | Args: 114 | pretrained (bool): If True, returns a model pre-trained on ImageNet 115 | progress (bool): If True, displays a progress bar of the download to stderr 116 | """ 117 | return _vgg('vgg11', 'A', False, pretrained, progress, **kwargs) 118 | 119 | 120 | def vgg11_bn(pretrained=False, progress=True, **kwargs): 121 | r"""VGG 11-layer model (configuration "A") with batch normalization 122 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 123 | Args: 124 | pretrained (bool): If True, returns a model pre-trained on ImageNet 125 | progress (bool): If True, displays a progress bar of the download to stderr 126 | """ 127 | return _vgg('vgg11_bn', 'A', True, pretrained, progress, **kwargs) 128 | 129 | 130 | def vgg13(pretrained=False, progress=True, **kwargs): 131 | r"""VGG 13-layer model (configuration "B") 132 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 133 | Args: 134 | pretrained (bool): If True, returns a model pre-trained on ImageNet 135 | progress (bool): If True, displays a progress bar of the download to stderr 136 | """ 137 | return _vgg('vgg13', 'B', False, pretrained, progress, **kwargs) 138 | 139 | 140 | def vgg13_bn(pretrained=False, progress=True, **kwargs): 141 | r"""VGG 13-layer model (configuration "B") with batch normalization 142 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 143 | Args: 144 | pretrained (bool): If True, returns a model pre-trained on ImageNet 145 | progress (bool): If True, displays a progress bar of the download to stderr 146 | """ 147 | return _vgg('vgg13_bn', 'B', True, pretrained, progress, **kwargs) 148 | 149 | 150 | def vgg16(pretrained=True, progress=True, **kwargs): 151 | r"""VGG 16-layer model (configuration "D") 152 | `"Very 
Deep Convolutional Networks For Large-Scale Image Recognition" `_ 153 | Args: 154 | pretrained (bool): If True, returns a model pre-trained on ImageNet 155 | progress (bool): If True, displays a progress bar of the download to stderr 156 | """ 157 | return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs) 158 | 159 | 160 | def vgg16_bn(pretrained=False, progress=True, **kwargs): 161 | r"""VGG 16-layer model (configuration "D") with batch normalization 162 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 163 | Args: 164 | pretrained (bool): If True, returns a model pre-trained on ImageNet 165 | progress (bool): If True, displays a progress bar of the download to stderr 166 | """ 167 | return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs) 168 | 169 | 170 | def vgg19(pretrained=False, progress=True, **kwargs): 171 | r"""VGG 19-layer model (configuration "E") 172 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 173 | Args: 174 | pretrained (bool): If True, returns a model pre-trained on ImageNet 175 | progress (bool): If True, displays a progress bar of the download to stderr 176 | """ 177 | return _vgg('vgg19', 'E', False, pretrained, progress, **kwargs) 178 | 179 | 180 | def vgg19_bn(pretrained=False, progress=True, **kwargs): 181 | r"""VGG 19-layer model (configuration 'E') with batch normalization 182 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 183 | Args: 184 | pretrained (bool): If True, returns a model pre-trained on ImageNet 185 | progress (bool): If True, displays a progress bar of the download to stderr 186 | """ 187 | return _vgg('vgg19_bn', 'E', True, pretrained, progress, **kwargs) -------------------------------------------------------------------------------- /setup.json: -------------------------------------------------------------------------------- 1 | { 2 | "_comment":"it is the record of path([] is represented as list and {} is represented as dict in python)", 3 | "DATASET":{ 4 | "ECSSD":["../dataset/ECSSD/images", "../dataset/ECSSD/segmentations"], 5 | "MSRA10K":["../dataset/MSRA10K/images", "../dataset/MSRA10K/segmentations"], 6 | "DUTS-TE":["../dataset/DUTS-TE/images", "../dataset/DUTS-TE/segmentations"], 7 | "DUTS-TR":["../dataset/DUTS-TR/images", "../dataset/DUTS-TR/segmentations"], 8 | "DUT-OMRON":["../dataset/DUT-OMRON/images", "../dataset/DUT-OMRON/segmentations"], 9 | "PASCAL-S":["../dataset/PASCAL-S/images", "../dataset/PASCAL-S/segmentations"], 10 | "SOD":["../dataset/SOD/images", "../dataset/SOD/segmentations"], 11 | "HKU-IS":["../dataset/HKU-IS/images", "../dataset/HKU-IS/segmentations"] 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /src/Eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import src 3 | from src import Metrics, utils 4 | import collections 5 | 6 | class Eval: 7 | def __init__(self, L): 8 | self.Loader = L 9 | self.scores = collections.OrderedDict() 10 | for val in L.vals: 11 | self.scores[val] = src.Score(val, L) 12 | 13 | def eval_Saliency(self, Model, epoch=0, supervised=True): 14 | savedict = {} 15 | Outputs = {val : Metrics.getOutPuts(Model, valdata['X'], self.Loader, supervised=supervised) for val, valdata in self.Loader.valdatas.items()} 16 | 17 | for val in self.Loader.valdatas.keys(): 18 | pred, valdata = Outputs[val]['final'], self.Loader.valdatas[val]['Y'] 19 | F = Metrics.maxF(pred, valdata, self.Loader.ids[-1]) 20 | M = 
Metrics.mae(pred, valdata)
21 | 
22 |             saves = self.scores[val].update([F, M], epoch)
23 |             savedict[val] = saves
24 | 
25 |         for val, score in self.scores.items():
26 |             score.print_present()
27 |         print('-----------------------------------------')
28 | 
29 |         if self.Loader.MODE == 'train':
30 |             torch.save(utils.makeDict(Model.state_dict()), utils.genPath(self.Loader.spath, 'present.pkl'))
31 |             for val, saves in savedict.items():
32 |                 for idx, save in enumerate(saves):
33 |                     if save:
34 |                         torch.save(utils.makeDict(Model.state_dict()), utils.genPath(self.Loader.spath, val+'_'+['F', 'M'][idx]+'.pkl'))
35 | 
36 |             for val, score in self.scores.items():
37 |                 score.print_best()
38 | 
39 |         else:
40 |             for val in self.Loader.valdatas.keys():
41 |                 Outputs[val]['Name'] = self.Loader.valdatas[val]['Name']
42 |                 Outputs[val]['Shape'] = self.Loader.valdatas[val]['Shape']
43 | 
44 |         return Outputs if self.Loader.save else None
45 | 
--------------------------------------------------------------------------------
/src/Experiment.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | import src
4 | from src import Loss
5 | 
6 | class Experiment(object):
7 |     def __init__(self, L, E):
8 |         self.Loader = L
9 |         self.Eval = E
10 |         self.Model = L.Model
11 |         if L.MODE in ('test',) or L.mode == 'ft':  # note: ('test') without the comma is a plain string, not a tuple
12 |             self.Model.load_state_dict(torch.load(L.mpath, map_location='cpu'))
13 | 
14 |         self.Model = self.Model.eval() if L.MODE == 'test' else self.Model.train()
15 |         if not L.cpu:
16 |             self.Model = torch.nn.DataParallel(self.Model.cuda(L.ids[0]), device_ids=L.ids)
17 | 
18 |     def optims(self, optim, params):
19 | 
20 |         if optim == 'SGD':
21 |             print('using SGD')
22 |             return torch.optim.SGD(params=params, momentum=0.9, weight_decay=0.0005)
23 | 
24 |         elif optim == 'Adam':
25 |             print('using Adam')
26 |             return torch.optim.Adam(params=params, weight_decay=0.0005)
27 | 
28 |     def schedulers(self, scheduler, optimizer):
29 | 
30 |         if scheduler == 'StepLR':
31 |             return torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)
32 | 
33 |         elif scheduler == 'ReduceLROnPlateau':
34 |             return torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
35 | 
--------------------------------------------------------------------------------
/src/Loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import argparse
4 | 
5 | import src
6 | import collections
7 | from src import utils
8 | import os
9 | 
10 | import torch
11 | from models.baseline import baseline
12 | from thop import profile
13 | #from flop import print_model_parm_flops, print_model_parm_nums
14 | 
15 | nece_args = {
16 |     'normal': ['batch', 'model', 'ids', 'spath', 'metrics', 'highers', 'mpath'],
17 |     'train': ['optim', 'loss', 'trset', 'sub', 'iter', 'epoch', 'scheduler', 'weights', 'no_contour'],
18 |     'test': ['save', 'rpath']
19 | }
20 | 
21 | def args(mode):
22 |     assert mode in ['train', 'test', 'debug']
23 |     parser = argparse.ArgumentParser()
24 | 
25 |     if mode == 'train':
26 |         parser.add_argument('--optim', default='SGD', help='set the optimizer of model [Adadelta, Adagrad, Adam, SparseAdam, Adamax, ASGD, LBFGS, RMSprop, Rprop, SGD]')
27 |         parser.add_argument('--trset', default='DUTS-TR', help='set the training set')
28 |         parser.add_argument('--scheduler', default='StepLR', help='set the scheduler')
29 |         parser.add_argument('--lr', default=0.01, type=float, help='set base learning rate')
30 | 
31 |     parser.add_argument('--model', default='resnet', help='Set the 
model') 32 | parser.add_argument('--batch', default=8, type=int, help='Batch Size') 33 | parser.add_argument('--size', default=288, type=int, help='Image Size') 34 | parser.add_argument('--vals', default='', help='Validation sets') 35 | parser.add_argument('--ids', default='0,1', help='Set the cuda devices') 36 | parser.add_argument('--sub', default='baseline', help='The name of network') 37 | parser.add_argument('--cpu', action='store_true') 38 | parser.add_argument('--debug', action='store_true') 39 | 40 | parser.add_argument('--save', action='store_false') 41 | parser.add_argument('--supervised', action='store_true') 42 | parser.add_argument('--spath', default='save', help='model path') 43 | parser.add_argument('--rpath', default='result', help='visualization path') 44 | 45 | return parser.parse_args() 46 | 47 | class Loader(object): 48 | def __init__(self, MODE): 49 | assert MODE in ['train', 'test', 'debug'] 50 | self.MODE = MODE 51 | 52 | opt = args(MODE) 53 | 54 | print('loading the settings') 55 | self.loading(opt) 56 | 57 | def check(self, nece, args): 58 | self.nece = nece['normal'] + nece[self.MODE] 59 | for arg in self.nece: 60 | if getattr(args, arg) is None: 61 | print('miss the %s' % (arg)) 62 | return False 63 | 64 | for arg in self.nece: 65 | self.__setattr__(arg, getattr(args, arg)) 66 | 67 | return True 68 | 69 | def loading(self, opt): 70 | self.sub = opt.sub 71 | self.debug = opt.debug 72 | self.model = opt.model 73 | 74 | if self.MODE == 'train': 75 | self.weights = [0.1, 0.3, 0.5, 0.7, 0.9, 1.5] 76 | 77 | self.optim = opt.optim 78 | self.batch = opt.batch 79 | self.scheduler = opt.scheduler 80 | self.lr = opt.lr 81 | 82 | self.plist = [['encoder', self.lr*0.1], ['decoder', self.lr]] 83 | 84 | self.trset = 'SOD' if self.debug else opt.trset 85 | self.trSet = src.DataSet(self.trset, mode='train', shape=opt.size, debug=self.debug) 86 | 87 | self.epoch = int(math.ceil(self.trSet.size / self.batch)) if self.MODE == 'train' else 10 88 | self.iter = self.epoch * 25 89 | else: 90 | self.batch = 1 91 | 92 | 93 | os.environ["CUDA_VISIBLE_DEVICES"] = opt.ids 94 | num_gpu = len(opt.ids.split(',')) 95 | self.ids = list(range(num_gpu)) 96 | print('Backbone: {}, Using Gpu: {}'.format(self.model, opt.ids)) 97 | 98 | self.cpu = opt.cpu 99 | self.save = opt.save 100 | self.supervised = opt.supervised 101 | self.rpath = opt.rpath 102 | if self.save and not os.path.exists(self.rpath): 103 | os.makedirs(self.rpath) 104 | npath = utils.genPath(self.rpath, self.sub) 105 | if not os.path.exists(npath): 106 | os.makedirs(npath) 107 | 108 | self.spath = utils.genPath(opt.spath, self.model, self.sub) 109 | if not os.path.exists(self.spath): 110 | os.makedirs(self.spath) 111 | self.mpath = self.spath + '/present.pkl' 112 | 113 | 114 | if self.debug: 115 | self.vals = ['SOD'] 116 | elif opt.vals == '': 117 | self.vals = ['SOD', 'PASCAL-S', 'ECSSD', 'DUTS-TE', 'HKU-IS', 'DUT-OMRON'] 118 | else: 119 | self.vals = [opt.vals, ] 120 | 121 | self.valdatas = collections.OrderedDict() 122 | for val in self.vals: 123 | self.valdatas[val] = src.DataSet(val, mode='test', shape=opt.size).getFull() 124 | 125 | self.channel = 16 if self.model.startswith('mobile') else 64 126 | self.Model = baseline(self.model, self.channel) 127 | 128 | #print_model_parm_flops(self.Model) 129 | #print_model_parm_nums(self.Model) 130 | 131 | #input = torch.randn(1, 3, opt.size, opt.size) 132 | #flops, params = profile(self.Model, inputs=(input, )) 133 | #print('FLOPs: {:.2f}, Params: {:.2f}.'.format(flops / 1e9, params / 1e6)) 
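        # The commented-out block above sketches FLOPs/params counting; if it is
        # re-enabled, thop's profile should run here, before the model is wrapped
        # in DataParallel by Experiment. A hypothetical, minimal version:
        #   dummy = torch.randn(1, 3, opt.size, opt.size)
        #   flops, params = profile(self.Model, inputs=(dummy, ))
        #   print('FLOPs: {:.2f}G, Params: {:.2f}M'.format(flops / 1e9, params / 1e6))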
134 | 135 | 136 | self.mode = self.MODE 137 | -------------------------------------------------------------------------------- /src/Loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, autograd, optim, Tensor, cuda 3 | from torch.nn import functional as F 4 | from torch.autograd import Variable 5 | 6 | 7 | class FS_loss(nn.Module): 8 | def __init__(self, weights, b=0.3): 9 | super(FS_loss, self).__init__() 10 | self.contour = weights 11 | self.b = b 12 | 13 | def forward(self, X, Y, weights): 14 | loss = 0 15 | batch = Y.size(0) 16 | 17 | for weight, x in zip(weights, X): 18 | pre = x.sigmoid_() 19 | scale = int(Y.size(2) / x.size(2)) 20 | pos = F.avg_pool2d(Y, kernel_size=scale, stride=scale).gt(0.5).float() 21 | tp = pre * pos 22 | 23 | tp = (tp.view(batch, -1)).sum(dim = -1) 24 | posi = (pos.view(batch, -1)).sum(dim = -1) 25 | pre = (pre.view(batch, -1)).sum(dim = -1) 26 | 27 | f_score = tp * (1 + self.b) / (self.b * posi + pre) 28 | loss += weight * (1 - f_score.mean()) 29 | return loss 30 | 31 | 32 | 33 | def ACT(X, batchs, args): 34 | bce = nn.BCEWithLogitsLoss(reduction='none') 35 | slc_gt = torch.tensor(batchs['Y']).cuda() 36 | ctr_gt = torch.tensor(batchs['C']).cuda() 37 | 38 | slc_loss, ctr_loss = 0, 0 39 | for slc_pred, ctr_pred, weight in zip(X['preds'], X['contour'], args.weights): 40 | scale = int(slc_gt.size(-1) / slc_pred.size(-1)) 41 | ys = F.avg_pool2d(slc_gt, kernel_size=scale, stride=scale).gt(0.5).float() 42 | yc = F.max_pool2d(ctr_gt, kernel_size=scale, stride=scale) 43 | 44 | slc_pred = slc_pred.squeeze(1) 45 | 46 | # contour loss 47 | #w = torch.yc 48 | 49 | # ACT loss 50 | pc = ctr_pred.sigmoid_() 51 | w = torch.where(pc > yc, pc, yc) 52 | 53 | slc_loss += (bce(slc_pred, ys) * (w * 4 + 1)).mean() * weight 54 | 55 | if ctr_pred is not None: 56 | ctr_pred = ctr_pred.squeeze(1) 57 | ctr_loss += bce(ctr_pred, yc).mean() * weight 58 | 59 | pc = F.interpolate(pc.unsqueeze(1), size=ctr_gt.size()[-2:], mode='bilinear').squeeze(1) 60 | w = torch.where(pc > ctr_gt, pc, ctr_gt) 61 | fnl_loss = (bce(X['final'], slc_gt.gt(0.5).float()) * (w * 4 + 1)).mean() * args.weights[-1] 62 | 63 | return fnl_loss + ctr_loss + slc_loss 64 | 65 | 66 | 67 | 68 | 69 | -------------------------------------------------------------------------------- /src/Metrics.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | import time 4 | 5 | import torch 6 | from torch import nn, autograd, optim, Tensor, cuda 7 | from torch.nn import functional as F 8 | 9 | class PRF(nn.Module): 10 | def __init__(self, device, Y, steps=255, end=1): 11 | super(PRF, self).__init__() 12 | self.thresholds = torch.linspace(0, end, steps=steps).cuda(device) 13 | self.Y = Y 14 | 15 | def forward(self, _Y): 16 | TPs = [torch.sum(torch.sum((_Y >= threshold) & (self.Y), -1), -1).float() for threshold in self.thresholds] 17 | T1s = [torch.sum(torch.sum(_Y >= threshold, -1), -1).float() for threshold in self.thresholds] 18 | T2 = Tensor.float(torch.sum(torch.sum(self.Y, -1), -1)) 19 | Ps = [(TP / (T1 + 1e-9)).mean() for TP, T1 in zip(TPs, T1s)] 20 | Rs = [(TP / (T2 + 1e-9)).mean() for TP in TPs] 21 | Fs = [(1.3 * P * R / (R + 0.3 * P + 1e-9)) for P, R in zip(Ps, Rs)] 22 | 23 | return {'P':Ps, 'R':Rs, 'F':Fs} 24 | 25 | def getOutPuts(model, DX, args, supervised=False): 26 | num_img, channel, height, width = DX.shape 27 | if supervised: 28 | OutPuts = 
{'final':np.empty((len(DX), height, width), dtype=np.float32), 'contour':np.empty((len(DX), 5, height, width), dtype=np.float32), 'preds':np.empty((len(DX), 5, height, width), dtype=np.float32), 'time':0.} 29 | else: 30 | OutPuts = {'final':np.empty((len(DX), height, width), dtype=np.float32), 'time':0.} 31 | t1 = time.time() 32 | 33 | for idx in range(0, len(DX), args.batch): 34 | ind = min(len(DX), idx + args.batch) 35 | X = torch.tensor(DX[idx:ind]).cuda(args.ids[0]).float() 36 | Outs = model(X) 37 | 38 | OutPuts['final'][idx:ind] = torch.sigmoid(Outs['final']).cpu().data.numpy() 39 | 40 | if supervised: 41 | for supervision in ['preds', 'contour']: 42 | for i, pred in enumerate(Outs[supervision]): 43 | pre = F.interpolate(pred.unsqueeze(0), (height, width), mode='bilinear')[0] 44 | pre = torch.sigmoid(pre).cpu().data.numpy() 45 | OutPuts[supervision][idx:ind, i] = pre 46 | 47 | X, Outs, pre = 0, 0, 0 48 | 49 | OutPuts['time'] = (time.time() - t1) 50 | 51 | 52 | return OutPuts 53 | 54 | 55 | def mae(preds, labels, th=0.5): 56 | return np.mean(np.abs(preds - labels)) 57 | 58 | def fscore(preds, labels, th=0.5): 59 | tmp = preds >= th 60 | TP = np.sum(tmp & labels) 61 | T1 = np.sum(tmp) 62 | T2 = np.sum(labels) 63 | F = 1.3 * TP / (T1 + 0.3 * T2 + 1e-9) 64 | 65 | return F 66 | 67 | def maxF(preds, labels, device): 68 | preds = torch.tensor(preds).cuda(device) 69 | labels = torch.tensor(labels, dtype=torch.uint8).cuda(device) 70 | 71 | prf = PRF(device, labels).cuda(device) 72 | Fs = prf(preds)['F'] 73 | Fs = [F.cpu().data.numpy() for F in Fs] 74 | 75 | prf.to(torch.device('cpu')) 76 | torch.cuda.empty_cache() 77 | return max(Fs) 78 | 79 | def Normalize(atten): 80 | 81 | a_min, a_max = atten.min(), atten.max() 82 | atten = (atten - a_min) * 1. / (a_max - a_min) * 255. 83 | 84 | return np.uint8(atten) 85 | -------------------------------------------------------------------------------- /src/Score.py: -------------------------------------------------------------------------------- 1 | class Score: 2 | def __init__(self, name, loader): 3 | self.name = name 4 | self.metrics = ['F', 'M'] 5 | self.highers = [1, 0] 6 | self.scores = [0. if higher else 1. 
--------------------------------------------------------------------------------
/src/Score.py:
--------------------------------------------------------------------------------
1 | class Score:
2 |     # Tracks the best max-F (higher is better) and MAE (lower is better) observed
3 |     # on one validation set, together with the epochs at which they occurred.
4 |     def __init__(self, name, loader):
5 |         self.name = name
6 |         self.metrics = ['F', 'M']
7 |         self.highers = [1, 0]
8 |         self.scores = [0. if higher else 1. for higher in self.highers]
9 |         self.best = list(self.scores)        # copies, not aliases of self.scores
10 |         self.best_epoch = [0] * len(self.scores)
11 |         self.present = list(self.scores)
12 | 
13 |     def update(self, scores, epoch):
14 |         self.present = scores
15 |         self.best = [max(best, score) if self.highers[idx] else min(best, score)
16 |                      for idx, (best, score) in enumerate(zip(self.best, scores))]
17 |         self.best_epoch = [epoch if present == best else best_epoch
18 |                            for present, best, best_epoch in zip(self.present, self.best, self.best_epoch)]
19 |         # one flag per metric: True means this epoch set a new best and should be saved
20 |         saves = [epoch == best_epoch for best_epoch in self.best_epoch]
21 | 
22 |         return saves
23 | 
24 |     def print_present(self):
25 |         m_str = '{} : {:.4f}, {} : {:.4f} on ' + self.name
26 |         m_list = []
27 |         for metric, present in zip(self.metrics, self.present):
28 |             m_list.append(metric)
29 |             m_list.append(present)
30 |         print(m_str.format(*m_list))
31 | 
32 |     def print_best(self):
33 |         m_str = 'Best score: {}_{} : {:.4f}, {}_{} : {:.4f} on ' + self.name
34 |         m_list = []
35 |         for metric, best, best_epoch in zip(self.metrics, self.best, self.best_epoch):
36 |             m_list.append(metric)
37 |             m_list.append(best_epoch)
38 |             m_list.append(best)
39 |         print(m_str.format(*m_list))
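
Note: `Score` keeps per-metric bests with their epochs and tells the caller which checkpoints to save. A quick sketch of the bookkeeping (the `loader` argument is unused by `Score` itself; the numbers are made up):

```python
from src.Score import Score

score = Score('ECSSD', loader=None)
saves = score.update([0.88, 0.06], epoch=1)   # [max-F, MAE]
saves = score.update([0.90, 0.05], epoch=2)   # both metrics improve
print(saves)        # [True, True]: epoch 2 set a new best for F and for MAE
score.print_best()  # e.g. "Best score: F_2 : 0.9000, M_2 : 0.0500 on ECSSD"
```
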
--------------------------------------------------------------------------------
/src/Tester.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from PIL import Image
4 | 
5 | import src
6 | from src import utils
7 | 
8 | class Tester(src.Experiment):
9 |     def __init__(self, L, E):
10 |         super(Tester, self).__init__(L, E)
11 | 
12 |     def test(self):
13 |         self.supervised = self.Loader.supervised
14 |         self.Outputs = self.Eval.eval_Saliency(self.Model, supervised=self.supervised)
15 |         if self.Loader.save:
16 |             self.save_preds()
17 | 
18 |     def save_preds(self):
19 |         for valname, output in self.Outputs.items():
20 |             rpath = utils.genPath(self.Loader.rpath, self.Loader.sub, valname)
21 |             if not os.path.exists(rpath):
22 |                 os.makedirs(rpath)
23 | 
24 |             names, shapes, finals = output['Name']['Y'], output['Shape'], output['final'] * 255.
25 | 
26 |             # final maps are resized back to each image's original resolution
27 |             ppath = utils.genPath(rpath, 'final')
28 |             if not os.path.exists(ppath):
29 |                 os.makedirs(ppath)
30 |             for name, shape, final in zip(names, shapes, finals):
31 |                 Image.fromarray(np.uint8(final)).resize(shape, Image.BICUBIC).save(utils.genPath(ppath, name))
32 | 
33 |             if self.supervised:
34 |                 # also dump the per-scale saliency and contour side outputs
35 |                 preds, conts = output['preds'] * 255., output['contour'] * 255.
36 |                 for name, pred, cont in zip(names, preds, conts):
37 |                     for idx, pre in enumerate(pred):
38 |                         pred_path = utils.genPath(rpath, 'pred_' + str(idx + 1))
39 |                         if not os.path.exists(pred_path):
40 |                             os.makedirs(pred_path)
41 |                         Image.fromarray(np.uint8(pre)).convert('L').save(utils.genPath(pred_path, name.split('.')[0] + '.png'))
42 | 
43 |                     for idx, pre in enumerate(cont):
44 |                         cont_path = utils.genPath(rpath, 'cont_' + str(idx + 1))
45 |                         if not os.path.exists(cont_path):
46 |                             os.makedirs(cont_path)
47 |                         Image.fromarray(np.uint8(pre)).convert('L').save(utils.genPath(cont_path, name.split('.')[0] + '.png'))
48 | 
49 |             print('Saved predictions for dataset: {}.'.format(valname))
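
Note: saved maps can be checked against ground truth with the helpers in `src/Metrics.py`. A sketch with placeholder paths (adjust the `save/` layout and dataset root to your own setup):

```python
import numpy as np
from PIL import Image
from src.Metrics import mae, fscore

# placeholder paths for one saved prediction and its ground-truth mask
pred = np.array(Image.open('save/vgg/job/ECSSD/final/0001.png').convert('L'), dtype=np.float32) / 255.
gt = np.array(Image.open('data/ECSSD/gt/0001.png').convert('L')) > 127

print('MAE :', mae(pred, gt.astype(np.float32)))
print('F0.3:', fscore(pred, gt))
```
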
--------------------------------------------------------------------------------
/src/Trainer.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | from torch import nn
4 | from progress.bar import Bar
5 | 
6 | import src
7 | from src import utils
8 | from src.Loss import ACT
9 | 
10 | def freeze_bn(model):
11 |     # Keep the encoder's BatchNorm statistics and affine parameters fixed.
12 |     for m in model.encoder.modules():
13 |         if isinstance(m, nn.BatchNorm2d):
14 |             m.eval()
15 |             m.weight.requires_grad = False
16 |             m.bias.requires_grad = False
17 | 
18 | 
19 | class Trainer(src.Experiment):
20 |     def __init__(self, L, E):
21 |         super(Trainer, self).__init__(L, E)
22 |         self.epochs = L.iter // L.epoch
23 | 
24 |         # Multi-GPU is handled by the DataParallel wrapper around the model;
25 |         # the optimizer itself needs no wrapping.
26 |         self.params = utils.genParams(L.plist, model=self.Model.module)
27 |         self.optimizer = self.optims(L.optim, self.params)
28 |         freeze_bn(self.Model.module)
29 |         self.scheduler = self.schedulers(L.scheduler, self.optimizer) if L.scheduler != 'None' else None
30 |         self.loss = ACT
31 | 
32 |     def epoch(self, idx):
33 |         st = time.time()
34 |         ans = 0
35 |         print('---------------------------------------------------------------------------')
36 |         bar = Bar('{} | epoch {}:'.format(self.Loader.sub, idx), max=self.Loader.epoch)
37 | 
38 |         for i in range(self.Loader.epoch):
39 |             self.optimizer.zero_grad()
40 |             batchs = self.Loader.trSet.getBatch(self.Loader.batch)
41 |             X = torch.tensor(batchs['X'], requires_grad=True).float().cuda(self.Loader.ids[0])
42 |             _y = self.Model(X, 'tr')
43 |             loss = self.loss(_y, batchs, self.Loader)
44 |             X, _y = 0, 0  # drop references so GPU memory can be reclaimed
45 |             ans += loss.cpu().data.numpy()
46 |             loss.backward()
47 | 
48 |             Bar.suffix = '{}/{} | loss: {:.5f}'.format(i, self.Loader.epoch, ans * 1. / (i + 1))
49 |             self.optimizer.step()
50 |             bar.next()
51 | 
52 |         bar.finish()
53 |         print('epoch: {}, time: {}, loss: {:.5f}.'.format(idx, time.time() - st, ans * 1. / self.Loader.epoch))
54 | 
55 |         st = time.time()
56 |         self.Eval.eval_Saliency(self.Model, epoch=idx, supervised=False)
57 |         print('Evaluate using time: {:.5f}.'.format(time.time() - st))
58 | 
59 |     def train(self):
60 |         for idx in range(self.epochs):
61 |             if self.scheduler is not None:  # L.scheduler may be 'None'
62 |                 self.scheduler.step()
63 |             self.epoch(idx + 1)
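
Note: the optimizer is built from per-module parameter groups, so the encoder and decoder can train at different learning rates. A minimal sketch of the `utils.genParams` pattern (the module names and rates here are illustrative, not the repo's actual settings):

```python
import torch
from torch import nn

# genParams-style grouping: each (attribute name, lr) pair becomes one param group
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.encoder = nn.Conv2d(3, 16, 3, padding=1)
        self.decoder = nn.Conv2d(16, 1, 3, padding=1)

model = Net()
plist = [['encoder', 1e-5], ['decoder', 1e-4]]  # pretrained backbone learns 10x slower
params = [{'params': getattr(model, p[0]).parameters(), 'lr': p[1]} for p in plist]
optimizer = torch.optim.SGD(params, lr=1e-4, momentum=0.9)
print([g['lr'] for g in optimizer.param_groups])  # [1e-05, 0.0001]
```
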
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | from .dataSet import *
2 | from .Eval import *
3 | from .Experiment import *
4 | from .Loader import *
5 | from .Tester import *
6 | from .Trainer import *
7 | from .Score import *
--------------------------------------------------------------------------------
/src/dataSet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import time
4 | 
5 | import cv2
6 | import numpy as np
7 | from PIL import Image
8 | 
9 | from src import utils
10 | 
11 | class _DataSet(object):
12 |     def __init__(self, name, mode='train', shape=224, debug=False):
13 |         super(_DataSet, self).__init__()
14 |         self.name = name
15 | 
16 |         self.tic = time.time()
17 |         # per-channel normalization statistics (NCHW layout)
18 |         self.mean = np.array([0.485, 0.458, 0.407]).reshape([1, 3, 1, 1])
19 |         self.std = np.array([0.229, 0.224, 0.225]).reshape([1, 3, 1, 1])
20 |         self.dataset = utils.loadJson(utils.genPath('setup'))['DATASET']
21 |         assert name in self.dataset
22 |         self.path = self.dataset[self.name]
23 | 
24 |         assert mode in ['train', 'test']
25 |         self.mode = mode
26 | 
27 |         self.shape = shape
28 |         self.flip_code = [1, 0, -1]  # horizontal, vertical, both
29 | 
30 |         self.getNames()
31 | 
32 |     def getNames(self):
33 |         names = list(map(lambda x: sorted(os.listdir(x)), self.path))
34 |         paths = [list(map(lambda x: utils.genPath(path, x), name)) for name, path in zip(names, self.path)]
35 | 
36 |         self.names = {'X': names[0], 'Y': names[1]}
37 |         self.paths = paths
38 |         self.size = len(paths[0])
39 | 
40 |     def Crop(self, X, Y, shape):
41 |         # 10% of the time take a random crop, otherwise resize to the target shape
42 |         dice = random.random()
43 |         h, w, _ = X.shape
44 | 
45 |         if dice < .1:
46 |             rand_h = random.randint(0, h - shape)
47 |             rand_w = random.randint(0, w - shape)
48 |             X = X[rand_h:rand_h + shape, rand_w:rand_w + shape]
49 |             Y = Y[rand_h:rand_h + shape, rand_w:rand_w + shape]
50 |         else:
51 |             X = cv2.resize(X, (shape, shape))
52 |             Y = cv2.resize(Y, (shape, shape))
53 | 
54 |         return X, Y
55 | 
56 |     def random_rotate(self, X, Y):
57 |         angle = np.random.randint(-25, 25)
58 |         h, w = Y.shape
59 |         center = (w / 2, h / 2)
60 |         M = cv2.getRotationMatrix2D(center, angle, 1.0)
61 | 
62 |         X = cv2.warpAffine(X, M, (w, h))
63 |         Y = cv2.warpAffine(Y, M, (w, h))
64 |         return X, Y
65 | 
66 |     def random_light(self, x):
67 |         contrast = np.random.rand(1) + 0.5
68 |         light = np.random.randint(-20, 20)
69 |         x = contrast * x + light
70 |         return np.clip(x, 0, 255)
71 | 
72 |     def Flip(self, X, Y):
73 |         dice = random.randint(0, 1)
74 |         return (X, Y) if dice == 0 else (cv2.flip(X, self.flip_code[dice - 1]), cv2.flip(Y, self.flip_code[dice - 1]))
75 | 
76 |     def Normalize(self, X):
77 |         X = X.transpose((0, 3, 1, 2))  # NHWC -> NCHW
78 |         X /= 255.
79 |         X -= self.mean
80 |         X /= self.std
81 |         return X
82 | 
83 |     def shuffle(self):
84 |         self.index = 0
85 |         random.shuffle(self.idlist)
86 | 
87 |     def reProcess(self, X):
88 |         # invert Normalize: NCHW float -> NHWC in the 0-255 range
89 |         X = ((X * self.std) + self.mean) * 255
90 |         X = X.transpose((0, 2, 3, 1))
91 |         return X
92 | 
93 |     def help(self):
94 |         print('shape : %s' % self.shape)
95 |         print('mode : %s' % self.mode)
96 |         print('[train, test]')
97 | 
98 | class DataSet(_DataSet):
99 |     def __init__(self, name, mode='train', shape=224, debug=False):
100 |         super(DataSet, self).__init__(name=name, mode=mode, shape=shape, debug=debug)
101 | 
102 |         if mode == 'train':
103 |             self.idlist = list(range(self.size))
104 |             self.shuffle()
105 | 
106 |         self.debug = debug
107 | 
108 |     def getFull(self):
109 |         DataDict = {
110 |             'X': np.empty((self.size, self.shape, self.shape, 3), dtype=np.float32),
111 |             'Y': np.empty((self.size, self.shape, self.shape), dtype=np.float32),
112 |             'C': np.empty((self.size, self.shape, self.shape), dtype=np.float32)
113 |         }
114 |         self.sizes = []
115 |         if self.debug:
116 |             print('debug mode! random images are generated.')
117 |             DataDict['X'] = np.random.rand(self.size, self.shape, self.shape, 3) * 255
118 |             DataDict['Y'] = np.random.rand(self.size, self.shape, self.shape) * 255
119 |             DataDict['C'] = np.random.rand(self.size, self.shape, self.shape)
120 |         else:
121 |             for idx in range(self.size):
122 |                 DataDict['X'][idx] = np.array(Image.open(self.paths[0][idx]).convert('RGB').resize((self.shape, self.shape), Image.ANTIALIAS))
123 |                 imgY = Image.open(self.paths[1][idx]).convert('L')
124 |                 self.sizes.append(imgY.size)
125 |                 DataDict['Y'][idx] = (np.array(imgY.resize((self.shape, self.shape), Image.ANTIALIAS)) > 127).astype(np.float64)
126 | 
127 |             # contour ground truth: morphological gradient of the binary mask
128 |             kernel = np.ones((5, 5))
129 |             for idx, y in enumerate(DataDict['Y']):
130 |                 DataDict['C'][idx] = cv2.dilate(y, kernel) - cv2.erode(y, kernel)
131 | 
132 |         DataDict['X'] = self.Normalize(DataDict['X'])
133 |         FullDict = {'X': DataDict['X'], 'Y': np.int32(DataDict['Y']), 'C': np.int32(DataDict['C']), 'Name': self.names, 'Shape': self.sizes}
134 |         print('loading {} images from {} using time: {}s.'.format(self.size, self.name, round(time.time() - self.tic, 3)))
135 | 
136 |         return FullDict
137 | 
138 |     def getBatch(self, batch):
139 |         # multi-scale training: jitter the batch resolution by +/- 32 pixels
140 |         scales = [-1, 0, 1]
141 |         b_size = int(np.random.choice(scales, 1)) * 32 + self.shape
142 |         aug_shape = int(b_size * 1.1)
143 | 
144 |         OutPuts = {
145 |             'X': np.empty((batch, b_size, b_size, 3), dtype=np.float32),
146 |             'Y': np.empty((batch, b_size, b_size), dtype=np.float32),
147 |             'C': np.empty((batch, b_size, b_size), dtype=np.float32)
148 |         }
149 | 
150 |         for idx in range(batch):
151 |             if self.index == self.size:
152 |                 self.shuffle()
153 | 
154 |             index = self.idlist[self.index]
155 |             X = np.array(Image.open(self.paths[0][index]).convert('RGB').resize((aug_shape, aug_shape), Image.ANTIALIAS))
156 |             Y = (np.array(Image.open(self.paths[1][index]).convert('L').resize((aug_shape, aug_shape), Image.ANTIALIAS)) > 127).astype(np.float64)
157 | 
158 |             X, Y = self.Crop(X, Y, b_size)
159 |             X, Y = self.Flip(X, Y)
160 | 
161 |             kernel = np.ones((5, 5))
162 |             C = cv2.dilate(Y, kernel) - cv2.erode(Y, kernel)
163 |             OutPuts['X'][idx], OutPuts['Y'][idx], OutPuts['C'][idx] = X, Y, C
164 |             self.index += 1
165 | 
166 |         OutPuts['X'] = self.Normalize(OutPuts['X'])
167 |         return OutPuts
168 | 
169 | if __name__ == '__main__':
170 |     # smoke test: write augmented samples to ./temp for visual inspection
171 |     data = DataSet('SOD', 'train')
172 |     batchsize = 1
173 | 
174 |     if not os.path.exists('temp'):
175 |         os.makedirs('temp')
176 | 
177 |     for i in range(300):
178 |         out = data.getBatch(batchsize)
179 | 
180 |         Y = out['Y'] * 255
181 |         C = out['C'] * 255
182 |         X = data.reProcess(out['X'])
183 | 
184 |         for x, y, c in zip(X, Y, C):
185 |             Image.fromarray(np.uint8(x)).save('temp/{}x.jpg'.format(i))
186 |             Image.fromarray(np.uint8(y)).save('temp/{}y.jpg'.format(i))
187 |             Image.fromarray(np.uint8(c)).save('temp/{}z.jpg'.format(i))
188 | 
189 |     ''' getFull smoke test:
190 |     out = data.getFull()
191 |     for i in range(300):
192 |         X = data.reProcess(np.expand_dims(out['X'][i], 0))
193 |         Y = np.expand_dims(out['Y'][i], 0) * 255
194 |         C = np.expand_dims(out['C'][i], 0) * 255
195 |         for x, y, c in zip(X, Y, C):
196 |             Image.fromarray(np.uint8(x)).save('temp/{}x.jpg'.format(i))
197 |             Image.fromarray(np.uint8(y)).save('temp/{}y.jpg'.format(i))
198 |             Image.fromarray(np.uint8(c)).save('temp/{}z.jpg'.format(i))
199 |     '''
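
Note: the contour ground truth `C` is a 5x5 morphological gradient of the binary mask (dilation minus erosion), which yields a thin band along object boundaries. A toy example of the band it produces:

```python
import cv2
import numpy as np

Y = np.zeros((64, 64), dtype=np.float64)
Y[16:48, 16:48] = 1.  # toy square mask

# same construction as in DataSet.getBatch / DataSet.getFull
kernel = np.ones((5, 5))
C = cv2.dilate(Y, kernel) - cv2.erode(Y, kernel)
print(C.sum())  # area of a thin ring of 1s hugging the square's boundary
```
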
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | from functools import reduce
3 | 
4 | from torch import nn
5 | 
6 | def deStr(param):
7 |     # 'flag'    -> ('flag', True)
8 |     # 'k=5'     -> ('k', 5)
9 |     # 'k=vgg16' -> ('k', 'vgg16')
10 |     # 'k=0.1'   -> ('k', 0.1); scientific notation such as '1e-4' is not supported
11 |     if '=' not in param:
12 |         return param, True
13 |     key, value = param.split('=')
14 |     if value.isalnum():
15 |         return (key, int(value)) if value.isdigit() else (key, value)
16 |     assert '.' in value
17 |     return key, float(value)
18 | 
19 | def loadJson(fileName):
20 |     with open(fileName + '.json', 'r') as jfile:
21 |         return json.load(jfile)
22 | 
23 | def checkKey(key):
24 |     # strip any number of leading 'module.' prefixes left by DataParallel
25 |     while key[:7] == 'module.':
26 |         key = key[7:]
27 |     return key
28 | 
29 | def makeDict(sdict):
30 |     return {checkKey(k): v for k, v in sdict.items()}
31 | 
32 | # param format: name:backbone:lr:options, e.g. encoder:vgg16:0.0001:pretrain,k=5,multilayers
33 | def deParams(params, mode):
34 |     mdict = {param.split(':')[0]: param.split(':')[1] for param in params}
35 |     mparams = {param.split(':')[1]: {key: value for key, value in filter(lambda p: p[0] != '', map(deStr, param.split(':')[-1].split(',')))} for param in params}
36 | 
37 |     if mode == 'train' or mode == 'ft':
38 |         plist = [[param.split(':')[0], float(param.split(':')[2])] for param in params]
39 |         return mdict, plist, mparams
40 | 
41 |     return mdict, mparams
42 | 
43 | def genParams(plist, model):
44 |     # one optimizer param group per (module name, learning rate) pair
45 |     return [{'params': getattr(model, p[0]).parameters(), 'lr': p[1]} for p in plist]
46 | 
47 | def genPath(*paths):
48 |     return reduce(lambda x, y: x + '/' + y, paths)
49 | 
50 | def initModule(modules):
51 |     for module in modules:
52 |         if type(module) is nn.Conv2d or type(module) is nn.Linear:
53 |             nn.init.kaiming_normal_(module.weight)
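
Note: `deParams` unpacks the colon-separated module specs shown in the comment above. A sketch of what that example string parses to:

```python
from src.utils import deParams

# one spec per module: name:backbone:lr:comma-separated options
params = ['encoder:vgg16:0.0001:pretrain,k=5,multilayers']
mdict, plist, mparams = deParams(params, mode='train')

print(mdict)    # {'encoder': 'vgg16'}
print(plist)    # [['encoder', 0.0001]]  -> per-module learning rates for genParams
print(mparams)  # {'vgg16': {'pretrain': True, 'k': 5, 'multilayers': True}}
```
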
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import src
2 | 
3 | def test():
4 |     L = src.Loader('test')
5 |     E = src.Eval(L)
6 |     TE = src.Tester(L, E)
7 |     TE.test()
8 | 
9 | if __name__ == '__main__':
10 |     test()
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from src import *
2 | 
3 | def train():
4 |     L = Loader('train')
5 |     E = Eval(L)
6 |     TR = Trainer(L, E)
7 |     TR.train()
8 | 
9 | if __name__ == '__main__':
10 |     train()
--------------------------------------------------------------------------------