├── .idea ├── .gitignore ├── Template.iml ├── deployment.xml ├── inspectionProfiles │ └── profiles_settings.xml ├── misc.xml ├── modules.xml └── remote-mappings.xml ├── Parameter ├── __init__.py ├── average_meter.py ├── lr_scheduler.py └── metric.py ├── README.md ├── configs └── configs.py ├── data ├── __init__.py ├── dataset.py ├── save.py └── sync_transforms.py ├── main.py ├── model ├── ST_Unet │ ├── deform_conv.py │ ├── model_resnet.py │ ├── vit_seg_configs.py │ ├── vit_seg_modeling.py │ └── vit_seg_modeling_resnet_skip.py ├── SwinUnet │ ├── swin_transformer_unet_skip_expand_decoder_sys.py │ └── vision_transformer.py ├── Swin_Transformer │ └── SwinT.py ├── TransUnet │ ├── vit_seg_configs.py │ ├── vit_seg_modeling.py │ └── vit_seg_modeling_resnet_skip.py ├── Unet │ ├── Unet.py │ └── _init_.py └── deeplabv3_version_1 │ ├── aspp.py │ ├── deeplabv3.py │ └── resnet.py ├── tool ├── Save_predict.py ├── predict.py ├── train.py └── val.py └── utils ├── Data_process.py ├── Loss.py └── palette.py /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Editor-based HTTP Client requests 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | -------------------------------------------------------------------------------- /.idea/Template.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | -------------------------------------------------------------------------------- /.idea/deployment.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 15 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: 
class AverageMeter(object):
    """Track the most recent value and the running weighted mean of a metric.

    Attributes:
        val: the value passed to the latest ``update`` call.
        sum: weighted sum of every value seen since the last ``reset``.
        count: total weight (usually the number of samples) seen so far.
        avg: ``sum / count``, or 0 while ``count`` is still zero.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed with weight ``n`` (e.g. the batch size)."""
        self.val = val  # current value
        self.sum += val * n
        self.count += n
        # A zero-weight update on a fresh meter (e.g. an empty batch) leaves
        # count at 0; report a mean of 0 instead of dividing by zero.
        self.avg = self.sum / self.count if self.count else 0
def matrix_change(conf_mat, num_classes=5):
    """Return the top-left ``num_classes`` x ``num_classes`` corner of ``conf_mat``."""
    return conf_mat[:num_classes, :num_classes]


def confusion_matrix(pred, label, num_classes):
    """Build a confusion matrix with labels on rows and predictions on columns.

    Args:
        pred: array of predicted class indices.
        label: array of ground-truth class indices, same shape as ``pred``.
            Pixels whose label lies outside ``[0, num_classes)`` are ignored.
        num_classes: number of classes used to build the full matrix.

    Returns:
        The matrix cropped by ``matrix_change`` with its default size (the
        first five classes).
    """
    pred = np.asarray(pred)
    label = np.asarray(label)
    mask = (label >= 0) & (label < num_classes)
    # np.bincount only accepts integers, so cast both operands (predictions
    # may arrive as floats).
    flat_index = num_classes * label[mask].astype(int) + pred[mask].astype(int)
    conf_mat = np.bincount(flat_index, minlength=num_classes ** 2).reshape(
        num_classes, num_classes)
    return matrix_change(conf_mat)


