├── HDMba ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ └── __init__.cpython-310.pyc ├── models │ ├── __pycache__ │ │ ├── FFA.cpython-36.pyc │ │ ├── FFA.cpython-38.pyc │ │ ├── FFA.cpython-39.pyc │ │ ├── FFA2.cpython-36.pyc │ │ ├── FFA2.cpython-38.pyc │ │ ├── FFA2.cpython-39.pyc │ │ ├── SST.cpython-310.pyc │ │ ├── SST.cpython-36.pyc │ │ ├── AACNet.cpython-36.pyc │ │ ├── CANet.cpython-310.pyc │ │ ├── CANet.cpython-36.pyc │ │ ├── FFANet.cpython-36.pyc │ │ ├── HDMba.cpython-310.pyc │ │ ├── LKDNet.cpython-36.pyc │ │ ├── SCConv.cpython-36.pyc │ │ ├── SGnet.cpython-310.pyc │ │ ├── SGnet.cpython-36.pyc │ │ ├── SGnet.cpython-38.pyc │ │ ├── SGnet.cpython-39.pyc │ │ ├── AACNet.cpython-310.pyc │ │ ├── CSUTrans.cpython-310.pyc │ │ ├── CSUTrans.cpython-36.pyc │ │ ├── FFANet.cpython-310.pyc │ │ ├── HFFormer.cpython-310.pyc │ │ ├── HFFormer.cpython-36.pyc │ │ ├── LKDNet.cpython-310.pyc │ │ ├── MSST_MLK.cpython-310.pyc │ │ ├── MSST_MLK.cpython-36.pyc │ │ ├── PSMBNet.cpython-310.pyc │ │ ├── PSMBNet.cpython-36.pyc │ │ ├── RSDformer.cpython-36.pyc │ │ ├── RSHazeNet.cpython-36.pyc │ │ ├── RSdehaze.cpython-310.pyc │ │ ├── RSdehaze.cpython-36.pyc │ │ ├── RSdehaze.cpython-38.pyc │ │ ├── RSdehaze.cpython-39.pyc │ │ ├── Restormer.cpython-36.pyc │ │ ├── SCConv.cpython-310.pyc │ │ ├── SSMamba.cpython-310.pyc │ │ ├── SSMamba.cpython-36.pyc │ │ ├── SST_MLK.cpython-310.pyc │ │ ├── SST_MLK.cpython-36.pyc │ │ ├── SST_MSF.cpython-310.pyc │ │ ├── SST_MSF.cpython-36.pyc │ │ ├── UFormer.cpython-310.pyc │ │ ├── UFormer.cpython-36.pyc │ │ ├── __init__.cpython-310.pyc │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── RSDformer.cpython-310.pyc │ │ ├── RSHazeNet.cpython-310.pyc │ │ ├── Restormer.cpython-310.pyc │ │ ├── SSMamba_ab.cpython-310.pyc │ │ ├── TransWeater.cpython-36.pyc │ │ ├── AIDTransformer.cpython-36.pyc │ │ ├── Dehazeformer.cpython-310.pyc │ │ ├── Dehazeformer.cpython-36.pyc │ │ ├── GF5dehazeNet.cpython-36.pyc │ │ ├── PerceptualLoss.cpython-36.pyc │ │ ├── PerceptualLoss.cpython-38.pyc │ │ ├── PerceptualLoss.cpython-39.pyc │ │ ├── SST_MLK_GDFN.cpython-310.pyc │ │ ├── SST_MLK_GDFN.cpython-36.pyc │ │ ├── SST_MLK_MDTA.cpython-310.pyc │ │ ├── SST_MLK_MDTA.cpython-36.pyc │ │ ├── TransWeater.cpython-310.pyc │ │ ├── TwobranchNet.cpython-310.pyc │ │ ├── TwobranchNet.cpython-36.pyc │ │ ├── TwobranchWoRB.cpython-36.pyc │ │ ├── AIDTransformer.cpython-310.pyc │ │ ├── PerceptualLoss.cpython-310.pyc │ │ ├── TwobranchWoFFAB.cpython-36.pyc │ │ ├── TwobranchWoLoss.cpython-36.pyc │ │ ├── TwobranchWoSSAB.cpython-36.pyc │ │ ├── Twobranch_ablation.cpython-36.pyc │ │ └── Twobranchablation.cpython-36.pyc │ ├── __init__.py │ ├── PerceptualLoss.py │ ├── FFANet.py │ ├── AACNet.py │ ├── HDMba.py │ ├── Dehazeformer.py │ └── AIDTransformer.py ├── option.py ├── data_utils.py ├── test.py ├── metrics.py ├── main.py └── util.py ├── performance.PNG └── README.md /HDMba/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /performance.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/performance.PNG -------------------------------------------------------------------------------- /HDMba/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /HDMba/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/FFA.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/FFA.cpython-36.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/FFA.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/FFA.cpython-38.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/FFA.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/FFA.cpython-39.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/FFA2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/FFA2.cpython-36.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/FFA2.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/FFA2.cpython-38.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/FFA2.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/FFA2.cpython-39.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/SST.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/SST.cpython-310.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/SST.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/SST.cpython-36.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/AACNet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/AACNet.cpython-36.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/CANet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/CANet.cpython-310.pyc 
-------------------------------------------------------------------------------- /HDMba/models/__pycache__/CANet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/CANet.cpython-36.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/FFANet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/FFANet.cpython-36.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/HDMba.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/HDMba.cpython-310.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/LKDNet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/LKDNet.cpython-36.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/SCConv.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/SCConv.cpython-36.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/SGnet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/SGnet.cpython-310.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/SGnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/SGnet.cpython-36.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/SGnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/SGnet.cpython-38.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/SGnet.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/SGnet.cpython-39.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/AACNet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/AACNet.cpython-310.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/CSUTrans.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/CSUTrans.cpython-310.pyc -------------------------------------------------------------------------------- 
/HDMba/models/__pycache__/CSUTrans.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/CSUTrans.cpython-36.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/FFANet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/FFANet.cpython-310.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/HFFormer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/HFFormer.cpython-310.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/HFFormer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/HFFormer.cpython-36.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/LKDNet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/LKDNet.cpython-310.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/MSST_MLK.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/MSST_MLK.cpython-310.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/MSST_MLK.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/MSST_MLK.cpython-36.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/PSMBNet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/PSMBNet.cpython-310.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/PSMBNet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/PSMBNet.cpython-36.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/RSDformer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/RSDformer.cpython-36.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/RSHazeNet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/RSHazeNet.cpython-36.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/RSdehaze.cpython-310.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/RSdehaze.cpython-310.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/RSdehaze.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/RSdehaze.cpython-36.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/RSdehaze.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/RSdehaze.cpython-38.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/RSdehaze.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/RSdehaze.cpython-39.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/Restormer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/Restormer.cpython-36.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/SCConv.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/SCConv.cpython-310.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/SSMamba.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/SSMamba.cpython-310.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/SSMamba.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/SSMamba.cpython-36.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/SST_MLK.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/SST_MLK.cpython-310.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/SST_MLK.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/SST_MLK.cpython-36.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/SST_MSF.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/SST_MSF.cpython-310.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/SST_MSF.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/SST_MSF.cpython-36.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/UFormer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/UFormer.cpython-310.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/UFormer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/UFormer.cpython-36.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/RSDformer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/RSDformer.cpython-310.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/RSHazeNet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/RSHazeNet.cpython-310.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/Restormer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/Restormer.cpython-310.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/SSMamba_ab.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/SSMamba_ab.cpython-310.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/TransWeater.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/TransWeater.cpython-36.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/AIDTransformer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/AIDTransformer.cpython-36.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/Dehazeformer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/Dehazeformer.cpython-310.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/Dehazeformer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/Dehazeformer.cpython-36.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/GF5dehazeNet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/GF5dehazeNet.cpython-36.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/PerceptualLoss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/PerceptualLoss.cpython-36.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/PerceptualLoss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/PerceptualLoss.cpython-38.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/PerceptualLoss.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/PerceptualLoss.cpython-39.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/SST_MLK_GDFN.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/SST_MLK_GDFN.cpython-310.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/SST_MLK_GDFN.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/SST_MLK_GDFN.cpython-36.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/SST_MLK_MDTA.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/SST_MLK_MDTA.cpython-310.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/SST_MLK_MDTA.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/SST_MLK_MDTA.cpython-36.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/TransWeater.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/TransWeater.cpython-310.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/TwobranchNet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/TwobranchNet.cpython-310.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/TwobranchNet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/TwobranchNet.cpython-36.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/TwobranchWoRB.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/TwobranchWoRB.cpython-36.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/AIDTransformer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/AIDTransformer.cpython-310.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/PerceptualLoss.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/PerceptualLoss.cpython-310.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/TwobranchWoFFAB.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/TwobranchWoFFAB.cpython-36.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/TwobranchWoLoss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/TwobranchWoLoss.cpython-36.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/TwobranchWoSSAB.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/TwobranchWoSSAB.cpython-36.pyc -------------------------------------------------------------------------------- /HDMba/models/__pycache__/Twobranch_ablation.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/Twobranch_ablation.cpython-36.pyc -------------------------------------------------------------------------------- 
/HDMba/models/__pycache__/Twobranchablation.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/Twobranchablation.cpython-36.pyc -------------------------------------------------------------------------------- /HDMba/models/__init__.py: -------------------------------------------------------------------------------- 1 | import sys,os 2 | dir = os.path.abspath(os.path.dirname(__file__)) 3 | sys.path.append(dir) 4 | # from FFA import FFA 5 | from PerceptualLoss import LossNetwork as PerLoss 6 | # import net 7 | -------------------------------------------------------------------------------- /HDMba/models/PerceptualLoss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | # --- Perceptual loss network --- # 6 | class LossNetwork(torch.nn.Module): 7 | def __init__(self, vgg_model): 8 | super(LossNetwork, self).__init__() 9 | self.vgg_layers = vgg_model 10 | self.layer_name_mapping = { 11 | '3': "relu1_2", 12 | '8': "relu2_2", 13 | '15': "relu3_3" 14 | } 15 | 16 | def output_features(self, x): 17 | output = {} 18 | for name, module in self.vgg_layers._modules.items(): 19 | x = module(x) 20 | if name in self.layer_name_mapping: 21 | output[self.layer_name_mapping[name]] = x 22 | return list(output.values()) 23 | 24 | def forward(self, dehaze, gt): 25 | loss = [] 26 | dehaze_features = self.output_features(dehaze) 27 | gt_features = self.output_features(gt) 28 | for dehaze_feature, gt_feature in zip(dehaze_features, gt_features): 29 | loss.append(F.mse_loss(dehaze_feature, gt_feature)) 30 | 31 | return sum(loss)/len(loss) -------------------------------------------------------------------------------- /HDMba/option.py: -------------------------------------------------------------------------------- 1 | import torch,os,sys,torchvision,argparse 2 | import torchvision.transforms as tfs 3 | import time,math 4 | import numpy as np 5 | from torch.backends import cudnn 6 | from torch import optim 7 | import torch,warnings 8 | from torch import nn 9 | import torchvision.utils as vutils 10 | warnings.filterwarnings('ignore') 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--steps', type=int, default=1000) 14 | parser.add_argument('--device', type=str, default='Automatic detection') 15 | parser.add_argument('--resume', type=bool, default=True) 16 | parser.add_argument('--eval_step', type=int, default=500) 17 | parser.add_argument('--lr', default=0.0001, type=float, help='learning rate') 18 | parser.add_argument('--model_dir', type=str, default='./trained_models/') 19 | parser.add_argument('--trainset', type=str, default='train') 20 | parser.add_argument('--testset', type=str, default='test') 21 | 22 | # model 23 | # parser.add_argument('--net', type=str, default='AACNet') 24 | # parser.add_argument('--net', type=str, default='AIDTransformer') 25 | # parser.add_argument('--net', type=str, default='DehazeFormer') 26 | parser.add_argument('--net', type=str, default='HDMba') 27 | 28 | parser.add_argument('--bs', type=int, default=2, help='batch size') 29 | parser.add_argument('--crop_size', type=int, default=64, help='Takes effect when using --crop') 30 | parser.add_argument('--no_lr_sche', action='store_true', help='no lr cos schedule') 31 | parser.add_argument('--perloss', help='spec_loss') 32 | 33 | opt = parser.parse_args() 34 | opt.device = 'cuda' if 
torch.cuda.is_available() else 'cpu' 35 | model_name = opt.trainset+'_'+opt.net.split('.')[0] 36 | opt.model_dir = opt.model_dir+model_name+'.pk' 37 | log_dir = 'logs/'+model_name 38 | 39 | print(opt) 40 | # print('model_dir:', opt.model_dir) 41 | 42 | if not os.path.exists('trained_models'): 43 | os.mkdir('trained_models') 44 | if not os.path.exists('numpy_files'): 45 | os.mkdir('numpy_files') 46 | if not os.path.exists('logs'): 47 | os.mkdir('logs') 48 | if not os.path.exists('samples'): 49 | os.mkdir('samples') 50 | if not os.path.exists(f"samples/{model_name}"): 51 | os.mkdir(f'samples/{model_name}') 52 | if not os.path.exists(log_dir): 53 | os.mkdir(log_dir) 54 | -------------------------------------------------------------------------------- /HDMba/data_utils.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import torchvision.transforms as tfs 3 | from torchvision.transforms import functional as FF 4 | import os,sys 5 | sys.path.append('.') 6 | sys.path.append('..') 7 | from torch.utils.data import DataLoader 8 | from metrics import * 9 | from option import opt 10 | from random import randrange 11 | from osgeo import gdal 12 | gdal.PushErrorHandler("CPLQuietErrorHandler") 13 | 14 | BS = opt.bs 15 | crop_size = opt.crop_size 16 | 17 | 18 | class RESIDE_Dataset(data.Dataset): 19 | # def __init__(self,path,train,size=crop_size,format='.png'): 20 | def __init__(self, path, train, size=crop_size, format='.tif'): 21 | super(RESIDE_Dataset,self).__init__() 22 | self.size = size 23 | self.train = train 24 | self.format = format 25 | # G5数据 26 | self.haze_imgs_dir = os.listdir(os.path.join(path, 'hazy')) 27 | self.haze_imgs = [os.path.join(path, 'hazy', img) for img in self.haze_imgs_dir] 28 | self.clear_dir = os.path.join(path, 'clear') 29 | 30 | def __getitem__(self, index): 31 | # G5数据 32 | haze = gdal.Open(self.haze_imgs[index], 0) # 读入的是结构 33 | haze_array = haze.ReadAsArray().astype(np.float32) # 读入数组 (305,512,512) 34 | band, width, height = haze_array.shape 35 | img_path = self.haze_imgs[index] 36 | id = os.path.basename(img_path).split('_')[0] 37 | clear_name = id+self.format 38 | clear = gdal.Open(os.path.join(self.clear_dir, clear_name)) 39 | clear_array = clear.ReadAsArray().astype(np.float32) 40 | haze_array = haze_array - (np.min(haze_array)-np.min(clear_array)) 41 | haze_array = haze_array / np.max(haze_array) 42 | clear_array = clear_array / np.max(clear_array) 43 | 44 | # 影像裁剪成块 45 | if isinstance(self.size, int): 46 | x, y = randrange(0, width - self.size + 1), randrange(0, height - self.size + 1) 47 | haze_array = haze_array[:, x:x+self.size, y:y+self.size] 48 | clear_array = clear_array[:, x:x + self.size, y:y + self.size] 49 | 50 | return haze_array, clear_array 51 | 52 | def __len__(self): 53 | return len(self.haze_imgs) 54 | # return len(self.haze_paths) 55 | 56 | 57 | import os 58 | pwd = os.getcwd() 59 | 60 | path = r'F:\GF-5 dehaze' # path to your 'data' folder 61 | train_loader = DataLoader(dataset=RESIDE_Dataset(path+r'\train', train=True, size=crop_size), batch_size=BS, shuffle=True) 62 | test_loader = DataLoader(dataset=RESIDE_Dataset(path+r'\test', train=False, size='whole img'), batch_size=1, shuffle=False) 63 | 64 | x, y = next(iter(train_loader)) 65 | 66 | 67 | if __name__ == "__main__": 68 | pass 69 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 |

HDMba

3 |

HDMba: Hyperspectral Remote Sensing Imagery Dehazing with State Space Model

4 | 5 | Hang Fu, [Genyun Sun](https://ocean.upc.edu.cn/2019/1107/c15434a224792/page.htm), Yinhe Li, [Jinchang Ren](https://scholar.google.com.hk/citations?user=Vsx9P-gAAAAJ&hl=zh-CN), [Aizhu Zhang](https://ocean.upc.edu.cn/2019/1108/c15434a224913/page.htm), Cheng Jing, [Pedram Ghamisi](https://www.ai4rs.com/) 6 | 7 | ArXiv Preprint ([arXiv:2406.05700](https://arxiv.org/abs/2406.05700)) 8 |
9 | 10 | 11 | # 12 | 13 | ## Abstract 14 | Haze contamination in hyperspectral remote sensing images (HSI) can lead to spatial visibility degradation and spectral distortion. Haze in HSI exhibits spatial irregularity and inhomogeneous spectral distribution, with few dehazing networks available. Current CNN and Transformer-based dehazing meth- ods fail to balance global scene recovery, local detail retention, and computational efficiency. Inspired by the ability of Mamba to model long-range dependencies with linear complexity, we explore its potential for HSI dehazing and propose the first HSI Dehazing Mamba (HDMba) network. Specifically, we design a novel window selective scan module (WSSM) that captures local dependencies within windows and global correlations between windows by partitioning them. This approach improves the ability of conventional Mamba in local feature extraction. By modeling the local and global spectral-spatial information flow, we achieve a comprehensive analysis of hazy regions. The DehazeMamba layer (DML), constructed by WSSM, and residual DehazeMamba (RDM) blocks, composed of DMLs, are the core components of the HDMba framework. These components effec- tively characterize the complex distribution of haze in HSIs, aid- ing in scene reconstruction and dehazing. Experimental results on the Gaofen-5 HSI dataset demonstrate that HDMba outperforms other state-of-the-art methods in dehazing performance. 15 | 16 | 17 |
18 | 19 |
20 | 21 | 22 | ## Datasets 23 | 24 | HyperDehazing: https://github.com/RsAI-lab/HyperDehazing 25 | 26 | HDD: Available from ([Paper](https://ieeexplore.ieee.org/document/9511329)) 27 | 28 | 29 | ## Other dehazing methods 30 | 31 | SG-Net:([Code](https://github.com/SZU-AdvTech-2022/158-A-Spectral-Grouping-based-Deep-Learning-Model-for-Haze-Removal-of-Hyperspectral-Images)) 32 | 33 | AACNet: ([Code](http://www.jiasen.tech/papers/)) 34 | 35 | DehazeFormer:([Code](https://github.com/IDKiro/DehazeFormer)) 36 | 37 | AIDFormer:([Code](https://github.com/AshutoshKulkarni4998/AIDTransformer)) 38 | 39 | RSDformer:([Code](https://github.com/MingTian99/RSDformer)) 40 | 41 | 42 | ## Acknowledgement 43 | This project is based on FFANet ([code](https://github.com/zhilin007/FFA-Net)). Thanks for their wonderful works. 44 | 45 | -------------------------------------------------------------------------------- /HDMba/test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | from models import * 4 | import cv2 5 | from PIL import Image 6 | import torch 7 | import torch.nn as nn 8 | import matplotlib.pyplot as plt 9 | from torchvision.utils import make_grid 10 | from osgeo import gdal 11 | from util import fuse_images, tensor2im 12 | gdal.PushErrorHandler("CPLQuietErrorHandler") 13 | 14 | 15 | from models.AACNet import AACNet 16 | from models.AIDTransformer import AIDTransformer 17 | from models.Dehazeformer import DehazeFormer 18 | from models.HDMba import HDMba 19 | 20 | 21 | import time 22 | abs = os.getcwd()+'/' 23 | 24 | 25 | def TwoPercentLinear(image, max_out=255, min_out=0): # 2%的线性拉伸 26 | b, g, r = cv2.split(image) # 分开三个波段 27 | 28 | def gray_process(gray, maxout = max_out, minout = min_out): 29 | high_value = np.percentile(gray, 98) # 取得98%直方图处对应灰度 30 | low_value = np.percentile(gray, 2) # 同理 31 | truncated_gray = np.clip(gray, a_min=low_value, a_max=high_value) 32 | processed_gray = ((truncated_gray - low_value)/(high_value - low_value)) * (maxout - minout)#线性拉伸嘛 33 | return processed_gray 34 | r_p = gray_process(r) 35 | g_p = gray_process(g) 36 | b_p = gray_process(b) 37 | result = cv2.merge((b_p, g_p, r_p)) #合并处理后的三个波段 38 | return np.uint8(result) 39 | 40 | 41 | def get_write_picture_fina(img): # get_write_picture函数得到训练过程中的可视化结果 42 | img = tensor2im(img, np.float) 43 | img = img.astype(np.uint8) 44 | output = TwoPercentLinear(img[:, :, (58, 38, 20)]) 45 | # output = TwoPercentLinear(img[:, :, (2, 1, 0)]) 46 | return output 47 | 48 | 49 | def tensorShow(name, tensors, titles): 50 | fig = plt.figure(figsize=(8, 8)) 51 | for tensor, tit, i in zip(tensors, titles, range(len(tensors))): 52 | img = make_grid(tensor) 53 | npimg = img.numpy() 54 | npimg = np.transpose(npimg, (1, 2, 0)) 55 | npimg = npimg / np.max(npimg) 56 | npimg = np.clip(npimg, 0, 1) 57 | ax = fig.add_subplot(121+i) 58 | ax.imshow(npimg) 59 | plt.imsave(f"../pred_GF5/{name}_{tit}.png", npimg) 60 | ax.set_title(tit) 61 | plt.tight_layout() 62 | plt.show() 63 | 64 | 65 | parser = argparse.ArgumentParser() 66 | parser.add_argument('--test_imgs', type=str, default='F:/L8MSI/Paired/Test/pre', help='Test imgs folder') 67 | opt = parser.parse_args() 68 | img_dir = opt.test_imgs+'/' 69 | 70 | # 训练好的网络 71 | model_dir = abs+f'trained_models/train_HDMba.pk' 72 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 73 | ckp = torch.load(model_dir, map_location=device) 74 | net = HDMba() 75 | 76 | net = nn.DataParallel(net) 77 | net.load_state_dict(ckp['model']) 78 | 
net = net.module.to(torch.device('cpu')) 79 | net.eval() 80 | 81 | 82 | for im in os.listdir(img_dir): 83 | # print(im) 84 | start_time = time.time() 85 | print(f'\r {im}', end='', flush=True) 86 | haze = gdal.Open(img_dir+im).ReadAsArray().astype(np.float32) 87 | haze = haze / np.max(haze) 88 | # haze = haze[] 89 | haze = np.expand_dims(haze, 0) 90 | haze = torch.from_numpy(haze).type(torch.FloatTensor) 91 | with torch.no_grad(): 92 | pred = net(haze) 93 | ts = torch.squeeze(pred.cpu()) 94 | print(f'|time_used :{(time.time() - start_time):.4f}s ') 95 | 96 | write_image2 = get_write_picture_fina(pred) 97 | write_image_name = "C:/Users/Administrator/Desktop/canet/" + str(im) + "_new.png" 98 | Image.fromarray(np.uint8(write_image2)).save(write_image_name) 99 | -------------------------------------------------------------------------------- /HDMba/models/FFANet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | # from torchsummary import summary 4 | # import torchsummary 5 | 6 | 7 | class PALayer(nn.Module): 8 | def __init__(self, channel): 9 | super(PALayer, self).__init__() 10 | self.pa = nn.Sequential( 11 | nn.Conv2d(channel, channel // 8, 1, padding=0, bias=True), 12 | nn.ReLU(inplace=True), 13 | nn.Conv2d(channel // 8, 1, 1, padding=0, bias=True), 14 | nn.Sigmoid() 15 | ) 16 | def forward(self, x): 17 | y = self.pa(x) 18 | return x * y 19 | 20 | 21 | class CALayer(nn.Module): 22 | def __init__(self, channel): 23 | super(CALayer, self).__init__() 24 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 25 | self.ca = nn.Sequential( 26 | nn.Conv2d(channel, channel // 8, 1, padding=0, bias=True), 27 | nn.ReLU(inplace=True), 28 | nn.Conv2d(channel // 8, channel, 1, padding=0, bias=True), 29 | nn.Sigmoid() 30 | ) 31 | 32 | def forward(self, x): 33 | y = self.avg_pool(x) 34 | y = self.ca(y) 35 | return x * y 36 | 37 | 38 | class Block(nn.Module): 39 | def __init__(self, dim): 40 | super(Block, self).__init__() 41 | self.conv1 = nn.Conv2d(dim, dim, 3, padding=1, bias=True) 42 | self.act1 = nn.ReLU(inplace=True) 43 | self.conv2 = nn.Conv2d(dim, dim, 3, padding=1, bias=True) 44 | self.calayer = CALayer(dim) 45 | self.palayer = PALayer(dim) 46 | 47 | def forward(self, x): 48 | 49 | res = self.act1(self.conv1(x)) 50 | res = res+x 51 | res = self.conv2(res) 52 | res = self.calayer(res) 53 | res = self.palayer(res) 54 | res += x 55 | return res 56 | 57 | 58 | class GroupRLKAs(nn.Module): 59 | def __init__(self, dim): 60 | super(GroupRLKAs, self).__init__() 61 | self.g1 = Block(dim) 62 | self.g2 = Block(dim) 63 | self.g3 = Block(dim) 64 | 65 | def forward(self, x): 66 | y1 = self.g1(x) 67 | y2 = self.g1(y1) 68 | y3 = self.g1(y2) 69 | return y3, torch.cat([y1, y2, y3], dim=1) 70 | 71 | 72 | class FFB(nn.Module): 73 | def __init__(self, dim): 74 | super(FFB, self).__init__() 75 | self.conv0 = nn.Conv2d(dim, 32, 1, bias=False) 76 | # self.activation0 = nn.ReLU() 77 | self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=False) 78 | 79 | def forward(self, x): 80 | x = self.conv0(x) 81 | # x = self.activation0(x) 82 | x = self.conv1(x) 83 | return x 84 | 85 | 86 | class FFANet(nn.Module): 87 | def __init__(self): 88 | super(FFANet, self).__init__() 89 | # 初始特征提取层 90 | # self.conv0 = nn.Conv2d(305, 32, 3, padding=1, bias=False) 91 | self.conv0 = nn.Conv2d(7, 32, 3, padding=1, bias=False) 92 | #self.conv0 = nn.Conv2d(4, 32, 3, padding=1, bias=False) 93 | self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=False) 94 | 95 | # 中间的块 96 | self.g1 = 
GroupRLKAs(32) 97 | self.g2 = GroupRLKAs(32) 98 | self.g3 = GroupRLKAs(32) 99 | 100 | # 后面的块 101 | self.fusion = FFB(96*3) 102 | self.att1 = CALayer(32) 103 | self.att2 = PALayer(32) 104 | self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=False) 105 | # self.conv3 = nn.Conv2d(32, 305, 3, padding=1, bias=False) 106 | self.conv3 = nn.Conv2d(32, 7, 3, padding=1, bias=False) 107 | #self.conv3 = nn.Conv2d(32, 4, 3, padding=1, bias=False) 108 | 109 | def forward(self, x): 110 | out1 = self.conv0(x) 111 | out2 = self.conv1(out1) 112 | 113 | R1, F1 = self.g1(out2) 114 | R2, F2 = self.g2(R1) 115 | R3, F3 = self.g2(R2) 116 | 117 | Fea = torch.cat([F1, F2, F3], dim=1) 118 | Fea = self.fusion(Fea) 119 | Fea = self.att1(Fea) 120 | Fea = self.att2(Fea) 121 | Fea = self.conv2(Fea) 122 | Fea = Fea + out1 123 | Fea = self.conv3(Fea) 124 | # Fea = Fea + x 125 | return Fea 126 | 127 | 128 | if __name__ == "__main__": 129 | net = FFANet() 130 | 131 | # input_data = torch.rand(1, 305, 512, 512) 132 | # net = FFANet() 133 | # out = net(input_data) 134 | # print(out.shape) 135 | 136 | # summary(net, torch.rand(1, 305, 512, 512)) 137 | 138 | device = torch.device('cpu') 139 | net.to(device) 140 | 141 | # torchsummary.summary(net.cuda(), (305, 512, 512)) -------------------------------------------------------------------------------- /HDMba/metrics.py: -------------------------------------------------------------------------------- 1 | from math import exp 2 | import math 3 | import numpy as np 4 | # import imgvision as iv # python图像光谱视觉分析库-imgvision 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from torch.autograd import Variable 9 | 10 | from math import exp 11 | import math 12 | import numpy as np 13 | 14 | import torch 15 | import torch.nn.functional as F 16 | from torch.autograd import Variable 17 | from torchvision.transforms import ToPILImage 18 | from numpy.linalg import norm 19 | 20 | def gaussian(window_size, sigma): 21 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 22 | return gauss / gauss.sum() 23 | 24 | def create_window(window_size, channel): 25 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 26 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 27 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 28 | return window 29 | 30 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 31 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel).cpu().numpy() 32 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel).cpu().numpy() 33 | # print(type(mu2)) 34 | mu1_sq = np.power(mu1, 2) # mul的2次方 35 | mu2_sq = np.power(mu2, 2) 36 | mu1_mu2 = mu1 * mu2 37 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel).cpu().numpy() - mu1_sq 38 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel).cpu().numpy() - mu2_sq 39 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel).cpu().numpy() - mu1_mu2 40 | C1 = 0.01 ** 2 41 | C2 = 0.03 ** 2 42 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 43 | 44 | if size_average: 45 | return np.mean(ssim_map) 46 | else: 47 | return ssim_map.mean(1).mean(1).mean(1) 48 | 49 | 50 | def ssim(img1, img2, window_size=11, size_average=True): 51 | img1 = torch.clamp(img1, min=0, max=1) # 将输入input张量每个元素的夹紧到区间 52 | img2 = 
torch.clamp(img2, min=0, max=1) 53 | (_, channel, _, _) = img1.size() 54 | window = create_window(window_size, channel) 55 | if img1.is_cuda: 56 | window = window.cuda(img1.get_device()) 57 | window = window.type_as(img1) 58 | return _ssim(img1, img2, window, window_size, channel, size_average) 59 | 60 | def psnr(pred, gt): 61 | pred = pred.clamp(0, 1).cpu().numpy() 62 | gt = gt.clamp(0, 1).cpu().numpy() 63 | imdff = pred - gt 64 | rmse = math.sqrt(np.mean(imdff ** 2)) 65 | if rmse == 0: 66 | return 100 67 | return 20 * math.log10(1.0 / rmse) 68 | 69 | 70 | def UQI(O, F): 71 | meanO = torch.mean(O) 72 | meanF = torch.mean(F) 73 | (_, _, m, n) = np.shape(F) 74 | varO = torch.sqrt(torch.sum((O - meanO) ** 2) / (m * n - 1)) 75 | varF = torch.sqrt(torch.sum((F - meanF) ** 2) / (m * n - 1)) 76 | 77 | covOF = torch.sum((O - meanO) * (F - meanF)) / (m * n - 1) 78 | UQI = 4 * meanO * meanF * covOF / ((meanO ** 2 + meanF ** 2) * (varO ** 2 + varF ** 2)) 79 | return UQI.data.cpu().numpy() 80 | 81 | 82 | def SAM(pred, target): 83 | pred = torch.clamp(pred, min=0, max=1) 84 | target = torch.clamp(target, min=0, max=1) 85 | pred1 = pred[0, :, :, :].cpu() 86 | target1 = target[0, :, :, :].cpu() 87 | pred1 = np.transpose(pred1, (2, 1, 0)) 88 | target1 = np.transpose(target1, (2, 1, 0)) 89 | sam_rad = np.zeros((pred1.shape[0], pred1.shape[1])) 90 | for x in range(pred1.shape[0]): 91 | for y in range(pred1.shape[1]): 92 | tmp_pred = pred1[x, y].ravel() 93 | tmp_true = target1[x, y].ravel() 94 | cos_value = (tmp_pred.mean() / (norm(tmp_pred) * tmp_true.mean() / norm(tmp_true))) 95 | # print(cos_value) 96 | if 1.0 < cos_value: 97 | cos_value = 1.0 98 | sam_rad[x, y] = cos_value 99 | SAM1 = np.arccos(sam_rad) 100 | # SAM1 = sam1.mean() * 180 / np.pi 101 | return SAM1 102 | 103 | 104 | def calc_sam(img_tgt, img_fus): 105 | img_tgt = np.squeeze(img_tgt) 106 | img_fus = np.squeeze(img_fus) 107 | img_tgt = img_tgt.reshape(img_tgt.shape[0], -1) 108 | img_fus = img_fus.reshape(img_fus.shape[0], -1) 109 | img_tgt = img_tgt / torch.max(img_tgt) 110 | img_fus = img_fus / torch.max(img_fus) 111 | A = torch.sqrt(torch.sum(img_tgt**2, 0)) 112 | B = torch.sqrt(torch.sum(img_fus**2, 0)) 113 | AB = torch.sum(img_tgt*img_fus, 0) 114 | sam = AB/(A*B) 115 | sam = torch.arccos_(sam) 116 | sam = torch.mean(sam) 117 | return sam 118 | 119 | 120 | if __name__ == "__main__": 121 | pass 122 | 123 | # np.random.seed(10) 124 | # pred = np.random.rand(1, 20, 100, 100) 125 | # targets = np.random.rand(1, 20, 100, 100) 126 | # pred = torch.rand(1, 20, 100, 100) 127 | # targets = torch.rand(1, 20, 100, 100) 128 | # sam1 = calc_sam(pred, targets) 129 | # sam1 = SAM(pred, targets) 130 | # print(sam1) 131 | # 132 | # pred1 = np.transpose(pred, (2, 1, 0)) 133 | # targets1 = np.transpose(targets, (2, 1, 0)) 134 | # Metric = iv.spectra_metric(pred1, targets1) 135 | # SAM1 = Metric.SAM() 136 | # print(SAM1) 137 | # 138 | # UQI1 = UQI(pred, targets) 139 | # print(UQI1) 140 | # 141 | # pred2 = torch.Tensor(pred) 142 | # targets2 = torch.Tensor(targets) 143 | # ssim1 = ssim(pred2, targets2).item() 144 | # print(ssim1) 145 | # psnr1 = psnr(pred2, targets2) 146 | # print(psnr1) 147 | 148 | 149 | 150 | -------------------------------------------------------------------------------- /HDMba/main.py: -------------------------------------------------------------------------------- 1 | # load Network 2 | from models.AACNet import AACNet 3 | from models.AIDTransformer import AIDTransformer 4 | from models.Dehazeformer import DehazeFormer 5 | from 
models.HDMba import HDMba 6 | 7 | import time, math 8 | import numpy as np 9 | from torch.backends import cudnn 10 | from torch import optim 11 | import torch, warnings 12 | from torch import nn 13 | # from tensorboardX import SummaryWriter 14 | import torchvision.utils as vutils 15 | warnings.filterwarnings('ignore') 16 | from option import opt, model_name, log_dir 17 | from data_utils import * 18 | from matplotlib import pyplot as plt 19 | # torch.autograd.set_detect_anomaly(True) 20 | 21 | print('log_dir :', log_dir) 22 | print('model_name:', model_name) 23 | models_ = { 24 | 'AACNet': AACNet(), 25 | 'AIDTransformer': AIDTransformer(), 26 | 'DehazeFormer': DehazeFormer(), 27 | 'HDMba': HDMba(), 28 | } 29 | loaders_ = \ 30 | {'train': train_loader, 'test': test_loader} 31 | start_time = time.time() 32 | T = opt.steps 33 | 34 | 35 | def lr_schedule_cosdecay(t, T, init_lr=opt.lr): 36 | lr = 0.5 * (1 + math.cos(t * math.pi / T)) * init_lr 37 | return lr 38 | 39 | 40 | def train(net, loader_train, loader_test, optim, criterion): 41 | losses = [] 42 | start_step = 0 43 | max_ssim = 0 44 | max_psnr = 0 45 | max_uqi = 0 46 | min_sam = 1 47 | ssims = [] 48 | psnrs = [] 49 | uqis = [] 50 | sams = [] 51 | if opt.resume and os.path.exists(opt.model_dir): 52 | print(f'resume from {opt.model_dir}') 53 | ckp = torch.load(opt.model_dir) 54 | losses = ckp['losses'] 55 | net.load_state_dict(ckp['model']) 56 | start_step = ckp['step'] 57 | max_ssim = ckp['max_ssim'] 58 | max_psnr = ckp['max_psnr'] 59 | min_sam = ckp['min_sam'] 60 | max_uqi = ckp['max_uqi'] 61 | psnrs = ckp['psnrs'] 62 | ssims = ckp['ssims'] 63 | uqis = ckp['uqis'] 64 | sams = ckp['sams'] 65 | print(f'start_step:{start_step} start training ---') 66 | else: 67 | print('train from scratch *** ') # 没有训练好的网络,从头训练 68 | 69 | train_los = np.zeros(opt.steps) 70 | for step in range(start_step+1, opt.steps+1): 71 | net.train() 72 | lr = opt.lr 73 | if not opt.no_lr_sche: 74 | lr = lr_schedule_cosdecay(step, T) 75 | for param_group in optim.param_groups: 76 | param_group["lr"] = lr 77 | x, y = next(iter(loader_train)) 78 | x = x.to(opt.device) 79 | y = y.to(opt.device) 80 | out = net(x) 81 | # loss function 82 | loss1 = criterion[0](out, y) 83 | loss2 = criterion[1](out, y) 84 | loss = loss1 + 0.01*loss2 85 | loss.backward() 86 | 87 | optim.step() 88 | optim.zero_grad() 89 | losses.append(loss.item()) 90 | train_los[step-1] = loss.item() 91 | print( 92 | f'\rtrain loss : {loss.item():.5f} |step :{step}/{opt.steps} |lr:{lr :.7f} |time_used :{(time.time() - start_time) :.4f}s', 93 | end='', flush=True) 94 | 95 | if step % opt.eval_step == 0: 96 | with torch.no_grad(): 97 | ssim_eval, psnr_eval, uqi_eval, sam_eval = test(net, loader_test) 98 | print(f'\nstep :{step} |ssim:{ssim_eval:.4f} |psnr:{psnr_eval:.4f} |uqi:{uqi_eval:.4f} |sam:{sam_eval:.4f}') 99 | ssims.append(ssim_eval) 100 | psnrs.append(psnr_eval) 101 | uqis.append(uqi_eval) 102 | sams.append(sam_eval) 103 | # if psnr_eval > max_psnr and ssim_eval > max_ssim and uqi_eval > max_uqi and min_sam > sam_eval: 104 | if psnr_eval > max_psnr and ssim_eval > max_ssim: 105 | max_ssim = max(max_ssim, ssim_eval) 106 | max_psnr = max(max_psnr, psnr_eval) 107 | max_uqi = max(max_uqi, uqi_eval) 108 | min_sam = min(min_sam, sam_eval) 109 | torch.save({ 110 | 'step': step, 111 | 'max_psnr': max_psnr, 112 | 'max_ssim': max_ssim, 113 | 'max_uqi': max_uqi, 114 | 'min_sam': min_sam, 115 | 'ssims': ssims, 116 | 'psnrs': psnrs, 117 | 'uqis': uqis, 118 | 'sams': sams, 119 | 'losses': losses, 120 | 'model': 
net.state_dict() 121 | }, opt.model_dir) 122 | print(f'\n model saved at step :{step}| max_psnr:{max_psnr:.4f}|max_ssim:{max_ssim:.4f}|max_uqi:{max_uqi:.4f} |min_sam:{min_sam:.4f}') 123 | 124 | iters = range(len(train_los)) 125 | plt.figure() 126 | plt.plot(iters, train_los, 'g', label='train loss') 127 | plt.show() 128 | np.save(f'./numpy_files/{model_name}_{opt.steps}_losses.npy', losses) 129 | np.save(f'./numpy_files/{model_name}_{opt.steps}_ssims.npy', ssims) 130 | np.save(f'./numpy_files/{model_name}_{opt.steps}_psnrs.npy', psnrs) 131 | np.save(f'./numpy_files/{model_name}_{opt.steps}_uqis.npy', uqis) 132 | np.save(f'./numpy_files/{model_name}_{opt.steps}_sams.npy', sams) 133 | 134 | 135 | def test(net, loader_test): # verification 136 | net.eval() 137 | torch.cuda.empty_cache() 138 | ssims = [] 139 | psnrs = [] 140 | uqis = [] 141 | sams = [] 142 | for i, (inputs, targets) in enumerate(loader_test): 143 | inputs = inputs.to(opt.device) 144 | targets = targets.to(opt.device) 145 | pred = net(inputs) 146 | ssim1 = ssim(pred, targets).item() 147 | psnr1 = psnr(pred, targets) 148 | uqi1 = UQI(pred, targets) 149 | sam1 = SAM(pred, targets) 150 | # sam1 = calc_sam(pred, targets) 151 | ssims.append(ssim1) 152 | psnrs.append(psnr1) 153 | uqis.append(uqi1) 154 | sams.append(sam1) 155 | return np.mean(ssims), np.mean(psnrs), np.mean(uqis), np.mean(sams) 156 | 157 | 158 | if __name__ == "__main__": 159 | loader_train = loaders_[opt.trainset] 160 | loader_test = loaders_[opt.testset] 161 | net = models_[opt.net] 162 | net = net.to(opt.device) 163 | if opt.device == 'cuda': 164 | net = torch.nn.DataParallel(net) 165 | cudnn.benchmark = True 166 | criterion = [] 167 | criterion.append(nn.L1Loss().to(opt.device)) # L1损失被放入到criterion[0] 168 | criterion.append(nn.MSELoss().to(opt.device)) 169 | 170 | optimizer = optim.Adam(params=filter(lambda x: x.requires_grad, net.parameters()), lr=opt.lr, betas=(0.9, 0.999), eps=1e-08) 171 | optimizer.zero_grad() 172 | if torch.cuda.device_count() > 1: 173 | model = torch.nn.DataParallel(net) # 前提是model已经在cuda上了 174 | train(net, loader_train, loader_test, optimizer, criterion) 175 | 176 | 177 | -------------------------------------------------------------------------------- /HDMba/util.py: -------------------------------------------------------------------------------- 1 | """This module contains simple helper functions """ 2 | from __future__ import print_function 3 | import torch 4 | import numpy as np 5 | from PIL import Image 6 | import scipy.ndimage as ndimage 7 | import os 8 | import cv2 9 | from random import randrange 10 | import torch.nn.functional as F 11 | 12 | 13 | def find_patches(img_c, patch_size): 14 | _, band, width, height = img_c.shape 15 | x1, y1 = randrange(0, width - patch_size + 1), randrange(0, height - patch_size + 1) 16 | img_patch = img_c[:, :, x1:x1 + patch_size, y1:y1 + patch_size] 17 | for id in range(5): 18 | x, y = randrange(0, width - patch_size + 1), randrange(0, height - patch_size + 1) 19 | patch = img_c[:, :, x:x + patch_size, y:y + patch_size] 20 | img_patch = torch.cat((img_patch, patch), dim=1) 21 | # img_patch.append(patch) 22 | return img_patch.cuda() 23 | 24 | 25 | def fuse_images(real_I, rec_J, refine_J): 26 | """ 27 | real_I, rec_J, and refine_J: Images with shape hxwx3 28 | """ 29 | # realness features 30 | mat_RGB2YMN = np.array([[0.299,0.587,0.114], 31 | [0.30,0.04,-0.35], 32 | [0.34,-0.6,0.17]]) 33 | 34 | recH,recW,recChl = rec_J.shape 35 | rec_J_flat = rec_J.reshape([recH*recW,recChl]) 36 | rec_J_flat_YMN = 
(mat_RGB2YMN.dot(rec_J_flat.T)).T 37 | rec_J_YMN = rec_J_flat_YMN.reshape(rec_J.shape) 38 | 39 | refine_J_flat = refine_J.reshape([recH*recW,recChl]) 40 | refine_J_flat_YMN = (mat_RGB2YMN.dot(refine_J_flat.T)).T 41 | refine_J_YMN = refine_J_flat_YMN.reshape(refine_J.shape) 42 | 43 | real_I_flat = real_I.reshape([recH*recW,recChl]) 44 | real_I_flat_YMN = (mat_RGB2YMN.dot(real_I_flat.T)).T 45 | real_I_YMN = real_I_flat_YMN.reshape(real_I.shape) 46 | 47 | # gradient features 48 | rec_Gx = cv2.Sobel(rec_J_YMN[:,:,0],cv2.CV_64F,1,0,ksize=3) 49 | rec_Gy = cv2.Sobel(rec_J_YMN[:,:,0],cv2.CV_64F,0,1,ksize=3) 50 | rec_GM = np.sqrt(rec_Gx**2 + rec_Gy**2) 51 | 52 | refine_Gx = cv2.Sobel(refine_J_YMN[:,:,0],cv2.CV_64F,1,0,ksize=3) 53 | refine_Gy = cv2.Sobel(refine_J_YMN[:,:,0],cv2.CV_64F,0,1,ksize=3) 54 | refine_GM = np.sqrt(refine_Gx**2 + refine_Gy**2) 55 | 56 | real_Gx = cv2.Sobel(real_I_YMN[:,:,0],cv2.CV_64F,1,0,ksize=3) 57 | real_Gy = cv2.Sobel(real_I_YMN[:,:,0],cv2.CV_64F,0,1,ksize=3) 58 | real_GM = np.sqrt(real_Gx**2 + real_Gy**2) 59 | 60 | # similarity 61 | rec_S_V = (2*real_GM*rec_GM+160)/(real_GM**2+rec_GM**2+160) 62 | rec_S_M = (2*rec_J_YMN[:,:,1]*real_I_YMN[:,:,1]+130)/(rec_J_YMN[:,:,1]**2+real_I_YMN[:,:,1]**2+130) 63 | rec_S_N = (2*rec_J_YMN[:,:,2]*real_I_YMN[:,:,2]+130)/(rec_J_YMN[:,:,2]**2+real_I_YMN[:,:,2]**2+130) 64 | rec_S_R = (rec_S_M*rec_S_N).reshape([recH,recW]) 65 | 66 | refine_S_V = (2*real_GM*refine_GM+160)/(real_GM**2+refine_GM**2+160) 67 | refine_S_M = (2*refine_J_YMN[:,:,1]*real_I_YMN[:,:,1]+130)/(refine_J_YMN[:,:,1]**2+real_I_YMN[:,:,1]**2+130) 68 | refine_S_N = (2*refine_J_YMN[:,:,2]*real_I_YMN[:,:,2]+130)/(refine_J_YMN[:,:,2]**2+real_I_YMN[:,:,2]**2+130) 69 | refine_S_R = (refine_S_M*refine_S_N).reshape([recH,recW]) 70 | 71 | rec_S = rec_S_R*np.power(rec_S_V, 0.4) 72 | refine_S = refine_S_R*np.power(refine_S_V, 0.4) 73 | 74 | fuseWeight = np.exp(rec_S)/(np.exp(rec_S)+np.exp(refine_S)) 75 | return fuseWeight 76 | 77 | 78 | def get_tensor_dark_channel(img, neighborhood_size): 79 | shape = img.shape 80 | if len(shape) == 4: 81 | img_min = torch.min(img, dim=1) 82 | img_dark = F.max_pool2d(img_min, kernel_size=neighborhood_size, stride=1) 83 | else: 84 | raise NotImplementedError('get_tensor_dark_channel is only for 4-d tensor [N*C*H*W]') 85 | 86 | return img_dark 87 | 88 | 89 | def array2Tensor(in_array, gpu_id=-1): 90 | in_shape = in_array.shape 91 | if len(in_shape) == 2: 92 | in_array = in_array[:,:,np.newaxis] 93 | 94 | arr_tmp = in_array.transpose([2,0,1]) 95 | arr_tmp = arr_tmp[np.newaxis,:] 96 | 97 | if gpu_id >= 0: 98 | return torch.tensor(arr_tmp.astype(np.float)).to(gpu_id) 99 | else: 100 | return torch.tensor(arr_tmp.astype(np.float)) 101 | 102 | 103 | def tensor2im(input_image, imtype=np.uint8): 104 | """"Converts a Tensor array into a numpy image array. 
105 | 106 | Parameters: 107 | input_image (tensor) -- the input image tensor array 108 | imtype (type) -- the desired type of the converted numpy array 109 | """ 110 | if not isinstance(input_image, np.ndarray): 111 | if isinstance(input_image, torch.Tensor): # get the data from a variable 112 | image_tensor = input_image.data 113 | else: 114 | return input_image 115 | image_numpy = image_tensor[0].cpu().float().numpy() # convert it into a numpy array 116 | if image_numpy.shape[0] == 1: # grayscale to RGB 117 | image_numpy = np.tile(image_numpy, (3, 1, 1)) 118 | # image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 # post-processing: tranpose and scaling 119 | # 计算张量的最小值和最大值 120 | min_val, max_val = image_numpy.min(), image_numpy.max() 121 | # 执行归一化操作 122 | image_numpy = (image_numpy - min_val) / (max_val - min_val) 123 | image_numpy = (np.transpose(image_numpy, (1, 2, 0))) * 255.0 # post-processing: tranpose and scaling 124 | else: # if it is a numpy array, do nothing 125 | image_numpy = input_image 126 | return image_numpy.astype(imtype) 127 | 128 | 129 | def rescale_tensor(input_tensor): 130 | """"Converts a Tensor array into the Tensor array whose data are identical to the image's. 131 | [height, width] not [width, height] 132 | 133 | Parameters: 134 | input_image (tensor) -- the input image tensor array 135 | imtype (type) -- the desired type of the converted numpy array 136 | """ 137 | 138 | if isinstance(input_tensor, torch.Tensor): 139 | input_tmp = input_tensor.cpu().float() 140 | output_tmp = (input_tmp + 1) / 2.0 * 255.0 141 | output_tmp = output_tmp.to(torch.uint8) 142 | else: 143 | return input_tensor 144 | 145 | return output_tmp.to(torch.float32) / 255.0 146 | 147 | 148 | 149 | 150 | -------------------------------------------------------------------------------- /HDMba/models/AACNet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.init as init 3 | import torch 4 | import torch.nn.functional as F 5 | from torchsummary import summary 6 | 7 | 8 | class FAConv(nn.Module): 9 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, deploy=False, reduce_gamma=False, gamma_init=None ): 10 | super(FAConv, self).__init__() 11 | self.deploy = deploy 12 | if deploy: 13 | self.fused_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(kernel_size,kernel_size), stride=stride, 14 | padding=padding, bias=False, padding_mode= padding_mode) 15 | # self.fused_point_conv = nn.Conv2d(in_channels=in_channels,out_channels=out_channels,kernel_size=1,stride=1,padding=0) 16 | else: 17 | self.square_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(kernel_size, kernel_size),stride=stride,padding=padding, dilation=dilation,bias=True) 18 | self.square_point_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1,stride=stride,bias=True) 19 | if padding - kernel_size // 2 >= 0: 20 | # Common use case. E.g., k=3, p=1 or k=5, p=2 21 | self.crop = 0 22 | # Compared to the KxK layer, the padding of the 1xK layer and Kx1 layer should be adjust to align the sliding windows (Fig 2 in the paper) 23 | hor_padding = [padding - kernel_size // 2, padding] 24 | ver_padding = [padding, padding - kernel_size // 2] 25 | else: 26 | # A negative "padding" (padding - kernel_size//2 < 0, which is not a common use case) is cropping. 
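# For example, kernel_size=5 with padding=1 would require a height padding of -1 for the 1x5 branch
# (and a width padding of -1 for the 5x1 branch); instead, crop = 5 // 2 - 1 = 1 row/column is trimmed
# from the corresponding side of the input in forward() before the asymmetric convolutions.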
27 | # Since nn.Conv2d does not support negative padding, we implement it manually 28 | self.crop = kernel_size // 2 - padding 29 | hor_padding = [0, padding] 30 | ver_padding = [padding, 0] 31 | 32 | self.ver_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(kernel_size, 1), 33 | stride=stride,padding=ver_padding, dilation=dilation, bias=True) 34 | self.hor_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(1, kernel_size), 35 | stride=stride,padding=hor_padding, dilation=dilation, bias=True) 36 | if reduce_gamma: 37 | self.init_gamma(1.0 / 3) 38 | if gamma_init is not None: 39 | assert not reduce_gamma 40 | self.init_gamma(gamma_init) 41 | 42 | def forward(self, input): 43 | if self.deploy: 44 | return self.fused_conv(input) 45 | else: 46 | square_outputs_dw=self.square_conv(input) 47 | square_outputs_pw = self.square_point_conv(input) 48 | square_outputs = square_outputs_dw+square_outputs_pw 49 | if self.crop > 0: 50 | ver_input = input[:, :, :, self.crop:-self.crop] 51 | hor_input = input[:, :, self.crop:-self.crop, :] 52 | else: 53 | ver_input = input 54 | hor_input = input 55 | vertical_outputs_dw = self.ver_conv(ver_input) 56 | vertical_outputs=vertical_outputs_dw 57 | horizontal_outputs_dw = self.hor_conv(hor_input) 58 | horizontal_outputs=horizontal_outputs_dw 59 | result = square_outputs + vertical_outputs + horizontal_outputs 60 | return result 61 | 62 | 63 | class GA(nn.Module): 64 | def __init__(self, dim=150): 65 | super(GA, self).__init__() 66 | self.ga = nn.Sequential(*[ 67 | nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, 68 | padding=0, groups=1, bias=True,padding_mode="zeros"), 69 | nn.ReLU(), 70 | nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, 71 | padding=0, groups=1, bias=True,padding_mode="zeros"), 72 | nn.Sigmoid() 73 | ]) 74 | 75 | def forward(self, x): 76 | out=self.ga(x) 77 | return out 78 | 79 | 80 | class AAConv(nn.Module): 81 | def __init__(self, dim,kernel_size,padding,stride,deploy): 82 | super(AAConv, self).__init__() 83 | self.faconv=FAConv(dim, dim, kernel_size=kernel_size, padding=padding, stride=stride, deploy=deploy) 84 | self.ga=GA(dim=dim) 85 | 86 | def forward(self, input): 87 | x=self.faconv(input) 88 | attn=self.ga(input) 89 | out=attn*x 90 | return out 91 | 92 | 93 | class PCSA(nn.Module): 94 | def __init__(self, dim=150, out_dim=150): 95 | super(PCSA, self).__init__() 96 | self.in_dim=dim 97 | self.out_dim=out_dim 98 | self.avg_pooling=torch.nn.AdaptiveAvgPool2d((1,1)) 99 | self.liner1=nn.Linear(dim,dim) 100 | self.liner2=nn.Linear(dim,dim) 101 | self.Sigmoid=nn.Sigmoid() 102 | self.conv2d = nn.Conv2d(dim,dim,1,1) 103 | self.avg_pooling_1d = torch.nn.AdaptiveAvgPool1d(1) 104 | self.conv1d=nn.Conv1d(1,1,1,1) 105 | 106 | def forward(self, x): 107 | # x=self.conv2d(x) 108 | 109 | x_pool=self.avg_pooling(x).squeeze(dim=3).permute(0,2,1) 110 | # print(x_pool.shape) 111 | q=self.liner1(x_pool) 112 | k=self.liner2(x_pool) 113 | attn=k.permute(0,2,1)@q 114 | attn = self.avg_pooling_1d(attn) 115 | attn = self.conv1d(attn.transpose(-1,-2)).transpose(-1,-2).unsqueeze(dim=3) 116 | attn=self.Sigmoid(attn) 117 | out =x*attn 118 | return out 119 | 120 | 121 | # Pixel Attention 像素注意力 122 | class AACNet(nn.Module): 123 | # def __init__(self, in_dim=4, out_dim=4, dim=64, kernel_size=3, padding=1, num_blocks=5, stride=1, deploy=False): 124 | # def __init__(self, in_dim=7, out_dim=7, dim=64, kernel_size=3, padding=1, num_blocks=5, stride=1, deploy=False): 125 | def 
__init__(self, in_dim=305, out_dim=305, dim=64, kernel_size=3, padding=1,num_blocks=5, stride=1,deploy=False): 126 | super(AACNet, self).__init__() 127 | self.blocks = nn.ModuleList([]) 128 | self.blocks1 = nn.ModuleList([]) 129 | self.blocks2 = nn.ModuleList([]) 130 | self.dim = dim 131 | self.t_Conv1 = nn.Conv2d(self.dim, self.dim, 3, 1, 1) 132 | self.t_Conv2 = nn.Conv2d(self.dim, self.dim, 3, 1, 1) 133 | self.t_Conv3 = nn.Conv2d(self.dim, self.dim, 3, 1, 1) 134 | self.num_block = num_blocks 135 | self.Convd_in = nn.Conv2d(in_dim, self.dim, kernel_size=1, padding=0, stride=1) 136 | self.Convd = nn.Conv2d(self.dim, out_dim, kernel_size=1, padding=0, stride=1) 137 | self.Convd_out = nn.Conv2d(out_dim, out_dim, 3, 1, 1) 138 | self.cattn = PCSA(self.dim, self.dim) 139 | # self.acdw=AAConv( self.dim, self.dim, kernel_size=kernel_size, padding=padding, stride=stride, deploy=False) 140 | # self.gps=1 141 | 142 | def weigth_init(m): 143 | if isinstance(m, nn.Conv2d): 144 | init.xavier_uniform_(m.weight.data) 145 | elif isinstance(m, nn.Linear): 146 | m.weight.data.normal_(0, 0.01) 147 | m.bias.data.zero_() 148 | for _ in range(num_blocks): 149 | self.blocks.append(nn.ModuleList([ 150 | AAConv(self.dim, kernel_size=kernel_size, padding=padding, stride=stride, deploy=deploy), 151 | nn.PReLU(), 152 | AAConv( self.dim, kernel_size=kernel_size, padding=padding, stride=stride, deploy=deploy), 153 | PCSA(self.dim,self.dim), 154 | ])) 155 | for _ in range(num_blocks): 156 | self.blocks1.append(nn.ModuleList([ 157 | AAConv(self.dim, kernel_size=kernel_size, padding=padding, stride=stride, deploy=deploy), 158 | nn.PReLU(), 159 | AAConv(self.dim, kernel_size=kernel_size, padding=padding, stride=stride, deploy=deploy), 160 | PCSA(self.dim, self.dim), 161 | ])) 162 | for _ in range(num_blocks): 163 | self.blocks2.append(nn.ModuleList([ 164 | AAConv(self.dim, kernel_size=kernel_size, padding=padding, stride=stride, deploy=deploy), 165 | nn.PReLU(), 166 | AAConv(self.dim, kernel_size=kernel_size, padding=padding, stride=stride, deploy=deploy), 167 | PCSA(self.dim, self.dim), 168 | ])) 169 | self.cattn.apply(weigth_init) 170 | self.blocks.apply(weigth_init) 171 | self.blocks1.apply(weigth_init) 172 | self.blocks2.apply(weigth_init) 173 | # self..apply(weigth_init) 174 | 175 | def forward(self, x): 176 | x_original_features = x 177 | x = self.Convd_in(x) 178 | x_shallow_features = x 179 | for (aaconv, act, aaconv1, pcsa) in self.blocks: 180 | res = x 181 | x = aaconv(x) 182 | x = act(x) 183 | x = aaconv1(x) 184 | x = pcsa(x) 185 | x = x+res 186 | x = self.t_Conv1(x) 187 | for (aaconv, act, aaconv1, pcsa) in self.blocks1: 188 | res1 = x 189 | x = aaconv(x) 190 | x = act(x) 191 | x = aaconv1(x) 192 | x = pcsa(x) 193 | x = x + res1 194 | x = self.t_Conv2(x) 195 | for (aaconv, act, aaconv1, pcsa) in self.blocks2: 196 | res2 = x 197 | x = aaconv(x) 198 | x = act(x) 199 | x = aaconv1(x) 200 | x = pcsa(x) 201 | x = x + res2 202 | x = self.t_Conv3(x) 203 | x = x+x_shallow_features 204 | x = self.cattn(x) 205 | x = self.Convd(x) 206 | x = self.Convd_out(x) 207 | out = x + x_original_features 208 | # out=x 209 | return out 210 | 211 | 212 | if __name__ == '__main__': 213 | net = AACNet(in_dim=305, out_dim=305, kernel_size=3, padding=1, stride=1, num_blocks=5).cuda() 214 | # net = AACNet(in_dim=7, out_dim=7, kernel_size=3, padding=1, stride=1, num_blocks=5).cuda() 215 | # net = AACNet(in_dim=4, out_dim=4, kernel_size=3, padding=1, stride=1, num_blocks=5).cuda() 216 | 217 | device = torch.device('cpu') 218 | net.to(device) 
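    # A minimal forward-pass sanity check (illustrative sketch): the residual design above
    # should map a random 305-band patch back to a tensor of the same shape.
    x = torch.randn(1, 305, 64, 64, device=device)   # (batch, bands, height, width)
    with torch.no_grad():
        y = net(x)
    print('output:', y.shape)                        # torch.Size([1, 305, 64, 64])
    print('params:', sum(p.numel() for p in net.parameters()))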
219 | # summary(net.cuda(), (305, 64, 64)) 220 | 221 | 222 | #======================================================================================================================= 223 | # if __name__ == '__main__': 224 | # x = torch.randn(4, 305, 64, 64).cuda() 225 | # test_kernel_padding = [(3,1), (5,2), (7,3), (9,4),(11,5) ] 226 | # # mcplb = MCPLB(in_dim=150, out_dim=150, kernel_size=3, padding=1, stride=1, num_blocks=5).cuda() 227 | # mcplb = AACNet(in_dim=305, out_dim=305, kernel_size=3, padding=1, stride=1, num_blocks=5).cuda() 228 | # out = mcplb(x) 229 | # # summary(mcplb.cuda(), (150, 64, 64)) 230 | -------------------------------------------------------------------------------- /HDMba/models/HDMba.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | from torch.utils.data import DataLoader, Dataset 5 | import torch.utils.checkpoint as checkpoint 6 | import torch.nn.functional as F 7 | import einops 8 | from einops import rearrange 9 | # import tqdm 10 | # system-related libraries 11 | import math 12 | import os 13 | import urllib.request 14 | from zipfile import ZipFile 15 | # from transformers import * 16 | from timm.models.layers import DropPath, to_2tuple 17 | 18 | torch.autograd.set_detect_anomaly(True) 19 | 20 | # configuration flags and hyperparameters 21 | USE_MAMBA = 1 22 | DIFFERENT_H_STATES_RECURRENT_UPDATE_MECHANISM = 0 23 | # select the device to use 24 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 25 | 26 | # manually defined hyperparameters 27 | batch_size = 4 # batch size 28 | last_batch_size = 81 # size of the last batch 29 | current_batch_size = batch_size 30 | different_batch_size = False 31 | h_new = None 32 | temp_buffer = None 33 | 34 | 35 | # define the S6 module 36 | class S6(nn.Module): 37 | def __init__(self, seq_len, d_model, state_size, device): 38 | super(S6, self).__init__() 39 | # a set of linear projections 40 | self.fc1 = nn.Linear(d_model, d_model, device=device) 41 | self.fc2 = nn.Linear(d_model, state_size, device=device) 42 | self.fc3 = nn.Linear(d_model, state_size, device=device) 43 | # store the hyperparameters 44 | self.seq_len = seq_len 45 | self.d_model = d_model 46 | self.state_size = state_size 47 | self.A = nn.Parameter(F.normalize(torch.ones(d_model, state_size, device=device), p=2, dim=-1)) 48 | # parameter initialization 49 | nn.init.xavier_uniform_(self.A) 50 | 51 | self.B = torch.zeros(batch_size, self.seq_len, self.state_size, device=device) 52 | self.C = torch.zeros(batch_size, self.seq_len, self.state_size, device=device) 53 | 54 | self.delta = torch.zeros(batch_size, self.seq_len, self.d_model, device=device) 55 | self.dA = torch.zeros(batch_size, self.seq_len, self.d_model, self.state_size, device=device) 56 | self.dB = torch.zeros(batch_size, self.seq_len, self.d_model, self.state_size, device=device) 57 | 58 | # internal state h and output y 59 | self.h = torch.zeros(batch_size, self.seq_len, self.d_model, self.state_size, device=device) 60 | self.y = torch.zeros(batch_size, self.seq_len, self.d_model, device=device) 61 | 62 | # discretization function 63 | def discretization(self): 64 | # the discretization is described on page 28 of the Mamba paper 65 | self.dB = torch.einsum("bld,bln->bldn", self.delta, self.B) 66 | # dA = torch.matrix_exp(A * delta) # matrix_exp() only supports square matrix 67 | self.dA = torch.exp(torch.einsum("bld,dn->bldn", self.delta, self.A)) 68 | return self.dA, self.dB 69 | 70 | # forward pass 71 | def forward(self, x): 72 | # see Algorithm 2 in the Mamba paper 73 | self.B = self.fc2(x) 74 | self.C = self.fc3(x) 75 | self.delta = F.softplus(self.fc1(x)) 76 | # discretization 77 | self.discretization() 78 | 79 | if DIFFERENT_H_STATES_RECURRENT_UPDATE_MECHANISM: 80 | # if 'h_new' were not used here, a local in-place update error would be triggered 81 | global current_batch_size 82 | current_batch_size = x.shape[0] 83 | 84 | if self.h.shape[0] != current_batch_size: 85 | different_batch_size = True 86 | # slice h so its batch dimension matches the current batch 87 | h_new = torch.einsum('bldn,bldn->bldn', self.dA, self.h[:current_batch_size, ...]) + rearrange(x, "b l d -> b l d 1") * self.dB 88 | else: 89 | different_batch_size = False 90 | h_new = torch.einsum('bldn,bldn->bldn', self.dA, self.h) + rearrange(x, "b l d -> b l d 1") * self.dB 91 | 92 | # project the state back to the output dimensions of y 93 | self.y = torch.einsum('bln,bldn->bld', self.C, h_new) 94 | 95 | # update h based on h_new 96 | global temp_buffer 97 | temp_buffer = h_new.detach().clone() if not self.h.requires_grad else h_new.clone() 98 | 99 | return self.y 100 | else: 101 | # this branch will trigger an error 102 | # set the dimensions of h 103 | h = torch.zeros(x.size(0), self.seq_len, self.d_model, self.state_size, device=x.device) 104 | y = torch.zeros_like(x) 105 | 106 | h = torch.einsum('bldn,bldn->bldn', self.dA, h) + rearrange(x, "b l d -> b l d 1") * self.dB 107 | 108 | # set the dimensions of y 109 | y = torch.einsum('bln,bldn->bld', self.C, h) 110 | return y 111 | 112 | 113 | class RMSNorm(nn.Module): 114 | def __init__(self, d_model: int, eps: float = 1e-5, device: str = 'cuda'): 115 | super().__init__() 116 | self.eps = eps 117 | self.weight = nn.Parameter(torch.ones(d_model, device=device)) 118 | 119 | def forward(self, x): 120 | output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight 121 | return output 122 | 123 | 124 | # define the MambaBlock module 125 | class MambaBlock(nn.Module): 126 | def __init__(self, seq_len, d_model, state_size, device): 127 | super(MambaBlock, self).__init__() 128 | self.inp_proj = nn.Linear(d_model, 2 * d_model, device=device) 129 | self.out_proj = nn.Linear(2 * d_model, d_model, device=device) 130 | # residual connection 131 | self.D = nn.Linear(d_model, 2 * d_model, device=device) 132 | # flag the bias to skip weight decay 133 | self.out_proj.bias._no_weight_decay = True 134 | # initialize the bias 135 | nn.init.constant_(self.out_proj.bias, 1.0) 136 | # instantiate the S6 module 137 | self.S6 = S6(seq_len, 2 * d_model, state_size, device) 138 | # add a 1D convolution 139 | self.conv = nn.Conv1d(seq_len, seq_len, kernel_size=3, padding=1, groups=seq_len, device=device) 140 | # add a linear layer 141 | self.conv_linear = nn.Linear(2 * d_model, 2 * d_model, device=device) 142 | # normalization 143 | self.norm = RMSNorm(d_model, device=device) 144 | 145 | def forward(self, x): 146 | # see Figure 3 in the Mamba paper 147 | x = self.norm(x) 148 | x_proj = self.inp_proj(x) 149 | # 1D convolution 150 | x_conv = self.conv(x_proj) 151 | x_conv_act = F.silu(x_conv) # Swish activation 152 | # linear projection 153 | x_conv_out = self.conv_linear(x_conv_act) 154 | # S6 module 155 | x_ssm = self.S6(x_conv_out) 156 | x_act = F.silu(x_ssm) # Swish activation 157 | # residual branch 158 | x_residual = F.silu(self.D(x)) 159 | x_combined = x_act * x_residual 160 | x_out = self.out_proj(x_combined) 161 | return x_out 162 | 163 | 164 | # inputs: sequence length, model dimension, state size 165 | class Mamba(nn.Module): 166 | def __init__(self, seq_len, d_model, state_size, device): 167 | super(Mamba, self).__init__() 168 | self.seq_len = seq_len 169 | self.d_model = d_model 170 | self.state_size = state_size 171 | self.mamba_block1 = MambaBlock(self.seq_len, self.d_model, self.state_size, device) 172 | self.mamba_block2 = MambaBlock(self.seq_len, self.d_model, self.state_size, device) 173 | self.mamba_block3 = MambaBlock(self.seq_len, self.d_model, self.state_size, device) 174 | 175 | def forward(self, x): 176 | x = self.mamba_block1(x) 177 | x = self.mamba_block2(x) 178 | x = self.mamba_block3(x) 179 | return x 180
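# Shape bookkeeping for the S6 update above (as used by SSMamba below: seq_len L = window_size**2 = 49,
# d_model D = 2*dim inside MambaBlock, state size N = dim):
#   delta: (B, L, D), A: (D, N), B: (B, L, N)
#   dA = exp(einsum("bld,dn->bldn", delta, A))  -> (B, L, D, N)
#   dB = einsum("bld,bln->bldn", delta, B)      -> (B, L, D, N)
#   h  = dA * h + x * dB                        -> (B, L, D, N)
#   y  = einsum("bln,bldn->bld", C, h)          -> (B, L, D)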
| 181 | 182 | def window_partition(x, window_size): 183 | """ 184 | Args: 185 | x: (B, H, W, C) 186 | window_size (int): window size 187 | """ 188 | B, H, W, C = x.shape 189 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 190 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 191 | return windows 192 | 193 | 194 | def window_reverse(windows, window_size, H, W): 195 | """ 196 | Args: 197 | windows: (num_windows*B, window_size, window_size, C) 198 | window_size (int): Window size 199 | H (int): Height of image 200 | W (int): Width of image 201 | """ 202 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 203 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 204 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 205 | return x 206 | 207 | 208 | class Mlp(nn.Module): 209 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 210 | super().__init__() 211 | out_features = out_features or in_features 212 | hidden_features = hidden_features or in_features 213 | self.fc1 = nn.Linear(in_features, hidden_features) 214 | self.act = act_layer() 215 | self.fc2 = nn.Linear(hidden_features, out_features) 216 | self.drop = nn.Dropout(drop) 217 | 218 | def forward(self, x): 219 | x = self.fc1(x) 220 | x = self.act(x) 221 | x = self.drop(x) 222 | x = self.fc2(x) 223 | x = self.drop(x) 224 | return x 225 | 226 | 227 | class SSMamba(nn.Module): 228 | def __init__(self, dim, window_size, shift_size=0, 229 | mlp_ratio=4., drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 230 | super().__init__() 231 | self.dim = dim 232 | self.window_size = window_size 233 | self.shift_size = shift_size 234 | self.mlp_ratio = mlp_ratio 235 | 236 | self.norm1 = norm_layer(dim) 237 | self.mamba = Mamba(seq_len=self.window_size ** 2, d_model=dim, state_size=dim, device=device) 238 | 239 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 240 | self.norm2 = norm_layer(dim) 241 | mlp_hidden_dim = int(dim * mlp_ratio) 242 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 243 | 244 | def forward(self, x): 245 | B, C, H, W = x.shape 246 | shortcut = x 247 | x = x.permute(0, 2, 3, 1).contiguous().view(B, H * W, C) # Change shape to (B, L, C) 248 | x = self.norm1(x) 249 | 250 | x = x.view(B, H, W, C) 251 | pad_r = (self.window_size - W % self.window_size) % self.window_size 252 | pad_b = (self.window_size - H % self.window_size) % self.window_size 253 | x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b)) 254 | _, Hp, Wp, _ = x.shape 255 | 256 | if self.shift_size > 0: 257 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 258 | else: 259 | shifted_x = x 260 | 261 | x_windows = window_partition(shifted_x, self.window_size) 262 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) 263 | 264 | # Apply Mamba 265 | attn_windows = self.mamba(x_windows) 266 | 267 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) 268 | shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) 269 | 270 | if self.shift_size > 0: 271 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 272 | else: 273 | x = shifted_x 274 | 275 | if pad_r > 0 or pad_b > 0: 276 | x = x[:, :H, :W, :].contiguous() 277 | 278 | x = x.view(B, H * W, C) 279 | x = self.norm2(x) 280 | x = x.view(B, H, W, C).permute(0, 3, 1, 2).contiguous() # Change shape back to (B, C, H, W) 281 | 282 | x = shortcut + self.drop_path(x) 283 | x = x + self.drop_path(self.mlp(x.permute(0, 2, 3, 1).contiguous().view(B, H * W, C)).view(B, H, W, C).permute(0, 3, 1, 2).contiguous()) 284 | 285 | return x 286 | 287 | 288 | class SSMaBlock(nn.Module): 289 | def __init__(self, 290 | dim=32, 291 | window_size=7, 292 | depth=4, 293 | mlp_ratio=2, 294 | drop_path=0.0): 295 | super(SSMaBlock, self).__init__() 296 | self.ssmablock = nn.Sequential(*[ 297 | SSMamba(dim=dim, window_size=window_size, 298 | shift_size=0 if (i % 2 == 0) else window_size // 2, 299 | mlp_ratio=mlp_ratio, 300 | drop_path=drop_path) 301 | for i in range(depth) 302 | ]) 303 | self.conv = nn.Conv2d(dim, dim, 3, 1, 1) 304 | 305 | def forward(self, x): 306 | out = self.ssmablock(x) 307 | out = self.conv(out) + x 308 | return out 309 | 310 | 311 | class HDMba(nn.Module): 312 | def __init__(self, 313 | inp_channels=305, 314 | dim=32, 315 | window_size=7, 316 | depths=[4, 4, 4, 4], 317 | mlp_ratio=2, 318 | bias=False, 319 | drop_path=0.0 320 | ): 321 | super(HDMba, self).__init__() 322 | 323 | self.conv_first = nn.Conv2d(inp_channels, dim, 3, 1, 1) # shallow feature extraction 324 | self.num_layers = depths 325 | self.layers = nn.ModuleList() 326 | 327 | for i_layer in range(len(self.num_layers)): 328 | layer = SSMaBlock(dim=dim, 329 | window_size=window_size, 330 | depth=depths[i_layer], 331 | mlp_ratio=mlp_ratio, 332 | drop_path=drop_path) 333 | self.layers.append(layer) 334 | 335 | self.output = nn.Conv2d(int(dim), dim, kernel_size=3, stride=1, padding=1, bias=bias) 336 | self.conv_delasta = nn.Conv2d(dim, inp_channels, 3, 1, 1) # reconstruction from features 337 | 338 | def forward(self, inp_img): 339 | f1 = self.conv_first(inp_img) 340 | x = f1 341 | for layer in self.layers: 342 | x = layer(x) 343 | 344 | x = self.output(x + f1) 345 | x = self.conv_delasta(x) + inp_img 346 | return x 347 | 348 | from torchsummary import summary 349 | 350 | if __name__ == '__main__': 
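    # A quick invariance check (illustrative sketch): window_partition and window_reverse
    # defined above are exact inverses when H and W are multiples of the window size.
    t = torch.randn(2, 14, 14, 32)                     # (B, H, W, C), window_size = 7
    w = window_partition(t, 7)                         # (num_windows*B, 7, 7, 32)
    assert torch.equal(window_reverse(w, 7, 14, 14), t)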
351 | net = HDMba().to(device) 352 | # x_train = torch.randn(2, 305, 64, 64).to(device) 353 | # out_train = net(x_train) 354 | # print("Training output shape:", out_train.shape) # Expected output shape should be (2, 32, 64, 64) 355 | # summary(net.cuda(), (305, 64, 64)) 356 | 357 | 358 | 359 | 360 | 361 | 362 | 363 | ############################################################################################ 364 | # class SSMamba(nn.Module): 365 | # def __init__(self, dim, input_resolution, num_heads, window_size, shift_size=0, 366 | # mlp_ratio=4., drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 367 | # super().__init__() 368 | # self.dim = dim 369 | # self.input_resolution = input_resolution 370 | # self.num_heads = num_heads 371 | # self.window_size = window_size 372 | # self.shift_size = shift_size 373 | # self.mlp_ratio = mlp_ratio 374 | # 375 | # if min(self.input_resolution) <= self.window_size: 376 | # # if window size is larger than input resolution, we don't partition windows 377 | # self.shift_size = 0 378 | # self.window_size = min(self.input_resolution) 379 | # 380 | # self.norm1 = norm_layer(dim) 381 | # self.mamba = Mamba(seq_len=self.window_size ** 2, d_model=dim, state_size=dim, device=device) 382 | # 383 | # self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 384 | # self.norm2 = norm_layer(dim) 385 | # mlp_hidden_dim = int(dim * mlp_ratio) 386 | # self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 387 | # 388 | # def forward(self, x): 389 | # H, W = self.input_resolution 390 | # B, L, C = x.shape 391 | # assert L == H * W, "input feature has wrong size" 392 | # 393 | # shortcut = x 394 | # x = self.norm1(x) 395 | # x = x.view(B, H, W, C) 396 | # # Pad feature maps to multiples of window size 397 | # pad_r = (self.window_size - W % self.window_size) % self.window_size 398 | # pad_b = (self.window_size - H % self.window_size) % self.window_size 399 | # x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b)) 400 | # _, Hp, Wp, _ = x.shape 401 | # 402 | # if self.shift_size > 0: 403 | # shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 404 | # else: 405 | # shifted_x = x 406 | # 407 | # x_windows = window_partition(shifted_x, self.window_size) 408 | # x_windows = x_windows.view(-1, self.window_size * self.window_size, C) 409 | # 410 | # # Apply Mamba 411 | # attn_windows = self.mamba(x_windows) 412 | # 413 | # attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) 414 | # shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) 415 | # 416 | # if self.shift_size > 0: 417 | # x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 418 | # else: 419 | # x = shifted_x 420 | # 421 | # if pad_r > 0 or pad_b > 0: 422 | # x = x[:, :H, :W, :].contiguous() 423 | # 424 | # x = x.view(B, H * W, C) 425 | # 426 | # # FFN 427 | # x = shortcut + self.drop_path(x) 428 | # x = x + self.drop_path(self.mlp(self.norm2(x))) 429 | # 430 | # return x 431 | # 432 | # 433 | # # if __name__ == '__main__': 434 | # # net = SSMamba(dim=30, input_resolution=[32, 32], num_heads=3).to(device) 435 | # # x = torch.randn(4, 32*32, 30).to(device) 436 | # # out = net(x) 437 | # # print(out.shape) 438 | # 439 | # 440 | # class SSMaBlock(nn.Module): 441 | # def __init__(self, 442 | # dim=32, 443 | # num_head=3, 444 | # input_resolution=[32,32], 445 | # window_size=7, 446 | # depth=3, 447 | # mlp_ratio=2, 448 | # drop_path=0.0): 449 | # super(SSMaBlock, 
self).__init__() 450 | # self.ssmablock = nn.Sequential(*[SSMamba(dim=dim, input_resolution=input_resolution, num_heads=num_head, window_size=window_size, 451 | # shift_size=0 if (i % 2 == 0) else window_size // 2, 452 | # mlp_ratio=mlp_ratio, 453 | # drop_path = drop_path, 454 | # # drop_path=drop_path[i], 455 | # ) 456 | # for i in range(depth)]) 457 | # self.conv = nn.Conv2d(dim, dim, 3, 1, 1) 458 | # 459 | # def forward(self, x): 460 | # out = self.ssmablock(x) 461 | # # out = self.conv(out) + x 462 | # return out 463 | # 464 | # 465 | # if __name__ == '__main__': 466 | # net = SSMaBlock(dim=30, input_resolution=[32, 32]).to(device) 467 | # x = torch.randn(4, 32*32, 30).to(device) 468 | # out = net(x) 469 | # print(out.shape) -------------------------------------------------------------------------------- /HDMba/models/Dehazeformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | import numpy as np 6 | from torch.nn.init import _calculate_fan_in_and_fan_out 7 | from timm.models.layers import to_2tuple, trunc_normal_ 8 | from torchsummary import summary 9 | 10 | 11 | class RLN(nn.Module): 12 | r"""Revised LayerNorm""" 13 | def __init__(self, dim, eps=1e-5, detach_grad=False): 14 | super(RLN, self).__init__() 15 | self.eps = eps 16 | self.detach_grad = detach_grad 17 | 18 | self.weight = nn.Parameter(torch.ones((1, dim, 1, 1))) 19 | self.bias = nn.Parameter(torch.zeros((1, dim, 1, 1))) 20 | 21 | self.meta1 = nn.Conv2d(1, dim, 1) 22 | self.meta2 = nn.Conv2d(1, dim, 1) 23 | 24 | trunc_normal_(self.meta1.weight, std=.02) 25 | nn.init.constant_(self.meta1.bias, 1) 26 | 27 | trunc_normal_(self.meta2.weight, std=.02) 28 | nn.init.constant_(self.meta2.bias, 0) 29 | 30 | def forward(self, input): 31 | mean = torch.mean(input, dim=(1, 2, 3), keepdim=True) 32 | std = torch.sqrt((input - mean).pow(2).mean(dim=(1, 2, 3), keepdim=True) + self.eps) 33 | 34 | normalized_input = (input - mean) / std 35 | 36 | if self.detach_grad: 37 | rescale, rebias = self.meta1(std.detach()), self.meta2(mean.detach()) 38 | else: 39 | rescale, rebias = self.meta1(std), self.meta2(mean) 40 | 41 | out = normalized_input * self.weight + self.bias 42 | return out, rescale, rebias 43 | 44 | 45 | class Mlp(nn.Module): 46 | def __init__(self, network_depth, in_features, hidden_features=None, out_features=None): 47 | super().__init__() 48 | out_features = out_features or in_features 49 | hidden_features = hidden_features or in_features 50 | 51 | self.network_depth = network_depth 52 | 53 | self.mlp = nn.Sequential( 54 | nn.Conv2d(in_features, hidden_features, 1), 55 | nn.ReLU(True), 56 | nn.Conv2d(hidden_features, out_features, 1) 57 | ) 58 | 59 | self.apply(self._init_weights) 60 | 61 | def _init_weights(self, m): 62 | if isinstance(m, nn.Conv2d): 63 | gain = (8 * self.network_depth) ** (-1/4) 64 | fan_in, fan_out = _calculate_fan_in_and_fan_out(m.weight) 65 | std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) 66 | trunc_normal_(m.weight, std=std) 67 | if m.bias is not None: 68 | nn.init.constant_(m.bias, 0) 69 | 70 | def forward(self, x): 71 | return self.mlp(x) 72 | 73 | 74 | def window_partition(x, window_size): 75 | B, H, W, C = x.shape 76 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 77 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size**2, C) 78 | return windows 79 | 80 | 81 | def window_reverse(windows, window_size, H, W): 82 
| B = int(windows.shape[0] / (H * W / window_size / window_size)) 83 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 84 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 85 | return x 86 | 87 | 88 | def get_relative_positions(window_size): 89 | coords_h = torch.arange(window_size) 90 | coords_w = torch.arange(window_size) 91 | 92 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 93 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 94 | relative_positions = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 95 | 96 | relative_positions = relative_positions.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 97 | relative_positions_log = torch.sign(relative_positions) * torch.log(1. + relative_positions.abs()) 98 | 99 | return relative_positions_log 100 | 101 | 102 | class WindowAttention(nn.Module): 103 | def __init__(self, dim, window_size, num_heads): 104 | 105 | super().__init__() 106 | self.dim = dim 107 | self.window_size = window_size # Wh, Ww 108 | self.num_heads = num_heads 109 | head_dim = dim // num_heads 110 | self.scale = head_dim ** -0.5 111 | 112 | relative_positions = get_relative_positions(self.window_size) 113 | self.register_buffer("relative_positions", relative_positions) 114 | self.meta = nn.Sequential( 115 | nn.Linear(2, 256, bias=True), 116 | nn.ReLU(True), 117 | nn.Linear(256, num_heads, bias=True) 118 | ) 119 | 120 | self.softmax = nn.Softmax(dim=-1) 121 | 122 | def forward(self, qkv): 123 | B_, N, _ = qkv.shape 124 | 125 | qkv = qkv.reshape(B_, N, 3, self.num_heads, self.dim // self.num_heads).permute(2, 0, 3, 1, 4) 126 | 127 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 128 | 129 | q = q * self.scale 130 | attn = (q @ k.transpose(-2, -1)) 131 | 132 | relative_position_bias = self.meta(self.relative_positions) 133 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 134 | attn = attn + relative_position_bias.unsqueeze(0) 135 | 136 | attn = self.softmax(attn) 137 | 138 | x = (attn @ v).transpose(1, 2).reshape(B_, N, self.dim) 139 | return x 140 | 141 | 142 | class Attention(nn.Module): 143 | def __init__(self, network_depth, dim, num_heads, window_size, shift_size, use_attn=False, conv_type=None): 144 | super().__init__() 145 | self.dim = dim 146 | self.head_dim = int(dim // num_heads) 147 | self.num_heads = num_heads 148 | 149 | self.window_size = window_size 150 | self.shift_size = shift_size 151 | 152 | self.network_depth = network_depth 153 | self.use_attn = use_attn 154 | self.conv_type = conv_type 155 | 156 | if self.conv_type == 'Conv': 157 | self.conv = nn.Sequential( 158 | nn.Conv2d(dim, dim, kernel_size=3, padding=1, padding_mode='reflect'), 159 | nn.ReLU(True), 160 | nn.Conv2d(dim, dim, kernel_size=3, padding=1, padding_mode='reflect') 161 | ) 162 | 163 | if self.conv_type == 'DWConv': 164 | self.conv = nn.Conv2d(dim, dim, kernel_size=5, padding=2, groups=dim, padding_mode='reflect') 165 | 166 | if self.conv_type == 'DWConv' or self.use_attn: 167 | self.V = nn.Conv2d(dim, dim, 1) 168 | self.proj = nn.Conv2d(dim, dim, 1) 169 | 170 | if self.use_attn: 171 | self.QK = nn.Conv2d(dim, dim * 2, 1) 172 | self.attn = WindowAttention(dim, window_size, num_heads) 173 | 174 | self.apply(self._init_weights) 175 | 176 | def _init_weights(self, m): 177 | if isinstance(m, nn.Conv2d): 178 | w_shape = m.weight.shape 179 | 180 | if w_shape[0] == self.dim * 2: # QK 181 | fan_in, fan_out 
= _calculate_fan_in_and_fan_out(m.weight) 182 | std = math.sqrt(2.0 / float(fan_in + fan_out)) 183 | trunc_normal_(m.weight, std=std) 184 | else: 185 | gain = (8 * self.network_depth) ** (-1/4) 186 | fan_in, fan_out = _calculate_fan_in_and_fan_out(m.weight) 187 | std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) 188 | trunc_normal_(m.weight, std=std) 189 | 190 | if m.bias is not None: 191 | nn.init.constant_(m.bias, 0) 192 | 193 | def check_size(self, x, shift=False): 194 | _, _, h, w = x.size() 195 | mod_pad_h = (self.window_size - h % self.window_size) % self.window_size 196 | mod_pad_w = (self.window_size - w % self.window_size) % self.window_size 197 | 198 | if shift: 199 | x = F.pad(x, (self.shift_size, (self.window_size-self.shift_size+mod_pad_w) % self.window_size, 200 | self.shift_size, (self.window_size-self.shift_size+mod_pad_h) % self.window_size), mode='reflect') 201 | else: 202 | x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') 203 | return x 204 | 205 | def forward(self, X): 206 | B, C, H, W = X.shape 207 | 208 | if self.conv_type == 'DWConv' or self.use_attn: 209 | V = self.V(X) 210 | 211 | if self.use_attn: 212 | QK = self.QK(X) 213 | QKV = torch.cat([QK, V], dim=1) 214 | 215 | # shift 216 | shifted_QKV = self.check_size(QKV, self.shift_size > 0) 217 | Ht, Wt = shifted_QKV.shape[2:] 218 | 219 | # partition windows 220 | shifted_QKV = shifted_QKV.permute(0, 2, 3, 1) 221 | qkv = window_partition(shifted_QKV, self.window_size) # nW*B, window_size**2, C 222 | 223 | attn_windows = self.attn(qkv) 224 | 225 | # merge windows 226 | shifted_out = window_reverse(attn_windows, self.window_size, Ht, Wt) # B H' W' C 227 | 228 | # reverse cyclic shift 229 | out = shifted_out[:, self.shift_size:(self.shift_size+H), self.shift_size:(self.shift_size+W), :] 230 | attn_out = out.permute(0, 3, 1, 2) 231 | 232 | if self.conv_type in ['Conv', 'DWConv']: 233 | conv_out = self.conv(V) 234 | out = self.proj(conv_out + attn_out) 235 | else: 236 | out = self.proj(attn_out) 237 | 238 | else: 239 | if self.conv_type == 'Conv': 240 | out = self.conv(X) # no attention and use conv, no projection 241 | elif self.conv_type == 'DWConv': 242 | out = self.proj(self.conv(V)) 243 | 244 | return out 245 | 246 | 247 | class TransformerBlock(nn.Module): 248 | def __init__(self, network_depth, dim, num_heads, mlp_ratio=4., 249 | norm_layer=nn.LayerNorm, mlp_norm=False, 250 | window_size=8, shift_size=0, use_attn=True, conv_type=None): 251 | super().__init__() 252 | self.use_attn = use_attn 253 | self.mlp_norm = mlp_norm 254 | 255 | self.norm1 = norm_layer(dim) if use_attn else nn.Identity() 256 | self.attn = Attention(network_depth, dim, num_heads=num_heads, window_size=window_size, 257 | shift_size=shift_size, use_attn=use_attn, conv_type=conv_type) 258 | 259 | self.norm2 = norm_layer(dim) if use_attn and mlp_norm else nn.Identity() 260 | self.mlp = Mlp(network_depth, dim, hidden_features=int(dim * mlp_ratio)) 261 | 262 | def forward(self, x): 263 | identity = x 264 | if self.use_attn: x, rescale, rebias = self.norm1(x) 265 | x = self.attn(x) 266 | if self.use_attn: x = x * rescale + rebias 267 | x = identity + x 268 | 269 | identity = x 270 | if self.use_attn and self.mlp_norm: x, rescale, rebias = self.norm2(x) 271 | x = self.mlp(x) 272 | if self.use_attn and self.mlp_norm: x = x * rescale + rebias 273 | x = identity + x 274 | return x 275 | 276 | 277 | class BasicLayer(nn.Module): 278 | def __init__(self, network_depth, dim, depth, num_heads, mlp_ratio=4., 279 | norm_layer=nn.LayerNorm, window_size=8, 
280 | attn_ratio=0., attn_loc='last', conv_type=None): 281 | 282 | super().__init__() 283 | self.dim = dim 284 | self.depth = depth 285 | 286 | attn_depth = attn_ratio * depth 287 | 288 | if attn_loc == 'last': 289 | use_attns = [i >= depth-attn_depth for i in range(depth)] 290 | elif attn_loc == 'first': 291 | use_attns = [i < attn_depth for i in range(depth)] 292 | elif attn_loc == 'middle': 293 | use_attns = [i >= (depth-attn_depth)//2 and i < (depth+attn_depth)//2 for i in range(depth)] 294 | 295 | # build blocks 296 | self.blocks = nn.ModuleList([ 297 | TransformerBlock(network_depth=network_depth, 298 | dim=dim, 299 | num_heads=num_heads, 300 | mlp_ratio=mlp_ratio, 301 | norm_layer=norm_layer, 302 | window_size=window_size, 303 | shift_size=0 if (i % 2 == 0) else window_size // 2, 304 | use_attn=use_attns[i], conv_type=conv_type) 305 | for i in range(depth)]) 306 | 307 | def forward(self, x): 308 | for blk in self.blocks: 309 | x = blk(x) 310 | return x 311 | 312 | 313 | class PatchEmbed(nn.Module): 314 | def __init__(self, patch_size=4, in_chans=3, embed_dim=96, kernel_size=None): 315 | super().__init__() 316 | self.in_chans = in_chans 317 | self.embed_dim = embed_dim 318 | 319 | if kernel_size is None: 320 | kernel_size = patch_size 321 | 322 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=patch_size, 323 | padding=(kernel_size-patch_size+1)//2, padding_mode='reflect') 324 | 325 | def forward(self, x): 326 | x = self.proj(x) 327 | return x 328 | 329 | 330 | class PatchUnEmbed(nn.Module): 331 | def __init__(self, patch_size=4, out_chans=3, embed_dim=96, kernel_size=None): 332 | super().__init__() 333 | self.out_chans = out_chans 334 | self.embed_dim = embed_dim 335 | 336 | if kernel_size is None: 337 | kernel_size = 1 338 | 339 | self.proj = nn.Sequential( 340 | nn.Conv2d(embed_dim, out_chans*patch_size**2, kernel_size=kernel_size, 341 | padding=kernel_size//2, padding_mode='reflect'), 342 | nn.PixelShuffle(patch_size) 343 | ) 344 | 345 | def forward(self, x): 346 | x = self.proj(x) 347 | return x 348 | 349 | 350 | class SKFusion(nn.Module): 351 | def __init__(self, dim, height=2, reduction=8): 352 | super(SKFusion, self).__init__() 353 | 354 | self.height = height 355 | d = max(int(dim/reduction), 4) 356 | 357 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 358 | self.mlp = nn.Sequential( 359 | nn.Conv2d(dim, d, 1, bias=False), 360 | nn.ReLU(), 361 | nn.Conv2d(d, dim*height, 1, bias=False) 362 | ) 363 | 364 | self.softmax = nn.Softmax(dim=1) 365 | 366 | def forward(self, in_feats): 367 | B, C, H, W = in_feats[0].shape 368 | 369 | in_feats = torch.cat(in_feats, dim=1) 370 | in_feats = in_feats.view(B, self.height, C, H, W) 371 | 372 | feats_sum = torch.sum(in_feats, dim=1) 373 | attn = self.mlp(self.avg_pool(feats_sum)) 374 | attn = self.softmax(attn.view(B, self.height, C, 1, 1)) 375 | 376 | out = torch.sum(in_feats*attn, dim=1) 377 | return out 378 | 379 | 380 | class DehazeFormer(nn.Module): 381 | def __init__(self, in_chans=305, out_chans=306, window_size=8, 382 | # def __init__(self, in_chans=7, out_chans=8, window_size=8, 383 | # def __init__(self, in_chans=4, out_chans=5, window_size=8, 384 | embed_dims=[24, 48, 96, 48, 24], 385 | mlp_ratios=[2., 4., 4., 2., 2.], 386 | depths=[8, 8, 8, 4, 4], 387 | num_heads=[2, 4, 6, 1, 1], 388 | attn_ratio=[1/4, 1/2, 3/4, 0, 0], 389 | conv_type=['DWConv', 'DWConv', 'DWConv', 'DWConv', 'DWConv'], 390 | norm_layer=[RLN, RLN, RLN, RLN, RLN]): 391 | super(DehazeFormer, self).__init__() 392 | 393 | # setting 394 | 
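        # The defaults above describe a five-stage U-shaped encoder-decoder:
        # embed_dims [24, 48, 96, 48, 24] with 2x patch merging/splitting between stages,
        # depths [8, 8, 8, 4, 4] blocks per stage, and attn_ratio deciding how many of the
        # last blocks in each stage use window attention (the rest use only the DWConv path);
        # SKFusion recombines the two skip connections on the decoder side.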
self.patch_size = 4 395 | self.window_size = window_size 396 | self.mlp_ratios = mlp_ratios 397 | 398 | # split image into non-overlapping patches 399 | self.patch_embed = PatchEmbed( 400 | patch_size=1, in_chans=in_chans, embed_dim=embed_dims[0], kernel_size=3) 401 | 402 | # backbone 403 | self.layer1 = BasicLayer(network_depth=sum(depths), dim=embed_dims[0], depth=depths[0], 404 | num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], 405 | norm_layer=norm_layer[0], window_size=window_size, 406 | attn_ratio=attn_ratio[0], attn_loc='last', conv_type=conv_type[0]) 407 | 408 | self.patch_merge1 = PatchEmbed( 409 | patch_size=2, in_chans=embed_dims[0], embed_dim=embed_dims[1]) 410 | 411 | self.skip1 = nn.Conv2d(embed_dims[0], embed_dims[0], 1) 412 | 413 | self.layer2 = BasicLayer(network_depth=sum(depths), dim=embed_dims[1], depth=depths[1], 414 | num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], 415 | norm_layer=norm_layer[1], window_size=window_size, 416 | attn_ratio=attn_ratio[1], attn_loc='last', conv_type=conv_type[1]) 417 | 418 | self.patch_merge2 = PatchEmbed( 419 | patch_size=2, in_chans=embed_dims[1], embed_dim=embed_dims[2]) 420 | 421 | self.skip2 = nn.Conv2d(embed_dims[1], embed_dims[1], 1) 422 | 423 | self.layer3 = BasicLayer(network_depth=sum(depths), dim=embed_dims[2], depth=depths[2], 424 | num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], 425 | norm_layer=norm_layer[2], window_size=window_size, 426 | attn_ratio=attn_ratio[2], attn_loc='last', conv_type=conv_type[2]) 427 | 428 | self.patch_split1 = PatchUnEmbed( 429 | patch_size=2, out_chans=embed_dims[3], embed_dim=embed_dims[2]) 430 | 431 | assert embed_dims[1] == embed_dims[3] 432 | self.fusion1 = SKFusion(embed_dims[3]) 433 | 434 | self.layer4 = BasicLayer(network_depth=sum(depths), dim=embed_dims[3], depth=depths[3], 435 | num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], 436 | norm_layer=norm_layer[3], window_size=window_size, 437 | attn_ratio=attn_ratio[3], attn_loc='last', conv_type=conv_type[3]) 438 | 439 | self.patch_split2 = PatchUnEmbed( 440 | patch_size=2, out_chans=embed_dims[4], embed_dim=embed_dims[3]) 441 | 442 | assert embed_dims[0] == embed_dims[4] 443 | self.fusion2 = SKFusion(embed_dims[4]) 444 | 445 | self.layer5 = BasicLayer(network_depth=sum(depths), dim=embed_dims[4], depth=depths[4], 446 | num_heads=num_heads[4], mlp_ratio=mlp_ratios[4], 447 | norm_layer=norm_layer[4], window_size=window_size, 448 | attn_ratio=attn_ratio[4], attn_loc='last', conv_type=conv_type[4]) 449 | 450 | # merge non-overlapping patches into image 451 | self.patch_unembed = PatchUnEmbed( 452 | patch_size=1, out_chans=out_chans, embed_dim=embed_dims[4], kernel_size=3) 453 | 454 | def check_image_size(self, x): 455 | # NOTE: for I2I test 456 | _, _, h, w = x.size() 457 | mod_pad_h = (self.patch_size - h % self.patch_size) % self.patch_size 458 | mod_pad_w = (self.patch_size - w % self.patch_size) % self.patch_size 459 | x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') 460 | return x 461 | 462 | def forward_features(self, x): 463 | x = self.patch_embed(x) 464 | x = self.layer1(x) 465 | skip1 = x 466 | 467 | x = self.patch_merge1(x) 468 | x = self.layer2(x) 469 | skip2 = x 470 | 471 | x = self.patch_merge2(x) 472 | x = self.layer3(x) 473 | x = self.patch_split1(x) 474 | 475 | x = self.fusion1([x, self.skip2(skip2)]) + x 476 | x = self.layer4(x) 477 | x = self.patch_split2(x) 478 | 479 | x = self.fusion2([x, self.skip1(skip1)]) + x 480 | x = self.layer5(x) 481 | x = self.patch_unembed(x) 482 | return x 483 | 484 | def forward(self, x): 
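        # forward_features() returns out_chans = in_chans + 1 maps: channel 0 is a per-pixel
        # gain K and the remaining channels a bias B, so the restored image is K * x - B + x,
        # computed on the padded input and then cropped back to the original H x W.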
485 | H, W = x.shape[2:] 486 | x = self.check_image_size(x) 487 | 488 | feat = self.forward_features(x) 489 | K, B = torch.split(feat, (1, 305), dim=1) 490 | # K, B = torch.split(feat, (1, 7), dim=1) 491 | # K, B = torch.split(feat, (1, 4), dim=1) 492 | 493 | x = K * x - B + x 494 | x = x[:, :, :H, :W] 495 | return x 496 | 497 | 498 | if __name__ == '__main__': 499 | net = DehazeFormer() 500 | device = torch.device('cpu') 501 | net.to(device) 502 | # summary(net.cuda(), (305, 64, 64)) 503 | 504 | 505 | # x = torch.randn(4, 305, 64, 64).cuda() 506 | # mcplb = DehazeFormer().cuda() 507 | # out = mcplb(x) 508 | # # summary(mcplb.cuda(), (305, 64, 64)) 509 | 510 | 511 | 512 | 513 | # def dehazeformer_t(): 514 | # return DehazeFormer( 515 | # embed_dims=[24, 48, 96, 48, 24], 516 | # mlp_ratios=[2., 4., 4., 2., 2.], 517 | # depths=[4, 4, 4, 2, 2], 518 | # num_heads=[2, 4, 6, 1, 1], 519 | # attn_ratio=[0, 1/2, 1, 0, 0], 520 | # conv_type=['DWConv', 'DWConv', 'DWConv', 'DWConv', 'DWConv']) 521 | # 522 | # 523 | # def dehazeformer_s(): 使用的这个 524 | # return DehazeFormer( 525 | # embed_dims=[24, 48, 96, 48, 24], 526 | # mlp_ratios=[2., 4., 4., 2., 2.], 527 | # depths=[8, 8, 8, 4, 4], 528 | # num_heads=[2, 4, 6, 1, 1], 529 | # attn_ratio=[1/4, 1/2, 3/4, 0, 0], 530 | # conv_type=['DWConv', 'DWConv', 'DWConv', 'DWConv', 'DWConv']) 531 | # 532 | # 533 | # def dehazeformer_b(): 534 | # return DehazeFormer( 535 | # embed_dims=[24, 48, 96, 48, 24], 536 | # mlp_ratios=[2., 4., 4., 2., 2.], 537 | # depths=[16, 16, 16, 8, 8], 538 | # num_heads=[2, 4, 6, 1, 1], 539 | # attn_ratio=[1/4, 1/2, 3/4, 0, 0], 540 | # conv_type=['DWConv', 'DWConv', 'DWConv', 'DWConv', 'DWConv']) 541 | # 542 | # 543 | # def dehazeformer_d(): 544 | # return DehazeFormer( 545 | # embed_dims=[24, 48, 96, 48, 24], 546 | # mlp_ratios=[2., 4., 4., 2., 2.], 547 | # depths=[32, 32, 32, 16, 16], 548 | # num_heads=[2, 4, 6, 1, 1], 549 | # attn_ratio=[1/4, 1/2, 3/4, 0, 0], 550 | # conv_type=['DWConv', 'DWConv', 'DWConv', 'DWConv', 'DWConv']) 551 | # 552 | # 553 | # def dehazeformer_w(): 554 | # return DehazeFormer( 555 | # embed_dims=[48, 96, 192, 96, 48], 556 | # mlp_ratios=[2., 4., 4., 2., 2.], 557 | # depths=[16, 16, 16, 8, 8], 558 | # num_heads=[2, 4, 6, 1, 1], 559 | # attn_ratio=[1/4, 1/2, 3/4, 0, 0], 560 | # conv_type=['DWConv', 'DWConv', 'DWConv', 'DWConv', 'DWConv']) 561 | # 562 | # 563 | # def dehazeformer_m(): 564 | # return DehazeFormer( 565 | # embed_dims=[24, 48, 96, 48, 24], 566 | # mlp_ratios=[2., 4., 4., 2., 2.], 567 | # depths=[12, 12, 12, 6, 6], 568 | # num_heads=[2, 4, 6, 1, 1], 569 | # attn_ratio=[1/4, 1/2, 3/4, 0, 0], 570 | # conv_type=['Conv', 'Conv', 'Conv', 'Conv', 'Conv']) 571 | # 572 | # 573 | # def dehazeformer_l(): 574 | # return DehazeFormer( 575 | # embed_dims=[48, 96, 192, 96, 48], 576 | # mlp_ratios=[2., 4., 4., 2., 2.], 577 | # depths=[16, 16, 16, 12, 12], 578 | # num_heads=[2, 4, 6, 1, 1], 579 | # attn_ratio=[1/4, 1/2, 3/4, 0, 0], 580 | # conv_type=['Conv', 'Conv', 'Conv', 'Conv', 'Conv']) -------------------------------------------------------------------------------- /HDMba/models/AIDTransformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.utils.checkpoint as checkpoint 4 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 5 | import torch.nn.functional as F 6 | from einops import rearrange, repeat 7 | from einops.layers.torch import Rearrange 8 | import math 9 | import numpy as np 10 | import time 
11 | from torch import einsum 12 | import einops 13 | from torchsummary import summary 14 | from torchvision.ops import DeformConv2d 15 | 16 | 17 | class ConvBlock(nn.Module): 18 | def __init__(self, in_channel, out_channel, strides=1): 19 | super(ConvBlock, self).__init__() 20 | self.strides = strides 21 | self.in_channel=in_channel 22 | self.out_channel=out_channel 23 | self.block = nn.Sequential( 24 | nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=strides, padding=1), 25 | nn.LeakyReLU(inplace=True), 26 | nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=strides, padding=1), 27 | nn.LeakyReLU(inplace=True), 28 | ) 29 | self.conv11 = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=strides, padding=0) 30 | 31 | def forward(self, x): 32 | out1 = self.block(x) 33 | out2 = self.conv11(x) 34 | out = out1 + out2 35 | return out 36 | 37 | def flops(self, H, W): 38 | flops = H*W*self.in_channel*self.out_channel*(3*3+1)+H*W*self.out_channel*self.out_channel*3*3 39 | return flops 40 | 41 | 42 | class PosCNN(nn.Module): 43 | def __init__(self, in_chans, embed_dim=768, s=1): 44 | super(PosCNN, self).__init__() 45 | self.proj = nn.Sequential(nn.Conv2d(in_chans, embed_dim, 3, s, 1, bias=True, groups=embed_dim)) 46 | self.s = s 47 | 48 | def forward(self, x, H=None, W=None): 49 | B, N, C = x.shape 50 | H = H or int(math.sqrt(N)) 51 | W = W or int(math.sqrt(N)) 52 | feat_token = x 53 | cnn_feat = feat_token.transpose(1, 2).view(B, C, H, W) 54 | if self.s == 1: 55 | x = self.proj(cnn_feat) + cnn_feat 56 | else: 57 | x = self.proj(cnn_feat) 58 | x = x.flatten(2).transpose(1, 2) 59 | return x 60 | 61 | def no_weight_decay(self): 62 | return ['proj.%d.weight' % i for i in range(4)] 63 | 64 | 65 | class SELayer(nn.Module): 66 | def __init__(self, channel, reduction=16): 67 | super(SELayer, self).__init__() 68 | self.avg_pool = nn.AdaptiveAvgPool1d(1) 69 | self.fc = nn.Sequential( 70 | nn.Linear(channel, channel // reduction, bias=False), 71 | nn.ReLU(inplace=True), 72 | nn.Linear(channel // reduction, channel, bias=False), 73 | nn.Sigmoid() 74 | ) 75 | 76 | def forward(self, x): # x: [B, N, C] 77 | x = torch.transpose(x, 1, 2) # [B, C, N] 78 | b, c, _ = x.size() 79 | y = self.avg_pool(x).view(b, c) 80 | y = self.fc(y).view(b, c, 1) 81 | x = x * y.expand_as(x) 82 | x = torch.transpose(x, 1, 2) # [B, N, C] 83 | return x 84 | 85 | 86 | class SpatialAttention(nn.Module): 87 | def __init__(self, kernel_size=7): 88 | super(SpatialAttention, self).__init__() 89 | 90 | self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False) 91 | self.sigmoid = nn.Sigmoid() 92 | 93 | def forward(self, x): 94 | avg_out = torch.mean(x, dim=1, keepdim=True) 95 | max_out, _ = torch.max(x, dim=1, keepdim=True) 96 | x1 = torch.cat([avg_out, max_out], dim=1) 97 | x1 = self.conv1(x1) 98 | return self.sigmoid(x1)*x 99 | 100 | 101 | class SepConv2d(torch.nn.Module): 102 | def __init__(self, 103 | in_channels, 104 | out_channels, 105 | kernel_size, 106 | stride=1, 107 | padding=0, 108 | dilation=1,act_layer=nn.ReLU): 109 | super(SepConv2d, self).__init__() 110 | 111 | self.depthwise = torch.nn.Conv2d(in_channels, 112 | in_channels, 113 | kernel_size=kernel_size, 114 | stride=stride, 115 | padding=padding, 116 | dilation=dilation, 117 | groups=in_channels) 118 | self.pointwise = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1) 119 | self.act_layer = act_layer() if act_layer is not None else nn.Identity() 120 | self.in_channels = in_channels 121 | self.out_channels = out_channels 122 | 
self.kernel_size = kernel_size 123 | self.stride = stride 124 | self.offset_conv1 = nn.Conv2d(in_channels, 216, 3, stride=1, padding=1, bias= False) 125 | self.deform1 = DeformConv2d(in_channels, out_channels, 3, padding=1, groups=8) 126 | 127 | self.SA = SpatialAttention() 128 | 129 | def offset_gen(self, x): 130 | sa = self.SA(x) 131 | o1, o2, mask = torch.chunk(sa, 3, dim=1) 132 | offset = torch.cat((o1, o2), dim=1) 133 | mask = torch.sigmoid(mask) 134 | return offset,mask 135 | 136 | def forward(self, x): 137 | offset1,mask = self.offset_gen(self.offset_conv1(x)) 138 | feat1 = self.deform1(x, offset1, mask) 139 | x = self.act_layer(feat1) 140 | x = self.pointwise(x) 141 | return x 142 | 143 | 144 | ######## Embedding for q,k,v ######## 145 | class ConvProjection(nn.Module): 146 | def __init__(self, dim, heads = 8, dim_head = 64, kernel_size=3, q_stride=1, k_stride=1, v_stride=1, dropout = 0., 147 | last_stage=False,bias=True): 148 | 149 | super().__init__() 150 | 151 | inner_dim = dim_head * heads 152 | self.heads = heads 153 | pad = (kernel_size - q_stride)//2 154 | self.to_q = SepConv2d(dim, inner_dim, kernel_size, q_stride, pad, bias) 155 | self.to_k = SepConv2d(dim, inner_dim, kernel_size, k_stride, pad, bias) 156 | self.to_v = SepConv2d(dim, inner_dim, kernel_size, v_stride, pad, bias) 157 | 158 | def forward(self, x, attn_kv=None): 159 | b, n, c, h = * x.shape, self.heads 160 | l = int(math.sqrt(n)) 161 | w = int(math.sqrt(n)) 162 | 163 | attn_kv = x if attn_kv is None else attn_kv 164 | x = rearrange(x, 'b (l w) c -> b c l w', l=l, w=w) 165 | attn_kv = rearrange(attn_kv, 'b (l w) c -> b c l w', l=l, w=w) 166 | # print(attn_kv) 167 | q = self.to_q(x) 168 | q = rearrange(q, 'b (h d) l w -> b h (l w) d', h=h) 169 | 170 | k = self.to_k(attn_kv) 171 | v = self.to_v(attn_kv) 172 | k = rearrange(k, 'b (h d) l w -> b h (l w) d', h=h) 173 | v = rearrange(v, 'b (h d) l w -> b h (l w) d', h=h) 174 | return q,k,v 175 | 176 | def flops(self, H, W): 177 | flops = 0 178 | flops += self.to_q.flops(H, W) 179 | flops += self.to_k.flops(H, W) 180 | flops += self.to_v.flops(H, W) 181 | return flops 182 | 183 | 184 | class LinearProjection(nn.Module): 185 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., bias=True): 186 | super().__init__() 187 | inner_dim = dim_head * heads 188 | self.heads = heads 189 | self.to_q = nn.Linear(dim, inner_dim, bias = bias) 190 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias = bias) 191 | self.dim = dim 192 | self.inner_dim = inner_dim 193 | 194 | def forward(self, x, attn_kv=None): 195 | B_, N, C = x.shape 196 | attn_kv = x if attn_kv is None else attn_kv 197 | q = self.to_q(x).reshape(B_, N, 1, self.heads, C // self.heads).permute(2, 0, 3, 1, 4) 198 | kv = self.to_kv(attn_kv).reshape(B_, N, 2, self.heads, C // self.heads).permute(2, 0, 3, 1, 4) 199 | q = q[0] 200 | k, v = kv[0], kv[1] 201 | return q,k,v 202 | 203 | def flops(self, H, W): 204 | flops = H*W*self.dim*self.inner_dim*3 205 | return flops 206 | 207 | 208 | class LinearProjection_Concat_kv(nn.Module): 209 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., bias=True): 210 | super().__init__() 211 | inner_dim = dim_head * heads 212 | self.heads = heads 213 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = bias) 214 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias = bias) 215 | self.dim = dim 216 | self.inner_dim = inner_dim 217 | 218 | def forward(self, x, attn_kv=None): 219 | B_, N, C = x.shape 220 | attn_kv = x if attn_kv is None else attn_kv 221 | qkv_dec = 
self.to_qkv(x).reshape(B_, N, 3, self.heads, C // self.heads).permute(2, 0, 3, 1, 4) 222 | kv_enc = self.to_kv(attn_kv).reshape(B_, N, 2, self.heads, C // self.heads).permute(2, 0, 3, 1, 4) 223 | q, k_d, v_d = qkv_dec[0], qkv_dec[1], qkv_dec[2] # make torchscript happy (cannot use tensor as tuple) 224 | k_e, v_e = kv_enc[0], kv_enc[1] 225 | k = torch.cat((k_d,k_e),dim=2) 226 | v = torch.cat((v_d,v_e),dim=2) 227 | return q,k,v 228 | 229 | def flops(self, H, W): 230 | flops = H*W*self.dim*self.inner_dim*5 231 | return flops 232 | 233 | 234 | ######################################### 235 | class WindowAttention(nn.Module): 236 | def __init__(self, dim, win_size,num_heads, token_projection='linear', qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.,se_layer=False): 237 | 238 | super().__init__() 239 | self.dim = dim 240 | self.win_size = win_size # Wh, Ww 241 | self.num_heads = num_heads 242 | head_dim = dim // num_heads 243 | self.scale = qk_scale or head_dim ** -0.5 244 | 245 | # define a parameter table of relative position bias 246 | self.relative_position_bias_table = nn.Parameter( 247 | torch.zeros((2 * win_size[0] - 1) * (2 * win_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH 248 | 249 | # get pair-wise relative position index for each token inside the window 250 | coords_h = torch.arange(self.win_size[0]) # [0,...,Wh-1] 251 | coords_w = torch.arange(self.win_size[1]) # [0,...,Ww-1] 252 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 253 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 254 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 255 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 256 | relative_coords[:, :, 0] += self.win_size[0] - 1 # shift to start from 0 257 | relative_coords[:, :, 1] += self.win_size[1] - 1 258 | relative_coords[:, :, 0] *= 2 * self.win_size[1] - 1 259 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 260 | self.register_buffer("relative_position_index", relative_position_index) 261 | 262 | # self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 263 | if token_projection =='conv': 264 | self.qkv = ConvProjection(dim,num_heads,dim//num_heads,bias=qkv_bias) 265 | elif token_projection =='linear_concat': 266 | self.qkv = LinearProjection_Concat_kv(dim,num_heads,dim//num_heads,bias=qkv_bias) 267 | else: 268 | self.qkv = LinearProjection(dim,num_heads,dim//num_heads,bias=qkv_bias) 269 | 270 | self.token_projection = token_projection 271 | self.attn_drop = nn.Dropout(attn_drop) 272 | self.proj = nn.Linear(dim, dim) 273 | self.se_layer = SELayer(dim) if se_layer else nn.Identity() 274 | self.proj_drop = nn.Dropout(proj_drop) 275 | 276 | trunc_normal_(self.relative_position_bias_table, std=.02) 277 | self.softmax = nn.Softmax(dim=-1) 278 | 279 | def forward(self, x, attn_kv=None, mask=None): 280 | B_, N, C = x.shape 281 | q, k, v = self.qkv(x,attn_kv) 282 | q = q * self.scale 283 | attn = (q @ k.transpose(-2, -1)) 284 | 285 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 286 | self.win_size[0] * self.win_size[1], self.win_size[0] * self.win_size[1], -1) # Wh*Ww,Wh*Ww,nH 287 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 288 | ratio = attn.size(-1)//relative_position_bias.size(-1) 289 | relative_position_bias = repeat(relative_position_bias, 'nH l c -> nH l (c d)', d = ratio) 290 | 291 | attn = attn + 
relative_position_bias.unsqueeze(0) 292 | 293 | if mask is not None: 294 | nW = mask.shape[0] 295 | mask = repeat(mask, 'nW m n -> nW m (n d)',d = ratio) 296 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N*ratio) + mask.unsqueeze(1).unsqueeze(0) 297 | attn = attn.view(-1, self.num_heads, N, N*ratio) 298 | attn = self.softmax(attn) 299 | else: 300 | attn = self.softmax(attn) 301 | 302 | attn = self.attn_drop(attn) 303 | 304 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 305 | x = self.proj(x) 306 | x = self.se_layer(x) 307 | x = self.proj_drop(x) 308 | return x 309 | 310 | def extra_repr(self) -> str: 311 | return f'dim={self.dim}, win_size={self.win_size}, num_heads={self.num_heads}' 312 | 313 | 314 | class Mlp(nn.Module): 315 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 316 | super().__init__() 317 | out_features = out_features or in_features 318 | hidden_features = hidden_features or in_features 319 | self.fc1 = nn.Linear(in_features, hidden_features) 320 | self.act = act_layer() 321 | self.fc2 = nn.Linear(hidden_features, out_features) 322 | self.drop = nn.Dropout(drop) 323 | self.in_features = in_features 324 | self.hidden_features = hidden_features 325 | self.out_features = out_features 326 | 327 | def forward(self, x): 328 | x = self.fc1(x) 329 | x = self.act(x) 330 | x = self.drop(x) 331 | x = self.fc2(x) 332 | x = self.drop(x) 333 | return x 334 | 335 | 336 | class LeFF(nn.Module): 337 | def __init__(self, dim=32, hidden_dim=128, act_layer=nn.GELU,drop = 0.): 338 | super().__init__() 339 | self.linear1 = nn.Sequential(nn.Linear(dim, hidden_dim), 340 | act_layer()) 341 | self.dwconv = nn.Sequential(nn.Conv2d(hidden_dim,hidden_dim,groups=hidden_dim,kernel_size=3,stride=1,padding=1), 342 | act_layer()) 343 | self.linear2 = nn.Sequential(nn.Linear(hidden_dim, dim)) 344 | self.dim = dim 345 | self.hidden_dim = hidden_dim 346 | 347 | def forward(self, x): 348 | # bs x hw x c 349 | bs, hw, c = x.size() 350 | hh = int(math.sqrt(hw)) 351 | 352 | x = self.linear1(x) 353 | 354 | # spatial restore 355 | x = rearrange(x, ' b (h w) (c) -> b c h w ', h = hh, w = hh) 356 | # bs,hidden_dim,32x32 357 | 358 | x = self.dwconv(x) 359 | 360 | # flaten 361 | x = rearrange(x, ' b c h w -> b (h w) c', h = hh, w = hh) 362 | 363 | x = self.linear2(x) 364 | 365 | return x 366 | 367 | 368 | def window_partition(x, win_size, dilation_rate=1): 369 | B, H, W, C = x.shape 370 | if dilation_rate != 1: 371 | x = x.permute(0,3,1,2) # B, C, H, W 372 | assert type(dilation_rate) is int, 'dilation_rate should be a int' 373 | x = F.unfold(x, kernel_size=win_size,dilation=dilation_rate,padding=4*(dilation_rate-1),stride=win_size) # B, C*Wh*Ww, H/Wh*W/Ww 374 | windows = x.permute(0,2,1).contiguous().view(-1, C, win_size, win_size) # B' ,C ,Wh ,Ww 375 | windows = windows.permute(0,2,3,1).contiguous() # B' ,Wh ,Ww ,C 376 | else: 377 | x = x.view(B, H // win_size, win_size, W // win_size, win_size, C) 378 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, win_size, win_size, C) # B' ,Wh ,Ww ,C 379 | return windows 380 | 381 | 382 | def window_reverse(windows, win_size, H, W, dilation_rate=1): 383 | # B' ,Wh ,Ww ,C 384 | B = int(windows.shape[0] / (H * W / win_size / win_size)) 385 | x = windows.view(B, H // win_size, W // win_size, win_size, win_size, -1) 386 | if dilation_rate !=1: 387 | x = windows.permute(0,5,3,4,1,2).contiguous() # B, C*Wh*Ww, H/Wh*W/Ww 388 | x = F.fold(x, (H, W), kernel_size=win_size, dilation=dilation_rate, 
padding=4*(dilation_rate-1),stride=win_size) 389 | else: 390 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 391 | return x 392 | 393 | 394 | class Downsample(nn.Module): 395 | def __init__(self, in_channel, out_channel): 396 | super(Downsample, self).__init__() 397 | self.conv = nn.Sequential( 398 | nn.Conv2d(in_channel, out_channel, kernel_size=4, stride=2, padding=1), 399 | ) 400 | self.in_channel = in_channel 401 | self.out_channel = out_channel 402 | 403 | def forward(self, x): 404 | B, L, C = x.shape 405 | # import pdb;pdb.set_trace() 406 | H = int(math.sqrt(L)) 407 | W = int(math.sqrt(L)) 408 | x = x.transpose(1, 2).contiguous().view(B, C, H, W) 409 | out = self.conv(x).flatten(2).transpose(1,2).contiguous() # B H*W C 410 | return out 411 | 412 | 413 | # Upsample Block 414 | class Upsample(nn.Module): 415 | def __init__(self, in_channel, out_channel): 416 | super(Upsample, self).__init__() 417 | self.deconv = nn.Sequential( 418 | nn.ConvTranspose2d(in_channel, out_channel, kernel_size=2, stride=2), 419 | ) 420 | self.in_channel = in_channel 421 | self.out_channel = out_channel 422 | 423 | def forward(self, x): 424 | B, L, C = x.shape 425 | H = int(math.sqrt(L)) 426 | W = int(math.sqrt(L)) 427 | x = x.transpose(1, 2).contiguous().view(B, C, H, W) 428 | out = self.deconv(x).flatten(2).transpose(1,2).contiguous() # B H*W C 429 | return out 430 | 431 | 432 | class Upsample_advanced(nn.Module): 433 | def __init__(self, in_channel, out_channel,factor): 434 | super(Upsample_advanced, self).__init__() 435 | self.deconv = nn.Sequential( 436 | nn.ConvTranspose2d(in_channel, out_channel, kernel_size=2, stride=factor), 437 | ) 438 | self.in_channel = in_channel 439 | self.out_channel = out_channel 440 | 441 | def forward(self, x): 442 | B, L, C = x.shape 443 | H = int(math.sqrt(L)) 444 | W = int(math.sqrt(L)) 445 | x = x.transpose(1, 2).contiguous().view(B, C, H, W) 446 | out = self.deconv(x).flatten(2).transpose(1,2).contiguous() # B H*W C 447 | return out 448 | 449 | 450 | # Input Projection 451 | class InputProj(nn.Module): 452 | def __init__(self, in_channel=305, out_channel=64, kernel_size=3, stride=1, norm_layer=None,act_layer=nn.LeakyReLU): 453 | super().__init__() 454 | self.proj = nn.Sequential( 455 | nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=kernel_size//2), 456 | act_layer(inplace=True) 457 | ) 458 | if norm_layer is not None: 459 | self.norm = norm_layer(out_channel) 460 | else: 461 | self.norm = None 462 | self.in_channel = in_channel 463 | self.out_channel = out_channel 464 | 465 | def forward(self, x): 466 | B, C, H, W = x.shape 467 | x = self.proj(x).flatten(2).transpose(1, 2).contiguous() # B H*W C 468 | if self.norm is not None: 469 | x = self.norm(x) 470 | return x 471 | 472 | 473 | # Output Projection 474 | class OutputProj(nn.Module): 475 | def __init__(self, in_channel=64, out_channel=305, kernel_size=3, stride=1, norm_layer=None,act_layer=None): 476 | super().__init__() 477 | self.proj = nn.Sequential( 478 | nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=kernel_size//2), 479 | ) 480 | if act_layer is not None: 481 | self.proj.add_module(act_layer(inplace=True)) 482 | if norm_layer is not None: 483 | self.norm = norm_layer(out_channel) 484 | else: 485 | self.norm = None 486 | self.in_channel = in_channel 487 | self.out_channel = out_channel 488 | 489 | def forward(self, x): 490 | B, L, C = x.shape 491 | H = int(math.sqrt(L)) 492 | W = int(math.sqrt(L)) 493 | x = x.transpose(1, 2).view(B, C, H, W) 494 
| x = self.proj(x) 495 | if self.norm is not None: 496 | x = self.norm(x) 497 | return x 498 | 499 | 500 | class Deformable_Attentive_Transformer(nn.Module): 501 | def __init__(self, dim, input_resolution, num_heads, win_size=8, shift_size=0, 502 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 503 | act_layer=nn.GELU, norm_layer=nn.LayerNorm,token_projection='conv',token_mlp='leff',se_layer=False): 504 | super().__init__() 505 | self.dim = dim 506 | self.input_resolution = input_resolution 507 | self.num_heads = num_heads 508 | self.win_size = win_size 509 | self.shift_size = shift_size 510 | self.mlp_ratio = mlp_ratio 511 | self.token_mlp = token_mlp 512 | if min(self.input_resolution) <= self.win_size: 513 | self.shift_size = 0 514 | self.win_size = min(self.input_resolution) 515 | assert 0 <= self.shift_size < self.win_size, "shift_size must in 0-win_size" 516 | 517 | self.norm1 = norm_layer(dim) 518 | self.attn = WindowAttention( 519 | dim, win_size=to_2tuple(self.win_size), num_heads=num_heads, 520 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, 521 | token_projection=token_projection,se_layer=se_layer) 522 | 523 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 524 | self.norm2 = norm_layer(dim) 525 | mlp_hidden_dim = int(dim * mlp_ratio) 526 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,act_layer=act_layer, drop=drop) if token_mlp=='ffn' else LeFF(dim,mlp_hidden_dim,act_layer=act_layer, drop=drop) 527 | 528 | 529 | def extra_repr(self) -> str: 530 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ 531 | f"win_size={self.win_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" 532 | 533 | def forward(self, x, mask=None): 534 | B, L, C = x.shape 535 | H = int(math.sqrt(L)) 536 | W = int(math.sqrt(L)) 537 | 538 | ## input mask 539 | if mask != None: 540 | input_mask = F.interpolate(mask, size=(H,W)).permute(0,2,3,1) 541 | input_mask_windows = window_partition(input_mask, self.win_size) # nW, win_size, win_size, 1 542 | attn_mask = input_mask_windows.view(-1, self.win_size * self.win_size) # nW, win_size*win_size 543 | attn_mask = attn_mask.unsqueeze(2)*attn_mask.unsqueeze(1) # nW, win_size*win_size, win_size*win_size 544 | attn_mask = attn_mask.masked_fill(attn_mask!=0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 545 | else: 546 | attn_mask = None 547 | 548 | ## shift mask 549 | if self.shift_size > 0: 550 | # calculate attention mask for SW-MSA 551 | shift_mask = torch.zeros((1, H, W, 1)).type_as(x) 552 | h_slices = (slice(0, -self.win_size), 553 | slice(-self.win_size, -self.shift_size), 554 | slice(-self.shift_size, None)) 555 | w_slices = (slice(0, -self.win_size), 556 | slice(-self.win_size, -self.shift_size), 557 | slice(-self.shift_size, None)) 558 | cnt = 0 559 | for h in h_slices: 560 | for w in w_slices: 561 | shift_mask[:, h, w, :] = cnt 562 | cnt += 1 563 | shift_mask_windows = window_partition(shift_mask, self.win_size) # nW, win_size, win_size, 1 564 | shift_mask_windows = shift_mask_windows.view(-1, self.win_size * self.win_size) # nW, win_size*win_size 565 | shift_attn_mask = shift_mask_windows.unsqueeze(1) - shift_mask_windows.unsqueeze(2) # nW, win_size*win_size, win_size*win_size 566 | shift_attn_mask = shift_attn_mask.masked_fill(shift_attn_mask != 0, float(-100.0)).masked_fill(shift_attn_mask == 0, float(0.0)) 567 | attn_mask = attn_mask + shift_attn_mask if attn_mask is not None else 
shift_attn_mask 568 | 569 | shortcut = x 570 | x = self.norm1(x) 571 | x = x.view(B, H, W, C) 572 | 573 | # cyclic shift 574 | if self.shift_size > 0: 575 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 576 | else: 577 | shifted_x = x 578 | 579 | # partition windows 580 | x_windows = window_partition(shifted_x, self.win_size) # nW*B, win_size, win_size, C N*C->C 581 | x_windows = x_windows.view(-1, self.win_size * self.win_size, C) # nW*B, win_size*win_size, C 582 | 583 | # W-MSA/SW-MSA 584 | attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, win_size*win_size, C 585 | 586 | # merge windows 587 | attn_windows = attn_windows.view(-1, self.win_size, self.win_size, C) 588 | shifted_x = window_reverse(attn_windows, self.win_size, H, W) # B H' W' C 589 | 590 | # reverse cyclic shift 591 | if self.shift_size > 0: 592 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 593 | else: 594 | x = shifted_x 595 | x = x.view(B, H * W, C) 596 | 597 | # FFN 598 | x = shortcut + self.drop_path(x) 599 | x = x + self.drop_path(self.mlp(self.norm2(x))) 600 | del attn_mask 601 | return x 602 | 603 | 604 | class TRANSFORMER_BLOCK(nn.Module): 605 | def __init__(self, dim, output_dim, input_resolution, depth, num_heads, win_size, 606 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 607 | drop_path=0., norm_layer=nn.LayerNorm, use_checkpoint=False, 608 | token_projection='linear',token_mlp='ffn',se_layer=False): 609 | 610 | super().__init__() 611 | self.dim = dim 612 | self.input_resolution = input_resolution 613 | self.depth = depth 614 | self.use_checkpoint = use_checkpoint 615 | # build blocks 616 | self.blocks = nn.ModuleList([ 617 | Deformable_Attentive_Transformer(dim=dim, input_resolution=input_resolution, 618 | num_heads=num_heads, win_size=win_size, 619 | shift_size=0 if (i % 2 == 0) else win_size // 2, 620 | mlp_ratio=mlp_ratio, 621 | qkv_bias=qkv_bias, qk_scale=qk_scale, 622 | drop=drop, attn_drop=attn_drop, 623 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 624 | norm_layer=norm_layer,token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer) 625 | for i in range(depth)]) 626 | 627 | def extra_repr(self) -> str: 628 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" 629 | 630 | def forward(self, x, mask=None): 631 | for blk in self.blocks: 632 | if self.use_checkpoint: 633 | x = checkpoint.checkpoint(blk, x) 634 | else: 635 | x = blk(x,mask) 636 | return x 637 | 638 | def flops(self): 639 | flops = 0 640 | for blk in self.blocks: 641 | flops += blk.flops() 642 | return flops 643 | 644 | 645 | class AIDTransformer(nn.Module): 646 | # def __init__(self, img_size=64, in_chans=4, 647 | # def __init__(self, img_size=64, in_chans=7, 648 | def __init__(self, img_size=64, in_chans=305, 649 | embed_dim=32, depths=[2, 2, 2, 2, 2, 2, 2, 2, 2], num_heads=[1, 2, 4, 8, 16, 16, 8, 4, 2], 650 | win_size=8, mlp_ratio=4., qkv_bias=True, qk_scale=None, 651 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, 652 | norm_layer=nn.LayerNorm, patch_norm=True, 653 | use_checkpoint=False, token_projection='conv', token_mlp='ffn', se_layer=False, 654 | dowsample=Downsample, upsample=Upsample, **kwargs): 655 | super().__init__() 656 | 657 | self.num_enc_layers = len(depths)//2 658 | self.num_dec_layers = len(depths)//2 659 | self.embed_dim = embed_dim 660 | self.patch_norm = patch_norm 661 | self.mlp_ratio = mlp_ratio 662 | self.token_projection = 
token_projection 663 | self.mlp = token_mlp 664 | self.win_size =win_size 665 | self.reso = img_size 666 | self.pos_drop = nn.Dropout(p=drop_rate) 667 | 668 | # stochastic depth 669 | enc_dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths[:self.num_enc_layers]))] 670 | conv_dpr = [drop_path_rate]*depths[4] 671 | dec_dpr = enc_dpr[::-1] 672 | 673 | # build layers 674 | 675 | # Input/Output 676 | self.input_proj = InputProj(in_channel=in_chans, out_channel=embed_dim, kernel_size=3, stride=1, act_layer=nn.LeakyReLU) 677 | self.output_proj = OutputProj(in_channel=3*embed_dim, out_channel=in_chans, kernel_size=3, stride=1) 678 | self.output_proj1 = OutputProj(in_channel=embed_dim, out_channel=in_chans, kernel_size=3, stride=1) 679 | # Encoder 680 | self.encoderlayer_0 = TRANSFORMER_BLOCK(dim=embed_dim, 681 | output_dim=embed_dim, 682 | input_resolution=(img_size, 683 | img_size), 684 | depth=depths[0], 685 | num_heads=num_heads[0], 686 | win_size=win_size, 687 | mlp_ratio=self.mlp_ratio, 688 | qkv_bias=qkv_bias, qk_scale=qk_scale, 689 | drop=drop_rate, attn_drop=attn_drop_rate, 690 | drop_path=enc_dpr[sum(depths[:0]):sum(depths[:1])], 691 | norm_layer=norm_layer, 692 | use_checkpoint=use_checkpoint, 693 | token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer) 694 | self.dowsample_0 = dowsample(embed_dim, embed_dim*2) 695 | self.hs_downsample = dowsample(embed_dim, embed_dim) 696 | self.hs_upsample = upsample(3*embed_dim,embed_dim) 697 | self.qs_upsample = upsample(5*embed_dim,embed_dim) 698 | self.qs_upsample1 = upsample(embed_dim,embed_dim) 699 | 700 | self.upsample_poolup1 = upsample(3*embed_dim,embed_dim) 701 | self.upsample_poolup2 = upsample(5*embed_dim,embed_dim*3) 702 | 703 | self.encoderlayer_1 = TRANSFORMER_BLOCK(dim=embed_dim*3, 704 | output_dim=embed_dim*2, 705 | input_resolution=(img_size // 2, 706 | img_size // 2), 707 | depth=depths[1], 708 | num_heads=num_heads[1], 709 | win_size=win_size, 710 | mlp_ratio=self.mlp_ratio, 711 | qkv_bias=qkv_bias, qk_scale=qk_scale, 712 | drop=drop_rate, attn_drop=attn_drop_rate, 713 | drop_path=enc_dpr[sum(depths[:1]):sum(depths[:2])], 714 | norm_layer=norm_layer, 715 | use_checkpoint=use_checkpoint, 716 | token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer) 717 | self.dowsample_1 = dowsample(embed_dim*3, embed_dim*4) 718 | self.encoderlayer_2 = TRANSFORMER_BLOCK(dim=embed_dim*5, 719 | output_dim=embed_dim*4, 720 | input_resolution=(img_size // (2 ** 2), 721 | img_size // (2 ** 2)), 722 | depth=depths[2], 723 | num_heads=num_heads[2], 724 | win_size=win_size, 725 | mlp_ratio=self.mlp_ratio, 726 | qkv_bias=qkv_bias, qk_scale=qk_scale, 727 | drop=drop_rate, attn_drop=attn_drop_rate, 728 | drop_path=enc_dpr[sum(depths[:2]):sum(depths[:3])], 729 | norm_layer=norm_layer, 730 | use_checkpoint=use_checkpoint, 731 | token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer) 732 | self.dowsample_2 = dowsample(embed_dim*5, embed_dim*8) 733 | self.encoderlayer_3 = TRANSFORMER_BLOCK(dim=embed_dim*8, 734 | output_dim=embed_dim*8, 735 | input_resolution=(img_size // (2 ** 3), 736 | img_size // (2 ** 3)), 737 | depth=depths[3], 738 | num_heads=num_heads[3], 739 | win_size=win_size, 740 | mlp_ratio=self.mlp_ratio, 741 | qkv_bias=qkv_bias, qk_scale=qk_scale, 742 | drop=drop_rate, attn_drop=attn_drop_rate, 743 | drop_path=enc_dpr[sum(depths[:3]):sum(depths[:4])], 744 | norm_layer=norm_layer, 745 | use_checkpoint=use_checkpoint, 746 | 
token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer) 747 | self.dowsample_3 = dowsample(embed_dim*8, embed_dim*16) 748 | 749 | # Bottleneck 750 | self.conv = TRANSFORMER_BLOCK(dim=embed_dim*9, 751 | output_dim=embed_dim*8, 752 | input_resolution=(img_size // (2 ** 4), 753 | img_size // (2 ** 4)), 754 | depth=depths[4], 755 | num_heads=num_heads[4], 756 | win_size=win_size, 757 | mlp_ratio=self.mlp_ratio, 758 | qkv_bias=qkv_bias, qk_scale=qk_scale, 759 | drop=drop_rate, attn_drop=attn_drop_rate, 760 | drop_path=conv_dpr, 761 | norm_layer=norm_layer, 762 | use_checkpoint=use_checkpoint, 763 | token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer) 764 | 765 | # Decoder 766 | self.upsample_0 = upsample(embed_dim*9, embed_dim*8) 767 | self.decoderlayer_0 = TRANSFORMER_BLOCK(dim=embed_dim*14, 768 | output_dim=embed_dim*16, 769 | input_resolution=(img_size // (2 ** 3), 770 | img_size // (2 ** 3)), 771 | depth=depths[5], 772 | num_heads=num_heads[5], 773 | win_size=win_size, 774 | mlp_ratio=self.mlp_ratio, 775 | qkv_bias=qkv_bias, qk_scale=qk_scale, 776 | drop=drop_rate, attn_drop=attn_drop_rate, 777 | drop_path=dec_dpr[:depths[5]], 778 | norm_layer=norm_layer, 779 | use_checkpoint=use_checkpoint, 780 | token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer) 781 | self.upsample_1 = upsample(embed_dim*14, embed_dim*4) 782 | self.decoderlayer_1 = TRANSFORMER_BLOCK(dim=embed_dim*8, 783 | output_dim=embed_dim*8, 784 | input_resolution=(img_size // (2 ** 2), 785 | img_size // (2 ** 2)), 786 | depth=depths[6], 787 | num_heads=num_heads[6], 788 | win_size=win_size, 789 | mlp_ratio=self.mlp_ratio, 790 | qkv_bias=qkv_bias, qk_scale=qk_scale, 791 | drop=drop_rate, attn_drop=attn_drop_rate, 792 | drop_path=dec_dpr[sum(depths[5:6]):sum(depths[5:7])], 793 | norm_layer=norm_layer, 794 | use_checkpoint=use_checkpoint, 795 | token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer) 796 | self.upsample_2 = upsample(embed_dim*8, embed_dim*2) 797 | self.decoderlayer_2 = TRANSFORMER_BLOCK(dim=embed_dim*3, 798 | output_dim=embed_dim*4, 799 | input_resolution=(img_size // 2, 800 | img_size // 2), 801 | depth=depths[7], 802 | num_heads=num_heads[7], 803 | win_size=win_size, 804 | mlp_ratio=self.mlp_ratio, 805 | qkv_bias=qkv_bias, qk_scale=qk_scale, 806 | drop=drop_rate, attn_drop=attn_drop_rate, 807 | drop_path=dec_dpr[sum(depths[5:7]):sum(depths[5:8])], 808 | norm_layer=norm_layer, 809 | use_checkpoint=use_checkpoint, 810 | token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer) 811 | self.upsample_3 = upsample(embed_dim*4, embed_dim) 812 | self.decoderlayer_3 = TRANSFORMER_BLOCK(dim=embed_dim*2, 813 | output_dim=embed_dim*2, 814 | input_resolution=(img_size, 815 | img_size), 816 | depth=depths[8], 817 | num_heads=num_heads[8], 818 | win_size=win_size, 819 | mlp_ratio=self.mlp_ratio, 820 | qkv_bias=qkv_bias, qk_scale=qk_scale, 821 | drop=drop_rate, attn_drop=attn_drop_rate, 822 | drop_path=dec_dpr[sum(depths[5:8]):sum(depths[5:9])], 823 | norm_layer=norm_layer, 824 | use_checkpoint=use_checkpoint, 825 | token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer) 826 | # 1/2 scale stream 827 | self.hs1 = TRANSFORMER_BLOCK(dim=embed_dim, 828 | output_dim=embed_dim, 829 | input_resolution=(img_size // 2, 830 | img_size // 2), 831 | depth=depths[1], 832 | num_heads=num_heads[1], 833 | win_size=win_size, 834 | mlp_ratio=self.mlp_ratio, 835 | qkv_bias=qkv_bias, qk_scale=qk_scale, 836 | drop=drop_rate, attn_drop=attn_drop_rate, 
837 | drop_path=enc_dpr[sum(depths[:1]):sum(depths[:2])], 838 | norm_layer=norm_layer, 839 | use_checkpoint=use_checkpoint, 840 | token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer) 841 | self.hs2 = TRANSFORMER_BLOCK(dim=3*embed_dim, 842 | output_dim=embed_dim, 843 | input_resolution=(img_size // 2, 844 | img_size // 2), 845 | depth=depths[1], 846 | num_heads=num_heads[1], 847 | win_size=win_size, 848 | mlp_ratio=self.mlp_ratio, 849 | qkv_bias=qkv_bias, qk_scale=qk_scale, 850 | drop=drop_rate, attn_drop=attn_drop_rate, 851 | drop_path=enc_dpr[sum(depths[:1]):sum(depths[:2])], 852 | norm_layer=norm_layer, 853 | use_checkpoint=use_checkpoint, 854 | token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer) 855 | self.hs3 = TRANSFORMER_BLOCK(dim=3*embed_dim, 856 | output_dim=embed_dim, 857 | input_resolution=(img_size // 2, 858 | img_size // 2), 859 | depth=depths[1], 860 | num_heads=num_heads[1], 861 | win_size=win_size, 862 | mlp_ratio=self.mlp_ratio, 863 | qkv_bias=qkv_bias, qk_scale=qk_scale, 864 | drop=drop_rate, attn_drop=attn_drop_rate, 865 | drop_path=enc_dpr[sum(depths[:1]):sum(depths[:2])], 866 | norm_layer=norm_layer, 867 | use_checkpoint=use_checkpoint, 868 | token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer) 869 | self.hs4 = TRANSFORMER_BLOCK(dim=3*embed_dim, 870 | output_dim=embed_dim, 871 | input_resolution=(img_size // 2, 872 | img_size // 2), 873 | depth=depths[1], 874 | num_heads=num_heads[1], 875 | win_size=win_size, 876 | mlp_ratio=self.mlp_ratio, 877 | qkv_bias=qkv_bias, qk_scale=qk_scale, 878 | drop=drop_rate, attn_drop=attn_drop_rate, 879 | drop_path=enc_dpr[sum(depths[:1]):sum(depths[:2])], 880 | norm_layer=norm_layer, 881 | use_checkpoint=use_checkpoint, 882 | token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer) 883 | 884 | self.qs1 = TRANSFORMER_BLOCK(dim=embed_dim, 885 | output_dim=embed_dim, 886 | input_resolution=(img_size // (2 ** 2), 887 | img_size // (2 ** 2)), 888 | depth=depths[2], 889 | num_heads=num_heads[2], 890 | win_size=win_size, 891 | mlp_ratio=self.mlp_ratio, 892 | qkv_bias=qkv_bias, qk_scale=qk_scale, 893 | drop=drop_rate, attn_drop=attn_drop_rate, 894 | drop_path=enc_dpr[sum(depths[:2]):sum(depths[:3])], 895 | norm_layer=norm_layer, 896 | use_checkpoint=use_checkpoint, 897 | token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer) 898 | self.qs2 = TRANSFORMER_BLOCK(dim=5*embed_dim, 899 | output_dim=embed_dim, 900 | input_resolution=(img_size // (2 ** 2), 901 | img_size // (2 ** 2)), 902 | depth=depths[2], 903 | num_heads=num_heads[2], 904 | win_size=win_size, 905 | mlp_ratio=self.mlp_ratio, 906 | qkv_bias=qkv_bias, qk_scale=qk_scale, 907 | drop=drop_rate, attn_drop=attn_drop_rate, 908 | drop_path=enc_dpr[sum(depths[:2]):sum(depths[:3])], 909 | norm_layer=norm_layer, 910 | use_checkpoint=use_checkpoint, 911 | token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer) 912 | self.qs3 = TRANSFORMER_BLOCK(dim=5*embed_dim, 913 | output_dim=embed_dim, 914 | input_resolution=(img_size // (2 ** 2), 915 | img_size // (2 ** 2)), 916 | depth=depths[2], 917 | num_heads=num_heads[2], 918 | win_size=win_size, 919 | mlp_ratio=self.mlp_ratio, 920 | qkv_bias=qkv_bias, qk_scale=qk_scale, 921 | drop=drop_rate, attn_drop=attn_drop_rate, 922 | drop_path=enc_dpr[sum(depths[:2]):sum(depths[:3])], 923 | norm_layer=norm_layer, 924 | use_checkpoint=use_checkpoint, 925 | token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer) 926 | self.qs4 = 
TRANSFORMER_BLOCK(dim=5*embed_dim, 927 | output_dim=embed_dim, 928 | input_resolution=(img_size // (2 ** 2), 929 | img_size // (2 ** 2)), 930 | depth=depths[2], 931 | num_heads=num_heads[2], 932 | win_size=win_size, 933 | mlp_ratio=self.mlp_ratio, 934 | qkv_bias=qkv_bias, qk_scale=qk_scale, 935 | drop=drop_rate, attn_drop=attn_drop_rate, 936 | drop_path=enc_dpr[sum(depths[:2]):sum(depths[:3])], 937 | norm_layer=norm_layer, 938 | use_checkpoint=use_checkpoint, 939 | token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer) 940 | 941 | # self.edge_booster = 942 | self.apply(self._init_weights) 943 | 944 | def _init_weights(self, m): 945 | if isinstance(m, nn.Linear): 946 | trunc_normal_(m.weight, std=.02) 947 | if isinstance(m, nn.Linear) and m.bias is not None: 948 | nn.init.constant_(m.bias, 0) 949 | elif isinstance(m, nn.LayerNorm): 950 | nn.init.constant_(m.bias, 0) 951 | nn.init.constant_(m.weight, 1.0) 952 | 953 | @torch.jit.ignore 954 | def no_weight_decay(self): 955 | return {'absolute_pos_embed'} 956 | 957 | @torch.jit.ignore 958 | def no_weight_decay_keywords(self): 959 | return {'relative_position_bias_table'} 960 | 961 | def extra_repr(self) -> str: 962 | return f"embed_dim={self.embed_dim}, token_projection={self.token_projection}, token_mlp={self.mlp},win_size={self.win_size}" 963 | 964 | def forward(self, x, mask=None): 965 | y = self.input_proj(x) 966 | y = self.pos_drop(y) 967 | conv01 = self.encoderlayer_0(y,mask=mask) 968 | conv11 = self.encoderlayer_0(conv01,mask=mask) 969 | conv21 = self.encoderlayer_0(conv11,mask=mask) 970 | conv31 = self.encoderlayer_0(conv21,mask=mask) 971 | conv41 = self.encoderlayer_0(conv31,mask=mask) 972 | conv51= self.encoderlayer_0(conv41,mask=mask) 973 | conv0 = self.encoderlayer_0(y,mask=mask) 974 | pool0 = self.dowsample_0(conv0) 975 | pool0 = torch.cat([pool0, self.hs_downsample(conv01)],-1) 976 | poolup0 = self.upsample_poolup1(pool0) 977 | difference1 = conv0 - poolup0 978 | conv1 = self.encoderlayer_1(pool0,mask=mask) 979 | pool1 = self.dowsample_1(conv1) 980 | pool1 = torch.cat([pool1, self.hs_downsample(self.hs_downsample(conv11))],-1) 981 | poolup1 = self.upsample_poolup2(pool1) 982 | difference2 = conv1 - poolup1 983 | 984 | conv2 = self.encoderlayer_2(pool1,mask=mask) 985 | pool2 = self.dowsample_2(conv2) 986 | pool2 = torch.cat([pool2, self.hs_downsample(self.hs_downsample(self.hs_downsample(conv21)))],-1) 987 | conv3 = self.conv(pool2, mask=mask) 988 | up0 = self.upsample_0(conv3) 989 | deconv0 = torch.cat([up0,conv2,self.hs_downsample(self.hs_downsample(conv31))],-1) 990 | deconv0 = self.decoderlayer_0(deconv0,mask=mask) 991 | up1 = self.upsample_1(deconv0) 992 | deconv1 = torch.cat([up1,difference2,self.hs_downsample(conv41)],-1) 993 | deconv1 = self.decoderlayer_1(deconv1,mask=mask) 994 | up2 = self.upsample_2(deconv1) 995 | deconv2 = torch.cat([up2,difference1],-1) 996 | deconv2 = self.decoderlayer_2(deconv2,mask=mask) 997 | z = self.output_proj1(conv51) 998 | y = self.output_proj(deconv2) 999 | return x + z + y 1000 | 1001 | 1002 | if __name__ == '__main__': 1003 | net = AIDTransformer() 1004 | device = torch.device('cpu') 1005 | net.to(device) 1006 | # summary(net.cuda(), (305, 64, 64)) 1007 | 1008 | # x = torch.randn(4, 305, 64, 64).cuda() 1009 | # mcplb = AIDTransformer().cuda() 1010 | # out = mcplb(x) 1011 | # summary(mcplb.cuda(), (305, 64, 64)) --------------------------------------------------------------------------------
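A minimal usage sketch for the AIDTransformer defined above, modeled on the commented-out example in its __main__ block (the batch size, CPU device, and eval/no_grad wrapping here are assumptions, not part of the repository): it builds the model with its default 305-band, 64x64 configuration and runs one forward pass, which returns a cube of the same shape as the input.

import torch

# Smoke-test sketch: assumes AIDTransformer (and the imports at the top of this file) are in scope.
net = AIDTransformer(img_size=64, in_chans=305, embed_dim=32)  # repository defaults
net.eval()
x = torch.randn(1, 305, 64, 64)        # one hazy hyperspectral patch: (B, bands, H, W)
with torch.no_grad():
    out = net(x)                       # forward() returns x + z + y, so the shape is preserved
print(out.shape)                       # expected: torch.Size([1, 305, 64, 64])

Because the output is formed as a residual on top of the input, the network effectively predicts a haze correction rather than reconstructing the whole cube from scratch.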
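Similarly, the window_partition / window_reverse helpers earlier in this file split a (B, H, W, C) feature map into non-overlapping win_size x win_size windows for window attention and stitch them back afterwards. A quick round-trip check (a sketch, assuming both functions are imported from this file and that H and W are divisible by win_size):

import torch

x = torch.randn(2, 16, 16, 4)                    # (B, H, W, C)
win = window_partition(x, win_size=8)            # -> (2 * 2 * 2, 8, 8, 4): one row per window
y = window_reverse(win, win_size=8, H=16, W=16)  # reassemble back to (2, 16, 16, 4)
assert torch.equal(x, y)                         # partition followed by reverse is lossless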