├── .idea
├── .gitignore
├── dblue.iml
├── inspectionProfiles
│ ├── Project_Default.xml
│ └── profiles_settings.xml
├── misc.xml
├── modules.xml
└── other.xml
├── README.md
├── models
├── resnet.py
├── resunet.py
└── unet.py
├── module
├── image.py
└── slice.py
└── progress
└── image_slice.py
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # 默认忽略的文件
2 | /shelf/
3 | /workspace.xml
4 | # 基于编辑器的 HTTP 客户端请求
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/.idea/dblue.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
26 |
27 |
28 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/other.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 遥感深度学习入门
2 |
3 | image.py:影像读取相关函数
4 |
5 | slice.py:影像切分相关函数
6 |
7 | image_slice.py:主函数,执行切分
8 |
9 | unet.py:Unet网络定义
10 |
11 | resunet.py:ResUnet网络定义
12 |
--------------------------------------------------------------------------------
/models/resnet.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2022/11/4 10:46
3 | # @Author : Zph
4 | # @Email : hhs_zph@mail.imu.edu.cn
5 | # @File : resnet.py
6 | # @Software: PyCharm
7 |
8 |
9 | import torch
10 | from torch import nn
11 | import torchinfo
12 | import onnx
13 | import netron
14 |
15 |
class ConvBN(nn.Module):
    """Convolution followed by batch normalisation, with an optional ReLU.

    The sub-layers live in ``self.conv_bn_relu`` in the order Conv2d (built
    without a bias term), BatchNorm2d and, when ``do_relu`` is true, ReLU.

    Args:
        in_channels: channel count of the input.
        out_channels: channel count produced by the convolution.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        padding: convolution padding.
        do_relu: append an in-place ReLU after the normalisation.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, do_relu=False):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        norm = nn.BatchNorm2d(out_channels)
        # The activation is only added on request.
        stages = (conv, norm, nn.ReLU(inplace=True)) if do_relu else (conv, norm)
        self.conv_bn_relu = nn.Sequential(*stages)

    def forward(self, in_lyr):
        """Run ``in_lyr`` through the stored layer sequence."""
        return self.conv_bn_relu(in_lyr)
30 |
31 |
class BasicBlock(nn.Module):
    """Two-convolution residual block.

    Main path: a 3x3 ConvBN with ReLU (carrying ``stride``) followed by a 3x3
    ConvBN without activation. The input is added back either unchanged or,
    when ``downsample`` is truthy, after a 1x1 convolution; the sum then goes
    through a ReLU.

    Args:
        in_channels: channel count of the input.
        out_channels: channel count produced by the block.
        stride: stride of the first convolution and of the shortcut
            projection; ``None`` is treated as 1.
        downsample: when truthy, project the shortcut with a 1x1 convolution.
    """

    def __init__(self, in_channels, out_channels, stride=None, downsample=None):
        super(BasicBlock, self).__init__()
        # A missing stride means "keep the resolution"; passing None straight
        # into Conv2d would not give a usable layer.
        stride = 1 if stride is None else stride
        # Remember whether the shortcut needs the projection.
        self.downsample = downsample
        # ReLU applied after the residual addition.
        self.relu = nn.ReLU(inplace=True)
        # Main path: conv+bn+relu, then conv+bn.
        layers = [ConvBN(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, do_relu=True),
                  ConvBN(out_channels, out_channels, kernel_size=3, stride=1, padding=1, do_relu=False)]
        self.basicblock = nn.Sequential(*layers)

        if downsample:
            # Shortcut projection: 1x1 convolution. It uses the same stride as
            # the main path so both branches always have matching sizes
            # (the stride used to be hard-coded to 2).
            self.do_downsample = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0)

    def forward(self, in_lyr):
        """Return ``relu(main_path(in_lyr) + shortcut(in_lyr))``."""
        basicblock_lyr = self.basicblock(in_lyr)
        # Identity shortcut unless a projection was requested.
        shortcut_lyr = self.do_downsample(in_lyr) if self.downsample else in_lyr
        return self.relu(basicblock_lyr + shortcut_lyr)
62 |
63 |
class BottleNeck(nn.Module):
    """Three-convolution bottleneck residual block.

    Main path: 1x1 ConvBN+ReLU to ``mid_channels``, 3x3 ConvBN+ReLU carrying
    ``stride``, then 1x1 ConvBN to ``out_channels`` without activation. The
    input is added back either unchanged or, when ``downsample`` is truthy,
    after a 1x1 convolution with the same stride; the sum goes through a ReLU.

    Args:
        in_channels: channel count of the input.
        mid_channels: channel count inside the bottleneck.
        out_channels: channel count produced by the block.
        stride: stride of the 3x3 convolution and of the shortcut projection;
            ``None`` is treated as 1.
        downsample: when truthy, project the shortcut with a 1x1 convolution.
    """

    def __init__(self, in_channels, mid_channels, out_channels, stride=None, downsample=None):
        super(BottleNeck, self).__init__()
        # A missing stride means "keep the resolution"; passing None straight
        # into Conv2d would not give a usable layer.
        stride = 1 if stride is None else stride
        # Remember whether the shortcut needs the projection.
        self.downsample = downsample
        # ReLU applied after the residual addition.
        self.relu = nn.ReLU(inplace=True)
        # Main path.
        layers = [ConvBN(in_channels, mid_channels, kernel_size=1, stride=1, padding=0, do_relu=True),
                  ConvBN(mid_channels, mid_channels, kernel_size=3, stride=stride, padding=1, do_relu=True),
                  ConvBN(mid_channels, out_channels, kernel_size=1, stride=1, padding=0, do_relu=False)]
        self.bottleneck = nn.Sequential(*layers)
        if downsample:
            # Shortcut projection matching the main path's channels and stride.
            self.do_downsample = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0)

    def forward(self, in_lyr):
        """Return ``relu(main_path(in_lyr) + shortcut(in_lyr))``."""
        bottleneck_lyr = self.bottleneck(in_lyr)
        # Identity shortcut unless a projection was requested.
        shortcut_lyr = self.do_downsample(in_lyr) if self.downsample else in_lyr
        return self.relu(bottleneck_lyr + shortcut_lyr)
92 |
93 |
class Resnet_bb(nn.Module):
    """ResNet classifier assembled from BasicBlock stages (ResNet-18/34 layout).

    Args:
        in_channels: channel count of the input image.
        conv_num: sequence with the number of BasicBlocks in each of the four
            residual stages.
        class_num: number of output features of the final linear layer.
    """

    def __init__(self, in_channels, conv_num=None, class_num=None):
        super(Resnet_bb, self).__init__()
        # Stem: 7x7 stride-2 convolution with BN and ReLU.
        self.conv1 = ConvBN(in_channels, 64, kernel_size=7, stride=2, padding=3, do_relu=True)
        # Max pooling.
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)

        def get_conv_block(in_channels, out_channels, conv_count=None, downsample=None):
            """Build one stage: an optional stride-2 first block, then stride-1 blocks."""
            if downsample:
                convblock = [BasicBlock(in_channels, out_channels, stride=2, downsample=True)]
            else:
                convblock = [BasicBlock(in_channels, out_channels, stride=1, downsample=False)]

            for _ in range(conv_count - 1):
                convblock.append(BasicBlock(out_channels, out_channels, stride=1, downsample=False))

            return convblock

        self.conv2 = nn.Sequential(*get_conv_block(64, 64, conv_num[0], downsample=False))
        self.conv3 = nn.Sequential(*get_conv_block(64, 128, conv_num[1], downsample=True))
        self.conv4 = nn.Sequential(*get_conv_block(128, 256, conv_num[2], downsample=True))
        self.conv5 = nn.Sequential(*get_conv_block(256, 512, conv_num[3], downsample=True))
        # Global average pooling + fully connected layer. Adaptive pooling
        # always yields a 1x1 map, so the 512-feature classifier also works
        # for inputs other than 224x224 (where the old fixed 7x7 window was
        # the whole map anyway).
        self.avepool = nn.AdaptiveAvgPool2d((1, 1))
        self.full = nn.Linear(512, class_num)

    def forward(self, in_lyr):
        """Return the classifier output for a batch of images."""
        conv1_lyr = self.conv1(in_lyr)
        maxpool_lyr = self.maxpool(conv1_lyr)
        conv2_lyr = self.conv2(maxpool_lyr)
        conv3_lyr = self.conv3(conv2_lyr)
        conv4_lyr = self.conv4(conv3_lyr)
        conv5_lyr = self.conv5(conv4_lyr)
        avepool_lyr = self.avepool(conv5_lyr)
        # Flatten everything except the batch dimension before the linear layer.
        full_lyr = self.full(avepool_lyr.view(avepool_lyr.size(0), -1))
        return full_lyr
134 |
135 |
class Resnet_bn(nn.Module):
    """ResNet classifier assembled from BottleNeck stages (ResNet-50/101 layout).

    Args:
        in_channels: channel count of the input image.
        conv_num: sequence with the number of BottleNecks in each of the four
            residual stages.
        class_num: number of output features of the final linear layer.
    """

    def __init__(self, in_channels, conv_num=None, class_num=None):
        super(Resnet_bn, self).__init__()
        # Stem: 7x7 stride-2 convolution with BN and ReLU.
        self.conv1 = ConvBN(in_channels, 64, kernel_size=7, stride=2, padding=3, do_relu=True)
        # Max pooling.
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)

        def get_conv_block(in_channels, mid_channels, out_channels, conv_count=None, stride=None):
            """Build one stage: a projecting first block, then identity-shortcut blocks."""
            convblock = [BottleNeck(in_channels, mid_channels, out_channels, stride=stride, downsample=True)]
            for _ in range(conv_count - 1):
                convblock.append(BottleNeck(out_channels, mid_channels, out_channels, stride=1, downsample=False))

            return convblock

        self.conv2 = nn.Sequential(*get_conv_block(64, 64, 256, conv_num[0], stride=1))
        self.conv3 = nn.Sequential(*get_conv_block(256, 128, 512, conv_num[1], stride=2))
        self.conv4 = nn.Sequential(*get_conv_block(512, 256, 1024, conv_num[2], stride=2))
        self.conv5 = nn.Sequential(*get_conv_block(1024, 512, 2048, conv_num[3], stride=2))
        # Global average pooling + fully connected layer. Adaptive pooling
        # always yields a 1x1 map, so the 2048-feature classifier also works
        # for inputs other than 224x224 (where the old fixed 7x7 window was
        # the whole map anyway).
        self.avepool = nn.AdaptiveAvgPool2d((1, 1))
        self.full = nn.Linear(2048, class_num)

    def forward(self, in_lyr):
        """Return the classifier output for a batch of images."""
        conv1_lyr = self.conv1(in_lyr)
        maxpool_lyr = self.maxpool(conv1_lyr)
        conv2_lyr = self.conv2(maxpool_lyr)
        conv3_lyr = self.conv3(conv2_lyr)
        conv4_lyr = self.conv4(conv3_lyr)
        conv5_lyr = self.conv5(conv4_lyr)
        avepool_lyr = self.avepool(conv5_lyr)
        # Flatten everything except the batch dimension before the linear layer.
        full_lyr = self.full(avepool_lyr.view(avepool_lyr.size(0), -1))
        return full_lyr
172 |
173 |
if __name__ == "__main__":
    # Use the GPU when one is present, otherwise fall back to the CPU instead
    # of crashing on the unconditional .cuda() call.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    resnet = Resnet_bn(in_channels=3, conv_num=[3, 4, 6, 3], class_num=1000).to(device)
    print(resnet)
    torchinfo.summary(resnet, input_size=(1, 3, 224, 224))
    # NOTE: the file below receives ONNX data although its name ends in ".pth".
    modelData = 'resnet_test.pth'
    test = torch.randn([1, 3, 224, 224]).to(device)
    # Export the network in ONNX format.
    torch.onnx.export(resnet, test, modelData, export_params=True)
    # Load the exported model again.
    onnx_model = onnx.load(modelData)
    # Add feature-map shape information to the graph and save it back.
    onnx.save(onnx.shape_inference.infer_shapes(onnx_model), modelData)
    # Show the network structure in Netron.
    netron.start(modelData)
188 | #
--------------------------------------------------------------------------------
/models/resunet.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2022/11/4 11:40
3 | # @Author : Zph
4 | # @Email : hhs_zph@mail.imu.edu.cn
5 | # @File : resunet.py
6 | # @Software: PyCharm
7 |
8 | import torch
9 | from torch import nn
10 | import torchinfo
11 | import onnx
12 | import netron
13 |
class ResBlock(nn.Module):
    """Residual block with a projected shortcut.

    Main path (``self.res_layers``): 3x3 conv, BN, ReLU, 3x3 conv, BN.
    Shortcut (``self.shortcut``): 1x1 conv, BN. All convolutions use stride 1
    and no bias. The two branches are summed and passed through a ReLU.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Main residual path.
        self.res_layers = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        # Shortcut path: the 1x1 convolution maps the input to out_channels.
        self.shortcut = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, in_lyr):
        """Return ``relu(res_layers(in_lyr) + shortcut(in_lyr))``."""
        main_branch = self.res_layers(in_lyr)
        side_branch = self.shortcut(in_lyr)
        merged = main_branch + side_branch
        return self.relu(merged)
