├── utils
│   ├── __init__.py
│   ├── config.py
│   ├── loss.py
│   └── utils.py
├── dataset
│   ├── Dataset_init.py
│   └── Linear_lesion.py
├── model
│   ├── Synchronized
│   │   ├── Temp
│   │   ├── sync_batchnorm
│   │   │   ├── Temp
│   │   │   ├── __init__.py
│   │   │   ├── unittest.py
│   │   │   ├── batchnorm_reimpl.py
│   │   │   ├── replicate.py
│   │   │   └── comm.py
│   │   └── CC.py
│   ├── unet_deep.py
│   ├── SE_unet.py
│   ├── Residualunet_deep.py
│   ├── resnet.py
│   ├── lipunet_deep.py
│   ├── rsunet.py
│   ├── CBAMunet_deep.py
│   ├── unet_carafe_deep.py
│   ├── unet_deep_improve.py
│   ├── unet_deepsup_stip.py
│   ├── SCconv_unet_deep.py
│   ├── unet_deep_usecarafe.py
│   ├── DualUnet.py
│   ├── mildnet.py
│   └── unet_deep_Asymmetric_Non-local.py
├── README.md
├── metric.py
└── demo.py

/utils/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/dataset/Dataset_init.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/model/Synchronized/Temp:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/model/Synchronized/sync_batchnorm/Temp:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/model/Synchronized/sync_batchnorm/__init__.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# File   : __init__.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
from .batchnorm import patch_sync_batchnorm, convert_model
from .replicate import DataParallelWithCallback, patch_replication_callback
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Introduction
This project is a PyTorch-based medical image segmentation template that implements most of what you need in medical image segmentation experiments: data processing, data augmentation, loss design, utility files, saving and visualizing logs, model files, training, validation, testing, and project configuration. It is derived from and improves upon [Pytorch_Medical_Segmention_Template](https://github.com/FENGShuanglang/Pytorch_Medical_Segmention_Template). We added several commonly used segmentation networks and recent modules, including attention, downsampling, and upsampling blocks.
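
Experiments are configured by editing `utils/config.py`. As a minimal sketch of how the pieces in this repository fit together (the dataset path layout and the surrounding training loop are assumptions, not shown here):

```python
from utils.config import DefaultConfig
from dataset.Linear_lesion import LinearLesion
from model.unet_deep import UNet_deepsup

cfg = DefaultConfig()
train_set = LinearLesion(cfg.data + '/Linear_lesion', (256, 256),
                         k_fold_test=cfg.test_fold, mode='train')
model = UNet_deepsup(in_channels=cfg.input_channel, n_classes=cfg.num_classes)
```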
--------------------------------------------------------------------------------
/model/Synchronized/sync_batchnorm/unittest.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# File   : unittest.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import unittest
import torch


class TorchTestCase(unittest.TestCase):
    def assertTensorClose(self, x, y):
        adiff = float((x - y).abs().max())
        if (y == 0).all():
            rdiff = 'NaN'
        else:
            rdiff = float((adiff / y).abs().max())

        message = (
            'Tensor close check failed\n'
            'adiff={}\n'
            'rdiff={}\n'
        ).format(adiff, rdiff)
        self.assertTrue(torch.allclose(x, y), message)
--------------------------------------------------------------------------------
/utils/config.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Tue Mar 12 14:03:52 2019

@author: Administrator
"""

class DefaultConfig(object):
    num_epochs = 100       # number of training epochs *
    epoch_start_i = 0
    checkpoint_step = 5    # unused
    validation_step = 1    # validate every this many training epochs
    crop_height = 256      # unused
    crop_width = 256       # unused
    batch_size = 2         # unused *
    input_channel = 1      # number of input image channels *

    data = r'C:\Users\Administrator\Desktop\model file\Pytorch_Medical_Segmention-multi-deep_spie\Dataset'  # root directory of the data *
    dataset = "Linear_lesion"  # dataset name (change to your own dataset's name) *
    log_dirs = r'C:\Users\Administrator\Desktop\model file\Pytorch_Medical_Segmention-multi-deep_spie\Linear_lesion_Code\UNet'  # folder for the TensorBoard logs *

    lr = 1e-3              # SGD learning rate *
    # lr = 0.0001          # Adam learning rate *
    lr_mode = 'poly'       # 'poly' learning-rate schedule
    net_work = 'UNet_deepsupusecarafe'  # selected network *
    momentum = 0.9         # optimizer momentum
    weight_decay = 1e-4    # L2 regularization coefficient

    mode = 'train'         # run mode *
    k_fold = 3             # number of cross-validation folds *
    test_fold = 3          # which fold folder to use at test time
    num_workers = 0
    num_classes = 3        # number of segmentation classes: 1 for binary segmentation, classes + background for multi-class **
    cuda = '0'             # GPU id selection **
    use_gpu = True
    pretrained_model_path = r'C:\Users\Administrator\Desktop\model file\Pytorch_Medical_Segmention-multi-deep_spie\Linear_lesion_Code\UNet\diao_deepsup_origin_conv7——crosscheckpoints\3/model_117_0.6867.pth.tar'  # checkpoint to load when mode='test'
    save_model_path = './UNet_deepsupusecarafe——crosscheckpoints'  # folder where models are saved
--------------------------------------------------------------------------------
/model/Synchronized/sync_batchnorm/batchnorm_reimpl.py:
--------------------------------------------------------------------------------
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : batchnorm_reimpl.py
# Author : acgtyrant
# Date   : 11/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import torch
import torch.nn as nn
import torch.nn.init as init

__all__ = ['BatchNorm2dReimpl']


class BatchNorm2dReimpl(nn.Module):
    """
    A re-implementation of batch normalization, used for testing the numerical
    stability.

    Author: acgtyrant
    See also:
    https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14
    """
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()

        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.weight = nn.Parameter(torch.empty(num_features))
        self.bias = nn.Parameter(torch.empty(num_features))
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.reset_parameters()

    def reset_running_stats(self):
        self.running_mean.zero_()
        self.running_var.fill_(1)

    def reset_parameters(self):
        self.reset_running_stats()
        init.uniform_(self.weight)
        init.zeros_(self.bias)

    def forward(self, input_):
        batchsize, channels, height, width = input_.size()
        numel = batchsize * height * width
        input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel)
        sum_ = input_.sum(1)
        sum_of_square = input_.pow(2).sum(1)
        mean = sum_ / numel
        sumvar = sum_of_square - sum_ * mean

        self.running_mean = (
            (1 - self.momentum) * self.running_mean
            + self.momentum * mean.detach()
        )
        unbias_var = sumvar / (numel - 1)
        self.running_var = (
            (1 - self.momentum) * self.running_var
            + self.momentum * unbias_var.detach()
        )

        bias_var = sumvar / numel
        inv_std = 1 / (bias_var + self.eps).pow(0.5)
        output = (
            (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) *
            self.weight.unsqueeze(1) + self.bias.unsqueeze(1))

        return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous()
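

# Illustrative self-check (not part of the upstream file): with the same affine
# parameters, this reimplementation should track nn.BatchNorm2d closely in
# training mode, since both normalize with the biased batch variance.
if __name__ == '__main__':
    torch.manual_seed(0)
    bn_ref = nn.BatchNorm2d(8)
    bn_re = BatchNorm2dReimpl(8)
    bn_re.weight.data.copy_(bn_ref.weight.data)
    bn_re.bias.data.copy_(bn_ref.bias.data)
    x = torch.randn(4, 8, 16, 16)
    print('outputs close:', torch.allclose(bn_ref(x), bn_re(x), atol=1e-4))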
--------------------------------------------------------------------------------
/metric.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Wed Sep 19 10:59:53 2018

"""

import numpy as np
import os
from PIL import Image
#path_true=r'D:\task\projects\cabunet\keras_cabunet\aug\data\testset\label\4'
#path_predict=r'D:\task\projects\cabunet\keras_cabunet\aug\result3\2\4'
path_true=r'G:\KeTi\JBHI_pytorch\BaseNet\dataset\path\to\PiFu\mask\f1'
path_predict=r'G:\KeTi\JBHI_pytorch\BaseNet_Full_channel_impore\dataset\test\f1'
TP=FPN=0
Jaccard=[]
for roots,dirs,files in os.walk(path_predict):
    if files:
#        dice=[]
#        num=0
        for file in files:
#            num=num+1
            pre_file_path=os.path.join(roots,file)
            true_file_path=os.path.join(path_true,file)
            img_pre = np.array(Image.open(pre_file_path).convert("L"))
            img_pre[img_pre==255]=1
            img_true = np.array(Image.open(true_file_path).convert("L"))
            img_true[img_true==255]=1
#            print(img_pre.shape)
#            print(img_true.shape)
#            TP = TP+np.sum(np.array(img_pre,dtype=np.int32)&np.array(img_true,dtype=np.int32))
#            FPN = FPN +np.sum(np.array(img_pre,dtype=np.int32)|np.array(img_true,dtype=np.int32))
            TP = TP+np.sum(img_pre*img_true)
            FPN = FPN +np.sum(img_pre)+np.sum(img_true)
            single_I=np.sum(img_pre*img_true)
            single_U=np.sum(img_pre)+np.sum(img_true)-single_I
            Jaccard.append(single_I/single_U)


dice = 2*TP/FPN
print('TP:',TP)
print('FPN:',FPN)
print("DICE",dice)
print('glob_Jaccard',TP/(FPN-TP))
print('single_Jaccard',sum(Jaccard)/len(Jaccard))


#
## pre_npy=np.load(pre_file_path)
## true_npy=np.load(true_file_path)
#            for i in range(1,4):
#                pre_npy_s=np.zeros(pre_npy.shape)
#                true_npy_s=np.zeros(true_npy.shape)
#
#                pre_npy_s[pre_npy==i]=1
#                true_npy_s[true_npy==i]=1
#                TP=np.sum(np.array(pre_npy_s,dtype=np.int32)&np.array(true_npy_s,dtype=np.int32))
#                FPN=np.sum(np.array(pre_npy_s,dtype=np.int32)|np.array(true_npy_s,dtype=np.int32))
#                print('%d_TP:' %num,TP)
#                print('%d_FPN:' %num,FPN)
##                if FPN!=0:
#
#                if np.sum(true_npy_s)!=0:
#                    dice_cof=2*TP/(TP+FPN)
#                    dice.append(dice_cof)
#dice=np.array(dice)
#print(dice)
#dice_mean=np.mean(dice)
#print(dice_mean)
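
# Toy check of the formulas above (illustrative): here FPN is |pred| + |true|,
# not FP+FN, so Dice = 2*TP/FPN and Jaccard = TP/(FPN - TP).
toy_pred = np.array([1, 1, 0])
toy_true = np.array([1, 0, 0])
toy_tp = np.sum(toy_pred * toy_true)                # 1
toy_fpn = np.sum(toy_pred) + np.sum(toy_true)       # 3
print('toy Dice:', 2 * toy_tp / toy_fpn)            # 0.666...
print('toy Jaccard:', toy_tp / (toy_fpn - toy_tp))  # 0.5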
--------------------------------------------------------------------------------
/model/Synchronized/CC.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Softmax

def INF(B,H,W):
    return -torch.diag(torch.tensor(float("inf")).cuda().repeat(H),0).unsqueeze(0).repeat(B*W,1,1)

#def INF(B,H,W):
#    return -torch.diag(torch.tensor(float("inf")).repeat(H),0).unsqueeze(0).repeat(B*W,1,1)

class CC_module(nn.Module):
    def __init__(self,in_dim):
        super(CC_module, self).__init__()
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)  # query feature map Q
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)    # key feature map K
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)     # value feature map V
        self.softmax = Softmax(dim=3)
        self.INF = INF
        self.gamma = nn.Parameter(torch.zeros(1))  # learnable scale parameter
    def forward(self, x):
        m_batchsize, _, height, width = x.size()  # batch size, channels, height H, width W of the feature map
        proj_query = self.query_conv(x)  # feature map Q
        proj_query_H = proj_query.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height).permute(0, 2, 1)  # Q along the height axis
        proj_query_W = proj_query.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width).permute(0, 2, 1)  # Q along the width axis
        proj_key = self.key_conv(x)  # feature map K
        proj_key_H = proj_key.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height)  # K along the height axis
        proj_key_W = proj_key.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width)  # K along the width axis
        proj_value = self.value_conv(x)  # feature map V
        proj_value_H = proj_value.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height)  # V along the height axis
        proj_value_W = proj_value.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width)  # V along the width axis
        energy_H = (torch.bmm(proj_query_H, proj_key_H)+self.INF(m_batchsize, height, width)).view(m_batchsize,width,height,height).permute(0,2,1,3)
        energy_W = torch.bmm(proj_query_W, proj_key_W).view(m_batchsize,height,width,width)  # batched matrix multiply: torch.bmm(a, b) with a of size (b, h, w) and b of size (b, w, h)
        concate = self.softmax(torch.cat([energy_H, energy_W], 3))
        #concate = concate * (concate>torch.mean(concate,dim=3,keepdim=True)).float()
        att_H = concate[:,:,:,0:height].permute(0,2,1,3).contiguous().view(m_batchsize*width,height,height)  # attention map A
        #print(concate)
        #print(att_H)
        att_W = concate[:,:,:,height:height+width].contiguous().view(m_batchsize*height,width,width)
        out_H = torch.bmm(proj_value_H, att_H.permute(0, 2, 1)).view(m_batchsize,width,-1,height).permute(0,2,3,1)
        out_W = torch.bmm(proj_value_W, att_W.permute(0, 2, 1)).view(m_batchsize,height,-1,width).permute(0,2,1,3)
        #print(out_H.size(),out_W.size())
        return self.gamma*(out_H + out_W) + x


if __name__ == '__main__':
    model = CC_module(64)
    x = torch.randn(2, 64, 5, 6)
    out = model(x)
    print(out.shape)
--------------------------------------------------------------------------------
/utils/loss.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
#import torch.nn.functional as F
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
#
#def loss_builder(loss_type):
#
#    if loss_type == 'cross_entropy':
#        weight_1 = torch.Tensor([1,5,10,20])
#        criterion = nn.NLLLoss(weight=weight_1,ignore_index=255)
#        criterion_2 = DiceLoss()
#        criterion_3 = nn.BCELoss()
#        return
#    elif loss_type == 'dice_loss':
#        weight_1 = torch.Tensor([1,5,10,20])
#        criterion_1 = nn.NLLLoss(weight=weight_1,ignore_index=255)
#        criterion_2 = EL_DiceLoss()
#        criterion_3 = nn.BCELoss()
#
#    if loss_type in ['mix_3','mix_33']:
#        criterion_1.cuda()
#        criterion_2.cuda()
#        criterion_3.cuda()
#    criterion = [criterion_1,criterion_2,criterion_3]
#
#    return criterion

class DiceLoss(nn.Module):
    def __init__(self,smooth=0.01):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self,input, target):
        input = torch.sigmoid(input)
        Dice = Variable(torch.Tensor([0]).float()).cuda()
        intersect=(input*target).sum()  # intersection of prediction and target
        union = torch.sum(input) + torch.sum(target)  # sum of the elements of each
        Dice=(2*intersect+self.smooth)/(union+self.smooth)
        dice_loss=1-Dice
        return dice_loss

class Multi_DiceLoss(nn.Module):
    def __init__(self, class_num=3,smooth=0.1):
        super(Multi_DiceLoss, self).__init__()
        self.smooth = smooth
        self.class_num = class_num

    def forward(self,input, target):
        input = torch.exp(input)  # element-wise exponential: turns log-probabilities back into probabilities, e.g. exp(log(2)) = 2

        Dice = Variable(torch.Tensor([0]).float()).cuda()
        for i in range(0,self.class_num):
            input_i = input[:,i,:,:]  # prediction map of class i
            target_i = (target == i).float()  # ground-truth mask of class i: 1 at matching pixels, 0 elsewhere
            intersect = (input_i*target_i).sum()
            union = torch.sum(input_i) + torch.sum(target_i)
            dice = (2 * intersect + self.smooth) / (union + self.smooth)
            Dice += dice
        dice_loss = 1 - Dice/(self.class_num)
        return dice_loss

class EL_DiceLoss(nn.Module):
    def __init__(self, class_num=4,smooth=1,gamma=0.5):
        super(EL_DiceLoss, self).__init__()
        self.smooth = smooth
        self.class_num = class_num
        self.gamma = gamma

    def forward(self,input, target):
        input = torch.exp(input)
        self.smooth = 0.
        Dice = Variable(torch.Tensor([0]).float()).cuda()
        for i in range(1,self.class_num):
            input_i = input[:,i,:,:]
            target_i = (target == i).float()
            intersect = (input_i*target_i).sum()
            union = torch.sum(input_i) + torch.sum(target_i)
            if target_i.sum() == 0:
                dice = Variable(torch.Tensor([1]).float()).cuda()
            else:
                dice = (2 * intersect + self.smooth) / (union + self.smooth)
            Dice += (-torch.log(dice))**self.gamma
        dice_loss = Dice/(self.class_num - 1)
        return dice_loss
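

# Illustrative usage (not from the repo's training script): the multi-class
# losses above exponentiate their input, so they expect log-probabilities such
# as the log_softmax outputs produced by the models in this repo. Requires a
# CUDA device, since the losses allocate with .cuda().
if __name__ == '__main__':
    logits = torch.randn(2, 3, 64, 64).cuda()
    log_probs = F.log_softmax(logits, dim=1)
    target = torch.randint(0, 3, (2, 64, 64)).cuda()
    criterion = Multi_DiceLoss(class_num=3)
    print('multi-class dice loss:', criterion(log_probs, target).item())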
--------------------------------------------------------------------------------
/model/Synchronized/sync_batchnorm/replicate.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# File   : replicate.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import functools

from torch.nn.parallel.data_parallel import DataParallel

__all__ = [
    'CallbackContext',
    'execute_replication_callbacks',
    'DataParallelWithCallback',
    'patch_replication_callback'
]


class CallbackContext(object):
    pass


def execute_replication_callbacks(modules):
    """
    Execute a replication callback `__data_parallel_replicate__` on each module created by original replication.

    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`

    Note that, as all modules are isomorphic, we assign each sub-module with a context
    (shared among multiple copies of this module on different devices).
    Through this context, different copies can share some information.

    We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
    of any slave copies.
    """
    master_copy = modules[0]
    nr_modules = len(list(master_copy.modules()))
    ctxs = [CallbackContext() for _ in range(nr_modules)]

    for i, module in enumerate(modules):
        for j, m in enumerate(module.modules()):
            if hasattr(m, '__data_parallel_replicate__'):
                m.__data_parallel_replicate__(ctxs[j], i)


class DataParallelWithCallback(DataParallel):
    """
    Data Parallel with a replication callback.

    A replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
    the original `replicate` function.
    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
        # sync_bn.__data_parallel_replicate__ will be invoked.
    """

    def replicate(self, module, device_ids):
        modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules


def patch_replication_callback(data_parallel):
    """
    Monkey-patch an existing `DataParallel` object. Add the replication callback.
    Useful when you have customized `DataParallel` implementation.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
        > patch_replication_callback(sync_bn)
        # this is equivalent to
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
    """

    assert isinstance(data_parallel, DataParallel)

    old_replicate = data_parallel.replicate

    @functools.wraps(old_replicate)
    def new_replicate(module, device_ids):
        modules = old_replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules

    data_parallel.replicate = new_replicate
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
import cv2
import argparse
from model.build_BiSeNet import BiSeNet
import os
import torch
from imgaug import augmenters as iaa
from PIL import Image
from torchvision import transforms
import numpy as np
from utils import reverse_one_hot, get_label_info, colour_code_segmentation

def predict_on_image(model, args):
    # pre-processing on image
    image = cv2.imread(args.data, -1)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    resize = iaa.Scale({'height': args.crop_height, 'width': args.crop_width})
    resize_det = resize.to_deterministic()
    image = resize_det.augment_image(image)
    image = Image.fromarray(image).convert('RGB')
    image = transforms.ToTensor()(image).unsqueeze(0)

    # read csv label path
    label_info = get_label_info(args.csv_path)
    # predict
    model.eval()
    predict = model(image).squeeze()
    predict = reverse_one_hot(predict)
    predict = colour_code_segmentation(np.array(predict), label_info)
    predict = cv2.resize(np.uint8(predict), (960, 720))
    cv2.imwrite(args.save_path, cv2.cvtColor(np.uint8(predict), cv2.COLOR_RGB2BGR))

def main(params):
    # basic parameters
    parser = argparse.ArgumentParser()
    parser.add_argument('--image', action='store_true', default=False, help='predict on image')
    parser.add_argument('--video', action='store_true', default=False, help='predict on video')
    parser.add_argument('--checkpoint_path', type=str, default=None, help='The path to the pretrained weights of model')
    parser.add_argument('--context_path', type=str, default="resnet101", help='The context path model you are using.')
    parser.add_argument('--num_classes', type=int, default=32, help='num of object classes (with void)')
    parser.add_argument('--data', type=str, default=None, help='Path to image or video for prediction')
    parser.add_argument('--crop_height', type=int, default=640, help='Height of cropped/resized input image to network')
    parser.add_argument('--crop_width', type=int, default=640, help='Width of cropped/resized input image to network')
    parser.add_argument('--cuda', type=str, default='0', help='GPU ids used for training')
    parser.add_argument('--use_gpu', type=bool, default=True, help='Whether to use GPU for training')
    parser.add_argument('--csv_path', type=str, default=None, required=True, help='Path to label info csv file')
    parser.add_argument('--save_path', type=str, default=None, required=True, help='Path to save predict image')

    args = parser.parse_args(params)

    # build model
    os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda
    model = BiSeNet(args.num_classes, args.context_path)
    if torch.cuda.is_available() and args.use_gpu:
        model = torch.nn.DataParallel(model).cuda()

    # load pretrained model if exists
    print('load model from %s ...' % args.checkpoint_path)
    model.module.load_state_dict(torch.load(args.checkpoint_path))
    print('Done!')

    # predict on image
    if args.image:
        predict_on_image(model, args)

    # predict on video
    if args.video:
        pass

if __name__ == '__main__':
    params = [
        '--image',
        '--data', '0016E5_06210.png',
        '--checkpoint_path', './checkpoints/epoch_295.pth',
        '--cuda', '4',
        '--csv_path', '/path/to/CamVid/class_dict.csv',
        '--save_path', 'demo.png'
    ]
    main(params)
--------------------------------------------------------------------------------
/model/Synchronized/sync_batchnorm/comm.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# File   : comm.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import queue
import collections
import threading

__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']


class FutureResult(object):
    """A thread-safe future implementation. Used only as one-to-one pipe."""

    def __init__(self):
        self._result = None
        self._lock = threading.Lock()
        self._cond = threading.Condition(self._lock)

    def put(self, result):
        with self._lock:
            assert self._result is None, 'Previous result hasn\'t been fetched.'
            self._result = result
            self._cond.notify()

    def get(self):
        with self._lock:
            if self._result is None:
                self._cond.wait()

            res = self._result
            self._result = None
            return res


_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])


class SlavePipe(_SlavePipeBase):
    """Pipe for master-slave communication."""

    def run_slave(self, msg):
        self.queue.put((self.identifier, msg))
        ret = self.result.get()
        self.queue.put(True)
        return ret


class SyncMaster(object):
    """An abstract `SyncMaster` object.

    - During the replication, as the data parallel will trigger a callback on each module, all slave devices should
      call `register_slave(id)` and obtain a `SlavePipe` to communicate with the master.
    - During the forward pass, the master device invokes `run_master`; all messages from the slave devices are
      collected and passed to a registered callback.
    - After receiving the messages, the master device should gather the information and determine the message to be
      passed back to each slave device.
    """

    def __init__(self, master_callback):
        """

        Args:
            master_callback: a callback to be invoked after having collected messages from slave devices.
        """
        self._master_callback = master_callback
        self._queue = queue.Queue()
        self._registry = collections.OrderedDict()
        self._activated = False

    def __getstate__(self):
        return {'master_callback': self._master_callback}

    def __setstate__(self, state):
        self.__init__(state['master_callback'])

    def register_slave(self, identifier):
        """
        Register a slave device.

        Args:
            identifier: an identifier, usually is the device id.

        Returns: a `SlavePipe` object which can be used to communicate with the master device.

        """
        if self._activated:
            assert self._queue.empty(), 'Queue is not clean before next initialization.'
            self._activated = False
            self._registry.clear()
        future = FutureResult()
        self._registry[identifier] = _MasterRegistry(future)
        return SlavePipe(identifier, self._queue, future)

    def run_master(self, master_msg):
        """
        Main entry for the master device in each forward pass.
        The messages are first collected from each device (including the master device), and then
        a callback is invoked to compute the message to be sent back to each device
        (including the master device).

        Args:
            master_msg: the message that the master wants to send to itself. This will be placed as the first
            message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.

        Returns: the message to be sent back to the master device.

        """
        self._activated = True

        intermediates = [(0, master_msg)]
        for i in range(self.nr_slaves):
            intermediates.append(self._queue.get())

        results = self._master_callback(intermediates)
        assert results[0][0] == 0, 'The first result should belong to the master.'

        for i, res in results:
            if i == 0:
                continue
            self._registry[i].result.put(res)

        for i in range(self.nr_slaves):
            assert self._queue.get() is True

        return results[0][1]

    @property
    def nr_slaves(self):
        return len(self._registry)
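

# Illustrative round trip (not part of the upstream file): one master and two
# slave threads synchronize a sum through SyncMaster. The callback receives
# [(0, master_msg), (id, msg), ...] and returns one (id, result) per device.
if __name__ == '__main__':
    master = SyncMaster(lambda msgs: [(i, sum(m for _, m in msgs)) for i, _ in msgs])
    pipes = [master.register_slave(i) for i in (1, 2)]
    results = []
    threads = [threading.Thread(target=lambda p=p, v=v: results.append(p.run_slave(v)))
               for p, v in zip(pipes, (10, 20))]
    for t in threads:
        t.start()
    print('master result:', master.run_master(1))  # 1 + 10 + 20 = 31
    for t in threads:
        t.join()
    print('slave results:', results)               # [31, 31]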
--------------------------------------------------------------------------------
/dataset/Linear_lesion.py:
--------------------------------------------------------------------------------
import torch
import glob
import os
from torchvision import transforms
from torchvision.transforms import functional as F
#import cv2
from PIL import Image
# import pandas as pd
import numpy as np
from imgaug import augmenters as iaa
import imgaug as ia
#from utils import get_label_info, one_hot_it
import random
import matplotlib.pyplot as plt

def augmentation():
    # augment images with spatial transformation: Flip, Affine, Rotation, etc...
    # see https://github.com/aleju/imgaug for more details
    pass

def augmentation_pixel():
    # augment images with pixel intensity transformation: GaussianBlur, Multiply, etc...
    pass

class LinearLesion(torch.utils.data.Dataset):
    def __init__(self, dataset_path,scale,k_fold_test=1, mode='train'):
        super().__init__()
        self.mode = mode
        self.img_path=dataset_path+'\\img'    # image directory
        self.mask_path=dataset_path+'\\mask'  # label directory
        self.image_lists,self.label_lists=self.read_list(self.img_path,k_fold_test=k_fold_test)
        self.flip =iaa.SomeOf((2,4),[
                iaa.Fliplr(0.5),
                iaa.Flipud(0.5),
                iaa.Affine(rotate=(-30, 30)),
                iaa.AdditiveGaussianNoise(scale=(0.0,0.08*255))], random_order=True)
        # resize
        self.resize_label = transforms.Resize(scale, Image.NEAREST)  # resize labels with nearest-neighbour (low-quality) interpolation
        self.resize_img = transforms.Resize(scale, Image.BILINEAR)   # resize images with bilinear interpolation
        # normalization
        self.to_tensor = transforms.ToTensor()  # converts a PIL Image or ndarray to a tensor scaled to [0, 1] by dividing by 255; adjust this if your ndarray uses a different value range

    def __getitem__(self, index):
        # load image and crop
        img = Image.open(self.image_lists[index])

        img = np.array(img)  # convert to a numpy array
        labels=self.label_lists[index]

        #load label
        if self.mode !='test':
            label = Image.open(self.label_lists[index])
            label = np.array(label)

            label[label==255]=1
            label[label==190]=2
#            label[label==105]=3

#            label=np.argmax(label,axis=-1)
#            label[label!=1]=0
            # augment image and label
            if self.mode == 'train':
                seq_det = self.flip.to_deterministic()  # freeze a deterministic augmentation sequence
                segmap = ia.SegmentationMapOnImage(label, shape=label.shape, nb_classes=3)
                img = seq_det.augment_image(img)  # apply the sequence to the image
#                plt.imshow(img.astype(np.float32))  # show the image
#                plt.show()
                label = seq_det.augment_segmentation_maps([segmap])[0].get_arr_int().astype(np.uint8)  # apply the same sequence to the segmentation map and convert back to numpy; shape (256, 512) here
#                plt.imshow(label.astype(np.float32))
#                plt.show()

#            label=np.reshape(label,(1,)+label.shape)
#
#            plt.imshow(label.astype(np.float32))  # show the label image
#            plt.show()
#            label=torch.from_numpy(label.copy()).float()  # binary segmentation uses float labels

            labels = torch.from_numpy(label.copy()).long()  # multi-class labels use long

        img=np.reshape(img,img.shape+(1,))  # keep this line when the input is single-channel ******

        img = self.to_tensor(img.copy()).float()

        return img, labels

    def __len__(self):
        return len(self.image_lists)

    def read_list(self,image_path,k_fold_test=1):
        fold=sorted(os.listdir(image_path))  # sorted list of the fold folders inside the image directory
#        print(fold)
        img_list=[]
        if self.mode=='train':
            fold_r=fold
            fold_r.remove('f'+str(k_fold_test))  # drop the test fold; folders are named 'f' + number
            for item in fold_r:
                img_list+=glob.glob(os.path.join(image_path,item)+'\\*.png')  # collect the images of every remaining fold folder
            label_list=[x.replace('img','mask').split('.')[0]+'.png' for x in img_list]  # label list; file names match the images

        elif self.mode=='val' or self.mode=='test':
            fold_s=fold[k_fold_test-1]
            img_list=glob.glob(os.path.join(image_path,fold_s)+'\\*.png')
            label_list=[x.replace('img','mask').split('.')[0]+'.png' for x in img_list]

        return img_list,label_list


if __name__ == '__main__':
    data = LinearLesion(r'C:\Users\Administrator\Desktop\model file\Pytorch_Medical_Segmention-multi-deep_spie\Dataset\Linear_lesion', (256, 256),mode='train')

    from torch.utils.data import DataLoader
    dataloader_test = DataLoader(
        data,
        # this has to be 1
        batch_size=1,
        shuffle=True,
        num_workers=0,
        pin_memory=True,
        drop_last=False
    )
    for i, (img, label) in enumerate(dataloader_test):
        label_colors = torch.unique(label.view(label.size(0), -1).type(torch.LongTensor))
        image_colors = torch.unique(img.view(img.size(0), -1), dim=1)

#        print(label.shape)
#        print(img.shape)

        label_arr = np.squeeze(label.numpy())  # drop singleton dims, e.g. (1, 3, 256, 512) -> (3, 256, 512) and (1, 1, 256, 512) -> (256, 512)
        image_arr = np.squeeze(img.numpy())
#        print(label_arr.shape)
#        print(image_arr.shape)

        # show the images
        plt.imshow(label_arr.astype(np.float32))  # a torch-layout (3, 256, 512) array cannot be displayed directly; it must be (256, 512, 3)
        plt.show()  # the labels here only take the values 0, 1, 2
        plt.imshow(image_arr.astype(np.float64))
        plt.show()

        # show the grey levels present in the images
#        print(image_colors)
#        print(label_colors)
#        print(list(label))
#        print(list(img.size()))
        break
        if i>3:
            break
--------------------------------------------------------------------------------
/model/unet_deep.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Mon Jan  6 19:27:45 2020

@author: Administrator
"""

import torch
import torch.nn as nn
#from layers import unetConv2, unetUp
#from utils import init_weights, count_param
import torchsummary
from torch.nn import functional as F
from torch.nn import init

class UNet_deepsup(nn.Module):

    def __init__(self, in_channels=1, n_classes=3, feature_scale=2, is_deconv=True, is_batchnorm=True):
        super(UNet_deepsup, self).__init__()
        self.in_channels = in_channels
        self.feature_scale = feature_scale
        self.is_deconv = is_deconv
        self.is_batchnorm = is_batchnorm

        filters = [64, 128, 256, 512, 1024]
        filters = [int(x / self.feature_scale) for x in filters]

        # downsampling
        self.maxpool = nn.MaxPool2d(kernel_size=2)
        self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm)
        self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm)
        self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm)
        self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm)
        self.center = unetConv2(filters[3], filters[4], self.is_batchnorm)
        # upsampling
        self.up_concat4 = unetUp(filters[4], filters[3], self.is_deconv)
        self.up_concat3 = unetUp(filters[3], filters[2], self.is_deconv)
        self.up_concat2 = unetUp(filters[2], filters[1], self.is_deconv)
        self.up_concat1 = unetUp(filters[1], filters[0], self.is_deconv)
        # final conv (without any concat)
        self.final = nn.Conv2d(filters[0], n_classes, 1)

        # deep supervision
        self.deepsup_3 = nn.Conv2d(filters[3], n_classes, kernel_size=1, stride=1, padding=0)
        self.output_3_up = nn.Upsample(scale_factor=8, mode='bilinear')
        self.deepsup_2 = nn.Conv2d(filters[2], n_classes, kernel_size=1, stride=1, padding=0)
        self.output_2_up = nn.Upsample(scale_factor=4, mode='bilinear')
        self.deepsup_1 = nn.Conv2d(filters[1], n_classes, kernel_size=1, stride=1, padding=0)
        self.output_1_up = nn.Upsample(scale_factor=2, mode='bilinear')

    def forward(self, inputs):
        conv1 = self.conv1(inputs)      # 16*512*512
        maxpool1 = self.maxpool(conv1)  # 16*256*256

        conv2 = self.conv2(maxpool1)    # 32*256*256
        maxpool2 = self.maxpool(conv2)  # 32*128*128

        conv3 = self.conv3(maxpool2)    # 64*128*128
        maxpool3 = self.maxpool(conv3)  # 64*64*64

        conv4 = self.conv4(maxpool3)    # 128*64*64
        maxpool4 = self.maxpool(conv4)  # 128*32*32

        center = self.center(maxpool4)  # 256*32*32

        up4 = self.up_concat4(center,conv4)  # 128*64*64
        up4_deep = self.deepsup_3(up4)
        up4_deep = self.output_3_up(up4_deep)

        up3 = self.up_concat3(up4,conv3)  # 64*128*128
        up3_deep = self.deepsup_2(up3)
        up3_deep = self.output_2_up(up3_deep)

        up2 = self.up_concat2(up3,conv2)  # 32*256*256
        up2_deep = self.deepsup_1(up2)
        up2_deep = self.output_1_up(up2_deep)

        up1 = self.up_concat1(up2,conv1)  # 16*512*512

        final = self.final(up1)
#        final=F.softmax(final,dim=1)  # softmax over the class dimension
        # log-probabilities for each supervision head and for the final output
        up4_deep = F.log_softmax(up4_deep,dim=1)
        up3_deep = F.log_softmax(up3_deep,dim=1)
        up2_deep = F.log_softmax(up2_deep,dim=1)
        final = F.log_softmax(final,dim=1)

        return up4_deep,up3_deep,up2_deep,final

class unetConv2(nn.Module):
    def __init__(self, in_size, out_size, is_batchnorm, n=2, ks=3, stride=1, padding=1):
        super(unetConv2, self).__init__()
        self.n = n
        self.ks = ks
        self.stride = stride
        self.padding = padding
        s = stride
        p = padding
        if is_batchnorm:
            for i in range(1, n+1):
                conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p),
                                     nn.BatchNorm2d(out_size),
                                     nn.ReLU(inplace=True),)
                setattr(self, 'conv%d'%i, conv)
                in_size = out_size
        else:
            for i in range(1, n+1):
                conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p),
                                     nn.ReLU(inplace=True),)
                setattr(self, 'conv%d'%i, conv)
                in_size = out_size

        # initialise the blocks
        for m in self.children():
            init_weights(m, init_type='kaiming')

    def forward(self, inputs):
        x = inputs
        for i in range(1, self.n+1):
            conv = getattr(self, 'conv%d'%i)
            x = conv(x)
        return x

class unetUp(nn.Module):
    def __init__(self, in_size, out_size, is_deconv, n_concat=2):
        super(unetUp, self).__init__()
        self.conv = unetConv2(in_size+(n_concat-2)*out_size, out_size, False)
        if is_deconv:
            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2, padding=0)
        else:
            self.up = nn.Sequential(
                 nn.UpsamplingBilinear2d(scale_factor=2),
                 nn.Conv2d(in_size, out_size, 1))

        # initialise the blocks
        for m in self.children():
            if m.__class__.__name__.find('unetConv2') != -1: continue  # skip nested unetConv2 blocks: they initialise themselves
            init_weights(m, init_type='kaiming')

    def forward(self, high_feature, *low_feature):
        outputs0 = self.up(high_feature)
        for feature in low_feature:
            outputs0 = torch.cat([outputs0, feature], 1)
        return self.conv(outputs0)

def init_weights(net, init_type='normal'):
    #print('initialization method [%s]' % init_type)
    if init_type == 'kaiming':
        net.apply(weights_init_kaiming)
    else:
        raise NotImplementedError('initialization method [%s] is not implemented' % init_type)

def weights_init_kaiming(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif classname.find('Linear') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif classname.find('BatchNorm') != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)


if __name__ == '__main__':
    inputs = torch.rand((2, 1, 512, 512)).cuda()

    unet_plus_plus = UNet_deepsup(in_channels=1, n_classes=2).cuda()
    a,b,c,output = unet_plus_plus(inputs)
    print('# parameters:', sum(param.numel() for param in unet_plus_plus.parameters()))
    def get_parameter_number(net):
        total_num = sum(p.numel() for p in net.parameters())
        trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad)
        return {'Total': total_num, 'Trainable': trainable_num}
    print('# parameters:', get_parameter_number(unet_plus_plus))
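    # Illustrative deep-supervision loss on the four log-probability outputs
    # (the 0.25/0.25/0.25/1.0 weights are an assumption, not taken from the
    # repo's training script):
    target = torch.randint(0, 2, (2, 512, 512)).cuda()
    nll = nn.NLLLoss()
    loss = sum(w * nll(o, target) for w, o in
               zip((0.25, 0.25, 0.25, 1.0), (a, b, c, output)))
    print('combined deep-supervision loss:', loss.item())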
be 1 127 | batch_size=1, 128 | shuffle=True, 129 | num_workers=0, 130 | pin_memory=True, 131 | drop_last=False 132 | ) 133 | for i, (img, label) in enumerate(dataloader_test): 134 | 135 | label_colors = torch.unique(label.view(label.size(0), -1) .type(torch.LongTensor)) 136 | image_colors = torch.unique(img.view(img.size(0), -1), dim=1) 137 | 138 | # print(label.shape) 139 | # print(img.shape) 140 | 141 | label_arr = np.squeeze(label.numpy())#去除维度中是1的维度,例如(1, 3, 256, 512)变(3, 256, 512),(1, 1, 256, 512)变(256, 512) 142 | image_arr = np.squeeze(img.numpy()) 143 | # print(label_arr.shape) 144 | # print(image_arr.shape) 145 | 146 | #显示图片 147 | plt.imshow(label_arr.astype(np.float32))#这里由于torch中tensor的格式问题,转化为np后(3, 256, 512)是不能显示的,一定要(256, 512,3) 148 | plt.show()#这里标签只有0,1,2 149 | plt.imshow(image_arr.astype(np.float64)) 150 | plt.show() 151 | 152 | #显示图片存在的灰度 153 | # print(image_colors) 154 | # print(label_colors) 155 | # print(list(label)) 156 | # print(list(img.size())) 157 | break 158 | if i>3: 159 | break 160 | -------------------------------------------------------------------------------- /model/unet_deep.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Mon Jan 6 19:27:45 2020 4 | 5 | @author: Administrator 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | #from layers import unetConv2, unetUp 11 | #from utils import init_weights, count_param 12 | import torchsummary 13 | from torch.nn import functional as F 14 | from torch.nn import init 15 | class UNet_deepsup(nn.Module): 16 | 17 | def __init__(self, in_channels=1, n_classes=3, feature_scale=2, is_deconv=True, is_batchnorm=True): 18 | super(UNet_deepsup, self).__init__() 19 | self.in_channels = in_channels 20 | self.feature_scale = feature_scale 21 | self.is_deconv = is_deconv 22 | self.is_batchnorm = is_batchnorm 23 | 24 | 25 | filters = [64, 128, 256, 512, 1024] 26 | filters = [int(x / self.feature_scale) for x in filters] 27 | 28 | # downsampling 29 | self.maxpool = nn.MaxPool2d(kernel_size=2) 30 | self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm) 31 | self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm) 32 | self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm) 33 | self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm) 34 | self.center = unetConv2(filters[3], filters[4], self.is_batchnorm) 35 | # upsampling 36 | self.up_concat4 = unetUp(filters[4], filters[3], self.is_deconv) 37 | self.up_concat3 = unetUp(filters[3], filters[2], self.is_deconv) 38 | self.up_concat2 = unetUp(filters[2], filters[1], self.is_deconv) 39 | self.up_concat1 = unetUp(filters[1], filters[0], self.is_deconv) 40 | # final conv (without any concat) 41 | self.final = nn.Conv2d(filters[0], n_classes, 1) 42 | 43 | #deep Supervision 44 | 45 | self.deepsup_3 = nn.Conv2d(filters[3], n_classes, kernel_size=1, stride=1, padding=0) 46 | self.output_3_up = nn.Upsample(scale_factor=8, mode='bilinear') 47 | self.deepsup_2 = nn.Conv2d(filters[2], n_classes, kernel_size=1, stride=1, padding=0) 48 | self.output_2_up = nn.Upsample(scale_factor=4, mode='bilinear') 49 | self.deepsup_1 = nn.Conv2d(filters[1], n_classes, kernel_size=1, stride=1, padding=0) 50 | self.output_1_up = nn.Upsample(scale_factor=2, mode='bilinear') 51 | 52 | 53 | def forward(self, inputs): 54 | conv1 = self.conv1(inputs) # 16*512*512 55 | maxpool1 = self.maxpool(conv1) # 16*256*256 56 | 57 | conv2 = self.conv2(maxpool1) # 32*256*256 58 | maxpool2 = 
self.maxpool(conv2) # 32*128*128 59 | 60 | conv3 = self.conv3(maxpool2) # 64*128*128 61 | maxpool3 = self.maxpool(conv3) # 64*64*64 62 | 63 | conv4 = self.conv4(maxpool3) # 128*64*64 64 | maxpool4 = self.maxpool(conv4) # 128*32*32 65 | 66 | center = self.center(maxpool4) # 256*32*32 67 | 68 | up4 = self.up_concat4(center,conv4) # 128*64*64 69 | up4_deep = self.deepsup_3(up4) 70 | up4_deep = self.output_3_up(up4_deep) 71 | 72 | up3 = self.up_concat3(up4,conv3) # 64*128*128 73 | up3_deep = self.deepsup_2(up3) 74 | up3_deep = self.output_2_up(up3_deep) 75 | 76 | up2 = self.up_concat2(up3,conv2) # 32*256*256 77 | up2_deep = self.deepsup_1(up2) 78 | up2_deep = self.output_1_up(up2_deep) 79 | 80 | up1 = self.up_concat1(up2,conv1) # 16*512*512 81 | 82 | 83 | final = self.final(up1) 84 | # final=F.softmax(final,dim=1)# softmax over the class dimension 85 | up4_deep = F.log_softmax(up4_deep,dim=1)# each auxiliary head is normalised on its own logits; the original reused `final` here, which made all four outputs identical and disabled deep supervision 86 | up3_deep = F.log_softmax(up3_deep,dim=1) 87 | up2_deep = F.log_softmax(up2_deep,dim=1) 88 | final=F.log_softmax(final,dim=1) 89 | 90 | return up4_deep,up3_deep,up2_deep,final 91 | 92 | class unetConv2(nn.Module): 93 | def __init__(self, in_size, out_size, is_batchnorm, n=2, ks=3, stride=1, padding=1): 94 | super(unetConv2, self).__init__() 95 | self.n = n 96 | self.ks = ks 97 | self.stride = stride 98 | self.padding = padding 99 | s = stride 100 | p = padding 101 | if is_batchnorm: 102 | for i in range(1, n+1): 103 | conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p), 104 | nn.BatchNorm2d(out_size), 105 | nn.ReLU(inplace=True),) 106 | setattr(self, 'conv%d'%i, conv) 107 | in_size = out_size 108 | 109 | else: 110 | for i in range(1, n+1): 111 | conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p), 112 | nn.ReLU(inplace=True),) 113 | setattr(self, 'conv%d'%i, conv) 114 | in_size = out_size 115 | 116 | # initialise the blocks 117 | for m in self.children(): 118 | init_weights(m, init_type='kaiming') 119 | 120 | def forward(self, inputs): 121 | x = inputs 122 | for i in range(1, self.n+1): 123 | conv = getattr(self, 'conv%d'%i) 124 | x = conv(x) 125 | 126 | return x 127 | 128 | class unetUp(nn.Module): 129 | def __init__(self, in_size, out_size, is_deconv, n_concat=2): 130 | super(unetUp, self).__init__() 131 | self.conv = unetConv2(in_size+(n_concat-2)*out_size, out_size, False) 132 | if is_deconv: 133 | self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2, padding=0) 134 | else: 135 | self.up = nn.Sequential( 136 | nn.UpsamplingBilinear2d(scale_factor=2), 137 | nn.Conv2d(in_size, out_size, 1)) 138 | 139 | # initialise the blocks 140 | for m in self.children(): 141 | if m.__class__.__name__.find('unetConv2') != -1: continue# continue skips the rest of the current loop iteration and moves on to the next one 142 | init_weights(m, init_type='kaiming') 143 | 144 | def forward(self, high_feature, *low_feature): 145 | outputs0 = self.up(high_feature) 146 | for feature in low_feature: 147 | outputs0 = torch.cat([outputs0, feature], 1) 148 | 149 | return self.conv(outputs0) 150 | 151 | def init_weights(net, init_type='normal'): 152 | #print('initialization method [%s]' % init_type) 153 | if init_type == 'kaiming': 154 | net.apply(weights_init_kaiming) 155 | else: 156 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 157 | def weights_init_kaiming(m): 158 | classname = m.__class__.__name__ 159 | if classname.find('Conv') != -1: 160 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 161 | elif classname.find('Linear') != -1: 162 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 163 | elif 
classname.find('BatchNorm') != -1: 164 | init.normal_(m.weight.data, 1.0, 0.02) 165 | init.constant_(m.bias.data, 0.0) 166 | 167 | 168 | if __name__ == '__main__': 169 | inputs = torch.rand((2, 1, 512, 512)).cuda() 170 | 171 | unet_plus_plus = UNet_deepsup(in_channels=1, n_classes=2).cuda() 172 | a,b,c,output = unet_plus_plus(inputs) 173 | print('# parameters:', sum(param.numel() for param in unet_plus_plus.parameters())) 174 | def get_parameter_number(net): 175 | total_num = sum(p.numel() for p in net.parameters()) 176 | trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad) 177 | return {'Total': total_num, 'Trainable': trainable_num} 178 | print('# parameters:', get_parameter_number(unet_plus_plus)) -------------------------------------------------------------------------------- /model/SE_unet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Wed Apr 17 12:59:10 2019 4 | 5 | @author: Fsl 6 | """ 7 | #import _init_paths 8 | import torch 9 | import torch.nn as nn 10 | #from layers import unetConv2, unetUp 11 | #from utils import init_weights, count_param 12 | import torchsummary 13 | from torch.nn import functional as F 14 | from torch.nn import init 15 | class SE_unet(nn.Module): 16 | 17 | def __init__(self, in_channels=1, n_classes=4, feature_scale=2, is_deconv=True, is_batchnorm=True): 18 | super(SE_unet, self).__init__() 19 | self.in_channels = in_channels 20 | self.feature_scale = feature_scale 21 | self.is_deconv = is_deconv 22 | self.is_batchnorm = is_batchnorm 23 | 24 | filters = [64, 128, 256, 512, 1024] 25 | filters = [int(x / self.feature_scale) for x in filters] 26 | 27 | # downsampling 28 | self.maxpool = nn.MaxPool2d(kernel_size=2) 29 | self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm) 30 | self.SE_Block_1 = SE_Block(planes=filters[0],r = 16) 31 | self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm) 32 | self.SE_Block_2 = SE_Block(planes=filters[1],r = 16) 33 | self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm) 34 | self.SE_Block_3 = SE_Block(planes=filters[2],r = 16) 35 | self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm) 36 | self.SE_Block_4 = SE_Block(planes=filters[3],r = 16) 37 | self.center = unetConv2(filters[3], filters[4], self.is_batchnorm) 38 | # upsampling 39 | self.up_concat4 = unetUp(filters[4], filters[3], self.is_deconv) 40 | self.up_concat3 = unetUp(filters[3], filters[2], self.is_deconv) 41 | self.up_concat2 = unetUp(filters[2], filters[1], self.is_deconv) 42 | self.up_concat1 = unetUp(filters[1], filters[0], self.is_deconv) 43 | # final conv (without any concat) 44 | self.final = nn.Conv2d(filters[0], n_classes, 1) 45 | 46 | # initialise weights 47 | for m in self.modules(): 48 | if isinstance(m, nn.Conv2d): 49 | init_weights(m, init_type='kaiming') 50 | elif isinstance(m, nn.BatchNorm2d): 51 | init_weights(m, init_type='kaiming') 52 | 53 | 54 | def forward(self, inputs): 55 | conv1 = self.conv1(inputs) # 16*512*512 56 | maxpool1 = self.SE_Block_1(conv1) 57 | maxpool1 = self.maxpool(maxpool1) # 16*256*256 58 | 59 | conv2 = self.conv2(maxpool1) # 32*256*256 60 | maxpool2 = self.SE_Block_2(conv2) 61 | maxpool2 = self.maxpool(maxpool2) # 32*128*128 62 | 63 | conv3 = self.conv3(maxpool2) # 64*128*128 64 | maxpool3 = self.SE_Block_3(conv3) 65 | maxpool3 = self.maxpool(maxpool3) # 64*64*64 66 | 67 | conv4 = self.conv4(maxpool3) # 128*64*64 68 | maxpool4 = self.SE_Block_4(conv4) 69 | maxpool4 = 
self.maxpool(maxpool4) # 128*32*32 70 | 71 | center = self.center(maxpool4) # 256*32*32 72 | up4 = self.up_concat4(center,conv4) # 128*64*64 73 | up3 = self.up_concat3(up4,conv3) # 64*128*128 74 | up2 = self.up_concat2(up3,conv2) # 32*256*256 75 | up1 = self.up_concat1(up2,conv1) # 16*512*512 76 | 77 | final = self.final(up1) 78 | # final=F.sigmoid(final) 79 | final=F.log_softmax(final,dim=1) 80 | 81 | return final 82 | 83 | class unetConv2(nn.Module): 84 | def __init__(self, in_size, out_size, is_batchnorm, n=2, ks=3, stride=1, padding=1): 85 | super(unetConv2, self).__init__() 86 | self.n = n 87 | self.ks = ks 88 | self.stride = stride 89 | self.padding = padding 90 | s = stride 91 | p = padding 92 | if is_batchnorm: 93 | for i in range(1, n+1): 94 | conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p), 95 | nn.BatchNorm2d(out_size), 96 | nn.ReLU(inplace=True),) 97 | setattr(self, 'conv%d'%i, conv) 98 | in_size = out_size 99 | 100 | else: 101 | for i in range(1, n+1): 102 | conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p), 103 | nn.ReLU(inplace=True),) 104 | setattr(self, 'conv%d'%i, conv) 105 | in_size = out_size 106 | 107 | # initialise the blocks 108 | for m in self.children(): 109 | init_weights(m, init_type='kaiming') 110 | 111 | def forward(self, inputs): 112 | x = inputs 113 | for i in range(1, self.n+1): 114 | conv = getattr(self, 'conv%d'%i) 115 | x = conv(x) 116 | 117 | return x 118 | 119 | class unetUp(nn.Module): 120 | def __init__(self, in_size, out_size, is_deconv, n_concat=2): 121 | super(unetUp, self).__init__() 122 | self.conv = unetConv2(in_size+(n_concat-2)*out_size, out_size, False) 123 | if is_deconv: 124 | self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2, padding=0)# upsampling via transposed convolution (deconvolution) 125 | else: 126 | self.up = nn.Sequential( 127 | nn.UpsamplingBilinear2d(scale_factor=2), 128 | nn.Conv2d(in_size, out_size, 1)) 129 | 130 | # initialise the blocks 131 | for m in self.children(): 132 | if m.__class__.__name__.find('unetConv2') != -1: continue 133 | init_weights(m, init_type='kaiming') 134 | 135 | def forward(self, high_feature, *low_feature): 136 | outputs0 = self.up(high_feature) 137 | for feature in low_feature: 138 | outputs0 = torch.cat([outputs0, feature], 1) 139 | return self.conv(outputs0) 140 | 141 | def init_weights(net, init_type='normal'): 142 | #print('initialization method [%s]' % init_type) 143 | if init_type == 'kaiming': 144 | net.apply(weights_init_kaiming) 145 | else: 146 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 147 | def weights_init_kaiming(m): 148 | classname = m.__class__.__name__ 149 | #print(classname) 150 | if classname.find('Conv') != -1: 151 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 152 | elif classname.find('Linear') != -1: 153 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 154 | elif classname.find('BatchNorm') != -1: 155 | init.normal_(m.weight.data, 1.0, 0.02) 156 | init.constant_(m.bias.data, 0.0) 157 | 158 | 159 | class SE_Block(nn.Module):# hand-written SE (squeeze-and-excitation) block 160 | def __init__(self, planes,r = 16, stride=1, downsample=None): 161 | super(SE_Block, self).__init__() 162 | # the middle part is the SE module itself; the rest follows the usual ResNet block layout 163 | self.relu = nn.ReLU(inplace=True) 164 | self.global_pool = nn.AdaptiveAvgPool2d(1) 165 | self.conv_down = nn.Conv2d( 166 | planes , planes // r, kernel_size=1, bias=False) 167 | self.conv_up = nn.Conv2d( 168 | planes // r, planes , kernel_size=1, bias=False) 169 | self.sig = nn.Sigmoid() 170 | # 171 | def forward(self, x): 172 | input = x 
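# editor's note (illustrative addition): the recalibration below works in three steps: squeeze with global_pool, (B,C,H,W) -> (B,C,1,1); excite through the conv_down/ReLU/conv_up bottleneck, C -> C/r -> C with r=16; then sig produces per-channel gates in (0,1) that are broadcast-multiplied onto the input 173 | out1 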
= self.global_pool(x) 174 | out1 = self.conv_down(out1) 175 | out1 = self.relu(out1) 176 | out1 = self.conv_up(out1) 177 | out1 = self.sig(out1) 178 | res = out1 * input 179 | 180 | return res 181 | #model = UNet() 182 | #torchsummary.summary(model, (1, 512, 512)) 183 | if __name__ == '__main__': 184 | inputs = torch.rand((2, 1, 512, 512)).cuda() 185 | 186 | unet_plus_plus = SE_unet(in_channels=1, n_classes=2).cuda() 187 | output = unet_plus_plus(inputs) 188 | print('# parameters:', sum(param.numel() for param in unet_plus_plus.parameters())) 189 | def get_parameter_number(net): 190 | total_num = sum(p.numel() for p in net.parameters()) 191 | trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad) 192 | return {'Total': total_num, 'Trainable': trainable_num} 193 | print('# parameters:', get_parameter_number(unet_plus_plus)) -------------------------------------------------------------------------------- /model/Residualunet_deep.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sun Apr 5 16:28:48 2020 4 | 5 | @author: 45780 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | #from layers import unetConv2, unetUp 11 | #from utils import init_weights, count_param 12 | import torchsummary 13 | from torch.nn import functional as F 14 | from torch.nn import init 15 | class Residual_UNet_deepsup(nn.Module): 16 | 17 | def __init__(self, in_channels=1, n_classes=4, feature_scale=2, is_deconv=True, is_batchnorm=True): 18 | super(Residual_UNet_deepsup, self).__init__() 19 | self.in_channels = in_channels 20 | self.feature_scale = feature_scale 21 | self.is_deconv = is_deconv 22 | self.is_batchnorm = is_batchnorm 23 | 24 | filters = [64, 128, 256, 512, 1024] 25 | filters = [int(x / self.feature_scale) for x in filters] 26 | 27 | # downsampling 28 | self.maxpool = nn.MaxPool2d(kernel_size=2) 29 | self.conv1 = ResidualBlock(self.in_channels, filters[0], self.is_batchnorm) 30 | self.conv2 = ResidualBlock(filters[0], filters[1], self.is_batchnorm) 31 | self.conv3 = ResidualBlock(filters[1], filters[2], self.is_batchnorm) 32 | self.conv4 = ResidualBlock(filters[2], filters[3], self.is_batchnorm) 33 | self.center = ResidualBlock(filters[3], filters[4], self.is_batchnorm) 34 | # upsampling 35 | self.up_concat4 = unetUp(filters[4], filters[3], self.is_deconv) 36 | self.up_concat3 = unetUp(filters[3], filters[2], self.is_deconv) 37 | self.up_concat2 = unetUp(filters[2], filters[1], self.is_deconv) 38 | self.up_concat1 = unetUp(filters[1], filters[0], self.is_deconv) 39 | # final conv (without any concat) 40 | self.final = nn.Conv2d(filters[0], n_classes, 1) 41 | 42 | #deep Supervision 43 | 44 | self.deepsup_3 = nn.Conv2d(filters[3], n_classes, kernel_size=1, stride=1, padding=0) 45 | self.output_3_up = nn.Upsample(scale_factor=8, mode='bilinear') 46 | self.deepsup_2 = nn.Conv2d(filters[2], n_classes, kernel_size=1, stride=1, padding=0) 47 | self.output_2_up = nn.Upsample(scale_factor=4, mode='bilinear') 48 | self.deepsup_1 = nn.Conv2d(filters[1], n_classes, kernel_size=1, stride=1, padding=0) 49 | self.output_1_up = nn.Upsample(scale_factor=2, mode='bilinear') 50 | 51 | 52 | def forward(self, inputs): 53 | conv1 = self.conv1(inputs) # 16*512*512 54 | maxpool1 = self.maxpool(conv1) # 16*256*256 55 | 56 | conv2 = self.conv2(maxpool1) # 32*256*256 57 | maxpool2 = self.maxpool(conv2) # 32*128*128 58 | 59 | conv3 = self.conv3(maxpool2) # 64*128*128 60 | maxpool3 = self.maxpool(conv3) # 64*64*64 61 | 62 
| conv4 = self.conv4(maxpool3) # 128*64*64 63 | maxpool4 = self.maxpool(conv4) # 128*32*32 64 | 65 | center = self.center(maxpool4) # 256*32*32 66 | 67 | up4 = self.up_concat4(center,conv4) # 128*64*64 68 | up4_deep = self.deepsup_3(up4) 69 | up4_deep = self.output_3_up(up4_deep) 70 | 71 | up3 = self.up_concat3(up4,conv3) # 64*128*128 72 | up3_deep = self.deepsup_2(up3) 73 | up3_deep = self.output_2_up(up3_deep) 74 | 75 | up2 = self.up_concat2(up3,conv2) # 32*256*256 76 | up2_deep = self.deepsup_1(up2) 77 | up2_deep = self.output_1_up(up2_deep) 78 | 79 | up1 = self.up_concat1(up2,conv1) # 16*512*512 80 | 81 | 82 | final = self.final(up1) 83 | # final=F.softmax(final,dim=1)# softmax over the class dimension 84 | up4_deep = F.log_softmax(up4_deep,dim=1)# each auxiliary head is normalised on its own logits; the original reused `final` here, disabling deep supervision 85 | up3_deep = F.log_softmax(up3_deep,dim=1) 86 | up2_deep = F.log_softmax(up2_deep,dim=1) 87 | final=F.log_softmax(final,dim=1) 88 | 89 | return up4_deep,up3_deep,up2_deep,final 90 | 91 | class unetConv2(nn.Module): 92 | def __init__(self, in_size, out_size, is_batchnorm, n=2, ks=3, stride=1, padding=1): 93 | super(unetConv2, self).__init__() 94 | self.n = n 95 | self.ks = ks 96 | self.stride = stride 97 | self.padding = padding 98 | s = stride 99 | p = padding 100 | if is_batchnorm: 101 | for i in range(1, n+1): 102 | conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p), 103 | nn.BatchNorm2d(out_size), 104 | nn.ReLU(inplace=True),) 105 | setattr(self, 'conv%d'%i, conv) 106 | in_size = out_size 107 | 108 | else: 109 | for i in range(1, n+1): 110 | conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p), 111 | nn.ReLU(inplace=True),) 112 | setattr(self, 'conv%d'%i, conv) 113 | in_size = out_size 114 | 115 | # initialise the blocks 116 | for m in self.children(): 117 | init_weights(m, init_type='kaiming') 118 | 119 | def forward(self, inputs): 120 | x = inputs 121 | for i in range(1, self.n+1): 122 | conv = getattr(self, 'conv%d'%i) 123 | x = conv(x) 124 | 125 | return x 126 | 127 | class unetUp(nn.Module): 128 | def __init__(self, in_size, out_size, is_deconv, n_concat=2): 129 | super(unetUp, self).__init__() 130 | self.conv = unetConv2(in_size+(n_concat-2)*out_size, out_size, False) 131 | if is_deconv: 132 | self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2, padding=0) 133 | else: 134 | self.up = nn.Sequential( 135 | nn.UpsamplingBilinear2d(scale_factor=2), 136 | nn.Conv2d(in_size, out_size, 1)) 137 | 138 | # initialise the blocks 139 | for m in self.children(): 140 | if m.__class__.__name__.find('unetConv2') != -1: continue# continue skips the rest of the current loop iteration and moves on to the next one 141 | init_weights(m, init_type='kaiming') 142 | 143 | def forward(self, high_feature, *low_feature): 144 | outputs0 = self.up(high_feature) 145 | for feature in low_feature: 146 | outputs0 = torch.cat([outputs0, feature], 1) 147 | 148 | return self.conv(outputs0) 149 | 150 | def init_weights(net, init_type='normal'): 151 | #print('initialization method [%s]' % init_type) 152 | if init_type == 'kaiming': 153 | net.apply(weights_init_kaiming) 154 | else: 155 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 156 | def weights_init_kaiming(m): 157 | classname = m.__class__.__name__ 158 | if classname.find('Conv') != -1: 159 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 160 | elif classname.find('Linear') != -1: 161 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 162 | elif classname.find('BatchNorm') != -1: 163 | init.normal_(m.weight.data, 1.0, 0.02) 164 | init.constant_(m.bias.data, 0.0) 165 | 
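# editor's sketch (illustrative, not in the original file): the block below computes y = ReLU(F(x) + P(x)), where F is the conv-BN-ReLU-conv-BN branch and P is the 1x1 projection conv1_1 that matches channel counts. Quick shape check for a hypothetical 16->32 block: blk = ResidualBlock(16, 32); blk(torch.rand(1, 16, 64, 64)).shape evaluates to torch.Size([1, 32, 64, 64]) 166 | class 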
ResidualBlock(nn.Module): 167 | def __init__(self, inplanes, planes, stride=1): 168 | super(ResidualBlock, self).__init__() 169 | self.stride = stride 170 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, 171 | padding=1, bias=False) 172 | self.bn1 = nn.BatchNorm2d(planes) 173 | self.relu = nn.ReLU(inplace=True) 174 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 175 | padding=1, bias=False) 176 | self.bn2 = nn.BatchNorm2d(planes) 177 | self.conv1_1 = nn.Conv2d(inplanes, planes, 1, stride) 178 | 179 | for m in self.children(): 180 | init_weights(m, init_type='kaiming') 181 | 182 | def forward(self, x): 183 | residual = self.conv1_1(x) 184 | 185 | out = self.conv1(x) 186 | out = self.bn1(out) 187 | out = self.relu(out) 188 | 189 | out = self.conv2(out) 190 | out = self.bn2(out) 191 | 192 | out += residual 193 | out = self.relu(out) 194 | 195 | return out 196 | 197 | if __name__ == '__main__': 198 | inputs = torch.rand((2, 1, 512, 512)).cuda() 199 | 200 | unet_plus_plus = Residual_UNet_deepsup(in_channels=1, n_classes=2).cuda() 201 | a,b,c,output = unet_plus_plus(inputs) 202 | print('# parameters:', sum(param.numel() for param in unet_plus_plus.parameters())) 203 | def get_parameter_number(net): 204 | total_num = sum(p.numel() for p in net.parameters()) 205 | trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad) 206 | return {'Total': total_num, 'Trainable': trainable_num} 207 | print('# parameters:', get_parameter_number(unet_plus_plus)) -------------------------------------------------------------------------------- /model/resnet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Wed Apr 10 09:57:49 2019 4 | 5 | @author: Fsl 6 | """ 7 | 8 | import torch.nn as nn 9 | import math 10 | import torch.utils.model_zoo as model_zoo 11 | import torchsummary 12 | 13 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 14 | 'resnet152'] 15 | 16 | 17 | model_urls = { 18 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 19 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 20 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 21 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 22 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 23 | } 24 | 25 | 26 | def conv3x3(in_planes, out_planes, stride=1): 27 | """3x3 convolution with padding""" 28 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 29 | padding=1, bias=False) 30 | 31 | 32 | class BasicBlock(nn.Module): 33 | expansion = 1 34 | 35 | def __init__(self, inplanes, planes, stride=1, downsample=None): 36 | super(BasicBlock, self).__init__() 37 | self.conv1 = conv3x3(inplanes, planes, stride) 38 | self.bn1 = nn.BatchNorm2d(planes) 39 | self.relu = nn.ReLU(inplace=True) 40 | self.conv2 = conv3x3(planes, planes) 41 | self.bn2 = nn.BatchNorm2d(planes) 42 | self.downsample = downsample 43 | self.stride = stride 44 | 45 | def forward(self, x): 46 | residual = x 47 | 48 | out = self.conv1(x) 49 | out = self.bn1(out) 50 | out = self.relu(out) 51 | 52 | out = self.conv2(out) 53 | out = self.bn2(out) 54 | 55 | if self.downsample is not None: 56 | residual = self.downsample(x) 57 | 58 | out += residual 59 | out = self.relu(out) 60 | 61 | return out 62 | 63 | 64 | class Bottleneck(nn.Module): 65 | expansion = 4 66 | 67 | def __init__(self, inplanes, 
planes, stride=1, downsample=None): 68 | super(Bottleneck, self).__init__() 69 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 70 | self.bn1 = nn.BatchNorm2d(planes) 71 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 72 | padding=1, bias=False) 73 | self.bn2 = nn.BatchNorm2d(planes) 74 | self.conv3 = nn.Conv2d(planes, planes * Bottleneck.expansion, kernel_size=1, bias=False) 75 | self.bn3 = nn.BatchNorm2d(planes * Bottleneck.expansion) 76 | self.relu = nn.ReLU(inplace=True) 77 | self.downsample = downsample 78 | self.stride = stride 79 | 80 | def forward(self, x): 81 | residual = x 82 | 83 | out = self.conv1(x) 84 | out = self.bn1(out) 85 | out = self.relu(out) 86 | 87 | out = self.conv2(out) 88 | out = self.bn2(out) 89 | out = self.relu(out) 90 | 91 | out = self.conv3(out) 92 | out = self.bn3(out) 93 | 94 | if self.downsample is not None: 95 | residual = self.downsample(x) 96 | 97 | out += residual 98 | out = self.relu(out) 99 | 100 | return out 101 | 102 | 103 | class ResNet(nn.Module): 104 | 105 | def __init__(self, block, layers, num_classes=1000,deep_base=False,stem_width=32): 106 | self.inplanes = stem_width*2 if deep_base else 64 107 | 108 | super(ResNet, self).__init__() 109 | if deep_base: 110 | self.conv1 = nn.Sequential( 111 | nn.Conv2d(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False), 112 | nn.BatchNorm2d(stem_width), 113 | nn.ReLU(inplace=True), 114 | nn.Conv2d(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False), 115 | nn.BatchNorm2d(stem_width), 116 | nn.ReLU(inplace=True), 117 | nn.Conv2d(stem_width, stem_width*2, kernel_size=3, stride=1, padding=1, bias=False), 118 | ) 119 | else: 120 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 121 | bias=False) 122 | 123 | self.bn1 = nn.BatchNorm2d(self.inplanes) 124 | self.relu = nn.ReLU(inplace=True) 125 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 126 | self.layer1 = self._make_layer(block, 64, layers[0]) 127 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 128 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 129 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 130 | self.avgpool = nn.AvgPool2d(7, stride=1) 131 | self.fc = nn.Linear(512 * block.expansion, num_classes) 132 | 133 | for m in self.modules(): 134 | if isinstance(m, nn.Conv2d): 135 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 136 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 137 | elif isinstance(m, nn.BatchNorm2d): 138 | m.weight.data.fill_(1) 139 | m.bias.data.zero_() 140 | 141 | def _make_layer(self, block, planes, blocks, stride=1): 142 | downsample = None 143 | if stride != 1 or self.inplanes != planes * block.expansion: 144 | downsample = nn.Sequential( 145 | nn.Conv2d(self.inplanes, planes * block.expansion, 146 | kernel_size=1, stride=stride, bias=False), 147 | nn.BatchNorm2d(planes * block.expansion), 148 | ) 149 | 150 | layers = [] 151 | layers.append(block(self.inplanes, planes, stride, downsample)) 152 | self.inplanes = planes * block.expansion 153 | for i in range(1, blocks): 154 | layers.append(block(self.inplanes, planes)) 155 | 156 | return nn.Sequential(*layers) 157 | 158 | def forward(self, x): 159 | x = self.conv1(x)# the stem is registered as self.conv1; the original called the undefined self.conv1_1, which raised an AttributeError 160 | x = self.bn1(x) 161 | x = self.relu(x) 162 | x = self.maxpool(x) 163 | 164 | x = self.layer1(x) 165 | x = self.layer2(x) 166 | x = self.layer3(x) 167 | t = self.layer4(x) 168 | 169 | # x = self.avgpool(x) 170 | # x = x.view(x.size(0), -1) 171 | # x = self.fc(x) 172 | 173 | return t 174 | 175 | 176 | def resnet18(pretrained=False, **kwargs): 177 | """Constructs a ResNet-18 model. 178 | 179 | Args: 180 | pretrained (bool): If True, returns a model pre-trained on ImageNet 181 | """ 182 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 183 | if pretrained: 184 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 185 | return model 186 | 187 | 188 | def resnet34(pretrained=False, **kwargs): 189 | """Constructs a ResNet-34 model. 190 | 191 | Args: 192 | pretrained (bool): If True, returns a model pre-trained on ImageNet 193 | """ 194 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 195 | # model_dict = model.state_dict() 196 | 197 | 198 | if pretrained: 199 | # pretrained_dict=model_zoo.load_url(model_urls['resnet34'],model_dir='/home/FENGsl/JBHI/Pretrain_model') 200 | # pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 201 | # model_dict.update(pretrained_dict) 202 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'],model_dir='/home/FENGsl/JBHI/Pretrain_model')) 203 | print('===> Pretrained model has been loaded; please fasten your seat belt and get ready to take off!') 204 | return model 205 | 206 | 207 | def resnet50(pretrained=False, **kwargs): 208 | """Constructs a ResNet-50 model. 209 | 210 | Args: 211 | pretrained (bool): If True, returns a model pre-trained on ImageNet 212 | """ 213 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 214 | if pretrained: 215 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 216 | return model 217 | 218 | 219 | def resnet101(pretrained=False, **kwargs): 220 | """Constructs a ResNet-101 model. 221 | 222 | Args: 223 | pretrained (bool): If True, returns a model pre-trained on ImageNet 224 | """ 225 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 226 | if pretrained: 227 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 228 | return model 229 | 230 | 
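# editor's sketch (an assumption mirroring the commented-out lines in resnet34 above, not the author's verified pipeline): partial loading of pretrained weights when the target model's head differs; the helper name is hypothetical and unused elsewhere in this repo def load_pretrained_partial(model, url): model_dict = model.state_dict() pretrained_dict = model_zoo.load_url(url) pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and v.shape == model_dict[k].shape} # keep only weights whose name and shape match model_dict.update(pretrained_dict) model.load_state_dict(model_dict) return model 231 | def resnet152(pretrained=False, **kwargs): 232 | """Constructs a ResNet-152 model. 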
233 | 234 | Args: 235 | pretrained (bool): If True, returns a model pre-trained on ImageNet 236 | """ 237 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 238 | if pretrained: 239 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 240 | return model 241 | # net = resnet34(pretrained=False) 242 | # torchsummary.summary(net, (3, 512, 512)) -------------------------------------------------------------------------------- /model/lipunet_deep.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Wed Apr 15 14:37:08 2020 4 | 5 | @author: 45780 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | #from layers import unetConv2, unetUp 11 | #from utils import init_weights, count_param 12 | import torchsummary 13 | from torch.nn import functional as F 14 | from torch.nn import init 15 | from collections import OrderedDict 16 | 17 | COEFF = 12.0 18 | 19 | class lipUNet_deepsup(nn.Module): 20 | 21 | def __init__(self, in_channels=1, n_classes=3, feature_scale=2, is_deconv=True, is_batchnorm=True): 22 | super(lipUNet_deepsup, self).__init__() 23 | self.in_channels = in_channels 24 | self.feature_scale = feature_scale 25 | self.is_deconv = is_deconv 26 | self.is_batchnorm = is_batchnorm 27 | 28 | filters = [64, 128, 256, 512, 1024] 29 | filters = [int(x / self.feature_scale) for x in filters] 30 | 31 | # downsampling 32 | self.maxpool_1 = SimplifiedLIP(filters[0]) 33 | self.maxpool_2 = SimplifiedLIP(filters[1]) 34 | self.maxpool_3 = SimplifiedLIP(filters[2]) 35 | self.maxpool_4 = SimplifiedLIP(filters[3]) 36 | 37 | self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm) 38 | self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm) 39 | self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm) 40 | self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm) 41 | self.center = unetConv2(filters[3], filters[4], self.is_batchnorm) 42 | # upsampling 43 | self.up_concat4 = unetUp(filters[4], filters[3], self.is_deconv) 44 | self.up_concat3 = unetUp(filters[3], filters[2], self.is_deconv) 45 | self.up_concat2 = unetUp(filters[2], filters[1], self.is_deconv) 46 | self.up_concat1 = unetUp(filters[1], filters[0], self.is_deconv) 47 | # final conv (without any concat) 48 | self.final = nn.Conv2d(filters[0], n_classes, 1) 49 | 50 | #deep Supervision 51 | 52 | self.deepsup_3 = nn.Conv2d(filters[3], n_classes, kernel_size=1, stride=1, padding=0) 53 | self.output_3_up = nn.Upsample(scale_factor=8, mode='bilinear') 54 | self.deepsup_2 = nn.Conv2d(filters[2], n_classes, kernel_size=1, stride=1, padding=0) 55 | self.output_2_up = nn.Upsample(scale_factor=4, mode='bilinear') 56 | self.deepsup_1 = nn.Conv2d(filters[1], n_classes, kernel_size=1, stride=1, padding=0) 57 | self.output_1_up = nn.Upsample(scale_factor=2, mode='bilinear') 58 | 59 | 60 | def forward(self, inputs): 61 | conv1 = self.conv1(inputs) # 16*512*512 62 | maxpool1 = self.maxpool_1(conv1) # 16*256*256 63 | 64 | conv2 = self.conv2(maxpool1) # 32*256*256 65 | maxpool2 = self.maxpool_2(conv2) # 32*128*128 66 | 67 | conv3 = self.conv3(maxpool2) # 64*128*128 68 | maxpool3 = self.maxpool_3(conv3) # 64*64*64 69 | 70 | conv4 = self.conv4(maxpool3) # 128*64*64 71 | maxpool4 = self.maxpool_4(conv4) # 128*32*32 72 | 73 | center = self.center(maxpool4) # 256*32*32 74 | 75 | up4 = self.up_concat4(center,conv4) # 128*64*64 76 | up4_deep = self.deepsup_3(up4) 77 | up4_deep = self.output_3_up(up4_deep) 78 | 79 | up3 = 
self.up_concat3(up4,conv3) # 64*128*128 80 | up3_deep = self.deepsup_2(up3) 81 | up3_deep = self.output_2_up(up3_deep) 82 | 83 | up2 = self.up_concat2(up3,conv2) # 32*256*256 84 | up2_deep = self.deepsup_1(up2) 85 | up2_deep = self.output_1_up(up2_deep) 86 | 87 | up1 = self.up_concat1(up2,conv1) # 16*512*512 88 | 89 | 90 | final = self.final(up1) 91 | # final=F.softmax(final,dim=1)# softmax over the class dimension 92 | up4_deep = F.log_softmax(up4_deep,dim=1)# each auxiliary head is normalised on its own logits; the original reused `final` here, disabling deep supervision 93 | up3_deep = F.log_softmax(up3_deep,dim=1) 94 | up2_deep = F.log_softmax(up2_deep,dim=1) 95 | final=F.log_softmax(final,dim=1) 96 | 97 | return up4_deep,up3_deep,up2_deep,final 98 | 99 | class unetConv2(nn.Module): 100 | def __init__(self, in_size, out_size, is_batchnorm, n=2, ks=3, stride=1, padding=1): 101 | super(unetConv2, self).__init__() 102 | self.n = n 103 | self.ks = ks 104 | self.stride = stride 105 | self.padding = padding 106 | s = stride 107 | p = padding 108 | if is_batchnorm: 109 | for i in range(1, n+1): 110 | conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p), 111 | nn.BatchNorm2d(out_size), 112 | nn.ReLU(inplace=True),) 113 | setattr(self, 'conv%d'%i, conv) 114 | in_size = out_size 115 | 116 | else: 117 | for i in range(1, n+1): 118 | conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p), 119 | nn.ReLU(inplace=True),) 120 | setattr(self, 'conv%d'%i, conv) 121 | in_size = out_size 122 | 123 | # initialise the blocks 124 | for m in self.children(): 125 | init_weights(m, init_type='kaiming') 126 | 127 | def forward(self, inputs): 128 | x = inputs 129 | for i in range(1, self.n+1): 130 | conv = getattr(self, 'conv%d'%i) 131 | x = conv(x) 132 | 133 | return x 134 | 135 | class unetUp(nn.Module): 136 | def __init__(self, in_size, out_size, is_deconv, n_concat=2): 137 | super(unetUp, self).__init__() 138 | self.conv = unetConv2(in_size+(n_concat-2)*out_size, out_size, False) 139 | if is_deconv: 140 | self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2, padding=0) 141 | else: 142 | self.up = nn.Sequential( 143 | nn.UpsamplingBilinear2d(scale_factor=2), 144 | nn.Conv2d(in_size, out_size, 1)) 145 | 146 | # initialise the blocks 147 | for m in self.children(): 148 | if m.__class__.__name__.find('unetConv2') != -1: continue# continue skips the rest of the current loop iteration and moves on to the next one 149 | init_weights(m, init_type='kaiming') 150 | 151 | def forward(self, high_feature, *low_feature): 152 | outputs0 = self.up(high_feature) 153 | for feature in low_feature: 154 | outputs0 = torch.cat([outputs0, feature], 1) 155 | 156 | return self.conv(outputs0) 157 | 158 | def init_weights(net, init_type='normal'): 159 | #print('initialization method [%s]' % init_type) 160 | if init_type == 'kaiming': 161 | net.apply(weights_init_kaiming) 162 | else: 163 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 164 | def weights_init_kaiming(m): 165 | classname = m.__class__.__name__ 166 | if classname.find('Conv') != -1: 167 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 168 | elif classname.find('Linear') != -1: 169 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 170 | elif classname.find('BatchNorm') != -1: 171 | init.normal_(m.weight.data, 1.0, 0.02) 172 | init.constant_(m.bias.data, 0.0) 173 | 174 | def lip2d(x, logit, kernel=3, stride=2, padding=1): 175 | weight = logit.exp() 176 | return F.avg_pool2d(x*weight, kernel, stride, padding)/F.avg_pool2d(weight, kernel, stride, padding) 177 | 178 | class SoftGate(nn.Module): 179 | def __init__(self): 180 | super(SoftGate, self).__init__() 181 | 
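# editor's note (illustrative addition): this is LIP, Local Importance-based Pooling. lip2d above computes avg_pool(x*w)/avg_pool(w) with w = exp(logit), i.e. a softmax-style weighted average over each 3x3 window in which high-logit pixels dominate the pooled value; SoftGate below bounds the logits to (0, COEFF=12) so exp() stays in a numerically safe range 182 | 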
def forward(self, x): 183 | return torch.sigmoid(x).mul(COEFF) 184 | 185 | class SimplifiedLIP(nn.Module): 186 | def __init__(self, channels): 187 | super(SimplifiedLIP, self).__init__() 188 | 189 | self.logit = nn.Sequential( 190 | OrderedDict(( 191 | ('conv', nn.Conv2d(channels, channels, 3, padding=1, bias=False)), 192 | ('bn', nn.InstanceNorm2d(channels, affine=True)), 193 | ('gate', SoftGate()), 194 | )) 195 | )# corresponds to the logit module g in the LIP formulation 196 | 197 | def init_layer(self): 198 | self.logit[0].weight.data.fill_(0.0) 199 | 200 | def forward(self, x): 201 | frac = lip2d(x, self.logit(x)) 202 | return frac 203 | 204 | if __name__ == '__main__': 205 | inputs = torch.rand((2, 1, 512, 512)).cuda() 206 | 207 | unet_plus_plus = lipUNet_deepsup(in_channels=1, n_classes=2).cuda() 208 | a,b,c,output = unet_plus_plus(inputs) 209 | print('# parameters:', sum(param.numel() for param in unet_plus_plus.parameters())) 210 | def get_parameter_number(net): 211 | total_num = sum(p.numel() for p in net.parameters()) 212 | trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad) 213 | return {'Total': total_num, 'Trainable': trainable_num} 214 | print('# parameters:', get_parameter_number(unet_plus_plus)) -------------------------------------------------------------------------------- /model/rsunet.py: -------------------------------------------------------------------------------- 1 | #import _init_paths 2 | import torch 3 | import torch.nn as nn 4 | #from layers import unetConv2, unetUp 5 | #from utils import init_weights, count_param 6 | import torchsummary 7 | from torch.nn import functional as F 8 | from torch.nn import init 9 | 10 | class RSUNet(nn.Module): 11 | def __init__(self, in_channels=1, n_classes=4, feature_scale=2, is_deconv=True, is_batchnorm=True): 12 | super(RSUNet, self).__init__() # a subclass __init__ must first invoke the parent constructor, so nn.Module is set up before any submodules are registered 13 | self.in_channels = in_channels 14 | self.feature_scale = feature_scale 15 | self.is_deconv = is_deconv 16 | self.is_batchnorm = is_batchnorm 17 | 18 | 19 | filters = [64, 128, 256, 512, 1024] 20 | filters = [int(x / self.feature_scale) for x in filters] 21 | 22 | # downsampling 23 | self.maxpool = nn.MaxPool2d(kernel_size=2) 24 | self.unetDown1 = unetDown(self.in_channels, filters[0], self.is_batchnorm) 25 | self.unetDown2 = unetDown(filters[0], filters[1], self.is_batchnorm) 26 | self.unetDown3 = unetDown(filters[1], filters[2], self.is_batchnorm) 27 | self.unetDown4 = unetDown(filters[2], filters[3], self.is_batchnorm) 28 | self.center = unetConv2(filters[3], filters[4], self.is_batchnorm) 29 | # upsampling 30 | self.up_concat4 = unetUp(filters[4], filters[3], self.is_deconv) 31 | self.up_concat3 = unetUp(filters[3], filters[2], self.is_deconv) 32 | self.up_concat2 = unetUp(filters[2], filters[1], self.is_deconv) 33 | self.up_concat1 = unetUp(filters[1], filters[0], self.is_deconv) 34 | # final conv (without any concat) 35 | self.final = nn.Conv2d(filters[0], n_classes, 1) 36 | 37 | #deep Supervision 38 | self.deepsup_3 = nn.Conv2d(filters[3], n_classes, kernel_size=1, stride=1, padding=0) 39 | self.output_3_up = nn.Upsample(scale_factor=8, mode='bilinear') 40 | self.deepsup_2 = nn.Conv2d(filters[2], n_classes, kernel_size=1, stride=1, padding=0) 41 | self.output_2_up = nn.Upsample(scale_factor=4, mode='bilinear') 42 | self.deepsup_1 = nn.Conv2d(filters[1], n_classes, kernel_size=1, stride=1, padding=0) 43 | self.output_1_up = nn.Upsample(scale_factor=2, mode='bilinear') 44 | 45 | 46 | # initialise weights 47 | for m in self.modules(): 
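# editor's note (illustrative addition): self.modules() iterates recursively over every submodule, so this single loop Kaiming-initialises each nn.Conv2d and nn.BatchNorm2d in the whole network, including those nested inside the unetDown/unetUp blocks 48 | 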
if isinstance(m, nn.Conv2d): # check the concrete type, so that any bare Conv2d held directly by this class is also initialised 49 | init_weights(m, init_type='kaiming') 50 | elif isinstance(m, nn.BatchNorm2d): 51 | init_weights(m, init_type='kaiming') 52 | 53 | def forward(self, inputs): 54 | conv1 = self.unetDown1(inputs) # 16*512*512 55 | maxpool1 = self.maxpool(conv1) # 16*256*256 56 | conv2 = self.unetDown2(maxpool1) # 32*256*256 57 | maxpool2 = self.maxpool(conv2) # 32*128*128 58 | conv3 = self.unetDown3(maxpool2) # 64*128*128 59 | maxpool3 = self.maxpool(conv3) # 64*64*64 60 | conv4 = self.unetDown4(maxpool3) # 128*64*64 61 | maxpool4 = self.maxpool(conv4) # 128*32*32 62 | 63 | center = self.center(maxpool4) # 256*32*32 64 | up4 = self.up_concat4(center,conv4) # 128*64*64 65 | up4_deep = self.deepsup_3(up4) 66 | up4_deep = self.output_3_up(up4_deep) 67 | up3 = self.up_concat3(up4,conv3) # 64*128*128 68 | up3_deep = self.deepsup_2(up3) 69 | up3_deep = self.output_2_up(up3_deep) 70 | up2 = self.up_concat2(up3,conv2) # 32*256*256 71 | up2_deep = self.deepsup_1(up2) 72 | up2_deep = self.output_1_up(up2_deep) 73 | up1 = self.up_concat1(up2,conv1) # 16*512*512 74 | 75 | final = self.final(up1) 76 | final=F.log_softmax(final,dim=1) 77 | up4_deep = F.log_softmax(up4_deep,dim=1)# each auxiliary head is normalised on its own logits; the original applied log_softmax to the already log-softmaxed `final`, so all four outputs were identical and double-normalised 78 | up3_deep = F.log_softmax(up3_deep,dim=1) 79 | up2_deep = F.log_softmax(up2_deep,dim=1) 80 | 81 | return up4_deep,up3_deep,up2_deep,final 82 | 83 | class unetConv2(nn.Module): 84 | def __init__(self, in_size, out_size, is_batchnorm, n=2, ks=3, stride=1, padding=1): 85 | super(unetConv2, self).__init__() 86 | self.n = n 87 | self.ks = ks 88 | self.stride = stride 89 | self.padding = padding 90 | s = stride 91 | p = padding 92 | if is_batchnorm: 93 | for i in range(1, n+1): 94 | conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p), 95 | nn.BatchNorm2d(out_size), 96 | nn.ReLU(inplace=True),) 97 | setattr(self, 'conv%d'%i, conv) 98 | in_size = out_size 99 | 100 | else: 101 | for i in range(1, n+1): 102 | conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p), 103 | nn.ReLU(inplace=True),) 104 | setattr(self, 'conv%d'%i, conv) 105 | in_size = out_size 106 | 107 | # initialise the blocks 108 | for m in self.children(): # the two conv sub-blocks 109 | init_weights(m, init_type='kaiming') 110 | 111 | def forward(self, inputs): 112 | x = inputs 113 | for i in range(1, self.n+1): 114 | conv = getattr(self, 'conv%d'%i) 115 | x = conv(x) 116 | 117 | return x 118 | 119 | class unetDown(nn.Module): 120 | def __init__(self, in_size, out_size, is_batchnorm): 121 | super(unetDown, self).__init__() 122 | self.conv = unetConv2(in_size, out_size, is_batchnorm) 123 | self.sb = SBU_Block(out_size) 124 | 125 | def forward(self, inputs): 126 | x = self.conv(inputs) 127 | shortcut = x.clone() 128 | sbx = self.sb(x) 129 | outputs = shortcut+sbx 130 | 131 | return outputs 132 | 133 | class unetUp(nn.Module): 134 | def __init__(self, in_size, out_size, is_deconv, n_concat=2): 135 | super(unetUp, self).__init__() 136 | self.conv = unetConv2(in_size+(n_concat-2)*out_size, out_size, True) 137 | if is_deconv: 138 | self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2, padding=0) 139 | else: 140 | self.up = nn.Sequential( 141 | nn.UpsamplingBilinear2d(scale_factor=2), 142 | nn.Conv2d(in_size, out_size, 1)) 143 | 144 | # initialise the blocks 145 | for m in self.children(): 146 | if m.__class__.__name__.find('unetConv2') != -1: continue # unetConv2 already initialises itself, so it does not need to be initialised again here 147 | init_weights(m, init_type='kaiming') 148 | 149 | def forward(self, high_feature, *low_feature): 150 | outputs0 = self.up(high_feature) 
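# editor's note (illustrative addition): channel bookkeeping for the concat below: self.up halves the channels (in_size -> out_size), each skip tensor adds out_size more, so with the default n_concat=2 the conv declared as unetConv2(in_size+(n_concat-2)*out_size, ...) receives exactly 2*out_size = in_size channels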
151 | for feature in low_feature: 152 | outputs0 = torch.cat([outputs0, feature], 1) 153 | return self.conv(outputs0) 154 | 155 | class SBU_Block(nn.Module): 156 | def __init__(self, channel): 157 | super(SBU_Block, self).__init__() 158 | self.conv = nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1) 159 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 160 | self.fc = nn.Sequential( 161 | nn.Conv2d(channel, channel, kernel_size=1, stride=1, padding=0, bias=False), 162 | nn.BatchNorm2d(channel), 163 | nn.ReLU(inplace=True), 164 | nn.Conv2d(channel, channel, kernel_size=1, stride=1, padding=0, bias=False), 165 | nn.Sigmoid() 166 | ) 167 | 168 | def forward(self, x): 169 | x = self.conv(x) 170 | y = x.clone() 171 | agap = self.avg_pool(torch.abs(x))# torch.abs returns the element-wise absolute value of its argument 172 | alpha = self.fc(agap) 173 | sigma = agap * alpha 174 | soft_threshold = torch.abs(sigma.expand_as(y)) 175 | 176 | y[torch.abs(y) < soft_threshold]=0 177 | return y 178 | 179 | def init_weights(net, init_type='normal'): 180 | #print('initialization method [%s]' % init_type) 181 | if init_type == 'kaiming': 182 | net.apply(weights_init_kaiming) # apply() recursively walks every module in the network and applies the given function to each of them 183 | else: 184 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 185 | 186 | def weights_init_kaiming(m): 187 | classname = m.__class__.__name__ 188 | # print(classname) 189 | if classname.find('Conv') != -1: 190 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 191 | elif classname.find('Linear') != -1: # fully connected layers 192 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 193 | elif classname.find('BatchNorm') != -1: 194 | init.normal_(m.weight.data, 1.0, 0.02) 195 | init.constant_(m.bias.data, 0.0) 196 | #model = RSUNet().cuda() 197 | #torchsummary.summary(model, (1, 512, 512)) -------------------------------------------------------------------------------- /model/CBAMunet_deep.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Thu Feb 27 12:49:13 2020 4 | 5 | @author: 45780 6 | """ 7 | 8 | 9 | import torch 10 | import torch.nn as nn 11 | #from layers import unetConv2, unetUp 12 | #from utils import init_weights, count_param 13 | import torchsummary 14 | from torch.nn import functional as F 15 | from torch.nn import init 16 | class CBAMUNet_deepsup(nn.Module): 17 | 18 | def __init__(self, in_channels=1, n_classes=4, feature_scale=2, is_deconv=True, is_batchnorm=True): 19 | super(CBAMUNet_deepsup, self).__init__() 20 | self.in_channels = in_channels 21 | self.feature_scale = feature_scale 22 | self.is_deconv = is_deconv 23 | self.is_batchnorm = is_batchnorm 24 | 25 | 26 | filters = [64, 128, 256, 512, 1024] 27 | filters = [int(x / self.feature_scale) for x in filters] 28 | 29 | # downsampling 30 | self.maxpool = nn.MaxPool2d(kernel_size=2) 31 | self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm) 32 | self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm) 33 | self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm) 34 | self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm) 35 | self.center = unetConv2(filters[3], filters[4], self.is_batchnorm) 36 | # upsampling 37 | self.up_concat4 = unetUp(filters[4], filters[3], self.is_deconv) 38 | self.up_concat3 = unetUp(filters[3], filters[2], self.is_deconv) 39 | self.up_concat2 = unetUp(filters[2], filters[1], self.is_deconv) 40 | self.up_concat1 = unetUp(filters[1], filters[0], self.is_deconv) 
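# editor's note (illustrative addition): each unetUp below runs CBAM-style attention on the concatenated skip features, ChannelAttention first and SpatialAttention second, before the two 3x3 convolutions (see unetUp.forward further down) 41 | # final 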
conv (without any concat) 42 | self.final = nn.Conv2d(filters[0], n_classes, 1) 43 | 44 | #deep Supervision 45 | 46 | self.deepsup_3 = nn.Conv2d(filters[3], n_classes, kernel_size=1, stride=1, padding=0) 47 | self.output_3_up = nn.Upsample(scale_factor=8, mode='bilinear') 48 | self.deepsup_2 = nn.Conv2d(filters[2], n_classes, kernel_size=1, stride=1, padding=0) 49 | self.output_2_up = nn.Upsample(scale_factor=4, mode='bilinear') 50 | self.deepsup_1 = nn.Conv2d(filters[1], n_classes, kernel_size=1, stride=1, padding=0) 51 | self.output_1_up = nn.Upsample(scale_factor=2, mode='bilinear') 52 | 53 | 54 | def forward(self, inputs): 55 | conv1 = self.conv1(inputs) # 16*512*512 56 | maxpool1 = self.maxpool(conv1) # 16*256*256 57 | 58 | conv2 = self.conv2(maxpool1) # 32*256*256 59 | maxpool2 = self.maxpool(conv2) # 32*128*128 60 | 61 | conv3 = self.conv3(maxpool2) # 64*128*128 62 | maxpool3 = self.maxpool(conv3) # 64*64*64 63 | 64 | conv4 = self.conv4(maxpool3) # 128*64*64 65 | maxpool4 = self.maxpool(conv4) # 128*32*32 66 | 67 | center = self.center(maxpool4) # 256*32*32 68 | 69 | up4 = self.up_concat4(center,conv4) # 128*64*64 70 | up4_deep = self.deepsup_3(up4) 71 | up4_deep = self.output_3_up(up4_deep) 72 | 73 | up3 = self.up_concat3(up4,conv3) # 64*128*128 74 | up3_deep = self.deepsup_2(up3) 75 | up3_deep = self.output_2_up(up3_deep) 76 | 77 | up2 = self.up_concat2(up3,conv2) # 32*256*256 78 | up2_deep = self.deepsup_1(up2) 79 | up2_deep = self.output_1_up(up2_deep) 80 | 81 | up1 = self.up_concat1(up2,conv1) # 16*512*512 82 | 83 | 84 | final = self.final(up1) 85 | # final=F.softmax(final,dim=1)# softmax over the class dimension 86 | up4_deep = F.log_softmax(up4_deep,dim=1)# each auxiliary head is normalised on its own logits; the original reused `final` here, disabling deep supervision 87 | up3_deep = F.log_softmax(up3_deep,dim=1) 88 | up2_deep = F.log_softmax(up2_deep,dim=1) 89 | final=F.log_softmax(final,dim=1) 90 | 91 | return up4_deep,up3_deep,up2_deep,final 92 | 93 | class unetConv2(nn.Module): 94 | def __init__(self, in_size, out_size, is_batchnorm, n=2, ks=3, stride=1, padding=1): 95 | super(unetConv2, self).__init__() 96 | self.n = n 97 | self.ks = ks 98 | self.stride = stride 99 | self.padding = padding 100 | s = stride 101 | p = padding 102 | if is_batchnorm: 103 | for i in range(1, n+1): 104 | conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p), 105 | nn.BatchNorm2d(out_size), 106 | nn.ReLU(inplace=True),) 107 | setattr(self, 'conv%d'%i, conv) 108 | in_size = out_size 109 | 110 | else: 111 | for i in range(1, n+1): 112 | conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p), 113 | nn.ReLU(inplace=True),) 114 | setattr(self, 'conv%d'%i, conv) 115 | in_size = out_size 116 | 117 | # initialise the blocks 118 | for m in self.children(): 119 | init_weights(m, init_type='kaiming') 120 | 121 | def forward(self, inputs): 122 | x = inputs 123 | for i in range(1, self.n+1): 124 | conv = getattr(self, 'conv%d'%i) 125 | x = conv(x) 126 | 127 | return x 128 | 129 | class unetUp(nn.Module): 130 | def __init__(self, in_size, out_size, is_deconv, n_concat=2): 131 | super(unetUp, self).__init__() 132 | self.conv = unetConv2(in_size+(n_concat-2)*out_size, out_size, False) 133 | if is_deconv: 134 | self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2, padding=0) 135 | else: 136 | self.up = nn.Sequential( 137 | nn.UpsamplingBilinear2d(scale_factor=2), 138 | nn.Conv2d(in_size, out_size, 1)) 139 | 140 | # attention 141 | self.ca = ChannelAttention(2*out_size) 142 | self.sa = SpatialAttention() 143 | 144 | # initialise the blocks 145 | for m in self.children(): 146 | if m.__class__.__name__.find('unetConv2') != 
-1: continue 147 | init_weights(m, init_type='kaiming') 148 | 149 | def forward(self, high_feature, *low_feature): 150 | outputs0 = self.up(high_feature) 151 | for feature in low_feature: 152 | outputs0 = torch.cat([outputs0, feature], 1) 153 | outputs0 = self.ca(outputs0) * outputs0 154 | outputs0 = self.sa(outputs0) * outputs0 155 | return self.conv(outputs0) 156 | 157 | def init_weights(net, init_type='normal'): 158 | #print('initialization method [%s]' % init_type) 159 | if init_type == 'kaiming': 160 | net.apply(weights_init_kaiming) 161 | else: 162 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 163 | def weights_init_kaiming(m): 164 | classname = m.__class__.__name__ 165 | if classname.find('Conv') != -1: 166 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 167 | elif classname.find('Linear') != -1: 168 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 169 | elif classname.find('BatchNorm') != -1: 170 | init.normal_(m.weight.data, 1.0, 0.02) 171 | init.constant_(m.bias.data, 0.0) 172 | 173 | 174 | class ChannelAttention(nn.Module): 175 | def __init__(self, in_planes, ratio=16): 176 | super(ChannelAttention, self).__init__() 177 | self.avg_pool = nn.AdaptiveAvgPool2d(1)# average pooling over the spatial dimensions; output shape is (batch, channel, 1, 1) 178 | self.max_pool = nn.AdaptiveMaxPool2d(1)# max pooling over the spatial dimensions; output shape is (batch, channel, 1, 1) 179 | 180 | self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)# reduction by the ratio argument; the original hard-coded 16 and ignored ratio 181 | self.relu1 = nn.ReLU() 182 | self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False) 183 | 184 | self.sigmoid = nn.Sigmoid() 185 | 186 | def forward(self, x): 187 | avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x)))) 188 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x)))) 189 | out = avg_out + max_out 190 | return self.sigmoid(out) 191 | 192 | class SpatialAttention(nn.Module): 193 | def __init__(self, kernel_size=7): 194 | super(SpatialAttention, self).__init__() 195 | 196 | assert kernel_size in (3, 7), 'kernel size must be 3 or 7' 197 | padding = 3 if kernel_size == 7 else 1 198 | 199 | self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False) 200 | self.sigmoid = nn.Sigmoid() 201 | 202 | def forward(self, x): 203 | avg_out = torch.mean(x, dim=1, keepdim=True)# average pooling over the channel dimension; output shape is (batch, 1, height, width) 204 | max_out, _ = torch.max(x, dim=1, keepdim=True)# max pooling over the channel dimension; output shape is (batch, 1, height, width) 205 | x = torch.cat([avg_out, max_out], dim=1) 206 | x = self.conv1(x) 207 | return self.sigmoid(x) 208 | 209 | 210 | if __name__ == '__main__': 211 | inputs = torch.rand((2, 1, 512, 512)).cuda() 212 | 213 | unet_plus_plus = CBAMUNet_deepsup(in_channels=1, n_classes=2).cuda() 214 | a,b,c,output = unet_plus_plus(inputs) 215 | print('# parameters:', sum(param.numel() for param in unet_plus_plus.parameters())) 216 | def get_parameter_number(net): 217 | total_num = sum(p.numel() for p in net.parameters()) 218 | trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad) 219 | return {'Total': total_num, 'Trainable': trainable_num} 220 | print('# parameters:', get_parameter_number(unet_plus_plus)) -------------------------------------------------------------------------------- /model/unet_carafe_deep.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Tue Mar 17 19:42:29 2020 4 | 5 | @author: 45780 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | #from layers import unetConv2, unetUp 11 | #from utils import init_weights, count_param 12 | import torchsummary 
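# editor's note (illustrative addition): the CARAFE upsampler implemented as CARAFE_3 at the bottom of this file has two streams: a kernel-prediction branch (the 1x1 'comp' conv compresses to c_mid channels, 'enc' expands to (scale*k_up)**2 = 100 channels for scale=2 and k_up=5, and PixelShuffle turns those into k_up**2 = 25 softmax-normalised weights per upsampled pixel) and a reassembly step (nearest-neighbour upsample plus unfold gathers every 5x5 neighbourhood, and the einsum takes their weighted sum)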
13 | from torch.nn import functional as F 14 | from torch.nn import init 15 | 16 | class UNet_carafe_deepsup(nn.Module): 17 | 18 | def __init__(self, in_channels=1, n_classes=3, feature_scale=2, is_deconv=False, is_batchnorm=True): 19 | super(UNet_carafe_deepsup, self).__init__() 20 | self.in_channels = in_channels 21 | self.feature_scale = feature_scale 22 | self.is_deconv = is_deconv 23 | self.is_batchnorm = is_batchnorm 24 | 25 | 26 | filters = [64, 128, 256, 512, 1024] 27 | filters = [int(x / self.feature_scale) for x in filters] 28 | 29 | # downsampling 30 | self.maxpool = nn.MaxPool2d(kernel_size=2) 31 | self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm) 32 | self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm) 33 | self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm) 34 | self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm) 35 | self.center = unetConv2(filters[3], filters[4], self.is_batchnorm) 36 | # upsampling 37 | self.up_concat4 = unetUp_2(filters[4], filters[3], self.is_deconv) 38 | self.up_concat3 = unetUp_2(filters[3], filters[2], self.is_deconv) 39 | self.up_concat2 = unetUp_2(filters[2], filters[1], self.is_deconv) 40 | self.up_concat1 = unetUp_2(filters[1], filters[0], self.is_deconv) 41 | # final conv (without any concat) 42 | self.final = nn.Conv2d(filters[0], n_classes, 1) 43 | 44 | #deep Supervision 45 | 46 | self.deepsup_3 = nn.Conv2d(filters[3], n_classes, kernel_size=1, stride=1, padding=0) 47 | self.output_3_up = nn.Upsample(scale_factor=8, mode='bilinear') 48 | self.deepsup_2 = nn.Conv2d(filters[2], n_classes, kernel_size=1, stride=1, padding=0) 49 | self.output_2_up = nn.Upsample(scale_factor=4, mode='bilinear') 50 | self.deepsup_1 = nn.Conv2d(filters[1], n_classes, kernel_size=1, stride=1, padding=0) 51 | self.output_1_up = nn.Upsample(scale_factor=2, mode='bilinear') 52 | 53 | 54 | def forward(self, inputs): 55 | conv1 = self.conv1(inputs) # 16*512*512 56 | maxpool1 = self.maxpool(conv1) # 16*256*256 57 | 58 | conv2 = self.conv2(maxpool1) # 32*256*256 59 | maxpool2 = self.maxpool(conv2) # 32*128*128 60 | 61 | conv3 = self.conv3(maxpool2) # 64*128*128 62 | maxpool3 = self.maxpool(conv3) # 64*64*64 63 | 64 | conv4 = self.conv4(maxpool3) # 128*64*64 65 | maxpool4 = self.maxpool(conv4) # 128*32*32 66 | 67 | center = self.center(maxpool4) # 256*32*32 68 | 69 | up4 = self.up_concat4(center,conv4) # 128*64*64 70 | up4_deep = self.deepsup_3(up4) 71 | up4_deep = self.output_3_up(up4_deep) 72 | 73 | up3 = self.up_concat3(up4,conv3) # 64*128*128 74 | up3_deep = self.deepsup_2(up3) 75 | up3_deep = self.output_2_up(up3_deep) 76 | 77 | up2 = self.up_concat2(up3,conv2) # 32*256*256 78 | up2_deep = self.deepsup_1(up2) 79 | up2_deep = self.output_1_up(up2_deep) 80 | 81 | up1 = self.up_concat1(up2,conv1) # 16*512*512 82 | 83 | 84 | final = self.final(up1) 85 | # final=F.softmax(final,dim=1)# softmax over the class dimension 86 | up4_deep = F.log_softmax(up4_deep,dim=1)# each auxiliary head is normalised on its own logits; the original reused `final` here, disabling deep supervision 87 | up3_deep = F.log_softmax(up3_deep,dim=1) 88 | up2_deep = F.log_softmax(up2_deep,dim=1) 89 | final=F.log_softmax(final,dim=1) 90 | 91 | return up4_deep,up3_deep,up2_deep,final 92 | 93 | class unetConv2(nn.Module): 94 | def __init__(self, in_size, out_size, is_batchnorm, n=2, ks=3, stride=1, padding=1): 95 | super(unetConv2, self).__init__() 96 | self.n = n 97 | self.ks = ks 98 | self.stride = stride 99 | self.padding = padding 100 | s = stride 101 | p = padding 102 | if is_batchnorm: 103 | for i in range(1, n+1): 104 | conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p), 
105 | nn.BatchNorm2d(out_size), 106 | nn.ReLU(inplace=True),) 107 | setattr(self, 'conv%d'%i, conv) 108 | in_size = out_size 109 | 110 | else: 111 | for i in range(1, n+1): 112 | conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p), 113 | nn.ReLU(inplace=True),) 114 | setattr(self, 'conv%d'%i, conv) 115 | in_size = out_size 116 | 117 | # initialise the blocks 118 | for m in self.children(): 119 | init_weights(m, init_type='kaiming') 120 | 121 | def forward(self, inputs): 122 | x = inputs 123 | for i in range(1, self.n+1): 124 | conv = getattr(self, 'conv%d'%i) 125 | x = conv(x) 126 | 127 | return x 128 | 129 | class unetUp_2(nn.Module): 130 | def __init__(self, in_size, out_size, is_deconv, n_concat=2): 131 | super(unetUp_2, self).__init__() 132 | self.conv = unetConv2(in_size+(n_concat-2)*out_size, out_size, False) 133 | 134 | self.up = CARAFE_3(in_size) 135 | self.conv1 = nn.Conv2d(in_size, out_size, 1) 136 | 137 | # self.up = Carafe(in_size) 138 | 139 | # initialise the blocks 140 | for m in self.children(): 141 | if m.__class__.__name__.find('unetConv2') != -1: continue# continue skips the rest of the current loop iteration and moves on to the next one 142 | init_weights(m, init_type='kaiming') 143 | 144 | def forward(self, high_feature, *low_feature): 145 | outputs0 = self.up(high_feature) 146 | outputs0 = self.conv1(outputs0) 147 | 148 | for feature in low_feature: 149 | outputs0 = torch.cat([outputs0, feature], 1) 150 | 151 | return self.conv(outputs0) 152 | 153 | class unetUp(nn.Module): 154 | def __init__(self, in_size, out_size, is_deconv, n_concat=2): 155 | super(unetUp, self).__init__() 156 | self.conv = unetConv2(in_size+(n_concat-2)*out_size, out_size, False) 157 | if is_deconv: 158 | self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2, padding=0) 159 | else: 160 | self.up = nn.Sequential( 161 | nn.UpsamplingBilinear2d(scale_factor=2), 162 | nn.Conv2d(in_size, out_size, 1)) 163 | 164 | # initialise the blocks 165 | for m in self.children(): 166 | if m.__class__.__name__.find('unetConv2') != -1: continue# continue skips the rest of the current loop iteration and moves on to the next one 167 | init_weights(m, init_type='kaiming') 168 | 169 | def forward(self, high_feature, *low_feature): 170 | outputs0 = self.up(high_feature) 171 | for feature in low_feature: 172 | outputs0 = torch.cat([outputs0, feature], 1) 173 | 174 | return self.conv(outputs0) 175 | 176 | def init_weights(net, init_type='normal'): 177 | #print('initialization method [%s]' % init_type) 178 | if init_type == 'kaiming': 179 | net.apply(weights_init_kaiming) 180 | else: 181 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 182 | def weights_init_kaiming(m): 183 | classname = m.__class__.__name__ 184 | if classname.find('Conv') != -1: 185 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 186 | elif classname.find('Linear') != -1: 187 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 188 | elif classname.find('BatchNorm') != -1: 189 | init.normal_(m.weight.data, 1.0, 0.02) 190 | init.constant_(m.bias.data, 0.0) 191 | 192 | 193 | class CARAFE_3(nn.Module): 194 | def __init__(self, c, c_mid=64, scale=2, k_up=5, k_enc=3): 195 | super(CARAFE_3, self).__init__() 196 | self.scale = scale 197 | 198 | self.comp = nn.Conv2d(c, c_mid,kernel_size=1, stride=1, 199 | padding=0, dilation=1)# channel compressor: reduces the channel count before kernel prediction 200 | self.bn1 = nn.BatchNorm2d(c_mid) 201 | self.relu = nn.ReLU(inplace=True) 202 | 203 | self.enc = nn.Conv2d(c_mid, (scale * k_up) ** 2, kernel_size=k_enc, 204 | stride=1, padding=k_enc // 2, dilation=1) 205 | self.bn2 = 
nn.BatchNorm2d((scale * k_up) ** 2) 206 | 207 | self.pix_shf = nn.PixelShuffle(scale) 208 | 209 | self.upsmp = nn.Upsample(scale_factor=scale, mode='nearest') 210 | self.unfold = nn.Unfold(kernel_size=k_up, dilation=scale, 211 | padding=k_up // 2 * scale) 212 | 213 | def forward(self, X): 214 | b, c, h, w = X.size() 215 | h_, w_ = h * self.scale, w * self.scale 216 | 217 | W = self.comp(X) # b * m * h * w 218 | W = self.bn1(W) 219 | W = self.relu(W) 220 | 221 | W = self.enc(W) # b * 100 * h * w 222 | W = self.bn2(W) 223 | W = self.pix_shf(W) # b * 25 * h_ * w_ 224 | W = F.softmax(W, dim=1) # b * 25 * h_ * w_ 225 | 226 | X = self.upsmp(X) # b * c * h_ * w_ 227 | X = self.unfold(X) # b * 25c * h_ * w_ 228 | X = X.view(b, c, -1, h_, w_) # b * 25 * c * h_ * w_ 229 | 230 | X = torch.einsum('bkhw,bckhw->bchw', [W, X]) # b * c * h_ * w_ 231 | return X 232 | 233 | if __name__ == '__main__': 234 | inputs = torch.rand((2, 1, 256, 512)).cuda() 235 | 236 | unet_plus_plus = UNet_carafe_deepsup(in_channels=1, n_classes=2).cuda() 237 | a,b,c,output = unet_plus_plus(inputs) 238 | print('# parameters:', sum(param.numel() for param in unet_plus_plus.parameters())) 239 | def get_parameter_number(net): 240 | total_num = sum(p.numel() for p in net.parameters()) 241 | trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad) 242 | return {'Total': total_num, 'Trainable': trainable_num} 243 | print('# parameters:', get_parameter_number(unet_plus_plus)) -------------------------------------------------------------------------------- /model/unet_deep_improve.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Tue May 26 16:05:35 2020 4 | 5 | @author: Administrator 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | #from layers import unetConv2, unetUp 11 | #from utils import init_weights, count_param 12 | import torchsummary 13 | from torch.nn import functional as F 14 | from torch.nn import init 15 | from collections import OrderedDict 16 | from torch.nn import Module, Sequential, Conv2d, ReLU,AdaptiveMaxPool2d, AdaptiveAvgPool2d, \ 17 | NLLLoss, BCELoss, CrossEntropyLoss, AvgPool2d, MaxPool2d, Parameter, Linear, Sigmoid, Softmax, Dropout, Embedding 18 | 19 | COEFF = 12.0 20 | 21 | class unet_deepsup_improve(nn.Module): 22 | def __init__(self, in_channels=1, n_classes=3, feature_scale=2, is_deconv=True, is_batchnorm=True): 23 | super(unet_deepsup_improve, self).__init__() 24 | self.in_channels = in_channels 25 | self.feature_scale = feature_scale 26 | self.is_deconv = is_deconv 27 | self.is_batchnorm = is_batchnorm 28 | 29 | filters = [64, 128, 256, 512, 1024] 30 | filters = [int(x / self.feature_scale) for x in filters] 31 | 32 | # downsampling 33 | self.maxpool_1 = StripPooling(filters[0],nn.BatchNorm2d) 34 | self.maxpool_2 = StripPooling(filters[1],nn.BatchNorm2d) 35 | self.maxpool_3 = StripPooling(filters[2],nn.BatchNorm2d) 36 | self.maxpool_4 = StripPooling(filters[3],nn.BatchNorm2d) 37 | 38 | self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm) 39 | self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm) 40 | self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm) 41 | self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm) 42 | self.center = unetConv2(filters[3], filters[4], self.is_batchnorm) 43 | # upsampling 44 | self.up_concat4 = unetUp(filters[4], filters[3], self.is_deconv) 45 | self.up_concat3 = unetUp(filters[3], filters[2], self.is_deconv) 
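# note: each unetUp stage below doubles the spatial resolution (ConvTranspose2d when is_deconv=True, otherwise bilinear upsampling followed by a 1x1 conv), concatenates the matching encoder feature map, and fuses the result with a double 3x3 conv block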
46 | self.up_concat2 = unetUp(filters[2], filters[1], self.is_deconv) 47 | self.up_concat1 = unetUp(filters[1], filters[0], self.is_deconv) 48 | # final conv (without any concat) 49 | self.final = nn.Conv2d(filters[0], n_classes, 1) 50 | 51 | # deep supervision heads 52 | 53 | self.deepsup_3 = nn.Conv2d(filters[3], n_classes, kernel_size=1, stride=1, padding=0) 54 | self.output_3_up = nn.Upsample(scale_factor=8, mode='bilinear') 55 | self.deepsup_2 = nn.Conv2d(filters[2], n_classes, kernel_size=1, stride=1, padding=0) 56 | self.output_2_up = nn.Upsample(scale_factor=4, mode='bilinear') 57 | self.deepsup_1 = nn.Conv2d(filters[1], n_classes, kernel_size=1, stride=1, padding=0) 58 | self.output_1_up = nn.Upsample(scale_factor=2, mode='bilinear') 59 | 60 | 61 | def forward(self, inputs): 62 | conv1 = self.conv1(inputs) # 16*512*512 63 | maxpool1 = self.maxpool_1(conv1) # 16*256*256 64 | 65 | conv2 = self.conv2(maxpool1) # 32*256*256 66 | maxpool2 = self.maxpool_2(conv2) # 32*128*128 67 | 68 | conv3 = self.conv3(maxpool2) # 64*128*128 69 | maxpool3 = self.maxpool_3(conv3) # 64*64*64 70 | 71 | conv4 = self.conv4(maxpool3) # 128*64*64 72 | maxpool4 = self.maxpool_4(conv4) # 128*32*32 73 | 74 | center = self.center(maxpool4) # 256*32*32 75 | 76 | up4 = self.up_concat4(center,conv4) # 128*64*64 77 | up4_deep = self.deepsup_3(up4) 78 | up4_deep = self.output_3_up(up4_deep) 79 | 80 | up3 = self.up_concat3(up4,conv3) # 64*128*128 81 | up3_deep = self.deepsup_2(up3) 82 | up3_deep = self.output_2_up(up3_deep) 83 | 84 | up2 = self.up_concat2(up3,conv2) # 32*256*256 85 | up2_deep = self.deepsup_1(up2) 86 | up2_deep = self.output_1_up(up2_deep) 87 | 88 | up1 = self.up_concat1(up2,conv1) # 16*512*512 89 | 90 | 91 | final = self.final(up1) 92 | # final = F.softmax(final, dim=1) # softmax over the class dimension 93 | up4_deep = F.log_softmax(up4_deep, dim=1) # log-softmax of each deep-supervision head's own logits 94 | up3_deep = F.log_softmax(up3_deep, dim=1) 95 | up2_deep = F.log_softmax(up2_deep, dim=1) 96 | final = F.log_softmax(final, dim=1) 97 | 98 | return up4_deep,up3_deep,up2_deep,final 99 | 100 | class unetConv2(nn.Module): 101 | def __init__(self, in_size, out_size, is_batchnorm, n=2, ks=3, stride=1, padding=1): 102 | super(unetConv2, self).__init__() 103 | self.n = n 104 | self.ks = ks 105 | self.stride = stride 106 | self.padding = padding 107 | s = stride 108 | p = padding 109 | if is_batchnorm: 110 | for i in range(1, n+1): 111 | conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p), 112 | nn.BatchNorm2d(out_size), 113 | nn.ReLU(inplace=True),) 114 | setattr(self, 'conv%d'%i, conv) 115 | in_size = out_size 116 | 117 | else: 118 | for i in range(1, n+1): 119 | conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p), 120 | nn.ReLU(inplace=True),) 121 | setattr(self, 'conv%d'%i, conv) 122 | in_size = out_size 123 | 124 | # initialise the blocks 125 | for m in self.children(): 126 | init_weights(m, init_type='kaiming') 127 | 128 | def forward(self, inputs): 129 | x = inputs 130 | for i in range(1, self.n+1): 131 | conv = getattr(self, 'conv%d'%i) 132 | x = conv(x) 133 | 134 | return x 135 | 136 | class unetUp(nn.Module): 137 | def __init__(self, in_size, out_size, is_deconv, n_concat=2): 138 | super(unetUp, self).__init__() 139 | self.conv = unetConv2(in_size+(n_concat-2)*out_size, out_size, False) 140 | if is_deconv: 141 | self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2, padding=0) 142 | else: 143 | self.up = nn.Sequential( 144 | nn.UpsamplingBilinear2d(scale_factor=2), 145 | nn.Conv2d(in_size, out_size, 1)) 146 | 147 | # initialise the blocks 148 | for m in self.children(): 149 | if m.__class__.__name__.find('unetConv2') != -1: continue # skip this child: unetConv2 blocks initialise their own weights 150 | init_weights(m, init_type='kaiming') 151 | 152 | def forward(self, high_feature, *low_feature): 153 | outputs0 = self.up(high_feature) 154 | for feature in low_feature: 155 | outputs0 = torch.cat([outputs0, feature], 1) 156 | 157 | return self.conv(outputs0) 158 | 159 | 160 | def init_weights(net, init_type='normal'): 161 | #print('initialization method [%s]' % init_type) 162 | if init_type == 'kaiming': 163 | net.apply(weights_init_kaiming) 164 | else: 165 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 166 | def weights_init_kaiming(m): 167 | classname = m.__class__.__name__ 168 | if classname.find('Conv') != -1: 169 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 170 | elif classname.find('Linear') != -1: 171 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 172 | elif classname.find('BatchNorm') != -1: 173 | init.normal_(m.weight.data, 1.0, 0.02) 174 | init.constant_(m.bias.data, 0.0) 175 | 176 | 177 | class StripPooling(nn.Module): 178 | """ 179 | Reference: 180 | """ 181 | def __init__(self, in_channels, norm_layer): 182 | super(StripPooling, self).__init__() 183 | ### strip pooling implemented with AdaptiveAvgPool2d 184 | self.pool3 = nn.AdaptiveAvgPool2d((1, None)) 185 | self.pool4 = nn.AdaptiveAvgPool2d((None, 1)) 186 | 187 | inter_channels = int(in_channels) 188 | 189 | self.conv1_2 = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 1, bias=False), 190 | norm_layer(inter_channels), 191 | nn.ReLU(True)) 192 | self.conv2_3 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, (1, 3), 1, (0, 1), bias=False), 193 | norm_layer(inter_channels)) 194 | self.conv2_4 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, (3, 1), 1, (1, 0), bias=False), 195 | norm_layer(inter_channels)) 196 | self.conv2_6 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, 1, 1, bias=False), 197 | norm_layer(inter_channels), 198 | nn.ReLU(True)) 199 | self.conv3 = nn.Sequential(nn.Conv2d(inter_channels, in_channels, 3,1,1,bias=False), 200 | norm_layer(in_channels) 201 | ) 202 | self.lip = SimplifiedLIP(inter_channels) 203 | ## STPM module: strip pooling plus LIP downsampling 204 | def forward(self, x): 205 | _, _, h, w = x.size() 206 | # x1 = self.conv1_2(x) 207 | x1 = self.lip(x) 208 | 209 | x2 = self.conv1_2(x) 210 | x2_4 = F.interpolate(self.conv2_3(self.pool3(x2)), (h//2, w//2), mode='bilinear') 211 | x2_5 = F.interpolate(self.conv2_4(self.pool4(x2)), (h//2, w//2), mode='bilinear') 212 | 213 | x3 = self.conv2_6(x2_4 + x2_5) 214 | out = self.conv3(x1+x3) 215 | 216 | return F.relu_(out) 217 | 218 | def lip2d(x, logit, kernel=3, stride=2, padding=1): 219 | weight = logit.exp() 220 | return F.avg_pool2d(x*weight, kernel, stride, padding)/F.avg_pool2d(weight, kernel, stride, padding) 221 | 222 | class SoftGate(nn.Module): 223 | def __init__(self): 224 | super(SoftGate, self).__init__() 225 | 226 | def forward(self, x): 227 | return torch.sigmoid(x).mul(COEFF) 228 | 229 | class SimplifiedLIP(nn.Module): 230 | def __init__(self, channels): 231 | super(SimplifiedLIP, self).__init__() 232 | 233 | self.logit = nn.Sequential( 234 | OrderedDict(( 235 | ('conv', nn.Conv2d(channels, channels, 3, padding=1, bias=False)), 236 | ('bn', nn.InstanceNorm2d(channels, affine=True)), 237 | ('gate', SoftGate()), 238 | )) 239 | ) # plays the role of the logit module g in LIP 240 | 241 | def init_layer(self): 242 | self.logit[0].weight.data.fill_(0.0) 243 | 244 | def forward(self, x): 245 | frac = lip2d(x, self.logit(x)) 246 | return frac 247 | 248 | 249 | 250 | def conv3x3(in_planes, out_planes, stride=1): 251 | """3x3 convolution with padding""" 252 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 253 | padding=1, bias=False) 254 | 255 | 256 | 257 | if __name__ == '__main__': 258 | inputs = torch.rand((2, 1, 512, 512)).cuda() 259 | 260 | unet_plus_plus = unet_deepsup_improve(in_channels=1, n_classes=2).cuda() 261 | a,b,c,output = unet_plus_plus(inputs) 262 | print('# parameters:', sum(param.numel() for param in unet_plus_plus.parameters())) 263 | def get_parameter_number(net): 264 | total_num = sum(p.numel() for p in net.parameters()) 265 | trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad) 266 | return {'Total': total_num, 'Trainable': trainable_num} 267 | print('# parameters:', get_parameter_number(unet_plus_plus)) -------------------------------------------------------------------------------- /model/unet_deepsup_stip.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Thu Jun 4 16:11:08 2020 4 | 5 | @author: Administrator 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | #from layers import unetConv2, unetUp 11 | #from utils import init_weights, count_param 12 | import torchsummary 13 | from torch.nn import functional as F 14 | from torch.nn import init 15 | from collections import OrderedDict 16 | from torch.nn import Module, Sequential, Conv2d, ReLU,AdaptiveMaxPool2d, AdaptiveAvgPool2d, \ 17 | NLLLoss, BCELoss, CrossEntropyLoss, AvgPool2d, MaxPool2d, Parameter, Linear, Sigmoid, Softmax, Dropout, Embedding 18 | 19 | COEFF = 12.0 20 | 21 | class unet_deepsup_strip(nn.Module): 22 | def __init__(self, in_channels=1, n_classes=3, feature_scale=2, is_deconv=True, is_batchnorm=True): 23 | super(unet_deepsup_strip, self).__init__() 24 | self.in_channels = in_channels 25 | self.feature_scale = feature_scale 26 | self.is_deconv = is_deconv 27 | self.is_batchnorm = is_batchnorm 28 | 29 | filters = [64, 128, 256, 512, 1024] 30 | filters = [int(x / self.feature_scale) for x in filters] 31 | 32 | # downsampling 33 | self.maxpool_1 = StripPooling(filters[0],(128, 256),nn.BatchNorm2d) 34 | self.maxpool_2 = StripPooling(filters[1],(64, 128),nn.BatchNorm2d) 35 | self.maxpool_3 = StripPooling(filters[2],(32, 64),nn.BatchNorm2d) 36 | self.maxpool_4 = StripPooling(filters[3],(16, 32),nn.BatchNorm2d) 37 | 38 | self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm) 39 | self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm) 40 | self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm) 41 | self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm) 42 | self.center = unetConv2(filters[3], filters[4], self.is_batchnorm) 43 | # upsampling 44 | self.up_concat4 = unetUp(filters[4], filters[3], self.is_deconv) 45 | self.up_concat3 = unetUp(filters[3], filters[2], self.is_deconv) 46 | self.up_concat2 = unetUp(filters[2], filters[1], self.is_deconv) 47 | self.up_concat1 = unetUp(filters[1], filters[0], self.is_deconv) 48 | # final conv (without any concat) 49 | self.final = nn.Conv2d(filters[0], n_classes, 1) 50 | 51 | # deep supervision heads 52 | 53 | self.deepsup_3 = nn.Conv2d(filters[3], n_classes, kernel_size=1, stride=1, padding=0) 54 | self.output_3_up = nn.Upsample(scale_factor=8, mode='bilinear') 55 | self.deepsup_2 = nn.Conv2d(filters[2], n_classes, kernel_size=1, stride=1, padding=0) 56 | self.output_2_up = nn.Upsample(scale_factor=4, mode='bilinear') 57 | self.deepsup_1 = nn.Conv2d(filters[1], n_classes, kernel_size=1, stride=1, padding=0) 58 | self.output_1_up = nn.Upsample(scale_factor=2, mode='bilinear') 59 | 60 | 61 | def forward(self, inputs): 62 | conv1 = self.conv1(inputs) # 16*512*512 63 | maxpool1 = self.maxpool_1(conv1) # 16*256*256 64 | 65 | conv2 = self.conv2(maxpool1) # 32*256*256 66 | maxpool2 = self.maxpool_2(conv2) # 32*128*128 67 | 68 | conv3 = self.conv3(maxpool2) # 64*128*128 69 | maxpool3 = self.maxpool_3(conv3) # 64*64*64 70 | 71 | conv4 = self.conv4(maxpool3) # 128*64*64 72 | maxpool4 = self.maxpool_4(conv4) # 128*32*32 73 | 74 | center = self.center(maxpool4) # 256*32*32 75 | 76 | up4 = self.up_concat4(center,conv4) # 128*64*64 77 | up4_deep = self.deepsup_3(up4) 78 | up4_deep = self.output_3_up(up4_deep) 79 | 80 | up3 = self.up_concat3(up4,conv3) # 64*128*128 81 | up3_deep = self.deepsup_2(up3) 82 | up3_deep = self.output_2_up(up3_deep) 83 | 84 | up2 = self.up_concat2(up3,conv2) # 32*256*256 85 | up2_deep = self.deepsup_1(up2) 86 | up2_deep = self.output_1_up(up2_deep) 87 | 88 | up1 = self.up_concat1(up2,conv1) # 16*512*512 89 | 90 | 91 | final = self.final(up1) 92 | # final = F.softmax(final, dim=1) # softmax over the class dimension 93 | up4_deep = F.log_softmax(up4_deep, dim=1) # log-softmax of each deep-supervision head's own logits 94 | up3_deep = F.log_softmax(up3_deep, dim=1) 95 | up2_deep = F.log_softmax(up2_deep, dim=1) 96 | final = F.log_softmax(final, dim=1) 97 | 98 | return up4_deep,up3_deep,up2_deep,final 99 | 100 | class unetConv2(nn.Module): 101 | def __init__(self, in_size, out_size, is_batchnorm, n=2, ks=3, stride=1, padding=1): 102 | super(unetConv2, self).__init__() 103 | self.n = n 104 | self.ks = ks 105 | self.stride = stride 106 | self.padding = padding 107 | s = stride 108 | p = padding 109 | if is_batchnorm: 110 | for i in range(1, n+1): 111 | conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p), 112 | nn.BatchNorm2d(out_size), 113 | nn.ReLU(inplace=True),) 114 | setattr(self, 'conv%d'%i, conv) 115 | in_size = out_size 116 | 117 | else: 118 | for i in range(1, n+1): 119 | conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p), 120 | nn.ReLU(inplace=True),) 121 | setattr(self, 'conv%d'%i, conv) 122 | in_size = out_size 123 | 124 | # initialise the blocks 125 | for m in self.children(): 126 | init_weights(m, init_type='kaiming') 127 | 128 | def forward(self, inputs): 129 | x = inputs 130 | for i in range(1, self.n+1): 131 | conv = getattr(self, 'conv%d'%i) 132 | x = conv(x) 133 | 134 | return x 135 | 136 | class unetUp(nn.Module): 137 | def __init__(self, in_size, out_size, is_deconv, n_concat=2): 138 | super(unetUp, self).__init__() 139 | self.conv = unetConv2(in_size+(n_concat-2)*out_size, out_size, False) 140 | if is_deconv: 141 | self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2, padding=0) 142 | else: 143 | self.up = nn.Sequential( 144 | nn.UpsamplingBilinear2d(scale_factor=2), 145 | nn.Conv2d(in_size, out_size, 1)) 146 | 147 | # initialise the blocks 148 | for m in self.children(): 149 | if m.__class__.__name__.find('unetConv2') != -1: continue # skip this child: unetConv2 blocks initialise their own weights 150 | init_weights(m, init_type='kaiming') 151 | 152 | def forward(self, high_feature, *low_feature): 153 | outputs0 = self.up(high_feature) 154 | for feature in low_feature: 155 | outputs0 = torch.cat([outputs0, feature], 1) 156 | 157 | return self.conv(outputs0) 158 | 159 | def init_weights(net, init_type='normal'): 160 | #print('initialization method [%s]' % init_type) 161 | if init_type == 'kaiming': 162 | net.apply(weights_init_kaiming) 163 | else: 164 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 165 | def weights_init_kaiming(m): 166 | classname = m.__class__.__name__ 167 | if classname.find('Conv') != -1: 168 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 169 | elif classname.find('Linear') != -1: 170 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 171 | elif classname.find('BatchNorm') != -1: 172 | init.normal_(m.weight.data, 1.0, 0.02) 173 | init.constant_(m.bias.data, 0.0) 174 | 175 | 176 | def conv3x3(in_planes, out_planes, stride=1): 177 | """3x3 convolution with padding""" 178 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 179 | padding=1, bias=False) 180 | 181 | 182 | 183 | class StripPooling(nn.Module): 184 | """ 185 | Reference: 186 | """ 187 | def __init__(self, in_channels, pool_size, norm_layer): 188 | super(StripPooling, self).__init__() 189 | ### strip pooling implemented with AdaptiveAvgPool2d 190 | self.pool1 = nn.AdaptiveAvgPool2d(pool_size[0]) 191 | self.pool2 = nn.AdaptiveAvgPool2d(pool_size[1]) 192 | self.pool3 = nn.AdaptiveAvgPool2d((1, None)) 193 | self.pool4 = nn.AdaptiveAvgPool2d((None, 1)) 194 | 195 | inter_channels = int(in_channels/2) 196 | self.conv1_1 = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 1, bias=False), 197 | norm_layer(inter_channels), 198 | nn.ReLU(True)) 199 | self.conv1_2 = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 1, bias=False), 200 | norm_layer(inter_channels), 201 | nn.ReLU(True)) 202 | self.conv2_0 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, 2, 1, bias=False), 203 | norm_layer(inter_channels)) 204 | self.conv2_1 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, 1, 1, bias=False), 205 | norm_layer(inter_channels)) 206 | self.conv2_2 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, 1, 1, bias=False), 207 | norm_layer(inter_channels)) 208 | self.conv2_3 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, (1, 3), 1, (0, 1), bias=False), 209 | norm_layer(inter_channels)) 210 | self.conv2_4 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, (3, 1), 1, (1, 0), bias=False), 211 | norm_layer(inter_channels)) 212 | self.conv2_5 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, 1, 1, bias=False), 213 | norm_layer(inter_channels), 214 | nn.ReLU(True)) 215 | self.conv2_6 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, 1, 1, bias=False), 216 | norm_layer(inter_channels), 217 | nn.ReLU(True)) 218 | self.conv3 = nn.Sequential(nn.Conv2d(inter_channels*2, in_channels, 1, bias=False), 219 | norm_layer(in_channels)) 220 | # bilinear interpolate options 221 | 222 | ## SPM module (strip pooling) 223 | def forward(self, x): 224 | _, _, h, w = x.size() 225 | x1 = self.conv1_1(x) 226 | x2 = self.conv1_2(x) 227 | x2_1 = self.conv2_0(x1) 228 | x2_2 = F.interpolate(self.conv2_1(self.pool1(x1)), (h//2, w//2), mode='bilinear') 229 | x2_3 = F.interpolate(self.conv2_2(self.pool2(x1)), (h//2, w//2), mode='bilinear') 230 | x2_4 = F.interpolate(self.conv2_3(self.pool3(x2)), (h//2, w//2), mode='bilinear') 231 | x2_5 = F.interpolate(self.conv2_4(self.pool4(x2)), (h//2, w//2), mode='bilinear') 232 | x1 = self.conv2_5(F.relu_(x2_1 + x2_2 + x2_3)) 233 | x2 = self.conv2_6(F.relu_(x2_5 + x2_4)) 234 | out = self.conv3(torch.cat([x1, x2], dim=1)) 235 | return F.relu_(out) 236 | 237 | 238 | if __name__ == '__main__': 239 | inputs = torch.rand((2, 1, 256, 512)).cuda() 240 | 241 | unet_plus_plus = unet_deepsup_strip(in_channels=1, n_classes=2).cuda() 242 | a,b,c,output = unet_plus_plus(inputs) 243 | print('# parameters:', sum(param.numel() for param in unet_plus_plus.parameters())) 244 | def get_parameter_number(net): 245 | total_num = sum(p.numel() for p in net.parameters()) 246 | trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad) 247 | return {'Total': total_num, 'Trainable': trainable_num} 248 | print('# parameters:', get_parameter_number(unet_plus_plus)) -------------------------------------------------------------------------------- /model/SCconv_unet_deep.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Mon May 4 15:35:39 2020 4 | 5 | @author: 45780 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | #from layers import unetConv2, unetUp 11 | #from utils import init_weights, count_param 12 | import torchsummary 13 | from torch.nn import functional as F 14 | from torch.nn import init 15 | 16 | class SCconv_UNet_deepsup(nn.Module): 17 | 18 | def __init__(self, in_channels=1, n_classes=4, feature_scale=2, is_deconv=True, is_batchnorm=True): 19 | super(SCconv_UNet_deepsup, self).__init__() 20 | self.in_channels = in_channels 21 | self.feature_scale = feature_scale 22 | self.is_deconv = is_deconv 23 | self.is_batchnorm = is_batchnorm 24 | 25 | 26 | filters = [64, 128, 256, 512, 1024] 27 | filters = [int(x / self.feature_scale) for x in filters] 28 | 29 | # downsampling (note: unetConv1 has no is_batchnorm parameter, so the is_batchnorm argument below binds to n; since True == 1, each stage gets a single SCnet block) 30 | self.maxpool = nn.MaxPool2d(kernel_size=2) 31 | self.conv1 = unetConv1(self.in_channels, filters[0], self.is_batchnorm) 32 | self.conv2 = unetConv1(filters[0], filters[1], self.is_batchnorm) 33 | self.conv3 = unetConv1(filters[1], filters[2], self.is_batchnorm) 34 | self.conv4 = unetConv1(filters[2], filters[3], self.is_batchnorm) 35 | self.center = unetConv1(filters[3], filters[4], self.is_batchnorm) 36 | # upsampling 37 | self.up_concat4 = unetUp(filters[4], filters[3], self.is_deconv) 38 | self.up_concat3 = unetUp(filters[3], filters[2], self.is_deconv) 39 | self.up_concat2 = unetUp(filters[2], filters[1], self.is_deconv) 40 | self.up_concat1 = unetUp(filters[1], filters[0], self.is_deconv) 41 | # final conv (without any concat) 42 | self.final = nn.Conv2d(filters[0], n_classes, 1) 43 | 44 | # deep supervision heads 45 | 46 | self.deepsup_3 = nn.Conv2d(filters[3], n_classes, kernel_size=1, stride=1, padding=0) 47 | self.output_3_up = nn.Upsample(scale_factor=8, mode='bilinear') 48 | self.deepsup_2 = nn.Conv2d(filters[2], n_classes, kernel_size=1, stride=1, padding=0) 49 | self.output_2_up = nn.Upsample(scale_factor=4, mode='bilinear') 50 | self.deepsup_1 = nn.Conv2d(filters[1], n_classes, kernel_size=1, stride=1, padding=0) 51 | self.output_1_up = nn.Upsample(scale_factor=2, mode='bilinear') 52 | 53 | 54 | def forward(self, inputs): 55 | conv1 = self.conv1(inputs) # 16*512*512 56 | maxpool1 = self.maxpool(conv1) # 16*256*256 57 | 58 | conv2 = self.conv2(maxpool1) # 32*256*256 59 | maxpool2 = self.maxpool(conv2) # 32*128*128 60 | 61 | conv3 = self.conv3(maxpool2) # 64*128*128 62 | maxpool3 = self.maxpool(conv3) # 64*64*64 63 | 64 | conv4 = self.conv4(maxpool3) # 128*64*64 65 | maxpool4 = self.maxpool(conv4) # 128*32*32 66 | 67 | center = self.center(maxpool4) # 256*32*32 68 | 69 | up4 = self.up_concat4(center,conv4) # 128*64*64 70 | up4_deep = self.deepsup_3(up4) 71 | up4_deep = self.output_3_up(up4_deep) 72 | 73 | up3 = self.up_concat3(up4,conv3) # 64*128*128 74 | up3_deep = self.deepsup_2(up3) 75 | up3_deep = self.output_2_up(up3_deep) 76 | 77 | up2 = self.up_concat2(up3,conv2) # 32*256*256 78 | up2_deep = self.deepsup_1(up2) 79 | up2_deep = self.output_1_up(up2_deep) 80 | 81 | up1 = self.up_concat1(up2,conv1) # 16*512*512 82 | 83 | 84 | final = self.final(up1) 85 | # final = F.softmax(final, dim=1) # softmax over the class dimension 86 | up4_deep = F.log_softmax(up4_deep, dim=1) # log-softmax of each deep-supervision head's own logits 87 | up3_deep = F.log_softmax(up3_deep, dim=1) 88 | up2_deep = F.log_softmax(up2_deep, dim=1) 89 | final = F.log_softmax(final, dim=1) 90 | 91 | return up4_deep,up3_deep,up2_deep,final 92 | 93 | class unetConv1(nn.Module): # SCnet blocks are used in the encoder only, not in the decoder 94 | def __init__(self, in_size, out_size, n=2, ks=3, stride=1, padding=1): 95 | super(unetConv1, self).__init__() 96 | self.n = n 97 | self.ks = ks 98 | self.stride = stride 99 | self.padding = padding 100 | 101 | for i in range(1, n+1): 102 | conv = SCnet(in_size, out_size) 103 | setattr(self, 'conv%d'%i, conv) 104 | in_size = out_size 105 | 106 | ## initialise the blocks 107 | # for m in self.children(): 108 | # init_weights(m, init_type='kaiming') 109 | 110 | def forward(self, inputs): 111 | x = inputs 112 | for i in range(1, self.n+1): 113 | conv = getattr(self, 'conv%d'%i) 114 | x = conv(x) 115 | 116 | return x 117 | 118 | 119 | class unetConv2(nn.Module): 120 | def __init__(self, in_size, out_size, is_batchnorm, n=2, ks=3, stride=1, padding=1): 121 | super(unetConv2, self).__init__() 122 | self.n = n 123 | self.ks = ks 124 | self.stride = stride 125 | self.padding = padding 126 | s = stride 127 | p = padding 128 | if is_batchnorm: 129 | for i in range(1, n+1): 130 | conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p), 131 | nn.BatchNorm2d(out_size), 132 | nn.ReLU(inplace=True),) 133 | setattr(self, 'conv%d'%i, conv) 134 | in_size = out_size 135 | 136 | else: 137 | for i in range(1, n+1): 138 | conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p), 139 | nn.ReLU(inplace=True),) 140 | setattr(self, 'conv%d'%i, conv) 141 | in_size = out_size 142 | 143 | # initialise the blocks 144 | for m in self.children(): 145 | init_weights(m, init_type='kaiming') 146 | 147 | def forward(self, inputs): 148 | x = inputs 149 | for i in range(1, self.n+1): 150 | conv = getattr(self, 'conv%d'%i) 151 | x = conv(x) 152 | 153 | return x 154 | 155 | class unetUp(nn.Module): 156 | def __init__(self, in_size, out_size, is_deconv, n_concat=2): 157 | super(unetUp, self).__init__() 158 | self.conv = unetConv2(in_size+(n_concat-2)*out_size, out_size, False) 159 | if is_deconv: 160 | self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2, padding=0) 161 | else: 162 | self.up = nn.Sequential( 163 | nn.UpsamplingBilinear2d(scale_factor=2), 164 | nn.Conv2d(in_size, out_size, 1)) 165 | 166 | # initialise the blocks 167 | for m in self.children(): 168 | if m.__class__.__name__.find('unetConv2') != -1: continue # skip this child: unetConv2 blocks initialise their own weights 169 | init_weights(m, init_type='kaiming') 170 | 171 | def forward(self, high_feature, *low_feature): 172 | outputs0 = self.up(high_feature) 173 | for feature in low_feature: 174 | outputs0 = torch.cat([outputs0, feature], 1) 175 | 176 | return self.conv(outputs0) 177 | 178 | def init_weights(net, init_type='normal'): 179 | #print('initialization method [%s]' % init_type) 180 | if init_type == 'kaiming': 181 | net.apply(weights_init_kaiming) 182 | else: 183 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 184 | def weights_init_kaiming(m): 185 | classname = m.__class__.__name__ 186 | if classname.find('Conv') != -1: 187 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 188 | elif classname.find('Linear') != -1: 189 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 190 | elif classname.find('BatchNorm') != -1: 191 | init.normal_(m.weight.data, 1.0, 0.02) 192 | init.constant_(m.bias.data, 0.0) 193 | 194 | def conv3x3(in_planes, out_planes, stride=1, groups=1): 195 | """3x3 convolution with padding""" 196 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 197 | padding=1, groups=groups, bias=False) 198 | 199 | 200 | def conv1x1(in_planes, out_planes, stride=1): 201 | """1x1 convolution""" 202 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 203 | 204 | 205 | class SCConv(nn.Module): # the branch drawn inside the dashed box in the SCNet paper figure 206 | def __init__(self, planes, stride, pooling_r): 207 | super(SCConv, self).__init__() 208 | self.k2 = nn.Sequential( 209 | nn.AvgPool2d(kernel_size=pooling_r, stride=pooling_r), 210 | conv3x3(planes, planes), 211 | nn.BatchNorm2d(planes), 212 | ) 213 | self.k3 = nn.Sequential( 214 | conv3x3(planes, planes), 215 | nn.BatchNorm2d(planes), 216 | ) 217 | self.k4 = nn.Sequential( 218 | conv3x3(planes, planes, stride), 219 | nn.BatchNorm2d(planes), 220 | nn.ReLU(inplace=True), 221 | ) 222 | # initialise the blocks 223 | for m in self.children(): 224 | init_weights(m, init_type='kaiming') 225 | 226 | def forward(self, x): 227 | identity = x 228 | 229 | out = torch.sigmoid(torch.add(identity, F.interpolate(self.k2(x), identity.size()[2:]))) # sigmoid(identity + k2) 230 | out = torch.mul(self.k3(x), out) # k3 * sigmoid(identity + k2) 231 | out = self.k4(out) # k4 232 | 233 | return out 234 | 235 | 236 | class SCnet(nn.Module): 237 | expansion = 4 238 | pooling_r = 4 # down-sampling rate of the avg pooling layer in the K3 path of SC-Conv.
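# summary of the SCnet block defined below: the input is split into two half-width branches,
#   x -> conv1_a -> bn -> relu -> k1 (3x3 conv)  --+
#   x -> conv1_b -> bn -> relu -> SCConv          --+--> concat -> bn3 -> relu
# SCConv gates k3(x) with sigmoid(x + upsample(k2(x))), so every position is re-weighted by
# context average-pooled over a pooling_r x pooling_r neighbourhood (self-calibrated convolution)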
239 | 240 | def __init__(self, inplanes, planes, stride=1): 241 | super(SCnet, self).__init__() 242 | planes = int(planes / 2) 243 | 244 | self.conv1_a = conv1x1(inplanes, planes) 245 | self.bn1_a = nn.BatchNorm2d(planes) 246 | 247 | self.k1 = nn.Sequential( 248 | conv3x3(planes, planes, stride), 249 | nn.BatchNorm2d(planes), 250 | nn.ReLU(inplace=True), 251 | ) 252 | 253 | self.conv1_b = conv1x1(inplanes, planes) 254 | self.bn1_b = nn.BatchNorm2d(planes) 255 | 256 | self.scconv = SCConv(planes, stride, self.pooling_r) 257 | 258 | self.bn3 = nn.BatchNorm2d(planes * 2 ) 259 | self.relu = nn.ReLU(inplace=True) 260 | self.stride = stride 261 | # initialise the blocks 262 | for m in self.children(): 263 | if m.__class__.__name__.find('SCConv') != -1: continue # skip this child: SCConv blocks initialise their own weights 264 | init_weights(m, init_type='kaiming') 265 | 266 | def forward(self, x): 267 | 268 | out_a= self.conv1_a(x) 269 | out_a = self.bn1_a(out_a) 270 | out_a = self.relu(out_a) 271 | 272 | out_a = self.k1(out_a) 273 | 274 | 275 | out_b = self.conv1_b(x) 276 | out_b = self.bn1_b(out_b) 277 | out_b = self.relu(out_b) 278 | 279 | out_b = self.scconv(out_b) 280 | 281 | out = torch.cat([out_a, out_b], dim=1) 282 | out = self.bn3(out) 283 | 284 | out = self.relu(out) 285 | 286 | return out 287 | 288 | if __name__ == '__main__': 289 | inputs = torch.rand((2, 1, 512, 512)).cuda() 290 | 291 | unet_plus_plus = SCconv_UNet_deepsup(in_channels=1, n_classes=2).cuda() 292 | a,b,c,output = unet_plus_plus(inputs) 293 | print('# parameters:', sum(param.numel() for param in unet_plus_plus.parameters())) 294 | def get_parameter_number(net): 295 | total_num = sum(p.numel() for p in net.parameters()) 296 | trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad) 297 | return {'Total': total_num, 'Trainable': trainable_num} 298 | print('# parameters:', get_parameter_number(unet_plus_plus)) -------------------------------------------------------------------------------- /model/unet_deep_usecarafe.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sun Jul 12 13:56:19 2020 4 | 5 | @author: Administrator 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | #from layers import unetConv2, unetUp 11 | #from utils import init_weights, count_param 12 | import torchsummary 13 | from torch.nn import functional as F 14 | from torch.nn import init 15 | 16 | class UNet_deepsupusecarafe(nn.Module): 17 | 18 | def __init__(self, in_channels=1, n_classes=3, feature_scale=2, is_deconv=True, is_batchnorm=True): 19 | super(UNet_deepsupusecarafe, self).__init__() 20 | self.in_channels = in_channels 21 | self.feature_scale = feature_scale 22 | self.is_deconv = is_deconv 23 | self.is_batchnorm = is_batchnorm 24 | 25 | 26 | filters = [64, 128, 256, 512, 1024] 27 | filters = [int(x / self.feature_scale) for x in filters] 28 | 29 | # downsampling 30 | self.maxpool = nn.MaxPool2d(kernel_size=2) 31 | self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm) 32 | self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm) 33 | self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm) 34 | self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm) 35 | self.center = unetConv2(filters[3], filters[4], self.is_batchnorm) 36 | # upsampling 37 | self.up_concat4 = unetUp(filters[4], filters[3], self.is_deconv) 38 | self.up_concat3 = unetUp(filters[3], filters[2], self.is_deconv) 39 | self.up_concat2 = unetUp(filters[2], filters[1], self.is_deconv) 40 | self.up_concat1 = unetUp(filters[1], filters[0], self.is_deconv) 41 | # final conv (without any concat) 42 | self.final = nn.Conv2d(filters[0], n_classes, 1) 43 | 44 | # deep supervision heads, each preceded by a CARAFE 2x upsample 45 | self.up_3 = CARAFE_3(filters[3]) 46 | self.deepsup_3 = nn.Conv2d(filters[3], n_classes, kernel_size=1, stride=1, padding=0) 47 | self.up_2 = CARAFE_3(filters[2]) 48 | self.deepsup_2 = nn.Conv2d(filters[2], n_classes, kernel_size=1, stride=1, padding=0) 49 | self.up_1 = CARAFE_2(filters[1]) 50 | self.deepsup_1 = nn.Conv2d(filters[1], n_classes, kernel_size=1, stride=1, padding=0) 51 | 52 | 53 | def forward(self, inputs): 54 | conv1 = self.conv1(inputs) # 16*512*512 55 | maxpool1 = self.maxpool(conv1) # 16*256*256 56 | 57 | conv2 = self.conv2(maxpool1) # 32*256*256 58 | maxpool2 = self.maxpool(conv2) # 32*128*128 59 | 60 | conv3 = self.conv3(maxpool2) # 64*128*128 61 | maxpool3 = self.maxpool(conv3) # 64*64*64 62 | 63 | conv4 = self.conv4(maxpool3) # 128*64*64 64 | maxpool4 = self.maxpool(conv4) # 128*32*32 65 | 66 | center = self.center(maxpool4) # 256*32*32 67 | 68 | up4 = self.up_concat4(center,conv4) # 128*64*64 69 | up4_1 = self.up_3(up4) 70 | up4_deep = self.deepsup_3(up4_1) 71 | 72 | up3 = self.up_concat3(up4,conv3) # 64*128*128 73 | up3_1 = self.up_2(up3) 74 | up3_deep = self.deepsup_2(up3_1) 75 | 76 | up2 = self.up_concat2(up3,conv2) # 32*256*256 77 | up2_1 = self.up_1(up2) 78 | up2_deep = self.deepsup_1(up2_1) 79 | 80 | up1 = self.up_concat1(up2,conv1) # 16*512*512 81 | 82 | 83 | final = self.final(up1) 84 | # final = F.softmax(final, dim=1) # softmax over the class dimension (the CARAFE heads above upsample only 2x, so each aux output keeps its decoder-stage resolution) 85 | up4_deep = F.log_softmax(up4_deep, dim=1) # log-softmax of each deep-supervision head's own logits 86 | up3_deep = F.log_softmax(up3_deep, dim=1) 87 | up2_deep = F.log_softmax(up2_deep, dim=1) 88 | final = F.log_softmax(final, dim=1) 89 | 90 | return up4_deep,up3_deep,up2_deep,final 91 | 92 | class unetConv2(nn.Module): 93 | def __init__(self, in_size, out_size, is_batchnorm, n=2, ks=3, stride=1, padding=1): 94 | super(unetConv2, self).__init__() 95 | self.n = n 96 | self.ks = ks 97 | self.stride = stride 98 | self.padding = padding 99 | s = stride 100 | p = padding 101 | if is_batchnorm: 102 | for i in range(1, n+1): 103 | conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p), 104 | nn.BatchNorm2d(out_size), 105 | nn.ReLU(inplace=True),) 106 | setattr(self, 'conv%d'%i, conv) 107 | in_size = out_size 108 | 109 | else: 110 | for i in range(1, n+1): 111 | conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p), 112 | nn.ReLU(inplace=True),) 113 | setattr(self, 'conv%d'%i, conv) 114 | in_size = out_size 115 | 116 | # initialise the blocks 117 | for m in self.children(): 118 | init_weights(m, init_type='kaiming') 119 | 120 | def forward(self, inputs): 121 | x = inputs 122 | for i in range(1, self.n+1): 123 | conv = getattr(self, 'conv%d'%i) 124 | x = conv(x) 125 | 126 | return x 127 | 128 | class unetUp_2(nn.Module): 129 | def __init__(self, in_size, out_size, is_deconv, n_concat=2): 130 | super(unetUp_2, self).__init__() 131 | self.conv = unetConv2(in_size+(n_concat-2)*out_size, out_size, False) 132 | 133 | self.up = CARAFE_3(in_size) 134 | self.conv1 = nn.Conv2d(in_size, out_size, 1) 135 | 136 | # self.up = Carafe(in_size) 137 | 138 | # initialise the blocks 139 | for m in self.children(): 140 | if m.__class__.__name__.find('unetConv2') != -1: continue # skip this child: unetConv2 blocks initialise their own weights 141 | init_weights(m, init_type='kaiming') 142 | 143 | def forward(self, high_feature, *low_feature): 144 | outputs0 = self.up(high_feature) 145 | outputs0 = self.conv1(outputs0) 146 | 147 | for feature in low_feature: 148 | outputs0 = torch.cat([outputs0, feature], 1) 149 | 150 | return self.conv(outputs0) 151 | 152 | class unetUp(nn.Module): 153 | def __init__(self, in_size, out_size, is_deconv, n_concat=2): 154 | super(unetUp, self).__init__() 155 | self.conv = unetConv2(in_size+(n_concat-2)*out_size, out_size, False) 156 | if is_deconv: 157 | self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2, padding=0) 158 | else: 159 | self.up = nn.Sequential( 160 | nn.UpsamplingBilinear2d(scale_factor=2), 161 | nn.Conv2d(in_size, out_size, 1)) 162 | 163 | # initialise the blocks 164 | for m in self.children(): 165 | if m.__class__.__name__.find('unetConv2') != -1: continue # skip this child: unetConv2 blocks initialise their own weights 166 | init_weights(m, init_type='kaiming') 167 | 168 | def forward(self, high_feature, *low_feature): 169 | outputs0 = self.up(high_feature) 170 | for feature in low_feature: 171 | outputs0 = torch.cat([outputs0, feature], 1) 172 | 173 | return self.conv(outputs0) 174 | 175 | def init_weights(net, init_type='normal'): 176 | #print('initialization method [%s]' % init_type) 177 | if init_type == 'kaiming': 178 | net.apply(weights_init_kaiming) 179 | else: 180 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 181 | def weights_init_kaiming(m): 182 | classname = m.__class__.__name__ 183 | if classname.find('Conv') != -1: 184 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 185 | elif classname.find('Linear') != -1: 186 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 187 | elif classname.find('BatchNorm') != -1: 188 | init.normal_(m.weight.data, 1.0, 0.02) 189 | init.constant_(m.bias.data, 0.0) 190 | 191 | 192 | class CARAFE_3(nn.Module): 193 | def __init__(self, c, c_mid=64, scale=2, k_up=5, k_enc=3): 194 | super(CARAFE_3, self).__init__() 195 | self.scale = scale 196 | 197 | self.comp = nn.Conv2d(c, c_mid,kernel_size=1, stride=1, 198 | padding=0, dilation=1) # channel compressor: 1x1 conv to shrink the channel count 199 | self.bn1 = nn.BatchNorm2d(c_mid) 200 | self.relu = nn.ReLU(inplace=True) 201 | 202 | self.enc = nn.Conv2d(c_mid, (scale * k_up) ** 2, kernel_size=k_enc, 203 | stride=1, padding=k_enc // 2, dilation=1) 204 | self.bn2 = nn.BatchNorm2d((scale * k_up) ** 2) 205 | 206 | self.pix_shf = nn.PixelShuffle(scale) 207 | 208 | self.upsmp = nn.Upsample(scale_factor=scale, mode='bilinear') 209 | self.unfold = nn.Unfold(kernel_size=k_up, dilation=scale, 210 | padding=k_up // 2 * scale) 211 | 212 | # initialise the blocks 213 | for m in self.children(): 214 | init_weights(m, init_type='kaiming') 215 | 216 | 217 | def forward(self, X): 218 | b, c, h, w = X.size() 219 | h_, w_ = h * self.scale, w * self.scale 220 | 221 | W = self.comp(X) # b * m * h * w 222 | W = self.bn1(W) 223 | W = self.relu(W) 224 | 225 | W = self.enc(W) # b * 100 * h * w 226 | W = self.bn2(W) 227 | W = self.pix_shf(W) # b * 25 * h_ * w_ 228 | W = F.softmax(W, dim=1) # b * 25 * h_ * w_ 229 | 230 | X = self.upsmp(X) # b * c * h_ * w_ 231 | X = self.unfold(X) # b * 25c * h_ * w_ 232 | X = X.view(b, c, -1, h_, w_) # b * 25 * c * h_ * w_ 233 | 234 | X = torch.einsum('bkhw,bckhw->bchw', [W, X]) # b * c * h_ * w_ 235 | return X 236 | 237 | class CARAFE_2(nn.Module): 238 | def __init__(self, c, c_mid=32, scale=2, k_up=5, k_enc=3): 239 | super(CARAFE_2, self).__init__() 240 | self.scale = scale 241 | 242 | self.comp = nn.Conv2d(c, c_mid,kernel_size=1, stride=1, 243 | padding=0, dilation=1) # channel compressor: 1x1 conv to shrink the channel count 244 | self.bn1 = nn.BatchNorm2d(c_mid) 245 | self.relu =
nn.ReLU(inplace=True) 246 | 247 | self.enc = nn.Conv2d(c_mid, (scale * k_up) ** 2, kernel_size=k_enc, 248 | stride=1, padding=k_enc // 2, dilation=1) 249 | self.bn2 = nn.BatchNorm2d((scale * k_up) ** 2) 250 | 251 | self.pix_shf = nn.PixelShuffle(scale) 252 | 253 | self.upsmp = nn.Upsample(scale_factor=scale, mode='bilinear') 254 | self.unfold = nn.Unfold(kernel_size=k_up, dilation=scale, 255 | padding=k_up // 2 * scale) 256 | 257 | # initialise the blocks 258 | for m in self.children(): 259 | init_weights(m, init_type='kaiming') 260 | 261 | 262 | def forward(self, X): 263 | b, c, h, w = X.size() 264 | h_, w_ = h * self.scale, w * self.scale 265 | 266 | W = self.comp(X) # b * m * h * w 267 | W = self.bn1(W) 268 | W = self.relu(W) 269 | 270 | W = self.enc(W) # b * 100 * h * w 271 | W = self.bn2(W) 272 | W = self.pix_shf(W) # b * 25 * h_ * w_ 273 | W = F.softmax(W, dim=1) # b * 25 * h_ * w_ 274 | 275 | X = self.upsmp(X) # b * c * h_ * w_ 276 | X = self.unfold(X) # b * 25c * h_ * w_ 277 | X = X.view(b, c, -1, h_, w_) # b * 25 * c * h_ * w_ 278 | 279 | X = torch.einsum('bkhw,bckhw->bchw', [W, X]) # b * c * h_ * w_ 280 | return X 281 | 282 | if __name__ == '__main__': 283 | inputs = torch.rand((2, 1, 256, 512)).cuda() 284 | 285 | unet_plus_plus = UNet_deepsupusecarafe(in_channels=1, n_classes=2).cuda() 286 | a,b,c,output = unet_plus_plus(inputs) 287 | print('# parameters:', sum(param.numel() for param in unet_plus_plus.parameters())) 288 | def get_parameter_number(net): 289 | total_num = sum(p.numel() for p in net.parameters()) 290 | trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad) 291 | return {'Total': total_num, 'Trainable': trainable_num} 292 | print('# parameters:', get_parameter_number(unet_plus_plus)) -------------------------------------------------------------------------------- /model/DualUnet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Thu Dec 26 19:16:06 2019 4 | 5 | @author: Administrator 6 | """ 7 | 8 | #import _init_paths 9 | import torch 10 | import torch.nn as nn 11 | #from layers import unetConv2, unetUp 12 | #from utils import init_weights, count_param 13 | import torchsummary 14 | from torch.nn import functional as F 15 | from torch.nn import init 16 | class Dual_Unet(nn.Module): 17 | 18 | def __init__(self, in_channels=1, n_classes=4, feature_scale=2, is_deconv=True, is_batchnorm=True): 19 | super(Dual_Unet, self).__init__() 20 | self.in_channels = in_channels 21 | self.feature_scale = feature_scale 22 | self.is_deconv = is_deconv 23 | self.is_batchnorm = is_batchnorm 24 | 25 | 26 | filters = [64, 128, 256, 512, 1024, 2048] 27 | filters = [int(x / self.feature_scale) for x in filters] 28 | 29 | 30 | # downsampling 31 | self.maxpool = nn.MaxPool2d(kernel_size=2) 32 | self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm) 33 | self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm) 34 | self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm) 35 | self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm) 36 | 37 | self.Context_1 = Context_Path(self.in_channels, filters[0]) 38 | self.Context_2 = Context_Path(filters[0], filters[1]) 39 | self.Context_3 = Context_Path(filters[1], filters[2]) 40 | self.Context_4 = Context_Path(filters[2], filters[3]) 41 | 42 | self.Attention_1 = Attention_Skip(filters[0]) 43 | self.Attention_2 = Attention_Skip(filters[1]) 44 | self.Attention_3 = Attention_Skip(filters[2]) 45 | 
self.Attention_4 = Attention_Skip(filters[3]) 46 | 47 | self.Feature_Fusion = Feature_Fusion(filters[4],filters[4]) 48 | 49 | # upsampling 50 | self.up_concat4 = unetUp(filters[4], filters[3], self.is_deconv) 51 | self.up_concat3 = unetUp(filters[3], filters[2], self.is_deconv) 52 | self.up_concat2 = unetUp(filters[2], filters[1], self.is_deconv) 53 | 54 | self.final = nn.Conv2d(filters[1], n_classes, 1) 55 | 56 | self.conv1_1 = nn.Conv2d(filters[5], filters[4], 1, 1) 57 | 58 | # 1*1 conv (without any concat) 59 | # self.final = Multiscale_Predict(filters[2], n_classes, 1) 60 | 61 | # initialise weights 62 | for m in self.modules(): 63 | if isinstance(m, nn.Conv2d): 64 | init_weights(m, init_type='kaiming') 65 | elif isinstance(m, nn.BatchNorm2d): 66 | init_weights(m, init_type='kaiming') 67 | 68 | def forward(self, inputs): 69 | conv1 = self.conv1(inputs) # 64*512*512 70 | maxpool1 = self.maxpool(conv1) # 64*256*256 71 | conv2 = self.conv2(maxpool1) # 128*256*256 72 | maxpool2 = self.maxpool(conv2) # 128*128*128 73 | conv3 = self.conv3(maxpool2) # 256*128*128 74 | maxpool3 = self.maxpool(conv3) # 256*64*64 75 | conv4 = self.conv4(maxpool3) # 512*64*64 76 | 77 | conx1 = self.Context_1(inputs)#64 78 | conx1_1 = self.maxpool(conx1) 79 | conx2 = self.Context_2(conx1_1)#128 80 | conx2_1 = self.maxpool(conx2) 81 | conx3 = self.Context_3(conx2_1)#256 82 | conx3_1 = self.maxpool(conx3) 83 | conx4 = self.Context_4(conx3_1)#512 84 | 85 | Attention1 = self.Attention_1(conv1,conx1) #128 86 | Attention2 = self.Attention_2(conv2,conx2) #256 87 | Attention3 = self.Attention_3(conv3,conx3) #512 88 | Attention4 = self.Attention_4(conv4,conx4) #1024 89 | 90 | Feature_Fusion = self.Feature_Fusion(conv4,conx4)#1024 91 | 92 | lay1 = torch.cat([Feature_Fusion, Attention4], dim=1)#2048 93 | lay1 = self.conv1_1(lay1)#1024 94 | 95 | lay2 = self.up_concat4(lay1,Attention3) # 512 96 | lay3 = self.up_concat3(lay2,Attention2) # 256 97 | lay4 = self.up_concat2(lay3,Attention1) # 128 98 | final = self.final(lay4) 99 | 100 | # final = self.final(lay3,lay4) 101 | # final=F.softmax(final,dim=1)#对每一行使用Softmax 102 | final=F.log_softmax(final,dim=1) 103 | 104 | return final 105 | 106 | 107 | 108 | 109 | class unetConv2(nn.Module): 110 | def __init__(self, in_size, out_size, is_batchnorm, n=2, ks=3, stride=1, padding=1): 111 | super(unetConv2, self).__init__() 112 | self.n = n 113 | self.ks = ks 114 | self.stride = stride 115 | self.padding = padding 116 | s = stride 117 | p = padding 118 | if is_batchnorm: 119 | for i in range(1, n+1): 120 | conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p), 121 | nn.BatchNorm2d(out_size), 122 | nn.ReLU(inplace=True),) 123 | setattr(self, 'conv%d'%i, conv) 124 | in_size = out_size 125 | 126 | else: 127 | for i in range(1, n+1): 128 | conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p), 129 | nn.ReLU(inplace=True),) 130 | setattr(self, 'conv%d'%i, conv) 131 | in_size = out_size 132 | 133 | # initialise the blocks 134 | for m in self.children(): 135 | init_weights(m, init_type='kaiming') 136 | 137 | def forward(self, inputs): 138 | x = inputs 139 | for i in range(1, self.n+1): 140 | conv = getattr(self, 'conv%d'%i) 141 | x = conv(x) 142 | 143 | return x 144 | 145 | class unetUp(nn.Module): 146 | def __init__(self, in_size, out_size, is_deconv, n_concat=2): 147 | super(unetUp, self).__init__() 148 | self.conv = unetConv2(in_size+(n_concat-2)*out_size, out_size, False) 149 | if is_deconv: 150 | self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2, padding=0) 
151 | else: 152 | self.up = nn.Sequential( 153 | nn.UpsamplingBilinear2d(scale_factor=2), 154 | nn.Conv2d(in_size, out_size, 1)) 155 | 156 | # initialise the blocks 157 | for m in self.children(): 158 | if m.__class__.__name__.find('unetConv2') != -1: continue 159 | init_weights(m, init_type='kaiming') 160 | 161 | def forward(self, high_feature, *low_feature): 162 | outputs0 = self.up(high_feature) 163 | for feature in low_feature: 164 | outputs0 = torch.cat([outputs0, feature], 1) 165 | 166 | return self.conv(outputs0) 167 | 168 | 169 | 170 | def init_weights(net, init_type='normal'): 171 | #print('initialization method [%s]' % init_type) 172 | if init_type == 'kaiming': 173 | net.apply(weights_init_kaiming) 174 | else: 175 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 176 | def weights_init_kaiming(m): 177 | classname = m.__class__.__name__ 178 | #print(classname) 179 | if classname.find('Conv') != -1: 180 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 181 | elif classname.find('Linear') != -1: 182 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 183 | elif classname.find('BatchNorm') != -1: 184 | init.normal_(m.weight.data, 1.0, 0.02) 185 | init.constant_(m.bias.data, 0.0) 186 | 187 | 188 | class Feature_Fusion(nn.Module):#自写的 189 | def __init__(self, inplanes, planes,r = 16, stride=1, downsample=None): 190 | super(Feature_Fusion, self).__init__() 191 | #中间是se模块的部分,其他是resnet正常部分 192 | self.conv1 = conv3x3(inplanes, planes, stride) 193 | self.bn1 = nn.BatchNorm2d(planes) 194 | self.relu = nn.ReLU(inplace=True) 195 | self.global_pool = nn.AdaptiveAvgPool2d(1) 196 | self.conv_down = nn.Conv2d( 197 | planes , planes // r, kernel_size=1, bias=False) 198 | self.conv_up = nn.Conv2d( 199 | planes // r, planes , kernel_size=1, bias=False) 200 | self.sig = nn.Sigmoid() 201 | self.bn1 = nn.BatchNorm2d(planes) 202 | 203 | def forward(self, x, y): 204 | 205 | out = torch.cat([x,y], dim=1) 206 | out = self.conv1(out) 207 | out = self.bn1(out) 208 | out = self.relu(out) 209 | input = out 210 | 211 | out1 = self.global_pool(out) 212 | out1 = self.conv_down(out1) 213 | out1 = self.relu(out1) 214 | out1 = self.conv_up(out1) 215 | out1 = self.sig(out1) 216 | 217 | res_1 = out1 * input 218 | res = res_1 + input 219 | 220 | return res 221 | 222 | 223 | class Attention_Skip(nn.Module):#自写的 224 | def __init__(self, planes,r = 16, stride=1, downsample=None): 225 | super(Attention_Skip, self).__init__() 226 | #中间是se模块的部分,其他是resnet正常部分 227 | self.relu = nn.ReLU(inplace=True) 228 | self.global_pool = nn.AdaptiveAvgPool2d(1) 229 | self.conv_down = nn.Conv2d( 230 | 2*planes , 2*planes // r, kernel_size=1, bias=False) 231 | self.conv_up = nn.Conv2d( 232 | 2*planes // r, 2*planes , kernel_size=1, bias=False) 233 | self.sig = nn.Sigmoid() 234 | self.bn1 = nn.BatchNorm2d(2*planes) 235 | 236 | def forward(self, x, y): 237 | 238 | input = torch.cat([x,y], dim=1) 239 | 240 | out1 = self.global_pool(input) 241 | out1 = self.conv_down(out1) 242 | out1 = self.relu(out1) 243 | out1 = self.conv_up(out1) 244 | out1 = self.sig(out1) 245 | out1 = self.bn1(out1) 246 | 247 | res = out1 * input 248 | 249 | return res 250 | 251 | 252 | class Context_Path(nn.Module): 253 | expansion = 1 254 | 255 | def __init__(self, inplanes, planes, stride=1): 256 | super(Context_Path, self).__init__() 257 | self.conv1 = conv3x3(inplanes, planes, stride) 258 | self.bn1 = nn.BatchNorm2d(planes) 259 | self.relu = nn.ReLU(inplace=True) 260 | self.conv2 = conv3x3(planes, planes, stride) 261 
| self.bn2 = nn.BatchNorm2d(inplanes) 262 | self.stride = stride 263 | self.conv1_1 = nn.Conv2d(inplanes, planes, 1, stride) 264 | self.conv1_2 = nn.Conv2d(inplanes, inplanes, 1, stride) 265 | self.conv1_3 = nn.Conv2d(3*planes, planes, 1, stride) 266 | 267 | def forward(self, x): 268 | 269 | out1 = self.conv1_1(x)#d = 64 270 | 271 | out2 = self.conv1_1(x) 272 | out2 = self.conv2(out2)#将原图通道变为与特征图输入通道一致,d = 32 273 | out2 = self.bn1(out2) 274 | out2 = self.relu(out2) 275 | 276 | out3 = self.conv1_2(x) 277 | out3 = self.conv1(out3) 278 | out3 = self.bn1(out3) 279 | out3 = self.relu(out3)#特征图与原图下采样后concat经过3*3卷积 280 | out3 = self.conv2(out3) 281 | out3 = self.bn1(out3) 282 | out3 = self.relu(out3)#特征图与原图下采样后concat经过3*3卷积 283 | 284 | out = torch.cat([out1, out2, out3], dim=1)#通道数变plane的3倍 285 | out = self.conv1_3(out) 286 | 287 | return out 288 | 289 | 290 | 291 | 292 | class Multiscale_Predict(nn.Module): 293 | expansion = 1 294 | 295 | def __init__(self, inplanes, planes, stride=1): 296 | super(Multiscale_Predict, self).__init__() 297 | self.conv1 = conv3x3(inplanes, planes, stride) 298 | self.bn1 = nn.BatchNorm2d(planes) 299 | self.relu = nn.ReLU(inplace=True) 300 | self.conv2 = conv3x3(planes, planes, stride) 301 | self.bn2 = nn.BatchNorm2d(inplanes) 302 | self.pixelshuffle = nn.PixelShuffle(2) 303 | self.stride = stride 304 | 305 | def forward(self, x ,y): 306 | 307 | out1 = self.conv1(x)#layer3 308 | out1 = self.pixelshuffle(out1) 309 | 310 | out2 = torch.cat([out1, y], dim=1) 311 | out2 = self.conv2(out2) 312 | out2 = self.bn1(out2) 313 | out2 = self.relu(out2) 314 | 315 | return out2 316 | 317 | 318 | 319 | def conv3x3(in_planes, out_planes, stride=1): 320 | """3x3 convolution with padding""" 321 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 322 | padding=1, bias=False) 323 | 324 | 325 | if __name__ == '__main__': 326 | inputs = torch.rand((2, 1, 512, 512)).cuda() 327 | 328 | unet_plus_plus = Dual_Unet(in_channels=1, n_classes=2).cuda() 329 | output = unet_plus_plus(inputs) 330 | print('# parameters:', sum(param.numel() for param in unet_plus_plus.parameters())) 331 | def get_parameter_number(net): 332 | total_num = sum(p.numel() for p in net.parameters()) 333 | trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad) 334 | return {'Total': total_num, 'Trainable': trainable_num} 335 | print('# parameters:', get_parameter_number(unet_plus_plus)) -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | #import torch.nn as nn 2 | import torch 3 | from torch.nn import functional as F 4 | #from PIL import Image 5 | import numpy as np 6 | import pandas as pd 7 | #import os 8 | import os.path as osp 9 | import shutil 10 | #import math 11 | 12 | def save_checkpoint(state,best_pred, epoch,is_best,checkpoint_path,filename='./checkpoint/checkpoint.pth.tar'): 13 | torch.save(state, filename) 14 | if is_best: 15 | shutil.copyfile(filename, osp.join(checkpoint_path,'model_{:03d}_{:.4f}.pth.tar'.format((epoch + 1),best_pred))) 16 | 17 | 18 | 19 | def adjust_learning_rate(opt, optimizer, epoch): 20 | """ 21 | Sets the learning rate to the initial LR decayed by 10 every 30 epochs(step = 30) 22 | """ 23 | if opt.lr_mode == 'step': 24 | lr = opt.lr * (0.1 ** (epoch // opt.step)) 25 | elif opt.lr_mode == 'poly': 26 | lr = opt.lr * (1 - epoch / opt.num_epochs) ** 0.9 27 | else: 28 | raise ValueError('Unknown lr mode 
{}'.format(opt.lr_mode)) 29 | 30 | for param_group in optimizer.param_groups: 31 | param_group['lr'] = lr 32 | return lr 33 | 34 | 35 | 36 | 37 | def one_hot_it(label, label_info): 38 | # return semantic_map -> [H, W, num_classes] 39 | semantic_map = [] 40 | for info in label_info: 41 | color = label_info[info] 42 | # colour_map = np.full((label.shape[0], label.shape[1], label.shape[2]), colour, dtype=int) 43 | equality = np.equal(label, color) 44 | class_map = np.all(equality, axis=-1) 45 | semantic_map.append(class_map) 46 | semantic_map = np.stack(semantic_map, axis=-1) 47 | return semantic_map 48 | 49 | def compute_score_multi(predict, target, smooth=1): 50 | 51 | assert(predict.shape == target.shape) 52 | #第一类 53 | overlap_1 = ((predict == 1)*(target == 1)).sum() #TP 54 | union_1 = (predict == 1).sum() + (target == 1).sum()-overlap_1 #FP+FN+TP 55 | FP_1=(predict == 1).sum()-overlap_1 #FP 56 | FN_1=(target == 1).sum()-overlap_1 #FN 57 | TN_1= target.shape[0]*target.shape[1]-union_1 #TN 58 | 59 | #第二类 60 | overlap_2 = ((predict == 2)*(target == 2)).sum() #TP 61 | union_2 = (predict == 2).sum() + (target == 2).sum()-overlap_2 #FP+FN+TP 62 | FP_2=(predict == 2).sum()-overlap_2 #FP 63 | FN_2=(target == 2).sum()-overlap_2 #FN 64 | TN_2= target.shape[0]*target.shape[1]-union_2 #TN 65 | 66 | 67 | #计算指标 68 | dice_1=(2*overlap_1 +smooth)/ (union_1 + overlap_1 +smooth) 69 | Acc_1=(overlap_1+TN_1+smooth)/(target.shape[0]*target.shape[1]+smooth) 70 | jaccard_1=(overlap_1 + smooth) / (union_1 + smooth) 71 | # Sensitivity_1=(overlap_1 +smooth) / ((target == 1).sum()+smooth) 72 | Sensitivity_1=(overlap_1 +smooth) / (overlap_1+FN_1 +smooth) 73 | Specificity_1=(TN_1 + smooth) / (FP_1 + TN_1 + smooth) 74 | 75 | dice_2=(2*overlap_2 +smooth)/ (union_2 + overlap_2 +smooth) 76 | Acc_2=(overlap_2+TN_2+smooth)/(target.shape[0]*target.shape[1]+smooth) 77 | jaccard_2=(overlap_2 + smooth) / (union_2 + smooth) 78 | Sensitivity_2=(overlap_2 +smooth) / (overlap_2+FN_2 +smooth) 79 | Specificity_2=(TN_2 + smooth) / (FP_2 + TN_2 + smooth) 80 | 81 | 82 | return dice_1,dice_2,Acc_1,Acc_2,jaccard_1,jaccard_2,Sensitivity_1,Sensitivity_2,Specificity_1,Specificity_2 83 | 84 | 85 | 86 | def eval_multi_seg(predict, target): 87 | pred_seg=torch.argmax(torch.exp(predict),dim=1).int() 88 | pred_seg = pred_seg.data.cpu().numpy() 89 | label_seg = target.data.cpu().numpy().astype(dtype=np.int) 90 | assert(pred_seg.shape == label_seg.shape) 91 | 92 | Dice_1 = [] 93 | Precsion_1 = [] 94 | Jaccard_1 = [] 95 | Sensitivity_1=[] 96 | Specificity_1=[] 97 | 98 | Dice_2 = [] 99 | Precsion_2 = [] 100 | Jaccard_2 = [] 101 | Sensitivity_2=[] 102 | Specificity_2=[] 103 | 104 | n = pred_seg.shape[0] 105 | 106 | for i in range(n): 107 | dice_1,dice_2,precsion_1,precsion_2,jaccard_1,jaccard_2,sensitivity_1,sensitivity_2,specificity_1,specificity_2= compute_score_multi(pred_seg[i],label_seg[i]) 108 | 109 | Dice_1.append(dice_1) 110 | Dice_2.append(dice_2) 111 | 112 | Precsion_1.append(precsion_1) 113 | Precsion_2.append(precsion_2) 114 | 115 | Jaccard_1.append(jaccard_1) 116 | Jaccard_2.append(jaccard_2) 117 | 118 | Sensitivity_1.append(sensitivity_1) 119 | Sensitivity_2.append(sensitivity_2) 120 | 121 | Specificity_1.append(specificity_1) 122 | Specificity_2.append(specificity_2) 123 | 124 | return Dice_1,Dice_2,Precsion_1,Precsion_2,Jaccard_1,Jaccard_2,Sensitivity_1,Sensitivity_2,Specificity_1,Specificity_2 125 | #def compute_score_multi(predict, target, forground = 1,smooth=0.001): 126 | # score = 0 127 | # count = 0 128 | # 

def eval_multi_seg(predict, target):
    pred_seg = torch.argmax(torch.exp(predict), dim=1).int()
    pred_seg = pred_seg.data.cpu().numpy()
    label_seg = target.data.cpu().numpy().astype(np.int64)
    assert pred_seg.shape == label_seg.shape

    Dice_1, Precision_1, Jaccard_1, Sensitivity_1, Specificity_1 = [], [], [], [], []
    Dice_2, Precision_2, Jaccard_2, Sensitivity_2, Specificity_2 = [], [], [], [], []

    n = pred_seg.shape[0]

    for i in range(n):
        # note: positions 3-4 returned by compute_score_multi are pixel accuracies (Acc_1, Acc_2)
        (dice_1, dice_2, precision_1, precision_2, jaccard_1, jaccard_2,
         sensitivity_1, sensitivity_2, specificity_1, specificity_2) = compute_score_multi(pred_seg[i], label_seg[i])

        Dice_1.append(dice_1)
        Dice_2.append(dice_2)

        Precision_1.append(precision_1)
        Precision_2.append(precision_2)

        Jaccard_1.append(jaccard_1)
        Jaccard_2.append(jaccard_2)

        Sensitivity_1.append(sensitivity_1)
        Sensitivity_2.append(sensitivity_2)

        Specificity_1.append(specificity_1)
        Specificity_2.append(specificity_2)

    return (Dice_1, Dice_2, Precision_1, Precision_2, Jaccard_1, Jaccard_2,
            Sensitivity_1, Sensitivity_2, Specificity_1, Specificity_2)


def compute_score_single(predict, target, foreground=1, smooth=1):
    target[target != foreground] = 0
    predict[predict != foreground] = 0
    assert predict.shape == target.shape
    overlap = ((predict == foreground) * (target == foreground)).sum()              # TP: predicted positive, actually positive
    union = (predict == foreground).sum() + (target == foreground).sum() - overlap  # TP + FP + FN
    FP = (predict == foreground).sum() - overlap   # FP: predicted positive, actually negative
    FN = (target == foreground).sum() - overlap    # FN: predicted negative, actually positive
    TN = target.shape[0] * target.shape[1] * target.shape[2] - union                # TN: predicted negative, actually negative

    dice = (2 * overlap + smooth) / (union + overlap + smooth)

    # note: this is global pixel accuracy rather than precision in the strict sense
    precision = ((predict == target).sum() + smooth) / (target.shape[0] * target.shape[1] * target.shape[2] + smooth)

    jaccard = (overlap + smooth) / (union + smooth)

    Sensitivity = (overlap + smooth) / ((target == foreground).sum() + smooth)

    Specificity = (TN + smooth) / (FP + TN + smooth)

    return dice, precision, jaccard, Sensitivity, Specificity


def eval_single_seg(predict, target, foreground=1):
    pred_seg = torch.round(torch.sigmoid(predict)).int()
    pred_seg = pred_seg.data.cpu().numpy()
    label_seg = target.data.cpu().numpy().astype(np.int64)
    assert pred_seg.shape == label_seg.shape

    Dice, Precision, Jaccard, Sensitivity, Specificity = [], [], [], [], []

    n = pred_seg.shape[0]

    for i in range(n):
        dice, precision, jaccard, sensitivity, specificity = compute_score_single(pred_seg[i], label_seg[i])
        Dice.append(dice)
        Precision.append(precision)
        Jaccard.append(jaccard)
        Sensitivity.append(sensitivity)
        Specificity.append(specificity)

    return Dice, Precision, Jaccard, Sensitivity, Specificity
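
# Minimal usage sketch for the binary-metric path (illustrative; the helper
# name `_demo_eval_single_seg` is not part of the original file):
def _demo_eval_single_seg():
    logits = torch.randn(4, 1, 64, 64)                    # raw model output, pre-sigmoid
    labels = torch.randint(0, 2, (4, 1, 64, 64)).float()  # binary ground truth
    dice, precision, jaccard, sens, spec = eval_single_seg(logits, labels)
    print('mean Dice over the batch:', sum(dice) / len(dice))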

def batch_pix_accuracy(pred, label, nclass=1):
    if nclass == 1:
        pred = torch.round(torch.sigmoid(pred)).int()
        pred = pred.cpu().numpy()
    else:
        pred = torch.argmax(pred, dim=1)   # argmax over classes (torch.max returns a (values, indices) tuple)
        pred = pred.cpu().numpy()
    label = label.cpu().numpy()
    pixel_labeled = np.sum(label >= 0)
    pixel_correct = np.sum(pred == label)

    assert pixel_correct <= pixel_labeled, \
        "Correct area should not exceed labeled area"

    return pixel_correct, pixel_labeled


def batch_intersection_union(predict, target, nclass):
    """Batch Intersection over Union
    Args:
        predict: input 4D tensor
        target: label 3D tensor
        nclass: number of categories (int); note: does not include background
    """
    if nclass == 1:
        pred = torch.round(torch.sigmoid(predict)).int()
        pred = pred.cpu().numpy()
        target = target.cpu().numpy()
        area_inter = np.sum(pred * target)
        area_union = np.sum(pred) + np.sum(target) - area_inter

        return area_inter, area_union

    if nclass > 1:
        _, predict = torch.max(predict, 1)
        mini = 1
        maxi = nclass
        nbins = nclass
        predict = predict.cpu().numpy() + 1
        target = target.cpu().numpy() + 1

        predict = predict * (target > 0).astype(predict.dtype)
        intersection = predict * (predict == target)
        # areas of intersection and union
        area_inter, _ = np.histogram(intersection, bins=nbins - 1, range=(mini + 1, maxi))
        area_pred, _ = np.histogram(predict, bins=nbins - 1, range=(mini + 1, maxi))
        area_lab, _ = np.histogram(target, bins=nbins - 1, range=(mini + 1, maxi))
        area_union = area_pred + area_lab - area_inter
        assert (area_inter <= area_union).all(), \
            "Intersection area should not exceed union area"
        return area_inter, area_union


def pixel_accuracy(im_pred, im_lab):
    im_pred = np.asarray(im_pred)
    im_lab = np.asarray(im_lab)

    # Remove classes from unlabeled pixels in gt image.
    # We should not penalize detections in unlabeled portions of the image.
    pixel_labeled = np.sum(im_lab > 0)
    pixel_correct = np.sum((im_pred == im_lab) * (im_lab > 0))
    return pixel_correct, pixel_labeled


def reverse_one_hot(image):
    """
    Transform a one-hot tensor of shape (num_classes, H, W) into a (H, W)
    tensor where each pixel value is the classified class key.

    # Arguments
        image: the one-hot format image

    # Returns
        A 2D tensor with the same width and height as the input, where each
        pixel value is the class key.
    """
    image = image.permute(1, 2, 0)
    x = torch.argmax(image, dim=-1)
    return x
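
# Quick demonstration of reverse_one_hot (illustrative; `_demo_reverse_one_hot`
# is not part of the original file):
def _demo_reverse_one_hot():
    one_hot = torch.zeros(3, 2, 2)     # (num_classes, H, W)
    one_hot[1, 0, 0] = 1.0             # pixel (0, 0) belongs to class 1
    one_hot[2, 1, 1] = 1.0             # pixel (1, 1) belongs to class 2
    print(reverse_one_hot(one_hot))    # tensor([[1, 0], [0, 2]])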

def colour_code_segmentation(image, label_values):
    """
    Given a 1-channel array of class keys, colour code the segmentation results.

    # Arguments
        image: single-channel array where each value represents the class key.
        label_values: dict mapping class names to colours

    # Returns
        Colour-coded image for segmentation visualization
    """
    label_values = [label_values[key] for key in label_values]
    colour_codes = np.array(label_values)
    x = colour_codes[image.astype(int)]

    return x
--------------------------------------------------------------------------------
/model/mildnet.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Wed Apr 17 12:59:10 2019

@author: Fsl
"""
import torch
import torch.nn as nn
import torchsummary
from torch.nn import functional as F
from torch.nn import init


class Mild_net(nn.Module):

    def __init__(self, in_channels=1, n_classes=4, feature_scale=2, is_deconv=False, is_batchnorm=True):
        super(Mild_net, self).__init__()
        self.in_channels = in_channels
        self.feature_scale = feature_scale
        self.is_deconv = is_deconv
        self.is_batchnorm = is_batchnorm

        filters = [64, 128, 256, 512, 1024, 640]
        filters = [int(x / self.feature_scale) for x in filters]

        # downsampling
        self.maxpool = nn.MaxPool2d(kernel_size=2)
        self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm)
        self.Mil_unit1 = Mil_unit(filters[0], filters[1])
        self.Residual_unit1 = Residual_unit(filters[1], filters[1])

        self.Mil_unit2 = Mil_unit(filters[1], filters[2])
        self.Residual_unit2 = Residual_unit(filters[2], filters[2])

        self.Mil_unit3 = Mil_unit(filters[2], filters[3])
        self.Residual_unit3 = Residual_unit(filters[3], filters[3])

        self.Mil_unit4 = Mil_unit(filters[3], filters[4])
        self.Residual_unit4 = Residual_unit(filters[4], filters[4])

        # ASPP
        self.aspp = ASPP(filters[4], filters[5])
        self.conv10 = nn.Conv2d(filters[5], filters[4], 1, stride=1)  # 1x1 conv mapping the ASPP output back to filters[4] channels

        # upsampling
        self.up_concat4 = unetUp(filters[4], filters[3], self.is_deconv)
        self.up_concat3 = unetUp(filters[3], filters[2], self.is_deconv)
        self.up_concat2 = unetUp(filters[2], filters[1], self.is_deconv)
        self.up_concat1 = unetUp(filters[1], filters[0], self.is_deconv)

        # final conv (without any concat)
        self.final = nn.Conv2d(filters[0], n_classes, 1)

        # initialise weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init_weights(m, init_type='kaiming')
            elif isinstance(m, nn.BatchNorm2d):
                init_weights(m, init_type='kaiming')
    def forward(self, inputs):
        conv1 = self.conv1(inputs)                # first two 3x3 convs, 64*256*512
        maxpool1 = self.maxpool(conv1)            # 64*128*256
        mil1 = self.Mil_unit1(maxpool1, inputs)   # 128*128*256
        Residual1 = self.Residual_unit1(mil1)     # 128*128*256
        maxpool2 = self.maxpool(Residual1)        # 128*64*128
        mil2 = self.Mil_unit2(maxpool2, inputs)   # 256*64*128
        Residual2 = self.Residual_unit2(mil2)     # 256*64*128
        maxpool3 = self.maxpool(Residual2)        # 256*32*64
        mil3 = self.Mil_unit3(maxpool3, inputs)   # 512*32*64
        Residual3 = self.Residual_unit3(mil3)     # 512*32*64
        maxpool4 = self.maxpool(Residual3)        # 512*16*32
        mil4 = self.Mil_unit4(maxpool4, inputs)   # 1024*16*32
        Residual4 = self.Residual_unit4(mil4)     # 1024*16*32

        aspp = self.aspp(Residual4)               # 640*16*32 (ASPP preserves the spatial size)
        aspp = self.conv10(aspp)

        up4 = self.up_concat4(aspp, Residual3)
        up3 = self.up_concat3(up4, Residual2)     # 128*128*256
        up2 = self.up_concat2(up3, Residual1)     # 64*256*512
        up1 = self.up_concat1(up2, conv1)         # 64*256*512
        final = self.final(up1)

        final = F.log_softmax(final, dim=1)

        return up1, final


class unetConv2(nn.Module):
    def __init__(self, in_size, out_size, is_batchnorm, n=2, ks=3, stride=1, padding=1):
        super(unetConv2, self).__init__()
        self.n = n
        self.ks = ks
        self.stride = stride
        self.padding = padding
        s = stride
        p = padding
        if is_batchnorm:
            for i in range(1, n + 1):
                conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p),
                                     nn.BatchNorm2d(out_size),
                                     nn.ReLU(inplace=True),)
                setattr(self, 'conv%d' % i, conv)
                in_size = out_size
        else:
            for i in range(1, n + 1):
                conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p),
                                     nn.ReLU(inplace=True),)
                setattr(self, 'conv%d' % i, conv)
                in_size = out_size

        # initialise the blocks
        for m in self.children():
            init_weights(m, init_type='kaiming')

    def forward(self, inputs):
        x = inputs
        for i in range(1, self.n + 1):
            conv = getattr(self, 'conv%d' % i)
            x = conv(x)

        return x


class unetUp(nn.Module):
    def __init__(self, in_size, out_size, is_deconv, n_concat=2):
        super(unetUp, self).__init__()
        self.conv = unetConv2(in_size + (n_concat - 2) * out_size, out_size, True)
        if is_deconv:
            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2, padding=0)
        else:
            self.up = nn.Sequential(
                nn.UpsamplingBilinear2d(scale_factor=2),
                nn.Conv2d(in_size, out_size, 1))

        # initialise the blocks
        for m in self.children():
            if m.__class__.__name__.find('unetConv2') != -1:
                continue  # unetConv2 initialises its own sub-blocks; skip re-initialising it
            init_weights(m, init_type='kaiming')

    def forward(self, high_feature, *low_feature):
        outputs0 = self.up(high_feature)
        for feature in low_feature:
            outputs0 = torch.cat([outputs0, feature], 1)
        return self.conv(outputs0)
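
# Channel bookkeeping for unetUp above (illustrative note, assuming the
# doubling filter widths Mild_net uses, i.e. in_size == 2 * out_size):
#   up:   in_size -> out_size            (deconv, or bilinear upsample + 1x1 conv)
#   cat:  out_size + (n_concat - 1) skip connections of out_size channels each
#   conv: expects in_size + (n_concat - 2) * out_size input channels
# e.g. unetUp(512, 256) upsamples to 256 channels, concatenates one 256-channel
# skip, and the 3x3 conv then sees 512 == in_size channels, as required.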

def init_weights(net, init_type='normal'):
    if init_type == 'kaiming':
        net.apply(weights_init_kaiming)
    else:
        raise NotImplementedError('initialization method [%s] is not implemented' % init_type)


def weights_init_kaiming(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif classname.find('Linear') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif classname.find('BatchNorm') != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)


class ASPP(nn.Module):
    def __init__(self, in_channel, depth):
        super(ASPP, self).__init__()
        self.in_channel = in_channel
        self.depth = depth
        # global average pooling: nn.AdaptiveAvgPool2d here; torch.mean(..., keepdim=True) in forward would also work
        self.mean = nn.AdaptiveAvgPool2d((1, 1))
        self.conv = nn.Conv2d(in_channel, depth, 1, 1)
        self.atrous_block1 = nn.Conv2d(in_channel, depth, 1, 1)                            # 1x1 conv, k=1 s=1 no pad
        self.atrous_block6 = nn.Conv2d(in_channel, depth, 3, 1, padding=6, dilation=6)     # 3x3 conv, dilation 6
        self.atrous_block12 = nn.Conv2d(in_channel, depth, 3, 1, padding=12, dilation=12)  # 3x3 conv, dilation 12
        self.atrous_block18 = nn.Conv2d(in_channel, depth, 3, 1, padding=18, dilation=18)  # 3x3 conv, dilation 18

        self.conv_1x1_output = nn.Conv2d(depth * 5, depth, 1, 1)

    def forward(self, x):
        size = x.shape[2:]

        image_features = self.mean(x)
        image_features = self.conv(image_features)
        image_features = F.interpolate(image_features, size=size, mode='bilinear', align_corners=False)

        atrous_block1 = self.atrous_block1(x)

        atrous_block6 = self.atrous_block6(x)

        atrous_block12 = self.atrous_block12(x)

        atrous_block18 = self.atrous_block18(x)

        net = self.conv_1x1_output(torch.cat([image_features, atrous_block1, atrous_block6,
                                              atrous_block12, atrous_block18], dim=1))
        return net
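
# Because every dilated 3x3 branch above uses padding == dilation, each branch
# preserves H x W, so the five branches can be concatenated channel-wise before
# the 1x1 fusion conv. A quick shape check (illustrative; not part of the
# original file):
def _demo_aspp_shapes():
    x = torch.rand(1, 512, 16, 32)
    print(ASPP(512, 320)(x).shape)   # torch.Size([1, 320, 16, 32])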

class Residual_unit(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Residual_unit, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Dilated_Residual_unit_1(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, kernel=3, stride=1):
        super(Dilated_Residual_unit_1, self).__init__()
        self.kernel = kernel
        self.stride = stride
        self.conv1 = nn.Conv2d(inplanes, planes, 3, 1, padding=2, dilation=2)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, 3, 1, padding=2, dilation=2)
        self.bn2 = nn.BatchNorm2d(planes)

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += residual
        out = self.relu(out)

        return out


class Dilated_Residual_unit_2(nn.Module):
    def __init__(self, inplanes, planes, kernel=3, stride=1):
        super(Dilated_Residual_unit_2, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, 3, 1, padding=4, dilation=4)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, 3, 1, padding=4, dilation=4)
        self.bn2 = nn.BatchNorm2d(planes)
        self.kernel = kernel
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += residual
        out = self.relu(out)

        return out


class Mil_unit(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1):
        super(Mil_unit, self).__init__()
        self.conv1 = conv3x3(inplanes, inplanes, stride)
        self.conv1x = conv3x3(inplanes, inplanes, stride)
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(inplanes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = conv3x3(planes, planes, stride)
        self.stride = stride
        self.conv1_1 = nn.Conv2d(1, inplanes, 1, stride)  # 1x1 conv lifting the raw input to `inplanes` channels (the 1 is for the grayscale input)

    def forward(self, x, origin):

        size = x.shape[2:]

        out1 = self.conv1(x)
        out1 = self.bn1(out1)
        out1 = self.relu(out1)
        out1 = self.conv2(out1)
        out1 = self.bn2(out1)      # feature branch: two 3x3 convs

        out2 = F.interpolate(origin, size, mode="bilinear", align_corners=False)
        out2 = self.conv1_1(out2)  # match the raw image's channels to the feature map's input channels

        out2 = self.conv1x(out2)
        out2 = self.bn1(out2)
        out2 = self.relu(out2)     # image branch: downsampled raw input through a 3x3 conv
        out3 = torch.cat([x, out2], 1)  # concat feature map with image branch; channels double to 2*inplanes == planes
        out3 = self.conv3(out3)
        out3 = self.bn2(out3)
        out3 = self.relu(out3)     # fused branch through a 3x3 conv

        out = out3 + out1          # sum the two paths
        out = self.relu(out)

        return out


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


# model = Mild_net().cuda()
# torchsummary.summary(model, (1, 512, 512))
if __name__ == '__main__':
    inputs = torch.rand((2, 1, 512, 512)).cuda()

    model = Mild_net(in_channels=1, n_classes=2).cuda()
    output = model(inputs)
    print('# parameters:', sum(param.numel() for param in model.parameters()))

    def get_parameter_number(net):
        total_num = sum(p.numel() for p in net.parameters())
        trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad)
        return {'Total': total_num, 'Trainable': trainable_num}
    print('# parameters:', get_parameter_number(model))
--------------------------------------------------------------------------------
/model/unet_deep_Asymmetric_Non-local.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Wed Jun 24 15:46:32 2020

@author: Administrator
"""

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn import init

class UNet_deepsup(nn.Module):

    def __init__(self, in_channels=1, n_classes=3, feature_scale=2, is_deconv=True, is_batchnorm=True):
        super(UNet_deepsup, self).__init__()
        self.in_channels = in_channels
        self.feature_scale = feature_scale
        self.is_deconv = is_deconv
        self.is_batchnorm = is_batchnorm

        filters = [64, 128, 256, 512, 1024]
        filters = [int(x / self.feature_scale) for x in filters]

        # downsampling
        self.maxpool = nn.MaxPool2d(kernel_size=2)
        self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm)
        self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm)
        self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm)
        self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm)
        self.center = unetConv2(filters[3], filters[4], self.is_batchnorm)
        # upsampling
        self.up_concat4 = unetUp(filters[4], filters[3], self.is_deconv)
        self.up_concat3 = unetUp(filters[3], filters[2], self.is_deconv)
        self.up_concat2 = unetUp(filters[2], filters[1], self.is_deconv)
        self.up_concat1 = unetUp(filters[1], filters[0], self.is_deconv)
        # final conv (without any concat)
        self.final = nn.Conv2d(filters[0], n_classes, 1)

        # deep supervision
        self.deepsup_3 = nn.Conv2d(filters[3], n_classes, kernel_size=1, stride=1, padding=0)
        self.output_3_up = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=False)
        self.deepsup_2 = nn.Conv2d(filters[2], n_classes, kernel_size=1, stride=1, padding=0)
        self.output_2_up = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)
        self.deepsup_1 = nn.Conv2d(filters[1], n_classes, kernel_size=1, stride=1, padding=0)
        self.output_1_up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

    def forward(self, inputs):
        conv1 = self.conv1(inputs)      # 16*512*512
        maxpool1 = self.maxpool(conv1)  # 16*256*256

        conv2 = self.conv2(maxpool1)    # 32*256*256
        maxpool2 = self.maxpool(conv2)  # 32*128*128

        conv3 = self.conv3(maxpool2)    # 64*128*128
        maxpool3 = self.maxpool(conv3)  # 64*64*64

        conv4 = self.conv4(maxpool3)    # 128*64*64
        maxpool4 = self.maxpool(conv4)  # 128*32*32

        center = self.center(maxpool4)  # 256*32*32

        up4 = self.up_concat4(center, conv4)   # 128*64*64
        up4_deep = self.deepsup_3(up4)
        up4_deep = self.output_3_up(up4_deep)

        up3 = self.up_concat3(up4, conv3)      # 64*128*128
        up3_deep = self.deepsup_2(up3)
        up3_deep = self.output_2_up(up3_deep)

        up2 = self.up_concat2(up3, conv2)      # 32*256*256
        up2_deep = self.deepsup_1(up2)
        up2_deep = self.output_1_up(up2_deep)

        up1 = self.up_concat1(up2, conv1)      # 16*512*512

        final = self.final(up1)
        # log-softmax over the class dimension, applied to each head separately
        up4_deep = F.log_softmax(up4_deep, dim=1)
        up3_deep = F.log_softmax(up3_deep, dim=1)
        up2_deep = F.log_softmax(up2_deep, dim=1)
        final = F.log_softmax(final, dim=1)

        return up4_deep, up3_deep, up2_deep, final
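
# How the four log-softmax heads are typically combined during training
# (illustrative; the loss weights and the helper name are assumptions, and the
# actual training script is not part of this file). Since every head already
# returns log-probabilities, nn.NLLLoss applies directly.
def _demo_deep_supervision_loss(model, images, masks):
    # masks: LongTensor of shape (N, H, W) holding class indices
    criterion = nn.NLLLoss()
    up4_deep, up3_deep, up2_deep, final = model(images)
    return (criterion(final, masks)
            + 0.4 * criterion(up2_deep, masks)
            + 0.3 * criterion(up3_deep, masks)
            + 0.2 * criterion(up4_deep, masks))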

class unetConv2(nn.Module):
    def __init__(self, in_size, out_size, is_batchnorm, n=2, ks=3, stride=1, padding=1):
        super(unetConv2, self).__init__()
        self.n = n
        self.ks = ks
        self.stride = stride
        self.padding = padding
        s = stride
        p = padding
        if is_batchnorm:
            for i in range(1, n + 1):
                conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p),
                                     nn.BatchNorm2d(out_size),
                                     nn.ReLU(inplace=True),)
                setattr(self, 'conv%d' % i, conv)
                in_size = out_size
        else:
            for i in range(1, n + 1):
                conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p),
                                     nn.ReLU(inplace=True),)
                setattr(self, 'conv%d' % i, conv)
                in_size = out_size

        # initialise the blocks
        for m in self.children():
            init_weights(m, init_type='kaiming')

    def forward(self, inputs):
        x = inputs
        for i in range(1, self.n + 1):
            conv = getattr(self, 'conv%d' % i)
            x = conv(x)

        return x


class unetUp(nn.Module):
    def __init__(self, in_size, out_size, is_deconv, n_concat=2):
        super(unetUp, self).__init__()
        self.conv = unetConv2(in_size + (n_concat - 2) * out_size, out_size, False)
        if is_deconv:
            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2, padding=0)
        else:
            self.up = nn.Sequential(
                nn.UpsamplingBilinear2d(scale_factor=2),
                nn.Conv2d(in_size, out_size, 1))

        # initialise the blocks
        for m in self.children():
            if m.__class__.__name__.find('unetConv2') != -1:
                continue  # unetConv2 initialises its own sub-blocks; skip to the next child
            init_weights(m, init_type='kaiming')

    def forward(self, high_feature, *low_feature):
        outputs0 = self.up(high_feature)
        for feature in low_feature:
            outputs0 = torch.cat([outputs0, feature], 1)

        return self.conv(outputs0)


def init_weights(net, init_type='normal'):
    if init_type == 'kaiming':
        net.apply(weights_init_kaiming)
    else:
        raise NotImplementedError('initialization method [%s] is not implemented' % init_type)


def weights_init_kaiming(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif classname.find('Linear') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif classname.find('BatchNorm') != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)


# NOTE: `_SelfAttentionBlock` and `ModuleHelper` below (and `BackboneSelector`
# further down) are not defined in this file; they come from the reference
# implementation of "Asymmetric Non-local Neural Networks" (the openseg
# codebase) and must be imported from there for these classes to run.
class SelfAttentionBlock2D(_SelfAttentionBlock):
    def __init__(self, in_channels, key_channels, value_channels, out_channels=None, scale=1, norm_type=None, psp_size=(1, 3, 6, 8)):
        super(SelfAttentionBlock2D, self).__init__(in_channels,
                                                   key_channels,
                                                   value_channels,
                                                   out_channels,
                                                   scale,
                                                   norm_type,
                                                   psp_size=psp_size)

class APNB(nn.Module):
    """
    Parameters:
        in_features / out_features: the channels of the input / output feature maps.
        dropout: 0.05 is used as the default value.
        sizes: multiple pooling sizes can be applied; only one size is used here.
    Return:
        features fused with object context information.
    """

    def __init__(self, in_channels, out_channels, key_channels, value_channels, dropout, sizes=([1]), norm_type=None, psp_size=(1, 3, 6, 8)):
        super(APNB, self).__init__()
        self.stages = []
        self.norm_type = norm_type
        self.psp_size = psp_size
        self.stages = nn.ModuleList(
            [self._make_stage(in_channels, out_channels, key_channels, value_channels, size) for size in sizes])
        self.conv_bn_dropout = nn.Sequential(
            nn.Conv2d(2 * in_channels, out_channels, kernel_size=1, padding=0),
            ModuleHelper.BNReLU(out_channels, norm_type=norm_type),
            nn.Dropout2d(dropout)
        )

    def _make_stage(self, in_channels, output_channels, key_channels, value_channels, size):
        return SelfAttentionBlock2D(in_channels,
                                    key_channels,
                                    value_channels,
                                    output_channels,
                                    size,
                                    self.norm_type,
                                    self.psp_size)

    def forward(self, feats):
        priors = [stage(feats) for stage in self.stages]
        context = priors[0]
        for i in range(1, len(priors)):
            context += priors[i]
        output = self.conv_bn_dropout(torch.cat([context, feats], 1))
        return output


class AFNB(nn.Module):
    """
    Parameters:
        in_features / out_features: the channels of the input / output feature maps.
        dropout: 0.05 is used as the default value.
        sizes: multiple pooling sizes can be applied; only one size is used here.
    Return:
        features fused with object context information.
    """

    def __init__(self, low_in_channels, high_in_channels, out_channels, key_channels, value_channels, dropout,
                 sizes=([1]), norm_type=None, psp_size=(1, 3, 6, 8)):
        super(AFNB, self).__init__()
        self.stages = []
        self.norm_type = norm_type
        self.psp_size = psp_size
        self.stages = nn.ModuleList(
            [self._make_stage([low_in_channels, high_in_channels], out_channels, key_channels, value_channels, size) for
             size in sizes])
        self.conv_bn_dropout = nn.Sequential(
            nn.Conv2d(out_channels + high_in_channels, out_channels, kernel_size=1, padding=0),
            ModuleHelper.BatchNorm2d(norm_type=self.norm_type)(out_channels),
            nn.Dropout2d(dropout)
        )

    def _make_stage(self, in_channels, output_channels, key_channels, value_channels, size):
        return SelfAttentionBlock2D(in_channels[0],
                                    in_channels[1],
                                    key_channels,
                                    value_channels,
                                    output_channels,
                                    size,
                                    self.norm_type,
                                    psp_size=self.psp_size)

    def forward(self, low_feats, high_feats):
        priors = [stage(low_feats, high_feats) for stage in self.stages]
        context = priors[0]
        for i in range(1, len(priors)):
            context += priors[i]
        output = self.conv_bn_dropout(torch.cat([context, high_feats], 1))
        return output
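
# The core "asymmetric" idea behind APNB/AFNB, in a self-contained sketch
# (illustrative; the classes above depend on the external ANN codebase, and
# `_ToyAsymmetricAttention` is not part of the original file). Keys and values
# are spatially subsampled before attention (the reference implementation
# pools at pyramid sizes psp_size=(1, 3, 6, 8); a single adaptive pool is used
# here for brevity), so attention costs O(N*S) with S << N instead of O(N^2).
class _ToyAsymmetricAttention(nn.Module):
    def __init__(self, channels, key_channels, pool_size=8):
        super(_ToyAsymmetricAttention, self).__init__()
        self.query = nn.Conv2d(channels, key_channels, 1)
        self.key = nn.Conv2d(channels, key_channels, 1)
        self.value = nn.Conv2d(channels, channels, 1)
        self.pool = nn.AdaptiveAvgPool2d(pool_size)  # N = H*W key/value positions -> S = pool_size**2

    def forward(self, x):
        b, c, h, w = x.shape
        q = self.query(x).flatten(2).transpose(1, 2)             # B x N x Ck
        k = self.key(self.pool(x)).flatten(2)                    # B x Ck x S
        v = self.value(self.pool(x)).flatten(2).transpose(1, 2)  # B x S x C
        attn = torch.softmax(q @ k, dim=-1)                      # B x N x S
        return (attn @ v).transpose(1, 2).reshape(b, c, h, w) + x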

class asymmetric_non_local_network(nn.Sequential):
    def __init__(self, configer):
        super(asymmetric_non_local_network, self).__init__()
        self.configer = configer
        self.num_classes = self.configer.get('data', 'num_classes')
        self.backbone = BackboneSelector(configer).get_backbone()
        # low_in_channels, high_in_channels, out_channels, key_channels, value_channels, dropout
        self.fusion = AFNB(1024, 2048, 2048, 256, 256, dropout=0.05, sizes=([1]),
                           norm_type=self.configer.get('network', 'norm_type'))
        # extra added layers
        self.context = nn.Sequential(
            nn.Conv2d(2048, 512, kernel_size=3, stride=1, padding=1),
            ModuleHelper.BNReLU(512, norm_type=self.configer.get('network', 'norm_type')),
            APNB(in_channels=512, out_channels=512, key_channels=256, value_channels=256,
                 dropout=0.05, sizes=([1]), norm_type=self.configer.get('network', 'norm_type'))
        )
        self.cls = nn.Conv2d(512, self.num_classes, kernel_size=1, stride=1, padding=0, bias=True)
        self.dsn = nn.Sequential(
            nn.Conv2d(1024, 512, kernel_size=3, stride=1, padding=1),
            ModuleHelper.BNReLU(512, norm_type=self.configer.get('network', 'norm_type')),
            nn.Dropout2d(0.05),
            nn.Conv2d(512, self.num_classes, kernel_size=1, stride=1, padding=0, bias=True)
        )

    def forward(self, x_):
        x = self.backbone(x_)
        aux_x = self.dsn(x[-2])
        x = self.fusion(x[-2], x[-1])
        x = self.context(x)
        x = self.cls(x)
        aux_x = F.interpolate(aux_x, size=(x_.size(2), x_.size(3)), mode="bilinear", align_corners=True)
        x = F.interpolate(x, size=(x_.size(2), x_.size(3)), mode="bilinear", align_corners=True)
        return aux_x, x


if __name__ == '__main__':
    inputs = torch.rand((2, 1, 256, 512)).cuda()

    model = UNet_deepsup(in_channels=1, n_classes=2).cuda()
    a, b, c, output = model(inputs)
    print('# parameters:', sum(param.numel() for param in model.parameters()))

    def get_parameter_number(net):
        total_num = sum(p.numel() for p in net.parameters())
        trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad)
        return {'Total': total_num, 'Trainable': trainable_num}
    print('# parameters:', get_parameter_number(model))
--------------------------------------------------------------------------------