def evaluate(Matrix_data):
    """Compute segmentation metrics from a confusion matrix.

    ``Matrix_data[i, j]`` counts pixels of true class ``i`` predicted as class
    ``j`` (the layout produced by ``confusion_matrix``).

    Returns:
        ``(acc, acc_per_class, pre, IoU, mean_IoU, kappa, F1_score, recall)``
        where ``acc_per_class`` is the per-class recall, ``pre`` the
        macro-averaged precision and ``recall`` the macro-averaged recall.
        Classes absent from the matrix yield NaN and are skipped by the means.
    """
    matrix = np.asarray(Matrix_data, dtype=np.float64)
    true_pos = np.diag(matrix)
    label_total = matrix.sum(axis=1)  # row sums: pixels that truly belong to each class
    pred_total = matrix.sum(axis=0)   # column sums: pixels predicted as each class
    total = matrix.sum()

    # Absent classes produce 0/0; silence the warnings and let nanmean skip them.
    with np.errstate(divide='ignore', invalid='ignore'):
        acc = true_pos.sum() / total
        acc_per_class = true_pos / label_total
        precision_per_class = true_pos / pred_total

        # Rows are labels, so TP / row-sum is recall and TP / column-sum is
        # precision (the two were previously swapped).
        recall = np.nanmean(acc_per_class)
        pre = np.nanmean(precision_per_class)

        F1_score = (2 * pre * recall) / (pre + recall)

        IoU = true_pos / (label_total + pred_total - true_pos)
        mean_IoU = np.nanmean(IoU)

        # Cohen's kappa
        pe = np.dot(pred_total, label_total) / (total ** 2)
        kappa = (acc - pe) / (1 - pe)
    return acc, acc_per_class, pre, IoU, mean_IoU, kappa, F1_score, recall
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 【论文阅读】Swin Transformer Embedding UNet用于遥感图像语义分割 2 | 3 | [TOC] 4 | 5 | 6 | 7 | Swin Transformer Embedding UNet for Remote Sensing Image Semantic Segmentation 8 | 9 | 全局上下文信息是遥感图像语义分割的关键 10 | 11 | 具有强大全局建模能力的Swin transformer 12 | 13 | 提出了一种新的RS图像语义分割框架ST-UNet型网络(UNet) 14 | 15 | 解决方案:将Swin transformer嵌入到经典的基于cnn的UNet中 16 | 17 | ST-UNet由Swin变压器和CNN并联构成了一种新型的双编码器结构 18 | 19 | 相应结构: 20 | 21 | - 建立像素级相关性来编码Swin变压器块中的空间信息 22 | - 构造了特征压缩模块(FCM) 23 | - 作为双编码器之间的桥梁,设计了一个关系聚合模块(RAM) 24 | 25 | 数据集的使用: 26 | 27 | - Vaihingen 28 | - Potsdam 29 | 30 | 31 | 32 | ## 一、相应介绍 33 | 34 | 35 | 36 | **具体作用:** 37 | 38 | - 编码器用于提取特征 39 | - 解码器在融合高级语义和低级空间信息的同时,尽可能精细地恢复图像分辨率 40 | 41 | u型网络(UNet)[14]利用解码器通过跳过连接来学习相应编码阶段的空间相关性 42 | 43 | 44 | 45 | 利用变压器的编码器-解码器结构来模拟序列中元素之间的相互作用。 46 | 47 | 本文针对CNN在全局建模方面的不足,提出了一种新的RS图像语义分割网络框架ST-UNet 48 | 49 | 50 | 51 | **相应结构层次:** 52 | 53 | - 以UNet中的编码器为主编码器,Swin变压器为辅助编码器,形成一个并行的双编码器结构 54 | - 设计良好的关系聚合模块(RAM)构建了从辅助编码器到主编码器的单向信息流 55 | - RAM是ST-UNet的关键组件 56 | - 将SIM附加到Swin变压器上,以探索全局特征的空间相关性 57 | - 使用FCM提高小尺度目标的分割精度 58 | 59 | 60 | 61 | **相应贡献:** 62 | 63 | - 构建了空间交互模块(SIM),重点关注空间维度上的像素级特征相关性,SIM还弥补了Swin变压器窗口机制所限制的全局建模能力 64 | - 提出了特征压缩模块(FCM),以缓解patch token下采样过程中小尺度特征的遗漏 65 | - 设计了关系聚合模块(RAM),从辅助编码器中提取与channel相关的信息作为全局线索来指导主编码器 66 | 67 | 68 | 69 | ## 二、相关工作 70 | 71 | ### 2.1 基于CNN的遥感图像语义分割 72 | 73 | 存在数据集: 74 | 75 | - IEEE地球科学与遥感学会(IGARSS)数据融合大赛 76 | - SpaceNet比赛 77 | - DeepGlobe比赛 78 | 79 | 80 | 81 | **在检测方面的发展过程** 82 | 83 | (1)在最开始的发展中,多分支并行卷积结构生成多尺度特征图,并设计自适应空间池化模块聚合更多局部上下文 84 | 85 | (2)引入了多层感知器(MLP),以产生更好的分割结果,最早是在自然语言中使用的。 86 | 87 | (3)关注了小尺度特征的特征提取 88 | 89 | (4)结合了基于patch的像素分类和像素到像素分割,引入了不确定映射,以实现对小尺度物体的高性能 90 | 91 | (5) 通过密集融合策略实现小尺度特征的聚合 92 | 93 | (6)明确引入边缘检测模块[43]来监督边界特征学习 94 | 95 | (7)提出了两个简单的边缘损失增强模块来增强物体边界的保存 96 | 97 | 98 | 99 | ### 2.2 Self-Attention机制 100 | 101 | 最早的注意力在计算机视觉领域 102 
| 103 | (1)Zhao et al[45]和Li et al[46]分别给出了视频字幕的区域级注意和帧级注意 104 | 105 | (2)SENet[48]通过全局平均池化层表示通道之间的关系,自动了解不同通道的重要性 106 | 107 | (3)CBAM[49]将通道级注意和空间级注意应用于自适应特征细化 108 | 109 | (4)Ding等[19]提出了patch attention module来突出feature map的重点区域 110 | 111 | (5)在GCN[51]框架的每个阶段引入通道注意块,对特征图进行分层优化 112 | 113 | (6)[52] 关注小批量图像中的相似对象,并通过自注意机制对它们之间的交互信息进行编码 114 | 115 | 116 | 117 | ### 2.3 Vision Transformer 118 | 119 | 首次提出用于机器翻译任务[53],超越了以往基于复杂递归或cnn的序列转导模型 120 | 121 | 标准transformer由多头自注意(MSA)、多层感知器(MLP)和层归一化(LN)组成 122 | 123 | 通过分割和扁平化将图像数据转化为一系列tokens 124 | 125 | 密集的预测任务,ViT仍然有巨大的训练成本,只能输出一个不能匹配预测目标(与输入图像分辨率相同)的低分辨率特征 126 | 127 | **在现在过程中的发展:** 128 | 129 | 1. SETR[58]将转换器视为编码器,结合简单的解码器对每一层的全局上下文进行建模,形成语义分割网络 130 | 2. PVT[59]模仿CNN主干的特点,在ViT中引入金字塔结构,获得多比例尺特征图 131 | 3. 基于移位窗口策略的Swin变压器,将MSA的计算限制在不重叠的窗口 132 | 4. 以Swin转换器为骨干,Cao等[31]和Lin等[32]开发了医学图像语义分割的u型编码器-解码器框架 133 | 5. TransUNet[20]和TransFuse[60]指出,纯transformer细分网络的效果并不理想,因为transformer只关注全局建模,缺乏定位能力 134 | 6. 创建了CNN和transformer的混合结构。TransUNet将CNN和transformer依次堆叠 135 | 136 | 137 | 138 | 在本文中采用Swin变压器块组成的辅助编码器为基于cnn的主编码器提供全局上下文信息,提出的ST-UNet首次将Swin变压器应用到RS图像分割任务中,弥补了纯cnn的不足,提高了分割精度 139 | 140 | 141 | 142 | ## 三、方法 143 | 144 | **ST-UNet中的三个重要模块:** 145 | 146 | - **RAM** 147 | - **SIM** 148 | - **FCM** 149 | 150 | 151 | 152 | ### 3.1 网络结构 153 | 154 | ST-UNet的整体架构 155 | 156 | ![image-20230203231637260](C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230203231637260.png) 157 | 158 | 159 | 160 | **相应组成部分:** 161 | 162 | - ST-UNet是Swin transformer和UNet的混合体,它继承了UNet的优良结构,采用跳跃式连接层连接编码器和解码器 163 | - ST-UNet构造了由基于cnn的残差网络和Swin变压器组成的双编码器 164 | - 通过RAM传输信息,充分获取RS图像的判别特征 165 | - 设计了SIM和FCM,进一步提高了Swin transformer的性能。 166 | 167 | 168 | 169 | **辅助 encoder部分** 170 | 171 | ------ 172 | 173 | **输入部分:** 174 | 175 | - RS图像X∈R^H×W×3^ 176 | - 数据划分为不重叠的patch,以模拟序列数据的“token” 177 | - 通过卷积从每张图像中获取重叠的patch token 178 | - patch尺寸为8 × 8,重叠率为50%。然后将线性嵌入层压平 179 | - patch投影到C1维 180 | - patch token被放入Swin变压器块堆叠的辅助编码器 181 | 182 | 183 | 184 | 
辅助编码器有四个特征提取阶段,每个阶段的输出定义为Sn, n = 1,2,3,4。标准的Swin变压器块包括两种类型,即基于窗口的变压器(W-Trans)和移位的W-Trans (SW-Trans)。 185 | 186 | 187 | 188 | **提出在SIM卡上建立像素级的信息交换,加在Swin transformer块上** 189 | 190 | SIM可以有效地弥补基于窗口的自我注意的局限性,缓解遮挡引起的语义模糊问题 191 | 192 | 193 | 194 | **通过缩短patch令牌长度构建FCM** 195 | 196 | 为了在与主编码器的特征分辨率匹配的同时获得多尺度特征,FCM的提出可以减少小尺度物体特征的遗漏 197 | 198 | 阶段n的输出分辨率为(H/(2^n+1^) × (W/(2^n+1^),维度为(2^n−1^)*C1 199 | 200 | 201 | 202 | 203 | 204 | **主要encode部分** 205 | 206 | ------ 207 | 208 | **输入部分:** 209 | 210 | - 原始RS图像X先在通道上压缩一半后馈送到ResNet50 211 | - 第n个残差块的输出特征图可表示为An∈R(H/(2^n+1^))×(W/(2^n+1^))×2^n−1^C2 212 | - 将An和辅助编码器对应级的输出Sn送入RAM,融合结果返回主编码器。 213 | - RAM模块作为主辅编码器之间的桥梁,通过可变形卷积和通道注意机制建立连接。 214 | 215 | 216 | 217 | **解码部分** 218 | 219 | ------ 220 | 221 | **具体操作:** 222 | 223 | - 特征F∈R(H/32)×(W/32)×1024,经过卷积层后送入解码器。然后,我们将其输入到2 × 2反卷积层以扩大分辨率 224 | - UNet之后,ST-UNet利用跳过连接层来连接编码器和解码器特性 225 | - 3×3卷积层的减少通道数量 226 | - 每个卷积层都伴随着一个批处理归一层和一个ReLU层 227 | - 最后,对特征F进行3 × 3卷积层和线性插值上采样,得到最终的预测掩码。 228 | 229 | 230 | 231 | ### 3.2 Swin Transformer BlocK 232 | 233 | 为了高效建模,Swin变压器提出了具有两种分区配置的W-MSA来替代普通MSA 234 | 235 | **MSA变化:** 236 | 237 | - 常规窗口配置(W-MSA) 238 | - 移位窗口配置(SW-MSA) 239 | 240 | 每个窗口只覆盖D × D补丁,将D设为8,**将两个Swin变压器块重命名为W-Trans块和SW-Trans块** 241 | 242 | ![image-20230204093250772](C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230204093250772.png) 243 | 244 | **相应的结构图** 245 | 246 | ![image-20230204093425033](C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230204093425033.png) 247 | 248 | 249 | 250 | ### 3.3 空间交互模块 251 | 252 | Swin transformer块在有限的窗口内建立patch token关系,有效地减少内存开销 253 | 254 | **具体操作:** 255 | 256 | - 采用了规则窗口和移位窗口的交替执行策略 257 | - 提出了跨W-Trans和SW-Trans区块的SIM,以进一步增强信息交换 258 | - SIM在两个空间维度上引入注意力,考虑像素之间的关系,而不仅仅是patch token 259 | - 在输入阶段将输入数据转化为一维 260 | 261 | 262 | 263 | **SIM结构框图** 264 | 265 | ![image-20230204094200372](C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230204094200372.png) 266 | 267 | **SIM操作:** 268 | 269 | - 通过一个大的接受场,将特征向量以2的扩张速率 270 | 
- 在一个3 × 3的扩张卷积层上进行卷积 271 | - 将通道数缩减为c1/2。然后,采用全局平均池化操作 272 | 273 | 在竖直方向和水平方向上的总张量分别记为 h×1×(c1/2)和1×w×(c1/2),因此我们将两者相乘得到与位置相关的注意力图M, 张量h×w×(c1/2),最后,将M与SW-Trans块的输出sl+1相加。 274 | 275 | M的维数需要通过卷积层增加,以匹配特征sl+1的维数(所以进行了1X1卷积来改变通道数) 276 | 277 | 278 | 279 | ### 3.4 特征压缩模块 280 | 281 | 在transformer的前期工作中,通过将图像补丁[27]、[59]平化投影或合并2个×2相邻补丁的特征,并对[30]进行线性处理,形成了一个层次网络。 282 | 283 | 在Swin变压器的patch token下采样中设计了FCM 284 | 285 | FCM避免了大量细节和结构信息的丢失,物体密集、小尺度的RS图像的语义分割,提高了小尺度对象的分割效果。 286 | 287 | 288 | 289 | **FCM结构框图** 290 | 291 | ![image-20230204100536737](C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230204100536737.png) 292 | 293 | **一种分支是扩大卷积的块:** 294 | 295 | - 扩大卷积的接受场,广泛地收集小尺度物体的特征和结构信息 296 | - 采用前1 × 1卷积层增维 297 | - 中间3 × 3扩张卷积层获取广泛的结构信息 298 | - 后1 × 1卷积层降低特征尺度 299 | - 输出结构(h/2)×(w/2)×2c1 300 | 301 | 302 | 303 | **另一个分支:** 304 | 305 | - 引入了软池[61]操作,以获得更精细的下采样 306 | - 软池可以以指数加权的方式激活池化内核中的像素 307 | - 将软池后的特征输入到卷积层(增维) 308 | - 输出结构(h/2)×(w/2)×2c1 309 | 310 | 311 | 312 | **两个分支按等比例合并为FCM的输出L** 313 | 314 | 315 | 316 | ### 3.5 关系聚合模块 317 | 318 | 基于cnn的主编码器在空间维度上提取了受卷积核限制的局部信息,缺乏对channel维度[48]之间关系的显式建模 319 | 320 | 提出了RAM,为了从整个特征图中强调重要且更具代表性的channel,从辅助编码器的全局特征中提取channel依赖关系,然后将其嵌入到从主编码器获得的局部特征中。 321 | 322 | 323 | 324 | **RAM结构特征图** 325 | 326 | ![image-20230204102128379](C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230204102128379.png) 327 | 328 | **RAM引入了可变形卷积[63]以适应不同形状的目标区域** 329 | 330 | 331 | 332 | **具体操作:** 333 | 334 | - An和Sn分别表示第n阶段主编码器和辅助编码器的输出 335 | - An输入到可变形卷积中,An = δ(An)。这里δ是一个3 × 3的可变形卷积 336 | - Sn被发送到卷积层以改变维数,由于特征图的每个通道都可以看作是一个特征检测器 337 | - 我们应用average-和max-pool层来计算通道上特征映射的统计特征, 338 | - 发送到共享的全连接层,PA&M结构数为 1×1×(c1/2) 339 | - σ代表ReLu函数,$1被设置为一个大小减半的全连接层 340 | - PA&M与PS相乘来优化每个通道 341 | 342 | ![image-20230204103003050](C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230204103003050.png) 343 | 344 | 345 | 346 | δ代表sigmoid函数,$2是一个大小增加的完全连接层,并表示元素级乘法。 347 | 348 | 
def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``1``/``0``."""
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("0", "false", "f", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))


# Definition of the command-line parameters
def parse_args(argv=None):
    """Parse the training configuration.

    Args:
        argv: optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]`` as before; passing a list allows programmatic use.

    Returns:
        The parsed ``argparse.Namespace`` with an extra ``directory``
        attribute (``weight/<dataset>/<model>/<start time>/``). The directory
        is created only when ``--save-file`` is enabled.
    """
    parser = argparse.ArgumentParser(description="RemoteSensingSegmentation by PyTorch")

    # dataset
    parser.add_argument('--dataset-name', type=str, default='Vaihingen')
    parser.add_argument('--train-data-root', type=str,
                        default='/home/students/master/2022/wangzy/PyCharm-Remote/ST_Unet_test/Vaihingen_Img/Train/')
    parser.add_argument('--val-data-root', type=str,
                        default='/home/students/master/2022/wangzy/PyCharm-Remote/ST_Unet_test/Vaihingen_Img/Test/')
    parser.add_argument('--train-batch-size', type=int, default=8, metavar='N',
                        help='batch size for training (default: %(default)s)')
    parser.add_argument('--val-batch-size', type=int, default=8, metavar='N',
                        help='batch size for testing (default: %(default)s)')

    # output_save_path
    # strftime formats the current local time; it names this run's output directory
    parser.add_argument('--experiment-start-time', type=str,
                        default=time.strftime('%m-%d-%H:%M:%S', time.localtime(time.time())))
    parser.add_argument('--save-pseudo-data-path', type=str,
                        default='/home/students/master/2022/wangzy/PyCharm-Remote/ST_Unet_test/pseudo_data')
    # Without a converter any value given on the command line (even "False")
    # was a non-empty, truthy string.
    parser.add_argument('--save-file', type=_str2bool, nargs='?', const=True, default=False)

    # augmentation
    parser.add_argument('--base-size', type=int, default=256, help='base image size')
    parser.add_argument('--crop-size', type=int, default=256, help='crop image size')
    parser.add_argument('--flip-ratio', type=float, default=0.5)
    parser.add_argument('--resize-scale-range', type=str, default='0.5, 2.0')

    # model
    parser.add_argument('--model', type=str, default='Swin_Transformer', help='model name')
    parser.add_argument('--pretrained', action='store_true', default=True)
    # store_true with default=True can never be switched off; offer the negation.
    parser.add_argument('--no-pretrained', dest='pretrained', action='store_false')

    # criterion
    # per-class loss weights; type=list would split one CLI string into characters
    parser.add_argument('--class-loss-weight', type=float, nargs='+', default=[
        0.008728536232175135, 0.05870821984204281, 0.030766985878693004, 0.03295408432939304,
        0.2399409412190348, 0.20305583055639448, 0.6344888568739531, 0.16440413437125656,
        0.5372260524694122, 0.22310945250778813, 0.04659596810284655, 0.19246378709444723,
        0.6087430986295436, 0.34431415558778183, 0.4718853977371564, 1.0])

    # loss
    parser.add_argument('--loss-names', type=str, default='cross_entropy')
    parser.add_argument('--classes-weight', type=str, default=None)
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='momentum (default: %(default)s)')
    parser.add_argument('--weight-decay', type=float, default=0.0001, metavar='M',
                        help='weight-decay (default: %(default)s)')

    # optimizer
    parser.add_argument('--optimizer-name', type=str, default='SGD')

    # learning_rate
    parser.add_argument('--base-lr', type=float, default=0.01, metavar='M',
                        help='initial learning rate (default: %(default)s)')

    # environment
    parser.add_argument('--use-cuda', action='store_true', default=True, help='using CUDA training')
    parser.add_argument('--no-cuda', dest='use_cuda', action='store_false', help='disable CUDA training')
    parser.add_argument('--num-GPUs', type=int, default=1, help='numbers of GPUs')
    parser.add_argument('--num_workers', type=int, default=32)

    # validation
    parser.add_argument('--eval', action='store_true', default=False, help='evaluation only')
    parser.add_argument('--no-val', action='store_true', default=False)

    parser.add_argument('--best-miou', type=float, default=0)

    parser.add_argument('--total-epochs', type=int, default=100, metavar='N',
                        help='number of epochs to train (default: %(default)s)')
    parser.add_argument('--start-epoch', type=int, default=0, metavar='N',
                        help='start epoch (default: %(default)s)')

    parser.add_argument('--resume-path', type=str, default=None)

    args = parser.parse_args(argv)

    directory = "weight/%s/%s/%s/" % (args.dataset_name, args.model, args.experiment_start_time)
    args.directory = directory

    if args.save_file:
        os.makedirs(directory, exist_ok=True)
        print("Creat and Save model.pth!")

    return args
mask_transform 31 | self.sync_transform = sync_transforms 32 | self.sync_img_mask = [] 33 | 34 | if mode == "train": 35 | key_word = 'train_data' 36 | elif mode == "val": 37 | key_word = 'val_data' 38 | else: 39 | key_word = 'test_data' 40 | 41 | if mode == "src": 42 | img_dir = os.path.join(root, 'rgb') 43 | mask_dir = os.path.join(root, 'label') 44 | else: 45 | for dirname in os.listdir(root): 46 | # 进入选定的文件夹 47 | if dirname == key_word in dirname: 48 | break 49 | 50 | # 读取其中的图像数据 51 | 52 | img_dir = os.path.join(root, dirname, 'rgb') 53 | mask_dir = os.path.join(root, dirname, 'label') 54 | 55 | # 将相应的图像数据进行保存 56 | for img_filename in os.listdir(img_dir): 57 | img_mask_pair = (os.path.join(img_dir, img_filename), 58 | os.path.join(mask_dir, 59 | img_filename.replace(img_filename[-8:], "label_" + img_filename[-8:]))) 60 | 61 | self.sync_img_mask.append(img_mask_pair) 62 | 63 | # print(self.sync_img_mask) 64 | if (len(self.sync_img_mask)) == 0: 65 | print("Found 0 data, please check your dataset!") 66 | 67 | def __getitem__(self, index): 68 | num_class = 6 69 | ignore_label = 5 70 | 71 | img_path, mask_path = self.sync_img_mask[index] 72 | img = Image.open(img_path).convert('RGB') 73 | mask = Image.open(mask_path).convert('L') # 将图像转化为灰度值 74 | 75 | # 将图像进行相应的裁剪,变换等操作 76 | if self.sync_transform is not None: 77 | img, mask = self.sync_transform(img, mask) 78 | 79 | # 将原始图像进行归一化操作 80 | if self.img_transform is not None: 81 | img = self.img_transform(img) 82 | 83 | # 将标签图转化为可以操作的形式 84 | if self.mask_transform is not None: 85 | mask = self.mask_transform(mask) 86 | 87 | mask[mask >= num_class] = ignore_label 88 | mask[mask < 0] = ignore_label 89 | 90 | return img, mask 91 | 92 | def __len__(self): 93 | return len(self.sync_img_mask) 94 | 95 | def classes(self): 96 | return self.class_names 97 | 98 | 99 | if __name__ == "__main__": 100 | pass 101 | # RSDataset(class_name, root=args.train_data_root, mode='train', sync_transforms=None) 102 | 
# Save the parameters of the current run
def save_work():
    """Write the parsed configuration to ``config.json`` in the run directory.

    The directory is ``work_dirs/<dataset>/<model>/<backbone>/<start time>/``.
    The ``<backbone>`` component is left out when ``args`` has no ``backbone``
    attribute. ``args.directory`` is updated to the directory used.
    """
    # The argument parser visible in configs/configs.py defines no --backbone
    # option, so reading args.backbone unconditionally raised AttributeError.
    backbone = getattr(args, "backbone", None)
    parts = [args.dataset_name, args.model]
    if backbone:
        parts.append(str(backbone))
    parts.append(args.experiment_start_time)

    directory = "work_dirs/" + "/".join(parts) + "/"
    args.directory = directory
    # exist_ok avoids the race between an existence check and the creation.
    os.makedirs(directory, exist_ok=True)

    config_file = os.path.join(directory, 'config.json')

    # Serialise the parameters as JSON so the run can be reproduced later
    with open(config_file, 'w') as file:
        json.dump(vars(args), file, indent=4)

    if args.use_cuda:
        print('Numbers of GPUs:', args.num_GPUs)
    else:
        print("Using CPU")
# Apply a sequence of paired transforms to an image and its mask
class Compose(object):
    """Chain transforms that each take and return an ``(img, mask)`` pair."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img, mask):
        """Run every transform in order and return the final ``(img, mask)``.

        Raises:
            AssertionError: if ``img.size`` and ``mask.size`` differ.
        """
        # Raise explicitly rather than with an ``assert`` statement so the
        # check is not stripped when Python runs with -O; the exception type
        # is unchanged.
        if img.size != mask.size:
            raise AssertionError(
                "img and mask sizes differ: %r vs %r" % (img.size, mask.size))
        for t in self.transforms:
            img, mask = t(img, mask)
        return img, mask
class RandomGaussianBlur(object):
    """Blur the image with probability ``prop``; the mask is never modified."""

    def __init__(self, prop):
        self.prop = prop

    def __call__(self, img, mask, prop=None):
        """Return ``(img, mask)``, with ``img`` Gaussian-blurred at random.

        Args:
            img: image exposing ``filter`` (a PIL image).
            mask: label mask, returned unchanged.
            prop: accepted only for backward compatibility and ignored; the
                probability given to the constructor is used. It is optional
                so the transform can be called as ``t(img, mask)`` like the
                other paired transforms.
        """
        if random.random() < self.prop:
            # Build the filter with its radius and pass the instance to
            # img.filter; the previous code called the filtered image instead.
            img = img.filter(ImageFilter.GaussianBlur(radius=random.random()))
        return img, mask
def main():
    """Train and validate the segmentation model selected by ``--model``.

    Raises:
        ValueError: if ``--model`` names a model this script cannot build.
    """
    # cudnn autotuning makes start-up slower but speeds up the steady state
    torch.backends.cudnn.benchmark = True

    # Load the configuration
    args = parse_args()

    # Build the training / validation data once; the result is indexed as
    # (train_loader, train_dataset, val_loader, ...).
    loaders = data_set(args)
    train_loader = loaders[0]
    train_dataset = loaders[1]
    val_loader = loaders[2]

    # CUDA_VISIBLE_DEVICES is set at module level, so the selected GPU is
    # exposed to torch as the default CUDA device; a hard-coded "cuda:2" is
    # an invalid ordinal when only one device is visible.
    use_cuda = args.use_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # Build the model
    if args.model == "Unet":
        model = Unet(num_classes=6).to(device)
    elif args.model == "ST-Unet":
        # NOTE(review): the ST_Unet imports are commented out at the top of
        # this file, so this branch builds the TransUnet VisionTransformer.
        config_vit = get_r50_b16_config()
        model = VisionTransformer(config_vit, img_size=256, num_classes=6).to(device)
    elif args.model == "deeplabv3+":
        model = DeepLabV3(num_classes=6).to(device)
    elif args.model == "SwinUnet":
        model = SwinUnet(num_classes=6).to(device)
    elif args.model == "TransUnet":
        config_vit = get_r50_b16_config()
        model = VisionTransformer(config_vit, img_size=256, num_classes=6).to(device)
    elif args.model == "Swin_Transformer":
        model = SwinTransformerV2().to(device)
    else:
        raise ValueError("Unknown model name: %r" % (args.model,))

    # Resume from a checkpoint when one is given
    if args.resume_path:
        state_dict = torch.load(args.resume_path, map_location=device)
        model.load_state_dict(state_dict, strict=False)

    # Loss functions
    criterion1 = nn.CrossEntropyLoss().to(device)
    criterion2 = DiceLoss(6).to(device)

    # Optimizer. torch optimizers have no ``.to``; their state follows the
    # parameters, which already live on ``device``. The call is kept only for
    # a custom object that provides it.
    optimizer = close_optimizer(args, model)
    if hasattr(optimizer, "to"):
        optimizer = optimizer.to(device)

    # Print a summary of the run
    Print_data(args.dataset_name, train_dataset.class_names,
               train_dataset, args.optimizer_name, args.model, args.total_epochs)

    # Training and validation
    traincd_Data = []
    for epoch in range(args.start_epoch, args.total_epochs):
        ACC = training(args, 6, model, optimizer, train_dataset, train_loader, criterion1, criterion2, device,
                       epoch)  # train for one epoch
        validating(args, 6, model, optimizer, train_dataset, val_loader, device, epoch)  # validate
        traincd_Data.append(ACC)
        print(" ")
    Creat_LineGraph(traincd_Data)  # plot the per-epoch curve
class DeformConv2d(nn.Module):
    """Deformable 2-D convolution with an optional modulation branch.

    ``p_conv`` predicts 2*N sampling offsets per output location
    (N = kernel_size**2). The padded input is sampled at
    ``p_0 + p_n + offset`` using bilinear weights over the four clamped integer
    neighbours. The N samples of every location are laid out as a
    ``kernel_size x kernel_size`` patch, and ``self.conv`` (whose stride equals
    the kernel size) reduces each patch to one output value.
    """

    def __init__(self, inc, outc, kernel_size=3, padding=1, stride=1, bias=None, modulation=False):
        """
        Args:
            inc: number of input channels passed to the internal convolutions.
            outc: number of output channels of the final convolution.
            kernel_size: side length of the sampling kernel (N = kernel_size**2 points).
            padding: zero padding applied to the input before sampling.
            stride: stride of the offset (and modulation) convolutions.
            bias: forwarded as the ``bias`` argument of the final ``nn.Conv2d``.
            modulation (bool, optional): If True, Modulated Defomable Convolution (Deformable ConvNets v2).
        """
        super(DeformConv2d, self).__init__()
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        self.zero_padding = nn.ZeroPad2d(padding)
        # stride == kernel_size: each ks x ks patch produced by _reshape_x_offset
        # is consumed exactly once.
        self.conv = nn.Conv2d(inc, outc, kernel_size=kernel_size, stride=kernel_size, bias=bias)

        # Offset predictor: 2 values (row/col) per kernel point; weights start at 0.
        self.p_conv = nn.Conv2d(inc, 2*kernel_size*kernel_size, kernel_size=3, padding=1, stride=stride)
        nn.init.constant_(self.p_conv.weight, 0)
        self.p_conv.register_backward_hook(self._set_lr)

        self.modulation = modulation
        if modulation:
            # Modulation predictor: one scalar per kernel point; weights start at 0.
            self.m_conv = nn.Conv2d(inc, kernel_size*kernel_size, kernel_size=3, padding=1, stride=stride)
            nn.init.constant_(self.m_conv.weight, 0)
            self.m_conv.register_backward_hook(self._set_lr)

    @staticmethod
    def _set_lr(module, grad_input, grad_output):
        """Backward hook registered on the offset/modulation convolutions.

        NOTE(review): as written this only rebinds two local names to generator
        expressions and returns None, so it does not change any gradient --
        confirm whether a 0.1 gradient scaling was intended.
        """
        grad_input = (grad_input[i] * 0.1 for i in range(len(grad_input)))
        grad_output = (grad_output[i] * 0.1 for i in range(len(grad_output)))

    def forward(self, x):
        """Apply the deformable convolution to ``x`` and return the result of ``self.conv``."""
        offset = self.p_conv(x)
        if self.modulation:
            m = torch.sigmoid(self.m_conv(x))

        dtype = offset.data.type()
        ks = self.kernel_size
        N = offset.size(1) // 2

        if self.padding:
            x = self.zero_padding(x)

        # (b, 2N, h, w)
        p = self._get_p(offset, dtype)

        # (b, h, w, 2N)
        p = p.contiguous().permute(0, 2, 3, 1)
        # Integer neighbours: lt = floor(p), rb = floor(p) + 1 (detached, no gradient).
        q_lt = p.detach().floor()
        q_rb = q_lt + 1

        # Clamp the neighbours to the padded input: first N channels against
        # x.size(2), last N channels against x.size(3).
        q_lt = torch.cat([torch.clamp(q_lt[..., :N], 0, x.size(2)-1), torch.clamp(q_lt[..., N:], 0, x.size(3)-1)], dim=-1).long()
        q_rb = torch.cat([torch.clamp(q_rb[..., :N], 0, x.size(2)-1), torch.clamp(q_rb[..., N:], 0, x.size(3)-1)], dim=-1).long()
        q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], dim=-1)
        q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], dim=-1)

        # clip p
        p = torch.cat([torch.clamp(p[..., :N], 0, x.size(2)-1), torch.clamp(p[..., N:], 0, x.size(3)-1)], dim=-1)

        # bilinear kernel (b, h, w, N)
        g_lt = (1 + (q_lt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_lt[..., N:].type_as(p) - p[..., N:]))
        g_rb = (1 - (q_rb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_rb[..., N:].type_as(p) - p[..., N:]))
        g_lb = (1 + (q_lb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_lb[..., N:].type_as(p) - p[..., N:]))
        g_rt = (1 - (q_rt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_rt[..., N:].type_as(p) - p[..., N:]))

        # (b, c, h, w, N)
        x_q_lt = self._get_x_q(x, q_lt, N)
        x_q_rb = self._get_x_q(x, q_rb, N)
        x_q_lb = self._get_x_q(x, q_lb, N)
        x_q_rt = self._get_x_q(x, q_rt, N)

        # (b, c, h, w, N)
        x_offset = g_lt.unsqueeze(dim=1) * x_q_lt + \
                   g_rb.unsqueeze(dim=1) * x_q_rb + \
                   g_lb.unsqueeze(dim=1) * x_q_lb + \
                   g_rt.unsqueeze(dim=1) * x_q_rt

        # modulation
        if self.modulation:
            # (b, N, h, w) -> (b, 1, h, w, N), then repeated across the channel axis.
            m = m.contiguous().permute(0, 2, 3, 1)
            m = m.unsqueeze(dim=1)
            m = torch.cat([m for _ in range(x_offset.size(1))], dim=1)
            x_offset *= m

        x_offset = self._reshape_x_offset(x_offset, ks)
        out = self.conv(x_offset)

        return out

    def _get_p_n(self, N, dtype):
        """Return the kernel-relative offsets, shaped (1, 2N, 1, 1) and cast to ``dtype``."""
        p_n_x, p_n_y = torch.meshgrid(
            torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1),
            torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1))
        # (2N, 1)
        p_n = torch.cat([torch.flatten(p_n_x), torch.flatten(p_n_y)], 0)
        p_n = p_n.view(1, 2*N, 1, 1).type(dtype)

        return p_n

    def _get_p_0(self, h, w, N, dtype):
        """Return the base sampling grid, shaped (1, 2N, h, w) and cast to ``dtype``.

        Coordinates start at 1 and advance by ``self.stride``.
        """
        p_0_x, p_0_y = torch.meshgrid(
            torch.arange(1, h*self.stride+1, self.stride),
            torch.arange(1, w*self.stride+1, self.stride))
        p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1)
        p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1)
        p_0 = torch.cat([p_0_x, p_0_y], 1).type(dtype)

        return p_0

    def _get_p(self, offset, dtype):
        """Return the absolute sampling positions ``p_0 + p_n + offset``."""
        N, h, w = offset.size(1)//2, offset.size(2), offset.size(3)

        # (1, 2N, 1, 1)
        p_n = self._get_p_n(N, dtype)
        # (1, 2N, h, w)
        p_0 = self._get_p_0(h, w, N, dtype)
        p = p_0 + p_n + offset
        return p

    def _get_x_q(self, x, q, N):
        """Gather the values of ``x`` at the integer positions ``q``; returns (b, c, h, w, N)."""
        b, h, w, _ = q.size()
        padded_w = x.size(3)
        c = x.size(1)
        # (b, c, h*w)
        x = x.contiguous().view(b, c, -1)

        # (b, h, w, N)
        index = q[..., :N]*padded_w + q[..., N:]  # offset_x*w + offset_y
        # (b, c, h*w*N)
        index = index.contiguous().unsqueeze(dim=1).expand(-1, c, -1, -1, -1).contiguous().view(b, c, -1)

        x_offset = x.gather(dim=-1, index=index).contiguous().view(b, c, h, w, N)

        return x_offset

    @staticmethod
    def _reshape_x_offset(x_offset, ks):
        """Rearrange (b, c, h, w, N) samples into a (b, c, h*ks, w*ks) map."""
        b, c, h, w, N = x_offset.size()
        x_offset = torch.cat([x_offset[..., s:s+ks].contiguous().view(b, c, h, w*ks) for s in range(0, N, ks)], dim=-1)
        x_offset = x_offset.contiguous().view(b, c, h*ks, w*ks)

        return x_offset