37 |
class DownSampling(nn.Module):
    """Down-sampling stage: 2x2 max pooling followed by a ResBlock."""

    def __init__(self, in_channel, out_channel):
        super().__init__()
        pool = nn.MaxPool2d(2)
        block = ResBlock(in_channel, out_channel)
        self.down_sampling = nn.Sequential(pool, block)

    def forward(self, in_channel):
        """Apply pooling and the residual block to the given tensor."""
        pooled_and_convolved = self.down_sampling(in_channel)
        return pooled_and_convolved
50 |
51 |
class UpSampling(nn.Module):
    """Up-sampling stage: transposed convolution, skip concatenation, ResBlock.

    ``self.up_sampling`` applies a 2x2 stride-2 transposed convolution from
    ``in_channel`` to ``out_channel`` channels, then ReLU and BatchNorm. The
    result is concatenated with the skip tensor along the channel axis and
    fed to ``ResBlock(in_channel, out_channel)``, so the concatenation must
    have ``in_channel`` channels.
    """

    def __init__(self, in_channel, out_channel):
        super(UpSampling, self).__init__()
        self.up_sampling = nn.Sequential(
            nn.ConvTranspose2d(in_channel, out_channel, kernel_size=2, stride=2, bias=False),
            nn.ReLU(inplace=True),  # inplace=True overwrites its input to save memory
            nn.BatchNorm2d(out_channel)
        )
        self.conv = ResBlock(in_channel, out_channel)

    def forward(self, in_lyr, down_lyr):
        """Upsample ``in_lyr``, concatenate ``down_lyr`` on dim 1 and convolve."""
        # The original wrapped this tensor in a 1-tuple through a stray
        # trailing comma and then indexed it with [0]; use it directly.
        upsampled = self.up_sampling(in_lyr)
        cat_lyr = torch.cat([upsampled, down_lyr], dim=1)
        out_lyr = self.conv(cat_lyr)
        return out_lyr
