├── .idea
├── .gitignore
├── Template.iml
├── deployment.xml
├── inspectionProfiles
│ └── profiles_settings.xml
├── misc.xml
├── modules.xml
└── remote-mappings.xml
├── Parameter
├── __init__.py
├── average_meter.py
├── lr_scheduler.py
└── metric.py
├── README.md
├── configs
└── configs.py
├── data
├── __init__.py
├── dataset.py
├── save.py
└── sync_transforms.py
├── main.py
├── model
├── ST_Unet
│ ├── deform_conv.py
│ ├── model_resnet.py
│ ├── vit_seg_configs.py
│ ├── vit_seg_modeling.py
│ └── vit_seg_modeling_resnet_skip.py
├── SwinUnet
│ ├── swin_transformer_unet_skip_expand_decoder_sys.py
│ └── vision_transformer.py
├── Swin_Transformer
│ └── SwinT.py
├── TransUnet
│ ├── vit_seg_configs.py
│ ├── vit_seg_modeling.py
│ └── vit_seg_modeling_resnet_skip.py
├── Unet
│ ├── Unet.py
│ └── _init_.py
└── deeplabv3_version_1
│ ├── aspp.py
│ ├── deeplabv3.py
│ └── resnet.py
├── tool
├── Save_predict.py
├── predict.py
├── train.py
└── val.py
└── utils
├── Data_process.py
├── Loss.py
└── palette.py
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/.idea/Template.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
--------------------------------------------------------------------------------
/.idea/deployment.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/remote-mappings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
--------------------------------------------------------------------------------
/Parameter/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wzysaber/ST_Unet_pytorch_Semantic-segmentation/b27f4d79ba85f81f793e17e686d6a7a1cd8b41ec/Parameter/__init__.py
--------------------------------------------------------------------------------
/Parameter/average_meter.py:
--------------------------------------------------------------------------------
1 | # 对相应的参数进行定义
2 | class AverageMeter(object):
3 | def __init__(self):
4 | self.reset()
5 |
6 | def reset(self):
7 | self.val = 0
8 | self.avg = 0
9 | self.sum = 0
10 | self.count = 0
11 |
12 | def update(self, val, n=1):
13 | self.val = val #当前值
14 | self.sum += val * n
15 | self.count += n
16 | self.avg = self.sum / self.count #平均值
--------------------------------------------------------------------------------
/Parameter/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | from torch.optim.lr_scheduler import _LRScheduler
2 |
3 |
4 | class PolynomialLR(_LRScheduler):
5 | def __init__(self, optimizer, step_size, iter_max, power, last_epoch=-1):
6 | self.step_size = step_size
7 | self.iter_max = iter_max
8 | self.power = power
9 | self.last_epoch = last_epoch
10 | super(PolynomialLR, self).__init__(optimizer, last_epoch)
11 |
12 | def polynomial_decay(self, lr):
13 | return lr * (1 - float(self.last_epoch) / self.iter_max) ** self.power
14 |
15 | def get_lr(self):
16 | if((self.last_epoch == 0) or (self.last_epoch % self.step_size != 0) or (self.last_epoch > self.iter_max)):
17 | return [group['lr'] for group in self.optimizer.param_groups]
18 | return [self.polynomial_decay(lr) for lr in self.base_lrs]
19 |
--------------------------------------------------------------------------------
/Parameter/metric.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def matrix_change(conf_mat, num_classes=5):
5 | Matrix_data = conf_mat[:num_classes, :num_classes]
6 | return Matrix_data
7 |
8 |
9 | def confusion_matrix(pred, label, num_classes):
10 | mask = (label >= 0) & (label < num_classes)
11 | conf_mat = np.bincount(num_classes * label[mask].astype(int) + pred[mask], minlength=num_classes ** 2).reshape(
12 | num_classes, num_classes)
13 | Matrix_data = matrix_change(conf_mat)
14 | return Matrix_data
15 |
16 |
17 | def evaluate(Matrix_data):
18 | matrix = Matrix_data
19 | acc = np.diag(matrix).sum() / matrix.sum()
20 | acc_per_class = np.diag(matrix) / matrix.sum(axis=1)
21 | pre = np.nanmean(acc_per_class)
22 |
23 | recall_class = np.diag(matrix) / matrix.sum(axis=0)
24 | recall = np.nanmean(recall_class)
25 |
26 | F1_score = (2 * pre * recall) / (pre + recall)
27 |
28 | IoU = np.diag(matrix) / (matrix.sum(axis=1) + matrix.sum(axis=0) - np.diag(matrix))
29 | mean_IoU = np.nanmean(IoU)
30 |
31 | # 求kappa
32 | pe = np.dot(np.sum(matrix, axis=0), np.sum(matrix, axis=1)) / (matrix.sum() ** 2)
33 | kappa = (acc - pe) / (1 - pe)
34 | return acc, acc_per_class, pre, IoU, mean_IoU, kappa, F1_score, recall
35 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 【论文阅读】Swin Transformer Embedding UNet用于遥感图像语义分割
2 |
3 | [TOC]
4 |
5 |
6 |
7 | Swin Transformer Embedding UNet for Remote Sensing Image Semantic Segmentation
8 |
9 | 全局上下文信息是遥感图像语义分割的关键
10 |
11 | 具有强大全局建模能力的Swin transformer
12 |
13 | 提出了一种新的RS图像语义分割框架ST-UNet型网络(UNet)
14 |
15 | 解决方案:将Swin transformer嵌入到经典的基于cnn的UNet中
16 |
17 | ST-UNet由Swin变压器和CNN并联构成了一种新型的双编码器结构
18 |
19 | 相应结构:
20 |
21 | - 建立像素级相关性来编码Swin变压器块中的空间信息
22 | - 构造了特征压缩模块(FCM)
23 | - 作为双编码器之间的桥梁,设计了一个关系聚合模块(RAM)
24 |
25 | 数据集的使用:
26 |
27 | - Vaihingen
28 | - Potsdam
29 |
30 |
31 |
32 | ## 一、相应介绍
33 |
34 |
35 |
36 | **具体作用:**
37 |
38 | - 编码器用于提取特征
39 | - 解码器在融合高级语义和低级空间信息的同时,尽可能精细地恢复图像分辨率
40 |
41 | u型网络(UNet)[14]利用解码器通过跳过连接来学习相应编码阶段的空间相关性
42 |
43 |
44 |
45 | 利用变压器的编码器-解码器结构来模拟序列中元素之间的相互作用。
46 |
47 | 本文针对CNN在全局建模方面的不足,提出了一种新的RS图像语义分割网络框架ST-UNet
48 |
49 |
50 |
51 | **相应结构层次:**
52 |
53 | - 以UNet中的编码器为主编码器,Swin变压器为辅助编码器,形成一个并行的双编码器结构
54 | - 设计良好的关系聚合模块(RAM)构建了从辅助编码器到主编码器的单向信息流
55 | - RAM是ST-UNet的关键组件
56 | - 将SIM卡附加到Swin变压器上,以探索全局特征的空间相关性
57 | - 使用FCM提高小尺度目标的分割精度
58 |
59 |
60 |
61 | **相应贡献:**
62 |
63 | - 构建了空间交互模块(SIM),重点关注空间维度上的像素级特征相关性,SIM还弥补了Swin变压器窗口机制所限制的全局建模能力
64 | - 提出了特征压缩模块(FCM),以缓解patch token下采样过程中小尺度特征的遗漏
65 | - 设计了一个随机存储器,从辅助编码器中提取与chanel相关的信息作为全局线索来指导主编码器
66 |
67 |
68 |
69 | ## 二、相关工作
70 |
71 | ### 2.1 基于CNN的遥感图像语义分割
72 |
73 | 存在数据集:
74 |
75 | - IEEE地球科学与遥感学会(IGARSS)数据融合大赛
76 | - SpaceNet比赛
77 | - DeepGlobe比赛
78 |
79 |
80 |
81 | **在检测方面的发展过程**
82 |
83 | (1)在最开始的发展中,多分支并行卷积结构生成多尺度特征图,并设计自适应空间池化模块聚合更多局部上下文
84 |
85 | (2)引入了多层感知器(MLP),以产生更好的分割结果,最早是在自然语言中使用的。
86 |
87 | (3)关注了小尺度特征的特征提取
88 |
89 | (4)结合了基于patch的像素分类和像素到像素分割,引入了不确定映射,以实现对小尺度物体的高性能
90 |
91 | (5) 通过密集融合策略实现小尺度特征的聚合
92 |
93 | (6)明确引入边缘检测模块[43]来监督边界特征学习
94 |
95 | (7)提出了两个简单的边缘损失增强模块来增强物体边界的保存
96 |
97 |
98 |
99 | ### 2.2 Self-Attention机制
100 |
101 | 最早的注意力在计算机视觉领域
102 |
103 | (1)Zhao et al[45]和Li et al[46]分别给出了视频字幕的区域级注意和帧级注意
104 |
105 | (2)SENet[48]通过全局平均池化层表示通道之间的关系,自动了解不同通道的重要性
106 |
107 | (3)CBAM[49]将通道级注意和空间级注意应用于自适应特征细化
108 |
109 | (4)Ding等[19]提出了patch attention module来突出feature map的重点区域
110 |
111 | (5)在GCN[51]框架的每个阶段引入通道注意块,对特征图进行分层优化
112 |
113 | (6)[52] 关注小批量图像中的相似对象,并通过自注意机制对它们之间的交互信息进行编码
114 |
115 |
116 |
117 | ### 2.3 Vision Transformer
118 |
119 | 首次提出用于机器翻译任务[53],超越了以往基于复杂递归或cnn的序列转导模型
120 |
121 | 标准transformer由多头自注意(MSA)、多层感知器(MLP)和层归一化(LN)组成
122 |
123 | 通过分割和扁平化将图像数据转化为一系列tokens
124 |
125 | 密集的预测任务,ViT仍然有巨大的训练成本,只能输出一个不能匹配预测目标(与输入图像分辨率相同)的低分辨率特征
126 |
127 | **在现在过程中的发展:**
128 |
129 | 1. SETR[58]将转换器视为编码器,结合简单的解码器对每一层的全局上下文进行建模,形成语义分割网络
130 | 2. PVT[59]模仿CNN主干的特点,在ViT中引入金字塔结构,获得多比例尺特征图
131 | 3. 基于移位窗口策略的Swin变压器,将MSA的计算限制在不重叠的窗口
132 | 4. 以Swin转换器为骨干,Cao等[31]和Lin等[32]开发了医学图像语义分割的u型编码器-解码器框架
133 | 5. TransUNet[20]和TransFuse[60]指出,纯transformer细分网络的效果并不理想,因为transformer只关注全局建模,缺乏定位能力
134 | 6. 创建了CNN和transformer的混合结构。TransUNet将CNN和transformer依次堆叠
135 |
136 |
137 |
138 | 在本文中采用Swin变压器块组成的辅助编码器为基于cnn的主编码器提供全局上下文信息,提出的ST-UNet首次将Swin变压器应用到RS图像分割任务中,弥补了纯cnn的不足,提高了分割精度
139 |
140 |
141 |
142 | ## 三、方法
143 |
144 | **ST-UNet中的三个重要模块:**
145 |
146 | - **RAM**
147 | - **SIM**
148 | - **FCM**
149 |
150 |
151 |
152 | ### 3.1 网络结构
153 |
154 | ST-UNet的整体架构
155 |
156 | 
157 |
158 |
159 |
160 | **相应组成部分:**
161 |
162 | - ST-UNet是Swin transformer和UNet的混合体,它继承了UNet的优良结构,采用跳跃式连接层连接编码器和解码器
163 | - ST-UNet构造了由基于cnn的残差网络和Swin变压器组成的双编码器
164 | - 通过RAM传输信息,充分获取RS图像的判别特征
165 | - 设计了SIM和FCM,进一步提高了Swin transformer的性能。
166 |
167 |
168 |
169 | **辅助 encoder部分**
170 |
171 | ------
172 |
173 | **输入部分:**
174 |
175 | - RS图像X∈R^H×W×3^
176 | - 数据划分为不重叠的patch,以模拟序列数据的“token”
177 | - 通过卷积从每张图像中获取重叠的patch token
178 | - patch尺寸为8 × 8,重叠率为50%。然后将线性嵌入层压平
179 | - patch投影到C1维
180 | - patch token被放入Swin变压器块堆叠的辅助编码器
181 |
182 |
183 |
184 | 辅助编码器有四个特征提取阶段,每个阶段的输出定义为Sn, n = 1,2,3,4。标准的Swin变压器块包括两种类型,即基于窗口的变压器(W-Trans)和移位的W-Trans (SW-Trans)。
185 |
186 |
187 |
188 | **提出在SIM卡上建立像素级的信息交换,加在Swin transformer块上**
189 |
190 | SIM可以有效地弥补基于窗口的自我注意的局限性,缓解遮挡引起的语义模糊问题
191 |
192 |
193 |
194 | **通过缩短patch令牌长度构建FCM**
195 |
196 | 为了在与主编码器的特征分辨率匹配的同时获得多尺度特征,FCM的提出可以减少小尺度物体特征的遗漏
197 |
198 | 阶段n的输出分辨率为(H/(2^n+1^) × (W/(2^n+1^),维度为(2^n−1^)*C1
199 |
200 |
201 |
202 |
203 |
204 | **主要encode部分**
205 |
206 | ------
207 |
208 | **输入部分:**
209 |
210 | - 原始RS图像X先在通道上压缩一半后馈送到ResNet50
211 | - 第n个残差块的输出特征图可表示为An∈R(H/(2^n+1^))×(W/(2^n+1^))×2^n−1^C2
212 | - 将An和辅助编码器对应级的输出Sn送入RAM,融合结果返回主编码器。
213 | - RAM模块作为主辅编码器之间的桥梁,通过可变形卷积和通道注意机制建立连接。
214 |
215 |
216 |
217 | **解码部分**
218 |
219 | ------
220 |
221 | **具体操作:**
222 |
223 | - 特征F∈R(H/32)×(W/32)×1024,经过卷积层后送入解码器。然后,我们将其输入到2 × 2反卷积层以扩大分辨率
224 | - UNet之后,ST-UNet利用跳过连接层来连接编码器和解码器特性
225 | - 3×3卷积层的减少通道数量
226 | - 每个卷积层都伴随着一个批处理归一层和一个ReLU层
227 | - 最后,对特征F进行3 × 3卷积层和线性插值上采样,得到最终的预测掩码。
228 |
229 |
230 |
231 | ### 3.2 Swin Transformer BlocK
232 |
233 | 为了高效建模,Swin变压器提出了具有两种分区配置的W-MSA来替代普通MSA
234 |
235 | **MSA变化:**
236 |
237 | - 常规窗口配置(W-MSA)
238 | - 移位窗口配置(SW-MSA)
239 |
240 | 每个窗口只覆盖D × D补丁,将D设为8,**将两个Swin变压器块重命名为W-Trans块和SW-Trans块**
241 |
242 | 
243 |
244 | **相应的结构图**
245 |
246 | 
247 |
248 |
249 |
250 | ### 3.3 空间交互模块
251 |
252 | Swin transformer块在有限的窗口内建立patch token关系,有效地减少内存开销
253 |
254 | **具体操作:**
255 |
256 | - 采用了规则窗口和移位窗口的交替执行策略
257 | - 提出了跨W-Trans和SW-Trans区块的SIM,以进一步增强信息交换
258 | - SIM在两个空间维度上引入注意力,考虑像素之间的关系,而不仅仅是patch token
259 | - 在输入阶段将输入数据转化为一维
260 |
261 |
262 |
263 | **SIM结构框图**
264 |
265 | 
266 |
267 | **SIM操作:**
268 |
269 | - 通过一个大的接受场,将特征向量以2的扩张速率
270 | - 在一个3 × 3的扩张卷积层上进行卷积
271 | - 将通道数缩减为c1/2。然后,采用全局平均池化操作
272 |
273 | 在竖直方向和水平方向上的总张量分别记为 h×1×(c1/2)和1×w×(c1/2),因此我们将两者相乘得到与位置相关的注意力图M, 张量h×w×(c1/2),最后,将M与SW-Trans块的输出sl+1相加。
274 |
275 | M的维数需要通过卷积层增加,以匹配特征sl+1的维数(所以进行了1X1卷积来改变通道数)
276 |
277 |
278 |
279 | ### 3.4 特征压缩模块
280 |
281 | 在transformer的前期工作中,通过将图像补丁[27]、[59]平化投影或合并2个×2相邻补丁的特征,并对[30]进行线性处理,形成了一个层次网络。
282 |
283 | 在Swin变压器的patch token下采样中设计了FCM
284 |
285 | FCM避免了大量细节和结构信息的丢失,物体密集、小尺度的RS图像的语义分割,提高了小尺度对象的分割效果。
286 |
287 |
288 |
289 | **FCM结构框图**
290 |
291 | 
292 |
293 | **一种分支是扩大卷积的块:**
294 |
295 | - 扩大卷积的接受场,广泛地收集小尺度物体的特征和结构信息
296 | - 采用前1 × 1卷积层增维
297 | - 中间3 × 3扩张卷积层获取广泛的结构信息
298 | - 后1 × 1卷积层降低特征尺度
299 | - 输出结构(h/2)×(w/2)×2c1
300 |
301 |
302 |
303 | **另一个分支:**
304 |
305 | - 引入了软池[61]操作,以获得更精细的下采样
306 | - 软池可以以指数加权的方式激活池化内核中的像素
307 | - 将软池后的特征输入到卷积层(增维)
308 | - 输出结构(h/2)×(w/2)×2c1
309 |
310 |
311 |
312 | **两个分支按等比例合并为FCM的输出L**
313 |
314 |
315 |
316 | ### 3.5 关系聚合模块
317 |
318 | 基于cnn的主编码器在空间维度上提取了受卷积核限制的局部信息,缺乏对channel维度[48]之间关系的显式建模
319 |
320 | 提出了RAM,为了从整个特征图中强调重要且更具代表性的channel,从辅助编码器的全局特征中提取channel依赖关系,然后将其嵌入到从主编码器获得的局部特征中。
321 |
322 |
323 |
324 | **RAM结构特征图**
325 |
326 | 
327 |
328 | **RAM引入了可变形卷积[63]以适应不同形状的目标区域**
329 |
330 |
331 |
332 | **具体操作:**
333 |
334 | - An和Sn分别表示第n阶段主编码器和辅助编码器的输出
335 | - An输入到可变形卷积中,An = δ(An)。这里δ是一个3 × 3的可变形卷积
336 | - Sn被发送到卷积层以改变维数,由于特征图的每个通道都可以看作是一个特征检测器
337 | - 我们应用average-和max-pool层来计算通道上特征映射的统计特征,
338 | - 发送到共享的全连接层,PA&M结构数为 1×1×(c1/2)
339 | - σ代表ReLu函数,$1被设置为一个大小减半的全连接层
340 | - PA&M与PS相乘来优化每个通道
341 |
342 | 
343 |
344 |
345 |
346 | δ代表sigmoid函数,$2是一个大小增加的完全连接层,并表示元素级乘法。
347 |
348 | 我们将Channel依赖P作为权值与变形卷积运算的结果An相乘,得到了细化的特征。最后,将细化后的特征与残差结构相连接,形成RAM的输出特征Tn
349 |
350 | 
351 |
352 |
353 |
354 | ## 四、实验结果
355 |
356 | ### 4.1 数据集
357 |
358 | Vaihingen Dataset
359 |
360 | 包含33张由先进机载传感器采集的真正射影像(TOP)图像,每个TOP图像都有红外(IR)、红色(R)和绿色(G)通道。
361 |
362 | **相应参数:**
363 |
364 | - 图像被标记为sic类别
365 | - 11张图像用于训练(图像id: 1、3、5、7、13、17、21、23、26、32和37)
366 | - 5张图像用于测试(图像id: 11、15、28、30和34),
367 | - 裁剪为256 × 256
368 |
369 |
370 |
371 | Potsdam Dataset
372 |
373 | 有38个相同大小的patch (6000 × 6000),都是从高分辨率TOP提取
374 |
375 | **相应参数:**
376 |
377 | - 数据集进行了六个类别的标注,用于语义分割研究
378 | - 每张图像都有三种通道组合,即IR-R-G、R-G-B和R-G-B- ir
379 | - 使用14张带有R-G-B的图像进行测试
380 | - (图像id: 2_13, 2_14, 3_13, 3_14, 4_13, 4_14, 4_15, 5_13, 5_14, 5_15, 6_14, 6_15, 7_13)
381 | - 其余24张带有R-G-B的图像进行训练
382 | - 我们将这些原始图像切割为256 × 256
383 |
384 |
385 |
386 | ### 4.2 具体参数
387 |
388 | **实验具体参数:**
389 |
390 | - 动量项为0.9,权重衰减为1e−4
391 | - SGD优化器
392 | - 初始学习率设置为0.01
393 | - 批处理大小设置为8
394 | - 最大epoch为100
395 |
396 |
397 |
398 | 采用联合损失 dice loss [71] LDice与骰子损失cross-entropy loss LCE
399 |
400 | 
401 |
402 | **评价指标:**
403 |
404 | - 平均交叉over联合(MIoU)
405 | - 平均F1 (Ave.F1)
406 |
407 |
408 |
409 | ### 4.3 消融实验
410 |
411 | 为了评估所提出的网络结构和三个重要模块的性能,我们将UNet作为基线网络
412 |
413 | 采用Vaihingen数据集
414 |
415 | **步骤:**
416 |
417 | - 在我们的ST-UNet中,主编码器采用半压缩的ResNet50
418 | - 辅助编码器采用“Tiny”配置的Swin变压器
419 |
420 |
421 |
422 | 主要分为两种:
423 |
424 | Add~LS~,即在编码的最后阶段才合并辅助编码器和主编码器的特征
425 |
426 | Add~ES~,辅助编码器和主编码器在每个编码阶段的特征,通过元素相加。
--------------------------------------------------------------------------------
/configs/configs.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | import os
4 |
5 |
6 | # 函数参数定义
7 | def parse_args():
8 | parser = argparse.ArgumentParser(description="RemoteSensingSegmentation by PyTorch")
9 |
10 | # dataset
11 | parser.add_argument('--dataset-name', type=str, default='Vaihingen')
12 | parser.add_argument('--train-data-root', type=str,
13 | default='/home/students/master/2022/wangzy/PyCharm-Remote/ST_Unet_test/Vaihingen_Img/Train/')
14 | parser.add_argument('--val-data-root', type=str,
15 | default='/home/students/master/2022/wangzy/PyCharm-Remote/ST_Unet_test/Vaihingen_Img/Test/')
16 | parser.add_argument('--train-batch-size', type=int, default=8, metavar='N',
17 | help='batch size for training (default:16)')
18 | parser.add_argument('--val-batch-size', type=int, default=8, metavar='N',
19 | help='batch size for testing (default:16)')
20 |
21 | # output_save_path
22 | # strftime格式化时间,显示当前的时间
23 | parser.add_argument('--experiment-start-time', type=str,
24 | default=time.strftime('%m-%d-%H:%M:%S', time.localtime(time.time())))
25 | parser.add_argument('--save-pseudo-data-path', type=str,
26 | default='/home/students/master/2022/wangzy/PyCharm-Remote/ST_Unet_test/pseudo_data')
27 | parser.add_argument('--save-file', default=False)
28 |
29 | # augmentation
30 | parser.add_argument('--base-size', type=int, default=256, help='base image size')
31 | parser.add_argument('--crop-size', type=int, default=256, help='crop image size')
32 | parser.add_argument('--flip-ratio', type=float, default=0.5)
33 | parser.add_argument('--resize-scale-range', type=str, default='0.5, 2.0')
34 |
35 | # model
36 | parser.add_argument('--model', type=str, default='Swin_Transformer', help='model name')
37 | parser.add_argument('--pretrained', action='store_true', default=True)
38 |
39 | # criterion
40 | # 损失的权重值
41 | parser.add_argument('--class-loss-weight', type=list, default=
42 | # [0.007814952234152803, 0.055862295151291756, 0.029094606950899726, 0.03104357983254851, 0.22757710412943985, 0.19666243636646102, 0.6088052968747066, 0.15683966777104494, 0.5288489922602664, 0.21668940382940433, 0.04310240828376457, 0.18284053575941367, 0.571096349549462, 0.32601488184885147, 0.45384359272537766, 1.0])
43 | # [0.007956167959807792, 0.05664417300631733, 0.029857031694750392, 0.03198534634969046, 0.2309102255169529,
44 | # 0.19627322641039702, 0.6074939752850792, 0.16196525436190998, 0.5396602408824741, 0.22346488456565283,
45 | # 0.04453628275090391, 0.18672995330033487, 0.5990724459491834, 0.33183887346397484, 0.47737597643193597, 1.0]
46 | [0.008728536232175135, 0.05870821984204281, 0.030766985878693004, 0.03295408432939304, 0.2399409412190348,
47 | 0.20305583055639448, 0.6344888568739531, 0.16440413437125656, 0.5372260524694122, 0.22310945250778813,
48 | 0.04659596810284655, 0.19246378709444723, 0.6087430986295436, 0.34431415558778183, 0.4718853977371564, 1.0])
49 |
50 | # loss
51 | parser.add_argument('--loss-names', type=str, default='cross_entropy')
52 | parser.add_argument('--classes-weight', type=str, default=None)
53 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='momentum (default:0.9)')
54 | parser.add_argument('--weight-decay', type=float, default=0.0001, metavar='M', help='weight-decay (default:1e-4)')
55 |
56 | # optimizer
57 | parser.add_argument('--optimizer-name', type=str, default='SGD')
58 |
59 | # learning_rate
60 | parser.add_argument('--base-lr', type=float, default=0.01, metavar='M', help='')
61 |
62 | # environment
63 | parser.add_argument('--use-cuda', action='store_true', default=True, help='using CUDA training')
64 | parser.add_argument('--num-GPUs', type=int, default=1, help='numbers of GPUs')
65 | parser.add_argument('--num_workers', type=int, default=32)
66 |
67 | # validation
68 | parser.add_argument('--eval', action='store_true', default=False, help='evaluation only')
69 | parser.add_argument('--no-val', action='store_true', default=False)
70 |
71 | parser.add_argument('--best-miou', type=float, default=0)
72 |
73 | parser.add_argument('--total-epochs', type=int, default=100, metavar='N',
74 | help='number of epochs to train (default: 120)')
75 | parser.add_argument('--start-epoch', type=int, default=0, metavar='N', help='start epoch (default:0)')
76 |
77 | parser.add_argument('--resume-path', type=str, default=None)
78 |
79 | args = parser.parse_args()
80 |
81 | directory = "weight/%s/%s/%s/" % (args.dataset_name, args.model, args.experiment_start_time)
82 | args.directory = directory
83 |
84 | if args.save_file:
85 | if not os.path.exists(directory):
86 | os.makedirs(directory)
87 | print("Creat and Save model.pth!")
88 |
89 | return args
90 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wzysaber/ST_Unet_pytorch_Semantic-segmentation/b27f4d79ba85f81f793e17e686d6a7a1cd8b41ec/data/__init__.py
--------------------------------------------------------------------------------
/data/dataset.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | import os
3 | from PIL import Image
4 | from torchvision import transforms
5 | import numpy as np
6 | import torch
7 | from utils.Data_process import five_classes
8 |
9 | # 将图像数据转化为numpy型
10 | class MaskToTensor(object): # 将MaskToTensor定义为可以调用的类
11 | def __call__(self, img):
12 | return torch.from_numpy(np.array(img, dtype=np.int32)).long()
13 |
14 |
15 | # 对图像进行归一化的操作
16 | img_transform = transforms.Compose([
17 | transforms.ToTensor(),
18 | transforms.Normalize([.485, .456, .406], [.229, .224, .225])
19 | ])
20 | mask_transform = MaskToTensor()
21 |
22 |
23 | class RSDataset(Dataset):
24 | def __init__(self, root=None, mode=None, img_transform=img_transform, mask_transform=mask_transform,
25 | sync_transforms=None):
26 | # 数据相关
27 | self.class_names = five_classes() # 图像中所包含的种类
28 | self.mode = mode
29 | self.img_transform = img_transform
30 | self.mask_transform = mask_transform
31 | self.sync_transform = sync_transforms
32 | self.sync_img_mask = []
33 |
34 | if mode == "train":
35 | key_word = 'train_data'
36 | elif mode == "val":
37 | key_word = 'val_data'
38 | else:
39 | key_word = 'test_data'
40 |
41 | if mode == "src":
42 | img_dir = os.path.join(root, 'rgb')
43 | mask_dir = os.path.join(root, 'label')
44 | else:
45 | for dirname in os.listdir(root):
46 | # 进入选定的文件夹
47 | if dirname == key_word in dirname:
48 | break
49 |
50 | # 读取其中的图像数据
51 |
52 | img_dir = os.path.join(root, dirname, 'rgb')
53 | mask_dir = os.path.join(root, dirname, 'label')
54 |
55 | # 将相应的图像数据进行保存
56 | for img_filename in os.listdir(img_dir):
57 | img_mask_pair = (os.path.join(img_dir, img_filename),
58 | os.path.join(mask_dir,
59 | img_filename.replace(img_filename[-8:], "label_" + img_filename[-8:])))
60 |
61 | self.sync_img_mask.append(img_mask_pair)
62 |
63 | # print(self.sync_img_mask)
64 | if (len(self.sync_img_mask)) == 0:
65 | print("Found 0 data, please check your dataset!")
66 |
67 | def __getitem__(self, index):
68 | num_class = 6
69 | ignore_label = 5
70 |
71 | img_path, mask_path = self.sync_img_mask[index]
72 | img = Image.open(img_path).convert('RGB')
73 | mask = Image.open(mask_path).convert('L') # 将图像转化为灰度值
74 |
75 | # 将图像进行相应的裁剪,变换等操作
76 | if self.sync_transform is not None:
77 | img, mask = self.sync_transform(img, mask)
78 |
79 | # 将原始图像进行归一化操作
80 | if self.img_transform is not None:
81 | img = self.img_transform(img)
82 |
83 | # 将标签图转化为可以操作的形式
84 | if self.mask_transform is not None:
85 | mask = self.mask_transform(mask)
86 |
87 | mask[mask >= num_class] = ignore_label
88 | mask[mask < 0] = ignore_label
89 |
90 | return img, mask
91 |
92 | def __len__(self):
93 | return len(self.sync_img_mask)
94 |
95 | def classes(self):
96 | return self.class_names
97 |
98 |
99 | if __name__ == "__main__":
100 | pass
101 | # RSDataset(class_name, root=args.train_data_root, mode='train', sync_transforms=None)
102 |
--------------------------------------------------------------------------------
/data/save.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import numpy as np
4 | from configs.configs import parse_args
5 | from PIL import Image
6 | from utils.palette import colorize_mask
7 | from torchvision import transforms
8 |
9 | args = parse_args()
10 |
11 |
12 | # 保存相应的工作参数
13 | def save_work():
14 | directory = "work_dirs/%s/%s/%s/%s/" % (args.dataset_name, args.model, args.backbone, args.experiment_start_time)
15 | args.directory = directory
16 | if not os.path.exists(directory):
17 | os.makedirs(directory)
18 |
19 | config_file = os.path.join(directory, 'config.json')
20 |
21 | # 将相应参数转换为json格式,进行文本保存
22 | with open(config_file, 'w') as file:
23 | json.dump(vars(args), file, indent=4)
24 |
25 | if args.use_cuda:
26 | print('Numbers of GPUs:', args.num_GPUs)
27 | else:
28 | print("Using CPU")
29 |
30 |
31 | # 归一化操作
32 | # zip的作用将元素打包成为元组
33 | class DeNormalize(object):
34 | def __init__(self, mean, std):
35 | self.mean = mean
36 | self.std = std
37 |
38 | def __call__(self, tensor):
39 | for t, m, s in zip(tensor, self.mean, self.std):
40 | t.mul_(s).add_(m)
41 | return tensor
42 |
43 |
44 | resore_transform = transforms.Compose([
45 | DeNormalize([.485, .456, .406], [.229, .224, .225]), # 对相应的数据进行归一化操作
46 | transforms.ToPILImage() # 将图片变化为可以查看的形式
47 | ])
48 |
49 |
50 | def save_pic(score, data, preds, save_path, epoch, index):
51 | val_visual = []
52 | # 将相应的图片进行保存到文件夹
53 | for i in range(score.shape[0]):
54 |
55 | num_score = np.sum(score[i] > 0.9)
56 |
57 | if num_score > 0.9 * (512 * 512):
58 | # 将图片进行归一化操作
59 | # 提取原始图像后进行操作
60 | img_pil = resore_transform(data[0][i])
61 |
62 | # 将图片转化为灰度图片
63 | # 这个是我的预测图像
64 | preds_pil = Image.fromarray(preds[i].astype(np.uint8)).convert('L')
65 |
66 | # 将预测图片转化为RGB
67 | pred_vis_pil = colorize_mask(preds[i])
68 |
69 | # 将图片转化为RGB
70 | gt_vis_pil = colorize_mask(data[1][i].numpy())
71 |
72 | # 将相应的数据包装起来
73 | dir_list = ['rgb', 'label', 'vis_label', 'gt']
74 | rgb_save_path = os.path.join(save_path, dir_list[0], str(epoch))
75 | label_save_path = os.path.join(save_path, dir_list[1], str(epoch))
76 | vis_save_path = os.path.join(save_path, dir_list[2], str(epoch))
77 | gt_save_path = os.path.join(save_path, dir_list[3], str(epoch))
78 |
79 | path_list = [rgb_save_path, label_save_path, vis_save_path, gt_save_path]
80 |
81 | # 创建相应的地址位置
82 | for path in range(4):
83 | if not os.path.exists(path_list[path]):
84 | os.makedirs(path_list[path])
85 |
86 | # 将相应的地址位进行保存
87 | img_pil.save(os.path.join(path_list[0], 'img_batch_%d_%d.jpg' % (index, i)))
88 | preds_pil.save(os.path.join(path_list[1], 'label_%d_%d.png' % (index, i)))
89 | pred_vis_pil.save(os.path.join(path_list[2], 'vis_%d_%d.png' % (index, i)))
90 | gt_vis_pil.save(os.path.join(path_list[3], 'gt_%d_%d.png' % (index, i)))
91 |
--------------------------------------------------------------------------------
/data/sync_transforms.py:
--------------------------------------------------------------------------------
1 | import random
2 | from PIL import Image, ImageOps, ImageFilter
3 | import numpy as np
4 |
5 |
6 | # 将transforms中的img和mask提取出来
7 | class Compose(object):
8 | def __init__(self, transforms):
9 | self.transforms = transforms
10 |
11 | def __call__(self, img, mask):
12 | assert img.size == mask.size
13 | for t in self.transforms:
14 | img, mask = t(img, mask)
15 | return img, mask
16 |
17 |
18 | class RandomScale(object):
19 | def __init__(self, base_size, crop_size, resize_scale_range):
20 | self.base_size = base_size
21 | self.crop_size = crop_size
22 | self.resize_scale_range = resize_scale_range
23 |
24 | def __call__(self, img, mask):
25 | w, h = img.size
26 |
27 | # print("img.size:", img.size)
28 | # randon.randint返回指定范围内的整数
29 |
30 | short_size = random.randint(int(self.base_size * self.resize_scale_range[0]),
31 | int(self.base_size * self.resize_scale_range[1]))
32 | # print("short_size:", short_size)
33 | # if h > w:
34 | # ow = short_size
35 | # oh = int(1.0 * h * ow / w)
36 | # else:
37 | # oh = short_size
38 | # ow = int(1.0 * w * oh / h)
39 | ow, oh = short_size, short_size
40 | # print("ow, oh = ", ow, oh)
41 | img, mask = img.resize((ow, oh), Image.BILINEAR), mask.resize((ow, oh), Image.NEAREST) # 对相应的图片进行缩放
42 |
43 | # 当相应的尺寸不够的时候做扩展
44 | if short_size < self.crop_size:
45 | padh = self.crop_size - oh if oh < self.crop_size else 0
46 | padw = self.crop_size - ow if ow < self.crop_size else 0
47 | img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
48 | mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=0)
49 |
50 | w, h = img.size
51 | img = np.array(img)
52 | mask = np.array(mask)
53 | num_crop = 0
54 | while num_crop < 5:
55 | x = random.randint(0, w - self.crop_size)
56 | y = random.randint(0, h - self.crop_size)
57 | endx = x + self.crop_size
58 | endy = y + self.crop_size
59 | patch = img[y:endy, x:endx]
60 | if (patch == 0).all():
61 | continue
62 | else:
63 | break
64 | img = img[y:endy, x:endx]
65 | mask = mask[y:endy, x:endx]
66 | img, mask = Image.fromarray(img), Image.fromarray(mask)
67 | return img, mask
68 |
69 |
70 | class RandomFlip(object):
71 | def __init__(self, flip_ratio=0.5):
72 | self.flip_ratio = flip_ratio
73 |
74 | def __call__(self, img, mask):
75 | if random.random() < self.flip_ratio:
76 | img, mask = img.transpose(Image.FLIP_LEFT_RIGHT), mask.transpose(Image.FLIP_LEFT_RIGHT)
77 | else:
78 | img, mask = img.transpose(Image.FLIP_TOP_BOTTOM), mask.transpose(Image.FLIP_TOP_BOTTOM)
79 | return img, mask
80 |
81 |
82 | class RandomGaussianBlur(object):
83 | def __init__(self, prop):
84 | self.prop = prop
85 |
86 | def __call__(self, img, mask, prop):
87 | if random.random() < self.prop:
88 | img = img.filter(ImageFilter.GaussianBlur)(radius=random.random())
89 | return img, mask
90 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from configs.configs import parse_args
5 | from model.deeplabv3_version_1.deeplabv3 import DeepLabV3
6 | from model.Unet.Unet import Unet
7 | # from model.ST_Unet.vit_seg_modeling import VisionTransformer
8 | # from model.ST_Unet.vit_seg_configs import get_r50_b16_config
9 | from model.SwinUnet.vision_transformer import SwinUnet
10 | from model.TransUnet.vit_seg_configs import get_r50_b16_config
11 | from model.TransUnet.vit_seg_modeling import VisionTransformer
12 | from model.Swin_Transformer.SwinT import SwinTransformerV2
13 |
14 | from tool.train import close_optimizer
15 | from tool.train import data_set
16 | from tool.train import training
17 | from tool.val import validating
18 |
19 | from utils.Loss import DiceLoss
20 | from utils.Data_process import Print_data
21 | from utils.Data_process import Creat_LineGraph
22 |
23 | # 忽略相应的警告
24 | import warnings
25 |
26 | warnings.filterwarnings("ignore")
27 |
28 | # 清除pytorch无用缓存
29 | import gc
30 |
31 | gc.collect()
32 | torch.cuda.empty_cache()
33 |
34 | # # 设置GPU的序列号
35 | import os
36 |
37 | os.environ["CUDA_VISIBLE_DEVICES"] = "2" # 设置采用的GPU序号
38 |
39 |
40 | def main():
41 | # # 所以,当这个参数设置为True时,启动算法的前期会比较慢,但算法跑起来以后会非常快
42 | torch.backends.cudnn.benchmark = True
43 |
44 | # 导入配置
45 | args = parse_args()
46 |
47 | # 加载训练和验证数据集
48 | train_loader = data_set(args)[0]
49 | train_dataset = data_set(args)[1]
50 |
51 | val_loader = data_set(args)[2]
52 |
53 | # 训练的相关配置
54 | device = torch.device("cuda:2")
55 |
56 | # 加载模型
57 | if args.model == "Unet":
58 | model = Unet(num_classes=6).to(device)
59 | elif args.model == "ST-Unet":
60 | config_vit = get_r50_b16_config()
61 | model = VisionTransformer(config_vit, img_size=256, num_classes=6).to(device)
62 | elif args.model == "deeplabv3+":
63 | model = DeepLabV3(num_classes=6).to(device)
64 | elif args.model == "SwinUnet":
65 | model = SwinUnet(num_classes=6).to(device)
66 | elif args.model == "TransUnet":
67 | config_vit = get_r50_b16_config()
68 | model = VisionTransformer(config_vit, img_size=256, num_classes=6).to(device)
69 | elif args.model == "Swin_Transformer":
70 | model = SwinTransformerV2().to(device)
71 |
72 | # 判断是否有训练好的模型
73 | if args.resume_path:
74 | state_dict = torch.load('.pth')
75 | model.load_state_dict(state_dict, state_dict=False)
76 |
77 | # 损失函数
78 | criterion1 = nn.CrossEntropyLoss().to(device)
79 | criterion2 = DiceLoss(6).to(device)
80 |
81 | # 优化器选择
82 | optimizer = close_optimizer(args, model).to(device)
83 |
84 | # 将相应的参数进行打印
85 | Print_data(args.dataset_name, train_dataset.class_names,
86 | train_dataset, args.optimizer_name, args.model, args.total_epochs)
87 |
88 | # 训练及验证
89 | traincd_Data = []
90 | for epoch in range(args.start_epoch, args.total_epochs):
91 | ACC = training(args, 6, model, optimizer, train_dataset, train_loader, criterion1, criterion2, device,
92 | epoch) # 对模型进行训练zzzz
93 | validating(args, 6, model, optimizer, train_dataset, val_loader, device, epoch) # 对模型进行验证
94 | traincd_Data.append(ACC)
95 | print(" ")
96 | Creat_LineGraph(traincd_Data) # 绘制相应曲线图
97 |
98 |
99 | if __name__ == "__main__":
100 | main()
101 |
--------------------------------------------------------------------------------
/model/ST_Unet/deform_conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | import os
5 |
6 | os.environ["CUDA_VISIBLE_DEVICES"] = "3" # 设置采用的GPU序号
7 |
8 | class DeformConv2d(nn.Module):
9 | def __init__(self, inc, outc, kernel_size=3, padding=1, stride=1, bias=None, modulation=False):
10 | """
11 | Args:
12 | modulation (bool, optional): If True, Modulated Defomable Convolution (Deformable ConvNets v2).
13 | """
14 | super(DeformConv2d, self).__init__()
15 | self.kernel_size = kernel_size
16 | self.padding = padding
17 | self.stride = stride
18 | self.zero_padding = nn.ZeroPad2d(padding)
19 | self.conv = nn.Conv2d(inc, outc, kernel_size=kernel_size, stride=kernel_size, bias=bias)
20 |
21 | self.p_conv = nn.Conv2d(inc, 2*kernel_size*kernel_size, kernel_size=3, padding=1, stride=stride)
22 | nn.init.constant_(self.p_conv.weight, 0)
23 | self.p_conv.register_backward_hook(self._set_lr)
24 |
25 | self.modulation = modulation
26 | if modulation:
27 | self.m_conv = nn.Conv2d(inc, kernel_size*kernel_size, kernel_size=3, padding=1, stride=stride)
28 | nn.init.constant_(self.m_conv.weight, 0)
29 | self.m_conv.register_backward_hook(self._set_lr)
30 |
31 | @staticmethod
32 | def _set_lr(module, grad_input, grad_output):
33 | grad_input = (grad_input[i] * 0.1 for i in range(len(grad_input)))
34 | grad_output = (grad_output[i] * 0.1 for i in range(len(grad_output)))
35 |
36 | def forward(self, x):
37 | offset = self.p_conv(x)
38 | if self.modulation:
39 | m = torch.sigmoid(self.m_conv(x))
40 |
41 | dtype = offset.data.type()
42 | ks = self.kernel_size
43 | N = offset.size(1) // 2
44 |
45 | if self.padding:
46 | x = self.zero_padding(x)
47 |
48 | # (b, 2N, h, w)
49 | p = self._get_p(offset, dtype)
50 |
51 | # (b, h, w, 2N)
52 | p = p.contiguous().permute(0, 2, 3, 1)
53 | q_lt = p.detach().floor()
54 | q_rb = q_lt + 1
55 |
56 | q_lt = torch.cat([torch.clamp(q_lt[..., :N], 0, x.size(2)-1), torch.clamp(q_lt[..., N:], 0, x.size(3)-1)], dim=-1).long()
57 | q_rb = torch.cat([torch.clamp(q_rb[..., :N], 0, x.size(2)-1), torch.clamp(q_rb[..., N:], 0, x.size(3)-1)], dim=-1).long()
58 | q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], dim=-1)
59 | q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], dim=-1)
60 |
61 | # clip p
62 | p = torch.cat([torch.clamp(p[..., :N], 0, x.size(2)-1), torch.clamp(p[..., N:], 0, x.size(3)-1)], dim=-1)
63 |
64 | # bilinear kernel (b, h, w, N)
65 | g_lt = (1 + (q_lt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_lt[..., N:].type_as(p) - p[..., N:]))
66 | g_rb = (1 - (q_rb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_rb[..., N:].type_as(p) - p[..., N:]))
67 | g_lb = (1 + (q_lb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_lb[..., N:].type_as(p) - p[..., N:]))
68 | g_rt = (1 - (q_rt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_rt[..., N:].type_as(p) - p[..., N:]))
69 |
70 | # (b, c, h, w, N)
71 | x_q_lt = self._get_x_q(x, q_lt, N)
72 | x_q_rb = self._get_x_q(x, q_rb, N)
73 | x_q_lb = self._get_x_q(x, q_lb, N)
74 | x_q_rt = self._get_x_q(x, q_rt, N)
75 |
76 | # (b, c, h, w, N)
77 | x_offset = g_lt.unsqueeze(dim=1) * x_q_lt + \
78 | g_rb.unsqueeze(dim=1) * x_q_rb + \
79 | g_lb.unsqueeze(dim=1) * x_q_lb + \
80 | g_rt.unsqueeze(dim=1) * x_q_rt
81 |
82 | # modulation
83 | if self.modulation:
84 | m = m.contiguous().permute(0, 2, 3, 1)
85 | m = m.unsqueeze(dim=1)
86 | m = torch.cat([m for _ in range(x_offset.size(1))], dim=1)
87 | x_offset *= m
88 |
89 | x_offset = self._reshape_x_offset(x_offset, ks)
90 | out = self.conv(x_offset)
91 |
92 | return out
93 |
94 | def _get_p_n(self, N, dtype):
95 | p_n_x, p_n_y = torch.meshgrid(
96 | torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1),
97 | torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1))
98 | # (2N, 1)
99 | p_n = torch.cat([torch.flatten(p_n_x), torch.flatten(p_n_y)], 0)
100 | p_n = p_n.view(1, 2*N, 1, 1).type(dtype)
101 |
102 | return p_n
103 |
104 | def _get_p_0(self, h, w, N, dtype):
105 | p_0_x, p_0_y = torch.meshgrid(
106 | torch.arange(1, h*self.stride+1, self.stride),
107 | torch.arange(1, w*self.stride+1, self.stride))
108 | p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1)
109 | p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1)
110 | p_0 = torch.cat([p_0_x, p_0_y], 1).type(dtype)
111 |
112 | return p_0
113 |
114 | def _get_p(self, offset, dtype):
115 | N, h, w = offset.size(1)//2, offset.size(2), offset.size(3)
116 |
117 | # (1, 2N, 1, 1)
118 | p_n = self._get_p_n(N, dtype)
119 | # (1, 2N, h, w)
120 | p_0 = self._get_p_0(h, w, N, dtype)
121 | p = p_0 + p_n + offset
122 | return p
123 |
124 | def _get_x_q(self, x, q, N):
125 | b, h, w, _ = q.size()
126 | padded_w = x.size(3)
127 | c = x.size(1)
128 | # (b, c, h*w)
129 | x = x.contiguous().view(b, c, -1)
130 |
131 | # (b, h, w, N)
132 | index = q[..., :N]*padded_w + q[..., N:] # offset_x*w + offset_y
133 | # (b, c, h*w*N)
134 | index = index.contiguous().unsqueeze(dim=1).expand(-1, c, -1, -1, -1).contiguous().view(b, c, -1)
135 |
136 | x_offset = x.gather(dim=-1, index=index).contiguous().view(b, c, h, w, N)
137 |
138 | return x_offset
139 |
140 | @staticmethod
141 | def _reshape_x_offset(x_offset, ks):
142 | b, c, h, w, N = x_offset.size()
143 | x_offset = torch.cat([x_offset[..., s:s+ks].contiguous().view(b, c, h, w*ks) for s in range(0, N, ks)], dim=-1)
144 | x_offset = x_offset.contiguous().view(b, c, h*ks, w*ks)
145 |
146 | return x_offset
--------------------------------------------------------------------------------
/model/ST_Unet/model_resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 | from torch.nn import init
6 |
7 |
8 | def conv3x3(in_planes, out_planes, stride=1):
9 | "3x3 convolution with padding"
10 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
11 | padding=1, bias=False)
12 |
13 |
14 | class BasicBlock(nn.Module):
15 | expansion = 1
16 |
17 | def __init__(self, inplanes, planes, stride=1, downsample=None, use_cbam=False):
18 | super(BasicBlock, self).__init__()
19 | self.conv1 = conv3x3(inplanes, planes, stride)
20 | self.bn1 = nn.BatchNorm2d(planes)
21 | self.relu = nn.ReLU(inplace=True)
22 | self.conv2 = conv3x3(planes, planes)
23 | self.bn2 = nn.BatchNorm2d(planes)
24 | self.downsample = downsample
25 | self.stride = stride
26 |
27 | if use_cbam:
28 | self.cbam = CBAM(planes, 16)
29 | else:
30 | self.cbam = None
31 |
32 | def forward(self, x):
33 | residual = x
34 |
35 | out = self.conv1(x)
36 | out = self.bn1(out)
37 | out = self.relu(out)
38 |
39 | out = self.conv2(out)
40 | out = self.bn2(out)
41 |
42 | if self.downsample is not None:
43 | residual = self.downsample(x)
44 |
45 | if not self.cbam is None:
46 | out = self.cbam(out)
47 |
48 | out += residual
49 | out = self.relu(out)
50 |
51 | return out
52 |
53 |
54 | class Bottleneck(nn.Module):
55 | expansion = 4
56 |
57 | def __init__(self, inplanes, planes, stride=1, downsample=None, use_cbam=False):
58 | super(Bottleneck, self).__init__()
59 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
60 | self.bn1 = nn.BatchNorm2d(planes)
61 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
62 | padding=1, bias=False)
63 | self.bn2 = nn.BatchNorm2d(planes)
64 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
65 | self.bn3 = nn.BatchNorm2d(planes * 4)
66 | self.relu = nn.ReLU(inplace=True)
67 | self.downsample = downsample
68 | self.stride = stride
69 |
70 | if use_cbam:
71 | self.cbam = CBAM(planes * 4, 16)
72 | else:
73 | self.cbam = None
74 |
75 | def forward(self, x):
76 | residual = x
77 |
78 | out = self.conv1(x)
79 | out = self.bn1(out)
80 | out = self.relu(out)
81 |
82 | out = self.conv2(out)
83 | out = self.bn2(out)
84 | out = self.relu(out)
85 |
86 | out = self.conv3(out)
87 | out = self.bn3(out)
88 |
89 | if self.downsample is not None:
90 | residual = self.downsample(x)
91 | out += residual
92 | out = self.relu(out)
93 |
94 | return out
95 |
96 |
97 | class ResNet(nn.Module):
98 | def __init__(self, block, layers, att_type=None):
99 | self.inplanes = 64
100 | super(ResNet, self).__init__()
101 |
102 | # different model config between ImageNet and CIFAR
103 |
104 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
105 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
106 | self.avgpool = nn.AvgPool2d(7)
107 |
108 | self.bn1 = nn.BatchNorm2d(64)
109 | self.relu = nn.ReLU(inplace=True)
110 |
111 | self.layer1 = self._make_layer(block, 64, layers[0], att_type=att_type)
112 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, att_type=att_type)
113 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, att_type=att_type)
114 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, att_type=att_type)
115 |
116 | init.kaiming_normal(self.fc.weight)
117 | for key in self.state_dict():
118 | if key.split('.')[-1] == "weight":
119 | if "conv" in key:
120 | init.kaiming_normal(self.state_dict()[key], mode='fan_out')
121 | if "bn" in key:
122 | if "SpatialGate" in key:
123 | self.state_dict()[key][...] = 0
124 | else:
125 | self.state_dict()[key][...] = 1
126 | elif key.split(".")[-1] == 'bias':
127 | self.state_dict()[key][...] = 0
128 |
129 | def _make_layer(self, block, planes, blocks, stride=1, att_type=None):
130 | downsample = None
131 | if stride != 1 or self.inplanes != planes * block.expansion:
132 | downsample = nn.Sequential(
133 | nn.Conv2d(self.inplanes, planes * block.expansion,
134 | kernel_size=1, stride=stride, bias=False),
135 | nn.BatchNorm2d(planes * block.expansion),
136 | )
137 | layers = []
138 | layers.append(block(self.inplanes, planes, stride, downsample, use_cbam=att_type == 'CBAM'))
139 | self.inplanes = planes * block.expansion
140 | for i in range(1, blocks):
141 | layers.append(block(self.inplanes, planes, use_cbam=att_type == 'CBAM'))
142 |
143 | return nn.Sequential(*layers)
144 |
145 | def forward(self, x):
146 | x = self.conv1(x)
147 | x = self.bn1(x)
148 | x = self.relu(x)
149 | x = self.maxpool(x)
150 |
151 | x = self.layer1(x)
152 | x = self.layer2(x)
153 | x = self.layer3(x)
154 | x = self.layer4(x)
155 | x = self.avgpool(x)
156 |
157 | x = x.view(x.size(0), -1)
158 | x = self.fc(x)
159 | return x
160 |
161 |
162 | def ResidualNet(depth, att_type):
163 | assert depth in [18, 34, 50, 101], 'network depth should be 18, 34, 50 or 101'
164 |
165 | if depth == 18:
166 | model = ResNet(BasicBlock, [2, 2, 2, 2], att_type)
167 |
168 | elif depth == 34:
169 | model = ResNet(BasicBlock, [3, 4, 6, 3], att_type)
170 |
171 | elif depth == 50:
172 | model = ResNet(Bottleneck, [3, 4, 6, 3], att_type)
173 |
174 | elif depth == 101:
175 | model = ResNet(Bottleneck, [3, 4, 23, 3], att_type)
176 |
177 | return model
178 |
--------------------------------------------------------------------------------
/model/ST_Unet/vit_seg_configs.py:
--------------------------------------------------------------------------------
1 | import ml_collections
2 |
3 |
4 | def get_b16_config():
5 | """Returns the ViT-B/16 configuration."""
6 | config = ml_collections.ConfigDict()
7 | config.patches = ml_collections.ConfigDict({'size': (16, 16)})
8 | config.hidden_size = 768
9 | config.transformer = ml_collections.ConfigDict()
10 | config.transformer.mlp_dim = 3072
11 | config.transformer.num_heads = 12
12 | config.transformer.num_layers = 1
13 | config.transformer.attention_dropout_rate = 0.0
14 | config.transformer.dropout_rate = 0.1
15 | # config.resnet.att_type = 'CBAM'
16 | config.classifier = 'seg'
17 | config.representation_size = None
18 | config.resnet_pretrained_path = None
19 | # config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_16.npz'
20 | config.patch_size = 16
21 |
22 | config.decoder_channels = (256, 128, 64, 16)
23 | config.n_classes = 2
24 | config.activation = 'softmax'
25 | return config
26 |
27 |
28 | def get_testing():
29 | """Returns a minimal configuration for testing."""
30 | config = ml_collections.ConfigDict()
31 | config.patches = ml_collections.ConfigDict({'size': (16, 16)})
32 | config.hidden_size = 1
33 | config.transformer = ml_collections.ConfigDict()
34 | config.transformer.mlp_dim = 1
35 | config.transformer.num_heads = 1
36 | config.transformer.num_layers = 1
37 | config.transformer.attention_dropout_rate = 0.0
38 | config.transformer.dropout_rate = 0.1
39 | config.classifier = 'token'
40 | config.representation_size = None
41 | return config
42 |
43 |
44 | def get_r50_b16_config():
45 | """Returns the Resnet50 + ViT-B/16 configuration.-------------------------wo yong de """
46 | config = get_b16_config()
47 |
48 | # 构建config.data容器,将不同的类型给放入进去
49 | config.data = ml_collections.ConfigDict()
50 | config.data.img_size = 256 # 6144
51 | config.data.in_chans = 3
52 |
53 | # 放入种类数目和相应的patch,就是256*256的图片划分成为4*4的patch结构,共256/4的数量
54 | config.n_classes = 6
55 | config.patches.grid = (4, 4)
56 |
57 | # 构建config.resnet容器,将不同的类型给放入进去
58 | config.resnet = ml_collections.ConfigDict()
59 | config.resnet.num_layers = (3, 4, 6, 3) # resnet的层数结构
60 | config.resnet.width_factor = 0.5
61 |
62 | config.classifier = 'seg' # 种类名称
63 |
64 | # 构建 config.trans容器,也就是辅助encoder(swin transformer)中的各个必要参数
65 | config.trans = ml_collections.ConfigDict()
66 | config.trans.num_heads = [3, 6, 12, 24] # 注意力的头的数目
67 | config.trans.depths = [2, 2, 6, 2] # swin transformer的网络结构深度
68 | config.trans.embed_dim = 96
69 | config.trans.window_size = 8
70 |
71 | # config.pretrained_path = '../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz' #yuxunlian
72 |
73 | # (256,128,64,16)#
74 | # #1024,512,256,128,64)
75 | # #(2048,1024,512,256,128)
76 | # #(256, 128, 64, 16)
77 | # 解码的通道数
78 | config.decoder_channels = (512, 256, 128, 64)
79 |
80 | # 链接的通道数量
81 | # [256,128,64,16]#[512,256,128,64,16]#[512,256,128,64,32]#[1024,512,256,128,64]#[512, 256, 64, 16]
82 | config.skip_channels = [512, 256, 128, 64]
83 |
84 | config.n_classes = 6 # 分类的个数
85 | config.n_skip = 4 # 链接的次数,或者直接理解成阶段数
86 | config.activation = 'softmax'
87 |
88 | return config
89 |
90 |
91 | def get_b32_config():
92 | """Returns the ViT-B/32 configuration."""
93 | config = get_b16_config()
94 | config.patches.size = (32, 32)
95 | config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_32.npz'
96 | return config
97 |
98 |
99 | def get_l16_config():
100 | """Returns the ViT-L/16 configuration."""
101 | config = ml_collections.ConfigDict()
102 | config.patches = ml_collections.ConfigDict({'size': (16, 16)})
103 | config.hidden_size = 1024
104 | config.transformer = ml_collections.ConfigDict()
105 | config.transformer.mlp_dim = 4096
106 | config.transformer.num_heads = 16
107 | config.transformer.num_layers = 24
108 | config.transformer.attention_dropout_rate = 0.0
109 | config.transformer.dropout_rate = 0.1
110 | config.representation_size = None
111 |
112 | # custom
113 | config.classifier = 'seg'
114 | config.resnet_pretrained_path = None
115 | config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-L_16.npz'
116 | config.decoder_channels = (256, 128, 64, 16)
117 | config.n_classes = 2
118 | config.activation = 'softmax'
119 | return config
120 |
121 |
122 | def get_r50_l16_config():
123 | """Returns the Resnet50 + ViT-L/16 configuration. customized """
124 | config = get_l16_config()
125 | config.patches.grid = (16, 16)
126 | config.resnet = ml_collections.ConfigDict()
127 | config.resnet.num_layers = (3, 4, 9)
128 | config.resnet.width_factor = 1
129 |
130 | config.classifier = 'seg'
131 | config.resnet_pretrained_path = '../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz'
132 | config.decoder_channels = (256, 128, 64, 16)
133 | config.skip_channels = [512, 256, 64, 16]
134 | config.n_classes = 2
135 | config.activation = 'softmax'
136 | return config
137 |
138 |
139 | def get_l32_config():
140 | """Returns the ViT-L/32 configuration."""
141 | config = get_l16_config()
142 | config.patches.size = (32, 32)
143 | return config
144 |
145 |
146 | def get_h14_config():
147 | """Returns the ViT-L/16 configuration."""
148 | config = ml_collections.ConfigDict()
149 | config.patches = ml_collections.ConfigDict({'size': (14, 14)})
150 | config.hidden_size = 1280
151 | config.transformer = ml_collections.ConfigDict()
152 | config.transformer.mlp_dim = 5120
153 | config.transformer.num_heads = 16
154 | config.transformer.num_layers = 32
155 | config.transformer.attention_dropout_rate = 0.0
156 | config.transformer.dropout_rate = 0.1
157 | config.classifier = 'token'
158 | config.representation_size = None
159 |
160 | return config
161 |
--------------------------------------------------------------------------------
/model/ST_Unet/vit_seg_modeling.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
6 | import copy
7 | import logging
8 | import math
9 |
10 | from os.path import join as pjoin
11 |
12 | import torch
13 | import torch.nn as nn
14 | import numpy as np
15 |
16 | from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
17 | from torch.nn.modules.utils import _pair
18 | from scipy import ndimage
19 |
20 | # import vit_seg_configs as configs
21 | # from vit_seg_modeling_resnet_skip import TransResNetV2
22 | # from model_resnet import *
23 |
24 | from model.ST_Unet import vit_seg_configs as configs
25 | from model.ST_Unet.vit_seg_modeling_resnet_skip import TransResNetV2
26 | from model.ST_Unet.model_resnet import *
27 |
28 | # 忽略相应的警告
29 | import warnings
30 |
31 | warnings.filterwarnings("ignore")
32 |
33 | logger = logging.getLogger(__name__)
34 |
35 | ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
36 | ATTENTION_K = "MultiHeadDotProductAttention_1/key"
37 | ATTENTION_V = "MultiHeadDotProductAttention_1/value"
38 | ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"
39 | FC_0 = "MlpBlock_3/Dense_0"
40 | FC_1 = "MlpBlock_3/Dense_1"
41 | ATTENTION_NORM = "LayerNorm_0"
42 | MLP_NORM = "LayerNorm_2"
43 |
44 |
45 | def np2th(weights, conv=False):
46 | """Possibly convert HWIO to OIHW."""
47 | if conv:
48 | weights = weights.transpose([3, 2, 0, 1])
49 | return torch.from_numpy(weights)
50 |
51 |
52 | def swish(x):
53 | return x * torch.sigmoid(x)
54 |
55 |
56 | ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}
57 |
58 |
59 | class Attention(nn.Module):
60 | def __init__(self, config, vis):
61 | super(Attention, self).__init__()
62 | self.vis = vis
63 | self.num_attention_heads = config.transformer["num_heads"] # 12
64 | self.attention_head_size = int(config.hidden_size / self.num_attention_heads) # 768/12
65 | self.all_head_size = self.num_attention_heads * self.attention_head_size
66 |
67 | self.query = Linear(config.hidden_size, self.all_head_size)
68 | self.key = Linear(config.hidden_size, self.all_head_size)
69 | self.value = Linear(config.hidden_size, self.all_head_size)
70 |
71 | self.out = Linear(config.hidden_size, config.hidden_size)
72 | self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
73 | self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])
74 |
75 | self.softmax = Softmax(dim=-1)
76 |
77 | def transpose_for_scores(self, x):
78 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
79 | x = x.view(*new_x_shape)
80 | return x.permute(0, 2, 1, 3)
81 |
82 | def forward(self, hidden_states):
83 | mixed_query_layer = self.query(hidden_states)
84 | mixed_key_layer = self.key(hidden_states)
85 | mixed_value_layer = self.value(hidden_states)
86 |
87 | query_layer = self.transpose_for_scores(mixed_query_layer)
88 | key_layer = self.transpose_for_scores(mixed_key_layer)
89 | value_layer = self.transpose_for_scores(mixed_value_layer)
90 |
91 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
92 | attention_scores = attention_scores / math.sqrt(self.attention_head_size)
93 | attention_probs = self.softmax(attention_scores)
94 | weights = attention_probs if self.vis else None
95 | attention_probs = self.attn_dropout(attention_probs)
96 |
97 | context_layer = torch.matmul(attention_probs, value_layer)
98 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
99 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
100 | context_layer = context_layer.view(*new_context_layer_shape)
101 | attention_output = self.out(context_layer)
102 | attention_output = self.proj_dropout(attention_output)
103 | return attention_output, weights
104 |
105 |
106 | class Mlp(nn.Module):
107 | def __init__(self, config):
108 | super(Mlp, self).__init__()
109 | self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"])
110 | self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size)
111 | self.act_fn = ACT2FN["gelu"]
112 | self.dropout = Dropout(config.transformer["dropout_rate"])
113 |
114 | self._init_weights()
115 |
116 | def _init_weights(self):
117 | nn.init.xavier_uniform_(self.fc1.weight)
118 | nn.init.xavier_uniform_(self.fc2.weight)
119 | nn.init.normal_(self.fc1.bias, std=1e-6)
120 | nn.init.normal_(self.fc2.bias, std=1e-6)
121 |
122 | def forward(self, x):
123 | x = self.fc1(x)
124 | x = self.act_fn(x)
125 | x = self.dropout(x)
126 | x = self.fc2(x)
127 | x = self.dropout(x)
128 | return x
129 |
130 |
131 | class Block(nn.Module):
132 | def __init__(self, config, vis):
133 | super(Block, self).__init__()
134 | self.hidden_size = config.hidden_size
135 | self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
136 | self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
137 | self.ffn = Mlp(config)
138 | self.attn = Attention(config, vis)
139 |
140 | def forward(self, x):
141 | h = x
142 | x = self.attention_norm(x)
143 | x, weights = self.attn(x)
144 | x = x + h
145 |
146 | h = x
147 | x = self.ffn_norm(x)
148 | x = self.ffn(x)
149 | x = x + h
150 | return x, weights
151 |
152 | def load_from(self, weights, n_block):
153 | ROOT = f"Transformer/encoderblock_{n_block}"
154 | with torch.no_grad():
155 | query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size,
156 | self.hidden_size).t()
157 | key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t()
158 | value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size,
159 | self.hidden_size).t()
160 | out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size,
161 | self.hidden_size).t()
162 |
163 | query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1)
164 | key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1)
165 | value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1)
166 | out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1)
167 |
168 | self.attn.query.weight.copy_(query_weight)
169 | self.attn.key.weight.copy_(key_weight)
170 | self.attn.value.weight.copy_(value_weight)
171 | self.attn.out.weight.copy_(out_weight)
172 | self.attn.query.bias.copy_(query_bias)
173 | self.attn.key.bias.copy_(key_bias)
174 | self.attn.value.bias.copy_(value_bias)
175 | self.attn.out.bias.copy_(out_bias)
176 |
177 | mlp_weight_0 = np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t()
178 | mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t()
179 | mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t()
180 | mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t()
181 |
182 | self.ffn.fc1.weight.copy_(mlp_weight_0)
183 | self.ffn.fc2.weight.copy_(mlp_weight_1)
184 | self.ffn.fc1.bias.copy_(mlp_bias_0)
185 | self.ffn.fc2.bias.copy_(mlp_bias_1)
186 |
187 | self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")]))
188 | self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")]))
189 | self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")]))
190 | self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")]))
191 |
192 |
193 | class Encoder(nn.Module):
194 | def __init__(self, config, vis):
195 | super(Encoder, self).__init__()
196 | self.vis = vis
197 | self.layer = nn.ModuleList()
198 | self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)
199 | for _ in range(config.transformer["num_layers"]):
200 | layer = Block(config, vis)
201 | self.layer.append(copy.deepcopy(layer))
202 |
203 | def forward(self, hidden_states):
204 | attn_weights = []
205 | for layer_block in self.layer:
206 | hidden_states, weights = layer_block(hidden_states)
207 | if self.vis:
208 | attn_weights.append(weights)
209 | # torch.Size([12, 256, 768])
210 | encoded = self.encoder_norm(hidden_states)
211 | return encoded, attn_weights
212 |
213 |
214 | class Transformer(nn.Module):
215 | def __init__(self, config, img_size):
216 | super(Transformer, self).__init__()
217 | img_size = _pair(img_size) # 构建(256,256)的数组
218 | # print(img_size)
219 | self.hybrid_model = TransResNetV2(config, block_units=config.resnet.num_layers,
220 | width_factor=config.resnet.width_factor) # 建立混合transformer和cnn的模型
221 |
222 | def forward(self, input):
223 | x, features = self.hybrid_model(input)
224 | return x, features
225 |
226 |
227 | class Conv2dReLU(nn.Sequential):
228 | def __init__(
229 | self,
230 | in_channels,
231 | out_channels,
232 | kernel_size,
233 | padding=0,
234 | stride=1,
235 | use_batchnorm=True,
236 | ):
237 | conv = nn.Conv2d(
238 | in_channels,
239 | out_channels,
240 | kernel_size,
241 | stride=stride,
242 | padding=padding,
243 | bias=not (use_batchnorm),
244 | )
245 | relu = nn.ReLU(inplace=True)
246 |
247 | bn = nn.BatchNorm2d(out_channels)
248 |
249 | super(Conv2dReLU, self).__init__(conv, bn, relu)
250 |
251 |
252 | class DecoderBlock(nn.Module):
253 | def __init__(
254 | self,
255 | in_channels,
256 | out_channels,
257 | skip_channels=0,
258 | use_batchnorm=True,
259 | ):
260 | super().__init__()
261 |
262 | self.conv1 = Conv2dReLU(
263 | in_channels // 2 + skip_channels,
264 | out_channels,
265 | kernel_size=3,
266 | padding=1,
267 | use_batchnorm=use_batchnorm,
268 | )
269 | self.conv2 = Conv2dReLU(
270 | out_channels,
271 | out_channels,
272 | kernel_size=3,
273 | padding=1,
274 | use_batchnorm=use_batchnorm,
275 | )
276 | self.conv3 = Conv2dReLU(
277 | in_channels // 2,
278 | in_channels // 2,
279 | kernel_size=3,
280 | padding=1,
281 | use_batchnorm=use_batchnorm,
282 | )
283 | self.conv4 = Conv2dReLU(
284 | 64,
285 | 64,
286 | kernel_size=3,
287 | padding=1,
288 | use_batchnorm=use_batchnorm,
289 | )
290 | self.up = nn.UpsamplingBilinear2d(scale_factor=2)
291 |
292 | # 逆卷积操作ConvTranspose2d
293 | self.conT = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
294 |
295 | self.conT1 = nn.ConvTranspose2d(64, 16, kernel_size=2, stride=2)
296 |
297 | def forward(self, x, skip=None):
298 | x = self.conT(x)
299 |
300 | if skip is not None:
301 | # skip = self.cbam(skip) #sptial attention
302 |
303 | x = torch.cat([x, skip], dim=1)
304 |
305 | x = self.conv1(x)
306 |
307 | x = self.conv2(x)
308 | else:
309 | x = self.conv3(x)
310 | x = self.conv4(x)
311 | x = self.conT1(x)
312 |
313 | return x
314 |
315 |
316 | class SegmentationHead(nn.Sequential):
317 |
318 | def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):
319 | conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
320 | upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity()
321 | super().__init__(conv2d, upsampling)
322 |
323 |
324 | class DecoderCup(nn.Module):
325 | def __init__(self, config):
326 | super().__init__()
327 | self.config = config
328 | head_channels = 1024
329 | decoder_channels = config.decoder_channels
330 | in_channels = [head_channels] + list(decoder_channels[:-1])
331 | # print('-: ', decoder_channels)
332 | # in_channels=[512,512,256,128,64]
333 | out_channels = decoder_channels
334 | # (512, 256, 128, 64)
335 |
336 | if self.config.n_skip != 0:
337 | # print('self.config.n_skip',self.config.n_skip) 3
338 | skip_channels = self.config.skip_channels
339 | # print('self.config.n_skip', self.config.skip_channels)[512, 256, 64, 16]
340 | for i in range(4 - self.config.n_skip): # re-select the skip channels according to n_skip 4
341 | skip_channels[3 - i] = 0 # skip_channels[3]=0 i=0 3#[512,256,128,64,16])
342 | else:
343 | skip_channels = [0, 0, 0, 0]
344 |
345 | # print(in_channels,out_channels, skip_channels) #[512, 256, 128, 64] (256, 128, 64, 16) [512, 256, 64, 0]
346 | blocks = [
347 | DecoderBlock(in_ch, out_ch, sk_ch) for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels)
348 | ]
349 | self.blocks = nn.ModuleList(blocks)
350 |
351 | self.conv_more = Conv2dReLU(1024, 1024, kernel_size=3, padding=1, use_batchnorm=True)
352 |
353 | def forward(self, x, features=None):
354 | # B, n_patch, hidden = hidden_states.size() # [12, 256, 768] reshape from (B, n_patch, hidden) to (B, h, w, hidden)
355 |
356 | x = self.conv_more(x)
357 | for i, decoder_block in enumerate(self.blocks):
358 |
359 | if i == 3:
360 | continue # 将初始化的32通道给排除
361 |
362 | if features is not None:
363 | skip = features[i] if (i < self.config.n_skip) else None # config.n_skip = 3
364 | # print('ss:', skip.shape)
365 | # print('x:', x.shape)
366 | else:
367 | skip = None
368 | x = decoder_block(x, skip=skip)
369 |
370 | x = decoder_block(x, skip=None)
371 | return x
372 |
373 |
374 | class VisionTransformer(nn.Module):
375 | def __init__(self, config, img_size=256, num_classes=6, zero_head=False):
376 | super(VisionTransformer, self).__init__()
377 | self.num_classes = num_classes # 分类数量
378 | self.zero_head = zero_head # 头的数量
379 |
380 | self.classifier = config.classifier
381 | self.transformer = Transformer(config, img_size) # 构造transformer和cnn的混合函数
382 | self.decoder = DecoderCup(config)
383 | self.segmentation_head = SegmentationHead(
384 | in_channels=16,
385 | out_channels=config['n_classes'],
386 | kernel_size=3,
387 | )
388 | self.config = config
389 |
390 | def forward(self, x):
391 | # print('111', x.shape)
392 | x, features = self.transformer(x) # (B, n_patch, hidden)
393 |
394 | # print(x.shape, features.shape)
395 | x = self.decoder(x, features)
396 | # print(x.shape)
397 | logits = self.segmentation_head(x)
398 | return logits
399 |
400 |
401 | CONFIGS = {
402 | 'ViT-B_16': configs.get_b16_config(),
403 | 'ViT-B_32': configs.get_b32_config(),
404 | 'ViT-L_16': configs.get_l16_config(),
405 | 'ViT-L_32': configs.get_l32_config(),
406 | 'ViT-H_14': configs.get_h14_config(),
407 | 'R50-ViT-B_16': configs.get_r50_b16_config(),
408 | 'R50-ViT-L_16': configs.get_r50_l16_config(),
409 | 'testing': configs.get_testing(),
410 | }
411 |
412 | if __name__ == '__main__':
413 | config_vit = configs.get_r50_b16_config()
414 | model = VisionTransformer(config_vit, img_size=256, num_classes=6)
415 |
416 | image = torch.randn(32, 3, 256, 256)
417 |
418 | output = model(image)
419 | print("input:", image.shape)
420 | print("output:", output.shape)
421 |
--------------------------------------------------------------------------------
/model/SwinUnet/swin_transformer_unet_skip_expand_decoder_sys.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.utils.checkpoint as checkpoint
4 | from einops import rearrange
5 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
6 |
7 |
8 | class Mlp(nn.Module):
9 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
10 | super().__init__()
11 | out_features = out_features or in_features
12 | hidden_features = hidden_features or in_features
13 | self.fc1 = nn.Linear(in_features, hidden_features)
14 | self.act = act_layer()
15 | self.fc2 = nn.Linear(hidden_features, out_features)
16 | self.drop = nn.Dropout(drop)
17 |
18 | def forward(self, x):
19 | x = self.fc1(x)
20 | x = self.act(x)
21 | x = self.drop(x)
22 | x = self.fc2(x)
23 | x = self.drop(x)
24 | return x
25 |
26 |
27 | def window_partition(x, window_size):
28 | """
29 | Args:
30 | x: (B, H, W, C)
31 | window_size (int): window size
32 |
33 | Returns:
34 | windows: (num_windows*B, window_size, window_size, C)
35 | """
36 | B, H, W, C = x.shape
37 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
38 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
39 | return windows
40 |
41 |
42 | def window_reverse(windows, window_size, H, W):
43 | """
44 | Args:
45 | windows: (num_windows*B, window_size, window_size, C)
46 | window_size (int): Window size
47 | H (int): Height of image
48 | W (int): Width of image
49 |
50 | Returns:
51 | x: (B, H, W, C)
52 | """
53 | B = int(windows.shape[0] / (H * W / window_size / window_size))
54 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
55 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
56 | return x
57 |
58 |
59 | class WindowAttention(nn.Module):
60 | r""" Window based multi-head self attention (W-MSA) module with relative position bias.
61 | It supports both of shifted and non-shifted window.
62 |
63 | Args:
64 | dim (int): Number of input channels.
65 | window_size (tuple[int]): The height and width of the window.
66 | num_heads (int): Number of attention heads.
67 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
68 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
69 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
70 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0
71 | """
72 |
73 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
74 |
75 | super().__init__()
76 | self.dim = dim
77 | self.window_size = window_size # Wh, Ww
78 | self.num_heads = num_heads
79 | head_dim = dim // num_heads
80 | self.scale = qk_scale or head_dim ** -0.5
81 |
82 | # define a parameter table of relative position bias
83 | self.relative_position_bias_table = nn.Parameter(
84 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
85 |
86 | # get pair-wise relative position index for each token inside the window
87 | coords_h = torch.arange(self.window_size[0])
88 | coords_w = torch.arange(self.window_size[1])
89 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
90 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
91 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
92 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
93 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
94 | relative_coords[:, :, 1] += self.window_size[1] - 1
95 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
96 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
97 | self.register_buffer("relative_position_index", relative_position_index)
98 |
99 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
100 | self.attn_drop = nn.Dropout(attn_drop)
101 | self.proj = nn.Linear(dim, dim)
102 | self.proj_drop = nn.Dropout(proj_drop)
103 |
104 | trunc_normal_(self.relative_position_bias_table, std=.02)
105 | self.softmax = nn.Softmax(dim=-1)
106 |
107 | def forward(self, x, mask=None):
108 | """
109 | Args:
110 | x: input features with shape of (num_windows*B, N, C)
111 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
112 | """
113 | B_, N, C = x.shape
114 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
115 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
116 |
117 | q = q * self.scale
118 | attn = (q @ k.transpose(-2, -1))
119 |
120 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
121 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
122 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
123 | attn = attn + relative_position_bias.unsqueeze(0)
124 |
125 | if mask is not None:
126 | nW = mask.shape[0]
127 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
128 | attn = attn.view(-1, self.num_heads, N, N)
129 | attn = self.softmax(attn)
130 | else:
131 | attn = self.softmax(attn)
132 |
133 | attn = self.attn_drop(attn)
134 |
135 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
136 | x = self.proj(x)
137 | x = self.proj_drop(x)
138 | return x
139 |
140 | def extra_repr(self) -> str:
141 | return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
142 |
143 | def flops(self, N):
144 | # calculate flops for 1 window with token length of N
145 | flops = 0
146 | # qkv = self.qkv(x)
147 | flops += N * self.dim * 3 * self.dim
148 | # attn = (q @ k.transpose(-2, -1))
149 | flops += self.num_heads * N * (self.dim // self.num_heads) * N
150 | # x = (attn @ v)
151 | flops += self.num_heads * N * N * (self.dim // self.num_heads)
152 | # x = self.proj(x)
153 | flops += N * self.dim * self.dim
154 | return flops
155 |
156 |
157 | class SwinTransformerBlock(nn.Module):
158 | r""" Swin Transformer Block.
159 |
160 | Args:
161 | dim (int): Number of input channels.
162 | input_resolution (tuple[int]): Input resulotion.
163 | num_heads (int): Number of attention heads.
164 | window_size (int): Window size.
165 | shift_size (int): Shift size for SW-MSA.
166 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
167 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
168 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
169 | drop (float, optional): Dropout rate. Default: 0.0
170 | attn_drop (float, optional): Attention dropout rate. Default: 0.0
171 | drop_path (float, optional): Stochastic depth rate. Default: 0.0
172 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
173 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
174 | """
175 |
176 | def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
177 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
178 | act_layer=nn.GELU, norm_layer=nn.LayerNorm):
179 | super().__init__()
180 | self.dim = dim
181 | self.input_resolution = input_resolution
182 | self.num_heads = num_heads
183 | self.window_size = window_size
184 | self.shift_size = shift_size
185 | self.mlp_ratio = mlp_ratio
186 | if min(self.input_resolution) <= self.window_size:
187 | # if window size is larger than input resolution, we don't partition windows
188 | self.shift_size = 0
189 | self.window_size = min(self.input_resolution)
190 | assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
191 |
192 | self.norm1 = norm_layer(dim)
193 | self.attn = WindowAttention(
194 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
195 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
196 |
197 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
198 | self.norm2 = norm_layer(dim)
199 | mlp_hidden_dim = int(dim * mlp_ratio)
200 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
201 |
202 | if self.shift_size > 0:
203 | # calculate attention mask for SW-MSA
204 | H, W = self.input_resolution
205 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
206 | h_slices = (slice(0, -self.window_size),
207 | slice(-self.window_size, -self.shift_size),
208 | slice(-self.shift_size, None))
209 | w_slices = (slice(0, -self.window_size),
210 | slice(-self.window_size, -self.shift_size),
211 | slice(-self.shift_size, None))
212 | cnt = 0
213 | for h in h_slices:
214 | for w in w_slices:
215 | img_mask[:, h, w, :] = cnt
216 | cnt += 1
217 |
218 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
219 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
220 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
221 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
222 | else:
223 | attn_mask = None
224 |
225 | self.register_buffer("attn_mask", attn_mask)
226 |
227 | def forward(self, x):
228 | H, W = self.input_resolution
229 | B, L, C = x.shape
230 | assert L == H * W, "input feature has wrong size"
231 |
232 | shortcut = x
233 | x = self.norm1(x)
234 | x = x.view(B, H, W, C)
235 |
236 | # cyclic shift
237 | if self.shift_size > 0:
238 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
239 | else:
240 | shifted_x = x
241 |
242 | # partition windows
243 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
244 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
245 |
246 | # W-MSA/SW-MSA
247 | attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
248 |
249 | # merge windows
250 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
251 | shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
252 |
253 | # reverse cyclic shift
254 | if self.shift_size > 0:
255 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
256 | else:
257 | x = shifted_x
258 | x = x.view(B, H * W, C)
259 |
260 | # FFN
261 | x = shortcut + self.drop_path(x)
262 | x = x + self.drop_path(self.mlp(self.norm2(x)))
263 |
264 | return x
265 |
266 | def extra_repr(self) -> str:
267 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
268 | f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
269 |
270 | def flops(self):
271 | flops = 0
272 | H, W = self.input_resolution
273 | # norm1
274 | flops += self.dim * H * W
275 | # W-MSA/SW-MSA
276 | nW = H * W / self.window_size / self.window_size
277 | flops += nW * self.attn.flops(self.window_size * self.window_size)
278 | # mlp
279 | flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
280 | # norm2
281 | flops += self.dim * H * W
282 | return flops
283 |
284 |
285 | class PatchMerging(nn.Module):
286 | r""" Patch Merging Layer.
287 |
288 | Args:
289 | input_resolution (tuple[int]): Resolution of input feature.
290 | dim (int): Number of input channels.
291 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
292 | """
293 |
294 | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
295 | super().__init__()
296 | self.input_resolution = input_resolution
297 | self.dim = dim
298 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
299 | self.norm = norm_layer(4 * dim)
300 |
301 | def forward(self, x):
302 | """
303 | x: B, H*W, C
304 | """
305 | H, W = self.input_resolution
306 | B, L, C = x.shape
307 | assert L == H * W, "input feature has wrong size"
308 | assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
309 |
310 | x = x.view(B, H, W, C)
311 |
312 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
313 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
314 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
315 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
316 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
317 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
318 |
319 | x = self.norm(x)
320 | x = self.reduction(x)
321 |
322 | return x
323 |
324 | # repr() 函数,得到的字符串通常可以用来重新获得该对象,将对象转化为供解释器读取的形式。
325 | def extra_repr(self) -> str:
326 | return f"input_resolution={self.input_resolution}, dim={self.dim}"
327 |
328 | def flops(self):
329 | H, W = self.input_resolution
330 | flops = H * W * self.dim
331 | flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
332 | return flops
333 |
334 |
335 | class PatchExpand(nn.Module):
336 | def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm):
337 | super().__init__()
338 | self.input_resolution = input_resolution
339 | self.dim = dim
340 | self.expand = nn.Linear(dim, 2 * dim, bias=False) if dim_scale == 2 else nn.Identity()
341 | self.norm = norm_layer(dim // dim_scale)
342 |
343 | def forward(self, x):
344 | """
345 | x: B, H*W, C
346 | """
347 | H, W = self.input_resolution
348 | x = self.expand(x)
349 | B, L, C = x.shape
350 | assert L == H * W, "input feature has wrong size"
351 |
352 | x = x.view(B, H, W, C)
353 | x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C // 4)
354 | x = x.view(B, -1, C // 4)
355 | x = self.norm(x)
356 |
357 | return x
358 |
359 |
360 | class FinalPatchExpand_X4(nn.Module):
361 | def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm):
362 | super().__init__()
363 | self.input_resolution = input_resolution
364 | self.dim = dim
365 | self.dim_scale = dim_scale
366 | self.expand = nn.Linear(dim, 16 * dim, bias=False)
367 | self.output_dim = dim
368 | self.norm = norm_layer(self.output_dim)
369 |
370 | def forward(self, x):
371 | """
372 | x: B, H*W, C
373 | """
374 | H, W = self.input_resolution
375 | x = self.expand(x)
376 | B, L, C = x.shape
377 | assert L == H * W, "input feature has wrong size"
378 |
379 | x = x.view(B, H, W, C)
380 | x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale,
381 | c=C // (self.dim_scale ** 2))
382 | x = x.view(B, -1, self.output_dim)
383 | x = self.norm(x)
384 |
385 | return x
386 |
387 |
388 | class BasicLayer(nn.Module):
389 | """ A basic Swin Transformer layer for one stage.
390 |
391 | Args:
392 | dim (int): Number of input channels.
393 | input_resolution (tuple[int]): Input resolution.
394 | depth (int): Number of blocks.
395 | num_heads (int): Number of attention heads.
396 | window_size (int): Local window size.
397 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
398 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
399 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
400 | drop (float, optional): Dropout rate. Default: 0.0
401 | attn_drop (float, optional): Attention dropout rate. Default: 0.0
402 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
403 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
404 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
405 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
406 | """
407 |
408 | def __init__(self, dim, input_resolution, depth, num_heads, window_size,
409 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
410 | drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
411 |
412 | super().__init__()
413 | self.dim = dim
414 | self.input_resolution = input_resolution
415 | self.depth = depth
416 | self.use_checkpoint = use_checkpoint
417 |
418 | # build blocks
419 | self.blocks = nn.ModuleList([
420 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
421 | num_heads=num_heads, window_size=window_size,
422 | shift_size=0 if (i % 2 == 0) else window_size // 2,
423 | mlp_ratio=mlp_ratio,
424 | qkv_bias=qkv_bias, qk_scale=qk_scale,
425 | drop=drop, attn_drop=attn_drop,
426 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
427 | norm_layer=norm_layer)
428 | for i in range(depth)])
429 |
430 | # patch merging layer
431 | if downsample is not None:
432 | self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
433 | else:
434 | self.downsample = None
435 |
436 | def forward(self, x):
437 | for blk in self.blocks:
438 | if self.use_checkpoint:
439 | x = checkpoint.checkpoint(blk, x)
440 | else:
441 | x = blk(x)
442 | if self.downsample is not None:
443 | x = self.downsample(x)
444 | return x
445 |
446 | def extra_repr(self) -> str:
447 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
448 |
449 | def flops(self):
450 | flops = 0
451 | for blk in self.blocks:
452 | flops += blk.flops()
453 | if self.downsample is not None:
454 | flops += self.downsample.flops()
455 | return flops
456 |
457 |
458 | class BasicLayer_up(nn.Module):
459 | """ A basic Swin Transformer layer for one stage.
460 |
461 | Args:
462 | dim (int): Number of input channels.
463 | input_resolution (tuple[int]): Input resolution.
464 | depth (int): Number of blocks.
465 | num_heads (int): Number of attention heads.
466 | window_size (int): Local window size.
467 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
468 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
469 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
470 | drop (float, optional): Dropout rate. Default: 0.0
471 | attn_drop (float, optional): Attention dropout rate. Default: 0.0
472 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
473 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
474 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
475 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
476 | """
477 |
478 | def __init__(self, dim, input_resolution, depth, num_heads, window_size,
479 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
480 | drop_path=0., norm_layer=nn.LayerNorm, upsample=None, use_checkpoint=False):
481 |
482 | super().__init__()
483 | self.dim = dim
484 | self.input_resolution = input_resolution
485 | self.depth = depth
486 | self.use_checkpoint = use_checkpoint
487 |
488 | # build blocks
489 | self.blocks = nn.ModuleList([
490 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
491 | num_heads=num_heads, window_size=window_size,
492 | shift_size=0 if (i % 2 == 0) else window_size // 2,
493 | mlp_ratio=mlp_ratio,
494 | qkv_bias=qkv_bias, qk_scale=qk_scale,
495 | drop=drop, attn_drop=attn_drop,
496 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
497 | norm_layer=norm_layer)
498 | for i in range(depth)])
499 |
500 | # patch merging layer
501 | if upsample is not None:
502 | self.upsample = PatchExpand(input_resolution, dim=dim, dim_scale=2, norm_layer=norm_layer)
503 | else:
504 | self.upsample = None
505 |
506 | def forward(self, x):
507 | for blk in self.blocks:
508 | if self.use_checkpoint:
509 | x = checkpoint.checkpoint(blk, x)
510 | else:
511 | x = blk(x)
512 | if self.upsample is not None:
513 | x = self.upsample(x)
514 | return x
515 |
516 |
517 | class PatchEmbed(nn.Module):
518 |
519 | def __init__(self, img_size=256, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
520 | super().__init__()
521 | img_size = to_2tuple(img_size)
522 | patch_size = to_2tuple(patch_size)
523 | patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
524 | self.img_size = img_size
525 | self.patch_size = patch_size
526 | self.patches_resolution = patches_resolution
527 | self.num_patches = patches_resolution[0] * patches_resolution[1]
528 |
529 | self.in_chans = in_chans
530 | self.embed_dim = embed_dim
531 |
532 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
533 | if norm_layer is not None:
534 | self.norm = norm_layer(embed_dim)
535 | else:
536 | self.norm = None
537 |
538 | def forward(self, x):
539 | B, C, H, W = x.shape
540 | # FIXME look at relaxing size constraints
541 | assert H == self.img_size[0] and W == self.img_size[1], \
542 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
543 | x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
544 | if self.norm is not None:
545 | x = self.norm(x)
546 | return x
547 |
548 | # 计算模型的复杂度
549 | def flops(self):
550 | Ho, Wo = self.patches_resolution
551 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
552 | if self.norm is not None:
553 | flops += Ho * Wo * self.embed_dim
554 | return flops
555 |
556 |
557 | class SwinTransformerSys(nn.Module):
558 | def __init__(self, img_size=256, patch_size=4, in_chans=3, num_classes=6,
559 | embed_dim=96, depths=[2, 2, 2, 2], depths_decoder=[1, 2, 2, 2], num_heads=[3, 6, 12, 24],
560 | window_size=8, mlp_ratio=4., qkv_bias=True, qk_scale=None,
561 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
562 | norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
563 | use_checkpoint=False, final_upsample="expand_first", **kwargs):
564 | super().__init__()
565 |
566 | self.num_classes = num_classes
567 | self.num_layers = len(depths)
568 | self.embed_dim = embed_dim
569 | self.ape = ape
570 | self.patch_norm = patch_norm
571 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
572 | self.num_features_up = int(embed_dim * 2)
573 | self.mlp_ratio = mlp_ratio
574 | self.final_upsample = final_upsample
575 |
576 | # split image into non-overlapping patches
577 | self.patch_embed = PatchEmbed(
578 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
579 | norm_layer=norm_layer if self.patch_norm else None)
580 | num_patches = self.patch_embed.num_patches
581 | patches_resolution = self.patch_embed.patches_resolution
582 | self.patches_resolution = patches_resolution
583 |
584 | # absolute position embedding
585 | if self.ape:
586 | self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
587 | trunc_normal_(self.absolute_pos_embed, std=.02)
588 |
589 | self.pos_drop = nn.Dropout(p=drop_rate)
590 |
591 | # stochastic depth
592 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
593 |
594 | # build encoder and bottleneck layers
595 | self.layers = nn.ModuleList()
596 | for i_layer in range(self.num_layers):
597 | layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
598 | input_resolution=(patches_resolution[0] // (2 ** i_layer),
599 | patches_resolution[1] // (2 ** i_layer)),
600 | depth=depths[i_layer],
601 | num_heads=num_heads[i_layer],
602 | window_size=window_size,
603 | mlp_ratio=self.mlp_ratio,
604 | qkv_bias=qkv_bias, qk_scale=qk_scale,
605 | drop=drop_rate, attn_drop=attn_drop_rate,
606 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
607 | norm_layer=norm_layer,
608 | downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
609 | use_checkpoint=use_checkpoint)
610 | self.layers.append(layer)
611 |
612 | # build decoder layers
613 | self.layers_up = nn.ModuleList()
614 | self.concat_back_dim = nn.ModuleList()
615 | for i_layer in range(self.num_layers):
616 | concat_linear = nn.Linear(2 * int(embed_dim * 2 ** (self.num_layers - 1 - i_layer)),
617 | int(embed_dim * 2 ** (
618 | self.num_layers - 1 - i_layer))) if i_layer > 0 else nn.Identity()
619 | if i_layer == 0:
620 | layer_up = PatchExpand(
621 | input_resolution=(patches_resolution[0] // (2 ** (self.num_layers - 1 - i_layer)),
622 | patches_resolution[1] // (2 ** (self.num_layers - 1 - i_layer))),
623 | dim=int(embed_dim * 2 ** (self.num_layers - 1 - i_layer)), dim_scale=2, norm_layer=norm_layer)
624 | else:
625 | layer_up = BasicLayer_up(dim=int(embed_dim * 2 ** (self.num_layers - 1 - i_layer)),
626 | input_resolution=(
627 | patches_resolution[0] // (2 ** (self.num_layers - 1 - i_layer)),
628 | patches_resolution[1] // (2 ** (self.num_layers - 1 - i_layer))),
629 | depth=depths[(self.num_layers - 1 - i_layer)],
630 | num_heads=num_heads[(self.num_layers - 1 - i_layer)],
631 | window_size=window_size,
632 | mlp_ratio=self.mlp_ratio,
633 | qkv_bias=qkv_bias, qk_scale=qk_scale,
634 | drop=drop_rate, attn_drop=attn_drop_rate,
635 | drop_path=dpr[sum(depths[:(self.num_layers - 1 - i_layer)]):sum(
636 | depths[:(self.num_layers - 1 - i_layer) + 1])],
637 | norm_layer=norm_layer,
638 | upsample=PatchExpand if (i_layer < self.num_layers - 1) else None,
639 | use_checkpoint=use_checkpoint)
640 | self.layers_up.append(layer_up)
641 | self.concat_back_dim.append(concat_linear)
642 |
643 | self.norm = norm_layer(self.num_features)
644 | self.norm_up = norm_layer(self.embed_dim)
645 |
646 | if self.final_upsample == "expand_first":
647 | self.up = FinalPatchExpand_X4(input_resolution=(img_size // patch_size, img_size // patch_size),
648 | dim_scale=4, dim=embed_dim)
649 | self.output = nn.Conv2d(in_channels=embed_dim, out_channels=self.num_classes, kernel_size=1, bias=False)
650 |
651 | self.apply(self._init_weights)
652 |
653 | def _init_weights(self, m):
654 | if isinstance(m, nn.Linear):
655 | trunc_normal_(m.weight, std=.02)
656 | if isinstance(m, nn.Linear) and m.bias is not None:
657 | nn.init.constant_(m.bias, 0)
658 | elif isinstance(m, nn.LayerNorm):
659 | nn.init.constant_(m.bias, 0)
660 | nn.init.constant_(m.weight, 1.0)
661 |
662 | @torch.jit.ignore
663 | def no_weight_decay(self):
664 | return {'absolute_pos_embed'}
665 |
666 | @torch.jit.ignore
667 | def no_weight_decay_keywords(self):
668 | return {'relative_position_bias_table'}
669 |
670 | # Encoder and Bottleneck
671 | def forward_features(self, x):
672 | x = self.patch_embed(x)
673 | if self.ape:
674 | x = x + self.absolute_pos_embed
675 | x = self.pos_drop(x)
676 | x_downsample = []
677 |
678 | for layer in self.layers:
679 | x_downsample.append(x)
680 | x = layer(x)
681 |
682 | x = self.norm(x) # B L C
683 |
684 | return x, x_downsample
685 |
686 | # Dencoder and Skip connection
687 | def forward_up_features(self, x, x_downsample):
688 | for inx, layer_up in enumerate(self.layers_up):
689 | if inx == 0:
690 | x = layer_up(x)
691 | else:
692 | x = torch.cat([x, x_downsample[3 - inx]], -1)
693 | x = self.concat_back_dim[inx](x)
694 | x = layer_up(x)
695 |
696 | x = self.norm_up(x) # B L C
697 |
698 | return x
699 |
700 | def up_x4(self, x):
701 | H, W = self.patches_resolution
702 | B, L, C = x.shape
703 | assert L == H * W, "input features has wrong size"
704 |
705 | if self.final_upsample == "expand_first":
706 | x = self.up(x)
707 | x = x.view(B, 4 * H, 4 * W, -1)
708 | x = x.permute(0, 3, 1, 2) # B,C,H,W
709 | x = self.output(x)
710 |
711 | return x
712 |
713 | def forward(self, x):
714 | x, x_downsample = self.forward_features(x)
715 | x = self.forward_up_features(x, x_downsample)
716 | x = self.up_x4(x)
717 |
718 | return x
719 |
720 | def flops(self):
721 | flops = 0
722 | flops += self.patch_embed.flops()
723 | for i, layer in enumerate(self.layers):
724 | flops += layer.flops()
725 | flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
726 | flops += self.num_features * self.num_classes
727 | return flops
728 |
--------------------------------------------------------------------------------
/model/SwinUnet/vision_transformer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | from __future__ import absolute_import # 引入相对引入的概念
3 | from __future__ import division
4 | from __future__ import print_function
5 |
6 | import copy
7 | import logging
8 | import math
9 |
10 | from os.path import join as pjoin
11 |
12 | import torch
13 | import torch.nn as nn
14 | import numpy as np
15 |
16 | from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
17 | from torch.nn.modules.utils import _pair
18 | from scipy import ndimage
19 | from model.SwinUnet.swin_transformer_unet_skip_expand_decoder_sys import SwinTransformerSys
20 |
21 | # 忽略相应的警告
22 | import warnings
23 | warnings.filterwarnings("ignore")
24 |
25 |
26 | logger = logging.getLogger(__name__)
27 |
28 |
29 | class SwinUnet(nn.Module):
30 | def __init__(self, img_size=256, num_classes=6, zero_head=False, vis=False):
31 | super(SwinUnet, self).__init__()
32 | self.num_classes = num_classes
33 | self.zero_head = zero_head
34 |
35 | self.swin_unet = SwinTransformerSys(img_size=256,
36 | patch_size=4,
37 | in_chans=3,
38 | num_classes=self.num_classes,
39 | embed_dim=96,
40 | depths=[2, 2, 2, 2],
41 | num_heads=[3, 6, 12, 24],
42 | window_size=8,
43 | mlp_ratio=0.2,
44 | qkv_bias=True,
45 | qk_scale=0.,
46 | drop_rate=0.,
47 | drop_path_rate=0.1,
48 | ape=False,
49 | # patch_norm=[0, 0, 0, 0],
50 | use_checkpoint=False)
51 |
52 | def forward(self, x):
53 | if x.size()[1] == 1:
54 | x = x.repeat(1, 3, 1, 1)
55 | logits = self.swin_unet(x)
56 | return logits
57 |
58 | # # 加载相应的预训练模型
59 | # def load_from(self, config):
60 | # pretrained_path = config.MODEL.PRETRAIN_CKPT
61 | # if pretrained_path is not None:
62 | # print("pretrained_path:{}".format(pretrained_path))
63 | # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
64 | # pretrained_dict = torch.load(pretrained_path, map_location=device)
65 | # if "model" not in pretrained_dict:
66 | # print("---start load pretrained modle by splitting---")
67 | # pretrained_dict = {k[17:]: v for k, v in pretrained_dict.items()}
68 | # for k in list(pretrained_dict.keys()):
69 | # if "output" in k:
70 | # print("delete key:{}".format(k))
71 | # del pretrained_dict[k]
72 | # msg = self.swin_unet.load_state_dict(pretrained_dict, strict=False)
73 | # # print(msg)
74 | # return
75 | # pretrained_dict = pretrained_dict['model']
76 | # print("---start load pretrained modle of swin encoder---")
77 | #
78 | # model_dict = self.swin_unet.state_dict()
79 | # full_dict = copy.deepcopy(pretrained_dict)
80 | # for k, v in pretrained_dict.items():
81 | # if "layers." in k:
82 | # current_layer_num = 3 - int(k[7:8])
83 | # current_k = "layers_up." + str(current_layer_num) + k[8:]
84 | # full_dict.update({current_k: v})
85 | # for k in list(full_dict.keys()):
86 | # if k in model_dict:
87 | # if full_dict[k].shape != model_dict[k].shape:
88 | # print("delete:{};shape pretrain:{};shape model:{}".format(k, v.shape, model_dict[k].shape))
89 | # del full_dict[k]
90 | #
91 | # msg = self.swin_unet.load_state_dict(full_dict, strict=False)
92 | # # print(msg)
93 | # else:
94 | # print("none pretrain")
95 |
96 |
97 | if __name__ == '__main__':
98 | model = SwinUnet()
99 | model.eval()
100 | image = torch.randn(32, 3, 256, 256)
101 |
102 | output = model(image)
103 | print("input:", image.shape)
104 | print("output:", output.shape)
105 |
--------------------------------------------------------------------------------
/model/Swin_Transformer/SwinT.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.utils.checkpoint as checkpoint
5 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
6 | import numpy as np
7 | from einops import rearrange
8 |
9 | import warnings
10 |
11 | warnings.filterwarnings("ignore")
12 |
13 |
14 | class Mlp(nn.Module):
15 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
16 | super().__init__()
17 | out_features = out_features or in_features
18 | hidden_features = hidden_features or in_features
19 | self.fc1 = nn.Linear(in_features, hidden_features)
20 | self.act = act_layer()
21 | self.fc2 = nn.Linear(hidden_features, out_features)
22 | self.drop = nn.Dropout(drop)
23 |
24 | def forward(self, x):
25 | x = self.fc1(x)
26 | x = self.act(x)
27 | x = self.drop(x)
28 | x = self.fc2(x)
29 | x = self.drop(x)
30 | return x
31 |
32 |
33 | def window_partition(x, window_size):
34 | """
35 | Args:
36 | x: (B, H, W, C)
37 | window_size (int): window size
38 |
39 | Returns:
40 | windows: (num_windows*B, window_size, window_size, C)
41 | """
42 | B, H, W, C = x.shape
43 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
44 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
45 | return windows
46 |
47 |
48 | def window_reverse(windows, window_size, H, W):
49 | """
50 | Args:
51 | windows: (num_windows*B, window_size, window_size, C)
52 | window_size (int): Window size
53 | H (int): Height of image
54 | W (int): Width of image
55 |
56 | Returns:
57 | x: (B, H, W, C)
58 | """
59 | B = int(windows.shape[0] / (H * W / window_size / window_size))
60 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
61 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
62 | return x
63 |
64 |
65 | class WindowAttention(nn.Module):
66 | r""" Window based multi-head self attention (W-MSA) module with relative position bias.
67 | It supports both of shifted and non-shifted window.
68 |
69 | Args:
70 | dim (int): Number of input channels.
71 | window_size (tuple[int]): The height and width of the window.
72 | num_heads (int): Number of attention heads.
73 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
74 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
75 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0
76 | pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
77 | """
78 |
79 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
80 | pretrained_window_size=[0, 0]):
81 |
82 | super().__init__()
83 | self.dim = dim
84 | self.window_size = window_size # Wh, Ww
85 | self.pretrained_window_size = pretrained_window_size
86 | self.num_heads = num_heads
87 |
88 | # 先生成tabel的值
89 | self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
90 |
91 | # mlp to generate continuous relative position bias
92 | self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
93 | nn.ReLU(inplace=True),
94 | nn.Linear(512, num_heads, bias=False))
95 |
96 | # 在进行相应的相对tabel位置的索引
97 | relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
98 | relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
99 | relative_coords_table = torch.stack(
100 | torch.meshgrid([relative_coords_h,
101 | relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
102 | if pretrained_window_size[0] > 0:
103 | relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
104 | relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
105 | else:
106 | relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
107 | relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
108 | relative_coords_table *= 8 # normalize to -8, 8
109 | relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
110 | torch.abs(relative_coords_table) + 1.0) / np.log2(8)
111 |
112 | self.register_buffer("relative_coords_table", relative_coords_table)
113 |
114 | # get pair-wise relative position index for each token inside the window
115 | coords_h = torch.arange(self.window_size[0])
116 | coords_w = torch.arange(self.window_size[1])
117 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
118 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
119 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
120 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
121 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
122 | relative_coords[:, :, 1] += self.window_size[1] - 1
123 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
124 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
125 | self.register_buffer("relative_position_index", relative_position_index)
126 |
127 | self.qkv = nn.Linear(dim, dim * 3, bias=False)
128 | if qkv_bias:
129 | self.q_bias = nn.Parameter(torch.zeros(dim))
130 | self.v_bias = nn.Parameter(torch.zeros(dim))
131 | else:
132 | self.q_bias = None
133 | self.v_bias = None
134 | self.attn_drop = nn.Dropout(attn_drop)
135 | self.proj = nn.Linear(dim, dim)
136 | self.proj_drop = nn.Dropout(proj_drop)
137 | self.softmax = nn.Softmax(dim=-1)
138 |
139 | def forward(self, x, mask=None):
140 | """
141 | Args:
142 | x: input features with shape of (num_windows*B, N, C)
143 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
144 | """
145 | B_, N, C = x.shape
146 | qkv_bias = None
147 | if self.q_bias is not None:
148 | qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
149 | qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
150 | qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
151 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
152 |
153 | # cosine attention
154 | attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
155 | logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01))).exp()
156 | attn = attn * logit_scale
157 |
158 | # 根据索引去取相应的值
159 | relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
160 | relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
161 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
162 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
163 | relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
164 | attn = attn + relative_position_bias.unsqueeze(0)
165 |
166 | # mask不为空,则进行的是SW-MSA操作,将mask加入
167 | if mask is not None:
168 | nW = mask.shape[0]
169 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
170 | attn = attn.view(-1, self.num_heads, N, N)
171 | attn = self.softmax(attn)
172 | else:
173 | attn = self.softmax(attn)
174 |
175 | attn = self.attn_drop(attn)
176 |
177 | # 这边是单纯的注意力机制,将相应的V
178 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
179 | x = self.proj(x)
180 | x = self.proj_drop(x)
181 | return x
182 |
183 | def extra_repr(self) -> str:
184 | return f'dim={self.dim}, window_size={self.window_size}, ' \
185 | f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}'
186 |
187 | def flops(self, N):
188 | # calculate flops for 1 window with token length of N
189 | flops = 0
190 | # qkv = self.qkv(x)
191 | flops += N * self.dim * 3 * self.dim
192 | # attn = (q @ k.transpose(-2, -1))
193 | flops += self.num_heads * N * (self.dim // self.num_heads) * N
194 | # x = (attn @ v)
195 | flops += self.num_heads * N * N * (self.dim // self.num_heads)
196 | # x = self.proj(x)
197 | flops += N * self.dim * self.dim
198 | return flops
199 |
200 |
201 | # 代码块的组成
202 | class SwinTransformerBlock(nn.Module):
203 | r""" Swin Transformer Block.
204 |
205 | Args:
206 | dim (int): Number of input channels.
207 | input_resolution (tuple[int]): Input resulotion.
208 | num_heads (int): Number of attention heads.
209 | window_size (int): Window size.
210 | shift_size (int): Shift size for SW-MSA.
211 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
212 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
213 | drop (float, optional): Dropout rate. Default: 0.0
214 | attn_drop (float, optional): Attention dropout rate. Default: 0.0
215 | drop_path (float, optional): Stochastic depth rate. Default: 0.0
216 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
217 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
218 | pretrained_window_size (int): Window size in pre-training.
219 | """
220 |
221 | def __init__(self,
222 | dim,
223 | input_resolution,
224 | num_heads,
225 | window_size=7,
226 | shift_size=0,
227 | mlp_ratio=4.,
228 | qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
229 | act_layer=nn.GELU,
230 | norm_layer=nn.LayerNorm,
231 | pretrained_window_size=0):
232 | super().__init__()
233 | self.dim = dim
234 | self.input_resolution = input_resolution
235 | self.num_heads = num_heads
236 | self.window_size = window_size
237 | self.shift_size = shift_size
238 | self.mlp_ratio = mlp_ratio
239 | if min(self.input_resolution) <= self.window_size:
240 | # if window size is larger than input resolution, we don't partition windows
241 | self.shift_size = 0
242 | self.window_size = min(self.input_resolution)
243 | assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
244 |
245 | self.norm1 = norm_layer(dim)
246 |
247 | # 窗口注意力函数的实现
248 | self.attn = WindowAttention(
249 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
250 | qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
251 | pretrained_window_size=to_2tuple(pretrained_window_size))
252 |
253 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
254 | self.norm2 = norm_layer(dim)
255 | mlp_hidden_dim = int(dim * mlp_ratio)
256 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
257 |
258 | # 生成mask矩阵在需要的时候传进去
259 | if self.shift_size > 0:
260 | # calculate attention mask for SW-MSA
261 | H, W = self.input_resolution
262 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
263 | h_slices = (slice(0, -self.window_size),
264 | slice(-self.window_size, -self.shift_size),
265 | slice(-self.shift_size, None))
266 | w_slices = (slice(0, -self.window_size),
267 | slice(-self.window_size, -self.shift_size),
268 | slice(-self.shift_size, None))
269 | cnt = 0
270 | for h in h_slices:
271 | for w in w_slices:
272 | img_mask[:, h, w, :] = cnt
273 | cnt += 1
274 |
275 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
276 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
277 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
278 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
279 | else:
280 | attn_mask = None
281 |
282 | self.register_buffer("attn_mask", attn_mask)
283 |
284 | def forward(self, x):
285 | H, W = self.input_resolution
286 | B, L, C = x.shape
287 | assert L == H * W, "input feature has wrong size"
288 |
289 | shortcut = x
290 | x = x.view(B, H, W, C)
291 |
292 | # 窗口移动,shift_size大于0时成立,进行相应操作
293 | if self.shift_size > 0:
294 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
295 | else:
296 | shifted_x = x
297 |
298 | # partition windows,将一个patch(56*56)化为windows
299 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
300 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
301 |
302 | # W-MSA/SW-MSA进行选择
303 | attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
304 |
305 | # merge windows
306 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
307 | shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
308 |
309 | # reverse cyclic shift
310 | if self.shift_size > 0:
311 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
312 | else:
313 | x = shifted_x
314 | x = x.view(B, H * W, C)
315 | x = shortcut + self.drop_path(self.norm1(x))
316 |
317 | # FFN
318 | x = x + self.drop_path(self.norm2(self.mlp(x)))
319 |
320 | return x
321 |
322 | def extra_repr(self) -> str:
323 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
324 | f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
325 |
326 | def flops(self):
327 | flops = 0
328 | H, W = self.input_resolution
329 | # norm1
330 | flops += self.dim * H * W
331 | # W-MSA/SW-MSA
332 | nW = H * W / self.window_size / self.window_size
333 | flops += nW * self.attn.flops(self.window_size * self.window_size)
334 | # mlp
335 | flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
336 | # norm2
337 | flops += self.dim * H * W
338 | return flops
339 |
340 |
341 | # 扩大全局视野的作用,每隔2,提取相应的元素,将4组合并,C变为4C,再进行卷积操作变为2C
342 | class PatchMerging(nn.Module):
343 | r""" Patch Merging Layer.
344 |
345 | Args:
346 | input_resolution (tuple[int]): Resolution of input feature.
347 | dim (int): Number of input channels.
348 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
349 | """
350 |
351 | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
352 | super().__init__()
353 | self.input_resolution = input_resolution
354 | self.dim = dim
355 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
356 | self.norm = norm_layer(2 * dim)
357 |
358 | def forward(self, x):
359 | """
360 | x: B, H*W, C
361 | """
362 | H, W = self.input_resolution
363 | B, L, C = x.shape
364 | assert L == H * W, "input feature has wrong size"
365 | assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
366 |
367 | x = x.view(B, H, W, C)
368 |
369 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
370 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
371 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
372 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
373 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
374 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
375 |
376 | # 进行相应的映射,将通道数减半
377 | x = self.reduction(x)
378 | x = self.norm(x)
379 |
380 | return x
381 |
382 | def extra_repr(self) -> str:
383 | return f"input_resolution={self.input_resolution}, dim={self.dim}"
384 |
385 | def flops(self):
386 | H, W = self.input_resolution
387 | flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
388 | flops += H * W * self.dim // 2
389 | return flops
390 |
391 |
392 | # 设置transformer的代码块
393 | class BasicLayer(nn.Module):
394 | """ A basic Swin Transformer layer for one stage.
395 |
396 | Args:
397 | dim (int): Number of input channels.
398 | input_resolution (tuple[int]): Input resolution.
399 | depth (int): Number of blocks.
400 | num_heads (int): Number of attention heads.
401 | window_size (int): Local window size.
402 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
403 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
404 | drop (float, optional): Dropout rate. Default: 0.0
405 | attn_drop (float, optional): Attention dropout rate. Default: 0.0
406 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
407 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
408 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
409 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
410 | pretrained_window_size (int): Local window size in pre-training.
411 | """
412 |
413 | def __init__(self,
414 | dim,
415 | input_resolution,
416 | depth, num_heads,
417 | window_size,
418 | mlp_ratio=4.,
419 | qkv_bias=True,
420 | drop=0., attn_drop=0., drop_path=0.,
421 | norm_layer=nn.LayerNorm,
422 | downsample=None, use_checkpoint=False,
423 | pretrained_window_size=0):
424 |
425 | super().__init__()
426 | self.dim = dim
427 | self.input_resolution = input_resolution
428 | self.depth = depth
429 | self.use_checkpoint = use_checkpoint
430 |
431 | # 建立SwinTransformerBlock,也就是将W-MSA和SW-MSA进行相应的编码块
432 | self.blocks = nn.ModuleList([
433 | SwinTransformerBlock(dim=dim,
434 | input_resolution=input_resolution,
435 | num_heads=num_heads,
436 | window_size=window_size,
437 | shift_size=0 if (i % 2 == 0) else window_size // 2,
438 | mlp_ratio=mlp_ratio,
439 | qkv_bias=qkv_bias,
440 | drop=drop, attn_drop=attn_drop,
441 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
442 | norm_layer=norm_layer,
443 | pretrained_window_size=pretrained_window_size)
444 | for i in range(depth)])
445 |
446 | # patch merging layer
447 | if downsample is not None:
448 | self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
449 | else:
450 | self.downsample = None
451 |
452 | def forward(self, x):
453 | for blk in self.blocks:
454 | if self.use_checkpoint:
455 | x = checkpoint.checkpoint(blk, x)
456 | else:
457 | x = blk(x)
458 | if self.downsample is not None:
459 | x = self.downsample(x)
460 | return x
461 |
462 | def extra_repr(self) -> str:
463 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
464 |
465 | def flops(self):
466 | flops = 0
467 | for blk in self.blocks:
468 | flops += blk.flops()
469 | if self.downsample is not None:
470 | flops += self.downsample.flops()
471 | return flops
472 |
473 | def _init_respostnorm(self):
474 | for blk in self.blocks:
475 | nn.init.constant_(blk.norm1.bias, 0)
476 | nn.init.constant_(blk.norm1.weight, 0)
477 | nn.init.constant_(blk.norm2.bias, 0)
478 | nn.init.constant_(blk.norm2.weight, 0)
479 |
480 |
481 | class PatchEmbed(nn.Module):
482 | r""" Image to Patch Embedding
483 |
484 | Args:
485 | img_size (int): Image size. Default: 224.
486 | patch_size (int): Patch token size. Default: 4.
487 | in_chans (int): Number of input image channels. Default: 3.
488 | embed_dim (int): Number of linear projection output channels. Default: 96.
489 | norm_layer (nn.Module, optional): Normalization layer. Default: None
490 | """
491 |
492 | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
493 | super().__init__()
494 | img_size = to_2tuple(img_size)
495 | patch_size = to_2tuple(patch_size)
496 | patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
497 | self.img_size = img_size
498 | self.patch_size = patch_size
499 | self.patches_resolution = patches_resolution
500 | self.num_patches = patches_resolution[0] * patches_resolution[1]
501 |
502 | self.in_chans = in_chans
503 | self.embed_dim = embed_dim
504 |
505 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
506 | if norm_layer is not None:
507 | self.norm = norm_layer(embed_dim)
508 | else:
509 | self.norm = None
510 |
511 | def forward(self, x):
512 | B, C, H, W = x.shape
513 | # FIXME look at relaxing size constraints
514 | assert H == self.img_size[0] and W == self.img_size[1], \
515 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
516 | x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
517 | if self.norm is not None:
518 | x = self.norm(x)
519 | return x
520 |
521 | def flops(self):
522 | Ho, Wo = self.patches_resolution
523 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
524 | if self.norm is not None:
525 | flops += Ho * Wo * self.embed_dim
526 | return flops
527 |
528 |
529 | class Decoder_block(nn.Module):
530 | def __init__(self):
531 | super().__init__()
532 | self.upsample_4 = nn.Sequential(
533 | nn.ConvTranspose2d(in_channels=768, out_channels=384, kernel_size=4, stride=2, padding=1)
534 | )
535 | self.stage_up_4 = nn.Sequential(
536 | nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, padding=1),
537 | nn.BatchNorm2d(384),
538 | nn.ReLU(),
539 | nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, padding=1),
540 | nn.BatchNorm2d(384),
541 | nn.ReLU()
542 | )
543 |
544 | self.upsample_3 = nn.Sequential(
545 | nn.ConvTranspose2d(in_channels=384, out_channels=192, kernel_size=4, stride=2, padding=1)
546 | )
547 | self.stage_up_3 = nn.Sequential(
548 | nn.Conv2d(in_channels=192, out_channels=192, kernel_size=3, padding=1),
549 | nn.BatchNorm2d(192),
550 | nn.ReLU(),
551 | nn.Conv2d(in_channels=192, out_channels=192, kernel_size=3, padding=1),
552 | nn.BatchNorm2d(192),
553 | nn.ReLU()
554 | )
555 |
556 | self.upsample_2 = nn.Sequential(
557 | nn.ConvTranspose2d(in_channels=192, out_channels=96, kernel_size=4, stride=2, padding=1)
558 | )
559 | self.stage_up_2 = nn.Sequential(
560 | nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, padding=1),
561 | nn.BatchNorm2d(96),
562 | nn.ReLU(),
563 | nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, padding=1),
564 | nn.BatchNorm2d(96),
565 | nn.ReLU()
566 | )
567 |
568 | self.upsample_1 = nn.Sequential(
569 | nn.ConvTranspose2d(in_channels=96, out_channels=48, kernel_size=4, stride=2, padding=1)
570 | )
571 | self.stage_up_1 = nn.Sequential(
572 | nn.Conv2d(in_channels=48, out_channels=48, kernel_size=3, padding=1),
573 | nn.BatchNorm2d(48),
574 | nn.ReLU(),
575 | nn.Conv2d(in_channels=48, out_channels=48, kernel_size=3, padding=1),
576 | nn.BatchNorm2d(48),
577 | nn.ReLU()
578 | )
579 |
580 | self.conT1 = nn.ConvTranspose2d(48, 16, kernel_size=2, stride=2)
581 |
582 | def forward(self, x):
583 | x = self.upsample_4(x)
584 | x = self.stage_up_4(x)
585 | x = self.upsample_3(x)
586 | x = self.stage_up_3(x)
587 | x = self.upsample_2(x)
588 | x = self.stage_up_2(x)
589 | x = self.upsample_1(x)
590 | x = self.stage_up_1(x)
591 |
592 | x = self.conT1(x)
593 | return x
594 |
595 |
596 | class SegmentationHead(nn.Sequential):
597 |
598 | def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):
599 | conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
600 | upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity()
601 | super().__init__(conv2d, upsampling)
602 |
603 |
604 | class SwinTransformerV2(nn.Module):
605 | r""" Swin Transformer
606 | A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
607 | https://arxiv.org/pdf/2103.14030
608 |
609 | Args:
610 | img_size (int | tuple(int)): Input image size. Default 224
611 | patch_size (int | tuple(int)): Patch size. Default: 4
612 | in_chans (int): Number of input image channels. Default: 3
613 | num_classes (int): Number of classes for classification head. Default: 1000
614 | embed_dim (int): Patch embedding dimension. Default: 96
615 | depths (tuple(int)): Depth of each Swin Transformer layer.
616 | num_heads (tuple(int)): Number of attention heads in different layers.
617 | window_size (int): Window size. Default: 7
618 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
619 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
620 | drop_rate (float): Dropout rate. Default: 0
621 | attn_drop_rate (float): Attention dropout rate. Default: 0
622 | drop_path_rate (float): Stochastic depth rate. Default: 0.1
623 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
624 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
625 | patch_norm (bool): If True, add normalization after patch embedding. Default: True
626 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
627 | pretrained_window_sizes (tuple(int)): Pretrained window sizes of each layer.
628 | """
629 |
630 | def __init__(self, img_size=256,
631 | patch_size=4,
632 | in_chans=3,
633 | num_classes=12,
634 | embed_dim=96,
635 | depths=[2, 2, 6, 2],
636 | num_heads=[3, 6, 12, 24],
637 | window_size=8,
638 | mlp_ratio=4.,
639 | qkv_bias=True,
640 | drop_rate=0.,
641 | attn_drop_rate=0.,
642 | drop_path_rate=0.1,
643 | norm_layer=nn.LayerNorm,
644 | ape=False,
645 | patch_norm=True,
646 | use_checkpoint=False,
647 | pretrained_window_sizes=[0, 0, 0, 0],
648 | **kwargs):
649 | super().__init__()
650 |
651 | self.num_classes = num_classes
652 | self.num_layers = len(depths)
653 | self.embed_dim = embed_dim
654 | self.ape = ape
655 | self.patch_norm = patch_norm
656 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
657 | self.mlp_ratio = mlp_ratio
658 |
659 | # 将224*224的图像4*4的patch大小进行相应的划分成56*56个
660 | self.patch_embed = PatchEmbed(
661 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
662 | norm_layer=norm_layer if self.patch_norm else None)
663 |
664 | # patch的总数量
665 | num_patches = self.patch_embed.num_patches
666 |
667 | # patch的相应方阵数量矩阵
668 | patches_resolution = self.patch_embed.patches_resolution
669 | self.patches_resolution = patches_resolution
670 |
671 | # 是否加入相对位置信息的嵌入
672 | if self.ape:
673 | self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
674 | trunc_normal_(self.absolute_pos_embed, std=.02)
675 |
676 | # 防止模型过拟合和欠拟合,p根据相应情况进行调整,一般属于0.4-0.6之间
677 | self.pos_drop = nn.Dropout(p=drop_rate)
678 |
679 | # 随机深度
680 | # 返回0到0.1之间的均匀间隔的参数值,数量和sum(depths)有关
681 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
682 |
683 | # 构建相应的w_MSA和SW_MSA的块
684 | # nn.ModuleList()以列表的形式保存多个子模块
685 | self.layers = nn.ModuleList()
686 |
687 | # 一共为4个阶段,整体采用len(depths),来进行确定
688 | for i_layer in range(self.num_layers):
689 | layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), # 不同阶段的维度数
690 | input_resolution=(patches_resolution[0] // (2 ** i_layer),
691 | patches_resolution[1] // (2 ** i_layer)),
692 | depth=depths[i_layer], # transformer在这个阶段的个数
693 | num_heads=num_heads[i_layer], # 注意力机制的head的数目
694 | window_size=window_size, # 窗口的大小
695 | mlp_ratio=self.mlp_ratio, # mlp的比例大小
696 | qkv_bias=qkv_bias, # 偏差值
697 | drop=drop_rate, attn_drop=attn_drop_rate, # 下降的比列
698 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
699 | norm_layer=norm_layer,
700 | downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
701 | # 进行相应的PatchMerging操作,把通道数减半
702 | use_checkpoint=use_checkpoint,
703 | pretrained_window_size=pretrained_window_sizes[i_layer])
704 | self.layers.append(layer)
705 |
706 | self.norm = norm_layer(self.num_features)
707 | self.avgpool = nn.AdaptiveAvgPool1d(1) # 进行池化操作,将相应的
708 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
709 |
710 | self.Decoder_block = Decoder_block()
711 | self.segmentation_head = SegmentationHead(
712 | in_channels=16,
713 | out_channels=6,
714 | kernel_size=3,
715 | )
716 |
717 | self.apply(self._init_weights)
718 | for bly in self.layers:
719 | bly._init_respostnorm()
720 |
721 | def _init_weights(self, m):
722 | if isinstance(m, nn.Linear):
723 | trunc_normal_(m.weight, std=.02)
724 | if isinstance(m, nn.Linear) and m.bias is not None:
725 | nn.init.constant_(m.bias, 0)
726 | elif isinstance(m, nn.LayerNorm):
727 | nn.init.constant_(m.bias, 0)
728 | nn.init.constant_(m.weight, 1.0)
729 |
730 | @torch.jit.ignore
731 | def no_weight_decay(self):
732 | return {'absolute_pos_embed'}
733 |
734 | @torch.jit.ignore
735 | def no_weight_decay_keywords(self):
736 | return {"cpb_mlp", "logit_scale", 'relative_position_bias_table'}
737 |
738 | def forward_features(self, x):
739 | B = x.shape[0]
740 | x = self.patch_embed(x)
741 | if self.ape:
742 | x = x + self.absolute_pos_embed
743 | x = self.pos_drop(x)
744 |
745 | for layer in self.layers:
746 | x = layer(x)
747 |
748 | x = self.norm(x)
749 | # torch.Size([32, 49, 768])
750 | x = x.contiguous().view(B, 8, 8, 768)
751 | x = rearrange(x, 'b h w c-> b c h w')
752 |
753 | x = self.Decoder_block(x)
754 |
755 | # x = self.norm(x) # B L C
756 | # x = self.avgpool(x.transpose(1, 2)) # B C 1
757 | # x = torch.flatten(x, 1)
758 |
759 | return x
760 |
761 | def forward(self, x):
762 | x = self.forward_features(x)
763 | logits = self.segmentation_head(x)
764 | return logits
765 |
766 | def flops(self):
767 | flops = 0
768 | flops += self.patch_embed.flops()
769 | for i, layer in enumerate(self.layers):
770 | flops += layer.flops()
771 | flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
772 | flops += self.num_features * self.num_classes
773 | return flops
774 |
775 |
776 | if __name__ == '__main__':
777 | model = SwinTransformerV2()
778 | model.eval()
779 | image = torch.randn(32, 3, 256, 256)
780 |
781 | output = model(image)
782 | print("input:", image.shape)
783 | print("output:", output.shape)
784 |
--------------------------------------------------------------------------------
/model/TransUnet/vit_seg_configs.py:
--------------------------------------------------------------------------------
1 | import ml_collections
2 |
3 | def get_b16_config():
4 | """Returns the ViT-B/16 configuration."""
5 | config = ml_collections.ConfigDict()
6 | config.patches = ml_collections.ConfigDict({'size': (16, 16)})
7 | config.hidden_size = 768
8 | config.transformer = ml_collections.ConfigDict()
9 | config.transformer.mlp_dim = 3072
10 | config.transformer.num_heads = 12
11 | config.transformer.num_layers = 12
12 | config.transformer.attention_dropout_rate = 0.0
13 | config.transformer.dropout_rate = 0.1
14 |
15 | config.classifier = 'seg'
16 | config.representation_size = None
17 | config.resnet_pretrained_path = None
18 | config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_16.npz'
19 | config.patch_size = 16
20 |
21 | config.decoder_channels = (256, 128, 64, 16)
22 | config.n_classes = 2
23 | config.activation = 'softmax'
24 | return config
25 |
26 |
27 | def get_testing():
28 | """Returns a minimal configuration for testing."""
29 | config = ml_collections.ConfigDict()
30 | config.patches = ml_collections.ConfigDict({'size': (16, 16)})
31 | config.hidden_size = 1
32 | config.transformer = ml_collections.ConfigDict()
33 | config.transformer.mlp_dim = 1
34 | config.transformer.num_heads = 1
35 | config.transformer.num_layers = 1
36 | config.transformer.attention_dropout_rate = 0.0
37 | config.transformer.dropout_rate = 0.1
38 | config.classifier = 'token'
39 | config.representation_size = None
40 | return config
41 |
42 | def get_r50_b16_config():
43 | """Returns the Resnet50 + ViT-B/16 configuration."""
44 | config = get_b16_config()
45 | config.patches.grid = (16, 16)
46 | config.resnet = ml_collections.ConfigDict()
47 | config.resnet.num_layers = (3, 4, 9)
48 | config.resnet.width_factor = 1
49 |
50 | config.classifier = 'seg'
51 | config.pretrained_path = '../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz'
52 | config.decoder_channels = (256, 128, 64, 16)
53 | config.skip_channels = [512, 256, 64, 16]
54 | config.n_classes = 6
55 | config.n_skip = 3
56 | config.activation = 'softmax'
57 |
58 | return config
59 |
60 |
61 | def get_b32_config():
62 | """Returns the ViT-B/32 configuration."""
63 | config = get_b16_config()
64 | config.patches.size = (32, 32)
65 | config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_32.npz'
66 | return config
67 |
68 |
69 | def get_l16_config():
70 | """Returns the ViT-L/16 configuration."""
71 | config = ml_collections.ConfigDict()
72 | config.patches = ml_collections.ConfigDict({'size': (16, 16)})
73 | config.hidden_size = 1024
74 | config.transformer = ml_collections.ConfigDict()
75 | config.transformer.mlp_dim = 4096
76 | config.transformer.num_heads = 16
77 | config.transformer.num_layers = 24
78 | config.transformer.attention_dropout_rate = 0.0
79 | config.transformer.dropout_rate = 0.1
80 | config.representation_size = None
81 |
82 | # custom
83 | config.classifier = 'seg'
84 | config.resnet_pretrained_path = None
85 | config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-L_16.npz'
86 | config.decoder_channels = (256, 128, 64, 16)
87 | config.n_classes = 2
88 | config.activation = 'softmax'
89 | return config
90 |
91 |
92 | def get_r50_l16_config():
93 | """Returns the Resnet50 + ViT-L/16 configuration. customized """
94 | config = get_l16_config()
95 | config.patches.grid = (16, 16)
96 | config.resnet = ml_collections.ConfigDict()
97 | config.resnet.num_layers = (3, 4, 9)
98 | config.resnet.width_factor = 1
99 |
100 | config.classifier = 'seg'
101 | config.resnet_pretrained_path = '../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz'
102 | config.decoder_channels = (256, 128, 64, 16)
103 | config.skip_channels = [512, 256, 64, 16]
104 | config.n_classes = 2
105 | config.activation = 'softmax'
106 | return config
107 |
108 |
109 | def get_l32_config():
110 | """Returns the ViT-L/32 configuration."""
111 | config = get_l16_config()
112 | config.patches.size = (32, 32)
113 | return config
114 |
115 |
116 | def get_h14_config():
117 | """Returns the ViT-L/16 configuration."""
118 | config = ml_collections.ConfigDict()
119 | config.patches = ml_collections.ConfigDict({'size': (14, 14)})
120 | config.hidden_size = 1280
121 | config.transformer = ml_collections.ConfigDict()
122 | config.transformer.mlp_dim = 5120
123 | config.transformer.num_heads = 16
124 | config.transformer.num_layers = 32
125 | config.transformer.attention_dropout_rate = 0.0
126 | config.transformer.dropout_rate = 0.1
127 | config.classifier = 'token'
128 | config.representation_size = None
129 |
130 | return config
131 |
--------------------------------------------------------------------------------
/model/TransUnet/vit_seg_modeling.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
6 | import copy
7 | import logging
8 | import math
9 |
10 | from os.path import join as pjoin
11 |
12 | import torch
13 | import torch.nn as nn
14 | import numpy as np
15 |
16 | from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
17 | from torch.nn.modules.utils import _pair
18 | from scipy import ndimage
19 | from model.TransUnet import vit_seg_configs as configs
20 | from model.TransUnet.vit_seg_modeling_resnet_skip import ResNetV2
21 |
22 |
23 | logger = logging.getLogger(__name__)
24 |
25 |
26 | ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
27 | ATTENTION_K = "MultiHeadDotProductAttention_1/key"
28 | ATTENTION_V = "MultiHeadDotProductAttention_1/value"
29 | ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"
30 | FC_0 = "MlpBlock_3/Dense_0"
31 | FC_1 = "MlpBlock_3/Dense_1"
32 | ATTENTION_NORM = "LayerNorm_0"
33 | MLP_NORM = "LayerNorm_2"
34 |
35 |
36 | def np2th(weights, conv=False):
37 | """Possibly convert HWIO to OIHW."""
38 | if conv:
39 | weights = weights.transpose([3, 2, 0, 1])
40 | return torch.from_numpy(weights)
41 |
42 |
43 | def swish(x):
44 | return x * torch.sigmoid(x)
45 |
46 |
47 | ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}
48 |
49 |
50 | class Attention(nn.Module):
51 | def __init__(self, config, vis):
52 | super(Attention, self).__init__()
53 | self.vis = vis
54 | self.num_attention_heads = config.transformer["num_heads"]
55 | self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
56 | self.all_head_size = self.num_attention_heads * self.attention_head_size
57 |
58 | self.query = Linear(config.hidden_size, self.all_head_size)
59 | self.key = Linear(config.hidden_size, self.all_head_size)
60 | self.value = Linear(config.hidden_size, self.all_head_size)
61 |
62 | self.out = Linear(config.hidden_size, config.hidden_size)
63 | self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
64 | self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])
65 |
66 | self.softmax = Softmax(dim=-1)
67 |
68 | def transpose_for_scores(self, x):
69 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
70 | x = x.view(*new_x_shape)
71 | return x.permute(0, 2, 1, 3)
72 |
73 | def forward(self, hidden_states):
74 | mixed_query_layer = self.query(hidden_states)
75 | mixed_key_layer = self.key(hidden_states)
76 | mixed_value_layer = self.value(hidden_states)
77 |
78 | query_layer = self.transpose_for_scores(mixed_query_layer)
79 | key_layer = self.transpose_for_scores(mixed_key_layer)
80 | value_layer = self.transpose_for_scores(mixed_value_layer)
81 |
82 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
83 | attention_scores = attention_scores / math.sqrt(self.attention_head_size)
84 | attention_probs = self.softmax(attention_scores)
85 | weights = attention_probs if self.vis else None
86 | attention_probs = self.attn_dropout(attention_probs)
87 |
88 | context_layer = torch.matmul(attention_probs, value_layer)
89 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
90 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
91 | context_layer = context_layer.view(*new_context_layer_shape)
92 | attention_output = self.out(context_layer)
93 | attention_output = self.proj_dropout(attention_output)
94 | return attention_output, weights
95 |
96 |
97 | class Mlp(nn.Module):
98 | def __init__(self, config):
99 | super(Mlp, self).__init__()
100 | self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"])
101 | self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size)
102 | self.act_fn = ACT2FN["gelu"]
103 | self.dropout = Dropout(config.transformer["dropout_rate"])
104 |
105 | self._init_weights()
106 |
107 | def _init_weights(self):
108 | nn.init.xavier_uniform_(self.fc1.weight)
109 | nn.init.xavier_uniform_(self.fc2.weight)
110 | nn.init.normal_(self.fc1.bias, std=1e-6)
111 | nn.init.normal_(self.fc2.bias, std=1e-6)
112 |
113 | def forward(self, x):
114 | x = self.fc1(x)
115 | x = self.act_fn(x)
116 | x = self.dropout(x)
117 | x = self.fc2(x)
118 | x = self.dropout(x)
119 | return x
120 |
121 |
122 | class Embeddings(nn.Module):
123 | """Construct the embeddings from patch, position embeddings.
124 | """
125 | def __init__(self, config, img_size, in_channels=3):
126 | super(Embeddings, self).__init__()
127 | self.hybrid = None
128 | self.config = config
129 | img_size = _pair(img_size)
130 |
131 | if config.patches.get("grid") is not None: # ResNet
132 | grid_size = config.patches["grid"]
133 | patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1])
134 | patch_size_real = (patch_size[0] * 16, patch_size[1] * 16)
135 | n_patches = (img_size[0] // patch_size_real[0]) * (img_size[1] // patch_size_real[1])
136 | self.hybrid = True
137 | else:
138 | patch_size = _pair(config.patches["size"])
139 | n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
140 | self.hybrid = False
141 |
142 | if self.hybrid:
143 | self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers, width_factor=config.resnet.width_factor)
144 | in_channels = self.hybrid_model.width * 16
145 | self.patch_embeddings = Conv2d(in_channels=in_channels,
146 | out_channels=config.hidden_size,
147 | kernel_size=patch_size,
148 | stride=patch_size)
149 | self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, config.hidden_size))
150 |
151 | self.dropout = Dropout(config.transformer["dropout_rate"])
152 |
153 |
154 | def forward(self, x):
155 | if self.hybrid:
156 | x, features = self.hybrid_model(x)
157 | else:
158 | features = None
159 | x = self.patch_embeddings(x) # (B, hidden. n_patches^(1/2), n_patches^(1/2))
160 | x = x.flatten(2)
161 | x = x.transpose(-1, -2) # (B, n_patches, hidden)
162 |
163 | embeddings = x + self.position_embeddings
164 | embeddings = self.dropout(embeddings)
165 | return embeddings, features
166 |
167 |
168 | class Block(nn.Module):
169 | def __init__(self, config, vis):
170 | super(Block, self).__init__()
171 | self.hidden_size = config.hidden_size
172 | self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
173 | self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
174 | self.ffn = Mlp(config)
175 | self.attn = Attention(config, vis)
176 |
177 | def forward(self, x):
178 | h = x
179 | x = self.attention_norm(x)
180 | x, weights = self.attn(x)
181 | x = x + h
182 |
183 | h = x
184 | x = self.ffn_norm(x)
185 | x = self.ffn(x)
186 | x = x + h
187 | return x, weights
188 |
189 | def load_from(self, weights, n_block):
190 | ROOT = f"Transformer/encoderblock_{n_block}"
191 | with torch.no_grad():
192 | query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t()
193 | key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t()
194 | value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t()
195 | out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t()
196 |
197 | query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1)
198 | key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1)
199 | value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1)
200 | out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1)
201 |
202 | self.attn.query.weight.copy_(query_weight)
203 | self.attn.key.weight.copy_(key_weight)
204 | self.attn.value.weight.copy_(value_weight)
205 | self.attn.out.weight.copy_(out_weight)
206 | self.attn.query.bias.copy_(query_bias)
207 | self.attn.key.bias.copy_(key_bias)
208 | self.attn.value.bias.copy_(value_bias)
209 | self.attn.out.bias.copy_(out_bias)
210 |
211 | mlp_weight_0 = np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t()
212 | mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t()
213 | mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t()
214 | mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t()
215 |
216 | self.ffn.fc1.weight.copy_(mlp_weight_0)
217 | self.ffn.fc2.weight.copy_(mlp_weight_1)
218 | self.ffn.fc1.bias.copy_(mlp_bias_0)
219 | self.ffn.fc2.bias.copy_(mlp_bias_1)
220 |
221 | self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")]))
222 | self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")]))
223 | self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")]))
224 | self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")]))
225 |
226 |
227 | class Encoder(nn.Module):
228 | def __init__(self, config, vis):
229 | super(Encoder, self).__init__()
230 | self.vis = vis
231 | self.layer = nn.ModuleList()
232 | self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)
233 | for _ in range(config.transformer["num_layers"]):
234 | layer = Block(config, vis)
235 | self.layer.append(copy.deepcopy(layer))
236 |
237 | def forward(self, hidden_states):
238 | attn_weights = []
239 | for layer_block in self.layer:
240 | hidden_states, weights = layer_block(hidden_states)
241 | if self.vis:
242 | attn_weights.append(weights)
243 | encoded = self.encoder_norm(hidden_states)
244 | return encoded, attn_weights
245 |
246 |
247 | class Transformer(nn.Module):
248 | def __init__(self, config, img_size, vis):
249 | super(Transformer, self).__init__()
250 | self.embeddings = Embeddings(config, img_size=img_size)
251 | self.encoder = Encoder(config, vis)
252 |
253 | def forward(self, input_ids):
254 | embedding_output, features = self.embeddings(input_ids)
255 | encoded, attn_weights = self.encoder(embedding_output) # (B, n_patch, hidden)
256 | return encoded, attn_weights, features
257 |
258 |
259 | class Conv2dReLU(nn.Sequential):
260 | def __init__(
261 | self,
262 | in_channels,
263 | out_channels,
264 | kernel_size,
265 | padding=0,
266 | stride=1,
267 | use_batchnorm=True,
268 | ):
269 | conv = nn.Conv2d(
270 | in_channels,
271 | out_channels,
272 | kernel_size,
273 | stride=stride,
274 | padding=padding,
275 | bias=not (use_batchnorm),
276 | )
277 | relu = nn.ReLU(inplace=True)
278 |
279 | bn = nn.BatchNorm2d(out_channels)
280 |
281 | super(Conv2dReLU, self).__init__(conv, bn, relu)
282 |
283 |
284 | class DecoderBlock(nn.Module):
285 | def __init__(
286 | self,
287 | in_channels,
288 | out_channels,
289 | skip_channels=0,
290 | use_batchnorm=True,
291 | ):
292 | super().__init__()
293 | self.conv1 = Conv2dReLU(
294 | in_channels + skip_channels,
295 | out_channels,
296 | kernel_size=3,
297 | padding=1,
298 | use_batchnorm=use_batchnorm,
299 | )
300 | self.conv2 = Conv2dReLU(
301 | out_channels,
302 | out_channels,
303 | kernel_size=3,
304 | padding=1,
305 | use_batchnorm=use_batchnorm,
306 | )
307 | self.up = nn.UpsamplingBilinear2d(scale_factor=2)
308 |
309 | def forward(self, x, skip=None):
310 | x = self.up(x)
311 | if skip is not None:
312 | x = torch.cat([x, skip], dim=1)
313 | x = self.conv1(x)
314 | x = self.conv2(x)
315 | return x
316 |
317 |
318 | class SegmentationHead(nn.Sequential):
319 |
320 | def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):
321 | conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
322 | upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity()
323 | super().__init__(conv2d, upsampling)
324 |
325 |
326 | class DecoderCup(nn.Module):
327 | def __init__(self, config):
328 | super().__init__()
329 | self.config = config
330 | head_channels = 512
331 | self.conv_more = Conv2dReLU(
332 | config.hidden_size,
333 | head_channels,
334 | kernel_size=3,
335 | padding=1,
336 | use_batchnorm=True,
337 | )
338 | decoder_channels = config.decoder_channels
339 | in_channels = [head_channels] + list(decoder_channels[:-1])
340 | out_channels = decoder_channels
341 |
342 | if self.config.n_skip != 0:
343 | skip_channels = self.config.skip_channels
344 | for i in range(4-self.config.n_skip): # re-select the skip channels according to n_skip
345 | skip_channels[3-i]=0
346 |
347 | else:
348 | skip_channels=[0,0,0,0]
349 |
350 | blocks = [
351 | DecoderBlock(in_ch, out_ch, sk_ch) for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels)
352 | ]
353 | self.blocks = nn.ModuleList(blocks)
354 |
355 | def forward(self, hidden_states, features=None):
356 | B, n_patch, hidden = hidden_states.size() # reshape from (B, n_patch, hidden) to (B, h, w, hidden)
357 | h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))
358 | x = hidden_states.permute(0, 2, 1)
359 | x = x.contiguous().view(B, hidden, h, w)
360 | x = self.conv_more(x)
361 | for i, decoder_block in enumerate(self.blocks):
362 | if features is not None:
363 | skip = features[i] if (i < self.config.n_skip) else None
364 | else:
365 | skip = None
366 | x = decoder_block(x, skip=skip)
367 | return x
368 |
369 |
370 | class VisionTransformer(nn.Module):
371 | def __init__(self, config, img_size=256, num_classes=6, zero_head=False, vis=False):
372 | super(VisionTransformer, self).__init__()
373 | self.num_classes = num_classes
374 | self.zero_head = zero_head
375 | self.classifier = config.classifier
376 | self.transformer = Transformer(config, img_size, vis)
377 | self.decoder = DecoderCup(config)
378 | self.segmentation_head = SegmentationHead(
379 | in_channels=config['decoder_channels'][-1],
380 | out_channels=config['n_classes'],
381 | kernel_size=3,
382 | )
383 | self.config = config
384 |
385 | def forward(self, x):
386 | if x.size()[1] == 1:
387 | x = x.repeat(1,3,1,1)
388 | x, attn_weights, features = self.transformer(x) # (B, n_patch, hidden)
389 | x = self.decoder(x, features)
390 | logits = self.segmentation_head(x)
391 | return logits
392 |
393 | def load_from(self, weights):
394 | with torch.no_grad():
395 |
396 | res_weight = weights
397 | self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True))
398 | self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"]))
399 |
400 | self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"]))
401 | self.transformer.encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"]))
402 |
403 | posemb = np2th(weights["Transformer/posembed_input/pos_embedding"])
404 |
405 | posemb_new = self.transformer.embeddings.position_embeddings
406 | if posemb.size() == posemb_new.size():
407 | self.transformer.embeddings.position_embeddings.copy_(posemb)
408 | elif posemb.size()[1]-1 == posemb_new.size()[1]:
409 | posemb = posemb[:, 1:]
410 | self.transformer.embeddings.position_embeddings.copy_(posemb)
411 | else:
412 | logger.info("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size()))
413 | ntok_new = posemb_new.size(1)
414 | if self.classifier == "seg":
415 | _, posemb_grid = posemb[:, :1], posemb[0, 1:]
416 | gs_old = int(np.sqrt(len(posemb_grid)))
417 | gs_new = int(np.sqrt(ntok_new))
418 | print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new))
419 | posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)
420 | zoom = (gs_new / gs_old, gs_new / gs_old, 1)
421 | posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1) # th2np
422 | posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1)
423 | posemb = posemb_grid
424 | self.transformer.embeddings.position_embeddings.copy_(np2th(posemb))
425 |
426 | # Encoder whole
427 | for bname, block in self.transformer.encoder.named_children():
428 | for uname, unit in block.named_children():
429 | unit.load_from(weights, n_block=uname)
430 |
431 | if self.transformer.embeddings.hybrid:
432 | self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(res_weight["conv_root/kernel"], conv=True))
433 | gn_weight = np2th(res_weight["gn_root/scale"]).view(-1)
434 | gn_bias = np2th(res_weight["gn_root/bias"]).view(-1)
435 | self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight)
436 | self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias)
437 |
438 | for bname, block in self.transformer.embeddings.hybrid_model.body.named_children():
439 | for uname, unit in block.named_children():
440 | unit.load_from(res_weight, n_block=bname, n_unit=uname)
441 |
442 | CONFIGS = {
443 | 'ViT-B_16': configs.get_b16_config(),
444 | 'ViT-B_32': configs.get_b32_config(),
445 | 'ViT-L_16': configs.get_l16_config(),
446 | 'ViT-L_32': configs.get_l32_config(),
447 | 'ViT-H_14': configs.get_h14_config(),
448 | 'R50-ViT-B_16': configs.get_r50_b16_config(),
449 | 'R50-ViT-L_16': configs.get_r50_l16_config(),
450 | 'testing': configs.get_testing(),
451 | }
452 |
453 |
454 | if __name__ == '__main__':
455 | config_vit = configs.get_r50_b16_config()
456 | model = VisionTransformer(config_vit, img_size=256, num_classes=6)
457 |
458 | image = torch.randn(32, 3, 256, 256)
459 |
460 | output = model(image)
461 | print("input:", image.shape)
462 | print("output:", output.shape)
--------------------------------------------------------------------------------
/model/TransUnet/vit_seg_modeling_resnet_skip.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | from os.path import join as pjoin
4 | from collections import OrderedDict
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 |
11 | def np2th(weights, conv=False):
12 | """Possibly convert HWIO to OIHW."""
13 | if conv:
14 | weights = weights.transpose([3, 2, 0, 1])
15 | return torch.from_numpy(weights)
16 |
17 |
18 | class StdConv2d(nn.Conv2d):
19 |
20 | def forward(self, x):
21 | w = self.weight
22 | v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
23 | w = (w - m) / torch.sqrt(v + 1e-5)
24 | return F.conv2d(x, w, self.bias, self.stride, self.padding,
25 | self.dilation, self.groups)
26 |
27 |
28 | def conv3x3(cin, cout, stride=1, groups=1, bias=False):
29 | return StdConv2d(cin, cout, kernel_size=3, stride=stride,
30 | padding=1, bias=bias, groups=groups)
31 |
32 |
33 | def conv1x1(cin, cout, stride=1, bias=False):
34 | return StdConv2d(cin, cout, kernel_size=1, stride=stride,
35 | padding=0, bias=bias)
36 |
37 |
38 | class PreActBottleneck(nn.Module):
39 | """Pre-activation (v2) bottleneck block.
40 | """
41 |
42 | def __init__(self, cin, cout=None, cmid=None, stride=1):
43 | super().__init__()
44 | cout = cout or cin
45 | cmid = cmid or cout//4
46 |
47 | self.gn1 = nn.GroupNorm(32, cmid, eps=1e-6)
48 | self.conv1 = conv1x1(cin, cmid, bias=False)
49 | self.gn2 = nn.GroupNorm(32, cmid, eps=1e-6)
50 | self.conv2 = conv3x3(cmid, cmid, stride, bias=False) # Original code has it on conv1!!
51 | self.gn3 = nn.GroupNorm(32, cout, eps=1e-6)
52 | self.conv3 = conv1x1(cmid, cout, bias=False)
53 | self.relu = nn.ReLU(inplace=True)
54 |
55 | if (stride != 1 or cin != cout):
56 | # Projection also with pre-activation according to paper.
57 | self.downsample = conv1x1(cin, cout, stride, bias=False)
58 | self.gn_proj = nn.GroupNorm(cout, cout)
59 |
60 | def forward(self, x):
61 |
62 | # Residual branch
63 | residual = x
64 | if hasattr(self, 'downsample'):
65 | residual = self.downsample(x)
66 | residual = self.gn_proj(residual)
67 |
68 | # Unit's branch
69 | y = self.relu(self.gn1(self.conv1(x)))
70 | y = self.relu(self.gn2(self.conv2(y)))
71 | y = self.gn3(self.conv3(y))
72 |
73 | y = self.relu(residual + y)
74 | return y
75 |
76 | def load_from(self, weights, n_block, n_unit):
77 | conv1_weight = np2th(weights[pjoin(n_block, n_unit, "conv1/kernel")], conv=True)
78 | conv2_weight = np2th(weights[pjoin(n_block, n_unit, "conv2/kernel")], conv=True)
79 | conv3_weight = np2th(weights[pjoin(n_block, n_unit, "conv3/kernel")], conv=True)
80 |
81 | gn1_weight = np2th(weights[pjoin(n_block, n_unit, "gn1/scale")])
82 | gn1_bias = np2th(weights[pjoin(n_block, n_unit, "gn1/bias")])
83 |
84 | gn2_weight = np2th(weights[pjoin(n_block, n_unit, "gn2/scale")])
85 | gn2_bias = np2th(weights[pjoin(n_block, n_unit, "gn2/bias")])
86 |
87 | gn3_weight = np2th(weights[pjoin(n_block, n_unit, "gn3/scale")])
88 | gn3_bias = np2th(weights[pjoin(n_block, n_unit, "gn3/bias")])
89 |
90 | self.conv1.weight.copy_(conv1_weight)
91 | self.conv2.weight.copy_(conv2_weight)
92 | self.conv3.weight.copy_(conv3_weight)
93 |
94 | self.gn1.weight.copy_(gn1_weight.view(-1))
95 | self.gn1.bias.copy_(gn1_bias.view(-1))
96 |
97 | self.gn2.weight.copy_(gn2_weight.view(-1))
98 | self.gn2.bias.copy_(gn2_bias.view(-1))
99 |
100 | self.gn3.weight.copy_(gn3_weight.view(-1))
101 | self.gn3.bias.copy_(gn3_bias.view(-1))
102 |
103 | if hasattr(self, 'downsample'):
104 | proj_conv_weight = np2th(weights[pjoin(n_block, n_unit, "conv_proj/kernel")], conv=True)
105 | proj_gn_weight = np2th(weights[pjoin(n_block, n_unit, "gn_proj/scale")])
106 | proj_gn_bias = np2th(weights[pjoin(n_block, n_unit, "gn_proj/bias")])
107 |
108 | self.downsample.weight.copy_(proj_conv_weight)
109 | self.gn_proj.weight.copy_(proj_gn_weight.view(-1))
110 | self.gn_proj.bias.copy_(proj_gn_bias.view(-1))
111 |
112 | class ResNetV2(nn.Module):
113 | """Implementation of Pre-activation (v2) ResNet mode."""
114 |
115 | def __init__(self, block_units, width_factor):
116 | super().__init__()
117 | width = int(64 * width_factor)
118 | self.width = width
119 |
120 | self.root = nn.Sequential(OrderedDict([
121 | ('conv', StdConv2d(3, width, kernel_size=7, stride=2, bias=False, padding=3)),
122 | ('gn', nn.GroupNorm(32, width, eps=1e-6)),
123 | ('relu', nn.ReLU(inplace=True)),
124 | # ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0))
125 | ]))
126 |
127 | self.body = nn.Sequential(OrderedDict([
128 | ('block1', nn.Sequential(OrderedDict(
129 | [('unit1', PreActBottleneck(cin=width, cout=width*4, cmid=width))] +
130 | [(f'unit{i:d}', PreActBottleneck(cin=width*4, cout=width*4, cmid=width)) for i in range(2, block_units[0] + 1)],
131 | ))),
132 | ('block2', nn.Sequential(OrderedDict(
133 | [('unit1', PreActBottleneck(cin=width*4, cout=width*8, cmid=width*2, stride=2))] +
134 | [(f'unit{i:d}', PreActBottleneck(cin=width*8, cout=width*8, cmid=width*2)) for i in range(2, block_units[1] + 1)],
135 | ))),
136 | ('block3', nn.Sequential(OrderedDict(
137 | [('unit1', PreActBottleneck(cin=width*8, cout=width*16, cmid=width*4, stride=2))] +
138 | [(f'unit{i:d}', PreActBottleneck(cin=width*16, cout=width*16, cmid=width*4)) for i in range(2, block_units[2] + 1)],
139 | ))),
140 | ]))
141 |
142 | def forward(self, x):
143 | features = []
144 | b, c, in_size, _ = x.size()
145 | x = self.root(x)
146 | features.append(x)
147 | x = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)(x)
148 | for i in range(len(self.body)-1):
149 | x = self.body[i](x)
150 | right_size = int(in_size / 4 / (i+1))
151 | if x.size()[2] != right_size:
152 | pad = right_size - x.size()[2]
153 | assert pad < 3 and pad > 0, "x {} should {}".format(x.size(), right_size)
154 | feat = torch.zeros((b, x.size()[1], right_size, right_size), device=x.device)
155 | feat[:, :, 0:x.size()[2], 0:x.size()[3]] = x[:]
156 | else:
157 | feat = x
158 | features.append(feat)
159 | x = self.body[-1](x)
160 | return x, features[::-1]
161 |
--------------------------------------------------------------------------------
/model/Unet/Unet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import torch
4 | import numpy as np
5 |
6 |
7 | class Unet(nn.Module):
8 | def __init__(self, num_classes):
9 | super(Unet, self).__init__()
10 | self.stage_1 = nn.Sequential(
11 | nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1),
12 | nn.BatchNorm2d(32),
13 | nn.ReLU(),
14 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
15 | nn.BatchNorm2d(64),
16 | nn.ReLU(),
17 | nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1),
18 | nn.BatchNorm2d(64),
19 | nn.ReLU(),
20 | )
21 |
22 | self.stage_2 = nn.Sequential(
23 | nn.MaxPool2d(kernel_size=2),
24 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
25 | nn.BatchNorm2d(128),
26 | nn.ReLU(),
27 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1),
28 | nn.BatchNorm2d(128),
29 | nn.ReLU(),
30 | )
31 |
32 | self.stage_3 = nn.Sequential(
33 | nn.MaxPool2d(kernel_size=2),
34 | nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1),
35 | nn.BatchNorm2d(256),
36 | nn.ReLU(),
37 | nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
38 | nn.BatchNorm2d(256),
39 | nn.ReLU(),
40 | )
41 |
42 | self.stage_4 = nn.Sequential(
43 | nn.MaxPool2d(kernel_size=2),
44 | nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1),
45 | nn.BatchNorm2d(512),
46 | nn.ReLU(),
47 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1),
48 | nn.BatchNorm2d(512),
49 | nn.ReLU(),
50 | )
51 |
52 | self.stage_5 = nn.Sequential(
53 | nn.MaxPool2d(kernel_size=2),
54 | nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, padding=1),
55 | nn.BatchNorm2d(1024),
56 | nn.ReLU(),
57 | nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=3, padding=1),
58 | nn.BatchNorm2d(1024),
59 | nn.ReLU(),
60 | )
61 |
62 | self.upsample_4 = nn.Sequential(
63 | nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=4, stride=2, padding=1)
64 | )
65 | self.upsample_3 = nn.Sequential(
66 | nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=4, stride=2, padding=1)
67 | )
68 | self.upsample_2 = nn.Sequential(
69 | nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1)
70 | )
71 | self.upsample_1 = nn.Sequential(
72 | nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1)
73 | )
74 |
75 | self.stage_up_4 = nn.Sequential(
76 | nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=3, padding=1),
77 | nn.BatchNorm2d(512),
78 | nn.ReLU(),
79 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1),
80 | nn.BatchNorm2d(512),
81 | nn.ReLU()
82 | )
83 |
84 | self.stage_up_3 = nn.Sequential(
85 | nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, padding=1),
86 | nn.BatchNorm2d(256),
87 | nn.ReLU(),
88 | nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
89 | nn.BatchNorm2d(256),
90 | nn.ReLU()
91 | )
92 |
93 | self.stage_up_2 = nn.Sequential(
94 | nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, padding=1),
95 | nn.BatchNorm2d(128),
96 | nn.ReLU(),
97 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1),
98 | nn.BatchNorm2d(128),
99 | nn.ReLU()
100 | )
101 | self.stage_up_1 = nn.Sequential(
102 | nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, padding=1),
103 | nn.BatchNorm2d(64),
104 | nn.ReLU(),
105 | nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1),
106 | nn.BatchNorm2d(64),
107 | nn.ReLU()
108 | )
109 |
110 | self.final = nn.Sequential(
111 | nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=3, padding=1),
112 | )
113 |
114 | def forward(self, x):
115 | x = x.float()
116 | # 下采样过程
117 | stage_1 = self.stage_1(x)
118 | stage_2 = self.stage_2(stage_1)
119 | stage_3 = self.stage_3(stage_2)
120 | stage_4 = self.stage_4(stage_3)
121 | stage_5 = self.stage_5(stage_4)
122 |
123 |
124 | # 上采样和合并stage_4
125 | up_4 = self.upsample_4(stage_5)
126 |
127 | up_4_conv = torch.cat([up_4, stage_4], dim=1)
128 | up_4_conv = self.stage_up_4(up_4_conv)
129 |
130 | # 上采样和合并stage_3
131 | up_3 = self.upsample_3(up_4_conv)
132 |
133 | up_3_conv = torch.cat([up_3, stage_3], dim=1)
134 | up_3_conv = self.stage_up_3(up_3_conv)
135 |
136 | # 上采样和合并stage_2
137 | up_2 = self.upsample_2(up_3_conv)
138 |
139 | up_2_conv = torch.cat([up_2, stage_2], dim=1)
140 | up_2_conv = self.stage_up_2(up_2_conv)
141 |
142 | # 上采样和合并stage_1
143 | up_1 = self.upsample_1(up_2_conv)
144 |
145 | up_1_conv = torch.cat([up_1, stage_1], dim=1)
146 | up_1_conv = self.stage_up_1(up_1_conv)
147 |
148 | output = self.final(up_1_conv)
149 |
150 | return output
151 |
152 |
153 | if __name__ == '__main__':
154 | model = Unet(6)
155 |
156 | image = torch.randn(32, 3, 256, 256)
157 |
158 | output = model(image)
159 |
160 | print("input:", image.shape)
161 | print("output:", output.shape)
162 |
--------------------------------------------------------------------------------
/model/Unet/_init_.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wzysaber/ST_Unet_pytorch_Semantic-segmentation/b27f4d79ba85f81f793e17e686d6a7a1cd8b41ec/model/Unet/_init_.py
--------------------------------------------------------------------------------
/model/deeplabv3_version_1/aspp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class ASPP_Bottleneck(nn.Module):
7 | def __init__(self, num_classes):
8 | super(ASPP_Bottleneck, self).__init__()
9 |
10 | self.conv_1x1_1 = nn.Conv2d(4*512, 256, kernel_size=1)
11 | self.bn_conv_1x1_1 = nn.BatchNorm2d(256)
12 |
13 | self.conv_3x3_1 = nn.Conv2d(4*512, 256, kernel_size=3, stride=1, padding=6, dilation=6)#6
14 | self.bn_conv_3x3_1 = nn.BatchNorm2d(256)
15 |
16 | self.conv_3x3_2 = nn.Conv2d(4*512, 256, kernel_size=3, stride=1, padding=12, dilation=12)#12
17 | self.bn_conv_3x3_2 = nn.BatchNorm2d(256)
18 |
19 | self.conv_3x3_3 = nn.Conv2d(4*512, 256, kernel_size=3, stride=1, padding=18, dilation=18)#18
20 | self.bn_conv_3x3_3 = nn.BatchNorm2d(256)
21 |
22 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
23 |
24 | self.conv_1x1_2 = nn.Conv2d(4*512, 256, kernel_size=1)
25 | self.bn_conv_1x1_2 = nn.BatchNorm2d(256)
26 |
27 | self.conv_1x1_3 = nn.Conv2d(1280, 256, kernel_size=1) # (1280 = 5*256)
28 | self.bn_conv_1x1_3 = nn.BatchNorm2d(256)
29 |
30 | self.conv_1x1_4 = nn.Conv2d(256, num_classes, kernel_size=1)
31 |
32 | def forward(self, feature_map):
33 | # (feature_map has shape (batch_size, 4*512, h/16, w/16))
34 |
35 | feature_map_h = feature_map.size()[2] # (== h/16)
36 | feature_map_w = feature_map.size()[3] # (== w/16)
37 |
38 | out_1x1 = F.relu(self.bn_conv_1x1_1(self.conv_1x1_1(feature_map))) # (shape: (batch_size, 256, h/16, w/16))
39 | out_3x3_1 = F.relu(self.bn_conv_3x3_1(self.conv_3x3_1(feature_map))) # (shape: (batch_size, 256, h/16, w/16))
40 | out_3x3_2 = F.relu(self.bn_conv_3x3_2(self.conv_3x3_2(feature_map))) # (shape: (batch_size, 256, h/16, w/16))
41 | out_3x3_3 = F.relu(self.bn_conv_3x3_3(self.conv_3x3_3(feature_map))) # (shape: (batch_size, 256, h/16, w/16))
42 |
43 | out_img = self.avg_pool(feature_map) # (shape: (batch_size, 512, 1, 1))
44 | out_img = F.relu(self.bn_conv_1x1_2(self.conv_1x1_2(out_img))) # (shape: (batch_size, 256, 1, 1))
45 | out_img = F.interpolate(out_img, size=(feature_map_h, feature_map_w), mode="bilinear", align_corners=False) # (shape: (batch_size, 256, h/16, w/16))
46 |
47 | out = torch.cat([out_1x1, out_3x3_1, out_3x3_2, out_3x3_3, out_img], 1) # (shape: (batch_size, 1280, h/16, w/16))
48 | out = F.relu(self.bn_conv_1x1_3(self.conv_1x1_3(out))) # (shape: (batch_size, 256, h/16, w/16))
49 | out = self.conv_1x1_4(out) # (shape: (batch_size, num_classes, h/16, w/16))
50 |
51 | return out
--------------------------------------------------------------------------------
/model/deeplabv3_version_1/deeplabv3.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | from resnet import ResNet50
4 | from aspp import ASPP_Bottleneck
5 | import torch
6 |
7 | class DeepLabV3(nn.Module):
8 | def __init__(self, num_classes=6):
9 | super(DeepLabV3, self).__init__()
10 | self.num_classes = num_classes
11 | self.resnet = ResNet50()
12 | self.aspp = ASPP_Bottleneck(num_classes=self.num_classes)
13 | self.sigmoid = nn.Sigmoid()
14 |
15 | def forward(self, x):
16 | h = x.size()[2]
17 | w = x.size()[3]
18 | feature_map = self.resnet(x)
19 | output = self.aspp(feature_map)
20 | output = F.interpolate(output, size=(h, w), mode="bilinear", align_corners=False)
21 | # output = self.sigmoid(output)
22 | return output
23 |
24 |
25 | if __name__ == '__main__':
26 | model = DeepLabV3()
27 | model.eval()
28 | image = torch.randn(32, 3, 256, 256)
29 | print(model)
30 | output = model(image)
31 | print("input:", image.shape)
32 | print("output:", output.shape)
33 |
--------------------------------------------------------------------------------
/model/deeplabv3_version_1/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torchvision.models as models
5 |
6 |
7 | def make_layer(block, in_channels, channels, num_blocks, stride=1, dilation=1):
8 | strides = [stride] + [1]*(num_blocks - 1) # (stride == 2, num_blocks == 4 --> strides == [2, 1, 1, 1])
9 |
10 | blocks = []
11 | for stride in strides:
12 | blocks.append(block(in_channels=in_channels, channels=channels, stride=stride, dilation=dilation))
13 | in_channels = block.expansion*channels
14 |
15 | layer = nn.Sequential(*blocks) # (*blocks: call with unpacked list entires as arguments)
16 |
17 | return layer
18 |
19 |
20 | class BasicBlock(nn.Module):
21 | expansion = 1
22 |
23 | def __init__(self, in_channels, channels, stride=1, dilation=1):
24 | super(BasicBlock, self).__init__()
25 |
26 | out_channels = self.expansion*channels
27 |
28 | self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=3, stride=stride, padding=dilation, dilation=dilation, bias=False)
29 | self.bn1 = nn.BatchNorm2d(channels)
30 |
31 | self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=dilation, dilation=dilation, bias=False)
32 | self.bn2 = nn.BatchNorm2d(channels)
33 |
34 | if (stride != 1) or (in_channels != out_channels):
35 | conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
36 | bn = nn.BatchNorm2d(out_channels)
37 | self.downsample = nn.Sequential(conv, bn)
38 | else:
39 | self.downsample = nn.Sequential()
40 |
41 | def forward(self, x):
42 | # (x has shape: (batch_size, in_channels, h, w))
43 |
44 | out = F.relu(self.bn1(self.conv1(x))) # (shape: (batch_size, channels, h, w) if stride == 1, (batch_size, channels, h/2, w/2) if stride == 2)
45 | out = self.bn2(self.conv2(out)) # (shape: (batch_size, channels, h, w) if stride == 1, (batch_size, channels, h/2, w/2) if stride == 2)
46 |
47 | out = out + self.downsample(x) # (shape: (batch_size, channels, h, w) if stride == 1, (batch_size, channels, h/2, w/2) if stride == 2)
48 |
49 | out = F.relu(out) # (shape: (batch_size, channels, h, w) if stride == 1, (batch_size, channels, h/2, w/2) if stride == 2)
50 |
51 | return out
52 |
53 |
54 | class Bottleneck(nn.Module):
55 | expansion = 4
56 |
57 | def __init__(self, in_channels, channels, stride=1, dilation=1):
58 | super(Bottleneck, self).__init__()
59 |
60 | out_channels = self.expansion*channels
61 |
62 | self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=1, bias=False)
63 | self.bn1 = nn.BatchNorm2d(channels)
64 |
65 | self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=stride, padding=dilation, dilation=dilation, bias=False)
66 | self.bn2 = nn.BatchNorm2d(channels)
67 |
68 | self.conv3 = nn.Conv2d(channels, out_channels, kernel_size=1, bias=False)
69 | self.bn3 = nn.BatchNorm2d(out_channels)
70 |
71 | if (stride != 1) or (in_channels != out_channels):
72 | conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
73 | bn = nn.BatchNorm2d(out_channels)
74 | self.downsample = nn.Sequential(conv, bn)
75 | else:
76 | self.downsample = nn.Sequential()
77 |
78 | def forward(self, x):
79 | # (x has shape: (batch_size, in_channels, h, w))
80 |
81 | out = F.relu(self.bn1(self.conv1(x))) # (shape: (batch_size, channels, h, w))
82 | out = F.relu(self.bn2(self.conv2(out))) # (shape: (batch_size, channels, h, w) if stride == 1, (batch_size, channels, h/2, w/2) if stride == 2)
83 | out = self.bn3(self.conv3(out)) # (shape: (batch_size, out_channels, h, w) if stride == 1, (batch_size, out_channels, h/2, w/2) if stride == 2)
84 |
85 | out = out + self.downsample(x) # (shape: (batch_size, out_channels, h, w) if stride == 1, (batch_size, out_channels, h/2, w/2) if stride == 2)
86 |
87 | out = F.relu(out) # (shape: (batch_size, out_channels, h, w) if stride == 1, (batch_size, out_channels, h/2, w/2) if stride == 2)
88 |
89 | return out
90 |
91 | class ResNet50(nn.Module):
92 | def __init__(self):
93 | super(ResNet50, self).__init__()
94 |
95 | resnet = models.resnet50()
96 | #resnet.load_state_dict((torch.load("/root/data/others/yaoganbisai/code_6_7/models/pretrained_model/resnet50-19c8e357.pth")))
97 | self.resnet = nn.Sequential(*list(resnet.children())[:-3])
98 | self.layer5 = make_layer(Bottleneck, in_channels=4*256, channels=512, num_blocks=3, stride=1, dilation=2)
99 |
100 | def forward(self, x):
101 | c4 = self.resnet(x)
102 | output = self.layer5(c4)
103 | return output
104 |
105 | def get_resnet50():
106 | return ResNet50()
--------------------------------------------------------------------------------
/tool/Save_predict.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 | import numpy as np
4 | from torchvision import transforms
5 |
6 | from PIL import Image
7 | from torch.autograd import Variable
8 |
9 | from torch.utils.data import Dataset
10 | from torch.utils.data import DataLoader
11 |
12 | from model.Unet.Unet import Unet
13 |
14 | img_transform = transforms.Compose([
15 | transforms.ToTensor(),
16 | transforms.Normalize([.485, .456, .406], [.229, .224, .225])])
17 |
18 | # ====================================================================================================
19 | import cv2 as cv
20 |
21 |
22 | def GetPadImNRowColLi(image_path, cutsize_h=256, cutsize_w=256, stride=256):
23 | image = cv.imread(image_path)
24 | h, w = image.shape[0], image.shape[1]
25 | h_pad_cutsize = h if (h // cutsize_h == 0) else (h // cutsize_h + 1) * cutsize_h
26 | w_pad_cutsize = w if (w // cutsize_w == 0) else (w // cutsize_w + 1) * cutsize_w
27 | image = cv.copyMakeBorder(image,
28 | 0,
29 | h_pad_cutsize - h,
30 | 0,
31 | w_pad_cutsize - w,
32 | cv.BORDER_CONSTANT, 0)
33 | N = image.shape[0] - cutsize_h + 1
34 | M = image.shape[1] - cutsize_w + 1
35 | from numpy import arange
36 | row = arange(0, N, stride)
37 | col = arange(0, M, stride)
38 | row_col_li = []
39 | for c in col:
40 | for r in row:
41 | row_col_li.append([c, r, c + cutsize_w, r + cutsize_h])
42 | return image, row_col_li
43 |
44 |
45 | # ====================================================================================================
46 |
47 |
48 | def snapshot_forward(model, dataloader, model_list, png, shape):
49 | model.eval()
50 | for (index, (image, pos_list)) in enumerate(dataloader):
51 | image = Variable(image).cuda()
52 | # print(image)
53 | # print(pos_list)
54 |
55 | predict_list = 0
56 | for model in model_list:
57 | predict_1 = model(image)
58 | predict_list = predict_1
59 | predict_2 = model(torch.flip(image, [-1]))
60 | predict_2 = torch.flip(predict_2, [-1])
61 |
62 | predict_3 = model(torch.flip(image, [-2]))
63 | predict_3 = torch.flip(predict_3, [-2])
64 |
65 | predict_4 = model(torch.flip(image, [-1, -2]))
66 | predict_4 = torch.flip(predict_4, [-1, -2])
67 |
68 | predict_list += (predict_1 + predict_2 + predict_3 + predict_4)
69 | predict_list = torch.argmax(predict_list.cpu(), 1).byte().numpy() # n x h x w
70 |
71 | batch_size = predict_list.shape[0] # batch大小
72 | for i in range(batch_size):
73 | predict = predict_list[i]
74 | pos = pos_list[i, :]
75 | [topleft_x, topleft_y, buttomright_x, buttomright_y] = pos
76 |
77 | if (buttomright_x - topleft_x) == 256 and (buttomright_y - topleft_y) == 256:
78 | # png[topleft_y + 128:buttomright_y - 128, topleft_x + 128:buttomright_x - 128] = predict[128:384,128:384]
79 | png[topleft_y:buttomright_y, topleft_x:buttomright_x] = predict
80 | else:
81 | raise ValueError(
82 | "target_size!=512, Got {},{}".format(buttomright_x - topleft_x, buttomright_y - topleft_y))
83 |
84 | h, w = png.shape
85 | # png = png[128:h - 128, 128:w - 128] # 去除整体外边界
86 | # zeros = (6800, 7200) # 去除补全512整数倍时的右下边界
87 | zeros = shape
88 | png = png[:zeros[0], :zeros[1]]
89 |
90 | return png
91 |
92 |
93 | def parse_args():
94 | parser = argparse.ArgumentParser(description="膨胀预测")
95 | parser.add_argument('--test-data-root', type=str,
96 | default="/home/students/master/2022/wangzy/dataset/Vaihingen/Train/image/top_mosaic_09cm_area13.tif")
97 | parser.add_argument('--test-batch-size', type=int, default=4, metavar='N',
98 | help='batch size for testing (default:16)')
99 | parser.add_argument('--num_workers', type=int, default=0)
100 |
101 | parser.add_argument("--model-path", type=str,
102 | default="/home/students/master/2022/wangzy/PyCharm-Remote/ST_Unet_test/weight/Vaihingen/Unet/03-21-22:41:16/epoch_84_miou_0.70_F1_0.82.pth")
103 | parser.add_argument("--pred-path", type=str, default="")
104 | args = parser.parse_args()
105 | return args
106 |
107 |
108 | def create_png(shape):
109 | # zeros = (6800, 7200)
110 | zeros = shape
111 | h, w = zeros[0], zeros[1]
112 | new_h = h if (h // 256 == 0) else (h // 256 + 1) * 256
113 | new_w = w if (w // 256 == 0) else (w // 256 + 1) * 256
114 | # new_h, new_w = (h//512+1)*512, (w//512+1)*512 # 填充下边界和右边界得到滑窗的整数倍
115 | # zeros = (new_h+128, new_w+128) # 填充空白边界,考虑到边缘数据
116 | zeros = (new_h, new_w)
117 | zeros = np.zeros(zeros, np.uint8)
118 | return zeros
119 |
120 |
121 | # ====================================================================================================
122 | class Inference_Dataset(Dataset):
123 | def __init__(self, root_dir, transforms):
124 | self.root_dir = root_dir
125 | # self.csv_file = pd.read_csv(csv_file, header=None)
126 | self.pad_image, self.row_col_li = GetPadImNRowColLi(root_dir)
127 | self.transforms = transforms
128 |
129 | def __len__(self):
130 | # return len(self.csv_file)
131 | return len(self.row_col_li)
132 |
133 | def __getitem__(self, idx):
134 | c, r, c_end, r_end = self.row_col_li[idx]
135 | image = Image.fromarray(self.pad_image[r:r_end, c:c_end])
136 | image = self.transforms(image)
137 | pos_list = np.array(self.row_col_li[idx])
138 | return image, pos_list
139 |
140 |
141 | # ====================================================================================================
142 |
143 |
144 | def reference():
145 | args = parse_args()
146 |
147 | dataset = Inference_Dataset(root_dir=args.test_data_root,
148 | transforms=img_transform)
149 | dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=False, num_workers=0)
150 |
151 | model = Unet(num_classes=6)
152 | state_dict = torch.load(args.model_path)
153 | # new_state_dict = OrderedDict()
154 | # for k, v in state_dict.items():
155 | # print(k)
156 | # name = k[7:]
157 | # new_state_dict[name] = v
158 | model.load_state_dict(state_dict, strict=False)
159 | model = model.cuda()
160 |
161 | # model = nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
162 |
163 | model_list = []
164 | model_list.append(model)
165 |
166 | # ==================================================================
167 | shape = cv.imread(args.test_data_root).shape
168 | zeros = create_png((shape[0], shape[1]))
169 | image = snapshot_forward(model, dataloader, model_list, zeros, (shape[0], shape[1]))
170 | # ==================================================================
171 |
172 | from utils.palette import colorize_mask
173 | overlap = colorize_mask(image)
174 |
175 | import matplotlib.pyplot as plt
176 | plt.title("predict")
177 | plt.imshow(overlap)
178 | plt.show()
179 |
180 |
181 | if __name__ == '__main__':
182 | reference()
183 |
--------------------------------------------------------------------------------
/tool/predict.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.transforms as T
3 | import numpy as np
4 | import cv2
5 | from PIL import Image
6 | from model.Unet.Unet import Unet
7 | import os
8 |
9 | from utils.palette import colorize_mask
10 | from Parameter import metric
11 | from prettytable import PrettyTable
12 |
13 |
14 | import matplotlib.pyplot as plt
15 |
16 | os.environ["CUDA_VISIBLE_DEVICES"] = "2" # 设置采用的GPU序号
17 |
18 |
19 | # 定义预测函数
20 | def predict(model, image_path, Gray_label_path):
21 | """
22 | 对输入图像进行预测,返回预测结果。
23 |
24 | Args:
25 | model (nn.Module): PyTorch模型实例
26 | image_path (str): 输入图像路径
27 |
28 | Returns:
29 | 预测结果的(N, H, W)的numpy数组
30 | """
31 | # 加载图像并做相应预处理
32 | img = Image.open(image_path).convert('RGB')
33 |
34 | transform = T.Compose([
35 | T.ToTensor(),
36 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
37 | ])
38 | img = transform(img).to(device)
39 | img = img.unsqueeze(0) # 增加batch维
40 |
41 | Gray_label = Image.open(Gray_label_path).convert('L')
42 | mask = torch.from_numpy(np.array(Gray_label, dtype=np.int8)).long().numpy()
43 |
44 | # 对输入图像进行预测
45 |
46 | output = model(img)
47 | # pred = output.argmax(dim=1) # 取最大值的索引
48 | _, pred = torch.max(output, 1) # 加_,则返回一行中最大数的位置。
49 |
50 | # 转为numpy数组并去掉batch维
51 | pred = pred.data.cpu().numpy().squeeze().astype(np.uint8) # 将数据提取出来
52 |
53 | return pred, mask
54 |
55 |
56 | def Check(pred, mask):
57 | conf_mat = np.zeros((5, 5)).astype(np.int64)
58 | conf_mat += metric.confusion_matrix(pred=pred.flatten(),
59 | label=mask.flatten(),
60 | num_classes=6)
61 |
62 | acc, acc_per_class, pre, IoU, mean_IoU, kappa, F1_score, val_recall = metric.evaluate(conf_mat)
63 |
64 | print("Mean_IoU:", mean_IoU)
65 | print("OA:", acc)
66 |
67 |
68 | if __name__ == '__main__':
69 | # 加载相应参数
70 | device = torch.device("cuda:3")
71 |
72 | image_path = "/home/students/master/2022/wangzy/dataset/Vaihingen/predict/Cut/rgb/top_mosaic_09cm_area1_0017.jpg"
73 | RGB_label_path = "/home/students/master/2022/wangzy/dataset/Vaihingen/predict/Cut/label/top_mosaic_09cm_area1_label_0017.jpg"
74 | Gray_label_path = "/home/students/master/2022/wangzy/dataset/Vaihingen/predict/Cut/Gray_label/top_mosaic_09cm_area1_label_0017.jpg"
75 |
76 | model_path = "/home/students/master/2022/wangzy/PyCharm-Remote/ST_Unet_test/weight/Vaihingen/Unet/03-21-22:41:16/epoch_84_miou_0.70_F1_0.82.pth" # 导入网络的参数
77 |
78 | # 加载原始标签
79 | image = cv2.imread(RGB_label_path)
80 | B, G, R = cv2.split(image)
81 | image = cv2.merge((R, G, B))
82 |
83 | # 加载模型
84 | model = Unet(num_classes=6)
85 |
86 | state_dict = torch.load(model_path)
87 | # new_state_dict = OrderedDict()
88 | # for k, v in state_dict.items():
89 | # print(k)
90 | # name = k[7:]
91 | # new_state_dict[name] = v
92 | model.load_state_dict(state_dict, strict=False)
93 | model = model.to(device)
94 |
95 | # 预测图像
96 | pred, mask = predict(model, image_path, Gray_label_path)
97 | overlap = colorize_mask(pred)
98 |
99 | # 查看评价指标
100 | Check(pred, mask)
101 |
102 | # 可视化预测结果
103 | plt.title("predict")
104 | plt.imshow(overlap)
105 | plt.show()
106 |
107 | plt.title("label")
108 | plt.imshow(image)
109 | plt.show()
110 |
--------------------------------------------------------------------------------
/tool/train.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import DataLoader
2 | import torch.nn as nn
3 | import torch
4 | import numpy as np
5 | import data.sync_transforms
6 |
7 | from tqdm import tqdm
8 | from data.dataset import RSDataset
9 | from torch.autograd import Variable
10 | from prettytable import PrettyTable
11 | from Parameter import average_meter, metric
12 |
13 |
14 | def close_optimizer(args, model):
15 | # 使用相应的优化器
16 | if args.optimizer_name == 'Adadelta':
17 | optimizer = torch.optim.Adadelta(model.parameters(),
18 | lr=args.base_lr,
19 | weight_decay=args.weight_decay)
20 | if args.optimizer_name == 'Adam':
21 | optimizer = torch.optim.Adam(model.parameters(),
22 | lr=args.base_lr)
23 |
24 | if args.optimizer_name == 'SGD':
25 | optimizer = torch.optim.SGD(params=model.parameters(),
26 | lr=args.base_lr,
27 | momentum=args.momentum,
28 | weight_decay=args.weight_decay)
29 |
30 | return optimizer
31 |
32 |
33 | def data_set(args):
34 | # 对载入图像进行数据增强
35 | resize_scale_range = [float(scale) for scale in args.resize_scale_range.split(',')] # 0.5 2.0
36 |
37 | sync_transform = data.sync_transforms.Compose([
38 | data.sync_transforms.RandomScale(args.base_size, args.crop_size, resize_scale_range),
39 | data.sync_transforms.RandomFlip(args.flip_ratio)
40 | ])
41 |
42 | # 数据集的载入和相应参数
43 | train_dataset = RSDataset(root=args.train_data_root, mode='src', sync_transforms=sync_transform) # 加载数据集
44 |
45 | train_loader = DataLoader(dataset=train_dataset,
46 | batch_size=args.train_batch_size,
47 | num_workers=args.num_workers,
48 | shuffle=True,
49 | drop_last=True)
50 |
51 | # print('class names {}.'.format(train_loader.class_names))
52 | # print('Number samples {}.'.format(len(train_loader))) # 将模型的种类数和名称进行打印
53 |
54 | # 实现相应验证集
55 | if not args.no_val:
56 | val_dataset = RSDataset(root=args.val_data_root, mode='src', sync_transforms=None)
57 | val_loader = DataLoader(dataset=val_dataset,
58 | batch_size=args.val_batch_size,
59 | num_workers=args.num_workers,
60 | shuffle=True,
61 | drop_last=True)
62 |
63 | return train_loader, train_dataset, val_loader, val_dataset
64 |
65 |
66 | def training(args, num_classes, model, optimizer, train_dataset, train_loader, criterion1, criterion2, device, epoch):
67 | model.train() # 把module设成训练模式,对Dropout和BatchNorm有影响
68 |
69 | train_loss = average_meter.AverageMeter()
70 |
71 | # “Poly”衰减策略
72 | max_iter = args.total_epochs * len(train_loader)
73 | curr_iter = epoch * len(train_loader) # 训练的数量
74 | lr = args.base_lr * (1 - float(curr_iter) / max_iter) ** 0.9 # 自己定义的学习率
75 |
76 | # 建立比较的矩阵16X16的格式,
77 | conf_mat = np.zeros((5, 5)).astype(np.int64)
78 |
79 | tbar = tqdm(train_loader) # 可视化显示数据的迭代
80 |
81 | # 将训练集里面的数据进行相应的遍历
82 | for index, data in enumerate(tbar):
83 | # assert data[0].size()[2:] == data[1].size()[1:]
84 | # data = self.mixup_transform(data, epoch)
85 |
86 | # 加载dataload中的图片
87 | imgs = Variable(data[0]).to(device)
88 | masks = Variable(data[1]).to(device)
89 |
90 | # 引入参数
91 | outputs = model(imgs)
92 | # torch.max(tensor, dim):指定维度上最大的数,返回tensor和下标
93 | _, preds = torch.max(outputs, 1) # 加_,则返回一行中最大数的位置。
94 | preds = preds.data.cpu().numpy().squeeze().astype(np.uint8) # 将数据提取出来
95 |
96 | loss1 = criterion1(outputs, masks)
97 | loss2 = criterion2(outputs, masks, softmax=True)
98 |
99 | loss = 0.5 * loss1 + 0.5 * loss2
100 |
101 | train_loss.update(loss, args.train_batch_size)
102 | # writer.add_scalar('train_loss', train_loss.avg, curr_iter)
103 |
104 | optimizer.zero_grad() # zero_grad()梯度清0
105 | loss.backward()
106 | optimizer.step()
107 |
108 | # 将相应的数据进行打印
109 | tbar.set_description('epoch {}, training loss {}, with learning rate {}.'.format(
110 | epoch, train_loss.val, lr
111 | ))
112 |
113 | masks = masks.data.cpu().numpy().squeeze().astype(np.uint8) # 将数据提取出来
114 |
115 | # 将相应的数据存储在矩阵方阵中
116 | conf_mat += metric.confusion_matrix(pred=preds.flatten(),
117 | label=masks.flatten(),
118 | num_classes=num_classes)
119 |
120 | # 评价参数
121 | train_acc, train_acc_per_class, train_pre, train_IoU, train_mean_IoU, train_kappa, train_F1_score, train_recall = metric.evaluate(
122 | conf_mat)
123 |
124 | table = PrettyTable(["序号", "名称", "acc", "IOu"])
125 |
126 | # 打印参数
127 | for i in range(5):
128 | table.add_row([i, train_dataset.class_names[i], train_acc_per_class[i], train_IoU[i]])
129 | print(table)
130 |
131 | print("F1_score:", train_F1_score)
132 | print("train_mean_IoU:", train_mean_IoU)
133 |
134 | print("\ntrain_acc(OA):", train_acc)
135 | print("kappa:", train_kappa)
136 | print(" ")
137 |
138 | return train_acc
139 |
--------------------------------------------------------------------------------
/tool/val.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import os
4 |
5 | from tqdm import tqdm
6 | from torch.autograd import Variable
7 | from Parameter import metric
8 | from prettytable import PrettyTable
9 |
10 |
11 | def validating(args, num_classes, model, optimizer, train_dataset, val_loader, device, epoch):
12 | model.eval() # 把module 设成预测模式,对Dropout和BatchNorm有影响
13 |
14 | # 构建矩阵方阵
15 | conf_mat = np.zeros((5, 5)).astype(np.int64)
16 | # 加载相应的数据集
17 | tbar = tqdm(val_loader)
18 |
19 | # 对数据进行遍历
20 | for index, data in enumerate(tbar):
21 | # assert data[0].size()[2:] == data[1].size()[1:]
22 |
23 | # 将相应的数据提取出来
24 | imgs = Variable(data[0]).to(device)
25 | masks = Variable(data[1]).to(device)
26 |
27 | optimizer.zero_grad() # 梯度清0
28 | outputs = model(imgs)
29 | _, preds = torch.max(outputs, 1) # 返回最大值的值,不是像素的值则为1
30 |
31 | # 将相应的参数进行提取
32 | preds = preds.data.cpu().numpy().squeeze().astype(np.uint8)
33 | masks = masks.data.cpu().numpy().squeeze().astype(np.uint8)
34 |
35 | conf_mat += metric.confusion_matrix(pred=preds.flatten(),
36 | label=masks.flatten(),
37 | num_classes=num_classes)
38 |
39 | # 打印相应的数据
40 | val_acc, val_acc_per_class, val_pre, val_IoU, val_mean_IoU, val_kappa, val_F1_score, val_recall = metric.evaluate(
41 | conf_mat)
42 |
43 | model_name = 'epoch_%d_miou_%.2f_F1_%.2f' % (epoch, val_mean_IoU, val_F1_score)
44 |
45 | # 保存相应训练中最好的模型
46 | if val_mean_IoU > args.best_miou:
47 | if args.save_file:
48 | torch.save(model.state_dict(), os.path.join(args.directory, model_name + '.pth'))
49 | args.best_miou = val_mean_IoU
50 |
51 | table = PrettyTable(["序号", "名称", "acc", "IoU"])
52 |
53 | for i in range(5):
54 | table.add_row([i, train_dataset.class_names[i], val_acc_per_class[i], val_IoU[i]])
55 | print(table)
56 | print("val_F1_score:", val_F1_score)
57 | print("val_mean_IoU:", val_mean_IoU)
58 | print("val_acc:", val_acc)
59 | print("best_miou:", args.best_miou)
60 |
--------------------------------------------------------------------------------
/utils/Data_process.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import matplotlib.pyplot as plt
4 |
5 | # 自定义类别
6 | def fifteen_classes():
7 | return ['其他类别',
8 | '水田',
9 | '水浇地',
10 | '旱耕地',
11 | '园地',
12 | '乔木林地',
13 | '灌木林地',
14 | '天然草地',
15 | '人工草地',
16 | '工业用地',
17 | '城市住宅',
18 | '村镇住宅',
19 | '交通运输',
20 | '河流',
21 | '湖泊',
22 | '坑塘']
23 |
24 |
25 | def five_classes():
26 | return [
27 | '不透明表面',
28 | '建筑',
29 | '灌木',
30 | '树',
31 | '车',
32 | ]
33 |
34 |
35 | def Print_data(dataset_name, class_name, train_dataset_len, optimizer_name, model, total_epochs):
36 | print('\ndataset:', dataset_name)
37 | print('classification:', class_name)
38 | print('Number samples {}.'.format(len(train_dataset_len))) # 将模型的种类数和名称进行打印
39 | print('\noptimizer:', optimizer_name)
40 | print('model:', model)
41 | print('epoch:', total_epochs)
42 | print("\nOK!,everything is fine,let's start training!\n")
43 |
44 |
45 | def Creat_LineGraph(traincd_line):
46 | x = range(len(traincd_line))
47 | y = traincd_line
48 | plt.plot(x, y, color="g", label="train cd H_acc", linewidth=0.3, marker=',')
49 | plt.xlabel('Epoch')
50 | plt.ylabel('Acc Value')
51 | plt.show()
52 |
--------------------------------------------------------------------------------
/utils/Loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.autograd import Variable
4 | import torch.nn as nn
5 |
6 | import torch
7 | import torch.nn as nn
8 |
9 |
10 | class DiceLoss(nn.Module):
11 | def __init__(self, n_classes):
12 | super(DiceLoss, self).__init__()
13 | self.n_classes = n_classes # 物体的输入数量
14 |
15 | #没有问题,但是需要的是进行一个one_hot_的解码,来满足6个特征图
16 | def _one_hot_encoder(self, input_tensor):
17 | tensor_list = []
18 | for i in range(self.n_classes):
19 | temp_prob = input_tensor == i # * torch.ones_like(input_tensor)
20 | tensor_list.append(temp_prob.unsqueeze(1))
21 | output_tensor = torch.cat(tensor_list, dim=1)
22 | return output_tensor.float()
23 |
24 | def _dice_loss(self, score, target):
25 | target = target.float()
26 | smooth = 1e-5
27 | intersect = torch.sum(score * target)
28 | y_sum = torch.sum(target * target)
29 | z_sum = torch.sum(score * score)
30 | loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
31 | loss = 1 - loss
32 | return loss
33 |
34 | def forward(self, inputs, target, weight=None, softmax=False):
35 | if softmax:
36 | inputs = torch.softmax(inputs, dim=1)#12, 6, 256, 256
37 | target = self._one_hot_encoder(target)#[12, 6, 256, 256]
38 | if weight is None:
39 | weight = [1] * self.n_classes
40 | assert inputs.size() == target.size(), 'predict {} & target {} shape do not match'.format(inputs.size(), target.size())
41 | class_wise_dice = []
42 | loss = 0.0
43 | for i in range(0, self.n_classes):
44 | dice = self._dice_loss(inputs[:, i], target[:, i])
45 | class_wise_dice.append(1.0 - dice.item())
46 | loss += dice * weight[i]
47 | return loss / self.n_classes
--------------------------------------------------------------------------------
/utils/palette.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import numpy as np
3 |
4 | # 染色板将图片进行染色
5 | palette = [
6 | 255, 255, 255, # 0 #surface
7 | 0, 0, 255, # 1 #building
8 | 0, 255, 255, # 2 #low vegetation
9 | 0, 255, 0, # 3 #tree
10 | 255, 255, 0, # 4 #car
11 | 255, 0, 0, # 5 #clutter/background red
12 | ]
13 |
14 | zero_pad = 256 * 3 - len(palette)
15 | for i in range(zero_pad):
16 | palette.append(0)
17 |
18 |
19 | # 将grey mask转化为彩色mask
20 |
21 | # putpalette
22 | # 为“P”或者“L”图像增加一个调色板。对于“L”图像,它的模式将变化为“P”。
23 | # 调色板序列需要包含768项整数,每组三个值表示对应像素的红,绿和蓝。用户可以使用768个byte的字符串代替这个整数序列。
24 |
25 | def colorize_mask(mask):
26 | mask_color = Image.fromarray(mask.astype(np.uint8)).convert('P')
27 | mask_color.putpalette(palette)
28 | return mask_color
29 |
--------------------------------------------------------------------------------