class ResNet(nn.Module):
    """ResNet classifier built from ``block`` units (optionally with CBAM attention).

    Args:
        block: residual block class exposing an ``expansion`` attribute and the
            constructor signature ``(inplanes, planes, stride=1,
            downsample=None, use_cbam=False)``.
        layers: number of blocks in each of the four stages.
        att_type: ``'CBAM'`` enables ``use_cbam`` in every block; any other
            value disables it.
        num_classes: output size of the final fully connected layer. Added
            with a default so existing call sites keep working.
    """

    def __init__(self, block, layers, att_type=None, num_classes=1000):
        super(ResNet, self).__init__()
        self.inplanes = 64

        # Stem (ImageNet-style: 7x7 stride-2 convolution followed by max pooling).
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.avgpool = nn.AvgPool2d(7)

        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)

        self.layer1 = self._make_layer(block, 64, layers[0], att_type=att_type)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, att_type=att_type)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, att_type=att_type)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, att_type=att_type)

        # The classifier head was referenced (here and in forward) but never
        # created, which made construction fail with AttributeError.
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        init.kaiming_normal_(self.fc.weight)
        self._init_weights()

    def _init_weights(self):
        """Initialise parameters by state-dict key name.

        Keys ending in ``weight``: Kaiming-normal (fan_out) when the key
        contains ``conv``; constant 1 when it contains ``bn`` (0 for keys that
        also contain ``SpatialGate``). Keys ending in ``bias`` are zeroed.
        Keys matching none of these keep their PyTorch defaults.
        """
        # state_dict() tensors share storage with the parameters, so the
        # in-place writes below update the module. Build the dict once.
        for key, tensor in self.state_dict().items():
            leaf = key.split('.')[-1]
            if leaf == "weight":
                if "conv" in key:
                    init.kaiming_normal_(tensor, mode='fan_out')
                if "bn" in key:
                    tensor[...] = 0 if "SpatialGate" in key else 1
            elif leaf == 'bias':
                tensor[...] = 0

    def _make_layer(self, block, planes, blocks, stride=1, att_type=None):
        """Build one stage of ``blocks`` residual blocks with ``planes`` base channels."""
        downsample = None
        # A projection shortcut is needed when the spatial size or channel count changes.
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        use_cbam = att_type == 'CBAM'
        layers = [block(self.inplanes, planes, stride, downsample, use_cbam=use_cbam)]
        self.inplanes = planes * block.expansion
        layers.extend(block(self.inplanes, planes, use_cbam=use_cbam) for _ in range(1, blocks))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for an image batch ``x``."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)

        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
def ResidualNet(depth, att_type):
    """Build a ResNet of the requested depth.

    Args:
        depth: one of 18, 34, 50 or 101.
        att_type: forwarded to ``ResNet`` (``'CBAM'`` enables attention blocks).

    Returns:
        The constructed ``ResNet`` instance.
    """
    assert depth in [18, 34, 50, 101], 'network depth should be 18, 34, 50 or 101'

    # depth -> (block class, blocks per stage)
    if depth in (18, 34):
        block_cls = BasicBlock
    else:
        block_cls = Bottleneck
    stage_blocks = {
        18: [2, 2, 2, 2],
        34: [3, 4, 6, 3],
        50: [3, 4, 6, 3],
        101: [3, 4, 23, 3],
    }[depth]

    return ResNet(block_cls, stage_blocks, att_type)
def get_r50_b16_config():
    """Return the ResNet-50 + ViT-B/16 hybrid configuration (the one used by this project).

    Starts from ``get_b16_config()`` and adds the input description, the
    ResNet layout, the auxiliary Swin-Transformer encoder settings and the
    decoder / skip-connection channel layout for a 6-class problem.
    """
    cfg = get_b16_config()

    # Input data description.
    data_cfg = ml_collections.ConfigDict()
    data_cfg.img_size = 256  # 6144
    data_cfg.in_chans = 3
    cfg.data = data_cfg

    # Number of classes and the patch grid.
    cfg.n_classes = 6
    cfg.patches.grid = (4, 4)

    # ResNet backbone settings.
    resnet_cfg = ml_collections.ConfigDict()
    resnet_cfg.num_layers = (3, 4, 6, 3)  # blocks per ResNet stage
    resnet_cfg.width_factor = 0.5
    cfg.resnet = resnet_cfg

    cfg.classifier = 'seg'

    # Settings of the auxiliary encoder (Swin Transformer).
    trans_cfg = ml_collections.ConfigDict()
    trans_cfg.num_heads = [3, 6, 12, 24]  # attention heads per stage
    trans_cfg.depths = [2, 2, 6, 2]  # Swin Transformer depth per stage
    trans_cfg.embed_dim = 96
    trans_cfg.window_size = 8
    cfg.trans = trans_cfg

    # Decoder channels per stage.
    cfg.decoder_channels = (512, 256, 128, 64)

    # Channels of the skip connections.
    cfg.skip_channels = [512, 256, 128, 64]

    cfg.n_skip = 4  # number of skip connections (i.e. number of stages)
    cfg.activation = 'softmax'

    return cfg
