├── README.md ├── data ├── gen_list.py ├── test.lst ├── train.lst ├── trainval.lst └── val.lst ├── eval.py ├── model ├── Danet │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── backbone.cpython-36.pyc │ │ ├── backbone.cpython-37.pyc │ │ ├── danet.cpython-36.pyc │ │ ├── danet.cpython-37.pyc │ │ ├── resnet.cpython-36.pyc │ │ └── resnet.cpython-37.pyc │ ├── backbone.py │ ├── danet.py │ └── resnet.py ├── Unet.py ├── Unet_module.py ├── __init__.py ├── __pycache__ │ ├── Unet.cpython-36.pyc │ ├── Unet.cpython-37.pyc │ ├── Unet_module.cpython-36.pyc │ ├── Unet_module.cpython-37.pyc │ ├── __init__.cpython-36.pyc │ └── __init__.cpython-37.pyc ├── deeplabv3 │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── _deeplab.cpython-36.pyc │ │ ├── _deeplab.cpython-37.pyc │ │ ├── modeling.cpython-36.pyc │ │ ├── modeling.cpython-37.pyc │ │ ├── utils.cpython-36.pyc │ │ └── utils.cpython-37.pyc │ ├── _deeplab.py │ ├── backbone │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── mobilenetv2.cpython-36.pyc │ │ │ ├── mobilenetv2.cpython-37.pyc │ │ │ ├── resnet.cpython-36.pyc │ │ │ ├── resnet.cpython-37.pyc │ │ │ ├── xception.cpython-36.pyc │ │ │ └── xception.cpython-37.pyc │ │ ├── mobilenetv2.py │ │ ├── resnet.py │ │ ├── sync_batchnorm │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ │ ├── __init__.cpython-36.pyc │ │ │ │ ├── __init__.cpython-37.pyc │ │ │ │ ├── batchnorm.cpython-36.pyc │ │ │ │ ├── batchnorm.cpython-37.pyc │ │ │ │ ├── comm.cpython-36.pyc │ │ │ │ ├── comm.cpython-37.pyc │ │ │ │ ├── replicate.cpython-36.pyc │ │ │ │ └── replicate.cpython-37.pyc │ │ │ ├── batchnorm.py │ │ │ ├── comm.py │ │ │ ├── replicate.py │ │ │ └── unittest.py │ │ └── xception.py │ ├── modeling.py │ └── utils.py └── hrnet │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── hrnet.cpython-36.pyc │ ├── hrnet.cpython-37.pyc │ ├── hrnet_module.cpython-36.pyc │ └── hrnet_module.cpython-37.pyc │ ├── hrnet.py │ └── hrnet_module.py ├── predict.py ├── runs ├── Sep14_22-36-00_amax-101studentLR_0.001_BS_32 │ └── events.out.tfevents.1600094160.amax-101.18906.0 ├── Sep15_12-37-15_amax-101studentLR_0.0001_BS_32 │ └── events.out.tfevents.1600144635.amax-101.9169.0 ├── Sep15_13-55-14_amax-101studentLR_0.0001_BS_32 │ └── events.out.tfevents.1600149314.amax-101.17317.0 └── Sep15_19-33-06_amax-101studentLR_0.0001_BS_32 │ └── events.out.tfevents.1600169586.amax-101.20053.0 ├── train.py ├── train ├── images │ ├── 1_3_2.tif │ ├── 1_3_3.tif │ ├── 1_4_0.tif │ ├── 1_4_1.tif │ ├── 1_4_2.tif │ ├── 1_4_3.tif │ ├── 1_5_0.tif │ ├── 1_5_1.tif │ ├── 1_5_2.tif │ └── 1_5_3.tif └── labels │ ├── 1_3_2.png │ ├── 1_3_3.png │ ├── 1_4_0.png │ ├── 1_4_1.png │ ├── 1_4_2.png │ ├── 1_4_3.png │ ├── 1_5_0.png │ ├── 1_5_1.png │ ├── 1_5_2.png │ └── 1_5_3.png ├── train_student.py └── utils ├── __init__.py ├── __pycache__ ├── __init__.cpython-36.pyc ├── __init__.cpython-37.pyc ├── basic_dataset.cpython-36.pyc ├── basic_dataset.cpython-37.pyc ├── data_vis.cpython-36.pyc ├── data_vis.cpython-37.pyc ├── function.cpython-36.pyc ├── function.cpython-37.pyc ├── hrnet_loss.cpython-36.pyc ├── loss.cpython-36.pyc ├── loss.cpython-37.pyc ├── polyLR.cpython-36.pyc └── polyLR.cpython-37.pyc ├── basic_dataset.py ├── data_vis.py ├── function.py ├── loss.py └── polyLR.py /README.md: -------------------------------------------------------------------------------- 1 | # 
Image_segmentation 2 | Analyze the spectral and spatial information of the various ground-object classes in high-resolution optical remote sensing images, and assign a semantic class label to every pixel in the image. Taking optical remote sensing images that contain typical land-use categories as the processing target, the provided remote sensing dataset is used to perform semantic segmentation of land-use types. -------------------------------------------------------------------------------- /data/gen_list.py: -------------------------------------------------------------------------------- 1 | import os 2 | data_dir="../../data/image_A" 3 | f=open("test.lst",'w') 4 | for i in os.listdir(data_dir): 5 | f.write("/home/archlab/lzr_satellite_image_regonization/data/image_A/"+str(i)+'\n') 6 | f.close() 7 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from tqdm import tqdm 4 | import numpy as np 5 | from utils import Evaluator 6 | from model import Unet 7 | from utils import BasicDataset 8 | from torch.utils.data import Dataset, DataLoader 9 | def eval_net(net,data_loader,device): 10 | net.eval() 11 | val_batch_num=len(data_loader) 12 | eval_loss=0 13 | 14 | e = Evaluator(num_class=8) 15 | pixel_acc_avg = 0 16 | mean_iou_avg = 0 17 | fw_iou_avg = 0 18 | 19 | with tqdm(total=val_batch_num, desc='Validation round', unit='batch', leave=False) as pbar: 20 | for idx,batch_samples in enumerate(data_loader): 21 | batch_image, batch_mask = batch_samples["image"], batch_samples["mask"] 22 | batch_image=batch_image.to(device=device,dtype=torch.float32) 23 | mask_true=batch_mask.to(device=device,dtype=torch.long) 24 | 25 | with torch.no_grad(): 26 | mask_pred=net(batch_image) 27 | probs = F.softmax(mask_pred, dim=1).squeeze(0) # [N, 8, 256, 256]; squeeze(0) only has an effect when N == 1 28 | pre = torch.argmax(probs, dim=1) # [N, 256, 256] 29 | 30 | # accumulate this batch into the evaluator's confusion matrix 31 | e.add_batch(mask_true.cpu().data.numpy(),pre.cpu().data.numpy()) 32 | pixel_acc=e.Pixel_Accuracy() 33 | pixel_acc_avg+=pixel_acc 34 | 35 | mean_iou=e.Mean_Intersection_over_Union() 36 | mean_iou_avg+=mean_iou 37 | 38 | fw_iou=e.Frequency_Weighted_Intersection_over_Union() 39 | fw_iou_avg+=fw_iou 40 | 41 | eval_loss+=F.cross_entropy(mask_pred,mask_true).item() 42 | pbar.set_postfix({'eval_loss': eval_loss/(idx+1)}) 43 | pbar.update() 44 | e.reset() 45 | print("pixel_acc_avg:"+str(pixel_acc_avg/val_batch_num)) 46 | print("mean_iou_avg:" + str(mean_iou_avg / val_batch_num)) 47 | print("fw_iou_avg:" + str(fw_iou_avg / val_batch_num)) 48 | net.train() 49 | return eval_loss/val_batch_num,pixel_acc_avg/val_batch_num,mean_iou_avg / val_batch_num,fw_iou_avg / val_batch_num 50 | 51 | kwargs={'map_location':lambda storage, loc: storage.cuda(1)} 52 | def load_GPUS(model,model_path,kwargs): 53 | state_dict = torch.load(model_path,**kwargs) 54 | # create new OrderedDict that does not contain `module.` 55 | from collections import OrderedDict 56 | new_state_dict = OrderedDict() 57 | for k, v in state_dict['net'].items(): 58 | name = k[7:] # remove `module.` 59 | new_state_dict[name] = v 60 | # load params 61 | model.load_state_dict(new_state_dict) 62 | return model 63 | 64 | if __name__=="__main__": 65 | dir_checkpoint = 'checkpoints/' 66 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 67 | net = Unet(n_channels=3, n_classes=8, bilinear=True) 68 | net.to(device=device) 69 | model = torch.load(dir_checkpoint + 'best_score_model_unet.pth') 70 | net.load_state_dict(model['net']) 71 | #net = load_GPUS(net, dir_checkpoint + 'student_net.pth', kwargs) 72 | sate_dataset_val = BasicDataset("./data/val.lst") 73 | eval_dataloader = DataLoader(sate_dataset_val, batch_size=32, shuffle=True,
num_workers=5, drop_last=True) 74 | print("begin") 75 | eval_net(net, eval_dataloader, device) -------------------------------------------------------------------------------- /model/Danet/__init__.py: -------------------------------------------------------------------------------- 1 | from .danet import DANet -------------------------------------------------------------------------------- /model/Danet/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/model/Danet/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /model/Danet/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/model/Danet/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /model/Danet/__pycache__/backbone.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/model/Danet/__pycache__/backbone.cpython-36.pyc -------------------------------------------------------------------------------- /model/Danet/__pycache__/backbone.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/model/Danet/__pycache__/backbone.cpython-37.pyc -------------------------------------------------------------------------------- /model/Danet/__pycache__/danet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/model/Danet/__pycache__/danet.cpython-36.pyc -------------------------------------------------------------------------------- /model/Danet/__pycache__/danet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/model/Danet/__pycache__/danet.cpython-37.pyc -------------------------------------------------------------------------------- /model/Danet/__pycache__/resnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/model/Danet/__pycache__/resnet.cpython-36.pyc -------------------------------------------------------------------------------- /model/Danet/__pycache__/resnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/model/Danet/__pycache__/resnet.cpython-37.pyc -------------------------------------------------------------------------------- /model/Danet/backbone.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from .resnet import resnet50 4 | from torch.nn import functional as F 5 | class ResNet50(nn.Module): 6 | def __init__(self, 
pretrained=True): 7 | """Declare all needed layers.""" 8 | super(ResNet50, self).__init__() 9 | self.model = resnet50(pretrained=pretrained) 10 | self.relu = self.model.relu # Place a hook 11 | 12 | layers_cfg = [4, 5, 6, 7] 13 | self.blocks = [] 14 | for i, num_this_layer in enumerate(layers_cfg): 15 | self.blocks.append(list(self.model.children())[num_this_layer]) 16 | 17 | def base_forward(self, x): 18 | feature_map = [] 19 | x = self.model.conv1(x) 20 | x = self.model.bn1(x) 21 | x = self.model.relu(x) 22 | x = self.model.maxpool(x) 23 | 24 | for i, block in enumerate(self.blocks): 25 | x = block(x) 26 | feature_map.append(x) 27 | 28 | out = nn.AvgPool2d(x.shape[2:])(x).view(x.shape[0], -1) 29 | 30 | return feature_map, out -------------------------------------------------------------------------------- /model/Danet/danet.py: -------------------------------------------------------------------------------- 1 | """Dual Attention Network""" 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from .backbone import ResNet50 6 | 7 | class DANet(ResNet50): 8 | r"""Dual Attention Network 9 | Parameters 10 | ---------- 11 | nclass : int 12 | Number of categories for the training dataset. 13 | backbone : string 14 | Pre-trained dilated backbone network type (default:'resnet50'; 'resnet50', 15 | 'resnet101' or 'resnet152'). 16 | norm_layer : object 17 | Normalization layer used in backbone network (default: :class:`nn.BatchNorm2d`; 18 | use a synchronized variant for cross-GPU BatchNormalization). 19 | aux : bool 20 | Auxiliary loss. 21 | Reference: 22 | Jun Fu, Jing Liu, Haijie Tian, Yong Li, Yongjun Bao, Zhiwei Fang, and Hanqing Lu. 23 | "Dual Attention Network for Scene Segmentation." *CVPR*, 2019 24 | """ 25 | 26 | def __init__(self, nclass, aux=True, **kwargs): 27 | super(DANet, self).__init__() # ResNet50.__init__ takes a `pretrained` flag, not nclass 28 | self.head = _DAHead(2048, nclass, aux, **kwargs) 29 | self.aux = aux 30 | self.__setattr__('exclusive', ['head']) 31 | 32 | def forward(self, x): 33 | size = x.size()[2:] 34 | feature_map,_ = self.base_forward(x) 35 | c3,c4 = feature_map[2],feature_map[3] 36 | 37 | outputs = [] 38 | x = self.head(c4) 39 | x0 = F.interpolate(x[0], size, mode='bilinear', align_corners=True) 40 | outputs.append(x0) 41 | 42 | if self.aux: 43 | #print('x[1]:{}'.format(x[1].shape)) 44 | x1 = F.interpolate(x[1], size, mode='bilinear', align_corners=True) 45 | x2 = F.interpolate(x[2], size, mode='bilinear', align_corners=True) 46 | outputs.append(x1) 47 | outputs.append(x2) 48 | return outputs 49 | 50 | 51 | class _PositionAttentionModule(nn.Module): 52 | """ Position attention module""" 53 | 54 | def __init__(self, in_channels, **kwargs): 55 | super(_PositionAttentionModule, self).__init__() 56 | self.conv_b = nn.Conv2d(in_channels, in_channels // 8, 1) 57 | self.conv_c = nn.Conv2d(in_channels, in_channels // 8, 1) 58 | self.conv_d = nn.Conv2d(in_channels, in_channels, 1) 59 | self.alpha = nn.Parameter(torch.zeros(1)) 60 | self.softmax = nn.Softmax(dim=-1) 61 | 62 | def forward(self, x): 63 | batch_size, _, height, width = x.size() 64 | feat_b = self.conv_b(x).view(batch_size, -1, height * width).permute(0, 2, 1) 65 | feat_c = self.conv_c(x).view(batch_size, -1, height * width) 66 | attention_s = self.softmax(torch.bmm(feat_b, feat_c)) 67 | feat_d = self.conv_d(x).view(batch_size, -1, height * width) 68 | feat_e = torch.bmm(feat_d, attention_s.permute(0, 2, 1)).view(batch_size, -1, height, width) 69 | out = self.alpha * feat_e + x 70 | 71 | return out 72 | 73 | 74 | class 
_ChannelAttentionModule(nn.Module): 75 | """Channel attention module""" 76 | 77 | def __init__(self, **kwargs): 78 | super(_ChannelAttentionModule, self).__init__() 79 | self.beta = nn.Parameter(torch.zeros(1)) 80 | self.softmax = nn.Softmax(dim=-1) 81 | 82 | def forward(self, x): 83 | batch_size, _, height, width = x.size() 84 | feat_a = x.view(batch_size, -1, height * width) 85 | feat_a_transpose = x.view(batch_size, -1, height * width).permute(0, 2, 1) 86 | attention = torch.bmm(feat_a, feat_a_transpose) 87 | attention_new = torch.max(attention, dim=-1, keepdim=True)[0].expand_as(attention) - attention 88 | attention = self.softmax(attention_new) 89 | 90 | feat_e = torch.bmm(attention, feat_a).view(batch_size, -1, height, width) 91 | out = self.beta * feat_e + x 92 | 93 | return out 94 | 95 | 96 | class _DAHead(nn.Module): 97 | def __init__(self, in_channels, nclass, aux=True, norm_layer=nn.BatchNorm2d, norm_kwargs=None, **kwargs): 98 | super(_DAHead, self).__init__() 99 | self.aux = aux 100 | inter_channels = in_channels // 4 101 | self.conv_p1 = nn.Sequential( 102 | nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False), 103 | norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)), 104 | nn.ReLU(True) 105 | ) 106 | self.conv_c1 = nn.Sequential( 107 | nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False), 108 | norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)), 109 | nn.ReLU(True) 110 | ) 111 | self.pam = _PositionAttentionModule(inter_channels, **kwargs) 112 | self.cam = _ChannelAttentionModule(**kwargs) 113 | self.conv_p2 = nn.Sequential( 114 | nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False), 115 | norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)), 116 | nn.ReLU(True) 117 | ) 118 | self.conv_c2 = nn.Sequential( 119 | nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False), 120 | norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)), 121 | nn.ReLU(True) 122 | ) 123 | self.out = nn.Sequential( 124 | nn.Dropout(0.1), 125 | nn.Conv2d(inter_channels, nclass, 1) 126 | ) 127 | if aux: 128 | self.conv_p3 = nn.Sequential( 129 | nn.Dropout(0.1), 130 | nn.Conv2d(inter_channels, nclass, 1) 131 | ) 132 | self.conv_c3 = nn.Sequential( 133 | nn.Dropout(0.1), 134 | nn.Conv2d(inter_channels, nclass, 1) 135 | ) 136 | 137 | def forward(self, x): 138 | feat_p = self.conv_p1(x) 139 | feat_p = self.pam(feat_p) 140 | feat_p = self.conv_p2(feat_p) 141 | 142 | feat_c = self.conv_c1(x) 143 | feat_c = self.cam(feat_c) 144 | feat_c = self.conv_c2(feat_c) 145 | 146 | feat_fusion = feat_p + feat_c 147 | 148 | outputs = [] 149 | fusion_out = self.out(feat_fusion) 150 | outputs.append(fusion_out) 151 | if self.aux: 152 | p_out = self.conv_p3(feat_p) 153 | c_out = self.conv_c3(feat_c) 154 | outputs.append(p_out) 155 | outputs.append(c_out) 156 | 157 | return tuple(outputs) 158 | 159 | # 160 | # def get_danet( backbone='resnet50', pretrained_base=True, **kwargs): 161 | # cityspaces_numclass = 19 162 | # model = DANet(cityspaces_numclass, backbone=backbone, pretrained_base=pretrained_base, **kwargs) 163 | # return model 164 | -------------------------------------------------------------------------------- /model/Danet/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | __all__ = ['ResNet', 'resnet50'] 5 | 6 | model_urls = { 7 | 'resnet50': 
'https://download.pytorch.org/models/resnet50-19c8e357.pth', 8 | 9 | } 10 | 11 | 12 | def conv3x3(in_planes, out_planes, stride=1): 13 | """3x3 convolution with padding""" 14 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 15 | padding=1, bias=False) 16 | 17 | 18 | class Bottleneck(nn.Module): 19 | expansion = 4 20 | 21 | def __init__(self, inplanes, planes, stride=1, downsample=None, rate=1): 22 | super(Bottleneck, self).__init__() 23 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 24 | self.bn1 = nn.BatchNorm2d(planes) 25 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 26 | padding=rate, dilation=rate, bias=False) 27 | self.bn2 = nn.BatchNorm2d(planes) 28 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 29 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 30 | self.relu = nn.ReLU(inplace=True) 31 | self.downsample = downsample 32 | self.stride = stride 33 | 34 | def forward(self, x): 35 | residual = x 36 | 37 | out = self.conv1(x) 38 | out = self.bn1(out) 39 | out = self.relu(out) 40 | 41 | out = self.conv2(out) 42 | out = self.bn2(out) 43 | out = self.relu(out) 44 | 45 | out = self.conv3(out) 46 | out = self.bn3(out) 47 | 48 | if self.downsample is not None: 49 | residual = self.downsample(x) 50 | 51 | out += residual 52 | out = self.relu(out) 53 | 54 | return out 55 | 56 | class ResNet(nn.Module): 57 | def __init__(self, block, layers, num_classes=1000): 58 | self.inplanes = 64 59 | super(ResNet, self).__init__() 60 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 61 | bias=False) 62 | self.bn1 = nn.BatchNorm2d(64) 63 | self.relu = nn.ReLU(inplace=True) 64 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 65 | self.layer1 = self._make_layer(block, 64, layers[0]) 66 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 67 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 68 | 69 | rates = [1, 2, 4] 70 | self.layer4 = self._make_deeplabv3_layer(block, 512, layers[3], rates=rates, stride=1) # stride 2 => stride 1 71 | self.avgpool = nn.AvgPool2d(7, stride=1) 72 | self.fc = nn.Linear(512 * block.expansion, num_classes) 73 | 74 | for m in self.modules(): 75 | if isinstance(m, nn.Conv2d): 76 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 77 | elif isinstance(m, nn.BatchNorm2d): 78 | nn.init.constant_(m.weight, 1) 79 | nn.init.constant_(m.bias, 0) 80 | 81 | def _make_layer(self, block, planes, blocks, stride=1): 82 | downsample = None 83 | if stride != 1 or self.inplanes != planes * block.expansion: 84 | downsample = nn.Sequential( 85 | nn.Conv2d(self.inplanes, planes * block.expansion, 86 | kernel_size=1, stride=stride, bias=False), 87 | nn.BatchNorm2d(planes * block.expansion), 88 | ) 89 | 90 | layers = [] 91 | layers.append(block(self.inplanes, planes, stride, downsample)) 92 | self.inplanes = planes * block.expansion 93 | for i in range(1, blocks): 94 | layers.append(block(self.inplanes, planes)) 95 | 96 | return nn.Sequential(*layers) 97 | 98 | def _make_deeplabv3_layer(self, block, planes, blocks, rates, stride=1): 99 | downsample = None 100 | if stride != 1 or self.inplanes != planes * block.expansion: 101 | downsample = nn.Sequential( 102 | nn.Conv2d(self.inplanes, planes * block.expansion, 103 | kernel_size=1, stride=stride, bias=False), 104 | nn.BatchNorm2d(planes * block.expansion), 105 | ) 106 | 107 | layers = [] 108 | layers.append(block(self.inplanes, planes, stride, downsample)) 109 | 
self.inplanes = planes * block.expansion 110 | for i in range(1, blocks): 111 | layers.append(block(self.inplanes, planes, rate=rates[i])) 112 | 113 | return nn.Sequential(*layers) 114 | 115 | def forward(self, x): 116 | x = self.conv1(x) 117 | x = self.bn1(x) 118 | x = self.relu(x) 119 | x = self.maxpool(x) 120 | 121 | x = self.layer1(x) 122 | x = self.layer2(x) 123 | x = self.layer3(x) 124 | x = self.layer4(x) 125 | 126 | x = self.avgpool(x) 127 | x = x.view(x.size(0), -1) 128 | x = self.fc(x) 129 | 130 | return x 131 | 132 | def resnet50(pretrained=False, **kwargs): 133 | """Constructs a ResNet-50 model. 134 | Args: 135 | pretrained (bool): If True, returns a model pre-trained on ImageNet 136 | """ 137 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 138 | if pretrained: 139 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 140 | return model -------------------------------------------------------------------------------- /model/Unet.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | from .Unet_module import * 3 | class Unet(nn.Module): 4 | def __init__(self,n_channels,n_classes,bilinear=True): 5 | super(Unet,self).__init__() 6 | self.n_channels=n_channels 7 | self.n_classes=n_classes 8 | self.bilinear=bilinear 9 | 10 | self.conv_block=Conv_block(n_channels,64) 11 | self.down1=Down_block(64,128) 12 | self.down2=Down_block(128,256) 13 | self.down3=Down_block(256,512) 14 | 15 | factor=2 if bilinear else 1 16 | self.down4=Down_block(512,1024//factor) 17 | 18 | self.up1=Up_block(1024,512//factor,bilinear) 19 | self.up2 = Up_block(512, 256// factor, bilinear) 20 | self.up3 = Up_block(256, 128// factor, bilinear) 21 | self.up4 = Up_block(128, 64, bilinear) 22 | self.output=OutConv(64,n_classes) 23 | 24 | def forward(self,x): 25 | x1=self.conv_block(x) 26 | x2 = self.down1(x1) 27 | x3 = self.down2(x2) 28 | x4 = self.down3(x3) 29 | x5 = self.down4(x4) 30 | x = self.up1(x5, x4) 31 | x = self.up2(x, x3) 32 | x = self.up3(x, x2) 33 | x = self.up4(x, x1) 34 | logits = self.output(x) 35 | return logits -------------------------------------------------------------------------------- /model/Unet_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | class Conv_block(nn.Module): 5 | def __init__(self,in_channels,out_channels,mid_channels=None): 6 | super().__init__() 7 | if not mid_channels: 8 | mid_channels=out_channels 9 | self.double_conv=nn.Sequential( 10 | nn.Conv2d(in_channels,mid_channels,kernel_size=3,padding=1), 11 | nn.BatchNorm2d(mid_channels), 12 | nn.ReLU(inplace=True), 13 | nn.Conv2d(mid_channels,out_channels,kernel_size=3,padding=1), 14 | nn.ReLU(inplace=True) 15 | ) 16 | def forward(self,x): 17 | return self.double_conv(x) 18 | 19 | class Down_block(nn.Module): 20 | def __init__(self,in_channels,out_channels): 21 | super().__init__() 22 | self.maxpool_conv=nn.Sequential( 23 | nn.MaxPool2d(2), 24 | Conv_block(in_channels,out_channels) 25 | ) 26 | def forward(self,x): 27 | return self.maxpool_conv(x) 28 | 29 | class Up_block(nn.Module): 30 | def __init__(self,in_channels,out_channels,bilinear=True): 31 | super().__init__() 32 | 33 | if bilinear: 34 | self.upsample=nn.Upsample(scale_factor=2,mode='bilinear',align_corners=True) 35 | self.conv=Conv_block(in_channels,out_channels,in_channels//2) 36 | else: 37 | 
self.upsample=nn.ConvTranspose2d(in_channels,in_channels//2,kernel_size=2,stride=2) 38 | self.conv=Conv_block(in_channels,out_channels) 39 | 40 | def forward(self,x1,x2): 41 | x1=self.upsample(x1) 42 | diffY=x2.size()[2]-x1.size()[2] 43 | diffX=x2.size()[3]-x1.size()[3] 44 | x1=F.pad(x1,[ 45 | diffX//2,diffX-diffX//2, 46 | diffY//2,diffY-diffY//2 47 | ]) 48 | x=torch.cat([x2,x1],dim=1) 49 | return self.conv(x) 50 | 51 | class OutConv(nn.Module): 52 | def __init__(self, in_channels, out_channels): 53 | super(OutConv, self).__init__() 54 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 55 | 56 | def forward(self, x): 57 | return self.conv(x) -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .Unet import Unet 2 | from .deeplabv3 import * 3 | from .Danet import DANet 4 | from .hrnet import HRNet -------------------------------------------------------------------------------- /model/__pycache__/Unet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/model/__pycache__/Unet.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/Unet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/model/__pycache__/Unet.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/Unet_module.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/model/__pycache__/Unet_module.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/Unet_module.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/model/__pycache__/Unet_module.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/model/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/model/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /model/deeplabv3/__init__.py: -------------------------------------------------------------------------------- 1 | from .modeling import * 2 | from ._deeplab import convert_to_separable_conv -------------------------------------------------------------------------------- /model/deeplabv3/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/model/deeplabv3/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /model/deeplabv3/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/model/deeplabv3/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /model/deeplabv3/__pycache__/_deeplab.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/model/deeplabv3/__pycache__/_deeplab.cpython-36.pyc -------------------------------------------------------------------------------- /model/deeplabv3/__pycache__/_deeplab.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/model/deeplabv3/__pycache__/_deeplab.cpython-37.pyc -------------------------------------------------------------------------------- /model/deeplabv3/__pycache__/modeling.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/model/deeplabv3/__pycache__/modeling.cpython-36.pyc -------------------------------------------------------------------------------- /model/deeplabv3/__pycache__/modeling.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/model/deeplabv3/__pycache__/modeling.cpython-37.pyc -------------------------------------------------------------------------------- /model/deeplabv3/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/model/deeplabv3/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /model/deeplabv3/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/model/deeplabv3/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /model/deeplabv3/_deeplab.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | from .utils import _SimpleSegmentationModel 6 | 7 | 8 | __all__ = ["DeepLabV3"] 9 | 10 | 11 | class DeepLabV3(_SimpleSegmentationModel): 12 | """ 13 | Implements DeepLabV3 model from 14 | `"Rethinking Atrous Convolution for Semantic Image Segmentation" 15 | <https://arxiv.org/abs/1706.05587>`_. 16 | 17 | Arguments: 18 | backbone (nn.Module): the network used to compute the features for the model. 19 | The backbone should return an OrderedDict[Tensor], with the key being 20 | "out" for the last feature map used, and "aux" if an auxiliary classifier 21 | is used. 
22 | classifier (nn.Module): module that takes the "out" element returned from 23 | the backbone and returns a dense prediction. 24 | aux_classifier (nn.Module, optional): auxiliary classifier used during training 25 | """ 26 | pass 27 | 28 | class DeepLabHeadV3Plus(nn.Module): 29 | def __init__(self, in_channels, low_level_channels, num_classes, aspp_dilate=[6, 12, 18]): 30 | super(DeepLabHeadV3Plus, self).__init__() 31 | self.project = nn.Sequential( 32 | nn.Conv2d(low_level_channels, 48, 1, bias=False), 33 | nn.BatchNorm2d(48), 34 | nn.ReLU(inplace=True), 35 | ) 36 | 37 | self.aspp = ASPP(in_channels, aspp_dilate) 38 | 39 | self.classifier = nn.Sequential( 40 | nn.Conv2d(304, 256, 3, padding=1, bias=False), 41 | nn.BatchNorm2d(256), 42 | nn.ReLU(inplace=True), 43 | nn.Conv2d(256, num_classes, 1) 44 | ) 45 | self._init_weight() 46 | 47 | def forward(self, feature): 48 | 49 | low_level_feature = self.project( feature['low_level'] ) 50 | output_feature = self.aspp(feature['out']) 51 | output_feature = F.interpolate(output_feature, size=low_level_feature.shape[2:], mode='bilinear', align_corners=False) 52 | return self.classifier( torch.cat( [ low_level_feature, output_feature ], dim=1 ) ) 53 | 54 | def _init_weight(self): 55 | for m in self.modules(): 56 | if isinstance(m, nn.Conv2d): 57 | nn.init.kaiming_normal_(m.weight) 58 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 59 | nn.init.constant_(m.weight, 1) 60 | nn.init.constant_(m.bias, 0) 61 | 62 | class DeepLabHead(nn.Module): 63 | def __init__(self, in_channels, num_classes, aspp_dilate=[12, 24, 36]): 64 | super(DeepLabHead, self).__init__() 65 | 66 | self.classifier = nn.Sequential( 67 | ASPP(in_channels, aspp_dilate), 68 | nn.Conv2d(256, 256, 3, padding=1, bias=False), 69 | nn.BatchNorm2d(256), 70 | nn.ReLU(inplace=True), 71 | nn.Conv2d(256, num_classes, 1) 72 | ) 73 | self._init_weight() 74 | 75 | def forward(self, feature): 76 | return self.classifier( feature['out'] ) 77 | 78 | def _init_weight(self): 79 | for m in self.modules(): 80 | if isinstance(m, nn.Conv2d): 81 | nn.init.kaiming_normal_(m.weight) 82 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 83 | nn.init.constant_(m.weight, 1) 84 | nn.init.constant_(m.bias, 0) 85 | 86 | class AtrousSeparableConvolution(nn.Module): 87 | """ Atrous Separable Convolution 88 | """ 89 | def __init__(self, in_channels, out_channels, kernel_size, 90 | stride=1, padding=0, dilation=1, bias=True): 91 | super(AtrousSeparableConvolution, self).__init__() 92 | self.body = nn.Sequential( 93 | # Separable Conv 94 | nn.Conv2d( in_channels, in_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias, groups=in_channels ), 95 | # PointWise Conv 96 | nn.Conv2d( in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias), 97 | ) 98 | 99 | self._init_weight() 100 | 101 | def forward(self, x): 102 | return self.body(x) 103 | 104 | def _init_weight(self): 105 | for m in self.modules(): 106 | if isinstance(m, nn.Conv2d): 107 | nn.init.kaiming_normal_(m.weight) 108 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 109 | nn.init.constant_(m.weight, 1) 110 | nn.init.constant_(m.bias, 0) 111 | 112 | class ASPPConv(nn.Sequential): 113 | def __init__(self, in_channels, out_channels, dilation): 114 | modules = [ 115 | nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False), 116 | nn.BatchNorm2d(out_channels), 117 | nn.ReLU(inplace=True) 118 | ] 119 | super(ASPPConv, self).__init__(*modules) 120 | 121 | class 
ASPPPooling(nn.Sequential): 122 | def __init__(self, in_channels, out_channels): 123 | super(ASPPPooling, self).__init__( 124 | nn.AdaptiveAvgPool2d(1), 125 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 126 | nn.BatchNorm2d(out_channels), 127 | nn.ReLU(inplace=True)) 128 | 129 | def forward(self, x): 130 | size = x.shape[-2:] 131 | x = super(ASPPPooling, self).forward(x) 132 | return F.interpolate(x, size=size, mode='bilinear', align_corners=False) 133 | 134 | class ASPP(nn.Module): 135 | def __init__(self, in_channels, atrous_rates): 136 | super(ASPP, self).__init__() 137 | out_channels = 256 138 | modules = [] 139 | modules.append(nn.Sequential( 140 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 141 | nn.BatchNorm2d(out_channels), 142 | nn.ReLU(inplace=True))) 143 | 144 | rate1, rate2, rate3 = tuple(atrous_rates) 145 | modules.append(ASPPConv(in_channels, out_channels, rate1)) 146 | modules.append(ASPPConv(in_channels, out_channels, rate2)) 147 | modules.append(ASPPConv(in_channels, out_channels, rate3)) 148 | modules.append(ASPPPooling(in_channels, out_channels)) 149 | 150 | self.convs = nn.ModuleList(modules) 151 | 152 | self.project = nn.Sequential( 153 | nn.Conv2d(5 * out_channels, out_channels, 1, bias=False), 154 | nn.BatchNorm2d(out_channels), 155 | nn.ReLU(inplace=True), 156 | nn.Dropout(0.1),) 157 | 158 | def forward(self, x): 159 | res = [] 160 | for conv in self.convs: 161 | res.append(conv(x)) 162 | res = torch.cat(res, dim=1) 163 | return self.project(res) 164 | 165 | 166 | 167 | def convert_to_separable_conv(module): 168 | new_module = module 169 | if isinstance(module, nn.Conv2d) and module.kernel_size[0]>1: 170 | new_module = AtrousSeparableConvolution(module.in_channels, 171 | module.out_channels, 172 | module.kernel_size, 173 | module.stride, 174 | module.padding, 175 | module.dilation, 176 | module.bias is not None) # pass a bool flag; passing the bias Parameter itself would break nn.Conv2d's truthiness check 177 | for name, child in module.named_children(): 178 | new_module.add_module(name, convert_to_separable_conv(child)) 179 | return new_module -------------------------------------------------------------------------------- /model/deeplabv3/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from . import resnet 2 | from . 
import mobilenetv2 3 | from .xception import AlignedXception 4 | -------------------------------------------------------------------------------- /model/deeplabv3/backbone/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/model/deeplabv3/backbone/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /model/deeplabv3/backbone/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/model/deeplabv3/backbone/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /model/deeplabv3/backbone/__pycache__/mobilenetv2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/model/deeplabv3/backbone/__pycache__/mobilenetv2.cpython-36.pyc -------------------------------------------------------------------------------- /model/deeplabv3/backbone/__pycache__/mobilenetv2.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/model/deeplabv3/backbone/__pycache__/mobilenetv2.cpython-37.pyc -------------------------------------------------------------------------------- /model/deeplabv3/backbone/__pycache__/resnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/model/deeplabv3/backbone/__pycache__/resnet.cpython-36.pyc -------------------------------------------------------------------------------- /model/deeplabv3/backbone/__pycache__/resnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/model/deeplabv3/backbone/__pycache__/resnet.cpython-37.pyc -------------------------------------------------------------------------------- /model/deeplabv3/backbone/__pycache__/xception.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/model/deeplabv3/backbone/__pycache__/xception.cpython-36.pyc -------------------------------------------------------------------------------- /model/deeplabv3/backbone/__pycache__/xception.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/model/deeplabv3/backbone/__pycache__/xception.cpython-37.pyc -------------------------------------------------------------------------------- /model/deeplabv3/backbone/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torchvision.models.utils import load_state_dict_from_url 3 | import torch.nn.functional as F 4 | 5 | __all__ = ['MobileNetV2', 'mobilenet_v2'] 
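# Usage sketch (illustrative, not part of the original file; it relies only on the
# MobileNetV2 class defined below): `output_stride` controls where the inverted-
# residual blocks stop striding and grow the dilation instead, so with
# output_stride=8 a 256x256 input keeps a 256/8 = 32x32 feature map:
#
#   import torch
#   net = MobileNetV2(num_classes=8, output_stride=8)
#   feats = net.features(torch.randn(1, 3, 256, 256))  # -> [1, 1280, 32, 32]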
6 | 7 | 8 | model_urls = { 9 | 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth', 10 | } 11 | 12 | 13 | def _make_divisible(v, divisor, min_value=None): 14 | """ 15 | This function is taken from the original tf repo. 16 | It ensures that all layers have a channel number that is divisible by 8 17 | It can be seen here: 18 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 19 | :param v: 20 | :param divisor: 21 | :param min_value: 22 | :return: 23 | """ 24 | if min_value is None: 25 | min_value = divisor 26 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 27 | # Make sure that round down does not go down by more than 10%. 28 | if new_v < 0.9 * v: 29 | new_v += divisor 30 | return new_v 31 | 32 | 33 | class ConvBNReLU(nn.Sequential): 34 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, dilation=1, groups=1): 35 | #padding = (kernel_size - 1) // 2 36 | super(ConvBNReLU, self).__init__( 37 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, 0, dilation=dilation, groups=groups, bias=False), 38 | nn.BatchNorm2d(out_planes), 39 | nn.ReLU6(inplace=True) 40 | ) 41 | 42 | def fixed_padding(kernel_size, dilation): 43 | kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1) 44 | pad_total = kernel_size_effective - 1 45 | pad_beg = pad_total // 2 46 | pad_end = pad_total - pad_beg 47 | return (pad_beg, pad_end, pad_beg, pad_end) 48 | 49 | class InvertedResidual(nn.Module): 50 | def __init__(self, inp, oup, stride, dilation, expand_ratio): 51 | super(InvertedResidual, self).__init__() 52 | self.stride = stride 53 | assert stride in [1, 2] 54 | 55 | hidden_dim = int(round(inp * expand_ratio)) 56 | self.use_res_connect = self.stride == 1 and inp == oup 57 | 58 | layers = [] 59 | if expand_ratio != 1: 60 | # pw 61 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) 62 | 63 | layers.extend([ 64 | # dw 65 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, dilation=dilation, groups=hidden_dim), 66 | # pw-linear 67 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 68 | nn.BatchNorm2d(oup), 69 | ]) 70 | self.conv = nn.Sequential(*layers) 71 | 72 | self.input_padding = fixed_padding( 3, dilation ) 73 | 74 | def forward(self, x): 75 | x_pad = F.pad(x, self.input_padding) 76 | if self.use_res_connect: 77 | return x + self.conv(x_pad) 78 | else: 79 | return self.conv(x_pad) 80 | 81 | class MobileNetV2(nn.Module): 82 | def __init__(self, num_classes=1000, output_stride=8, width_mult=1.0, inverted_residual_setting=None, round_nearest=8): 83 | """ 84 | MobileNet V2 main class 85 | 86 | Args: 87 | num_classes (int): Number of classes 88 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount 89 | inverted_residual_setting: Network structure 90 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number 91 | Set to 1 to turn off rounding 92 | """ 93 | super(MobileNetV2, self).__init__() 94 | block = InvertedResidual 95 | input_channel = 32 96 | last_channel = 1280 97 | self.output_stride = output_stride 98 | current_stride = 1 99 | if inverted_residual_setting is None: 100 | inverted_residual_setting = [ 101 | # t, c, n, s 102 | [1, 16, 1, 1], 103 | [6, 24, 2, 2], 104 | [6, 32, 3, 2], 105 | [6, 64, 4, 2], 106 | [6, 96, 3, 1], 107 | [6, 160, 3, 2], 108 | [6, 320, 1, 1], 109 | ] 110 | 111 | # only check the first element, assuming user knows t,c,n,s are required 112 | if len(inverted_residual_setting) == 0 
or len(inverted_residual_setting[0]) != 4: 113 | raise ValueError("inverted_residual_setting should be non-empty " 114 | "or a 4-element list, got {}".format(inverted_residual_setting)) 115 | 116 | # building first layer 117 | input_channel = _make_divisible(input_channel * width_mult, round_nearest) 118 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) 119 | features = [ConvBNReLU(3, input_channel, stride=2)] 120 | current_stride *= 2 121 | dilation=1 122 | previous_dilation = 1 123 | 124 | # building inverted residual blocks 125 | for t, c, n, s in inverted_residual_setting: 126 | output_channel = _make_divisible(c * width_mult, round_nearest) 127 | previous_dilation = dilation 128 | if current_stride == output_stride: 129 | stride = 1 130 | dilation *= s 131 | else: 132 | stride = s 133 | current_stride *= s 134 | output_channel = int(c * width_mult) 135 | 136 | for i in range(n): 137 | if i==0: 138 | features.append(block(input_channel, output_channel, stride, previous_dilation, expand_ratio=t)) 139 | else: 140 | features.append(block(input_channel, output_channel, 1, dilation, expand_ratio=t)) 141 | input_channel = output_channel 142 | # building last several layers 143 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1)) 144 | # make it nn.Sequential 145 | self.features = nn.Sequential(*features) 146 | 147 | # building classifier 148 | self.classifier = nn.Sequential( 149 | nn.Dropout(0.2), 150 | nn.Linear(self.last_channel, num_classes), 151 | ) 152 | 153 | # weight initialization 154 | for m in self.modules(): 155 | if isinstance(m, nn.Conv2d): 156 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 157 | if m.bias is not None: 158 | nn.init.zeros_(m.bias) 159 | elif isinstance(m, nn.BatchNorm2d): 160 | nn.init.ones_(m.weight) 161 | nn.init.zeros_(m.bias) 162 | elif isinstance(m, nn.Linear): 163 | nn.init.normal_(m.weight, 0, 0.01) 164 | nn.init.zeros_(m.bias) 165 | 166 | def forward(self, x): 167 | x = self.features(x) 168 | x = x.mean([2, 3]) 169 | x = self.classifier(x) 170 | return x 171 | 172 | 173 | def mobilenet_v2(pretrained=False, progress=True, **kwargs): 174 | """ 175 | Constructs a MobileNetV2 architecture from 176 | `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_. 
177 | 178 | Args: 179 | pretrained (bool): If True, returns a model pre-trained on ImageNet 180 | progress (bool): If True, displays a progress bar of the download to stderr 181 | """ 182 | model = MobileNetV2(**kwargs) 183 | if pretrained: 184 | state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'], 185 | progress=progress) 186 | model.load_state_dict(state_dict) 187 | return model 188 | -------------------------------------------------------------------------------- /model/deeplabv3/backbone/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision.models.utils import load_state_dict_from_url 4 | 5 | 6 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 7 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 8 | 'wide_resnet50_2', 'wide_resnet101_2'] 9 | 10 | 11 | model_urls = { 12 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 13 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 14 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 15 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 16 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 17 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 18 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 19 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 20 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 21 | } 22 | 23 | 24 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 25 | """3x3 convolution with padding""" 26 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 27 | padding=dilation, groups=groups, bias=False, dilation=dilation) 28 | 29 | 30 | def conv1x1(in_planes, out_planes, stride=1): 31 | """1x1 convolution""" 32 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 33 | 34 | 35 | class BasicBlock(nn.Module): 36 | expansion = 1 37 | 38 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 39 | base_width=64, dilation=1, norm_layer=None): 40 | super(BasicBlock, self).__init__() 41 | if norm_layer is None: 42 | norm_layer = nn.BatchNorm2d 43 | if groups != 1 or base_width != 64: 44 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 45 | if dilation > 1: 46 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 47 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 48 | self.conv1 = conv3x3(inplanes, planes, stride) 49 | self.bn1 = norm_layer(planes) 50 | self.relu = nn.ReLU(inplace=True) 51 | self.conv2 = conv3x3(planes, planes) 52 | self.bn2 = norm_layer(planes) 53 | self.downsample = downsample 54 | self.stride = stride 55 | 56 | def forward(self, x): 57 | identity = x 58 | 59 | out = self.conv1(x) 60 | out = self.bn1(out) 61 | out = self.relu(out) 62 | 63 | out = self.conv2(out) 64 | out = self.bn2(out) 65 | 66 | if self.downsample is not None: 67 | identity = self.downsample(x) 68 | 69 | out += identity 70 | out = self.relu(out) 71 | 72 | return out 73 | 74 | 75 | class Bottleneck(nn.Module): 76 | expansion = 4 77 | 78 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 79 | base_width=64, dilation=1, norm_layer=None): 80 | 
super(Bottleneck, self).__init__() 81 | if norm_layer is None: 82 | norm_layer = nn.BatchNorm2d 83 | width = int(planes * (base_width / 64.)) * groups 84 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 85 | self.conv1 = conv1x1(inplanes, width) 86 | self.bn1 = norm_layer(width) 87 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 88 | self.bn2 = norm_layer(width) 89 | self.conv3 = conv1x1(width, planes * self.expansion) 90 | self.bn3 = norm_layer(planes * self.expansion) 91 | self.relu = nn.ReLU(inplace=True) 92 | self.downsample = downsample 93 | self.stride = stride 94 | 95 | def forward(self, x): 96 | identity = x 97 | 98 | out = self.conv1(x) 99 | out = self.bn1(out) 100 | out = self.relu(out) 101 | 102 | out = self.conv2(out) 103 | out = self.bn2(out) 104 | out = self.relu(out) 105 | 106 | out = self.conv3(out) 107 | out = self.bn3(out) 108 | 109 | if self.downsample is not None: 110 | identity = self.downsample(x) 111 | 112 | out += identity 113 | out = self.relu(out) 114 | 115 | return out 116 | 117 | 118 | class ResNet(nn.Module): 119 | 120 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 121 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 122 | norm_layer=None): 123 | super(ResNet, self).__init__() 124 | if norm_layer is None: 125 | norm_layer = nn.BatchNorm2d 126 | self._norm_layer = norm_layer 127 | 128 | self.inplanes = 64 129 | self.dilation = 1 130 | if replace_stride_with_dilation is None: 131 | # each element in the tuple indicates if we should replace 132 | # the 2x2 stride with a dilated convolution instead 133 | replace_stride_with_dilation = [False, False, False] 134 | if len(replace_stride_with_dilation) != 3: 135 | raise ValueError("replace_stride_with_dilation should be None " 136 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 137 | self.groups = groups 138 | self.base_width = width_per_group 139 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 140 | bias=False) 141 | self.bn1 = norm_layer(self.inplanes) 142 | self.relu = nn.ReLU(inplace=True) 143 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 144 | self.layer1 = self._make_layer(block, 64, layers[0]) 145 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 146 | dilate=replace_stride_with_dilation[0]) 147 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 148 | dilate=replace_stride_with_dilation[1]) 149 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 150 | dilate=replace_stride_with_dilation[2]) 151 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 152 | self.fc = nn.Linear(512 * block.expansion, num_classes) 153 | 154 | for m in self.modules(): 155 | if isinstance(m, nn.Conv2d): 156 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 157 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 158 | nn.init.constant_(m.weight, 1) 159 | nn.init.constant_(m.bias, 0) 160 | 161 | # Zero-initialize the last BN in each residual branch, 162 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
163 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 164 | if zero_init_residual: 165 | for m in self.modules(): 166 | if isinstance(m, Bottleneck): 167 | nn.init.constant_(m.bn3.weight, 0) 168 | elif isinstance(m, BasicBlock): 169 | nn.init.constant_(m.bn2.weight, 0) 170 | 171 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 172 | norm_layer = self._norm_layer 173 | downsample = None 174 | previous_dilation = self.dilation 175 | if dilate: 176 | self.dilation *= stride 177 | stride = 1 178 | if stride != 1 or self.inplanes != planes * block.expansion: 179 | downsample = nn.Sequential( 180 | conv1x1(self.inplanes, planes * block.expansion, stride), 181 | norm_layer(planes * block.expansion), 182 | ) 183 | 184 | layers = [] 185 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 186 | self.base_width, previous_dilation, norm_layer)) 187 | self.inplanes = planes * block.expansion 188 | for _ in range(1, blocks): 189 | layers.append(block(self.inplanes, planes, groups=self.groups, 190 | base_width=self.base_width, dilation=self.dilation, 191 | norm_layer=norm_layer)) 192 | 193 | return nn.Sequential(*layers) 194 | 195 | def forward(self, x): 196 | x = self.conv1(x) 197 | x = self.bn1(x) 198 | x = self.relu(x) 199 | x = self.maxpool(x) 200 | 201 | x = self.layer1(x) 202 | x = self.layer2(x) 203 | x = self.layer3(x) 204 | x = self.layer4(x) 205 | 206 | x = self.avgpool(x) 207 | x = torch.flatten(x, 1) 208 | x = self.fc(x) 209 | 210 | return x 211 | 212 | 213 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 214 | model = ResNet(block, layers, **kwargs) 215 | if pretrained: 216 | state_dict = load_state_dict_from_url(model_urls[arch], 217 | progress=progress) 218 | model.load_state_dict(state_dict) 219 | return model 220 | 221 | 222 | def resnet18(pretrained=False, progress=True, **kwargs): 223 | r"""ResNet-18 model from 224 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_ 225 | 226 | Args: 227 | pretrained (bool): If True, returns a model pre-trained on ImageNet 228 | progress (bool): If True, displays a progress bar of the download to stderr 229 | """ 230 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 231 | **kwargs) 232 | 233 | 234 | def resnet34(pretrained=False, progress=True, **kwargs): 235 | r"""ResNet-34 model from 236 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_ 237 | 238 | Args: 239 | pretrained (bool): If True, returns a model pre-trained on ImageNet 240 | progress (bool): If True, displays a progress bar of the download to stderr 241 | """ 242 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 243 | **kwargs) 244 | 245 | 246 | def resnet50(pretrained=False, progress=True, **kwargs): 247 | r"""ResNet-50 model from 248 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_ 249 | 250 | Args: 251 | pretrained (bool): If True, returns a model pre-trained on ImageNet 252 | progress (bool): If True, displays a progress bar of the download to stderr 253 | """ 254 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 255 | **kwargs) 256 | 257 | 258 | def resnet101(pretrained=False, progress=True, **kwargs): 259 | r"""ResNet-101 model from 260 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_ 261 | 262 | Args: 263 | pretrained (bool): If True, returns a model pre-trained on ImageNet 264 | progress (bool): If True, displays a progress bar of the download to stderr 265 | """ 266 | return 
_resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 267 | **kwargs) 268 | 269 | 270 | def resnet152(pretrained=False, progress=True, **kwargs): 271 | r"""ResNet-152 model from 272 | `"Deep Residual Learning for Image Recognition" `_ 273 | 274 | Args: 275 | pretrained (bool): If True, returns a model pre-trained on ImageNet 276 | progress (bool): If True, displays a progress bar of the download to stderr 277 | """ 278 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 279 | **kwargs) 280 | 281 | 282 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 283 | r"""ResNeXt-50 32x4d model from 284 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 285 | 286 | Args: 287 | pretrained (bool): If True, returns a model pre-trained on ImageNet 288 | progress (bool): If True, displays a progress bar of the download to stderr 289 | """ 290 | kwargs['groups'] = 32 291 | kwargs['width_per_group'] = 4 292 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 293 | pretrained, progress, **kwargs) 294 | 295 | 296 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 297 | r"""ResNeXt-101 32x8d model from 298 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 299 | 300 | Args: 301 | pretrained (bool): If True, returns a model pre-trained on ImageNet 302 | progress (bool): If True, displays a progress bar of the download to stderr 303 | """ 304 | kwargs['groups'] = 32 305 | kwargs['width_per_group'] = 8 306 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 307 | pretrained, progress, **kwargs) 308 | 309 | 310 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs): 311 | r"""Wide ResNet-50-2 model from 312 | `"Wide Residual Networks" `_ 313 | 314 | The model is the same as ResNet except for the bottleneck number of channels 315 | which is twice larger in every block. The number of channels in outer 1x1 316 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 317 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 318 | 319 | Args: 320 | pretrained (bool): If True, returns a model pre-trained on ImageNet 321 | progress (bool): If True, displays a progress bar of the download to stderr 322 | """ 323 | kwargs['width_per_group'] = 64 * 2 324 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 325 | pretrained, progress, **kwargs) 326 | 327 | 328 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs): 329 | r"""Wide ResNet-101-2 model from 330 | `"Wide Residual Networks" `_ 331 | 332 | The model is the same as ResNet except for the bottleneck number of channels 333 | which is twice larger in every block. The number of channels in outer 1x1 334 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 335 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
336 | 337 | Args: 338 | pretrained (bool): If True, returns a model pre-trained on ImageNet 339 | progress (bool): If True, displays a progress bar of the download to stderr 340 | """ 341 | kwargs['width_per_group'] = 64 * 2 342 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 343 | pretrained, progress, **kwargs) 344 | -------------------------------------------------------------------------------- /model/deeplabv3/backbone/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .replicate import DataParallelWithCallback, patch_replication_callback -------------------------------------------------------------------------------- /model/deeplabv3/backbone/sync_batchnorm/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/model/deeplabv3/backbone/sync_batchnorm/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /model/deeplabv3/backbone/sync_batchnorm/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/model/deeplabv3/backbone/sync_batchnorm/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /model/deeplabv3/backbone/sync_batchnorm/__pycache__/batchnorm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/model/deeplabv3/backbone/sync_batchnorm/__pycache__/batchnorm.cpython-36.pyc -------------------------------------------------------------------------------- /model/deeplabv3/backbone/sync_batchnorm/__pycache__/batchnorm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/model/deeplabv3/backbone/sync_batchnorm/__pycache__/batchnorm.cpython-37.pyc -------------------------------------------------------------------------------- /model/deeplabv3/backbone/sync_batchnorm/__pycache__/comm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/model/deeplabv3/backbone/sync_batchnorm/__pycache__/comm.cpython-36.pyc -------------------------------------------------------------------------------- /model/deeplabv3/backbone/sync_batchnorm/__pycache__/comm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/model/deeplabv3/backbone/sync_batchnorm/__pycache__/comm.cpython-37.pyc 
--------------------------------------------------------------------------------
/model/deeplabv3/backbone/sync_batchnorm/__pycache__/replicate.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/model/deeplabv3/backbone/sync_batchnorm/__pycache__/replicate.cpython-36.pyc
--------------------------------------------------------------------------------
/model/deeplabv3/backbone/sync_batchnorm/__pycache__/replicate.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/model/deeplabv3/backbone/sync_batchnorm/__pycache__/replicate.cpython-37.pyc
--------------------------------------------------------------------------------
/model/deeplabv3/backbone/sync_batchnorm/batchnorm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : batchnorm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import collections
12 |
13 | import torch
14 | import torch.nn.functional as F
15 |
16 | from torch.nn.modules.batchnorm import _BatchNorm
17 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
18 |
19 | from .comm import SyncMaster
20 |
21 | __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
22 |
23 |
24 | def _sum_ft(tensor):
25 | """sum over the first and last dimension"""
26 | return tensor.sum(dim=0).sum(dim=-1)
27 |
28 |
29 | def _unsqueeze_ft(tensor):
30 | """add new dimensions at the front and the tail"""
31 | return tensor.unsqueeze(0).unsqueeze(-1)
32 |
33 |
34 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
35 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
36 |
37 |
38 | class _SynchronizedBatchNorm(_BatchNorm):
39 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
40 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
41 |
42 | self._sync_master = SyncMaster(self._data_parallel_master)
43 |
44 | self._is_parallel = False
45 | self._parallel_id = None
46 | self._slave_pipe = None
47 |
48 | def forward(self, input):
49 | # If this is not a parallel computation or we are in evaluation mode, use PyTorch's implementation.
50 | if not (self._is_parallel and self.training):
51 | return F.batch_norm(
52 | input, self.running_mean, self.running_var, self.weight, self.bias,
53 | self.training, self.momentum, self.eps)
54 |
55 | # Resize the input to (B, C, -1).
56 | input_shape = input.size()
57 | input = input.view(input.size(0), self.num_features, -1)
58 |
59 | # Compute the sum and square-sum.
60 | sum_size = input.size(0) * input.size(2)
61 | input_sum = _sum_ft(input)
62 | input_ssum = _sum_ft(input ** 2)
63 |
64 | # Reduce-and-broadcast the statistics.
65 | if self._parallel_id == 0:
66 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
67 | else:
68 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
69 |
70 | # Compute the output.
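# `inv_std` is the reciprocal square root of the eps-clamped (biased) batch
# variance, computed on the master from the pooled sums and broadcast back;
# multiplying `self.weight` into it below fuses the normalization and the
# affine scale into a single elementwise pass.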
71 | if self.affine:
72 | # MJY:: Fuse the multiplication for speed.
73 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
74 | else:
75 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
76 |
77 | # Reshape it.
78 | return output.view(input_shape)
79 |
80 | def __data_parallel_replicate__(self, ctx, copy_id):
81 | self._is_parallel = True
82 | self._parallel_id = copy_id
83 |
84 | # parallel_id == 0 means master device.
85 | if self._parallel_id == 0:
86 | ctx.sync_master = self._sync_master
87 | else:
88 | self._slave_pipe = ctx.sync_master.register_slave(copy_id)
89 |
90 | def _data_parallel_master(self, intermediates):
91 | """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
92 |
93 | # Always using same "device order" makes the ReduceAdd operation faster.
94 | # Thanks to:: Tete Xiao (http://tetexiao.com/)
95 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
96 |
97 | to_reduce = [i[1][:2] for i in intermediates]
98 | to_reduce = [j for i in to_reduce for j in i] # flatten
99 | target_gpus = [i[1].sum.get_device() for i in intermediates]
100 |
101 | sum_size = sum([i[1].sum_size for i in intermediates])
102 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
103 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
104 |
105 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
106 |
107 | outputs = []
108 | for i, rec in enumerate(intermediates):
109 | outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2])))
110 |
111 | return outputs
112 |
113 | def _compute_mean_std(self, sum_, ssum, size):
114 | """Compute the mean and standard-deviation with sum and square-sum. This method
115 | also maintains the moving average on the master device."""
116 | assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
117 | mean = sum_ / size
118 | sumvar = ssum - sum_ * mean
119 | unbias_var = sumvar / (size - 1)
120 | bias_var = sumvar / size
121 |
122 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
123 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
124 |
125 | return mean, bias_var.clamp(self.eps) ** -0.5
126 |
127 |
128 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
129 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
130 | mini-batch.
131 | .. math::
132 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
133 | This module differs from the built-in PyTorch BatchNorm1d as the mean and
134 | standard-deviation are reduced across all devices during training.
135 | For example, when one uses `nn.DataParallel` to wrap the network during
136 | training, PyTorch's implementation normalizes the tensor on each device using
137 | the statistics only on that device, which accelerates the computation and
138 | is easy to implement, but the statistics might be inaccurate.
139 | Instead, in this synchronized version, the statistics will be computed
140 | over all training samples distributed on multiple devices.
141 |
142 | Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
143 | as the built-in PyTorch implementation.
144 | The mean and standard-deviation are calculated per-dimension over
145 | the mini-batches and gamma and beta are learnable parameter vectors
146 | of size C (where C is the input size).
147 | During training, this layer keeps a running estimate of its computed mean
148 | and variance. The running sum is kept with a default momentum of 0.1.
149 | During evaluation, this running mean/variance is used for normalization.
150 | Because the BatchNorm is done over the `C` dimension, computing statistics
151 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm.
152 | Args:
153 | num_features: num_features from an expected input of size
154 | `batch_size x num_features [x width]`
155 | eps: a value added to the denominator for numerical stability.
156 | Default: 1e-5
157 | momentum: the value used for the running_mean and running_var
158 | computation. Default: 0.1
159 | affine: a boolean value that when set to ``True``, gives the layer learnable
160 | affine parameters. Default: ``True``
161 | Shape:
162 | - Input: :math:`(N, C)` or :math:`(N, C, L)`
163 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
164 | Examples:
165 | >>> # With Learnable Parameters
166 | >>> m = SynchronizedBatchNorm1d(100)
167 | >>> # Without Learnable Parameters
168 | >>> m = SynchronizedBatchNorm1d(100, affine=False)
169 | >>> input = torch.autograd.Variable(torch.randn(20, 100))
170 | >>> output = m(input)
171 | """
172 |
173 | def _check_input_dim(self, input):
174 | if input.dim() != 2 and input.dim() != 3:
175 | raise ValueError('expected 2D or 3D input (got {}D input)'
176 | .format(input.dim()))
177 | super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
178 |
179 |
180 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
181 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
182 | of 3d inputs
183 | .. math::
184 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
185 | This module differs from the built-in PyTorch BatchNorm2d as the mean and
186 | standard-deviation are reduced across all devices during training.
187 | For example, when one uses `nn.DataParallel` to wrap the network during
188 | training, PyTorch's implementation normalizes the tensor on each device using
189 | the statistics only on that device, which accelerates the computation and
190 | is easy to implement, but the statistics might be inaccurate.
191 | Instead, in this synchronized version, the statistics will be computed
192 | over all training samples distributed on multiple devices.
193 |
194 | Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
195 | as the built-in PyTorch implementation.
196 | The mean and standard-deviation are calculated per-dimension over
197 | the mini-batches and gamma and beta are learnable parameter vectors
198 | of size C (where C is the input size).
199 | During training, this layer keeps a running estimate of its computed mean
200 | and variance. The running sum is kept with a default momentum of 0.1.
201 | During evaluation, this running mean/variance is used for normalization.
202 | Because the BatchNorm is done over the `C` dimension, computing statistics
203 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm.
204 | Args:
205 | num_features: num_features from an expected input of
206 | size batch_size x num_features x height x width
207 | eps: a value added to the denominator for numerical stability.
208 | Default: 1e-5
209 | momentum: the value used for the running_mean and running_var
210 | computation. Default: 0.1
211 | affine: a boolean value that when set to ``True``, gives the layer learnable
212 | affine parameters. Default: ``True``
213 | Shape:
214 | - Input: :math:`(N, C, H, W)`
215 | - Output: :math:`(N, C, H, W)` (same shape as input)
216 | Examples:
217 | >>> # With Learnable Parameters
218 | >>> m = SynchronizedBatchNorm2d(100)
219 | >>> # Without Learnable Parameters
220 | >>> m = SynchronizedBatchNorm2d(100, affine=False)
221 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
222 | >>> output = m(input)
223 | """
224 |
225 | def _check_input_dim(self, input):
226 | if input.dim() != 4:
227 | raise ValueError('expected 4D input (got {}D input)'
228 | .format(input.dim()))
229 | super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
230 |
231 |
232 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
233 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
234 | of 4d inputs
235 | .. math::
236 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
237 | This module differs from the built-in PyTorch BatchNorm3d as the mean and
238 | standard-deviation are reduced across all devices during training.
239 | For example, when one uses `nn.DataParallel` to wrap the network during
240 | training, PyTorch's implementation normalizes the tensor on each device using
241 | the statistics only on that device, which accelerates the computation and
242 | is easy to implement, but the statistics might be inaccurate.
243 | Instead, in this synchronized version, the statistics will be computed
244 | over all training samples distributed on multiple devices.
245 |
246 | Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
247 | as the built-in PyTorch implementation.
248 | The mean and standard-deviation are calculated per-dimension over
249 | the mini-batches and gamma and beta are learnable parameter vectors
250 | of size C (where C is the input size).
251 | During training, this layer keeps a running estimate of its computed mean
252 | and variance. The running sum is kept with a default momentum of 0.1.
253 | During evaluation, this running mean/variance is used for normalization.
254 | Because the BatchNorm is done over the `C` dimension, computing statistics
255 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
256 | or Spatio-temporal BatchNorm.
257 | Args:
258 | num_features: num_features from an expected input of
259 | size batch_size x num_features x depth x height x width
260 | eps: a value added to the denominator for numerical stability.
261 | Default: 1e-5
262 | momentum: the value used for the running_mean and running_var
263 | computation. Default: 0.1
264 | affine: a boolean value that when set to ``True``, gives the layer learnable
265 | affine parameters. Default: ``True``
266 | Shape:
267 | - Input: :math:`(N, C, D, H, W)`
268 | - Output: :math:`(N, C, D, H, W)` (same shape as input)
269 | Examples:
270 | >>> # With Learnable Parameters
271 | >>> m = SynchronizedBatchNorm3d(100)
272 | >>> # Without Learnable Parameters
273 | >>> m = SynchronizedBatchNorm3d(100, affine=False)
274 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
275 | >>> output = m(input)
276 | """
277 |
278 | def _check_input_dim(self, input):
279 | if input.dim() != 5:
280 | raise ValueError('expected 5D input (got {}D input)'
281 | .format(input.dim()))
282 | super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
--------------------------------------------------------------------------------
/model/deeplabv3/backbone/sync_batchnorm/comm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : comm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import queue
12 | import collections
13 | import threading
14 |
15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
16 |
17 |
18 | class FutureResult(object):
19 | """A thread-safe future implementation. Used only as one-to-one pipe."""
20 |
21 | def __init__(self):
22 | self._result = None
23 | self._lock = threading.Lock()
24 | self._cond = threading.Condition(self._lock)
25 |
26 | def put(self, result):
27 | with self._lock:
28 | assert self._result is None, 'Previous result hasn\'t been fetched.'
29 | self._result = result
30 | self._cond.notify()
31 |
32 | def get(self):
33 | with self._lock:
34 | if self._result is None:
35 | self._cond.wait()
36 |
37 | res = self._result
38 | self._result = None
39 | return res
40 |
41 |
42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
44 |
45 |
46 | class SlavePipe(_SlavePipeBase):
47 | """Pipe for master-slave communication."""
48 |
49 | def run_slave(self, msg):
50 | self.queue.put((self.identifier, msg))
51 | ret = self.result.get()
52 | self.queue.put(True)
53 | return ret
54 |
55 |
56 | class SyncMaster(object):
57 | """An abstract `SyncMaster` object.
58 | - During the replication, as the data parallel will trigger a callback of each module, all slave devices should
59 | call `register(id)` and obtain a `SlavePipe` to communicate with the master.
60 | - During the forward pass, the master device invokes `run_master`; all messages from slave devices will be collected
61 | and passed to a registered callback.
62 | - After receiving the messages, the master device should gather the information and determine the message to be
63 | passed back to each slave device.
64 | """
65 |
66 | def __init__(self, master_callback):
67 | """
68 | Args:
69 | master_callback: a callback to be invoked after having collected messages from slave devices.
70 | """
71 | self._master_callback = master_callback
72 | self._queue = queue.Queue()
73 | self._registry = collections.OrderedDict()
74 | self._activated = False
75 |
76 | def __getstate__(self):
77 | return {'master_callback': self._master_callback}
78 |
79 | def __setstate__(self, state):
80 | self.__init__(state['master_callback'])
81 |
82 | def register_slave(self, identifier):
83 | """
84 | Register a slave device.
85 | Args:
86 | identifier: an identifier, usually is the device id.
87 | Returns: a `SlavePipe` object which can be used to communicate with the master device.
88 | """
89 | if self._activated:
90 | assert self._queue.empty(), 'Queue is not clean before next initialization.'
91 | self._activated = False
92 | self._registry.clear()
93 | future = FutureResult()
94 | self._registry[identifier] = _MasterRegistry(future)
95 | return SlavePipe(identifier, self._queue, future)
96 |
97 | def run_master(self, master_msg):
98 | """
99 | Main entry for the master device in each forward pass.
100 | The messages are first collected from each device (including the master device), and then
101 | a callback is invoked to compute the message to be sent back to each device
102 | (including the master device).
103 | Args:
104 | master_msg: the message that the master wants to send to itself. This will be placed as the first
105 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
106 | Returns: the message to be sent back to the master device.
107 | """
108 | self._activated = True
109 |
110 | intermediates = [(0, master_msg)]
111 | for i in range(self.nr_slaves):
112 | intermediates.append(self._queue.get())
113 |
114 | results = self._master_callback(intermediates)
115 | assert results[0][0] == 0, 'The first result should belong to the master.'
116 |
117 | for i, res in results:
118 | if i == 0:
119 | continue
120 | self._registry[i].result.put(res)
121 |
122 | for i in range(self.nr_slaves):
123 | assert self._queue.get() is True
124 |
125 | return results[0][1]
126 |
127 | @property
128 | def nr_slaves(self):
129 | return len(self._registry)
130 |
--------------------------------------------------------------------------------
/model/deeplabv3/backbone/sync_batchnorm/replicate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : replicate.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import functools
12 |
13 | from torch.nn.parallel.data_parallel import DataParallel
14 |
15 | __all__ = [
16 | 'CallbackContext',
17 | 'execute_replication_callbacks',
18 | 'DataParallelWithCallback',
19 | 'patch_replication_callback'
20 | ]
21 |
22 |
23 | class CallbackContext(object):
24 | pass
25 |
26 |
27 | def execute_replication_callbacks(modules):
28 | """
29 | Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
30 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
31 | Note that, as all modules are isomorphic, we assign each sub-module a context
32 | (shared among multiple copies of this module on different devices).
33 | Through this context, different copies can share some information.
34 | We guarantee that the callback on the master copy (the first copy) will be called ahead of the callbacks
35 | of any slave copies.
36 | """
37 | master_copy = modules[0]
38 | nr_modules = len(list(master_copy.modules()))
39 | ctxs = [CallbackContext() for _ in range(nr_modules)]
40 |
41 | for i, module in enumerate(modules):
42 | for j, m in enumerate(module.modules()):
43 | if hasattr(m, '__data_parallel_replicate__'):
44 | m.__data_parallel_replicate__(ctxs[j], i)
45 |
46 |
47 | class DataParallelWithCallback(DataParallel):
48 | """
49 | Data Parallel with a replication callback.
50 | A replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
51 | the original `replicate` function.
52 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
53 | Examples:
54 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
55 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
56 | # sync_bn.__data_parallel_replicate__ will be invoked.
57 | """
58 |
59 | def replicate(self, module, device_ids):
60 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
61 | execute_replication_callbacks(modules)
62 | return modules
63 |
64 |
65 | def patch_replication_callback(data_parallel):
66 | """
67 | Monkey-patch an existing `DataParallel` object. Add the replication callback.
68 | Useful when you have a customized `DataParallel` implementation.
69 | Examples:
70 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
71 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
72 | > patch_replication_callback(sync_bn)
73 | # this is equivalent to
74 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
75 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
76 | """
77 |
78 | assert isinstance(data_parallel, DataParallel)
79 |
80 | old_replicate = data_parallel.replicate
81 |
82 | @functools.wraps(old_replicate)
83 | def new_replicate(module, device_ids):
84 | modules = old_replicate(module, device_ids)
85 | execute_replication_callbacks(modules)
86 | return modules
87 |
88 | data_parallel.replicate = new_replicate
--------------------------------------------------------------------------------
/model/deeplabv3/backbone/sync_batchnorm/unittest.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : unittest.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
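#
# Test helpers only: as_numpy() converts tensors/Variables to NumPy arrays, and
# TorchTestCase.assertTensorClose() asserts that two tensors match within the
# given absolute/relative tolerances.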
10 | 11 | import unittest 12 | 13 | import numpy as np 14 | from torch.autograd import Variable 15 | 16 | 17 | def as_numpy(v): 18 | if isinstance(v, Variable): 19 | v = v.data 20 | return v.cpu().numpy() 21 | 22 | 23 | class TorchTestCase(unittest.TestCase): 24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): 25 | npa, npb = as_numpy(a), as_numpy(b) 26 | self.assertTrue( 27 | np.allclose(npa, npb, atol=atol), 28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) 29 | ) 30 | -------------------------------------------------------------------------------- /model/deeplabv3/backbone/xception.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.utils.model_zoo as model_zoo 6 | from .sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 7 | 8 | def fixed_padding(inputs, kernel_size, dilation): 9 | kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1) 10 | pad_total = kernel_size_effective - 1 11 | pad_beg = pad_total // 2 12 | pad_end = pad_total - pad_beg 13 | padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end)) 14 | return padded_inputs 15 | 16 | 17 | class SeparableConv2d(nn.Module): 18 | def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, BatchNorm=None): 19 | super(SeparableConv2d, self).__init__() 20 | 21 | self.conv1 = nn.Conv2d(inplanes, inplanes, kernel_size, stride, 0, dilation, 22 | groups=inplanes, bias=bias) 23 | self.bn = BatchNorm(inplanes) 24 | self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias) 25 | 26 | def forward(self, x): 27 | x = fixed_padding(x, self.conv1.kernel_size[0], dilation=self.conv1.dilation[0]) 28 | x = self.conv1(x) 29 | x = self.bn(x) 30 | x = self.pointwise(x) 31 | return x 32 | 33 | 34 | class Block(nn.Module): 35 | def __init__(self, inplanes, planes, reps, stride=1, dilation=1, BatchNorm=None, 36 | start_with_relu=True, grow_first=True, is_last=False): 37 | super(Block, self).__init__() 38 | 39 | if planes != inplanes or stride != 1: 40 | self.skip = nn.Conv2d(inplanes, planes, 1, stride=stride, bias=False) 41 | self.skipbn = BatchNorm(planes) 42 | else: 43 | self.skip = None 44 | 45 | self.relu = nn.ReLU(inplace=True) 46 | rep = [] 47 | 48 | filters = inplanes 49 | if grow_first: 50 | rep.append(self.relu) 51 | rep.append(SeparableConv2d(inplanes, planes, 3, 1, dilation, BatchNorm=BatchNorm)) 52 | rep.append(BatchNorm(planes)) 53 | filters = planes 54 | 55 | for i in range(reps - 1): 56 | rep.append(self.relu) 57 | rep.append(SeparableConv2d(filters, filters, 3, 1, dilation, BatchNorm=BatchNorm)) 58 | rep.append(BatchNorm(filters)) 59 | 60 | if not grow_first: 61 | rep.append(self.relu) 62 | rep.append(SeparableConv2d(inplanes, planes, 3, 1, dilation, BatchNorm=BatchNorm)) 63 | rep.append(BatchNorm(planes)) 64 | 65 | if stride != 1: 66 | rep.append(self.relu) 67 | rep.append(SeparableConv2d(planes, planes, 3, 2, BatchNorm=BatchNorm)) 68 | rep.append(BatchNorm(planes)) 69 | 70 | if stride == 1 and is_last: 71 | rep.append(self.relu) 72 | rep.append(SeparableConv2d(planes, planes, 3, 1, BatchNorm=BatchNorm)) 73 | rep.append(BatchNorm(planes)) 74 | 75 | if not start_with_relu: 76 | rep = rep[1:] 77 | 78 | self.rep = nn.Sequential(*rep) 79 | 80 | def forward(self, inp): 81 | x = self.rep(inp) 82 | 83 | if self.skip is not None: 
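# The skip path only exists when the channel count or stride changes;
# the 1x1 (possibly strided) conv below projects the input so that the
# residual add lines up with the output of the main branch.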
84 | skip = self.skip(inp)
85 | skip = self.skipbn(skip)
86 | else:
87 | skip = inp
88 |
89 | x = x + skip
90 |
91 | return x
92 |
93 |
94 | class AlignedXception(nn.Module):
95 | """
96 | Modified Aligned Xception
97 | """
98 | def __init__(self, output_stride, BatchNorm,
99 | pretrained=True):
100 | super(AlignedXception, self).__init__()
101 |
102 | if output_stride == 16:
103 | entry_block3_stride = 2
104 | middle_block_dilation = 1
105 | exit_block_dilations = (1, 2)
106 | elif output_stride == 8:
107 | entry_block3_stride = 1
108 | middle_block_dilation = 2
109 | exit_block_dilations = (2, 4)
110 | else:
111 | raise NotImplementedError
112 |
113 |
114 | # Entry flow
115 | self.conv1 = nn.Conv2d(3, 32, 3, stride=2, padding=1, bias=False)
116 | self.bn1 = BatchNorm(32)
117 | self.relu = nn.ReLU(inplace=True)
118 |
119 | self.conv2 = nn.Conv2d(32, 64, 3, stride=1, padding=1, bias=False)
120 | self.bn2 = BatchNorm(64)
121 |
122 | self.block1 = Block(64, 128, reps=2, stride=2, BatchNorm=BatchNorm, start_with_relu=False)
123 | self.block2 = Block(128, 256, reps=2, stride=2, BatchNorm=BatchNorm, start_with_relu=False,
124 | grow_first=True)
125 | self.block3 = Block(256, 728, reps=2, stride=entry_block3_stride, BatchNorm=BatchNorm,
126 | start_with_relu=True, grow_first=True, is_last=True)
127 |
128 | # Middle flow
129 | self.block4 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
130 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
131 | self.block5 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
132 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
133 | self.block6 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
134 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
135 | self.block7 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
136 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
137 | self.block8 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
138 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
139 | self.block9 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
140 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
141 | self.block10 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
142 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
143 | self.block11 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
144 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
145 | self.block12 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
146 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
147 | self.block13 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
148 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
149 | self.block14 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
150 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
151 | self.block15 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
152 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
153 | self.block16 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
154 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
155 | self.block17 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
156 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
157 | self.block18 = Block(728, 728, reps=3, stride=1,
dilation=middle_block_dilation, 158 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 159 | self.block19 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 160 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 161 | 162 | # Exit flow 163 | self.block20 = Block(728, 1024, reps=2, stride=1, dilation=exit_block_dilations[0], 164 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=False, is_last=True) 165 | 166 | self.conv3 = SeparableConv2d(1024, 1536, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm) 167 | self.bn3 = BatchNorm(1536) 168 | 169 | self.conv4 = SeparableConv2d(1536, 1536, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm) 170 | self.bn4 = BatchNorm(1536) 171 | 172 | self.conv5 = SeparableConv2d(1536, 2048, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm) 173 | self.bn5 = BatchNorm(2048) 174 | 175 | # Init weights 176 | self._init_weight() 177 | 178 | # Load pretrained model 179 | if pretrained: 180 | self._load_pretrained_model() 181 | 182 | def forward(self, x): 183 | # Entry flow 184 | x = self.conv1(x) 185 | x = self.bn1(x) 186 | x = self.relu(x) 187 | 188 | x = self.conv2(x) 189 | x = self.bn2(x) 190 | x = self.relu(x) 191 | 192 | x = self.block1(x) 193 | # add relu here 194 | x = self.relu(x) 195 | low_level_feat = x 196 | x = self.block2(x) 197 | x = self.block3(x) 198 | 199 | # Middle flow 200 | x = self.block4(x) 201 | x = self.block5(x) 202 | x = self.block6(x) 203 | x = self.block7(x) 204 | x = self.block8(x) 205 | x = self.block9(x) 206 | x = self.block10(x) 207 | x = self.block11(x) 208 | x = self.block12(x) 209 | x = self.block13(x) 210 | x = self.block14(x) 211 | x = self.block15(x) 212 | x = self.block16(x) 213 | x = self.block17(x) 214 | x = self.block18(x) 215 | x = self.block19(x) 216 | 217 | # Exit flow 218 | x = self.block20(x) 219 | x = self.relu(x) 220 | x = self.conv3(x) 221 | x = self.bn3(x) 222 | x = self.relu(x) 223 | 224 | x = self.conv4(x) 225 | x = self.bn4(x) 226 | x = self.relu(x) 227 | 228 | x = self.conv5(x) 229 | x = self.bn5(x) 230 | x = self.relu(x) 231 | 232 | return {'out':x, 'low_level':low_level_feat} 233 | 234 | def _init_weight(self): 235 | for m in self.modules(): 236 | if isinstance(m, nn.Conv2d): 237 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 238 | m.weight.data.normal_(0, math.sqrt(2. 
/ n))
239 | elif isinstance(m, SynchronizedBatchNorm2d):
240 | m.weight.data.fill_(1)
241 | m.bias.data.zero_()
242 | elif isinstance(m, nn.BatchNorm2d):
243 | m.weight.data.fill_(1)
244 | m.bias.data.zero_()
245 |
246 |
247 | def _load_pretrained_model(self):
248 | pretrain_dict = model_zoo.load_url('http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth')
249 | model_dict = {}
250 | state_dict = self.state_dict()
251 |
252 | for k, v in pretrain_dict.items():
253 | if k in state_dict:
254 | if 'pointwise' in k:
255 | v = v.unsqueeze(-1).unsqueeze(-1)
256 | if k.startswith('block11'):
257 | model_dict[k] = v
258 | model_dict[k.replace('block11', 'block12')] = v
259 | model_dict[k.replace('block11', 'block13')] = v
260 | model_dict[k.replace('block11', 'block14')] = v
261 | model_dict[k.replace('block11', 'block15')] = v
262 | model_dict[k.replace('block11', 'block16')] = v
263 | model_dict[k.replace('block11', 'block17')] = v
264 | model_dict[k.replace('block11', 'block18')] = v
265 | model_dict[k.replace('block11', 'block19')] = v
266 | elif k.startswith('block12'):
267 | model_dict[k.replace('block12', 'block20')] = v
268 | elif k.startswith('bn3'):
269 | model_dict[k] = v
270 | model_dict[k.replace('bn3', 'bn4')] = v
271 | elif k.startswith('conv4'):
272 | model_dict[k.replace('conv4', 'conv5')] = v
273 | elif k.startswith('bn4'):
274 | model_dict[k.replace('bn4', 'bn5')] = v
275 | else:
276 | model_dict[k] = v
277 | state_dict.update(model_dict)
278 | self.load_state_dict(state_dict)
279 |
280 |
281 |
282 | if __name__ == "__main__":
283 | import torch
284 | model = AlignedXception(BatchNorm=nn.BatchNorm2d, pretrained=False, output_stride=16)
285 | input = torch.rand(1, 3, 256, 256)
286 | out = model(input) # forward returns a dict with keys 'out' and 'low_level'
287 | print(out['out'].size())
288 | print(out['low_level'].size())
--------------------------------------------------------------------------------
/model/deeplabv3/modeling.py:
--------------------------------------------------------------------------------
1 | from .utils import IntermediateLayerGetter
2 | from ._deeplab import DeepLabHead, DeepLabHeadV3Plus, DeepLabV3
3 | from .backbone import resnet
4 | from .backbone import mobilenetv2
5 | from .backbone import AlignedXception
6 | import torch.nn as nn
7 | from torch.nn import functional as F
8 | def _segm_resnet(name, backbone_name, num_classes, output_stride, pretrained_backbone):
9 |
10 | if output_stride==8:
11 | replace_stride_with_dilation=[False, True, True]
12 | aspp_dilate = [12, 24, 36]
13 | else:
14 | replace_stride_with_dilation=[False, False, True]
15 | aspp_dilate = [6, 12, 18]
16 |
17 | backbone = resnet.__dict__[backbone_name](
18 | pretrained=pretrained_backbone,
19 | replace_stride_with_dilation=replace_stride_with_dilation)
20 |
21 | inplanes = 2048
22 | low_level_planes = 256
23 |
24 | if name=='deeplabv3plus':
25 | return_layers = {'layer4': 'out', 'layer1': 'low_level'}
26 | classifier = DeepLabHeadV3Plus(inplanes, low_level_planes, num_classes, aspp_dilate)
27 | elif name=='deeplabv3':
28 | return_layers = {'layer4': 'out'}
29 | classifier = DeepLabHead(inplanes, num_classes, aspp_dilate)
30 | backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
31 |
32 | model = DeepLabV3(backbone, classifier)
33 | return model
34 |
35 | def _segm_mobilenet(name, backbone_name, num_classes, output_stride, pretrained_backbone):
36 | if output_stride==8:
37 | aspp_dilate = [12, 24, 36]
38 | else:
39 | aspp_dilate = [6, 12, 18]
40 |
41 | backbone = mobilenetv2.mobilenet_v2(pretrained=pretrained_backbone, output_stride=output_stride)
42 |
43 | # rename layers
44 | backbone.low_level_features = backbone.features[0:4]
45 | backbone.high_level_features = backbone.features[4:-1]
46 | backbone.features = None
47 | backbone.classifier = None
48 |
49 | inplanes = 320
50 | low_level_planes = 24
51 |
52 | if name=='deeplabv3plus':
53 | return_layers = {'high_level_features': 'out', 'low_level_features': 'low_level'}
54 | classifier = DeepLabHeadV3Plus(inplanes, low_level_planes, num_classes, aspp_dilate)
55 | elif name=='deeplabv3':
56 | return_layers = {'high_level_features': 'out'}
57 | classifier = DeepLabHead(inplanes, num_classes, aspp_dilate)
58 | backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
59 |
60 | model = DeepLabV3(backbone, classifier)
61 | return model
62 |
63 | def _load_model(arch_type, backbone, num_classes, output_stride, pretrained_backbone):
64 |
65 | if backbone=='mobilenetv2':
66 | model = _segm_mobilenet(arch_type, backbone, num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone)
67 | elif backbone.startswith('resnet'):
68 | model = _segm_resnet(arch_type, backbone, num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone)
69 | else:
70 | raise NotImplementedError
71 | return model
72 |
73 |
74 | # Deeplab v3
75 |
76 | def deeplabv3_resnet50(num_classes=21, output_stride=8, pretrained_backbone=True):
77 | """Constructs a DeepLabV3 model with a ResNet-50 backbone.
78 |
79 | Args:
80 | num_classes (int): number of classes.
81 | output_stride (int): output stride for deeplab.
82 | pretrained_backbone (bool): If True, use the pretrained backbone.
83 | """
84 | return _load_model('deeplabv3', 'resnet50', num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone)
85 |
86 | def deeplabv3_resnet101(num_classes=21, output_stride=8, pretrained_backbone=True):
87 | """Constructs a DeepLabV3 model with a ResNet-101 backbone.
88 |
89 | Args:
90 | num_classes (int): number of classes.
91 | output_stride (int): output stride for deeplab.
92 | pretrained_backbone (bool): If True, use the pretrained backbone.
93 | """
94 | return _load_model('deeplabv3', 'resnet101', num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone)
95 |
96 | def deeplabv3_mobilenet(num_classes=21, output_stride=8, pretrained_backbone=True, **kwargs):
97 | """Constructs a DeepLabV3 model with a MobileNetv2 backbone.
98 |
99 | Args:
100 | num_classes (int): number of classes.
101 | output_stride (int): output stride for deeplab.
102 | pretrained_backbone (bool): If True, use the pretrained backbone.
103 | """
104 | return _load_model('deeplabv3', 'mobilenetv2', num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone)
105 |
106 |
107 | # Deeplab v3+
108 |
109 | def deeplabv3plus_resnet50(num_classes=21, output_stride=8, pretrained_backbone=True):
110 | """Constructs a DeepLabV3+ model with a ResNet-50 backbone.
111 |
112 | Args:
113 | num_classes (int): number of classes.
114 | output_stride (int): output stride for deeplab.
115 | pretrained_backbone (bool): If True, use the pretrained backbone.
116 | """
117 | return _load_model('deeplabv3plus', 'resnet50', num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone)
118 |
119 |
120 | def deeplabv3plus_resnet101(num_classes=21, output_stride=8, pretrained_backbone=True):
121 | """Constructs a DeepLabV3+ model with a ResNet-101 backbone.
122 |
123 | Args:
124 | num_classes (int): number of classes.
125 | output_stride (int): output stride for deeplab.
126 | pretrained_backbone (bool): If True, use the pretrained backbone.
127 | """
128 | return _load_model('deeplabv3plus', 'resnet101', num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone)
129 |
130 |
131 | def deeplabv3plus_mobilenet(num_classes=21, output_stride=8, pretrained_backbone=True):
132 | """Constructs a DeepLabV3+ model with a MobileNetv2 backbone.
133 |
134 | Args:
135 | num_classes (int): number of classes.
136 | output_stride (int): output stride for deeplab.
137 | pretrained_backbone (bool): If True, use the pretrained backbone.
138 | """
139 | return _load_model('deeplabv3plus', 'mobilenetv2', num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone)
140 |
141 | class deeplabv3plus_xception(nn.Module):
142 | def __init__(self, in_channels, low_level_channels, num_classes, aspp_dilate=[6, 12, 18]):
143 | super(deeplabv3plus_xception, self).__init__()
144 | self.backbone = AlignedXception(BatchNorm=nn.BatchNorm2d, pretrained=False, output_stride=16)
145 | self.deeplabv3_plus=DeepLabHeadV3Plus(in_channels, low_level_channels, num_classes, aspp_dilate) # pass aspp_dilate through instead of silently dropping it
146 |
147 |
148 | def forward(self,x):
149 | input_shape = x.shape[-2:] # upsample back to the input size rather than a hard-coded 256x256
150 | x=self.backbone(x)
151 | x=self.deeplabv3_plus(x)
152 | x=F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
153 | return x
154 |
155 |
--------------------------------------------------------------------------------
/model/deeplabv3/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import torch.nn.functional as F
5 | from collections import OrderedDict
6 |
7 | class _SimpleSegmentationModel(nn.Module):
8 | def __init__(self, backbone, classifier):
9 | super(_SimpleSegmentationModel, self).__init__()
10 | self.backbone = backbone
11 | self.classifier = classifier
12 |
13 | def forward(self, x):
14 | input_shape = x.shape[-2:]
15 | features = self.backbone(x)
16 | x = self.classifier(features)
17 | x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
18 | return x
19 |
20 |
21 | class IntermediateLayerGetter(nn.ModuleDict):
22 | """
23 | Module wrapper that returns intermediate layers from a model
24 |
25 | It has a strong assumption that the modules have been registered
26 | into the model in the same order as they are used.
27 | This means that one should **not** reuse the same nn.Module
28 | twice in the forward if you want this to work.
29 |
30 | Additionally, it is only able to query submodules that are directly
31 | assigned to the model. So if `model` is passed, `model.feature1` can
32 | be returned, but not `model.feature1.layer2`.
33 |
34 | Arguments:
35 | model (nn.Module): model on which we will extract the features
36 | return_layers (Dict[name, new_name]): a dict containing the names
37 | of the modules for which the activations will be returned as
38 | the key of the dict, and the value of the dict is the name
39 | of the returned activation (which the user can specify).
40 |
41 | Examples::
42 |
43 | >>> m = torchvision.models.resnet18(pretrained=True)
44 | >>> # extract layer1 and layer3, giving as names `feat1` and `feat2`
45 | >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m,
46 | >>> {'layer1': 'feat1', 'layer3': 'feat2'})
47 | >>> out = new_m(torch.rand(1, 3, 224, 224))
48 | >>> print([(k, v.shape) for k, v in out.items()])
49 | >>> [('feat1', torch.Size([1, 64, 56, 56])),
50 | >>> ('feat2', torch.Size([1, 256, 14, 14]))]
51 | """
52 | def __init__(self, model, return_layers):
53 | if not set(return_layers).issubset([name for name, _ in model.named_children()]):
54 | raise ValueError("return_layers are not present in model")
55 |
56 | orig_return_layers = return_layers
57 | return_layers = {k: v for k, v in return_layers.items()}
58 | layers = OrderedDict()
59 | for name, module in model.named_children():
60 | layers[name] = module
61 | if name in return_layers:
62 | del return_layers[name]
63 | if not return_layers:
64 | break
65 |
66 | super(IntermediateLayerGetter, self).__init__(layers)
67 | self.return_layers = orig_return_layers
68 |
69 | def forward(self, x):
70 | out = OrderedDict()
71 | for name, module in self.named_children():
72 | x = module(x)
73 | if name in self.return_layers:
74 | out_name = self.return_layers[name]
75 | out[out_name] = x
76 | return out
77 |
--------------------------------------------------------------------------------
/model/hrnet/__init__.py:
--------------------------------------------------------------------------------
1 | from .hrnet import HRNet
--------------------------------------------------------------------------------
/model/hrnet/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/model/hrnet/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/model/hrnet/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/model/hrnet/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/model/hrnet/__pycache__/hrnet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/model/hrnet/__pycache__/hrnet.cpython-36.pyc
--------------------------------------------------------------------------------
/model/hrnet/__pycache__/hrnet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/model/hrnet/__pycache__/hrnet.cpython-37.pyc
--------------------------------------------------------------------------------
/model/hrnet/__pycache__/hrnet_module.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/model/hrnet/__pycache__/hrnet_module.cpython-36.pyc
--------------------------------------------------------------------------------
/model/hrnet/__pycache__/hrnet_module.cpython-37.pyc:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/model/hrnet/__pycache__/hrnet_module.cpython-37.pyc -------------------------------------------------------------------------------- /model/hrnet/hrnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from .hrnet_module import BasicBlock, Bottleneck 4 | 5 | class StageModule(nn.Module): 6 | def __init__(self, stage, output_branches, c, bn_momentum): 7 | super(StageModule, self).__init__() 8 | self.stage = stage 9 | self.output_branches = output_branches 10 | 11 | self.branches = nn.ModuleList() 12 | for i in range(self.stage): 13 | w = c * (2 ** i) 14 | branch = nn.Sequential( 15 | BasicBlock(w, w, bn_momentum=bn_momentum), 16 | BasicBlock(w, w, bn_momentum=bn_momentum), 17 | BasicBlock(w, w, bn_momentum=bn_momentum), 18 | BasicBlock(w, w, bn_momentum=bn_momentum), 19 | ) 20 | self.branches.append(branch) 21 | 22 | self.fuse_layers = nn.ModuleList() 23 | # for each output_branches (i.e. each branch in all cases but the very last one) 24 | for i in range(self.output_branches): 25 | self.fuse_layers.append(nn.ModuleList()) 26 | for j in range(self.stage): # for each branch 27 | if i == j: 28 | self.fuse_layers[-1].append(nn.Sequential()) # Used in place of "None" because it is callable 29 | elif i < j: 30 | self.fuse_layers[-1].append(nn.Sequential( 31 | nn.Conv2d(c * (2 ** j), c * (2 ** i), kernel_size=(1, 1), stride=(1, 1), bias=False), 32 | nn.BatchNorm2d(c * (2 ** i), eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), 33 | nn.Upsample(scale_factor=(2.0 ** (j - i)), mode='nearest'), 34 | )) 35 | elif i > j: 36 | ops = [] 37 | for k in range(i - j - 1): 38 | ops.append(nn.Sequential( 39 | nn.Conv2d(c * (2 ** j), c * (2 ** j), kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), 40 | bias=False), 41 | nn.BatchNorm2d(c * (2 ** j), eps=1e-05, momentum=0.1, affine=True, 42 | track_running_stats=True), 43 | nn.ReLU(inplace=True), 44 | )) 45 | ops.append(nn.Sequential( 46 | nn.Conv2d(c * (2 ** j), c * (2 ** i), kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), 47 | bias=False), 48 | nn.BatchNorm2d(c * (2 ** i), eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), 49 | )) 50 | self.fuse_layers[-1].append(nn.Sequential(*ops)) 51 | 52 | self.relu = nn.ReLU(inplace=True) 53 | 54 | def forward(self, x): 55 | assert len(self.branches) == len(x) 56 | 57 | x = [branch(b) for branch, b in zip(self.branches, x)] 58 | 59 | x_fused = [] 60 | for i in range(len(self.fuse_layers)): 61 | for j in range(0, len(self.branches)): 62 | if j == 0: 63 | x_fused.append(self.fuse_layers[i][0](x[0])) 64 | else: 65 | x_fused[i] = x_fused[i] + self.fuse_layers[i][j](x[j]) 66 | 67 | for i in range(len(x_fused)): 68 | x_fused[i] = self.relu(x_fused[i]) 69 | 70 | return x_fused 71 | 72 | 73 | class HRNet(nn.Module): 74 | def __init__(self, c=48, num_classes=8, bn_momentum=0.1): 75 | super(HRNet, self).__init__() 76 | 77 | # Input (stem net) 78 | self.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) 79 | self.bn1 = nn.BatchNorm2d(64, eps=1e-05, momentum=bn_momentum, affine=True, track_running_stats=True) 80 | self.conv2 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) 81 | self.bn2 = nn.BatchNorm2d(64, eps=1e-05, momentum=bn_momentum, affine=True, 
track_running_stats=True) 82 | self.relu = nn.ReLU(inplace=True) 83 | 84 | # Stage 1 (layer1) - First group of bottleneck (resnet) modules 85 | downsample = nn.Sequential( 86 | nn.Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False), 87 | nn.BatchNorm2d(256, eps=1e-05, momentum=bn_momentum, affine=True, track_running_stats=True), 88 | ) 89 | self.layer1 = nn.Sequential( 90 | Bottleneck(64, 64, downsample=downsample), 91 | Bottleneck(256, 64), 92 | Bottleneck(256, 64), 93 | Bottleneck(256, 64), 94 | ) 95 | 96 | # Fusion layer 1 (transition1) - Creation of the first two branches (one full and one half resolution) 97 | self.transition1 = nn.ModuleList([ 98 | nn.Sequential( 99 | nn.Conv2d(256, c, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False), 100 | nn.BatchNorm2d(c, eps=1e-05, momentum=bn_momentum, affine=True, track_running_stats=True), 101 | nn.ReLU(inplace=True), 102 | ), 103 | nn.Sequential(nn.Sequential( # Double Sequential to fit with official pretrained weights 104 | nn.Conv2d(256, c * (2 ** 1), kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False), 105 | nn.BatchNorm2d(c * (2 ** 1), eps=1e-05, momentum=bn_momentum, affine=True, track_running_stats=True), 106 | nn.ReLU(inplace=True), 107 | )), 108 | ]) 109 | 110 | # Stage 2 (stage2) - Second module with 1 group of bottleneck (resnet) modules. This has 2 branches 111 | self.stage2 = nn.Sequential( 112 | StageModule(stage=2, output_branches=2, c=c, bn_momentum=bn_momentum), 113 | ) 114 | 115 | # Fusion layer 2 (transition2) - Creation of the third branch (1/4 resolution) 116 | self.transition2 = nn.ModuleList([ 117 | nn.Sequential(), # None, - Used in place of "None" because it is callable 118 | nn.Sequential(), # None, - Used in place of "None" because it is callable 119 | nn.Sequential(nn.Sequential( # Double Sequential to fit with official pretrained weights 120 | nn.Conv2d(c * (2 ** 1), c * (2 ** 2), kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False), 121 | nn.BatchNorm2d(c * (2 ** 2), eps=1e-05, momentum=bn_momentum, affine=True, track_running_stats=True), 122 | nn.ReLU(inplace=True), 123 | )), # ToDo Why the new branch derives from the "upper" branch only? 124 | ]) 125 | 126 | # Stage 3 (stage3) - Third module with 4 groups of bottleneck (resnet) modules. This has 3 branches 127 | self.stage3 = nn.Sequential( 128 | StageModule(stage=3, output_branches=3, c=c, bn_momentum=bn_momentum), 129 | StageModule(stage=3, output_branches=3, c=c, bn_momentum=bn_momentum), 130 | StageModule(stage=3, output_branches=3, c=c, bn_momentum=bn_momentum), 131 | StageModule(stage=3, output_branches=3, c=c, bn_momentum=bn_momentum), 132 | ) 133 | 134 | # Fusion layer 3 (transition3) - Creation of the fourth branch (1/8 resolution) 135 | self.transition3 = nn.ModuleList([ 136 | nn.Sequential(), # None, - Used in place of "None" because it is callable 137 | nn.Sequential(), # None, - Used in place of "None" because it is callable 138 | nn.Sequential(), # None, - Used in place of "None" because it is callable 139 | nn.Sequential(nn.Sequential( # Double Sequential to fit with official pretrained weights 140 | nn.Conv2d(c * (2 ** 2), c * (2 ** 3), kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False), 141 | nn.BatchNorm2d(c * (2 ** 3), eps=1e-05, momentum=bn_momentum, affine=True, track_running_stats=True), 142 | nn.ReLU(inplace=True), 143 | )), # ToDo Why the new branch derives from the "upper" branch only? 
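# Answering the ToDo above: the official HRNet code does the same thing --
# each new lower-resolution branch is spawned from the last branch of the
# previous stage via a strided 3x3 conv, rather than by fusing all branches.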
144 | ]) 145 | 146 | # Stage 4 (stage4) - Fourth module with 3 groups of basic (resnet) blocks. This has 4 branches 147 | self.stage4 = nn.Sequential( 148 | StageModule(stage=4, output_branches=4, c=c, bn_momentum=bn_momentum), 149 | StageModule(stage=4, output_branches=4, c=c, bn_momentum=bn_momentum), 150 | StageModule(stage=4, output_branches=1, c=c, bn_momentum=bn_momentum), # the last module fuses everything into the highest-resolution branch 151 | ) 152 | 153 | # Final layer (final_layer) 154 | # self.final_layer = nn.Conv2d(c, nof_joints, kernel_size=(1, 1), stride=(1, 1)) 155 | 156 | self.last_layer = nn.Sequential( 157 | nn.Conv2d( 158 | in_channels=c, 159 | out_channels=c, 160 | kernel_size=1, 161 | stride=1, 162 | padding=0), 163 | nn.BatchNorm2d(c, momentum=0.1), 164 | nn.ReLU(inplace=True), 165 | nn.Conv2d( 166 | in_channels=c, 167 | out_channels=num_classes, 168 | kernel_size=1, 169 | stride=1) 170 | ) 171 | 172 | def forward(self, x): 173 | x = self.conv1(x) 174 | x = self.bn1(x) 175 | x = self.relu(x) 176 | x = self.conv2(x) 177 | x = self.bn2(x) 178 | x = self.relu(x) 179 | 180 | x = self.layer1(x) 181 | x = [trans(x) for trans in self.transition1] # From here on, x is a list (length == number of branches) 182 | 183 | x = self.stage2(x) 184 | # x = [trans(x[-1]) for trans in self.transition2] # the new branch derives from the lowest-resolution branch (x[-1]) 185 | x = [ 186 | self.transition2[0](x[0]), 187 | self.transition2[1](x[1]), 188 | self.transition2[2](x[-1]) 189 | ] # the new branch derives from the lowest-resolution branch (x[-1]) 190 | 191 | x = self.stage3(x) 192 | # x = [trans(x) for trans in self.transition3] # the new branch derives from the lowest-resolution branch (x[-1]) 193 | x = [ 194 | self.transition3[0](x[0]), 195 | self.transition3[1](x[1]), 196 | self.transition3[2](x[2]), 197 | self.transition3[3](x[-1]) 198 | ] # the new branch derives from the lowest-resolution branch (x[-1]) 199 | 200 | x = self.stage4(x) 201 | x = self.last_layer(x[-1]) 202 | 203 | return x 204 | 205 | # 206 | # if __name__ == '__main__': 207 | # model = HRNet(48, 8, 0.1) 208 | # #model = HRNet(32, 8, 0.1) 209 | # 210 | # # print(model) 211 | # # 212 | # # model.load_state_dict( 213 | # # # torch.load('./weights/pose_hrnet_w48_384x288.pth') 214 | # # torch.load('./weights/pose_hrnet_w32_256x192.pth') 215 | # # ) 216 | # # print('ok!!') 217 | # 218 | # # if torch.cuda.is_available() and False: 219 | # # torch.backends.cudnn.deterministic = True 220 | # # device = torch.device('cuda:0') 221 | # # else: 222 | # # device = torch.device('cpu') 223 | # device = torch.device('cuda:0') 224 | # print(device) 225 | # 226 | # model = model.to(device) 227 | # 228 | # y = model(torch.ones(1, 3, 256, 256).to(device)) 229 | # print(y.shape) 230 | # from torchsummaryX import summary 231 | # sum = summary(model.cuda(), torch.rand((1, 3, 256, 256)).cuda()) 232 | # sum.to_excel('data.xls') -------------------------------------------------------------------------------- /model/hrnet/hrnet_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class Bottleneck(nn.Module): 6 | expansion = 4 7 | 8 | def __init__(self, inplanes, planes, stride=1, downsample=None, bn_momentum=0.1): 9 | super(Bottleneck, self).__init__() 10 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 11 | self.bn1 = nn.BatchNorm2d(planes, momentum=bn_momentum) 12 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 13 | self.bn2 = nn.BatchNorm2d(planes, momentum=bn_momentum) 14 | self.conv3 = nn.Conv2d(planes,
planes * self.expansion, kernel_size=1, bias=False) 15 | self.bn3 = nn.BatchNorm2d(planes * self.expansion, momentum=bn_momentum) 16 | self.relu = nn.ReLU(inplace=True) 17 | self.downsample = downsample 18 | self.stride = stride 19 | 20 | def forward(self, x): 21 | residual = x 22 | 23 | out = self.conv1(x) 24 | out = self.bn1(out) 25 | out = self.relu(out) 26 | 27 | out = self.conv2(out) 28 | out = self.bn2(out) 29 | out = self.relu(out) 30 | 31 | out = self.conv3(out) 32 | out = self.bn3(out) 33 | 34 | if self.downsample is not None: 35 | residual = self.downsample(x) 36 | 37 | out += residual 38 | out = self.relu(out) 39 | 40 | return out 41 | 42 | 43 | class BasicBlock(nn.Module): 44 | expansion = 1 45 | 46 | def __init__(self, inplanes, planes, stride=1, downsample=None, bn_momentum=0.1): 47 | super(BasicBlock, self).__init__() 48 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 49 | self.bn1 = nn.BatchNorm2d(planes, momentum=bn_momentum) 50 | self.relu = nn.ReLU(inplace=True) 51 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) # conv2 consumes conv1's output, so its in_channels must be planes (passing inplanes only worked here because the two are equal everywhere in this repo) 52 | self.bn2 = nn.BatchNorm2d(planes, momentum=bn_momentum) 53 | self.downsample = downsample 54 | self.stride = stride 55 | 56 | def forward(self, x): 57 | residual = x 58 | 59 | out = self.conv1(x) 60 | out = self.bn1(out) 61 | out = self.relu(out) 62 | 63 | out = self.conv2(out) 64 | out = self.bn2(out) 65 | 66 | if self.downsample is not None: 67 | residual = self.downsample(x) 68 | 69 | out += residual 70 | out = self.relu(out) 71 | 72 | return out -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | from utils import plot_img_and_mask 2 | from utils import BasicDataset 3 | from model import Unet,deeplabv3plus_resnet50,deeplabv3plus_resnet101 4 | import matplotlib.pyplot as plt 5 | import logging 6 | import torch 7 | import numpy as np 8 | import torch.nn.functional as F 9 | import os 10 | import cv2 11 | from torchvision import transforms 12 | from PIL import Image 13 | from tqdm import tqdm 14 | # import os 15 | # os.environ["CUDA_VISIBLE_DEVICES"] = "2" 16 | def predict_img(net,ori_img,device): 17 | net.eval() 18 | img = np.float32(ori_img) / 127.5 - 1 # same normalization as BasicDataset.preprocess 19 | img=torch.from_numpy(img).permute(2, 0, 1).type(torch.FloatTensor).unsqueeze(0) 20 | img=img.to(device=device,dtype=torch.float32) 21 | 22 | with torch.no_grad(): 23 | output=net(img) 24 | probs=F.softmax(output,dim=1).squeeze(0) #[8, 256, 256] 25 | pre=torch.argmax(probs,dim=0).cpu().data.numpy()#[256,256] 26 | pre_mask_img=Image.fromarray(np.uint8(pre)) 27 | palette = [ 28 | 0, 0, 0, 29 | 0, 0, 255, 30 | 15, 29, 15, 31 | 26, 141, 52, 32 | 41, 41, 41, 33 | 65, 105, 225, 34 | 85, 11, 18, 35 | 128, 0, 128, 36 | ] 37 | pre_mask_img.putpalette(palette) 38 | 39 | # tf = transforms.Compose( 40 | # [ 41 | # transforms.ToPILImage(), 42 | # transforms.ToTensor() 43 | # ] 44 | # ) 45 | # probs = tf(probs.cpu()) 46 | # full_mask = probs.squeeze().cpu().numpy()#(3, 256, 256) 47 | 48 | return pre,pre_mask_img 49 | 50 | def mask_to_image(mask): 51 | return (mask.transpose(1,2,0) * 255).astype(np.uint8) 52 | 53 | kwargs={'map_location':lambda storage, loc: storage.cuda(1)} 54 | def load_GPUS(model,model_path,kwargs): 55 | state_dict = torch.load(model_path,**kwargs) 56 | # create new OrderedDict that does not contain `module.` 57 | from collections import OrderedDict 58 | new_state_dict = OrderedDict() 59 | for k, v in state_dict['net'].items(): 60 | name = k[7:] # remove `module.` 61 | new_state_dict[name] = v 62 | # load params 63 | model.load_state_dict(new_state_dict) 64 | return model 65 | if __name__ == "__main__": 66 | 67 | matches = [100, 200, 300, 400, 500, 600, 700, 800] 68 | dir_checkpoint = 'checkpoints/' 69 | device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu') 70 | #net = Unet(n_channels=3, n_classes=8, bilinear=True) 71 | net = Unet(n_channels=3,n_classes=8) 72 | net.to(device=device) 73 | #net=load_GPUS(net, dir_checkpoint + 'best_score_model_res50_deeplabv3+.pth', kwargs) 74 | checkpoint = torch.load(dir_checkpoint + 'student_net.pth',map_location=device) 75 | net.load_state_dict(checkpoint['net']) 76 | logging.info("Model loaded!") 77 | 78 | list_path = "data/test.lst" 79 | output_path="data/results/" 80 | img_list = [line.strip('\n') for line in open(list_path)] 81 | for i, fn in tqdm(enumerate(img_list)): 82 | logging.info("\nPredicting image {} ...".format(i)) 83 | img = Image.open(fn) 84 | pre,_=predict_img(net,img,device) 85 | # map class indices back to the submission label values in one vectorized lookup 86 | save_img = np.array(matches, dtype=np.uint16)[pre] 87 | index=fn.split("/")[-1].split(".")[0] 88 | cv2.imwrite(os.path.join(output_path, index+".png"), save_img) 89 | 90 | # image_path='../baseline/test/images/1_1_1.tif' 91 | # img=Image.open(image_path) 92 | # mask,pre_mask_img=predict_img(net,img,device) 93 | # # result = mask_to_image(mask) 94 | # print(mask) 95 | # 96 | # plt.figure() 97 | # plt.subplot(1,2,1) 98 | # plt.imshow(pre_mask_img) 99 | # plt.subplot(1,2,2) 100 | # plt.imshow(img) 101 | # plt.show() 102 | 103 | 104 | 105 | 106 | -------------------------------------------------------------------------------- /runs/Sep14_22-36-00_amax-101studentLR_0.001_BS_32/events.out.tfevents.1600094160.amax-101.18906.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/runs/Sep14_22-36-00_amax-101studentLR_0.001_BS_32/events.out.tfevents.1600094160.amax-101.18906.0 -------------------------------------------------------------------------------- /runs/Sep15_12-37-15_amax-101studentLR_0.0001_BS_32/events.out.tfevents.1600144635.amax-101.9169.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/runs/Sep15_12-37-15_amax-101studentLR_0.0001_BS_32/events.out.tfevents.1600144635.amax-101.9169.0 -------------------------------------------------------------------------------- /runs/Sep15_13-55-14_amax-101studentLR_0.0001_BS_32/events.out.tfevents.1600149314.amax-101.17317.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/runs/Sep15_13-55-14_amax-101studentLR_0.0001_BS_32/events.out.tfevents.1600149314.amax-101.17317.0 -------------------------------------------------------------------------------- /runs/Sep15_19-33-06_amax-101studentLR_0.0001_BS_32/events.out.tfevents.1600169586.amax-101.20053.0: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/runs/Sep15_19-33-06_amax-101studentLR_0.0001_BS_32/events.out.tfevents.1600169586.amax-101.20053.0 -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sat Nov 23 16:39:55 2019 4 | 5 | @author: Administrator 6 | """ 7 | 8 | # coding: utf-8 9 | 10 | import numpy as np 11 | import sys 12 | import torch.nn.functional as F 13 | from datetime import datetime 14 | 15 | import argparse 16 | import logging 17 | 18 | from model import Unet,deeplabv3plus_resnet50,deeplabv3plus_resnet101,DANet,HRNet,deeplabv3plus_xception 19 | from utils import BasicDataset,CrossEntropy,PolyLR,FocalLoss 20 | from eval import eval_net 21 | from torch.utils.tensorboard import SummaryWriter 22 | import torch 23 | from torch.utils.data import Dataset, DataLoader 24 | import torch.nn as nn 25 | from torch import optim 26 | from tqdm import tqdm 27 | import os 28 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 29 | 30 | kwargs={'map_location':lambda storage, loc: storage.cuda(1)} 31 | def load_GPUS(model,model_path,kwargs): 32 | state_dict = torch.load(model_path,**kwargs) 33 | # create new OrderedDict that does not contain `module.` 34 | from collections import OrderedDict 35 | new_state_dict = OrderedDict() 36 | for k, v in state_dict['net'].items(): 37 | name = k[7:] # remove `module.` 38 | new_state_dict[name] = v 39 | # load params 40 | model.load_state_dict(new_state_dict) 41 | return model 42 | 43 | 44 | correct_ratio = [] 45 | alpha = 0.5 46 | batch_size = 32 47 | 48 | sate_dataset_train = BasicDataset("./data/train.lst")# load the training list; preprocessing lives in this Dataset class 49 | train_steps = len(sate_dataset_train) 50 | sate_dataset_val= BasicDataset("./data/val.lst") 51 | train_dataloader = DataLoader(sate_dataset_train, batch_size=batch_size, shuffle=True, num_workers=8)# wrap the training set in a DataLoader 52 | eval_dataloader = DataLoader(sate_dataset_val, batch_size=batch_size, shuffle=True, num_workers=8, drop_last=True)# wrap the validation set in a DataLoader; drop_last discards the final incomplete batch
53 | 54 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 55 | teach_model = deeplabv3plus_resnet50(num_classes=8, output_stride=16) 56 | teach_model.to(device=device) 57 | teach_model.load_state_dict(torch.load('./checkpoints/0.7669_best_score_model_res50_deeplabv3+.pth',map_location=device)['net']) 58 | teach_model.eval() # keep the teacher in inference mode 59 | model = Unet(n_channels=3,n_classes=8) 60 | model.to(device=device) 61 | #model_dir = './checkpoints/best_score_model_unet.pth' 62 | model_dir = './checkpoints/student_net.pth' 63 | 64 | if os.path.exists(model_dir): 65 | #model = load_GPUS(model_dir, model_dir, kwargs) 66 | model.load_state_dict(torch.load(model_dir)['net']) 67 | print("loading model successful----" + model_dir) 68 | #model.load_state_dict(torch.load('teach_net_params_0.9895.pkl')) 69 | criterion = nn.CrossEntropyLoss() 70 | criterion2 = nn.KLDivLoss() # KL divergence for the distillation (soft-target) term 71 | 72 | optimizer = optim.Adam(model.parameters(),lr = 0.0001) 73 | #optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9) 74 | n_size = batch_size*256*256 75 | 76 | writer = SummaryWriter(comment='student'+f'LR_0.0001_BS_32')# create a TensorBoard writer 77 | epochs = 50 78 | global_step = 1 79 | for epoch in range(epochs): 80 | loss_sigma = 0.0 81 | correct = 0.0 82 | total = 0.0 83 | model.train() 84 | with tqdm(total=train_steps, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar: 85 | for i, data in enumerate(train_dataloader): 86 | inputs, labels = data['image'], data['mask'] 87 | inputs = inputs.to(device=device, dtype=torch.float32) 88 | labels = labels.to(device=device, dtype=torch.long) 89 | 90 | optimizer.zero_grad() 91 | 92 | outputs = model(inputs) 93 | loss1 = criterion(outputs, labels) 94 | 95 | teacher_outputs = teach_model(inputs.float()).detach() # no gradients flow into the teacher 96 | T = 2 97 | outputs_S = F.log_softmax(outputs/T,dim=1) # KLDivLoss expects log-probabilities as input 98 | outputs_T = F.softmax(teacher_outputs/T,dim=1) 99 | loss2 = criterion2(outputs_S,outputs_T)*T*T 100 | 101 | loss = loss1*(1-alpha) + loss2*alpha 102 | 103 | # loss = loss1 104 | loss.backward() 105 | optimizer.step() 106 | 107 | _, predicted = torch.max(outputs.data, dim = 1) 108 | correct = (predicted.cpu()==labels.cpu()).squeeze().sum().numpy() 109 | pbar.set_postfix(**{'loss_avg': loss.item(),'Acc': correct/n_size}) 110 | writer.add_scalar('Loss/train', loss.item(), global_step) 111 | writer.add_scalar('acc/train', correct/n_size, global_step) # log the accuracy, not the raw pixel count 112 | # print('loss_avg:{:.2} Acc:{:.2%}'.format(loss_avg, correct/n_size/100)) 113 | pbar.update(inputs.shape[0]) 114 | global_step += 1 115 | if epoch % 1 == 0: 116 | loss_sigma = 0.0 117 | correct = 0 118 | cls_num = 8 # number of classes (conf_mat is allocated but never filled below) 119 | conf_mat = np.zeros([cls_num, cls_num]) # confusion matrix 120 | model.eval() 121 | for i, data in enumerate(eval_dataloader): 122 | 123 | # fetch the images and labels
124 | inputs, labels = data['image'], data['mask'] 125 | inputs = inputs.to(device=device, dtype=torch.float32) 126 | labels = labels.to(device=device, dtype=torch.long) 127 | # forward 128 | outputs = model(inputs) 129 | outputs.detach_() 130 | 131 | # compute the loss 132 | loss = criterion(outputs, labels) 133 | loss_sigma += loss.item() 134 | 135 | # accumulate statistics 136 | _, predicted = torch.max(outputs.data, 1) 137 | # labels = labels.data # Variable --> tensor 138 | correct += (predicted.cpu()==labels.cpu()).squeeze().sum().numpy() 139 | 140 | avg_correct = correct/len(eval_dataloader)/n_size 141 | val_loss,pixel_acc_avg,mean_iou_avg,fw_iou_avg = eval_net(model,eval_dataloader,device) 142 | writer.add_scalar('Loss/test', loss_sigma/len(eval_dataloader), global_step) 143 | writer.add_scalar('fw_iou/test', fw_iou_avg, global_step) 144 | writer.add_scalar('acc/test',avg_correct , global_step) 145 | 146 | if epoch==0: 147 | _fw_iou_avg = fw_iou_avg 148 | net_save_path = 'checkpoints/student_net' + '.pth' 149 | model_file = {'net':model.state_dict(),'correct':correct/len(eval_dataloader),'epoch':epoch+1} 150 | torch.save(model_file,net_save_path) 151 | print('-------------------------{} set correct:{:.4%}---------------------'.format('Valid', avg_correct)) 152 | print('-------------------------{} set fw_iou:{:.4%}---------------------'.format('Valid', fw_iou_avg)) 153 | elif fw_iou_avg > _fw_iou_avg: 154 | _fw_iou_avg = fw_iou_avg 155 | net_save_path = 'checkpoints/student_net' + '.pth' 156 | model_file = {'net':model.state_dict(),'correct':correct/len(eval_dataloader),'epoch':epoch+1} 157 | torch.save(model_file,net_save_path) 158 | print('-------------------------{} set correct:{:.4%}---------------------'.format('Valid', avg_correct)) 159 | print('-------------------------{} set fw_iou:{:.4%}---------------------'.format('Valid', fw_iou_avg)) 160 | 161 | -------------------------------------------------------------------------------- /train/images/1_3_2.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/train/images/1_3_2.tif -------------------------------------------------------------------------------- /train/images/1_3_3.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/train/images/1_3_3.tif -------------------------------------------------------------------------------- /train/images/1_4_0.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/train/images/1_4_0.tif -------------------------------------------------------------------------------- /train/images/1_4_1.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/train/images/1_4_1.tif -------------------------------------------------------------------------------- /train/images/1_4_2.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/train/images/1_4_2.tif -------------------------------------------------------------------------------- /train/images/1_4_3.tif:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/train/images/1_4_3.tif -------------------------------------------------------------------------------- /train/images/1_5_0.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/train/images/1_5_0.tif -------------------------------------------------------------------------------- /train/images/1_5_1.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/train/images/1_5_1.tif -------------------------------------------------------------------------------- /train/images/1_5_2.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/train/images/1_5_2.tif -------------------------------------------------------------------------------- /train/images/1_5_3.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/train/images/1_5_3.tif -------------------------------------------------------------------------------- /train/labels/1_3_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/train/labels/1_3_2.png -------------------------------------------------------------------------------- /train/labels/1_3_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/train/labels/1_3_3.png -------------------------------------------------------------------------------- /train/labels/1_4_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/train/labels/1_4_0.png -------------------------------------------------------------------------------- /train/labels/1_4_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/train/labels/1_4_1.png -------------------------------------------------------------------------------- /train/labels/1_4_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/train/labels/1_4_2.png -------------------------------------------------------------------------------- /train/labels/1_4_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/train/labels/1_4_3.png -------------------------------------------------------------------------------- /train/labels/1_5_0.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/train/labels/1_5_0.png -------------------------------------------------------------------------------- /train/labels/1_5_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/train/labels/1_5_1.png -------------------------------------------------------------------------------- /train/labels/1_5_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/train/labels/1_5_2.png -------------------------------------------------------------------------------- /train/labels/1_5_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/train/labels/1_5_3.png -------------------------------------------------------------------------------- /train_student.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import argparse 3 | import logging 4 | from model import Unet,deeplabv3plus_resnet50,deeplabv3plus_resnet101,DANet,HRNet,deeplabv3plus_mobilenet 5 | from utils import BasicDataset,CrossEntropy,PolyLR,FocalLoss 6 | from torch.utils.tensorboard import SummaryWriter 7 | import torch 8 | from torch.utils.data import Dataset, DataLoader 9 | import torch.nn as nn 10 | from torch import optim 11 | from tqdm import tqdm 12 | import os 13 | from eval import eval_net 14 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 15 | def train_net(net,device,batch_size,epochs,lr,dir_checkpoint,checkpoint_name): 16 | global_step = 0 17 | writer = SummaryWriter(comment=checkpoint_name+f'LR_{lr}_BS_{batch_size}')# create a TensorBoard writer 18 | sate_dataset_train = BasicDataset("./data/train.lst")# load the training list; preprocessing lives in this Dataset class 19 | sate_dataset_val= BasicDataset("./data/val.lst") 20 | train_steps = len(sate_dataset_train) 21 | train_dataloader = DataLoader(sate_dataset_train, batch_size=batch_size, shuffle=True, num_workers=8)# wrap the training set in a DataLoader 22 | eval_dataloader = DataLoader(sate_dataset_val, batch_size=batch_size, shuffle=True, num_workers=8, drop_last=True)# wrap the validation set in a DataLoader; drop_last discards the final batch when it has fewer than batch_size samples 23 | criterion = nn.CrossEntropyLoss()# cross-entropy loss 24 | #criterion = CrossEntropy() # cross-entropy loss 25 | #criterion = FocalLoss()# focal loss 26 | #optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay=1e-8)# optimizer 27 | optimizer = optim.SGD(net.parameters(), lr=lr, weight_decay=1e-8,momentum = 0.9) # optimizer 28 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=80,factor=0.9,min_lr=5e-5)# learning-rate scheduler 29 | #scheduler = PolyLR(optimizer, 8*100000/batch_size, power=0.9) 30 | epoch_val_loss = float('inf')# track the best model, using the validation score as the criterion 31 | fw_iou_avg = 0 32 | for epoch in range(epochs): 33 | epochs_loss=0# accumulate the loss over this epoch 34 | with tqdm(total=train_steps, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar: 35 | for idx, batch_samples in enumerate(train_dataloader): 36 | batch_image, batch_mask = batch_samples["image"], batch_samples["mask"] 37 | batch_image=batch_image.to(device=device, dtype=torch.float32) 38 | logits=net(batch_image) #torch.Size([batchsize, 8, 256, 256]) 39 | y_true=batch_mask.to(device=device, dtype=torch.long) #torch.Size([batchsize, 256, 256]) 40 |
loss=criterion(logits,y_true) 41 | epochs_loss += loss.item() 42 | writer.add_scalar('Loss/train', loss.item(), global_step) 43 | pbar.set_postfix(**{'loss (batch)': loss.item()}) 44 | 45 | optimizer.zero_grad() 46 | loss.backward() 47 | nn.utils.clip_grad_value_(net.parameters(), 0.1)# gradient clipping 48 | optimizer.step() 49 | pbar.update(batch_image.shape[0])# advance the progress bar by the batch size 50 | global_step += 1 51 | scheduler.step(loss) # step the scheduler on the monitored loss 52 | #scheduler.step() 53 | writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], global_step) 54 | if global_step % (train_steps // ( batch_size)) == 0: 55 | for tag,value in net.named_parameters(): 56 | tag=tag.replace('.','/') 57 | writer.add_histogram('weights/'+tag,value.data.cpu().numpy(),global_step) 58 | writer.add_histogram('grads/' + tag, value.grad.data.cpu().numpy(), global_step) # log the gradients, not a second copy of the weights 59 | val_loss,pixel_acc_avg,mean_iou_avg,_fw_iou_avg = eval_net(net,eval_dataloader,device) 60 | if fw_iou_avg<_fw_iou_avg: 61 | fw_iou_avg=_fw_iou_avg 62 | logging.info('Validation cross entropy: {}'.format(val_loss)) 63 | writer.add_scalar('Loss/test', val_loss, global_step) 64 | writer.add_scalar('pixel_acc_avg', pixel_acc_avg, global_step) 65 | writer.add_scalar('mean_iou_avg', mean_iou_avg, global_step) 66 | writer.add_scalar('fw_iou_avg', fw_iou_avg, global_step) 67 | 68 | 69 | # The checkpoint stores the best validation score; after each epoch the stored score is compared with the current one, and the model is saved whenever it improves 70 | if os.path.exists(dir_checkpoint+checkpoint_name): # a checkpoint file already exists 71 | checkpoint = torch.load(dir_checkpoint+checkpoint_name) 72 | print(fw_iou_avg, checkpoint['fw_iou_avg']) 73 | if fw_iou_avg>checkpoint['fw_iou_avg']: 74 | print('saving new best model...') 75 | state ={'net':net.state_dict(),'epoch_val_score':epoch_val_loss,'fw_iou_avg':fw_iou_avg,'epochth':epoch + 1} 76 | torch.save(state,dir_checkpoint+checkpoint_name) 77 | logging.info(f'checkpoint {epoch + 1} saved!') 78 | else:# no checkpoint file exists yet 79 | try: 80 | os.mkdir(dir_checkpoint) 81 | logging.info('create checkpoint directory!') 82 | except OSError: 83 | logging.info('save checkpoint error!') 84 | state ={'net':net.state_dict(),'epoch_val_score':epoch_val_loss,'fw_iou_avg':fw_iou_avg,'epochth':epoch + 1} 85 | torch.save(state, dir_checkpoint + checkpoint_name) 86 | logging.info(f'checkpoint {epoch + 1} saved!') 87 | writer.close() 88 | 89 | def get_args(): 90 | parser = argparse.ArgumentParser(description='Train the UNet on images and target masks', 91 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 92 | parser.add_argument('-e', '--epochs', metavar='E', type=int, default=50, 93 | help='Number of epochs', dest='epochs') 94 | parser.add_argument('-b', '--batch-size', metavar='B', type=int, nargs='?', default=64, 95 | help='Batch size', dest='batchsize') 96 | parser.add_argument('-l', '--learning-rate', metavar='LR', type=float, nargs='?', default=1e-3, 97 | help='Learning rate', dest='lr') 98 | return parser.parse_args() 99 | 100 | kwargs={'map_location':lambda storage, loc: storage.cuda(1)} 101 | def load_GPUS(model,model_path,kwargs): 102 | state_dict = torch.load(model_path,**kwargs) 103 | # create new OrderedDict that does not contain `module.` 104 | print("The model's loss is:"+str(state_dict['epoch_val_score'])) 105 | print("The model's fw_iou_avg is:" + str(state_dict['fw_iou_avg'])) 106 | from collections import OrderedDict 107 | new_state_dict = OrderedDict() 108 | for k, v in state_dict['net'].items(): 109 | name = k[7:] # remove `module.` 110 | new_state_dict[name] = v 111 | # load params 112 | model.load_state_dict(new_state_dict) 113 | print("loading
model success!") 114 | return model 115 | 116 | if __name__=="__main__": 117 | dir_checkpoint='checkpoints/' 118 | checkpoint_name='best_score_model_unet.pth' 119 | args = get_args() 120 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 121 | logging.info(f'Using device {device}') 122 | net = Unet(n_channels=3, n_classes=8, bilinear=True) 123 | #net = deeplabv3plus_mobilenet(num_classes=8, output_stride=16) 124 | #net = DANet(8, backbone='resnet50', pretrained_base=True) 125 | #net = HRNet(48, 8, 0.1) 126 | net.to(device=device) 127 | # state_dict = torch.load(dir_checkpoint+checkpoint_name) 128 | # print(state_dict['epoch_val_score'], state_dict['fw_iou_avg']) 129 | # net.load_state_dict(state_dict['net']) 130 | #net = load_GPUS(net, dir_checkpoint +checkpoint_name, kwargs) 131 | #net = torch.nn.DataParallel(net) 132 | train_net(net,device, args.batchsize, args.epochs,args.lr,dir_checkpoint,checkpoint_name) 133 | 134 | # import torch 135 | # model=deeplabv3_resnet50(num_classes=8, output_stride=16) 136 | # # from torchsummary import summary 137 | # model=model.cuda() 138 | # x=torch.rand((2, 3, 256, 256)).cuda() 139 | # # summary(model.cuda(),x) 140 | # y=model(x) 141 | # print(y.shape) 142 | 143 | 144 | 145 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .basic_dataset import BasicDataset 2 | from .data_vis import plot_img_and_mask 3 | from .function import Evaluator 4 | from .loss import CrossEntropy,FocalLoss 5 | from .polyLR import PolyLR -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/basic_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/utils/__pycache__/basic_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/basic_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/utils/__pycache__/basic_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/data_vis.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/utils/__pycache__/data_vis.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/data_vis.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/utils/__pycache__/data_vis.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/function.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/utils/__pycache__/function.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/function.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/utils/__pycache__/function.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/hrnet_loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/utils/__pycache__/hrnet_loss.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/utils/__pycache__/loss.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/utils/__pycache__/loss.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/polyLR.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/utils/__pycache__/polyLR.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/polyLR.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weihngo/Image_segmentation/1b913fe430c326d4570b7db3ac0113045ff0aab1/utils/__pycache__/polyLR.cpython-37.pyc -------------------------------------------------------------------------------- /utils/basic_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from torch.utils.data import Dataset, DataLoader 4 | import logging 5 | import numpy as np 6 | from PIL import Image 7 | class BasicDataset(Dataset): 8 | def __init__(self,list_path): 9 | self.list_path=list_path 10 | self.img_list = [] 11 | self.matches = [100, 200, 300, 400, 500, 600, 700, 800] 12 | self.img_list=[line.strip().split(" ") for line in open(list_path)] 13 | print(f'Creating dataset with {len(self.img_list)} examples') 14 | 15 | def __len__(self): 16 | return len(self.img_list) 17 | 18 | @classmethod 19 | def preprocess(cls,img,mask,matches,nClasses=8): 20 | #process img 21 | img = np.float32(img) / 127.5 - 1 22 | mask=np.array(mask) 23 | #process mask 
24 | #seg_labels = np.zeros((256, 256, nClasses)) 25 | for m in matches: 26 | mask[mask == m] = matches.index(m) 27 | 28 | #one-hot 29 | # for c in range(nClasses): 30 | # seg_labels[:, :, c] = (mask == c).astype(int) 31 | # seg_labels = np.reshape(seg_labels, (256 * 256, nClasses)) 32 | return img,mask 33 | 34 | def __getitem__(self, i): 35 | img_file = self.img_list[i][0] 36 | mask_file=self.img_list[i][1] 37 | img = Image.open(img_file) 38 | mask = Image.open(mask_file) 39 | 40 | img,mask = self.preprocess(img,mask,self.matches) 41 | return { 42 | 'image': torch.from_numpy(img).permute(2,0,1).type(torch.FloatTensor), 43 | 'mask': torch.from_numpy(mask).type(torch.FloatTensor) 44 | } 45 | # 46 | # if __name__ == "__main__": 47 | # sentiment_dataset = BasicDataset("../data/train.lst") 48 | # sentiment_dataloader = DataLoader(sentiment_dataset, batch_size=1,shuffle=True,num_workers=5) 49 | # for idx, batch_samples in enumerate(sentiment_dataloader): 50 | # text_batchs, text_labels = batch_samples["image"], batch_samples["mask"] 51 | # # print(text_batchs) 52 | # print(text_labels[0].shape) 53 | # break 54 | -------------------------------------------------------------------------------- /utils/data_vis.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | 4 | def plot_img_and_mask(img, mask): 5 | classes = mask.shape[2] if len(mask.shape) > 2 else 1 6 | fig, ax = plt.subplots(1, classes + 1) 7 | ax[0].set_title('Input image') 8 | ax[0].imshow(img) 9 | if classes > 1: 10 | for i in range(classes): 11 | ax[i+1].set_title(f'Output mask (class {i+1})') 12 | ax[i+1].imshow(mask[:, :, i]) 13 | else: 14 | ax[1].set_title('Output mask') 15 | ax[1].imshow(mask) 16 | plt.xticks([]), plt.yticks([]) 17 | plt.show() -------------------------------------------------------------------------------- /utils/function.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | class Evaluator(object): 3 | def __init__(self, num_class): 4 | self.num_class = num_class 5 | self.confusion_matrix = np.zeros((self.num_class,) * 2) # num_class x num_class; rows are ground-truth classes, columns are predicted classes 6 | 7 | ''' 8 | Overall pixel accuracy: correctly classified pixels / all pixels. 9 | ''' 10 | 11 | def Pixel_Accuracy(self): 12 | Acc = np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum() 13 | return Acc 14 | 15 | ''' 16 | Mean per-class pixel accuracy. 17 | ''' 18 | 19 | def Pixel_Accuracy_Class(self): 20 | Acc = np.diag(self.confusion_matrix) / self.confusion_matrix.sum(axis=1) 21 | Acc = np.nanmean(Acc) 22 | return Acc 23 | 24 | ''' 25 | Mean Intersection over Union (MIoU), the standard semantic-segmentation metric. 26 | For each class, IoU is the overlap between the ground-truth region and the predicted region, divided by their union. In confusion-matrix terms, for class i: 27 | IoU_i = TP / (TP + FP + FN) 28 | where TP are pixels of class i predicted as i, FP are pixels of other classes predicted as i, and FN are pixels of class i predicted as another class. 29 | MIoU is the mean of IoU_i over all classes. 30 | ''' 31 | 32 | def Mean_Intersection_over_Union(self): 33 | MIoU = np.diag(self.confusion_matrix) / ( 34 | np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) - 35 | np.diag(self.confusion_matrix)) 36 | MIoU = np.nanmean(MIoU) # mean over classes, ignoring NaN entries 37 | return MIoU 38 | 39 | def Class_IOU(self): 40 | MIoU = np.diag(self.confusion_matrix) / ( 41 | np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) - 42 | np.diag(self.confusion_matrix)) 43 | return MIoU 44 | 45 | def Frequency_Weighted_Intersection_over_Union(self): 46 | freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix) 47 | iu = np.diag(self.confusion_matrix) / ( 48 | np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) - 49 | np.diag(self.confusion_matrix)) 50 | 51 | FWIoU = (freq[freq > 0] * iu[freq > 0]).sum() 52 | return FWIoU 53 | 54 | ''' 55 | Usage: 56 | evaluator = Evaluator(num_class) 57 | evaluator.add_batch(target, pred) # target: [batch_size, H, W], pred: [batch_size, H, W] 58 | add_batch accumulates the confusion matrix over every batch of an epoch, so the metrics above can be computed once at the end; call reset() before the next epoch. 59 | ''' 60 | 61 | def _generate_matrix(self, gt_image, pre_image): 62 | mask = (gt_image >= 0) & (gt_image < self.num_class) # keep only valid ground-truth labels in [0, num_class) 63 | label = self.num_class * gt_image[mask].astype('int') + pre_image[mask] # unique id per (gt, pred) pair 64 | # np.bincount counts the ids in [0, num_class**2); reshaping yields the (num_class, num_class) confusion matrix 65 | count = np.bincount(label, minlength=self.num_class ** 2) 66 | confusion_matrix = count.reshape(self.num_class, self.num_class) 67 | return confusion_matrix 68 | 69 | def add_batch(self, gt_image, pre_image): 70 | assert gt_image.shape == pre_image.shape 71 | # accumulate the pixel-wise confusion matrix across batches 72 | self.confusion_matrix += self._generate_matrix(gt_image, pre_image) 73 | 74 | def reset(self): 75 | self.confusion_matrix = np.zeros((self.num_class,) * 2) 76 | 77 | # if __name__=="__main__": 78 | # gt_image = np.array([ 79 | # [0, 1, 2, 4], 80 | # [0, 0, 0, 0], 81 | # [0, 0, 0, 0], 82 | # [0, 0, 0, 0] 83 | # ]) 84 | # 85 | # pre_image = np.array([ 86 | # [0, 1, 2, 4], 87 | # [0, 1, 0, 0], 88 | # [0, 1, 0, 0], 89 | # [0, 0, 1, 0] 90 | # ]) 91 | # e=Evaluator(num_class=8) 92 | # e.add_batch(gt_image,pre_image) 93 | # acc=e.Pixel_Accuracy() 94 | # print(acc) -------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | import logging 5 | class CrossEntropy(nn.Module): 6 | def __init__(self, ignore_label=-1, weight=None): 7 | super(CrossEntropy, self).__init__() 8 | self.ignore_label = ignore_label 9 | self.criterion = nn.CrossEntropyLoss( 10 | weight=weight, 11 | ignore_index=ignore_label 12 | ) 13 | 14 | def _forward(self, score, target): 15 | ph, pw = score.size(2), score.size(3) 16 | h, w = target.size(1), target.size(2) 17 | if ph != h or pw != w: 18 | score = F.interpolate(input=score, size=( 19 | h, w), mode='bilinear', align_corners=False) 20 | loss = self.criterion(score, target) 21 | 22 | return loss 23 | 24 | def forward(self, score, target): 25 | score = [score] 26 | weights = [1] 27 | assert len(weights) == len(score) 28 | 29 | return sum([w * self._forward(x, target) for (w, x) in zip(weights, score)]) 30 | 31 | class FocalLoss(nn.Module): 32 | def __init__(self, alpha=0.5, gamma=2, size_average=True, ignore_index=255): 33 | super(FocalLoss, self).__init__() 34 | self.alpha = alpha 35 | self.gamma = gamma 36 | self.ignore_index = ignore_index 37 | self.size_average = size_average 38 | 39 | def forward(self, inputs, targets): 40 | ce_loss = F.cross_entropy(inputs, targets, 41 | reduction='none', ignore_index=self.ignore_index) 42 | pt = torch.exp(-ce_loss) 43 | focal_loss = self.alpha * (1-pt)**self.gamma * ce_loss 44 | if self.size_average: 45 | return focal_loss.mean() 46 | else: 47 | return focal_loss.sum() -------------------------------------------------------------------------------- /utils/polyLR.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import _LRScheduler, StepLR 2 | 3 | 4 | class PolyLR(_LRScheduler): 5 | def __init__(self, optimizer, max_iters, power=0.9, last_epoch=-1, min_lr=1e-6): 6 | self.power = power 7 | self.max_iters = max_iters 8 | self.min_lr = min_lr # floor that keeps the learning rate from reaching zero 9 | super(PolyLR, self).__init__(optimizer, last_epoch) 10 | 11 | def get_lr(self): 12 | return [max(base_lr * (1 - self.last_epoch / self.max_iters) ** self.power, self.min_lr) 13 | for base_lr in self.base_lrs] --------------------------------------------------------------------------------
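A note on the distillation objective in train.py: the student U-Net is trained on a weighted sum of hard-label cross-entropy and a temperature-scaled KL term against the DeepLabv3+ teacher, loss = (1-alpha)*CE + alpha*T^2*KL. Below is a minimal, self-contained sketch of that loss, assuming the same shapes as the training script ([batch, 8, 256, 256] logits, [batch, 256, 256] integer masks); the random tensors are placeholders, and reduction='batchmean' is an assumption of this sketch (the script uses the PyTorch default), chosen because it matches the mathematical definition of KL divergence.

import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, alpha=0.5, T=2.0):
    # hard-label term: ordinary cross-entropy against the ground-truth mask
    ce = F.cross_entropy(student_logits, labels)
    # soft-label term: KL between temperature-softened distributions;
    # kl_div expects log-probabilities as input and probabilities as target,
    # and the T*T factor restores the gradient scale lost to the temperature
    kd = F.kl_div(F.log_softmax(student_logits / T, dim=1),
                  F.softmax(teacher_logits / T, dim=1),
                  reduction='batchmean') * T * T
    return (1 - alpha) * ce + alpha * kd

student = torch.randn(2, 8, 256, 256)        # stand-in student logits
teacher = torch.randn(2, 8, 256, 256)        # stand-in teacher logits
labels = torch.randint(0, 8, (2, 256, 256))  # stand-in ground-truth masks
print(distillation_loss(student, teacher, labels))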
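The confusion-matrix trick in utils/function.py also deserves a worked example: encoding each (ground truth, prediction) pixel pair as num_class * gt + pred turns the whole matrix into a single np.bincount call, and every metric in Evaluator falls out of that matrix. A small sketch with 3 classes and 6 hand-made pixels (toy values, not project data):

import numpy as np

num_class = 3
gt = np.array([0, 0, 1, 1, 2, 2])    # ground-truth label of each pixel
pred = np.array([0, 1, 1, 1, 2, 0])  # predicted label of each pixel

label = num_class * gt + pred        # unique id per (gt, pred) pair
cm = np.bincount(label, minlength=num_class ** 2).reshape(num_class, num_class)
# rows are ground truth, columns are predictions:
# [[1 1 0]
#  [0 2 0]
#  [1 0 1]]

freq = cm.sum(axis=1) / cm.sum()     # class frequencies
iu = np.diag(cm) / (cm.sum(axis=1) + cm.sum(axis=0) - np.diag(cm))  # per-class IoU
fwiou = (freq[freq > 0] * iu[freq > 0]).sum()  # frequency-weighted IoU, as in Evaluator
print(cm)
print(fwiou)                         # 0.5 for this toy input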
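Finally, utils/polyLR.py implements the polynomial decay lr = base_lr * (1 - iter / max_iters)^power with a min_lr floor. Unlike the ReduceLROnPlateau scheduler actually used in train_student.py, it is meant to be stepped once per training iteration rather than once per epoch, which is what the commented-out PolyLR line in train_student.py hints at. A hedged usage sketch, assuming the repository root is on PYTHONPATH; the tiny stand-in model and iteration count are placeholders, not the project's real settings:

import torch
from torch import nn, optim
from utils import PolyLR  # exported via utils/__init__.py

model = nn.Conv2d(3, 8, kernel_size=1)  # stand-in for the real network
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler = PolyLR(optimizer, max_iters=1000, power=0.9)

for it in range(1000):
    optimizer.zero_grad()
    loss = model(torch.randn(1, 3, 8, 8)).mean()  # dummy forward pass
    loss.backward()
    optimizer.step()
    scheduler.step()  # one step per iteration; the LR decays toward min_lr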