├── .idea ├── .gitignore ├── dblue.iml ├── inspectionProfiles │ ├── Project_Default.xml │ └── profiles_settings.xml ├── misc.xml ├── modules.xml └── other.xml ├── README.md ├── models ├── resnet.py ├── resunet.py └── unet.py ├── module ├── image.py └── slice.py └── progress └── image_slice.py /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # 默认忽略的文件 2 | /shelf/ 3 | /workspace.xml 4 | # 基于编辑器的 HTTP 客户端请求 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | -------------------------------------------------------------------------------- /.idea/dblue.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 28 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/other.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | -------------------------------------------------------------------------------- /README.md: 
# ------------------------------------------------------------------------------
# README.md
# ------------------------------------------------------------------------------
# 遥感深度学习入门
#
# image.py:影像读取相关函数
#
# slice.py:影像切分相关函数
#
# image_slice.py:主函数,执行切分
#
# resnet.py:ResNet网络定义
#
# unet.py:Unet网络定义
#
# resunet.py:ResUnet网络定义

# ------------------------------------------------------------------------------
# /models/resnet.py
# ------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @Time : 2022/11/4 10:46
# @Author : Zph
# @Email : hhs_zph@mail.imu.edu.cn
# @File : resnet.py
# @Software: PyCharm

import torch
from torch import nn


# torchinfo / onnx / netron are only needed by the ``__main__`` demo below and
# are imported there, so importing the model definitions needs only torch.


def _as_stride(stride):
    """Treat ``stride=None`` (the historical signature default) as stride 1."""
    return 1 if stride is None else stride