def get_h14_config():
    """Returns the ViT-H/14 configuration."""
    config = ml_collections.ConfigDict()
    # 14x14 patches (the previous docstring wrongly said ViT-L/16).
    config.patches = ml_collections.ConfigDict({'size': (14, 14)})
    config.hidden_size = 1280
    config.transformer = ml_collections.ConfigDict()
    config.transformer.mlp_dim = 5120
    config.transformer.num_heads = 16
    config.transformer.num_layers = 32
    config.transformer.attention_dropout_rate = 0.0
    config.transformer.dropout_rate = 0.1
    config.classifier = 'token'
    config.representation_size = None

    return config
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        config: object exposing ``hidden_size`` and a ``transformer`` mapping
            with ``"num_heads"`` and ``"attention_dropout_rate"``.
        vis: when truthy, ``forward`` also returns the attention probabilities.
    """

    def __init__(self, config, vis):
        super(Attention, self).__init__()
        hidden = config.hidden_size
        drop_rate = config.transformer["attention_dropout_rate"]

        self.vis = vis
        self.num_attention_heads = config.transformer["num_heads"]
        self.attention_head_size = int(hidden / self.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Input projections.
        self.query = Linear(hidden, self.all_head_size)
        self.key = Linear(hidden, self.all_head_size)
        self.value = Linear(hidden, self.all_head_size)

        # Output projection; both dropouts use the attention dropout rate.
        self.out = Linear(hidden, hidden)
        self.attn_dropout = Dropout(drop_rate)
        self.proj_dropout = Dropout(drop_rate)

        self.softmax = Softmax(dim=-1)

    def transpose_for_scores(self, x):
        """Split the last dim into heads: (..., all_head) -> permuted (0, 2, 1, 3)."""
        per_head = x.view(*x.size()[:-1], self.num_attention_heads, self.attention_head_size)
        return per_head.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        """Return ``(attention_output, weights)``; ``weights`` is None unless ``vis``."""
        q = self.transpose_for_scores(self.query(hidden_states))
        k = self.transpose_for_scores(self.key(hidden_states))
        v = self.transpose_for_scores(self.value(hidden_states))

        scores = torch.matmul(q, k.transpose(-1, -2))
        scores = scores / math.sqrt(self.attention_head_size)
        probs = self.softmax(scores)
        # Expose the probabilities (before dropout) only when requested.
        weights = probs if self.vis else None
        probs = self.attn_dropout(probs)

        # Merge the heads back into the last dimension.
        context = torch.matmul(probs, v).permute(0, 2, 1, 3).contiguous()
        context = context.view(*context.size()[:-2], self.all_head_size)

        attention_output = self.proj_dropout(self.out(context))
        return attention_output, weights
np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t() 158 | value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, 159 | self.hidden_size).t() 160 | out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, 161 | self.hidden_size).t() 162 | 163 | query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1) 164 | key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1) 165 | value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1) 166 | out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1) 167 | 168 | self.attn.query.weight.copy_(query_weight) 169 | self.attn.key.weight.copy_(key_weight) 170 | self.attn.value.weight.copy_(value_weight) 171 | self.attn.out.weight.copy_(out_weight) 172 | self.attn.query.bias.copy_(query_bias) 173 | self.attn.key.bias.copy_(key_bias) 174 | self.attn.value.bias.copy_(value_bias) 175 | self.attn.out.bias.copy_(out_bias) 176 | 177 | mlp_weight_0 = np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t() 178 | mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t() 179 | mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t() 180 | mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t() 181 | 182 | self.ffn.fc1.weight.copy_(mlp_weight_0) 183 | self.ffn.fc2.weight.copy_(mlp_weight_1) 184 | self.ffn.fc1.bias.copy_(mlp_bias_0) 185 | self.ffn.fc2.bias.copy_(mlp_bias_1) 186 | 187 | self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")])) 188 | self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")])) 189 | self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")])) 190 | self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")])) 191 | 192 | 193 | class Encoder(nn.Module): 194 | def __init__(self, config, vis): 195 | super(Encoder, self).__init__() 196 | self.vis = vis 197 | self.layer = 
class Conv2dReLU(nn.Sequential):
    """Conv2d -> BatchNorm2d -> ReLU as a single sequential module.

    Args:
        in_channels: input channels of the convolution.
        out_channels: output channels of the convolution and batch norm.
        kernel_size: convolution kernel size.
        padding: convolution padding.
        stride: convolution stride.
        use_batchnorm: when True the convolution is created without a bias.
            NOTE(review): the BatchNorm2d layer is added regardless of this
            flag -- confirm whether that is intended.
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            padding=0,
            stride=1,
            use_batchnorm=True,
    ):
        has_bias = not use_batchnorm
        stages = [
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                stride=stride,
                padding=padding,
                bias=has_bias,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        super(Conv2dReLU, self).__init__(*stages)
class DecoderCup(nn.Module):
    """Decoder that upsamples the encoder output through a cascade of ``DecoderBlock`` units.

    Args:
        config: configuration exposing ``decoder_channels``, ``n_skip`` and,
            when ``n_skip != 0``, ``skip_channels``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        head_channels = 1024
        decoder_channels = config.decoder_channels
        # Each block consumes the previous block's output channels.
        in_channels = [head_channels] + list(decoder_channels[:-1])
        out_channels = decoder_channels

        if self.config.n_skip != 0:
            # Work on a copy: the original code zeroed entries of
            # config.skip_channels in place, silently altering the shared config.
            skip_channels = list(self.config.skip_channels)
            # Disable the trailing skip connections that exceed n_skip.
            for i in range(4 - self.config.n_skip):
                skip_channels[3 - i] = 0
        else:
            skip_channels = [0, 0, 0, 0]

        blocks = [
            DecoderBlock(in_ch, out_ch, sk_ch)
            for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels)
        ]
        self.blocks = nn.ModuleList(blocks)

        self.conv_more = Conv2dReLU(1024, 1024, kernel_size=3, padding=1, use_batchnorm=True)

    def forward(self, x, features=None):
        """Decode ``x`` using the optional encoder ``features`` as skip connections.

        Every block except the one at index 3 is applied in order with
        ``features[i]`` as skip when ``i < config.n_skip``; the last block is
        then applied once without a skip connection.
        """
        x = self.conv_more(x)
        for i, decoder_block in enumerate(self.blocks):
            if i == 3:
                # Index 3 is not applied inside the loop (see the final call below).
                continue

            if features is not None and i < self.config.n_skip:
                skip = features[i]
            else:
                skip = None
            x = decoder_block(x, skip=skip)

        # Explicitly use the last block; the original relied on the loop
        # variable still being bound to it after the loop finished.
        x = self.blocks[-1](x, skip=None)
        return x
class Mlp(nn.Module):
    """Two-layer feed-forward network: Linear -> activation -> dropout -> Linear -> dropout.

    Args:
        in_features: size of the input features.
        hidden_features: size of the hidden layer; falls back to ``in_features`` when falsy.
        out_features: size of the output; falls back to ``in_features`` when falsy.
        act_layer: activation module class, instantiated without arguments.
        drop: dropout probability used after both linear layers.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        width_out = out_features or in_features
        width_hidden = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, width_hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(width_hidden, width_out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply the feed-forward network to ``x``."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
def window_reverse(windows, window_size, H, W):
    """Merge windows back into a feature map.

    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    # Number of windows per image, then the batch size recovered from it.
    windows_per_image = H * W / window_size / window_size
    batch = int(windows.shape[0] / windows_per_image)

    # (B, H/ws, W/ws, ws, ws, C) -> interleave window rows/cols -> (B, H, W, C)
    tiled = windows.view(batch, H // window_size, W // window_size, window_size, window_size, -1)
    merged = tiled.permute(0, 1, 3, 2, 4, 5).contiguous()
    return merged.view(batch, H, W, -1)
class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
            Any falsy value (None or 0) selects the default.
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # `or` means an explicit qk_scale of 0 is treated the same as None.
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        # scale the row offset so that (row, col) pairs map to unique flat indices
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        # buffer: saved with the module but not trained
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: additive mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None;
                it is added to the attention logits before the softmax

        Returns:
            Tensor with the same (num_windows*B, N, C) shape as ``x``.
        """
        B_, N, C = x.shape
        # one projection, then split into q/k/v: (3, B_, nH, N, C // nH)
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        # gather the learned bias for every (query, key) pair inside a window
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            # expose the per-image window axis so the mask broadcasts over batch and heads
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        """Return the summary shown inside ``repr(module)``."""
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        """Return the multiply count of this module for one window of ``N`` tokens."""
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        #  x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops
class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.

    Applies (optionally shifted) window attention followed by an MLP, each
    wrapped in a residual connection with stochastic depth.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            # three bands per axis: untouched area, last window minus the shift, the shifted strip
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            # label each of the 3 x 3 regions with its own id
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            # non-zero difference <=> the two tokens carry different region ids
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            # such pairs get -100 added to their logits, same-region pairs get 0
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        # registered as a buffer (None when no shift is used)
        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        """Run the block on ``x`` of shape (B, H*W, C) and return the same shape.

        H and W are taken from ``self.input_resolution``; a mismatch with the
        token count trips the assertion below.
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

    def extra_repr(self) -> str:
        """Return the summary shown inside ``repr(module)``."""
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

    def flops(self):
        """Return the estimated operation count of one forward pass.

        The result is a float because the window count is computed with true division.
        """
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W
        # W-MSA/SW-MSA
        nW = H * W / self.window_size / self.window_size
        flops += nW * self.attn.flops(self.window_size * self.window_size)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops
class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Concatenates every 2x2 neighbourhood of tokens along the channel axis,
    normalises the 4*C result and projects it down to 2*C channels.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        merged_dim = 4 * dim
        self.reduction = nn.Linear(merged_dim, 2 * dim, bias=False)
        self.norm = norm_layer(merged_dim)

    def forward(self, x):
        """
        x: B, H*W, C  ->  B, H/2*W/2, 2*C
        """
        height, width = self.input_resolution
        batch, length, channels = x.shape
        assert length == height * width, "input feature has wrong size"
        assert height % 2 == 0 and width % 2 == 0, f"x size ({height}*{width}) are not even."

        grid = x.view(batch, height, width, channels)
        # (row, column) offsets of the four tokens in each 2x2 neighbourhood, in
        # the channel order the reduction layer was trained with.
        offsets = ((0, 0), (1, 0), (0, 1), (1, 1))
        merged = torch.cat([grid[:, r::2, c::2, :] for r, c in offsets], -1)  # B H/2 W/2 4*C
        merged = merged.view(batch, -1, 4 * channels)  # B H/2*W/2 4*C

        return self.reduction(self.norm(merged))

    # extra_repr() supplies the text that nn.Module.__repr__ prints for this layer.
    def extra_repr(self) -> str:
        return f"input_resolution={self.input_resolution}, dim={self.dim}"

    def flops(self):
        """Return the operation count of the norm plus the linear reduction."""
        height, width = self.input_resolution
        norm_cost = height * width * self.dim
        reduction_cost = (height // 2) * (width // 2) * 4 * self.dim * 2 * self.dim
        return norm_cost + reduction_cost


class PatchExpand(nn.Module):
    """Double the spatial resolution of a token grid.

    With ``dim_scale == 2`` the channels are first doubled by a linear layer;
    for any other value the input is passed through unchanged. Groups of four
    channel blocks are then rearranged into 2x2 spatial neighbourhoods and the
    result is normalised.

    Args:
        input_resolution (tuple[int]): (H, W) of the incoming token grid.
        dim (int): Number of input channels.
        dim_scale (int): Selects the linear expansion (see above). Default: 2
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        if dim_scale == 2:
            self.expand = nn.Linear(dim, 2 * dim, bias=False)
        else:
            self.expand = nn.Identity()
        self.norm = norm_layer(dim // dim_scale)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        height, width = self.input_resolution
        expanded = self.expand(x)
        batch, length, channels = expanded.shape
        assert length == height * width, "input feature has wrong size"

        out_channels = channels // 4
        grid = expanded.view(batch, height, width, channels)
        grid = rearrange(grid, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=out_channels)
        tokens = grid.view(batch, -1, out_channels)

        return self.norm(tokens)
class FinalPatchExpand_X4(nn.Module):
    """Final upsampling stage: enlarge the token grid by ``dim_scale`` per axis.

    Every token is linearly expanded to ``dim_scale ** 2 * dim`` channels, which
    are then spread over a ``dim_scale x dim_scale`` spatial neighbourhood, so
    the channel count of the output equals ``dim``.

    The expansion width used to be hard-coded to ``16 * dim``, which is only
    correct for ``dim_scale == 4``. It is now derived from ``dim_scale``, so the
    default configuration is unchanged and other scales produce consistent shapes.

    Args:
        input_resolution (tuple[int]): (H, W) of the incoming token grid.
        dim (int): Number of input (and output) channels.
        dim_scale (int): Upsampling factor per spatial axis. Default: 4
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.dim_scale = dim_scale
        # one channel block of size `dim` per output pixel of the neighbourhood
        self.expand = nn.Linear(dim, dim_scale ** 2 * dim, bias=False)
        self.output_dim = dim
        self.norm = norm_layer(self.output_dim)

    def forward(self, x):
        """
        x: B, H*W, C  ->  B, (dim_scale*H)*(dim_scale*W), C
        """
        H, W = self.input_resolution
        x = self.expand(x)
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)
        x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale,
                      c=C // (self.dim_scale ** 2))
        x = x.view(B, -1, self.output_dim)
        x = self.norm(x)

        return x
class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.

    Stacks ``depth`` Swin blocks that alternate between regular windows
    (even index, shift 0) and shifted windows (odd index, shift
    ``window_size // 2``), optionally followed by a downsampling layer.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | list[float] | tuple[float], optional): Stochastic depth rate, either one
            value shared by all blocks or one value per block. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer class applied at the end of the
            layer; called as ``downsample(input_resolution, dim=dim, norm_layer=norm_layer)``.
            Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # The docstring allows a tuple of rates, so both sequence types are
        # indexed per block (previously a tuple was handed whole to every block).
        per_block_rates = isinstance(drop_path, (list, tuple))

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if per_block_rates else drop_path,
                                 norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        """Apply all blocks in order, then the downsample layer if one was built."""
        for blk in self.blocks:
            if self.use_checkpoint:
                # recompute activations in backward instead of storing them
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

    def extra_repr(self) -> str:
        """Return the summary shown inside ``repr(module)``."""
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"

    def flops(self):
        """Return the summed ``flops()`` of every block and of the downsample layer."""
        flops = sum(blk.flops() for blk in self.blocks)
        if self.downsample is not None:
            flops += self.downsample.flops()
        return flops
class BasicLayer_up(nn.Module):
    """ A basic Swin Transformer decoder layer for one stage.

    Stacks ``depth`` Swin blocks that alternate between regular windows
    (even index) and shifted windows (odd index), optionally followed by an
    upsampling layer.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | list[float] | tuple[float], optional): Stochastic depth rate, either one
            value shared by all blocks or one value per block. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        upsample (nn.Module | None, optional): Upsample layer class applied at the end of the layer;
            called as ``upsample(input_resolution, dim=dim, dim_scale=2, norm_layer=norm_layer)``.
            A non-callable truthy value keeps the historical behaviour of using ``PatchExpand``.
            Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, upsample=None, use_checkpoint=False):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        per_block_rates = isinstance(drop_path, (list, tuple))

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if per_block_rates else drop_path,
                                 norm_layer=norm_layer)
            for i in range(depth)])

        # patch expanding layer
        if upsample is not None:
            # Honour the class the caller supplied; the argument used to be
            # ignored and PatchExpand was instantiated unconditionally.
            upsample_cls = upsample if callable(upsample) else PatchExpand
            self.upsample = upsample_cls(input_resolution, dim=dim, dim_scale=2, norm_layer=norm_layer)
        else:
            self.upsample = None

    def forward(self, x):
        """Apply all blocks in order, then the upsample layer if one was built."""
        for blk in self.blocks:
            if self.use_checkpoint:
                # recompute activations in backward instead of storing them
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        if self.upsample is not None:
            x = self.upsample(x)
        return x
class PatchEmbed(nn.Module):
    """Turn an image into a sequence of patch tokens with a strided convolution.

    Args:
        img_size (int | tuple[int]): Expected input image size. Default: 256
        patch_size (int | tuple[int]): Side length of one patch. Default: 4
        in_chans (int): Number of image channels. Default: 3
        embed_dim (int): Number of channels of each token. Default: 96
        norm_layer (nn.Module, optional): Normalization applied to the tokens;
            skipped when None. Default: None
    """

    def __init__(self, img_size=256, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)
        rows = self.img_size[0] // self.patch_size[0]
        cols = self.img_size[1] // self.patch_size[1]
        self.patches_resolution = [rows, cols]
        self.num_patches = rows * cols

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        # kernel == stride, so every patch is projected independently
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size)
        self.norm = None if norm_layer is None else norm_layer(embed_dim)

    def forward(self, x):
        """Map (B, C, H, W) images to (B, Ph*Pw, embed_dim) tokens."""
        _, _, height, width = x.shape
        # FIXME look at relaxing size constraints
        assert height == self.img_size[0] and width == self.img_size[1], \
            f"Input image size ({height}*{width}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        tokens = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is None:
            return tokens
        return self.norm(tokens)

    # Complexity estimate of the embedding.
    def flops(self):
        """Return the operation count of the projection plus the optional norm."""
        rows, cols = self.patches_resolution
        patch_area = self.patch_size[0] * self.patch_size[1]
        total = rows * cols * self.embed_dim * self.in_chans * patch_area
        if self.norm is not None:
            total += rows * cols * self.embed_dim
        return total
class SwinTransformerSys(nn.Module):
    """Swin-Unet: a U-shaped encoder/decoder built from Swin Transformer stages.

    The encoder embeds the image into patch tokens and runs ``len(depths)``
    stages with patch merging in between. The decoder mirrors it with patch
    expanding and concatenates the matching encoder output at every stage
    (skip connections). A final x4 expansion plus a 1x1 convolution yields
    per-pixel class logits.

    Args:
        img_size (int | tuple[int]): Input image size. Default: 256
        patch_size (int | tuple[int]): Patch size. Default: 4
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of output classes. Default: 6
        embed_dim (int): Token dimension of the first stage. Default: 96
        depths (Sequence[int]): Number of blocks per encoder stage.
        depths_decoder (Sequence[int]): Accepted for compatibility; not used.
        num_heads (Sequence[int]): Attention heads per stage.
        window_size (int): Window size. Default: 8
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None): Override default qk scale of head_dim ** -0.5 if set.
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Maximum stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
        final_upsample (str): Only "expand_first" builds the output head. Default: "expand_first"
    """

    def __init__(self, img_size=256, patch_size=4, in_chans=3, num_classes=6,
                 embed_dim=96, depths=(2, 2, 2, 2), depths_decoder=(1, 2, 2, 2), num_heads=(3, 6, 12, 24),
                 window_size=8, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, final_upsample="expand_first", **kwargs):
        super().__init__()

        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.num_features_up = int(embed_dim * 2)
        self.mlp_ratio = mlp_ratio
        self.final_upsample = final_upsample

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build encoder and bottleneck layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=depths[i_layer],
                               num_heads=num_heads[i_layer],
                               window_size=window_size,
                               mlp_ratio=self.mlp_ratio,
                               qkv_bias=qkv_bias, qk_scale=qk_scale,
                               drop=drop_rate, attn_drop=attn_drop_rate,
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               norm_layer=norm_layer,
                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                               use_checkpoint=use_checkpoint)
            self.layers.append(layer)

        # build decoder layers
        self.layers_up = nn.ModuleList()
        self.concat_back_dim = nn.ModuleList()
        for i_layer in range(self.num_layers):
            # encoder stage mirrored by this decoder stage
            level = self.num_layers - 1 - i_layer
            level_dim = int(embed_dim * 2 ** level)
            level_resolution = (patches_resolution[0] // (2 ** level),
                                patches_resolution[1] // (2 ** level))
            # halves the channels again after the skip concatenation
            concat_linear = nn.Linear(2 * level_dim, level_dim) if i_layer > 0 else nn.Identity()
            if i_layer == 0:
                layer_up = PatchExpand(input_resolution=level_resolution,
                                       dim=level_dim, dim_scale=2, norm_layer=norm_layer)
            else:
                layer_up = BasicLayer_up(dim=level_dim,
                                         input_resolution=level_resolution,
                                         depth=depths[level],
                                         num_heads=num_heads[level],
                                         window_size=window_size,
                                         mlp_ratio=self.mlp_ratio,
                                         qkv_bias=qkv_bias, qk_scale=qk_scale,
                                         drop=drop_rate, attn_drop=attn_drop_rate,
                                         drop_path=dpr[sum(depths[:level]):sum(depths[:level + 1])],
                                         norm_layer=norm_layer,
                                         upsample=PatchExpand if (i_layer < self.num_layers - 1) else None,
                                         use_checkpoint=use_checkpoint)
            self.layers_up.append(layer_up)
            self.concat_back_dim.append(concat_linear)

        self.norm = norm_layer(self.num_features)
        self.norm_up = norm_layer(self.embed_dim)

        if self.final_upsample == "expand_first":
            # Use the resolution computed by PatchEmbed so that tuple img_size /
            # patch_size values work too (previously `img_size // patch_size`).
            self.up = FinalPatchExpand_X4(input_resolution=tuple(patches_resolution),
                                          dim_scale=4, dim=embed_dim)
            self.output = nn.Conv2d(in_channels=embed_dim, out_channels=self.num_classes, kernel_size=1, bias=False)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise Linear weights (truncated normal) and LayerNorm (identity)."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    # Encoder and Bottleneck
    def forward_features(self, x):
        """Return the bottleneck tokens and the list of inputs to every encoder stage."""
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)
        x_downsample = []

        for layer in self.layers:
            x_downsample.append(x)
            x = layer(x)

        x = self.norm(x)  # B L C

        return x, x_downsample

    # Decoder and Skip connection
    def forward_up_features(self, x, x_downsample):
        """Run the decoder, concatenating the mirrored encoder features at each stage."""
        for inx, layer_up in enumerate(self.layers_up):
            if inx == 0:
                x = layer_up(x)
            else:
                # Mirror index derived from the actual stage count; the former
                # hard-coded `3 - inx` was only valid for four stages.
                skip = x_downsample[self.num_layers - 1 - inx]
                x = torch.cat([x, skip], -1)
                x = self.concat_back_dim[inx](x)
                x = layer_up(x)

        x = self.norm_up(x)  # B L C

        return x

    def up_x4(self, x):
        """Expand tokens x4 per axis and project to class logits (B, num_classes, 4H, 4W)."""
        H, W = self.patches_resolution
        B, L, C = x.shape
        assert L == H * W, "input features has wrong size"

        if self.final_upsample == "expand_first":
            x = self.up(x)
            x = x.view(B, 4 * H, 4 * W, -1)
            x = x.permute(0, 3, 1, 2)  # B,C,H,W
            x = self.output(x)

        return x

    def forward(self, x):
        x, x_downsample = self.forward_features(x)
        x = self.forward_up_features(x, x_downsample)
        x = self.up_x4(x)

        return x

    def flops(self):
        """Return the operation count of the patch embedding and encoder stages.

        The decoder layers are not included in this figure.
        """
        flops = 0
        flops += self.patch_embed.flops()
        for layer in self.layers:
            flops += layer.flops()
        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
        flops += self.num_features * self.num_classes
        return flops
