├── LICENSE
├── README.md
├── common.py
├── data
│   ├── HStest.py
│   ├── HStrain.py
│   └── __init__.py
├── demo.sh
├── loss.py
├── main_CST.py
├── metrics.py
├── network
│   ├── CST.py
│   └── csa.py
├── test_demo.sh
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Tomchenshi
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CST
2 | The PyTorch implementation of the hyperspectral image super-resolution method CST.
3 |
4 | ## Requirements
5 | * Python 3.6.13
6 | * PyTorch 1.8
7 |
8 | ## Preparation
9 | To prepare the training, validation and testing sets, refer to SSPSR and download the MATLAB codes (mcodes) for cropping the hyperspectral images.
10 |
11 | ## Training
12 | To train CST, run the following command.
13 | ```
14 | sh demo.sh
15 | ```
16 | ## Testing
17 | To test CST, run the following command.
18 | ```
19 | sh test_demo.sh
20 | ```
21 | ## References
22 | * [SSPSR](https://github.com/junjun-jiang/SSPSR)
23 |
--------------------------------------------------------------------------------
/common.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 | import torch
4 | import torch.nn.functional as F
5 |
6 |
7 |
8 |
9 | def default_conv(in_channels, out_channels, kernel_size, stride=1, bias=True, dilation=1, groups=1):
10 |     if dilation==1:
11 |         return nn.Conv2d(
12 |             in_channels, out_channels, kernel_size, stride,
13 |             padding=(kernel_size//2), bias=bias, groups=groups)
14 |     elif dilation==2:
15 |         return nn.Conv2d(
16 |             in_channels, out_channels, kernel_size, stride,
17 |             padding=2, bias=bias, dilation=dilation, groups=groups)
18 |
19 | else:
20 | padding = int((kernel_size - 1) / 2) * dilation
21 | return nn.Conv2d(
22 | in_channels, out_channels, kernel_size,
23 | stride, padding=padding, bias=bias, dilation=dilation, groups=groups)
24 |
25 |
26 | class CALayer(nn.Module):
27 | def __init__(self, channel, reduction=16):
28 | super(CALayer, self).__init__()
29 | # global average pooling: feature --> point
30 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
31 | # feature channel downscale and upscale --> channel weight
32 | self.conv_du = nn.Sequential(
33 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
34 | nn.ReLU(inplace=True),
35 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
36 | nn.Sigmoid()
37 | )
38 |
39 | def forward(self, x):
40 | y = self.avg_pool(x)
41 | y = self.conv_du(y)
42 | return x * y
43 |
44 | def mean_channels(F):
45 | assert(F.dim() == 4)
46 | spatial_sum = F.sum(3, keepdim=True).sum(2, keepdim=True)
47 | return spatial_sum / (F.size(2) * F.size(3))
48 |
49 | def stdv_channels(F):
50 | assert(F.dim() == 4)
51 | F_mean = mean_channels(F)
52 | F_variance = (F - F_mean).pow(2).sum(3, keepdim=True).sum(2, keepdim=True) / (F.size(2) * F.size(3))
53 | return F_variance.pow(0.5)
54 |
55 | class Upsampler(nn.Sequential):
56 | def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):
57 | m = []
58 | if (scale & (scale - 1)) == 0: # Is scale = 2^n?
59 | for _ in range(int(math.log(scale, 2))):
60 |                 m.append(conv(n_feats, 4 * n_feats, 3, bias=bias))
61 | m.append(nn.PixelShuffle(2))
62 | if bn:
63 | m.append(nn.BatchNorm2d(n_feats))
64 | if act == 'relu':
65 | m.append(nn.ReLU(True))
66 | elif act == 'prelu':
67 | m.append(nn.PReLU(n_feats))
68 |
69 | elif scale == 3:
70 |             m.append(conv(n_feats, 9 * n_feats, 3, bias=bias))
71 | m.append(nn.PixelShuffle(3))
72 | if bn:
73 | m.append(nn.BatchNorm2d(n_feats))
74 | if act == 'relu':
75 | m.append(nn.ReLU(True))
76 | elif act == 'prelu':
77 | m.append(nn.PReLU(n_feats))
78 | else:
79 | raise NotImplementedError
80 |
81 | super(Upsampler, self).__init__(*m)
82 |
83 |
84 | class ResAttentionBlock(nn.Module):
85 | def __init__(self, conv, n_feats, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
86 | super(ResAttentionBlock, self).__init__()
87 | m = []
88 | for i in range(2):
89 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
90 | if bn:
91 | m.append(nn.BatchNorm2d(n_feats))
92 | if i == 0:
93 | m.append(act)
94 |
95 | m.append(CALayer(n_feats, 16))
96 |
97 | self.body = nn.Sequential(*m)
98 | self.res_scale = res_scale
99 |
100 | def forward(self, x):
101 | res = self.body(x).mul(self.res_scale)
102 | res += x
103 | return res
--------------------------------------------------------------------------------
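
A hedged usage sketch (not part of the repository) for the building blocks defined in common.py above; it assumes the file is importable as `common`, and the tensor sizes are arbitrary. CALayer re-weights channels with a squeeze-and-excitation style gate and keeps the input shape, while Upsampler enlarges the spatial dimensions by `scale` with PixelShuffle.

```
import torch
from common import CALayer, Upsampler, default_conv

x = torch.randn(2, 64, 32, 32)                      # (B, C, H, W) feature map
ca = CALayer(channel=64, reduction=16)              # channel attention gate
up = Upsampler(default_conv, scale=4, n_feats=64)   # x4 upsampling via PixelShuffle

print(ca(x).shape)   # torch.Size([2, 64, 32, 32])   - same shape, channels re-weighted
print(up(x).shape)   # torch.Size([2, 64, 128, 128]) - spatial size multiplied by 4
```
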
/data/HStest.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch.utils.data as data
3 | import scipy.io as sio
4 | import torch
5 |
6 |
7 | class HSTestData(data.Dataset):
8 | def __init__(self, image_dir, use_3D=False):
9 | test_data = sio.loadmat(image_dir)
10 | self.use_3Dconv = use_3D
11 | self.ms = np.array(test_data['ms'][...], dtype=np.float32)
12 | self.lms = np.array(test_data['ms_bicubic'][...], dtype=np.float32)
13 | self.gt = np.array(test_data['gt'][...], dtype=np.float32)
14 |
15 | def __getitem__(self, index):
16 | gt = self.gt[index, :, :, :]
17 | ms = self.ms[index, :, :, :]
18 | lms = self.lms[index, :, :, :]
19 | if self.use_3Dconv:
20 | ms, lms, gt = ms[np.newaxis, :, :, :], lms[np.newaxis, :, :, :], gt[np.newaxis, :, :, :]
21 | ms = torch.from_numpy(ms.copy()).permute(0, 3, 1, 2)
22 | lms = torch.from_numpy(lms.copy()).permute(0, 3, 1, 2)
23 | gt = torch.from_numpy(gt.copy()).permute(0, 3, 1, 2)
24 | else:
25 | ms = torch.from_numpy(ms.copy()).permute(2, 0, 1)
26 | lms = torch.from_numpy(lms.copy()).permute(2, 0, 1)
27 | gt = torch.from_numpy(gt.copy()).permute(2, 0, 1)
28 | #ms = torch.from_numpy(ms.transpose((2, 0, 1)))
29 | #lms = torch.from_numpy(lms.transpose((2, 0, 1)))
30 | #gt = torch.from_numpy(gt.transpose((2, 0, 1)))
31 | return ms, lms, gt
32 |
33 | def __len__(self):
34 | return self.gt.shape[0]
35 |
--------------------------------------------------------------------------------
/data/HStrain.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch.utils.data as data
3 | import scipy.io as sio
4 | import torch
5 | import os
6 | import utils
7 |
8 |
9 | def is_mat_file(filename):
10 | return any(filename.endswith(extension) for extension in [".mat"])
11 |
12 |
13 | class HSTrainingData(data.Dataset):
14 | def __init__(self, image_dir, augment=None, use_3D=False):
15 | self.image_files = [os.path.join(image_dir, x) for x in os.listdir(image_dir) if is_mat_file(x)]
16 | # self.image_files = []
17 | # for i in self.image_folders:
18 | # images = os.listdir(i)
19 | # for j in images:
20 | # if is_mat_file(j):
21 | # full_path = os.path.join(i, j)
22 | # self.image_files.append(full_path)
23 | self.augment = augment
24 | self.use_3Dconv = use_3D
25 | if self.augment:
26 | self.factor = 8
27 | else:
28 | self.factor = 1
29 |
30 | def __getitem__(self, index):
31 | file_index = index
32 | aug_num = 0
33 | if self.augment:
34 | file_index = index // self.factor #
35 | aug_num = int(index % self.factor) # 0-7
36 | load_dir = self.image_files[file_index]
37 | data = sio.loadmat(load_dir)
38 | ms = np.array(data['ms'][...], dtype=np.float32)
39 | lms = np.array(data['ms_bicubic'][...], dtype=np.float32)
40 | gt = np.array(data['gt'][...], dtype=np.float32)
41 | ms, lms, gt = utils.data_augmentation(ms, mode=aug_num), utils.data_augmentation(lms, mode=aug_num), \
42 | utils.data_augmentation(gt, mode=aug_num)
43 | if self.use_3Dconv:
44 | ms, lms, gt = ms[np.newaxis, :, :, :], lms[np.newaxis, :, :, :], gt[np.newaxis, :, :, :]
45 | ms = torch.from_numpy(ms.copy()).permute(0, 3, 1, 2)
46 | lms = torch.from_numpy(lms.copy()).permute(0, 3, 1, 2)
47 | gt = torch.from_numpy(gt.copy()).permute(0, 3, 1, 2)
48 | else:
49 | ms = torch.from_numpy(ms.copy()).permute(2, 0, 1)
50 | lms = torch.from_numpy(lms.copy()).permute(2, 0, 1)
51 | gt = torch.from_numpy(gt.copy()).permute(2, 0, 1)
52 | return ms, lms, gt
53 |
54 | def __len__(self):
55 | return len(self.image_files)*self.factor
56 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .HStrain import HSTrainingData
2 | from .HStest import HSTestData
3 |
--------------------------------------------------------------------------------
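
A hypothetical loading sketch for the two dataset classes exported above, mirroring how main_CST.py builds its loaders; the ./datasets/Chikusei_x4/... paths are placeholders. Each item is an (ms, lms, gt) triple of float32 CHW tensors: the low-resolution cube ('ms'), its bicubic upsampling ('ms_bicubic'), and the ground truth ('gt').

```
from torch.utils.data import DataLoader
from data import HSTrainingData, HSTestData

# Placeholder paths; main_CST.py derives them from --dataset_name and --n_scale.
train_set = HSTrainingData(image_dir='./datasets/Chikusei_x4/trains/', augment=True)
test_set = HSTestData('./datasets/Chikusei_x4/Chikusei_test.mat')

train_loader = DataLoader(train_set, batch_size=32, num_workers=8, shuffle=True)
test_loader = DataLoader(test_set, batch_size=1, shuffle=False)

ms, lms, gt = train_set[0]   # float32 tensors: (C, h, w), (C, H, W), (C, H, W)
print(ms.shape, lms.shape, gt.shape)
```
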
/demo.sh:
--------------------------------------------------------------------------------
1 | python main_CST.py train --dataset_name "Chikusei" --n_scale 4 --gpus "0,1"
2 |
3 | python main_CST.py train --dataset_name "Houston" --n_scale 4 --gpus "0,1"
4 |
5 | python main_CST.py train --dataset_name "Pavia" --n_scale 4 --gpus "0,1"
6 |
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class HLoss(torch.nn.Module):
6 | def __init__(self, la1, la2, sam=True, gra=True):
7 | super(HLoss, self).__init__()
8 | self.lamd1 = la1
9 | self.lamd2 = la2
10 | self.sam = sam
11 | self.gra = gra
12 |
13 | self.fidelity = torch.nn.L1Loss()
14 |         self.gra_loss = torch.nn.L1Loss()
15 |
16 | def forward(self, y, gt):
17 | loss1 = self.fidelity(y, gt)
18 | loss2 = self.lamd1 * cal_sam(y, gt)
19 |         loss3 = self.lamd2 * self.gra_loss(cal_gradient(y), cal_gradient(gt))
20 | loss = loss1 + loss2 + loss3
21 | return loss
22 |
23 |
24 | class HyLoss(torch.nn.Module):
25 | def __init__(self, la1=0.1):
26 | super(HyLoss, self).__init__()
27 | self.lamd1 = la1
28 | self.fidelity = torch.nn.L1Loss()
29 |
30 | def forward(self, y, gt):
31 | loss1 = self.fidelity(y, gt)
32 | loss2 = self.lamd1 * cal_sam(y, gt)
33 | loss = loss1 + loss2
34 | return loss
35 |
36 | class HybridLoss(torch.nn.Module):
37 | def __init__(self, lamd=1e-1, spatial_tv=False, spectral_tv=False):
38 | super(HybridLoss, self).__init__()
39 | self.lamd = lamd
40 | self.use_spatial_TV = spatial_tv
41 | self.use_spectral_TV = spectral_tv
42 | self.fidelity = torch.nn.L1Loss()
43 | self.spatial = TVLoss(weight=1e-3)
44 | self.spectral = TVLossSpectral(weight=1e-3)
45 |
46 | def forward(self, y, gt):
47 | loss = self.fidelity(y, gt)
48 | spatial_TV = 0.0
49 | spectral_TV = 0.0
50 | if self.use_spatial_TV:
51 | spatial_TV = self.spatial(y)
52 | if self.use_spectral_TV:
53 | spectral_TV = self.spectral(y)
54 | total_loss = loss + spatial_TV + spectral_TV
55 | return total_loss
56 |
57 |
58 | # from https://github.com/jxgu1016/Total_Variation_Loss.pytorch with slight modifications
59 | class TVLoss(torch.nn.Module):
60 | def __init__(self, weight=1.0):
61 | super(TVLoss, self).__init__()
62 | self.TVLoss_weight = weight
63 |
64 | def forward(self, x):
65 | batch_size = x.size()[0]
66 | h_x = x.size()[2]
67 | w_x = x.size()[3]
68 | count_h = self._tensor_size(x[:, :, 1:, :])
69 | count_w = self._tensor_size(x[:, :, :, 1:])
70 | # h_tv = torch.abs(x[:, :, 1:, :] - x[:, :, :h_x - 1, :]).sum()
71 | # w_tv = torch.abs(x[:, :, :, 1:] - x[:, :, :, :w_x - 1]).sum()
72 | h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum()
73 | w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum()
74 | return self.TVLoss_weight * (h_tv / count_h + w_tv / count_w) / batch_size
75 |
76 | def _tensor_size(self, t):
77 | return t.size()[1] * t.size()[2] * t.size()[3]
78 |
79 |
80 | class TVLossSpectral(torch.nn.Module):
81 | def __init__(self, weight=1.0):
82 | super(TVLossSpectral, self).__init__()
83 | self.TVLoss_weight = weight
84 |
85 | def forward(self, x):
86 | batch_size = x.size()[0]
87 | c_x = x.size()[1]
88 | count_c = self._tensor_size(x[:, 1:, :, :])
89 | # c_tv = torch.abs((x[:, 1:, :, :] - x[:, :c_x - 1, :, :])).sum()
90 | c_tv = torch.pow((x[:, 1:, :, :] - x[:, :c_x - 1, :, :]), 2).sum()
91 | return self.TVLoss_weight * 2 * (c_tv / count_c) / batch_size
92 |
93 | def _tensor_size(self, t):
94 | return t.size()[1] * t.size()[2] * t.size()[3]
95 |
96 | def cal_sam(Itrue, Ifake):
97 | esp = 1e-6
98 | InnerPro = torch.sum(Itrue*Ifake,1,keepdim=True)
99 | len1 = torch.norm(Itrue, p=2,dim=1,keepdim=True)
100 | len2 = torch.norm(Ifake, p=2,dim=1,keepdim=True)
101 | divisor = len1*len2
102 | mask = torch.eq(divisor,0)
103 | divisor = divisor + (mask.float())*esp
104 | cosA = torch.sum(InnerPro/divisor,1).clamp(-1+esp, 1-esp)
105 | sam = torch.acos(cosA)
106 | sam = torch.mean(sam) / np.pi
107 | return sam
108 |
109 |
110 | def cal_gradient_c(x):
111 | c_x = x.size(1)
112 | g = x[:, 1:, 1:, 1:] - x[:, :c_x - 1, 1:, 1:]
113 | return g
114 |
115 |
116 | def cal_gradient_x(x):
117 | c_x = x.size(2)
118 | g = x[:, 1:, 1:, 1:] - x[:, 1:, :c_x - 1, 1:]
119 | return g
120 |
121 |
122 | def cal_gradient_y(x):
123 | c_x = x.size(3)
124 | g = x[:, 1:, 1:, 1:] - x[:, 1:, 1:, :c_x - 1]
125 | return g
126 |
127 |
128 | def cal_gradient(inp):
129 | x = cal_gradient_x(inp)
130 | y = cal_gradient_y(inp)
131 | c = cal_gradient_c(inp)
132 | g = torch.sqrt(torch.pow(x, 2) + torch.pow(y, 2) + torch.pow(c, 2) + 1e-6)
133 | return g
--------------------------------------------------------------------------------
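
A minimal, hypothetical sketch of how HLoss can be exercised (not from the repository): the loss combines an L1 fidelity term with la1 times a spectral-angle (SAM) term and la2 times an L1 penalty on spatial/spectral gradients, so it expects 4-D (B, C, H, W) tensors with at least two bands and a 2x2 spatial extent.

```
import torch
from loss import HLoss

criterion = HLoss(la1=0.3, la2=0.1)   # same weights as the main_CST.py defaults

y = torch.rand(4, 31, 32, 32, requires_grad=True)   # stand-in network output (B, C, H, W)
gt = torch.rand(4, 31, 32, 32)                      # stand-in ground truth

loss = criterion(y, gt)   # L1 fidelity + la1 * SAM + la2 * gradient L1
loss.backward()
print(loss.item())
```
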
/main_CST.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 | import random
5 | import time
6 | import torch
7 | import cv2
8 | import math
9 | import numpy as np
10 | import torch.backends.cudnn as cudnn
11 | from torch.optim import Adam
12 | from torch.utils.data import DataLoader
13 | from tensorboardX import SummaryWriter
14 | from torchnet import meter
15 | import json
16 | from tqdm import tqdm
17 | from data import HSTrainingData
18 | from data import HSTestData
19 | from network.CST import *
20 | from common import *
21 | from metrics import compare_mpsnr
22 | # loss
23 | from loss import HLoss
24 | # from loss import HyLapLoss
25 | from metrics import quality_assessment
26 |
27 | # global settings
28 | resume = False
29 | log_interval = 50
30 | model_name = ''
31 | test_data_dir = ''
32 |
33 |
34 | def main():
35 | # parsers
36 | main_parser = argparse.ArgumentParser(description="parser for SR network")
37 | subparsers = main_parser.add_subparsers(title="subcommands", dest="subcommand")
38 | train_parser = subparsers.add_parser("train", help="parser for training arguments")
39 | train_parser.add_argument("--cuda", type=int, required=False,default=1,
40 | help="set it to 1 for running on GPU, 0 for CPU")
41 |     train_parser.add_argument("--batch_size", type=int, default=32, help="batch size, default set to 32")
42 |     train_parser.add_argument("--epochs", type=int, default=300, help="epochs, default set to 300")
43 |     train_parser.add_argument("--n_feats", type=int, default=180, help="n_feats, default set to 180")
44 |     train_parser.add_argument("--n_scale", type=int, default=4, help="n_scale, default set to 4")
45 |     train_parser.add_argument("--dataset_name", type=str, default="Chikusei", help="dataset name, default set to Chikusei")
46 |     train_parser.add_argument("--model_title", type=str, default="CST", help="model title, default set to CST")
47 | train_parser.add_argument("--seed", type=int, default=3000, help="start seed for model")
48 |     train_parser.add_argument('--la1', type=float, default=0.3, help="weight of the SAM term in the loss")
49 |     train_parser.add_argument('--la2', type=float, default=0.1, help="weight of the gradient term in the loss")
50 | train_parser.add_argument("--learning_rate", type=float, default=1e-4,
51 | help="learning rate, default set to 1e-4")
52 | train_parser.add_argument("--weight_decay", type=float, default=0, help="weight decay, default set to 0")
53 |     train_parser.add_argument("--gpus", type=str, default="1", help="gpu ids (default: 1)")
54 |
55 | test_parser = subparsers.add_parser("test", help="parser for testing arguments")
56 | test_parser.add_argument("--cuda", type=int, required=False,default=1,
57 | help="set it to 1 for running on GPU, 0 for CPU")
58 |     test_parser.add_argument("--gpus", type=str, default="0,1", help="gpu ids (default: 0,1)")
59 |     test_parser.add_argument("--dataset_name", type=str, default="Chikusei", help="dataset name, default set to Chikusei")
60 |     test_parser.add_argument("--model_title", type=str, default="CST", help="model title, default set to CST")
61 |     test_parser.add_argument("--n_feats", type=int, default=180, help="n_feats, default set to 180")
62 |     test_parser.add_argument("--n_scale", type=int, default=4, help="n_scale, default set to 4")
63 |
64 |
65 | args = main_parser.parse_args()
66 | print('===>GPU:',args.gpus)
67 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
68 | if args.subcommand is None:
69 | print("ERROR: specify either train or test")
70 | sys.exit(1)
71 | if args.cuda and not torch.cuda.is_available():
72 | print("ERROR: cuda is not available, try running on CPU")
73 | sys.exit(1)
74 | if args.subcommand == "train":
75 | train(args)
76 | else:
77 | test(args)
78 | pass
79 |
80 |
81 | def train(args):
82 | traintime = str(time.ctime())
83 | device = torch.device("cuda" if args.cuda else "cpu")
84 | # args.seed = random.randint(1, 10000)
85 | print("Start seed: ", args.seed)
86 | torch.manual_seed(args.seed)
87 | if args.cuda:
88 | torch.cuda.manual_seed(args.seed)
89 | cudnn.benchmark = True
90 |
91 | print('===> Loading datasets')
92 | train_path = './datasets/'+args.dataset_name+'_x'+str(args.n_scale)+'/trains/'
93 | result_path = './results/' + args.dataset_name + '_x' + str(args.n_scale)+'/'
94 | test_data_dir = './datasets/'+args.dataset_name+'_x'+str(args.n_scale)+'/'+args.dataset_name+'_test.mat'
95 |
96 | train_set = HSTrainingData(image_dir=train_path, augment=True)
97 |
98 | train_loader = DataLoader(train_set, batch_size=args.batch_size, num_workers=8, shuffle=True)
99 | test_set = HSTestData(test_data_dir)
100 | test_loader = DataLoader(test_set, batch_size=1, shuffle=False)
101 |
102 | if args.dataset_name=='Cave':
103 | colors = 31
104 | elif args.dataset_name=='Pavia':
105 | colors = 102
106 | elif args.dataset_name=='Houston':
107 | colors = 48
108 | else:
109 | colors = 128
110 |
111 | print('===> Building model:{}'.format(args.model_title))
112 | net = CST(inp_channels=colors, dim=args.n_feats, depths=[4,4,4,4], num_heads=[6,6,6,6],mlp_ratio=2, scale=args.n_scale)
113 | # print(net)
114 | model_title = args.dataset_name + "_" + args.model_title+'_x'+ str(args.n_scale)
115 |
116 | args.model_title = model_title
117 |
118 | if torch.cuda.device_count() > 1:
119 | print("===> Let's use", torch.cuda.device_count(), "GPUs.")
120 | net = torch.nn.DataParallel(net)
121 | start_epoch = 0
122 |
123 | if resume:
124 | model_name = './checkpoints/' + model_title + "_ckpt_epoch_" + str(300) + ".pth"
125 | if os.path.isfile(model_name):
126 | print("=> loading checkpoint '{}'".format(model_name))
127 | checkpoint = torch.load(model_name)
128 | start_epoch = checkpoint["epoch"]
129 |             net.load_state_dict(checkpoint["model"])
130 | else:
131 | print("=> no checkpoint found at '{}'".format(model_name))
132 | net.to(device).train()
133 | print_network(net)
134 | # loss functions to choose
135 | h_loss = HLoss(args.la1,args.la2)
136 |
137 |
138 | print("===> Setting optimizer and logger")
139 | # add adam optimizer
140 | optimizer = Adam(net.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
141 | epoch_meter = meter.AverageValueMeter()
142 | writer = SummaryWriter('runs/'+model_title+'_'+traintime)
143 |
144 |
145 | best_psnr = 0
146 | best_epoch = 0
147 |
148 | print('===> Start training')
149 | for e in range(start_epoch, args.epochs):
150 | psnr = []
151 | adjust_learning_rate(args.learning_rate, optimizer, e+1)
152 | epoch_meter.reset()
153 | net.train()
154 | print("Start epoch {}, learning rate = {}".format(e + 1, optimizer.param_groups[0]["lr"]))
155 | for iteration, (x, lms, gt) in enumerate(tqdm(train_loader, leave=False)):
156 | x, lms, gt = x.to(device), lms.to(device), gt.to(device)
157 | psnr = []
158 | optimizer.zero_grad()
159 | y = net(x, lms)
160 | loss = h_loss(y, gt)
161 | epoch_meter.add(loss.item())
162 | loss.backward()
163 | optimizer.step()
164 | # tensorboard visualization
165 | if (iteration + log_interval) % log_interval == 0:
166 | print("===> {} \tEpoch[{}]({}/{}): Loss: {:.6f}".format(time.ctime(), e + 1, iteration + 1, len(train_loader)-1, loss.item()))
167 | n_iter = e * len(train_loader) + iteration + 1
168 | writer.add_scalar('scalar/train_loss', loss, n_iter)
169 |
170 | print("Running testset")
171 | net.eval()
172 | with torch.no_grad():
173 | output = []
174 | test_number = 0
175 | for i, (ms, lms, gt) in enumerate(test_loader):
176 | ms, lms, gt = ms.to(device), lms.to(device), gt.to(device)
177 | y = net(ms, lms)
178 | y, gt = y.squeeze().cpu().numpy().transpose(1, 2, 0), gt.squeeze().cpu().numpy().transpose(1, 2, 0)
179 | y = y[:gt.shape[0], :gt.shape[1], :]
180 | psnr_value = compare_mpsnr(gt, y, data_range=1.)
181 | psnr.append(psnr_value)
182 | output.append(y)
183 | test_number += 1
184 |
185 | avg_psnr = sum(psnr) / test_number
186 | if avg_psnr >best_psnr:
187 | best_psnr = avg_psnr
188 | best_epoch = e+1
189 | save_checkpoint(args, net, e + 1, traintime)
190 | writer.add_scalar('scalar/test_psnr', avg_psnr, e + 1)
191 |
192 | print("===> {}\tEpoch {} Training Complete: Avg. Loss: {:.6f} PSNR:{:.3f} best_psnr:{:.3f} best_epoch:{}".format(
193 | time.ctime(), e+1, epoch_meter.value()[0], avg_psnr, best_psnr, best_epoch))
194 | # run validation set every epoch
195 | # eval_loss = validate(args, eval_loader, net, L1_loss)
196 | # tensorboard visualization
197 | writer.add_scalar('scalar/avg_epoch_loss', epoch_meter.value()[0], e + 1)
198 | # writer.add_scalar('scalar/avg_validation_loss', eval_loss, e + 1)
199 |         # save model weights at checkpoints every 5 epochs
200 | if (e + 1) % 5 == 0:
201 | save_checkpoint(args, net, e+1, traintime)
202 |
203 | ## Save the testing results
204 |
205 | print('===> Start testing')
206 |     model_name = './checkpoints/' + traintime + '/' + args.model_title + "_ckpt_epoch_" + str(best_epoch) + ".pth"
207 | with torch.no_grad():
208 | test_number = 0
209 | epoch_meter = meter.AverageValueMeter()
210 | epoch_meter.reset()
211 | # loading model
212 | net = CST(inp_channels=colors, dim=args.n_feats, depths=[4, 4, 4, 4],
213 | num_heads=[6, 6, 6, 6], mlp_ratio=2, scale=args.n_scale)
214 | net.to(device).eval()
215 | state_dict = torch.load(model_name)
216 | net.load_state_dict(state_dict['model'])
217 |
218 | output = []
219 | for i, (ms, lms, gt) in enumerate(test_loader):
220 | # compute output
221 | ms, lms, gt = ms.to(device), lms.to(device), gt.\
222 | to(device)
223 | # y = model(ms)
224 | y = net(ms, lms)
225 | y, gt = y.squeeze().cpu().numpy().transpose(1, 2, 0), gt.squeeze().cpu().numpy().transpose(1, 2, 0)
226 | y = y[:gt.shape[0],:gt.shape[1],:]
227 | if i==0:
228 | indices = quality_assessment(gt, y, data_range=1., ratio=4)
229 | else:
230 | indices = sum_dict(indices, quality_assessment(gt, y, data_range=1., ratio=4))
231 | output.append(y)
232 | test_number += 1
233 | for index in indices:
234 | indices[index] = indices[index] / test_number
235 |
236 |
237 | save_dir = result_path + model_title + '.npy'
238 | np.save(save_dir, output)
239 | print("Test finished, test results saved to .npy file at ", save_dir)
240 | print(indices)
241 | QIstr = model_title+'_'+str(time.ctime()) + ".txt"
242 | json.dump(indices, open(QIstr, 'w'))
243 |
244 |
245 | def sum_dict(a, b):
246 | temp = dict()
247 | for key in a.keys()| b.keys():
248 | temp[key] = sum([d.get(key, 0) for d in (a, b)])
249 | return temp
250 |
251 |
252 | def adjust_learning_rate(start_lr, optimizer, epoch):
253 | """Sets the learning rate to the initial LR decayed by 2 every 150 epochs"""
254 | lr = start_lr * (0.5 ** (epoch // 150))
255 | for param_group in optimizer.param_groups:
256 | param_group['lr'] = lr
257 |
258 |
259 | def validate(args, loader, model, criterion):
260 | device = torch.device("cuda" if args.cuda else "cpu")
261 | # switch to evaluate mode
262 | model.eval()
263 | epoch_meter = meter.AverageValueMeter()
264 | epoch_meter.reset()
265 | with torch.no_grad():
266 | for i, (ms, lms, gt) in enumerate(loader):
267 | ms, lms, gt = ms.to(device), lms.to(device), gt.to(device)
268 | y = model(ms, lms)
269 | loss = criterion(y, gt)
270 | epoch_meter.add(loss.item())
271 |
272 | # back to training mode
273 | model.train()
274 | return epoch_meter.value()[0]
275 |
276 |
277 | def test(args):
278 | if args.dataset_name=='Cave':
279 | colors = 31
280 | elif args.dataset_name=='Pavia':
281 | colors = 102
282 | elif args.dataset_name=='Houston':
283 | colors = 48
284 | else:
285 | colors = 128
286 | test_data_dir = './datasets/' + args.dataset_name + '_x' + str(args.n_scale) + '/' + args.dataset_name + '_test.mat'
287 | result_path = './results/' + args.dataset_name + '_x' + str(args.n_scale) + '/'
288 | model_title = args.model_title+'_x' + str(args.n_scale)
289 | #model_name = './checkpoints/' +'/'+args.dataset_name +'_'+ model_title + "_ckpt_epoch_" + str() + ".pth"
290 | model_name = './model/' +args.dataset_name + '_'+ model_title + ".pth"
291 | device = torch.device("cuda" if args.cuda else "cpu")
292 | print('===> Loading testset')
293 |
294 | test_set = HSTestData(test_data_dir)
295 | test_loader = DataLoader(test_set, batch_size=1, shuffle=False)
296 | print('===> Start testing')
297 |
298 | with torch.no_grad():
299 | test_number = 0
300 | epoch_meter = meter.AverageValueMeter()
301 | epoch_meter.reset()
302 | # loading model
303 | net = CST(inp_channels=colors, dim=args.n_feats, depths=[4, 4, 4, 4],
304 | num_heads=[6, 6, 6, 6], mlp_ratio=2, scale=args.n_scale)
305 | net.to(device).eval()
306 | state_dict = torch.load(model_name)
307 | net.load_state_dict(state_dict['model'])
308 |
309 | output = []
310 | for i, (ms, lms, gt) in enumerate(test_loader):
311 | # compute output
312 | ms, lms, gt = ms.to(device), lms.to(device), gt.\
313 | to(device)
314 | # y = model(ms)
315 | y = net(ms, lms)
316 | y, gt = y.squeeze().cpu().numpy().transpose(1, 2, 0), gt.squeeze().cpu().numpy().transpose(1, 2, 0)
317 | y = y[:gt.shape[0],:gt.shape[1],:]
318 | if i==0:
319 | indices = quality_assessment(gt, y, data_range=1., ratio=4)
320 | else:
321 | indices = sum_dict(indices, quality_assessment(gt, y, data_range=1., ratio=4))
322 | output.append(y)
323 | test_number += 1
324 | for index in indices:
325 | indices[index] = indices[index] / test_number
326 |
327 | #save_dir = "./test.npy"
328 | save_dir = result_path + model_title + '.npy'
329 | np.save(save_dir, output)
330 | print("Test finished, test results saved to .npy file at ", save_dir)
331 | print(indices)
332 | QIstr = model_title+'_'+str(time.ctime()) + ".txt"
333 | json.dump(indices, open(QIstr, 'w'))
334 |
335 | def save_checkpoint(args, model, epoch, traintime):
336 | device = torch.device("cuda" if args.cuda else "cpu")
337 | model.eval().cpu()
338 | checkpoint_model_dir = './checkpoints/'+traintime+'/'
339 | if not os.path.exists(checkpoint_model_dir):
340 | os.makedirs(checkpoint_model_dir)
341 | ckpt_model_filename = args.model_title + "_ckpt_epoch_" + str(epoch) + ".pth"
342 | ckpt_model_path = os.path.join(checkpoint_model_dir, ckpt_model_filename)
343 |
344 | if torch.cuda.device_count() > 1:
345 | state = {"epoch": epoch, "model": model.module.state_dict()}
346 | else:
347 | state = {"epoch": epoch, "model": model.state_dict()}
348 | torch.save(state, ckpt_model_path)
349 | model.to(device).train()
350 | print("Checkpoint saved to {}".format(ckpt_model_path))
351 |
352 |
353 | def print_network(net):
354 | num_params = 0
355 | for param in net.parameters():
356 | num_params += param.numel()
357 | print('Total number of parameters: %d' % num_params)
358 |
359 |
360 | if __name__ == "__main__":
361 | main()
362 |
--------------------------------------------------------------------------------
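
save_checkpoint() above stores a dictionary of the form {'epoch': int, 'model': state_dict}. Below is a hedged sketch (not part of the repository) of reloading such a checkpoint into a fresh CST model outside the training script; the checkpoint path is hypothetical, and inp_channels=128 matches the Chikusei setting used in train().

```
import torch
from network.CST import CST

# Hypothetical path; the real filenames follow
# './checkpoints/<traintime>/<dataset>_CST_x<scale>_ckpt_epoch_<N>.pth'.
ckpt_path = './checkpoints/some_run/Chikusei_CST_x4_ckpt_epoch_300.pth'

net = CST(inp_channels=128, dim=180, depths=[4, 4, 4, 4],
          num_heads=[6, 6, 6, 6], mlp_ratio=2, scale=4)
state = torch.load(ckpt_path, map_location='cpu')
net.load_state_dict(state['model'])   # checkpoint stores {'epoch': ..., 'model': state_dict}
net.eval()
print('loaded epoch:', state['epoch'])
```
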
/metrics.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @Author : zhwzhong
4 | @License : (C) Copyright 2013-2018, hit
5 | @Contact : zhwzhong.hit@gmail.com
6 | @Software: PyCharm
7 | @File : metrics.py
8 | @Time : 2019/12/4 17:35
9 | @Desc :
10 | """
11 | import numpy as np
12 | from scipy.signal import convolve2d
13 | from skimage.measure import compare_psnr, compare_ssim  # requires scikit-image < 0.18 (old metric API)
14 |
15 | def compare_ergas(x_true, x_pred, ratio):
16 | """
17 |     Calculate ERGAS. ERGAS offers a global indication of the quality of the fused image. The ideal value is 0.
18 |     :param x_true:
19 |     :param x_pred:
20 |     :param ratio: upsampling ratio (scale factor)
21 | :return:
22 | """
23 | x_true, x_pred = img_2d_mat(x_true=x_true, x_pred=x_pred)
24 | sum_ergas = 0
25 | for i in range(x_true.shape[0]):
26 | vec_x = x_true[i]
27 | vec_y = x_pred[i]
28 | err = vec_x - vec_y
29 | r_mse = np.mean(np.power(err, 2))
30 | tmp = r_mse / (np.mean(vec_x)**2)
31 | sum_ergas += tmp
32 | return (100 / ratio) * np.sqrt(sum_ergas / x_true.shape[0])
33 |
34 |
35 | def compare_sam(x_true, x_pred):
36 | """
37 |     :param x_true: hyperspectral image with shape (H, W, C)
38 |     :param x_pred: hyperspectral image with shape (H, W, C)
39 |     :return: spectral angle (SAM, in degrees) between the original and the reconstructed hyperspectral data
40 | """
41 | num = 0
42 | sum_sam = 0
43 | x_true, x_pred = x_true.astype(np.float32), x_pred.astype(np.float32)
44 | for x in range(x_true.shape[0]):
45 | for y in range(x_true.shape[1]):
46 | tmp_pred = x_pred[x, y].ravel()
47 | tmp_true = x_true[x, y].ravel()
48 | if np.linalg.norm(tmp_true) != 0 and np.linalg.norm(tmp_pred) != 0:
49 | sum_sam += np.arccos(
50 | np.minimum(1, np.inner(tmp_pred, tmp_true) / (np.linalg.norm(tmp_true) * np.linalg.norm(tmp_pred))))
51 |
52 | num += 1
53 | sam_deg = (sum_sam / num) * 180 / np.pi
54 | return sam_deg
55 |
56 |
57 | def compare_corr(x_true, x_pred):
58 | """
59 | Calculate the cross correlation between x_pred and x_true.
60 |     Compute the correlation coefficient for each corresponding band, then average over bands.
61 | CC is a spatial measure.
62 | """
63 | x_true, x_pred = img_2d_mat(x_true=x_true, x_pred=x_pred)
64 | x_true = x_true - np.mean(x_true, axis=1).reshape(-1, 1)
65 | x_pred = x_pred - np.mean(x_pred, axis=1).reshape(-1, 1)
66 | numerator = np.sum(x_true * x_pred, axis=1).reshape(-1, 1)
67 | denominator = np.sqrt(np.sum(x_true * x_true, axis=1) * np.sum(x_pred * x_pred, axis=1)).reshape(-1, 1)
68 | return (numerator / denominator).mean()
69 |
70 |
71 | def img_2d_mat(x_true, x_pred):
72 | """
73 |     # Flatten the 3-D hyperspectral image into a 2-D matrix
74 | :param x_true: (H, W, C)
75 | :param x_pred: (H, W, C)
76 | :return: a matrix which shape is (C, H * W)
77 | """
78 | h, w, c = x_true.shape
79 | x_true, x_pred = x_true.astype(np.float32), x_pred.astype(np.float32)
80 | x_mat = np.zeros((c, h * w), dtype=np.float32)
81 | y_mat = np.zeros((c, h * w), dtype=np.float32)
82 | for i in range(c):
83 | x_mat[i] = x_true[:, :, i].reshape((1, -1))
84 | y_mat[i] = x_pred[:, :, i].reshape((1, -1))
85 | return x_mat, y_mat
86 |
87 |
88 | def compare_rmse(x_true, x_pred):
89 | """
90 | Calculate Root mean squared error
91 | :param x_true:
92 | :param x_pred:
93 | :return:
94 | """
95 | x_true, x_pred = x_true.astype(np.float32), x_pred.astype(np.float32)
96 | return np.linalg.norm(x_true - x_pred) / (np.sqrt(x_true.shape[0] * x_true.shape[1] * x_true.shape[2]))
97 |
98 |
99 | def compare_mpsnr(x_true, x_pred, data_range):
100 | """
101 | :param x_true: Input image must have three dimension (H, W, C)
102 | :param x_pred:
103 | :return:
104 | """
105 | x_true, x_pred = x_true.astype(np.float32), x_pred.astype(np.float32)
106 | channels = x_true.shape[2]
107 | total_psnr = [compare_psnr(im_true=x_true[:, :, k], im_test=x_pred[:, :, k], data_range=data_range)
108 | for k in range(channels)]
109 |
110 | return np.mean(total_psnr)
111 |
112 | def compare_mpsnr_test(x_true, x_pred, data_range):
113 | """
114 | :param x_true: Input image must have three dimension (H, W, C)
115 | :param x_pred:
116 | :return:
117 | """
118 | x_true, x_pred = x_true.astype(np.float32), x_pred.astype(np.float32)
119 | print(np.argwhere(np.isnan(x_true)))
120 | print(np.argwhere(np.isnan(x_pred)))
121 | channels = x_true.shape[2]
122 | total_psnr = [compare_psnr(im_true=x_true[:, :, k], im_test=x_pred[:, :, k], data_range=data_range)
123 | for k in range(channels)]
124 |
125 | return np.mean(total_psnr)
126 |
127 |
128 | def compare_mssim(x_true, x_pred, data_range, multidimension):
129 | """
130 |
131 | :param x_true:
132 | :param x_pred:
133 | :param data_range:
134 | :param multidimension:
135 | :return:
136 | """
137 | mssim = [compare_ssim(X=x_true[:, :, i], Y=x_pred[:, :, i], data_range=data_range, multidimension=multidimension)
138 | for i in range(x_true.shape[2])]
139 |
140 | return np.mean(mssim)
141 |
142 |
143 | def compare_sid(x_true, x_pred):
144 | """
145 | SID is an information theoretic measure for spectral similarity and discriminability.
146 | :param x_true:
147 | :param x_pred:
148 | :return:
149 | """
150 | x_true, x_pred = x_true.astype(np.float32), x_pred.astype(np.float32)
151 | N = x_true.shape[2]
152 | err = np.zeros(N)
153 | for i in range(N):
154 | err[i] = abs(np.sum(x_pred[:, :, i] * np.log10((x_pred[:, :, i] + 1e-3) / (x_true[:, :, i] + 1e-3))) +
155 | np.sum(x_true[:, :, i] * np.log10((x_true[:, :, i] + 1e-3) / (x_pred[:, :, i] + 1e-3))))
156 | return np.mean(err / (x_true.shape[1] * x_true.shape[0]))
157 |
158 |
159 | def compare_appsa(x_true, x_pred):
160 | """
161 |
162 | :param x_true:
163 | :param x_pred:
164 | :return:
165 | """
166 | x_true, x_pred = x_true.astype(np.float32), x_pred.astype(np.float32)
167 | nom = np.sum(x_true * x_pred, axis=2)
168 | denom = np.linalg.norm(x_true, axis=2) * np.linalg.norm(x_pred, axis=2)
169 |
170 | cos = np.where((nom / (denom + 1e-3)) > 1, 1, (nom / (denom + 1e-3)))
171 | appsa = np.arccos(cos)
172 | return np.sum(appsa) / (x_true.shape[1] * x_true.shape[0])
173 |
174 |
175 | def compare_mare(x_true, x_pred):
176 | """
177 |
178 | :param x_true:
179 | :param x_pred:
180 | :return:
181 | """
182 | x_true, x_pred = x_true.astype(np.float32), x_pred.astype(np.float32)
183 | diff = x_true - x_pred
184 | abs_diff = np.abs(diff)
185 | relative_abs_diff = np.divide(abs_diff, x_true + 1) # added epsilon to avoid division by zero.
186 | return np.mean(relative_abs_diff)
187 |
188 |
189 | def img_qi(img1, img2, block_size=8):
190 | N = block_size ** 2
191 | sum2_filter = np.ones((block_size, block_size))
192 |
193 | img1_sq = img1 * img1
194 | img2_sq = img2 * img2
195 | img12 = img1 * img2
196 |
197 | img1_sum = convolve2d(img1, np.rot90(sum2_filter), mode='valid')
198 | img2_sum = convolve2d(img2, np.rot90(sum2_filter), mode='valid')
199 | img1_sq_sum = convolve2d(img1_sq, np.rot90(sum2_filter), mode='valid')
200 | img2_sq_sum = convolve2d(img2_sq, np.rot90(sum2_filter), mode='valid')
201 | img12_sum = convolve2d(img12, np.rot90(sum2_filter), mode='valid')
202 |
203 | img12_sum_mul = img1_sum * img2_sum
204 | img12_sq_sum_mul = img1_sum * img1_sum + img2_sum * img2_sum
205 | numerator = 4 * (N * img12_sum - img12_sum_mul) * img12_sum_mul
206 | denominator1 = N * (img1_sq_sum + img2_sq_sum) - img12_sq_sum_mul
207 | denominator = denominator1 * img12_sq_sum_mul
208 | quality_map = np.ones(denominator.shape)
209 | index = (denominator1 == 0) & (img12_sq_sum_mul != 0)
210 | quality_map[index] = 2 * img12_sum_mul[index] / img12_sq_sum_mul[index]
211 | index = (denominator != 0)
212 | quality_map[index] = numerator[index] / denominator[index]
213 | return quality_map.mean()
214 |
215 |
216 | def compare_qave(x_true, x_pred, block_size=8):
217 | n_bands = x_true.shape[2]
218 | q_orig = np.zeros(n_bands)
219 | for idim in range(n_bands):
220 | q_orig[idim] = img_qi(x_true[:, :, idim], x_pred[:, :, idim], block_size)
221 | return q_orig.mean()
222 |
223 |
224 | def quality_assessment(x_true, x_pred, data_range, ratio, multi_dimension=False, block_size=8):
225 | """
226 |
227 | :param multi_dimension:
228 | :param ratio:
229 | :param data_range:
230 | :param x_true:
231 | :param x_pred:
232 | :param block_size
233 | :return:
234 | """
235 | result = {'MPSNR': compare_mpsnr(x_true=x_true, x_pred=x_pred, data_range=data_range),
236 | 'MSSIM': compare_mssim(x_true=x_true, x_pred=x_pred, data_range=data_range,
237 | multidimension=multi_dimension),
238 | 'ERGAS': compare_ergas(x_true=x_true, x_pred=x_pred, ratio=ratio),
239 | 'SAM': compare_sam(x_true=x_true, x_pred=x_pred),
240 | # 'SID': compare_sid(x_true=x_true, x_pred=x_pred),
241 | 'CrossCorrelation': compare_corr(x_true=x_true, x_pred=x_pred),
242 | 'RMSE': compare_rmse(x_true=x_true, x_pred=x_pred),
243 | # 'APPSA': compare_appsa(x_true=x_true, x_pred=x_pred),
244 | # 'MARE': compare_mare(x_true=x_true, x_pred=x_pred),
245 | # "QAVE": compare_qave(x_true=x_true, x_pred=x_pred, block_size=block_size)
246 | }
247 | return result
248 |
249 | # from scipy import io as sio
250 | # im_out = np.array(sio.loadmat('/home/zhwzhong/PycharmProject/HyperSR/SOAT/HyperSR/SRindices/Chikuse_EDSRViDeCNN_Blocks=9_Feats=256_Loss_H_Real_1_1_X2X2_N5new_BS32_Epo60_epoch_60_Fri_Sep_20_21:38:44_2019.mat')['output'])
251 | # im_gt = np.array(sio.loadmat('/home/zhwzhong/PycharmProject/HyperSR/SOAT/HyperSR/SRindices/Chikusei_test.mat')['gt'])
252 | #
253 | # sum_rmse, sum_sam, sum_psnr, sum_ssim, sum_ergas = [], [], [], [], []
254 | # for i in range(im_gt.shape[0]):
255 | # print(im_out[i].shape)
256 | # score = quality_assessment(x_pred=im_out[i], x_true=im_gt[i], data_range=1, ratio=4, multi_dimension=False, block_size=8)
257 | # sum_rmse.append(score['RMSE'])
258 | # sum_psnr.append(score['MPSNR'])
259 | # sum_ssim.append(score['MSSIM'])
260 | # sum_sam.append(score['SAM'])
261 | # sum_ergas.append(score['ERGAS'])
262 | #
263 | # print(np.mean(sum_rmse), np.mean(sum_psnr), np.mean(sum_ssim), np.mean(sum_sam))
264 |
--------------------------------------------------------------------------------
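
A hypothetical call to quality_assessment on random (H, W, C) arrays, mirroring how main_CST.py evaluates each test cube. Note that this module imports compare_psnr/compare_ssim from skimage.measure, so it needs an older scikit-image release (those functions were removed in version 0.18).

```
import numpy as np
from metrics import quality_assessment

# Random stand-ins for a ground-truth cube and its reconstruction, shape (H, W, C).
gt = np.random.rand(64, 64, 31).astype(np.float32)
sr = np.clip(gt + 0.01 * np.random.randn(64, 64, 31).astype(np.float32), 0, 1)

scores = quality_assessment(x_true=gt, x_pred=sr, data_range=1.0, ratio=4)
print(scores)   # {'MPSNR': ..., 'MSSIM': ..., 'ERGAS': ..., 'SAM': ..., 'CrossCorrelation': ..., 'RMSE': ...}
```
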
/network/CST.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import torch.nn as nn
4 | from common import *
5 | from einops import rearrange
6 | from network.csa import CSA
7 | from timm.models.layers import DropPath, trunc_normal_
8 | import scipy.io as sio
9 |
10 | class CST(nn.Module):
11 |     """CST
12 |     Transformer network for hyperspectral image super-resolution.
13 |     Args:
14 |         inp_channels (int, optional): Input channels (spectral bands) of the HSI. Defaults to 31.
15 |         dim (int, optional): Embedding dimension. Defaults to 90.
16 |         scale (int, optional): Upsampling factor of the reconstruction. Defaults to 4.
17 |         depths (list, optional): Number of Transformer blocks at each stage of the network. Defaults to [6, 6, 6, 6, 6, 6].
18 |         num_heads (list, optional): Number of attention heads at each stage. Defaults to [6, 6, 6, 6, 6, 6].
19 |         mlp_ratio (int, optional): Ratio of the mlp hidden dim to the embedding dim. Defaults to 2.
20 |         qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Defaults to True.
21 |         qk_scale (float | None, optional): Override the default qk scale of head_dim ** -0.5 if set. Defaults to None.
22 |         bias (bool, optional): Bias of the spectral attention projections. Defaults to False.
23 |         drop_path_rate (float, optional): Stochastic depth rate. Defaults to 0.1.
24 |     """
25 |
26 | def __init__(self,
27 | inp_channels=31,
28 | dim=90,
29 | depths=[6, 6, 6, 6, 6, 6],
30 | num_heads=[6, 6, 6, 6, 6, 6],
31 | mlp_ratio=2,
32 | qkv_bias=True, qk_scale=None,
33 | bias=False,
34 | drop_path_rate=0.1,
35 | scale=4
36 | ):
37 | super(CST, self).__init__()
38 |
39 |         self.conv_first = nn.Conv2d(inp_channels, dim, 3, 1, 1)  # shallow feature extraction
40 | self.num_layers = depths
41 | self.layers = nn.ModuleList()
42 | print("network depth:", len(self.num_layers))
43 |
44 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
45 | for i_layer in range(len(self.num_layers)):
46 | layer = Cstage(dim=dim,
47 | depth=depths[i_layer],
48 | num_head=num_heads[i_layer],
49 | mlp_ratio=mlp_ratio,
50 | qkv_bias=qkv_bias, qk_scale=qk_scale,
51 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
52 | bias=bias)
53 | self.layers.append(layer)
54 |
55 | # self.conv_delasta = nn.Conv2d(dim, inp_channels, 3, 1, 1) # reconstruction from features
56 | self.skip_conv = default_conv(inp_channels, dim, 3)
57 | self.upsample = Upsampler(default_conv, scale, dim)
58 | self.tail = default_conv(dim, inp_channels, 3)
59 | self.conv = default_conv(dim,dim,3)
60 |
61 | def forward(self, inp_img, lms):
62 | f1 = self.conv_first(inp_img)
63 | # ff = f1.detach().cpu().numpy()
64 | # outputfile = "f1.mat"
65 | # sio.savemat(outputfile, {'features':ff})
66 | # print("save successfully")
67 |
68 | x = f1
69 | for i in range(len(self.num_layers)):
70 | x = self.layers[i](x)
71 | x = self.conv(x + f1)
72 | # x = self.conv_delasta(x) + inp_img
73 | x = self.upsample(x)
74 | x = x + self.skip_conv(lms)
75 | x = self.tail(x)
76 | return x
77 |
78 |
79 | class Cstage(nn.Module):
80 | def __init__(self,
81 | dim=90,
82 | split_size=(2,16),
83 | depth=6,
84 | num_head=6,
85 | mlp_ratio=2,
86 | qkv_bias=True, qk_scale=None,
87 | drop_path=0.1,
88 | bias=False):
89 | super(Cstage, self).__init__()
90 | self.layers1 = nn.ModuleList()
91 | self.layers2 = ResAttentionBlock(default_conv, dim, 1, res_scale=0.1)
92 | self.depth = depth
93 | for i_layer in range(depth):
94 | self.layers1.append(CSMA(dim=dim,
95 | input_resolution=(32, 32),
96 | num_heads=num_head,
97 | drop_path=drop_path[i_layer],
98 | split_size=split_size,
99 | shift_size=[0,0] if (i_layer % 2 == 0) else [split_size[0]//2, split_size[1]//2],
100 | mlp_ratio=mlp_ratio,
101 | attn_drop=0,
102 | qkv_bias=qkv_bias, qk_scale=qk_scale, bias=bias))
103 | self.conv = nn.Conv2d(dim, dim, 1)
104 |
105 | def forward(self, x):
106 | x1 = x
107 | for i in range(self.depth):
108 | x1 = self.layers1[i](x1)
109 | x2 = self.layers2(x)
110 | out = self.conv(x1) + x2
111 | out = x + out
112 | return out
113 |
114 |
115 | class CSE(nn.Module):
116 | """global spectral attention (CSE)
117 | Args:
118 | dim (int): Number of input channels.
119 | num_heads (int): Number of attention heads
120 | bias (bool): If True, add a learnable bias to projection
121 | """
122 |
123 | def __init__(self, dim, num_heads, bias, k=0.5, sr_ratio=2):
124 | super(CSE, self).__init__()
125 | self.num_heads = num_heads
126 | self.k = int(k * dim)
127 | self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
128 | # self.qkv = nn.Conv2d(dim, k*3, kernel_size=1, bias=bias)
129 | self.sr_ratio = sr_ratio
130 | self.v = nn.Conv2d(dim, self.k, kernel_size=1, bias=bias)
131 | self.qk = BSConvU(dim, 2 * self.k, kernel_size=sr_ratio, stride=sr_ratio, padding=0)
132 | self.project_out = nn.Conv2d(self.k, dim, kernel_size=1, bias=bias)
133 | self.norm = nn.LayerNorm(dim)
134 |
135 | def forward(self, x):
136 | b, c, h, w = x.shape
137 | qk = self.qk(x)
138 | q, k = qk.chunk(2, dim=1) # b self.k h/s w/s
139 | v = self.v(x) # b k h w
140 | q = q.reshape(b, self.num_heads, self.k // self.num_heads, -1)
141 | k = k.reshape(b, self.num_heads, self.k // self.num_heads, -1)
142 | v = v.reshape(b, self.num_heads, self.k // self.num_heads, -1) # b k h w
143 |
144 | q = torch.nn.functional.normalize(q, dim=-1)
145 | k = torch.nn.functional.normalize(k, dim=-1)
146 | attn = (q @ k.transpose(-2, -1)) * self.temperature
147 | attn = attn.softmax(dim=-1)
148 |
149 | out = (attn @ v)
150 |
151 | out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
152 | out = self.project_out(out)
153 | return out
154 |
155 | def flops(self, patchresolution):
156 | flops = 0
157 | H, W, C = patchresolution
158 | flops += H * C * W * C
159 | flops += C * C * H * W
160 | return flops
161 |
162 |
163 | class FeedForward(nn.Module):
164 | def __init__(self, dim, ffn_expansion_factor=2.66, bias=False):
165 | super(FeedForward, self).__init__()
166 |
167 | hidden_features = int(dim*ffn_expansion_factor)
168 |
169 | self.bsconv = BSConvU(dim, hidden_features*2, kernel_size=3, stride=1, padding=1, bias=bias)
170 |
171 | self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
172 |
173 | def forward(self, x):
174 | x1, x2 = self.bsconv(x).chunk(2, dim=1)
175 | x = F.gelu(x1) * x2
176 | x = self.project_out(x)
177 | return x
178 |
179 | class BSConvU(torch.nn.Module):
180 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1,
181 | dilation=1, bias=True, padding_mode="zeros", with_ln=False, bn_kwargs=None):
182 | super().__init__()
183 | self.with_ln = with_ln
184 | # check arguments
185 | if bn_kwargs is None:
186 | bn_kwargs = {}
187 |
188 | # pointwise
189 | self.pw = torch.nn.Conv2d(
190 | in_channels=in_channels,
191 | out_channels=out_channels,
192 | kernel_size=1,
193 | stride=1,
194 | padding=0,
195 | dilation=1,
196 | groups=1,
197 | bias=False,
198 | )
199 |
200 | # depthwise
201 | self.dw = torch.nn.Conv2d(
202 | in_channels=out_channels,
203 | out_channels=out_channels,
204 | kernel_size=kernel_size,
205 | stride=stride,
206 | padding=padding,
207 | dilation=dilation,
208 | groups=out_channels,
209 | bias=bias,
210 | padding_mode=padding_mode,
211 | )
212 |
213 | def forward(self, fea):
214 | fea = self.pw(fea)
215 | fea = self.dw(fea)
216 | return fea
217 |
218 | class CSMA(nn.Module):
219 | def __init__(self, dim, input_resolution=[32,32], num_heads=6, drop_path=0.0, split_size=[7, 7], shift_size=[0,0],
220 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., act_layer=nn.GELU, bias=False):
221 | super(CSMA, self).__init__()
222 | self.dim = dim
223 | self.input_resolution = input_resolution
224 | self.num_heads = num_heads
225 | self.mlp_ratio = mlp_ratio
226 |
227 | self.norm1 = nn.LayerNorm(dim)
228 | self.norm2 = nn.LayerNorm(dim)
229 |
230 |
231 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
232 | self.ffn = FeedForward(dim)
233 |
234 | self.attns = CSA(
235 | dim,
236 | input_resolution=input_resolution,
237 | num_heads=num_heads,
238 | split_size=split_size,
239 | shift_size=shift_size,
240 | qkv_bias=qkv_bias,
241 | attn_drop=attn_drop,
242 | proj_drop=drop)
243 | self.spectral_attn = CSE(dim, num_heads, bias)
244 |
245 | def forward(self, x):
246 | B, C, H, W = x.shape
247 | x = x.flatten(2).transpose(1, 2)
248 | shortcut = x
249 | x = self.norm1(x)
250 | x = self.attns(x, (H,W))
251 |
252 | x = x.view(B, H * W, C)
253 | x = x.transpose(1, 2).view(B, C, H, W)
254 | x = self.spectral_attn(x) # global spectral attention
255 |
256 | x = x.flatten(2).transpose(1, 2)
257 | # FFN
258 | x = shortcut + self.drop_path(x)
259 | x = x + self.drop_path(self.ffn(self.norm2(x).transpose(1, 2).view(B, C, H, W)).flatten(2).transpose(1, 2))
260 |
261 | x = x.transpose(1, 2).view(B, C, H, W)
262 | return x
263 |
264 |
--------------------------------------------------------------------------------
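
A hedged forward-pass shape check for CST (not part of the repository), using the same configuration that train() in main_CST.py builds for a 31-band dataset such as Cave. The attention stages are constructed for 32x32 feature maps (input_resolution=(32, 32) in Cstage), so the low-resolution input here is a 32x32 patch and lms is its x4 bicubic counterpart; both tensors are random stand-ins.

```
import torch
from network.CST import CST

# Configuration used by train() in main_CST.py for a 31-band dataset (e.g. Cave), x4 SR.
net = CST(inp_channels=31, dim=180, depths=[4, 4, 4, 4],
          num_heads=[6, 6, 6, 6], mlp_ratio=2, scale=4).eval()

ms = torch.rand(1, 31, 32, 32)      # low-resolution HSI patch (B, C, h, w)
lms = torch.rand(1, 31, 128, 128)   # its bicubic upsampling to the target size

with torch.no_grad():
    sr = net(ms, lms)
print(sr.shape)   # torch.Size([1, 31, 128, 128])
```
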
/network/csa.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.utils.checkpoint as checkpoint
4 |
5 | from timm.models.layers import DropPath, trunc_normal_
6 | from einops.layers.torch import Rearrange
7 | from einops import rearrange
8 |
9 | import math
10 | import numpy as np
11 |
12 |
13 | class CSA(nn.Module):
14 | """ Regular Cross Aggregation Transformer Block.
15 | Args:
16 | dim (int): Number of input channels.
17 |         input_resolution (tuple(int)): Input resolution.
18 | num_heads (int): Number of attention heads.
19 | split_size (tuple(int)): Height and Width of the regular rectangle window (regular-Rwin).
20 | shift_size (tuple(int)): Shift size for regular-Rwin.
21 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
22 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
23 | qk_scale (float | None): Override default qk scale of head_dim ** -0.5 if set.
24 | drop (float): Dropout rate. Default: 0.0
25 | attn_drop (float): Attention dropout rate. Default: 0.0
26 | drop_path (float): Stochastic depth rate. Default: 0.0
27 | act_layer (nn.Module): Activation layer. Default: nn.GELU
28 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm
29 | """
30 |
31 | def __init__(self, dim,
32 | input_resolution,
33 | num_heads,
34 | split_size=[2,4],
35 | shift_size=[1,2],
36 | qkv_bias=True,
37 | qk_scale=None,
38 | attn_drop=0.,
39 | proj_drop=0.,):
40 | super().__init__()
41 | self.dim = dim
42 | self.num_heads = num_heads
43 | self.input_resolution = input_resolution
44 | self.num_heads = num_heads
45 | self.split_size = split_size
46 | self.shift_size = shift_size
47 |
48 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
49 |
50 | assert 0 <= self.shift_size[0] < self.split_size[0], "shift_size must in 0-split_size0"
51 | assert 0 <= self.shift_size[1] < self.split_size[1], "shift_size must in 0-split_size1"
52 |
53 | self.proj = nn.Linear(dim, dim)
54 | self.attn_drop = nn.Dropout(attn_drop)
55 |
56 | self.attns = nn.ModuleList([
57 | Attention_regular(
58 | dim, resolution=self.input_resolution, idx=i,
59 | split_size=split_size, num_heads=num_heads // 2, dim_out=dim // 2,
60 | qk_scale=qk_scale, attn_drop=attn_drop, position_bias=True)
61 | for i in range(2)])
62 |
63 | self.get_v = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim) # DW Conv
64 |
65 | if self.shift_size[0] > 0 or self.shift_size[1] > 0:
66 | attn_mask = self.calculate_mask(self.input_resolution[0], self.input_resolution[1])
67 | self.register_buffer("attn_mask_0", attn_mask[0])
68 | self.register_buffer("attn_mask_1", attn_mask[1])
69 | else:
70 | attn_mask = None
71 |
72 | self.register_buffer("attn_mask_0", None)
73 | self.register_buffer("attn_mask_1", None)
74 |
75 | def calculate_mask(self, H, W):
76 | # The implementation builds on Swin Transformer code https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
77 | # calculate attention mask for Rwin
78 | img_mask_0 = torch.zeros([1, H, W, 1]) # 1 H W 1 idx=0
79 | img_mask_1 = torch.zeros([1, H, W, 1]) # 1 H W 1 idx=1
80 | h_slices_0 = (slice(0, -self.split_size[0]),
81 | slice(-self.split_size[0], -self.shift_size[0]),
82 | slice(-self.shift_size[0], None))
83 | w_slices_0 = (slice(0, -self.split_size[1]),
84 | slice(-self.split_size[1], -self.shift_size[1]),
85 | slice(-self.shift_size[1], None))
86 |
87 | h_slices_1 = (slice(0, -self.split_size[1]),
88 | slice(-self.split_size[1], -self.shift_size[1]),
89 | slice(-self.shift_size[1], None))
90 | w_slices_1 = (slice(0, -self.split_size[0]),
91 | slice(-self.split_size[0], -self.shift_size[0]),
92 | slice(-self.shift_size[0], None))
93 | cnt = 0
94 | for h in h_slices_0:
95 | for w in w_slices_0:
96 | img_mask_0[:, h, w, :] = cnt
97 | cnt += 1
98 | cnt = 0
99 | for h in h_slices_1:
100 | for w in w_slices_1:
101 | img_mask_1[:, h, w, :] = cnt
102 | cnt += 1
103 |
104 | # calculate mask for H-Shift
105 | img_mask_0 = img_mask_0.view(1, H // self.split_size[0], self.split_size[0], W // self.split_size[1],
106 | self.split_size[1], 1)
107 | img_mask_0 = img_mask_0.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, self.split_size[0], self.split_size[1],
108 | 1) # nW, sw[0], sw[1], 1
109 | mask_windows_0 = img_mask_0.view(-1, self.split_size[0] * self.split_size[1])
110 | attn_mask_0 = mask_windows_0.unsqueeze(1) - mask_windows_0.unsqueeze(2)
111 | attn_mask_0 = attn_mask_0.masked_fill(attn_mask_0 != 0, float(-100.0)).masked_fill(attn_mask_0 == 0, float(0.0))
112 |
113 | # calculate mask for V-Shift
114 | img_mask_1 = img_mask_1.view(1, H // self.split_size[1], self.split_size[1], W // self.split_size[0],
115 | self.split_size[0], 1)
116 | img_mask_1 = img_mask_1.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, self.split_size[0], self.split_size[1],
117 | 1) # nW, sw[1], sw[0], 1
118 | mask_windows_1 = img_mask_1.view(-1, self.split_size[1] * self.split_size[0])
119 | attn_mask_1 = mask_windows_1.unsqueeze(1) - mask_windows_1.unsqueeze(2)
120 | attn_mask_1 = attn_mask_1.masked_fill(attn_mask_1 != 0, float(-100.0)).masked_fill(attn_mask_1 == 0, float(0.0))
121 |
122 | return attn_mask_0, attn_mask_1
123 |
124 | def forward(self, x, x_size):
125 | """
126 | Input: x: (B, H*W, C), x_size: (H, W)
127 | Output: x: (B, H*W, C)
128 | """
129 |
130 | H, W = x_size
131 | B, L, C = x.shape
132 | assert L == H * W, "flatten img_tokens has wrong size"
133 |
134 | qkv = self.qkv(x).reshape(B, -1, 3, C).permute(2, 0, 1, 3) # 3, B, HW, C
135 | # v without partition
136 | v = qkv[2].transpose(-2, -1).contiguous().view(B, C, H, W)
137 |
138 | if self.shift_size[0] > 0 or self.shift_size[1] > 0:
139 | qkv = qkv.view(3, B, H, W, C)
140 | # H-Shift
141 | qkv_0 = torch.roll(qkv[:, :, :, :, :C // 2], shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(2, 3))
142 | qkv_0 = qkv_0.view(3, B, L, C // 2)
143 | # V-Shift
144 | qkv_1 = torch.roll(qkv[:, :, :, :, C // 2:], shifts=(-self.shift_size[1], -self.shift_size[0]), dims=(2, 3))
145 | qkv_1 = qkv_1.view(3, B, L, C // 2)
146 |
147 | if self.input_resolution[0] != H or self.input_resolution[1] != W:
148 | mask_tmp = self.calculate_mask(H, W)
149 | # H-Rwin
150 | x1_shift = self.attns[0](qkv_0, H, W, mask=mask_tmp[0].to(x.device))
151 | # V-Rwin
152 | x2_shift = self.attns[1](qkv_1, H, W, mask=mask_tmp[1].to(x.device))
153 |
154 | else:
155 | # H-Rwin
156 | x1_shift = self.attns[0](qkv_0, H, W, mask=self.attn_mask_0)
157 | # V-Rwin
158 | x2_shift = self.attns[1](qkv_1, H, W, mask=self.attn_mask_1)
159 |
160 | x1 = torch.roll(x1_shift, shifts=(self.shift_size[0], self.shift_size[1]), dims=(1, 2))
161 | x2 = torch.roll(x2_shift, shifts=(self.shift_size[1], self.shift_size[0]), dims=(1, 2))
162 | x1 = x1.view(B, L, C // 2).contiguous()
163 | x2 = x2.view(B, L, C // 2).contiguous()
164 | # Concat
165 | attened_x = torch.cat([x1, x2], dim=2)
166 | else:
167 | # V-Rwin
168 | x1 = self.attns[0](qkv[:, :, :, :C // 2], H, W).view(B, L, C // 2).contiguous()
169 | # H-Rwin
170 | x2 = self.attns[1](qkv[:, :, :, C // 2:], H, W).view(B, L, C // 2).contiguous()
171 | # Concat
172 | attened_x = torch.cat([x1, x2], dim=2)
173 |
174 | # Locality Complementary Module
175 | lcm = self.get_v(v)
176 | lcm = lcm.permute(0, 2, 3, 1).contiguous().view(B, L, C)
177 |
178 | attened_x = attened_x + lcm
179 |
180 | attened_x = self.proj(attened_x)
181 | x = x + attened_x
182 |
183 | return x
184 |
185 |
186 | class Attention_regular(nn.Module):
187 | """ Regular Rectangle-Window (regular-Rwin) self-attention with dynamic relative position bias.
188 | It supports both of shifted and non-shifted window.
189 | Args:
190 | dim (int): Number of input channels.
191 | resolution (int): Input resolution.
192 |         idx (int): The index of H-Rwin and V-Rwin, 0 is H-Rwin, 1 is V-Rwin. (different order from Attention_axial)
193 | split_size (tuple(int)): Height and Width of the regular rectangle window (regular-Rwin).
194 | dim_out (int | None): The dimension of the attention output. Default: None
195 | num_heads (int): Number of attention heads. Default: 6
196 | attn_drop (float): Dropout ratio of attention weight. Default: 0.0
197 | proj_drop (float): Dropout ratio of output. Default: 0.0
198 | qk_scale (float | None): Override default qk scale of head_dim ** -0.5 if set
199 | position_bias (bool): The dynamic relative position bias. Default: True
200 | """
201 | def __init__(self,
202 | dim,
203 | resolution,
204 | idx,
205 | split_size=[2,4],
206 | dim_out=None,
207 | num_heads=6,
208 | attn_drop=0.,
209 | qk_scale=None,
210 | position_bias=True):
211 | super().__init__()
212 | self.dim = dim
213 | self.dim_out = dim_out or dim
214 | self.resolution = resolution
215 | self.split_size = split_size
216 | self.num_heads = num_heads
217 | self.idx = idx
218 | self.position_bias = position_bias
219 |
220 | head_dim = dim // num_heads
221 | self.scale = qk_scale or head_dim ** -0.5
222 | if idx == -1:
223 | H_sp, W_sp = self.resolution, self.resolution
224 | elif idx == 0:
225 | H_sp, W_sp = self.split_size[0], self.split_size[1]
226 | elif idx == 1:
227 | W_sp, H_sp = self.split_size[0], self.split_size[1]
228 | else:
229 | # unknown idx: fail loudly instead of continuing with undefined window sizes
230 | raise ValueError(f"ERROR MODE: unexpected idx {idx}")
231 | self.H_sp = H_sp
232 | self.W_sp = W_sp
233 |
234 | if self.position_bias:
235 | self.pos = DynamicPosBias(self.dim // 4, self.num_heads, residual=False)
236 | # generate mother-set
237 | position_bias_h = torch.arange(1 - self.H_sp, self.H_sp)
238 | position_bias_w = torch.arange(1 - self.W_sp, self.W_sp)
239 | biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w]))
240 | biases = biases.flatten(1).transpose(0, 1).contiguous().float()
241 | self.register_buffer('rpe_biases', biases)
242 |
243 | # get pair-wise relative position index for each token inside the window
244 | coords_h = torch.arange(self.H_sp)
245 | coords_w = torch.arange(self.W_sp)
246 | coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
247 | coords_flatten = torch.flatten(coords, 1)
248 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
249 | relative_coords = relative_coords.permute(1, 2, 0).contiguous()
250 | relative_coords[:, :, 0] += self.H_sp - 1
251 | relative_coords[:, :, 1] += self.W_sp - 1
252 | relative_coords[:, :, 0] *= 2 * self.W_sp - 1
253 | relative_position_index = relative_coords.sum(-1)
254 | self.register_buffer('relative_position_index', relative_position_index)
255 |
256 | self.attn_drop = nn.Dropout(attn_drop)
257 | self.pool = nn.AdaptiveAvgPool2d((self.H_sp, self.W_sp))
258 | self.pool2 = nn.AdaptiveMaxPool2d((self.H_sp, self.W_sp))
259 |
260 | def im2win(self, x, H, W):
261 | B, N, C = x.shape
262 | x = x.transpose(-2,-1).contiguous().view(B, C, H, W)
263 | x = img2windows(x, self.H_sp, self.W_sp)
264 | x = x.reshape(-1, self.H_sp* self.W_sp, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3).contiguous()
265 | return x # -1, heads, H_s* W_s, C // self.num_heads
266 |
267 | def forward(self, qkv, H, W, mask=None):
268 | """
269 | Input: qkv: (3, B, L, C), H, W, mask: (nW, N, N), N is the window size
270 | Output: x: (B, H, W, C)
271 | """
272 | q,k,v = qkv[0], qkv[1], qkv[2] #B L C
273 |
274 | B, L, C = q.shape
275 | assert L == H * W, "flatten img_tokens has wrong size"
276 |
277 | # partition the q,k,v, image to window
278 |
279 | q1 = q.transpose(-2, -1).view(B, C, H, W)  # B, C, H, W
280 | q1 = self.pool(q1[:, :C//2, :, :])
281 | q2 = q.transpose(-2, -1).view(B, C, H, W)
282 | q2 = self.pool2(q2[:, C//2:, :, :])
283 | q = torch.cat([q1,q2],dim=1)
284 | q = q.reshape(B, self.num_heads, C//self.num_heads, self.H_sp, self.W_sp).flatten(3).transpose(-2, -1)
285 | q = q.repeat(H*W //(self.H_sp*self.W_sp),1,1,1) # -1, heads, H_s* W_s, C // self.num_heads
286 | k = self.im2win(k, H, W) # -1, heads, H_s* W_s, C // self.num_heads
287 | v = self.im2win(v, H, W)
288 |
289 | q = q * self.scale
290 | attn = (q @ k.transpose(-2, -1)) # B head N C @ B head C N --> B head N N
291 |
292 | # calculate drpe
293 | if self.position_bias:
294 | pos = self.pos(self.rpe_biases)
295 | # select position bias
296 | relative_position_bias = pos[self.relative_position_index.view(-1)].view(
297 | self.H_sp * self.W_sp, self.H_sp * self.W_sp, -1)
298 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
299 | attn = attn + relative_position_bias.unsqueeze(0)
300 |
301 | N = attn.shape[3]
302 |
303 | # use mask for shift window
304 | if mask is not None:
305 | nW = mask.shape[0]
306 | attn = attn.view(B, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
307 | attn = attn.view(-1, self.num_heads, N, N)
308 | attn = nn.functional.softmax(attn, dim=-1, dtype=attn.dtype)
309 | attn = self.attn_drop(attn)
310 |
311 | x = (attn @ v)
312 | x = x.transpose(1, 2).reshape(-1, self.H_sp* self.W_sp, C) # B head N N @ B head N C
313 |
314 | # merge the window, window to image
315 | x = windows2img(x, self.H_sp, self.W_sp, H, W) # B H' W' C
316 |
317 | return x
318 |
319 |
320 | class DynamicPosBias(nn.Module):
321 | # The implementation builds on Crossformer code https://github.com/cheerss/CrossFormer/blob/main/models/crossformer.py
322 | """ Dynamic Relative Position Bias.
323 | Args:
324 | dim (int): Number of input channels.
325 | num_heads (int): Number of attention heads.
326 | residual (bool): If True, use residual connections between the MLP stages.
327 | """
328 | def __init__(self, dim, num_heads, residual):
329 | super().__init__()
330 | self.residual = residual
331 | self.num_heads = num_heads
332 | self.pos_dim = dim // 4
333 | self.pos_proj = nn.Linear(2, self.pos_dim)
334 | self.pos1 = nn.Sequential(
335 | nn.LayerNorm(self.pos_dim),
336 | nn.ReLU(inplace=True),
337 | nn.Linear(self.pos_dim, self.pos_dim),
338 | )
339 | self.pos2 = nn.Sequential(
340 | nn.LayerNorm(self.pos_dim),
341 | nn.ReLU(inplace=True),
342 | nn.Linear(self.pos_dim, self.pos_dim)
343 | )
344 | self.pos3 = nn.Sequential(
345 | nn.LayerNorm(self.pos_dim),
346 | nn.ReLU(inplace=True),
347 | nn.Linear(self.pos_dim, self.num_heads)
348 | )
349 |
350 | def forward(self, biases):
351 | if self.residual:
352 | pos = self.pos_proj(biases) # (2Gh-1)*(2Gw-1), pos_dim
353 | pos = pos + self.pos1(pos)
354 | pos = pos + self.pos2(pos)
355 | pos = self.pos3(pos)
356 | else:
357 | pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases))))
358 | return pos
359 |
360 |
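# --- Editor's illustration (not part of the original source): a commented-out
# --- sketch of how DynamicPosBias is driven by Attention_regular above, assuming
# --- an 8-dim bias MLP, 2 heads and a 2x4 window.
# H_sp, W_sp = 2, 4
# pos = DynamicPosBias(8, num_heads=2, residual=False)
# # mother-set of relative offsets, built exactly as in Attention_regular.__init__
# bias_h = torch.arange(1 - H_sp, H_sp)
# bias_w = torch.arange(1 - W_sp, W_sp)
# biases = torch.stack(torch.meshgrid([bias_h, bias_w])).flatten(1).transpose(0, 1).float()
# print(pos(biases).shape)  # ((2*H_sp-1)*(2*W_sp-1), num_heads) = (21, 2)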
361 | def img2windows(img, H_sp, W_sp):
362 | """
363 | Input: Image (B, C, H, W)
364 | Output: Window Partition (B', N, C)
365 | """
366 | B, C, H, W = img.shape
367 | img_reshape = img.view(B, C, H // H_sp, H_sp, W // W_sp, W_sp)
368 | img_perm = img_reshape.permute(0, 2, 4, 3, 5, 1).contiguous().reshape(-1, H_sp* W_sp, C)
369 | return img_perm
370 |
371 |
372 | def windows2img(img_splits_hw, H_sp, W_sp, H, W):
373 | """
374 | Input: Window Partition (B', N, C)
375 | Output: Image (B, H, W, C)
376 | """
377 | B = int(img_splits_hw.shape[0] / (H * W / H_sp / W_sp))
378 |
379 | img = img_splits_hw.view(B, H // H_sp, W // W_sp, H_sp, W_sp, -1)
380 | img = img.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
381 | return img
--------------------------------------------------------------------------------
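Editor's note: `img2windows` and `windows2img` above are shape-preserving inverses of each other. Below is a minimal sanity-check sketch, not part of the repository; the import path `network.csa` is an assumption based on the tree above, and H, W must be divisible by the window sides.
```
import torch
from network.csa import img2windows, windows2img  # assumed module path

B, C, H, W = 2, 16, 8, 12
H_sp, W_sp = 2, 4                       # rectangle window, must divide H and W

x = torch.randn(B, C, H, W)
win = img2windows(x, H_sp, W_sp)        # (B * H/H_sp * W/W_sp, H_sp*W_sp, C)
assert win.shape == (B * (H // H_sp) * (W // W_sp), H_sp * W_sp, C)

y = windows2img(win, H_sp, W_sp, H, W)  # (B, H, W, C)
assert torch.allclose(y, x.permute(0, 2, 3, 1))
```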
/test_demo.sh:
--------------------------------------------------------------------------------
1 | python main_CST.py test --model_title "CST" --dataset "Chikusei" --n_scale 4 --gpus "0"
2 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import scipy.io as sio
2 | import numpy as np
3 | import torch
4 | import cv2
5 | from torch.utils.data import DataLoader
6 |
7 | def data_augmentation(label, mode=0):
8 | if mode == 0:
9 | # original
10 | return label
11 | elif mode == 1:
12 | # flip up and down
13 | return np.flipud(label)
14 | elif mode == 2:
16 | # rotate 90 degrees counterclockwise
16 | return np.rot90(label)
17 | elif mode == 3:
18 | # rotate 90 degree and flip up and down
19 | return np.flipud(np.rot90(label))
20 | elif mode == 4:
21 | # rotate 180 degree
22 | return np.rot90(label, k=2)
23 | elif mode == 5:
24 | # rotate 180 degree and flip
25 | return np.flipud(np.rot90(label, k=2))
26 | elif mode == 6:
27 | # rotate 270 degree
28 | return np.rot90(label, k=3)
29 | elif mode == 7:
30 | # rotate 270 degree and flip
31 | return np.flipud(np.rot90(label, k=3))
32 |
33 |
34 | # rescale every channel to the range [0, 1]
35 | def channel_scale(img):
36 | eps = 1e-5
37 | max_list = np.max((np.max(img, axis=0)), axis=0)
38 | min_list = np.min((np.min(img, axis=0)), axis=0)
39 | output = (img - min_list) / (max_list - min_list + eps)
40 | return output
41 |
42 |
43 | # upsample (bicubic) before feeding into the network
44 | def upsample(img, ratio):
45 | [h, w, _] = img.shape
46 | return cv2.resize(img, (ratio*w, ratio*h), interpolation=cv2.INTER_CUBIC)  # cv2.resize expects dsize=(width, height)
47 |
48 |
49 | def bicubic_downsample(img, ratio):
50 | [h, w, _] = img.shape
51 | new_h, new_w = int(ratio * h), int(ratio * w)
52 | return cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_CUBIC)  # cv2.resize expects dsize=(width, height)
53 |
54 |
55 | def wald_downsample(data, ratio):
56 | [h, w, c] = data.shape
57 | out = []
58 | for i in range(c):
59 | dst = cv2.GaussianBlur(data[:, :, i], (7, 7), 0)
60 | dst = dst[0:h:ratio, 0:w:ratio, np.newaxis]
61 | out.append(dst)
62 | out = np.concatenate(out, axis=2)
63 | return out
64 |
65 |
66 | def save_result(result_dir, out):
67 | out = out.numpy().transpose((0, 2, 3, 1))
68 | sio.savemat(result_dir, {'output': out})
69 |
70 |
71 | def sam_loss(y, ref):
72 | (b, ch, h, w) = y.size()
73 | tmp1 = y.view(b, ch, h * w).transpose(1, 2)
74 | tmp2 = ref.view(b, ch, h * w)
75 | sam = torch.bmm(tmp1, tmp2)
76 | idx = torch.arange(0, h * w, out=torch.LongTensor())
77 | sam = sam[:, idx, idx].view(b, h, w)
78 | norm1 = torch.norm(y, 2, 1)
79 | norm2 = torch.norm(ref, 2, 1)
80 | sam = torch.div(sam, (norm1 * norm2))
81 | sam = torch.sum(sam) / (b * h * w)
82 | return sam
83 |
84 |
85 | def extract_RGB(y):
86 | # take 4-2-1 band (R-G-B) for WV-3
87 | R = torch.unsqueeze(torch.mean(y[:, 4:8, :, :], 1), 1)
88 | G = torch.unsqueeze(torch.mean(y[:, 2:4, :, :], 1), 1)
89 | B = torch.unsqueeze(torch.mean(y[:, 0:2, :, :], 1), 1)
90 | y_RGB = torch.cat((R, G, B), 1)
91 | return y_RGB
92 |
93 |
94 | def extract_edge(data):
95 | N = data.shape[0]
96 | out = np.zeros_like(data)
97 | for i in range(N):
98 | if len(data.shape) == 3:
99 | out[i, :, :] = data[i, :, :] - cv2.boxFilter(data[i, :, :], -1, (5, 5))
100 | else:
101 | out[i, :, :, :] = data[i, :, :, :] - cv2.boxFilter(data[i, :, :, :], -1, (5, 5))
102 | return out
103 |
104 |
105 | def normalize_batch(batch):
106 | # normalize using imagenet mean and std
107 | mean = torch.Tensor([0.485, 0.456, 0.406]).view(-1, 1, 1).cuda()
108 | std = torch.Tensor([0.229, 0.224, 0.225]).view(-1, 1, 1).cuda()
109 | return (batch - mean) / std
110 |
111 |
112 | def add_channel(rgb):
113 | # initialize other channels using the average of RGB from VGG
114 | R = torch.unsqueeze(rgb[:, 0, :, :], 1)
115 | G = torch.unsqueeze(rgb[:, 1, :, :], 1)
116 | B = torch.unsqueeze(rgb[:, 2, :, :], 1)
117 | all_channel = torch.cat((B, B, G, G, R, R, R, R), 1)
118 | return all_channel
119 |
120 |
121 | # from LapSRN
122 | class L1_Charbonnier_loss(torch.nn.Module):
123 | """L1 Charbonnierloss."""
124 | def __init__(self):
125 | super(L1_Charbonnier_loss, self).__init__()
126 | self.eps = 1e-6
127 |
128 | def forward(self, X, Y):
129 | diff = torch.add(X, -Y)
130 | error = torch.sqrt(diff * diff + self.eps)
131 | loss = torch.sum(error)
132 | return loss
133 |
134 |
135 |
--------------------------------------------------------------------------------
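Editor's note: a minimal usage sketch (not part of the repository) for the two loss helpers in utils.py. It assumes utils.py is importable from the repository root and that its dependencies (torch, cv2) are installed; the 31-band tensors are illustrative only.
```
import torch
from utils import sam_loss, L1_Charbonnier_loss

pred   = torch.rand(2, 31, 16, 16)     # (batch, spectral bands, H, W)
target = torch.rand(2, 31, 16, 16)

charb = L1_Charbonnier_loss()
print(charb(pred, target))             # scalar: sum of sqrt(diff^2 + eps) over all elements
print(sam_loss(pred, target))          # scalar: mean cosine similarity between per-pixel spectra
```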