class ConvBN(nn.Module):
    """Conv2d (no bias) followed by BatchNorm2d, with an optional trailing ReLU.

    :param in_channels: channels of the input
    :param out_channels: channels of the output
    :param kernel_size: forwarded to ``nn.Conv2d``
    :param stride: forwarded to ``nn.Conv2d``
    :param padding: forwarded to ``nn.Conv2d``
    :param do_relu: append ``nn.ReLU(inplace=True)`` after the BN layer
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, do_relu=False):
        super(ConvBN, self).__init__()
        layers = [
            # the conv bias would be cancelled by BatchNorm's own shift
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(out_channels)]
        if do_relu:
            layers.append(nn.ReLU(inplace=True))
        self.conv_bn_relu = nn.Sequential(*layers)

    def forward(self, in_lyr):
        return self.conv_bn_relu(in_lyr)


class BasicBlock(nn.Module):
    """ResNet BasicBlock: ConvBN(3x3)+ReLU -> ConvBN(3x3), plus shortcut, then ReLU.

    :param in_channels: channels of the block input
    :param out_channels: channels of the block output
    :param stride: stride of the first 3x3 conv and of the projection shortcut;
        ``None`` means 1
    :param downsample: if truthy the shortcut is a 1x1 conv projection (needed
        whenever the stride is not 1 or the channel count changes); otherwise
        the input is added unchanged
    """

    def __init__(self, in_channels, out_channels, stride=None, downsample=None):
        super(BasicBlock, self).__init__()
        stride = _as_stride(stride)
        self.downsample = downsample
        # ReLU applied after the residual addition
        self.relu = nn.ReLU(inplace=True)
        # main path: conv+bn+relu -> conv+bn
        layers = [ConvBN(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, do_relu=True),
                  ConvBN(out_channels, out_channels, kernel_size=3, stride=1, padding=1, do_relu=False)]
        self.basicblock = nn.Sequential(*layers)

        if downsample:
            # The projection must use the main path's stride; a fixed stride of 2
            # breaks blocks that only change the channel count.
            self.do_downsample = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0)

    def forward(self, in_lyr):
        basicblock_lyr = self.basicblock(in_lyr)
        shortcut_lyr = self.do_downsample(in_lyr) if self.downsample else in_lyr
        return self.relu(basicblock_lyr + shortcut_lyr)


class BottleNeck(nn.Module):
    """ResNet BottleNeck: 1x1 reduce -> 3x3 (strided) -> 1x1 expand, plus shortcut, then ReLU.

    :param in_channels: channels of the block input
    :param mid_channels: channels of the inner 3x3 conv
    :param out_channels: channels of the block output
    :param stride: stride of the 3x3 conv and of the projection shortcut;
        ``None`` means 1
    :param downsample: if truthy the shortcut is a 1x1 conv projection,
        otherwise the input is added unchanged
    """

    def __init__(self, in_channels, mid_channels, out_channels, stride=None, downsample=None):
        super(BottleNeck, self).__init__()
        stride = _as_stride(stride)
        self.downsample = downsample
        # ReLU applied after the residual addition
        self.relu = nn.ReLU(inplace=True)
        # main path
        layers = [ConvBN(in_channels, mid_channels, kernel_size=1, stride=1, padding=0, do_relu=True),
                  ConvBN(mid_channels, mid_channels, kernel_size=3, stride=stride, padding=1, do_relu=True),
                  ConvBN(mid_channels, out_channels, kernel_size=1, stride=1, padding=0, do_relu=False)]
        self.bottleneck = nn.Sequential(*layers)
        if downsample:
            self.do_downsample = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0)

    def forward(self, in_lyr):
        bottleneck_lyr = self.bottleneck(in_lyr)
        shortcut_lyr = self.do_downsample(in_lyr) if self.downsample else in_lyr
        return self.relu(bottleneck_lyr + shortcut_lyr)


class _ResnetBase(nn.Module):
    """Forward pass shared by both ResNet variants.

    stem (conv1 + maxpool) -> conv2..conv5 stages -> average pool -> flatten -> FC.
    Subclasses define ``conv1``, ``maxpool``, ``conv2``..``conv5``, ``avepool``
    and ``full``.
    """

    def forward(self, in_lyr):
        lyr = self.maxpool(self.conv1(in_lyr))
        lyr = self.conv5(self.conv4(self.conv3(self.conv2(lyr))))
        avepool_lyr = self.avepool(lyr)
        return self.full(avepool_lyr.view(avepool_lyr.size(0), -1))


class Resnet_bb(_ResnetBase):
    """BasicBlock ResNet: ResNet-18 with ``conv_num=[2, 2, 2, 2]``, ResNet-34 with ``[3, 4, 6, 3]``.

    :param in_channels: channels of the input image
    :param conv_num: number of BasicBlocks in each of the four stages
        (default: ResNet-18)
    :param class_num: size of the classifier output (default: 1000)
    """

    def __init__(self, in_channels, conv_num=None, class_num=None):
        super(Resnet_bb, self).__init__()
        conv_num = (2, 2, 2, 2) if conv_num is None else conv_num
        class_num = 1000 if class_num is None else class_num
        # stem
        self.conv1 = ConvBN(in_channels, 64, kernel_size=7, stride=2, padding=3, do_relu=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)

        def get_conv_block(in_channels, out_channels, conv_count=None, downsample=None):
            # only the first block of a stage changes resolution / channel count
            stride = 2 if downsample else 1
            convblock = [BasicBlock(in_channels, out_channels, stride=stride, downsample=bool(downsample))]
            convblock.extend(BasicBlock(out_channels, out_channels, stride=1, downsample=False)
                             for _ in range(conv_count - 1))
            return convblock

        self.conv2 = nn.Sequential(*get_conv_block(64, 64, conv_num[0], downsample=False))
        self.conv3 = nn.Sequential(*get_conv_block(64, 128, conv_num[1], downsample=True))
        self.conv4 = nn.Sequential(*get_conv_block(128, 256, conv_num[2], downsample=True))
        self.conv5 = nn.Sequential(*get_conv_block(256, 512, conv_num[3], downsample=True))
        # Global average pooling: same result as AvgPool2d(7) for 224x224 inputs
        # (7x7 final map) and valid for every other input size too.
        self.avepool = nn.AdaptiveAvgPool2d(1)
        self.full = nn.Linear(512, class_num)


class Resnet_bn(_ResnetBase):
    """BottleNeck ResNet: ResNet-50 with ``conv_num=[3, 4, 6, 3]``, ResNet-101 with ``[3, 4, 23, 3]``.

    :param in_channels: channels of the input image
    :param conv_num: number of BottleNecks in each of the four stages
        (default: ResNet-50)
    :param class_num: size of the classifier output (default: 1000)
    """

    def __init__(self, in_channels, conv_num=None, class_num=None):
        super(Resnet_bn, self).__init__()
        conv_num = (3, 4, 6, 3) if conv_num is None else conv_num
        class_num = 1000 if class_num is None else class_num
        # stem
        self.conv1 = ConvBN(in_channels, 64, kernel_size=7, stride=2, padding=3, do_relu=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)

        def get_conv_block(in_channels, mid_channels, out_channels, conv_count=None, stride=None):
            # the first block always projects (channels change 64->256, 256->512, ...)
            convblock = [BottleNeck(in_channels, mid_channels, out_channels, stride=stride, downsample=True)]
            convblock.extend(BottleNeck(out_channels, mid_channels, out_channels, stride=1, downsample=False)
                             for _ in range(conv_count - 1))
            return convblock

        self.conv2 = nn.Sequential(*get_conv_block(64, 64, 256, conv_num[0], stride=1))
        self.conv3 = nn.Sequential(*get_conv_block(256, 128, 512, conv_num[1], stride=2))
        self.conv4 = nn.Sequential(*get_conv_block(512, 256, 1024, conv_num[2], stride=2))
        self.conv5 = nn.Sequential(*get_conv_block(1024, 512, 2048, conv_num[3], stride=2))
        # Global average pooling (equals AvgPool2d(7) for 224x224 inputs)
        self.avepool = nn.AdaptiveAvgPool2d(1)
        self.full = nn.Linear(2048, class_num)


if __name__ == "__main__":
    import torchinfo
    import onnx
    import netron

    device = "cuda" if torch.cuda.is_available() else "cpu"
    resnet = Resnet_bn(in_channels=3, conv_num=[3, 4, 6, 3], class_num=1000).to(device)
    print(resnet)
    torchinfo.summary(resnet, input_size=(1, 3, 224, 224))
    # the exported file is an ONNX model despite the .pth extension
    modelData = 'resnet_test.pth'
    test = torch.randn([1, 3, 224, 224]).to(device)
    # export to ONNX
    torch.onnx.export(resnet, test, modelData, export_params=True)
    # reload it
    onnx_model = onnx.load(modelData)
    # add inferred feature-map shapes to the graph
    onnx.save(onnx.shape_inference.infer_shapes(onnx_model), modelData)
    # show the network structure
    netron.start(modelData)

# ------------------------------------------------------------------------------
# /models/resunet.py
# ------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @Time : 2022/11/4 11:40
# @Author : Zph
# @Email : hhs_zph@mail.imu.edu.cn
# @File : resunet.py
# @Software: PyCharm

import torch
from torch import nn
# torchinfo (and, for the commented ONNX export, onnx / netron) are only needed by
# the ``__main__`` demo at the end of resunet.py and are imported there, so
# importing the network needs only torch.


class ResBlock(nn.Module):
    """Residual block used by ResUNet.

    Main path: 3x3 conv -> BN -> ReLU -> 3x3 conv -> BN.
    Shortcut: 1x1 conv -> BN (always a projection, so ``in_channels`` may
    differ from ``out_channels``).
    The sum goes through a final ReLU. Spatial size is preserved.
    """

    def __init__(self, in_channels, out_channels):
        super(ResBlock, self).__init__()
        # residual (main) path
        res_layers = [nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
                      nn.BatchNorm2d(out_channels),
                      nn.ReLU(inplace=True),
                      nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
                      nn.BatchNorm2d(out_channels)]
        self.res_layers = nn.Sequential(*res_layers)
        # projection shortcut
        shortcut_layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(out_channels)]
        self.shortcut = nn.Sequential(*shortcut_layers)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, in_lyr):
        res = self.res_layers(in_lyr)
        shortcut = self.shortcut(in_lyr)
        return self.relu(res + shortcut)


class DownSampling(nn.Module):
    """Encoder step: 2x2 max-pool (halves H and W, rounding down), then a ResBlock."""

    def __init__(self, in_channel, out_channel):
        super(DownSampling, self).__init__()
        self.down_sampling = nn.Sequential(
            nn.MaxPool2d(2),
            ResBlock(in_channel, out_channel)
        )

    def forward(self, in_channel):
        # ``in_channel`` is the input feature map (parameter name kept for compatibility)
        return self.down_sampling(in_channel)


class UpSampling(nn.Module):
    """Decoder step: 2x transposed-conv upsampling, concat with the skip map, then a ResBlock.

    The skip map must have ``in_channel - out_channel`` channels so the
    concatenation has the ``in_channel`` channels the ResBlock expects.
    """

    def __init__(self, in_channel, out_channel):
        super(UpSampling, self).__init__()
        self.up_sampling = nn.Sequential(
            nn.ConvTranspose2d(in_channel, out_channel, kernel_size=2, stride=2, bias=False),
            nn.ReLU(inplace=True),  # inplace: overwrite the input tensor to save memory
            nn.BatchNorm2d(out_channel)
        )
        self.conv = ResBlock(in_channel, out_channel)

    def forward(self, in_lyr, down_lyr):
        up_lyr = self.up_sampling(in_lyr)
        # MaxPool2d floors odd sizes, so the upsampled map can be a row/column short
        # of the skip map (inputs not divisible by 16). Pad it, split evenly, so
        # torch.cat always works. No-op when the sizes already match.
        diff_h = down_lyr.size(2) - up_lyr.size(2)
        diff_w = down_lyr.size(3) - up_lyr.size(3)
        if diff_h or diff_w:
            up_lyr = nn.functional.pad(up_lyr, [diff_w // 2, diff_w - diff_w // 2,
                                                diff_h // 2, diff_h - diff_h // 2])
        cat_lyr = torch.cat([up_lyr, down_lyr], dim=1)
        return self.conv(cat_lyr)


class ResUNet(nn.Module):
    """U-Net built from ResBlocks: 4 down-sampling and 4 up-sampling stages, sigmoid output.

    :param in_channel: channels of the input image
    :param out_channel: channels of the output map (passed through a sigmoid)
    :param hidden_channels: five widths for encoder levels 0..4; each must be
        twice the previous one so the decoder concatenations line up
    """

    def __init__(self, in_channel, out_channel, hidden_channels):
        super(ResUNet, self).__init__()
        self.conv_1 = ResBlock(in_channel, hidden_channels[0])
        self.down_1 = DownSampling(hidden_channels[0], hidden_channels[1])
        self.down_2 = DownSampling(hidden_channels[1], hidden_channels[2])
        self.down_3 = DownSampling(hidden_channels[2], hidden_channels[3])
        self.down_4 = DownSampling(hidden_channels[3], hidden_channels[4])

        self.up_1 = UpSampling(hidden_channels[4], hidden_channels[3])
        self.up_2 = UpSampling(hidden_channels[3], hidden_channels[2])
        self.up_3 = UpSampling(hidden_channels[2], hidden_channels[1])
        self.up_4 = UpSampling(hidden_channels[1], hidden_channels[0])

        self.conv_2 = nn.Conv2d(hidden_channels[0], out_channel, kernel_size=1, stride=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def forward(self, in_lyr):
        conv_lyr_1 = self.conv_1(in_lyr)
        down_lyr_1 = self.down_1(conv_lyr_1)
        down_lyr_2 = self.down_2(down_lyr_1)
        down_lyr_3 = self.down_3(down_lyr_2)
        down_lyr_4 = self.down_4(down_lyr_3)
        up_lyr_1 = self.up_1(down_lyr_4, down_lyr_3)
        up_lyr_2 = self.up_2(up_lyr_1, down_lyr_2)
        up_lyr_3 = self.up_3(up_lyr_2, down_lyr_1)
        up_lyr_4 = self.up_4(up_lyr_3, conv_lyr_1)
        conv_lyr_2 = self.conv_2(up_lyr_4)
        return self.sigmoid(conv_lyr_2)


if __name__ == '__main__':
    import torchinfo

    # hidden_channels = [64, 128, 256, 512, 1024]
    hidden_channels = [32, 64, 128, 256, 512]
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    resunet = ResUNet(4, 1, hidden_channels).to(device)

    torchinfo.summary(resunet, input_size=(1, 4, 256, 256))
    # ONNX export (requires ``import onnx`` and ``import netron``):
    # test = torch.randn([1, 4, 256, 256]).to(device)
    # modelData = 'resunet_test.pth'
    # torch.onnx.export(resunet, test, modelData, export_params=True, opset_version=9)
    # onnx_model = onnx.load(modelData)
    # onnx.save(onnx.shape_inference.infer_shapes(onnx_model), modelData)
    # netron.start(modelData)

# ------------------------------------------------------------------------------
# /models/unet.py
# ------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @Time : 2022/11/4 10:46
# @Author : Zph
# @Email : hhs_zph@mail.imu.edu.cn
# @File : unet.py
# @Software: PyCharm

import torch
import torch.nn as nn


class DoubleConv(nn.Module):
    """Two (3x3 conv -> ReLU -> BN) layers; spatial size is preserved."""

    def __init__(self, in_channel, out_channel):
        super(DoubleConv, self).__init__()
        self.doubleconv = nn.Sequential(
            nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, padding=1, bias=False),
            nn.ReLU(inplace=True),  # inplace: overwrite the input tensor to save memory
            nn.BatchNorm2d(out_channel),
            nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_channel)
        )

    def forward(self, in_channel):
        # ``in_channel`` is the input feature map (parameter name kept for compatibility)
        return self.doubleconv(in_channel)


class DownSampling(nn.Module):
    """Encoder step: 2x2 max-pool (halves H and W, rounding down), then DoubleConv."""

    def __init__(self, in_channel, out_channel):
        super(DownSampling, self).__init__()
        self.down_sampling = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channel, out_channel)
        )

    def forward(self, in_channel):
        return self.down_sampling(in_channel)


class UpSampling(nn.Module):
    """Decoder step: 2x transposed-conv upsampling, concat with the skip map, then DoubleConv.

    The skip map must have ``in_channel - out_channel`` channels so the
    concatenation has the ``in_channel`` channels DoubleConv expects.
    """

    def __init__(self, in_channel, out_channel):
        super(UpSampling, self).__init__()
        self.up_sampling = nn.Sequential(
            nn.ConvTranspose2d(in_channels=in_channel, out_channels=out_channel, kernel_size=2, stride=2, bias=False),
            nn.ReLU(inplace=True),  # inplace: overwrite the input tensor to save memory
            nn.BatchNorm2d(out_channel)
        )
        self.conv = DoubleConv(in_channel, out_channel)

    def forward(self, in_lyr, down_lyr):
        up_lyr = self.up_sampling(in_lyr)
        # pad a floor-rounded map back to the skip map's size (see MaxPool2d)
        diff_h = down_lyr.size(2) - up_lyr.size(2)
        diff_w = down_lyr.size(3) - up_lyr.size(3)
        if diff_h or diff_w:
            up_lyr = nn.functional.pad(up_lyr, [diff_w // 2, diff_w - diff_w // 2,
                                                diff_h // 2, diff_h - diff_h // 2])
        cat_lyr = torch.cat([up_lyr, down_lyr], dim=1)
        return self.conv(cat_lyr)


class UNet(nn.Module):
    """U-Net: 4 down-sampling and 4 up-sampling stages, sigmoid output.

    :param in_channel: channels of the input image
    :param out_channel: channels of the output map (passed through a sigmoid)
    :param hidden_channels: five widths for encoder levels 0..4; each must be
        twice the previous one so the decoder concatenations line up
    """

    def __init__(self, in_channel, out_channel, hidden_channels):
        super(UNet, self).__init__()
        self.conv_1 = DoubleConv(in_channel, hidden_channels[0])  # 2 x (conv + ReLU + BN)
        self.down_1 = DownSampling(hidden_channels[0], hidden_channels[1])
        self.down_2 = DownSampling(hidden_channels[1], hidden_channels[2])
        self.down_3 = DownSampling(hidden_channels[2], hidden_channels[3])
        self.down_4 = DownSampling(hidden_channels[3], hidden_channels[4])
        self.up_1 = UpSampling(hidden_channels[4], hidden_channels[3])
        self.up_2 = UpSampling(hidden_channels[3], hidden_channels[2])
        self.up_3 = UpSampling(hidden_channels[2], hidden_channels[1])
        self.up_4 = UpSampling(hidden_channels[1], hidden_channels[0])
        self.conv_2 = nn.Conv2d(in_channels=hidden_channels[0], out_channels=out_channel, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, in_lyr):
        conv_lyr_1 = self.conv_1(in_lyr)
        down_lyr_1 = self.down_1(conv_lyr_1)
        down_lyr_2 = self.down_2(down_lyr_1)
        down_lyr_3 = self.down_3(down_lyr_2)
        down_lyr_4 = self.down_4(down_lyr_3)
        up_lyr_1 = self.up_1(down_lyr_4, down_lyr_3)
        up_lyr_2 = self.up_2(up_lyr_1, down_lyr_2)
        up_lyr_3 = self.up_3(up_lyr_2, down_lyr_1)
        up_lyr_4 = self.up_4(up_lyr_3, conv_lyr_1)
        conv_lyr_2 = self.conv_2(up_lyr_4)
        return self.sigmoid(conv_lyr_2)


if __name__ == '__main__':
    import torchinfo
    import netron
    import onnx

    # hidden_channels = [64, 128, 256, 512, 1024]
    hidden_channels = [32, 64, 128, 256, 512]
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    unet = UNet(4, 1, hidden_channels).to(device)
    test = torch.randn([1, 4, 512, 512]).to(device)
    # print the network summary
    torchinfo.summary(unet, input_size=(1, 4, 512, 512))
    # output path (the file is an ONNX model despite the .pth extension)
    modelData = 'unet_test.pth'
    # export to ONNX
    torch.onnx.export(unet, test, modelData, export_params=True)
    # reload it
    onnx_model = onnx.load(modelData)
    # add inferred feature-map shapes to the graph
    onnx.save(onnx.shape_inference.infer_shapes(onnx_model), modelData)
    # show the network structure
    netron.start(modelData)

# ------------------------------------------------------------------------------
# /module/image.py
# ------------------------------------------------------------------------------
from osgeo import gdal
import numpy as np
import os

# PROJ / GDAL data folders of the original author's conda env. They are only
# applied when the folder exists: pointing PROJ_LIB / GDAL_DATA at a missing
# folder would break projection handling on every other machine.
_GDAL_ENV_DIRS = {
    'PROJ_LIB': r'C:\Users\Lenovo\.conda\envs\zph\Library\share\proj',
    'GDAL_DATA': r'C:\Users\Lenovo\.conda\envs\zph\Library\share',
}
for _env_name, _env_dir in _GDAL_ENV_DIRS.items():
    if os.path.isdir(_env_dir):
        os.environ[_env_name] = _env_dir
gdal.PushErrorHandler("CPLQuietErrorHandler")


class ImageProcess:
    """Read-only wrapper around a GDAL raster, plus a GeoTIFF writer.

    Attributes:
        filepath: path passed to ``gdal.Open``.
        dataset: the opened GDAL dataset.
        info: ``[bands, width, height, geotransform, projection]`` after
            :meth:`read_img_info`, else ``[]``.
        img_data: pixel array after :meth:`read_img_data`, else ``None``.
        data_8bit: not set anywhere in this module.
    """

    # numpy dtype name -> GDAL pixel type. Signed and unsigned integers are kept
    # apart so negative values are not wrapped. Any other dtype (floats, 64-bit
    # ints, bool, ...) is written as Float32, as before.
    _GDAL_TYPES = {
        'uint8': gdal.GDT_Byte,
        # GDT_Int8 only exists in GDAL >= 3.7; Int16 can hold every int8 value
        'int8': getattr(gdal, 'GDT_Int8', gdal.GDT_Int16),
        'uint16': gdal.GDT_UInt16,
        'int16': gdal.GDT_Int16,
        'uint32': gdal.GDT_UInt32,
        'int32': gdal.GDT_Int32,
    }

    def __init__(self, filepath: str):
        self.filepath = filepath
        self.dataset = gdal.Open(self.filepath, gdal.GA_ReadOnly)
        if self.dataset is None:
            # GDAL's own error text is suppressed by the quiet handler above, so
            # fail here with a clear message instead of an AttributeError later
            raise RuntimeError(f"无法打开影像文件: {filepath}")
        self.info = []
        self.img_data = None
        self.data_8bit = None

    def read_img_info(self):
        """Return and store ``[bands, width, height, geotransform, projection]``."""
        img_bands = self.dataset.RasterCount
        img_width = self.dataset.RasterXSize
        img_height = self.dataset.RasterYSize
        img_geotrans = self.dataset.GetGeoTransform()
        img_proj = self.dataset.GetProjection()
        self.info = [img_bands, img_width, img_height, img_geotrans, img_proj]
        return self.info

    def read_img_data(self):
        """Read the full raster extent with ``ReadAsArray`` and store it in ``img_data``.

        The array is 2-D for single-band rasters and (bands, rows, cols) for
        multi-band rasters.
        """
        if not self.info:
            # width/height come from read_img_info(); load them if it was skipped
            self.read_img_info()
        self.img_data = self.dataset.ReadAsArray(0, 0, self.info[1], self.info[2])
        return self.img_data

    @staticmethod
    def write_img(filename: str, img_data: np.ndarray, **kwargs):
        """Write a 2-D ``(rows, cols)`` or 3-D ``(bands, rows, cols)`` array as a GeoTIFF.

        :param filename: output path
        :param img_data: pixel array; the GDAL pixel type follows its dtype
        :param kwargs: optional ``img_geotrans`` (6-element geotransform) and
            ``img_proj`` (projection string) stored in the file
        :raises RuntimeError: if GDAL cannot create the output file
        """
        datatype = ImageProcess._GDAL_TYPES.get(img_data.dtype.name, gdal.GDT_Float32)
        if img_data.ndim >= 3:
            img_bands, img_height, img_width = img_data.shape
        else:
            img_bands, (img_height, img_width) = 1, img_data.shape
        driver = gdal.GetDriverByName("GTiff")
        outdataset = driver.Create(filename, img_width, img_height, img_bands, datatype)
        if outdataset is None:
            raise RuntimeError(f"无法创建影像文件: {filename}")
        if 'img_geotrans' in kwargs:
            outdataset.SetGeoTransform(kwargs['img_geotrans'])
        if 'img_proj' in kwargs:
            outdataset.SetProjection(kwargs['img_proj'])
        # View every input as (bands, rows, cols) so each band written is 2-D.
        # This also covers a 3-D array with a single band.
        for band_no, band_array in enumerate(img_data.reshape(img_bands, img_height, img_width), start=1):
            outdataset.GetRasterBand(band_no).WriteArray(band_array)
        outdataset.FlushCache()
        del outdataset  # dropping the last reference closes the file


def _read_raster(path, kind):
    """Open *path*, print its metadata and array shape prefixed with *kind*, return (obj, info, data)."""
    raster = ImageProcess(filepath=path)
    raster_info = raster.read_img_info()
    print(f"{kind}影像元信息:{raster_info}")
    raster_data = raster.read_img_data()
    print(f"{kind}矩阵大小:{raster_data.shape}")
    return raster, raster_info, raster_data


def read_multi_bands(image_path):
    """
    Read a multi-band raster.
    :param image_path: path of the multi-band file
    :return: (ImageProcess object, metadata list, pixel array)
    """
    return _read_raster(image_path, "多波段")


def read_single_band(band_path):
    """
    Read a single-band raster.
    :param band_path: path of the single-band file
    :return: (ImageProcess object, metadata list, pixel array)
    """
    return _read_raster(band_path, "单波段")

# ------------------------------------------------------------------------------
# /module/slice.py
# ------------------------------------------------------------------------------
import math
import os
import random
import numpy as np
from alive_progress import alive_bar
from module.image import *


def cal_single_band_slice(single_band_data, slice_size=1000):
    """
    Compute the grid windows that tile a 2-D band.
    :param single_band_data: 2-D array; only its shape is used
    :param slice_size: window edge length in pixels (must be positive)
    :return: list of [row_min, row_max, col_min, col_max] (max exclusive) in
             row-major order; windows on the last row/column are clipped to the array
    """
    if slice_size <= 0:
        raise ValueError(f"slice_size must be positive, got {slice_size}")
    single_band_size = single_band_data.shape
    n_rows, n_cols = single_band_size[0], single_band_size[1]
    row_num = math.ceil(n_rows / slice_size)  # round up so the edge is covered
    col_num = math.ceil(n_cols / slice_size)
    print(f"行列数:{single_band_size},行分割数量:{row_num},列分割数量:{col_num}")
    return [[i * slice_size, min((i + 1) * slice_size, n_rows),
             j * slice_size, min((j + 1) * slice_size, n_cols)]
            for i in range(row_num)
            for j in range(col_num)]


def _cut(data, index, slice_size, edge_fill, lead):
    """Cut window *index* from the two axes that follow the *lead* axes of *data*.

    *lead* is ``()`` for a 2-D band and ``(slice(None),)`` for a (bands, rows, cols)
    stack. With *edge_fill*, a window that is not slice_size x slice_size is copied
    into a ZERO-filled slice_size x slice_size canvas at row/col offset
    ``index % slice_size``. A previous version used uninitialised memory for the
    canvas. Always returns a new array with ``data``'s dtype.
    """
    row_min, row_max, col_min, col_max = index[0], index[1], index[2], index[3]
    window = data[lead + (slice(row_min, row_max), slice(col_min, col_max))]
    if edge_fill and (row_max - row_min != slice_size or col_max - col_min != slice_size):
        result = np.zeros(data.shape[:len(lead)] + (slice_size, slice_size), dtype=data.dtype)
        new_row_min = row_min % slice_size
        new_col_min = col_min % slice_size
        result[lead + (slice(new_row_min, new_row_min + (row_max - row_min)),
                       slice(new_col_min, new_col_min + (col_max - col_min)))] = window
        return result
    return window.astype(data.dtype)


def single_band_slice(single_band_data, index=(0, 1000, 0, 1000), slice_size=1000, edge_fill=False):
    """
    Cut one window out of a 2-D band.
    :param single_band_data: original 2-D array
    :param index: [row_min, row_max, col_min, col_max]
    :param slice_size: tile edge length
    :param edge_fill: zero-pad windows smaller than slice_size x slice_size
    :return: the tile (new array, same dtype)
    """
    return _cut(single_band_data, index, slice_size, edge_fill, lead=())


def multi_bands_slice(multi_bands_data, index=(0, 1000, 0, 1000), slice_size=1000, edge_fill=False):
    """
    Cut one window out of a (bands, rows, cols) stack.
    :param multi_bands_data: original 3-D array
    :param index: [row_min, row_max, col_min, col_max]
    :param slice_size: tile edge length
    :param edge_fill: zero-pad windows smaller than slice_size x slice_size
    :return: the tile (new array, same dtype)
    """
    return _cut(multi_bands_data, index, slice_size, edge_fill, lead=(slice(None),))


def slice_conbine(slice_all, slice_index):
    """
    Mosaic 2-D tiles back into one array (float64, np.zeros' default dtype).
    :param slice_all: the tiles, in the same order as slice_index
    :param slice_index: windows from cal_single_band_slice; the last one fixes the output size
    :return: the merged 2-D array
    """
    combine_data = np.zeros(shape=(slice_index[-1][1], slice_index[-1][3]))
    for i, (row_min, row_max, col_min, col_max) in enumerate(slice_index):
        # Edge tiles produced with edge_fill=True are larger than their window.
        # Keep only their real (top-left) part.
        combine_data[row_min:row_max, col_min:col_max] = \
            np.asarray(slice_all[i])[:row_max - row_min, :col_max - col_min]
    return combine_data


# correctly spelled alias of slice_conbine
slice_combine = slice_conbine


def coordtransf(Xpixel, Ypixel, GeoTransform):
    """
    Geotransform of a sub-window whose top-left pixel is (Xpixel, Ypixel).
    :param Xpixel: COLUMN index of the window's top-left pixel
    :param Ypixel: ROW index of the window's top-left pixel
    :param GeoTransform: original 6-element GDAL geotransform
    :return: window geotransform (same pixel size / rotation, shifted origin)
    """
    x_geo = GeoTransform[0] + GeoTransform[1] * Xpixel + GeoTransform[2] * Ypixel
    y_geo = GeoTransform[3] + GeoTransform[4] * Xpixel + GeoTransform[5] * Ypixel
    return (x_geo, GeoTransform[1], GeoTransform[2], y_geo, GeoTransform[4], GeoTransform[5])


def _tile_path(out_dir, prefix, i):
    """Output path ``<out_dir>/<prefix><i>.tif`` using the platform's separator."""
    return os.path.join(out_dir, prefix + str(i) + '.tif')


def _rand_index(height, width, slice_size, slice_count):
    """*slice_count* random [row_min, row_max, col_min, col_max] windows fully inside the raster."""
    if slice_size > height or slice_size > width:
        raise ValueError(f"裁剪大小 {slice_size} 超过影像尺寸 {height}x{width}")
    # randint includes both ends: the last valid start is size - slice_size
    rows = [random.randint(0, height - slice_size) for _ in range(slice_count)]
    cols = [random.randint(0, width - slice_size) for _ in range(slice_count)]
    return [[r, r + slice_size, c, c + slice_size] for r, c in zip(rows, cols)]


def _check_same_grid(image_info, band_info):
    """Refuse an image/label pair whose width/height differ; tiles would not align."""
    if image_info[1:3] != band_info[1:3]:
        raise ValueError(f"image与label行列数不一致: {image_info[1:3]} vs {band_info[1:3]}")


def _slice_and_write(slice_index, jobs, slice_size, edge_fill):
    """Cut each window from every job's array and write it as a GeoTIFF.

    :param slice_index: windows as [row_min, row_max, col_min, col_max]
    :param jobs: tuples (raster, info, data, cut, out_dir, prefix); raster/info/data
        come from read_multi_bands / read_single_band, cut is multi_bands_slice or
        single_band_slice, and tile i goes to out_dir/<prefix><i>.tif
    :param slice_size: tile edge length, forwarded to cut
    :param edge_fill: forwarded to cut

    The tile geotransform is derived from the FIRST job's geotransform and reused
    for every job (an image and its label share one grid). Each job keeps its own
    projection.
    """
    ref_geotrans = jobs[0][1][3]
    with alive_bar(len(slice_index), force_tty=True) as bar:
        for i, slice_element in enumerate(slice_index):
            # top-left pixel of the window: column = col_min, row = row_min
            slice_geotrans = coordtransf(slice_element[2], slice_element[0], ref_geotrans)
            for raster, info, data, cut, out_dir, prefix in jobs:
                tile = cut(data, index=slice_element, slice_size=slice_size, edge_fill=edge_fill)
                raster.write_img(_tile_path(out_dir, prefix, i), tile,
                                 img_geotrans=slice_geotrans, img_proj=info[4])
            bar()


def multi_bands_grid_slice(image_path, image_slice_dir, slice_size, edge_fill=False):
    """
    Grid-tile a multi-band raster into multi_grid_slice_<i>.tif.
    :param image_path: source multi-band raster
    :param image_slice_dir: output folder (not created here)
    :param slice_size: tile edge length
    :param edge_fill: zero-pad edge tiles to slice_size x slice_size
    :return:
    """
    image, image_info, image_data = read_multi_bands(image_path)
    slice_index = cal_single_band_slice(image_data[0, :, :], slice_size=slice_size)
    _slice_and_write(slice_index,
                     [(image, image_info, image_data, multi_bands_slice, image_slice_dir, 'multi_grid_slice_')],
                     slice_size, edge_fill)
    print('多波段格网裁剪完成')


def single_band_grid_slice(band_path, band_slice_dir, slice_size, edge_fill=False):
    """
    Grid-tile a single-band raster into single_grid_slice_<i>.tif.
    :param band_path: source single-band raster
    :param band_slice_dir: output folder (not created here)
    :param slice_size: tile edge length
    :param edge_fill: zero-pad edge tiles to slice_size x slice_size
    :return:
    """
    band, band_info, band_data = read_single_band(band_path)
    slice_index = cal_single_band_slice(band_data, slice_size=slice_size)
    _slice_and_write(slice_index,
                     [(band, band_info, band_data, single_band_slice, band_slice_dir, 'single_grid_slice_')],
                     slice_size, edge_fill)
    print('单波段格网裁剪完成')


def multi_bands_rand_slice(image_path, image_slice_dir, slice_size, slice_count):
    """
    Cut slice_count random windows from a multi-band raster into multi_rand_slice_<i>.tif.
    :param image_path: source multi-band raster
    :param image_slice_dir: output folder (not created here)
    :param slice_size: tile edge length (must not exceed the raster size)
    :param slice_count: number of tiles
    :return:
    """
    image, image_info, image_data = read_multi_bands(image_path)
    rand_index = _rand_index(image_info[2], image_info[1], slice_size, slice_count)
    _slice_and_write(rand_index,
                     [(image, image_info, image_data, multi_bands_slice, image_slice_dir, 'multi_rand_slice_')],
                     slice_size, edge_fill=False)
    print('多波段随机裁剪完成')


def single_band_rand_slice(band_path, band_slice_dir, slice_size, slice_count):
    """
    Cut slice_count random windows from a single-band raster into single_rand_slice_<i>.tif.
    :param band_path: source single-band raster
    :param band_slice_dir: output folder (not created here)
    :param slice_size: tile edge length (must not exceed the raster size)
    :param slice_count: number of tiles
    :return:
    """
    band, band_info, band_data = read_single_band(band_path)
    rand_index = _rand_index(band_info[2], band_info[1], slice_size, slice_count)
    _slice_and_write(rand_index,
                     [(band, band_info, band_data, single_band_slice, band_slice_dir, 'single_rand_slice_')],
                     slice_size, edge_fill=False)
    print('单波段随机裁剪完成')


def deeplr_grid_slice(image_path, band_path, image_slice_dir, band_slice_dir, slice_size, edge_fill=False):
    """
    Deep-learning samples, grid tiling: cut the image and its label on the same grid.
    :param image_path: source image (multi-band)
    :param band_path: source label (single-band, same width/height as the image)
    :param image_slice_dir: output folder for image tiles
    :param band_slice_dir: output folder for label tiles
    :param slice_size: tile edge length
    :param edge_fill: zero-pad edge tiles to slice_size x slice_size
    :return:
    """
    image, image_info, image_data = read_multi_bands(image_path)
    band, band_info, band_data = read_single_band(band_path)
    _check_same_grid(image_info, band_info)
    slice_index = cal_single_band_slice(image_data[0, :, :], slice_size=slice_size)
    _slice_and_write(slice_index,
                     [(image, image_info, image_data, multi_bands_slice, image_slice_dir, 'multi_grid_slice_'),
                      (band, band_info, band_data, single_band_slice, band_slice_dir, 'single_grid_slice_')],
                     slice_size, edge_fill)
    print('深度学习样本-格网裁剪完成')


def deeplr_rand_slice(image_path, band_path, image_slice_dir, band_slice_dir, slice_size, slice_count):
    """
    Deep-learning samples, random tiling: cut the image and its label at the same random windows.
    :param image_path: source image (multi-band)
    :param band_path: source label (single-band, same width/height as the image)
    :param image_slice_dir: output folder for image tiles
    :param band_slice_dir: output folder for label tiles
    :param slice_size: tile edge length (must not exceed the raster size)
    :param slice_count: number of tile pairs
    :return:
    """
    image, image_info, image_data = read_multi_bands(image_path)
    band, band_info, band_data = read_single_band(band_path)
    _check_same_grid(image_info, band_info)
    rand_index = _rand_index(image_info[2], image_info[1], slice_size, slice_count)
    _slice_and_write(rand_index,
                     [(image, image_info, image_data, multi_bands_slice, image_slice_dir, 'multi_rand_slice_'),
                      (band, band_info, band_data, single_band_slice, band_slice_dir, 'single_rand_slice_')],
                     slice_size, edge_fill=False)
    print('深度学习样本-随机裁剪完成')
-------------------------------------------------------------------------------- /progress/image_slice.py: -------------------------------------------------------------------------------- 1 | from module import slice 2 | 3 | if __name__ == '__main__': 4 | print(f"0-多波段格网裁剪、1-单波段格网裁剪、2-多波段随机裁剪、3-单波段随机裁剪、4-制作深度学习样本-格网裁剪、5-制作深度学习样本-随机裁剪") 5 | slice_type = input(f"请选择:") 6 | if int(slice_type) == 0: 7 | # 参数设置 8 | image_path = input(f"请输入待裁剪多波段影像路径:") 9 | image_slice_dir = input(f"请输入结果存放路径:") 10 | slice_size = int(input(f"请输入裁剪块大小:")) 11 | edge_fill = bool(int(input(f"是否进行边缘填充(0/1):"))) 12 | slice.multi_bands_grid_slice(image_path, image_slice_dir, slice_size, edge_fill=edge_fill) 13 | 14 | elif int(slice_type) == 1: 15 | # 参数设置 16 | band_path = input(f"请输入待裁剪单波段影像路径:") 17 | band_slice_dir = input(f"请输入结果存放路径:") 18 | slice_size = int(input(f"请输入裁剪块大小:")) 19 | edge_fill = bool(int(input(f"是否进行边缘填充(0/1):"))) 20 | slice.single_band_grid_slice(band_path, band_slice_dir, slice_size, edge_fill=edge_fill) 21 | 22 | elif int(slice_type) == 2: 23 | # 参数设置 24 | image_path = input(f"请输入待裁剪多波段影像路径:") 25 | image_slice_dir = input(f"请输入结果存放路径:") 26 | slice_size = int(input(f"请输入裁剪块大小:")) 27 | slice_count = int(input(f"请输入裁剪数量:")) 28 | slice.multi_bands_rand_slice(image_path, image_slice_dir, slice_size, slice_count) 29 | 30 | elif int(slice_type) == 3: 31 | # 参数设置 32 | band_path = input(f"请输入待裁剪单波段影像路径:") 33 | band_slice_dir = input(f"请输入结果存放路径:") 34 | slice_size = int(input(f"请输入裁剪块大小:")) 35 | slice_count = int(input(f"请输入裁剪数量:")) 36 | slice.single_band_rand_slice(band_path, band_slice_dir, slice_size, slice_count) 37 | 38 | elif int(slice_type) == 4: 39 | # 参数设置 40 | image_path = input(f"请输入待裁剪多波段image路径:") 41 | band_path = input(f"请输入待裁剪单波段label路径:") 42 | image_slice_dir = input(f"请输入image裁剪结果存放路径:") 43 | band_slice_dir = input(f"请输入label裁剪结果存放路径:") 44 | slice_size = int(input(f"请输入裁剪块大小:")) 45 | edge_fill = bool(int(input(f"是否进行边缘填充(0/1):"))) 46 | 
slice.deeplr_grid_slice(image_path, band_path, image_slice_dir, band_slice_dir, slice_size, edge_fill=edge_fill) 47 | 48 | elif int(slice_type) == 5: 49 | # 参数设置 50 | image_path = input(f"请输入待裁剪多波段image路径:") 51 | band_path = input(f"请输入待裁剪单波段label路径:") 52 | image_slice_dir = input(f"请输入image裁剪结果存放路径:") 53 | band_slice_dir = input(f"请输入label裁剪结果存放路径:") 54 | slice_size = int(input(f"请输入裁剪块大小:")) 55 | slice_count = int(input(f"请输入裁剪数量:")) 56 | slice.deeplr_rand_slice(image_path, band_path, image_slice_dir, band_slice_dir, slice_size, slice_count) 57 | 58 | else: 59 | print('输入错误') --------------------------------------------------------------------------------