class SwinUnet(nn.Module):
    """Thin wrapper that builds a ``SwinTransformerSys`` segmentation network.

    Args:
        img_size (int): Side length of the (square) input images; forwarded to
            the backbone. Previously this argument was ignored and 256 was
            always used. Default: 256
        num_classes (int): Number of output classes. Default: 6
        zero_head (bool): Stored on the instance; not used by this class.
        vis (bool): Accepted for interface compatibility; not used.
    """

    def __init__(self, img_size=256, num_classes=6, zero_head=False, vis=False):
        super(SwinUnet, self).__init__()
        self.num_classes = num_classes
        self.zero_head = zero_head

        # NOTE(review): mlp_ratio=0.2 makes every MLP hidden layer narrower than
        # its input, and qk_scale=0. is falsy so the attention falls back to its
        # default scale. Both are kept because changing them alters the
        # architecture of already trained checkpoints; confirm they are intended.
        self.swin_unet = SwinTransformerSys(img_size=img_size,
                                            patch_size=4,
                                            in_chans=3,
                                            num_classes=self.num_classes,
                                            embed_dim=96,
                                            depths=[2, 2, 2, 2],
                                            num_heads=[3, 6, 12, 24],
                                            window_size=8,
                                            mlp_ratio=0.2,
                                            qkv_bias=True,
                                            qk_scale=0.,
                                            drop_rate=0.,
                                            drop_path_rate=0.1,
                                            ape=False,
                                            use_checkpoint=False)

    def forward(self, x):
        """Return segmentation logits; single-channel inputs are tiled to 3 channels."""
        if x.size()[1] == 1:
            x = x.repeat(1, 3, 1, 1)
        logits = self.swin_unet(x)
        return logits

    # NOTE: loading of pretrained Swin weights is not implemented in this class.
class Mlp(nn.Module):
    """Position-wise feed-forward network: Linear -> activation -> Linear.

    Dropout with probability ``drop`` follows the activation and the second
    linear layer.

    Args:
        in_features (int): Size of the incoming feature dimension.
        hidden_features (int, optional): Hidden width; a falsy value falls
            back to ``in_features``.
        out_features (int, optional): Output width; a falsy value falls back
            to ``in_features``.
        act_layer: Zero-argument factory for the activation. Default: nn.GELU
        drop (float): Dropout probability. Default: 0.0
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply the stages in the fixed order fc1, act, drop, fc2, drop."""
        for stage in (self.fc1, self.act, self.drop, self.fc2, self.drop):
            x = stage(x)
        return x
def window_partition(x, window_size):
    """Split a (B, H, W, C) tensor into square, non-overlapping windows.

    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    rows, cols = H // window_size, W // window_size
    blocks = x.view(B, rows, window_size, cols, window_size, C)
    # (B, rows, cols, ws, ws, C): one contiguous slab per window
    blocks = blocks.permute(0, 1, 3, 2, 4, 5).contiguous()
    return blocks.view(-1, window_size, window_size, C)


def window_reverse(windows, window_size, H, W):
    """Inverse of ``window_partition``: stitch windows back into feature maps.

    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    per_image = H * W / window_size / window_size
    B = int(windows.shape[0] / per_image)
    blocks = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    blocks = blocks.permute(0, 1, 3, 2, 4, 5).contiguous()
    return blocks.view(B, H, W, -1)
class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    This variant uses cosine attention with a learnable per-head logit scale and
    produces the relative position bias with a small MLP over log-spaced
    relative coordinates.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query and value
            (the key part of the bias is fixed to zero). Default: True
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
        pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
    """

    # NOTE(review): `pretrained_window_size` has a mutable list default that is
    # stored on the instance; it is only read in this class.
    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
                 pretrained_window_size=[0, 0]):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.pretrained_window_size = pretrained_window_size
        self.num_heads = num_heads

        # learnable per-head scale of the cosine logits, initialised to log(10)
        self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)

        # mlp to generate continuous relative position bias
        self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
                                     nn.ReLU(inplace=True),
                                     nn.Linear(512, num_heads, bias=False))

        # build the table of relative coordinates that the bias MLP is evaluated on
        relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
        relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
        relative_coords_table = torch.stack(
            torch.meshgrid([relative_coords_h,
                            relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0)  # 1, 2*Wh-1, 2*Ww-1, 2
        # NOTE(review): both branches divide by (size - 1), which is zero for a
        # window (or pretrained window) of size 1 - confirm that case cannot occur.
        if pretrained_window_size[0] > 0:
            relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
            relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
        else:
            relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
            relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
        relative_coords_table *= 8  # normalize to -8, 8
        # signed log2 spacing, divided by log2(8)
        relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
            torch.abs(relative_coords_table) + 1.0) / np.log2(8)

        self.register_buffer("relative_coords_table", relative_coords_table)

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        # scale the row offset so that (row, col) pairs map to unique flat indices
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        # the projection itself is bias-free; q/v biases are separate parameters
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(dim))
            self.v_bias = nn.Parameter(torch.zeros(dim))
        else:
            self.q_bias = None
            self.v_bias = None
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: additive mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None;
                it is added to the attention logits before the softmax

        Returns:
            Tensor with the same (num_windows*B, N, C) shape as ``x``.
        """
        B_, N, C = x.shape
        qkv_bias = None
        if self.q_bias is not None:
            # assemble (q_bias, zeros for k, v_bias) into one bias vector
            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        # cosine attention
        attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
        # clamp the log-scale at log(1 / 0.01) before exponentiating
        logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01))).exp()
        attn = attn * logit_scale

        # evaluate the bias MLP, then look up the value for every token pair by index
        relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
        relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        # squash the bias into the range (0, 16)
        relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
        attn = attn + relative_position_bias.unsqueeze(0)

        # a mask is supplied for shifted-window attention (SW-MSA): add it to the logits
        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        # weight the values with the attention probabilities and merge the heads
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        """Return the summary shown inside ``repr(module)``."""
        return f'dim={self.dim}, window_size={self.window_size}, ' \
               f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        """Return the multiply count of this module for one window of ``N`` tokens."""
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        #  x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops
class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.

    One residual block made of (shifted) window attention followed by an MLP;
    the normalisation is applied to each branch output before it is added
    back to the residual stream.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
        pretrained_window_size (int): Window size in pre-training.
    """

    def __init__(self,
                 dim,
                 input_resolution,
                 num_heads,
                 window_size=7,
                 shift_size=0,
                 mlp_ratio=4.,
                 qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 pretrained_window_size=0):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio

        shortest_side = min(self.input_resolution)
        if shortest_side <= self.window_size:
            # The window would cover the whole feature map: use a single,
            # unshifted window of exactly that size.
            self.shift_size = 0
            self.window_size = shortest_side
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)

        # Window attention, used both for W-MSA and SW-MSA.
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
            pretrained_window_size=to_2tuple(pretrained_window_size))

        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)

        # The mask is fixed for a given resolution, so it is built once and kept as a buffer.
        self.register_buffer("attn_mask", self._build_attn_mask())

    def _build_attn_mask(self):
        """Return the SW-MSA attention mask, or None when the block is not shifted."""
        if not self.shift_size > 0:
            return None

        H, W = self.input_resolution
        region_map = torch.zeros((1, H, W, 1))  # 1 H W 1
        # The same three bands are used along the height and the width.
        bands = (slice(0, -self.window_size),
                 slice(-self.window_size, -self.shift_size),
                 slice(-self.shift_size, None))
        band_pairs = ((rows, cols) for rows in bands for cols in bands)
        for region_id, (rows, cols) in enumerate(band_pairs):
            region_map[:, rows, cols, :] = region_id

        per_window = window_partition(region_map, self.window_size)  # nW, window_size, window_size, 1
        per_window = per_window.view(-1, self.window_size * self.window_size)
        # Token pairs coming from different regions get a large negative logit.
        difference = per_window.unsqueeze(1) - per_window.unsqueeze(2)
        blocked = difference.masked_fill(difference != 0, float(-100.0))
        return blocked.masked_fill(difference == 0, float(0.0))

    def forward(self, x):
        """Run the block on tokens of shape (B, H*W, C) and return the same shape."""
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        residual = x
        size = self.window_size
        shift = self.shift_size
        grid = x.view(B, H, W, C)

        # Cyclic shift, only for SW-MSA blocks.
        if shift > 0:
            grid = torch.roll(grid, shifts=(-shift, -shift), dims=(1, 2))

        # Cut the map into windows: nW*B, window_size*window_size, C
        windows = window_partition(grid, size).view(-1, size * size, C)

        # W-MSA / SW-MSA (attn_mask is None for unshifted blocks).
        windows = self.attn(windows, mask=self.attn_mask)

        # Merge the windows back into a map: B H' W' C
        grid = window_reverse(windows.view(-1, size, size, C), size, H, W)

        # Undo the cyclic shift.
        if shift > 0:
            grid = torch.roll(grid, shifts=(shift, shift), dims=(1, 2))

        x = residual + self.drop_path(self.norm1(grid.view(B, H * W, C)))

        # FFN
        return x + self.drop_path(self.norm2(self.mlp(x)))

    def extra_repr(self) -> str:
        """Return the summary shown inside the module's ``repr``."""
        head = f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
        tail = f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
        return head + tail

    def flops(self):
        """Estimate the operation count of one forward pass of the block."""
        H, W = self.input_resolution
        area = self.window_size * self.window_size
        # norm1
        total = self.dim * H * W
        # W-MSA/SW-MSA
        num_windows = H * W / self.window_size / self.window_size
        total += num_windows * self.attn.flops(area)
        # mlp
        total += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        total += self.dim * H * W
        return total
class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Gathers every 2x2 neighbourhood of tokens into one token (C -> 4*C),
    projects it to 2*C with a bias-free linear layer and normalises the
    result, halving the spatial resolution.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(2 * dim)

    def forward(self, x):
        """
        x: B, H*W, C

        Returns a tensor of shape B, H/2*W/2, 2*C.
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        grid = x.view(B, H, W, C)

        # Four interleaved sub-grids, each B H/2 W/2 C, in the order
        # (even row, even col), (odd, even), (even, odd), (odd, odd).
        corners = [grid[:, row::2, col::2, :] for col in (0, 1) for row in (0, 1)]
        merged = torch.cat(corners, -1).view(B, -1, 4 * C)  # B H/2*W/2 4*C

        # Project 4*C -> 2*C, then normalise.
        return self.norm(self.reduction(merged))

    def extra_repr(self) -> str:
        """Return the summary shown inside the module's ``repr``."""
        return f"input_resolution={self.input_resolution}, dim={self.dim}"

    def flops(self):
        """Estimate the operation count of the projection and the normalisation."""
        H, W = self.input_resolution
        projection = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
        normalisation = H * W * self.dim // 2
        return projection + normalisation
class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.

    Stacks ``depth`` SwinTransformerBlocks that alternate between regular
    (shift 0) and shifted (shift ``window_size // 2``) windows, optionally
    followed by a downsampling layer.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | list[float] | tuple[float], optional): Stochastic depth rate, either one
            value for all blocks or one value per block. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        pretrained_window_size (int): Local window size in pre-training.
    """

    def __init__(self,
                 dim,
                 input_resolution,
                 depth, num_heads,
                 window_size,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 drop=0., attn_drop=0., drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 downsample=None, use_checkpoint=False,
                 pretrained_window_size=0):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # A per-block schedule may arrive as a list or, as the docstring allows, a tuple.
        per_block_drop_path = isinstance(drop_path, (list, tuple))

        # Even blocks use regular windows (W-MSA), odd blocks shifted windows (SW-MSA).
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim,
                                 input_resolution=input_resolution,
                                 num_heads=num_heads,
                                 window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if per_block_drop_path else drop_path,
                                 norm_layer=norm_layer,
                                 pretrained_window_size=pretrained_window_size)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        """Run all blocks (optionally checkpointed), then the downsample layer if present."""
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

    def extra_repr(self) -> str:
        """Return the summary shown inside the module's ``repr``."""
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"

    def flops(self):
        """Sum the operation counts of the blocks and of the downsample layer."""
        flops = sum(blk.flops() for blk in self.blocks)
        if self.downsample is not None:
            flops += self.downsample.flops()
        return flops

    def _init_respostnorm(self):
        """Zero the weight and bias of both normalisation layers of every block."""
        for blk in self.blocks:
            nn.init.constant_(blk.norm1.bias, 0)
            nn.init.constant_(blk.norm1.weight, 0)
            nn.init.constant_(blk.norm2.bias, 0)
            nn.init.constant_(blk.norm2.weight, 0)
class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding

    Cuts the image into non-overlapping patches with a strided convolution
    and returns them as a token sequence.

    Args:
        img_size (int): Image size.  Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)
        # Number of patches along the height and the width.
        self.patches_resolution = [side // patch for side, patch in zip(self.img_size, self.patch_size)]
        self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        # Kernel size equal to the stride: every patch is projected independently.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size)
        self.norm = None if norm_layer is None else norm_layer(embed_dim)

    def forward(self, x):
        """Embed an image batch of shape (B, C, H, W) into tokens of shape (B, Ph*Pw, embed_dim)."""
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        expected_h, expected_w = self.img_size[0], self.img_size[1]
        assert H == expected_h and W == expected_w, \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        tokens = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is None:
            return tokens
        return self.norm(tokens)

    def flops(self):
        """Estimate the operation count of the projection and of the optional norm."""
        Ho, Wo = self.patches_resolution
        total = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            total += Ho * Wo * self.embed_dim
        return total
class Decoder_block(nn.Module):
    """Convolutional decoder that enlarges a feature map 32 times.

    Four stages each double the spatial size with a stride-2 transposed
    convolution that halves the channel count, then refine the result with
    two 3x3 convolution + BatchNorm + ReLU layers. A last stride-2 transposed
    convolution maps to ``out_channels``.

    The defaults reproduce the original fixed layout
    (768 -> 384 -> 192 -> 96 -> 48 -> 16); attribute names, and therefore
    ``state_dict`` keys, are unchanged.

    Args:
        in_channels (int): Channels of the input feature map. Must be a positive
            multiple of 16 because it is halved four times. Default: 768
        out_channels (int): Channels of the returned feature map. Default: 16

    Raises:
        ValueError: If ``in_channels`` is not a positive multiple of 16.
    """

    def __init__(self, in_channels=768, out_channels=16):
        super().__init__()
        if in_channels <= 0 or in_channels % 16 != 0:
            raise ValueError(f"in_channels must be a positive multiple of 16, got {in_channels}")

        width = in_channels
        # Stages are numbered 4..1 from the coarsest to the finest resolution.
        for stage in (4, 3, 2, 1):
            setattr(self, f"upsample_{stage}", nn.Sequential(
                nn.ConvTranspose2d(in_channels=width, out_channels=width // 2, kernel_size=4, stride=2, padding=1)
            ))
            width //= 2
            setattr(self, f"stage_up_{stage}", self._refinement(width))

        self.conT1 = nn.ConvTranspose2d(width, out_channels, kernel_size=2, stride=2)

    @staticmethod
    def _refinement(channels):
        """Build two (3x3 conv, BatchNorm, ReLU) groups that keep ``channels`` unchanged."""
        layers = []
        for _ in range(2):
            layers += [
                nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(channels),
                nn.ReLU(),
            ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Decode ``x`` of shape (B, in_channels, H, W) to (B, out_channels, 32*H, 32*W)."""
        x = self.upsample_4(x)
        x = self.stage_up_4(x)
        x = self.upsample_3(x)
        x = self.stage_up_3(x)
        x = self.upsample_2(x)
        x = self.stage_up_2(x)
        x = self.upsample_1(x)
        x = self.stage_up_1(x)

        x = self.conT1(x)
        return x
class SegmentationHead(nn.Sequential):
    """Convolution that maps features to per-class logits, with optional upsampling.

    Args:
        in_channels (int): Channels of the incoming feature map.
        out_channels (int): Number of output channels.
        kernel_size (int): Kernel size of the convolution; the padding is
            ``kernel_size // 2``. Default: 3
        upsampling (int | float): Bilinear scale factor applied after the
            convolution when greater than 1; otherwise no resizing is done. Default: 1
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):
        projection = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
        if upsampling > 1:
            resize = nn.UpsamplingBilinear2d(scale_factor=upsampling)
        else:
            resize = nn.Identity()
        super().__init__(projection, resize)
class SwinTransformerV2(nn.Module):
    r""" Swin Transformer encoder followed by a convolutional decoder and a segmentation head.
        Encoder after: `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -
          https://arxiv.org/pdf/2103.14030

    Args:
        img_size (int | tuple(int)): Input image size. Default 256
        patch_size (int | tuple(int)): Patch size. Default: 4
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Size of the ``head`` linear layer. ``forward`` does not use that
            layer; the segmentation head always outputs 6 channels. Default: 12
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 8
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
        pretrained_window_sizes (tuple(int)): Pretrained window sizes of each layer.
    """

    def __init__(self, img_size=256,
                 patch_size=4,
                 in_chans=3,
                 num_classes=12,
                 embed_dim=96,
                 depths=[2, 2, 6, 2],
                 num_heads=[3, 6, 12, 24],
                 window_size=8,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm,
                 ape=False,
                 patch_norm=True,
                 use_checkpoint=False,
                 pretrained_window_sizes=[0, 0, 0, 0],
                 **kwargs):
        super().__init__()

        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        # Channel count after the last stage: it doubles at every patch merging.
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio

        # Split the image into non-overlapping patches and embed each one.
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)

        # Total number of patches and the patch grid (rows, columns).
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # Optional learnable absolute position embedding.
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        # Dropout applied to the embedded tokens.
        self.pos_drop = nn.Dropout(p=drop_rate)

        # Stochastic depth: evenly spaced rates from 0 to drop_path_rate, one per block.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # One BasicLayer per stage; the number of stages is len(depths).
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),  # channel count of this stage
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=depths[i_layer],  # number of blocks in this stage
                               num_heads=num_heads[i_layer],
                               window_size=window_size,
                               mlp_ratio=self.mlp_ratio,
                               qkv_bias=qkv_bias,
                               drop=drop_rate, attn_drop=attn_drop_rate,
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               norm_layer=norm_layer,
                               # Every stage but the last ends with a patch merging layer.
                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                               use_checkpoint=use_checkpoint,
                               pretrained_window_size=pretrained_window_sizes[i_layer])
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features)
        # avgpool and head are created but not used by forward().
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        # NOTE(review): Decoder_block() is built with its fixed default layout, so the decoder
        # still assumes num_features == 768 (embed_dim=96 with four stages) - confirm before
        # using other widths.
        self.Decoder_block = Decoder_block()
        # NOTE(review): out_channels is fixed to 6 and ignores num_classes - confirm that this
        # is intended before relying on num_classes.
        self.segmentation_head = SegmentationHead(
            in_channels=16,
            out_channels=6,
            kernel_size=3,
        )

        self.apply(self._init_weights)
        for bly in self.layers:
            bly._init_respostnorm()

    def _init_weights(self, m):
        """Initialise Linear layers (truncated normal, zero bias) and LayerNorm (weight 1, bias 0)."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the parameter names excluded from weight decay."""
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        """Return the name fragments of parameters excluded from weight decay."""
        return {"cpb_mlp", "logit_scale", 'relative_position_bias_table'}

    def forward_features(self, x):
        """Encode an image batch and decode it to the 16-channel decoder feature map."""
        B = x.shape[0]
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)
        # The last stage has no downsampling, so its input resolution is also the token grid
        # of its output. Deriving the shape from the model instead of hard-coding
        # (8, 8, 768) keeps the reshape valid for other image sizes and widths.
        H, W = self.layers[-1].input_resolution
        x = x.contiguous().view(B, H, W, self.num_features)
        x = rearrange(x, 'b h w c-> b c h w')

        x = self.Decoder_block(x)

        return x

    def forward(self, x):
        """Return the segmentation logits for an image batch."""
        x = self.forward_features(x)
        logits = self.segmentation_head(x)
        return logits

    def flops(self):
        """Estimate the operation count of the encoder and the (unused) linear head."""
        flops = 0
        flops += self.patch_embed.flops()
        for layer in self.layers:
            flops += layer.flops()
        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
        flops += self.num_features * self.num_classes
        return flops