69 |
70 |
class ResUNet(nn.Module):
    """U-shaped encoder/decoder network built from residual blocks.

    Args:
        in_channel: channel count of the input.
        out_channel: channel count of the output map.
        hidden_channels: indexable with five channel widths, one per
            encoder level (index 0 is the full-resolution level).
    """

    def __init__(self, in_channel, out_channel, hidden_channels):
        super().__init__()
        widths = hidden_channels

        # Encoder: one residual block, then four down-sampling stages.
        self.conv_1 = ResBlock(in_channel, widths[0])
        self.down_1 = DownSampling(widths[0], widths[1])
        self.down_2 = DownSampling(widths[1], widths[2])
        self.down_3 = DownSampling(widths[2], widths[3])
        self.down_4 = DownSampling(widths[3], widths[4])

        # Decoder: four up-sampling stages mirroring the encoder.
        self.up_1 = UpSampling(widths[4], widths[3])
        self.up_2 = UpSampling(widths[3], widths[2])
        self.up_3 = UpSampling(widths[2], widths[1])
        self.up_4 = UpSampling(widths[1], widths[0])

        # Output head: 1x1 convolution followed by a sigmoid.
        self.conv_2 = nn.Conv2d(widths[0], out_channel, kernel_size=1, stride=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def forward(self, in_lyr):
        """Run the encoder, then the decoder with skip connections, then the head."""
        skip_0 = self.conv_1(in_lyr)
        skip_1 = self.down_1(skip_0)
        skip_2 = self.down_2(skip_1)
        skip_3 = self.down_3(skip_2)
        decoded = self.down_4(skip_3)
        # Each decoder stage consumes the encoder output of the matching level.
        decoder_plan = (
            (self.up_1, skip_3),
            (self.up_2, skip_2),
            (self.up_3, skip_1),
            (self.up_4, skip_0),
        )
        for up_stage, skip in decoder_plan:
            decoded = up_stage(decoded, skip)
        return self.sigmoid(self.conv_2(decoded))
103 |
104 |
105 |
if __name__ == '__main__':
    # hidden_channels = [64, 128, 256, 512, 1024]
    hidden_channels = [32, 64, 128, 256, 512]
    # Use the GPU when one is present, otherwise fall back to the CPU instead
    # of crashing on the unconditional .cuda() call.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    resunet = ResUNet(4, 1, hidden_channels).to(device)

    # Print a layer-by-layer summary for a single 4-channel 256x256 input.
    torchinfo.summary(resunet, input_size=(1, 4, 256, 256))
112 | #
113 | # test = torch.randn([1, 4, 256, 256]).cuda()
114 | # # 设定网络保存路径
115 | # modelData = 'resunet_test.pth'
116 | # # 保存为onnx格式
117 | # torch.onnx.export(resunet, test, modelData, export_params=True, opset_version=9)
118 | # # 重新读取
119 | # onnx_model = onnx.load(modelData)
120 | # # 增加特征图纬度信息
121 | # onnx.save(onnx.shape_inference.infer_shapes(onnx_model), modelData)
122 | # # 显示网络结构
123 | # netron.start(modelData)
124 |
125 |
126 |
127 |
128 | # res = UpSampling(256, 128)
129 | # in_lyr = torch.randn(size=(1,256,64,64))
130 | # down_lyr = torch.randn(size=(1,128,128,128))
131 | # result = res(in_lyr,down_lyr)
132 | # print(result.size())
133 |
134 | #torchinfo.summary(res, input_size=(1, 512, 256, 256))
135 |
136 |
137 |
138 |
139 |
140 |
141 |
142 |
143 |
144 |
145 |
146 |
147 |
148 |
149 |
150 |
151 |
152 |
--------------------------------------------------------------------------------
/models/unet.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2022/11/4 10:46
3 | # @Author : Zph
4 | # @Email : hhs_zph@mail.imu.edu.cn
5 | # @File : unet.py
6 | # @Software: PyCharm
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torchinfo
11 | import netron
12 | import onnx
13 |
14 |
class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> ReLU -> BatchNorm) groups.

    The first group maps ``in_channel`` to ``out_channel`` channels, the
    second keeps ``out_channel``. Convolutions use padding 1 and no bias.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        stages = []
        for source_width in (in_channel, out_channel):
            stages.append(nn.Conv2d(in_channels=source_width, out_channels=out_channel,
                                    kernel_size=3, padding=1, bias=False))
            # inplace=True overwrites the activation input to save memory.
            stages.append(nn.ReLU(inplace=True))
            stages.append(nn.BatchNorm2d(out_channel))
        self.doubleconv = nn.Sequential(*stages)

    def forward(self, in_channel):
        """Apply both convolution groups to the given tensor."""
        convolved = self.doubleconv(in_channel)
        return convolved
