├── requirements.txt
├── model
├── __init__.py
├── ACM
│   ├── __init__.py
│   ├── fusion.py
│   └── acm.py
├── AGPCNet
│   ├── __init__.py
│   ├── fusion.py
│   ├── agpc.py
│   ├── resnet.py
│   └── context.py
├── DNANet
│   ├── __init__.py
│   └── dna_net.py
├── URANet
│   ├── __init__.py
│   └── uranet.py
├── ABC
│   ├── __init__.py
│   ├── Module.py
│   └── ABCNet.py
├── MTUet
│   ├── __init__.py
│   ├── vit.py
│   └── mtu_uet.py
├── RDIAN
│   ├── __init__.py
│   ├── cbam.py
│   ├── direction.py
│   └── rdian.py
├── UIUNet
│   ├── __init__.py
│   └── fusion.py
└── build_segmentor.py
├── utils
├── __init__.py
├── drawing.py
├── loss.py
├── save_model.py
├── scheduler.py
├── data.py
├── visual.py
├── logs.py
├── metric.py
└── tools.py
├── configs
├── acm
│   ├── acm_res20_unet_256x256_500e_nudt.py
│   ├── acm_res20_unet_512x512_800e_nuaa.py
│   ├── acm_res20_unet_512x512_500e_irstd1k.py
│   ├── acm_res20_unet_256x256_300e_sirstaug.py
│   ├── acm_res20_fpn_512x512_500e_irstd1k.py
│   ├── acm_res20_fpn_256x256_500e_nudt.py
│   ├── acm_res20_fpn_256x256_300e_sirstaug.py
│   └── acm_res20_fpn_512x512_800e_nuaa.py
├── dnanet
│   ├── dnanet_res34_512x512_800e_nuaa.py
│   ├── dnanet_res18_512x512_800e_nuaa.py
│   ├── dnanet_vgg10_512x512_800e_nuaa.py
│   ├── dnanet_res18_256x256_300e_sirstaug.py
│   ├── dnanet_res18_256x256_800e_nudt.py
│   ├── dnanet_res18_512x512_500e_irstd1k.py
│   └── dnanet_res10_512x512_800e_nuaa.py
├── abcnet
│   ├── abcnet_clft-b_256x256_1500e_nudt.py
│   ├── abcnet_clft-b_512x512_1500e_nuaa.py
│   ├── abcnet_clft-l_512x512_1500e_nuaa.py
│   ├── abcnet_clft-b_512x512_500e_irstd1k.py
│   ├── abcnet_clft-l_512x512_500e_irstd1k.py
│   ├── abcnet_clft-b_256x256_500e_sirstaug.py
│   ├── abcnet_clft-l_256x256_1500e_nudt.py
│   ├── abcnet_clft-l_256x256_500e_sirstaug.py
│   ├── abcnet_clft-s_256x256_500e_sirstaug.py
│   ├── abcnet_clft-s_512x512_500e_irstd1k.py
│   ├── abcnet_clft-s_256x256_1500e_nudt.py
│   └── abcnet_clft-s_512x512_1500e_nuaa.py
├── agpcnet
│   ├── agpcnet_res34_512x512_800e_nuaa.py
│   ├── agpcnet_res18_256x256_800e_nudt.py
│   ├── agpcnet_res18_512x512_800e_nuaa.py
│   ├── agpcnet_res34_512x512_500e_irstd1k.py
│   ├── agpcnet_res34_256x256_500e_nudt.py
│   └── agpcnet_res34_256x256_300e_sirstaug.py
├── _base_
│   ├── datasets
│   │   ├── nuaa.py
│   │   ├── nudt.py
│   │   ├── irstd1k.py
│   │   └── sirstaug.py
│   ├── models
│   │   ├── acm.py
│   │   ├── abcnet.py
│   │   ├── uranet.py
│   │   ├── agpc.py
│   │   ├── dnanet.py
│   │   └── mscan.py
│   ├── default_runtime.py
│   └── schedules
│   │   └── schedule_500e.py
├── uranet
│   ├── uranet_base_512x512_500e_irstd1k.py
│   ├── uranet_base_256x256_500e_nudt.py
│   ├── uranet_base_512x512_800e_nuaa.py
│   └── uranet_base_256x256_300e_sirstaug.py
├── rdian
│   ├── rdian_base_512x512_800e_nuaa.py
│   ├── rdian_base_256x256_1500e_nudt.py
│   ├── rdian_base_256x256_300e_sirstaug.py
│   └── rdian_base_512x512_500e_irstd1k.py
├── mtuet
│   ├── mtuuet_base_512x512_1500e_nudt.py
│   ├── mtuuet_base_512x512_800e_nuaa.py
│   ├── mtuuet_base_512x512_300e_srstaug.py
│   └── mtuuet_base_512x512_500e_istd1k.py
└── uiunet
│   ├── uiunet_base_256x256_1500e_nudt.py
│   ├── uiunet_base_512x512_800e_nuaa.py
│   ├── uiunet_base_256x256_300e_sirstaug.py
│   └── uiunet_base_512x512_500e_irst1k.py
├── LICENSE
├── docs
├── get_started.md
├── add_dataset.md
├── add_loss.md
├── add_model.md
└── add_optimizer.md
├── README.md
├── test.py
└── train.py
/requirements.txt: -------------------------------------------------------------------------------- 1 | mmcv-full==1.4.0 2 | mmdet==2.25.0 3 | mmsegmentation==0.28.0 -------------------------------------------------------------------------------- /model/__init__.py:
-------------------------------------------------------------------------------- 1 | # @Time : 2022/4/5 15:18 2 | # @Author : PEIWEN PAN 3 | # @Email : 121106022690@njust.edu.cn 4 | # @File : __init__.py.py 5 | # @Software: PyCharm 6 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # @Time : 2022/4/6 14:41 2 | # @Author : PEIWEN PAN 3 | # @Email : 121106022690@njust.edu.cn 4 | # @File : __init__.py.py 5 | # @Software: PyCharm 6 | -------------------------------------------------------------------------------- /model/ACM/__init__.py: -------------------------------------------------------------------------------- 1 | # @Time : 2022/9/20 15:45 2 | # @Author : PEIWEN PAN 3 | # @Email : 121106022690@njust.edu.cn 4 | # @File : __init__.py.py 5 | # @Software: PyCharm 6 | -------------------------------------------------------------------------------- /model/AGPCNet/__init__.py: -------------------------------------------------------------------------------- 1 | # @Time : 2022/5/18 17:23 2 | # @Author : PEIWEN PAN 3 | # @Email : 121106022690@njust.edu.cn 4 | # @File : __init__.py.py 5 | # @Software: PyCharm 6 | -------------------------------------------------------------------------------- /model/DNANet/__init__.py: -------------------------------------------------------------------------------- 1 | # @Time : 2022/4/6 16:17 2 | # @Author : PEIWEN PAN 3 | # @Email : 121106022690@njust.edu.cn 4 | # @File : __init__.py.py 5 | # @Software: PyCharm 6 | -------------------------------------------------------------------------------- /model/URANet/__init__.py: -------------------------------------------------------------------------------- 1 | # @Time : 2022/6/10 16:13 2 | # @Author : PEIWEN PAN 3 | # @Email : 121106022690@njust.edu.cn 4 | # @File : __init__.py.py 5 | # @Software: PyCharm 6 | -------------------------------------------------------------------------------- /configs/acm/acm_res20_unet_256x256_500e_nudt.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | 'acm_res20_fpn_256x256_500e_nudt.py', 3 | ] 4 | # model settings 5 | model = dict( 6 | decode_head=dict( 7 | name='ASKCResUNet') 8 | ) 9 | -------------------------------------------------------------------------------- /configs/acm/acm_res20_unet_512x512_800e_nuaa.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | 'acm_res20_fpn_512x512_800e_nuaa.py', 3 | ] 4 | # model settings 5 | model = dict( 6 | decode_head=dict( 7 | name='ASKCResUNet') 8 | ) 9 | -------------------------------------------------------------------------------- /model/ABC/__init__.py: -------------------------------------------------------------------------------- 1 | # @Time : 2023/10/3 16:13 2 | # @Author : PEIWEN PAN 3 | # @Email : 121106022690@njust.edu.cn 4 | # @File : __init__.py.py 5 | # @Software: PyCharm 6 | from model.ABC.ABCNet import ABCNet -------------------------------------------------------------------------------- /configs/acm/acm_res20_unet_512x512_500e_irstd1k.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | 'acm_res20_fpn_512x512_500e_irstd1k.py', 3 | ] 4 | # model settings 5 | model = dict( 6 | decode_head=dict( 7 | name='ASKCResUNet') 8 | ) 9 | -------------------------------------------------------------------------------- /model/MTUet/__init__.py: 
-------------------------------------------------------------------------------- 1 | # @Time : 2024/2/6 13:33 2 | # @Author : PEIWEN PAN 3 | # @Email : 121106022690@njust.edu.cn 4 | # @File : __init__.py.py 5 | # @Software: PyCharm 6 | from model.MTUet.mtu_uet import MTUNet -------------------------------------------------------------------------------- /model/RDIAN/__init__.py: -------------------------------------------------------------------------------- 1 | # @Time : 2024/2/24 15:30 2 | # @Author : PEIWEN PAN 3 | # @Email : 121106022690@njust.edu.cn 4 | # @File : __init__.py.py 5 | # @Software: PyCharm 6 | from model.RDIAN.rdian import RDIAN -------------------------------------------------------------------------------- /model/UIUNet/__init__.py: -------------------------------------------------------------------------------- 1 | # @Time : 2024/2/25 01:17 2 | # @Author : PEIWEN PAN 3 | # @Email : 121106022690@njust.edu.cn 4 | # @File : __init__.py.py 5 | # @Software: PyCharm 6 | from model.UIUNet.uiunet import UIUNet -------------------------------------------------------------------------------- /configs/acm/acm_res20_unet_256x256_300e_sirstaug.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | 'acm_res20_fpn_256x256_300e_sirstaug.py', 3 | ] 4 | # model settings 5 | model = dict( 6 | decode_head=dict( 7 | name='ASKCResUNet') 8 | ) 9 | -------------------------------------------------------------------------------- /configs/dnanet/dnanet_res34_512x512_800e_nuaa.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | 'dnanet_res10_512x512_800e_nuaa.py' 3 | ] 4 | # model settings 5 | model = dict( 6 | decode_head=dict( 7 | num_blocks=[3, 4, 6, 3] 8 | ) 9 | ) -------------------------------------------------------------------------------- /configs/abcnet/abcnet_clft-b_256x256_1500e_nudt.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | 'abcnet_clft-s_256x256_1500e_nudt.py' 3 | ] 4 | 5 | model = dict( 6 | decode_head=dict( 7 | dim=32 8 | ) 9 | ) 10 | data = dict(train_batch=16) -------------------------------------------------------------------------------- /configs/abcnet/abcnet_clft-b_512x512_1500e_nuaa.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | 'abcnet_clft-s_512x512_1500e_nuaa.py' 3 | ] 4 | 5 | model = dict( 6 | decode_head=dict( 7 | dim=32 8 | ) 9 | ) 10 | data = dict(train_batch=16) -------------------------------------------------------------------------------- /configs/abcnet/abcnet_clft-l_512x512_1500e_nuaa.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | 'abcnet_clft-s_512x512_1500e_nuaa.py' 3 | ] 4 | 5 | model = dict( 6 | decode_head=dict( 7 | dim=64 8 | ) 9 | ) 10 | data = dict(train_batch=4) -------------------------------------------------------------------------------- /configs/abcnet/abcnet_clft-b_512x512_500e_irstd1k.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | 'abcnet_clft-s_512x512_500e_irstd1k.py' 3 | ] 4 | 5 | model = dict( 6 | decode_head=dict( 7 | dim=32 8 | ) 9 | ) 10 | data = dict(train_batch=16) -------------------------------------------------------------------------------- /configs/abcnet/abcnet_clft-l_512x512_500e_irstd1k.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 |
'abcnet_clft-s_512x512_500e_irstd1k.py' 3 | ] 4 | 5 | model = dict( 6 | decode_head=dict( 7 | dim=64 8 | ) 9 | ) 10 | data = dict(train_batch=4) -------------------------------------------------------------------------------- /configs/abcnet/abcnet_clft-b_256x256_500e_sirstaug.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | 'abcnet_clft-s_256x256_500e_sirstaug.py' 3 | ] 4 | 5 | model = dict( 6 | decode_head=dict( 7 | dim=32 8 | ) 9 | ) 10 | data = dict(train_batch=16) -------------------------------------------------------------------------------- /configs/abcnet/abcnet_clft-l_256x256_1500e_nudt.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | 'abcnet_clft-s_256x256_1500e_nudt.py' 3 | ] 4 | 5 | model = dict( 6 | decode_head=dict( 7 | dim=64 8 | ) 9 | ) 10 | data = dict(train_batch=16) 11 | -------------------------------------------------------------------------------- /configs/abcnet/abcnet_clft-l_256x256_500e_sirstaug.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | 'abcnet_clft-s_256x256_500e_sirstaug.py' 3 | ] 4 | 5 | model = dict( 6 | decode_head=dict( 7 | dim=64 8 | ) 9 | ) 10 | data = dict(train_batch=16) -------------------------------------------------------------------------------- /configs/agpcnet/agpcnet_res34_512x512_800e_nuaa.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | 'agpcnet_res18_512x512_800e_nuaa.py', 3 | ] 4 | # model settings 5 | model = dict( 6 | decode_head=dict( 7 | backbone='resnet34') 8 | ) 9 | data = dict(train_batch=4) 10 | -------------------------------------------------------------------------------- /configs/dnanet/dnanet_res18_512x512_800e_nuaa.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | 'dnanet_res10_512x512_800e_nuaa.py' 3 | ] 4 | # model settings 5 | model = dict( 6 | decode_head=dict( 7 | num_blocks=[2, 2, 2, 2] 8 | ) 9 | ) 10 | data = dict(train_batch=8) 11 | -------------------------------------------------------------------------------- /configs/dnanet/dnanet_vgg10_512x512_800e_nuaa.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | 'dnanet_res10_512x512_800e_nuaa.py' 3 | ] 4 | # model settings 5 | model = dict( 6 | backbone=dict( 7 | type=None, 8 | type_info='vgg', 9 | ), 10 | decode_head=dict( 11 | block='vgg', 12 | num_blocks=[1, 1, 1, 1] 13 | ) 14 | ) -------------------------------------------------------------------------------- /configs/dnanet/dnanet_res18_256x256_300e_sirstaug.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/datasets/sirstaug.py', 3 | '../_base_/default_runtime.py', 4 | '../_base_/schedules/schedule_500e.py', 5 | '../_base_/models/dnanet.py' 6 | ] 7 | optimizer = dict( 8 | type='SGD', 9 | setting=dict(lr=0.05, momentum=0.9, weight_decay=0.0005) 10 | ) 11 | -------------------------------------------------------------------------------- /configs/acm/acm_res20_fpn_512x512_500e_irstd1k.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/datasets/irstd1k.py', 3 | '../_base_/default_runtime.py', 4 | '../_base_/schedules/schedule_500e.py', 5 | '../_base_/models/acm.py' 6 | ] 7 | optimizer = dict( 8 | type='Adagrad', 9 | setting=dict(lr=0.05, weight_decay=1e-4) 10 | ) 11 | 
data = dict(train_batch=8) 12 | -------------------------------------------------------------------------------- /configs/_base_/datasets/nuaa.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | data = dict( 3 | dataset_type='NUAA', 4 | data_root='/data1/ppw/works/All_ISTD/datasets/NUAA', 5 | base_size=512, 6 | crop_size=512, 7 | data_aug=True, 8 | suffix='png', 9 | num_workers=8, 10 | train_batch=8, 11 | test_batch=8, 12 | train_dir='trainval', 13 | test_dir='test' 14 | ) 15 | -------------------------------------------------------------------------------- /configs/_base_/datasets/nudt.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | data = dict( 3 | dataset_type='NUDT', 4 | data_root='/data1/ppw/works/All_ISTD/datasets/NUDT', 5 | base_size=256, 6 | crop_size=256, 7 | data_aug=True, 8 | suffix='png', 9 | num_workers=8, 10 | train_batch=32, 11 | test_batch=32, 12 | train_dir='trainval', 13 | test_dir='test' 14 | ) 15 | -------------------------------------------------------------------------------- /configs/acm/acm_res20_fpn_256x256_500e_nudt.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/datasets/nudt.py', 3 | '../_base_/default_runtime.py', 4 | '../_base_/schedules/schedule_500e.py', 5 | '../_base_/models/acm.py' 6 | ] 7 | optimizer = dict( 8 | type='Adagrad', 9 | setting=dict(lr=0.05, weight_decay=1e-4) 10 | ) 11 | runner = dict(type='EpochBasedRunner', max_epochs=1500) 12 | -------------------------------------------------------------------------------- /configs/_base_/datasets/irstd1k.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | data = dict( 3 | dataset_type='IRSTD1k', 4 | data_root='/data1/ppw/works/All_ISTD/datasets/IRSTD-1k', 5 | base_size=512, 6 | crop_size=512, 7 | data_aug=True, 8 | suffix='png', 9 | num_workers=8, 10 | train_batch=8, 11 | test_batch=8, 12 | train_dir='trainval', 13 | test_dir='test' 14 | ) 15 | -------------------------------------------------------------------------------- /configs/acm/acm_res20_fpn_256x256_300e_sirstaug.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/datasets/sirstaug.py', 3 | '../_base_/default_runtime.py', 4 | '../_base_/schedules/schedule_500e.py', 5 | '../_base_/models/acm.py' 6 | ] 7 | optimizer = dict( 8 | type='Adagrad', 9 | setting=dict(lr=0.05, weight_decay=1e-4) 10 | ) 11 | runner = dict(type='EpochBasedRunner', max_epochs=300) 12 | -------------------------------------------------------------------------------- /configs/uranet/uranet_base_512x512_500e_irstd1k.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/datasets/irstd1k.py', 3 | '../_base_/default_runtime.py', 4 | '../_base_/schedules/schedule_500e.py', 5 | '../_base_/models/uranet.py' 6 | ] 7 | 8 | optimizer = dict( 9 | type='AdamW', 10 | setting=dict(lr=0.001, weight_decay=1e-4, betas=(0.9, 0.999)) 11 | ) 12 | data = dict(train_batch=8) 13 | -------------------------------------------------------------------------------- /configs/_base_/datasets/sirstaug.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | data = dict( 3 | dataset_type='SIRSTAUG', 4 | data_root='/data1/ppw/works/All_ISTD/datasets/SIRST_AUG', 5 | base_size=256, 6 | 
crop_size=256, 7 | data_aug=True, 8 | suffix='png', 9 | num_workers=8, 10 | train_batch=32, 11 | test_batch=32, 12 | train_dir='trainval', 13 | test_dir='test' 14 | ) 15 | -------------------------------------------------------------------------------- /configs/dnanet/dnanet_res18_256x256_800e_nudt.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/datasets/nudt.py', 3 | '../_base_/default_runtime.py', 4 | '../_base_/schedules/schedule_500e.py', 5 | '../_base_/models/dnanet.py' 6 | ] 7 | optimizer = dict( 8 | type='SGD', 9 | setting=dict(lr=0.05, momentum=0.9, weight_decay=0.0005) 10 | ) 11 | runner = dict(type='EpochBasedRunner', max_epochs=1500) 12 | -------------------------------------------------------------------------------- /configs/agpcnet/agpcnet_res18_256x256_800e_nudt.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/datasets/nudt.py', 3 | '../_base_/default_runtime.py', 4 | '../_base_/schedules/schedule_500e.py', 5 | '../_base_/models/agpc.py' 6 | ] 7 | 8 | optimizer = dict( 9 | type='SGD', 10 | setting=dict(lr=0.05, momentum=0.9, weight_decay=0.0005) 11 | ) 12 | runner = dict(type='EpochBasedRunner', max_epochs=1500) 13 | -------------------------------------------------------------------------------- /configs/agpcnet/agpcnet_res18_512x512_800e_nuaa.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/datasets/nuaa.py', 3 | '../_base_/default_runtime.py', 4 | '../_base_/schedules/schedule_500e.py', 5 | '../_base_/models/agpc.py' 6 | ] 7 | 8 | optimizer = dict( 9 | type='SGD', 10 | setting=dict(lr=0.05, momentum=0.9, weight_decay=0.0005) 11 | ) 12 | runner = dict(type='EpochBasedRunner', max_epochs=800) 13 | -------------------------------------------------------------------------------- /configs/acm/acm_res20_fpn_512x512_800e_nuaa.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/datasets/nuaa.py', 3 | '../_base_/default_runtime.py', 4 | '../_base_/schedules/schedule_500e.py', 5 | '../_base_/models/acm.py' 6 | ] 7 | optimizer = dict( 8 | type='Adagrad', 9 | setting=dict(lr=0.05, weight_decay=1e-4) 10 | ) 11 | runner = dict(type='EpochBasedRunner', max_epochs=800) 12 | data = dict(train_batch=8) 13 | -------------------------------------------------------------------------------- /configs/uranet/uranet_base_256x256_500e_nudt.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/datasets/nudt.py', 3 | '../_base_/default_runtime.py', 4 | '../_base_/schedules/schedule_500e.py', 5 | '../_base_/models/uranet.py' 6 | ] 7 | 8 | optimizer = dict( 9 | type='AdamW', 10 | setting=dict(lr=0.0001, weight_decay=0.01, betas=(0.9, 0.999)) 11 | ) 12 | runner = dict(type='EpochBasedRunner', max_epochs=1500) 13 | -------------------------------------------------------------------------------- /configs/_base_/models/acm.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | name='ACM', 3 | type='EncoderDecoder', 4 | pretrained=None, 5 | backbone=dict( 6 | type=None, 7 | type_info='resnet' 8 | ), 9 | decode_head=dict( 10 | type='ASKCResNetFPN', 11 | layer_blocks=[4, 4, 4], 12 | channels=[8, 16, 32, 64], 13 | fuse_model='AsymBi' 14 | ), 15 | loss=dict(type='SoftIoULoss') 16 | ) 17 | 
-------------------------------------------------------------------------------- /configs/_base_/models/abcnet.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | name='ABCNet', 3 | type='EncoderDecoder', 4 | pretrained=None, 5 | backbone=dict( 6 | type=None 7 | ), 8 | decode_head=dict( 9 | type='ABCNet', 10 | in_ch=3, 11 | out_ch=1, 12 | dim=64, # in dim 13 | ori_h=256, # image height == width 14 | deep_supervision=True 15 | ), 16 | loss=dict(type='SoftIoULoss') 17 | ) 18 | -------------------------------------------------------------------------------- /configs/uranet/uranet_base_512x512_800e_nuaa.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/datasets/nuaa.py', 3 | '../_base_/default_runtime.py', 4 | '../_base_/schedules/schedule_500e.py', 5 | '../_base_/models/uranet.py' 6 | ] 7 | 8 | optimizer = dict( 9 | type='AdamW', 10 | setting=dict(lr=0.001, weight_decay=0.01, betas=(0.9, 0.999)) 11 | ) 12 | runner = dict(type='EpochBasedRunner', max_epochs=800) 13 | data = dict(train_batch=8) 14 | -------------------------------------------------------------------------------- /configs/uranet/uranet_base_256x256_300e_sirstaug.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/datasets/sirstaug.py', 3 | '../_base_/default_runtime.py', 4 | '../_base_/schedules/schedule_500e.py', 5 | '../_base_/models/uranet.py' 6 | ] 7 | 8 | optimizer = dict( 9 | type='AdamW', 10 | setting=dict(lr=0.001, weight_decay=0.01, betas=(0.9, 0.999)) 11 | ) 12 | runner = dict(type='EpochBasedRunner', max_epochs=300) 13 | data = dict(train_batch=64) 14 | -------------------------------------------------------------------------------- /configs/dnanet/dnanet_res18_512x512_500e_irstd1k.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/datasets/irstd1k.py', 3 | '../_base_/default_runtime.py', 4 | '../_base_/schedules/schedule_500e.py', 5 | '../_base_/models/dnanet.py' 6 | ] 7 | optimizer = dict( 8 | type='SGD', 9 | setting=dict(lr=0.05, momentum=0.9, weight_decay=0.0005) 10 | ) 11 | data = dict(train_batch=8) 12 | runner = dict(type='EpochBasedRunner', max_epochs=1000) 13 | random_seed = 64 14 | -------------------------------------------------------------------------------- /configs/agpcnet/agpcnet_res34_512x512_500e_irstd1k.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/datasets/irstd1k.py', 3 | '../_base_/default_runtime.py', 4 | '../_base_/schedules/schedule_500e.py', 5 | '../_base_/models/agpc.py' 6 | ] 7 | model = dict( 8 | decode_head=dict( 9 | backbone='resnet34') 10 | ) 11 | 12 | optimizer = dict( 13 | type='SGD', 14 | setting=dict(lr=0.05, momentum=0.9, weight_decay=0.0005) 15 | ) 16 | data = dict(train_batch=4) 17 | -------------------------------------------------------------------------------- /configs/_base_/default_runtime.py: -------------------------------------------------------------------------------- 1 | # yapf:disable 2 | log_config = dict( 3 | interval=10, 4 | hooks=[ 5 | dict(type='TextLoggerHook', by_epoch=True), 6 | dict(type='TensorboardLoggerHook') 7 | ]) 8 | # yapf:enable 9 | dist_params = dict(backend='nccl') 10 | log_level = 'INFO' 11 | load_from = None 12 | resume_from = None 13 | workflow = [('train', 1)] 14 | cudnn_benchmark = True 15 | find_unused_parameters = False 16 | 
random_seed = 42 17 | gpus = 1 18 | -------------------------------------------------------------------------------- /configs/abcnet/abcnet_clft-s_256x256_500e_sirstaug.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/datasets/sirstaug.py', 3 | '../_base_/default_runtime.py', 4 | '../_base_/schedules/schedule_500e.py', 5 | '../_base_/models/abcnet.py' 6 | ] 7 | model = dict( 8 | decode_head=dict( 9 | dim=16, 10 | ) 11 | ) 12 | 13 | optimizer = dict( 14 | type='AdamW', 15 | setting=dict(lr=0.0001, weight_decay=0.01, betas=(0.9, 0.999)) 16 | ) 17 | data = dict(train_batch=32) 18 | -------------------------------------------------------------------------------- /configs/abcnet/abcnet_clft-s_512x512_500e_irstd1k.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/datasets/irstd1k.py', 3 | '../_base_/default_runtime.py', 4 | '../_base_/schedules/schedule_500e.py', 5 | '../_base_/models/abcnet.py' 6 | ] 7 | model = dict( 8 | decode_head=dict( 9 | dim=16, 10 | ori_h=512 11 | ) 12 | ) 13 | 14 | optimizer = dict( 15 | type='AdamW', 16 | setting=dict(lr=0.0001, weight_decay=0.01, betas=(0.9, 0.999)) 17 | ) 18 | data = dict(train_batch=32) 19 | -------------------------------------------------------------------------------- /configs/_base_/models/uranet.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | model = dict( 3 | name='UraNet_ffc', 4 | type='EncoderDecoder', 5 | pretrained=None, 6 | backbone=dict( 7 | type=None, 8 | type_info='resnet', 9 | ), 10 | decode_head=dict( 11 | type='URANet', 12 | in_channel=3, 13 | base_dim=32, 14 | class_num=1, 15 | bilinear=True, 16 | use_da=True, 17 | theta=0.7 18 | ), 19 | loss=dict(type='SoftIoULoss') 20 | ) 21 | -------------------------------------------------------------------------------- /configs/_base_/models/agpc.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | model = dict( 3 | name='AGPCNet', 4 | type='EncoderDecoder', 5 | pretrained=None, 6 | backbone=dict( 7 | type=None, 8 | type_info='resnet', 9 | ), 10 | decode_head=dict( 11 | type='AGPCNet', 12 | backbone='resnet18', 13 | scalse=[10, 6, 5, 3], 14 | reduce_ratios=[16, 4], 15 | gca_type='patch', 16 | gca_att='post', 17 | drop=0.1), 18 | loss=dict(type='SoftIoULoss') 19 | ) 20 | -------------------------------------------------------------------------------- /configs/_base_/models/dnanet.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | model = dict( 3 | name='DNANet', 4 | type='EncoderDecoder', 5 | pretrained=None, 6 | backbone=dict( 7 | type=None, 8 | type_info='resnet', 9 | ), 10 | decode_head=dict( 11 | type='DNANet', 12 | num_classes=1, 13 | input_channels=3, 14 | block_name='resnet', 15 | num_blocks=[2, 2, 2, 2], 16 | nb_filter=[16, 32, 64, 128, 256] 17 | ), 18 | loss=dict(type='SoftIoULoss') 19 | ) 20 | -------------------------------------------------------------------------------- /configs/agpcnet/agpcnet_res34_256x256_500e_nudt.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/datasets/nudt.py', 3 | '../_base_/default_runtime.py', 4 | '../_base_/schedules/schedule_500e.py', 5 | '../_base_/models/agpc.py' 6 | ] 7 | model = dict( 8 | decode_head=dict( 9 | backbone='resnet34') 10 | ) 11 | 12 | optimizer = dict( 13 | 
type='SGD', 14 | setting=dict(lr=0.05, momentum=0.9, weight_decay=0.0005) 15 | ) 16 | runner = dict(type='EpochBasedRunner', max_epochs=3000) 17 | data = dict(train_batch=16) 18 | -------------------------------------------------------------------------------- /configs/_base_/schedules/schedule_500e.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict( 3 | type='SGD', 4 | setting=dict(lr=0.01, weight_decay=0.0005) 5 | ) 6 | optimizer_config = dict() 7 | # learning policy 8 | # TODO warmup only 'linear' 9 | lr_config = dict(policy='PolyLR', warmup='linear', power=0.9, min_lr=1e-4, warmup_epochs=5) 10 | # runtime settings 11 | runner = dict(type='EpochBasedRunner', max_epochs=500) 12 | checkpoint_config = dict(by_epoch=True, interval=1) 13 | evaluation = dict(epochval=1, metric='mIoU', pre_eval=True) 14 | -------------------------------------------------------------------------------- /configs/abcnet/abcnet_clft-s_256x256_1500e_nudt.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/datasets/nudt.py', 3 | '../_base_/default_runtime.py', 4 | '../_base_/schedules/schedule_500e.py', 5 | '../_base_/models/abcnet.py' 6 | ] 7 | model = dict( 8 | decode_head=dict( 9 | dim=16, 10 | ) 11 | ) 12 | 13 | optimizer = dict( 14 | type='AdamW', 15 | setting=dict(lr=0.0001, weight_decay=0.01, betas=(0.9, 0.999)) 16 | ) 17 | data = dict(train_batch=32) 18 | runner = dict(type='EpochBasedRunner', max_epochs=1500) 19 | -------------------------------------------------------------------------------- /configs/agpcnet/agpcnet_res34_256x256_300e_sirstaug.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/datasets/sirstaug.py', 3 | '../_base_/default_runtime.py', 4 | '../_base_/schedules/schedule_500e.py', 5 | '../_base_/models/agpc.py' 6 | ] 7 | model = dict( 8 | decode_head=dict( 9 | backbone='resnet34') 10 | ) 11 | 12 | optimizer = dict( 13 | type='SGD', 14 | setting=dict(lr=0.05, momentum=0.9, weight_decay=0.0005) 15 | ) 16 | runner = dict(type='EpochBasedRunner', max_epochs=300) 17 | data = dict(train_batch=16) 18 | -------------------------------------------------------------------------------- /configs/abcnet/abcnet_clft-s_512x512_1500e_nuaa.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/datasets/nuaa.py', 3 | '../_base_/default_runtime.py', 4 | '../_base_/schedules/schedule_500e.py', 5 | '../_base_/models/abcnet.py' 6 | ] 7 | 8 | model = dict( 9 | decode_head=dict( 10 | dim=16, 11 | ori_h=512 12 | ) 13 | ) 14 | 15 | optimizer = dict( 16 | type='AdamW', 17 | setting=dict(lr=0.0003, weight_decay=0.01, betas=(0.9, 0.999)) 18 | ) 19 | runner = dict(type='EpochBasedRunner', max_epochs=1500) 20 | data = dict(train_batch=32) 21 | -------------------------------------------------------------------------------- /configs/rdian/rdian_base_512x512_800e_nuaa.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/datasets/nuaa.py', 3 | '../_base_/default_runtime.py', 4 | '../_base_/schedules/schedule_500e.py' 5 | ] 6 | 7 | model = dict( 8 | name='Segformer', 9 | type='EncoderDecoder', 10 | pretrained=None, 11 | backbone=dict( 12 | type=None, 13 | ), 14 | decode_head=dict( 15 | type='RDIAN' 16 | ), 17 | loss=dict(type='SoftIoULoss') 18 | ) 19 | 20 | optimizer = dict( 21 | type='Adagrad', 22 | 
setting=dict(lr=0.005, weight_decay=1e-4) 23 | ) 24 | runner = dict(type='EpochBasedRunner', max_epochs=800) 25 | data = dict( 26 | train_batch=8, 27 | test_batch=8, 28 | rgb=False) -------------------------------------------------------------------------------- /configs/mtuet/mtuuet_base_512x512_1500e_nudt.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/datasets/nudt.py', 3 | '../_base_/default_runtime.py', 4 | '../_base_/schedules/schedule_500e.py' 5 | ] 6 | 7 | model = dict( 8 | name='Segformer', 9 | type='EncoderDecoder', 10 | pretrained=None, 11 | backbone=dict( 12 | type=None, 13 | ), 14 | decode_head=dict( 15 | type='res_UNet' 16 | ), 17 | loss=dict(type='SoftIoULoss') 18 | ) 19 | 20 | optimizer = dict( 21 | type='AdamW', 22 | setting=dict(lr=0.0001, weight_decay=0.01, betas=(0.9, 0.999)) 23 | ) 24 | runner = dict(type='EpochBasedRunner', max_epochs=1500) 25 | data = dict( 26 | train_batch=32, 27 | test_batch=32) -------------------------------------------------------------------------------- /configs/mtuet/mtuuet_base_512x512_800e_nuaa.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/datasets/nuaa.py', 3 | '../_base_/default_runtime.py', 4 | '../_base_/schedules/schedule_500e.py' 5 | ] 6 | 7 | model = dict( 8 | name='Segformer', 9 | type='EncoderDecoder', 10 | pretrained=None, 11 | backbone=dict( 12 | type=None, 13 | ), 14 | decode_head=dict( 15 | type='res_UNet' 16 | ), 17 | loss=dict(type='SoftIoULoss') 18 | ) 19 | 20 | optimizer = dict( 21 | type='AdamW', 22 | setting=dict(lr=0.0001, weight_decay=0.01, betas=(0.9, 0.999)) 23 | ) 24 | runner = dict(type='EpochBasedRunner', max_epochs=800) 25 | data = dict( 26 | train_batch=16, 27 | test_batch=16) -------------------------------------------------------------------------------- /configs/mtuet/mtuuet_base_512x512_300e_srstaug.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/datasets/sirstaug.py', 3 | '../_base_/default_runtime.py', 4 | '../_base_/schedules/schedule_500e.py' 5 | ] 6 | 7 | model = dict( 8 | name='Segformer', 9 | type='EncoderDecoder', 10 | pretrained=None, 11 | backbone=dict( 12 | type=None, 13 | ), 14 | decode_head=dict( 15 | type='res_UNet' 16 | ), 17 | loss=dict(type='SoftIoULoss') 18 | ) 19 | 20 | optimizer = dict( 21 | type='AdamW', 22 | setting=dict(lr=0.0001, weight_decay=0.01, betas=(0.9, 0.999)) 23 | ) 24 | runner = dict(type='EpochBasedRunner', max_epochs=300) 25 | data = dict( 26 | train_batch=32, 27 | test_batch=32) -------------------------------------------------------------------------------- /configs/mtuet/mtuuet_base_512x512_500e_istd1k.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/datasets/irstd1k.py', 3 | '../_base_/default_runtime.py', 4 | '../_base_/schedules/schedule_500e.py' 5 | ] 6 | 7 | model = dict( 8 | name='Segformer', 9 | type='EncoderDecoder', 10 | pretrained=None, 11 | backbone=dict( 12 | type=None, 13 | ), 14 | decode_head=dict( 15 | type='res_UNet' 16 | ), 17 | loss=dict(type='SoftIoULoss') 18 | ) 19 | 20 | optimizer = dict( 21 | type='AdamW', 22 | setting=dict(lr=0.0001, weight_decay=0.01, betas=(0.9, 0.999)) 23 | ) 24 | runner = dict(type='EpochBasedRunner', max_epochs=800) 25 | data = dict( 26 | train_batch=16, 27 | test_batch=16) -------------------------------------------------------------------------------- 
/configs/rdian/rdian_base_256x256_1500e_nudt.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/datasets/nudt.py', 3 | '../_base_/default_runtime.py', 4 | '../_base_/schedules/schedule_500e.py' 5 | ] 6 | 7 | 8 | model = dict( 9 | name='Segformer', 10 | type='EncoderDecoder', 11 | pretrained=None, 12 | backbone=dict( 13 | type=None, 14 | ), 15 | decode_head=dict( 16 | type='RDIAN' 17 | ), 18 | loss=dict(type='SoftIoULoss') 19 | ) 20 | 21 | optimizer = dict( 22 | type='Adagrad', 23 | setting=dict(lr=0.005, weight_decay=1e-4) 24 | ) 25 | runner = dict(type='EpochBasedRunner', max_epochs=1500) 26 | data = dict( 27 | train_batch=16, 28 | test_batch=16, 29 | rgb=False) -------------------------------------------------------------------------------- /configs/rdian/rdian_base_256x256_300e_sirstaug.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/datasets/sirstaug.py', 3 | '../_base_/default_runtime.py', 4 | '../_base_/schedules/schedule_500e.py' 5 | ] 6 | 7 | model = dict( 8 | name='Segformer', 9 | type='EncoderDecoder', 10 | pretrained=None, 11 | backbone=dict( 12 | type=None, 13 | ), 14 | decode_head=dict( 15 | type='RDIAN' 16 | ), 17 | loss=dict(type='SoftIoULoss') 18 | ) 19 | 20 | optimizer = dict( 21 | type='Adagrad', 22 | setting=dict(lr=0.005, weight_decay=1e-4) 23 | ) 24 | runner = dict(type='EpochBasedRunner', max_epochs=300) 25 | data = dict( 26 | train_batch=16, 27 | test_batch=16, 28 | rgb=False) -------------------------------------------------------------------------------- /configs/rdian/rdian_base_512x512_500e_irstd1k.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/datasets/irstd1k.py', 3 | '../_base_/default_runtime.py', 4 | '../_base_/schedules/schedule_500e.py' 5 | ] 6 | 7 | model = dict( 8 | name='Segformer', 9 | type='EncoderDecoder', 10 | pretrained=None, 11 | backbone=dict( 12 | type=None, 13 | ), 14 | decode_head=dict( 15 | type='RDIAN' 16 | ), 17 | loss=dict(type='SoftIoULoss') 18 | ) 19 | 20 | optimizer = dict( 21 | type='Adagrad', 22 | setting=dict(lr=0.005, weight_decay=1e-4) 23 | ) 24 | runner = dict(type='EpochBasedRunner', max_epochs=500) 25 | data = dict( 26 | train_batch=8, 27 | test_batch=8, 28 | rgb=False) 29 | -------------------------------------------------------------------------------- /configs/uiunet/uiunet_base_256x256_1500e_nudt.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/datasets/nudt.py', 3 | '../_base_/default_runtime.py', 4 | '../_base_/schedules/schedule_500e.py' 5 | ] 6 | 7 | model = dict( 8 | name='Segformer', 9 | type='EncoderDecoder', 10 | pretrained=None, 11 | backbone=dict( 12 | type=None, 13 | ), 14 | decode_head=dict( 15 | type='UIUNet', 16 | deep_supervision=True 17 | ), 18 | loss=dict(type='SoftIoULoss') 19 | ) 20 | 21 | optimizer = dict( 22 | type='Adam', 23 | setting=dict(lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0) 24 | ) 25 | runner = dict(type='EpochBasedRunner', max_epochs=1500) 26 | data = dict( 27 | train_batch=8, 28 | test_batch=8) 29 | -------------------------------------------------------------------------------- /configs/uiunet/uiunet_base_512x512_800e_nuaa.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/datasets/nuaa.py', 3 | '../_base_/default_runtime.py', 4 | 
'../_base_/schedules/schedule_500e.py' 5 | ] 6 | 7 | model = dict( 8 | name='Segformer', 9 | type='EncoderDecoder', 10 | pretrained=None, 11 | backbone=dict( 12 | type=None, 13 | ), 14 | decode_head=dict( 15 | type='UIUNet', 16 | deep_supervision=True 17 | ), 18 | loss=dict(type='SoftIoULoss') 19 | ) 20 | 21 | optimizer = dict( 22 | type='Adam', 23 | setting=dict(lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0) 24 | ) 25 | runner = dict(type='EpochBasedRunner', max_epochs=800) 26 | data = dict( 27 | train_batch=4, 28 | test_batch=4) 29 | -------------------------------------------------------------------------------- /configs/uiunet/uiunet_base_256x256_300e_sirstaug.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/datasets/sirstaug.py', 3 | '../_base_/default_runtime.py', 4 | '../_base_/schedules/schedule_500e.py' 5 | ] 6 | 7 | model = dict( 8 | name='Segformer', 9 | type='EncoderDecoder', 10 | pretrained=None, 11 | backbone=dict( 12 | type=None, 13 | ), 14 | decode_head=dict( 15 | type='UIUNet', 16 | deep_supervision=True 17 | ), 18 | loss=dict(type='SoftIoULoss') 19 | ) 20 | 21 | optimizer = dict( 22 | type='Adam', 23 | setting=dict(lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0) 24 | ) 25 | runner = dict(type='EpochBasedRunner', max_epochs=300) 26 | data = dict( 27 | train_batch=8, 28 | test_batch=8) 29 | -------------------------------------------------------------------------------- /configs/uiunet/uiunet_base_512x512_500e_irst1k.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/datasets/irstd1k.py', 3 | '../_base_/default_runtime.py', 4 | '../_base_/schedules/schedule_500e.py' 5 | ] 6 | 7 | model = dict( 8 | name='Segformer', 9 | type='EncoderDecoder', 10 | pretrained=None, 11 | backbone=dict( 12 | type=None, 13 | ), 14 | decode_head=dict( 15 | type='UIUNet', 16 | deep_supervision=True 17 | ), 18 | loss=dict(type='SoftIoULoss') 19 | ) 20 | 21 | optimizer = dict( 22 | type='Adam', 23 | setting=dict(lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0) 24 | ) 25 | runner = dict(type='EpochBasedRunner', max_epochs=500) 26 | data = dict( 27 | train_batch=4, 28 | test_batch=4) 29 | -------------------------------------------------------------------------------- /configs/dnanet/dnanet_res10_512x512_800e_nuaa.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/datasets/nuaa.py', 3 | '../_base_/default_runtime.py', 4 | '../_base_/schedules/schedule_500e.py' 5 | ] 6 | # model settings 7 | model = dict( 8 | name='DNANet', 9 | type='EncoderDecoder', 10 | pretrained=None, 11 | backbone=dict( 12 | type=None, 13 | type_info='resnet', 14 | ), 15 | decode_head=dict( 16 | type='DNANet', 17 | num_classes=1, 18 | input_channels=3, 19 | block_name='resnet', 20 | num_blocks=[1, 1, 1, 1], 21 | nb_filter=[16, 32, 64, 128, 256] 22 | ), 23 | loss=dict(type='SoftIoULoss') 24 | ) 25 | optimizer = dict( 26 | type='Adagrad', 27 | setting=dict(lr=0.05, weight_decay=1e-4) 28 | ) 29 | runner = dict(type='EpochBasedRunner', max_epochs=800) 30 | -------------------------------------------------------------------------------- /configs/_base_/models/mscan.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | ham_norm_cfg = dict(type='GN', num_groups=32, requires_grad=True) 4 | model = dict( 5 | 
type='EncoderDecoder', 6 | pretrained=None, 7 | backbone=dict( 8 | type='MSCAN', 9 | embed_dims=[32, 64, 160, 256], 10 | mlp_ratios=[8, 8, 4, 4], 11 | drop_rate=0.0, 12 | drop_path_rate=0.1, 13 | depths=[3, 3, 5, 2], 14 | norm_cfg=dict(type='SyncBN', requires_grad=True)), 15 | decode_head=dict( 16 | type='LightHamHead', 17 | in_channels=[64, 160, 256], 18 | in_index=[1, 2, 3], 19 | channels=256, 20 | ham_channels=256, 21 | dropout_ratio=0.1, 22 | num_classes=19, 23 | norm_cfg=ham_norm_cfg, 24 | align_corners=False), 25 | loss=dict(type='SoftIoULoss'), 26 | ) 27 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 PANPEIWEN 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /docs/get_started.md: -------------------------------------------------------------------------------- 1 | ## Installation 2 | Our config mechanism is based on MMCV, so the MMCV family of packages needs to be installed first. 3 | ``` 4 | # Python == 3.8 5 | # PyTorch == 1.10 6 | # CUDA == 11.1 7 | 8 | conda create -n Infrared python=3.8 9 | conda install pytorch==1.10.0 torchvision==0.11.1 torchaudio==0.10.0 cudatoolkit=11.1 -c pytorch -c nvidia 10 | 11 | pip install -U openmim 12 | mim install mmcv-full==1.4.0 13 | mim install mmdet==2.25.0 14 | mim install mmsegmentation==0.28.0 15 | ``` 16 | 17 | After installation, if any other package turns out to be missing at runtime, you can install it directly with pip. 18 | ## Dataset Preparation 19 | 20 | ### File Structure 21 | ``` 22 | |-datasets 23 | |-NUAA 24 | |-trainval 25 | |-images 26 | |-Misc_1.png 27 | ... 28 | |-masks 29 | |-Misc_1.png 30 | ... 31 | |-test 32 | |-images 33 | |-Misc_1.png 34 | ... 35 | |-masks 36 | |-Misc_1.png 37 | ... 38 | |-IRSTD 39 | ...
40 | ``` 41 | 42 | ### Datasets Link 43 | 44 | https://drive.google.com/drive/folders/1RGpVHccGb8B4_spX_RZPEMW9pyeXOIaC?usp=sharing 45 | -------------------------------------------------------------------------------- /utils/drawing.py: -------------------------------------------------------------------------------- 1 | # @Time : 2022/9/14 19:57 2 | # @Author : PEIWEN PAN 3 | # @Email : 121106022690@njust.edu.cn 4 | # @File : drawing.py 5 | # @Software: PyCharm 6 | import matplotlib.pyplot as plt 7 | 8 | 9 | def drawing_loss(num_epoch, train_loss, test_loss, save_dir, curve_file): 10 | plt.figure() 11 | plt.plot(num_epoch, train_loss, label='train_loss') 12 | plt.plot(num_epoch, test_loss, label='test_loss') 13 | plt.legend() 14 | plt.ylabel('Loss') 15 | plt.xlabel('Epoch') 16 | plt.savefig("work_dirs/" + save_dir + '/' + curve_file + "/fig_loss.png") 17 | 18 | 19 | def drawing_iou(num_epoch, mIoU, nIoU, save_dir, curve_file): 20 | plt.figure() 21 | plt.plot(num_epoch, mIoU, label='mIoU') 22 | plt.plot(num_epoch, nIoU, label='nIoU') 23 | plt.legend() 24 | plt.ylabel('IoU') 25 | plt.xlabel('Epoch') 26 | plt.savefig("work_dirs/" + save_dir + '/' + curve_file + "/fig_IoU.png") 27 | 28 | 29 | def drawing_f1(num_epoch, f1, save_dir, curve_file): 30 | plt.figure() 31 | plt.plot(num_epoch, f1, label='F1-score') 32 | plt.legend() 33 | plt.ylabel('F1-score') 34 | plt.xlabel('Epoch') 35 | plt.savefig("work_dirs/" + save_dir + '/' + curve_file + "/fig_F1-score.png") 36 | -------------------------------------------------------------------------------- /model/build_segmentor.py: -------------------------------------------------------------------------------- 1 | # @Time : 2022/9/22 17:02 2 | # @Author : PEIWEN PAN 3 | # @Email : 121106022690@njust.edu.cn 4 | # @File : build_segmentor.py 5 | # @Software: PyCharm 6 | import torch.nn as nn 7 | from model.AGPCNet.agpc import AGPCNet, AGPCNet_Pro 8 | from model.ACM.acm import ASKCResNetFPN, ASKCResUNet 9 | from model.DNANet.dna_net import DNANet 10 | from model.URANet.uranet import URANet 11 | from model.ABC.ABCNet import ABCNet 12 | from model.RDIAN.rdian import RDIAN 13 | from model.MTUet.mtu_uet import MTUNet 14 | from model.UIUNet.uiunet import UIUNet 15 | 16 | 17 | __all__ = ['Model', 'AGPCNet', 'AGPCNet_Pro', 'ASKCResUNet', 'ASKCResNetFPN', 'DNANet', 'URANet', 'ABCNet', 'RDIAN', 'MTUNet', 'UIUNet'] 18 | 19 | 20 | class Model(nn.Module): 21 | def __init__(self, cfg): 22 | super(Model, self).__init__() 23 | backbone_name = cfg.model['backbone']['type'] if cfg.model['backbone']['type'] else None 24 | decode_name = cfg.model['decode_head']['type'] 25 | backbone_class = globals()[backbone_name] if backbone_name else None # resolve the class object by name from the imports above 26 | decode_class = globals()[decode_name] 27 | self.backbone = backbone_class(**cfg.model['backbone']) if backbone_name else None 28 | self.decode_head = decode_class(**cfg.model['decode_head']) 29 | 30 | def forward(self, x): 31 | if self.backbone: 32 | x = self.backbone(x) 33 | out = self.decode_head(x) 34 | return out 35 | -------------------------------------------------------------------------------- /docs/add_dataset.md: -------------------------------------------------------------------------------- 1 | ## Add Custom Dataset 2 | 3 | You need to follow the process below to add a custom dataset. 4 | 5 | ### Dataset Preparation 6 | 7 | Please refer 8 | to [get_started.md](https://github.com/PANPEIWEN/Infrared-Small-Target-Segmentation-Framework/blob/main/docs/get_started.md) 9 | for dataset preparation.
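As a quick reminder, a custom dataset is expected to mirror the layout described there; the split folder names map to `train_dir`/`test_dir` in the config below (a sketch with a placeholder dataset name):

```
|-datasets
    |-YOUR_DATASET
        |-trainval
            |-images
            |-masks
        |-test
            |-images
            |-masks
```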
10 | 11 | ### Add Dataset Config File 12 | 13 | 1. Create a config file named _your_dataset_name.py_ in 14 | the [configs/\_base\_/datasets](https://github.com/PANPEIWEN/Infrared-Small-Target-Segmentation-Framework/blob/main/configs/_base_/datasets) 15 | folder. 16 | 2. Config code specification: 17 | 18 | ```python 19 | data = dict( 20 | # For identification, no practical use 21 | dataset_type='NUAA', 22 | # Your dataset path 23 | data_root='/data1/ppw/works/All_ISTD/datasets/NUAA', 24 | # The size images are resized to; this is the training and testing image size if data_aug=False 25 | base_size=512, 26 | # The size images are cropped to; this is the training and testing image size if data_aug=True 27 | crop_size=512, 28 | # Whether to use data augmentation; more augmentation options will be added later 29 | data_aug=True, 30 | # File suffix of the data images 31 | suffix='png', 32 | # DataLoader num_workers 33 | num_workers=8, 34 | # Train batch size 35 | train_batch=16, 36 | # Test batch size 37 | test_batch=8, 38 | # The folder name where the training set is stored 39 | train_dir='trainval', 40 | # The folder name where the testing set is stored 41 | test_dir='test' 42 | ) 43 | ``` -------------------------------------------------------------------------------- /model/AGPCNet/fusion.py: -------------------------------------------------------------------------------- 1 | # @Time : 2022/5/18 17:25 2 | # @Author : PEIWEN PAN 3 | # @Email : 121106022690@njust.edu.cn 4 | # @File : fusion.py 5 | # @Software: PyCharm 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | __all__ = ['AsymFusionModule'] 12 | 13 | 14 | class AsymFusionModule(nn.Module): 15 | def __init__(self, planes_high, planes_low, planes_out): 16 | super(AsymFusionModule, self).__init__() 17 | self.pa = nn.Sequential( 18 | nn.Conv2d(planes_low, planes_low//4, kernel_size=1), 19 | nn.BatchNorm2d(planes_low//4), 20 | nn.ReLU(True), 21 | 22 | nn.Conv2d(planes_low//4, planes_low, kernel_size=1), 23 | nn.BatchNorm2d(planes_low), 24 | nn.Sigmoid(), 25 | ) 26 | self.plus_conv = nn.Sequential( 27 | nn.Conv2d(planes_high, planes_low, kernel_size=1), 28 | nn.BatchNorm2d(planes_low), 29 | nn.ReLU(True) 30 | ) 31 | self.ca = nn.Sequential( 32 | nn.AdaptiveAvgPool2d(1), 33 | nn.Conv2d(planes_low, planes_low//4, kernel_size=1), 34 | nn.BatchNorm2d(planes_low//4), 35 | nn.ReLU(True), 36 | 37 | nn.Conv2d(planes_low//4, planes_low, kernel_size=1), 38 | nn.BatchNorm2d(planes_low), 39 | nn.Sigmoid(), 40 | ) 41 | self.end_conv = nn.Sequential( 42 | nn.Conv2d(planes_low, planes_out, 3, 1, 1), 43 | nn.BatchNorm2d(planes_out), 44 | nn.ReLU(True), 45 | ) 46 | 47 | def forward(self, x_high, x_low): 48 | x_high = self.plus_conv(x_high) 49 | pa = self.pa(x_low) 50 | ca = self.ca(x_high) 51 | 52 | feat = x_low + x_high 53 | feat = self.end_conv(feat) 54 | feat = feat * ca 55 | feat = feat * pa 56 | return feat -------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | # @Time : 2022/4/6 14:58 2 | # @Author : PEIWEN PAN 3 | # @Email : 121106022690@njust.edu.cn 4 | # @File : loss.py 5 | # @Software: PyCharm 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | class SoftIoULoss(nn.Module): 11 | def __init__(self, **kwargs): 12 | super(SoftIoULoss, self).__init__() 13 | 14 | def forward(self, pred, target): 15 | # Old One 16 | pred = torch.sigmoid(pred) 17 | smooth
= 1 18 | 19 | # print("pred.shape: ", pred.shape) 20 | # print("target.shape: ", target.shape) 21 | 22 | intersection = pred * target 23 | loss = (intersection.sum() + smooth) / (pred.sum() + target.sum() - intersection.sum() + smooth) 24 | 25 | # loss = (intersection.sum(axis=(1, 2, 3)) + smooth) / \ 26 | # (pred.sum(axis=(1, 2, 3)) + target.sum(axis=(1, 2, 3)) 27 | # - intersection.sum(axis=(1, 2, 3)) + smooth) 28 | 29 | loss = 1 - loss.mean() 30 | # loss = (1 - loss).mean() 31 | 32 | return loss 33 | 34 | 35 | class CrossEntropy(nn.Module): 36 | def __init__(self, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean', 37 | label_smoothing=0.0, **kwargs): 38 | super(CrossEntropy, self).__init__() 39 | self.crit = nn.CrossEntropyLoss(weight, size_average, ignore_index, reduce, reduction, label_smoothing) 40 | 41 | def forward(self, pred, target): 42 | target = target.squeeze(dim=1) # squeeze() is not in-place, so keep the result 43 | loss = self.crit(pred, target) 44 | return loss 45 | 46 | 47 | class BCEWithLogits(nn.Module): 48 | def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean', pos_weight=None, **kwargs): 49 | super(BCEWithLogits, self).__init__() 50 | self.crit = nn.BCEWithLogitsLoss(weight, size_average, reduce, reduction, pos_weight) 51 | 52 | def forward(self, pred, target): 53 | loss = self.crit(pred, target) 54 | return loss 55 | -------------------------------------------------------------------------------- /utils/save_model.py: -------------------------------------------------------------------------------- 1 | # @Time : 2022/4/6 20:22 2 | # @Author : PEIWEN PAN 3 | # @Email : 121106022690@njust.edu.cn 4 | # @File : save_model.py 5 | # @Software: PyCharm 6 | from datetime import datetime 7 | 8 | import numpy as np 9 | import os 10 | import torch.nn as nn 11 | import torch 12 | from skimage import measure 13 | import numpy 14 | 15 | 16 | def make_dir(dataset, model): 17 | now = datetime.now() 18 | dt_string = now.strftime("%Y_%m_%d_%H_%M_%S") 19 | save_dir = "%s_%s_%s" % (dataset, model, dt_string) 20 | os.makedirs('work_dirs/%s' % save_dir, exist_ok=True) 21 | return save_dir 22 | 23 | 24 | def save_ckpt(state, save_path, filename): 25 | torch.save(state, os.path.join(save_path, filename)) 26 | 27 | 28 | def save_model_and_result(dt_string, epoch, train_loss, test_loss, best_iou, recall, precision, save_mIoU_dir, 29 | save_other_metric_dir): 30 | with open(save_mIoU_dir, 'a') as f: 31 | f.write('{} - {:04d}:\t - train_loss: {:04f}:\t - test_loss: {:04f}:\t mIoU {:.4f}\n'.format(dt_string, epoch, 32 | train_loss, 33 | test_loss, 34 | best_iou)) 35 | with open(save_other_metric_dir, 'a') as f: 36 | f.write(dt_string) 37 | f.write('-') 38 | f.write(str(epoch)) 39 | f.write('\n') 40 | f.write('Recall-----:') 41 | for i in range(len(recall)): 42 | f.write(' ') 43 | f.write(str(round(recall[i], 8))) 44 | f.write(' ') 45 | f.write('\n') 46 | f.write('Precision--:') 47 | for i in range(len(precision)): 48 | f.write(' ') 49 | f.write(str(round(precision[i], 8))) 50 | f.write(' ') 51 | f.write('\n') 52 | -------------------------------------------------------------------------------- /docs/add_loss.md: -------------------------------------------------------------------------------- 1 | ## Add Loss Function 2 | 3 | You need to follow the process below to add a loss function. 4 | 5 | ### Add or rewrite a loss function 6 | 7 | You need to perform these operations 8 | in [utils/loss.py](https://github.com/PANPEIWEN/Infrared-Small-Target-Segmentation-Framework/blob/main/utils/loss.py).
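Once a loss class exists in that file, any config selects it by class name, just as the shipped configs do (`YourLossName` below is a placeholder):

```python
# in a config file: select the loss by its class name
loss = dict(type='SoftIoULoss')  # or dict(type='YourLossName')
```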
9 | 10 | _Notice: Do not apply sigmoid or softmax twice, once in the model output layer and again in the loss function; as a rule, 11 | perform this operation only in the loss function._ 12 | 13 | 1. Add a custom loss function 14 | 15 | ```python 16 | class YourLossName(nn.Module): 17 | def __init__(self, args1, args2, ..., **kwargs): 18 | super(YourLossName, self).__init__() 19 | pass 20 | 21 | def forward(self, pred, target): 22 | pred = torch.sigmoid(pred) 23 | loss = ... 24 | pass 25 | return loss 26 | ``` 27 | 28 | 2. Add a loss function from PyTorch 29 | 30 | ```python 31 | """ 32 | We need to wrap the built-in PyTorch loss functions here. 33 | For example, we wrap nn.BCEWithLogitsLoss. 34 | """ 35 | # 'BCEWithLogits' is the new class name, you can also use the original name 'BCEWithLogitsLoss' 36 | class BCEWithLogits(nn.Module): 37 | # The parameters here need to be consistent with those required by nn.BCEWithLogitsLoss, and **kwargs is required. 38 | def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean', pos_weight=None, **kwargs): 39 | super(BCEWithLogits, self).__init__() 40 | # Pass the parameters in __init__ to nn.BCEWithLogitsLoss 41 | self.crit = nn.BCEWithLogitsLoss(weight, size_average, reduce, reduction, pos_weight) 42 | 43 | def forward(self, pred, target): 44 | # Maybe a softmax or sigmoid operation is required. 45 | # If the tensor dimensions do not match, squeeze or unsqueeze them here. 46 | # Calculate loss 47 | loss = self.crit(pred, target) 48 | return loss 49 | 50 | ``` 51 | 52 | 3. Add the loss function class name to \_\_all__ in 53 | the [build/build_criterion.py](https://github.com/PANPEIWEN/Infrared-Small-Target-Segmentation-Framework/blob/main/build/build_criterion.py) 54 | file.
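A sketch of that registration (hypothetical excerpt; build/build_criterion.py is not reproduced in this listing, but model/build_segmentor.py uses the same `__all__`-plus-`globals()` lookup pattern):

```python
# build/build_criterion.py (hypothetical excerpt)
__all__ = ['SoftIoULoss', 'CrossEntropy', 'BCEWithLogits', 'YourLossName']
```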
55 | 56 | _How to modify the config file to use loss function, please refer to [docs/add_model.md](https://github.com/PANPEIWEN/Infrared-Small-Target-Segmentation-Framework/blob/main/docs/add_model.md)._ -------------------------------------------------------------------------------- /model/ABC/Module.py: -------------------------------------------------------------------------------- 1 | # @Time : 2023/3/17 15:56 2 | # @Author : PEIWEN PAN 3 | # @Email : 121106022690@njust.edu.cn 4 | # @File : Module.py 5 | # @Software: PyCharm 6 | from __future__ import print_function, division 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.utils.data 10 | import torch 11 | 12 | 13 | class conv_block(nn.Module): 14 | """ 15 | Convolution Block 16 | """ 17 | 18 | def __init__(self, in_ch, out_ch): 19 | super(conv_block, self).__init__() 20 | 21 | self.conv = nn.Sequential( 22 | nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True), 23 | nn.BatchNorm2d(out_ch), 24 | nn.ReLU(inplace=True), 25 | nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True), 26 | nn.BatchNorm2d(out_ch), 27 | nn.ReLU(inplace=True)) 28 | 29 | def forward(self, x): 30 | x = self.conv(x) 31 | return x 32 | 33 | 34 | class up_conv(nn.Module): 35 | """ 36 | Up Convolution Block 37 | """ 38 | 39 | def __init__(self, in_ch, out_ch): 40 | super(up_conv, self).__init__() 41 | self.up = nn.Sequential( 42 | nn.Upsample(scale_factor=2, mode='bilinear'), 43 | nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True), 44 | nn.BatchNorm2d(out_ch), 45 | nn.ReLU(inplace=True) 46 | ) 47 | 48 | def forward(self, x): 49 | x = self.up(x) 50 | return x 51 | 52 | 53 | # self.active = torch.nn.Sigmoid() 54 | def _upsample_like(src, tar): 55 | src = F.upsample(src, size=tar.shape[2:], mode='bilinear') 56 | return src 57 | 58 | 59 | def conv_relu_bn(in_channel, out_channel, dirate): 60 | return nn.Sequential( 61 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, stride=1, padding=dirate, 62 | dilation=dirate), 63 | nn.BatchNorm2d(out_channel), 64 | nn.ReLU(inplace=True) 65 | ) 66 | 67 | 68 | class dconv_block(nn.Module): 69 | """ 70 | Convolution Block 71 | """ 72 | 73 | def __init__(self, in_ch, out_ch): 74 | super(dconv_block, self).__init__() 75 | self.conv1 = conv_relu_bn(in_ch, out_ch, 1) 76 | self.dconv1 = conv_relu_bn(out_ch, out_ch // 2, 2) 77 | self.dconv2 = conv_relu_bn(out_ch // 2, out_ch // 2, 4) 78 | self.dconv3 = conv_relu_bn(out_ch, out_ch, 2) 79 | self.conv2 = conv_relu_bn(out_ch * 2, out_ch, 1) 80 | 81 | def forward(self, x): 82 | x1 = self.conv1(x) 83 | dx1 = self.dconv1(x1) 84 | dx2 = self.dconv2(dx1) 85 | dx3 = self.dconv3(torch.cat((dx1, dx2), dim=1)) 86 | out = self.conv2(torch.cat((x1, dx3), dim=1)) 87 | return out 88 | -------------------------------------------------------------------------------- /utils/scheduler.py: -------------------------------------------------------------------------------- 1 | # @Time : 2022/5/31 17:19 2 | # @Author : PEIWEN PAN 3 | # @Email : 121106022690@njust.edu.cn 4 | # @File : scheduler.py 5 | # @Software: PyCharm 6 | import math 7 | 8 | 9 | def linear(optimizer, epoch, base_lr, warmup_epoch=5): 10 | if epoch == 0: 11 | lr = base_lr / warmup_epoch 12 | else: 13 | lr = epoch * (base_lr / warmup_epoch) 14 | for param_group in optimizer.param_groups: 15 | param_group['lr'] = lr 16 | 17 | 18 | class PolyLR(object): 19 | def __init__(self, optimizer, num_epochs, base_lr, warmup, power=0.9, 
warmup_epochs=5, **kwargs):
20 |         super(PolyLR, self).__init__()
21 |         self.optimizer = optimizer
22 |         self.num_epochs = num_epochs
23 |         self.base_lr = base_lr
24 |         self.warmup = warmup
25 |         self.warmup_epoch = warmup_epochs if self.warmup else 0
26 |         self.power = power
27 | 
28 |     def step(self, epoch):
29 |         if self.warmup and epoch <= self.warmup_epoch:
30 |             globals()[self.warmup](self.optimizer, epoch, self.base_lr, self.warmup_epoch)
31 |         else:
32 |             lr = self.base_lr * (1 - (epoch - self.warmup_epoch) / self.num_epochs) ** self.power
33 |             for param_group in self.optimizer.param_groups:
34 |                 param_group['lr'] = lr
35 | 
36 | 
37 | class CosineAnnealingLR(object):
38 |     def __init__(self, optimizer, num_epochs, base_lr, warmup, min_lr=1e-4, warmup_epochs=5, **kwargs):
39 |         super(CosineAnnealingLR, self).__init__()
40 |         self.optimizer = optimizer
41 |         self.num_epochs = num_epochs
42 |         self.base_lr = base_lr
43 |         self.warmup = warmup
44 |         self.warmup_epoch = warmup_epochs if self.warmup else 0
45 |         self.min_lr = min_lr
46 | 
47 |     def step(self, epoch):
48 |         if self.warmup and epoch <= self.warmup_epoch:
49 |             globals()[self.warmup](self.optimizer, epoch, self.base_lr, self.warmup_epoch)
50 |         else:
51 |             lr = self.min_lr + ((self.base_lr - self.min_lr) / 2) * (
52 |                     1 + math.cos((epoch - self.warmup_epoch) / self.num_epochs * math.pi))
53 |             for param_group in self.optimizer.param_groups:
54 |                 param_group['lr'] = lr
55 | 
56 | 
57 | class StepLR(object):
58 |     def __init__(self, optimizer, step, base_lr, warmup, gamma=0.1, warmup_epochs=5, **kwargs):
59 |         super(StepLR, self).__init__()
60 |         self.optimizer = optimizer
61 |         # store the milestone list under a different name so it does not shadow the step() method
62 |         self.milestones = step
63 |         self.gamma = gamma
64 |         self.base_lr = base_lr
65 |         self.warmup = warmup
66 |         self.warmup_epoch = warmup_epochs if self.warmup else 0
67 | 
68 |     def step(self, epoch):
69 |         if self.warmup and epoch <= self.warmup_epoch:
70 |             globals()[self.warmup](self.optimizer, epoch, self.base_lr, self.warmup_epoch)
71 |         else:
72 |             if epoch in self.milestones:
73 |                 for param_group in self.optimizer.param_groups:
74 |                     param_group['lr'] *= self.gamma
--------------------------------------------------------------------------------
/model/UIUNet/fusion.py:
--------------------------------------------------------------------------------
 1 | # @Time    : 2024/2/25 01:17
 2 | # @Author  : PEIWEN PAN
 3 | # @Email   : 121106022690@njust.edu.cn
 4 | # @File    : fusion.py
 5 | # @Software: PyCharm
 6 | import torch
 7 | import torch.nn as nn
 8 | 
 9 | class AsymBiChaFuseReduce(nn.Module):
10 |     def __init__(self, in_high_channels, in_low_channels, out_channels=64, r=4):
11 |         super(AsymBiChaFuseReduce, self).__init__()
12 |         assert in_low_channels == out_channels
13 |         self.high_channels = in_high_channels
14 |         self.low_channels = in_low_channels
15 |         self.out_channels = out_channels
16 |         self.bottleneck_channels = int(out_channels // r)
17 | 
18 |         self.feature_high = nn.Sequential(
19 |             nn.Conv2d(self.high_channels, self.out_channels, 1, 1, 0),
20 |             nn.BatchNorm2d(out_channels),
21 |             nn.ReLU(True),
22 |         )
23 | 
24 |         self.topdown = nn.Sequential(
25 |             nn.AdaptiveAvgPool2d((1, 1)),
26 |             nn.Conv2d(self.out_channels, self.bottleneck_channels, 1, 1, 0),
27 |             nn.BatchNorm2d(self.bottleneck_channels),
28 |             nn.ReLU(True),
29 | 
30 |             nn.Conv2d(self.bottleneck_channels, self.out_channels, 1, 1, 0),
31 |             nn.BatchNorm2d(self.out_channels),
32 |             nn.Sigmoid(),
33 |         )
34 | 
35 |         # bottom-up branch with spatial attention (Cross UtU)
36 |         self.bottomup = nn.Sequential(
37 |             nn.Conv2d(self.low_channels,
self.bottleneck_channels, 1, 1, 0),
38 |             nn.BatchNorm2d(self.bottleneck_channels),
39 |             nn.ReLU(True),
40 | 
41 |             SpatialAttention(kernel_size=3),
42 |             nn.Sigmoid()
43 |         )
44 | 
45 |         self.post = nn.Sequential(
46 |             nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
47 |             nn.BatchNorm2d(self.out_channels),
48 |             nn.ReLU(True),
49 |         )
50 | 
51 |     def forward(self, xh, xl):
52 |         xh = self.feature_high(xh)
53 | 
54 |         topdown_wei = self.topdown(xh)
55 |         bottomup_wei = self.bottomup(xl * topdown_wei)
56 |         xs1 = 2 * xl * topdown_wei
57 |         out1 = self.post(xs1)
58 | 
59 |         xs2 = 2 * xh * bottomup_wei
60 |         out2 = self.post(xs2)
61 |         return out1, out2
62 | 
63 | 
64 | class SpatialAttention(nn.Module):
65 |     def __init__(self, kernel_size=3):
66 |         super(SpatialAttention, self).__init__()
67 | 
68 |         assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
69 |         padding = 3 if kernel_size == 7 else 1
70 | 
71 |         self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
72 | 
73 |     def forward(self, x):
74 |         avg_out = torch.mean(x, dim=1, keepdim=True)
75 |         max_out, _ = torch.max(x, dim=1, keepdim=True)
76 |         x = torch.cat([avg_out, max_out], dim=1)
77 |         x = self.conv1(x)
78 |         return x
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # The current code is too bloated, so we are refactoring it; the new version will be released soon and will integrate more models. 💦💦💦🍺🍺🍺🚀🚀🚀
 2 | 
 3 | # Infrared-Small-Target-Segmentation-Framework
 4 | 
 5 | A general framework for infrared small target detection and segmentation. By modifying or adding config files, you can
 6 | adjust various parameters, switch models and datasets, and easily add your own models, datasets, and more.
 7 | It is recommended to spend a little time reading the [tutorial](https://github.com/PANPEIWEN/Infrared-Small-Target-Segmentation-Framework#framework-usage-tutorial) before use, so you can master the framework faster.
 8 | 
 9 | ## The tutorial and code are being improved...
10 | 
11 | ## Installation
12 | 
13 | Please refer
14 | to [get_started.md](https://github.com/PANPEIWEN/Infrared-Small-Target-Segmentation-Framework/blob/main/docs/get_started.md)
15 | for installation and dataset preparation.
16 | 
17 | ## Training
18 | 
19 | ### Single GPU Training
20 | 
21 | ```
22 | python train.py <config_file>
23 | ```
24 | 
25 | For example, to train the ACM model with FPN on a single GPU, run:
26 | 
27 | ```
28 | python train.py configs/acm/acm_res20_fpn_512x512_800e_nuaa.py
29 | ```
30 | 
31 | ### Multi GPU Training
32 | 
33 | ```nproc_per_node``` is the number of GPUs you are using.
34 | 
35 | ```
36 | python -m torch.distributed.launch --nproc_per_node=<gpu_num> train.py <config_file>
37 | ```
38 | 
39 | For example, to train the ACM model with FPN on 2 GPUs, run:
40 | 
41 | ```
42 | python -m torch.distributed.launch --nproc_per_node=2 train.py configs/acm/acm_res20_fpn_512x512_800e_nuaa.py
43 | ```
44 | 
45 | ### Notes
46 | 
47 | * You can specify the visible GPUs by editing ```os.environ['CUDA_VISIBLE_DEVICES']``` on the second line of train.py.
48 | * Be sure to set args.local_rank to 0 if using multi-GPU training.
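For example, to expose only the first two GPUs to the framework, that line would read (the device ids here are illustrative; use your own):

```
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
```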
49 | 
50 | ## Test
51 | 
52 | ```
53 | python test.py <config_file> <checkpoint_file>
54 | ```
55 | 
56 | For example, to test the ACM model with FPN, run:
57 | 
58 | ```
59 | python test.py configs/acm/acm_res20_fpn_512x512_800e_nuaa.py work_dirs/acm_res20_fpn_512x512_800e_nuaa/20221009_231431/best.pth.tar
60 | ```
61 | 
62 | If you want to visualize the results, just add ```--show``` at the end of the above command.
63 | 
64 | ## Framework usage tutorial
65 | This part explains the config files in detail, so that you can understand their contents quickly and master the overall framework.
66 | ### Add custom model
67 | Please refer
68 | to [add_model.md](https://github.com/PANPEIWEN/Infrared-Small-Target-Segmentation-Framework/blob/main/docs/add_model.md)
69 | for adding a custom model.
70 | ### Add custom dataset
71 | Please refer
72 | to [add_dataset.md](https://github.com/PANPEIWEN/Infrared-Small-Target-Segmentation-Framework/blob/main/docs/add_dataset.md)
73 | for adding a custom dataset.
74 | ### Add loss function
75 | Please refer
76 | to [add_loss.md](https://github.com/PANPEIWEN/Infrared-Small-Target-Segmentation-Framework/blob/main/docs/add_loss.md)
77 | for adding a loss function.
78 | 
79 | _Notice: Even if you use a loss function that already exists in PyTorch, you still need to follow this process._
80 | ### Add optimizer and scheduler
81 | Please refer
82 | to [add_optimizer.md](https://github.com/PANPEIWEN/Infrared-Small-Target-Segmentation-Framework/blob/main/docs/add_optimizer.md)
83 | for adding an optimizer and scheduler.
84 | 
85 | _Notice: Even if you use an optimizer or scheduler that already exists in PyTorch, you still need to follow this process._
--------------------------------------------------------------------------------
/docs/add_model.md:
--------------------------------------------------------------------------------
 1 | ## Add Custom Model
 2 | 
 3 | Follow the process below to add a custom model.
 4 | 
 5 | ### Add Model File
 6 | 
 7 | 1. Create a new **Python Package** named _YourModel_ in
 8 |    the [model](https://github.com/PANPEIWEN/Infrared-Small-Target-Segmentation-Framework/blob/main/model) folder.
 9 | 2. Create a model file named _yourmodel.py_ in the **YourModel** folder.
10 | 3. Model code specification:
11 | 
12 | ```python
13 | class YourModelName(nn.Module):
14 |     """
15 |     1. You must add **kwargs.
16 |     2. If you want to use deep supervision, you can add deep_supervision in __init__.
17 |        Note: the parameter name must be 'deep_supervision'.
18 |        If you use deep supervision, the final output must be at the end
19 |        of the output list, for example:
20 |        >>> out1 = self.conv1(x)
21 |        >>> out2 = self.conv2(x)
22 |        >>> out = self.final_conv(x)
23 |        >>> return [out1, out2, out] if self.deep_supervision else out
24 | 
25 |     """
26 |     def __init__(self, args1, args2, ..., deep_supervision=True, **kwargs):
27 |         super(YourModelName, self).__init__()
28 |         pass
29 | 
30 |     def forward(self, x):
31 |         pass
32 |         return outs
33 | 
34 | ```
35 | 
36 | 4. Modify the [build_segmentor.py](https://github.com/PANPEIWEN/Infrared-Small-Target-Segmentation-Framework/blob/main/model/build_segmentor.py)
37 |    file:
38 | 
39 | ```python
40 | # add
41 | from model.YourModel.yourmodel import YourModelName
42 | 
43 | __all__ = [..., 'YourModelName']
44 | ```
45 | 
46 | ### Add Model Config File
47 | 
48 | 1. Create a new directory named _yourmodel_ in
49 |    the [configs](https://github.com/PANPEIWEN/Infrared-Small-Target-Segmentation-Framework/blob/main/configs) folder.
50 | 2.
Create config file named _yourmodel_base_512x512_800e_nuaa.py_ in the **yourmodel** folder. 51 | 52 | _Tips: Recommended config file name naming rules:_ 53 | ```[model_name]_[model_scale]_[data_size]_[train_epoch]_[dataset_name]``` 54 | 3. Config code specification: 55 | 56 | ```python 57 | """ 58 | In the _base_ is the config you inherited, you can modify the places you need to modify after inheritance. 59 | For example, you want to modify train and test batch, you can write like this: 60 | >>> data = dict( 61 | >>> train_batch=32, 62 | >>> test_batch=32) 63 | You can use this method flexibly to make the config file more concise. 64 | It is recommended to use this method to add or modify the places you need to set, instead of modifying the 65 | config file in the _base_ folder. 66 | """ 67 | _base_ = [ 68 | # dataset config file 69 | '../_base_/datasets/nuaa.py', 70 | # run config file 71 | '../_base_/default_runtime.py', 72 | # optimizer and schedule config file 73 | '../_base_/schedules/schedule_500e.py' 74 | ] 75 | 76 | # model settings 77 | model = dict( 78 | # YourModelName 79 | name='YourModelName', 80 | type='EncoderDecoder', 81 | 82 | # If your model has a separate backbone, that: 83 | # >>> type=ClassName 84 | # type_info only represent information, no practical use. 85 | # This code cannot be deleted. 86 | backbone=dict( 87 | type=None, 88 | type_info='resnet', 89 | ), 90 | 91 | # The type must be the same as your model class name. 92 | # The parameters are the parameters inside your model __init__ 93 | decode_head=dict( 94 | type='YourModelName', 95 | args1=..., 96 | args2=..., 97 | deep_supervision=True, 98 | ... 99 | ), 100 | 101 | # The type must in __all__ of build_criterion.py 102 | # If you want to set some parameters, you just need to add a key-value pair after type. 103 | # For example: 104 | # >>> loss=dict(type='BCEWithLogits', reduction='mean') 105 | loss=dict(type='SoftIoULoss') 106 | ) 107 | 108 | # The type must in __all__ of build_optimizer.py, you need set parameters in setting. 109 | # There cannot be key-value pairs that are not in the optimizer parameter list here. 
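# For example, momentum is valid for SGD but would raise a TypeError if passed to Adam.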
110 | optimizer = dict( 111 | type='SGD', 112 | setting=dict(lr=0.01, momentum=0.9, weight_decay=0.0005) 113 | ) 114 | ``` -------------------------------------------------------------------------------- /model/RDIAN/cbam.py: -------------------------------------------------------------------------------- 1 | # @Time : 2024/2/24 15:31 2 | # @Author : PEIWEN PAN 3 | # @Email : 121106022690@njust.edu.cn 4 | # @File : cbam.py 5 | # @Software: PyCharm 6 | import torch 7 | import math 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class BasicConv(nn.Module): 13 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, 14 | bn=True, bias=False): 15 | super(BasicConv, self).__init__() 16 | self.out_channels = out_planes 17 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, 18 | dilation=dilation, groups=groups, bias=bias) 19 | self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None 20 | self.relu = nn.ReLU() if relu else None 21 | 22 | def forward(self, x): 23 | x = self.conv(x) 24 | if self.bn is not None: 25 | x = self.bn(x) 26 | if self.relu is not None: 27 | x = self.relu(x) 28 | return x 29 | 30 | 31 | class Flatten(nn.Module): 32 | def forward(self, x): 33 | return x.view(x.size(0), -1) 34 | 35 | 36 | class ChannelGate(nn.Module): 37 | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']): 38 | super(ChannelGate, self).__init__() 39 | self.gate_channels = gate_channels 40 | self.mlp = nn.Sequential( 41 | Flatten(), 42 | nn.Linear(gate_channels, gate_channels // reduction_ratio), 43 | nn.ReLU(), 44 | nn.Linear(gate_channels // reduction_ratio, gate_channels) 45 | ) 46 | self.pool_types = pool_types 47 | 48 | def forward(self, x): 49 | channel_att_sum = None 50 | for pool_type in self.pool_types: 51 | if pool_type == 'avg': 52 | avg_pool = F.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) 53 | channel_att_raw = self.mlp(avg_pool) 54 | elif pool_type == 'max': 55 | max_pool = F.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) 56 | channel_att_raw = self.mlp(max_pool) 57 | elif pool_type == 'lp': 58 | lp_pool = F.lp_pool2d(x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) 59 | channel_att_raw = self.mlp(lp_pool) 60 | elif pool_type == 'lse': 61 | # LSE pool only 62 | lse_pool = logsumexp_2d(x) 63 | channel_att_raw = self.mlp(lse_pool) 64 | 65 | if channel_att_sum is None: 66 | channel_att_sum = channel_att_raw 67 | else: 68 | channel_att_sum = channel_att_sum + channel_att_raw 69 | 70 | scale = F.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x) 71 | return x * scale 72 | 73 | 74 | def logsumexp_2d(tensor): 75 | tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1) 76 | s, _ = torch.max(tensor_flatten, dim=2, keepdim=True) 77 | outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log() 78 | return outputs 79 | 80 | 81 | class ChannelPool(nn.Module): 82 | def forward(self, x): 83 | return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1) 84 | 85 | 86 | class SpatialGate(nn.Module): 87 | def __init__(self): 88 | super(SpatialGate, self).__init__() 89 | kernel_size = 7 90 | self.compress = ChannelPool() 91 | self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False) 92 | 93 | def forward(self, x): 94 | x_compress = self.compress(x) 95 | x_out = 
self.spatial(x_compress) 96 | scale = F.sigmoid(x_out) 97 | return x * scale 98 | 99 | 100 | class CBAM(nn.Module): 101 | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False): 102 | super(CBAM, self).__init__() 103 | self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types) 104 | self.SpatialGate = SpatialGate() 105 | 106 | def forward(self, x): 107 | x_out = self.ChannelGate(x) 108 | x_out = self.SpatialGate(x_out) 109 | return x_out 110 | -------------------------------------------------------------------------------- /utils/data.py: -------------------------------------------------------------------------------- 1 | # @Time : 2022/4/6 14:41 2 | # @Author : PEIWEN PAN 3 | # @Email : 121106022690@njust.edu.cn 4 | # @File : data.py 5 | # @Software: PyCharm 6 | import random 7 | import sys 8 | import os.path as osp 9 | import os 10 | from PIL import Image, ImageOps, ImageFilter 11 | import torchvision.transforms as transforms 12 | import torch.utils.data as Data 13 | import torch 14 | import numpy as np 15 | import math 16 | 17 | 18 | class DatasetLoad(Data.Dataset): 19 | def __init__(self, data_root, base_size, crop_size, mode, train_dir, test_dir, data_aug=True, suffix='png', 20 | rgb=True, **kwargs): 21 | self.base_size = base_size 22 | self.crop_size = crop_size 23 | self.mode = mode 24 | self.data_aug = data_aug 25 | self.rgb = rgb 26 | assert mode in ['train', 'test'], 'The mode should be train or test' 27 | if mode == 'train': 28 | self.data_dir = osp.join(data_root, train_dir) 29 | else: 30 | self.data_dir = osp.join(data_root, test_dir) 31 | 32 | self.img_names = [] 33 | for img in os.listdir(osp.join(self.data_dir, 'images')): 34 | if img.endswith(suffix): 35 | self.img_names.append(img) 36 | 37 | self.rgb_transform = transforms.Compose([ 38 | transforms.ToTensor(), 39 | transforms.Normalize([.485, .456, .406], [.229, .224, .225]), 40 | ]) 41 | 42 | self.gray_transform = transforms.Compose([ 43 | transforms.ToTensor(), 44 | transforms.Normalize([-0.1246], [1.0923]) 45 | ]) 46 | 47 | def _sync_transform(self, img, mask): 48 | if self.mode == 'train' and self.data_aug: 49 | if random.random() < 0.5: 50 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 51 | mask = mask.transpose(Image.FLIP_LEFT_RIGHT) 52 | crop_size = self.crop_size 53 | long_size = random.randint( 54 | int(self.base_size * 0.5), int(self.base_size * 2.0)) 55 | # int(self.base_size * 0.8), int(self.base_size * 1.2)) 56 | w, h = img.size 57 | if h > w: 58 | oh = long_size 59 | ow = int(1.0 * w * long_size / h + 0.5) 60 | short_size = ow 61 | else: 62 | ow = long_size 63 | oh = int(1.0 * h * long_size / w + 0.5) 64 | short_size = oh 65 | img = img.resize((ow, oh), Image.BILINEAR) 66 | mask = mask.resize((ow, oh), Image.NEAREST) 67 | if short_size < crop_size: 68 | padh = crop_size - oh if oh < crop_size else 0 69 | padw = crop_size - ow if ow < crop_size else 0 70 | img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0) 71 | mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=0) 72 | w, h = img.size 73 | x1 = random.randint(0, w - crop_size) 74 | y1 = random.randint(0, h - crop_size) 75 | img = img.crop((x1, y1, x1 + crop_size, y1 + crop_size)) 76 | mask = mask.crop((x1, y1, x1 + crop_size, y1 + crop_size)) 77 | if random.random() < 0.5: 78 | img = img.filter(ImageFilter.GaussianBlur(radius=random.random())) 79 | img, mask = np.array(img), np.array(mask) 80 | img = self.rgb_transform(img) if self.rgb else self.gray_transform(img) 81 | mask = 
transforms.ToTensor()(mask) 82 | else: 83 | img = img.resize((self.base_size, self.base_size), Image.BILINEAR) 84 | mask = mask.resize((self.base_size, self.base_size), Image.NEAREST) 85 | img, mask = np.array(img), np.array(mask) 86 | img = self.rgb_transform(img) if self.rgb else self.gray_transform(img) 87 | mask = transforms.ToTensor()(mask) 88 | return img, mask 89 | 90 | def __getitem__(self, item): 91 | img_name = self.img_names[item] 92 | img_path = osp.join(self.data_dir, 'images', img_name) 93 | label_path = osp.join(self.data_dir, 'masks', img_name) 94 | img = Image.open(img_path).convert('RGB') if self.rgb else Image.open(img_path).convert('L') 95 | mask = Image.open(label_path).convert('L') 96 | img, mask = self._sync_transform(img, mask) 97 | return img, mask 98 | 99 | def __len__(self): 100 | return len(self.img_names) 101 | -------------------------------------------------------------------------------- /utils/visual.py: -------------------------------------------------------------------------------- 1 | # @Time : 2022/4/7 17:01 2 | # @Author : PEIWEN PAN 3 | # @Email : 121106022690@njust.edu.cn 4 | # @File : visual.py 5 | # @Software: PyCharm 6 | import os 7 | import shutil 8 | 9 | import numpy as np 10 | import torch 11 | from PIL import Image 12 | from matplotlib import pyplot as plt 13 | 14 | 15 | def make_show_dir(show_dir): 16 | if not os.path.exists(show_dir): 17 | os.mkdir(show_dir) 18 | 19 | if os.path.exists(os.path.join(show_dir, 'result')): 20 | shutil.rmtree(os.path.join(show_dir, 'result')) # 删除目录,包括目录下的所有文件 21 | os.mkdir(os.path.join(show_dir, 'result')) 22 | 23 | if os.path.exists(os.path.join(show_dir, 'fuse')): 24 | shutil.rmtree(os.path.join(show_dir, 'fuse')) # 删除目录,包括目录下的所有文件 25 | os.mkdir(os.path.join(show_dir, 'fuse')) 26 | 27 | 28 | def save_Pred_GT(preds, labels, show_dir, num, cfg): 29 | img_name = os.listdir(os.path.join(cfg.data['data_root'], cfg.data['test_dir'], 'images')) 30 | val_img_ids = [] 31 | for img in img_name: 32 | val_img_ids.append(img.split('.')[0]) 33 | # predsss = ((torch.sigmoid((pred)).cpu().numpy()) * 255).astype('int64') 34 | batch = preds.size()[0] 35 | for b in range(batch): 36 | predsss = np.array((preds[b, :, :, :] > 0).cpu()).astype('int64') * 255 37 | predsss = np.uint8(predsss) 38 | labelsss = labels[b, :, :, :] * 255 39 | labelsss = np.uint8(labelsss.cpu()) 40 | 41 | img = Image.fromarray(predsss.reshape(cfg.data['crop_size'], cfg.data['crop_size'])) 42 | img.save(show_dir + '/result/' + '%s_Pred' % (val_img_ids[num + b]) + '.' + cfg.data['suffix']) 43 | img = Image.fromarray(labelsss.reshape(cfg.data['crop_size'], cfg.data['crop_size'])) 44 | img.save(show_dir + '/result/' + '%s_GT' % (val_img_ids[num + b]) + '.' 
+ cfg.data['suffix'])
45 | 
46 | 
47 | def save_Pred_GT_visulize(pred, img_demo_dir, img_demo_index, suffix, cfg):
48 |     predsss = np.array((pred > 0).cpu()).astype('int64') * 255
49 |     predsss = np.uint8(predsss)
50 | 
51 |     img = Image.fromarray(predsss.reshape(cfg.data['crop_size'], cfg.data['crop_size']))
52 |     img.save(img_demo_dir + '/' + '%s_Pred' % (img_demo_index) + suffix)
53 | 
54 |     plt.figure(figsize=(10, 6))
55 |     plt.subplot(1, 2, 1)
56 |     img = plt.imread(img_demo_dir + '/' + img_demo_index + suffix)
57 |     plt.imshow(img, cmap='gray')
58 |     plt.xlabel("Raw Image", size=11)
59 | 
60 |     plt.subplot(1, 2, 2)
61 |     img = plt.imread(img_demo_dir + '/' + '%s_Pred' % (img_demo_index) + suffix)
62 |     plt.imshow(img, cmap='gray')
63 |     plt.xlabel("Predicts", size=11)
64 | 
65 |     plt.savefig(img_demo_dir + '/' + img_demo_index + "_fuse" + suffix, facecolor='w', edgecolor='red')
66 |     plt.show()
67 | 
68 | 
69 | def total_show_generation(show_dir, cfg):
70 |     source_image_path = os.path.join(cfg.data['data_root'], cfg.data['test_dir'], 'images')
71 |     ids = []
72 |     img_name = os.listdir(source_image_path)
73 |     for img in img_name:
74 |         ids.append(img.split('.')[0])
75 |     for i in range(len(ids)):
76 |         source_image = source_image_path + '/' + ids[i] + '.' + cfg.data['suffix']
77 |         target_image = show_dir + '/result/' + ids[i] + '.' + cfg.data['suffix']
78 |         shutil.copy(source_image, target_image)
79 |     for i in range(len(ids)):
80 |         source_image = show_dir + '/result/' + ids[i] + '.' + cfg.data['suffix']
81 |         img = Image.open(source_image)
82 |         # Image.LANCZOS replaces the deprecated Image.ANTIALIAS (removed in Pillow 10)
83 |         img = img.resize((cfg.data['crop_size'], cfg.data['crop_size']), Image.LANCZOS)
84 |         img.save(source_image)
85 |     for m in range(len(ids)):
86 |         print('Processing image %d' % (m + 1))
87 |         plt.figure(figsize=(10, 6))
88 |         plt.subplot(1, 3, 1)
89 |         img = plt.imread(show_dir + '/result/' + ids[m] + '.' + cfg.data['suffix'])
90 |         plt.imshow(img, cmap='gray')
91 |         plt.xlabel("Raw Image", size=11)
92 | 
93 |         plt.subplot(1, 3, 2)
94 |         img = plt.imread(show_dir + '/result/' + ids[m] + '_GT' + '.' + cfg.data['suffix'])
95 |         plt.imshow(img, cmap='gray')
96 |         plt.xlabel("Ground Truth", size=11)
97 | 
98 |         plt.subplot(1, 3, 3)
99 |         img = plt.imread(show_dir + '/result/' + ids[m] + '_Pred' + '.' + cfg.data['suffix'])
100 |         plt.imshow(img, cmap='gray')
101 |         plt.xlabel("Predicts", size=11)
102 |         plt.savefig(show_dir + '/fuse/' + ids[m].split('.')[0] + "_fuse" + '.'
+ cfg.data['suffix'], 102 | facecolor='w', edgecolor='red') 103 | -------------------------------------------------------------------------------- /model/RDIAN/direction.py: -------------------------------------------------------------------------------- 1 | # @Time : 2024/2/24 15:31 2 | # @Author : PEIWEN PAN 3 | # @Email : 121106022690@njust.edu.cn 4 | # @File : direction.py 5 | # @Software: PyCharm 6 | import torch 7 | import math 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class Conv_d11(nn.Module): 13 | def __init__(self): 14 | super(Conv_d11, self).__init__() 15 | kernel = [[-1, 0, 0, 0, 0], 16 | [0, 0, 0, 0, 0], 17 | [0, 0, 1, 0, 0], 18 | [0, 0, 0, 0, 0], 19 | [0, 0, 0, 0, 0]] 20 | 21 | kernel = torch.FloatTensor(kernel).unsqueeze(0).unsqueeze(0) 22 | self.weight = nn.Parameter(data=kernel, requires_grad=False) 23 | 24 | def forward(self, input): 25 | return F.conv2d(input, self.weight, padding=2) 26 | 27 | 28 | class Conv_d12(nn.Module): 29 | def __init__(self): 30 | super(Conv_d12, self).__init__() 31 | kernel = [[0, 0, -1, 0, 0], 32 | [0, 0, 0, 0, 0], 33 | [0, 0, 1, 0, 0], 34 | [0, 0, 0, 0, 0], 35 | [0, 0, 0, 0, 0]] 36 | 37 | kernel = torch.FloatTensor(kernel).unsqueeze(0).unsqueeze(0) 38 | self.weight = nn.Parameter(data=kernel, requires_grad=False) 39 | 40 | def forward(self, input): 41 | return F.conv2d(input, self.weight, padding=2) 42 | 43 | 44 | class Conv_d13(nn.Module): 45 | def __init__(self): 46 | super(Conv_d13, self).__init__() 47 | kernel = [[0, 0, 0, 0, -1], 48 | [0, 0, 0, 0, 0], 49 | [0, 0, 1, 0, 0], 50 | [0, 0, 0, 0, 0], 51 | [0, 0, 0, 0, 0]] 52 | 53 | kernel = torch.FloatTensor(kernel).unsqueeze(0).unsqueeze(0) 54 | self.weight = nn.Parameter(data=kernel, requires_grad=False) 55 | 56 | def forward(self, input): 57 | return F.conv2d(input, self.weight, padding=2) 58 | 59 | 60 | class Conv_d14(nn.Module): 61 | def __init__(self): 62 | super(Conv_d14, self).__init__() 63 | kernel = [[0, 0, 0, 0, 0], 64 | [0, 0, 0, 0, 0], 65 | [0, 0, 1, 0, -1], 66 | [0, 0, 0, 0, 0], 67 | [0, 0, 0, 0, 0]] 68 | 69 | kernel = torch.FloatTensor(kernel).unsqueeze(0).unsqueeze(0) 70 | self.weight = nn.Parameter(data=kernel, requires_grad=False) 71 | 72 | def forward(self, input): 73 | return F.conv2d(input, self.weight, padding=2) 74 | 75 | 76 | class Conv_d15(nn.Module): 77 | def __init__(self): 78 | super(Conv_d15, self).__init__() 79 | kernel = [[0, 0, 0, 0, 0], 80 | [0, 0, 0, 0, 0], 81 | [0, 0, 1, 0, 0], 82 | [0, 0, 0, 0, 0], 83 | [0, 0, 0, 0, -1]] 84 | 85 | kernel = torch.FloatTensor(kernel).unsqueeze(0).unsqueeze(0) 86 | self.weight = nn.Parameter(data=kernel, requires_grad=False) 87 | 88 | def forward(self, input): 89 | return F.conv2d(input, self.weight, padding=2) 90 | 91 | 92 | class Conv_d16(nn.Module): 93 | def __init__(self): 94 | super(Conv_d16, self).__init__() 95 | kernel = [[0, 0, 0, 0, 0], 96 | [0, 0, 0, 0, 0], 97 | [0, 0, 1, 0, 0], 98 | [0, 0, 0, 0, 0], 99 | [0, 0, -1, 0, 0]] 100 | 101 | kernel = torch.FloatTensor(kernel).unsqueeze(0).unsqueeze(0) 102 | self.weight = nn.Parameter(data=kernel, requires_grad=False) 103 | 104 | def forward(self, input): 105 | return F.conv2d(input, self.weight, padding=2) 106 | 107 | 108 | class Conv_d17(nn.Module): 109 | def __init__(self): 110 | super(Conv_d17, self).__init__() 111 | kernel = [[0, 0, 0, 0, 0], 112 | [0, 0, 0, 0, 0], 113 | [0, 0, 1, 0, 0], 114 | [0, 0, 0, 0, 0], 115 | [-1, 0, 0, 0, 0]] 116 | 117 | kernel = torch.FloatTensor(kernel).unsqueeze(0).unsqueeze(0) 118 | self.weight = 
nn.Parameter(data=kernel, requires_grad=False) 119 | 120 | def forward(self, input): 121 | return F.conv2d(input, self.weight, padding=2) 122 | 123 | 124 | class Conv_d18(nn.Module): 125 | def __init__(self): 126 | super(Conv_d18, self).__init__() 127 | kernel = [[0, 0, 0, 0, 0], 128 | [0, 0, 0, 0, 0], 129 | [-1, 0, 1, 0, 0], 130 | [0, 0, 0, 0, 0], 131 | [0, 0, 0, 0, 0]] 132 | 133 | kernel = torch.FloatTensor(kernel).unsqueeze(0).unsqueeze(0) 134 | self.weight = nn.Parameter(data=kernel, requires_grad=False) 135 | 136 | def forward(self, input): 137 | return F.conv2d(input, self.weight, padding=2) 138 | -------------------------------------------------------------------------------- /model/RDIAN/rdian.py: -------------------------------------------------------------------------------- 1 | # @Time : 2024/2/24 15:32 2 | # @Author : PEIWEN PAN 3 | # @Email : 121106022690@njust.edu.cn 4 | # @File : rdian.py 5 | # @Software: PyCharm 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from model.RDIAN.cbam import * 11 | from model.RDIAN.direction import * 12 | 13 | 14 | class _FCNHead(nn.Module): 15 | def __init__(self, in_channels, out_channels): 16 | super(_FCNHead, self).__init__() 17 | inter_channels = in_channels // 4 18 | self.block = nn.Sequential( 19 | nn.Conv2d(in_channels, inter_channels, 3, 1, 1, bias=False), 20 | nn.BatchNorm2d(inter_channels), 21 | nn.ReLU(True), 22 | nn.Dropout(0.1), 23 | nn.Conv2d(inter_channels, out_channels, 1, 1, 0) 24 | ) 25 | 26 | def forward(self, x): 27 | return self.block(x) 28 | 29 | 30 | def conv_batch(in_num, out_num, kernel_size=3, padding=1, stride=1): 31 | return nn.Sequential( 32 | nn.Conv2d(in_num, out_num, kernel_size=kernel_size, stride=stride, padding=padding, bias=False), 33 | nn.BatchNorm2d(out_num), 34 | nn.LeakyReLU()) 35 | 36 | 37 | class NewBlock(nn.Module): 38 | def __init__(self, in_channels, stride, kernel_size, padding): 39 | super(NewBlock, self).__init__() 40 | reduced_channels = int(in_channels / 2) 41 | self.layer1 = conv_batch(in_channels, reduced_channels, kernel_size=kernel_size, padding=padding, stride=stride) 42 | self.layer2 = conv_batch(reduced_channels, in_channels, kernel_size=kernel_size, padding=padding, stride=stride) 43 | 44 | def forward(self, x): 45 | residual = x 46 | out = self.layer1(x) 47 | out = self.layer2(out) 48 | out += residual 49 | return out 50 | 51 | 52 | class RDIAN(nn.Module): 53 | def __init__(self, **kwargs): 54 | super(RDIAN, self).__init__() 55 | accumulate_params = "none" 56 | self.conv1 = conv_batch(1, 16) 57 | self.conv2 = conv_batch(16, 32, stride=2) 58 | self.residual_block0 = self.make_layer(NewBlock, in_channels=32, num_blocks=1, kernel_size=1, padding=0, 59 | stride=1) 60 | self.residual_block1 = self.make_layer(NewBlock, in_channels=32, num_blocks=2, kernel_size=3, padding=1, 61 | stride=1) 62 | self.residual_block2 = self.make_layer(NewBlock, in_channels=32, num_blocks=2, kernel_size=5, padding=2, 63 | stride=1) 64 | self.residual_block3 = self.make_layer(NewBlock, in_channels=32, num_blocks=2, kernel_size=7, padding=3, 65 | stride=1) 66 | self.cbam = CBAM(32, 32) 67 | self.conv_cat = conv_batch(4 * 32, 32, 3, padding=1) 68 | self.conv_res = conv_batch(16, 32, 1, padding=0) 69 | self.relu = nn.ReLU(True) 70 | 71 | self.d11 = Conv_d11() 72 | self.d12 = Conv_d12() 73 | self.d13 = Conv_d13() 74 | self.d14 = Conv_d14() 75 | self.d15 = Conv_d15() 76 | self.d16 = Conv_d16() 77 | self.d17 = Conv_d17() 78 | self.d18 = Conv_d18() 79 | 80 | self.head = 
_FCNHead(32, 1) 81 | 82 | def forward(self, x): 83 | _, _, hei, wid = x.shape 84 | d11 = self.d11(x) 85 | d12 = self.d12(x) 86 | d13 = self.d13(x) 87 | d14 = self.d14(x) 88 | d15 = self.d15(x) 89 | d16 = self.d16(x) 90 | d17 = self.d17(x) 91 | d18 = self.d18(x) 92 | md = d11.mul(d15) + d12.mul(d16) + d13.mul(d17) + d14.mul(d18) 93 | md = F.sigmoid(md) 94 | 95 | out1 = self.conv1(x) 96 | out2 = out1.mul(md) 97 | out = self.conv2(out1 + out2) 98 | 99 | c0 = self.residual_block0(out) 100 | c1 = self.residual_block1(out) 101 | c2 = self.residual_block2(out) 102 | c3 = self.residual_block3(out) 103 | 104 | x_cat = self.conv_cat(torch.cat((c0, c1, c2, c3), dim=1)) # [16,32,240,240] 105 | x_a = self.cbam(x_cat) 106 | 107 | temp = F.interpolate(x_a, size=[hei, wid], mode='bilinear') 108 | temp2 = self.conv_res(out1) 109 | x_new = self.relu(temp + temp2) 110 | self.x_new = x_new 111 | pred = self.head(x_new) 112 | 113 | return pred 114 | 115 | def make_layer(self, block, in_channels, num_blocks, stride, kernel_size, padding): 116 | layers = [] 117 | for i in range(0, num_blocks): 118 | layers.append(block(in_channels, stride, kernel_size, padding)) 119 | return nn.Sequential(*layers) 120 | 121 | 122 | if __name__ == '__main__': 123 | x = torch.rand(8, 1, 256, 256) 124 | model = RDIAN() 125 | out = model(x) 126 | print(out.size()) -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # @Time : 2023/10/1 20:01 2 | # @Author : PEIWEN PAN 3 | # @Email : 121106022690@njust.edu.cn 4 | # @File : rebuild_test.py 5 | # @Software: PyCharm 6 | import argparse 7 | import os 8 | from mmcv import Config 9 | from tqdm import tqdm 10 | from build.build_model import build_model 11 | from build.build_criterion import build_criterion 12 | from build.build_dataset import build_dataset 13 | 14 | from utils.visual import * 15 | from utils.tools import * 16 | 17 | 18 | def parse_args(): 19 | parser = argparse.ArgumentParser( 20 | description='mmseg test (and eval) a model') 21 | parser.add_argument('config', help='test config file path') 22 | parser.add_argument('checkpoint', help='checkpoint file') 23 | parser.add_argument( 24 | '--work-dir', 25 | help=('if specified, the evaluation metric results will be dumped' 26 | 'into the directory as txt')) 27 | parser.add_argument('--show', action='store_true', help='show results') 28 | parser.add_argument( 29 | '--show-dir', help='directory where painted images will be saved') 30 | parser.add_argument( 31 | '--gpu-id', 32 | type=int, 33 | default=0, 34 | help='id of gpu to use ' 35 | '(only applicable to non-distributed testing)') 36 | parser.add_argument('--local_rank', type=int, default=-1) 37 | args = parser.parse_args() 38 | if 'LOCAL_RANK' not in os.environ: 39 | os.environ['LOCAL_RANK'] = str(args.local_rank) 40 | return args 41 | 42 | 43 | class Test(object): 44 | def __init__(self, args, cfg): 45 | super(Test, self).__init__() 46 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s: %(message)s', datefmt='%F %T') 47 | self.cfg = cfg 48 | self.deep_supervision = 'deep_supervision' in self.cfg.model['decode_head'] 49 | self.save_dir = args.work_dir if args.work_dir else os.path.dirname(os.path.abspath(args.checkpoint)) 50 | self.show_dir = args.show_dir if args.show_dir else os.path.join(self.save_dir, 'show') 51 | make_show_dir(self.show_dir) if args.show else empty_function() 52 | _, self.test_data, _, self.img_num = 
build_dataset(args, self.cfg) 53 | self.criterion = build_criterion(self.cfg) 54 | self.model = build_model(self.cfg) 55 | self.mIoU_metric = SigmoidMetric() 56 | self.nIoU_metric = SamplewiseSigmoidMetric(1, score_thresh=0.5) 57 | self.ROC = ROCMetric(1, 10) 58 | self.PD_FA = PD_FA(1, 10, cfg) 59 | self.best_recall = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 60 | self.best_precision = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 61 | self.mIoU_metric.reset() 62 | self.nIoU_metric.reset() 63 | self.PD_FA.reset() 64 | 65 | self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 66 | checkpoint = torch.load(args.checkpoint) 67 | self.model.load_state_dict(checkpoint['state_dict']) 68 | logging.info("Model Initializing") 69 | self.model = self.model.to(self.device) 70 | self.model.eval() 71 | tbar = tqdm(self.test_data) 72 | losses = [] 73 | 74 | with torch.no_grad(): 75 | for i, (img, mask) in enumerate(tbar): 76 | img, mask = data2device(args, (img, mask), self.device) 77 | preds = self.model(img) 78 | loss, preds = compute_loss(preds, mask, self.deep_supervision, cfg, self.criterion) 79 | losses.append(loss.item()) 80 | 81 | self.ROC.update(preds, mask) 82 | self.mIoU_metric.update(preds, mask) 83 | self.nIoU_metric.update(preds, mask) 84 | self.PD_FA.update(preds, mask) 85 | _, mIoU = self.mIoU_metric.get() 86 | _, nIoU = self.nIoU_metric.get() 87 | ture_positive_rate, false_positive_rate, recall, precision, F1_score = self.ROC.get() 88 | tbar.set_description( 89 | 'Loss %.4f, mIoU %.4f, nIoU %.4f, F1-score %.4f' % (np.mean(losses), mIoU, nIoU, F1_score)) 90 | if args.show: 91 | save_Pred_GT(preds, mask, self.show_dir, cfg.data['test_batch'] * i, cfg) 92 | FA, PD = self.PD_FA.get(self.img_num) 93 | save_test_config(cfg, self.save_dir) 94 | save_result_for_test(self.save_dir, mIoU, nIoU, recall, precision, FA, PD, F1_score, ture_positive_rate, 95 | false_positive_rate) 96 | if args.show: 97 | total_show_generation(self.show_dir, cfg) 98 | logging.info('Finishing') 99 | logging.info('mIoU: %.4f, nIoU: %.4f, F1-score: %.4f' % (mIoU, nIoU, F1_score)) 100 | 101 | 102 | def main(args): 103 | cfg = Config.fromfile(args.config) 104 | tester = Test(args, cfg) 105 | 106 | 107 | if __name__ == "__main__": 108 | args = parse_args() 109 | main(args) 110 | -------------------------------------------------------------------------------- /model/MTUet/vit.py: -------------------------------------------------------------------------------- 1 | # @Time : 2024/2/6 13:33 2 | # @Author : PEIWEN PAN 3 | # @Email : 121106022690@njust.edu.cn 4 | # @File : vit.py 5 | # @Software: PyCharm 6 | import torch 7 | import torch.nn as nn 8 | import numpy as np 9 | from einops import rearrange, repeat 10 | 11 | 12 | class MultiHeadAttention(nn.Module): 13 | def __init__(self, embedding_dim, head_num): 14 | super().__init__() 15 | 16 | self.head_num = head_num 17 | self.dk = (embedding_dim // head_num) ** 1 / 2 18 | 19 | self.qkv_layer = nn.Linear(embedding_dim, embedding_dim * 3, bias=False) 20 | self.out_attention = nn.Linear(embedding_dim, embedding_dim, bias=False) 21 | 22 | def forward(self, x, mask=None): 23 | qkv = self.qkv_layer(x) 24 | 25 | query, key, value = tuple(rearrange(qkv, 'b t (d k h ) -> k b h t d ', k=3, h=self.head_num)) 26 | energy = torch.einsum("... i d , ... j d -> ... i j", query, key) * self.dk 27 | 28 | if mask is not None: 29 | energy = energy.masked_fill(mask, -np.inf) 30 | 31 | attention = torch.softmax(energy, dim=-1) 32 | 33 | x = torch.einsum("... i j , ... j d -> ... 
i d", attention, value) 34 | 35 | x = rearrange(x, "b h t d -> b t (h d)") 36 | x = self.out_attention(x) 37 | 38 | return x 39 | 40 | 41 | class MLP(nn.Module): 42 | def __init__(self, embedding_dim, mlp_dim): 43 | super().__init__() 44 | 45 | self.mlp_layers = nn.Sequential( 46 | nn.Linear(embedding_dim, mlp_dim), 47 | nn.GELU(), 48 | nn.Dropout(0.1), 49 | nn.Linear(mlp_dim, embedding_dim), 50 | nn.Dropout(0.1) 51 | ) 52 | 53 | def forward(self, x): 54 | x = self.mlp_layers(x) 55 | 56 | return x 57 | 58 | 59 | class TransformerEncoderBlock(nn.Module): 60 | def __init__(self, embedding_dim, head_num, mlp_dim): 61 | super().__init__() 62 | 63 | self.multi_head_attention = MultiHeadAttention(embedding_dim, head_num) 64 | self.mlp = MLP(embedding_dim, mlp_dim) 65 | 66 | self.layer_norm1 = nn.LayerNorm(embedding_dim) 67 | self.layer_norm2 = nn.LayerNorm(embedding_dim) 68 | 69 | self.dropout = nn.Dropout(0.1) 70 | 71 | def forward(self, x): 72 | _x = self.multi_head_attention(x) 73 | _x = self.dropout(_x) 74 | x = x + _x 75 | x = self.layer_norm1(x) 76 | 77 | _x = self.mlp(x) 78 | x = x + _x 79 | x = self.layer_norm2(x) 80 | 81 | return x 82 | 83 | 84 | class TransformerEncoder(nn.Module): 85 | def __init__(self, embedding_dim, head_num, mlp_dim, block_num=12): 86 | super().__init__() 87 | 88 | self.layer_blocks = nn.ModuleList( 89 | [TransformerEncoderBlock(embedding_dim, head_num, mlp_dim) for _ in range(block_num)]) 90 | 91 | def forward(self, x): 92 | for layer_block in self.layer_blocks: 93 | x = layer_block(x) 94 | 95 | return x 96 | 97 | 98 | class ViT(nn.Module): 99 | def __init__(self, img_dim, in_channels, embedding_dim, head_num, mlp_dim, 100 | block_num, patch_dim, classification=True, num_classes=1): 101 | super().__init__() 102 | 103 | self.patch_dim = patch_dim 104 | self.classification = classification 105 | self.num_tokens = (img_dim // patch_dim) ** 2 106 | self.token_dim = in_channels * (patch_dim ** 2) 107 | 108 | self.projection = nn.Linear(self.token_dim, embedding_dim) 109 | self.embedding = nn.Parameter(torch.rand(self.num_tokens + 1, embedding_dim)) 110 | 111 | self.cls_token = nn.Parameter(torch.randn(1, 1, embedding_dim)) 112 | 113 | self.dropout = nn.Dropout(0.1) 114 | 115 | self.transformer = TransformerEncoder(embedding_dim, head_num, mlp_dim, block_num) 116 | 117 | if self.classification: 118 | self.mlp_head = nn.Linear(embedding_dim, num_classes) 119 | 120 | def forward(self, x): 121 | 122 | img_patches = rearrange(x, 123 | 'b c (patch_x x) (patch_y y) -> b (x y) (patch_x patch_y c)', 124 | patch_x=self.patch_dim, patch_y=self.patch_dim) 125 | 126 | batch_size, tokens, _ = img_patches.shape 127 | 128 | project = self.projection(img_patches) 129 | token = repeat(self.cls_token, 'b ... 
-> (b batch_size) ...',
130 |                        batch_size=batch_size)
131 | 
132 |         patches = torch.cat([token, project], dim=1)
133 |         patches += self.embedding[:tokens + 1, :]
134 | 
135 |         x = self.dropout(patches)
136 |         x = self.transformer(x)
137 |         x = self.mlp_head(x[:, 0, :]) if self.classification else x[:, 1:, :]
138 | 
139 |         return x
140 | 
141 | 
142 | if __name__ == '__main__':
143 |     vit = ViT(img_dim=128,
144 |               in_channels=3,
145 |               patch_dim=16,
146 |               embedding_dim=512,
147 |               block_num=6,
148 |               head_num=4,
149 |               mlp_dim=1024)
150 |     print(sum(p.numel() for p in vit.parameters()))
151 |     print(vit(torch.rand(1, 3, 128, 128)).shape)
--------------------------------------------------------------------------------
/model/AGPCNet/agpc.py:
--------------------------------------------------------------------------------
 1 | # @Time    : 2022/5/18 17:25
 2 | # @Author  : PEIWEN PAN
 3 | # @Email   : 121106022690@njust.edu.cn
 4 | # @File    : agpc.py
 5 | # @Software: PyCharm
 6 | import torch
 7 | import torch.nn as nn
 8 | import torch.nn.functional as F
 9 | 
10 | from model.AGPCNet.resnet import *
11 | from model.AGPCNet.context import CPM, AGCB_Element, AGCB_Patch
12 | from model.AGPCNet.fusion import *
13 | 
14 | 
15 | class _FCNHead(nn.Module):
16 |     def __init__(self, in_channels, out_channels, drop=0.5):
17 |         super(_FCNHead, self).__init__()
18 |         inter_channels = in_channels // 4
19 |         self.block = nn.Sequential(
20 |             nn.Conv2d(in_channels, inter_channels, 3, 1, 1),
21 |             nn.BatchNorm2d(inter_channels),
22 |             nn.ReLU(True),
23 |             nn.Dropout(drop),
24 |             nn.Conv2d(inter_channels, out_channels, 1, 1, 0)
25 |         )
26 | 
27 |     def forward(self, x):
28 |         return self.block(x)
29 | 
30 | 
31 | class AGPCNet(nn.Module):
32 |     def __init__(self, backbone='resnet18', scales=(10, 6), reduce_ratios=(8, 8), gca_type='patch', gca_att='origin',
33 |                  drop=0.1, **kwargs):
34 |         super(AGPCNet, self).__init__()
35 |         assert backbone in ['resnet18', 'resnet34']
36 |         assert gca_type in ['patch', 'element']
37 |         assert gca_att in ['origin', 'post']
38 | 
39 |         if backbone == 'resnet18':
40 |             self.backbone = resnet18(pretrained=True)
41 |         elif backbone == 'resnet34':
42 |             self.backbone = resnet34(pretrained=True)
43 |         else:
44 |             raise NotImplementedError
45 | 
46 |         self.fuse23 = AsymFusionModule(512, 256, 256)
47 |         self.fuse12 = AsymFusionModule(256, 128, 128)
48 | 
49 |         self.head = _FCNHead(128, 1, drop=drop)
50 | 
51 |         self.context = CPM(planes=512, scales=scales, reduce_ratios=reduce_ratios, block_type=gca_type,
52 |                            att_mode=gca_att)
53 | 
54 |         # Initialize parameters by iterating over all modules
55 |         for m in self.modules():
56 |             # Check the module type and apply the matching initialization
57 |             if isinstance(m, nn.Conv2d):
58 |                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
59 |             elif isinstance(m, nn.BatchNorm2d):
60 |                 nn.init.constant_(m.weight, 1)
61 |                 nn.init.constant_(m.bias, 0)
62 | 
63 |     def forward(self, x):
64 |         _, _, hei, wid = x.shape
65 | 
66 |         c1, c2, c3 = self.backbone(x)
67 | 
68 |         out = self.context(c3)
69 | 
70 |         out = F.interpolate(out, size=[hei // 4, wid // 4], mode='bilinear', align_corners=True)
71 |         out = self.fuse23(out, c2)
72 | 
73 |         out = F.interpolate(out, size=[hei // 2, wid // 2], mode='bilinear', align_corners=True)
74 |         out = self.fuse12(out, c1)
75 | 
76 |         pred = self.head(out)
77 |         out = F.interpolate(pred, size=[hei, wid], mode='bilinear', align_corners=True)
78 | 
79 |         return out
80 | 
81 | 
82 | class AGPCNet_Pro(nn.Module):
83 |     def __init__(self, backbone='resnet18', scales=(10, 6), reduce_ratios=(8, 8), gca_type='patch', gca_att='origin',
84 |                  drop=0.1, **kwargs):
85 |         super(AGPCNet_Pro, self).__init__()
 86 |         assert backbone in ['resnet18', 'resnet34']
 87 |         assert gca_type in ['patch', 'element']
 88 |         assert gca_att in ['origin', 'post']
 89 | 
 90 |         if backbone == 'resnet18':
 91 |             self.backbone = resnet18(pretrained=True)
 92 |         elif backbone == 'resnet34':
 93 |             self.backbone = resnet34(pretrained=True)
 94 |         else:
 95 |             raise NotImplementedError
 96 | 
 97 |         self.fuse23 = AsymFusionModule(512, 256, 256)
 98 |         self.fuse12 = AsymFusionModule(256, 128, 128)
 99 | 
100 |         self.head = _FCNHead(128, 1, drop=drop)
101 | 
102 |         self.context = CPM(planes=512, scales=scales, reduce_ratios=reduce_ratios, block_type=gca_type,
103 |                            att_mode=gca_att)
104 | 
105 |         # Initialize parameters by iterating over all modules
106 |         for m in self.modules():
107 |             # Check the module type and apply the matching initialization
108 |             if isinstance(m, nn.Conv2d):
109 |                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
110 |             elif isinstance(m, nn.BatchNorm2d):
111 |                 nn.init.constant_(m.weight, 1)
112 |                 nn.init.constant_(m.bias, 0)
113 | 
114 |     def forward(self, x):
115 |         _, _, hei, wid = x.shape
116 | 
117 |         c1, c2, c3 = self.backbone(x)
118 | 
119 |         out = self.context(c3)
120 | 
121 |         out = F.interpolate(out, size=[hei // 4, wid // 4], mode='bilinear', align_corners=True)
122 |         out = self.fuse23(out, c2)
123 | 
124 |         out = F.interpolate(out, size=[hei // 2, wid // 2], mode='bilinear', align_corners=True)
125 |         out = self.fuse12(out, c1)
126 | 
127 |         pred = self.head(out)
128 |         out = F.interpolate(pred, size=[hei, wid], mode='bilinear', align_corners=True)
129 | 
130 |         return out
131 | 
132 | 
133 | if __name__ == '__main__':
134 |     model = AGPCNet(backbone='resnet18', scales=(10, 6, 5, 3), reduce_ratios=(16, 4), gca_type='patch', gca_att='post',
135 |                     drop=0.1)
136 |     x = torch.rand(8, 3, 256, 256)
137 |     out = model(x)
138 |     print(out.size())
--------------------------------------------------------------------------------
/docs/add_optimizer.md:
--------------------------------------------------------------------------------
 1 | ## Add Optimizer and Scheduler
 2 | 
 3 | Follow the process below to add an optimizer and scheduler.
 4 | 
 5 | ### Add Optimizer
 6 | 
 7 | 1. An optimizer does not need to be rewritten the way a loss function does. You only need to change the value of type to a name
 8 |    in \_\_all__ of [build/build_optimizer.py](https://github.com/PANPEIWEN/Infrared-Small-Target-Segmentation-Framework/blob/main/build/build_optimizer.py)
 9 |    , and then
10 |    pass the corresponding parameters in the optimizer's setting. The details are described below.
11 | 
12 | ### Add Scheduler
13 | 
14 | To combine warmup and the scheduler easily, none of our schedulers use PyTorch's implementations; they are rewritten
15 | from scratch and do not inherit from PyTorch schedulers the way some loss functions wrap PyTorch losses.
16 | 
17 | You need to perform these operations
18 | in [utils/scheduler.py](https://github.com/PANPEIWEN/Infrared-Small-Target-Segmentation-Framework/blob/main/utils/scheduler.py)
19 | .
20 | 
21 | 1. Add a scheduler
22 | 
23 | ```python
24 | class YourScheduler(object):
25 |     """
26 |     The optimizer, base_lr, warmup and warmup_epochs parameters are required; the default values of warmup
27 |     and warmup_epochs are None and 0 respectively.
28 |     num_epochs represents the total number of epochs for training. Not every scheduler requires this parameter,
29 |     but if it is required, the parameter name must be num_epochs.
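    For example, PolyLR in utils/scheduler.py requires num_epochs and a power parameter, while StepLR takes
    a list of milestone epochs instead.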
30 | """ 31 | def __init__(self, optimizer, base_lr, num_epochs, args1, args2, ..., warmup=None, warmup_epochs=0, **kwargs): 32 | super(YourScheduler, self).__init__() 33 | # The next four lines should be written in this format 34 | self.optimizer = optimizer 35 | self.base_lr = base_lr 36 | self.warmup = warmup 37 | self.warmup_epoch = warmup_epochs if self.warmup else 0 38 | pass 39 | 40 | def step(self, epoch): 41 | # The learning rate policy needs to follow the format 42 | if self.warmup and epoch <= self.warmup_epoch: 43 | # warmup 44 | globals()[self.warmup](self.optimizer, self.args1, self.args2, ...) 45 | # For example: 46 | # >>> globals()[self.warmup](self.optimizer, epoch, self.base_lr, self.warmup_epoch) 47 | else: 48 | # Calculate the learning rate 49 | lr = ... 50 | # Update learning rate 51 | for param_group in self.optimizer.param_groups: 52 | param_group['lr'] = lr 53 | ``` 54 | 55 | 2. Add the scheduler class name to \_\_all__ in 56 | the [build/build_scheduler.py](https://github.com/PANPEIWEN/Infrared-Small-Target-Segmentation-Framework/blob/main/build/build_scheduler.py) 57 | . 58 | 59 | ### Add Warmup 60 | 61 | You need to perform these operations 62 | in [utils/scheduler.py](https://github.com/PANPEIWEN/Infrared-Small-Target-Segmentation-Framework/blob/main/utils/scheduler.py) 63 | . 64 | 65 | 1. Add warmup 66 | 67 | ```python 68 | def your_warmup(optimizer, args1, args2, ...): 69 | ... 70 | # Calculate the learning rate 71 | lr = ... 72 | # Update learning rate 73 | for param_group in optimizer.param_groups: 74 | param_group['lr'] = lr 75 | ``` 76 | 77 | 2. Add the scheduler function name to \_\_all__ in 78 | the [build/build_scheduler.py](https://github.com/PANPEIWEN/Infrared-Small-Target-Segmentation-Framework/blob/main/build/build_scheduler.py) 79 | . 80 | 81 | ### Modify Config File 82 | 83 | The settings of optimizer, scheduler and warmup are concentrated in 84 | the [configs/\_base_/schedules/schedule_500e.py](https://github.com/PANPEIWEN/Infrared-Small-Target-Segmentation-Framework/blob/main/configs/_base_/schedules/schedule_500e.py) 85 | file, which has other settings. Next, we 86 | will introduce the configuration file in detail. 87 | 88 | ```python 89 | """ 90 | Since no method has been found to rewrite the optimizer in pytorch, it is recommended to rewrite the optimizer 91 | dictionary in the final config file to cover it, which is only for illustration here. 92 | Please refer to docs/add_model.md for details. 93 | """ 94 | optimizer = dict( 95 | # The type must in __all__ of build/build_optimizer.py. 96 | type='SGD', 97 | # Set the parameters of the optimizer, since there is no **kwargs parameter, the parameters set here can only 98 | # be parameters common to all optimizers. 99 | # So it is recommended to rewrite the optimizer dictionary in the final configuration file to overwrite it. 100 | setting=dict(lr=0.01) 101 | ) 102 | 103 | # No practical use 104 | optimizer_config = dict() 105 | 106 | """ 107 | Choose your scheduler and warmup strategy, the policy and warmup must in __all__ of build/build_scheduler.py. 108 | The first letter is uppercase for policy, the first letter is lowercase for warmup. 109 | The parameters required by scheduler and warmup can be passed in directly by adding key-value pairs. 110 | """ 111 | lr_config = dict(policy='PolyLR', warmup='linear', args1=..., args2=..., ...) 
112 | 113 | # Number of training epochs 114 | runner = dict(type='EpochBasedRunner', max_epochs=500) 115 | # If by_epoch=True, the checkpoint is saved every interval epoch. 116 | checkpoint_config = dict(by_epoch=True, interval=1) 117 | # It has no practical effect at present, and this function will be implemented in the future. 118 | # Validate every epochval epoch. 119 | evaluation = dict(epochval=1) 120 | 121 | ``` -------------------------------------------------------------------------------- /utils/logs.py: -------------------------------------------------------------------------------- 1 | # @Time : 2022/4/6 19:04 2 | # @Author : PEIWEN PAN 3 | # @Email : 121106022690@njust.edu.cn 4 | # @File : logs.py 5 | # @Software: PyCharm 6 | from datetime import datetime 7 | 8 | import os 9 | 10 | 11 | def save_config_log(cfg, save_dir, file_name): 12 | with open('work_dirs/%s/%s/train_log.txt' % (save_dir, file_name), 'a') as f: 13 | print(cfg.pretty_text) 14 | f.write(cfg.pretty_text) 15 | f.write('\n') 16 | return 17 | 18 | def save_test_config(cfg, save_dir): 19 | with open('%s/test_log.txt' % save_dir, 'a') as f: 20 | print(cfg.pretty_text) 21 | f.write('config_file = ' + cfg.filename) 22 | f.write('\n') 23 | f.write(cfg.pretty_text) 24 | f.write('\n') 25 | return 26 | 27 | 28 | def save_train_args_log(args, save_dir): 29 | dict_args = vars(args) 30 | args_key = list(dict_args.keys()) 31 | args_value = list(dict_args.values()) 32 | with open('work_dirs/%s/train_log.txt' % save_dir, 'a') as f: 33 | now = datetime.now() 34 | f.write("time:--") 35 | dt_string = now.strftime("%Y/%m/%d %H:%M:%S ") 36 | f.write(dt_string) 37 | f.write('\n') 38 | for i in range(len(args_key)): 39 | f.write(args_key[i]) 40 | f.write(':--') 41 | f.write(str(args_value[i])) 42 | f.write('\n') 43 | f.write('\n') 44 | return 45 | 46 | 47 | def save_model_struct(save_dir, file_name, model): 48 | with open('work_dirs/%s/%s/model.txt' % (save_dir, file_name), 'a') as f: 49 | f.write(str(model)) 50 | return 51 | 52 | def save_train_log(save_dir, file_name, epoch, epochs, iter, iters, loss, lr, time): 53 | with open('work_dirs/%s/%s/train_log.txt' % (save_dir, file_name), 'a') as f: 54 | now = datetime.now() 55 | dt_string = now.strftime("%Y/%m/%d %H:%M:%S ") 56 | f.write(dt_string) 57 | f.write('Epoch: [%d/%d] Iter[%d/%d] Loss: %.4f Lr: %.5f Time: %.5f' 58 | % (epoch, epochs, iter, iters, loss, lr, time)) 59 | f.write('\n') 60 | return 61 | 62 | 63 | def save_test_log(save_dir, file_name, epoch, epochs, loss, mIoU, nIoU, f1, best_miou, best_niou, best_f1): 64 | with open('work_dirs/%s/%s/train_log.txt' % (save_dir, file_name), 'a') as f: 65 | now = datetime.now() 66 | dt_string = now.strftime("%Y/%m/%d %H:%M:%S ") 67 | f.write(dt_string) 68 | f.write('Epoch: [%d/%d] Loss: %.4f mIoU: %.4f nIoU: %.4f F1-score: %.4f ' 69 | 'Best_mIoU: %.4f Best_nIoU: %.4f Best_F1-score: %.4f' % ( 70 | epoch, epochs, loss, mIoU, nIoU, f1, best_miou, best_niou, best_f1)) 71 | f.write('\n') 72 | return 73 | 74 | 75 | def save_result_for_test(save_dir, mIoU, nIoU, recall, precision, FA, PD, f1, tp, fp): 76 | with open('%s/test_log.txt' % save_dir, 'a') as f: 77 | now = datetime.now() 78 | dt_string = now.strftime("%Y/%m/%d %H:%M:%S") 79 | f.write(dt_string) 80 | f.write('\n') 81 | f.write('mIoU: %.4f nIoU: %.4f F1-score: %.4f' % (mIoU, nIoU, f1)) 82 | f.write('\n') 83 | f.write('Recall-----:') 84 | for i in range(len(recall)): 85 | f.write(' ') 86 | f.write(str(round(recall[i], 8))) 87 | f.write(' ') 88 | f.write('\n') 89 | 
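        # like Recall above, each metric below is written as one value per ROC threshold
        # (ten by default, matching ROCMetric(1, 10) in test.py)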
f.write('Precision--:') 90 | for i in range(len(precision)): 91 | f.write(' ') 92 | f.write(str(round(precision[i], 8))) 93 | f.write(' ') 94 | f.write('\n') 95 | f.write('TP---------:') 96 | for i in range(len(tp)): 97 | f.write(' ') 98 | f.write(str(round(tp[i], 8))) 99 | f.write(' ') 100 | f.write('\n') 101 | f.write('FP---------:') 102 | for i in range(len(fp)): 103 | f.write(' ') 104 | f.write(str(round(fp[i], 8))) 105 | f.write(' ') 106 | f.write('\n') 107 | f.write('PD---------:') 108 | for i in range(len(PD)): 109 | f.write(' ') 110 | f.write(str(round(PD[i], 8))) 111 | f.write(' ') 112 | f.write('\n') 113 | f.write('FA---------:') 114 | for i in range(len(FA)): 115 | f.write(' ') 116 | f.write(str(round(FA[i], 8))) 117 | f.write(' ') 118 | f.write('\n') 119 | f.write( 120 | '---------------------------------------------------------------------------------------------------------' 121 | '---------------------------------------------------------------------------------------------------\n') 122 | f.write( 123 | '---------------------------------------------------------------------------------------------------------' 124 | '---------------------------------------------------------------------------------------------------\n') 125 | f.write( 126 | '---------------------------------------------------------------------------------------------------------' 127 | '---------------------------------------------------------------------------------------------------\n') 128 | return 129 | 130 | 131 | def make_dir(config): 132 | save_dir = config 133 | os.makedirs('work_dirs/%s' % save_dir, exist_ok=True) 134 | return save_dir 135 | 136 | 137 | def make_log_dir(config, log_file): 138 | os.makedirs('work_dirs/%s' % config, exist_ok=True) 139 | os.makedirs('work_dirs/%s/%s' % (config, log_file), exist_ok=True) 140 | 141 | 142 | def train_log_file(): 143 | now = datetime.now() 144 | dt_string = now.strftime("%Y%m%d_%H%M%S") 145 | file_name = dt_string 146 | return file_name 147 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # @Time : 2023/6/16 16:36 2 | # @Author : PEIWEN PAN 3 | # @Email : 121106022690@njust.edu.cn 4 | # @File : rebuild_train.py 5 | # @Software: PyCharm 6 | import argparse 7 | import os 8 | import time 9 | 10 | import torch.distributed 11 | import torch.nn 12 | from mmcv import Config, DictAction 13 | 14 | from utils.tools import * 15 | 16 | from build.build_model import build_model 17 | from build.build_criterion import build_criterion 18 | from build.build_optimizer import build_optimizer 19 | from build.build_dataset import build_dataset 20 | from build.build_scheduler import build_scheduler 21 | 22 | 23 | def parse_args(): 24 | parser = argparse.ArgumentParser(description='Train a segmentor') 25 | parser.add_argument('config', help='train config file path') 26 | parser.add_argument( 27 | '--load-from', help='the checkpoint file to load weights from') 28 | parser.add_argument( 29 | '--resume-from', help='the checkpoint file to resume from') 30 | group_gpus = parser.add_mutually_exclusive_group() 31 | group_gpus.add_argument( 32 | '--gpu-id', 33 | type=int, 34 | default=0, 35 | help='id of gpu to use ' 36 | '(only applicable to non-distributed training)') 37 | parser.add_argument( 38 | '--cfg-options', 39 | nargs='+', 40 | action=DictAction, 41 | help='override some settings in the used config, the key-value pair ' 42 | 'in xxx=yyy format will be 
merged into config file. If the value to ' 43 | 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' 44 | 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' 45 | 'Note that the quotation marks are necessary and that no white space ' 46 | 'is allowed.') 47 | parser.add_argument('--local_rank', type=int, default=-1) 48 | args = parser.parse_args() 49 | if 'LOCAL_RANK' not in os.environ: 50 | os.environ['LOCAL_RANK'] = str(args.local_rank) 51 | return args 52 | 53 | 54 | class Train(object): 55 | def __init__(self, args, cfg): 56 | super(Train, self).__init__() 57 | self.cfg = cfg 58 | self.cfg.gpus = torch.cuda.device_count() if args.local_rank != -1 else 1 59 | self.resume = args.resume_from 60 | self.deep_supervision = 'deep_supervision' in self.cfg.model['decode_head'] 61 | 62 | self.device = init_devices(args, self.cfg) 63 | 64 | data = build_dataset(args, self.cfg) 65 | self.data = init_data(args, data) 66 | 67 | model = build_model(self.cfg) 68 | self.model, checkpoint = init_model(args, self.cfg, model, self.device) 69 | self.criterion = build_criterion(self.cfg) 70 | optimizer = build_optimizer(self.model, self.cfg) 71 | if self.cfg.lr_config['policy']: 72 | self.scheduler = build_scheduler(optimizer, self.cfg) 73 | 74 | self.optimizer, self.metrics = init_metrics(args, optimizer, checkpoint if args.resume_from else None) 75 | self.save_dir, self.train_log_file, self.write = save_log(args, self.cfg, self.model) 76 | 77 | def training(self, epoch): 78 | self.model.train() 79 | losses = [] 80 | if args.local_rank != -1: 81 | self.data['train_sample'].set_epoch(epoch) 82 | if not self.resume and self.cfg.lr_config['policy']: 83 | self.scheduler.step(epoch - 1) 84 | 85 | for i, data in enumerate(self.data['train_data']): 86 | since = time.time() 87 | img, mask = data2device(args, data, self.device) 88 | preds = self.model(img) 89 | loss, _ = compute_loss(preds, mask, self.deep_supervision, self.cfg, self.criterion) 90 | self.optimizer.zero_grad() 91 | loss.backward() 92 | self.optimizer.step() 93 | losses.append(loss.item()) 94 | time_elapsed = time.time() - since 95 | show_log('train', args, self.cfg, epoch, losses, self.save_dir, self.train_log_file, i=i, data=self.data,time_elapsed=time_elapsed, optimizer=self.optimizer) 96 | save_model('train', args, self.cfg, epoch, self.model, losses, self.optimizer, self.metrics, self.save_dir, self.train_log_file) 97 | update_log('train', args, self.metrics, self.write, losses, epoch, optimizer=self.optimizer) 98 | 99 | def testing(self, epoch): 100 | self.model.eval() 101 | reset_metrics(self.metrics) 102 | eval_losses = [] 103 | with torch.no_grad(): 104 | for i, data in enumerate(self.data['test_data']): 105 | img, mask = data2device(args, data, self.device) 106 | preds = self.model(img) 107 | loss, preds = compute_loss(preds, mask, self.deep_supervision, self.cfg, self.criterion) 108 | eval_losses.append(loss.item()) 109 | IoU, nIoU, F1_score = update_metrics(preds, mask, self.metrics) 110 | show_log('test', args, self.cfg, epoch, eval_losses, self.save_dir, self.train_log_file, IoU=IoU, nIoU=nIoU,F1_score=F1_score, metrics=self.metrics) 111 | append_metrics(args, self.metrics, eval_losses, IoU, nIoU, F1_score) 112 | save_model('test', args, self.cfg, epoch, self.model, eval_losses, self.optimizer, self.metrics, 113 | self.save_dir, self.train_log_file, IoU=IoU, nIoU=nIoU) 114 | draw(args, self.metrics, self.save_dir, self.train_log_file) 115 | update_log('test', args, self.metrics, self.write, eval_losses, 
epoch, IoU=IoU, nIoU=nIoU, 116 | F1_score=F1_score) 117 | 118 | 119 | def main(args): 120 | cfg = Config.fromfile(args.config) 121 | trainer = Train(args, cfg) 122 | if args.local_rank != -1: 123 | torch.distributed.barrier() 124 | start = torch.load(args.resume_from)['epoch'] + 1 if args.resume_from else 1 125 | for i in range(start, cfg.runner['max_epochs'] + 1): 126 | trainer.training(i) 127 | trainer.testing(i) 128 | 129 | 130 | if __name__ == '__main__': 131 | args = parse_args() 132 | main(args) 133 | -------------------------------------------------------------------------------- /model/ACM/fusion.py: -------------------------------------------------------------------------------- 1 | # @Time : 2022/9/21 11:08 2 | # @Author : PEIWEN PAN 3 | # @Email : 121106022690@njust.edu.cn 4 | # @File : fusion.py 5 | # @Software: PyCharm 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | class BiLocalChaFuseReduce(nn.Module): 11 | def __init__(self, in_high_channels, in_low_channels, out_channels=64, r=4): 12 | super(BiLocalChaFuseReduce, self).__init__() 13 | 14 | assert in_low_channels == out_channels 15 | self.high_channels = in_high_channels 16 | self.low_channels = in_low_channels 17 | self.out_channels = out_channels 18 | self.bottleneck_channels = int(out_channels // r) 19 | 20 | self.feature_high = nn.Sequential( 21 | nn.Conv2d(self.high_channels, self.out_channels, 1, 1, 0), 22 | nn.BatchNorm2d(self.out_channels), 23 | nn.ReLU(True), 24 | ) 25 | 26 | self.topdown = nn.Sequential( 27 | nn.Conv2d(self.out_channels, self.bottleneck_channels, 1, 1, 0), 28 | nn.BatchNorm2d(self.bottleneck_channels), 29 | nn.ReLU(True), 30 | 31 | nn.Conv2d(self.bottleneck_channels, self.out_channels, 1, 1, 0), 32 | nn.BatchNorm2d(self.out_channels), 33 | nn.Sigmoid() 34 | ) 35 | 36 | self.bottomup = nn.Sequential( 37 | nn.Conv2d(self.low_channels, self.bottleneck_channels, 1, 1, 0), 38 | nn.BatchNorm2d(self.bottleneck_channels), 39 | nn.ReLU(True), 40 | 41 | nn.Conv2d(self.bottleneck_channels, self.out_channels, 1, 1, 0), 42 | nn.BatchNorm2d(self.out_channels), 43 | nn.Sigmoid(), 44 | ) 45 | 46 | self.post = nn.Sequential( 47 | nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1), 48 | nn.BatchNorm2d(self.out_channels), 49 | nn.ReLU(True), 50 | ) 51 | 52 | def forward(self, xh, xl): 53 | xh = self.feature_high(xh) 54 | topdown_wei = self.topdown(xh) 55 | bottomup_wei = self.bottomup(xl) 56 | 57 | out = 2 * xl * topdown_wei + 2 * xh * bottomup_wei 58 | out = self.post(out) 59 | return out 60 | 61 | 62 | class AsymBiChaFuseReduce(nn.Module): 63 | def __init__(self, in_high_channels, in_low_channels, out_channels=64, r=4): 64 | super(AsymBiChaFuseReduce, self).__init__() 65 | 66 | assert in_low_channels == out_channels 67 | self.high_channels = in_high_channels 68 | self.low_channels = in_low_channels 69 | self.out_channels = out_channels 70 | self.bottleneck_channels = int(out_channels // r) 71 | 72 | self.feature_high = nn.Sequential( 73 | nn.Conv2d(self.high_channels, self.out_channels, 1, 1, 0), 74 | nn.BatchNorm2d(out_channels), 75 | nn.ReLU(True), 76 | ) 77 | 78 | self.topdown = nn.Sequential( 79 | nn.AdaptiveAvgPool2d((1, 1)), 80 | nn.Conv2d(self.out_channels, self.bottleneck_channels, 1, 1, 0), 81 | nn.BatchNorm2d(self.bottleneck_channels), 82 | nn.ReLU(True), 83 | 84 | nn.Conv2d(self.bottleneck_channels, self.out_channels, 1, 1, 0), 85 | nn.BatchNorm2d(self.out_channels), 86 | nn.Sigmoid(), 87 | ) 88 | 89 | self.bottomup = nn.Sequential( 90 | nn.Conv2d(self.low_channels, 
self.bottleneck_channels, 1, 1, 0), 91 | nn.BatchNorm2d(self.bottleneck_channels), 92 | nn.ReLU(True), 93 | 94 | nn.Conv2d(self.bottleneck_channels, self.out_channels, 1, 1, 0), 95 | nn.BatchNorm2d(self.out_channels), 96 | nn.Sigmoid(), 97 | ) 98 | 99 | self.post = nn.Sequential( 100 | nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1), 101 | nn.BatchNorm2d(self.out_channels), 102 | nn.ReLU(True), 103 | ) 104 | 105 | def forward(self, xh, xl): 106 | xh = self.feature_high(xh) 107 | 108 | topdown_wei = self.topdown(xh) 109 | bottomup_wei = self.bottomup(xl) 110 | xs = 2 * xl * topdown_wei + 2 * xh * bottomup_wei 111 | out = self.post(xs) 112 | return out 113 | 114 | 115 | class BiGlobalChaFuseReduce(nn.Module): 116 | def __init__(self, in_high_channels, in_low_channels, out_channels=64, r=4): 117 | super(BiGlobalChaFuseReduce, self).__init__() 118 | 119 | assert in_low_channels == out_channels 120 | self.high_channels = in_high_channels 121 | self.low_channels = in_low_channels 122 | self.out_channels = out_channels 123 | self.bottleneck_channels = int(out_channels // r) 124 | 125 | self.feature_high = nn.Sequential( 126 | nn.Conv2d(self.high_channels, self.out_channels, 1, 1, 0), 127 | nn.BatchNorm2d(out_channels), 128 | nn.ReLU(True), 129 | ) 130 | 131 | self.topdown = nn.Sequential( 132 | nn.AdaptiveAvgPool2d((1, 1)), 133 | nn.Conv2d(self.out_channels, self.bottleneck_channels, 1, 1, 0), 134 | nn.BatchNorm2d(self.bottleneck_channels), 135 | nn.ReLU(True), 136 | 137 | nn.Conv2d(self.bottleneck_channels, self.out_channels, 1, 1, 0), 138 | nn.BatchNorm2d(self.out_channels), 139 | nn.Sigmoid(), 140 | ) 141 | 142 | self.bottomup = nn.Sequential( 143 | nn.AdaptiveAvgPool2d((1, 1)), 144 | nn.Conv2d(self.low_channels, self.bottleneck_channels, 1, 1, 0), 145 | nn.BatchNorm2d(self.bottleneck_channels), 146 | nn.ReLU(True), 147 | 148 | nn.Conv2d(self.bottleneck_channels, self.out_channels, 1, 1, 0), 149 | nn.BatchNorm2d(self.out_channels), 150 | nn.Sigmoid(), 151 | ) 152 | 153 | self.post = nn.Sequential( 154 | nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1), 155 | nn.BatchNorm2d(self.out_channels), 156 | nn.ReLU(True), 157 | ) 158 | 159 | def forward(self, xh, xl): 160 | xh = self.feature_high(xh) 161 | 162 | topdown_wei = self.topdown(xh) 163 | bottomup_wei = self.bottomup(xl) 164 | xs = 2 * xl * topdown_wei + 2 * xh * bottomup_wei 165 | out = self.post(xs) 166 | return out 167 | -------------------------------------------------------------------------------- /model/ABC/ABCNet.py: -------------------------------------------------------------------------------- 1 | # @Time : 2023/3/17 15:56 2 | # @Author : PEIWEN PAN 3 | # @Email : 121106022690@njust.edu.cn 4 | # @File : ABCNet.py 5 | # @Software: PyCharm 6 | import torch 7 | import torch.nn as nn 8 | from model.ABC.Module import conv_block, up_conv, _upsample_like, conv_relu_bn, dconv_block 9 | from einops import rearrange 10 | 11 | 12 | class Attention(nn.Module): 13 | def __init__(self, in_dim, in_feature, out_feature): 14 | super(Attention, self).__init__() 15 | self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=1, kernel_size=1) 16 | self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=1, kernel_size=1) 17 | self.query_line = nn.Linear(in_features=in_feature, out_features=out_feature) 18 | self.key_line = nn.Linear(in_features=in_feature, out_features=out_feature) 19 | self.s_conv = nn.Conv2d(in_channels=1, out_channels=in_dim, kernel_size=1) 20 | self.softmax = nn.Softmax(dim=-1) 21 | 22 | def forward(self, x): 
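# What follows: x is compressed to single-channel query/key maps, each flattened
# map is projected from in_feature to out_feature, and their outer product forms
# an (out_feature x out_feature) attention map that s_conv expands back to in_dim
# channels before the softmax.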
23 | q = rearrange(self.query_line(rearrange(self.query_conv(x), 'b 1 h w -> b (h w)')), 'b h -> b h 1') 24 | k = rearrange(self.key_line(rearrange(self.key_conv(x), 'b 1 h w -> b (h w)')), 'b h -> b 1 h') 25 | att = rearrange(torch.matmul(q, k), 'b h w -> b 1 h w') 26 | att = self.softmax(self.s_conv(att)) 27 | return att 28 | 29 | 30 | class Conv(nn.Module): 31 | def __init__(self, in_dim): 32 | super(Conv, self).__init__() 33 | self.convs = nn.ModuleList([conv_relu_bn(in_dim, in_dim, 1) for _ in range(3)]) 34 | 35 | def forward(self, x): 36 | for conv in self.convs: 37 | x = conv(x) 38 | return x 39 | 40 | 41 | class DConv(nn.Module): 42 | def __init__(self, in_dim): 43 | super(DConv, self).__init__() 44 | dilation = [2, 4, 2] 45 | self.dconvs = nn.ModuleList([conv_relu_bn(in_dim, in_dim, dirate) for dirate in dilation]) 46 | 47 | def forward(self, x): 48 | for dconv in self.dconvs: 49 | x = dconv(x) 50 | return x 51 | 52 | 53 | class ConvAttention(nn.Module): 54 | def __init__(self, in_dim, in_feature, out_feature): 55 | super(ConvAttention, self).__init__() 56 | self.conv = Conv(in_dim) 57 | self.dconv = DConv(in_dim) 58 | self.att = Attention(in_dim, in_feature, out_feature) 59 | self.gamma = nn.Parameter(torch.zeros(1)) 60 | 61 | def forward(self, x): 62 | q = self.conv(x) 63 | k = self.dconv(x) 64 | v = q + k 65 | att = self.att(x) 66 | out = torch.matmul(att, v) 67 | return self.gamma * out + v + x 68 | 69 | 70 | class FeedForward(nn.Module): 71 | def __init__(self, in_dim, out_dim): 72 | super(FeedForward, self).__init__() 73 | self.conv = conv_relu_bn(in_dim, out_dim, 1) 74 | # self.x_conv = nn.Conv2d(in_dim, out_dim, kernel_size=1) 75 | self.x_conv = nn.Sequential( 76 | nn.Conv2d(in_dim, out_dim, kernel_size=1), 77 | nn.BatchNorm2d(out_dim), 78 | nn.ReLU(inplace=True) 79 | ) 80 | 81 | def forward(self, x): 82 | out = self.conv(x) 83 | x = self.x_conv(x) 84 | return x + out 85 | 86 | 87 | class ConvTransformer(nn.Module): 88 | def __init__(self, in_dim, out_dim, in_feature, out_feature): 89 | super(ConvTransformer, self).__init__() 90 | self.attention = ConvAttention(in_dim, in_feature, out_feature) 91 | self.feedforward = FeedForward(in_dim, out_dim) 92 | 93 | def forward(self, x): 94 | x = self.attention(x) 95 | out = self.feedforward(x) 96 | return out 97 | 98 | 99 | class ABCNet(nn.Module): 100 | def __init__(self, in_ch=3, out_ch=1, dim=64, ori_h=256, deep_supervision=True, **kwargs): 101 | super(ABCNet, self).__init__() 102 | self.deep_supervision = deep_supervision 103 | filters = [dim, dim * 2, dim * 4, dim * 8, dim * 16] 104 | features = [ori_h // 2, ori_h // 4, ori_h // 8, ori_h // 16] 105 | self.maxpools = nn.ModuleList([nn.MaxPool2d(kernel_size=2, stride=2) for _ in range(4)]) 106 | self.Conv1 = conv_block(in_ch=in_ch, out_ch=filters[0]) 107 | # self.Conv1 = ConvTransformer(in_ch, filters[0], pow(ori_h, 2), ori_h) 108 | self.Convtans2 = ConvTransformer(filters[0], filters[1], pow(features[0], 2), features[0]) 109 | self.Convtans3 = ConvTransformer(filters[1], filters[2], pow(features[1], 2), features[1]) 110 | self.Convtans4 = ConvTransformer(filters[2], filters[3], pow(features[2], 2), features[2]) 111 | self.Conv5 = dconv_block(in_ch=filters[3], out_ch=filters[4]) 112 | 113 | self.Up5 = up_conv(filters[4], filters[3]) 114 | self.Up_conv5 = dconv_block(filters[4], filters[3]) 115 | 116 | self.Up4 = up_conv(filters[3], filters[2]) 117 | self.Up_conv4 = conv_block(filters[3], filters[2]) 118 | 119 | self.Up3 = up_conv(filters[2], filters[1]) 120 | self.Up_conv3 = 
conv_block(filters[2], filters[1]) 121 | 122 | self.Up2 = up_conv(filters[1], filters[0]) 123 | self.Up_conv2 = conv_block(filters[1], filters[0]) 124 | 125 | self.Conv = nn.Conv2d(filters[0], out_ch, kernel_size=1, stride=1, padding=0) 126 | 127 | # -------------------------------------------------------------------------------------------------------------- 128 | self.conv5 = nn.Conv2d(filters[4], out_ch, kernel_size=3, stride=1, padding=1) 129 | self.conv4 = nn.Conv2d(filters[3], out_ch, kernel_size=3, stride=1, padding=1) 130 | self.conv3 = nn.Conv2d(filters[2], out_ch, kernel_size=3, stride=1, padding=1) 131 | self.conv2 = nn.Conv2d(filters[1], out_ch, kernel_size=3, stride=1, padding=1) 132 | self.conv1 = nn.Conv2d(filters[0], out_ch, kernel_size=3, stride=1, padding=1) 133 | # -------------------------------------------------------------------------------------------------------------- 134 | 135 | def forward(self, x): 136 | e1 = self.Conv1(x) 137 | 138 | e2 = self.maxpools[0](e1) 139 | e2 = self.Convtans2(e2) 140 | 141 | e3 = self.maxpools[1](e2) 142 | e3 = self.Convtans3(e3) 143 | 144 | e4 = self.maxpools[2](e3) 145 | e4 = self.Convtans4(e4) 146 | 147 | e5 = self.maxpools[3](e4) 148 | e5 = self.Conv5(e5) 149 | 150 | d5 = self.Up5(e5) 151 | d5 = torch.cat((e4, d5), dim=1) 152 | d5 = self.Up_conv5(d5) 153 | 154 | d4 = self.Up4(d5) 155 | d4 = torch.cat((e3, d4), dim=1) 156 | d4 = self.Up_conv4(d4) 157 | 158 | d3 = self.Up3(d4) 159 | d3 = torch.cat((e2, d3), dim=1) 160 | d3 = self.Up_conv3(d3) 161 | 162 | d2 = self.Up2(d3) 163 | d2 = torch.cat((e1, d2), dim=1) 164 | d2 = self.Up_conv2(d2) 165 | 166 | out = self.Conv(d2) 167 | 168 | d_s1 = self.conv1(d2) 169 | d_s2 = self.conv2(d3) 170 | d_s2 = _upsample_like(d_s2, d_s1) 171 | d_s3 = self.conv3(d4) 172 | d_s3 = _upsample_like(d_s3, d_s1) 173 | d_s4 = self.conv4(d5) 174 | d_s4 = _upsample_like(d_s4, d_s1) 175 | d_s5 = self.conv5(e5) 176 | d_s5 = _upsample_like(d_s5, d_s1) 177 | if self.deep_supervision: 178 | outs = [d_s1, d_s2, d_s3, d_s4, d_s5, out] 179 | else: 180 | outs = out 181 | # d1 = self.active(out) 182 | 183 | return outs 184 | 185 | 186 | if __name__ == '__main__': 187 | x = torch.randn(8, 3, 256, 256) 188 | model = ABCNet(ori_h=256) 189 | print(model(x)) 190 | 191 | -------------------------------------------------------------------------------- /model/URANet/uranet.py: -------------------------------------------------------------------------------- 1 | # @Time : 2022/6/10 16:13 2 | # @Author : PEIWEN PAN 3 | # @Email : 121106022690@njust.edu.cn 4 | # @File : uranet.py 5 | # @Software: PyCharm 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class CDC_conv(nn.Module): 13 | def __init__(self, in_channels, out_channels, bias=True, kernel_size=3, padding=1, dilation=1, theta=0.0): 14 | super().__init__() 15 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding, dilation=dilation, 16 | bias=bias) 17 | self.theta = theta 18 | 19 | def forward(self, x): 20 | norm_out = self.conv(x) 21 | [c_out, c_in, kernel_size, kernel_size] = self.conv.weight.shape 22 | kernel_diff = self.conv.weight.sum(2).sum(2) 23 | kernel_diff = kernel_diff[:, :, None, None] 24 | diff_out = F.conv2d(input=x, weight=kernel_diff, bias=self.conv.bias, stride=self.conv.stride, padding=0) 25 | out = norm_out - self.theta * diff_out 26 | return out 27 | 28 | 29 | class Layernorm(nn.Module): 30 | def __init__(self, in_c): 31 | super().__init__() 32 |
self.layernorm = nn.LayerNorm(in_c) 33 | 34 | def forward(self, x): 35 | x = x.permute(0, 2, 3, 1) 36 | x = self.layernorm(x) 37 | x = x.permute(0, 3, 1, 2) 38 | return x 39 | 40 | 41 | class ResidualBlock(nn.Module): 42 | def __init__(self, in_c, out_c, theta=0.7, norm=nn.BatchNorm2d): 43 | super().__init__() 44 | self.conv_block = nn.Sequential( 45 | CDC_conv(in_c, out_c, kernel_size=3, padding=1, theta=theta, bias=False if norm == nn.BatchNorm2d else True), 46 | norm(out_c), 47 | nn.ReLU(inplace=True), 48 | CDC_conv(out_c, out_c, kernel_size=3, padding=1, theta=theta, bias=False if norm == nn.BatchNorm2d else True), 49 | norm(out_c), 50 | ) 51 | self.residual_block = nn.Sequential( 52 | nn.Conv2d(in_c, out_c, kernel_size=1, bias=False), 53 | nn.BatchNorm2d(out_c) 54 | ) 55 | self.relu = nn.ReLU(inplace=True) 56 | 57 | def forward(self, x): 58 | conv_out = self.conv_block(x) 59 | residual_out = self.residual_block(x) 60 | out = self.relu(conv_out + residual_out) 61 | return out 62 | 63 | 64 | class UpsampleBlock(nn.Module): 65 | def __init__(self, in_c, out_c, bilinear=True): 66 | super().__init__() 67 | if bilinear: 68 | self.up_block = nn.Sequential( 69 | nn.Conv2d(in_c, out_c, kernel_size=1, stride=1, padding=0, bias=False), 70 | nn.BatchNorm2d(out_c), 71 | nn.ReLU(inplace=True) 72 | ) 73 | else: 74 | self.up_block = nn.Sequential( 75 | nn.ConvTranspose2d(in_channels=in_c, out_channels=out_c, kernel_size=3, stride=2, padding=1, 76 | bias=False), 77 | nn.BatchNorm2d(out_c), 78 | nn.ReLU(inplace=True) 79 | ) 80 | 81 | def forward(self, x, lateral): 82 | size = lateral.shape[-2:] 83 | x = self.up_block(x) 84 | x = F.interpolate(x, size=size, mode='bilinear', align_corners=False) 85 | out = x + lateral 86 | return out 87 | 88 | 89 | class Position_attention(nn.Module): 90 | def __init__(self, in_c, mid_c=None): 91 | super().__init__() 92 | mid_c = mid_c or in_c // 8 93 | self.q = nn.Conv2d(in_c, mid_c, kernel_size=1) 94 | self.k = nn.Conv2d(in_c, mid_c, kernel_size=1) 95 | self.v = nn.Conv2d(in_c, in_c, kernel_size=1) 96 | self.gamma = nn.Parameter(torch.zeros(1)) 97 | self.softmax = nn.Softmax(dim=-1) 98 | 99 | def forward(self, x): 100 | b, _, h, w = x.shape 101 | q = self.q(x).view(b, -1, h * w).permute(0, 2, 1) # bs, hw, c 102 | k = self.k(x).view(b, -1, h * w) # bs, c, hw 103 | v = self.v(x).view(b, -1, h * w) # bs, c, hw 104 | att = self.softmax(q @ k) 105 | out = (v @ att.permute(0, 2, 1)).view(b, -1, h, w) 106 | out = self.gamma * out + x 107 | 108 | return out 109 | 110 | 111 | class Channel_attention(nn.Module): 112 | def __init__(self, in_c): 113 | super().__init__() 114 | self.in_c = in_c 115 | self.softmax = nn.Softmax(dim=-1) 116 | self.gamma = nn.Parameter(torch.zeros(1)) 117 | 118 | def forward(self, x): 119 | b, _, h, w = x.shape 120 | q = x.view(b, -1, h * w) # bs, c, hw 121 | k = x.view(b, -1, h * w).permute(0, 2, 1) # bs, hw, c 122 | v = x.view(b, -1, h * w) # bs, c, hw 123 | att = self.softmax(q @ k) # b, c, c 124 | out = att @ v 125 | out = out.view(b, -1, h, w) 126 | out = self.gamma * out + x 127 | return out 128 | 129 | 130 | class Double_attention(nn.Module): 131 | def __init__(self, in_c, mid_c=None): 132 | super().__init__() 133 | self.pam = Position_attention(in_c, mid_c) 134 | self.cam = Channel_attention(in_c) 135 | self.relu = nn.ReLU() 136 | 137 | def forward(self, x): 138 | pam_out = self.pam(x) 139 | cam_out = self.cam(x) 140 | return pam_out + cam_out 141 | 142 | 143 | class URANet(nn.Module): 144 | def __init__(self, in_channel=3, base_dim=32,
class_num=1, bilinear=True, use_da=True, theta=0.7, 145 | norm=nn.BatchNorm2d, **kwargs): 146 | super(URANet, self).__init__() 147 | self.norm = norm 148 | self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2) 149 | self.conv1 = nn.Sequential( 150 | # nn.Conv2d(in_c, base_dim, kernel_size=3, padding=1, bias=False), 151 | CDC_conv(in_channel, base_dim, bias=False, theta=theta), 152 | nn.BatchNorm2d(base_dim), 153 | nn.ReLU(inplace=True), 154 | ) 155 | self.conv2 = nn.Sequential( 156 | # nn.Conv2d(base_dim, base_dim, kernel_size=3, padding=1, bias=False), 157 | CDC_conv(base_dim, base_dim, bias=False, theta=theta), 158 | nn.BatchNorm2d(base_dim), 159 | nn.ReLU(inplace=True), 160 | ) 161 | self.layer1 = ResidualBlock(base_dim, base_dim * 2, theta=theta, norm=self.norm) 162 | self.layer2 = ResidualBlock(base_dim * 2, base_dim * 4, theta=theta, norm=self.norm) 163 | self.layer3 = ResidualBlock(base_dim * 4, base_dim * 8, theta=theta, norm=self.norm) 164 | self.layer4 = ResidualBlock(base_dim * 8, base_dim * 16, theta=theta, norm=self.norm) 165 | self.da = Double_attention(base_dim * 16, None) if use_da else nn.Identity() 166 | self.up3 = UpsampleBlock(base_dim * 16, base_dim * 8, bilinear=bilinear) 167 | self.up2 = UpsampleBlock(base_dim * 8, base_dim * 4, bilinear=bilinear) 168 | self.up1 = UpsampleBlock(base_dim * 4, base_dim * 2, bilinear=bilinear) 169 | self.up0 = UpsampleBlock(base_dim * 2, base_dim, bilinear=bilinear) 170 | self.last_conv = nn.Conv2d(base_dim, class_num, kernel_size=1, stride=1) 171 | 172 | def forward(self, x): 173 | out_0 = self.conv1(x) 174 | out_0 = self.conv2(out_0) 175 | out_1 = self.layer1(out_0) 176 | out_2 = self.layer2(self.maxpool(out_1)) 177 | out_3 = self.layer3(self.maxpool(out_2)) 178 | out_4 = self.layer4(self.maxpool(out_3)) 179 | out_da = self.da(out_4) 180 | up_3 = self.up3(out_da, out_3) 181 | up_2 = self.up2(up_3, out_2) 182 | up_1 = self.up1(up_2, out_1) 183 | up_0 = self.up0(up_1, out_0) 184 | out = self.last_conv(up_0) 185 | return out 186 | 187 | 188 | def main(): 189 | x = torch.rand(2, 3, 512, 512) 190 | net = URANet() 191 | out = net(x) 192 | print(out.shape) 193 | 194 | 195 | if __name__ == '__main__': 196 | main() 197 | -------------------------------------------------------------------------------- /model/MTUet/mtu_uet.py: -------------------------------------------------------------------------------- 1 | # @Time : 2024/2/6 13:34 2 | # @Author : PEIWEN PAN 3 | # @Email : 121106022690@njust.edu.cn 4 | # @File : mtu_uet.py 5 | # @Software: PyCharm 6 | import torch 7 | import torch.nn as nn 8 | 9 | from model.MTUet.vit import ViT 10 | 11 | from einops import rearrange 12 | 13 | 14 | class _FCNHead(nn.Module): 15 | # pylint: disable=redefined-outer-name 16 | def __init__(self, in_channels, channels, momentum, norm_layer=nn.BatchNorm2d, norm_kwargs=None, **kwargs): 17 | super(_FCNHead, self).__init__() 18 | inter_channels = in_channels // 4 19 | self.block = nn.Sequential( 20 | nn.Conv2d(in_channels=in_channels, out_channels=inter_channels, kernel_size=3, padding=1, bias=False), 21 | norm_layer(inter_channels, momentum=momentum), 22 | nn.ReLU(inplace=True), 23 | nn.Dropout(0.1), 24 | nn.Conv2d(in_channels=inter_channels, out_channels=channels, kernel_size=1) 25 | ) 26 | 27 | # pylint: disable=arguments-differ 28 | def forward(self, x): 29 | return self.block(x) 30 | 31 | 32 | class Res_block(nn.Module): 33 | def __init__(self, in_channels, out_channels, stride=1): 34 | super(Res_block, self).__init__() 35 | self.conv1 = nn.Conv2d(in_channels,
out_channels, kernel_size=3, stride=stride, padding=1) 36 | self.bn1 = nn.BatchNorm2d(out_channels) 37 | self.relu = nn.LeakyReLU(inplace=True) 38 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1) 39 | self.bn2 = nn.BatchNorm2d(out_channels) 40 | # self.fca = FCA_Layer(out_channels) 41 | if stride != 1 or out_channels != in_channels: 42 | self.shortcut = nn.Sequential( 43 | nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride), 44 | nn.BatchNorm2d(out_channels)) 45 | else: 46 | self.shortcut = None 47 | 48 | def forward(self, x): 49 | residual = x 50 | if self.shortcut is not None: 51 | residual = self.shortcut(x) 52 | out = self.conv1(x) 53 | out = self.bn1(out) 54 | out = self.relu(out) 55 | out = self.conv2(out) 56 | out = self.bn2(out) 57 | 58 | out += residual 59 | out = self.relu(out) 60 | return out 61 | 62 | 63 | class MTUNet(nn.Module): 64 | def __init__(self, num_classes=1, input_channels=3, block=Res_block, num_blocks=(2, 2, 2, 2), 65 | nb_filter=(16, 32, 64, 128, 256), **kwargs): 66 | super(MTUNet, self).__init__() 67 | 68 | self.pool = nn.MaxPool2d(2, 2) 69 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 70 | self.up4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True) 71 | self.up8 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True) 72 | self.up16 = nn.Upsample(scale_factor=16, mode='bilinear', align_corners=True) 73 | 74 | self.conv0_0_0 = self._make_layer(block, input_channels, input_channels) 75 | self.conv0_0 = self._make_layer(block, input_channels, nb_filter[0]) 76 | # for name,value in self.conv0_0.named_parameters(): 77 | # value.requires_grad = False 78 | self.conv1_0 = self._make_layer(block, nb_filter[0], nb_filter[1], num_blocks[0]) 79 | # for name,value in self.conv1_0.named_parameters(): 80 | # value.requires_grad = False 81 | self.conv2_0 = self._make_layer(block, nb_filter[1], nb_filter[2], num_blocks[1]) 82 | # for name,value in self.conv2_0.named_parameters(): 83 | # value.requires_grad = False 84 | self.conv3_0 = self._make_layer(block, nb_filter[2], nb_filter[3], num_blocks[2]) 85 | # for name,value in self.conv3_0.named_parameters(): 86 | # value.requires_grad = False 87 | self.conv4_0 = self._make_layer(block, nb_filter[3], nb_filter[4], num_blocks[3]) 88 | # for name,value in self.conv4_0.named_parameters(): 89 | # value.requires_grad = False 90 | 91 | self.vit5 = ViT(img_dim=1024, in_channels=nb_filter[0], embedding_dim=nb_filter[2], head_num=1, mlp_dim=64 * 64, 92 | block_num=1, patch_dim=16, classification=False, num_classes=1) 93 | 94 | self.vit4 = ViT(img_dim=512, in_channels=nb_filter[1], embedding_dim=nb_filter[2], head_num=1, mlp_dim=64 * 64, 95 | block_num=1, patch_dim=8, classification=False, num_classes=1) 96 | self.vit3 = ViT(img_dim=256, in_channels=nb_filter[2], embedding_dim=nb_filter[2], head_num=1, mlp_dim=64 * 64, 97 | block_num=1, patch_dim=4, classification=False, num_classes=1) 98 | self.vit2 = ViT(img_dim=128, in_channels=nb_filter[3], embedding_dim=nb_filter[2], head_num=1, mlp_dim=64 * 64, 99 | block_num=1, patch_dim=2, classification=False, num_classes=1) 100 | self.vit1 = ViT(img_dim=64, in_channels=nb_filter[4], embedding_dim=nb_filter[4], head_num=1, mlp_dim=64 * 64, 101 | block_num=1, patch_dim=1, classification=False, num_classes=1) 102 | 103 | # self.conv4_1 = self._make_layer(block, nb_filter[4] + nb_filter[4], nb_filter[4]) 104 | 105 | self.conv3_1_1 = self._make_layer(block, 106 | nb_filter[2] + nb_filter[2] + 
nb_filter[2] + nb_filter[2] + nb_filter[4], 107 | nb_filter[4]) 108 | 109 | self.conv3_1 = self._make_layer(block, nb_filter[3] + nb_filter[4], nb_filter[3]) 110 | # for name,value in self.conv3_1.named_parameters(): 111 | # value.requires_grad = False 112 | self.conv2_2 = self._make_layer(block, nb_filter[2] + nb_filter[3], nb_filter[2]) 113 | # for name,value in self.conv2_2.named_parameters(): 114 | # value.requires_grad = False 115 | self.conv1_3 = self._make_layer(block, nb_filter[1] + nb_filter[2], nb_filter[1]) 116 | # for name,value in self.conv1_3.named_parameters(): 117 | # value.requires_grad = False 118 | self.conv0_4 = self._make_layer(block, nb_filter[0] + nb_filter[1], nb_filter[0]) 119 | # for name,value in self.conv0_4.named_parameters(): 120 | # value.requires_grad = False 121 | 122 | # self.final1 = nn.Conv2d(nb_filter[0]+3, num_classes, kernel_size=1) 123 | self.head = _FCNHead(nb_filter[0], channels=num_classes, momentum=0.9) 124 | 125 | # self.final2 = nn.BatchNorm2d(num_classes) 126 | self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1) 127 | # for name,value in self.final.named_parameters(): 128 | # value.requires_grad = False 129 | 130 | def _make_layer(self, block, input_channels, output_channels, num_blocks=1): 131 | layers = [] 132 | layers.append(block(input_channels, output_channels)) 133 | for i in range(num_blocks - 1): 134 | layers.append(block(output_channels, output_channels)) 135 | return nn.Sequential(*layers) 136 | 137 | def forward(self, input): 138 | # x0_0_0 = self.conv0_0_0(input) 139 | x0_0 = self.conv0_0(input) 140 | # (4,16,256,256) 141 | x1_0 = self.conv1_0(self.pool(x0_0)) 142 | # (4,32,128,128) 143 | x2_0 = self.conv2_0(self.pool(x1_0)) 144 | # (4,64,64,64) 145 | x3_0 = self.conv3_0(self.pool(x2_0)) # (4,128,32,32) 146 | 147 | out = self.conv4_0(self.pool(x3_0)) 148 | # (4,256,16,16) 149 | 150 | out = torch.cat([rearrange(self.vit2(x3_0), "b (x y) c -> b c x y", x=32, y=32), 151 | rearrange(self.vit3(x2_0), "b (x y) c -> b c x y", x=32, y=32), 152 | rearrange(self.vit4(x1_0), "b (x y) c -> b c x y", x=32, y=32), 153 | rearrange(self.vit5(x0_0), "b (x y) c -> b c x y", x=32, y=32), out], 1) 154 | 155 | out = self.conv3_1_1(out) 156 | 157 | out = self.conv3_1(torch.cat([x3_0, self.up(out)], 1)) 158 | 159 | out = self.conv2_2(torch.cat([x2_0, self.up(out)], 1)) 160 | 161 | out = self.conv1_3(torch.cat([x1_0, self.up(out)], 1)) 162 | 163 | out = self.conv0_4(torch.cat([x0_0, self.up(out)], 1)) 164 | 165 | out = self.final(out) 166 | 167 | return out 168 | 169 | 170 | if __name__ == '__main__': 171 | x = torch.rand(8, 3, 512, 512) 172 | model = MTUNet() 173 | out = model(x) 174 | print(out.size()) 175 | -------------------------------------------------------------------------------- /model/ACM/acm.py: -------------------------------------------------------------------------------- 1 | # @Time : 2022/9/21 11:09 2 | # @Author : PEIWEN PAN 3 | # @Email : 121106022690@njust.edu.cn 4 | # @File : acm.py 5 | # @Software: PyCharm 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from model.ACM.fusion import AsymBiChaFuseReduce, BiLocalChaFuseReduce, BiGlobalChaFuseReduce 11 | 12 | 13 | class ResidualBlock(nn.Module): 14 | def __init__(self, in_channels, out_channels, stride, downsample): 15 | super(ResidualBlock, self).__init__() 16 | self.body = nn.Sequential( 17 | nn.Conv2d(in_channels, out_channels, 3, stride, 1, bias=False), 18 | nn.BatchNorm2d(out_channels), 19 | nn.ReLU(True), 20 | 21 |
nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False), 22 | nn.BatchNorm2d(out_channels), 23 | ) 24 | if downsample: 25 | self.downsample = nn.Sequential( 26 | nn.Conv2d(in_channels, out_channels, 1, stride, 0, bias=False), 27 | nn.BatchNorm2d(out_channels), 28 | ) 29 | else: 30 | self.downsample = nn.Sequential() 31 | 32 | def forward(self, x): 33 | residual = x 34 | x = self.body(x) 35 | 36 | if self.downsample: 37 | residual = self.downsample(residual) 38 | 39 | out = F.relu(x + residual, True) 40 | return out 41 | 42 | 43 | class _FCNHead(nn.Module): 44 | def __init__(self, in_channels, out_channels): 45 | super(_FCNHead, self).__init__() 46 | inter_channels = in_channels // 4 47 | self.block = nn.Sequential( 48 | nn.Conv2d(in_channels, inter_channels, 3, 1, 1, bias=False), 49 | nn.BatchNorm2d(inter_channels), 50 | nn.ReLU(True), 51 | nn.Dropout(0.1), 52 | nn.Conv2d(inter_channels, out_channels, 1, 1, 0) 53 | ) 54 | 55 | def forward(self, x): 56 | return self.block(x) 57 | 58 | 59 | class ASKCResNetFPN(nn.Module): 60 | def __init__(self, layer_blocks, channels, fuse_mode='AsymBi', **kwargs): 61 | super(ASKCResNetFPN, self).__init__() 62 | 63 | stem_width = channels[0] 64 | self.stem = nn.Sequential( 65 | nn.BatchNorm2d(3), 66 | nn.Conv2d(3, stem_width, 3, 2, 1, bias=False), 67 | nn.BatchNorm2d(stem_width), 68 | nn.ReLU(True), 69 | 70 | nn.Conv2d(stem_width, stem_width, 3, 1, 1, bias=False), 71 | nn.BatchNorm2d(stem_width), 72 | nn.ReLU(True), 73 | 74 | nn.Conv2d(stem_width, stem_width * 2, 3, 1, 1, bias=False), 75 | nn.BatchNorm2d(stem_width * 2), 76 | nn.ReLU(True), 77 | nn.MaxPool2d(3, 2, 1) 78 | ) 79 | 80 | self.layer1 = self._make_layer(block=ResidualBlock, block_num=layer_blocks[0], 81 | in_channels=channels[1], out_channels=channels[1], stride=1) 82 | self.layer2 = self._make_layer(block=ResidualBlock, block_num=layer_blocks[1], 83 | in_channels=channels[1], out_channels=channels[2], stride=2) 84 | self.layer3 = self._make_layer(block=ResidualBlock, block_num=layer_blocks[2], 85 | in_channels=channels[2], out_channels=channels[3], stride=2) 86 | 87 | self.fuse23 = self._fuse_layer(channels[3], channels[2], channels[2], fuse_mode) 88 | self.fuse12 = self._fuse_layer(channels[2], channels[1], channels[1], fuse_mode) 89 | 90 | self.head = _FCNHead(channels[1], 1) 91 | 92 | def forward(self, x): 93 | _, _, hei, wid = x.shape 94 | 95 | x = self.stem(x) 96 | c1 = self.layer1(x) 97 | c2 = self.layer2(c1) 98 | out = self.layer3(c2) 99 | 100 | out = F.interpolate(out, size=[hei // 8, wid // 8], mode='bilinear') 101 | out = self.fuse23(out, c2) 102 | 103 | out = F.interpolate(out, size=[hei // 4, wid // 4], mode='bilinear') 104 | out = self.fuse12(out, c1) 105 | 106 | pred = self.head(out) 107 | out = F.interpolate(pred, size=[hei, wid], mode='bilinear') 108 | 109 | return out 110 | 111 | def _make_layer(self, block, block_num, in_channels, out_channels, stride): 112 | downsample = (in_channels != out_channels) or (stride != 1) 113 | layer = [] 114 | layer.append(block(in_channels, out_channels, stride, downsample)) 115 | for _ in range(block_num - 1): 116 | layer.append(block(out_channels, out_channels, 1, False)) 117 | return nn.Sequential(*layer) 118 | 119 | def _fuse_layer(self, in_high_channels, in_low_channels, out_channels, fuse_mode='AsymBi'): 120 | assert fuse_mode in ['BiLocal', 'AsymBi', 'BiGlobal'] 121 | if fuse_mode == 'BiLocal': 122 | fuse_layer = BiLocalChaFuseReduce(in_high_channels, in_low_channels, out_channels) 123 | elif fuse_mode == 'AsymBi': 124 | fuse_layer = 
AsymBiChaFuseReduce(in_high_channels, in_low_channels, out_channels) 125 | elif fuse_mode == 'BiGlobal': 126 | fuse_layer = BiGlobalChaFuseReduce(in_high_channels, in_low_channels, out_channels) 127 | else: 128 | raise NameError('Unknown fuse_mode: %s' % fuse_mode) 129 | return fuse_layer 130 | 131 | 132 | class ASKCResUNet(nn.Module): 133 | def __init__(self, layer_blocks, channels, fuse_mode='AsymBi', **kwargs): 134 | super(ASKCResUNet, self).__init__() 135 | 136 | stem_width = int(channels[0]) 137 | self.stem = nn.Sequential( 138 | nn.BatchNorm2d(3), 139 | nn.Conv2d(3, stem_width, 3, 2, 1, bias=False), 140 | nn.BatchNorm2d(stem_width), 141 | nn.ReLU(True), 142 | 143 | nn.Conv2d(stem_width, stem_width, 3, 1, 1, bias=False), 144 | nn.BatchNorm2d(stem_width), 145 | nn.ReLU(True), 146 | 147 | nn.Conv2d(stem_width, 2 * stem_width, 3, 1, 1, bias=False), 148 | nn.BatchNorm2d(2 * stem_width), 149 | nn.ReLU(True), 150 | 151 | nn.MaxPool2d(3, 2, 1), 152 | ) 153 | 154 | self.layer1 = self._make_layer(block=ResidualBlock, block_num=layer_blocks[0], 155 | in_channels=channels[1], out_channels=channels[1], stride=1) 156 | self.layer2 = self._make_layer(block=ResidualBlock, block_num=layer_blocks[1], 157 | in_channels=channels[1], out_channels=channels[2], stride=2) 158 | self.layer3 = self._make_layer(block=ResidualBlock, block_num=layer_blocks[2], 159 | in_channels=channels[2], out_channels=channels[3], stride=2) 160 | 161 | self.deconv2 = nn.ConvTranspose2d(channels[3], channels[2], 4, 2, 1) 162 | self.fuse2 = self._fuse_layer(channels[2], channels[2], channels[2], fuse_mode) 163 | self.uplayer2 = self._make_layer(block=ResidualBlock, block_num=layer_blocks[1], 164 | in_channels=channels[2], out_channels=channels[2], stride=1) 165 | 166 | self.deconv1 = nn.ConvTranspose2d(channels[2], channels[1], 4, 2, 1) 167 | self.fuse1 = self._fuse_layer(channels[1], channels[1], channels[1], fuse_mode) 168 | self.uplayer1 = self._make_layer(block=ResidualBlock, block_num=layer_blocks[0], 169 | in_channels=channels[1], out_channels=channels[1], stride=1) 170 | 171 | self.head = _FCNHead(channels[1], 1) 172 | 173 | def forward(self, x): 174 | _, _, hei, wid = x.shape 175 | 176 | x = self.stem(x) 177 | c1 = self.layer1(x) 178 | c2 = self.layer2(c1) 179 | c3 = self.layer3(c2) 180 | 181 | deconc2 = self.deconv2(c3) 182 | fusec2 = self.fuse2(deconc2, c2) 183 | upc2 = self.uplayer2(fusec2) 184 | 185 | deconc1 = self.deconv1(upc2) 186 | fusec1 = self.fuse1(deconc1, c1) 187 | upc1 = self.uplayer1(fusec1) 188 | 189 | pred = self.head(upc1) 190 | out = F.interpolate(pred, size=[hei, wid], mode='bilinear') 191 | return out 192 | 193 | def _make_layer(self, block, block_num, in_channels, out_channels, stride): 194 | layer = [] 195 | downsample = (in_channels != out_channels) or (stride != 1) 196 | layer.append(block(in_channels, out_channels, stride, downsample)) 197 | for _ in range(block_num - 1): 198 | layer.append(block(out_channels, out_channels, 1, False)) 199 | return nn.Sequential(*layer) 200 | 201 | def _fuse_layer(self, in_high_channels, in_low_channels, out_channels, fuse_mode='AsymBi'): 202 | assert fuse_mode in ['BiLocal', 'AsymBi', 'BiGlobal'] 203 | if fuse_mode == 'BiLocal': 204 | fuse_layer = BiLocalChaFuseReduce(in_high_channels, in_low_channels, out_channels) 205 | elif fuse_mode == 'AsymBi': 206 | fuse_layer = AsymBiChaFuseReduce(in_high_channels, in_low_channels, out_channels) 207 | elif fuse_mode == 'BiGlobal': 208 | fuse_layer = BiGlobalChaFuseReduce(in_high_channels, in_low_channels, out_channels) 209 | else: 210 | raise NameError('Unknown fuse_mode: %s' % fuse_mode) 211 | return
fuse_layer 212 | 213 | 214 | if __name__ == '__main__': 215 | x = torch.rand(8, 3, 512, 512) 216 | model = ASKCResNetFPN([4, 4, 4], [8, 16, 32, 64]) 217 | out = model(x) 218 | print(out.size()) 219 | -------------------------------------------------------------------------------- /model/DNANet/dna_net.py: -------------------------------------------------------------------------------- 1 | # @Time : 2022/4/6 16:17 2 | # @Author : PEIWEN PAN 3 | # @Email : 121106022690@njust.edu.cn 4 | # @File : dna_net.py 5 | # @Software: PyCharm 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | class VGG_CBAM_Block(nn.Module): 11 | def __init__(self, in_channels, out_channels): 12 | super().__init__() 13 | self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1) 14 | self.bn1 = nn.BatchNorm2d(out_channels) 15 | self.relu = nn.ReLU(inplace=True) 16 | self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1) 17 | self.bn2 = nn.BatchNorm2d(out_channels) 18 | self.relu = nn.ReLU(inplace=True) 19 | self.ca = ChannelAttention(out_channels) 20 | self.sa = SpatialAttention() 21 | 22 | def forward(self, x): 23 | out = self.conv1(x) 24 | out = self.bn1(out) 25 | out = self.relu(out) 26 | out = self.conv2(out) 27 | out = self.bn2(out) 28 | out = self.ca(out) * out 29 | out = self.sa(out) * out 30 | out = self.relu(out) 31 | return out 32 | 33 | 34 | class ChannelAttention(nn.Module): 35 | def __init__(self, in_planes, ratio=16): 36 | super(ChannelAttention, self).__init__() 37 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 38 | self.max_pool = nn.AdaptiveMaxPool2d(1) 39 | self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False) 40 | self.relu1 = nn.ReLU() 41 | self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False) 42 | self.sigmoid = nn.Sigmoid() 43 | 44 | def forward(self, x): 45 | avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x)))) 46 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x)))) 47 | out = avg_out + max_out 48 | return self.sigmoid(out) 49 | 50 | 51 | class SpatialAttention(nn.Module): 52 | def __init__(self, kernel_size=7): 53 | super(SpatialAttention, self).__init__() 54 | assert kernel_size in (3, 7), 'kernel size must be 3 or 7' 55 | padding = 3 if kernel_size == 7 else 1 56 | self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False) 57 | self.sigmoid = nn.Sigmoid() 58 | 59 | def forward(self, x): 60 | avg_out = torch.mean(x, dim=1, keepdim=True) 61 | max_out, _ = torch.max(x, dim=1, keepdim=True) 62 | x = torch.cat([avg_out, max_out], dim=1) 63 | x = self.conv1(x) 64 | return self.sigmoid(x) 65 | 66 | 67 | class Res_CBAM_block(nn.Module): 68 | def __init__(self, in_channels, out_channels, stride=1): 69 | super(Res_CBAM_block, self).__init__() 70 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1) 71 | self.bn1 = nn.BatchNorm2d(out_channels) 72 | self.relu = nn.ReLU(inplace=True) 73 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1) 74 | self.bn2 = nn.BatchNorm2d(out_channels) 75 | if stride != 1 or out_channels != in_channels: 76 | self.shortcut = nn.Sequential( 77 | nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride), 78 | nn.BatchNorm2d(out_channels)) 79 | else: 80 | self.shortcut = None 81 | 82 | self.ca = ChannelAttention(out_channels) 83 | self.sa = SpatialAttention() 84 | 85 | def forward(self, x): 86 | residual = x 87 | if self.shortcut is not None: 88 | residual = self.shortcut(x) 89 | out = self.conv1(x) 90 | out = self.bn1(out) 91 | out =
self.relu(out) 92 | out = self.conv2(out) 93 | out = self.bn2(out) 94 | out = self.ca(out) * out 95 | out = self.sa(out) * out 96 | out += residual 97 | out = self.relu(out) 98 | return out 99 | 100 | 101 | class DNANet(nn.Module): 102 | def __init__(self, num_classes, input_channels, block_name, num_blocks, nb_filter, deep_supervision=False, **kwargs): 103 | super(DNANet, self).__init__() 104 | self.relu = nn.ReLU(inplace=True) 105 | self.deep_supervision = deep_supervision 106 | self.pool = nn.MaxPool2d(2, 2) 107 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 108 | self.down = nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=True) 109 | 110 | self.up_4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True) 111 | self.up_8 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True) 112 | self.up_16 = nn.Upsample(scale_factor=16, mode='bilinear', align_corners=True) 113 | 114 | block = Res_CBAM_block if block_name == 'resnet' else VGG_CBAM_Block 115 | self.conv0_0 = self._make_layer(block, input_channels, nb_filter[0]) 116 | self.conv1_0 = self._make_layer(block, nb_filter[0], nb_filter[1], num_blocks[0]) 117 | self.conv2_0 = self._make_layer(block, nb_filter[1], nb_filter[2], num_blocks[1]) 118 | self.conv3_0 = self._make_layer(block, nb_filter[2], nb_filter[3], num_blocks[2]) 119 | self.conv4_0 = self._make_layer(block, nb_filter[3], nb_filter[4], num_blocks[3]) 120 | 121 | self.conv0_1 = self._make_layer(block, nb_filter[0] + nb_filter[1], nb_filter[0]) 122 | self.conv1_1 = self._make_layer(block, nb_filter[1] + nb_filter[2] + nb_filter[0], nb_filter[1], num_blocks[0]) 123 | self.conv2_1 = self._make_layer(block, nb_filter[2] + nb_filter[3] + nb_filter[1], nb_filter[2], num_blocks[1]) 124 | self.conv3_1 = self._make_layer(block, nb_filter[3] + nb_filter[4] + nb_filter[2], nb_filter[3], num_blocks[2]) 125 | 126 | self.conv0_2 = self._make_layer(block, nb_filter[0] * 2 + nb_filter[1], nb_filter[0]) 127 | self.conv1_2 = self._make_layer(block, nb_filter[1] * 2 + nb_filter[2] + nb_filter[0], nb_filter[1], 128 | num_blocks[0]) 129 | self.conv2_2 = self._make_layer(block, nb_filter[2] * 2 + nb_filter[3] + nb_filter[1], nb_filter[2], 130 | num_blocks[1]) 131 | 132 | self.conv0_3 = self._make_layer(block, nb_filter[0] * 3 + nb_filter[1], nb_filter[0]) 133 | self.conv1_3 = self._make_layer(block, nb_filter[1] * 3 + nb_filter[2] + nb_filter[0], nb_filter[1], 134 | num_blocks[0]) 135 | 136 | self.conv0_4 = self._make_layer(block, nb_filter[0] * 4 + nb_filter[1], nb_filter[0]) 137 | 138 | self.conv0_4_final = self._make_layer(block, nb_filter[0] * 5, nb_filter[0]) 139 | 140 | self.conv0_4_1x1 = nn.Conv2d(nb_filter[4], nb_filter[0], kernel_size=1, stride=1) 141 | self.conv0_3_1x1 = nn.Conv2d(nb_filter[3], nb_filter[0], kernel_size=1, stride=1) 142 | self.conv0_2_1x1 = nn.Conv2d(nb_filter[2], nb_filter[0], kernel_size=1, stride=1) 143 | self.conv0_1_1x1 = nn.Conv2d(nb_filter[1], nb_filter[0], kernel_size=1, stride=1) 144 | 145 | if self.deep_supervision: 146 | self.final1 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1) 147 | self.final2 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1) 148 | self.final3 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1) 149 | self.final4 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1) 150 | else: 151 | self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1) 152 | 153 | def _make_layer(self, block, input_channels, output_channels, num_blocks=1): 154 | layers = [] 155 | 
layers.append(block(input_channels, output_channels)) 156 | for i in range(num_blocks - 1): 157 | layers.append(block(output_channels, output_channels)) 158 | return nn.Sequential(*layers) 159 | 160 | def forward(self, input): 161 | x0_0 = self.conv0_0(input) 162 | x1_0 = self.conv1_0(self.pool(x0_0)) 163 | x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1)) 164 | 165 | x2_0 = self.conv2_0(self.pool(x1_0)) 166 | x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0), self.down(x0_1)], 1)) 167 | x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1)) 168 | 169 | x3_0 = self.conv3_0(self.pool(x2_0)) 170 | x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0), self.down(x1_1)], 1)) 171 | x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1), self.down(x0_2)], 1)) 172 | x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1)) 173 | 174 | x4_0 = self.conv4_0(self.pool(x3_0)) 175 | x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0), self.down(x2_1)], 1)) 176 | x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1), self.down(x1_2)], 1)) 177 | x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2), self.down(x0_3)], 1)) 178 | x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1)) 179 | 180 | Final_x0_4 = self.conv0_4_final( 181 | torch.cat([self.up_16(self.conv0_4_1x1(x4_0)), self.up_8(self.conv0_3_1x1(x3_1)), 182 | self.up_4(self.conv0_2_1x1(x2_2)), self.up(self.conv0_1_1x1(x1_3)), x0_4], 1)) 183 | 184 | if self.deep_supervision: 185 | output1 = self.final1(x0_1) 186 | output2 = self.final2(x0_2) 187 | output3 = self.final3(x0_3) 188 | output4 = self.final4(Final_x0_4) 189 | return [output1, output2, output3, output4] 190 | else: 191 | output = self.final(Final_x0_4) 192 | return output 193 | -------------------------------------------------------------------------------- /utils/metric.py: -------------------------------------------------------------------------------- 1 | # @Time : 2022/4/6 14:54 2 | # @Author : PEIWEN PAN 3 | # @Email : 121106022690@njust.edu.cn 4 | # @File : metric.py 5 | # @Software: PyCharm 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import numpy as np 10 | from skimage import measure 11 | 12 | 13 | class SigmoidMetric(): 14 | def __init__(self, score_thresh=0): 15 | self.score_thresh = score_thresh 16 | self.reset() 17 | 18 | def update(self, pred, labels): 19 | correct, labeled = self.batch_pix_accuracy(pred, labels) 20 | inter, union = self.batch_intersection_union(pred, labels) 21 | 22 | self.total_correct += correct 23 | self.total_label += labeled 24 | self.total_inter += inter 25 | self.total_union += union 26 | 27 | def get(self): 28 | """Gets the current evaluation result.""" 29 | pixAcc = 1.0 * self.total_correct / (np.spacing(1) + self.total_label) 30 | IoU = 1.0 * self.total_inter / (np.spacing(1) + self.total_union) 31 | mIoU = IoU.mean() 32 | return pixAcc, mIoU 33 | 34 | def reset(self): 35 | """Resets the internal evaluation result to initial state.""" 36 | self.total_inter = 0 37 | self.total_union = 0 38 | self.total_correct = 0 39 | self.total_label = 0 40 | 41 | def batch_pix_accuracy(self, output, target): 42 | assert output.shape == target.shape 43 | output = output.cpu().detach().numpy() 44 | target = target.cpu().detach().numpy() 45 | 46 | predict = (output > self.score_thresh).astype('int64') # P 47 | pixel_labeled = np.sum(target > 0) # T 48 | pixel_correct = np.sum((predict == target) * (target > 0)) # TP 49 | assert pixel_correct <= pixel_labeled 50 | return 
pixel_correct, pixel_labeled 51 | 52 | def batch_intersection_union(self, output, target): 53 | mini = 1 54 | maxi = 1 # nclass 55 | nbins = 1 # nclass 56 | predict = (output.cpu().detach().numpy() > self.score_thresh).astype('int64') # P 57 | target = target.cpu().numpy().astype('int64') # T 58 | intersection = predict * (predict == target) # TP 59 | 60 | 61 | # areas of intersection and union 62 | area_inter, _ = np.histogram(intersection, bins=nbins, range=(mini, maxi)) 63 | area_pred, _ = np.histogram(predict, bins=nbins, range=(mini, maxi)) 64 | area_lab, _ = np.histogram(target, bins=nbins, range=(mini, maxi)) 65 | area_union = area_pred + area_lab - area_inter 66 | assert (area_inter <= area_union).all() 67 | return area_inter, area_union 68 | 69 | 70 | class SamplewiseSigmoidMetric(): 71 | def __init__(self, nclass, score_thresh=0.5): 72 | self.nclass = nclass 73 | self.score_thresh = score_thresh 74 | self.reset() 75 | 76 | def update(self, preds, labels): 77 | """Updates the internal evaluation result.""" 78 | inter_arr, union_arr = self.batch_intersection_union(preds, labels, 79 | self.nclass, self.score_thresh) 80 | self.total_inter = np.append(self.total_inter, inter_arr) 81 | self.total_union = np.append(self.total_union, union_arr) 82 | 83 | def get(self): 84 | """Gets the current evaluation result.""" 85 | IoU = 1.0 * self.total_inter / (np.spacing(1) + self.total_union) 86 | mIoU = IoU.mean() 87 | return IoU, mIoU 88 | 89 | def reset(self): 90 | """Resets the internal evaluation result to initial state.""" 91 | self.total_inter = np.array([]) 92 | self.total_union = np.array([]) 93 | self.total_correct = np.array([]) 94 | self.total_label = np.array([]) 95 | 96 | def batch_intersection_union(self, output, target, nclass, score_thresh): 97 | """mIoU""" 98 | # inputs are tensor 99 | # the category 0 is ignored class, typically for background / boundary 100 | mini = 1 101 | maxi = 1 # nclass 102 | nbins = 1 # nclass 103 | 104 | predict = (torch.sigmoid(output).cpu().detach().numpy() > score_thresh).astype('int64') # P 105 | target = target.cpu().detach().numpy().astype('int64') # T 106 | intersection = predict * (predict == target) # TP 107 | 108 | num_sample = intersection.shape[0] 109 | area_inter_arr = np.zeros(num_sample) 110 | area_pred_arr = np.zeros(num_sample) 111 | area_lab_arr = np.zeros(num_sample) 112 | area_union_arr = np.zeros(num_sample) 113 | 114 | for b in range(num_sample): 115 | # areas of intersection and union 116 | area_inter, _ = np.histogram(intersection[b], bins=nbins, range=(mini, maxi)) 117 | area_inter_arr[b] = area_inter 118 | 119 | area_pred, _ = np.histogram(predict[b], bins=nbins, range=(mini, maxi)) 120 | area_pred_arr[b] = area_pred 121 | 122 | area_lab, _ = np.histogram(target[b], bins=nbins, range=(mini, maxi)) 123 | area_lab_arr[b] = area_lab 124 | 125 | area_union = area_pred + area_lab - area_inter 126 | area_union_arr[b] = area_union 127 | 128 | assert (area_inter <= area_union).all() 129 | 130 | return area_inter_arr, area_union_arr 131 | 132 | 133 | class ROCMetric(): 134 | """Computes TP/FP rates, recall, precision and F1-score over a sweep of score thresholds 135 | """ 136 | 137 | def __init__(self, nclass, bins): # bins determines how many discrete threshold values are sampled along the ROC curve 138 | super(ROCMetric, self).__init__() 139 | self.nclass = nclass 140 | self.bins = bins 141 | self.tp_arr = np.zeros(self.bins + 1) 142 | self.pos_arr = np.zeros(self.bins + 1) 143 | self.fp_arr = np.zeros(self.bins + 1) 144 | self.neg_arr = np.zeros(self.bins + 1) 145 | self.class_pos = np.zeros(self.bins + 1) 146 | # self.reset()
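# update() sweeps bins + 1 score thresholds evenly spaced in [0, 1] over the sigmoid
# outputs and accumulates TP / positive / FP / negative counts per threshold; get()
# turns them into ROC points, recall, precision and F1 (recall[5]/precision[5]
# corresponds to the 0.5 threshold when bins == 10, as used in utils/tools.py).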
147 | 148 | def update(self, preds, labels): 149 | for iBin in range(self.bins + 1): 150 | score_thresh = (iBin + 0.0) / self.bins 151 | # print(iBin, "-th, score_thresh: ", score_thresh) 152 | i_tp, i_pos, i_fp, i_neg, i_class_pos = cal_tp_pos_fp_neg(preds, labels, self.nclass, score_thresh) 153 | self.tp_arr[iBin] += i_tp 154 | self.pos_arr[iBin] += i_pos 155 | self.fp_arr[iBin] += i_fp 156 | self.neg_arr[iBin] += i_neg 157 | self.class_pos[iBin] += i_class_pos 158 | 159 | def get(self): 160 | tp_rates = self.tp_arr / (self.pos_arr + 0.001) 161 | fp_rates = self.fp_arr / (self.neg_arr + 0.001) 162 | 163 | recall = self.tp_arr / (self.pos_arr + 0.001) 164 | precision = self.tp_arr / (self.class_pos + 0.001) 165 | f1_score = (2.0 * recall[5] * precision[5]) / (recall[5] + precision[5] + 0.00001) 166 | 167 | return tp_rates, fp_rates, recall, precision, f1_score 168 | 169 | def reset(self): 170 | self.tp_arr = np.zeros([11]) 171 | self.pos_arr = np.zeros([11]) 172 | self.fp_arr = np.zeros([11]) 173 | self.neg_arr = np.zeros([11]) 174 | self.class_pos = np.zeros([11]) 175 | 176 | 177 | class PD_FA(): 178 | def __init__(self, nclass, bins, cfg): 179 | super(PD_FA, self).__init__() 180 | self.nclass = nclass 181 | self.bins = bins 182 | self.image_area_total = [] 183 | self.image_area_match = [] 184 | self.FA = np.zeros(self.bins + 1) 185 | self.PD = np.zeros(self.bins + 1) 186 | self.target = np.zeros(self.bins + 1) 187 | self.cfg = cfg 188 | 189 | def update(self, preds, labels): 190 | 191 | for iBin in range(self.bins + 1): 192 | score_thresh = iBin * (255 / self.bins) 193 | batch = preds.size()[0] 194 | for b in range(batch): 195 | predits = np.array((preds[b, :, :, :] > score_thresh).cpu()).astype('int64') 196 | predits = np.reshape(predits, (self.cfg.data['crop_size'], self.cfg.data['crop_size'])) 197 | labelss = np.array((labels[b, :, :, :]).cpu()).astype('int64') # P 198 | labelss = np.reshape(labelss, (self.cfg.data['crop_size'], self.cfg.data['crop_size'])) 199 | 200 | image = measure.label(predits, connectivity=2) 201 | coord_image = measure.regionprops(image) 202 | label = measure.label(labelss, connectivity=2) 203 | coord_label = measure.regionprops(label) 204 | 205 | self.target[iBin] += len(coord_label) 206 | self.image_area_total = [] 207 | self.image_area_match = [] 208 | self.distance_match = [] 209 | self.dismatch = [] 210 | 211 | for K in range(len(coord_image)): 212 | area_image = np.array(coord_image[K].area) 213 | self.image_area_total.append(area_image) 214 | 215 | for i in range(len(coord_label)): 216 | centroid_label = np.array(list(coord_label[i].centroid)) 217 | for m in range(len(coord_image)): 218 | centroid_image = np.array(list(coord_image[m].centroid)) 219 | distance = np.linalg.norm(centroid_image - centroid_label) 220 | area_image = np.array(coord_image[m].area) 221 | if distance < 3: 222 | self.distance_match.append(distance) 223 | self.image_area_match.append(area_image) 224 | 225 | del coord_image[m] 226 | break 227 | 228 | self.dismatch = [x for x in self.image_area_total if x not in self.image_area_match] 229 | self.FA[iBin] += np.sum(self.dismatch) 230 | self.PD[iBin] += len(self.distance_match) 231 | 232 | def get(self, img_num): 233 | 234 | Final_FA = self.FA / ((self.cfg.data['crop_size'] * self.cfg.data['crop_size']) * img_num) 235 | Final_PD = self.PD / self.target 236 | 237 | return Final_FA, Final_PD 238 | 239 | def reset(self): 240 | self.FA = np.zeros([self.bins + 1]) 241 | self.PD = np.zeros([self.bins + 1]) 242 | 243 | 244 | def 
244 | def cal_tp_pos_fp_neg(output, target, nclass, score_thresh):
245 |     predict = (torch.sigmoid(output) > score_thresh).float()
246 |     if len(target.shape) == 3:
247 |         target = target.float().unsqueeze(1)  # keep this a tensor; np.expand_dims would fail on CUDA tensors
248 |     elif len(target.shape) == 4:
249 |         target = target.float()
250 |     else:
251 |         raise ValueError("Unknown target dimension")
252 |     intersection = predict * ((predict == target).float())
253 |     tp = intersection.sum()
254 |     fp = (predict * ((predict != target).float())).sum()
255 |     tn = ((1 - predict) * ((predict == target).float())).sum()
256 |     fn = (((predict != target).float()) * (1 - predict)).sum()
257 |     pos = tp + fn
258 |     neg = fp + tn
259 |     class_pos = tp + fp
260 |     return tp, pos, fp, neg, class_pos
261 | 
262 | 
263 | if __name__ == '__main__':
264 |     pred = torch.rand(8, 1, 512, 512)
265 |     target = torch.rand(8, 1, 512, 512)
266 |     m1 = SigmoidMetric()
267 |     m2 = SamplewiseSigmoidMetric(nclass=1, score_thresh=0.5)
268 |     m1.update(pred, target)
269 |     m2.update(pred, target)
270 |     pixAcc, mIoU = m1.get()
271 |     _, nIoU = m2.get()
272 | 
--------------------------------------------------------------------------------
/utils/tools.py:
--------------------------------------------------------------------------------
1 | # @Time : 2022/9/14 22:11
2 | # @Author : PEIWEN PAN
3 | # @Email : 121106022690@njust.edu.cn
4 | # @File : tools.py
5 | # @Software: PyCharm
6 | 
7 | import random
8 | import torch.distributed
9 | import torch.nn
10 | from utils.metric import *
11 | from torch.utils.tensorboard import SummaryWriter
12 | from utils.logs import *
13 | import shutil
14 | from utils.save_model import *
15 | from utils.drawing import *
16 | import logging
17 | 
18 | 
19 | def random_seed(n):
20 |     random.seed(n)
21 |     np.random.seed(n)
22 |     torch.manual_seed(n)
23 |     torch.cuda.manual_seed_all(n)
24 | 
25 | 
26 | def empty_function():
27 |     pass
28 | 
29 | 
30 | def model_wrapper(model_dict):
31 |     new_dict = {}
32 |     for k, v in model_dict.items():
33 |         new_dict['decode_head.'
+ k] = v 34 | return new_dict 35 | 36 | 37 | def init_metrics(args, optimizer, checkpoint=None): 38 | best_mIoU, best_nIoU, best_f1 = 0.0, 0.0, 0.0 39 | train_loss, test_loss, mIoU, nIoU, f1, num_epoch = [], [], [], [], [], [] 40 | if args.resume_from: 41 | best_mIoU = checkpoint['best_mIoU'] 42 | best_nIoU = checkpoint['best_nIoU'] 43 | best_f1 = checkpoint['best_f1'] 44 | train_loss = checkpoint['train_loss'] 45 | test_loss = checkpoint['test_loss'] 46 | mIoU = checkpoint['mIoU'] 47 | nIoU = checkpoint['nIoU'] 48 | f1 = checkpoint['f1'] 49 | num_epoch = checkpoint['num_epoch'] 50 | optimizer.load_state_dict(checkpoint['optimizer']) 51 | iou_metric = SigmoidMetric() 52 | nIoU_metric = SamplewiseSigmoidMetric(1, score_thresh=0.5) 53 | ROC = ROCMetric(1, 10) 54 | 55 | return optimizer, {'best_mIoU': best_mIoU, 'best_nIoU': best_nIoU, 'best_f1': best_f1, 'train_loss': train_loss, 56 | 'test_loss': test_loss, 'mIoU': mIoU, 'nIoU': nIoU, 'f1': f1, 'num_epoch': num_epoch, 57 | 'iou_metric': iou_metric, 'nIoU_metric': nIoU_metric, 'ROC': ROC} 58 | 59 | 60 | def init_data(args, data): 61 | train_sample = None 62 | if args.local_rank != -1: 63 | train_sample, train_data, test_data, train_data_len, test_data_len = data 64 | else: 65 | train_data, test_data, train_data_len, test_data_len = data 66 | return {'train_sample': train_sample, 'train_data': train_data, 'test_data': test_data, 67 | 'train_data_len': train_data_len, 'test_data_len': test_data_len} 68 | 69 | 70 | def init_model(args, cfg, model, device): 71 | checkpoint = None 72 | if args.load_from: 73 | cfg.load_from = args.load_from 74 | checkpoint = torch.load(args.load_from) 75 | model.load_state_dict(checkpoint) 76 | 77 | if args.resume_from: 78 | cfg.resume_from = args.resume_from 79 | checkpoint = torch.load(args.resume_from) 80 | model.load_state_dict(checkpoint['state_dict']) 81 | print("Model Initializing") 82 | 83 | if args.local_rank != -1: 84 | model.to(device) 85 | model = torch.nn.parallel.DistributedDataParallel( 86 | model, device_ids=[args.local_rank], output_device=args.local_rank, 87 | find_unused_parameters=cfg.find_unused_parameters) 88 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 89 | else: 90 | model = model.to(device) 91 | 92 | return model, checkpoint 93 | 94 | 95 | def init_devices(args, cfg): 96 | if args.local_rank != -1: 97 | device = torch.device('cuda', args.local_rank) 98 | torch.cuda.set_device(args.local_rank) 99 | torch.distributed.init_process_group(backend=cfg.dist_params['backend']) 100 | else: 101 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 102 | random_seed(cfg.random_seed) 103 | return device 104 | 105 | 106 | def save_log(args, cfg, model): 107 | save_dir, train_log_file_name, write = None, None, None 108 | if args.local_rank <= 0: 109 | save_dir = args.config.split('/')[-1][:-3] 110 | train_log_file_name = train_log_file() 111 | make_log_dir(save_dir, train_log_file_name) 112 | save_config_log(cfg, save_dir, train_log_file_name) 113 | save_model_struct(save_dir, train_log_file_name, model) 114 | if 'develop' in cfg: 115 | shutil.copy(cfg.develop['source_file_root'], 116 | os.path.join('work_dirs', save_dir, train_log_file_name, 'model.py')) 117 | write = SummaryWriter(log_dir='work_dirs/' + save_dir + '/' + train_log_file_name + '/tf_logs') 118 | return save_dir, train_log_file_name, write 119 | 120 | 121 | def data2device(args, data, device): 122 | img, mask = data 123 | if args.local_rank != -1: 124 | img = img.cuda() 125 | mask = mask.cuda() 126 | 
else: 127 | img = img.to(device) 128 | mask = mask.to(device) 129 | return img, mask 130 | 131 | 132 | def compute_loss(preds, mask, deep_supervision, cfg, criterion): 133 | if deep_supervision and cfg.model['decode_head']['deep_supervision']: 134 | loss = [] 135 | for pre in preds: 136 | loss.append(criterion(pre, mask)) 137 | loss = sum(loss) 138 | preds = preds[-1] 139 | else: 140 | loss = criterion(preds, mask) 141 | return loss, preds 142 | 143 | 144 | def show_log(mode, args, cfg, epoch, losses, save_dir, train_log_file, **kwargs): 145 | if mode not in ['train', 'test']: 146 | raise ValueError('The parameter "mode" input should be "train" or "test"') 147 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s: %(message)s', datefmt='%F %T') 148 | if args.local_rank <= 0: 149 | if mode == 'train': 150 | msg = 'Epoch %d/%d, Iter %d/%d, train loss %.4f, lr %.5f, time %.5f' % ( 151 | epoch, cfg.runner['max_epochs'], kwargs['i'] + 1, 152 | kwargs['data']['train_data_len'] / cfg.data['train_batch'] / cfg.gpus, 153 | np.mean(losses), kwargs['optimizer'].state_dict()['param_groups'][0]['lr'], kwargs['time_elapsed']) 154 | logging.info(msg) 155 | if (kwargs['i'] + 1) % cfg.log_config['interval'] == 0: 156 | save_train_log(save_dir, train_log_file, epoch, cfg.runner['max_epochs'], kwargs['i'] + 1, 157 | kwargs['data']['train_data_len'] / cfg.data['train_batch'] / cfg.gpus, 158 | np.mean(losses), kwargs['optimizer'].state_dict()['param_groups'][0]['lr'], 159 | kwargs['time_elapsed']) 160 | else: 161 | msg = 'Epoch %d/%d, test loss %.4f, mIoU %.4f, nIoU %.4f, F1-score %.4f, best_mIoU %.4f, ' \ 162 | 'best_nIoU %.4f, best_F1-score %.4f' % ( 163 | epoch, cfg.runner['max_epochs'], np.mean(losses), kwargs['IoU'], kwargs['nIoU'], 164 | kwargs['F1_score'], kwargs['metrics']['best_mIoU'], kwargs['metrics']['best_nIoU'], 165 | kwargs['metrics']['best_f1']) 166 | logging.info(msg) 167 | save_test_log(save_dir, train_log_file, epoch, cfg.runner['max_epochs'], 168 | np.mean(losses), kwargs['IoU'], kwargs['nIoU'], kwargs['F1_score'], 169 | kwargs['metrics']['best_mIoU'], kwargs['metrics']['best_nIoU'], kwargs['metrics']['best_f1']) 170 | 171 | 172 | def save_model(mode, args, cfg, epoch, model, losses, optimizer, metrics, save_dir, train_log_file, **kwargs): 173 | if mode not in ['train', 'test']: 174 | raise ValueError('The parameter "mode" input should be "train" or "test"') 175 | if args.local_rank <= 0: 176 | ckpt_info = { 177 | 'epoch': epoch, 178 | 'state_dict': model.module.state_dict() if args.local_rank != -1 else model.state_dict(), 179 | 'loss': np.mean(losses), 180 | 'optimizer': optimizer.state_dict(), 181 | 'train_loss': metrics['train_loss'], 182 | 'test_loss': metrics['test_loss'], 183 | 'num_epoch': metrics['num_epoch'], 184 | 'best_mIoU': metrics['best_mIoU'], 185 | 'best_nIoU': metrics['best_nIoU'], 186 | 'best_f1': metrics['best_f1'], 187 | 'mIoU': metrics['mIoU'], 188 | 'nIoU': metrics['nIoU'], 189 | 'f1': metrics['f1'] 190 | } 191 | if mode == 'train': 192 | save_ckpt(ckpt_info, save_path='work_dirs/' + save_dir + '/' + train_log_file, filename='last.pth.tar') 193 | if cfg.checkpoint_config['by_epoch'] and epoch % cfg.checkpoint_config['interval'] == 0: 194 | save_ckpt(ckpt_info, save_path='work_dirs/' + save_dir + '/' + train_log_file, 195 | filename='epoch_%d' % epoch + '.pth.tar') 196 | else: 197 | if kwargs['IoU'] > metrics['best_mIoU'] or kwargs['nIoU'] > metrics['best_nIoU']: 198 | save_ckpt(ckpt_info, save_path='work_dirs/' + save_dir + '/' + 
train_log_file, filename='best.pth.tar') 199 | if kwargs['IoU'] > metrics['best_mIoU']: 200 | save_ckpt(ckpt_info, save_path='work_dirs/' + save_dir + '/' + train_log_file, 201 | filename='best_mIoU.pth.tar') 202 | if kwargs['nIoU'] > metrics['best_nIoU']: 203 | save_ckpt(ckpt_info, save_path='work_dirs/' + save_dir + '/' + train_log_file, 204 | filename='best_nIoU.pth.tar') 205 | 206 | 207 | def update_log(mode, args, metrics, write, losses, epoch, **kwargs): 208 | if mode not in ['train', 'test']: 209 | raise ValueError('The parameter "mode" input should be "train" or "test"') 210 | if args.local_rank <= 0: 211 | if mode == 'train': 212 | metrics['train_loss'].append(np.mean(losses)) 213 | metrics['num_epoch'].append(epoch) 214 | write.add_scalar('train/train_loss', np.mean(losses), epoch) 215 | write.add_scalar('train/lr', kwargs['optimizer'].state_dict()['param_groups'][0]['lr'], epoch) 216 | else: 217 | metrics['best_mIoU'] = max(kwargs['IoU'], metrics['best_mIoU']) 218 | metrics['best_nIoU'] = max(kwargs['nIoU'], metrics['best_nIoU']) 219 | metrics['best_f1'] = max(kwargs['F1_score'], metrics['best_f1']) 220 | write.add_scalar('train/test_loss', np.mean(losses), epoch) 221 | write.add_scalar('test/mIoU', kwargs['IoU'], epoch) 222 | write.add_scalar('test/nIoU', kwargs['nIoU'], epoch) 223 | write.add_scalar('test/F1-score', kwargs['F1_score'], epoch) 224 | 225 | 226 | def reset_metrics(metrics): 227 | metrics['iou_metric'].reset() 228 | metrics['nIoU_metric'].reset() 229 | metrics['ROC'].reset() 230 | 231 | 232 | def update_metrics(preds, mask, metrics): 233 | metrics['iou_metric'].update(preds, mask) 234 | metrics['nIoU_metric'].update(preds, mask) 235 | metrics['ROC'].update(preds, mask) 236 | _, IoU = metrics['iou_metric'].get() 237 | _, nIoU = metrics['nIoU_metric'].get() 238 | _, _, _, _, F1_score = metrics['ROC'].get() 239 | return IoU, nIoU, F1_score 240 | 241 | 242 | def append_metrics(args, metrics, losses, IoU, nIoU, F1_score): 243 | if args.local_rank <= 0: 244 | metrics['test_loss'].append(np.mean(losses)) 245 | metrics['mIoU'].append(IoU) 246 | metrics['nIoU'].append(nIoU) 247 | metrics['f1'].append(F1_score) 248 | 249 | 250 | def draw(args, metrics, save_dir, train_log_file): 251 | if args.local_rank <= 0: 252 | drawing_loss(metrics['num_epoch'], metrics['train_loss'], metrics['test_loss'], save_dir, train_log_file) 253 | drawing_iou(metrics['num_epoch'], metrics['mIoU'], metrics['nIoU'], save_dir, train_log_file) 254 | drawing_f1(metrics['num_epoch'], metrics['f1'], save_dir, train_log_file) 255 | -------------------------------------------------------------------------------- /model/AGPCNet/resnet.py: -------------------------------------------------------------------------------- 1 | # @Time : 2022/5/18 17:25 2 | # @Author : PEIWEN PAN 3 | # @Email : 121106022690@njust.edu.cn 4 | # @File : resnet.py 5 | # @Software: PyCharm 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torchvision import models 10 | 11 | import math 12 | 13 | try: 14 | from torch.hub import load_state_dict_from_url 15 | except ImportError: 16 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 17 | 18 | __all__ = ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'] 19 | 20 | model_urls = { 21 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 22 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 23 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 24 | 
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 25 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 26 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 27 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 28 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 29 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 30 | } 31 | 32 | 33 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 34 | """3x3 convolution with padding""" 35 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 36 | padding=dilation, groups=groups, bias=False, dilation=dilation) 37 | 38 | 39 | def conv1x1(in_planes, out_planes, stride=1): 40 | """1x1 convolution""" 41 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 42 | 43 | 44 | class BasicBlock(nn.Module): 45 | expansion = 1 46 | 47 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 48 | base_width=64, dilation=1, norm_layer=None): 49 | super(BasicBlock, self).__init__() 50 | if norm_layer is None: 51 | norm_layer = nn.BatchNorm2d 52 | if groups != 1 or base_width != 64: 53 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 54 | if dilation > 1: 55 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 56 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 57 | self.conv1 = conv3x3(inplanes, planes, stride) 58 | self.bn1 = norm_layer(planes) 59 | self.relu = nn.ReLU(inplace=True) 60 | self.conv2 = conv3x3(planes, planes) 61 | self.bn2 = norm_layer(planes) 62 | self.downsample = downsample 63 | self.stride = stride 64 | 65 | def forward(self, x): 66 | identity = x 67 | 68 | out = self.conv1(x) 69 | out = self.bn1(out) 70 | out = self.relu(out) 71 | 72 | out = self.conv2(out) 73 | out = self.bn2(out) 74 | 75 | if self.downsample is not None: 76 | identity = self.downsample(x) 77 | 78 | out += identity 79 | out = self.relu(out) 80 | 81 | return out 82 | 83 | 84 | class Bottleneck(nn.Module): 85 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 86 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 87 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 88 | # This variant is also known as ResNet V1.5 and improves accuracy according to 89 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
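    # Added note: with expansion = 4 below, every Bottleneck emits planes * 4
    # channels (e.g. planes=64 -> a 256-channel output), which is why _make_layer
    # compares self.inplanes against planes * block.expansion when deciding
    # whether a 1x1 downsample projection is needed.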
90 | 91 | expansion = 4 92 | 93 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 94 | base_width=64, dilation=1, norm_layer=None): 95 | super(Bottleneck, self).__init__() 96 | if norm_layer is None: 97 | norm_layer = nn.BatchNorm2d 98 | width = int(planes * (base_width / 64.)) * groups 99 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 100 | self.conv1 = conv1x1(inplanes, width) 101 | self.bn1 = norm_layer(width) 102 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 103 | self.bn2 = norm_layer(width) 104 | self.conv3 = conv1x1(width, planes * self.expansion) 105 | self.bn3 = norm_layer(planes * self.expansion) 106 | self.relu = nn.ReLU(inplace=True) 107 | self.downsample = downsample 108 | self.stride = stride 109 | 110 | def forward(self, x): 111 | identity = x 112 | 113 | out = self.conv1(x) 114 | out = self.bn1(out) 115 | out = self.relu(out) 116 | 117 | out = self.conv2(out) 118 | out = self.bn2(out) 119 | out = self.relu(out) 120 | 121 | out = self.conv3(out) 122 | out = self.bn3(out) 123 | 124 | if self.downsample is not None: 125 | identity = self.downsample(x) 126 | 127 | out += identity 128 | out = self.relu(out) 129 | 130 | return out 131 | 132 | 133 | class ResNet(nn.Module): 134 | 135 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 136 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 137 | norm_layer=None): 138 | super(ResNet, self).__init__() 139 | if norm_layer is None: 140 | norm_layer = nn.BatchNorm2d 141 | self._norm_layer = norm_layer 142 | 143 | self.inplanes = 64 144 | self.dilation = 1 145 | if replace_stride_with_dilation is None: 146 | # each element in the tuple indicates if we should replace 147 | # the 2x2 stride with a dilated convolution instead 148 | replace_stride_with_dilation = [False, False, False] 149 | if len(replace_stride_with_dilation) != 3: 150 | raise ValueError("replace_stride_with_dilation should be None " 151 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 152 | self.groups = groups 153 | self.base_width = width_per_group 154 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=1, padding=3, 155 | bias=False) 156 | self.bn1 = norm_layer(self.inplanes) 157 | self.relu = nn.ReLU(inplace=True) 158 | # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 159 | self.layer1 = self._make_layer(block, 64, layers[0]) 160 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 161 | dilate=replace_stride_with_dilation[0]) 162 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 163 | dilate=replace_stride_with_dilation[1]) 164 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 165 | dilate=replace_stride_with_dilation[2]) 166 | # self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 167 | # self.fc = nn.Linear(512 * block.expansion, num_classes) 168 | 169 | for m in self.modules(): 170 | if isinstance(m, nn.Conv2d): 171 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 172 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 173 | nn.init.constant_(m.weight, 1) 174 | nn.init.constant_(m.bias, 0) 175 | 176 | # Zero-initialize the last BN in each residual branch, 177 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
178 |         # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
179 |         if zero_init_residual:
180 |             for m in self.modules():
181 |                 if isinstance(m, Bottleneck):
182 |                     nn.init.constant_(m.bn3.weight, 0)
183 |                 elif isinstance(m, BasicBlock):
184 |                     nn.init.constant_(m.bn2.weight, 0)
185 | 
186 |     def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
187 |         norm_layer = self._norm_layer
188 |         downsample = None
189 |         previous_dilation = self.dilation
190 |         if dilate:
191 |             self.dilation *= stride
192 |             stride = 1
193 |         if stride != 1 or self.inplanes != planes * block.expansion:
194 |             downsample = nn.Sequential(
195 |                 conv1x1(self.inplanes, planes * block.expansion, stride),
196 |                 norm_layer(planes * block.expansion),
197 |             )
198 | 
199 |         layers = []
200 |         layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
201 |                             self.base_width, previous_dilation, norm_layer))
202 |         self.inplanes = planes * block.expansion
203 |         for _ in range(1, blocks):
204 |             layers.append(block(self.inplanes, planes, groups=self.groups,
205 |                                 base_width=self.base_width, dilation=self.dilation,
206 |                                 norm_layer=norm_layer))
207 | 
208 |         return nn.Sequential(*layers)
209 | 
210 |     def _forward_impl(self, x):
211 |         # See note [TorchScript super()]
212 |         x = self.conv1(x)
213 |         x = self.bn1(x)
214 |         x = self.relu(x)
215 |         # x = self.maxpool(x)
216 | 
217 |         x = self.layer1(x)
218 |         c1 = self.layer2(x)
219 |         c2 = self.layer3(c1)
220 |         c3 = self.layer4(c2)
221 | 
222 |         # x = self.avgpool(x)
223 |         # x = torch.flatten(x, 1)
224 |         # x = self.fc(x)
225 | 
226 |         return c1, c2, c3
227 | 
228 |     def forward(self, x):
229 |         return self._forward_impl(x)
230 | 
231 | 
232 | def _resnet(arch, block, layers, pretrained, progress, **kwargs):
233 |     model = ResNet(block, layers, **kwargs)
234 |     if pretrained:
235 |         state_dict = load_state_dict_from_url(model_urls[arch],
236 |                                               progress=progress)
237 |         model.load_state_dict(state_dict, strict=False)  # strict=False: the checkpoint's fc head has no counterpart in this backbone
238 |     return model
239 | 
240 | 
241 | def resnet18(pretrained=False, progress=True, **kwargs):
242 |     r"""ResNet-18 model from
243 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
244 | 
245 |     Args:
246 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
247 |         progress (bool): If True, displays a progress bar of the download to stderr
248 |     """
249 |     return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
250 |                    **kwargs)
251 | 
252 | 
253 | def resnet34(pretrained=False, progress=True, **kwargs):
254 |     r"""ResNet-34 model from
255 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
256 | 
257 |     Args:
258 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
259 |         progress (bool): If True, displays a progress bar of the download to stderr
260 |     """
261 |     return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
262 |                    **kwargs)
263 | 
264 | 
265 | def resnet50(pretrained=False, progress=True, **kwargs):
266 |     r"""ResNet-50 model from
267 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
268 | 
269 |     Args:
270 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
271 |         progress (bool): If True, displays a progress bar of the download to stderr
272 |     """
273 |     return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
274 |                    **kwargs)
275 | 
276 | 
277 | def resnet101(pretrained=False, progress=True, **kwargs):
278 |     r"""ResNet-101 model from
279 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
280 | 
281 |     Args:
282 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
283 |         progress (bool): If True, displays a progress bar of the download to stderr
284 |     """
285 |     return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
286 |                    **kwargs)
287 | 
288 | 
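# Added sketch: unlike torchvision, this backbone keeps a stride-1 stem and drops
# the max-pool, so the three returned stages are downsampled by only 2x/4x/8x --
# resolution that small infrared targets need. A minimal CPU smoke test, assuming
# a 256x256 input:
#
#     import torch
#     backbone = resnet18(pretrained=False)
#     c1, c2, c3 = backbone(torch.rand(1, 3, 256, 256))
#     # c1: (1, 128, 128, 128), c2: (1, 256, 64, 64), c3: (1, 512, 32, 32)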
289 | def resnet152(pretrained=False, progress=True, **kwargs):
290 |     r"""ResNet-152 model from
291 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
292 | 
293 |     Args:
294 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
295 |         progress (bool): If True, displays a progress bar of the download to stderr
296 |     """
297 |     return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
298 |                    **kwargs)
--------------------------------------------------------------------------------
/model/AGPCNet/context.py:
--------------------------------------------------------------------------------
1 | # @Time : 2022/5/18 17:24
2 | # @Author : PEIWEN PAN
3 | # @Email : 121106022690@njust.edu.cn
4 | # @File : context.py
5 | # @Software: PyCharm
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | 
10 | 
11 | __all__ = ['NonLocalBlock', 'GCA_Channel', 'GCA_Element', 'AGCB_Element', 'AGCB_Patch', 'CPM']
12 | 
13 | 
14 | class NonLocalBlock(nn.Module):
15 |     def __init__(self, planes, reduce_ratio=8):
16 |         super(NonLocalBlock, self).__init__()
17 | 
18 |         inter_planes = planes // reduce_ratio
19 |         self.query_conv = nn.Conv2d(planes, inter_planes, kernel_size=1)
20 |         self.key_conv = nn.Conv2d(planes, inter_planes, kernel_size=1)
21 |         self.value_conv = nn.Conv2d(planes, planes, kernel_size=1)
22 |         self.gamma = nn.Parameter(torch.zeros(1))
23 | 
24 |         self.softmax = nn.Softmax(dim=-1)
25 | 
26 |     def forward(self, x):
27 |         m_batchsize, C, height, width = x.size()
28 | 
29 |         proj_query = self.query_conv(x)
30 |         proj_key = self.key_conv(x)
31 |         proj_value = self.value_conv(x)
32 | 
33 |         proj_query = proj_query.contiguous().view(m_batchsize, -1, width * height).permute(0, 2, 1)
34 |         proj_key = proj_key.contiguous().view(m_batchsize, -1, width * height)
35 |         energy = torch.bmm(proj_query, proj_key)
36 |         attention = self.softmax(energy)
37 |         proj_value = proj_value.contiguous().view(m_batchsize, -1, width * height)
38 | 
39 |         out = torch.bmm(proj_value, attention.permute(0, 2, 1))
40 |         out = out.view(m_batchsize, -1, height, width)
41 | 
42 |         out = self.gamma * out + x
43 |         return out
44 | 
45 | 
46 | class GCA_Channel(nn.Module):
47 |     def __init__(self, planes, scale, reduce_ratio_nl, att_mode='origin'):
48 |         super(GCA_Channel, self).__init__()
49 |         assert att_mode in ['origin', 'post']
50 | 
51 |         self.att_mode = att_mode
52 |         if att_mode == 'origin':
53 |             self.pool = nn.AdaptiveMaxPool2d(scale)
54 |             self.non_local_att = NonLocalBlock(planes, reduce_ratio=reduce_ratio_nl)
55 |             self.sigmoid = nn.Sigmoid()
56 |         elif att_mode == 'post':
57 |             self.pool = nn.AdaptiveMaxPool2d(scale)
58 |             self.non_local_att = NonLocalBlock(planes, reduce_ratio=1)
59 |             self.conv_att = nn.Sequential(
60 |                 nn.Conv2d(planes, planes // 4, kernel_size=1),
61 |                 nn.BatchNorm2d(planes // 4),
62 |                 nn.ReLU(True),
63 | 
64 |                 nn.Conv2d(planes // 4, planes, kernel_size=1),
65 |                 nn.BatchNorm2d(planes),
66 |                 nn.Sigmoid(),
67 |             )
68 |         else:
69 |             raise NotImplementedError
70 | 
71 |     def forward(self, x):
72 |         if self.att_mode == 'origin':
73 |             gca = self.pool(x)
74 |             gca = self.non_local_att(gca)
75 |             gca = self.sigmoid(gca)
76 |         elif self.att_mode == 'post':
77 |             gca = self.pool(x)
78 |             gca = self.non_local_att(gca)
79 |             gca = self.conv_att(gca)
80 |         else:
81 |             raise
NotImplementedError 82 | return gca 83 | 84 | 85 | class GCA_Element(nn.Module): 86 | def __init__(self, planes, scale, reduce_ratio_nl, att_mode='origin'): 87 | super(GCA_Element, self).__init__() 88 | assert att_mode in ['origin', 'post'] 89 | 90 | self.att_mode = att_mode 91 | if att_mode == 'origin': 92 | self.pool = nn.AdaptiveMaxPool2d(scale) 93 | self.non_local_att = NonLocalBlock(planes, reduce_ratio=reduce_ratio_nl) 94 | self.sigmoid = nn.Sigmoid() 95 | elif att_mode == 'post': 96 | self.pool = nn.AdaptiveMaxPool2d(scale) 97 | self.non_local_att = NonLocalBlock(planes, reduce_ratio=1) 98 | self.conv_att = nn.Sequential( 99 | nn.Conv2d(planes, planes // 4, kernel_size=1), 100 | nn.BatchNorm2d(planes // 4), 101 | nn.ReLU(True), 102 | 103 | nn.Conv2d(planes // 4, planes, kernel_size=1), 104 | nn.BatchNorm2d(planes), 105 | ) 106 | self.sigmoid = nn.Sigmoid() 107 | else: 108 | raise NotImplementedError 109 | 110 | def forward(self, x): 111 | batch_size, C, height, width = x.size() 112 | 113 | if self.att_mode == 'origin': 114 | gca = self.pool(x) 115 | gca = self.non_local_att(gca) 116 | gca = F.interpolate(gca, [height, width], mode='bilinear', align_corners=True) 117 | gca = self.sigmoid(gca) 118 | elif self.att_mode == 'post': 119 | gca = self.pool(x) 120 | gca = self.non_local_att(gca) 121 | gca = self.conv_att(gca) 122 | gca = F.interpolate(gca, [height, width], mode='bilinear', align_corners=True) 123 | gca = self.sigmoid(gca) 124 | else: 125 | raise NotImplementedError 126 | return gca 127 | 128 | 129 | class AGCB_Patch(nn.Module): 130 | def __init__(self, planes, scale=2, reduce_ratio_nl=32, att_mode='origin'): 131 | super(AGCB_Patch, self).__init__() 132 | 133 | self.scale = scale 134 | self.non_local = NonLocalBlock(planes, reduce_ratio=reduce_ratio_nl) 135 | self.conv = nn.Sequential( 136 | nn.Conv2d(planes, planes, 3, 1, 1), 137 | nn.BatchNorm2d(planes), 138 | # nn.Dropout(0.1) 139 | ) 140 | self.relu = nn.ReLU(True) 141 | self.attention = GCA_Channel(planes, scale, reduce_ratio_nl, att_mode=att_mode) 142 | 143 | self.gamma = nn.Parameter(torch.zeros(1)) 144 | 145 | def forward(self, x): 146 | ## long context 147 | gca = self.attention(x) 148 | 149 | ## single scale non local 150 | batch_size, C, height, width = x.size() 151 | 152 | local_x, local_y, attention_ind = [], [], [] 153 | step_h, step_w = height // self.scale, width // self.scale 154 | for i in range(self.scale): 155 | for j in range(self.scale): 156 | start_x, start_y = i * step_h, j * step_w 157 | end_x, end_y = min(start_x + step_h, height), min(start_y + step_w, width) 158 | if i == (self.scale - 1): 159 | end_x = height 160 | if j == (self.scale - 1): 161 | end_y = width 162 | 163 | local_x += [start_x, end_x] 164 | local_y += [start_y, end_y] 165 | attention_ind += [i, j] 166 | 167 | index_cnt = 2 * self.scale * self.scale 168 | assert len(local_x) == index_cnt 169 | 170 | context_list = [] 171 | for i in range(0, index_cnt, 2): 172 | block = x[:, :, local_x[i]:local_x[i+1], local_y[i]:local_y[i+1]] 173 | attention = gca[:, :, attention_ind[i], attention_ind[i+1]].view(batch_size, C, 1, 1) 174 | context_list.append(self.non_local(block) * attention) 175 | 176 | tmp = [] 177 | for i in range(self.scale): 178 | row_tmp = [] 179 | for j in range(self.scale): 180 | row_tmp.append(context_list[j + i * self.scale]) 181 | tmp.append(torch.cat(row_tmp, 3)) 182 | context = torch.cat(tmp, 2) 183 | 184 | context = self.conv(context) 185 | context = self.gamma * context + x 186 | context = self.relu(context) 187 | 
return context 188 | 189 | 190 | class AGCB_Element(nn.Module): 191 | def __init__(self, planes, scale=2, reduce_ratio_nl=32, att_mode='origin'): 192 | super(AGCB_Element, self).__init__() 193 | 194 | self.scale = scale 195 | self.non_local = NonLocalBlock(planes, reduce_ratio=reduce_ratio_nl) 196 | self.conv = nn.Sequential( 197 | nn.Conv2d(planes, planes, 3, 1, 1), 198 | nn.BatchNorm2d(planes), 199 | # nn.Dropout(0.1) 200 | ) 201 | self.relu = nn.ReLU(True) 202 | self.attention = GCA_Element(planes, scale, reduce_ratio_nl, att_mode=att_mode) 203 | 204 | self.gamma = nn.Parameter(torch.zeros(1)) 205 | 206 | def forward(self, x): 207 | ## long context 208 | gca = self.attention(x) 209 | 210 | ## single scale non local 211 | batch_size, C, height, width = x.size() 212 | 213 | local_x, local_y, attention_ind = [], [], [] 214 | step_h, step_w = height // self.scale, width // self.scale 215 | for i in range(self.scale): 216 | for j in range(self.scale): 217 | start_x, start_y = i * step_h, j * step_w 218 | end_x, end_y = min(start_x + step_h, height), min(start_y + step_w, width) 219 | if i == (self.scale - 1): 220 | end_x = height 221 | if j == (self.scale - 1): 222 | end_y = width 223 | 224 | local_x += [start_x, end_x] 225 | local_y += [start_y, end_y] 226 | attention_ind += [i, j] 227 | 228 | index_cnt = 2 * self.scale * self.scale 229 | assert len(local_x) == index_cnt 230 | 231 | context_list = [] 232 | for i in range(0, index_cnt, 2): 233 | block = x[:, :, local_x[i]:local_x[i+1], local_y[i]:local_y[i+1]] 234 | # attention = gca[:, :, attention_ind[i], attention_ind[i+1]].view(batch_size, C, 1, 1) 235 | context_list.append(self.non_local(block)) 236 | 237 | tmp = [] 238 | for i in range(self.scale): 239 | row_tmp = [] 240 | for j in range(self.scale): 241 | row_tmp.append(context_list[j + i * self.scale]) 242 | tmp.append(torch.cat(row_tmp, 3)) 243 | context = torch.cat(tmp, 2) 244 | 245 | context = context * gca 246 | context = self.conv(context) 247 | context = self.gamma * context + x 248 | context = self.relu(context) 249 | return context 250 | 251 | 252 | class AGCB_NoGCA(nn.Module): 253 | def __init__(self, planes, scale=2, reduce_ratio_nl=32): 254 | super(AGCB_NoGCA, self).__init__() 255 | 256 | self.scale = scale 257 | self.non_local = NonLocalBlock(planes, reduce_ratio=reduce_ratio_nl) 258 | self.conv = nn.Sequential( 259 | nn.Conv2d(planes, planes, 3, 1, 1), 260 | nn.BatchNorm2d(planes), 261 | # nn.Dropout(0.1) 262 | ) 263 | self.relu = nn.ReLU(True) 264 | 265 | self.gamma = nn.Parameter(torch.zeros(1)) 266 | 267 | def forward(self, x): 268 | ## single scale non local 269 | batch_size, C, height, width = x.size() 270 | 271 | local_x, local_y, attention_ind = [], [], [] 272 | step_h, step_w = height // self.scale, width // self.scale 273 | for i in range(self.scale): 274 | for j in range(self.scale): 275 | start_x, start_y = i * step_h, j * step_w 276 | end_x, end_y = min(start_x + step_h, height), min(start_y + step_w, width) 277 | if i == (self.scale - 1): 278 | end_x = height 279 | if j == (self.scale - 1): 280 | end_y = width 281 | 282 | local_x += [start_x, end_x] 283 | local_y += [start_y, end_y] 284 | attention_ind += [i, j] 285 | 286 | index_cnt = 2 * self.scale * self.scale 287 | assert len(local_x) == index_cnt 288 | 289 | context_list = [] 290 | for i in range(0, index_cnt, 2): 291 | block = x[:, :, local_x[i]:local_x[i+1], local_y[i]:local_y[i+1]] 292 | context_list.append(self.non_local(block)) 293 | 294 | tmp = [] 295 | for i in range(self.scale): 296 | row_tmp = 
[] 297 | for j in range(self.scale): 298 | row_tmp.append(context_list[j + i * self.scale]) 299 | tmp.append(torch.cat(row_tmp, 3)) 300 | context = torch.cat(tmp, 2) 301 | 302 | context = self.conv(context) 303 | context = self.gamma * context + x 304 | context = self.relu(context) 305 | return context 306 | 307 | 308 | class CPM(nn.Module): 309 | def __init__(self, planes, block_type, scales=(3,5,6,10), reduce_ratios=(4,8), att_mode='origin'): 310 | super(CPM, self).__init__() 311 | assert block_type in ['patch', 'element'] 312 | assert att_mode in ['origin', 'post'] 313 | 314 | inter_planes = planes // reduce_ratios[0] 315 | self.conv1 = nn.Sequential( 316 | nn.Conv2d(planes, inter_planes, kernel_size=1), 317 | nn.BatchNorm2d(inter_planes), 318 | nn.ReLU(True), 319 | ) 320 | 321 | if block_type == 'patch': 322 | self.scale_list = nn.ModuleList( 323 | [AGCB_Patch(inter_planes, scale=scale, reduce_ratio_nl=reduce_ratios[1], att_mode=att_mode) 324 | for scale in scales]) 325 | elif block_type == 'element': 326 | self.scale_list = nn.ModuleList( 327 | [AGCB_Element(inter_planes, scale=scale, reduce_ratio_nl=reduce_ratios[1], att_mode=att_mode) 328 | for scale in scales]) 329 | else: 330 | raise NotImplementedError 331 | 332 | channels = inter_planes * (len(scales) + 1) 333 | self.conv2 = nn.Sequential( 334 | nn.Conv2d(channels, planes, 1), 335 | nn.BatchNorm2d(planes), 336 | nn.ReLU(True), 337 | ) 338 | 339 | def forward(self, x): 340 | reduced = self.conv1(x) 341 | 342 | blocks = [] 343 | for i in range(len(self.scale_list)): 344 | blocks.append(self.scale_list[i](reduced)) 345 | out = torch.cat(blocks, 1) 346 | out = torch.cat((reduced, out), 1) 347 | out = self.conv2(out) 348 | return out --------------------------------------------------------------------------------
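A quick way to sanity-check the context modules above is a shape smoke test: CPM reduces the channels by reduce_ratios[0], runs one AGCB branch per scale, concatenates the reduced input with every branch, and projects back to `planes`, so the output shape matches the input. A minimal sketch, assuming hypothetical sizes (128 channels, 40x40 features) not taken from the repo's configs:

import torch
from model.AGPCNet.context import CPM

cpm = CPM(planes=128, block_type='patch', scales=(3, 5, 6, 10),
          reduce_ratios=(4, 8), att_mode='origin')
feat = torch.rand(2, 128, 40, 40)   # (batch, channels, H, W)
out = cpm(feat)
print(out.shape)                    # torch.Size([2, 128, 40, 40]) -- channels and resolution preserved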