def get_r50_b16_config():
    """Return the ResNet-50 + ViT-B/16 configuration.

    Starts from ``get_b16_config()``, adds the patch grid and the ResNet
    settings, then overrides the segmentation-specific fields.
    """
    config = get_b16_config()
    config.patches.grid = (16, 16)

    resnet = ml_collections.ConfigDict()
    resnet.num_layers = (3, 4, 9)
    resnet.width_factor = 1
    config.resnet = resnet

    overrides = {
        'classifier': 'seg',
        'pretrained_path': '../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz',
        'decoder_channels': (256, 128, 64, 16),
        'skip_channels': [512, 256, 64, 16],
        'n_classes': 6,
        'n_skip': 3,
        'activation': 'softmax',
    }
    for field, value in overrides.items():
        setattr(config, field, value)

    return config
def get_l16_config():
    """Return the ViT-L/16 configuration (hidden size 1024, 24 layers, 16 heads)."""
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (16, 16)})
    config.hidden_size = 1024

    transformer = ml_collections.ConfigDict()
    transformer.mlp_dim = 4096
    transformer.num_heads = 16
    transformer.num_layers = 24
    transformer.attention_dropout_rate = 0.0
    transformer.dropout_rate = 0.1
    config.transformer = transformer

    config.representation_size = None

    # custom
    custom_fields = {
        'classifier': 'seg',
        'resnet_pretrained_path': None,
        'pretrained_path': '../model/vit_checkpoint/imagenet21k/ViT-L_16.npz',
        'decoder_channels': (256, 128, 64, 16),
        'n_classes': 2,
        'activation': 'softmax',
    }
    for field, value in custom_fields.items():
        setattr(config, field, value)
    return config
def get_h14_config():
    """Return the ViT-H/14 configuration (14x14 patches, hidden size 1280, 32 layers)."""
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (14, 14)})
    config.hidden_size = 1280

    transformer = ml_collections.ConfigDict()
    transformer.mlp_dim = 5120
    transformer.num_heads = 16
    transformer.num_layers = 32
    transformer.attention_dropout_rate = 0.0
    transformer.dropout_rate = 0.1
    config.transformer = transformer

    config.classifier = 'token'
    config.representation_size = None

    return config
from model.TransUnet import vit_seg_configs as configs
from model.TransUnet.vit_seg_modeling_resnet_skip import ResNetV2


logger = logging.getLogger(__name__)


# Sub-keys of one encoder block inside the pretrained checkpoint.
# Block.load_from joins them as
# "Transformer/encoderblock_<n>/<sub-key>/<kernel|bias|scale>".
ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
ATTENTION_K = "MultiHeadDotProductAttention_1/key"
ATTENTION_V = "MultiHeadDotProductAttention_1/value"
ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"
FC_0 = "MlpBlock_3/Dense_0"
FC_1 = "MlpBlock_3/Dense_1"
ATTENTION_NORM = "LayerNorm_0"
MLP_NORM = "LayerNorm_2"


def np2th(weights, conv=False):
    """Convert a NumPy array to a torch tensor.

    With ``conv=True`` the axes are first permuted with ``[3, 2, 0, 1]``
    (HWIO -> OIHW convolution-kernel layout).
    """
    if conv:
        weights = weights.transpose([3, 2, 0, 1])
    return torch.from_numpy(weights)


def swish(x):
    """Return ``x * sigmoid(x)``."""
    return x * torch.sigmoid(x)


ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}


class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    ``forward`` returns ``(output, weights)``; ``weights`` holds the softmax
    attention probabilities when ``vis`` is truthy and is ``None`` otherwise.
    """

    def __init__(self, config, vis):
        super(Attention, self).__init__()
        self.vis = vis
        self.num_attention_heads = config.transformer["num_heads"]
        self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = Linear(config.hidden_size, self.all_head_size)
        self.key = Linear(config.hidden_size, self.all_head_size)
        self.value = Linear(config.hidden_size, self.all_head_size)

        self.out = Linear(config.hidden_size, config.hidden_size)
        self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
        self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])

        self.softmax = Softmax(dim=-1)

    def transpose_for_scores(self, x):
        """Split the last dimension into heads and move the head axis to dim 1."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = self.softmax(attention_scores)
        # The probabilities are captured before dropout is applied.
        weights = attention_probs if self.vis else None
        attention_probs = self.attn_dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        # Merge the heads back into a single feature dimension.
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        attention_output = self.out(context_layer)
        attention_output = self.proj_dropout(attention_output)
        return attention_output, weights


class Mlp(nn.Module):
    """Two-layer feed-forward network: Linear -> GELU -> Dropout -> Linear -> Dropout."""

    def __init__(self, config):
        super(Mlp, self).__init__()
        self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"])
        self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size)
        self.act_fn = ACT2FN["gelu"]
        self.dropout = Dropout(config.transformer["dropout_rate"])

        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform weights, near-zero (std=1e-6) normal biases."""
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.normal_(self.fc1.bias, std=1e-6)
        nn.init.normal_(self.fc2.bias, std=1e-6)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x


class Embeddings(nn.Module):
    """Construct the embeddings from patch, position embeddings.

    When ``config.patches`` has a ``grid`` entry the module is "hybrid": the
    input first goes through a ``ResNetV2`` and the patch convolution runs on
    its output. Otherwise the patch convolution runs on the image directly.
    ``forward`` returns ``(embeddings, features)``; ``features`` is ``None``
    in the non-hybrid case.
    """

    def __init__(self, config, img_size, in_channels=3):
        super(Embeddings, self).__init__()
        self.hybrid = None
        self.config = config
        img_size = _pair(img_size)

        if config.patches.get("grid") is not None:  # ResNet
            grid_size = config.patches["grid"]
            # Patch size measured on the ResNet output (image size // 16).
            patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1])
            patch_size_real = (patch_size[0] * 16, patch_size[1] * 16)
            n_patches = (img_size[0] // patch_size_real[0]) * (img_size[1] // patch_size_real[1])
            self.hybrid = True
        else:
            patch_size = _pair(config.patches["size"])
            n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
            self.hybrid = False

        if self.hybrid:
            self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers,
                                         width_factor=config.resnet.width_factor)
            in_channels = self.hybrid_model.width * 16
        self.patch_embeddings = Conv2d(in_channels=in_channels,
                                       out_channels=config.hidden_size,
                                       kernel_size=patch_size,
                                       stride=patch_size)
        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, config.hidden_size))

        self.dropout = Dropout(config.transformer["dropout_rate"])

    def forward(self, x):
        if self.hybrid:
            x, features = self.hybrid_model(x)
        else:
            features = None
        x = self.patch_embeddings(x)  # (B, hidden, n_patches^(1/2), n_patches^(1/2))
        x = x.flatten(2)
        x = x.transpose(-1, -2)  # (B, n_patches, hidden)

        embeddings = x + self.position_embeddings
        embeddings = self.dropout(embeddings)
        return embeddings, features


class Block(nn.Module):
    """Pre-norm transformer block: LayerNorm -> attention and LayerNorm -> MLP, each with a residual."""

    def __init__(self, config, vis):
        super(Block, self).__init__()
        self.hidden_size = config.hidden_size
        self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn = Mlp(config)
        self.attn = Attention(config, vis)

    def forward(self, x):
        h = x
        x = self.attention_norm(x)
        x, weights = self.attn(x)
        x = x + h

        h = x
        x = self.ffn_norm(x)
        x = self.ffn(x)
        x = x + h
        return x, weights

    def load_from(self, weights, n_block):
        """Copy this block's parameters from checkpoint mapping ``weights``.

        ``n_block`` is the block index used in the checkpoint key
        ``Transformer/encoderblock_<n_block>``.
        """
        ROOT = f"Transformer/encoderblock_{n_block}"
        with torch.no_grad():
            # Kernels are reshaped to (hidden, hidden) and transposed before
            # being copied into the nn.Linear weights.
            query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t()

            query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1)
            key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1)
            value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1)
            out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1)

            self.attn.query.weight.copy_(query_weight)
            self.attn.key.weight.copy_(key_weight)
            self.attn.value.weight.copy_(value_weight)
            self.attn.out.weight.copy_(out_weight)
            self.attn.query.bias.copy_(query_bias)
            self.attn.key.bias.copy_(key_bias)
            self.attn.value.bias.copy_(value_bias)
            self.attn.out.bias.copy_(out_bias)

            mlp_weight_0 = np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t()
            mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t()
            mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t()
            mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t()

            self.ffn.fc1.weight.copy_(mlp_weight_0)
            self.ffn.fc2.weight.copy_(mlp_weight_1)
            self.ffn.fc1.bias.copy_(mlp_bias_0)
            self.ffn.fc2.bias.copy_(mlp_bias_1)

            self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")]))
            self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")]))
            self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")]))
            self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")]))


class Encoder(nn.Module):
    """Stack of ``config.transformer["num_layers"]`` Blocks followed by a LayerNorm.

    ``forward`` returns ``(encoded, attn_weights)``; ``attn_weights`` is a
    list with one entry per block when ``vis`` is truthy, else an empty list.
    """

    def __init__(self, config, vis):
        super(Encoder, self).__init__()
        self.vis = vis
        self.layer = nn.ModuleList()
        self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)
        for _ in range(config.transformer["num_layers"]):
            # Each Block is freshly constructed, so no extra copy is needed.
            self.layer.append(Block(config, vis))

    def forward(self, hidden_states):
        attn_weights = []
        for layer_block in self.layer:
            hidden_states, weights = layer_block(hidden_states)
            if self.vis:
                attn_weights.append(weights)
        encoded = self.encoder_norm(hidden_states)
        return encoded, attn_weights


class Transformer(nn.Module):
    """Embeddings followed by the Encoder; returns ``(encoded, attn_weights, features)``."""

    def __init__(self, config, img_size, vis):
        super(Transformer, self).__init__()
        self.embeddings = Embeddings(config, img_size=img_size)
        self.encoder = Encoder(config, vis)

    def forward(self, input_ids):
        embedding_output, features = self.embeddings(input_ids)
        encoded, attn_weights = self.encoder(embedding_output)  # (B, n_patch, hidden)
        return encoded, attn_weights, features


class Conv2dReLU(nn.Sequential):
    """Conv2d -> BatchNorm2d -> ReLU.

    NOTE: ``use_batchnorm`` only controls whether the convolution has a bias
    (``bias=not use_batchnorm``); the BatchNorm2d layer is always part of the
    sequence.
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            padding=0,
            stride=1,
            use_batchnorm=True,
    ):
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=not (use_batchnorm),
        )
        relu = nn.ReLU(inplace=True)

        bn = nn.BatchNorm2d(out_channels)

        super(Conv2dReLU, self).__init__(conv, bn, relu)