31 |
32 |
class DownSampling(nn.Module):
    """Down-sampling stage: 2x2 max pooling followed by a DoubleConv."""

    def __init__(self, in_channel, out_channel):
        super().__init__()
        pool = nn.MaxPool2d(2)
        block = DoubleConv(in_channel, out_channel)
        self.down_sampling = nn.Sequential(pool, block)

    def forward(self, in_channel):
        """Apply pooling and the double convolution to the given tensor."""
        pooled_and_convolved = self.down_sampling(in_channel)
        return pooled_and_convolved
45 |
46 |
class UpSampling(nn.Module):
    """Up-sampling stage: transposed convolution, skip concatenation, DoubleConv.

    ``self.up_sampling`` applies a 2x2 stride-2 transposed convolution from
    ``in_channel`` to ``out_channel`` channels, then ReLU and BatchNorm. The
    result is concatenated with the skip tensor along the channel axis and
    fed to ``DoubleConv(in_channel, out_channel)``, so the concatenation must
    have ``in_channel`` channels.
    """

    def __init__(self, in_channel, out_channel):
        super(UpSampling, self).__init__()
        self.up_sampling = nn.Sequential(
            nn.ConvTranspose2d(in_channels=in_channel, out_channels=out_channel, kernel_size=2, stride=2, bias=False),
            nn.ReLU(inplace=True),  # inplace=True overwrites its input to save memory
            nn.BatchNorm2d(out_channel)
        )
        self.conv = DoubleConv(in_channel, out_channel)

    def forward(self, in_lyr, down_lyr):
        """Upsample ``in_lyr``, concatenate ``down_lyr`` on dim 1 and convolve."""
        # The original wrapped this tensor in a 1-tuple through a stray
        # trailing comma and then indexed it with [0]; use it directly.
        upsampled = self.up_sampling(in_lyr)
        cat_lyr = torch.cat([upsampled, down_lyr], dim=1)
        out_lyr = self.conv(cat_lyr)
        return out_lyr
64 |
65 |
class UNet(nn.Module):
    """U-shaped encoder/decoder network built from DoubleConv blocks.

    Args:
        in_channel: channel count of the input.
        out_channel: channel count of the output map.
        hidden_channels: indexable with five channel widths, one per
            encoder level (index 0 is the full-resolution level).
    """

    def __init__(self, in_channel, out_channel, hidden_channels):
        super().__init__()
        widths = hidden_channels

        # Encoder: one double convolution, then four down-sampling stages.
        self.conv_1 = DoubleConv(in_channel, widths[0])
        self.down_1 = DownSampling(widths[0], widths[1])
        self.down_2 = DownSampling(widths[1], widths[2])
        self.down_3 = DownSampling(widths[2], widths[3])
        self.down_4 = DownSampling(widths[3], widths[4])
        # Decoder: four up-sampling stages mirroring the encoder.
        self.up_1 = UpSampling(widths[4], widths[3])
        self.up_2 = UpSampling(widths[3], widths[2])
        self.up_3 = UpSampling(widths[2], widths[1])
        self.up_4 = UpSampling(widths[1], widths[0])
        # Output head: 1x1 convolution followed by a sigmoid.
        self.conv_2 = nn.Conv2d(in_channels=widths[0], out_channels=out_channel, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, in_lyr):
        """Run the encoder, then the decoder with skip connections, then the head."""
        skip_0 = self.conv_1(in_lyr)
        skip_1 = self.down_1(skip_0)
        skip_2 = self.down_2(skip_1)
        skip_3 = self.down_3(skip_2)
        decoded = self.down_4(skip_3)
        # Each decoder stage consumes the encoder output of the matching level.
        decoder_plan = (
            (self.up_1, skip_3),
            (self.up_2, skip_2),
            (self.up_3, skip_1),
            (self.up_4, skip_0),
        )
        for up_stage, skip in decoder_plan:
            decoded = up_stage(decoded, skip)
        return self.sigmoid(self.conv_2(decoded))
96 |
97 |
if __name__ == '__main__':
    # hidden_channels = [64, 128, 256, 512, 1024]
    hidden_channels = [32, 64, 128, 256, 512]
    # Use the GPU when one is present, otherwise fall back to the CPU instead
    # of crashing on the unconditional .cuda() call.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    unet = UNet(4, 1, hidden_channels).to(device)
    test = torch.randn([1, 4, 512, 512]).to(device)
    # Print the network structure.
    torchinfo.summary(unet, input_size=(1, 4, 512, 512))
    # Path the network is saved to.
    # NOTE: the file receives ONNX data although its name ends in ".pth".
    modelData = 'unet_test.pth'
    # Export the network in ONNX format.
    torch.onnx.export(unet, test, modelData, export_params=True)
    # Load the exported model again.
    onnx_model = onnx.load(modelData)
    # Add feature-map shape information to the graph and save it back.
    onnx.save(onnx.shape_inference.infer_shapes(onnx_model), modelData)
    # Show the network structure in Netron.
    netron.start(modelData)
115 | #
--------------------------------------------------------------------------------
/module/image.py:
--------------------------------------------------------------------------------
1 | from osgeo import gdal
2 | import numpy as np
3 | import os
# Point PROJ and GDAL at the data directories of one specific conda
# environment. The paths are machine-specific, so they are applied only when
# the directory exists; on any other machine the existing environment is left
# untouched instead of being overwritten with a path that is not there.
for _env_name, _env_dir in (
    ('PROJ_LIB', r'C:\Users\Lenovo\.conda\envs\zph\Library\share\proj'),
    ('GDAL_DATA', r'C:\Users\Lenovo\.conda\envs\zph\Library\share'),
):
    if os.path.isdir(_env_dir):
        os.environ[_env_name] = _env_dir
