├── HDMba
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   └── __init__.cpython-310.pyc
│   ├── models
│   │   ├── __pycache__
│   │   │   ├── FFA.cpython-36.pyc
│   │   │   ├── FFA.cpython-38.pyc
│   │   │   ├── FFA.cpython-39.pyc
│   │   │   ├── FFA2.cpython-36.pyc
│   │   │   ├── FFA2.cpython-38.pyc
│   │   │   ├── FFA2.cpython-39.pyc
│   │   │   ├── SST.cpython-310.pyc
│   │   │   ├── SST.cpython-36.pyc
│   │   │   ├── AACNet.cpython-36.pyc
│   │   │   ├── CANet.cpython-310.pyc
│   │   │   ├── CANet.cpython-36.pyc
│   │   │   ├── FFANet.cpython-36.pyc
│   │   │   ├── HDMba.cpython-310.pyc
│   │   │   ├── LKDNet.cpython-36.pyc
│   │   │   ├── SCConv.cpython-36.pyc
│   │   │   ├── SGnet.cpython-310.pyc
│   │   │   ├── SGnet.cpython-36.pyc
│   │   │   ├── SGnet.cpython-38.pyc
│   │   │   ├── SGnet.cpython-39.pyc
│   │   │   ├── AACNet.cpython-310.pyc
│   │   │   ├── CSUTrans.cpython-310.pyc
│   │   │   ├── CSUTrans.cpython-36.pyc
│   │   │   ├── FFANet.cpython-310.pyc
│   │   │   ├── HFFormer.cpython-310.pyc
│   │   │   ├── HFFormer.cpython-36.pyc
│   │   │   ├── LKDNet.cpython-310.pyc
│   │   │   ├── MSST_MLK.cpython-310.pyc
│   │   │   ├── MSST_MLK.cpython-36.pyc
│   │   │   ├── PSMBNet.cpython-310.pyc
│   │   │   ├── PSMBNet.cpython-36.pyc
│   │   │   ├── RSDformer.cpython-36.pyc
│   │   │   ├── RSHazeNet.cpython-36.pyc
│   │   │   ├── RSdehaze.cpython-310.pyc
│   │   │   ├── RSdehaze.cpython-36.pyc
│   │   │   ├── RSdehaze.cpython-38.pyc
│   │   │   ├── RSdehaze.cpython-39.pyc
│   │   │   ├── Restormer.cpython-36.pyc
│   │   │   ├── SCConv.cpython-310.pyc
│   │   │   ├── SSMamba.cpython-310.pyc
│   │   │   ├── SSMamba.cpython-36.pyc
│   │   │   ├── SST_MLK.cpython-310.pyc
│   │   │   ├── SST_MLK.cpython-36.pyc
│   │   │   ├── SST_MSF.cpython-310.pyc
│   │   │   ├── SST_MSF.cpython-36.pyc
│   │   │   ├── UFormer.cpython-310.pyc
│   │   │   ├── UFormer.cpython-36.pyc
│   │   │   ├── __init__.cpython-310.pyc
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── __init__.cpython-39.pyc
│   │   │   ├── RSDformer.cpython-310.pyc
│   │   │   ├── RSHazeNet.cpython-310.pyc
│   │   │   ├── Restormer.cpython-310.pyc
│   │   │   ├── SSMamba_ab.cpython-310.pyc
│   │   │   ├── TransWeater.cpython-36.pyc
│   │   │   ├── AIDTransformer.cpython-36.pyc
│   │   │   ├── Dehazeformer.cpython-310.pyc
│   │   │   ├── Dehazeformer.cpython-36.pyc
│   │   │   ├── GF5dehazeNet.cpython-36.pyc
│   │   │   ├── PerceptualLoss.cpython-36.pyc
│   │   │   ├── PerceptualLoss.cpython-38.pyc
│   │   │   ├── PerceptualLoss.cpython-39.pyc
│   │   │   ├── SST_MLK_GDFN.cpython-310.pyc
│   │   │   ├── SST_MLK_GDFN.cpython-36.pyc
│   │   │   ├── SST_MLK_MDTA.cpython-310.pyc
│   │   │   ├── SST_MLK_MDTA.cpython-36.pyc
│   │   │   ├── TransWeater.cpython-310.pyc
│   │   │   ├── TwobranchNet.cpython-310.pyc
│   │   │   ├── TwobranchNet.cpython-36.pyc
│   │   │   ├── TwobranchWoRB.cpython-36.pyc
│   │   │   ├── AIDTransformer.cpython-310.pyc
│   │   │   ├── PerceptualLoss.cpython-310.pyc
│   │   │   ├── TwobranchWoFFAB.cpython-36.pyc
│   │   │   ├── TwobranchWoLoss.cpython-36.pyc
│   │   │   ├── TwobranchWoSSAB.cpython-36.pyc
│   │   │   ├── Twobranch_ablation.cpython-36.pyc
│   │   │   └── Twobranchablation.cpython-36.pyc
│   │   ├── __init__.py
│   │   ├── PerceptualLoss.py
│   │   ├── FFANet.py
│   │   ├── AACNet.py
│   │   ├── HDMba.py
│   │   ├── Dehazeformer.py
│   │   └── AIDTransformer.py
│   ├── option.py
│   ├── data_utils.py
│   ├── test.py
│   ├── metrics.py
│   ├── main.py
│   └── util.py
├── performance.PNG
└── README.md
/HDMba/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/performance.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/performance.PNG
--------------------------------------------------------------------------------
/HDMba/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/HDMba/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/FFA.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/FFA.cpython-36.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/FFA.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/FFA.cpython-38.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/FFA.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/FFA.cpython-39.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/FFA2.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/FFA2.cpython-36.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/FFA2.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/FFA2.cpython-38.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/FFA2.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/FFA2.cpython-39.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/SST.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/SST.cpython-310.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/SST.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/SST.cpython-36.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/AACNet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/AACNet.cpython-36.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/CANet.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/CANet.cpython-310.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/CANet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/CANet.cpython-36.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/FFANet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/FFANet.cpython-36.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/HDMba.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/HDMba.cpython-310.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/LKDNet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/LKDNet.cpython-36.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/SCConv.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/SCConv.cpython-36.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/SGnet.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/SGnet.cpython-310.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/SGnet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/SGnet.cpython-36.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/SGnet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/SGnet.cpython-38.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/SGnet.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/SGnet.cpython-39.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/AACNet.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/AACNet.cpython-310.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/CSUTrans.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/CSUTrans.cpython-310.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/CSUTrans.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/CSUTrans.cpython-36.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/FFANet.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/FFANet.cpython-310.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/HFFormer.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/HFFormer.cpython-310.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/HFFormer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/HFFormer.cpython-36.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/LKDNet.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/LKDNet.cpython-310.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/MSST_MLK.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/MSST_MLK.cpython-310.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/MSST_MLK.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/MSST_MLK.cpython-36.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/PSMBNet.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/PSMBNet.cpython-310.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/PSMBNet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/PSMBNet.cpython-36.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/RSDformer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/RSDformer.cpython-36.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/RSHazeNet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/RSHazeNet.cpython-36.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/RSdehaze.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/RSdehaze.cpython-310.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/RSdehaze.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/RSdehaze.cpython-36.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/RSdehaze.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/RSdehaze.cpython-38.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/RSdehaze.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/RSdehaze.cpython-39.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/Restormer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/Restormer.cpython-36.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/SCConv.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/SCConv.cpython-310.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/SSMamba.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/SSMamba.cpython-310.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/SSMamba.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/SSMamba.cpython-36.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/SST_MLK.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/SST_MLK.cpython-310.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/SST_MLK.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/SST_MLK.cpython-36.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/SST_MSF.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/SST_MSF.cpython-310.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/SST_MSF.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/SST_MSF.cpython-36.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/UFormer.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/UFormer.cpython-310.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/UFormer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/UFormer.cpython-36.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/RSDformer.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/RSDformer.cpython-310.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/RSHazeNet.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/RSHazeNet.cpython-310.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/Restormer.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/Restormer.cpython-310.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/SSMamba_ab.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/SSMamba_ab.cpython-310.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/TransWeater.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/TransWeater.cpython-36.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/AIDTransformer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/AIDTransformer.cpython-36.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/Dehazeformer.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/Dehazeformer.cpython-310.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/Dehazeformer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/Dehazeformer.cpython-36.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/GF5dehazeNet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/GF5dehazeNet.cpython-36.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/PerceptualLoss.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/PerceptualLoss.cpython-36.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/PerceptualLoss.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/PerceptualLoss.cpython-38.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/PerceptualLoss.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/PerceptualLoss.cpython-39.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/SST_MLK_GDFN.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/SST_MLK_GDFN.cpython-310.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/SST_MLK_GDFN.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/SST_MLK_GDFN.cpython-36.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/SST_MLK_MDTA.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/SST_MLK_MDTA.cpython-310.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/SST_MLK_MDTA.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/SST_MLK_MDTA.cpython-36.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/TransWeater.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/TransWeater.cpython-310.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/TwobranchNet.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/TwobranchNet.cpython-310.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/TwobranchNet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/TwobranchNet.cpython-36.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/TwobranchWoRB.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/TwobranchWoRB.cpython-36.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/AIDTransformer.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/AIDTransformer.cpython-310.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/PerceptualLoss.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/PerceptualLoss.cpython-310.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/TwobranchWoFFAB.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/TwobranchWoFFAB.cpython-36.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/TwobranchWoLoss.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/TwobranchWoLoss.cpython-36.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/TwobranchWoSSAB.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/TwobranchWoSSAB.cpython-36.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/Twobranch_ablation.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/Twobranch_ablation.cpython-36.pyc
--------------------------------------------------------------------------------
/HDMba/models/__pycache__/Twobranchablation.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RsAI-lab/HDMba/HEAD/HDMba/models/__pycache__/Twobranchablation.cpython-36.pyc
--------------------------------------------------------------------------------
/HDMba/models/__init__.py:
--------------------------------------------------------------------------------
1 | import sys,os
2 | dir = os.path.abspath(os.path.dirname(__file__))
3 | sys.path.append(dir)
4 | # from FFA import FFA
5 | from PerceptualLoss import LossNetwork as PerLoss
6 | # import net
7 |
--------------------------------------------------------------------------------
/HDMba/models/PerceptualLoss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | # --- Perceptual loss network --- #
6 | class LossNetwork(torch.nn.Module):
7 | def __init__(self, vgg_model):
8 | super(LossNetwork, self).__init__()
9 | self.vgg_layers = vgg_model
10 | self.layer_name_mapping = {
11 | '3': "relu1_2",
12 | '8': "relu2_2",
13 | '15': "relu3_3"
14 | }
15 |
16 | def output_features(self, x):
17 | output = {}
18 | for name, module in self.vgg_layers._modules.items():
19 | x = module(x)
20 | if name in self.layer_name_mapping:
21 | output[self.layer_name_mapping[name]] = x
22 | return list(output.values())
23 |
24 | def forward(self, dehaze, gt):
25 | loss = []
26 | dehaze_features = self.output_features(dehaze)
27 | gt_features = self.output_features(gt)
28 | for dehaze_feature, gt_feature in zip(dehaze_features, gt_features):
29 | loss.append(F.mse_loss(dehaze_feature, gt_feature))
30 |
31 | return sum(loss)/len(loss)
--------------------------------------------------------------------------------
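For context, here is a minimal usage sketch (not part of the repository) showing how `LossNetwork` is typically paired with a truncated torchvision VGG16 feature extractor; the keys `'3'`, `'8'`, `'15'` in `layer_name_mapping` match relu1_2/relu2_2/relu3_3 of `vgg16().features`. The slice length, pretrained flag, dummy 3-channel tensors, and running it from the `HDMba` directory are assumptions, not project settings:

```python
# Hypothetical usage of models/PerceptualLoss.py (illustrative only).
import torch
from torchvision import models
from models.PerceptualLoss import LossNetwork

vgg = models.vgg16(pretrained=True).features[:16]  # layers '0'..'15', up to relu3_3
for p in vgg.parameters():
    p.requires_grad = False                        # frozen feature extractor

per_loss = LossNetwork(vgg).eval()
dehaze = torch.rand(1, 3, 64, 64)                  # VGG16 expects 3-channel input
gt = torch.rand(1, 3, 64, 64)
print(per_loss(dehaze, gt))                        # mean MSE over the three selected feature maps
```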
/HDMba/option.py:
--------------------------------------------------------------------------------
1 | import torch,os,sys,torchvision,argparse
2 | import torchvision.transforms as tfs
3 | import time,math
4 | import numpy as np
5 | from torch.backends import cudnn
6 | from torch import optim
7 | import torch,warnings
8 | from torch import nn
9 | import torchvision.utils as vutils
10 | warnings.filterwarnings('ignore')
11 |
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument('--steps', type=int, default=1000)
14 | parser.add_argument('--device', type=str, default='Automatic detection')
15 | parser.add_argument('--resume', type=bool, default=True)
16 | parser.add_argument('--eval_step', type=int, default=500)
17 | parser.add_argument('--lr', default=0.0001, type=float, help='learning rate')
18 | parser.add_argument('--model_dir', type=str, default='./trained_models/')
19 | parser.add_argument('--trainset', type=str, default='train')
20 | parser.add_argument('--testset', type=str, default='test')
21 |
22 | # model
23 | # parser.add_argument('--net', type=str, default='AACNet')
24 | # parser.add_argument('--net', type=str, default='AIDTransformer')
25 | # parser.add_argument('--net', type=str, default='DehazeFormer')
26 | parser.add_argument('--net', type=str, default='HDMba')
27 |
28 | parser.add_argument('--bs', type=int, default=2, help='batch size')
29 | parser.add_argument('--crop_size', type=int, default=64, help='Takes effect when using --crop')
30 | parser.add_argument('--no_lr_sche', action='store_true', help='no lr cos schedule')
31 | parser.add_argument('--perloss', action='store_true', help='use perceptual/spectral loss')
32 |
33 | opt = parser.parse_args()
34 | opt.device = 'cuda' if torch.cuda.is_available() else 'cpu'
35 | model_name = opt.trainset+'_'+opt.net.split('.')[0]
36 | opt.model_dir = opt.model_dir+model_name+'.pk'
37 | log_dir = 'logs/'+model_name
38 |
39 | print(opt)
40 | # print('model_dir:', opt.model_dir)
41 |
42 | if not os.path.exists('trained_models'):
43 | os.mkdir('trained_models')
44 | if not os.path.exists('numpy_files'):
45 | os.mkdir('numpy_files')
46 | if not os.path.exists('logs'):
47 | os.mkdir('logs')
48 | if not os.path.exists('samples'):
49 | os.mkdir('samples')
50 | if not os.path.exists(f"samples/{model_name}"):
51 | os.mkdir(f'samples/{model_name}')
52 | if not os.path.exists(log_dir):
53 | os.mkdir(log_dir)
54 |
--------------------------------------------------------------------------------
/HDMba/data_utils.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | import torchvision.transforms as tfs
3 | from torchvision.transforms import functional as FF
4 | import os,sys
5 | sys.path.append('.')
6 | sys.path.append('..')
7 | from torch.utils.data import DataLoader
8 | from metrics import *
9 | from option import opt
10 | from random import randrange
11 | from osgeo import gdal
12 | gdal.PushErrorHandler("CPLQuietErrorHandler")
13 |
14 | BS = opt.bs
15 | crop_size = opt.crop_size
16 |
17 |
18 | class RESIDE_Dataset(data.Dataset):
19 | # def __init__(self,path,train,size=crop_size,format='.png'):
20 | def __init__(self, path, train, size=crop_size, format='.tif'):
21 | super(RESIDE_Dataset,self).__init__()
22 | self.size = size
23 | self.train = train
24 | self.format = format
25 |         # GF-5 data
26 | self.haze_imgs_dir = os.listdir(os.path.join(path, 'hazy'))
27 | self.haze_imgs = [os.path.join(path, 'hazy', img) for img in self.haze_imgs_dir]
28 | self.clear_dir = os.path.join(path, 'clear')
29 |
30 | def __getitem__(self, index):
31 |         # GF-5 data
32 |         haze = gdal.Open(self.haze_imgs[index], 0)  # returns a GDAL dataset handle
33 |         haze_array = haze.ReadAsArray().astype(np.float32)  # read as an array, e.g. (305, 512, 512)
34 | band, width, height = haze_array.shape
35 | img_path = self.haze_imgs[index]
36 | id = os.path.basename(img_path).split('_')[0]
37 | clear_name = id+self.format
38 | clear = gdal.Open(os.path.join(self.clear_dir, clear_name))
39 | clear_array = clear.ReadAsArray().astype(np.float32)
40 | haze_array = haze_array - (np.min(haze_array)-np.min(clear_array))
41 | haze_array = haze_array / np.max(haze_array)
42 | clear_array = clear_array / np.max(clear_array)
43 |
44 |         # crop the image into patches
45 | if isinstance(self.size, int):
46 | x, y = randrange(0, width - self.size + 1), randrange(0, height - self.size + 1)
47 | haze_array = haze_array[:, x:x+self.size, y:y+self.size]
48 | clear_array = clear_array[:, x:x + self.size, y:y + self.size]
49 |
50 | return haze_array, clear_array
51 |
52 | def __len__(self):
53 | return len(self.haze_imgs)
54 | # return len(self.haze_paths)
55 |
56 |
57 | import os
58 | pwd = os.getcwd()
59 |
60 | path = r'F:\GF-5 dehaze' # path to your 'data' folder
61 | train_loader = DataLoader(dataset=RESIDE_Dataset(path+r'\train', train=True, size=crop_size), batch_size=BS, shuffle=True)
62 | test_loader = DataLoader(dataset=RESIDE_Dataset(path+r'\test', train=False, size='whole img'), batch_size=1, shuffle=False)
63 |
64 |
65 |
66 |
67 | if __name__ == "__main__":
68 |     # quick sanity check: fetch one batch and print its shape
69 |     x, y = next(iter(train_loader))
70 |     print(x.shape, y.shape)
--------------------------------------------------------------------------------
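Note on the data layout expected by `RESIDE_Dataset` above (inferred from the code, not from separate documentation): each split folder (`train/`, `test/`) must contain `hazy/` and `clear/` subfolders of GeoTIFFs, and a hazy image named `<id>_*.tif` is paired with `clear/<id>.tif`. The dataset root is hard-coded as `F:\GF-5 dehaze` and should be edited for your environment.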
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # HDMba: Hyperspectral Remote Sensing Imagery Dehazing with State Space Model
4 |
5 | Hang Fu, [Genyun Sun](https://ocean.upc.edu.cn/2019/1107/c15434a224792/page.htm), Yinhe Li, [Jinchang Ren](https://scholar.google.com.hk/citations?user=Vsx9P-gAAAAJ&hl=zh-CN), [Aizhu Zhang](https://ocean.upc.edu.cn/2019/1108/c15434a224913/page.htm), Cheng Jing, [Pedram Ghamisi](https://www.ai4rs.com/)
6 |
7 | ArXiv Preprint ([arXiv:2406.05700](https://arxiv.org/abs/2406.05700))
8 |
9 |
10 |
11 |
12 |
13 | ## Abstract
14 | Haze contamination in hyperspectral remote sensing images (HSI) can lead to spatial visibility degradation and spectral distortion. Haze in HSI exhibits spatial irregularity and inhomogeneous spectral distribution, and few dehazing networks are available for it. Current CNN- and Transformer-based dehazing methods fail to balance global scene recovery, local detail retention, and computational efficiency. Inspired by the ability of Mamba to model long-range dependencies with linear complexity, we explore its potential for HSI dehazing and propose the first HSI Dehazing Mamba (HDMba) network. Specifically, we design a novel window selective scan module (WSSM) that captures local dependencies within windows and global correlations between windows by partitioning them. This approach improves the ability of conventional Mamba in local feature extraction. By modeling the local and global spectral-spatial information flow, we achieve a comprehensive analysis of hazy regions. The DehazeMamba layer (DML), constructed by WSSM, and the residual DehazeMamba (RDM) blocks, composed of DMLs, are the core components of the HDMba framework. These components effectively characterize the complex distribution of haze in HSIs, aiding in scene reconstruction and dehazing. Experimental results on the Gaofen-5 HSI dataset demonstrate that HDMba outperforms other state-of-the-art methods in dehazing performance.
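For readers new to window-based scanning, the snippet below is an illustrative sketch of the window-partitioning idea only; it is not the released WSSM implementation (see `HDMba/models/HDMba.py`), and the helper name, window size, and tensor shapes are made up for the example:

```python
# Illustrative sketch: split a feature map into non-overlapping windows,
# each of which becomes a short sequence for a per-window scan.
import torch

def window_partition(x, win):
    # x: (B, C, H, W) -> (B * n_windows, C, win*win)
    B, C, H, W = x.shape
    x = x.unfold(2, win, win).unfold(3, win, win)   # (B, C, nH, nW, win, win)
    x = x.permute(0, 2, 3, 1, 4, 5).contiguous()    # (B, nH, nW, C, win, win)
    return x.view(-1, C, win * win)

feat = torch.rand(1, 32, 64, 64)
seqs = window_partition(feat, win=8)
print(seqs.shape)  # torch.Size([64, 32, 64]): 64 windows, each a length-64 sequence per channel
```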
15 |
16 |
17 |
18 | ![performance comparison](performance.PNG)
19 |
20 |
21 |
22 | ## Datasets
23 |
24 | HyperDehazing: https://github.com/RsAI-lab/HyperDehazing
25 |
26 | HDD: Available from ([Paper](https://ieeexplore.ieee.org/document/9511329))
27 |
28 |
29 | ## Other dehazing methods
30 |
31 | SG-Net: ([Code](https://github.com/SZU-AdvTech-2022/158-A-Spectral-Grouping-based-Deep-Learning-Model-for-Haze-Removal-of-Hyperspectral-Images))
32 |
33 | AACNet: ([Code](http://www.jiasen.tech/papers/))
34 |
35 | DehazeFormer: ([Code](https://github.com/IDKiro/DehazeFormer))
36 |
37 | AIDFormer: ([Code](https://github.com/AshutoshKulkarni4998/AIDTransformer))
38 |
39 | RSDformer: ([Code](https://github.com/MingTian99/RSDformer))
40 |
41 |
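## Usage

Training and inference entry points are `main.py` and `test.py`; hyperparameters are defined in `option.py`, and the dataset root is hard-coded in `data_utils.py` (`F:\GF-5 dehaze`), so edit that path first. A typical run, inferred from the default flags rather than an official recipe, is `python main.py --net HDMba --steps 1000 --bs 2 --crop_size 64 --lr 0.0001` for training, followed by `python test.py --test_imgs <path to hazy test images>` for inference.
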
42 | ## Acknowledgement
43 | This project is based on FFA-Net ([code](https://github.com/zhilin007/FFA-Net)). Thanks for their wonderful work.
44 |
45 |
--------------------------------------------------------------------------------
/HDMba/test.py:
--------------------------------------------------------------------------------
1 | import argparse, os  # os is used below (getcwd, listdir); previously it was only available via 'from models import *'
2 | import numpy as np
3 | from models import *
4 | import cv2
5 | from PIL import Image
6 | import torch
7 | import torch.nn as nn
8 | import matplotlib.pyplot as plt
9 | from torchvision.utils import make_grid
10 | from osgeo import gdal
11 | from util import fuse_images, tensor2im
12 | gdal.PushErrorHandler("CPLQuietErrorHandler")
13 |
14 |
15 | from models.AACNet import AACNet
16 | from models.AIDTransformer import AIDTransformer
17 | from models.Dehazeformer import DehazeFormer
18 | from models.HDMba import HDMba
19 |
20 |
21 | import time
22 | abs = os.getcwd()+'/'
23 |
24 |
25 | def TwoPercentLinear(image, max_out=255, min_out=0):  # 2% linear stretch
26 |     b, g, r = cv2.split(image)  # split the three bands
27 |
28 |     def gray_process(gray, maxout = max_out, minout = min_out):
29 |         high_value = np.percentile(gray, 98)  # gray level at the 98th percentile
30 |         low_value = np.percentile(gray, 2)  # gray level at the 2nd percentile
31 |         truncated_gray = np.clip(gray, a_min=low_value, a_max=high_value)
32 |         processed_gray = ((truncated_gray - low_value)/(high_value - low_value)) * (maxout - minout)  # linear stretch
33 |         return processed_gray
34 |     r_p = gray_process(r)
35 |     g_p = gray_process(g)
36 |     b_p = gray_process(b)
37 |     result = cv2.merge((b_p, g_p, r_p))  # merge the processed bands
38 | return np.uint8(result)
39 |
40 |
41 | def get_write_picture_fina(img):  # build the visualization image written out during testing
42 |     img = tensor2im(img, np.float32)  # np.float was removed in NumPy >= 1.24; use an explicit dtype
43 | img = img.astype(np.uint8)
44 | output = TwoPercentLinear(img[:, :, (58, 38, 20)])
45 | # output = TwoPercentLinear(img[:, :, (2, 1, 0)])
46 | return output
47 |
48 |
49 | def tensorShow(name, tensors, titles):
50 | fig = plt.figure(figsize=(8, 8))
51 | for tensor, tit, i in zip(tensors, titles, range(len(tensors))):
52 | img = make_grid(tensor)
53 | npimg = img.numpy()
54 | npimg = np.transpose(npimg, (1, 2, 0))
55 | npimg = npimg / np.max(npimg)
56 | npimg = np.clip(npimg, 0, 1)
57 | ax = fig.add_subplot(121+i)
58 | ax.imshow(npimg)
59 | plt.imsave(f"../pred_GF5/{name}_{tit}.png", npimg)
60 | ax.set_title(tit)
61 | plt.tight_layout()
62 | plt.show()
63 |
64 |
65 | parser = argparse.ArgumentParser()
66 | parser.add_argument('--test_imgs', type=str, default='F:/L8MSI/Paired/Test/pre', help='Test imgs folder')
67 | opt = parser.parse_args()
68 | img_dir = opt.test_imgs+'/'
69 |
70 | # trained network checkpoint
71 | model_dir = abs+f'trained_models/train_HDMba.pk'
72 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
73 | ckp = torch.load(model_dir, map_location=device)
74 | net = HDMba()
75 |
76 | net = nn.DataParallel(net)
77 | net.load_state_dict(ckp['model'])
78 | net = net.module.to(torch.device('cpu'))
79 | net.eval()
80 |
81 |
82 | for im in os.listdir(img_dir):
83 | # print(im)
84 | start_time = time.time()
85 | print(f'\r {im}', end='', flush=True)
86 | haze = gdal.Open(img_dir+im).ReadAsArray().astype(np.float32)
87 | haze = haze / np.max(haze)
88 | # haze = haze[]
89 | haze = np.expand_dims(haze, 0)
90 | haze = torch.from_numpy(haze).type(torch.FloatTensor)
91 | with torch.no_grad():
92 | pred = net(haze)
93 | ts = torch.squeeze(pred.cpu())
94 | print(f'|time_used :{(time.time() - start_time):.4f}s ')
95 |
96 | write_image2 = get_write_picture_fina(pred)
97 | write_image_name = "C:/Users/Administrator/Desktop/canet/" + str(im) + "_new.png"
98 | Image.fromarray(np.uint8(write_image2)).save(write_image_name)
99 |
--------------------------------------------------------------------------------
/HDMba/models/FFANet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | # from torchsummary import summary
4 | # import torchsummary
5 |
6 |
7 | class PALayer(nn.Module):
8 | def __init__(self, channel):
9 | super(PALayer, self).__init__()
10 | self.pa = nn.Sequential(
11 | nn.Conv2d(channel, channel // 8, 1, padding=0, bias=True),
12 | nn.ReLU(inplace=True),
13 | nn.Conv2d(channel // 8, 1, 1, padding=0, bias=True),
14 | nn.Sigmoid()
15 | )
16 | def forward(self, x):
17 | y = self.pa(x)
18 | return x * y
19 |
20 |
21 | class CALayer(nn.Module):
22 | def __init__(self, channel):
23 | super(CALayer, self).__init__()
24 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
25 | self.ca = nn.Sequential(
26 | nn.Conv2d(channel, channel // 8, 1, padding=0, bias=True),
27 | nn.ReLU(inplace=True),
28 | nn.Conv2d(channel // 8, channel, 1, padding=0, bias=True),
29 | nn.Sigmoid()
30 | )
31 |
32 | def forward(self, x):
33 | y = self.avg_pool(x)
34 | y = self.ca(y)
35 | return x * y
36 |
37 |
38 | class Block(nn.Module):
39 | def __init__(self, dim):
40 | super(Block, self).__init__()
41 | self.conv1 = nn.Conv2d(dim, dim, 3, padding=1, bias=True)
42 | self.act1 = nn.ReLU(inplace=True)
43 | self.conv2 = nn.Conv2d(dim, dim, 3, padding=1, bias=True)
44 | self.calayer = CALayer(dim)
45 | self.palayer = PALayer(dim)
46 |
47 | def forward(self, x):
48 |
49 | res = self.act1(self.conv1(x))
50 | res = res+x
51 | res = self.conv2(res)
52 | res = self.calayer(res)
53 | res = self.palayer(res)
54 | res += x
55 | return res
56 |
57 |
58 | class GroupRLKAs(nn.Module):
59 | def __init__(self, dim):
60 | super(GroupRLKAs, self).__init__()
61 | self.g1 = Block(dim)
62 | self.g2 = Block(dim)
63 | self.g3 = Block(dim)
64 |
65 |     def forward(self, x):
66 |         y1 = self.g1(x)
67 |         y2 = self.g2(y1)  # fixed: g2/g3 were defined but g1 was reused three times
68 |         y3 = self.g3(y2)
69 |         return y3, torch.cat([y1, y2, y3], dim=1)
70 |
71 |
72 | class FFB(nn.Module):
73 | def __init__(self, dim):
74 | super(FFB, self).__init__()
75 | self.conv0 = nn.Conv2d(dim, 32, 1, bias=False)
76 | # self.activation0 = nn.ReLU()
77 | self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=False)
78 |
79 | def forward(self, x):
80 | x = self.conv0(x)
81 | # x = self.activation0(x)
82 | x = self.conv1(x)
83 | return x
84 |
85 |
86 | class FFANet(nn.Module):
87 | def __init__(self):
88 | super(FFANet, self).__init__()
89 |         # initial feature-extraction layers
90 | # self.conv0 = nn.Conv2d(305, 32, 3, padding=1, bias=False)
91 | self.conv0 = nn.Conv2d(7, 32, 3, padding=1, bias=False)
92 | #self.conv0 = nn.Conv2d(4, 32, 3, padding=1, bias=False)
93 | self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=False)
94 |
95 |         # intermediate groups
96 | self.g1 = GroupRLKAs(32)
97 | self.g2 = GroupRLKAs(32)
98 | self.g3 = GroupRLKAs(32)
99 |
100 |         # fusion and reconstruction blocks
101 | self.fusion = FFB(96*3)
102 | self.att1 = CALayer(32)
103 | self.att2 = PALayer(32)
104 | self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=False)
105 | # self.conv3 = nn.Conv2d(32, 305, 3, padding=1, bias=False)
106 | self.conv3 = nn.Conv2d(32, 7, 3, padding=1, bias=False)
107 | #self.conv3 = nn.Conv2d(32, 4, 3, padding=1, bias=False)
108 |
109 | def forward(self, x):
110 | out1 = self.conv0(x)
111 | out2 = self.conv1(out1)
112 |
113 | R1, F1 = self.g1(out2)
114 | R2, F2 = self.g2(R1)
115 |         R3, F3 = self.g3(R2)  # fixed: was self.g2
116 |
117 | Fea = torch.cat([F1, F2, F3], dim=1)
118 | Fea = self.fusion(Fea)
119 | Fea = self.att1(Fea)
120 | Fea = self.att2(Fea)
121 | Fea = self.conv2(Fea)
122 | Fea = Fea + out1
123 | Fea = self.conv3(Fea)
124 | # Fea = Fea + x
125 | return Fea
126 |
127 |
128 | if __name__ == "__main__":
129 |     net = FFANet()
130 |     device = torch.device('cpu')
131 |     net.to(device)
132 |
133 |     # smoke test with the 7-band configuration set in __init__;
134 |     # for the 305-band GF-5 setting, switch conv0/conv3 above and use torch.rand(1, 305, 512, 512)
135 |     input_data = torch.rand(1, 7, 64, 64)
136 |     out = net(input_data)
137 |     print(out.shape)
--------------------------------------------------------------------------------
/HDMba/metrics.py:
--------------------------------------------------------------------------------
1 | from math import exp
2 | import math
3 | import numpy as np
4 | # import imgvision as iv  # imgvision: a Python library for spectral image quality analysis
5 |
6 | import torch
7 | import torch.nn.functional as F
8 | from torch.autograd import Variable
9 |
10 | from torchvision.transforms import ToPILImage
11 | from numpy.linalg import norm
19 |
20 | def gaussian(window_size, sigma):
21 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
22 | return gauss / gauss.sum()
23 |
24 | def create_window(window_size, channel):
25 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
26 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
27 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
28 | return window
29 |
30 | def _ssim(img1, img2, window, window_size, channel, size_average=True):
31 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel).cpu().numpy()
32 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel).cpu().numpy()
33 | # print(type(mu2))
34 |     mu1_sq = np.power(mu1, 2)  # mu1 squared
35 | mu2_sq = np.power(mu2, 2)
36 | mu1_mu2 = mu1 * mu2
37 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel).cpu().numpy() - mu1_sq
38 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel).cpu().numpy() - mu2_sq
39 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel).cpu().numpy() - mu1_mu2
40 | C1 = 0.01 ** 2
41 | C2 = 0.03 ** 2
42 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
43 |
44 | if size_average:
45 | return np.mean(ssim_map)
46 | else:
47 | return ssim_map.mean(1).mean(1).mean(1)
48 |
49 |
50 | def ssim(img1, img2, window_size=11, size_average=True):
51 |     img1 = torch.clamp(img1, min=0, max=1)  # clamp each element to the range [0, 1]
52 | img2 = torch.clamp(img2, min=0, max=1)
53 | (_, channel, _, _) = img1.size()
54 | window = create_window(window_size, channel)
55 | if img1.is_cuda:
56 | window = window.cuda(img1.get_device())
57 | window = window.type_as(img1)
58 | return _ssim(img1, img2, window, window_size, channel, size_average)
59 |
60 | def psnr(pred, gt):
61 | pred = pred.clamp(0, 1).cpu().numpy()
62 | gt = gt.clamp(0, 1).cpu().numpy()
63 | imdff = pred - gt
64 | rmse = math.sqrt(np.mean(imdff ** 2))
65 | if rmse == 0:
66 | return 100
67 | return 20 * math.log10(1.0 / rmse)
68 |
69 |
70 | def UQI(O, F):
71 | meanO = torch.mean(O)
72 | meanF = torch.mean(F)
73 | (_, _, m, n) = np.shape(F)
74 | varO = torch.sqrt(torch.sum((O - meanO) ** 2) / (m * n - 1))
75 | varF = torch.sqrt(torch.sum((F - meanF) ** 2) / (m * n - 1))
76 |
77 | covOF = torch.sum((O - meanO) * (F - meanF)) / (m * n - 1)
78 | UQI = 4 * meanO * meanF * covOF / ((meanO ** 2 + meanF ** 2) * (varO ** 2 + varF ** 2))
79 | return UQI.data.cpu().numpy()
80 |
81 |
82 | def SAM(pred, target):
83 | pred = torch.clamp(pred, min=0, max=1)
84 | target = torch.clamp(target, min=0, max=1)
85 | pred1 = pred[0, :, :, :].cpu()
86 | target1 = target[0, :, :, :].cpu()
87 | pred1 = np.transpose(pred1, (2, 1, 0))
88 | target1 = np.transpose(target1, (2, 1, 0))
89 | sam_rad = np.zeros((pred1.shape[0], pred1.shape[1]))
90 | for x in range(pred1.shape[0]):
91 | for y in range(pred1.shape[1]):
92 | tmp_pred = pred1[x, y].ravel()
93 | tmp_true = target1[x, y].ravel()
94 |             # SAM: cosine of the angle between the two spectral vectors (fixed formula)
95 |             cos_value = np.dot(tmp_pred, tmp_true) / (norm(tmp_pred) * norm(tmp_true))
96 |             if 1.0 < cos_value:
97 |                 cos_value = 1.0
98 |             sam_rad[x, y] = cos_value
99 | SAM1 = np.arccos(sam_rad)
100 | # SAM1 = sam1.mean() * 180 / np.pi
101 | return SAM1
102 |
103 |
104 | def calc_sam(img_tgt, img_fus):
105 | img_tgt = np.squeeze(img_tgt)
106 | img_fus = np.squeeze(img_fus)
107 | img_tgt = img_tgt.reshape(img_tgt.shape[0], -1)
108 | img_fus = img_fus.reshape(img_fus.shape[0], -1)
109 | img_tgt = img_tgt / torch.max(img_tgt)
110 | img_fus = img_fus / torch.max(img_fus)
111 | A = torch.sqrt(torch.sum(img_tgt**2, 0))
112 | B = torch.sqrt(torch.sum(img_fus**2, 0))
113 | AB = torch.sum(img_tgt*img_fus, 0)
114 | sam = AB/(A*B)
115 | sam = torch.arccos_(sam)
116 | sam = torch.mean(sam)
117 | return sam
118 |
119 |
120 | if __name__ == "__main__":
121 | pass
122 |
123 | # np.random.seed(10)
124 | # pred = np.random.rand(1, 20, 100, 100)
125 | # targets = np.random.rand(1, 20, 100, 100)
126 | # pred = torch.rand(1, 20, 100, 100)
127 | # targets = torch.rand(1, 20, 100, 100)
128 | # sam1 = calc_sam(pred, targets)
129 | # sam1 = SAM(pred, targets)
130 | # print(sam1)
131 | #
132 | # pred1 = np.transpose(pred, (2, 1, 0))
133 | # targets1 = np.transpose(targets, (2, 1, 0))
134 | # Metric = iv.spectra_metric(pred1, targets1)
135 | # SAM1 = Metric.SAM()
136 | # print(SAM1)
137 | #
138 | # UQI1 = UQI(pred, targets)
139 | # print(UQI1)
140 | #
141 | # pred2 = torch.Tensor(pred)
142 | # targets2 = torch.Tensor(targets)
143 | # ssim1 = ssim(pred2, targets2).item()
144 | # print(ssim1)
145 | # psnr1 = psnr(pred2, targets2)
146 | # print(psnr1)
147 |
148 |
149 |
150 |
--------------------------------------------------------------------------------
/HDMba/main.py:
--------------------------------------------------------------------------------
1 | # load Network
2 | from models.AACNet import AACNet
3 | from models.AIDTransformer import AIDTransformer
4 | from models.Dehazeformer import DehazeFormer
5 | from models.HDMba import HDMba
6 |
7 | import time, math
8 | import numpy as np
9 | from torch.backends import cudnn
10 | from torch import optim
11 | import torch, warnings
12 | from torch import nn
13 | # from tensorboardX import SummaryWriter
14 | import torchvision.utils as vutils
15 | warnings.filterwarnings('ignore')
16 | from option import opt, model_name, log_dir
17 | from data_utils import *
18 | from matplotlib import pyplot as plt
19 | # torch.autograd.set_detect_anomaly(True)
20 |
21 | print('log_dir :', log_dir)
22 | print('model_name:', model_name)
23 | models_ = {
24 | 'AACNet': AACNet(),
25 | 'AIDTransformer': AIDTransformer(),
26 | 'DehazeFormer': DehazeFormer(),
27 | 'HDMba': HDMba(),
28 | }
29 | loaders_ = \
30 | {'train': train_loader, 'test': test_loader}
31 | start_time = time.time()
32 | T = opt.steps
33 |
34 |
35 | def lr_schedule_cosdecay(t, T, init_lr=opt.lr):
36 | lr = 0.5 * (1 + math.cos(t * math.pi / T)) * init_lr
37 | return lr
38 |
39 |
40 | def train(net, loader_train, loader_test, optim, criterion):
41 | losses = []
42 | start_step = 0
43 | max_ssim = 0
44 | max_psnr = 0
45 | max_uqi = 0
46 | min_sam = 1
47 | ssims = []
48 | psnrs = []
49 | uqis = []
50 | sams = []
51 | if opt.resume and os.path.exists(opt.model_dir):
52 | print(f'resume from {opt.model_dir}')
53 | ckp = torch.load(opt.model_dir)
54 | losses = ckp['losses']
55 | net.load_state_dict(ckp['model'])
56 | start_step = ckp['step']
57 | max_ssim = ckp['max_ssim']
58 | max_psnr = ckp['max_psnr']
59 | min_sam = ckp['min_sam']
60 | max_uqi = ckp['max_uqi']
61 | psnrs = ckp['psnrs']
62 | ssims = ckp['ssims']
63 | uqis = ckp['uqis']
64 | sams = ckp['sams']
65 | print(f'start_step:{start_step} start training ---')
66 | else:
67 |         print('train from scratch *** ')  # no saved checkpoint found; train from scratch
68 |
69 | train_los = np.zeros(opt.steps)
70 | for step in range(start_step+1, opt.steps+1):
71 | net.train()
72 | lr = opt.lr
73 | if not opt.no_lr_sche:
74 | lr = lr_schedule_cosdecay(step, T)
75 | for param_group in optim.param_groups:
76 | param_group["lr"] = lr
77 | x, y = next(iter(loader_train))
78 | x = x.to(opt.device)
79 | y = y.to(opt.device)
80 | out = net(x)
81 | # loss function
82 | loss1 = criterion[0](out, y)
83 | loss2 = criterion[1](out, y)
84 | loss = loss1 + 0.01*loss2
85 | loss.backward()
86 |
87 | optim.step()
88 | optim.zero_grad()
89 | losses.append(loss.item())
90 | train_los[step-1] = loss.item()
91 | print(
92 | f'\rtrain loss : {loss.item():.5f} |step :{step}/{opt.steps} |lr:{lr :.7f} |time_used :{(time.time() - start_time) :.4f}s',
93 | end='', flush=True)
94 |
95 | if step % opt.eval_step == 0:
96 | with torch.no_grad():
97 | ssim_eval, psnr_eval, uqi_eval, sam_eval = test(net, loader_test)
98 | print(f'\nstep :{step} |ssim:{ssim_eval:.4f} |psnr:{psnr_eval:.4f} |uqi:{uqi_eval:.4f} |sam:{sam_eval:.4f}')
99 | ssims.append(ssim_eval)
100 | psnrs.append(psnr_eval)
101 | uqis.append(uqi_eval)
102 | sams.append(sam_eval)
103 | # if psnr_eval > max_psnr and ssim_eval > max_ssim and uqi_eval > max_uqi and min_sam > sam_eval:
104 | if psnr_eval > max_psnr and ssim_eval > max_ssim:
105 | max_ssim = max(max_ssim, ssim_eval)
106 | max_psnr = max(max_psnr, psnr_eval)
107 | max_uqi = max(max_uqi, uqi_eval)
108 | min_sam = min(min_sam, sam_eval)
109 | torch.save({
110 | 'step': step,
111 | 'max_psnr': max_psnr,
112 | 'max_ssim': max_ssim,
113 | 'max_uqi': max_uqi,
114 | 'min_sam': min_sam,
115 | 'ssims': ssims,
116 | 'psnrs': psnrs,
117 | 'uqis': uqis,
118 | 'sams': sams,
119 | 'losses': losses,
120 | 'model': net.state_dict()
121 | }, opt.model_dir)
122 | print(f'\n model saved at step :{step}| max_psnr:{max_psnr:.4f}|max_ssim:{max_ssim:.4f}|max_uqi:{max_uqi:.4f} |min_sam:{min_sam:.4f}')
123 |
124 | iters = range(len(train_los))
125 | plt.figure()
126 | plt.plot(iters, train_los, 'g', label='train loss')
127 | plt.show()
128 | np.save(f'./numpy_files/{model_name}_{opt.steps}_losses.npy', losses)
129 | np.save(f'./numpy_files/{model_name}_{opt.steps}_ssims.npy', ssims)
130 | np.save(f'./numpy_files/{model_name}_{opt.steps}_psnrs.npy', psnrs)
131 | np.save(f'./numpy_files/{model_name}_{opt.steps}_uqis.npy', uqis)
132 | np.save(f'./numpy_files/{model_name}_{opt.steps}_sams.npy', sams)
133 |
134 |
135 | def test(net, loader_test): # verification
136 | net.eval()
137 | torch.cuda.empty_cache()
138 | ssims = []
139 | psnrs = []
140 | uqis = []
141 | sams = []
142 | for i, (inputs, targets) in enumerate(loader_test):
143 | inputs = inputs.to(opt.device)
144 | targets = targets.to(opt.device)
145 | pred = net(inputs)
146 | ssim1 = ssim(pred, targets).item()
147 | psnr1 = psnr(pred, targets)
148 | uqi1 = UQI(pred, targets)
149 | sam1 = SAM(pred, targets)
150 | # sam1 = calc_sam(pred, targets)
151 | ssims.append(ssim1)
152 | psnrs.append(psnr1)
153 | uqis.append(uqi1)
154 | sams.append(sam1)
155 | return np.mean(ssims), np.mean(psnrs), np.mean(uqis), np.mean(sams)
156 |
157 |
158 | if __name__ == "__main__":
159 | loader_train = loaders_[opt.trainset]
160 | loader_test = loaders_[opt.testset]
161 | net = models_[opt.net]
162 | net = net.to(opt.device)
163 | if opt.device == 'cuda':
164 | net = torch.nn.DataParallel(net)
165 | cudnn.benchmark = True
166 | criterion = []
167 |     criterion.append(nn.L1Loss().to(opt.device))  # L1 loss goes into criterion[0]
168 | criterion.append(nn.MSELoss().to(opt.device))
169 |
170 | optimizer = optim.Adam(params=filter(lambda x: x.requires_grad, net.parameters()), lr=opt.lr, betas=(0.9, 0.999), eps=1e-08)
171 | optimizer.zero_grad()
172 | if torch.cuda.device_count() > 1:
173 |         model = torch.nn.DataParallel(net)  # assumes the model is already on CUDA; note this wrapper is never used afterwards (training runs on `net`)
174 | train(net, loader_train, loader_test, optimizer, criterion)
175 |
176 |
177 |
--------------------------------------------------------------------------------
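The training loop above keeps only the best checkpoint: whenever PSNR and SSIM both improve, it writes a dict with the tracked metrics and `net.state_dict()` to `opt.model_dir`. Below is a minimal, hedged sketch of how such a checkpoint could be reloaded for evaluation; the checkpoint path is an assumption, and the `module.`-prefix handling reflects the DataParallel wrapping used in main.py.

```python
# Hedged sketch: reload the checkpoint written by train() above and restore the tracked metrics.
# The path and the HDMba import are illustrative assumptions; train() stores the fields shown below.
import torch
from models.HDMba import HDMba

ckpt_path = './trained_models/HDMba.pk'          # assumption: whatever opt.model_dir pointed to during training
ckpt = torch.load(ckpt_path, map_location='cpu')

net = HDMba()
# state_dict keys carry a 'module.' prefix when saved from a DataParallel wrapper, as in main.py
state = {k.replace('module.', '', 1): v for k, v in ckpt['model'].items()}
net.load_state_dict(state)
net.eval()

print(f"best step {ckpt['step']}: psnr={ckpt['max_psnr']:.4f}, ssim={ckpt['max_ssim']:.4f}, "
      f"uqi={ckpt['max_uqi']:.4f}, sam={ckpt['min_sam']:.4f}")
```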
/HDMba/util.py:
--------------------------------------------------------------------------------
1 | """This module contains simple helper functions """
2 | from __future__ import print_function
3 | import torch
4 | import numpy as np
5 | from PIL import Image
6 | import scipy.ndimage as ndimage
7 | import os
8 | import cv2
9 | from random import randrange
10 | import torch.nn.functional as F
11 |
12 |
13 | def find_patches(img_c, patch_size):
14 | _, band, width, height = img_c.shape
15 | x1, y1 = randrange(0, width - patch_size + 1), randrange(0, height - patch_size + 1)
16 | img_patch = img_c[:, :, x1:x1 + patch_size, y1:y1 + patch_size]
17 | for id in range(5):
18 | x, y = randrange(0, width - patch_size + 1), randrange(0, height - patch_size + 1)
19 | patch = img_c[:, :, x:x + patch_size, y:y + patch_size]
20 | img_patch = torch.cat((img_patch, patch), dim=1)
21 | # img_patch.append(patch)
22 | return img_patch.cuda()
23 |
24 |
25 | def fuse_images(real_I, rec_J, refine_J):
26 | """
27 | real_I, rec_J, and refine_J: Images with shape hxwx3
28 | """
29 | # realness features
30 | mat_RGB2YMN = np.array([[0.299,0.587,0.114],
31 | [0.30,0.04,-0.35],
32 | [0.34,-0.6,0.17]])
33 |
34 | recH,recW,recChl = rec_J.shape
35 | rec_J_flat = rec_J.reshape([recH*recW,recChl])
36 | rec_J_flat_YMN = (mat_RGB2YMN.dot(rec_J_flat.T)).T
37 | rec_J_YMN = rec_J_flat_YMN.reshape(rec_J.shape)
38 |
39 | refine_J_flat = refine_J.reshape([recH*recW,recChl])
40 | refine_J_flat_YMN = (mat_RGB2YMN.dot(refine_J_flat.T)).T
41 | refine_J_YMN = refine_J_flat_YMN.reshape(refine_J.shape)
42 |
43 | real_I_flat = real_I.reshape([recH*recW,recChl])
44 | real_I_flat_YMN = (mat_RGB2YMN.dot(real_I_flat.T)).T
45 | real_I_YMN = real_I_flat_YMN.reshape(real_I.shape)
46 |
47 | # gradient features
48 | rec_Gx = cv2.Sobel(rec_J_YMN[:,:,0],cv2.CV_64F,1,0,ksize=3)
49 | rec_Gy = cv2.Sobel(rec_J_YMN[:,:,0],cv2.CV_64F,0,1,ksize=3)
50 | rec_GM = np.sqrt(rec_Gx**2 + rec_Gy**2)
51 |
52 | refine_Gx = cv2.Sobel(refine_J_YMN[:,:,0],cv2.CV_64F,1,0,ksize=3)
53 | refine_Gy = cv2.Sobel(refine_J_YMN[:,:,0],cv2.CV_64F,0,1,ksize=3)
54 | refine_GM = np.sqrt(refine_Gx**2 + refine_Gy**2)
55 |
56 | real_Gx = cv2.Sobel(real_I_YMN[:,:,0],cv2.CV_64F,1,0,ksize=3)
57 | real_Gy = cv2.Sobel(real_I_YMN[:,:,0],cv2.CV_64F,0,1,ksize=3)
58 | real_GM = np.sqrt(real_Gx**2 + real_Gy**2)
59 |
60 | # similarity
61 | rec_S_V = (2*real_GM*rec_GM+160)/(real_GM**2+rec_GM**2+160)
62 | rec_S_M = (2*rec_J_YMN[:,:,1]*real_I_YMN[:,:,1]+130)/(rec_J_YMN[:,:,1]**2+real_I_YMN[:,:,1]**2+130)
63 | rec_S_N = (2*rec_J_YMN[:,:,2]*real_I_YMN[:,:,2]+130)/(rec_J_YMN[:,:,2]**2+real_I_YMN[:,:,2]**2+130)
64 | rec_S_R = (rec_S_M*rec_S_N).reshape([recH,recW])
65 |
66 | refine_S_V = (2*real_GM*refine_GM+160)/(real_GM**2+refine_GM**2+160)
67 | refine_S_M = (2*refine_J_YMN[:,:,1]*real_I_YMN[:,:,1]+130)/(refine_J_YMN[:,:,1]**2+real_I_YMN[:,:,1]**2+130)
68 | refine_S_N = (2*refine_J_YMN[:,:,2]*real_I_YMN[:,:,2]+130)/(refine_J_YMN[:,:,2]**2+real_I_YMN[:,:,2]**2+130)
69 | refine_S_R = (refine_S_M*refine_S_N).reshape([recH,recW])
70 |
71 | rec_S = rec_S_R*np.power(rec_S_V, 0.4)
72 | refine_S = refine_S_R*np.power(refine_S_V, 0.4)
73 |
74 | fuseWeight = np.exp(rec_S)/(np.exp(rec_S)+np.exp(refine_S))
75 | return fuseWeight
76 |
77 |
78 | def get_tensor_dark_channel(img, neighborhood_size):
79 | shape = img.shape
80 | if len(shape) == 4:
81 |         img_min = torch.min(img, dim=1, keepdim=True)[0]  # per-pixel minimum over bands (torch.min returns (values, indices))
82 |         img_dark = -F.max_pool2d(-img_min, kernel_size=neighborhood_size, stride=1)  # neighborhood minimum (erosion) via negated max-pool
83 | else:
84 | raise NotImplementedError('get_tensor_dark_channel is only for 4-d tensor [N*C*H*W]')
85 |
86 | return img_dark
87 |
88 |
89 | def array2Tensor(in_array, gpu_id=-1):
90 | in_shape = in_array.shape
91 | if len(in_shape) == 2:
92 | in_array = in_array[:,:,np.newaxis]
93 |
94 | arr_tmp = in_array.transpose([2,0,1])
95 | arr_tmp = arr_tmp[np.newaxis,:]
96 |
97 | if gpu_id >= 0:
98 |         return torch.tensor(arr_tmp.astype(np.float64)).to(gpu_id)  # np.float was removed in NumPy 1.24; float64 keeps the old behaviour
99 |     else:
100 |         return torch.tensor(arr_tmp.astype(np.float64))
101 |
102 |
103 | def tensor2im(input_image, imtype=np.uint8):
104 |     """Converts a Tensor array into a numpy image array.
105 |
106 | Parameters:
107 | input_image (tensor) -- the input image tensor array
108 | imtype (type) -- the desired type of the converted numpy array
109 | """
110 | if not isinstance(input_image, np.ndarray):
111 | if isinstance(input_image, torch.Tensor): # get the data from a variable
112 | image_tensor = input_image.data
113 | else:
114 | return input_image
115 | image_numpy = image_tensor[0].cpu().float().numpy() # convert it into a numpy array
116 | if image_numpy.shape[0] == 1: # grayscale to RGB
117 | image_numpy = np.tile(image_numpy, (3, 1, 1))
118 |         # image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0  # post-processing: transpose and scaling
119 |         # compute the tensor's minimum and maximum
120 |         min_val, max_val = image_numpy.min(), image_numpy.max()
121 |         # min-max normalize to [0, 1]
122 |         image_numpy = (image_numpy - min_val) / (max_val - min_val)
123 |         image_numpy = (np.transpose(image_numpy, (1, 2, 0))) * 255.0  # post-processing: transpose and scaling
124 | else: # if it is a numpy array, do nothing
125 | image_numpy = input_image
126 | return image_numpy.astype(imtype)
127 |
128 |
129 | def rescale_tensor(input_tensor):
130 |     """Converts a tensor with values in [-1, 1] into a float tensor in [0, 1] whose data match the quantized (uint8) image,
131 |     keeping the [height, width] layout, not [width, height].
132 | 
133 |     Parameters:
134 |         input_tensor (tensor) -- the input image tensor with values in [-1, 1]
135 |     Returns a float32 tensor rescaled to [0, 1].
136 | """
137 |
138 | if isinstance(input_tensor, torch.Tensor):
139 | input_tmp = input_tensor.cpu().float()
140 | output_tmp = (input_tmp + 1) / 2.0 * 255.0
141 | output_tmp = output_tmp.to(torch.uint8)
142 | else:
143 | return input_tensor
144 |
145 | return output_tmp.to(torch.float32) / 255.0
146 |
147 |
148 |
149 |
150 |
--------------------------------------------------------------------------------
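`array2Tensor` and `tensor2im` above are layout helpers: the first turns an HxWxC array into a 1xCxHxW tensor, the second min-max normalizes a tensor back into an HxWxC uint8 image. A small round-trip sketch follows; the image size is arbitrary and `util` is assumed to be importable from the HDMba directory.

```python
# Hedged usage sketch for the helpers above; the 3-band 32x32 image is arbitrary.
import numpy as np
from util import array2Tensor, tensor2im  # assumes the HDMba directory is on the path

img = np.random.rand(32, 32, 3)   # H x W x C float image in [0, 1]
t = array2Tensor(img)             # -> tensor of shape (1, 3, 32, 32)
out = tensor2im(t)                # -> uint8 array of shape (32, 32, 3), min-max scaled to [0, 255]
print(t.shape, out.shape, out.dtype)
```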
/HDMba/models/AACNet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.init as init
3 | import torch
4 | import torch.nn.functional as F
5 | from torchsummary import summary
6 |
7 |
8 | class FAConv(nn.Module):
9 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, deploy=False, reduce_gamma=False, gamma_init=None ):
10 | super(FAConv, self).__init__()
11 | self.deploy = deploy
12 | if deploy:
13 | self.fused_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(kernel_size,kernel_size), stride=stride,
14 |                                       padding=padding, bias=False, padding_mode='zeros')  # fix: padding_mode was an undefined name here
15 | # self.fused_point_conv = nn.Conv2d(in_channels=in_channels,out_channels=out_channels,kernel_size=1,stride=1,padding=0)
16 | else:
17 | self.square_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(kernel_size, kernel_size),stride=stride,padding=padding, dilation=dilation,bias=True)
18 | self.square_point_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1,stride=stride,bias=True)
19 | if padding - kernel_size // 2 >= 0:
20 | # Common use case. E.g., k=3, p=1 or k=5, p=2
21 | self.crop = 0
22 |             # Compared to the KxK layer, the padding of the 1xK layer and Kx1 layer should be adjusted to align the sliding windows (Fig 2 in the paper)
23 | hor_padding = [padding - kernel_size // 2, padding]
24 | ver_padding = [padding, padding - kernel_size // 2]
25 | else:
26 | # A negative "padding" (padding - kernel_size//2 < 0, which is not a common use case) is cropping.
27 | # Since nn.Conv2d does not support negative padding, we implement it manually
28 | self.crop = kernel_size // 2 - padding
29 | hor_padding = [0, padding]
30 | ver_padding = [padding, 0]
31 |
32 | self.ver_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(kernel_size, 1),
33 | stride=stride,padding=ver_padding, dilation=dilation, bias=True)
34 | self.hor_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(1, kernel_size),
35 | stride=stride,padding=hor_padding, dilation=dilation, bias=True)
36 | if reduce_gamma:
37 | self.init_gamma(1.0 / 3)
38 | if gamma_init is not None:
39 | assert not reduce_gamma
40 | self.init_gamma(gamma_init)
41 |
42 | def forward(self, input):
43 | if self.deploy:
44 | return self.fused_conv(input)
45 | else:
46 | square_outputs_dw=self.square_conv(input)
47 | square_outputs_pw = self.square_point_conv(input)
48 | square_outputs = square_outputs_dw+square_outputs_pw
49 | if self.crop > 0:
50 | ver_input = input[:, :, :, self.crop:-self.crop]
51 | hor_input = input[:, :, self.crop:-self.crop, :]
52 | else:
53 | ver_input = input
54 | hor_input = input
55 | vertical_outputs_dw = self.ver_conv(ver_input)
56 | vertical_outputs=vertical_outputs_dw
57 | horizontal_outputs_dw = self.hor_conv(hor_input)
58 | horizontal_outputs=horizontal_outputs_dw
59 | result = square_outputs + vertical_outputs + horizontal_outputs
60 | return result
61 |
62 |
63 | class GA(nn.Module):
64 | def __init__(self, dim=150):
65 | super(GA, self).__init__()
66 | self.ga = nn.Sequential(*[
67 | nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1,
68 | padding=0, groups=1, bias=True,padding_mode="zeros"),
69 | nn.ReLU(),
70 | nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1,
71 | padding=0, groups=1, bias=True,padding_mode="zeros"),
72 | nn.Sigmoid()
73 | ])
74 |
75 | def forward(self, x):
76 | out=self.ga(x)
77 | return out
78 |
79 |
80 | class AAConv(nn.Module):
81 | def __init__(self, dim,kernel_size,padding,stride,deploy):
82 | super(AAConv, self).__init__()
83 | self.faconv=FAConv(dim, dim, kernel_size=kernel_size, padding=padding, stride=stride, deploy=deploy)
84 | self.ga=GA(dim=dim)
85 |
86 | def forward(self, input):
87 | x=self.faconv(input)
88 | attn=self.ga(input)
89 | out=attn*x
90 | return out
91 |
92 |
93 | class PCSA(nn.Module):
94 | def __init__(self, dim=150, out_dim=150):
95 | super(PCSA, self).__init__()
96 | self.in_dim=dim
97 | self.out_dim=out_dim
98 | self.avg_pooling=torch.nn.AdaptiveAvgPool2d((1,1))
99 | self.liner1=nn.Linear(dim,dim)
100 | self.liner2=nn.Linear(dim,dim)
101 | self.Sigmoid=nn.Sigmoid()
102 | self.conv2d = nn.Conv2d(dim,dim,1,1)
103 | self.avg_pooling_1d = torch.nn.AdaptiveAvgPool1d(1)
104 | self.conv1d=nn.Conv1d(1,1,1,1)
105 |
106 | def forward(self, x):
107 | # x=self.conv2d(x)
108 |
109 | x_pool=self.avg_pooling(x).squeeze(dim=3).permute(0,2,1)
110 | # print(x_pool.shape)
111 | q=self.liner1(x_pool)
112 | k=self.liner2(x_pool)
113 | attn=k.permute(0,2,1)@q
114 | attn = self.avg_pooling_1d(attn)
115 | attn = self.conv1d(attn.transpose(-1,-2)).transpose(-1,-2).unsqueeze(dim=3)
116 | attn=self.Sigmoid(attn)
117 | out =x*attn
118 | return out
119 |
120 |
121 | # Pixel Attention
122 | class AACNet(nn.Module):
123 | # def __init__(self, in_dim=4, out_dim=4, dim=64, kernel_size=3, padding=1, num_blocks=5, stride=1, deploy=False):
124 | # def __init__(self, in_dim=7, out_dim=7, dim=64, kernel_size=3, padding=1, num_blocks=5, stride=1, deploy=False):
125 | def __init__(self, in_dim=305, out_dim=305, dim=64, kernel_size=3, padding=1,num_blocks=5, stride=1,deploy=False):
126 | super(AACNet, self).__init__()
127 | self.blocks = nn.ModuleList([])
128 | self.blocks1 = nn.ModuleList([])
129 | self.blocks2 = nn.ModuleList([])
130 | self.dim = dim
131 | self.t_Conv1 = nn.Conv2d(self.dim, self.dim, 3, 1, 1)
132 | self.t_Conv2 = nn.Conv2d(self.dim, self.dim, 3, 1, 1)
133 | self.t_Conv3 = nn.Conv2d(self.dim, self.dim, 3, 1, 1)
134 | self.num_block = num_blocks
135 | self.Convd_in = nn.Conv2d(in_dim, self.dim, kernel_size=1, padding=0, stride=1)
136 | self.Convd = nn.Conv2d(self.dim, out_dim, kernel_size=1, padding=0, stride=1)
137 | self.Convd_out = nn.Conv2d(out_dim, out_dim, 3, 1, 1)
138 | self.cattn = PCSA(self.dim, self.dim)
139 | # self.acdw=AAConv( self.dim, self.dim, kernel_size=kernel_size, padding=padding, stride=stride, deploy=False)
140 | # self.gps=1
141 |
142 |         def weight_init(m):
143 | if isinstance(m, nn.Conv2d):
144 | init.xavier_uniform_(m.weight.data)
145 | elif isinstance(m, nn.Linear):
146 | m.weight.data.normal_(0, 0.01)
147 | m.bias.data.zero_()
148 | for _ in range(num_blocks):
149 | self.blocks.append(nn.ModuleList([
150 | AAConv(self.dim, kernel_size=kernel_size, padding=padding, stride=stride, deploy=deploy),
151 | nn.PReLU(),
152 | AAConv( self.dim, kernel_size=kernel_size, padding=padding, stride=stride, deploy=deploy),
153 | PCSA(self.dim,self.dim),
154 | ]))
155 | for _ in range(num_blocks):
156 | self.blocks1.append(nn.ModuleList([
157 | AAConv(self.dim, kernel_size=kernel_size, padding=padding, stride=stride, deploy=deploy),
158 | nn.PReLU(),
159 | AAConv(self.dim, kernel_size=kernel_size, padding=padding, stride=stride, deploy=deploy),
160 | PCSA(self.dim, self.dim),
161 | ]))
162 | for _ in range(num_blocks):
163 | self.blocks2.append(nn.ModuleList([
164 | AAConv(self.dim, kernel_size=kernel_size, padding=padding, stride=stride, deploy=deploy),
165 | nn.PReLU(),
166 | AAConv(self.dim, kernel_size=kernel_size, padding=padding, stride=stride, deploy=deploy),
167 | PCSA(self.dim, self.dim),
168 | ]))
169 |         self.cattn.apply(weight_init)
170 |         self.blocks.apply(weight_init)
171 |         self.blocks1.apply(weight_init)
172 |         self.blocks2.apply(weight_init)
173 | # self..apply(weigth_init)
174 |
175 | def forward(self, x):
176 | x_original_features = x
177 | x = self.Convd_in(x)
178 | x_shallow_features = x
179 | for (aaconv, act, aaconv1, pcsa) in self.blocks:
180 | res = x
181 | x = aaconv(x)
182 | x = act(x)
183 | x = aaconv1(x)
184 | x = pcsa(x)
185 | x = x+res
186 | x = self.t_Conv1(x)
187 | for (aaconv, act, aaconv1, pcsa) in self.blocks1:
188 | res1 = x
189 | x = aaconv(x)
190 | x = act(x)
191 | x = aaconv1(x)
192 | x = pcsa(x)
193 | x = x + res1
194 | x = self.t_Conv2(x)
195 | for (aaconv, act, aaconv1, pcsa) in self.blocks2:
196 | res2 = x
197 | x = aaconv(x)
198 | x = act(x)
199 | x = aaconv1(x)
200 | x = pcsa(x)
201 | x = x + res2
202 | x = self.t_Conv3(x)
203 | x = x+x_shallow_features
204 | x = self.cattn(x)
205 | x = self.Convd(x)
206 | x = self.Convd_out(x)
207 | out = x + x_original_features
208 | # out=x
209 | return out
210 |
211 |
212 | if __name__ == '__main__':
213 | net = AACNet(in_dim=305, out_dim=305, kernel_size=3, padding=1, stride=1, num_blocks=5).cuda()
214 | # net = AACNet(in_dim=7, out_dim=7, kernel_size=3, padding=1, stride=1, num_blocks=5).cuda()
215 | # net = AACNet(in_dim=4, out_dim=4, kernel_size=3, padding=1, stride=1, num_blocks=5).cuda()
216 |
217 | device = torch.device('cpu')
218 | net.to(device)
219 | # summary(net.cuda(), (305, 64, 64))
220 |
221 |
222 | #=======================================================================================================================
223 | # if __name__ == '__main__':
224 | # x = torch.randn(4, 305, 64, 64).cuda()
225 | # test_kernel_padding = [(3,1), (5,2), (7,3), (9,4),(11,5) ]
226 | # # mcplb = MCPLB(in_dim=150, out_dim=150, kernel_size=3, padding=1, stride=1, num_blocks=5).cuda()
227 | # mcplb = AACNet(in_dim=305, out_dim=305, kernel_size=3, padding=1, stride=1, num_blocks=5).cuda()
228 | # out = mcplb(x)
229 | # # summary(mcplb.cuda(), (150, 64, 64))
230 |
--------------------------------------------------------------------------------
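AACNet above stacks three groups of `num_blocks` residual units, each combining two asymmetric-padding `AAConv` layers with a `PCSA` channel-attention block, and adds a long skip from the 305-band input to the output, so spatial and spectral dimensions are preserved. A quick CPU shape check under the defaults follows; the input size is purely illustrative.

```python
# Hedged smoke test for AACNet; the 305-band 32x32 cube is an arbitrary illustration.
import torch
from models.AACNet import AACNet  # assumes the HDMba directory is on the path

net = AACNet(in_dim=305, out_dim=305, dim=64, kernel_size=3, padding=1, num_blocks=5)
x = torch.randn(1, 305, 32, 32)   # (batch, bands, height, width)
with torch.no_grad():
    y = net(x)
print(y.shape)                    # expected: torch.Size([1, 305, 32, 32])
```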
/HDMba/models/HDMba.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | from torch.utils.data import DataLoader, Dataset
5 | import torch.utils.checkpoint as checkpoint
6 | import torch.nn.functional as F
7 | import einops
8 | from einops import rearrange
9 | # import tqdm
10 | # standard-library / system imports
11 | import math
12 | import os
13 | import urllib.request
14 | from zipfile import ZipFile
15 | # from transformers import *
16 | from timm.models.layers import DropPath, to_2tuple
17 |
18 | torch.autograd.set_detect_anomaly(True)
19 |
20 | # configuration flags and hyperparameters
21 | USE_MAMBA = 1
22 | DIFFERENT_H_STATES_RECURRENT_UPDATE_MECHANISM = 0
23 | # select the device to use
24 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
25 |
26 | # manually chosen hyperparameters
27 | batch_size = 4  # batch size
28 | last_batch_size = 81  # size of the last batch
29 | current_batch_size = batch_size
30 | different_batch_size = False
31 | h_new = None
32 | temp_buffer = None
33 |
34 |
35 | # S6 (selective state-space) module
36 | class S6(nn.Module):
37 | def __init__(self, seq_len, d_model, state_size, device):
38 | super(S6, self).__init__()
39 |         # linear projections
40 | self.fc1 = nn.Linear(d_model, d_model, device=device)
41 | self.fc2 = nn.Linear(d_model, state_size, device=device)
42 | self.fc3 = nn.Linear(d_model, state_size, device=device)
43 |         # store hyperparameters
44 | self.seq_len = seq_len
45 | self.d_model = d_model
46 | self.state_size = state_size
47 | self.A = nn.Parameter(F.normalize(torch.ones(d_model, state_size, device=device), p=2, dim=-1))
48 |         # parameter initialization
49 | nn.init.xavier_uniform_(self.A)
50 |
51 | self.B = torch.zeros(batch_size, self.seq_len, self.state_size, device=device)
52 | self.C = torch.zeros(batch_size, self.seq_len, self.state_size, device=device)
53 |
54 | self.delta = torch.zeros(batch_size, self.seq_len, self.d_model, device=device)
55 | self.dA = torch.zeros(batch_size, self.seq_len, self.d_model, self.state_size, device=device)
56 | self.dB = torch.zeros(batch_size, self.seq_len, self.d_model, self.state_size, device=device)
57 |
58 |         # internal state h and output y buffers
59 | self.h = torch.zeros(batch_size, self.seq_len, self.d_model, self.state_size, device=device)
60 | self.y = torch.zeros(batch_size, self.seq_len, self.d_model, device=device)
61 |
62 |     # discretization
63 |     def discretization(self):
64 |         # the discretization is described on p. 28 of the Mamba paper
65 | self.dB = torch.einsum("bld,bln->bldn", self.delta, self.B)
66 | # dA = torch.matrix_exp(A * delta) # matrix_exp() only supports square matrix
67 | self.dA = torch.exp(torch.einsum("bld,dn->bldn", self.delta, self.A))
68 | return self.dA, self.dB
69 |
70 |     # forward pass
71 |     def forward(self, x):
72 |         # see Algorithm 2 in the Mamba paper
73 | self.B = self.fc2(x)
74 | self.C = self.fc3(x)
75 | self.delta = F.softplus(self.fc1(x))
76 |         # discretize
77 | self.discretization()
78 |
79 | if DIFFERENT_H_STATES_RECURRENT_UPDATE_MECHANISM:
80 |             # a separate 'h_new' is used; updating self.h in place would trigger a runtime error
81 | global current_batch_size
82 | current_batch_size = x.shape[0]
83 |
84 | if self.h.shape[0] != current_batch_size:
85 | different_batch_size = True
86 |                 # slice h so its batch dimension matches the current batch
87 | h_new = torch.einsum('bldn,bldn->bldn', self.dA, self.h[:current_batch_size, ...]) + rearrange(x, "b l d -> b l d 1") * self.dB
88 | else:
89 | different_batch_size = False
90 | h_new = torch.einsum('bldn,bldn->bldn', self.dA, self.h) + rearrange(x, "b l d -> b l d 1") * self.dB
91 |
92 |             # compute y with shape (b, l, d)
93 | self.y = torch.einsum('bln,bldn->bld', self.C, h_new)
94 |
95 |             # cache h_new so h can be updated from it
96 | global temp_buffer
97 | temp_buffer = h_new.detach().clone() if not self.h.requires_grad else h_new.clone()
98 |
99 | return self.y
100 | else:
101 |             # fallback branch: no recurrent state is carried across calls
102 |             # set up h with the required shape
103 | h = torch.zeros(x.size(0), self.seq_len, self.d_model, self.state_size, device=x.device)
104 | y = torch.zeros_like(x)
105 |
106 | h = torch.einsum('bldn,bldn->bldn', self.dA, h) + rearrange(x, "b l d -> b l d 1") * self.dB
107 |
108 |             # compute y
109 | y = torch.einsum('bln,bldn->bld', self.C, h)
110 | return y
111 |
112 |
113 | class RMSNorm(nn.Module):
114 | def __init__(self, d_model: int, eps: float = 1e-5, device: str = 'cuda'):
115 | super().__init__()
116 | self.eps = eps
117 | self.weight = nn.Parameter(torch.ones(d_model, device=device))
118 |
119 | def forward(self, x):
120 | output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
121 | return output
122 |
123 |
124 | # MambaBlock module
125 | class MambaBlock(nn.Module):
126 | def __init__(self, seq_len, d_model, state_size, device):
127 | super(MambaBlock, self).__init__()
128 | self.inp_proj = nn.Linear(d_model, 2 * d_model, device=device)
129 | self.out_proj = nn.Linear(2 * d_model, d_model, device=device)
130 |         # residual (skip) projection
131 |         self.D = nn.Linear(d_model, 2 * d_model, device=device)
132 |         # mark the bias so it is excluded from weight decay
133 |         self.out_proj.bias._no_weight_decay = True
134 |         # initialize the bias
135 |         nn.init.constant_(self.out_proj.bias, 1.0)
136 |         # S6 module
137 |         self.S6 = S6(seq_len, 2 * d_model, state_size, device)
138 |         # depthwise 1D convolution over the sequence
139 |         self.conv = nn.Conv1d(seq_len, seq_len, kernel_size=3, padding=1, groups=seq_len, device=device)
140 |         # linear layer after the convolution
141 |         self.conv_linear = nn.Linear(2 * d_model, 2 * d_model, device=device)
142 |         # normalization
143 | self.norm = RMSNorm(d_model, device=device)
144 |
145 | def forward(self, x):
146 |         # see Fig. 3 in the Mamba paper
147 | x = self.norm(x)
148 | x_proj = self.inp_proj(x)
149 |         # 1D convolution
150 |         x_conv = self.conv(x_proj)
151 |         x_conv_act = F.silu(x_conv)  # SiLU (Swish) activation
152 |         # linear projection
153 | x_conv_out = self.conv_linear(x_conv_act)
154 |         # S6 module
155 |         x_ssm = self.S6(x_conv_out)
156 |         x_act = F.silu(x_ssm)  # SiLU (Swish) activation
157 |         # gated residual connection
158 | x_residual = F.silu(self.D(x))
159 | x_combined = x_act * x_residual
160 | x_out = self.out_proj(x_combined)
161 | return x_out
162 |
163 |
164 | # inputs: sequence length, model dimension, state size
165 | class Mamba(nn.Module):
166 | def __init__(self, seq_len, d_model, state_size, device):
167 | super(Mamba, self).__init__()
168 | self.seq_len = seq_len
169 | self.d_model = d_model
170 | self.state_size = state_size
171 | self.mamba_block1 = MambaBlock(self.seq_len, self.d_model, self.state_size, device)
172 | self.mamba_block2 = MambaBlock(self.seq_len, self.d_model, self.state_size, device)
173 | self.mamba_block3 = MambaBlock(self.seq_len, self.d_model, self.state_size, device)
174 |
175 | def forward(self, x):
176 | x = self.mamba_block1(x)
177 | x = self.mamba_block2(x)
178 | x = self.mamba_block3(x)
179 | return x
180 |
181 |
182 | def window_partition(x, window_size):
183 | """
184 | Args:
185 | x: (B, H, W, C)
186 | window_size (int): window size
187 | """
188 | B, H, W, C = x.shape
189 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
190 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
191 | return windows
192 |
193 |
194 | def window_reverse(windows, window_size, H, W):
195 | """
196 | Args:
197 | windows: (num_windows*B, window_size, window_size, C)
198 | window_size (int): Window size
199 | H (int): Height of image
200 | W (int): Width of image
201 | """
202 | B = int(windows.shape[0] / (H * W / window_size / window_size))
203 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
204 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
205 | return x
206 |
207 |
208 | class Mlp(nn.Module):
209 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
210 | super().__init__()
211 | out_features = out_features or in_features
212 | hidden_features = hidden_features or in_features
213 | self.fc1 = nn.Linear(in_features, hidden_features)
214 | self.act = act_layer()
215 | self.fc2 = nn.Linear(hidden_features, out_features)
216 | self.drop = nn.Dropout(drop)
217 |
218 | def forward(self, x):
219 | x = self.fc1(x)
220 | x = self.act(x)
221 | x = self.drop(x)
222 | x = self.fc2(x)
223 | x = self.drop(x)
224 | return x
225 |
226 |
227 | class SSMamba(nn.Module):
228 | def __init__(self, dim, window_size, shift_size=0,
229 | mlp_ratio=4., drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
230 | super().__init__()
231 | self.dim = dim
232 | self.window_size = window_size
233 | self.shift_size = shift_size
234 | self.mlp_ratio = mlp_ratio
235 |
236 | self.norm1 = norm_layer(dim)
237 | self.mamba = Mamba(seq_len=self.window_size ** 2, d_model=dim, state_size=dim, device=device)
238 |
239 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
240 | self.norm2 = norm_layer(dim)
241 | mlp_hidden_dim = int(dim * mlp_ratio)
242 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
243 |
244 | def forward(self, x):
245 | B, C, H, W = x.shape
246 | shortcut = x
247 | x = x.permute(0, 2, 3, 1).contiguous().view(B, H * W, C) # Change shape to (B, L, C)
248 | x = self.norm1(x)
249 |
250 | x = x.view(B, H, W, C)
251 | pad_r = (self.window_size - W % self.window_size) % self.window_size
252 | pad_b = (self.window_size - H % self.window_size) % self.window_size
253 | x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
254 | _, Hp, Wp, _ = x.shape
255 |
256 | if self.shift_size > 0:
257 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
258 | else:
259 | shifted_x = x
260 |
261 | x_windows = window_partition(shifted_x, self.window_size)
262 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
263 |
264 | # Apply Mamba
265 | attn_windows = self.mamba(x_windows)
266 |
267 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
268 | shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)
269 |
270 | if self.shift_size > 0:
271 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
272 | else:
273 | x = shifted_x
274 |
275 | if pad_r > 0 or pad_b > 0:
276 | x = x[:, :H, :W, :].contiguous()
277 |
278 | x = x.view(B, H * W, C)
279 | x = self.norm2(x)
280 | x = x.view(B, H, W, C).permute(0, 3, 1, 2).contiguous() # Change shape back to (B, C, H, W)
281 |
282 | x = shortcut + self.drop_path(x)
283 | x = x + self.drop_path(self.mlp(x.permute(0, 2, 3, 1).contiguous().view(B, H * W, C)).view(B, H, W, C).permute(0, 3, 1, 2).contiguous())
284 |
285 | return x
286 |
287 |
288 | class SSMaBlock(nn.Module):
289 | def __init__(self,
290 | dim=32,
291 | window_size=7,
292 | depth=4,
293 | mlp_ratio=2,
294 | drop_path=0.0):
295 | super(SSMaBlock, self).__init__()
296 | self.ssmablock = nn.Sequential(*[
297 | SSMamba(dim=dim, window_size=window_size,
298 | shift_size=0 if (i % 2 == 0) else window_size // 2,
299 | mlp_ratio=mlp_ratio,
300 | drop_path=drop_path)
301 | for i in range(depth)
302 | ])
303 | self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
304 |
305 | def forward(self, x):
306 | out = self.ssmablock(x)
307 | out = self.conv(out) + x
308 | return out
309 |
310 |
311 | class HDMba(nn.Module):
312 | def __init__(self,
313 | inp_channels=305,
314 | dim=32,
315 | window_size=7,
316 | depths=[4, 4, 4, 4],
317 | mlp_ratio=2,
318 | bias=False,
319 | drop_path=0.0
320 | ):
321 | super(HDMba, self).__init__()
322 |
323 | self.conv_first = nn.Conv2d(inp_channels, dim, 3, 1, 1) # shallow feature extraction
324 | self.num_layers = depths
325 | self.layers = nn.ModuleList()
326 |
327 | for i_layer in range(len(self.num_layers)):
328 | layer = SSMaBlock(dim=dim,
329 | window_size=window_size,
330 | depth=depths[i_layer],
331 | mlp_ratio=mlp_ratio,
332 | drop_path=drop_path)
333 | self.layers.append(layer)
334 |
335 | self.output = nn.Conv2d(int(dim), dim, kernel_size=3, stride=1, padding=1, bias=bias)
336 | self.conv_delasta = nn.Conv2d(dim, inp_channels, 3, 1, 1) # reconstruction from features
337 |
338 | def forward(self, inp_img):
339 | f1 = self.conv_first(inp_img)
340 | x = f1
341 | for layer in self.layers:
342 | x = layer(x)
343 |
344 | x = self.output(x + f1)
345 | x = self.conv_delasta(x) + inp_img
346 | return x
347 |
348 | from torchsummary import summary
349 |
350 | if __name__ == '__main__':
351 | net = HDMba().to(device)
352 | # x_train = torch.randn(2, 305, 64, 64).to(device)
353 | # out_train = net(x_train)
354 | # print("Training output shape:", out_train.shape) # Expected output shape should be (2, 32, 64, 64)
355 | # summary(net.cuda(), (305, 64, 64))
356 |
357 |
358 |
359 |
360 |
361 |
362 |
363 | ############################################################################################
364 | # class SSMamba(nn.Module):
365 | # def __init__(self, dim, input_resolution, num_heads, window_size, shift_size=0,
366 | # mlp_ratio=4., drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
367 | # super().__init__()
368 | # self.dim = dim
369 | # self.input_resolution = input_resolution
370 | # self.num_heads = num_heads
371 | # self.window_size = window_size
372 | # self.shift_size = shift_size
373 | # self.mlp_ratio = mlp_ratio
374 | #
375 | # if min(self.input_resolution) <= self.window_size:
376 | # # if window size is larger than input resolution, we don't partition windows
377 | # self.shift_size = 0
378 | # self.window_size = min(self.input_resolution)
379 | #
380 | # self.norm1 = norm_layer(dim)
381 | # self.mamba = Mamba(seq_len=self.window_size ** 2, d_model=dim, state_size=dim, device=device)
382 | #
383 | # self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
384 | # self.norm2 = norm_layer(dim)
385 | # mlp_hidden_dim = int(dim * mlp_ratio)
386 | # self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
387 | #
388 | # def forward(self, x):
389 | # H, W = self.input_resolution
390 | # B, L, C = x.shape
391 | # assert L == H * W, "input feature has wrong size"
392 | #
393 | # shortcut = x
394 | # x = self.norm1(x)
395 | # x = x.view(B, H, W, C)
396 | # # Pad feature maps to multiples of window size
397 | # pad_r = (self.window_size - W % self.window_size) % self.window_size
398 | # pad_b = (self.window_size - H % self.window_size) % self.window_size
399 | # x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
400 | # _, Hp, Wp, _ = x.shape
401 | #
402 | # if self.shift_size > 0:
403 | # shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
404 | # else:
405 | # shifted_x = x
406 | #
407 | # x_windows = window_partition(shifted_x, self.window_size)
408 | # x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
409 | #
410 | # # Apply Mamba
411 | # attn_windows = self.mamba(x_windows)
412 | #
413 | # attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
414 | # shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)
415 | #
416 | # if self.shift_size > 0:
417 | # x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
418 | # else:
419 | # x = shifted_x
420 | #
421 | # if pad_r > 0 or pad_b > 0:
422 | # x = x[:, :H, :W, :].contiguous()
423 | #
424 | # x = x.view(B, H * W, C)
425 | #
426 | # # FFN
427 | # x = shortcut + self.drop_path(x)
428 | # x = x + self.drop_path(self.mlp(self.norm2(x)))
429 | #
430 | # return x
431 | #
432 | #
433 | # # if __name__ == '__main__':
434 | # # net = SSMamba(dim=30, input_resolution=[32, 32], num_heads=3).to(device)
435 | # # x = torch.randn(4, 32*32, 30).to(device)
436 | # # out = net(x)
437 | # # print(out.shape)
438 | #
439 | #
440 | # class SSMaBlock(nn.Module):
441 | # def __init__(self,
442 | # dim=32,
443 | # num_head=3,
444 | # input_resolution=[32,32],
445 | # window_size=7,
446 | # depth=3,
447 | # mlp_ratio=2,
448 | # drop_path=0.0):
449 | # super(SSMaBlock, self).__init__()
450 | # self.ssmablock = nn.Sequential(*[SSMamba(dim=dim, input_resolution=input_resolution, num_heads=num_head, window_size=window_size,
451 | # shift_size=0 if (i % 2 == 0) else window_size // 2,
452 | # mlp_ratio=mlp_ratio,
453 | # drop_path = drop_path,
454 | # # drop_path=drop_path[i],
455 | # )
456 | # for i in range(depth)])
457 | # self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
458 | #
459 | # def forward(self, x):
460 | # out = self.ssmablock(x)
461 | # # out = self.conv(out) + x
462 | # return out
463 | #
464 | #
465 | # if __name__ == '__main__':
466 | # net = SSMaBlock(dim=30, input_resolution=[32, 32]).to(device)
467 | # x = torch.randn(4, 32*32, 30).to(device)
468 | # out = net(x)
469 | # print(out.shape)
--------------------------------------------------------------------------------
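HDMba above extracts a shallow 32-channel feature map, runs it through four `SSMaBlock` stages of window-partitioned Mamba blocks (window_size=7, with internal padding when the spatial size is not a multiple of 7), and reconstructs the 305-band cube with a global residual, so the output matches the input shape. A minimal forward-pass sketch under those defaults follows; the 64x64 patch size is only an example.

```python
# Hedged shape check for HDMba; the input cube size is illustrative.
import torch
from models.HDMba import HDMba, device  # `device` is the module-level cuda/cpu selector defined above

net = HDMba(inp_channels=305, dim=32, window_size=7, depths=[4, 4, 4, 4]).to(device)
x = torch.randn(1, 305, 64, 64, device=device)  # (batch, bands, H, W)
with torch.no_grad():
    y = net(x)
print(y.shape)                                   # expected: torch.Size([1, 305, 64, 64])
```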
/HDMba/models/Dehazeformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 | import numpy as np
6 | from torch.nn.init import _calculate_fan_in_and_fan_out
7 | from timm.models.layers import to_2tuple, trunc_normal_
8 | from torchsummary import summary
9 |
10 |
11 | class RLN(nn.Module):
12 | r"""Revised LayerNorm"""
13 | def __init__(self, dim, eps=1e-5, detach_grad=False):
14 | super(RLN, self).__init__()
15 | self.eps = eps
16 | self.detach_grad = detach_grad
17 |
18 | self.weight = nn.Parameter(torch.ones((1, dim, 1, 1)))
19 | self.bias = nn.Parameter(torch.zeros((1, dim, 1, 1)))
20 |
21 | self.meta1 = nn.Conv2d(1, dim, 1)
22 | self.meta2 = nn.Conv2d(1, dim, 1)
23 |
24 | trunc_normal_(self.meta1.weight, std=.02)
25 | nn.init.constant_(self.meta1.bias, 1)
26 |
27 | trunc_normal_(self.meta2.weight, std=.02)
28 | nn.init.constant_(self.meta2.bias, 0)
29 |
30 | def forward(self, input):
31 | mean = torch.mean(input, dim=(1, 2, 3), keepdim=True)
32 | std = torch.sqrt((input - mean).pow(2).mean(dim=(1, 2, 3), keepdim=True) + self.eps)
33 |
34 | normalized_input = (input - mean) / std
35 |
36 | if self.detach_grad:
37 | rescale, rebias = self.meta1(std.detach()), self.meta2(mean.detach())
38 | else:
39 | rescale, rebias = self.meta1(std), self.meta2(mean)
40 |
41 | out = normalized_input * self.weight + self.bias
42 | return out, rescale, rebias
43 |
44 |
45 | class Mlp(nn.Module):
46 | def __init__(self, network_depth, in_features, hidden_features=None, out_features=None):
47 | super().__init__()
48 | out_features = out_features or in_features
49 | hidden_features = hidden_features or in_features
50 |
51 | self.network_depth = network_depth
52 |
53 | self.mlp = nn.Sequential(
54 | nn.Conv2d(in_features, hidden_features, 1),
55 | nn.ReLU(True),
56 | nn.Conv2d(hidden_features, out_features, 1)
57 | )
58 |
59 | self.apply(self._init_weights)
60 |
61 | def _init_weights(self, m):
62 | if isinstance(m, nn.Conv2d):
63 | gain = (8 * self.network_depth) ** (-1/4)
64 | fan_in, fan_out = _calculate_fan_in_and_fan_out(m.weight)
65 | std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
66 | trunc_normal_(m.weight, std=std)
67 | if m.bias is not None:
68 | nn.init.constant_(m.bias, 0)
69 |
70 | def forward(self, x):
71 | return self.mlp(x)
72 |
73 |
74 | def window_partition(x, window_size):
75 | B, H, W, C = x.shape
76 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
77 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size**2, C)
78 | return windows
79 |
80 |
81 | def window_reverse(windows, window_size, H, W):
82 | B = int(windows.shape[0] / (H * W / window_size / window_size))
83 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
84 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
85 | return x
86 |
87 |
88 | def get_relative_positions(window_size):
89 | coords_h = torch.arange(window_size)
90 | coords_w = torch.arange(window_size)
91 |
92 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
93 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
94 | relative_positions = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
95 |
96 | relative_positions = relative_positions.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
97 | relative_positions_log = torch.sign(relative_positions) * torch.log(1. + relative_positions.abs())
98 |
99 | return relative_positions_log
100 |
101 |
102 | class WindowAttention(nn.Module):
103 | def __init__(self, dim, window_size, num_heads):
104 |
105 | super().__init__()
106 | self.dim = dim
107 | self.window_size = window_size # Wh, Ww
108 | self.num_heads = num_heads
109 | head_dim = dim // num_heads
110 | self.scale = head_dim ** -0.5
111 |
112 | relative_positions = get_relative_positions(self.window_size)
113 | self.register_buffer("relative_positions", relative_positions)
114 | self.meta = nn.Sequential(
115 | nn.Linear(2, 256, bias=True),
116 | nn.ReLU(True),
117 | nn.Linear(256, num_heads, bias=True)
118 | )
119 |
120 | self.softmax = nn.Softmax(dim=-1)
121 |
122 | def forward(self, qkv):
123 | B_, N, _ = qkv.shape
124 |
125 | qkv = qkv.reshape(B_, N, 3, self.num_heads, self.dim // self.num_heads).permute(2, 0, 3, 1, 4)
126 |
127 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
128 |
129 | q = q * self.scale
130 | attn = (q @ k.transpose(-2, -1))
131 |
132 | relative_position_bias = self.meta(self.relative_positions)
133 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
134 | attn = attn + relative_position_bias.unsqueeze(0)
135 |
136 | attn = self.softmax(attn)
137 |
138 | x = (attn @ v).transpose(1, 2).reshape(B_, N, self.dim)
139 | return x
140 |
141 |
142 | class Attention(nn.Module):
143 | def __init__(self, network_depth, dim, num_heads, window_size, shift_size, use_attn=False, conv_type=None):
144 | super().__init__()
145 | self.dim = dim
146 | self.head_dim = int(dim // num_heads)
147 | self.num_heads = num_heads
148 |
149 | self.window_size = window_size
150 | self.shift_size = shift_size
151 |
152 | self.network_depth = network_depth
153 | self.use_attn = use_attn
154 | self.conv_type = conv_type
155 |
156 | if self.conv_type == 'Conv':
157 | self.conv = nn.Sequential(
158 | nn.Conv2d(dim, dim, kernel_size=3, padding=1, padding_mode='reflect'),
159 | nn.ReLU(True),
160 | nn.Conv2d(dim, dim, kernel_size=3, padding=1, padding_mode='reflect')
161 | )
162 |
163 | if self.conv_type == 'DWConv':
164 | self.conv = nn.Conv2d(dim, dim, kernel_size=5, padding=2, groups=dim, padding_mode='reflect')
165 |
166 | if self.conv_type == 'DWConv' or self.use_attn:
167 | self.V = nn.Conv2d(dim, dim, 1)
168 | self.proj = nn.Conv2d(dim, dim, 1)
169 |
170 | if self.use_attn:
171 | self.QK = nn.Conv2d(dim, dim * 2, 1)
172 | self.attn = WindowAttention(dim, window_size, num_heads)
173 |
174 | self.apply(self._init_weights)
175 |
176 | def _init_weights(self, m):
177 | if isinstance(m, nn.Conv2d):
178 | w_shape = m.weight.shape
179 |
180 | if w_shape[0] == self.dim * 2: # QK
181 | fan_in, fan_out = _calculate_fan_in_and_fan_out(m.weight)
182 | std = math.sqrt(2.0 / float(fan_in + fan_out))
183 | trunc_normal_(m.weight, std=std)
184 | else:
185 | gain = (8 * self.network_depth) ** (-1/4)
186 | fan_in, fan_out = _calculate_fan_in_and_fan_out(m.weight)
187 | std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
188 | trunc_normal_(m.weight, std=std)
189 |
190 | if m.bias is not None:
191 | nn.init.constant_(m.bias, 0)
192 |
193 | def check_size(self, x, shift=False):
194 | _, _, h, w = x.size()
195 | mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
196 | mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
197 |
198 | if shift:
199 | x = F.pad(x, (self.shift_size, (self.window_size-self.shift_size+mod_pad_w) % self.window_size,
200 | self.shift_size, (self.window_size-self.shift_size+mod_pad_h) % self.window_size), mode='reflect')
201 | else:
202 | x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
203 | return x
204 |
205 | def forward(self, X):
206 | B, C, H, W = X.shape
207 |
208 | if self.conv_type == 'DWConv' or self.use_attn:
209 | V = self.V(X)
210 |
211 | if self.use_attn:
212 | QK = self.QK(X)
213 | QKV = torch.cat([QK, V], dim=1)
214 |
215 | # shift
216 | shifted_QKV = self.check_size(QKV, self.shift_size > 0)
217 | Ht, Wt = shifted_QKV.shape[2:]
218 |
219 | # partition windows
220 | shifted_QKV = shifted_QKV.permute(0, 2, 3, 1)
221 | qkv = window_partition(shifted_QKV, self.window_size) # nW*B, window_size**2, C
222 |
223 | attn_windows = self.attn(qkv)
224 |
225 | # merge windows
226 | shifted_out = window_reverse(attn_windows, self.window_size, Ht, Wt) # B H' W' C
227 |
228 | # reverse cyclic shift
229 | out = shifted_out[:, self.shift_size:(self.shift_size+H), self.shift_size:(self.shift_size+W), :]
230 | attn_out = out.permute(0, 3, 1, 2)
231 |
232 | if self.conv_type in ['Conv', 'DWConv']:
233 | conv_out = self.conv(V)
234 | out = self.proj(conv_out + attn_out)
235 | else:
236 | out = self.proj(attn_out)
237 |
238 | else:
239 | if self.conv_type == 'Conv':
240 | out = self.conv(X) # no attention and use conv, no projection
241 | elif self.conv_type == 'DWConv':
242 | out = self.proj(self.conv(V))
243 |
244 | return out
245 |
246 |
247 | class TransformerBlock(nn.Module):
248 | def __init__(self, network_depth, dim, num_heads, mlp_ratio=4.,
249 | norm_layer=nn.LayerNorm, mlp_norm=False,
250 | window_size=8, shift_size=0, use_attn=True, conv_type=None):
251 | super().__init__()
252 | self.use_attn = use_attn
253 | self.mlp_norm = mlp_norm
254 |
255 | self.norm1 = norm_layer(dim) if use_attn else nn.Identity()
256 | self.attn = Attention(network_depth, dim, num_heads=num_heads, window_size=window_size,
257 | shift_size=shift_size, use_attn=use_attn, conv_type=conv_type)
258 |
259 | self.norm2 = norm_layer(dim) if use_attn and mlp_norm else nn.Identity()
260 | self.mlp = Mlp(network_depth, dim, hidden_features=int(dim * mlp_ratio))
261 |
262 | def forward(self, x):
263 | identity = x
264 | if self.use_attn: x, rescale, rebias = self.norm1(x)
265 | x = self.attn(x)
266 | if self.use_attn: x = x * rescale + rebias
267 | x = identity + x
268 |
269 | identity = x
270 | if self.use_attn and self.mlp_norm: x, rescale, rebias = self.norm2(x)
271 | x = self.mlp(x)
272 | if self.use_attn and self.mlp_norm: x = x * rescale + rebias
273 | x = identity + x
274 | return x
275 |
276 |
277 | class BasicLayer(nn.Module):
278 | def __init__(self, network_depth, dim, depth, num_heads, mlp_ratio=4.,
279 | norm_layer=nn.LayerNorm, window_size=8,
280 | attn_ratio=0., attn_loc='last', conv_type=None):
281 |
282 | super().__init__()
283 | self.dim = dim
284 | self.depth = depth
285 |
286 | attn_depth = attn_ratio * depth
287 |
288 | if attn_loc == 'last':
289 | use_attns = [i >= depth-attn_depth for i in range(depth)]
290 | elif attn_loc == 'first':
291 | use_attns = [i < attn_depth for i in range(depth)]
292 | elif attn_loc == 'middle':
293 | use_attns = [i >= (depth-attn_depth)//2 and i < (depth+attn_depth)//2 for i in range(depth)]
294 |
295 | # build blocks
296 | self.blocks = nn.ModuleList([
297 | TransformerBlock(network_depth=network_depth,
298 | dim=dim,
299 | num_heads=num_heads,
300 | mlp_ratio=mlp_ratio,
301 | norm_layer=norm_layer,
302 | window_size=window_size,
303 | shift_size=0 if (i % 2 == 0) else window_size // 2,
304 | use_attn=use_attns[i], conv_type=conv_type)
305 | for i in range(depth)])
306 |
307 | def forward(self, x):
308 | for blk in self.blocks:
309 | x = blk(x)
310 | return x
311 |
312 |
313 | class PatchEmbed(nn.Module):
314 | def __init__(self, patch_size=4, in_chans=3, embed_dim=96, kernel_size=None):
315 | super().__init__()
316 | self.in_chans = in_chans
317 | self.embed_dim = embed_dim
318 |
319 | if kernel_size is None:
320 | kernel_size = patch_size
321 |
322 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=patch_size,
323 | padding=(kernel_size-patch_size+1)//2, padding_mode='reflect')
324 |
325 | def forward(self, x):
326 | x = self.proj(x)
327 | return x
328 |
329 |
330 | class PatchUnEmbed(nn.Module):
331 | def __init__(self, patch_size=4, out_chans=3, embed_dim=96, kernel_size=None):
332 | super().__init__()
333 | self.out_chans = out_chans
334 | self.embed_dim = embed_dim
335 |
336 | if kernel_size is None:
337 | kernel_size = 1
338 |
339 | self.proj = nn.Sequential(
340 | nn.Conv2d(embed_dim, out_chans*patch_size**2, kernel_size=kernel_size,
341 | padding=kernel_size//2, padding_mode='reflect'),
342 | nn.PixelShuffle(patch_size)
343 | )
344 |
345 | def forward(self, x):
346 | x = self.proj(x)
347 | return x
348 |
349 |
350 | class SKFusion(nn.Module):
351 | def __init__(self, dim, height=2, reduction=8):
352 | super(SKFusion, self).__init__()
353 |
354 | self.height = height
355 | d = max(int(dim/reduction), 4)
356 |
357 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
358 | self.mlp = nn.Sequential(
359 | nn.Conv2d(dim, d, 1, bias=False),
360 | nn.ReLU(),
361 | nn.Conv2d(d, dim*height, 1, bias=False)
362 | )
363 |
364 | self.softmax = nn.Softmax(dim=1)
365 |
366 | def forward(self, in_feats):
367 | B, C, H, W = in_feats[0].shape
368 |
369 | in_feats = torch.cat(in_feats, dim=1)
370 | in_feats = in_feats.view(B, self.height, C, H, W)
371 |
372 | feats_sum = torch.sum(in_feats, dim=1)
373 | attn = self.mlp(self.avg_pool(feats_sum))
374 | attn = self.softmax(attn.view(B, self.height, C, 1, 1))
375 |
376 | out = torch.sum(in_feats*attn, dim=1)
377 | return out
378 |
379 |
380 | class DehazeFormer(nn.Module):
381 | def __init__(self, in_chans=305, out_chans=306, window_size=8,
382 | # def __init__(self, in_chans=7, out_chans=8, window_size=8,
383 | # def __init__(self, in_chans=4, out_chans=5, window_size=8,
384 | embed_dims=[24, 48, 96, 48, 24],
385 | mlp_ratios=[2., 4., 4., 2., 2.],
386 | depths=[8, 8, 8, 4, 4],
387 | num_heads=[2, 4, 6, 1, 1],
388 | attn_ratio=[1/4, 1/2, 3/4, 0, 0],
389 | conv_type=['DWConv', 'DWConv', 'DWConv', 'DWConv', 'DWConv'],
390 | norm_layer=[RLN, RLN, RLN, RLN, RLN]):
391 | super(DehazeFormer, self).__init__()
392 |
393 | # setting
394 | self.patch_size = 4
395 | self.window_size = window_size
396 | self.mlp_ratios = mlp_ratios
397 |
398 | # split image into non-overlapping patches
399 | self.patch_embed = PatchEmbed(
400 | patch_size=1, in_chans=in_chans, embed_dim=embed_dims[0], kernel_size=3)
401 |
402 | # backbone
403 | self.layer1 = BasicLayer(network_depth=sum(depths), dim=embed_dims[0], depth=depths[0],
404 | num_heads=num_heads[0], mlp_ratio=mlp_ratios[0],
405 | norm_layer=norm_layer[0], window_size=window_size,
406 | attn_ratio=attn_ratio[0], attn_loc='last', conv_type=conv_type[0])
407 |
408 | self.patch_merge1 = PatchEmbed(
409 | patch_size=2, in_chans=embed_dims[0], embed_dim=embed_dims[1])
410 |
411 | self.skip1 = nn.Conv2d(embed_dims[0], embed_dims[0], 1)
412 |
413 | self.layer2 = BasicLayer(network_depth=sum(depths), dim=embed_dims[1], depth=depths[1],
414 | num_heads=num_heads[1], mlp_ratio=mlp_ratios[1],
415 | norm_layer=norm_layer[1], window_size=window_size,
416 | attn_ratio=attn_ratio[1], attn_loc='last', conv_type=conv_type[1])
417 |
418 | self.patch_merge2 = PatchEmbed(
419 | patch_size=2, in_chans=embed_dims[1], embed_dim=embed_dims[2])
420 |
421 | self.skip2 = nn.Conv2d(embed_dims[1], embed_dims[1], 1)
422 |
423 | self.layer3 = BasicLayer(network_depth=sum(depths), dim=embed_dims[2], depth=depths[2],
424 | num_heads=num_heads[2], mlp_ratio=mlp_ratios[2],
425 | norm_layer=norm_layer[2], window_size=window_size,
426 | attn_ratio=attn_ratio[2], attn_loc='last', conv_type=conv_type[2])
427 |
428 | self.patch_split1 = PatchUnEmbed(
429 | patch_size=2, out_chans=embed_dims[3], embed_dim=embed_dims[2])
430 |
431 | assert embed_dims[1] == embed_dims[3]
432 | self.fusion1 = SKFusion(embed_dims[3])
433 |
434 | self.layer4 = BasicLayer(network_depth=sum(depths), dim=embed_dims[3], depth=depths[3],
435 | num_heads=num_heads[3], mlp_ratio=mlp_ratios[3],
436 | norm_layer=norm_layer[3], window_size=window_size,
437 | attn_ratio=attn_ratio[3], attn_loc='last', conv_type=conv_type[3])
438 |
439 | self.patch_split2 = PatchUnEmbed(
440 | patch_size=2, out_chans=embed_dims[4], embed_dim=embed_dims[3])
441 |
442 | assert embed_dims[0] == embed_dims[4]
443 | self.fusion2 = SKFusion(embed_dims[4])
444 |
445 | self.layer5 = BasicLayer(network_depth=sum(depths), dim=embed_dims[4], depth=depths[4],
446 | num_heads=num_heads[4], mlp_ratio=mlp_ratios[4],
447 | norm_layer=norm_layer[4], window_size=window_size,
448 | attn_ratio=attn_ratio[4], attn_loc='last', conv_type=conv_type[4])
449 |
450 | # merge non-overlapping patches into image
451 | self.patch_unembed = PatchUnEmbed(
452 | patch_size=1, out_chans=out_chans, embed_dim=embed_dims[4], kernel_size=3)
453 |
454 | def check_image_size(self, x):
455 | # NOTE: for I2I test
456 | _, _, h, w = x.size()
457 | mod_pad_h = (self.patch_size - h % self.patch_size) % self.patch_size
458 | mod_pad_w = (self.patch_size - w % self.patch_size) % self.patch_size
459 | x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
460 | return x
461 |
462 | def forward_features(self, x):
463 | x = self.patch_embed(x)
464 | x = self.layer1(x)
465 | skip1 = x
466 |
467 | x = self.patch_merge1(x)
468 | x = self.layer2(x)
469 | skip2 = x
470 |
471 | x = self.patch_merge2(x)
472 | x = self.layer3(x)
473 | x = self.patch_split1(x)
474 |
475 | x = self.fusion1([x, self.skip2(skip2)]) + x
476 | x = self.layer4(x)
477 | x = self.patch_split2(x)
478 |
479 | x = self.fusion2([x, self.skip1(skip1)]) + x
480 | x = self.layer5(x)
481 | x = self.patch_unembed(x)
482 | return x
483 |
484 | def forward(self, x):
485 | H, W = x.shape[2:]
486 | x = self.check_image_size(x)
487 |
488 | feat = self.forward_features(x)
489 | K, B = torch.split(feat, (1, 305), dim=1)
490 | # K, B = torch.split(feat, (1, 7), dim=1)
491 | # K, B = torch.split(feat, (1, 4), dim=1)
492 |
493 | x = K * x - B + x
494 | x = x[:, :, :H, :W]
495 | return x
496 |
497 |
498 | if __name__ == '__main__':
499 | net = DehazeFormer()
500 | device = torch.device('cpu')
501 | net.to(device)
502 | # summary(net.cuda(), (305, 64, 64))
503 |
504 |
505 | # x = torch.randn(4, 305, 64, 64).cuda()
506 | # mcplb = DehazeFormer().cuda()
507 | # out = mcplb(x)
508 | # # summary(mcplb.cuda(), (305, 64, 64))
509 |
510 |
511 |
512 |
513 | # def dehazeformer_t():
514 | # return DehazeFormer(
515 | # embed_dims=[24, 48, 96, 48, 24],
516 | # mlp_ratios=[2., 4., 4., 2., 2.],
517 | # depths=[4, 4, 4, 2, 2],
518 | # num_heads=[2, 4, 6, 1, 1],
519 | # attn_ratio=[0, 1/2, 1, 0, 0],
520 | # conv_type=['DWConv', 'DWConv', 'DWConv', 'DWConv', 'DWConv'])
521 | #
522 | #
523 | # def dehazeformer_s():  # this is the variant that was used
524 | # return DehazeFormer(
525 | # embed_dims=[24, 48, 96, 48, 24],
526 | # mlp_ratios=[2., 4., 4., 2., 2.],
527 | # depths=[8, 8, 8, 4, 4],
528 | # num_heads=[2, 4, 6, 1, 1],
529 | # attn_ratio=[1/4, 1/2, 3/4, 0, 0],
530 | # conv_type=['DWConv', 'DWConv', 'DWConv', 'DWConv', 'DWConv'])
531 | #
532 | #
533 | # def dehazeformer_b():
534 | # return DehazeFormer(
535 | # embed_dims=[24, 48, 96, 48, 24],
536 | # mlp_ratios=[2., 4., 4., 2., 2.],
537 | # depths=[16, 16, 16, 8, 8],
538 | # num_heads=[2, 4, 6, 1, 1],
539 | # attn_ratio=[1/4, 1/2, 3/4, 0, 0],
540 | # conv_type=['DWConv', 'DWConv', 'DWConv', 'DWConv', 'DWConv'])
541 | #
542 | #
543 | # def dehazeformer_d():
544 | # return DehazeFormer(
545 | # embed_dims=[24, 48, 96, 48, 24],
546 | # mlp_ratios=[2., 4., 4., 2., 2.],
547 | # depths=[32, 32, 32, 16, 16],
548 | # num_heads=[2, 4, 6, 1, 1],
549 | # attn_ratio=[1/4, 1/2, 3/4, 0, 0],
550 | # conv_type=['DWConv', 'DWConv', 'DWConv', 'DWConv', 'DWConv'])
551 | #
552 | #
553 | # def dehazeformer_w():
554 | # return DehazeFormer(
555 | # embed_dims=[48, 96, 192, 96, 48],
556 | # mlp_ratios=[2., 4., 4., 2., 2.],
557 | # depths=[16, 16, 16, 8, 8],
558 | # num_heads=[2, 4, 6, 1, 1],
559 | # attn_ratio=[1/4, 1/2, 3/4, 0, 0],
560 | # conv_type=['DWConv', 'DWConv', 'DWConv', 'DWConv', 'DWConv'])
561 | #
562 | #
563 | # def dehazeformer_m():
564 | # return DehazeFormer(
565 | # embed_dims=[24, 48, 96, 48, 24],
566 | # mlp_ratios=[2., 4., 4., 2., 2.],
567 | # depths=[12, 12, 12, 6, 6],
568 | # num_heads=[2, 4, 6, 1, 1],
569 | # attn_ratio=[1/4, 1/2, 3/4, 0, 0],
570 | # conv_type=['Conv', 'Conv', 'Conv', 'Conv', 'Conv'])
571 | #
572 | #
573 | # def dehazeformer_l():
574 | # return DehazeFormer(
575 | # embed_dims=[48, 96, 192, 96, 48],
576 | # mlp_ratios=[2., 4., 4., 2., 2.],
577 | # depths=[16, 16, 16, 12, 12],
578 | # num_heads=[2, 4, 6, 1, 1],
579 | # attn_ratio=[1/4, 1/2, 3/4, 0, 0],
580 | # conv_type=['Conv', 'Conv', 'Conv', 'Conv', 'Conv'])
--------------------------------------------------------------------------------
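`DehazeFormer.forward` above predicts one channel more than the input (out_chans=306): the feature map is split into a 1-channel gain `K`, broadcast over all bands, and a 305-channel bias `B`, and the restored cube is computed as `K * x - B + x`, a learned affine refinement of the hazy input. A tiny sketch of that final step in isolation follows (dummy tensors, purely illustrative).

```python
# Hedged illustration of the soft-reconstruction step at the end of DehazeFormer.forward.
import torch

x = torch.randn(2, 305, 16, 16)            # hazy input cube (values are dummies)
feat = torch.randn(2, 306, 16, 16)         # stand-in for forward_features(x)
K, B = torch.split(feat, (1, 305), dim=1)  # 1-channel gain, 305-channel bias
x_hat = K * x - B + x                      # broadcasted affine refinement, same shape as x
print(x_hat.shape)                         # torch.Size([2, 305, 16, 16])
```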
/HDMba/models/AIDTransformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.utils.checkpoint as checkpoint
4 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
5 | import torch.nn.functional as F
6 | from einops import rearrange, repeat
7 | from einops.layers.torch import Rearrange
8 | import math
9 | import numpy as np
10 | import time
11 | from torch import einsum
12 | import einops
13 | from torchsummary import summary
14 | from torchvision.ops import DeformConv2d
15 |
16 |
17 | class ConvBlock(nn.Module):
18 | def __init__(self, in_channel, out_channel, strides=1):
19 | super(ConvBlock, self).__init__()
20 | self.strides = strides
21 | self.in_channel=in_channel
22 | self.out_channel=out_channel
23 | self.block = nn.Sequential(
24 | nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=strides, padding=1),
25 | nn.LeakyReLU(inplace=True),
26 | nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=strides, padding=1),
27 | nn.LeakyReLU(inplace=True),
28 | )
29 | self.conv11 = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=strides, padding=0)
30 |
31 | def forward(self, x):
32 | out1 = self.block(x)
33 | out2 = self.conv11(x)
34 | out = out1 + out2
35 | return out
36 |
37 | def flops(self, H, W):
38 | flops = H*W*self.in_channel*self.out_channel*(3*3+1)+H*W*self.out_channel*self.out_channel*3*3
39 | return flops
40 |
41 |
42 | class PosCNN(nn.Module):
43 | def __init__(self, in_chans, embed_dim=768, s=1):
44 | super(PosCNN, self).__init__()
45 | self.proj = nn.Sequential(nn.Conv2d(in_chans, embed_dim, 3, s, 1, bias=True, groups=embed_dim))
46 | self.s = s
47 |
48 | def forward(self, x, H=None, W=None):
49 | B, N, C = x.shape
50 | H = H or int(math.sqrt(N))
51 | W = W or int(math.sqrt(N))
52 | feat_token = x
53 | cnn_feat = feat_token.transpose(1, 2).view(B, C, H, W)
54 | if self.s == 1:
55 | x = self.proj(cnn_feat) + cnn_feat
56 | else:
57 | x = self.proj(cnn_feat)
58 | x = x.flatten(2).transpose(1, 2)
59 | return x
60 |
61 | def no_weight_decay(self):
62 | return ['proj.%d.weight' % i for i in range(4)]
63 |
64 |
65 | class SELayer(nn.Module):
66 | def __init__(self, channel, reduction=16):
67 | super(SELayer, self).__init__()
68 | self.avg_pool = nn.AdaptiveAvgPool1d(1)
69 | self.fc = nn.Sequential(
70 | nn.Linear(channel, channel // reduction, bias=False),
71 | nn.ReLU(inplace=True),
72 | nn.Linear(channel // reduction, channel, bias=False),
73 | nn.Sigmoid()
74 | )
75 |
76 | def forward(self, x): # x: [B, N, C]
77 | x = torch.transpose(x, 1, 2) # [B, C, N]
78 | b, c, _ = x.size()
79 | y = self.avg_pool(x).view(b, c)
80 | y = self.fc(y).view(b, c, 1)
81 | x = x * y.expand_as(x)
82 | x = torch.transpose(x, 1, 2) # [B, N, C]
83 | return x
84 |
85 |
86 | class SpatialAttention(nn.Module):
87 | def __init__(self, kernel_size=7):
88 | super(SpatialAttention, self).__init__()
89 |
90 | self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
91 | self.sigmoid = nn.Sigmoid()
92 |
93 | def forward(self, x):
94 | avg_out = torch.mean(x, dim=1, keepdim=True)
95 | max_out, _ = torch.max(x, dim=1, keepdim=True)
96 | x1 = torch.cat([avg_out, max_out], dim=1)
97 | x1 = self.conv1(x1)
98 | return self.sigmoid(x1)*x
99 |
100 |
101 | class SepConv2d(torch.nn.Module):
102 | def __init__(self,
103 | in_channels,
104 | out_channels,
105 | kernel_size,
106 | stride=1,
107 | padding=0,
108 | dilation=1,act_layer=nn.ReLU):
109 | super(SepConv2d, self).__init__()
110 |
111 | self.depthwise = torch.nn.Conv2d(in_channels,
112 | in_channels,
113 | kernel_size=kernel_size,
114 | stride=stride,
115 | padding=padding,
116 | dilation=dilation,
117 | groups=in_channels)
118 | self.pointwise = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1)
119 | self.act_layer = act_layer() if act_layer is not None else nn.Identity()
120 | self.in_channels = in_channels
121 | self.out_channels = out_channels
122 | self.kernel_size = kernel_size
123 | self.stride = stride
124 | self.offset_conv1 = nn.Conv2d(in_channels, 216, 3, stride=1, padding=1, bias= False)
125 | self.deform1 = DeformConv2d(in_channels, out_channels, 3, padding=1, groups=8)
126 |
127 | self.SA = SpatialAttention()
128 |
129 | def offset_gen(self, x):
130 | sa = self.SA(x)
131 | o1, o2, mask = torch.chunk(sa, 3, dim=1)
132 | offset = torch.cat((o1, o2), dim=1)
133 | mask = torch.sigmoid(mask)
134 | return offset,mask
135 |
136 | def forward(self, x):
137 | offset1,mask = self.offset_gen(self.offset_conv1(x))
138 | feat1 = self.deform1(x, offset1, mask)
139 | x = self.act_layer(feat1)
140 | x = self.pointwise(x)
141 | return x
142 |
143 |
144 | ######## Embedding for q,k,v ########
145 | class ConvProjection(nn.Module):
146 | def __init__(self, dim, heads = 8, dim_head = 64, kernel_size=3, q_stride=1, k_stride=1, v_stride=1, dropout = 0.,
147 | last_stage=False,bias=True):
148 |
149 | super().__init__()
150 |
151 | inner_dim = dim_head * heads
152 | self.heads = heads
153 | pad = (kernel_size - q_stride)//2
154 | self.to_q = SepConv2d(dim, inner_dim, kernel_size, q_stride, pad, bias)
155 | self.to_k = SepConv2d(dim, inner_dim, kernel_size, k_stride, pad, bias)
156 | self.to_v = SepConv2d(dim, inner_dim, kernel_size, v_stride, pad, bias)
157 |
158 | def forward(self, x, attn_kv=None):
159 |         b, n, c, h = *x.shape, self.heads
160 | l = int(math.sqrt(n))
161 | w = int(math.sqrt(n))
162 |
163 | attn_kv = x if attn_kv is None else attn_kv
164 | x = rearrange(x, 'b (l w) c -> b c l w', l=l, w=w)
165 | attn_kv = rearrange(attn_kv, 'b (l w) c -> b c l w', l=l, w=w)
166 | # print(attn_kv)
167 | q = self.to_q(x)
168 | q = rearrange(q, 'b (h d) l w -> b h (l w) d', h=h)
169 |
170 | k = self.to_k(attn_kv)
171 | v = self.to_v(attn_kv)
172 | k = rearrange(k, 'b (h d) l w -> b h (l w) d', h=h)
173 | v = rearrange(v, 'b (h d) l w -> b h (l w) d', h=h)
174 | return q,k,v
175 |
176 | def flops(self, H, W):
177 | flops = 0
178 | flops += self.to_q.flops(H, W)
179 | flops += self.to_k.flops(H, W)
180 | flops += self.to_v.flops(H, W)
181 | return flops
182 |
183 |
184 | class LinearProjection(nn.Module):
185 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., bias=True):
186 | super().__init__()
187 | inner_dim = dim_head * heads
188 | self.heads = heads
189 | self.to_q = nn.Linear(dim, inner_dim, bias = bias)
190 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias = bias)
191 | self.dim = dim
192 | self.inner_dim = inner_dim
193 |
194 | def forward(self, x, attn_kv=None):
195 | B_, N, C = x.shape
196 | attn_kv = x if attn_kv is None else attn_kv
197 | q = self.to_q(x).reshape(B_, N, 1, self.heads, C // self.heads).permute(2, 0, 3, 1, 4)
198 | kv = self.to_kv(attn_kv).reshape(B_, N, 2, self.heads, C // self.heads).permute(2, 0, 3, 1, 4)
199 | q = q[0]
200 | k, v = kv[0], kv[1]
201 | return q,k,v
202 |
203 | def flops(self, H, W):
204 | flops = H*W*self.dim*self.inner_dim*3
205 | return flops
206 |
207 |
208 | class LinearProjection_Concat_kv(nn.Module):
209 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., bias=True):
210 | super().__init__()
211 | inner_dim = dim_head * heads
212 | self.heads = heads
213 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = bias)
214 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias = bias)
215 | self.dim = dim
216 | self.inner_dim = inner_dim
217 |
218 | def forward(self, x, attn_kv=None):
219 | B_, N, C = x.shape
220 | attn_kv = x if attn_kv is None else attn_kv
221 | qkv_dec = self.to_qkv(x).reshape(B_, N, 3, self.heads, C // self.heads).permute(2, 0, 3, 1, 4)
222 | kv_enc = self.to_kv(attn_kv).reshape(B_, N, 2, self.heads, C // self.heads).permute(2, 0, 3, 1, 4)
223 | q, k_d, v_d = qkv_dec[0], qkv_dec[1], qkv_dec[2] # make torchscript happy (cannot use tensor as tuple)
224 | k_e, v_e = kv_enc[0], kv_enc[1]
225 | k = torch.cat((k_d,k_e),dim=2)
226 | v = torch.cat((v_d,v_e),dim=2)
227 | return q,k,v
228 |
229 | def flops(self, H, W):
230 | flops = H*W*self.dim*self.inner_dim*5
231 | return flops
232 |
233 |
234 | #########################################
235 | class WindowAttention(nn.Module):
236 | def __init__(self, dim, win_size,num_heads, token_projection='linear', qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.,se_layer=False):
237 |
238 | super().__init__()
239 | self.dim = dim
240 | self.win_size = win_size # Wh, Ww
241 | self.num_heads = num_heads
242 | head_dim = dim // num_heads
243 | self.scale = qk_scale or head_dim ** -0.5
244 |
245 | # define a parameter table of relative position bias
246 | self.relative_position_bias_table = nn.Parameter(
247 | torch.zeros((2 * win_size[0] - 1) * (2 * win_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
248 |
249 | # get pair-wise relative position index for each token inside the window
250 | coords_h = torch.arange(self.win_size[0]) # [0,...,Wh-1]
251 | coords_w = torch.arange(self.win_size[1]) # [0,...,Ww-1]
252 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
253 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
254 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
255 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
256 | relative_coords[:, :, 0] += self.win_size[0] - 1 # shift to start from 0
257 | relative_coords[:, :, 1] += self.win_size[1] - 1
258 | relative_coords[:, :, 0] *= 2 * self.win_size[1] - 1
259 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
260 | self.register_buffer("relative_position_index", relative_position_index)
261 |
262 | # self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
263 | if token_projection =='conv':
264 | self.qkv = ConvProjection(dim,num_heads,dim//num_heads,bias=qkv_bias)
265 | elif token_projection =='linear_concat':
266 | self.qkv = LinearProjection_Concat_kv(dim,num_heads,dim//num_heads,bias=qkv_bias)
267 | else:
268 | self.qkv = LinearProjection(dim,num_heads,dim//num_heads,bias=qkv_bias)
269 |
270 | self.token_projection = token_projection
271 | self.attn_drop = nn.Dropout(attn_drop)
272 | self.proj = nn.Linear(dim, dim)
273 | self.se_layer = SELayer(dim) if se_layer else nn.Identity()
274 | self.proj_drop = nn.Dropout(proj_drop)
275 |
276 | trunc_normal_(self.relative_position_bias_table, std=.02)
277 | self.softmax = nn.Softmax(dim=-1)
278 |
279 | def forward(self, x, attn_kv=None, mask=None):
280 | B_, N, C = x.shape
281 | q, k, v = self.qkv(x,attn_kv)
282 | q = q * self.scale
283 | attn = (q @ k.transpose(-2, -1))
284 |
285 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
286 | self.win_size[0] * self.win_size[1], self.win_size[0] * self.win_size[1], -1) # Wh*Ww,Wh*Ww,nH
287 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
288 | ratio = attn.size(-1)//relative_position_bias.size(-1)
289 | relative_position_bias = repeat(relative_position_bias, 'nH l c -> nH l (c d)', d = ratio)
290 |
291 | attn = attn + relative_position_bias.unsqueeze(0)
292 |
293 | if mask is not None:
294 | nW = mask.shape[0]
295 | mask = repeat(mask, 'nW m n -> nW m (n d)',d = ratio)
296 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N*ratio) + mask.unsqueeze(1).unsqueeze(0)
297 | attn = attn.view(-1, self.num_heads, N, N*ratio)
298 | attn = self.softmax(attn)
299 | else:
300 | attn = self.softmax(attn)
301 |
302 | attn = self.attn_drop(attn)
303 |
304 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
305 | x = self.proj(x)
306 | x = self.se_layer(x)
307 | x = self.proj_drop(x)
308 | return x
309 |
310 | def extra_repr(self) -> str:
311 | return f'dim={self.dim}, win_size={self.win_size}, num_heads={self.num_heads}'
312 |
313 |
314 | class Mlp(nn.Module):
315 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
316 | super().__init__()
317 | out_features = out_features or in_features
318 | hidden_features = hidden_features or in_features
319 | self.fc1 = nn.Linear(in_features, hidden_features)
320 | self.act = act_layer()
321 | self.fc2 = nn.Linear(hidden_features, out_features)
322 | self.drop = nn.Dropout(drop)
323 | self.in_features = in_features
324 | self.hidden_features = hidden_features
325 | self.out_features = out_features
326 |
327 | def forward(self, x):
328 | x = self.fc1(x)
329 | x = self.act(x)
330 | x = self.drop(x)
331 | x = self.fc2(x)
332 | x = self.drop(x)
333 | return x
334 |
335 |
336 | class LeFF(nn.Module):
337 | def __init__(self, dim=32, hidden_dim=128, act_layer=nn.GELU,drop = 0.):
338 | super().__init__()
339 | self.linear1 = nn.Sequential(nn.Linear(dim, hidden_dim),
340 | act_layer())
341 | self.dwconv = nn.Sequential(nn.Conv2d(hidden_dim,hidden_dim,groups=hidden_dim,kernel_size=3,stride=1,padding=1),
342 | act_layer())
343 | self.linear2 = nn.Sequential(nn.Linear(hidden_dim, dim))
344 | self.dim = dim
345 | self.hidden_dim = hidden_dim
346 |
347 | def forward(self, x):
348 | # bs x hw x c
349 | bs, hw, c = x.size()
350 | hh = int(math.sqrt(hw))
351 |
352 | x = self.linear1(x)
353 |
354 | # spatial restore
355 | x = rearrange(x, ' b (h w) (c) -> b c h w ', h = hh, w = hh)
356 | # bs,hidden_dim,32x32
357 |
358 | x = self.dwconv(x)
359 |
360 |         # flatten
361 | x = rearrange(x, ' b c h w -> b (h w) c', h = hh, w = hh)
362 |
363 | x = self.linear2(x)
364 |
365 | return x
366 |
367 |
368 | def window_partition(x, win_size, dilation_rate=1):
369 | B, H, W, C = x.shape
370 | if dilation_rate != 1:
371 | x = x.permute(0,3,1,2) # B, C, H, W
372 |         assert type(dilation_rate) is int, 'dilation_rate should be an int'
373 | x = F.unfold(x, kernel_size=win_size,dilation=dilation_rate,padding=4*(dilation_rate-1),stride=win_size) # B, C*Wh*Ww, H/Wh*W/Ww
374 | windows = x.permute(0,2,1).contiguous().view(-1, C, win_size, win_size) # B' ,C ,Wh ,Ww
375 | windows = windows.permute(0,2,3,1).contiguous() # B' ,Wh ,Ww ,C
376 | else:
377 | x = x.view(B, H // win_size, win_size, W // win_size, win_size, C)
378 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, win_size, win_size, C) # B' ,Wh ,Ww ,C
379 | return windows
380 |
381 |
382 | def window_reverse(windows, win_size, H, W, dilation_rate=1):
383 | # B' ,Wh ,Ww ,C
384 | B = int(windows.shape[0] / (H * W / win_size / win_size))
385 | x = windows.view(B, H // win_size, W // win_size, win_size, win_size, -1)
386 | if dilation_rate !=1:
387 | x = windows.permute(0,5,3,4,1,2).contiguous() # B, C*Wh*Ww, H/Wh*W/Ww
388 | x = F.fold(x, (H, W), kernel_size=win_size, dilation=dilation_rate, padding=4*(dilation_rate-1),stride=win_size)
389 | else:
390 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
391 | return x
392 |
393 |
394 | class Downsample(nn.Module):
395 | def __init__(self, in_channel, out_channel):
396 | super(Downsample, self).__init__()
397 | self.conv = nn.Sequential(
398 | nn.Conv2d(in_channel, out_channel, kernel_size=4, stride=2, padding=1),
399 | )
400 | self.in_channel = in_channel
401 | self.out_channel = out_channel
402 |
403 | def forward(self, x):
404 | B, L, C = x.shape
405 | # import pdb;pdb.set_trace()
406 | H = int(math.sqrt(L))
407 | W = int(math.sqrt(L))
408 | x = x.transpose(1, 2).contiguous().view(B, C, H, W)
409 | out = self.conv(x).flatten(2).transpose(1,2).contiguous() # B H*W C
410 | return out
411 |
412 |
413 | # Upsample Block
414 | class Upsample(nn.Module):
415 | def __init__(self, in_channel, out_channel):
416 | super(Upsample, self).__init__()
417 | self.deconv = nn.Sequential(
418 | nn.ConvTranspose2d(in_channel, out_channel, kernel_size=2, stride=2),
419 | )
420 | self.in_channel = in_channel
421 | self.out_channel = out_channel
422 |
423 | def forward(self, x):
424 | B, L, C = x.shape
425 | H = int(math.sqrt(L))
426 | W = int(math.sqrt(L))
427 | x = x.transpose(1, 2).contiguous().view(B, C, H, W)
428 | out = self.deconv(x).flatten(2).transpose(1,2).contiguous() # B H*W C
429 | return out
430 |
431 |
432 | class Upsample_advanced(nn.Module):
433 | def __init__(self, in_channel, out_channel,factor):
434 | super(Upsample_advanced, self).__init__()
435 | self.deconv = nn.Sequential(
436 | nn.ConvTranspose2d(in_channel, out_channel, kernel_size=2, stride=factor),
437 | )
438 | self.in_channel = in_channel
439 | self.out_channel = out_channel
440 |
441 | def forward(self, x):
442 | B, L, C = x.shape
443 | H = int(math.sqrt(L))
444 | W = int(math.sqrt(L))
445 | x = x.transpose(1, 2).contiguous().view(B, C, H, W)
446 | out = self.deconv(x).flatten(2).transpose(1,2).contiguous() # B H*W C
447 | return out
448 |
449 |
450 | # Input Projection
451 | class InputProj(nn.Module):
452 | def __init__(self, in_channel=305, out_channel=64, kernel_size=3, stride=1, norm_layer=None,act_layer=nn.LeakyReLU):
453 | super().__init__()
454 | self.proj = nn.Sequential(
455 | nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=kernel_size//2),
456 | act_layer(inplace=True)
457 | )
458 | if norm_layer is not None:
459 | self.norm = norm_layer(out_channel)
460 | else:
461 | self.norm = None
462 | self.in_channel = in_channel
463 | self.out_channel = out_channel
464 |
465 | def forward(self, x):
466 | B, C, H, W = x.shape
467 | x = self.proj(x).flatten(2).transpose(1, 2).contiguous() # B H*W C
468 | if self.norm is not None:
469 | x = self.norm(x)
470 | return x
471 |
472 |
473 | # Output Projection
474 | class OutputProj(nn.Module):
475 | def __init__(self, in_channel=64, out_channel=305, kernel_size=3, stride=1, norm_layer=None,act_layer=None):
476 | super().__init__()
477 | self.proj = nn.Sequential(
478 | nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=kernel_size//2),
479 | )
480 | if act_layer is not None:
481 |             self.proj.add_module('act', act_layer(inplace=True))  # add_module requires a name
482 | if norm_layer is not None:
483 | self.norm = norm_layer(out_channel)
484 | else:
485 | self.norm = None
486 | self.in_channel = in_channel
487 | self.out_channel = out_channel
488 |
489 | def forward(self, x):
490 | B, L, C = x.shape
491 | H = int(math.sqrt(L))
492 | W = int(math.sqrt(L))
493 | x = x.transpose(1, 2).view(B, C, H, W)
494 | x = self.proj(x)
495 | if self.norm is not None:
496 | x = self.norm(x)
497 | return x
498 |
499 |
500 | class Deformable_Attentive_Transformer(nn.Module):
501 | def __init__(self, dim, input_resolution, num_heads, win_size=8, shift_size=0,
502 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
503 | act_layer=nn.GELU, norm_layer=nn.LayerNorm,token_projection='conv',token_mlp='leff',se_layer=False):
504 | super().__init__()
505 | self.dim = dim
506 | self.input_resolution = input_resolution
507 | self.num_heads = num_heads
508 | self.win_size = win_size
509 | self.shift_size = shift_size
510 | self.mlp_ratio = mlp_ratio
511 | self.token_mlp = token_mlp
512 | if min(self.input_resolution) <= self.win_size:
513 | self.shift_size = 0
514 | self.win_size = min(self.input_resolution)
515 |         assert 0 <= self.shift_size < self.win_size, "shift_size must be in the range [0, win_size)"
516 |
517 | self.norm1 = norm_layer(dim)
518 | self.attn = WindowAttention(
519 | dim, win_size=to_2tuple(self.win_size), num_heads=num_heads,
520 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
521 | token_projection=token_projection,se_layer=se_layer)
522 |
523 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
524 | self.norm2 = norm_layer(dim)
525 | mlp_hidden_dim = int(dim * mlp_ratio)
526 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,act_layer=act_layer, drop=drop) if token_mlp=='ffn' else LeFF(dim,mlp_hidden_dim,act_layer=act_layer, drop=drop)
527 |
528 |
529 | def extra_repr(self) -> str:
530 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
531 | f"win_size={self.win_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
532 |
533 | def forward(self, x, mask=None):
534 | B, L, C = x.shape
535 | H = int(math.sqrt(L))
536 | W = int(math.sqrt(L))
537 |
538 | ## input mask
539 |         if mask is not None:  # avoid elementwise comparison of a tensor with None
540 | input_mask = F.interpolate(mask, size=(H,W)).permute(0,2,3,1)
541 | input_mask_windows = window_partition(input_mask, self.win_size) # nW, win_size, win_size, 1
542 | attn_mask = input_mask_windows.view(-1, self.win_size * self.win_size) # nW, win_size*win_size
543 | attn_mask = attn_mask.unsqueeze(2)*attn_mask.unsqueeze(1) # nW, win_size*win_size, win_size*win_size
544 | attn_mask = attn_mask.masked_fill(attn_mask!=0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
545 | else:
546 | attn_mask = None
547 |
548 | ## shift mask
549 | if self.shift_size > 0:
550 | # calculate attention mask for SW-MSA
551 | shift_mask = torch.zeros((1, H, W, 1)).type_as(x)
552 | h_slices = (slice(0, -self.win_size),
553 | slice(-self.win_size, -self.shift_size),
554 | slice(-self.shift_size, None))
555 | w_slices = (slice(0, -self.win_size),
556 | slice(-self.win_size, -self.shift_size),
557 | slice(-self.shift_size, None))
558 | cnt = 0
559 | for h in h_slices:
560 | for w in w_slices:
561 | shift_mask[:, h, w, :] = cnt
562 | cnt += 1
563 | shift_mask_windows = window_partition(shift_mask, self.win_size) # nW, win_size, win_size, 1
564 | shift_mask_windows = shift_mask_windows.view(-1, self.win_size * self.win_size) # nW, win_size*win_size
565 | shift_attn_mask = shift_mask_windows.unsqueeze(1) - shift_mask_windows.unsqueeze(2) # nW, win_size*win_size, win_size*win_size
566 | shift_attn_mask = shift_attn_mask.masked_fill(shift_attn_mask != 0, float(-100.0)).masked_fill(shift_attn_mask == 0, float(0.0))
567 | attn_mask = attn_mask + shift_attn_mask if attn_mask is not None else shift_attn_mask
568 |
569 | shortcut = x
570 | x = self.norm1(x)
571 | x = x.view(B, H, W, C)
572 |
573 | # cyclic shift
574 | if self.shift_size > 0:
575 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
576 | else:
577 | shifted_x = x
578 |
579 | # partition windows
580 | x_windows = window_partition(shifted_x, self.win_size) # nW*B, win_size, win_size, C N*C->C
581 | x_windows = x_windows.view(-1, self.win_size * self.win_size, C) # nW*B, win_size*win_size, C
582 |
583 | # W-MSA/SW-MSA
584 | attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, win_size*win_size, C
585 |
586 | # merge windows
587 | attn_windows = attn_windows.view(-1, self.win_size, self.win_size, C)
588 | shifted_x = window_reverse(attn_windows, self.win_size, H, W) # B H' W' C
589 |
590 | # reverse cyclic shift
591 | if self.shift_size > 0:
592 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
593 | else:
594 | x = shifted_x
595 | x = x.view(B, H * W, C)
596 |
597 | # FFN
598 | x = shortcut + self.drop_path(x)
599 | x = x + self.drop_path(self.mlp(self.norm2(x)))
600 | del attn_mask
601 | return x
602 |
603 |
604 | class TRANSFORMER_BLOCK(nn.Module):
605 | def __init__(self, dim, output_dim, input_resolution, depth, num_heads, win_size,
606 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
607 | drop_path=0., norm_layer=nn.LayerNorm, use_checkpoint=False,
608 | token_projection='linear',token_mlp='ffn',se_layer=False):
609 |
610 | super().__init__()
611 | self.dim = dim
612 | self.input_resolution = input_resolution
613 | self.depth = depth
614 | self.use_checkpoint = use_checkpoint
615 | # build blocks
616 | self.blocks = nn.ModuleList([
617 | Deformable_Attentive_Transformer(dim=dim, input_resolution=input_resolution,
618 | num_heads=num_heads, win_size=win_size,
619 | shift_size=0 if (i % 2 == 0) else win_size // 2,
620 | mlp_ratio=mlp_ratio,
621 | qkv_bias=qkv_bias, qk_scale=qk_scale,
622 | drop=drop, attn_drop=attn_drop,
623 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
624 | norm_layer=norm_layer,token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer)
625 | for i in range(depth)])
626 |
627 | def extra_repr(self) -> str:
628 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
629 |
630 | def forward(self, x, mask=None):
631 | for blk in self.blocks:
632 | if self.use_checkpoint:
633 | x = checkpoint.checkpoint(blk, x)
634 | else:
635 | x = blk(x,mask)
636 | return x
637 |
638 | def flops(self):
639 | flops = 0
640 | for blk in self.blocks:
641 | flops += blk.flops()
642 | return flops
643 |
644 |
645 | class AIDTransformer(nn.Module):
646 | # def __init__(self, img_size=64, in_chans=4,
647 | # def __init__(self, img_size=64, in_chans=7,
648 | def __init__(self, img_size=64, in_chans=305,
649 | embed_dim=32, depths=[2, 2, 2, 2, 2, 2, 2, 2, 2], num_heads=[1, 2, 4, 8, 16, 16, 8, 4, 2],
650 | win_size=8, mlp_ratio=4., qkv_bias=True, qk_scale=None,
651 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
652 | norm_layer=nn.LayerNorm, patch_norm=True,
653 | use_checkpoint=False, token_projection='conv', token_mlp='ffn', se_layer=False,
654 | dowsample=Downsample, upsample=Upsample, **kwargs):
655 | super().__init__()
656 |
657 | self.num_enc_layers = len(depths)//2
658 | self.num_dec_layers = len(depths)//2
659 | self.embed_dim = embed_dim
660 | self.patch_norm = patch_norm
661 | self.mlp_ratio = mlp_ratio
662 | self.token_projection = token_projection
663 | self.mlp = token_mlp
664 | self.win_size =win_size
665 | self.reso = img_size
666 | self.pos_drop = nn.Dropout(p=drop_rate)
667 |
668 | # stochastic depth
669 | enc_dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths[:self.num_enc_layers]))]
670 | conv_dpr = [drop_path_rate]*depths[4]
671 | dec_dpr = enc_dpr[::-1]
672 |
673 | # build layers
674 |
675 | # Input/Output
676 | self.input_proj = InputProj(in_channel=in_chans, out_channel=embed_dim, kernel_size=3, stride=1, act_layer=nn.LeakyReLU)
677 | self.output_proj = OutputProj(in_channel=3*embed_dim, out_channel=in_chans, kernel_size=3, stride=1)
678 | self.output_proj1 = OutputProj(in_channel=embed_dim, out_channel=in_chans, kernel_size=3, stride=1)
679 | # Encoder
680 | self.encoderlayer_0 = TRANSFORMER_BLOCK(dim=embed_dim,
681 | output_dim=embed_dim,
682 | input_resolution=(img_size,
683 | img_size),
684 | depth=depths[0],
685 | num_heads=num_heads[0],
686 | win_size=win_size,
687 | mlp_ratio=self.mlp_ratio,
688 | qkv_bias=qkv_bias, qk_scale=qk_scale,
689 | drop=drop_rate, attn_drop=attn_drop_rate,
690 | drop_path=enc_dpr[sum(depths[:0]):sum(depths[:1])],
691 | norm_layer=norm_layer,
692 | use_checkpoint=use_checkpoint,
693 | token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer)
694 | self.dowsample_0 = dowsample(embed_dim, embed_dim*2)
695 | self.hs_downsample = dowsample(embed_dim, embed_dim)
696 | self.hs_upsample = upsample(3*embed_dim,embed_dim)
697 | self.qs_upsample = upsample(5*embed_dim,embed_dim)
698 | self.qs_upsample1 = upsample(embed_dim,embed_dim)
699 |
700 | self.upsample_poolup1 = upsample(3*embed_dim,embed_dim)
701 | self.upsample_poolup2 = upsample(5*embed_dim,embed_dim*3)
702 |
703 | self.encoderlayer_1 = TRANSFORMER_BLOCK(dim=embed_dim*3,
704 | output_dim=embed_dim*2,
705 | input_resolution=(img_size // 2,
706 | img_size // 2),
707 | depth=depths[1],
708 | num_heads=num_heads[1],
709 | win_size=win_size,
710 | mlp_ratio=self.mlp_ratio,
711 | qkv_bias=qkv_bias, qk_scale=qk_scale,
712 | drop=drop_rate, attn_drop=attn_drop_rate,
713 | drop_path=enc_dpr[sum(depths[:1]):sum(depths[:2])],
714 | norm_layer=norm_layer,
715 | use_checkpoint=use_checkpoint,
716 | token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer)
717 | self.dowsample_1 = dowsample(embed_dim*3, embed_dim*4)
718 | self.encoderlayer_2 = TRANSFORMER_BLOCK(dim=embed_dim*5,
719 | output_dim=embed_dim*4,
720 | input_resolution=(img_size // (2 ** 2),
721 | img_size // (2 ** 2)),
722 | depth=depths[2],
723 | num_heads=num_heads[2],
724 | win_size=win_size,
725 | mlp_ratio=self.mlp_ratio,
726 | qkv_bias=qkv_bias, qk_scale=qk_scale,
727 | drop=drop_rate, attn_drop=attn_drop_rate,
728 | drop_path=enc_dpr[sum(depths[:2]):sum(depths[:3])],
729 | norm_layer=norm_layer,
730 | use_checkpoint=use_checkpoint,
731 | token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer)
732 | self.dowsample_2 = dowsample(embed_dim*5, embed_dim*8)
733 | self.encoderlayer_3 = TRANSFORMER_BLOCK(dim=embed_dim*8,
734 | output_dim=embed_dim*8,
735 | input_resolution=(img_size // (2 ** 3),
736 | img_size // (2 ** 3)),
737 | depth=depths[3],
738 | num_heads=num_heads[3],
739 | win_size=win_size,
740 | mlp_ratio=self.mlp_ratio,
741 | qkv_bias=qkv_bias, qk_scale=qk_scale,
742 | drop=drop_rate, attn_drop=attn_drop_rate,
743 | drop_path=enc_dpr[sum(depths[:3]):sum(depths[:4])],
744 | norm_layer=norm_layer,
745 | use_checkpoint=use_checkpoint,
746 | token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer)
747 | self.dowsample_3 = dowsample(embed_dim*8, embed_dim*16)
748 |
749 | # Bottleneck
750 | self.conv = TRANSFORMER_BLOCK(dim=embed_dim*9,
751 | output_dim=embed_dim*8,
752 | input_resolution=(img_size // (2 ** 4),
753 | img_size // (2 ** 4)),
754 | depth=depths[4],
755 | num_heads=num_heads[4],
756 | win_size=win_size,
757 | mlp_ratio=self.mlp_ratio,
758 | qkv_bias=qkv_bias, qk_scale=qk_scale,
759 | drop=drop_rate, attn_drop=attn_drop_rate,
760 | drop_path=conv_dpr,
761 | norm_layer=norm_layer,
762 | use_checkpoint=use_checkpoint,
763 | token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer)
764 |
765 | # Decoder
766 | self.upsample_0 = upsample(embed_dim*9, embed_dim*8)
767 | self.decoderlayer_0 = TRANSFORMER_BLOCK(dim=embed_dim*14,
768 | output_dim=embed_dim*16,
769 | input_resolution=(img_size // (2 ** 3),
770 | img_size // (2 ** 3)),
771 | depth=depths[5],
772 | num_heads=num_heads[5],
773 | win_size=win_size,
774 | mlp_ratio=self.mlp_ratio,
775 | qkv_bias=qkv_bias, qk_scale=qk_scale,
776 | drop=drop_rate, attn_drop=attn_drop_rate,
777 | drop_path=dec_dpr[:depths[5]],
778 | norm_layer=norm_layer,
779 | use_checkpoint=use_checkpoint,
780 | token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer)
781 | self.upsample_1 = upsample(embed_dim*14, embed_dim*4)
782 | self.decoderlayer_1 = TRANSFORMER_BLOCK(dim=embed_dim*8,
783 | output_dim=embed_dim*8,
784 | input_resolution=(img_size // (2 ** 2),
785 | img_size // (2 ** 2)),
786 | depth=depths[6],
787 | num_heads=num_heads[6],
788 | win_size=win_size,
789 | mlp_ratio=self.mlp_ratio,
790 | qkv_bias=qkv_bias, qk_scale=qk_scale,
791 | drop=drop_rate, attn_drop=attn_drop_rate,
792 | drop_path=dec_dpr[sum(depths[5:6]):sum(depths[5:7])],
793 | norm_layer=norm_layer,
794 | use_checkpoint=use_checkpoint,
795 | token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer)
796 | self.upsample_2 = upsample(embed_dim*8, embed_dim*2)
797 | self.decoderlayer_2 = TRANSFORMER_BLOCK(dim=embed_dim*3,
798 | output_dim=embed_dim*4,
799 | input_resolution=(img_size // 2,
800 | img_size // 2),
801 | depth=depths[7],
802 | num_heads=num_heads[7],
803 | win_size=win_size,
804 | mlp_ratio=self.mlp_ratio,
805 | qkv_bias=qkv_bias, qk_scale=qk_scale,
806 | drop=drop_rate, attn_drop=attn_drop_rate,
807 | drop_path=dec_dpr[sum(depths[5:7]):sum(depths[5:8])],
808 | norm_layer=norm_layer,
809 | use_checkpoint=use_checkpoint,
810 | token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer)
811 | self.upsample_3 = upsample(embed_dim*4, embed_dim)
812 | self.decoderlayer_3 = TRANSFORMER_BLOCK(dim=embed_dim*2,
813 | output_dim=embed_dim*2,
814 | input_resolution=(img_size,
815 | img_size),
816 | depth=depths[8],
817 | num_heads=num_heads[8],
818 | win_size=win_size,
819 | mlp_ratio=self.mlp_ratio,
820 | qkv_bias=qkv_bias, qk_scale=qk_scale,
821 | drop=drop_rate, attn_drop=attn_drop_rate,
822 | drop_path=dec_dpr[sum(depths[5:8]):sum(depths[5:9])],
823 | norm_layer=norm_layer,
824 | use_checkpoint=use_checkpoint,
825 | token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer)
826 | # 1/2 scale stream
827 | self.hs1 = TRANSFORMER_BLOCK(dim=embed_dim,
828 | output_dim=embed_dim,
829 | input_resolution=(img_size // 2,
830 | img_size // 2),
831 | depth=depths[1],
832 | num_heads=num_heads[1],
833 | win_size=win_size,
834 | mlp_ratio=self.mlp_ratio,
835 | qkv_bias=qkv_bias, qk_scale=qk_scale,
836 | drop=drop_rate, attn_drop=attn_drop_rate,
837 | drop_path=enc_dpr[sum(depths[:1]):sum(depths[:2])],
838 | norm_layer=norm_layer,
839 | use_checkpoint=use_checkpoint,
840 | token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer)
841 | self.hs2 = TRANSFORMER_BLOCK(dim=3*embed_dim,
842 | output_dim=embed_dim,
843 | input_resolution=(img_size // 2,
844 | img_size // 2),
845 | depth=depths[1],
846 | num_heads=num_heads[1],
847 | win_size=win_size,
848 | mlp_ratio=self.mlp_ratio,
849 | qkv_bias=qkv_bias, qk_scale=qk_scale,
850 | drop=drop_rate, attn_drop=attn_drop_rate,
851 | drop_path=enc_dpr[sum(depths[:1]):sum(depths[:2])],
852 | norm_layer=norm_layer,
853 | use_checkpoint=use_checkpoint,
854 | token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer)
855 | self.hs3 = TRANSFORMER_BLOCK(dim=3*embed_dim,
856 | output_dim=embed_dim,
857 | input_resolution=(img_size // 2,
858 | img_size // 2),
859 | depth=depths[1],
860 | num_heads=num_heads[1],
861 | win_size=win_size,
862 | mlp_ratio=self.mlp_ratio,
863 | qkv_bias=qkv_bias, qk_scale=qk_scale,
864 | drop=drop_rate, attn_drop=attn_drop_rate,
865 | drop_path=enc_dpr[sum(depths[:1]):sum(depths[:2])],
866 | norm_layer=norm_layer,
867 | use_checkpoint=use_checkpoint,
868 | token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer)
869 | self.hs4 = TRANSFORMER_BLOCK(dim=3*embed_dim,
870 | output_dim=embed_dim,
871 | input_resolution=(img_size // 2,
872 | img_size // 2),
873 | depth=depths[1],
874 | num_heads=num_heads[1],
875 | win_size=win_size,
876 | mlp_ratio=self.mlp_ratio,
877 | qkv_bias=qkv_bias, qk_scale=qk_scale,
878 | drop=drop_rate, attn_drop=attn_drop_rate,
879 | drop_path=enc_dpr[sum(depths[:1]):sum(depths[:2])],
880 | norm_layer=norm_layer,
881 | use_checkpoint=use_checkpoint,
882 | token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer)
883 |
884 | self.qs1 = TRANSFORMER_BLOCK(dim=embed_dim,
885 | output_dim=embed_dim,
886 | input_resolution=(img_size // (2 ** 2),
887 | img_size // (2 ** 2)),
888 | depth=depths[2],
889 | num_heads=num_heads[2],
890 | win_size=win_size,
891 | mlp_ratio=self.mlp_ratio,
892 | qkv_bias=qkv_bias, qk_scale=qk_scale,
893 | drop=drop_rate, attn_drop=attn_drop_rate,
894 | drop_path=enc_dpr[sum(depths[:2]):sum(depths[:3])],
895 | norm_layer=norm_layer,
896 | use_checkpoint=use_checkpoint,
897 | token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer)
898 | self.qs2 = TRANSFORMER_BLOCK(dim=5*embed_dim,
899 | output_dim=embed_dim,
900 | input_resolution=(img_size // (2 ** 2),
901 | img_size // (2 ** 2)),
902 | depth=depths[2],
903 | num_heads=num_heads[2],
904 | win_size=win_size,
905 | mlp_ratio=self.mlp_ratio,
906 | qkv_bias=qkv_bias, qk_scale=qk_scale,
907 | drop=drop_rate, attn_drop=attn_drop_rate,
908 | drop_path=enc_dpr[sum(depths[:2]):sum(depths[:3])],
909 | norm_layer=norm_layer,
910 | use_checkpoint=use_checkpoint,
911 | token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer)
912 | self.qs3 = TRANSFORMER_BLOCK(dim=5*embed_dim,
913 | output_dim=embed_dim,
914 | input_resolution=(img_size // (2 ** 2),
915 | img_size // (2 ** 2)),
916 | depth=depths[2],
917 | num_heads=num_heads[2],
918 | win_size=win_size,
919 | mlp_ratio=self.mlp_ratio,
920 | qkv_bias=qkv_bias, qk_scale=qk_scale,
921 | drop=drop_rate, attn_drop=attn_drop_rate,
922 | drop_path=enc_dpr[sum(depths[:2]):sum(depths[:3])],
923 | norm_layer=norm_layer,
924 | use_checkpoint=use_checkpoint,
925 | token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer)
926 | self.qs4 = TRANSFORMER_BLOCK(dim=5*embed_dim,
927 | output_dim=embed_dim,
928 | input_resolution=(img_size // (2 ** 2),
929 | img_size // (2 ** 2)),
930 | depth=depths[2],
931 | num_heads=num_heads[2],
932 | win_size=win_size,
933 | mlp_ratio=self.mlp_ratio,
934 | qkv_bias=qkv_bias, qk_scale=qk_scale,
935 | drop=drop_rate, attn_drop=attn_drop_rate,
936 | drop_path=enc_dpr[sum(depths[:2]):sum(depths[:3])],
937 | norm_layer=norm_layer,
938 | use_checkpoint=use_checkpoint,
939 | token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer)
940 |
941 | # self.edge_booster =
942 | self.apply(self._init_weights)
943 |
944 | def _init_weights(self, m):
945 | if isinstance(m, nn.Linear):
946 | trunc_normal_(m.weight, std=.02)
947 | if isinstance(m, nn.Linear) and m.bias is not None:
948 | nn.init.constant_(m.bias, 0)
949 | elif isinstance(m, nn.LayerNorm):
950 | nn.init.constant_(m.bias, 0)
951 | nn.init.constant_(m.weight, 1.0)
952 |
953 | @torch.jit.ignore
954 | def no_weight_decay(self):
955 | return {'absolute_pos_embed'}
956 |
957 | @torch.jit.ignore
958 | def no_weight_decay_keywords(self):
959 | return {'relative_position_bias_table'}
960 |
961 | def extra_repr(self) -> str:
962 | return f"embed_dim={self.embed_dim}, token_projection={self.token_projection}, token_mlp={self.mlp},win_size={self.win_size}"
963 |
964 | def forward(self, x, mask=None):
965 | y = self.input_proj(x)
966 | y = self.pos_drop(y)
967 | conv01 = self.encoderlayer_0(y,mask=mask)
968 | conv11 = self.encoderlayer_0(conv01,mask=mask)
969 | conv21 = self.encoderlayer_0(conv11,mask=mask)
970 | conv31 = self.encoderlayer_0(conv21,mask=mask)
971 | conv41 = self.encoderlayer_0(conv31,mask=mask)
972 |         conv51 = self.encoderlayer_0(conv41,mask=mask)
973 | conv0 = self.encoderlayer_0(y,mask=mask)
974 | pool0 = self.dowsample_0(conv0)
975 | pool0 = torch.cat([pool0, self.hs_downsample(conv01)],-1)
976 | poolup0 = self.upsample_poolup1(pool0)
977 | difference1 = conv0 - poolup0
978 | conv1 = self.encoderlayer_1(pool0,mask=mask)
979 | pool1 = self.dowsample_1(conv1)
980 | pool1 = torch.cat([pool1, self.hs_downsample(self.hs_downsample(conv11))],-1)
981 | poolup1 = self.upsample_poolup2(pool1)
982 | difference2 = conv1 - poolup1
983 |
984 | conv2 = self.encoderlayer_2(pool1,mask=mask)
985 | pool2 = self.dowsample_2(conv2)
986 | pool2 = torch.cat([pool2, self.hs_downsample(self.hs_downsample(self.hs_downsample(conv21)))],-1)
987 | conv3 = self.conv(pool2, mask=mask)
988 | up0 = self.upsample_0(conv3)
989 | deconv0 = torch.cat([up0,conv2,self.hs_downsample(self.hs_downsample(conv31))],-1)
990 | deconv0 = self.decoderlayer_0(deconv0,mask=mask)
991 | up1 = self.upsample_1(deconv0)
992 | deconv1 = torch.cat([up1,difference2,self.hs_downsample(conv41)],-1)
993 | deconv1 = self.decoderlayer_1(deconv1,mask=mask)
994 | up2 = self.upsample_2(deconv1)
995 | deconv2 = torch.cat([up2,difference1],-1)
996 | deconv2 = self.decoderlayer_2(deconv2,mask=mask)
997 | z = self.output_proj1(conv51)
998 | y = self.output_proj(deconv2)
999 | return x + z + y
1000 |
1001 |
1002 | if __name__ == '__main__':
1003 | net = AIDTransformer()
1004 | device = torch.device('cpu')
1005 | net.to(device)
1006 | # summary(net.cuda(), (305, 64, 64))
1007 |
1008 | # x = torch.randn(4, 305, 64, 64).cuda()
1009 | # mcplb = AIDTransformer().cuda()
1010 | # out = mcplb(x)
1011 | # summary(mcplb.cuda(), (305, 64, 64))
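1012 | 
1013 |     # Minimal CPU smoke test (a sketch, under assumptions): uses the default
1014 |     # 305-band, 64x64 configuration above and assumes torchvision >= 0.10 for
1015 |     # the modulated deformable convolution used inside SepConv2d.
1016 |     x = torch.randn(1, 305, 64, 64, device=device)
1017 |     with torch.no_grad():
1018 |         out = net(x)
1019 |     print(out.shape)  # expected: torch.Size([1, 305, 64, 64]), same shape as the input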
--------------------------------------------------------------------------------