class DecoderBlock(nn.Module):
    """2x bilinear upsampling, optional skip concatenation, then two 3x3 Conv2dReLU layers."""

    def __init__(
            self,
            in_channels,
            out_channels,
            skip_channels=0,
            use_batchnorm=True,
    ):
        super().__init__()
        self.conv1 = Conv2dReLU(
            in_channels + skip_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.conv2 = Conv2dReLU(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.up = nn.UpsamplingBilinear2d(scale_factor=2)

    def forward(self, x, skip=None):
        x = self.up(x)
        if skip is not None:
            x = torch.cat([x, skip], dim=1)
        x = self.conv1(x)
        x = self.conv2(x)
        return x


class SegmentationHead(nn.Sequential):
    """Same-padded Conv2d, followed by bilinear upsampling when ``upsampling > 1``."""

    def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):
        conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
        upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity()
        super().__init__(conv2d, upsampling)


class DecoderCup(nn.Module):
    """Cascaded upsampling decoder over the encoder tokens.

    Reads ``hidden_size``, ``decoder_channels``, ``n_skip`` and (when
    ``n_skip != 0``) ``skip_channels`` from ``config``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        head_channels = 512
        self.conv_more = Conv2dReLU(
            config.hidden_size,
            head_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=True,
        )
        decoder_channels = config.decoder_channels
        in_channels = [head_channels] + list(decoder_channels[:-1])
        out_channels = decoder_channels

        if self.config.n_skip != 0:
            # Work on a copy: zeroing entries on the config's own list would
            # leak into every other model built from the same config object.
            skip_channels = list(self.config.skip_channels)
            for i in range(4 - self.config.n_skip):  # re-select the skip channels according to n_skip
                skip_channels[3 - i] = 0
        else:
            skip_channels = [0, 0, 0, 0]

        blocks = [
            DecoderBlock(in_ch, out_ch, sk_ch)
            for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels)
        ]
        self.blocks = nn.ModuleList(blocks)

    def forward(self, hidden_states, features=None):
        B, n_patch, hidden = hidden_states.size()  # reshape from (B, n_patch, hidden) to (B, h, w, hidden)
        # The token sequence is laid out as a square grid.
        h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))
        x = hidden_states.permute(0, 2, 1)
        x = x.contiguous().view(B, hidden, h, w)
        x = self.conv_more(x)
        for i, decoder_block in enumerate(self.blocks):
            if features is not None:
                skip = features[i] if (i < self.config.n_skip) else None
            else:
                skip = None
            x = decoder_block(x, skip=skip)
        return x


class VisionTransformer(nn.Module):
    """Transformer encoder + DecoderCup + SegmentationHead; ``forward`` returns logits.

    NOTE: ``num_classes`` is only stored on the instance; the number of output
    channels of the segmentation head is taken from ``config['n_classes']``.
    """

    def __init__(self, config, img_size=256, num_classes=6, zero_head=False, vis=False):
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.zero_head = zero_head
        self.classifier = config.classifier
        self.transformer = Transformer(config, img_size, vis)
        self.decoder = DecoderCup(config)
        self.segmentation_head = SegmentationHead(
            in_channels=config['decoder_channels'][-1],
            out_channels=config['n_classes'],
            kernel_size=3,
        )
        self.config = config

    def forward(self, x):
        # Single-channel input is repeated to three channels.
        if x.size()[1] == 1:
            x = x.repeat(1, 3, 1, 1)
        x, attn_weights, features = self.transformer(x)  # (B, n_patch, hidden)
        x = self.decoder(x, features)
        logits = self.segmentation_head(x)
        return logits

    def load_from(self, weights):
        """Copy pretrained parameters from checkpoint mapping ``weights`` into the model."""
        with torch.no_grad():
            embeddings = self.transformer.embeddings
            encoder = self.transformer.encoder

            embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True))
            embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"]))

            encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"]))
            encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"]))

            posemb = np2th(weights["Transformer/posembed_input/pos_embedding"])

            posemb_new = embeddings.position_embeddings
            if posemb.size() == posemb_new.size():
                posemb_new.copy_(posemb)
            elif posemb.size()[1] - 1 == posemb_new.size()[1]:
                # Checkpoint has one extra leading token; drop it.
                posemb_new.copy_(posemb[:, 1:])
            else:
                logger.info("load_pretrained: resized variant: %s to %s", posemb.size(), posemb_new.size())
                ntok_new = posemb_new.size(1)
                # Embeddings allocates exactly n_patches position tokens, so
                # the leading checkpoint token is dropped for every classifier
                # type (previously this branch only worked for "seg").
                posemb_grid = posemb[0, 1:]
                gs_old = int(np.sqrt(len(posemb_grid)))
                gs_new = int(np.sqrt(ntok_new))
                print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new))
                posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)
                zoom = (gs_new / gs_old, gs_new / gs_old, 1)
                # Linear interpolation of the position grid, done in NumPy.
                posemb_grid = ndimage.zoom(posemb_grid.numpy(), zoom, order=1)
                posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1)
                posemb_new.copy_(np2th(posemb_grid))

            # Encoder whole
            for bname, block in encoder.named_children():
                for uname, unit in block.named_children():
                    unit.load_from(weights, n_block=uname)

            if embeddings.hybrid:
                hybrid_model = embeddings.hybrid_model
                hybrid_model.root.conv.weight.copy_(np2th(weights["conv_root/kernel"], conv=True))
                gn_weight = np2th(weights["gn_root/scale"]).view(-1)
                gn_bias = np2th(weights["gn_root/bias"]).view(-1)
                hybrid_model.root.gn.weight.copy_(gn_weight)
                hybrid_model.root.gn.bias.copy_(gn_bias)

                for bname, block in hybrid_model.body.named_children():
                    for uname, unit in block.named_children():
                        unit.load_from(weights, n_block=bname, n_unit=uname)


CONFIGS = {
    'ViT-B_16': configs.get_b16_config(),
    'ViT-B_32': configs.get_b32_config(),
    'ViT-L_16': configs.get_l16_config(),
    'ViT-L_32': configs.get_l32_config(),
    'ViT-H_14': configs.get_h14_config(),
    'R50-ViT-B_16': configs.get_r50_b16_config(),
    'R50-ViT-L_16': configs.get_r50_l16_config(),
    'testing': configs.get_testing(),
}


if __name__ == '__main__':
    config_vit = configs.get_r50_b16_config()
    model = VisionTransformer(config_vit, img_size=256, num_classes=6)

    image = torch.randn(32, 3, 256, 256)

    output = model(image)
    print("input:", image.shape)
    print("output:", output.shape)

# --------------------------------------------------------------------------------
# /model/TransUnet/vit_seg_modeling_resnet_skip.py:
# --------------------------------------------------------------------------------
import math

from os.path import join as pjoin
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F


def np2th(weights, conv=False):
    """Possibly convert HWIO to OIHW, then wrap the NumPy array as a torch tensor."""
    if conv:
        weights = weights.transpose([3, 2, 0, 1])
    return torch.from_numpy(weights)

class StdConv2d(nn.Conv2d):
    """Conv2d that standardizes its kernel (per output channel) on every forward pass."""

    def forward(self, x):
        w = self.weight
        # Mean/variance over (in_channels, kH, kW) of each output filter.
        v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
        w = (w - m) / torch.sqrt(v + 1e-5)
        return F.conv2d(x, w, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)


def conv3x3(cin, cout, stride=1, groups=1, bias=False):
    """Return a 3x3 StdConv2d with padding 1."""
    return StdConv2d(cin, cout, kernel_size=3, stride=stride,
                     padding=1, bias=bias, groups=groups)


def conv1x1(cin, cout, stride=1, bias=False):
    """Return a 1x1 StdConv2d with padding 0."""
    return StdConv2d(cin, cout, kernel_size=1, stride=stride,
                     padding=0, bias=bias)


class PreActBottleneck(nn.Module):
    """Pre-activation (v2) bottleneck block.

    A projection shortcut (1x1 conv + GroupNorm) is created only when
    ``stride != 1`` or ``cin != cout``.
    """

    def __init__(self, cin, cout=None, cmid=None, stride=1):
        super().__init__()
        cout = cout or cin
        cmid = cmid or cout // 4

        self.gn1 = nn.GroupNorm(32, cmid, eps=1e-6)
        self.conv1 = conv1x1(cin, cmid, bias=False)
        self.gn2 = nn.GroupNorm(32, cmid, eps=1e-6)
        self.conv2 = conv3x3(cmid, cmid, stride, bias=False)  # Original code has it on conv1!!
        self.gn3 = nn.GroupNorm(32, cout, eps=1e-6)
        self.conv3 = conv1x1(cmid, cout, bias=False)
        self.relu = nn.ReLU(inplace=True)

        if (stride != 1 or cin != cout):
            # Projection also with pre-activation according to paper.
            self.downsample = conv1x1(cin, cout, stride, bias=False)
            self.gn_proj = nn.GroupNorm(cout, cout)

    def forward(self, x):

        # Residual branch
        residual = x
        if hasattr(self, 'downsample'):
            residual = self.downsample(x)
            residual = self.gn_proj(residual)

        # Unit's branch
        y = self.relu(self.gn1(self.conv1(x)))
        y = self.relu(self.gn2(self.conv2(y)))
        y = self.gn3(self.conv3(y))

        y = self.relu(residual + y)
        return y

    def load_from(self, weights, n_block, n_unit):
        """Copy this unit's parameters from checkpoint mapping ``weights``.

        Keys are built as ``<n_block>/<n_unit>/<param>``.
        """
        conv1_weight = np2th(weights[pjoin(n_block, n_unit, "conv1/kernel")], conv=True)
        conv2_weight = np2th(weights[pjoin(n_block, n_unit, "conv2/kernel")], conv=True)
        conv3_weight = np2th(weights[pjoin(n_block, n_unit, "conv3/kernel")], conv=True)

        gn1_weight = np2th(weights[pjoin(n_block, n_unit, "gn1/scale")])
        gn1_bias = np2th(weights[pjoin(n_block, n_unit, "gn1/bias")])

        gn2_weight = np2th(weights[pjoin(n_block, n_unit, "gn2/scale")])
        gn2_bias = np2th(weights[pjoin(n_block, n_unit, "gn2/bias")])

        gn3_weight = np2th(weights[pjoin(n_block, n_unit, "gn3/scale")])
        gn3_bias = np2th(weights[pjoin(n_block, n_unit, "gn3/bias")])

        # In-place copies into parameters must not be tracked by autograd;
        # the guard makes the method safe to call on its own.
        with torch.no_grad():
            self.conv1.weight.copy_(conv1_weight)
            self.conv2.weight.copy_(conv2_weight)
            self.conv3.weight.copy_(conv3_weight)

            self.gn1.weight.copy_(gn1_weight.view(-1))
            self.gn1.bias.copy_(gn1_bias.view(-1))

            self.gn2.weight.copy_(gn2_weight.view(-1))
            self.gn2.bias.copy_(gn2_bias.view(-1))

            self.gn3.weight.copy_(gn3_weight.view(-1))
            self.gn3.bias.copy_(gn3_bias.view(-1))

            if hasattr(self, 'downsample'):
                proj_conv_weight = np2th(weights[pjoin(n_block, n_unit, "conv_proj/kernel")], conv=True)
                proj_gn_weight = np2th(weights[pjoin(n_block, n_unit, "gn_proj/scale")])
                proj_gn_bias = np2th(weights[pjoin(n_block, n_unit, "gn_proj/bias")])

                self.downsample.weight.copy_(proj_conv_weight)
                self.gn_proj.weight.copy_(proj_gn_weight.view(-1))
                self.gn_proj.bias.copy_(proj_gn_bias.view(-1))


class ResNetV2(nn.Module):
    """Implementation of Pre-activation (v2) ResNet mode.

    ``forward`` returns ``(x, features)``: the output of the last body block
    and the intermediate feature maps (root output and every body block but
    the last), ordered from the deepest to the shallowest.
    """

    def __init__(self, block_units, width_factor):
        super().__init__()
        width = int(64 * width_factor)
        self.width = width

        self.root = nn.Sequential(OrderedDict([
            ('conv', StdConv2d(3, width, kernel_size=7, stride=2, bias=False, padding=3)),
            ('gn', nn.GroupNorm(32, width, eps=1e-6)),
            ('relu', nn.ReLU(inplace=True)),
            # The max-pool is applied in forward() so the root output can be kept as a skip feature.
        ]))

        self.body = nn.Sequential(OrderedDict([
            ('block1', nn.Sequential(OrderedDict(
                [('unit1', PreActBottleneck(cin=width, cout=width*4, cmid=width))] +
                [(f'unit{i:d}', PreActBottleneck(cin=width*4, cout=width*4, cmid=width)) for i in range(2, block_units[0] + 1)],
            ))),
            ('block2', nn.Sequential(OrderedDict(
                [('unit1', PreActBottleneck(cin=width*4, cout=width*8, cmid=width*2, stride=2))] +
                [(f'unit{i:d}', PreActBottleneck(cin=width*8, cout=width*8, cmid=width*2)) for i in range(2, block_units[1] + 1)],
            ))),
            ('block3', nn.Sequential(OrderedDict(
                [('unit1', PreActBottleneck(cin=width*8, cout=width*16, cmid=width*4, stride=2))] +
                [(f'unit{i:d}', PreActBottleneck(cin=width*16, cout=width*16, cmid=width*4)) for i in range(2, block_units[2] + 1)],
            ))),
        ]))

    def forward(self, x):
        features = []
        b, c, in_size, _ = x.size()
        x = self.root(x)
        features.append(x)
        x = F.max_pool2d(x, kernel_size=3, stride=2, padding=0)
        for i in range(len(self.body) - 1):
            x = self.body[i](x)
            right_size = int(in_size / 4 / (i + 1))
            if x.size()[2] != right_size:
                # The unpadded max-pool can leave the map 1-2 pixels short of
                # the size the decoder expects; zero-pad at the bottom/right.
                pad = right_size - x.size()[2]
                assert pad < 3 and pad > 0, "x {} should {}".format(x.size(), right_size)
                # new_zeros keeps x's dtype and device, so the skip feature
                # matches the rest of the network under fp16/fp64 as well.
                feat = x.new_zeros((b, x.size()[1], right_size, right_size))
                feat[:, :, 0:x.size()[2], 0:x.size()[3]] = x[:]
            else:
                feat = x
            features.append(feat)
        x = self.body[-1](x)
        return x, features[::-1]

# --------------------------------------------------------------------------------
# /model/Unet/Unet.py:
# --------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F
import torch
import numpy as np


def _conv_bn_relu(channels, pool_first=False):
    """Build ``[MaxPool2d(2)] + (Conv3x3 -> BatchNorm2d -> ReLU) * k`` as an nn.Sequential.

    ``channels`` lists the channel counts; each consecutive pair is the
    in/out width of one convolution.
    """
    layers = [nn.MaxPool2d(kernel_size=2)] if pool_first else []
    for cin, cout in zip(channels[:-1], channels[1:]):
        layers += [
            nn.Conv2d(in_channels=cin, out_channels=cout, kernel_size=3, padding=1),
            nn.BatchNorm2d(cout),
            nn.ReLU(),
        ]
    return nn.Sequential(*layers)


def _upsample(cin, cout):
    """Build the 2x transposed-convolution upsampler used between decoder stages."""
    return nn.Sequential(
        nn.ConvTranspose2d(in_channels=cin, out_channels=cout, kernel_size=4, stride=2, padding=1)
    )


class Unet(nn.Module):
    """Five-stage U-Net for 3-channel input; ``forward`` returns ``num_classes`` channels."""

    def __init__(self, num_classes):
        super(Unet, self).__init__()
        # Encoder: every stage after the first starts with a 2x max-pool.
        self.stage_1 = _conv_bn_relu((3, 32, 64, 64))
        self.stage_2 = _conv_bn_relu((64, 128, 128), pool_first=True)
        self.stage_3 = _conv_bn_relu((128, 256, 256), pool_first=True)
        self.stage_4 = _conv_bn_relu((256, 512, 512), pool_first=True)
        self.stage_5 = _conv_bn_relu((512, 1024, 1024), pool_first=True)

        self.upsample_4 = _upsample(1024, 512)
        self.upsample_3 = _upsample(512, 256)
        self.upsample_2 = _upsample(256, 128)
        self.upsample_1 = _upsample(128, 64)

        # Decoder convs take the upsampled map concatenated with the skip,
        # hence twice the channels on input.
        self.stage_up_4 = _conv_bn_relu((1024, 512, 512))
        self.stage_up_3 = _conv_bn_relu((512, 256, 256))
        self.stage_up_2 = _conv_bn_relu((256, 128, 128))
        self.stage_up_1 = _conv_bn_relu((128, 64, 64))

        self.final = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=3, padding=1),
        )

    def forward(self, x):
        x = x.float()
        # Down-sampling path
        stage_1 = self.stage_1(x)
        stage_2 = self.stage_2(stage_1)
        stage_3 = self.stage_3(stage_2)
        stage_4 = self.stage_4(stage_3)
        stage_5 = self.stage_5(stage_4)

        # Upsample and merge with stage_4
        up_4 = self.upsample_4(stage_5)
        up_4_conv = self.stage_up_4(torch.cat([up_4, stage_4], dim=1))

        # Upsample and merge with stage_3
        up_3 = self.upsample_3(up_4_conv)
        up_3_conv = self.stage_up_3(torch.cat([up_3, stage_3], dim=1))

        # Upsample and merge with stage_2
        up_2 = self.upsample_2(up_3_conv)
        up_2_conv = self.stage_up_2(torch.cat([up_2, stage_2], dim=1))

        # Upsample and merge with stage_1
        up_1 = self.upsample_1(up_2_conv)
        up_1_conv = self.stage_up_1(torch.cat([up_1, stage_1], dim=1))

        output = self.final(up_1_conv)

        return output


if __name__ == '__main__':
    model = Unet(6)

    image = torch.randn(32, 3, 256, 256)

    output = model(image)

    print("input:", image.shape)
    print("output:", output.shape)

# --------------------------------------------------------------------------------
# /model/Unet/_init_.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/Wzysaber/ST_Unet_pytorch_Semantic-segmentation/b27f4d79ba85f81f793e17e686d6a7a1cd8b41ec/model/Unet/_init_.py
# --------------------------------------------------------------------------------
# /model/deeplabv3_version_1/aspp.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.nn.functional as F 4 | 5 | 6 | class ASPP_Bottleneck(nn.Module): 7 | def __init__(self, num_classes): 8 | super(ASPP_Bottleneck, self).__init__() 9 | 10 | self.conv_1x1_1 = nn.Conv2d(4*512, 256, kernel_size=1) 11 | self.bn_conv_1x1_1 = nn.BatchNorm2d(256) 12 | 13 | self.conv_3x3_1 = nn.Conv2d(4*512, 256, kernel_size=3, stride=1, padding=6, dilation=6)#6 14 | self.bn_conv_3x3_1 = nn.BatchNorm2d(256) 15 | 16 | self.conv_3x3_2 = nn.Conv2d(4*512, 256, kernel_size=3, stride=1, padding=12, dilation=12)#12 17 | self.bn_conv_3x3_2 = nn.BatchNorm2d(256) 18 | 19 | self.conv_3x3_3 = nn.Conv2d(4*512, 256, kernel_size=3, stride=1, padding=18, dilation=18)#18 20 | self.bn_conv_3x3_3 = nn.BatchNorm2d(256) 21 | 22 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 23 | 24 | self.conv_1x1_2 = nn.Conv2d(4*512, 256, kernel_size=1) 25 | self.bn_conv_1x1_2 = nn.BatchNorm2d(256) 26 | 27 | self.conv_1x1_3 = nn.Conv2d(1280, 256, kernel_size=1) # (1280 = 5*256) 28 | self.bn_conv_1x1_3 = nn.BatchNorm2d(256) 29 | 30 | self.conv_1x1_4 = nn.Conv2d(256, num_classes, kernel_size=1) 31 | 32 | def forward(self, feature_map): 33 | # (feature_map has shape (batch_size, 4*512, h/16, w/16)) 34 | 35 | feature_map_h = feature_map.size()[2] # (== h/16) 36 | feature_map_w = feature_map.size()[3] # (== w/16) 37 | 38 | out_1x1 = F.relu(self.bn_conv_1x1_1(self.conv_1x1_1(feature_map))) # (shape: (batch_size, 256, h/16, w/16)) 39 | out_3x3_1 = F.relu(self.bn_conv_3x3_1(self.conv_3x3_1(feature_map))) # (shape: (batch_size, 256, h/16, w/16)) 40 | out_3x3_2 = F.relu(self.bn_conv_3x3_2(self.conv_3x3_2(feature_map))) # (shape: (batch_size, 256, h/16, w/16)) 41 | out_3x3_3 = F.relu(self.bn_conv_3x3_3(self.conv_3x3_3(feature_map))) # (shape: (batch_size, 256, h/16, w/16)) 42 | 43 | out_img = self.avg_pool(feature_map) # (shape: (batch_size, 512, 1, 1)) 44 | out_img = F.relu(self.bn_conv_1x1_2(self.conv_1x1_2(out_img))) # (shape: (batch_size, 256, 1, 1)) 45 | out_img = F.interpolate(out_img, size=(feature_map_h, 
feature_map_w), mode="bilinear", align_corners=False) # (shape: (batch_size, 256, h/16, w/16)) 46 | 47 | out = torch.cat([out_1x1, out_3x3_1, out_3x3_2, out_3x3_3, out_img], 1) # (shape: (batch_size, 1280, h/16, w/16)) 48 | out = F.relu(self.bn_conv_1x1_3(self.conv_1x1_3(out))) # (shape: (batch_size, 256, h/16, w/16)) 49 | out = self.conv_1x1_4(out) # (shape: (batch_size, num_classes, h/16, w/16)) 50 | 51 | return out -------------------------------------------------------------------------------- /model/deeplabv3_version_1/deeplabv3.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from resnet import ResNet50 4 | from aspp import ASPP_Bottleneck 5 | import torch 6 | 7 | class DeepLabV3(nn.Module): 8 | def __init__(self, num_classes=6): 9 | super(DeepLabV3, self).__init__() 10 | self.num_classes = num_classes 11 | self.resnet = ResNet50() 12 | self.aspp = ASPP_Bottleneck(num_classes=self.num_classes) 13 | self.sigmoid = nn.Sigmoid() 14 | 15 | def forward(self, x): 16 | h = x.size()[2] 17 | w = x.size()[3] 18 | feature_map = self.resnet(x) 19 | output = self.aspp(feature_map) 20 | output = F.interpolate(output, size=(h, w), mode="bilinear", align_corners=False) 21 | # output = self.sigmoid(output) 22 | return output 23 | 24 | 25 | if __name__ == '__main__': 26 | model = DeepLabV3() 27 | model.eval() 28 | image = torch.randn(32, 3, 256, 256) 29 | print(model) 30 | output = model(image) 31 | print("input:", image.shape) 32 | print("output:", output.shape) 33 | -------------------------------------------------------------------------------- /model/deeplabv3_version_1/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision.models as models 5 | 6 | 7 | def make_layer(block, in_channels, channels, num_blocks, stride=1, dilation=1): 8 | 
class Bottleneck(nn.Module):
    """Three-convolution residual block: 1x1 reduce, 3x3 (strided/dilated), 1x1 expand.

    The block outputs ``expansion * channels`` channels.  When the stride is
    not 1 or the input channel count differs from the output channel count,
    the shortcut is a strided 1x1 convolution followed by batch norm;
    otherwise it is an empty ``nn.Sequential`` (identity).
    """

    expansion = 4

    def __init__(self, in_channels, channels, stride=1, dilation=1):
        super().__init__()

        widened = channels * self.expansion

        # 1x1: squeeze to `channels`
        self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)

        # 3x3: carries the stride and dilation; padding == dilation keeps the
        # spatial size when stride is 1
        self.conv2 = nn.Conv2d(
            channels, channels,
            kernel_size=3, stride=stride,
            padding=dilation, dilation=dilation,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(channels)

        # 1x1: expand to `expansion * channels`
        self.conv3 = nn.Conv2d(channels, widened, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(widened)

        shortcut_is_identity = stride == 1 and in_channels == widened
        if shortcut_is_identity:
            self.downsample = nn.Sequential()
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, widened, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(widened),
            )

    def forward(self, x):
        """Apply the residual block to ``x`` of shape (batch, in_channels, h, w)."""
        shortcut = self.downsample(x)

        y = F.relu(self.bn1(self.conv1(x)))
        y = F.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))

        return F.relu(y + shortcut)
def GetPadImNRowColLi(image_path, cutsize_h=256, cutsize_w=256, stride=256):
    """Load an image, zero-pad it to a whole number of tiles and list the tile boxes.

    The image is padded on the bottom and right only, so original pixel
    coordinates are unchanged.

    Args:
        image_path: path of the image, read with ``cv.imread``.
        cutsize_h: tile height in pixels.
        cutsize_w: tile width in pixels.
        stride: step between consecutive tile origins, in pixels.

    Returns:
        ``(image, row_col_li)`` where ``image`` is the padded array and
        ``row_col_li`` is a list of ``[x0, y0, x1, y1]`` boxes (column-major:
        all rows of the first column of tiles, then the next column, ...).
        Any output canvas the tiles are written into must be at least as
        large as the padded image.

    Raises:
        OSError: if the image cannot be read.
    """
    from numpy import arange

    image = cv.imread(image_path)
    if image is None:
        # cv.imread signals failure by returning None instead of raising.
        raise OSError("cannot read image: {}".format(image_path))

    h, w = image.shape[0], image.shape[1]

    # Ceil to the next multiple of the tile size.  The previous test
    # (`h // cutsize == 0`) added a full extra tile to exact multiples and
    # left images smaller than one tile unpadded, which produced no tiles.
    h_pad_cutsize = -(-h // cutsize_h) * cutsize_h
    w_pad_cutsize = -(-w // cutsize_w) * cutsize_w

    image = cv.copyMakeBorder(image,
                              0,
                              h_pad_cutsize - h,
                              0,
                              w_pad_cutsize - w,
                              cv.BORDER_CONSTANT,
                              value=0)

    # Exclusive upper bounds for tile origins so every tile fits in the image.
    N = image.shape[0] - cutsize_h + 1
    M = image.shape[1] - cutsize_w + 1
    row = arange(0, N, stride)
    col = arange(0, M, stride)

    row_col_li = [[c, r, c + cutsize_w, r + cutsize_h] for c in col for r in row]
    return image, row_col_li
def snapshot_forward(model, dataloader, model_list, png, shape):
    """Predict every tile with flip test-time augmentation and stitch the result.

    For each batch, every model in ``model_list`` is evaluated on the tile
    and on its horizontal, vertical and double flips; the flipped outputs are
    flipped back, all score maps are summed with equal weight, and the
    arg-max class is written into ``png`` at the tile's position.

    Args:
        model: reference model; it is switched to eval mode and its device is
            used for the input tiles.
        dataloader: iterable of ``(image, pos_list)`` batches, ``pos_list[i]``
            being ``[x0, y0, x1, y1]`` of tile ``i`` in ``png``.
        model_list: non-empty list of models forming the ensemble.
        png: 2-D uint8 canvas, modified in place.
        shape: ``(height, width)`` of the un-padded image.

    Returns:
        ``png`` cropped to ``shape``.

    Raises:
        ValueError: if ``model_list`` is empty or a position box does not
            match the size of its predicted tile.
    """
    if not model_list:
        raise ValueError("model_list must contain at least one model")

    model.eval()
    for net in model_list:
        net.eval()

    # Run on whatever device the reference model lives on (CUDA in the
    # original script); fall back to CUDA-if-available for parameterless models.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Identity, horizontal flip, vertical flip, both.
    flip_dims = ([], [-1], [-2], [-1, -2])

    with torch.no_grad():  # inference only: do not build autograd graphs
        for image, pos_list in dataloader:
            image = image.to(device)

            # Accumulate out of place across ALL models and views.  The old
            # code aliased the accumulator to the first view (double-counting
            # it via in-place +=) and reset it for every model.
            scores = None
            for net in model_list:
                for dims in flip_dims:
                    if dims:
                        view = torch.flip(net(torch.flip(image, dims)), dims)
                    else:
                        view = net(image)
                    scores = view if scores is None else scores + view

            labels = torch.argmax(scores.cpu(), 1).byte().numpy()  # n x h x w

            for i in range(labels.shape[0]):
                tile = labels[i]
                topleft_x, topleft_y, buttomright_x, buttomright_y = (int(v) for v in pos_list[i])
                box_w = buttomright_x - topleft_x
                box_h = buttomright_y - topleft_y

                if (box_h, box_w) != tuple(tile.shape):
                    raise ValueError(
                        "tile box {}x{} does not match prediction {}x{}".format(
                            box_w, box_h, tile.shape[1], tile.shape[0]))
                png[topleft_y:buttomright_y, topleft_x:buttomright_x] = tile

    # Drop the bottom/right padding that was added to reach whole tiles.
    png = png[:shape[0], :shape[1]]

    return png
def create_png(shape, tile=256):
    """Allocate the zero-filled uint8 canvas that predicted tiles are written into.

    Each dimension is rounded up past itself to a multiple of ``tile`` so the
    canvas can always hold every tile, including at least one whole tile when
    the image is smaller than ``tile``.  (Previously a dimension below 256
    was returned unpadded, giving a canvas smaller than a single tile.)  The
    extra tile of slack for exact multiples is harmless because the caller
    crops the result back to the image shape.

    Args:
        shape: ``(height, width)`` of the un-padded image; extra entries are ignored.
        tile: tile edge length in pixels; defaults to the 256 used so far.

    Returns:
        ``numpy.ndarray`` of zeros, dtype ``uint8``, shape ``(new_h, new_w)``.
    """
    h, w = shape[0], shape[1]
    new_h = (h // tile + 1) * tile
    new_w = (w // tile + 1) * tile
    return np.zeros((new_h, new_w), np.uint8)
def predict(model, image_path, Gray_label_path):
    """Predict the class map of one image and load its gray-level label.

    Args:
        model (nn.Module): PyTorch model instance.
        image_path (str): path of the input RGB image.
        Gray_label_path (str): path of the single-channel label image.

    Returns:
        ``(pred, mask)``: ``pred`` is the uint8 arg-max class map with the
        batch dimension squeezed out; ``mask`` is the label as an int64 array.
    """
    # Load and normalize the image.
    img = Image.open(image_path).convert('RGB')

    transform = T.Compose([
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Use the model's own device instead of the module-level `device`, which
    # only exists when this file is run as a script.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")
    img = transform(img).unsqueeze(0).to(target_device)  # add the batch dimension

    Gray_label = Image.open(Gray_label_path).convert('L')
    # NOTE(review): values pass through int8 exactly as before, so labels
    # >= 128 wrap to negative numbers - confirm that is intended.
    mask = np.array(Gray_label, dtype=np.int8).astype(np.int64)

    # Inference: eval mode and no autograd bookkeeping.
    # The caller's train/eval mode is restored afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            output = model(img)
    finally:
        model.train(was_training)

    _, pred = torch.max(output, 1)  # index of the max score along the class axis

    # To numpy, dropping the batch dimension.
    pred = pred.cpu().numpy().squeeze().astype(np.uint8)

    return pred, mask
def close_optimizer(args, model):
    """Build the optimizer selected by ``args.optimizer_name``.

    Args:
        args: namespace providing ``optimizer_name`` and ``base_lr``;
            ``weight_decay`` is read for 'Adadelta' and 'SGD', ``momentum``
            for 'SGD' only.
        model: module whose parameters are optimized.

    Returns:
        A ``torch.optim.Adadelta``, ``torch.optim.Adam`` or ``torch.optim.SGD``.

    Raises:
        ValueError: if ``optimizer_name`` is not one of the supported names
            (previously this surfaced as an ``UnboundLocalError``).
    """
    name = args.optimizer_name

    if name == 'Adadelta':
        return torch.optim.Adadelta(model.parameters(),
                                    lr=args.base_lr,
                                    weight_decay=args.weight_decay)
    if name == 'Adam':
        return torch.optim.Adam(model.parameters(),
                                lr=args.base_lr)
    if name == 'SGD':
        return torch.optim.SGD(params=model.parameters(),
                               lr=args.base_lr,
                               momentum=args.momentum,
                               weight_decay=args.weight_decay)

    raise ValueError(
        "unsupported optimizer_name {!r}; expected 'Adadelta', 'Adam' or 'SGD'".format(name))
def training(args, num_classes, model, optimizer, train_dataset, train_loader, criterion1, criterion2, device, epoch):
    """Run one training epoch, print per-class metrics and return the overall accuracy.

    Args:
        args: namespace providing ``total_epochs``, ``base_lr`` and ``train_batch_size``.
        num_classes: forwarded to ``metric.confusion_matrix``.
        model: network to train; switched to train mode.
        optimizer: optimizer stepped once per batch.
        train_dataset: dataset exposing ``class_names`` for the report table.
        train_loader: iterable of batches; ``data[0]`` are images, ``data[1]`` masks.
        criterion1: loss called as ``criterion1(outputs, masks)``.
        criterion2: loss called as ``criterion2(outputs, masks, softmax=True)``.
        device: device the batches are moved to.
        epoch: current epoch index, used for the displayed learning rate.

    Returns:
        The first value returned by ``metric.evaluate`` (overall accuracy).
    """
    model.train()  # training mode: affects Dropout and BatchNorm

    train_loss = average_meter.AverageMeter()

    # "Poly" decay.  NOTE(review): this value is only shown in the progress
    # bar; this function never writes it into the optimizer's param groups.
    max_iter = args.total_epochs * len(train_loader)
    curr_iter = epoch * len(train_loader)
    lr = args.base_lr * (1 - float(curr_iter) / max_iter) ** 0.9

    # NOTE(review): the accumulator is fixed at 5x5 while num_classes is a
    # parameter - confirm against metric.confusion_matrix's output shape.
    conf_mat = np.zeros((5, 5)).astype(np.int64)

    tbar = tqdm(train_loader)  # progress bar over the batches

    for data in tbar:
        imgs = data[0].to(device)
        masks = data[1].to(device)

        outputs = model(imgs)
        # torch.max over dim 1 returns (values, indices); keep the class indices.
        _, preds = torch.max(outputs, 1)
        preds = preds.detach().cpu().numpy().squeeze().astype(np.uint8)

        loss1 = criterion1(outputs, masks)
        loss2 = criterion2(outputs, masks, softmax=True)

        loss = 0.5 * loss1 + 0.5 * loss2

        # Record a plain float: feeding the tensor itself kept every batch's
        # autograd history alive inside the running sum and printed a tensor
        # repr in the progress bar.
        train_loss.update(loss.item(), args.train_batch_size)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        tbar.set_description('epoch {}, training loss {}, with learning rate {}.'.format(
            epoch, train_loss.val, lr
        ))

        masks = masks.detach().cpu().numpy().squeeze().astype(np.uint8)

        # Accumulate this batch into the confusion matrix.
        conf_mat += metric.confusion_matrix(pred=preds.flatten(),
                                            label=masks.flatten(),
                                            num_classes=num_classes)

    # Epoch-level metrics.
    train_acc, train_acc_per_class, train_pre, train_IoU, train_mean_IoU, train_kappa, train_F1_score, train_recall = metric.evaluate(
        conf_mat)

    table = PrettyTable(["序号", "名称", "acc", "IOu"])

    for i in range(5):
        table.add_row([i, train_dataset.class_names[i], train_acc_per_class[i], train_IoU[i]])
    print(table)

    print("F1_score:", train_F1_score)
    print("train_mean_IoU:", train_mean_IoU)

    print("\ntrain_acc(OA):", train_acc)
    print("kappa:", train_kappa)
    print(" ")

    return train_acc
def fifteen_classes():
    """Return the land-cover class names: index 0 is the catch-all 'other' class, followed by 15 named classes."""
    names = (
        '其他类别',
        '水田', '水浇地', '旱耕地',
        '园地',
        '乔木林地', '灌木林地',
        '天然草地', '人工草地',
        '工业用地',
        '城市住宅', '村镇住宅',
        '交通运输',
        '河流', '湖泊', '坑塘',
    )
    return list(names)


def five_classes():
    """Return the five class names as a fresh list."""
    names = ('不透明表面', '建筑', '灌木', '树', '车')
    return list(names)
class DiceLoss(nn.Module):
    """Multi-class soft Dice loss averaged over ``n_classes``.

    Args:
        n_classes: number of classes; integer targets are expanded to this
            many one-hot channels.
    """

    def __init__(self, n_classes):
        super(DiceLoss, self).__init__()
        self.n_classes = n_classes

    def _one_hot_encoder(self, input_tensor):
        """Turn an integer label tensor (N, ...) into a float one-hot tensor (N, n_classes, ...)."""
        planes = [(input_tensor == i).unsqueeze(1) for i in range(self.n_classes)]
        return torch.cat(planes, dim=1).float()

    def _dice_loss(self, score, target):
        """Return ``1 - Dice`` for one class, using squared sums in the denominator."""
        target = target.float()
        smooth = 1e-5  # keeps the ratio defined when both maps are empty
        intersect = torch.sum(score * target)
        y_sum = torch.sum(target * target)
        z_sum = torch.sum(score * score)
        loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
        return 1 - loss

    def forward(self, inputs, target, weight=None, softmax=False):
        """Compute the weighted mean Dice loss.

        Args:
            inputs: per-class scores of shape (N, n_classes, ...).
            target: integer class labels of shape (N, ...).
            weight: optional per-class weights; defaults to all ones.
            softmax: if true, apply softmax over dim 1 to ``inputs`` first.

        Returns:
            Scalar tensor: sum of weighted per-class losses divided by ``n_classes``.

        Raises:
            ValueError: if ``inputs`` and the one-hot target differ in shape.
        """
        if softmax:
            inputs = torch.softmax(inputs, dim=1)
        target = self._one_hot_encoder(target)
        if weight is None:
            weight = [1] * self.n_classes
        # Explicit raise instead of `assert`, which is stripped under -O.
        if inputs.size() != target.size():
            raise ValueError(
                'predict {} & target {} shape do not match'.format(inputs.size(), target.size()))

        # The unused per-class `.item()` bookkeeping was dropped: it forced a
        # device-to-host sync per class on every step and was never read.
        loss = 0.0
        for i in range(self.n_classes):
            loss = loss + self._dice_loss(inputs[:, i], target[:, i]) * weight[i]
        return loss / self.n_classes