# Install GDAL's quiet error handler.
gdal.PushErrorHandler("CPLQuietErrorHandler")
7 |
8 |
class ImageProcess:
    """Small wrapper around a GDAL dataset opened read-only.

    ``info`` is filled by :meth:`read_img_info` and ``img_data`` by
    :meth:`read_img_data`; :meth:`write_img` is a static GeoTIFF writer.
    """

    def __init__(self, filepath: str):
        """Open ``filepath`` with GDAL in read-only mode."""
        self.filepath = filepath
        self.dataset = gdal.Open(self.filepath, gdal.GA_ReadOnly)
        self.info = []
        self.img_data = None
        self.data_8bit = None

    def read_img_info(self):
        """Collect and return ``[bands, width, height, geotransform, projection]``.

        Raises:
            RuntimeError: if the dataset could not be opened.
        """
        if self.dataset is None:
            # Fail with a clear message rather than an AttributeError on None.
            raise RuntimeError(f"GDAL could not open raster file: {self.filepath}")
        # Band count, width and height.
        img_bands = self.dataset.RasterCount
        img_width = self.dataset.RasterXSize
        img_height = self.dataset.RasterYSize
        # Affine geotransform and projection.
        img_geotrans = self.dataset.GetGeoTransform()
        img_proj = self.dataset.GetProjection()
        self.info = [img_bands, img_width, img_height, img_geotrans, img_proj]
        return self.info

    def read_img_data(self):
        """Read the whole raster as an array, store it in ``img_data`` and return it."""
        if not self.info:
            # The read window needs the width and height; fetch them on demand
            # so the call order no longer matters.
            self.read_img_info()
        self.img_data = self.dataset.ReadAsArray(0, 0, self.info[1], self.info[2])
        return self.img_data

    @staticmethod
    def write_img(filename: str, img_data: np.ndarray, **kwargs):
        """Write ``img_data`` to ``filename`` as a GeoTIFF.

        Args:
            filename: output path.
            img_data: 2-D array (one band) or 3-D array laid out as
                (bands, height, width).
            **kwargs: optional ``img_geotrans`` (geotransform) and
                ``img_proj`` (projection) to store in the file.

        Raises:
            RuntimeError: if GDAL cannot create the output file.
        """
        # Choose the GDAL data type from the array dtype.
        dtype_name = img_data.dtype.name
        if 'int8' in dtype_name:
            datatype = gdal.GDT_Byte
        elif dtype_name == 'int16':
            # Signed 16-bit data must not be stored as unsigned, otherwise
            # negative values are corrupted.
            datatype = gdal.GDT_Int16
        elif 'int16' in dtype_name:
            datatype = gdal.GDT_UInt16
        else:
            datatype = gdal.GDT_Float32
        # Work out the band count and raster size from the array rank.
        if len(img_data.shape) >= 3:
            img_bands, img_height, img_width = img_data.shape
        else:
            img_bands, (img_height, img_width) = 1, img_data.shape
        # Create the output file.
        driver = gdal.GetDriverByName("GTiff")
        outdataset = driver.Create(filename, img_width, img_height, img_bands, datatype)
        if outdataset is None:
            raise RuntimeError(f"GDAL could not create output file: {filename}")
        # Store the geotransform when supplied.
        if 'img_geotrans' in kwargs:
            outdataset.SetGeoTransform(kwargs['img_geotrans'])
        # Store the projection when supplied.
        if 'img_proj' in kwargs:
            outdataset.SetProjection(kwargs['img_proj'])
        # Write the pixel data band by band.
        if img_bands == 1:
            outdataset.GetRasterBand(1).WriteArray(img_data)
        else:
            for i in range(img_bands):
                outdataset.GetRasterBand(i + 1).WriteArray(img_data[i])

        # Drop the last reference to the dataset.
        del outdataset
64 |
65 |
def read_multi_bands(image_path):
    """
    Read a multi-band raster file.
    :param image_path: path of the multi-band file
    :return: image object, image metadata, image array
    """
    # Open the raster.
    reader = ImageProcess(filepath=image_path)
    # Metadata first, so it is printed before the pixel data is read.
    metadata = reader.read_img_info()
    print(f"多波段影像元信息:{metadata}")
    # Pixel array.
    pixels = reader.read_img_data()
    print(f"多波段矩阵大小:{pixels.shape}")
    return reader, metadata, pixels
81 |
82 |
def read_single_band(band_path):
    """
    Read a single-band raster file.
    :param band_path: path of the single-band file
    :return: image object, image metadata, image array
    """
    # Open the raster.
    reader = ImageProcess(filepath=band_path)
    # Metadata first, so it is printed before the pixel data is read.
    metadata = reader.read_img_info()
    print(f"单波段影像元信息:{metadata}")
    # Pixel array.
    pixels = reader.read_img_data()
    print(f"单波段矩阵大小:{pixels.shape}")
    return reader, metadata, pixels
98 |
--------------------------------------------------------------------------------
/module/slice.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 | import numpy as np
4 | from alive_progress import alive_bar
5 | from module.image import *
6 |
def cal_single_band_slice(single_band_data, slice_size=1000):
    """
    Compute the row/column bounds of every tile of a regular grid.
    :param single_band_data: single-band source array
    :param slice_size: tile edge length
    :return: nested list, one [row_min, row_max, col_min, col_max] per tile,
             ordered row by row; tiles on the last row/column are clipped
             to the array size
    """
    single_band_size = single_band_data.shape
    total_rows, total_cols = single_band_size[0], single_band_size[1]
    # Round up so a partial tile at the border is still counted.
    row_num = math.ceil(total_rows / slice_size)
    col_num = math.ceil(total_cols / slice_size)
    print(f"行列数:{single_band_size},行分割数量:{row_num},列分割数量:{col_num}")
    return [
        [
            i * slice_size,
            min((i + 1) * slice_size, total_rows),
            j * slice_size,
            min((j + 1) * slice_size, total_cols),
        ]
        for i in range(row_num)
        for j in range(col_num)
    ]
31 |
32 |
def single_band_slice(single_band_data, index=(0, 1000, 0, 1000), slice_size=1000, edge_fill=False):
    """
    Cut one tile out of a single-band array.
    :param single_band_data: source 2-D array
    :param index: tile bounds [row_min, row_max, col_min, col_max]
    :param slice_size: tile edge length
    :param edge_fill: when true, a tile smaller than slice_size is placed in a
                      zero-filled slice_size x slice_size array
    :return: the tile as a new array with the dtype of the source
    """
    row_min, row_max, col_min, col_max = index[0], index[1], index[2], index[3]
    tile = single_band_data[row_min:row_max, col_min:col_max]
    tile_rows = row_max - row_min
    tile_cols = col_max - col_min
    if edge_fill and (tile_rows != slice_size or tile_cols != slice_size):
        # Zero-initialise the padded tile: np.empty would leave arbitrary
        # memory contents in the cells that are not overwritten below.
        result = np.zeros((slice_size, slice_size), dtype=single_band_data.dtype)
        new_row_min = row_min % slice_size
        new_col_min = col_min % slice_size
        result[new_row_min:new_row_min + tile_rows, new_col_min:new_col_min + tile_cols] = tile
        return result
    # astype copies, so the caller never receives a view of the source array.
    return tile.astype(single_band_data.dtype)
56 |
57 |
def multi_bands_slice(multi_bands_data, index=(0, 1000, 0, 1000), slice_size=1000, edge_fill=False):
    """
    Cut one tile out of a multi-band array laid out as (bands, rows, cols).
    :param multi_bands_data: source 3-D array
    :param index: tile bounds [row_min, row_max, col_min, col_max]
    :param slice_size: tile edge length
    :param edge_fill: when true, a tile smaller than slice_size is placed in a
                      zero-filled (bands, slice_size, slice_size) array
    :return: the tile as a new array with the dtype of the source
    """
    row_min, row_max, col_min, col_max = index[0], index[1], index[2], index[3]
    tile = multi_bands_data[:, row_min:row_max, col_min:col_max]
    tile_rows = row_max - row_min
    tile_cols = col_max - col_min
    if edge_fill and (tile_rows != slice_size or tile_cols != slice_size):
        # Zero-initialise the padded tile: np.empty would leave arbitrary
        # memory contents in the cells that are not overwritten below.
        result = np.zeros((multi_bands_data.shape[0], slice_size, slice_size),
                          dtype=multi_bands_data.dtype)
        new_row_min = row_min % slice_size
        new_col_min = col_min % slice_size
        result[:, new_row_min:new_row_min + tile_rows, new_col_min:new_col_min + tile_cols] = tile
        return result
    # astype copies, so the caller never receives a view of the source array.
    return tile.astype(multi_bands_data.dtype)
81 |
82 |
def slice_conbine(slice_all, slice_index):
    """
    Stitch tiles back into one array.
    :param slice_all: list with every tile array
    :param slice_index: [row_min, row_max, col_min, col_max] of every tile,
                        in the same order as slice_all
    :return: the merged array (zeros where no tile was written)
    """
    # The last tile's upper bounds give the size of the merged array.
    last_bounds = slice_index[-1]
    combine_data = np.zeros(shape=(last_bounds[1], last_bounds[3]))
    for position, bounds in enumerate(slice_index):
        row_span = slice(bounds[0], bounds[1])
        col_span = slice(bounds[2], bounds[3])
        combine_data[row_span, col_span] = slice_all[position]
    return combine_data
95 |
96 |
def coordtransf(Xpixel, Ypixel, GeoTransform):
    """
    Shift the origin of an affine geotransform to a pixel position.
    :param Xpixel: pixel offset multiplied by coefficients 1 and 4
                   (callers in this module pass the column of the tile's upper-left corner)
    :param Ypixel: pixel offset multiplied by coefficients 2 and 5
                   (callers in this module pass the row of the tile's upper-left corner)
    :param GeoTransform: original six-coefficient geotransform
    :return: new geotransform tuple with the origin moved, other coefficients unchanged
    """
    origin_x, x_per_xpix, x_per_ypix = GeoTransform[0], GeoTransform[1], GeoTransform[2]
    origin_y, y_per_xpix, y_per_ypix = GeoTransform[3], GeoTransform[4], GeoTransform[5]
    shifted_x = origin_x + x_per_xpix * Xpixel + Ypixel * x_per_ypix
    shifted_y = origin_y + y_per_xpix * Xpixel + Ypixel * y_per_ypix
    return (shifted_x, x_per_xpix, x_per_ypix, shifted_y, y_per_xpix, y_per_ypix)
109 |
110 |
def multi_bands_grid_slice(image_path, image_slice_dir, slice_size, edge_fill=False):
    """
    Cut a multi-band image into a regular grid of tiles and write each tile to disk.
    :param image_path: source multi-band image
    :param image_slice_dir: folder the tiles are written to
    :param slice_size: tile edge length
    :param edge_fill: pad border tiles to the full tile size
    :return:
    """
    import os  # local import: build output paths with the platform separator

    image, image_info, image_data = read_multi_bands(image_path)
    # Row/column bounds of every tile.
    slice_index = cal_single_band_slice(image_data[0, :, :], slice_size=slice_size)
    # Cut and write the tiles.
    with alive_bar(len(slice_index), force_tty=True) as bar:
        for i, slice_element in enumerate(slice_index):
            slice_data = multi_bands_slice(image_data, index=slice_element, slice_size=slice_size,
                                           edge_fill=edge_fill)
            # Move the geotransform origin to the tile's upper-left corner (column, row).
            slice_geotrans = coordtransf(slice_element[2], slice_element[0], image_info[3])
            # os.path.join replaces the hard-coded backslash so the file also
            # lands inside the folder on non-Windows systems.
            out_path = os.path.join(image_slice_dir, f'multi_grid_slice_{i}.tif')
            image.write_img(out_path, slice_data, img_geotrans=slice_geotrans, img_proj=image_info[4])
            bar()
    print('多波段格网裁剪完成')
132 |
133 |
def single_band_grid_slice(band_path, band_slice_dir, slice_size, edge_fill=False):
    """
    Cut a single-band image into a regular grid of tiles and write each tile to disk.
    :param band_path: source single-band image
    :param band_slice_dir: folder the tiles are written to
    :param slice_size: tile edge length
    :param edge_fill: pad border tiles to the full tile size
    :return:
    """
    import os  # local import: build output paths with the platform separator

    band, band_info, band_data = read_single_band(band_path)
    # Row/column bounds of every tile.
    slice_index = cal_single_band_slice(band_data, slice_size=slice_size)
    # Cut and write the tiles.
    with alive_bar(len(slice_index), force_tty=True) as bar:
        for i, slice_element in enumerate(slice_index):
            slice_data = single_band_slice(band_data, index=slice_element, slice_size=slice_size,
                                           edge_fill=edge_fill)
            # Move the geotransform origin to the tile's upper-left corner (column, row).
            slice_geotrans = coordtransf(slice_element[2], slice_element[0], band_info[3])
            # os.path.join replaces the hard-coded backslash so the file also
            # lands inside the folder on non-Windows systems.
            out_path = os.path.join(band_slice_dir, f'single_grid_slice_{i}.tif')
            band.write_img(out_path, slice_data, img_geotrans=slice_geotrans, img_proj=band_info[4])
            bar()
    print('单波段格网裁剪完成')
155 |
156 |
def multi_bands_rand_slice(image_path, image_slice_dir, slice_size, slice_count):
    """
    Cut randomly placed square tiles out of a multi-band image and write them to disk.
    :param image_path: source multi-band image
    :param image_slice_dir: folder the tiles are written to
    :param slice_size: tile edge length
    :param slice_count: number of tiles
    :return:
    :raises ValueError: if the image is smaller than slice_size
    """
    import os  # local import: build output paths with the platform separator

    image, image_info, image_data = read_multi_bands(image_path)
    # Largest valid start offsets. random.randint includes its upper bound, so
    # the bound is (size - slice_size); subtracting one more would never pick
    # the last window and would reject an image exactly slice_size large.
    max_row_start = image_info[2] - slice_size
    max_col_start = image_info[1] - slice_size
    if max_row_start < 0 or max_col_start < 0:
        raise ValueError(
            f"slice_size {slice_size} exceeds the image size "
            f"({image_info[2]} rows x {image_info[1]} columns)")
    # Random start points: all rows first, then all columns.
    randx = [random.randint(0, max_row_start) for _ in range(slice_count)]
    randy = [random.randint(0, max_col_start) for _ in range(slice_count)]
    rand_index = [[row, row + slice_size, col, col + slice_size] for row, col in zip(randx, randy)]
    # Cut and write the tiles.
    with alive_bar(len(rand_index), force_tty=True) as bar:
        for i, slice_element in enumerate(rand_index):
            slice_data = multi_bands_slice(image_data, index=slice_element, slice_size=slice_size)
            # Move the geotransform origin to the tile's upper-left corner (column, row).
            slice_geotrans = coordtransf(slice_element[2], slice_element[0], image_info[3])
            out_path = os.path.join(image_slice_dir, f'multi_rand_slice_{i}.tif')
            image.write_img(out_path, slice_data, img_geotrans=slice_geotrans, img_proj=image_info[4])
            bar()
    print('多波段随机裁剪完成')
182 |
183 |
def single_band_rand_slice(band_path, band_slice_dir, slice_size, slice_count):
    """
    Cut randomly placed square tiles out of a single-band image and write them to disk.
    :param band_path: source single-band image
    :param band_slice_dir: folder the tiles are written to
    :param slice_size: tile edge length
    :param slice_count: number of tiles
    :return:
    :raises ValueError: if the image is smaller than slice_size
    """
    import os  # local import: build output paths with the platform separator

    band, band_info, band_data = read_single_band(band_path)
    # Largest valid start offsets. random.randint includes its upper bound, so
    # the bound is (size - slice_size); subtracting one more would never pick
    # the last window and would reject an image exactly slice_size large.
    max_row_start = band_info[2] - slice_size
    max_col_start = band_info[1] - slice_size
    if max_row_start < 0 or max_col_start < 0:
        raise ValueError(
            f"slice_size {slice_size} exceeds the image size "
            f"({band_info[2]} rows x {band_info[1]} columns)")
    # Random start points: all rows first, then all columns.
    randx = [random.randint(0, max_row_start) for _ in range(slice_count)]
    randy = [random.randint(0, max_col_start) for _ in range(slice_count)]
    rand_index = [[row, row + slice_size, col, col + slice_size] for row, col in zip(randx, randy)]
    # Cut and write the tiles.
    with alive_bar(len(rand_index), force_tty=True) as bar:
        for i, slice_element in enumerate(rand_index):
            slice_data = single_band_slice(band_data, index=slice_element, slice_size=slice_size)
            # Move the geotransform origin to the tile's upper-left corner (column, row).
            slice_geotrans = coordtransf(slice_element[2], slice_element[0], band_info[3])
            out_path = os.path.join(band_slice_dir, f'single_rand_slice_{i}.tif')
            band.write_img(out_path, slice_data, img_geotrans=slice_geotrans, img_proj=band_info[4])
            bar()
    print('单波段随机裁剪完成')
209 |
210 |
def deeplr_grid_slice(image_path, band_path, image_slice_dir, band_slice_dir, slice_size, edge_fill=False):
    """
    Build deep-learning samples on a regular grid: cut the multi-band image
    and the single-band label with the same tile bounds.
    :param image_path: source image
    :param band_path: source label raster
    :param image_slice_dir: folder the image tiles are written to
    :param band_slice_dir: folder the label tiles are written to
    :param slice_size: tile edge length
    :param edge_fill: pad border tiles to the full tile size
    :return:
    """
    import os  # local import: build output paths with the platform separator

    image, image_info, image_data = read_multi_bands(image_path)
    band, band_info, band_data = read_single_band(band_path)
    # Row/column bounds of every tile, derived from the image.
    slice_index = cal_single_band_slice(image_data[0, :, :], slice_size=slice_size)
    # Cut and write the tiles.
    with alive_bar(len(slice_index), force_tty=True) as bar:
        for i, slice_element in enumerate(slice_index):
            slice_data = multi_bands_slice(image_data, index=slice_element, slice_size=slice_size,
                                           edge_fill=edge_fill)
            # Move the geotransform origin to the tile's upper-left corner (column, row).
            slice_geotrans = coordtransf(slice_element[2], slice_element[0], image_info[3])
            # os.path.join replaces the hard-coded backslash so the files also
            # land inside the folders on non-Windows systems.
            image_out = os.path.join(image_slice_dir, f'multi_grid_slice_{i}.tif')
            image.write_img(image_out, slice_data, img_geotrans=slice_geotrans, img_proj=image_info[4])

            # The label tile reuses the image tile's bounds and geotransform.
            slice_band = single_band_slice(band_data, index=slice_element, slice_size=slice_size,
                                           edge_fill=edge_fill)
            band_out = os.path.join(band_slice_dir, f'single_grid_slice_{i}.tif')
            band.write_img(band_out, slice_band, img_geotrans=slice_geotrans, img_proj=band_info[4])
            bar()
    print('深度学习样本-格网裁剪完成')
240 |
241 |
def deeplr_rand_slice(image_path, band_path, image_slice_dir, band_slice_dir, slice_size, slice_count):
    """
    Build deep-learning samples by random slicing.

    ``slice_count`` random windows of side ``slice_size`` are drawn and the
    same window is cut from both the multi-band image and the single-band
    label raster, so sample ``i`` of both outputs shares the same index
    window and the same geotransform.

    Output files are named ``multi_rand_slice_<i>.tif`` (image) and
    ``single_rand_slice_<i>.tif`` (label).

    :param image_path: path of the source multi-band image
    :param band_path: path of the source single-band label raster
    :param image_slice_dir: directory that receives the image samples
    :param band_slice_dir: directory that receives the label samples
    :param slice_size: side length of every window
    :param slice_count: number of windows to draw
    :return: None
    :raises ValueError: if samples are requested but ``slice_size`` does not
        fit inside the extents reported by the image metadata
    """
    # Function-scope import so the block does not depend on the module header.
    import os

    image, image_info, image_data = read_multi_bands(image_path)
    band, band_info, band_data = read_single_band(band_path)

    # Largest admissible window origin along each axis (randint is inclusive).
    max_x = image_info[2] - slice_size - 1
    max_y = image_info[1] - slice_size - 1
    if slice_count > 0 and (max_x < 0 or max_y < 0):
        # randint would raise an opaque "empty range" ValueError here; raise
        # the same exception type with an actionable message instead.
        raise ValueError(
            'slice_size %s does not fit inside an image of extent (%s, %s)'
            % (slice_size, image_info[2], image_info[1])
        )

    # Draw the random origins: all first-axis values, then all second-axis
    # values, so the sequence consumed from the RNG matches a seeded run.
    randx = [random.randint(0, max_x) for _ in range(slice_count)]
    randy = [random.randint(0, max_y) for _ in range(slice_count)]
    # Each window is [x0, x0 + size, y0, y0 + size].
    rand_index = [[x0, x0 + slice_size, y0, y0 + slice_size] for x0, y0 in zip(randx, randy)]

    # Cut and write the samples.
    with alive_bar(len(rand_index), force_tty=True) as bar:
        for i, slice_element in enumerate(rand_index):
            # Geotransform of this window, derived from the image geotransform;
            # it is shared by the image sample and the label sample.
            slice_geotrans = coordtransf(slice_element[2], slice_element[0], image_info[3])

            # Multi-band image sample.
            slice_data = multi_bands_slice(image_data, index=slice_element, slice_size=slice_size)
            # os.path.join instead of a hard-coded '\\' keeps the output inside
            # the target directory on non-Windows platforms as well.
            image_tile_path = os.path.join(image_slice_dir, 'multi_rand_slice_' + str(i) + '.tif')
            image.write_img(image_tile_path, slice_data,
                            img_geotrans=slice_geotrans, img_proj=image_info[4])

            # Single-band label sample for the same window.
            slice_band = single_band_slice(band_data, index=slice_element, slice_size=slice_size)
            band_tile_path = os.path.join(band_slice_dir, 'single_rand_slice_' + str(i) + '.tif')
            band.write_img(band_tile_path, slice_band,
                           img_geotrans=slice_geotrans, img_proj=band_info[4])
            bar()
    print('深度学习样本-随机裁剪完成')
274 |
--------------------------------------------------------------------------------
/progress/image_slice.py:
--------------------------------------------------------------------------------
1 | from module import slice
2 |
if __name__ == '__main__':
    # Interactive entry point: pick one of six slicing modes, collect its
    # parameters from stdin, then delegate to the matching function in
    # module.slice.
    print("0-多波段格网裁剪、1-单波段格网裁剪、2-多波段随机裁剪、3-单波段随机裁剪、4-制作深度学习样本-格网裁剪、5-制作深度学习样本-随机裁剪")
    try:
        # Parse the menu choice once instead of once per branch.
        slice_type = int(input("请选择:"))
    except ValueError:
        # A non-numeric reply is routed to the "invalid input" branch below
        # rather than aborting with a traceback.
        slice_type = None

    if slice_type == 0:
        # Multi-band grid slicing.
        image_path = input("请输入待裁剪多波段影像路径:")
        image_slice_dir = input("请输入结果存放路径:")
        slice_size = int(input("请输入裁剪块大小:"))
        edge_fill = bool(int(input("是否进行边缘填充(0/1):")))
        slice.multi_bands_grid_slice(image_path, image_slice_dir, slice_size, edge_fill=edge_fill)

    elif slice_type == 1:
        # Single-band grid slicing.
        band_path = input("请输入待裁剪单波段影像路径:")
        band_slice_dir = input("请输入结果存放路径:")
        slice_size = int(input("请输入裁剪块大小:"))
        edge_fill = bool(int(input("是否进行边缘填充(0/1):")))
        slice.single_band_grid_slice(band_path, band_slice_dir, slice_size, edge_fill=edge_fill)

    elif slice_type == 2:
        # Multi-band random slicing.
        image_path = input("请输入待裁剪多波段影像路径:")
        image_slice_dir = input("请输入结果存放路径:")
        slice_size = int(input("请输入裁剪块大小:"))
        slice_count = int(input("请输入裁剪数量:"))
        slice.multi_bands_rand_slice(image_path, image_slice_dir, slice_size, slice_count)

    elif slice_type == 3:
        # Single-band random slicing.
        band_path = input("请输入待裁剪单波段影像路径:")
        band_slice_dir = input("请输入结果存放路径:")
        slice_size = int(input("请输入裁剪块大小:"))
        slice_count = int(input("请输入裁剪数量:"))
        slice.single_band_rand_slice(band_path, band_slice_dir, slice_size, slice_count)

    elif slice_type == 4:
        # Deep-learning samples, grid slicing (image + label).
        image_path = input("请输入待裁剪多波段image路径:")
        band_path = input("请输入待裁剪单波段label路径:")
        image_slice_dir = input("请输入image裁剪结果存放路径:")
        band_slice_dir = input("请输入label裁剪结果存放路径:")
        slice_size = int(input("请输入裁剪块大小:"))
        edge_fill = bool(int(input("是否进行边缘填充(0/1):")))
        slice.deeplr_grid_slice(image_path, band_path, image_slice_dir, band_slice_dir, slice_size, edge_fill=edge_fill)

    elif slice_type == 5:
        # Deep-learning samples, random slicing (image + label).
        image_path = input("请输入待裁剪多波段image路径:")
        band_path = input("请输入待裁剪单波段label路径:")
        image_slice_dir = input("请输入image裁剪结果存放路径:")
        band_slice_dir = input("请输入label裁剪结果存放路径:")
        slice_size = int(input("请输入裁剪块大小:"))
        slice_count = int(input("请输入裁剪数量:"))
        slice.deeplr_rand_slice(image_path, band_path, image_slice_dir, band_slice_dir, slice_size, slice_count)

    else:
        # Unknown or unparsable menu choice.
        print('输入错误')
--------------------------------------------------------------------------------