├── benchmark
├── README.md
└── datasets
│ ├── ICVL
│ └── todo.txt
│ ├── RealHSI
│ ├── real_test.txt
│ ├── real_train.txt
│ ├── build_test.py
│ └── real_build_lmdb.py
│ ├── Harvard
│ ├── harvard_test.txt
│ ├── harvard_train.txt
│ ├── harvard_meta.txt
│ ├── harvard_build_test.py
│ └── harvard_build_lmdb.py
│ └── CAVE
│ ├── cave_test.txt
│ ├── cave_train.txt
│ ├── cave_meta.txt
│ ├── cave_preprocess.py
│ ├── cave_build_test.py
│ └── cave_build_lmdb.py
├── hsir
├── model
│ ├── man
│ │ ├── __init__.py
│ │ ├── v2.py
│ │ └── v1.py
│ ├── t3sc
│ │ ├── utils
│ │ │ ├── __init__.py
│ │ │ └── patches_handler.py
│ │ ├── metrics
│ │ │ ├── __init__.py
│ │ │ └── utils.py
│ │ ├── layers
│ │ │ ├── __init__.py
│ │ │ ├── soft_thresholding.py
│ │ │ ├── encoding_layer.py
│ │ │ └── lowrank_sc_layer.py
│ │ ├── __init__.py
│ │ ├── configs
│ │ │ └── t3sc.yaml
│ │ ├── multilayer.py
│ │ └── base.py
│ ├── __init__.py
│ ├── trq3d
│ │ ├── __init__.py
│ │ ├── conv.py
│ │ └── combinations.py
│ ├── qrnn3d
│ │ ├── __init__.py
│ │ ├── conv.py
│ │ ├── qrnn.py
│ │ └── net.py
│ ├── denet.py
│ ├── memnet.py
│ ├── hsidcnn.py
│ ├── unet3d.py
│ ├── grunet.py
│ └── hsdt
│ │ ├── sepconv.py
│ │ ├── __init__.py
│ │ ├── arch.py
│ │ └── attention.py
├── data
│ ├── transform
│ │ ├── __init__.py
│ │ ├── general.py
│ │ ├── sr.py
│ │ └── noise.py
│ ├── __init__.py
│ ├── utils.py
│ ├── dataloader.py
│ └── dataset.py
├── __init__.py
├── scheduler.py
└── schedule.py
├── docs
├── requirements.txt
├── source
│ ├── getstart.md
│ ├── dataset.md
│ ├── index.md
│ └── conf.py
├── Makefile
└── make.bat
├── hsiboard
├── app.py
├── main.py
├── table.py
├── viewer.py
├── box.py
├── util.py
└── cmp.py
├── setup.py
├── .readthedocs.yaml
├── LICENSE
├── hsirun
├── benchmark.py
├── train_ssr.py
├── train.py
└── test.py
└── .gitignore
/benchmark/README.md:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/hsir/model/man/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/benchmark/datasets/ICVL/todo.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/hsir/data/transform/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/hsir/model/t3sc/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .patches_handler import PatchesHandler
2 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | myst-parser
2 | sphinx-rtd-theme
3 | furo
4 | sphinx-copybutton
5 | sphinx-inline-tabs
--------------------------------------------------------------------------------
/hsir/model/t3sc/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from .metrics import mergas, mfsim, mpsnr, msam, mse, mssim, psnr
2 |
--------------------------------------------------------------------------------
/hsir/__init__.py:
--------------------------------------------------------------------------------
1 | from . import scheduler
2 | from . import schedule
3 | from . import trainer
4 | from . import model
5 | from . import data
--------------------------------------------------------------------------------
/hsir/data/__init__.py:
--------------------------------------------------------------------------------
1 | from . import dataset
2 | from . import transform
3 | from .dataset import HSITestDataset, HSITrainDataset, HSITransformTestDataset
4 | from . import ssr
--------------------------------------------------------------------------------
/hsir/model/t3sc/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from .lowrank_sc_layer import LowRankSCLayer
2 | from .encoding_layer import EncodingLayer
3 | from .soft_thresholding import SoftThresholding
4 |
--------------------------------------------------------------------------------
/benchmark/datasets/RealHSI/real_test.txt:
--------------------------------------------------------------------------------
1 | 41
2 | 42
3 | 43
4 | 44
5 | 45
6 | 46
7 | 47
8 | 48
9 | 49
10 | 50
11 | 51
12 | 52
13 | 53
14 | 54
15 | 55
16 | 56
17 | 57
18 | 58
19 | 59
--------------------------------------------------------------------------------
/hsir/model/__init__.py:
--------------------------------------------------------------------------------
1 | from . import qrnn3d
2 | from . import t3sc
3 | from . import grunet
4 | from . import restormer
5 | from . import uformer
6 | from . import unet3d
7 | from . import hsidcnn
--------------------------------------------------------------------------------
/benchmark/datasets/Harvard/harvard_test.txt:
--------------------------------------------------------------------------------
1 | imgc7
2 | imgd3
3 | imgd9
4 | imge3
5 | imge7
6 | imgf4
7 | imgf8
8 | imgh3
9 | imga1
10 | imga7
11 | imgb3
12 | imgb7
13 | imgc2
14 | imgc8
15 | imgd4
16 | imge0
17 | imge4
18 | imgf1
19 | imgf5
20 | imgh0
--------------------------------------------------------------------------------
/benchmark/datasets/CAVE/cave_test.txt:
--------------------------------------------------------------------------------
1 | thread_spools_ms
2 | jelly_beans_ms
3 | beads_ms
4 | fake_and_real_peppers_ms
5 | face_ms
6 | pompoms_ms
7 | flowers_ms
8 | photo_and_face_ms
9 | fake_and_real_lemons_ms
10 | stuffed_toys_ms
11 | glass_tiles_ms
12 | oil_painting_ms
--------------------------------------------------------------------------------
/benchmark/datasets/RealHSI/real_train.txt:
--------------------------------------------------------------------------------
1 | 1
2 | 2
3 | 3
4 | 4
5 | 5
6 | 6
7 | 7
8 | 8
9 | 9
10 | 10
11 | 11
12 | 12
13 | 13
14 | 14
15 | 15
16 | 16
17 | 17
18 | 18
19 | 19
20 | 20
21 | 21
22 | 22
23 | 23
24 | 24
25 | 25
26 | 26
27 | 27
28 | 28
29 | 29
30 | 30
31 | 31
32 | 32
33 | 33
34 | 34
35 | 35
36 | 36
37 | 37
38 | 38
39 | 39
40 | 40
--------------------------------------------------------------------------------
/benchmark/datasets/Harvard/harvard_train.txt:
--------------------------------------------------------------------------------
1 | imga2
2 | imgb0
3 | imgb4
4 | imgb8
5 | imgc4
6 | imgc9
7 | imgd7
8 | imge1
9 | imge5
10 | imgf2
11 | imgf6
12 | imgh1
13 | img1
14 | imga5
15 | imgb1
16 | imgb5
17 | imgb9
18 | imgc5
19 | imgd2
20 | imgd8
21 | imge2
22 | imge6
23 | imgf3
24 | imgf7
25 | imgh2
26 | img2
27 | imga6
28 | imgb2
29 | imgb6
30 | imgc1
--------------------------------------------------------------------------------
/hsiboard/app.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import subprocess
4 |
5 |
if __name__ == '__main__':
    # Thin launcher: parse --logdir and start the Streamlit app defined in main.py.
    parser = argparse.ArgumentParser('HSIR Board')
    parser.add_argument('--logdir', default='results')
    args = parser.parse_args()

    # main.py lives next to this file; an absolute path makes the launcher
    # independent of the current working directory.
    path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'main.py')
    # Everything after the bare '--' is passed on as arguments of main.py.
    # Exit with Streamlit's return code instead of always reporting success.
    raise SystemExit(
        subprocess.call(['streamlit', 'run', path, '--', '--logdir', args.logdir])
    )
13 |
--------------------------------------------------------------------------------
/benchmark/datasets/CAVE/cave_train.txt:
--------------------------------------------------------------------------------
1 | fake_and_real_sushi_ms
2 | watercolors_ms
3 | chart_and_stuffed_toy_ms
4 | superballs_ms
5 | real_and_fake_peppers_ms
6 | fake_and_real_tomatoes_ms
7 | balloons_ms
8 | fake_and_real_beers_ms
9 | real_and_fake_apples_ms
10 | cloth_ms
11 | feathers_ms
12 | hairs_ms
13 | paints_ms
14 | sponges_ms
15 | clay_ms
16 | fake_and_real_food_ms
17 | fake_and_real_strawberries_ms
18 | egyptian_statue_ms
19 | fake_and_real_lemon_slices_ms
20 | cd_ms
--------------------------------------------------------------------------------
/hsir/data/utils.py:
--------------------------------------------------------------------------------
1 | import torchlight as tl
2 | import torch
3 | import numpy as np
4 |
5 |
def visualize_gray(hsi, band=20):
    """Return one slice of ``hsi`` along axis 1, keeping that axis with length 1.

    ``hsi`` is indexed with four subscripts, so it is presumably a 4-D batch
    whose axis 1 is the spectral axis -- confirm against the callers.
    """
    selected = slice(band, band + 1)
    return hsi[:, selected, :, :]
8 |
9 |
def visualize_color(hsi):
    """Project ``hsi`` through the ``srf`` matrix of ``tl.transforms.HSI2RGB``.

    Axis 1 of ``hsi`` is moved last, multiplied by the transposed ``srf``
    matrix, and the resulting axis is moved back to position 1.
    NOTE(review): ``srf`` is presumably a spectral response function mapping
    bands to RGB -- confirm against torchlight.
    """
    response = torch.from_numpy(tl.transforms.HSI2RGB().srf).float().to(hsi.device)
    channels_last = hsi.permute(0, 2, 3, 1)
    projected = channels_last @ response.T
    return projected.permute(0, 3, 1, 2)
15 |
16 |
def worker_init_fn(worker_id):
    """Re-seed NumPy's global RNG with a value offset by ``worker_id``.

    The base value is the first entry of the state array returned by
    ``np.random.get_state()``; adding ``worker_id`` gives each worker a
    different seed.

    Args:
        worker_id: integer index of the worker being initialised.
    """
    base_seed = int(np.random.get_state()[1][0])
    # np.random.seed only accepts values in [0, 2**32 - 1]; wrap around so a
    # large base value plus the worker offset can never be rejected.
    np.random.seed((base_seed + worker_id) % 2 ** 32)
19 |
--------------------------------------------------------------------------------
/hsir/model/t3sc/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .multilayer import MultilayerModel
3 |
4 |
def build_t3sc(num_band):
    """Build a ``MultilayerModel`` from the bundled ``configs/t3sc.yaml``.

    The ``params.channels`` entry of the loaded config is overwritten with
    ``num_band`` before the model is constructed. The returned model carries
    ``use_2dconv = True`` and ``bandwise = False``.
    """
    # Imported lazily so omegaconf is only needed when this model is built.
    from omegaconf import OmegaConf

    package_dir = os.path.dirname(os.path.abspath(__file__))
    cfg = OmegaConf.load(os.path.join(package_dir, 'configs', 't3sc.yaml'))
    cfg.params.channels = num_band

    model = MultilayerModel(**cfg.params)
    model.use_2dconv = True
    model.bandwise = False
    return model
15 |
16 |
def t3sc():
    """Build the T3SC model for 31 bands."""
    num_band = 31
    return build_t3sc(num_band)
19 |
--------------------------------------------------------------------------------
/hsir/model/trq3d/__init__.py:
--------------------------------------------------------------------------------
1 | # the original version, with FCU before and after
2 | from functools import partial
3 | from .trq3d import *
4 |
5 |
# Rebind TRQ3D so that it is always constructed with this package's
# encoder and decoder classes.
TRQ3D = partial(TRQ3D, Encoder=TRQ3DEncoder, Decoder=TRQ3DDecoder)
11 |
12 |
def trq3d():
    """Build the default TRQ3D network with ``use_2dconv`` and ``bandwise`` off."""
    config = dict(
        in_channels=1,
        in_channels_tr=31,
        channels=16,
        channels_tr=16,
        med_channels=31,
        num_half_layer=4,
        sample_idx=[1, 3],
        has_ad=True,
        input_resolution=(512, 512),
    )
    net = TRQ3D(**config)
    net.use_2dconv = False
    net.bandwise = False
    return net
19 |
--------------------------------------------------------------------------------
/docs/source/getstart.md:
--------------------------------------------------------------------------------
1 | # Getting Started
2 |
3 | ```{toctree}
4 | :maxdepth: 2
5 | ```
6 |
7 |
8 | [](https://pypi.org/project/hsir/)
9 |
10 | Out-of-box Hyperspectral Image Restoration Toolbox
11 |
12 |
13 | ## Install
14 |
15 | ```shell
16 | pip install hsir
17 | ```
18 |
19 | ## Usage
20 |
21 | Here are some runnable examples; please refer to the code for more options.
22 |
23 | ```shell
24 | python hsirun/train.py -a qrnn3d.qrnn3d
25 | python hsirun/test.py -a qrnn3d.qrnn3d -r qrnn3d.pth -t icvl_512_50
26 | ```
--------------------------------------------------------------------------------
/benchmark/datasets/Harvard/harvard_meta.txt:
--------------------------------------------------------------------------------
1 | imga2
2 | imgb0
3 | imgb4
4 | imgb8
5 | imgc4
6 | imgc9
7 | imgd7
8 | imge1
9 | imge5
10 | imgf2
11 | imgf6
12 | imgh1
13 | img1
14 | imga5
15 | imgb1
16 | imgb5
17 | imgb9
18 | imgc5
19 | imgd2
20 | imgd8
21 | imge2
22 | imge6
23 | imgf3
24 | imgf7
25 | imgh2
26 | img2
27 | imga6
28 | imgb2
29 | imgb6
30 | imgc1
31 | imgc7
32 | imgd3
33 | imgd9
34 | imge3
35 | imge7
36 | imgf4
37 | imgf8
38 | imgh3
39 | imga1
40 | imga7
41 | imgb3
42 | imgb7
43 | imgc2
44 | imgc8
45 | imgd4
46 | imge0
47 | imge4
48 | imgf1
49 | imgf5
50 | imgh0
--------------------------------------------------------------------------------
/benchmark/datasets/CAVE/cave_meta.txt:
--------------------------------------------------------------------------------
1 | fake_and_real_sushi_ms
2 | watercolors_ms
3 | chart_and_stuffed_toy_ms
4 | superballs_ms
5 | real_and_fake_peppers_ms
6 | fake_and_real_tomatoes_ms
7 | balloons_ms
8 | fake_and_real_beers_ms
9 | real_and_fake_apples_ms
10 | cloth_ms
11 | feathers_ms
12 | hairs_ms
13 | paints_ms
14 | sponges_ms
15 | clay_ms
16 | fake_and_real_food_ms
17 | fake_and_real_strawberries_ms
18 | egyptian_statue_ms
19 | fake_and_real_lemon_slices_ms
20 | cd_ms
21 | thread_spools_ms
22 | jelly_beans_ms
23 | beads_ms
24 | fake_and_real_peppers_ms
25 | face_ms
26 | pompoms_ms
27 | flowers_ms
28 | photo_and_face_ms
29 | fake_and_real_lemons_ms
30 | stuffed_toys_ms
31 | glass_tiles_ms
32 | oil_painting_ms
--------------------------------------------------------------------------------
/hsir/scheduler.py:
--------------------------------------------------------------------------------
1 |
def adjust_learning_rate(optimizer, lr):
    """Write ``lr`` into the ``'lr'`` entry of every group in ``optimizer.param_groups``."""
    for group in optimizer.param_groups:
        group.update(lr=lr)
5 |
class MultiStepSetLR:
    """Set the optimizer's learning rate at chosen epochs.

    ``schedule`` maps an epoch number to the learning rate that is written
    into the optimizer when the internal epoch counter reaches that number.
    Only the epoch counter is saved in the state dict.
    """

    def __init__(self, optimizer, schedule, epoch=0) -> None:
        self.optimizer, self.schedule, self.epoch = optimizer, schedule, epoch

    def step(self):
        """Advance the epoch counter and apply the scheduled rate, if there is one."""
        self.epoch += 1
        if self.epoch not in self.schedule.keys():
            return
        adjust_learning_rate(self.optimizer, self.schedule[self.epoch])

    def state_dict(self):
        """Return the serialisable state: the current epoch counter."""
        return dict(epoch=self.epoch)

    def load_state_dict(self, state_dict):
        """Restore the epoch counter from ``state_dict['epoch']``."""
        self.epoch = state_dict['epoch']
22 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = source
9 | BUILDDIR = build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/benchmark/datasets/CAVE/cave_preprocess.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import cv2
4 | import scipy.io as sio
5 |
# Convert every CAVE scene (31 single-band PNG files) into one .mat file that
# stores the bands stacked along the last axis under the key 'gt'.
root = '/home/wzliu/projects/data/cave'
# Scene folders end in 'ms'; sorting makes the processing order reproducible
# (os.listdir returns entries in arbitrary order).
names = sorted(n for n in os.listdir(root) if n.endswith('ms'))

save_dir = '/home/wzliu/projects/data/cave_mat'
os.makedirs(save_dir, exist_ok=True)

for n in names:
    bands = []
    for b in range(31):
        # Band files are numbered from 01: <root>/<scene>/<scene>/<scene>_NN.png
        path = os.path.join(root, n, n, '{}_{:0>2d}.png'.format(n, b + 1))
        # NOTE(review): IMREAD_GRAYSCALE yields 8-bit data; if the source PNGs
        # have a higher bit depth, precision is reduced here -- confirm intended.
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError('Cannot read band image: {}'.format(path))
        bands.append(img)
    data = np.stack(bands, axis=2)

    print(n)
    sio.savemat(os.path.join(save_dir, n + '.mat'), {'gt': data})
23 |
--------------------------------------------------------------------------------
/hsiboard/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import streamlit as st
3 | import os
4 | import sys
5 |
6 | sys.path.append(os.path.dirname(os.path.abspath(__file__)))
7 |
8 | import viewer
9 | import table
10 | import cmp
11 |
def main(logdir):
    """Show a page selector in the sidebar and render the chosen page for ``logdir``."""
    pages = {
        "Viewer": viewer.main,
        "Comparator": cmp.main,
        "Table": table.main,
    }
    choice = st.sidebar.selectbox("Select a page", pages.keys())
    render = pages[choice]
    render(logdir)
20 |
21 |
if __name__ == '__main__':
    # Command-line entry point: the only option is the results directory.
    cli = argparse.ArgumentParser('HSIR Board')
    cli.add_argument('--logdir', default='results')
    options = cli.parse_args()
    main(options.logdir)
28 |
--------------------------------------------------------------------------------
/hsir/model/t3sc/metrics/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 |
EPS = 1e-12


def abs(x):
    """Return sqrt(a**2 + b**2 + EPS) for the two components stored on axis 4.

    ``a`` and ``b`` are entries 0 and 1 of axis 4 of ``x``. EPS keeps the
    square-root argument strictly positive (presumably for a finite gradient
    at zero -- confirm).
    """
    first = x[:, :, :, :, 0]
    second = x[:, :, :, :, 1]
    return torch.sqrt(first ** 2 + second ** 2 + EPS)
10 |
11 |
def real(x):
    """Return entry 0 along axis 4 of ``x``."""
    index = (slice(None),) * 4 + (0,)
    return x[index]
14 |
15 |
def imag(x):
    """Return entry 1 along axis 4 of ``x``."""
    index = (slice(None),) * 4 + (1,)
    return x[index]
18 |
19 |
def downsample(img1, img2, maxSize=256):
    """Average-pool both images by the same integer factor.

    The factor is ``round(min(H, W) / maxSize)`` (at least 1), taken from the
    shape of ``img1``. With a factor of 1 both inputs are returned unchanged.

    Args:
        img1: 4-D tensor; its shape ``(_, channels, H, W)`` fixes the factor.
        img2: tensor pooled with the same kernel and stride as ``img1``.
        maxSize: reference size used to derive the pooling factor.

    Returns:
        Tuple ``(img1, img2)`` after pooling.
    """
    _, channels, H, W = img1.shape
    f = int(max(1, np.round(min(H, W) / maxSize)))
    if f > 1:
        # Build the box filter in the dtype/device of the input; a float32
        # kernel would make conv2d fail on float64 or half inputs.
        aveKernel = torch.ones(channels, 1, f, f, dtype=img1.dtype, device=img1.device) / f ** 2
        # groups=channels applies the filter to every channel independently.
        img1 = F.conv2d(img1, aveKernel, stride=f, padding=0, groups=channels)
        img2 = F.conv2d(img2, aveKernel.to(img2), stride=f, padding=0, groups=channels)
    return img1, img2
28 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 |
4 | from pathlib import Path
# The repository README doubles as the long description of the distribution.
this_directory = Path(__file__).parent
# Read with an explicit encoding so the result does not depend on the locale.
long_description = (this_directory / "README.md").read_text(encoding="utf-8")

# hsirun/ and hsiboard/ are added by hand on top of what find_packages()
# returns (presumably because they have no __init__.py -- confirm).
packages = find_packages()
packages.append('hsirun')
packages.append('hsiboard')

setup(
    name='hsir',
    description="Hyperspectral Image Restoration Toolbox",
    long_description=long_description,
    long_description_content_type='text/markdown',
    url='https://github.com/Zeqiang-Lai/HSIR',
    packages=packages,
    # Keys must match the package names listed above ('hsirun', not 'hsirrun').
    package_dir={'hsir': 'hsir', 'hsirun': 'hsirun', 'hsiboard': 'hsiboard'},
    version='0.1.1',
    include_package_data=True,
    install_requires=['tqdm', 'qqdm', 'timm', 'streamlit', 'torchlights'],
)
24 |
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # .readthedocs.yaml
2 | # Read the Docs configuration file
3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
4 |
5 | # Required
6 | version: 2
7 |
8 | # Set the version of Python and other tools you might need
9 | build:
10 | os: ubuntu-20.04
11 | tools:
12 | python: "3.9"
13 | # You can also specify other tool versions:
14 | # nodejs: "16"
15 | # rust: "1.55"
16 | # golang: "1.17"
17 |
18 | # Build documentation in the docs/ directory with Sphinx
19 | sphinx:
20 | configuration: docs/source/conf.py
21 |
22 | # If using Sphinx, optionally build your docs in additional formats such as PDF
23 | # formats:
24 | # - pdf
25 |
26 | # Optionally declare the Python requirements required to build your docs
27 | python:
28 | install:
29 | - requirements: docs/requirements.txt
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=source
11 | set BUILDDIR=build
12 |
13 | if "%1" == "" goto help
14 |
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | echo.
18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | echo.installed, then set the SPHINXBUILD environment variable to point
20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | echo.may add the Sphinx directory to PATH.
22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from
24 | echo.https://www.sphinx-doc.org/
25 | exit /b 1
26 | )
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/hsir/model/qrnn3d/__init__.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from .net import QRNN3DDecoder, QRNN3DEncoder, QRNNREDC3D
3 | from .qrnn import QRNNConv3D, QRNNDeConv3D, QRNNUpsampleConv3d, BiQRNNConv3D, BiQRNNDeConv3D
4 |
5 |
# Pre-bind the building-block classes so the network classes can be
# constructed without passing them each time. Order matters: QRNNREDC3D is
# bound to the already pre-bound encoder and decoder.
QRNN3DEncoder = partial(QRNN3DEncoder, QRNNConv3D=QRNNConv3D)

QRNN3DDecoder = partial(
    QRNN3DDecoder,
    QRNNDeConv3D=QRNNDeConv3D,
    QRNNUpsampleConv3d=QRNNUpsampleConv3d,
)

QRNNREDC3D = partial(
    QRNNREDC3D,
    BiQRNNConv3D=BiQRNNConv3D,
    BiQRNNDeConv3D=BiQRNNDeConv3D,
    QRNN3DEncoder=QRNN3DEncoder,
    QRNN3DDecoder=QRNN3DDecoder,
)
22 |
23 |
def qrnn3d():
    """Build the default QRNNREDC3D network with ``use_2dconv`` and ``bandwise`` off."""
    positional = (1, 16, 5, [1, 3])
    net = QRNNREDC3D(*positional, has_ad=True)
    net.use_2dconv = False
    net.bandwise = False
    return net
29 |
30 |
def qrnn3d_nobn():
    """Same as ``qrnn3d`` but constructed with ``bn=False``."""
    positional = (1, 16, 5, [1, 3])
    net = QRNNREDC3D(*positional, has_ad=True, bn=False)
    net.use_2dconv = False
    net.bandwise = False
    return net
36 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Zeqiang Lai
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/benchmark/datasets/RealHSI/build_test.py:
--------------------------------------------------------------------------------
1 | from skimage import io
2 | import numpy as np
3 | import scipy.io as sio
4 | import os
5 |
def load_tif_img(filepath):
    """Read the image at ``filepath`` and return it as a float32 array."""
    return io.imread(filepath).astype(np.float32)
10 |
11 |
# Build paired test samples: for every file listed in data/real34, load the
# matching ground-truth and 'input50' TIFFs, normalise, histogram-match the
# input to the ground truth, crop, and save both arrays into one .mat file.
from skimage import exposure  # imported once instead of on every loop iteration

root = '/media/exthdd/datasets/hsi/real_hsi/real_dataset'

save_dir = '/media/exthdd/datasets/hsi/real_hsi/real_dataset/test34_50'
os.makedirs(save_dir, exist_ok=True)

# The scene list comes from this (working-directory relative) folder; the
# earlier listing of <root>/gt was overwritten before use and has been dropped.
names = os.listdir('data/real34')

for n in names:
    # Strip the extension, whatever its length, to get the scene name.
    n = os.path.splitext(n)[0]
    print(n)
    # The loaded arrays are transposed so their first axis becomes the last.
    gt = load_tif_img(os.path.join(root, 'gt', n + '.tif')).transpose((1, 2, 0))
    noisy = load_tif_img(os.path.join(root, 'input50', n + '.tif')).transpose((1, 2, 0))
    # Scale each array by its own maximum.
    gt = gt / gt.max()
    noisy = noisy / noisy.max()
    # NOTE(review): newer scikit-image releases replace `multichannel` with
    # `channel_axis` -- confirm against the installed version.
    noisy = exposure.match_histograms(noisy, gt, multichannel=False)
    # Keep the top-left 688 x 512 region of both arrays.
    noisy = noisy[:688, :512, :]
    gt = gt[:688, :512, :]
    sio.savemat(os.path.join(save_dir, n + '.mat'), {'gt': gt, 'input': noisy})
--------------------------------------------------------------------------------
/hsir/model/t3sc/configs/t3sc.yaml:
--------------------------------------------------------------------------------
1 | class_name: "MultilayerModel"
2 | trainable: true
3 | beta: 0
4 | ckpt: null
5 | params:
6 | ssl: 0
7 | n_ssl: 4
8 | channels: 31
9 | ckpt: null
10 | layers:
11 | l0:
12 | name: "LowRankSCLayer"
13 | params:
14 | patch_side: 1
15 | K: 12
16 | rank: 1
17 | code_size: 64
18 | stride: 1
19 | input_centering: 1
20 | patch_centering: 0
21 | tied: "D"
22 | init_method: "kaiming_uniform"
23 | lbda_init: 0.001
24 | lbda_mode: "MC"
25 | beta: 0
26 | ssl: 00
27 | l1:
28 | name: "LowRankSCLayer"
29 | params:
30 | patch_side: 5
31 | K: 5
32 | rank: 3
33 | code_size: 1024
34 | stride: 1
35 | input_centering: 0
36 | patch_centering: 1
37 | tied: "D"
38 | init_method: "kaiming_uniform"
39 | lbda_init: 0.001
40 | lbda_mode: "MC"
41 | beta: 0
42 | ssl: 0
43 |
44 | backtracking:
45 | monitor: "val_mpsnr"
46 | mode: "max"
47 | dirpath: "backtracking"
48 | period: 5
49 | div_thresh: 4
50 | dummy: False
51 | lr_decay: 0.8
52 | id: T3SC
53 |
--------------------------------------------------------------------------------
/docs/source/dataset.md:
--------------------------------------------------------------------------------
1 | # Datasets
2 |
3 | 🌟 存放位置: `/share/dataset/hsi`
4 |
5 | ```tree
6 | ├── cave
7 | ├── harvard
8 | ├── icvl
9 | │ ├── test
10 | │ │ ├── icvl_512_30
11 | │ │ ├── icvl_512_50
12 | │ │ ├── icvl_512_70
13 | │ │ ├── icvl_512_blind
14 | │ │ ├── icvl_512_deadline
15 | │ │ ├── icvl_512_impulse
16 | │ │ ├── icvl_512_mixture
17 | │ │ ├── icvl_512_noniid
18 | │ │ └── icvl_512_stripe
19 | │ └── train
20 | │ └── ICVL64_31_100.db
21 | ├── raw
22 | └── remote
23 | ├── Indian_pines.mat
24 | ├── Pavia.mat
25 | ├── PaviaU.mat
26 | ├── Salinas.mat
27 | └── Urban_R162.mat
28 | ```
29 | ## ICVL
30 |
31 | 自然场景高光谱图像数据集
32 |
33 | - 来源:[Official Website](http://icvl.cs.bgu.ac.il/hyperspectral/)
34 | - 服务器存放路径: `/data/dataset/hsi/icvl`
35 |
36 | Meta Data
37 |
38 | - 201 张图
39 |
40 | 示例图:
41 |
42 | 
43 |
44 | 文件结构:
45 | ```
46 |
47 | ```
48 |
49 | Cite
50 | ```bibtex
51 | @inproceedings{arad_and_ben_shahar_2016_ECCV,
52 | title={Sparse Recovery of Hyperspectral Signal from Natural RGB Images},
53 | author={Arad, Boaz and Ben-Shahar, Ohad},
54 | booktitle={European Conference on Computer Vision},
55 | pages={19--34},
56 | year={2016},
57 | organization={Springer}
58 | }
59 | ```
60 |
61 | ## CAVE
62 |
63 |
64 |
65 | ## Harvard
66 |
67 | ## Remote Sensed
68 |
69 |
70 |
71 |
--------------------------------------------------------------------------------
/hsir/model/t3sc/layers/soft_thresholding.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
MODES = ["SG", "SC", "MG", "MC"]


class SoftThresholding(nn.Module):
    """Soft-thresholding with learnable threshold(s) ``lbda``.

    ``mode`` is a two-letter code:
      * first letter: "S" keeps one set of thresholds for all unfoldings,
        "M" keeps ``K`` sets, selected in ``forward`` through ``k``;
      * second letter: "G" uses one threshold for all channels, "C" uses
        ``code_size`` thresholds, one per channel.
    Every threshold starts at ``lbda_init``.
    """

    def __init__(self, mode, lbda_init, code_size=None, K=None):
        super().__init__()
        assert mode in MODES, f"Mode {mode!r} not recognized"
        self.mode = mode

        per_unfolding = self.mode[0] == "M"
        per_channel = self.mode[1] == "C"
        # Shape broadcasts against an input whose axis 1 is the channel axis.
        lbda_shape = (1, code_size, 1, 1) if per_channel else (1, 1, 1, 1)

        def new_lbda():
            return nn.Parameter(lbda_init * torch.ones(*lbda_shape))

        if per_unfolding:
            self.lbda = nn.ParameterList([new_lbda() for _ in range(K)])
        else:
            self.lbda = new_lbda()

    def forward(self, x, k=None):
        """Threshold ``x``; ``k`` picks the threshold set in the "M" modes."""
        lbda = self.lbda[k] if self.mode[0] == "M" else self.lbda
        return self._forward(x, lbda)

    def _forward(self, x, lbda):
        # Difference of two ReLUs: shrinks |x| by lbda and zeroes |x| <= lbda.
        positive_part = F.relu(x - lbda)
        negative_part = F.relu(-x - lbda)
        return positive_part - negative_part
42 |
--------------------------------------------------------------------------------
/benchmark/datasets/CAVE/cave_build_test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from scipy.io import loadmat, savemat
4 |
5 |
def minmax_normalize(array):
    """Linearly rescale ``array`` so its minimum maps to 0 and its maximum to 1.

    A constant array (maximum equal to minimum) is returned as all zeros
    instead of the NaNs a 0/0 division would produce.
    """
    amin = np.min(array)
    span = np.max(array) - amin
    if span == 0:
        return np.zeros_like(array, dtype=float)
    return (array - amin) / span
10 |
11 |
def main(get_noise, sigma):
    """Write noisy test samples for every scene listed in ``cave_test.txt``.

    For each scene the 'gt' array is loaded, min-max normalised, and saved
    together with ``gt + get_noise(gt.shape)`` and ``sigma`` into
    ``cave_512_<sigma>/<scene>.mat``. NumPy's RNG is seeded with 2022 before
    the first scene so the noise is reproducible.

    Args:
        get_noise: callable taking a shape and returning a noise array.
        sigma: label stored in the output and used in the folder name.
    """
    root = '/home/wzliu/projects/data/cave_mat/'
    save_root = os.path.join('/home/wzliu/projects/data/cave_test/', 'cave_512_{}'.format(sigma))
    os.makedirs(save_root, exist_ok=True)

    # The scene list is read relative to the current working directory.
    with open('cave_test.txt') as list_file:
        scene_names = [line.strip() for line in list_file]

    np.random.seed(2022)
    for scene in scene_names:
        print(scene)
        clean = minmax_normalize(loadmat(os.path.join(root, scene))['gt'])
        noisy = clean + get_noise(clean.shape)
        sample = {'input': noisy, 'gt': clean, 'sigma': sigma}
        savemat(os.path.join(save_root, scene + '.mat'), sample)
31 |
32 |
def fixed(sigma):
    """Return a noise generator with fixed level ``sigma`` (on a 0-255 scale).

    The generator maps a shape to standard-normal samples scaled by
    ``sigma / 255``.
    """
    def gaussian(shape):
        samples = np.random.randn(*shape)
        return samples * sigma / 255.
    return gaussian
37 |
def blind(min, max):
    """Return a noise generator whose level is drawn uniformly per call.

    Each call draws standard-normal samples of the requested shape, then one
    level in ``[min, max)`` (on a 0-255 scale), and scales the samples by
    ``level / 255``.
    """
    def gaussian(shape):
        # Draw order (samples first, then the level) fixes the RNG stream.
        samples = np.random.randn(*shape)
        level = min + np.random.rand(1) * (max - min)
        return samples * level / 255.
    return gaussian
42 |
# Generate the three fixed-level test sets, then the blind one.
for level in (30, 50, 70):
    main(fixed(level), level)
main(blind(10, 70), 'blind')
47 |
48 |
--------------------------------------------------------------------------------
/hsir/model/t3sc/layers/encoding_layer.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import torch.nn as nn
4 |
5 | logger = logging.getLogger(__name__)
6 | logger.setLevel(logging.DEBUG)
7 |
8 |
class EncodingLayer(nn.Module):
    """Base class for layers that encode an input and decode it back.

    Subclasses implement ``_encode`` and ``_decode``. When ``input_centering``
    is enabled, ``encode`` subtracts the mean over axes 2 and 3 of its input
    and ``decode`` adds the stored means back to its output.
    """

    def __init__(
        self,
        in_channels=None,
        code_size=None,
        input_centering=False,
        **kwargs,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.code_size = code_size
        self.input_centering = input_centering

    def forward(self, x, mode=None, **kwargs):
        """Run ``encode``, ``decode`` or both.

        Args:
            x: input tensor.
            mode: "encode", "decode", or None to run encode then decode.
            **kwargs: forwarded to ``encode`` / ``decode``.
        """
        assert mode in ["encode", "decode", None], f"Mode {mode!r} unknown"

        if mode in ["encode", None]:
            x = self.encode(x, **kwargs)
        if mode in ["decode", None]:
            x = self.decode(x, **kwargs)
        return x

    def encode(self, x, **kwargs):
        """Optionally centre ``x``, then call ``_encode``."""
        if self.input_centering:
            # Remember the means so decode() can restore them.
            self.input_means = x.mean(dim=[2, 3], keepdim=True)
            # Out-of-place on purpose: an in-place subtraction would modify
            # the tensor owned by the caller.
            x = x - self.input_means

        x = self._encode(x, **kwargs)

        return x

    def decode(self, x, **kwargs):
        """Call ``_decode``, then add back the means stored by ``encode``."""
        x = self._decode(x, **kwargs)

        if self.input_centering:
            # Requires a previous encode() call to have set input_means.
            x = x + self.input_means

        return x

    def _encode(self, x, **kwargs):
        raise NotImplementedError

    def _decode(self, x, **kwargs):
        raise NotImplementedError
53 |
--------------------------------------------------------------------------------
/hsiboard/table.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from util import *
3 |
4 |
def main(logdir):
    """Render the 'Table' page: per-method statistics found under ``logdir``.

    With exactly one method selected, its statistics are shown as a single
    table. Otherwise one table is shown per 'Name' value, with one row per
    selected method.
    """
    with st.sidebar:
        st.title('HSIR Board')
        st.subheader('Table')
        st.caption(f'Directory: {logdir}')

    methods = listdir(logdir, exclude=['gt'])
    selected_methods = st.sidebar.multiselect(
        "Select Method",
        methods,
        default=methods,
    )

    stat = load_stat(logdir)

    st.subheader('Not loaded')
    st.text(set(selected_methods) - set(stat.keys()))

    if len(selected_methods) == 1:
        selected_method = selected_methods[0]
        st.header(selected_method)
        # A selected method without statistics is already listed under
        # 'Not loaded'; skip the table instead of raising KeyError.
        if selected_method in stat:
            st.dataframe(stat[selected_method])
    else:
        table = {}
        # Honour the sidebar selection: only selected methods that actually
        # have statistics contribute rows.
        for m in selected_methods:
            if m not in stat:
                continue
            for d in stat[m]:
                row = {'Method': m}
                row.update({k: v for k, v in d.items() if k != 'Name'})
                table.setdefault(d['Name'], []).append(row)
        for k, v in table.items():
            st.subheader(k)
            st.dataframe(v)
42 |
43 |
if __name__ == '__main__':
    # Command-line entry point: the only option is the results directory.
    cli = argparse.ArgumentParser('HSIR Board')
    cli.add_argument('--logdir', default='results')
    options = cli.parse_args()
    main(options.logdir)
50 |
--------------------------------------------------------------------------------
/docs/source/index.md:
--------------------------------------------------------------------------------
1 | ---
2 | hide-toc: true
3 | ---
4 |
5 | # Welcome to HSIR
6 |
7 |
8 | ```{toctree}
9 | :maxdepth: 3
10 | :hidden: true
11 |
12 | getstart
13 | dataset
14 | benchmark
15 | ```
16 |
17 |
18 | ```{toctree}
19 | :caption: Useful Links
20 | :hidden:
21 | PyPI page <https://pypi.org/project/hsir/>
22 | GitHub Repository <https://github.com/bit-isp/HSIR>
23 | ```
24 |
25 |
26 | [](https://pypi.org/project/hsir/)
27 |
28 | Out-of-box Hyperspectral Image Restoration Toolbox
29 |
30 |
31 | ## Install
32 |
33 | ```shell
34 | pip install hsir
35 | ```
36 |
37 | ## Usage
38 |
39 | Here are some runnable examples; please refer to the code for more options.
40 |
41 | ```shell
42 | python hsirun/train.py -a qrnn3d.qrnn3d
43 | python hsirun/test.py -a qrnn3d.qrnn3d -r qrnn3d.pth -t icvl_512_50
44 | ```
45 |
46 |
47 | ## Acknowledgement
48 |
49 | - [QRNN3D](https://github.com/Vandermode/QRNN3D)
50 | - [DPHSIR](https://github.com/Zeqiang-Lai/DPHSIR)
51 | - [MST](https://github.com/caiyuanhao1998/MST)
52 | - [TLC](https://github.com/megvii-research/TLC)
53 |
54 | ## Citation
55 |
56 | If you find this repo helpful, please consider citing us.
57 |
58 | ```bibtex
59 | @misc{hsir,
60 | author={Zeqiang Lai and Miaoyu Li and Ying Fu},
61 | title={HSIR: Out-of-box Hyperspectral Image Restoration Toolbox},
62 | year={2022},
63 | url={https://github.com/bit-isp/HSIR},
64 | }
65 | ```
66 |
67 | 
--------------------------------------------------------------------------------
/benchmark/datasets/Harvard/harvard_build_test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from scipy.io import loadmat, savemat
4 |
5 |
def minmax_normalize(array):
    """Linearly rescale ``array`` so its minimum maps to 0 and its maximum to 1.

    Args:
        array: Numeric array-like accepted by ``np.min`` / ``np.max``.

    Returns:
        ``(array - min) / (max - min)``. For a constant input (max == min)
        an all-zero float array of the same shape is returned instead of the
        NaNs a 0/0 division would produce.
    """
    amin = np.min(array)
    span = np.max(array) - amin
    if span == 0:
        # Constant input: there is no range to stretch, avoid dividing by 0.
        return np.zeros_like(array, dtype=float)
    return (array - amin) / span
10 |
def crop_center(img, cropx, cropy):
    """Return the central ``cropy`` x ``cropx`` window of a 3-D array.

    The first two axes of ``img`` are cropped (axis 0 by ``cropy``, axis 1 by
    ``cropx``); the third axis is kept whole.
    """
    height, width, _ = img.shape
    top = height // 2 - (cropy // 2)
    left = width // 2 - (cropx // 2)
    rows = slice(top, top + cropy)
    cols = slice(left, left + cropx)
    return img[rows, cols, :]
16 |
def main(get_noise, sigma,
         root='/home/wzliu/projects/data/harvard/CZ_hsdb/',
         save_dir='/home/wzliu/projects/data/harvard_test/',
         list_file='harvard_test.txt'):
    """Build one noisy Harvard test set and save it as ``.mat`` files.

    For every file name listed in ``list_file`` the ``'ref'`` cube is loaded,
    min-max normalized, center-cropped to 512x512 and corrupted with the
    noise returned by ``get_noise``. Each result is written to
    ``<save_dir>/harvard_512_<sigma>/<name>.mat`` with the keys ``input``,
    ``gt`` and ``sigma``.

    Args:
        get_noise: Callable mapping an array shape to a noise array of that
            shape (see ``fixed`` / ``blind``).
        sigma: Noise label stored in the file and used in the folder name.
        root: Directory holding the source ``.mat`` files.
        save_dir: Parent directory of the generated test-set folder.
        list_file: Text file with one source file name per line.
    """
    save_root = os.path.join(save_dir, 'harvard_512_{}'.format(sigma))
    os.makedirs(save_root, exist_ok=True)

    with open(list_file) as f:
        # Skip blank lines so a trailing newline does not become a bogus entry.
        fns = [line.strip() for line in f if line.strip()]

    # Fixed seed: the generated test set is reproducible across runs.
    np.random.seed(2022)
    for fn in fns:
        print(fn)
        gt = loadmat(os.path.join(root, fn))['ref']
        gt = crop_center(minmax_normalize(gt), 512, 512)
        print(gt.shape)
        noisy = gt + get_noise(gt.shape)
        savemat(os.path.join(save_root, fn + '.mat'),
                {'input': noisy, 'gt': gt, 'sigma': sigma})
38 |
39 |
def fixed(sigma):
    """Return a noise generator with a fixed level of ``sigma / 255``."""
    def get_noise(shape):
        # Standard-normal samples scaled by the requested level.
        standard = np.random.randn(*shape)
        return standard * sigma / 255.
    return get_noise
44 |
def blind(min, max):
    """Return a noise generator whose level is drawn anew on every call.

    Each call samples one level uniformly from ``[min, max)`` and applies it
    (divided by 255) to standard-normal noise of the requested shape.
    """
    def get_noise(shape):
        standard = np.random.randn(*shape)
        level = min + np.random.rand(1) * (max - min)
        return standard * level / 255.
    return get_noise
49 |
# Build the three fixed-level test sets, then the blind one.
for level in (30, 50, 70):
    main(fixed(level), level)
main(blind(10, 70), 'blind')
54 |
55 |
--------------------------------------------------------------------------------
/hsir/data/transform/general.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | import threading
4 |
5 |
class HSI2Tensor(object):
    """
    Convert a numpy array into a float torch tensor.

    With ``use_chw=True`` the array's dimensions are kept as they are; with
    ``use_chw=False`` a leading singleton dimension is prepended.
    """

    def __init__(self, use_chw=True):
        """ use_chw: True keeps the input dimensions, False prepends one. """
        self.use_chw = use_chw

    def __call__(self, hsi):
        tensor = torch.from_numpy(hsi)
        if self.use_chw:
            return tensor.float()
        return tensor.unsqueeze(0).float()
21 |
22 |
class Identity(object):
    """
    Identity transform: returns its input unchanged.
    """

    def __call__(self, data):
        return data
30 |
31 |
class RandomCrop(object):
    """
    Crop a random ``croph`` x ``cropw`` window from the last two axes of a
    3-D array; the first axis is kept whole.
    """

    def __init__(self, croph, cropw):
        self.croph, self.cropw = croph, cropw

    def __call__(self, img):
        _, height, width = img.shape
        # Offsets are drawn with the stdlib ``random`` module (row first).
        top = random.randint(0, height - self.croph)
        left = random.randint(0, width - self.cropw)
        bottom, right = top + self.croph, left + self.cropw
        return img[:, top:bottom, left:right]
47 |
48 |
class LockedIterator(object):
    """Iterator wrapper that serializes every ``next()`` call with a lock."""

    def __init__(self, it):
        self.lock = threading.Lock()
        self.it = it.__iter__()

    def __iter__(self):
        return self

    def __next__(self):
        # The context manager releases the lock even when the wrapped
        # iterator raises (e.g. StopIteration).
        with self.lock:
            return next(self.it)
62 |
63 |
class SequentialSelect(object):
    """Apply the given transforms in round-robin order, one per call.

    The position counter is wrapped in a ``LockedIterator`` so advancing it
    is serialized by a lock.
    """

    def __pos(self, n):
        # Endless 0, 1, ..., n-1, 0, 1, ... index stream.
        index = 0
        while True:
            yield index
            index = (index + 1) % n

    def __init__(self, transforms):
        self.transforms = transforms
        self.pos = LockedIterator(self.__pos(len(transforms)))

    def __call__(self, img):
        chosen = self.transforms[next(self.pos)]
        return chosen(img)
79 |
--------------------------------------------------------------------------------
/hsirun/benchmark.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | from datetime import datetime
5 | from ast import literal_eval
6 | import platform
7 |
8 | import torch
9 | import torch.cuda
10 | import torch.backends.cudnn
11 |
12 | from torchlight.utils import instantiate
13 |
# The benchmark only runs forward passes, so autograd is disabled globally.
torch.set_grad_enabled(False)
# Turn on cuDNN's benchmark mode for the timed runs.
torch.backends.cudnn.benchmark = True
16 |
17 |
def main(arch, input_shape, total_steps=10, save_path='benchmark.json'):
    """Measure parameter count and mean forward time of ``arch`` on CUDA.

    The model is created with ``instantiate(arch)``, moved to the GPU in eval
    mode and fed a random tensor of ``input_shape``. After 10 warm-up passes,
    ``total_steps`` passes are timed with CUDA events. The result is appended
    to the list stored under ``arch`` in the JSON file ``save_path``.

    Args:
        arch: Architecture identifier understood by ``instantiate``.
        input_shape: Sequence of ints giving the random input's shape.
        total_steps: Number of timed forward passes (must be >= 1).
        save_path: JSON database file, created when missing.

    Raises:
        ValueError: If ``total_steps`` is smaller than 1.
    """
    steps = int(total_steps)
    if steps < 1:
        raise ValueError('total_steps must be >= 1, got {}'.format(total_steps))

    net = instantiate(arch)
    net = net.cuda().eval()
    dummy = torch.randn(*input_shape).cuda()

    total_params = sum(p.nelement() for p in net.parameters()) / 1e6
    print("Number of parameter: %.2fM" % (total_params))

    # warm up for benchmark
    for _ in range(10):
        net(dummy)

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    for _ in range(steps):
        net(dummy)
    end.record()

    # Both events must have completed before elapsed_time() can be read.
    torch.cuda.synchronize()

    avg_time = start.elapsed_time(end) / steps
    print('Time: {} ms'.format(avg_time))

    database = {}
    if os.path.exists(save_path):
        with open(save_path, 'r') as f:
            database = json.load(f)
    dtstr = datetime.now().strftime("%Y-%m-%d-%H:%M:%S")
    database.setdefault(arch, []).append({
        'runtime': avg_time, 'params': total_params, 'date': dtstr,
        'os': platform.platform(),
        'processor': platform.processor(),
    })
    with open(save_path, 'w') as f:
        json.dump(database, f, indent=4)
59 |
60 |
if __name__ == '__main__':
    # CLI wrapper: the input shape arrives as a Python literal string.
    cli = argparse.ArgumentParser('Benchmark model runtime and params')
    cli.add_argument('-a', '--arch', type=str, required=True)
    cli.add_argument('-t', '--steps', type=int, default=10)
    cli.add_argument('-s', '--save-path', type=str, default='benchmark.json')
    cli.add_argument('-i', '--input-shape', type=str, default='[1,1,31,512,512]')
    opts = cli.parse_args()
    main(opts.arch, literal_eval(opts.input_shape), opts.steps, opts.save_path)
70 |
--------------------------------------------------------------------------------
/hsiboard/viewer.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from os.path import join
3 |
4 | from util import *
5 | from box import *
6 |
7 |
def main(logdir):
    """Render the Streamlit "Viewer" page: an image grid for one method/dataset.

    The sidebar selects method, dataset and image type, the optional
    enlarged-box overlay and the number of grid columns. Every PNG under
    ``<logdir>/<method>/<dataset>/<type>`` is shown; the caption carries the
    image's ``MPSNR`` from ``log.json`` when it is available.

    Args:
        logdir: Root results directory, one sub-folder per method.

    ``st``, ``os`` and the loaders come from ``from util import *``;
    ``addbox`` and ``mapbbpox`` come from ``from box import *``.
    """
    set_page_container_style()
    # NOTE(review): the extent of the sidebar block was reconstructed from a
    # dump without indentation - confirm against the original file.
    with st.sidebar:
        st.title('HSIR Board')
        st.subheader('Viewer')
        st.caption(f'Directory: {logdir}')
        methods = listdir(logdir, exclude=['gt'])
        selected_method = st.sidebar.selectbox(
            "Select Method",
            methods
        )
        datasets = listdir(join(logdir, selected_method))
        selected_dataset = st.sidebar.selectbox(
            "Select Dataset",
            datasets
        )
        selected_vis_type = st.sidebar.selectbox(
            "Select Image Type",
            ['color', 'gray']
        )

        st.header('Box')
        enable_enlarge = st.checkbox('Enlarge')
        crow = st.slider('row coordinate', min_value=0.0, max_value=1.0, value=0.2)
        ccol = st.slider('col coordinate', min_value=0.0, max_value=1.0, value=0.2)
        selected_box_pos = st.sidebar.selectbox(
            "Select Box Position",
            ['Bottom Right', 'Bottom Left', 'Top Right', 'Top Left'],
        )

        st.header('Layout')
        ncol = st.slider('number of columns', min_value=1, max_value=20, value=6)

    result_dir = join(logdir, selected_method, selected_dataset)
    imgs = load_imgs(join(result_dir, selected_vis_type))

    # Ceil division: no empty trailing row when len(imgs) is a multiple of ncol.
    nrow = (len(imgs) + ncol - 1) // ncol
    grids = make_grid(nrow, ncol)
    details = load_per_image_stat(result_dir)
    with st.container():
        for idx, (name, img) in enumerate(imgs.items()):
            name = os.path.splitext(name)[0]
            if enable_enlarge:
                h, w = img.shape[:2]
                # addbox() reads its point as (x, y): the column fraction
                # scales the width, the row fraction scales the height.
                img = addbox(img.copy(), (int(w * ccol), int(h * crow)),
                             bbpos=mapbbpox[selected_box_pos])
            per_image = details.get(name)
            if per_image and 'MPSNR' in per_image:
                caption = '%s [%.4f]' % (name, per_image['MPSNR'])
            else:
                # No statistic recorded for this image: show the name only.
                caption = name
            ct = grids[idx // ncol][idx % ncol]
            ct.image(img, caption=caption)
57 |
58 |
59 |
if __name__ == '__main__':
    # Command-line entry point: only the results directory is configurable.
    cli = argparse.ArgumentParser('HSIR Board')
    cli.add_argument('--logdir', default='results')
    main(cli.parse_args().logdir)
66 |
--------------------------------------------------------------------------------
/hsir/model/denet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.init as init
3 |
4 |
# Layer plan consumed in order by DeNet.__init__: each entry adds one
# Conv2d + BatchNorm2d + ReLU group with the given width and dilation.
"""
each tuple unit: [out_channels, dilation]
"""
cfg = ([64, 1], [64, 1], [64, 1],
       [128, 1], [128, 1], [128, 1],
       [256, 2], [256, 2], [256, 2],
       [128, 1], [128, 1], [128, 1],
       [64, 1], [64, 1], [64, 1], [64, 1])
13 |
14 |
class DeNet(nn.Module):
    """Residual 2-D CNN: ``forward(x)`` returns ``x - denet(x)``.

    The network is Conv+ReLU, one Conv+BatchNorm+ReLU group per ``cfg``
    entry, and a final Conv that maps back to the input channel count so the
    subtraction in ``forward`` is shape-compatible for any ``in_channels``.

    Args:
        in_channels: Number of channels of the input (and of the output).
        kernel_size: Convolution kernel size. Paddings are derived from it so
            odd kernels keep the spatial size; for the default 3 they equal
            the previously hard-coded values.
        init_weights: Accepted for compatibility; currently unused because
            the call to ``_initialize_weights`` is disabled.
    """

    def __init__(self, in_channels=10, kernel_size=3, init_weights=True):
        super(DeNet, self).__init__()
        num_bands = in_channels          # remembered for the output conv
        half_kernel = kernel_size // 2   # "same" padding at dilation 1
        out_channels = 64

        # add CR
        layers = [
            nn.Conv2d(in_channels=num_bands, out_channels=out_channels,
                      kernel_size=kernel_size, padding=half_kernel, bias=True),
            nn.ReLU(inplace=True),
        ]

        # add CBR1 to CBR16
        in_channels = out_channels
        for out_channels, dilation in cfg:
            # Padding scales with dilation, so every dilation keeps the size.
            layers.append(nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                    kernel_size=kernel_size,
                                    padding=dilation * half_kernel,
                                    dilation=dilation, bias=False))
            layers.append(nn.BatchNorm2d(num_features=out_channels))
            layers.append(nn.ReLU(inplace=True))
            in_channels = out_channels

        # add C: back to the input channel count (was fixed at 64 -> 10).
        layers.append(nn.Conv2d(in_channels=in_channels, out_channels=num_bands,
                                kernel_size=kernel_size, padding=half_kernel,
                                bias=False))
        self.denet = nn.Sequential(*layers)
        # Weight initialisation is intentionally left disabled:
        # if init_weights:
        #     self._initialize_weights()

    def forward(self, x):
        """Return the input minus the network's output."""
        return x - self.denet(x)

    def _initialize_weights(self):
        """Orthogonal init for conv weights, constants for biases and BN."""
        print("===> Start initializing weights")
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.orthogonal_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
63 |
def hsidenet():
    """Build a 31-channel ``DeNet`` tagged with the flags the framework reads.

    The returned model has ``use_2dconv = True`` and ``bandwise = False``.
    """
    model = DeNet(in_channels=31)
    model.use_2dconv, model.bandwise = True, False
    return model
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | checkpoints
2 | results
3 | model_zoo
4 | ./data
5 | # Byte-compiled / optimized / DLL files
6 | __pycache__/
7 | *.py[cod]
8 | *$py.class
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | pip-wheel-metadata/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
--------------------------------------------------------------------------------
/hsir/model/t3sc/multilayer.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | #from .base import BaseModel
7 | from . import layers
8 |
# Module-level logger; its level is forced to DEBUG here.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
11 |
12 |
class MultilayerModel(nn.Module):
    """Stack of encode/decode layers applied forwards, then in reverse.

    Args:
        channels: Channel count fed to the first layer.
        layers: Mapping with keys ``"l0"``, ``"l1"``, ... Each value holds a
            ``"name"`` (a class looked up in the ``layers`` package) and
            ``"params"`` (keyword arguments for that class).
        ssl, n_ssl: Stored on the instance and logged; not used further in
            this class.
        ckpt: Optional checkpoint path; its ``"state_dict"`` entry is loaded.
    """

    def __init__(
        self,
        channels,
        layers,
        ssl=0,
        n_ssl=0,
        ckpt=None,
    ):
        super().__init__()
        self.channels = channels
        self.layers_params = layers
        self.ssl = ssl
        self.n_ssl = n_ssl
        logger.debug(f"ssl : {self.ssl}, n_ssl : {self.n_ssl}")

        self.init_layers()
        self.normalized_dict = False

        logger.info(f"Using SSL : {self.ssl}")
        self.ckpt = ckpt
        if self.ckpt is not None:
            logger.info(f"Loading ckpt {self.ckpt!r}")
            # map_location="cpu": a checkpoint saved from CUDA tensors can
            # then be read on a CPU-only host; load_state_dict copies the
            # values into this module's own parameters afterwards.
            d = torch.load(self.ckpt, map_location="cpu")
            self.load_state_dict(d["state_dict"])

    def init_layers(self):
        """Instantiate the configured layers, chaining their channel sizes."""
        list_layers = []
        in_channels = self.channels

        for i in range(len(self.layers_params)):
            logger.debug(f"Initializing layer {i}")
            spec = self.layers_params[f"l{i}"]
            layer_cls = layers.__dict__[spec["name"]]
            layer = layer_cls(
                in_channels=in_channels,
                **spec["params"],
            )
            # The next layer consumes this layer's code as its input.
            in_channels = layer.code_size

            list_layers.append(layer)
        self.layers = nn.ModuleList(list_layers)

    def forward(
        self, x, mode=None, img_id=None, sigmas=None, ssl_idx=None, **kwargs
    ):
        """Run encode, decode, or both (``mode=None``) on a float copy of ``x``."""
        assert mode in ["encode", "decode", None], f"Mode {mode!r} unknown"
        x = x.float().clone()

        if mode in ["encode", None]:
            x = self.encode(x, img_id, sigmas=sigmas, ssl_idx=ssl_idx)
        if mode in ["decode", None]:
            x = self.decode(x, img_id)
        return x

    def encode(self, x, img_id, sigmas, ssl_idx):
        """Pass ``x`` through every layer in order, in ``"encode"`` mode."""
        for layer in self.layers:
            x = layer(
                x,
                mode="encode",
                img_id=img_id,
                sigmas=sigmas,
                ssl_idx=ssl_idx,
            )
        return x

    def decode(self, x, img_id):
        """Pass ``x`` through the layers in reverse order, in ``"decode"`` mode."""
        for layer in reversed(self.layers):
            x = layer(x, mode="decode", img_id=img_id)

        return x
86 |
--------------------------------------------------------------------------------
/hsiboard/box.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 |
# Maps a human-readable corner label to the short position code accepted by
# get_border_box_pt / addbox ('br', 'bl', 'ur', 'ul').
mapbbpox = {
    'Bottom Right': 'br',
    'Bottom Left': 'bl',
    'Top Right': 'ur',
    'Top Left': 'ul',
}
10 |
def get_border_box_pt(h, w, size, thickness, pos):
    """Return the (x, y) top-left corner of a box placed in an image corner.

    ``pos`` is one of 'br', 'ur', 'ul', 'bl'. For the right/bottom sides the
    corner is pulled in by ``size + thickness`` from the image edge. Any
    other ``pos`` yields ``None``.
    """
    if pos in ('br', 'ur'):
        x = w - size - thickness
    elif pos in ('ul', 'bl'):
        x = 0
    else:
        return None
    y = h - size - thickness if pos in ('br', 'bl') else 0
    return (x, y)
20 |
def addbox(img, pt, size=100,
           bbsize=200, bbpos='br',
           color=(255, 0, 0),
           thickness=1,
           bbthickness=2):
    """Paste an enlarged copy of a square region into a corner of ``img``.

    The ``size`` x ``size`` region whose top-left corner is ``pt`` (given as
    (x, y)) is resized to ``bbsize`` x ``bbsize`` and written into the corner
    selected by ``bbpos``. Both the enlarged box and the source region are
    then outlined. ``img`` is modified in place and also returned.
    """
    height, width = img.shape[0], img.shape[1]
    x0, y0 = pt[0], pt[1]

    corner = get_border_box_pt(height, width, bbsize, bbthickness, pos=bbpos)
    cx, cy = corner[0], corner[1]

    # Enlarge the selected region and paste it into the chosen corner.
    zoomed = cv2.resize(img[y0:y0 + size, x0:x0 + size], (bbsize, bbsize))
    img[cy:cy + bbsize, cx:cx + bbsize] = zoomed

    # big box
    cv2.rectangle(img, corner, (cx + bbsize, cy + bbsize), color, bbthickness)

    # small box
    cv2.rectangle(img, pt, (x0 + size, y0 + size), color, thickness)

    return img
44 |
45 |
def convert_color(arr, cmap='viridis', vmin=0, vmax=0.1):
    """Colour-map a 2-D array with matplotlib and return its RGB channels.

    Values are mapped with colour limits ``[vmin, vmax]``; the alpha channel
    of the RGBA result is dropped.
    """
    import matplotlib.cm as cm
    mapper = cm.ScalarMappable(cmap=cmap)
    mapper.set_clim(vmin, vmax)
    return np.array(mapper.to_rgba(arr, alpha=1)[:, :, :3])
52 |
53 |
def addbox_with_diff(input, gt, pt, vmax=0.1, size=100,
                     color=(255, 0, 0),
                     thickness=1,
                     bbthickness=2,
                     sep=4):
    """Compose ``input`` with two side panels: a zoomed crop and its error map.

    The output places ``input`` on the left with the region at ``pt``
    (given as (x, y), ``size`` x ``size``) outlined. To the right, separated
    by ``sep`` pixels, are the enlarged crop (top) and the enlarged
    colour-mapped ``|input - gt|`` of the first channel (bottom).

    The panel size is ``(H - sep) // 2`` and both panels are written through
    slices of exactly that size, so odd heights and odd ``sep`` values work;
    for even ``H`` and ``sep`` the layout is unchanged.

    ``bbthickness`` is accepted for signature compatibility and is unused.
    """
    if len(input.shape) == 2:
        input = np.expand_dims(input, -1)
    if len(gt.shape) == 2:
        gt = np.expand_dims(gt, -1)

    H, W = input.shape[:2]
    C = 3

    # Two stacked square panels plus one separator must fit into H rows.
    bbsize = (H - sep) // 2
    left = W + sep

    out = np.zeros((H, left + bbsize, C))
    out[:, :W] = input

    # small box.
    pt1 = pt
    pt2 = (pt1[0] + size, pt1[1] + size)
    cv2.rectangle(out, pt1, pt2, color, thickness)

    # crop.
    crop_img = input[pt[1]:pt[1] + size, pt[0]:pt[0] + size, :]
    crop_img = cv2.resize(crop_img, (bbsize, bbsize))

    # diff
    diff = convert_color(np.abs(input - gt)[:, :, 0], vmin=0, vmax=vmax)

    crop_img_diff = diff[pt[1]:pt[1] + size, pt[0]:pt[0] + size, :]
    crop_img_diff = cv2.resize(crop_img_diff, (bbsize, bbsize))

    # cv2.resize drops a trailing singleton channel; restore it.
    if len(crop_img.shape) == 2:
        crop_img = np.expand_dims(crop_img, axis=-1)
    if len(crop_img_diff.shape) == 2:
        crop_img_diff = np.expand_dims(crop_img_diff, axis=-1)

    top = bbsize + sep
    out[:bbsize, left:left + bbsize, :] = crop_img
    out[top:top + bbsize, left:left + bbsize, :] = crop_img_diff

    return out
93 |
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # This file only contains a selection of the most common options. For a full
4 | # list see the documentation:
5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
6 |
7 | # -- Path setup --------------------------------------------------------------
8 |
9 | # If extensions (or modules to document with autodoc) are in another directory,
10 | # add these directories to sys.path here. If the directory is relative to the
11 | # documentation root, use os.path.abspath to make it absolute, like shown here.
12 | #
13 | # import os
14 | # import sys
15 | # sys.path.insert(0, os.path.abspath('.'))
16 |
17 |
18 | # -- Project information -----------------------------------------------------
19 |
# Project metadata; `project` is reused below as `html_title`.
project = 'HSIR'
copyright = '2022, BIT-ISP Lab'
author = 'Zeqiang-Lai, Miaoyu Li'

# The full version, including alpha/beta/rc tags
release = '0.0.1'
26 |
27 |
28 | # -- General configuration ---------------------------------------------------
29 |
30 | # Add any Sphinx extension module names here, as strings. They can be
31 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
32 | # ones.
extensions = [
    # Third-party extensions.
    'myst_parser',
    'sphinx_copybutton',
    # "sphinx_inline_tabs",
    # Extensions bundled with Sphinx.
    'sphinx.ext.autodoc',
    'sphinx.ext.autosummary',
    'sphinx.ext.doctest',
    'sphinx.ext.todo',
    'sphinx.ext.coverage',
    'sphinx.ext.mathjax',
    'sphinx.ext.viewcode',
    'sphinx.ext.napoleon',
]
46 |
47 | # Add any paths that contain templates here, relative to this directory.
48 | templates_path = ['_templates']
49 |
50 | # List of patterns, relative to source directory, that match files and
51 | # directories to ignore when looking for source files.
52 | # This pattern also affects html_static_path and html_extra_path.
53 | exclude_patterns = []
54 |
55 | # The name of the Pygments (syntax highlighting) style to use.
56 | pygments_style = 'sphinx'
57 |
58 | # -- Options for HTML output -------------------------------------------------
59 |
60 | # The theme to use for HTML and HTML Help pages. See the documentation for
61 | # a list of builtin themes.
62 | #
# Alternative themes kept for reference:
# html_theme = 'alabaster'
# html_theme = 'sphinx_rtd_theme'
html_theme = 'furo'

# Page title: reuse the project name defined above.
html_title = project
68 |
69 | # Add any paths that contain custom static files (such as style sheets) here,
70 | # relative to this directory. They are copied after the builtin static files,
71 | # so a file named "default.css" will overwrite the builtin "default.css".
72 | html_static_path = ['_static']
73 |
74 |
75 | #
76 | # -- Options for TODOs -------------------------------------------------------
77 | #
# Setting consumed by the 'sphinx.ext.todo' extension enabled above.
todo_include_todos = True

#
# -- Options for Markdown files ----------------------------------------------
#
# NOTE(review): `myst_admonition_enable` / `myst_deflist_enable` look like
# legacy MyST-Parser option names - confirm they are still honoured by the
# myst-parser version pinned for the docs build.
myst_admonition_enable = True
myst_deflist_enable = True
myst_heading_anchors = 3
86 |
--------------------------------------------------------------------------------
/hsirun/train_ssr.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 |
4 | from torch.utils.data import DataLoader
5 | from torch.optim.lr_scheduler import CosineAnnealingLR
6 | from torchlight.utils import instantiate
7 | from torchlight.nn.utils import get_learning_rate
8 |
9 | from hsir.data.ssr.dataset import TrainDataset, ValidDataset
10 | from hsir.trainer import Trainer
11 |
12 |
def train_cfg():
    """Parse the training command line and return the resulting namespace.

    After parsing, ``gpu_ids`` is converted from a comma-separated string to
    a list of ints and ``name`` falls back to ``arch`` when it was not given.
    """
    # (flags, keyword arguments) for every supported option, in CLI order.
    options = [
        (('--arch', '-a'), dict(required=True)),
        (('--name', '-n'), dict(
            type=str, default=None,
            help='name of the experiment, if not specified, arch will be used.')),
        (('--lr',), dict(type=float, default=4e-4)),
        (('--bs',), dict(type=int, default=20)),
        (('--epochs',), dict(type=int, default=5)),
        (('--schedule',), dict(type=str, default='hsir.schedule.denoise_default')),
        (('--resume', '-r'), dict(action='store_true')),
        (('--resume-path', '-rp'), dict(type=str, default=None)),
        (('--data-root',), dict(type=str, default='data/rgb2hsi')),
        (('--data-size',), dict(type=int, default=None)),
        (('--save-root',), dict(type=str, default='checkpoints/ssr')),
        (('--gpu-ids',), dict(type=str, default='0', help='gpu ids')),
    ]
    parser = argparse.ArgumentParser()
    for flags, spec in options:
        parser.add_argument(*flags, **spec)

    cfg = parser.parse_args()
    cfg.gpu_ids = [int(gpu) for gpu in cfg.gpu_ids.split(',')]
    if cfg.name is None:
        cfg.name = cfg.arch
    return cfg
32 |
33 |
def main():
    """Train the model selected on the command line and keep checkpoints.

    Every epoch writes ``model_latest.pth``, then validates; the best
    validation PSNR so far also writes ``model_best.pth``, and every 10th
    epoch writes an additional checkpoint via ``save_checkpoint()``.
    """
    cfg = train_cfg()

    trainer = Trainer(
        instantiate(cfg.arch),
        lr=cfg.lr,
        save_dir=os.path.join(cfg.save_root, cfg.name),
        gpu_ids=cfg.gpu_ids,
    )
    trainer.logger.print(cfg)
    if cfg.resume:
        trainer.load(cfg.resume_path)

    train_set = TrainDataset(cfg.data_root, size=cfg.data_size, stride=64)
    train_loader = DataLoader(train_set, batch_size=cfg.bs, shuffle=True,
                              num_workers=8, pin_memory=True)
    val_loader = DataLoader(ValidDataset(cfg.data_root), batch_size=1)

    # Main loop
    epoch_per_save = 10
    best_psnr = 0

    while trainer.epoch < cfg.epochs:
        current_lr = get_learning_rate(trainer.optimizer)
        trainer.logger.print('Epoch [{}] Use lr={}'.format(trainer.epoch, current_lr))

        trainer.train(train_loader)

        # save ckpt
        trainer.save_checkpoint('model_latest.pth')
        psnr = trainer.validate(val_loader, 'NITRE')['psnr']
        if psnr > best_psnr:
            best_psnr = psnr
            trainer.save_checkpoint('model_best.pth')
        if trainer.epoch % epoch_per_save == 0:
            trainer.save_checkpoint()
71 |
72 |
73 |
# Run the training loop only when executed as a script.
if __name__ == '__main__':
    main()
76 |
--------------------------------------------------------------------------------
/hsir/model/t3sc/utils/patches_handler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import logging
6 |
# Module-level logger; its level is forced to DEBUG here.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
9 |
10 |
class PatchesHandler(nn.Module):
    """Split images into strided patches and average them back together.

    ``forward(x, mode="extract")`` pads ``x`` so a whole number of patches
    fits, then unfolds its last two axes into patches.
    ``forward(p, mode="aggregate")`` folds patches back, divides by the
    per-pixel patch count and removes the padding.

    Args:
        size: Patch size, an int (square) or a pair (height, width).
        channels: Number of channels of the patches being aggregated.
        stride: Step between neighbouring patches, on both axes.
        padding: Padding mode handed to ``torch.nn.functional.pad``.
    """

    def __init__(self, size, channels, stride, padding="constant"):
        super().__init__()
        if isinstance(size, int):
            self.size = np.array([size, size])
        else:
            self.size = np.array(size)
        self.channels = channels
        self.n_elements = self.channels * self.size[0] * self.size[1]
        self.stride = stride

        self.padding = padding
        self.fold = None
        self.normalizer = None
        # Identifies the input the cached normalizer was computed for.
        self._normalizer_key = None
        self.img_size = None

    def forward(self, x, mode="extract"):
        """Dispatch to extraction or aggregation.

        Raises:
            ValueError: If ``mode`` is neither "extract" nor "aggregate".
        """
        if mode == "extract":
            x = self.pad(x)
            x = self.extract(x)
            return x
        elif mode == "aggregate":
            x = self.aggregate(x)
            x = self.unpad(x)
            return x
        else:
            raise ValueError(f"Mode {mode!r} not recognized")

    def set_img_size(self, img_size):
        """Recompute patch counts and paddings when the image size changes."""
        if np.any(self.img_size != np.array(img_size)):
            self.img_size = np.array(img_size)

            self.n_patches = 1 + np.ceil(
                np.maximum(self.img_size - self.size, 0) / self.stride
            ).astype(int)
            pads = (
                self.size + (self.n_patches - 1) * self.stride - self.img_size
            )
            # Plain Python ints, so torch receives ordinary integers.
            self.padded_size = tuple(int(v) for v in self.img_size + pads)
            # F.pad lists the last axis first: (left, right, top, bottom).
            _pads = []
            for i in reversed(range(2)):
                _pads += [0, int(pads[i])]
            self.pads = _pads

    def pad(self, x):
        """Pad the bottom/right of ``x`` (axes 2 and 3) to the padded size."""
        self.set_img_size(x.shape[2:])
        x = F.pad(x, pad=self.pads, mode=self.padding)
        return x

    def unpad(self, x):
        """Crop axes 2 and 3 back to the recorded image size."""
        x = x[:, :, : self.img_size[0], : self.img_size[1]]
        return x

    def extract(self, x):
        """Unfold axes 2 and 3 into patches.

        The two window axes are moved in front of the two patch-grid axes.
        """
        x = x.unfold(dimension=2, size=self.size[0], step=self.stride)
        x = x.unfold(dimension=3, size=self.size[1], step=self.stride)
        x = x.permute(0, 1, 4, 5, 2, 3)

        x = x.contiguous()

        return x

    def aggregate(self, x):
        """Fold patches back into an image, averaging overlapping pixels."""
        n_total = int(self.n_patches[0]) * int(self.n_patches[1])
        x = x.view(-1, int(self.n_elements), n_total)

        if self.fold is None or self.padded_size != self.fold.output_size:
            self.init_fold()

        out = self.fold(x)

        # The normalizer (patches covering each pixel) depends only on the
        # input layout, so it is reused until that layout changes. The old
        # test compared the folded normalizer's shape with the unfolded
        # input's shape, which never match, so it was rebuilt on every call.
        key = (tuple(x.shape), self.padded_size, x.dtype, x.device)
        if self.normalizer is None or self._normalizer_key != key:
            self.normalizer = self.fold(torch.ones_like(x))
            self._normalizer_key = key
        return (out / self.normalizer).squeeze(1)

    def init_fold(self):
        """(Re)create the ``nn.Fold`` module for the current padded size."""
        logger.debug(f"Initializing fold, padded shape: {self.padded_size}")
        self.fold = nn.Fold(
            output_size=self.padded_size,
            kernel_size=tuple(int(v) for v in self.size),
            stride=self.stride,
        )
93 |
--------------------------------------------------------------------------------
/hsir/model/memnet.py:
--------------------------------------------------------------------------------
1 | """MemNet"""
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from sync_batchnorm import SynchronizedBatchNorm2d
6 |
# Normalization layer used by BNReLUConv; switch the alias to use the
# synchronized variant imported above.
# BatchNorm2d = SynchronizedBatchNorm2d
BatchNorm2d = nn.BatchNorm2d
9 |
10 |
class MemNet(nn.Module):
    """Feature extractor, chain of memory blocks, reconstructor, global skip.

    Every memory block receives the running list of long-term memories (the
    extracted features plus the outputs of all earlier blocks).
    """

    def __init__(self, in_channels, channels, num_memblock, num_resblock):
        super(MemNet, self).__init__()
        self.feature_extractor = BNReLUConv(in_channels, channels)
        self.reconstructor = BNReLUConv(channels, in_channels)
        blocks = []
        for index in range(num_memblock):
            # Block ``index`` gates over ``index + 1`` long-term memories.
            blocks.append(MemoryBlock(channels, num_resblock, index + 1))
        self.dense_memory = nn.ModuleList(blocks)
        self.freeze_bn = True
        self.freeze_bn_affine = True

    def forward(self, x):
        features = self.feature_extractor(x)
        long_term = [features]
        out = features
        for block in self.dense_memory:
            # Each block appends its own output to ``long_term``.
            out = block(out, long_term)
        return self.reconstructor(out) + x
33 |
34 |
class MemoryBlock(nn.Module):
    """Recursive unit of residual blocks followed by a 1x1 gate.

    Note: ``num_memblock`` is the number of long-term memories this block
    receives (its 1-based position in the chain).
    """

    def __init__(self, channels, num_resblock, num_memblock):
        super(MemoryBlock, self).__init__()
        units = [ResidualBlock(channels) for _ in range(num_resblock)]
        self.recursive_unit = nn.ModuleList(units)
        gate_inputs = (num_resblock + num_memblock) * channels
        self.gate_unit = BNReLUConv(gate_inputs, channels, 1, 1, 0)

    def forward(self, x, ys):
        """Gate over short-term and long-term memories.

        ``ys`` is the list of long-term memories from earlier blocks; the
        gate output is appended to it in place and also returned.
        """
        short_term = []
        for unit in self.recursive_unit:
            x = unit(x)
            short_term.append(x)

        gate_out = self.gate_unit(torch.cat(short_term + ys, 1))
        ys.append(gate_out)
        return gate_out
57 |
58 |
class ResidualBlock(torch.nn.Module):
    """Residual block (https://arxiv.org/abs/1512.03385).

    Two ``BNReLUConv`` stages whose result is added to the block input.
    """

    def __init__(self, channels, k=3, s=1, p=1):
        super(ResidualBlock, self).__init__()
        self.relu_conv1 = BNReLUConv(channels, channels, k, s, p)
        self.relu_conv2 = BNReLUConv(channels, channels, k, s, p)

    def forward(self, x):
        branch = self.relu_conv2(self.relu_conv1(x))
        return branch + x
76 |
77 |
class BNReLUConv(nn.Sequential):
    """BatchNorm -> ReLU -> bias-free Conv2d, registered as 'bn', 'relu', 'conv'."""

    def __init__(self, in_channels, channels, k=3, s=1, p=1, inplace=True):
        super(BNReLUConv, self).__init__()
        stages = (
            ('bn', BatchNorm2d(in_channels)),
            ('relu', nn.ReLU(inplace=inplace)),
            ('conv', nn.Conv2d(in_channels, channels, k, s, p, bias=False)),
        )
        for stage_name, stage in stages:
            self.add_module(stage_name, stage)
84 |
85 |
def memnet():
    """Build ``MemNet(31, 64, 6, 6)`` tagged with the flags the framework reads.

    The returned model has ``use_2dconv = True`` and ``bandwise = False``.
    """
    model = MemNet(31, 64, 6, 6)
    model.use_2dconv, model.bandwise = True, False
    return model
--------------------------------------------------------------------------------
/hsiboard/util.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from collections import OrderedDict
4 | from os.path import exists, join
5 |
6 | import imageio
7 | import numpy as np
8 | import streamlit as st
9 | import cv2
10 |
# Default page colours: the defaults of set_page_container_style's
# `background_color` and `color` parameters.
BACKGROUND_COLOR = 'white'
COLOR = 'black'
13 |
14 |
def set_page_container_style(
        max_width: int = 1100, max_width_100_percent: bool = True,
        padding_top: int = 2, padding_right: int = 2, padding_left: int = 2, padding_bottom: int = 10,
        color: str = COLOR, background_color: str = BACKGROUND_COLOR,
) -> None:
    """Emit a markdown block through ``st.markdown`` with HTML allowed.

    ``max_width_str`` is set to ``'max-width: 100%;'`` when
    ``max_width_100_percent`` is true and to ``'max-width: {max_width}px;'``
    otherwise.

    NOTE(review): in the text visible here the template passed to
    ``st.markdown`` is blank, so ``max_width_str``, the padding values and the
    colors are never referenced. The template body presumably consumes them
    and appears to be missing from this listing -- confirm against the
    original file.
    """
    if max_width_100_percent:
        max_width_str = f'max-width: 100%;'
    else:
        max_width_str = f'max-width: {max_width}px;'
    st.markdown(
        f'''

        ''',
        unsafe_allow_html=True,
    )
45 |
46 |
def listdir(path, exclude=()):
    """Return the names of the sub-directories of ``path``.

    Args:
        path: directory to scan.
        exclude: collection of directory names to leave out. The default is
            an immutable empty tuple so that no mutable object is shared
            between calls.

    Returns:
        A list with the names (not the full paths) of every entry of
        ``path`` that is a directory and is not in ``exclude``, in the order
        reported by ``os.listdir``.
    """
    return [
        entry for entry in os.listdir(path)
        if os.path.isdir(join(path, entry)) and entry not in exclude
    ]
53 |
54 |
def load_method_stat(logdir):
    """Collect the averaged metrics of every run found under ``logdir``.

    Each sub-directory of ``logdir`` that contains a ``log.json`` yields one
    record; sub-directories without that file are skipped.

    Returns:
        A list of ``OrderedDict`` records. Each holds the sub-directory name
        under ``'Name'`` followed by the entries of the ``'avg'`` mapping
        stored in that directory's ``log.json``.
    """
    stat = []
    for folder in listdir(logdir):
        path = join(logdir, folder, 'log.json')
        if not exists(path):
            continue
        # context manager so the handle is closed as soon as parsing is done
        with open(path, 'r') as f:
            data = json.load(f)
        record = OrderedDict()
        record['Name'] = folder
        record.update(data['avg'])
        stat.append(record)
    return stat
66 |
67 |
def load_stat(logdir):
    """Load the method statistics of every sub-directory of ``logdir``.

    The ``gt`` directory is excluded. Loading is best effort: a directory
    whose statistics cannot be read is reported on stdout and left out of
    the result.

    Returns:
        A dict mapping sub-directory name to the list returned by
        ``load_method_stat`` for it.
    """
    total_stat = {}
    for folder in listdir(logdir, exclude=['gt']):
        try:
            total_stat[folder] = load_method_stat(join(logdir, folder))
        except Exception:
            # only ordinary errors are ignored; KeyboardInterrupt and
            # SystemExit must still be able to stop the program
            print('error loading', folder, 'ignored')
    return total_stat
77 |
def load_per_image_stat(logdir):
    """Return the ``'detail'`` entry of ``<logdir>/log.json``.

    Raises:
        OSError: if the file cannot be opened.
        KeyError: if the parsed JSON has no ``'detail'`` key.
    """
    path = join(logdir, 'log.json')
    # context manager so the handle is closed as soon as parsing is done
    with open(path, 'r') as f:
        data = json.load(f)
    return data['detail']
82 |
def load_imgs(path):
    """Read every file of ``path`` whose name ends in ``png``.

    Returns:
        A dict mapping file name to the image read by ``imageio.imread``,
        converted to a NumPy array and clipped to ``[0, 255]``.
    """
    return {
        fname: np.array(imageio.imread(join(path, fname))).clip(0, 255)
        for fname in os.listdir(path)
        if fname.endswith('png')
    }
91 |
92 |
def make_grid(rows, cols):
    """Create ``rows`` Streamlit containers, each split by ``st.columns(cols)``.

    Returns:
        A list with one ``st.columns(cols)`` result per row.
    """
    grid = []
    for _ in range(rows):
        with st.container():
            grid.append(st.columns(cols))
    return grid
99 |
def encode_image(img):
    """Encode an image array as PNG bytes with ``cv2.imencode``.

    The array is clipped to ``[0, 1]``, scaled by 255 and cast to ``uint8``.
    Arrays with three dimensions have their last axis reversed before
    encoding (presumably RGB -> BGR for OpenCV -- TODO confirm).
    """
    pixels = np.uint8(img.clip(0, 1) * 255)
    if len(pixels.shape) == 3:
        pixels = pixels[:, :, ::-1]
    encoded = cv2.imencode('.png', pixels)[1]
    return encoded.tobytes()
--------------------------------------------------------------------------------
/hsir/model/qrnn3d/conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from sync_batchnorm import SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
4 |
5 | BatchNorm3d = SynchronizedBatchNorm3d
6 |
7 |
class BNReLUConv3d(nn.Sequential):
    """Sequential ``bn`` -> ``relu`` -> ``conv`` (bias-free ``nn.Conv3d``) block."""

    def __init__(self, in_channels, channels, k=3, s=1, p=1, inplace=False):
        super().__init__()
        # attribute assignment registers the sub-modules under these names,
        # in this order
        self.bn = BatchNorm3d(in_channels)
        self.relu = nn.ReLU(inplace=inplace)
        self.conv = nn.Conv3d(in_channels, channels, k, s, p, bias=False)
14 |
15 |
class BNReLUDeConv3d(nn.Sequential):
    """Sequential ``bn`` -> ``relu`` -> ``deconv`` (bias-free ``nn.ConvTranspose3d``) block."""

    def __init__(self, in_channels, channels, k=3, s=1, p=1, inplace=False):
        super().__init__()
        # attribute assignment registers the sub-modules under these names,
        # in this order
        self.bn = BatchNorm3d(in_channels)
        self.relu = nn.ReLU(inplace=inplace)
        self.deconv = nn.ConvTranspose3d(in_channels, channels, k, s, p, bias=False)
22 |
23 |
class BNReLUUpsampleConv3d(nn.Sequential):
    """Sequential ``bn`` -> ``relu`` -> ``upsample_conv`` (bias-free ``UpsampleConv3d``) block.

    ``upsample`` is forwarded to ``UpsampleConv3d`` and defaults to ``(1, 2, 2)``.
    """

    def __init__(self, in_channels, channels, k=3, s=1, p=1, upsample=(1,2,2), inplace=False):
        super().__init__()
        # attribute assignment registers the sub-modules under these names,
        # in this order
        self.bn = BatchNorm3d(in_channels)
        self.relu = nn.ReLU(inplace=inplace)
        self.upsample_conv = UpsampleConv3d(
            in_channels, channels, k, s, p, bias=False, upsample=upsample)
30 |
31 |
class UpsampleConv3d(torch.nn.Module):
    """Optional trilinear upsampling followed by a 3D convolution.

    Upsampling first and convolving afterwards is used here in place of a
    transposed convolution.
    ref: http://distill.pub/2016/deconv-checkerboard/
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=True, upsample=None):
        super().__init__()
        self.upsample = upsample
        # the interpolation layer only exists when a scale factor is given
        if upsample:
            self.upsample_layer = torch.nn.Upsample(
                scale_factor=upsample, mode='trilinear', align_corners=True)

        self.conv3d = torch.nn.Conv3d(
            in_channels, out_channels, kernel_size, stride, padding, bias=bias)

    def forward(self, x):
        if self.upsample:
            x = self.upsample_layer(x)
        return self.conv3d(x)
53 |
54 |
class BasicConv3d(nn.Sequential):
    """``nn.Conv3d`` registered as ``conv``, preceded by ``bn`` when ``bn`` is true."""

    def __init__(self, in_channels, channels, k=3, s=1, p=1, bias=False, bn=True):
        super().__init__()
        # attribute assignment registers the sub-modules in this order
        if bn:
            self.bn = BatchNorm3d(in_channels)
        self.conv = nn.Conv3d(in_channels, channels, k, s, p, bias=bias)
61 |
62 |
class BasicDeConv3d(nn.Sequential):
    """``nn.ConvTranspose3d`` registered as ``deconv``, preceded by ``bn`` when ``bn`` is true."""

    def __init__(self, in_channels, channels, k=3, s=1, p=1, bias=False, bn=True):
        super().__init__()
        # attribute assignment registers the sub-modules in this order
        if bn:
            self.bn = BatchNorm3d(in_channels)
        self.deconv = nn.ConvTranspose3d(in_channels, channels, k, s, p, bias=bias)
69 |
70 |
class BasicUpsampleConv3d(nn.Sequential):
    """Bias-free ``UpsampleConv3d`` registered as ``upsample_conv``, preceded by ``bn`` when ``bn`` is true."""

    def __init__(self, in_channels, channels, k=3, s=1, p=1, upsample=(1,2,2), bn=True):
        super().__init__()
        # attribute assignment registers the sub-modules in this order
        if bn:
            self.bn = BatchNorm3d(in_channels)
        self.upsample_conv = UpsampleConv3d(
            in_channels, channels, k, s, p, bias=False, upsample=upsample)
77 |
--------------------------------------------------------------------------------
/hsir/data/dataloader.py:
--------------------------------------------------------------------------------
1 | import torchvision.transforms as T
2 | import torch.utils.data as D
3 | import hsir.data.transform.noise as N
4 | import hsir.data.transform.general as G
5 | from hsir.data import HSITestDataset, HSITrainDataset
6 | from hsir.data.utils import worker_init_fn
7 |
8 |
def gaussian_loader_train_s1(root, use_chw=False):
    """Training loader whose inputs go through ``N.AddNoise(50)``.

    Args:
        root: location handed to ``HSITrainDataset``.
        use_chw: forwarded to ``G.HSI2Tensor`` for both input and target.

    Returns:
        A shuffled ``DataLoader`` with batch size 16 and 8 workers.
    """
    shared_tf = G.Identity()
    noisy_tf = T.Compose([
        N.AddNoise(50),
        G.HSI2Tensor(use_chw=use_chw),
    ])
    clean_tf = G.HSI2Tensor(use_chw=use_chw)
    train_set = HSITrainDataset(root, noisy_tf, clean_tf, shared_tf)
    return D.DataLoader(
        train_set,
        batch_size=16,
        shuffle=True,
        num_workers=8,
        pin_memory=True,
        worker_init_fn=worker_init_fn,
    )
20 |
21 |
def gaussian_loader_train_s2(root, use_chw=False):
    """Training loader with ``G.RandomCrop(32, 32)`` and ``N.AddNoiseBlind``.

    Args:
        root: location handed to ``HSITrainDataset``.
        use_chw: forwarded to ``G.HSI2Tensor`` for both input and target.

    Returns:
        A shuffled ``DataLoader`` with batch size 64 and 8 workers.
    """
    shared_tf = G.RandomCrop(32, 32)
    noisy_tf = T.Compose([
        N.AddNoiseBlind([10, 30, 50, 70]),
        G.HSI2Tensor(use_chw=use_chw),
    ])
    clean_tf = G.HSI2Tensor(use_chw=use_chw)
    train_set = HSITrainDataset(root, noisy_tf, clean_tf, shared_tf)
    return D.DataLoader(
        train_set,
        batch_size=64,
        shuffle=True,
        num_workers=8,
        pin_memory=True,
        worker_init_fn=worker_init_fn,
    )
33 |
34 |
def gaussian_loader_train_s2_16(root, use_chw=False):
    """Training loader with ``N.AddNoiseBlind`` inputs and no cropping.

    Args:
        root: location handed to ``HSITrainDataset``.
        use_chw: forwarded to ``G.HSI2Tensor`` for both input and target.

    Returns:
        A shuffled ``DataLoader`` with batch size 16 and 8 workers.
    """
    shared_tf = G.Identity()
    noisy_tf = T.Compose([
        N.AddNoiseBlind([10, 30, 50, 70]),
        G.HSI2Tensor(use_chw=use_chw),
    ])
    clean_tf = G.HSI2Tensor(use_chw=use_chw)
    train_set = HSITrainDataset(root, noisy_tf, clean_tf, shared_tf)
    return D.DataLoader(
        train_set,
        batch_size=16,
        shuffle=True,
        num_workers=8,
        pin_memory=True,
        worker_init_fn=worker_init_fn,
    )
46 |
47 |
def gaussian_loader_val(root, use_chw=False):
    """Validation loader over ``HSITestDataset(root, size=5)``.

    Returns:
        An unshuffled ``DataLoader`` with batch size 1 and one worker.
    """
    val_set = HSITestDataset(root, size=5, use_chw=use_chw)
    return D.DataLoader(val_set, batch_size=1, shuffle=False, num_workers=1)
52 |
53 |
def gaussian_loader_test(root, use_chw=False):
    """Test loader over ``HSITestDataset(root, return_name=True)``.

    Returns:
        An unshuffled ``DataLoader`` with batch size 1 and one worker.
    """
    test_set = HSITestDataset(root, return_name=True, use_chw=use_chw)
    return D.DataLoader(test_set, batch_size=1, shuffle=False, num_workers=1)
58 |
59 |
def complex_loader_train(root, use_chw=False):
    """Training loader for the mixed-noise setting.

    Inputs go through ``N.AddNoiseNoniid`` and then ``G.SequentialSelect``
    over four options: a pass-through, ``N.AddNoiseImpulse``,
    ``N.AddNoiseStripe`` and ``N.AddNoiseDeadline``.

    Args:
        root: location handed to ``HSITrainDataset``.
        use_chw: forwarded to ``G.HSI2Tensor`` for both input and target.

    Returns:
        A shuffled ``DataLoader`` with batch size 16 and 8 workers.
    """
    sigmas = [10, 30, 50, 70]

    shared_tf = G.Identity()
    extra_noise = G.SequentialSelect(
        transforms=[
            lambda x: x,
            N.AddNoiseImpulse(),
            N.AddNoiseStripe(),
            N.AddNoiseDeadline(),
        ]
    )
    noisy_tf = T.Compose([
        N.AddNoiseNoniid(sigmas),
        extra_noise,
        G.HSI2Tensor(use_chw=use_chw),
    ])
    clean_tf = G.HSI2Tensor(use_chw=use_chw)
    train_set = HSITrainDataset(root, noisy_tf, clean_tf, shared_tf)
    return D.DataLoader(
        train_set,
        batch_size=16,
        shuffle=True,
        num_workers=8,
        pin_memory=True,
        worker_init_fn=worker_init_fn,
    )
81 |
82 |
def complex_loader_val(root, use_chw=False):
    """Validation loader over ``HSITestDataset(root, size=5)``.

    Returns:
        An unshuffled ``DataLoader`` with batch size 1 and one worker.
    """
    val_set = HSITestDataset(root, use_chw=use_chw, size=5)
    return D.DataLoader(val_set, batch_size=1, shuffle=False, num_workers=1)
87 |
--------------------------------------------------------------------------------
/hsir/data/transform/sr.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import scipy.ndimage
4 |
5 | # C,H,W format
6 |
7 |
8 | """For Super-Resolution"""
9 |
10 |
class SRDegrade(object):
    """Shrink and then re-enlarge the last two axes with ``scipy.ndimage.zoom``.

    The first axis is always zoomed by a factor of 1.
    """

    def __init__(self, scale_factor=2):
        self.scale_factor = scale_factor

    def __call__(self, img):
        from scipy.ndimage import zoom
        sf = self.scale_factor
        shrunk = zoom(img, zoom=(1, 1. / sf, 1. / sf))
        return zoom(shrunk, zoom=(1, sf, sf))
20 |
21 |
class GaussianBlurScipy(object):
    """Gaussian blur through ``scipy.ndimage.gaussian_filter``.

    Args:
        ksize: kernel size used to derive the ``truncate`` argument as
            ``(((ksize - 1) / 2) - 0.5) / sigma``.
        sigma: standard deviation handed to ``gaussian_filter``.

    NOTE(review): ``sigma`` is a scalar, so the filter is applied along every
    axis of ``img``, including the first one -- confirm this is intended for
    C,H,W inputs.
    """

    def __init__(self, ksize=8, sigma=3):
        self.sigma = sigma
        self.truncate = (((ksize - 1) / 2) - 0.5) / sigma

    def __call__(self, img):
        # import from scipy.ndimage directly: the scipy.ndimage.filters
        # namespace is deprecated
        from scipy.ndimage import gaussian_filter
        return gaussian_filter(img, sigma=self.sigma, truncate=self.truncate)
31 |
32 |
33 | # ---------------------- Blur ----------------------- #
34 |
def fspecial_gaussian(hsize, sigma):
    """Build a normalized square Gaussian kernel.

    Args:
        hsize: side length of the kernel.
        sigma: standard deviation of the Gaussian.

    Returns:
        An ``hsize`` x ``hsize`` array sampled on a grid centred on zero.
        Entries below ``eps * max`` are zeroed, and the kernel is divided by
        its sum whenever that sum is non-zero.
    """
    half = (hsize - 1.0) / 2.0
    offsets = np.arange(-half, half + 1)
    x, y = np.meshgrid(offsets, offsets)
    h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
    # np.finfo instead of the deprecated scipy.finfo alias
    h[h < np.finfo(float).eps * h.max()] = 0
    total = h.sum()
    if total != 0:
        h = h / total
    return h
48 |
49 |
class AbstractBlur:
    """Base class for blurs that convolve an image with ``self.kernel``.

    Subclasses are expected to set ``self.kernel`` to a 2-D array. It is
    expanded with a leading axis of length 1 before convolving, and the
    convolution uses ``mode='wrap'``.
    """

    def __call__(self, img):
        # scipy.ndimage.convolve is the supported entry point; the
        # scipy.ndimage.filters namespace is deprecated
        return scipy.ndimage.convolve(
            img, np.expand_dims(self.kernel, axis=0), mode='wrap')
55 |
56 |
class GaussianBlur(AbstractBlur):
    """``AbstractBlur`` whose kernel comes from ``fspecial_gaussian(ksize, sigma)``."""

    def __init__(self, ksize=8, sigma=3):
        kernel = fspecial_gaussian(ksize, sigma)
        self.kernel = kernel
60 |
61 |
class UniformBlur(AbstractBlur):
    """``AbstractBlur`` with a ``ksize`` x ``ksize`` kernel of equal weights ``1 / ksize**2``."""

    def __init__(self, ksize):
        area = ksize * ksize
        self.kernel = np.ones((ksize, ksize)) / area
65 |
66 |
67 | ## -------------------- Resize -------------------- ##
68 |
class KFoldDownsample:
    ''' k-fold downsampler:
        Keeps the upper-left pixel of each distinct sf x sf patch and discards the others
    '''

    def __init__(self, sf):
        self.sf = sf

    def __call__(self, img):
        # every sf-th element along the last two axes, starting at index 0
        stride = slice(0, None, self.sf)
        return img[:, stride, stride]
80 |
81 |
class Resize:
    """Resize an image by the factor ``sf`` with ``cv2.resize``.

    The input is transposed with ``(1, 2, 0)`` before resizing and with
    ``(2, 0, 1)`` afterwards. ``mode`` selects the interpolation through
    ``mode_map``.
    """

    # interpolation name -> OpenCV flag
    mode_map = {
        'nearest': cv2.INTER_NEAREST,
        'linear': cv2.INTER_LINEAR,
        'cubic': cv2.INTER_CUBIC,
        'area': cv2.INTER_AREA
    }

    def __init__(self, sf, mode='cubic'):
        self.sf = sf
        self.mode = self.mode_map[mode]

    def __call__(self, img):
        hwc = img.transpose(1, 2, 0)
        rows, cols = hwc.shape[0], hwc.shape[1]
        # the target size is built from shape[1] first, then shape[0]
        dsize = (int(cols * self.sf), int(rows * self.sf))
        resized = cv2.resize(hwc, dsize, interpolation=self.mode)
        return resized.transpose(2, 0, 1)
99 |
100 |
class BicubicDownsample(Resize):
    """``Resize`` with factor ``1 / sf`` and the ``'cubic'`` mode."""

    def __init__(self, sf):
        super().__init__(1 / sf, mode='cubic')

    # __call__ is inherited unchanged from Resize
107 |
108 |
class BicubicUpsample(Resize):
    """``Resize`` with factor ``sf`` and the ``'cubic'`` mode."""

    def __init__(self, sf):
        super().__init__(sf, mode='cubic')

    # __call__ is inherited unchanged from Resize
115 |
--------------------------------------------------------------------------------
/hsir/schedule.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 |
3 |
# Training schedule record.
#   base_lr      -- learning rate the trainer is created with
#   stage1_epoch -- presumably the epoch at which the first training stage ends (TODO confirm)
#   lr_schedule  -- dict keyed by epoch index with a learning rate as value;
#                   presumably the rate to apply from that epoch on (TODO confirm)
#   max_epochs   -- training stops once the epoch counter reaches this value
TrainSchedule = namedtuple('TrainSchedule', ['base_lr', 'stage1_epoch', 'lr_schedule', 'max_epochs'])


# 80-epoch schedule; the rate is raised back to 1e-3 at epoch 30 (== stage1_epoch).
denoise_default = TrainSchedule(
    max_epochs=80,
    stage1_epoch=30,
    base_lr=1e-3,
    lr_schedule={
        0: 1e-3,
        20: 1e-4,
        30: 1e-3,
        45: 1e-4,
        55: 5e-5,
        60: 1e-5,
        65: 5e-6,
        75: 1e-6,
    },
)

# Short 30-epoch schedule.
# NOTE(review): the entry at epoch 30 equals max_epochs, so it looks unreachable -- confirm.
denoise_1x = TrainSchedule(
    max_epochs=30,
    stage1_epoch=15,
    base_lr=1e-3,
    lr_schedule={
        0: 1e-3,
        5: 1e-4,
        10: 1e-3,
        15: 1e-4,
        20: 5e-5,
        25: 1e-5,
        30: 5e-6,
    },
)

# 80-epoch schedule starting from 2e-4.
denoise_unet = TrainSchedule(
    max_epochs=80,
    stage1_epoch=30,
    base_lr=2e-4,
    lr_schedule={
        0: 2e-4,
        20: 1e-4,
        30: 2e-4,
        40: 1e-4,
        50: 5e-5,
        60: 1e-5,
        70: 5e-6,
    },
)

# 80-epoch schedule; no entry before epoch 45, so base_lr presumably applies until then.
denoise_grunet = TrainSchedule(
    max_epochs=80,
    stage1_epoch=30,
    base_lr=1e-4,
    lr_schedule={
        45: 5e-5,
        60: 1e-5,
        70: 5e-6,
    },
)

# 80-epoch schedule starting from 1e-4.
denoise_restormer = TrainSchedule(
    max_epochs=80,
    stage1_epoch=30,
    base_lr=1e-4,
    lr_schedule={
        0: 1e-4,
        45: 5e-5,
        55: 1e-5,
        65: 5e-6,
        75: 1e-6,
    },
)

# Alias: the same schedule object as denoise_restormer.
denoise_hsid_cnn = denoise_restormer

# 80-epoch schedule; the rate is raised back to 1e-4 at epoch 30.
denoise_swinir = TrainSchedule(
    max_epochs=80,
    stage1_epoch=30,
    base_lr=2e-4,
    lr_schedule={
        0: 2e-4,
        17: 1e-4,
        22: 5e-5,
        30: 1e-4,
        50: 1e-5,
        60: 5e-6,
        70: 1e-6,
    },
)

# 80-epoch schedule starting from 2e-4.
denoise_uformer = TrainSchedule(
    max_epochs=80,
    stage1_epoch=30,
    base_lr=2e-4,
    lr_schedule={
        0: 2e-4,
        17: 1e-4,
        40: 5e-5,
        50: 1e-5,
        60: 5e-6,
        70: 1e-6,
    },
)


# The denoise_complex_* schedules run to epoch 110 and only list entries from
# epoch 80 on (presumably a continuation of an 80-epoch run -- TODO confirm).
denoise_complex_uformer = TrainSchedule(
    max_epochs=110,
    stage1_epoch=30,
    base_lr=1e-4,
    lr_schedule={
        80: 1e-4,
        85: 5e-5,
        90: 1e-5,
        95: 5e-6,
        100: 1e-6
    },
)


# Same values as denoise_complex_uformer.
denoise_complex_swinir = TrainSchedule(
    max_epochs=110,
    stage1_epoch=30,
    base_lr=1e-4,
    lr_schedule={
        80: 1e-4,
        85: 5e-5,
        90: 1e-5,
        95: 5e-6,
        100: 1e-6
    },
)


# Epoch 80-110 schedule starting from 1e-3.
denoise_complex_default = TrainSchedule(
    max_epochs=110,
    stage1_epoch=30,
    base_lr=1e-3,
    lr_schedule={
        80: 1e-3,
        90: 5e-4,
        95: 1e-4,
        100: 5e-5,
        105: 1e-5,
    },
)

# Epoch 80-110 schedule with an early drop to 5e-5 at epoch 82.
denoise_complex_restormer = TrainSchedule(
    max_epochs=110,
    stage1_epoch=30,
    base_lr=1e-4,
    lr_schedule={
        80: 1e-4,
        82: 5e-5,
        90: 1e-5,
        95: 5e-6,
        100: 1e-6
    },
)

# Epoch 80-110 schedule that keeps 1e-4 until epoch 95.
denoise_complex_hsid_cnn = TrainSchedule(
    max_epochs=110,
    stage1_epoch=30,
    base_lr=1e-4,
    lr_schedule={
        80: 1e-4,
        95: 5e-5,
        100: 1e-5,
        105: 1e-6,
    },
)
174 |
--------------------------------------------------------------------------------
/hsirun/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | from os.path import join
4 |
5 | from torchlight.utils import instantiate, locate
6 | from torchlight.nn.utils import adjust_learning_rate, get_learning_rate
7 |
8 | import hsir.data.dataloader as loaders
9 | from hsir.trainer import Trainer
10 | from hsir.scheduler import MultiStepSetLR
11 |
12 |
def train_cfg():
    """Parse the training command line.

    Returns:
        The ``argparse`` namespace, post-processed so that ``gpu_ids`` is a
        list of ints (split on commas) and ``name`` falls back to ``arch``
        when it was not given.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--arch', '-a', required=True)
    ap.add_argument('--noise', default='gaussian', choices=['gaussian', 'complex'])
    ap.add_argument('--name', '-n', type=str, default=None,
                    help='name of the experiment, if not specified, arch will be used.')
    ap.add_argument('--lr', type=float, default=None)
    ap.add_argument('-s', '--schedule', type=str, default='hsir.schedule.denoise_default')
    ap.add_argument('--resume', '-r', type=str, default=None)
    ap.add_argument('--bandwise', action='store_true')
    ap.add_argument('--use-conv2d', action='store_true')
    ap.add_argument('--train-root', type=str, default='data/ICVL64_31_100.db')
    ap.add_argument('--test-root', type=str, default='data')
    ap.add_argument('--save-root', type=str, default='checkpoints')
    ap.add_argument('--save-freq', type=int, default=10)
    ap.add_argument('--gpu-ids', type=str, default='0', help='gpu ids')

    args = ap.parse_args()
    args.gpu_ids = [int(gpu) for gpu in args.gpu_ids.split(',')]
    if args.name is None:
        args.name = args.arch
    return args
33 |
34 |
def main():
    """Train the network selected on the command line.

    Builds the network and the ``Trainer``, prepares the loaders for the
    chosen noise setting, then trains until ``schedule.max_epochs``. After
    every epoch the model is validated, the best-PSNR checkpoint and the
    latest checkpoint are written, and a numbered checkpoint is written every
    ``cfg.save_freq`` epochs.
    """
    cfg = train_cfg()
    net = instantiate(cfg.arch)
    schedule = locate(cfg.schedule)
    trainer = Trainer(
        net,
        lr=schedule.base_lr,
        save_dir=join(cfg.save_root, cfg.name),
        gpu_ids=cfg.gpu_ids,
        bandwise=cfg.bandwise,
    )
    trainer.logger.print(cfg)
    if cfg.resume:
        trainer.load(cfg.resume)

    # prepare dataset
    if cfg.noise == 'gaussian':
        train_loader1 = loaders.gaussian_loader_train_s1(cfg.train_root, cfg.use_conv2d)
        train_loader2 = loaders.gaussian_loader_train_s2_16(cfg.train_root, cfg.use_conv2d)
        val_name = 'icvl_512_50'
        val_loader = loaders.gaussian_loader_val(join(cfg.test_root, val_name), cfg.use_conv2d)
    else:
        train_loader = loaders.complex_loader_train(cfg.train_root, cfg.use_conv2d)
        val_name = 'icvl_512_mixture'
        val_loader = loaders.complex_loader_val(join(cfg.test_root, val_name), cfg.use_conv2d)

    # Main loop
    if cfg.lr:
        adjust_learning_rate(trainer.optimizer, cfg.lr)  # override lr
    lr_scheduler = MultiStepSetLR(trainer.optimizer, schedule.lr_schedule, epoch=trainer.epoch)
    epoch_per_save = cfg.save_freq
    best_psnr = 0
    # The gaussian setting switches loaders where the schedule says stage 1
    # ends. A hard-coded 30 would never reach stage 2 with a schedule such as
    # max_epochs=30 / stage1_epoch=15.
    stage1_epoch = schedule.stage1_epoch
    while trainer.epoch < schedule.max_epochs:
        np.random.seed()  # reset seed per epoch, otherwise the noise will be added with a specific pattern
        trainer.logger.print('Epoch [{}] Use lr={}'.format(trainer.epoch, get_learning_rate(trainer.optimizer)))

        # train
        if cfg.noise == 'gaussian':
            if trainer.epoch == stage1_epoch:
                # a new stage starts: the best-model tracking starts over
                best_psnr = 0
            if trainer.epoch < stage1_epoch:
                trainer.train(train_loader1)
            else:
                trainer.train(train_loader2, warm_up=trainer.epoch == stage1_epoch)
        else:
            # NOTE(review): 80 matches the first lr_schedule entry of the
            # denoise_complex_* schedules; kept as is.
            trainer.train(train_loader, warm_up=trainer.epoch == 80)

        # save ckpt
        metrics = trainer.validate(val_loader, val_name)
        if metrics['psnr'] > best_psnr:
            best_psnr = metrics['psnr']
            trainer.save_checkpoint('model_best.pth')
        if trainer.epoch % epoch_per_save == 0:
            trainer.save_checkpoint()
        trainer.save_checkpoint('model_latest.pth')

        lr_scheduler.step()
87 |
88 |
89 | if __name__ == '__main__':
90 | main()
91 |
--------------------------------------------------------------------------------
/hsir/model/hsidcnn.py:
--------------------------------------------------------------------------------
1 | """
2 | Ported from official caffe implementation
3 | https://github.com/qzhang95/HSID-CNN
4 | """
5 |
6 | import torch
7 | import torch.nn as nn
8 |
9 |
def hsid_cnn():
    """Return an ``HSIDCNN`` built with its default arguments."""
    model = HSIDCNN()
    return model
12 |
13 |
class HSIDCNN(nn.Module):
    """Band-by-band denoiser that pairs each band with ``2 * num_adj_bands`` other bands.

    ``forward`` expects the bands on dim 1. Every band is processed on its
    own by ``_forward`` and the results are concatenated on dim 1 again.
    """

    def __init__(self, num_adj_bands=12):
        super().__init__()
        self.num_adj_bands = num_adj_bands
        adj_channels = num_adj_bands * 2

        # multi-scale heads over the stack of adjacent bands
        self.conv_k3 = nn.Conv2d(adj_channels, 20, 3, 1, 1)
        self.conv_k5 = nn.Conv2d(adj_channels, 20, 5, 1, 2)
        self.conv_k7 = nn.Conv2d(adj_channels, 20, 7, 1, 3)

        # multi-scale heads over the single band
        self.conv_k3_2 = nn.Conv2d(1, 20, 3, 1, 1)
        self.conv_k5_2 = nn.Conv2d(1, 20, 5, 1, 2)
        self.conv_k7_2 = nn.Conv2d(1, 20, 7, 1, 3)

        # trunk: conv1 maps the 120 head channels to 60, conv2..conv9 keep 60
        self.conv1 = nn.Conv2d(120, 60, 3, 1, 1)
        for idx in range(2, 10):
            setattr(self, f'conv{idx}', nn.Conv2d(60, 60, 3, 1, 1))

        # tail over the four concatenated trunk taps
        self.tail = nn.Sequential(
            nn.Conv2d(60 * 4, 15, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(15, 1, 3, 1, 1),
        )

        self._reset_parameters()

    def _reset_parameters(self):
        """Kaiming-normal (fan_out) weights and zero biases for every ``nn.Conv2d``."""
        convs = (m for m in self.modules() if isinstance(m, nn.Conv2d))
        for conv in convs:
            torch.nn.init.kaiming_normal_(conv.weight, mode='fan_out')
            torch.nn.init.constant_(conv.bias, 0)

    def forward(self, x):
        """Denoise every band of ``x`` and concatenate the results on dim 1."""
        k = self.num_adj_bands
        num_bands = x.shape[1]

        denoised = []
        for i in range(num_bands):
            band = x[:, i:i + 1, :, :]
            if i < k:
                # leading bands share the first 2k bands as context
                adj = x[:, :k * 2, :, :]
            elif i < num_bands - k:
                # interior bands use k bands on either side, excluding band i
                adj = torch.cat([x[:, i - k:i, :, :],
                                 x[:, i + 1:i + 1 + k, :, :]], dim=1)
            else:
                # trailing bands share the last 2k bands as context
                adj = x[:, -k * 2:, :, :]
            denoised.append(self._forward(band, adj))
        return torch.cat(denoised, dim=1)

    def _forward(self, x, adj_x):
        """Process one band ``x`` together with its adjacent-band stack ``adj_x``."""
        adj_feat = torch.cat(
            [self.conv_k3(adj_x), self.conv_k5(adj_x), self.conv_k7(adj_x)], dim=1).relu()
        band_feat = torch.cat(
            [self.conv_k3_2(x), self.conv_k5_2(x), self.conv_k7_2(x)], dim=1).relu()
        feat = torch.cat([adj_feat, band_feat], dim=1)

        feat = self.conv1(feat).relu()
        feat = self.conv2(feat).relu()
        # the outputs of conv3, conv5, conv7 and conv9 are kept for the tail
        tap3 = self.conv3(feat).relu()
        tap5 = self.conv5(self.conv4(tap3).relu()).relu()
        tap7 = self.conv7(self.conv6(tap5).relu()).relu()
        tap9 = self.conv9(self.conv8(tap7).relu()).relu()

        return self.tail(torch.cat([tap3, tap5, tap7, tap9], dim=1))
107 |
108 |
if __name__ == '__main__':
    # smoke test: moves the model and a random input to CUDA and prints the
    # output shape (a CUDA device is required)
    net = HSIDCNN().cuda()
    x = torch.randn(4,31,64,64).cuda()
    out = net(x)
    print(out.shape)
114 |
--------------------------------------------------------------------------------
/hsir/model/unet3d.py:
--------------------------------------------------------------------------------
1 | """
2 | Model from
3 | Hyperspectral Image Denoising With Realistic Data ICCV 2021
4 | """
5 |
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import torch.utils.data
9 | import torch
10 | import math
11 |
class conv_block(nn.Module):
    """
    Convolution Block

    Two ``Conv3d`` (kernel 3, padding 1) + ``LeakyReLU`` stages, summed with
    a kernel-1 ``Conv3d`` projection of the input.
    """
    def __init__(self, in_ch, out_ch):
        super().__init__()

        stages = [
            nn.Conv3d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            nn.Conv3d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
        ]
        self.conv = nn.Sequential(*stages)

        # projection so the shortcut has out_ch channels as well
        self.conv_residual = nn.Conv3d(in_ch, out_ch, kernel_size=1, stride=1, bias=True)

    def forward(self, x):
        main = self.conv(x)
        shortcut = self.conv_residual(x)
        return main + shortcut
33 |
class U_Net_3D(nn.Module):
    """3D U-Net with four down/up levels and a global residual connection.

    The strided down convolutions and the transposed up convolutions only
    change the last two axes (stride ``(1, 2, 2)``). Channel widths are
    ``dim * (1, 2, 4, 8, 16)``. The network returns ``Conv(...) + x``.
    """

    def __init__(self, in_ch=1, out_ch=1, dim=32):
        super().__init__()

        widths = [dim, dim * 2, dim * 4, dim * 8, dim * 16]
        down_cfg = dict(kernel_size=(1, 4, 4), stride=(1, 2, 2), padding=(0, 1, 1), bias=True)
        up_cfg = dict(kernel_size=(1, 2, 2), stride=(1, 2, 2), padding=0, bias=True)

        # strided convolutions that halve the last two axes
        self.Down1 = nn.Conv3d(widths[0], widths[0], **down_cfg)
        self.Down2 = nn.Conv3d(widths[1], widths[1], **down_cfg)
        self.Down3 = nn.Conv3d(widths[2], widths[2], **down_cfg)
        self.Down4 = nn.Conv3d(widths[3], widths[3], **down_cfg)

        # encoder blocks
        self.Conv1 = conv_block(in_ch, widths[0])
        self.Conv2 = conv_block(widths[0], widths[1])
        self.Conv3 = conv_block(widths[1], widths[2])
        self.Conv4 = conv_block(widths[2], widths[3])
        self.Conv5 = conv_block(widths[3], widths[4])

        # decoder: each level upsamples, then fuses the concatenated skip
        self.Up5 = nn.ConvTranspose3d(widths[4], widths[3], **up_cfg)
        self.Up_conv5 = conv_block(widths[4], widths[3])

        self.Up4 = nn.ConvTranspose3d(widths[3], widths[2], **up_cfg)
        self.Up_conv4 = conv_block(widths[3], widths[2])

        self.Up3 = nn.ConvTranspose3d(widths[2], widths[1], **up_cfg)
        self.Up_conv3 = conv_block(widths[2], widths[1])

        self.Up2 = nn.ConvTranspose3d(widths[1], widths[0], **up_cfg)
        self.Up_conv2 = conv_block(widths[1], widths[0])

        self.Conv = nn.Conv3d(widths[0], out_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # encoder
        e1 = self.Conv1(x)
        e2 = self.Conv2(self.Down1(e1))
        e3 = self.Conv3(self.Down2(e2))
        e4 = self.Conv4(self.Down3(e3))
        e5 = self.Conv5(self.Down4(e4))

        # decoder: skip features go first in every concatenation
        d5 = self.Up_conv5(torch.cat((e4, self.Up5(e5)), dim=1))
        d4 = self.Up_conv4(torch.cat((e3, self.Up4(d5)), dim=1))
        d3 = self.Up_conv3(torch.cat((e2, self.Up3(d4)), dim=1))
        d2 = self.Up_conv2(torch.cat((e1, self.Up2(d3)), dim=1))

        # global residual: the network output is added to the input
        return self.Conv(d2) + x
101 |
def unet3d():
    """Build ``U_Net_3D()`` with ``use_2dconv`` and ``bandwise`` set to ``False``."""
    model = U_Net_3D()
    model.use_2dconv, model.bandwise = False, False
    return model
107 |
def unet3d_m():
    """Build ``U_Net_3D(dim=16)`` with ``use_2dconv`` and ``bandwise`` set to ``False``."""
    model = U_Net_3D(dim=16)
    model.use_2dconv, model.bandwise = False, False
    return model
--------------------------------------------------------------------------------
/hsir/model/grunet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.utils.data
3 | import torch
4 |
5 | from .qrnn3d.qrnn import QRNNConv3D, QRNNUpsampleConv3d, BiQRNNConv3D, BiQRNNDeConv3D, QRNNDeConv3D
6 |
7 |
class conv_block(nn.Module):
    """Two stacked ``QRNNConv3D`` layers summed with a ``k=1`` ``QRNNConv3D`` shortcut.

    ``reverse`` is forwarded unchanged to all three layers.
    """

    def __init__(self, in_ch, out_ch, bn=False):
        super().__init__()

        self.conv1 = QRNNConv3D(in_ch, out_ch, bn=bn)
        self.conv2 = QRNNConv3D(out_ch, out_ch, bn=bn)
        self.conv_residual = QRNNConv3D(in_ch, out_ch, k=1, s=1, p=0, bn=bn)

    def forward(self, x, reverse=False):
        hidden = self.conv1(x, reverse=reverse)
        hidden = self.conv2(hidden, reverse=reverse)
        shortcut = self.conv_residual(x, reverse=reverse)
        return hidden + shortcut
20 |
21 |
class deconv_block(nn.Module):
    """Two stacked ``QRNNDeConv3D`` layers summed with a ``k=1`` ``QRNNDeConv3D`` shortcut.

    ``reverse`` is forwarded unchanged to all three layers.
    """

    def __init__(self, in_ch, out_ch, bn=False):
        super().__init__()

        self.conv1 = QRNNDeConv3D(in_ch, out_ch, bn=bn)
        self.conv2 = QRNNDeConv3D(out_ch, out_ch, bn=bn)
        self.conv_residual = QRNNDeConv3D(in_ch, out_ch, k=1, s=1, p=0, bn=bn)

    def forward(self, x, reverse=False):
        hidden = self.conv1(x, reverse=reverse)
        hidden = self.conv2(hidden, reverse=reverse)
        shortcut = self.conv_residual(x, reverse=reverse)
        return hidden + shortcut
34 |
35 |
class GRUnet(nn.Module):
    """U-Net style network assembled from QRNN 3D layers.

    Args:
        in_ch: number of input channels.
        out_ch: number of channels produced by the final layer. The default
            of 1 reproduces the previously hard-coded value.
        use_noise_map: when true, only channel 0 of the input is added back
            onto the output instead of the whole input.
        bn: forwarded to every QRNN layer.

    The ``Down*`` layers use stride ``(1, 2, 2)``. The ``reverse`` flag
    passed to the layers alternates between ``True`` and ``False`` from one
    layer to the next.
    """

    def __init__(self, in_ch=1, out_ch=1, use_noise_map=False, bn=False):
        super(GRUnet, self).__init__()
        self.use_2dconv = False
        self.bandwise = False
        self.use_noise_map = use_noise_map

        n1 = 16
        filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]

        self.Down1 = QRNNConv3D(filters[0], filters[0], k=3, s=(1, 2, 2), p=1, bn=bn)
        self.Down2 = QRNNConv3D(filters[1], filters[1], k=3, s=(1, 2, 2), p=1, bn=bn)
        self.Down3 = QRNNConv3D(filters[2], filters[2], k=3, s=(1, 2, 2), p=1, bn=bn)
        self.Down4 = QRNNConv3D(filters[3], filters[3], k=3, s=(1, 2, 2), p=1, bn=bn)

        self.Conv1 = BiQRNNConv3D(in_ch, filters[0], bn=bn)
        self.Conv2 = conv_block(filters[0], filters[1], bn=bn)
        self.Conv3 = conv_block(filters[1], filters[2], bn=bn)
        self.Conv4 = conv_block(filters[2], filters[3], bn=bn)
        self.Conv5 = conv_block(filters[3], filters[4], bn=bn)

        self.Up5 = QRNNUpsampleConv3d(filters[4], filters[3], bn=bn)
        self.Up_conv5 = deconv_block(filters[4], filters[3], bn=bn)

        self.Up4 = QRNNUpsampleConv3d(filters[3], filters[2], bn=bn)
        self.Up_conv4 = deconv_block(filters[3], filters[2], bn=bn)

        self.Up3 = QRNNUpsampleConv3d(filters[2], filters[1], bn=bn)
        self.Up_conv3 = deconv_block(filters[2], filters[1], bn=bn)

        self.Up2 = QRNNUpsampleConv3d(filters[1], filters[0], bn=bn)
        self.Up_conv2 = deconv_block(filters[1], filters[0], bn=bn)

        # honour out_ch instead of a hard-coded single output channel
        self.Conv = BiQRNNDeConv3D(filters[0], out_ch, bias=True, bn=bn)

    def forward(self, x):
        # x is 5-D; presumably [batch, channel, band, H, W] -- TODO confirm axis order
        e1 = self.Conv1(x)

        e2 = self.Down1(e1, reverse=True)
        e2 = self.Conv2(e2, reverse=False)

        e3 = self.Down2(e2, reverse=True)
        e3 = self.Conv3(e3, reverse=False)

        e4 = self.Down3(e3, reverse=True)
        e4 = self.Conv4(e4, reverse=False)

        e5 = self.Down4(e4, reverse=True)
        e5 = self.Conv5(e5, reverse=False)

        # decoder: skip features go first in every concatenation
        d5 = self.Up5(e5, reverse=True)
        d5 = torch.cat((e4, d5), dim=1)
        d5 = self.Up_conv5(d5, reverse=False)

        d4 = self.Up4(d5, reverse=True)
        d4 = torch.cat((e3, d4), dim=1)
        d4 = self.Up_conv4(d4, reverse=False)

        d3 = self.Up3(d4, reverse=True)
        d3 = torch.cat((e2, d3), dim=1)
        d3 = self.Up_conv3(d3, reverse=False)

        d2 = self.Up2(d3, reverse=True)
        d2 = torch.cat((e1, d2), dim=1)
        d2 = self.Up_conv2(d2, reverse=False)

        out = self.Conv(d2)

        if self.use_noise_map:
            # only channel 0 of the input is used as the residual
            return out + x[:, 0, :, :, :].unsqueeze(1)
        return out + x
109 |
110 |
def grunet():
    """Return a ``GRUnet`` built with its default arguments."""
    model = GRUnet()
    return model
113 |
114 |
def grunet_noise_map():
    """Return a ``GRUnet`` with two input channels and ``use_noise_map`` enabled."""
    model = GRUnet(in_ch=2, use_noise_map=True)
    return model
117 |
--------------------------------------------------------------------------------
/hsiboard/cmp.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import imageio
3 | from os.path import join
4 | import numpy as np
5 | import json
6 | import zipfile
7 | import io
8 |
9 | from util import *
10 | from box import *
11 |
12 |
def main(logdir):
    """Render the Streamlit "Comparator" page for the results under ``logdir``.

    Images are read from ``<logdir>/<method>/<dataset>/<color|gray>/<image>``
    and per-image statistics from ``<logdir>/<method>/<dataset>/log.json``
    (its ``detail`` entry, keyed by the image file name without extension).
    The method named ``gt`` supplies the reference image for the difference
    views.

    Args:
        logdir: root directory holding one sub-directory per method.
    """
    set_page_container_style()
    with st.sidebar:
        st.title('HSIR Board')
        st.subheader('Comparator')
        st.caption(f'Directory: {logdir}')
        methods = listdir(logdir)
        selected_methods = st.sidebar.multiselect(
            "Select Method",
            methods,
            default=methods
        )
        if not selected_methods:
            # Everything below indexes the first selected method.
            st.warning('Select at least one method.')
            return

        datasets_list = [listdir(join(logdir, m)) for m in selected_methods]
        # Only datasets present for every selected method can be compared;
        # sorting keeps the option order (and the default choice) stable.
        datasets = sorted(set(datasets_list[0]).intersection(*map(set, datasets_list)))
        if not datasets:
            st.warning('The selected methods have no dataset in common.')
            return
        selected_dataset = st.sidebar.selectbox(
            "Select Dataset",
            datasets
        )
        selected_vis_type = st.sidebar.selectbox(
            "Select Image Type",
            ['color', 'gray']
        )

        img_names = os.listdir(join(logdir, selected_methods[0], selected_dataset, selected_vis_type))
        selected_img = st.sidebar.selectbox(
            "Select Image",
            img_names
        )

        st.header('Box')

        enable_enlarge = st.checkbox('Enlarge')
        enable_diff = st.checkbox('Difference Map')
        enable_sidebyside = st.checkbox('Side by Side')

        selected_box_pos = st.sidebar.selectbox(
            "Select Box Position",
            ['Bottom Right', 'Bottom Left', 'Top Right', 'Top Left'],
        )

        crow = st.slider('row coordinate', min_value=0.0, max_value=1.0, value=0.2)
        ccol = st.slider('col coordinate', min_value=0.0, max_value=1.0, value=0.2)
        vmax = st.slider('vmax', min_value=0.0, max_value=1.0, value=0.1)

        st.header('Layout')
        ncol = st.slider('number of columns', min_value=1, max_value=20, value=4)

    img_key = os.path.splitext(selected_img)[0]
    imgs = {}
    stats = {}
    for m in selected_methods:
        img = imageio.imread(join(logdir, m, selected_dataset, selected_vis_type, selected_img))
        imgs[m] = np.array(img, dtype=np.float32) / 255
        # Close the log file deterministically instead of leaking the handle.
        with open(join(logdir, m, selected_dataset, 'log.json')) as log_file:
            stat = json.load(log_file)
        stats[m] = stat['detail'][img_key]

    gt = imgs.get('gt')
    if gt is None and (enable_diff or enable_sidebyside):
        # The difference views need the reference image; fall back to the
        # plain view rather than failing when 'gt' is not selected.
        st.warning("Difference views need the 'gt' method to be selected.")
        enable_diff = enable_sidebyside = False

    nrow = len(imgs) // ncol + 1
    grids = make_grid(nrow, ncol)

    download_data = {'img': {}, 'meta': {}}
    with st.container():
        idx = 0
        for method, im in imgs.items():
            img = im.copy()
            # Display/download name drops anything after the last dot; the
            # stats lookup keeps the full method key so dotted names work.
            name = os.path.splitext(method)[0]
            method_stats = stats[method]
            if enable_diff:
                diff = np.abs(img - gt)
                if len(diff.shape) == 3:
                    diff = diff.mean(-1)
                img = convert_color(diff, vmax=vmax)

            if enable_sidebyside:
                h, w = img.shape[:2]
                img = addbox_with_diff(img, gt, (int(h * crow), int(w * ccol)), vmax=vmax)

            elif enable_enlarge:
                h, w = img.shape[:2]
                img = addbox(img, (int(h * crow), int(w * ccol)), bbpos=mapbbpox[selected_box_pos])

            ct = grids[idx // ncol][idx % ncol]
            ct.image(img, caption='%s [%.4f]' % (name, method_stats['MPSNR']), clamp=[0, 1])
            idx += 1

            download_data['img'][name] = img
            download_data['meta'][name] = method_stats

    with st.sidebar:
        st.header('Download')
        filename = img_key
        st.download_button('Download ' + filename,
                           to_zip(download_data),
                           file_name=f'{filename}_{selected_dataset}.zip',
                           mime='application/zip')
105 |
106 |
def to_zip(data):
    """Pack images and their metadata into an in-memory zip archive.

    Args:
        data: mapping with an ``'img'`` dict (name -> image, each stored as
            ``<name>.png`` via ``encode_image``) and a ``'meta'`` object that
            is dumped to ``meta.json``.

    Returns:
        The ``io.BytesIO`` buffer the archive was written to.
    """
    archive_bytes = io.BytesIO()
    with zipfile.ZipFile(archive_bytes, mode="a",
                         compression=zipfile.ZIP_DEFLATED,
                         allowZip64=False) as archive:
        for stem in data['img']:
            archive.writestr(stem + '.png', encode_image(data['img'][stem]))

        meta_text = json.dumps(data['meta'], indent=4)
        archive.writestr('meta.json', meta_text)
    return archive_bytes
116 |
117 |
if __name__ == '__main__':
    # Script entry point: read --logdir (default 'results') and render the page.
    cli = argparse.ArgumentParser('HSIR Board')
    cli.add_argument('--logdir', default='results')
    options = cli.parse_args()
    main(options.logdir)
124 |
--------------------------------------------------------------------------------
/hsir/model/trq3d/conv.py:
--------------------------------------------------------------------------------
1 | from .combinations import *
2 |
3 |
class QRNN3DLayer(nn.Module):
    """Quasi-recurrent layer: one convolution yields the gates, then f-pooling
    is applied step by step along dim 2 of the gate tensors.

    Args:
        in_channels: stored on the module; not otherwise used by this class.
        hidden_channels: size of each chunk split from the conv output along
            dim 1 (the output must split into exactly ``Z`` and ``F``).
        conv_layer: module applied to the input to produce the gates.
        act: activation applied to ``Z``: ``'tanh'``, ``'relu'`` or ``'none'``.
    """

    def __init__(self, in_channels, hidden_channels, conv_layer, act='tanh'):
        super(QRNN3DLayer, self).__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        # quasi_conv_layer
        self.conv = conv_layer
        self.act = act

    def _conv_step(self, inputs):
        """Return ``(Z, F)``: the activated candidate and the sigmoid forget gate.

        Raises:
            NotImplementedError: if ``self.act`` is not a supported name.
        """
        gates = self.conv(inputs)
        Z, F = gates.split(split_size=self.hidden_channels, dim=1)
        if self.act == 'tanh':
            return Z.tanh(), F.sigmoid()
        elif self.act == 'relu':
            return Z.relu(), F.sigmoid()
        elif self.act == 'none':
            # The gate must be evaluated here as well; returning the bound
            # method ``F.sigmoid`` made ``forward`` fail for act='none'.
            return Z, F.sigmoid()
        else:
            raise NotImplementedError

    def _rnn_step(self, z, f, h):
        """One f-pooling update; ``h is None`` marks the first step."""
        h_ = (1 - f) * z if h is None else f * h + (1 - f) * z
        return h_

    def forward(self, inputs, reverse=False):
        """Run the recurrence along dim 2 and concatenate the hidden states.

        Args:
            inputs: tensor passed to the wrapped convolution.
            reverse: when true, iterate from the last slice to the first; the
                returned states are still ordered like the input slices.
        """
        Z, F = self._conv_step(inputs)
        z_steps = Z.split(1, 2)  # split along timestep
        f_steps = F.split(1, 2)
        if reverse:
            z_steps, f_steps = z_steps[::-1], f_steps[::-1]

        h = None
        h_time = []
        for z, f in zip(z_steps, f_steps):
            h = self._rnn_step(z, f, h)
            h_time.append(h)
        if reverse:
            # Restore input order once instead of inserting at the front
            # on every step.
            h_time.reverse()

        # return concatenated hidden states
        return torch.cat(h_time, dim=2)

    def extra_repr(self):
        return 'act={}'.format(self.act)
51 |
52 |
class BiQRNN3DLayer(QRNN3DLayer):
    """Bidirectional variant: the conv output is split into ``Z``, ``F1`` and
    ``F2``; ``F1`` gates a first-to-last pass, ``F2`` a last-to-first pass, and
    the two state sequences are summed.
    """

    def _conv_step(self, inputs):
        """Return ``(Z, F1, F2)`` with ``Z`` activated and both gates sigmoided."""
        Z, F1, F2 = self.conv(inputs).split(split_size=self.hidden_channels, dim=1)
        if self.act == 'tanh':
            Z = Z.tanh()
        elif self.act == 'relu':
            Z = Z.relu()
        elif self.act != 'none':
            raise NotImplementedError
        return Z, F1.sigmoid(), F2.sigmoid()

    def forward(self, inputs, fname=None):
        """Run both passes along dim 2 and return their element-wise sum.

        Args:
            inputs: tensor passed to the wrapped convolution.
            fname: if not ``None``, the gates and both state sequences are
                written there with ``torch.save``.
        """
        Z, F1, F2 = self._conv_step(inputs)
        steps = Z.split(1, 2)  # slices along the recurrence axis

        state = None
        left_states = []
        for z, f in zip(steps, F1.split(1, 2)):
            state = self._rnn_step(z, f, state)
            left_states.append(state)

        state = None
        right_states = []
        for z, f in zip(reversed(steps), reversed(F2.split(1, 2))):
            state = self._rnn_step(z, f, state)
            right_states.append(state)
        # Put the backward states back into input order.
        right_states.reverse()

        hsl = torch.cat(left_states, dim=2)
        hsr = torch.cat(right_states, dim=2)

        if fname is not None:
            torch.save({'z': Z, 'fl': F1, 'fr': F2, 'hsl': hsl, 'hsr': hsr}, fname)
        return hsl + hsr
92 |
93 |
class BiQRNNConv3D(BiQRNN3DLayer):
    """``BiQRNN3DLayer`` whose gates are produced by a ``BasicConv3d``."""

    def __init__(self, in_channels, hidden_channels, k=3, s=1, p=1, bn=True, act='tanh'):
        # The conv emits three chunks of hidden_channels (Z, F1, F2).
        gate_conv = BasicConv3d(in_channels, hidden_channels * 3, k, s, p, bn=bn)
        super().__init__(in_channels, hidden_channels, gate_conv, act=act)
98 |
99 |
class BiQRNNDeConv3D(BiQRNN3DLayer):
    """``BiQRNN3DLayer`` whose gates are produced by a ``BasicDeConv3d``."""

    def __init__(self, in_channels, hidden_channels, k=3, s=1, p=1, bias=False, bn=True, act='tanh'):
        # The deconv emits three chunks of hidden_channels (Z, F1, F2).
        gate_conv = BasicDeConv3d(in_channels, hidden_channels * 3, k, s, p, bias=bias, bn=bn)
        super().__init__(in_channels, hidden_channels, gate_conv, act=act)
105 |
106 |
class QRNNConv3D(QRNN3DLayer):
    """``QRNN3DLayer`` whose gates are produced by a ``BasicConv3d``."""

    def __init__(self, in_channels, hidden_channels, k=3, s=1, p=1, bn=True, act='tanh'):
        # The conv emits two chunks of hidden_channels (Z, F).
        gate_conv = BasicConv3d(in_channels, hidden_channels * 2, k, s, p, bn=bn)
        super().__init__(in_channels, hidden_channels, gate_conv, act=act)
111 |
112 |
class QRNNDeConv3D(QRNN3DLayer):
    """``QRNN3DLayer`` whose gates are produced by a ``BasicDeConv3d``."""

    def __init__(self, in_channels, hidden_channels, k=3, s=1, p=1, bn=True, act='tanh'):
        # The deconv emits two chunks of hidden_channels (Z, F).
        gate_conv = BasicDeConv3d(in_channels, hidden_channels * 2, k, s, p, bn=bn)
        super().__init__(in_channels, hidden_channels, gate_conv, act=act)
117 |
118 |
class QRNNUpsampleConv3D(QRNN3DLayer):
    """``QRNN3DLayer`` whose gates are produced by a ``BasicUpsampleConv3d``."""

    def __init__(self, in_channels, hidden_channels, k=3, s=1, p=1, upsample=(1, 2, 2), bn=True, act='tanh'):
        # The upsample-conv emits two chunks of hidden_channels (Z, F).
        gate_conv = BasicUpsampleConv3d(in_channels, hidden_channels * 2, k, s, p, upsample, bn=bn)
        super().__init__(in_channels, hidden_channels, gate_conv, act=act)
124 |
125 |
126 |
127 |
--------------------------------------------------------------------------------
/hsir/model/qrnn3d/qrnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .conv import *
5 |
6 |
7 | """F pooling"""
8 |
9 |
class QRNN3DLayer(nn.Module):
    """Quasi-recurrent layer: one convolution yields the gates, then f-pooling
    is applied step by step along dim 2 of the gate tensors.

    Args:
        in_channels: stored on the module; not otherwise used by this class.
        hidden_channels: size of each chunk split from the conv output along
            dim 1 (the output must split into exactly ``Z`` and ``F``).
        conv_layer: module applied to the input to produce the gates.
        act: activation applied to ``Z``: ``'tanh'``, ``'relu'`` or ``'none'``.
    """

    def __init__(self, in_channels, hidden_channels, conv_layer, act='tanh'):
        super(QRNN3DLayer, self).__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        # quasi_conv_layer
        self.conv = conv_layer
        self.act = act

    def _conv_step(self, inputs):
        """Return ``(Z, F)``: the activated candidate and the sigmoid forget gate.

        Raises:
            NotImplementedError: if ``self.act`` is not a supported name.
        """
        gates = self.conv(inputs)
        Z, F = gates.split(split_size=self.hidden_channels, dim=1)
        if self.act == 'tanh':
            return Z.tanh(), F.sigmoid()
        elif self.act == 'relu':
            return Z.relu(), F.sigmoid()
        elif self.act == 'none':
            # The gate must be evaluated here as well; returning the bound
            # method ``F.sigmoid`` made ``forward`` fail for act='none'.
            return Z, F.sigmoid()
        else:
            raise NotImplementedError

    def _rnn_step(self, z, f, h):
        """One f-pooling update; ``h is None`` marks the first step."""
        h_ = (1 - f) * z if h is None else f * h + (1 - f) * z
        return h_

    def forward(self, inputs, reverse=False):
        """Run the recurrence along dim 2 and concatenate the hidden states.

        Args:
            inputs: tensor passed to the wrapped convolution.
            reverse: when true, iterate from the last slice to the first; the
                returned states are still ordered like the input slices.
        """
        Z, F = self._conv_step(inputs)
        z_steps = Z.split(1, 2)  # split along timestep
        f_steps = F.split(1, 2)
        if reverse:
            z_steps, f_steps = z_steps[::-1], f_steps[::-1]

        h = None
        h_time = []
        for z, f in zip(z_steps, f_steps):
            h = self._rnn_step(z, f, h)
            h_time.append(h)
        if reverse:
            # Restore input order once instead of inserting at the front
            # on every step.
            h_time.reverse()

        # return concatenated hidden states
        return torch.cat(h_time, dim=2)

    def extra_repr(self):
        return 'act={}'.format(self.act)
57 |
58 |
class BiQRNN3DLayer(QRNN3DLayer):
    """Bidirectional variant: the conv output is split into ``Z``, ``F1`` and
    ``F2``; ``F1`` gates a first-to-last pass, ``F2`` a last-to-first pass, and
    the two state sequences are summed.
    """

    def _conv_step(self, inputs):
        """Return ``(Z, F1, F2)`` with ``Z`` activated and both gates sigmoided."""
        Z, F1, F2 = self.conv(inputs).split(split_size=self.hidden_channels, dim=1)
        if self.act == 'tanh':
            Z = Z.tanh()
        elif self.act == 'relu':
            Z = Z.relu()
        elif self.act != 'none':
            raise NotImplementedError
        return Z, F1.sigmoid(), F2.sigmoid()

    def forward(self, inputs, fname=None):
        """Run both passes along dim 2 and return their element-wise sum.

        Args:
            inputs: tensor passed to the wrapped convolution.
            fname: if not ``None``, the gates and both state sequences are
                written there with ``torch.save``.
        """
        Z, F1, F2 = self._conv_step(inputs)
        steps = Z.split(1, 2)  # slices along the recurrence axis

        state = None
        left_states = []
        for z, f in zip(steps, F1.split(1, 2)):
            state = self._rnn_step(z, f, state)
            left_states.append(state)

        state = None
        right_states = []
        for z, f in zip(reversed(steps), reversed(F2.split(1, 2))):
            state = self._rnn_step(z, f, state)
            right_states.append(state)
        # Put the backward states back into input order.
        right_states.reverse()

        hsl = torch.cat(left_states, dim=2)
        hsr = torch.cat(right_states, dim=2)

        if fname is not None:
            torch.save({'z': Z, 'fl': F1, 'fr': F2, 'hsl': hsl, 'hsr': hsr}, fname)
        return hsl + hsr
98 |
99 |
class BiQRNNConv3D(BiQRNN3DLayer):
    """``BiQRNN3DLayer`` whose gates are produced by a ``BasicConv3d``."""

    def __init__(self, in_channels, hidden_channels, k=3, s=1, p=1, bn=True, act='tanh'):
        # The conv emits three chunks of hidden_channels (Z, F1, F2).
        gate_conv = BasicConv3d(in_channels, hidden_channels * 3, k, s, p, bn=bn)
        super().__init__(in_channels, hidden_channels, gate_conv, act=act)
104 |
105 |
class BiQRNNDeConv3D(BiQRNN3DLayer):
    """``BiQRNN3DLayer`` whose gates are produced by a ``BasicDeConv3d``."""

    def __init__(self, in_channels, hidden_channels, k=3, s=1, p=1, bias=False, bn=True, act='tanh'):
        # The deconv emits three chunks of hidden_channels (Z, F1, F2).
        gate_conv = BasicDeConv3d(in_channels, hidden_channels * 3, k, s, p, bias=bias, bn=bn)
        super().__init__(in_channels, hidden_channels, gate_conv, act=act)
110 |
111 |
class QRNNConv3D(QRNN3DLayer):
    """``QRNN3DLayer`` whose gates are produced by a ``BasicConv3d``."""

    def __init__(self, in_channels, hidden_channels, k=3, s=1, p=1, bn=True, act='tanh'):
        # The conv emits two chunks of hidden_channels (Z, F).
        gate_conv = BasicConv3d(in_channels, hidden_channels * 2, k, s, p, bn=bn)
        super().__init__(in_channels, hidden_channels, gate_conv, act=act)
116 |
117 |
class QRNNDeConv3D(QRNN3DLayer):
    """``QRNN3DLayer`` whose gates are produced by a ``BasicDeConv3d``."""

    def __init__(self, in_channels, hidden_channels, k=3, s=1, p=1, bn=True, act='tanh'):
        # The deconv emits two chunks of hidden_channels (Z, F).
        gate_conv = BasicDeConv3d(in_channels, hidden_channels * 2, k, s, p, bn=bn)
        super().__init__(in_channels, hidden_channels, gate_conv, act=act)
122 |
123 |
class QRNNUpsampleConv3d(QRNN3DLayer):
    """``QRNN3DLayer`` whose gates are produced by a ``BasicUpsampleConv3d``."""

    def __init__(self, in_channels, hidden_channels, k=3, s=1, p=1, upsample=(1, 2, 2), bn=True, act='tanh'):
        # The upsample-conv emits two chunks of hidden_channels (Z, F).
        gate_conv = BasicUpsampleConv3d(in_channels, hidden_channels * 2, k, s, p, upsample, bn=bn)
        super().__init__(in_channels, hidden_channels, gate_conv, act=act)
128 |
--------------------------------------------------------------------------------
/hsir/model/hsdt/sepconv.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | import torch.nn as nn
3 | import torch
4 |
5 |
def repeat(x):
    """Return ``x`` unchanged if it is a tuple or list, else ``[x, x, x]``."""
    return x if isinstance(x, (tuple, list)) else [x, x, x]
10 |
11 |
class SepConv_PD(nn.Module):
    """Two chained convolutions: a ``(k[0], 1, 1)`` kernel first (``pw_conv``),
    then a ``(1, k[1], k[2])`` kernel (``dw_conv``).

    ``BASECONV`` is the convolution class used for both; ``k``, ``s`` and ``p``
    may be scalars or 3-element sequences.
    """

    def __init__(self, BASECONV, in_ch, out_ch, k, s=1, p=1, bias=False):
        super().__init__()
        k, s, p = repeat(k), repeat(s), repeat(p)
        conv_opts = dict(bias=bias, padding_mode='zeros')

        self.pw_conv = BASECONV(
            in_ch, out_ch,
            (k[0], 1, 1), (s[0], 1, 1), (p[0], 0, 0),
            **conv_opts
        )
        self.dw_conv = BASECONV(
            out_ch, out_ch,
            (1, k[1], k[2]), (1, s[1], s[2]), (0, p[1], p[2]),
            **conv_opts
        )

    def forward(self, x):
        return self.dw_conv(self.pw_conv(x))

    @staticmethod
    def of(base_conv):
        """Return a factory for this class with ``BASECONV`` pre-bound."""
        return partial(SepConv_PD, base_conv)
29 |
30 |
class SepConv_DP(nn.Module):
    """Two chained convolutions: a ``(1, k[1], k[2])`` kernel first
    (``dw_conv``), then a ``(k[0], 1, 1)`` kernel (``pw_conv``).

    ``BASECONV`` is the convolution class used for both; ``k``, ``s`` and ``p``
    may be scalars or 3-element sequences.
    """

    def __init__(self, BASECONV, in_ch, out_ch, k, s=1, p=1, bias=False):
        super().__init__()
        k, s, p = repeat(k), repeat(s), repeat(p)
        conv_opts = dict(bias=bias, padding_mode='zeros')

        self.dw_conv = BASECONV(
            in_ch, out_ch,
            (1, k[1], k[2]), (1, s[1], s[2]), (0, p[1], p[2]),
            **conv_opts
        )
        self.pw_conv = BASECONV(
            out_ch, out_ch,
            (k[0], 1, 1), (s[0], 1, 1), (p[0], 0, 0),
            **conv_opts
        )

    def forward(self, x):
        return self.pw_conv(self.dw_conv(x))

    @staticmethod
    def of(base_conv):
        """Return a factory for this class with ``BASECONV`` pre-bound."""
        return partial(SepConv_DP, base_conv)
48 |
49 |
class SepConv_DP_CA(nn.Module):
    """``(1, k[1], k[2])`` conv followed by a ``(k[0], 1, 1)`` conv that emits
    ``2 * out_ch`` channels; the second half, passed through a sigmoid,
    multiplies the first half.
    """

    def __init__(self, BASECONV, in_ch, out_ch, k, s=1, p=1, bias=False):
        super().__init__()
        k, s, p = repeat(k), repeat(s), repeat(p)
        conv_opts = dict(bias=bias, padding_mode='zeros')

        self.dw_conv = BASECONV(
            in_ch, out_ch,
            (1, k[1], k[2]), (1, s[1], s[2]), (0, p[1], p[2]),
            **conv_opts
        )
        self.pw_conv = BASECONV(
            out_ch, out_ch * 2,
            (k[0], 1, 1), (s[0], 1, 1), (p[0], 0, 0),
            **conv_opts
        )

    def forward(self, x):
        # Split the doubled channel dim into features and their gate.
        features, gate = torch.chunk(self.pw_conv(self.dw_conv(x)), 2, dim=1)
        return features * gate.sigmoid()

    @staticmethod
    def of(base_conv):
        """Return a factory for this class with ``BASECONV`` pre-bound."""
        return partial(SepConv_DP_CA, base_conv)
69 |
70 |
class S3Conv(nn.Module):
    """Sum of two parallel branches applied to the same input.

    ``dw_conv`` is two ``(1, k[1], k[2])`` convolutions, each followed by
    ``LeakyReLU`` (only the first takes the stride); ``pw_conv`` is a single
    ``(k[0], 1, 1)`` convolution. ``forward`` adds the two results.
    """

    def __init__(self, BASECONV, in_ch, out_ch, k, s=1, p=1, bias=False):
        super().__init__()
        k, s, p = repeat(k), repeat(s), repeat(p)
        conv_opts = dict(bias=bias, padding_mode='zeros')
        plane_kernel = (1, k[1], k[2])
        plane_padding = (0, p[1], p[2])

        self.dw_conv = nn.Sequential(
            BASECONV(in_ch, out_ch, plane_kernel, (1, s[1], s[2]), plane_padding, **conv_opts),
            nn.LeakyReLU(),
            BASECONV(out_ch, out_ch, plane_kernel, 1, plane_padding, **conv_opts),
            nn.LeakyReLU(),
        )
        self.pw_conv = BASECONV(
            in_ch, out_ch,
            (k[0], 1, 1), (s[0], 1, 1), (p[0], 0, 0),
            **conv_opts
        )

    def forward(self, x):
        return self.dw_conv(x) + self.pw_conv(x)

    @staticmethod
    def of(base_conv):
        """Return a factory for this class with ``BASECONV`` pre-bound."""
        return partial(S3Conv, base_conv)
94 |
95 |
class S3Conv_Seq(nn.Module):
    """Sequential form: two ``(1, k[1], k[2])`` convolutions with ``LeakyReLU``
    (``dw_conv``) feeding a ``(k[0], 1, 1)`` convolution (``pw_conv``).
    """

    def __init__(self, BASECONV, in_ch, out_ch, k, s=1, p=1, bias=False):
        super().__init__()
        k, s, p = repeat(k), repeat(s), repeat(p)
        conv_opts = dict(bias=bias, padding_mode='zeros')
        plane_kernel = (1, k[1], k[2])
        plane_padding = (0, p[1], p[2])

        self.dw_conv = nn.Sequential(
            BASECONV(in_ch, out_ch, plane_kernel, (1, s[1], s[2]), plane_padding, **conv_opts),
            nn.LeakyReLU(),
            BASECONV(out_ch, out_ch, plane_kernel, 1, plane_padding, **conv_opts),
            nn.LeakyReLU(),
        )
        self.pw_conv = BASECONV(
            out_ch, out_ch,
            (k[0], 1, 1), (s[0], 1, 1), (p[0], 0, 0),
            **conv_opts
        )

    def forward(self, x):
        return self.pw_conv(self.dw_conv(x))

    @staticmethod
    def of(base_conv):
        """Return a factory for this class with ``BASECONV`` pre-bound."""
        return partial(S3Conv_Seq, base_conv)
118 |
119 |
class S3Conv1(nn.Module):
    """Sum of two parallel branches: one ``(1, k[1], k[2])`` convolution with
    ``LeakyReLU`` (``dw_conv``) and one ``(k[0], 1, 1)`` convolution
    (``pw_conv``), both applied to the same input.
    """

    def __init__(self, BASECONV, in_ch, out_ch, k, s=1, p=1, bias=False):
        super().__init__()
        k, s, p = repeat(k), repeat(s), repeat(p)
        conv_opts = dict(bias=bias, padding_mode='zeros')

        self.dw_conv = nn.Sequential(
            BASECONV(in_ch, out_ch, (1, k[1], k[2]), (1, s[1], s[2]), (0, p[1], p[2]), **conv_opts),
            nn.LeakyReLU(),
        )
        self.pw_conv = BASECONV(
            in_ch, out_ch,
            (k[0], 1, 1), (s[0], 1, 1), (p[0], 0, 0),
            **conv_opts
        )

    def forward(self, x):
        return self.dw_conv(x) + self.pw_conv(x)

    @staticmethod
    def of(base_conv):
        """Return a factory for this class with ``BASECONV`` pre-bound."""
        return partial(S3Conv1, base_conv)
140 |
--------------------------------------------------------------------------------
/hsir/model/trq3d/combinations.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import functional
4 | from sync_batchnorm import SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
5 |
6 | BatchNorm2d = SynchronizedBatchNorm2d
7 | BatchNorm3d = SynchronizedBatchNorm3d
8 |
9 |
class BNReLUConv3d(nn.Sequential):
    """``bn`` -> ``relu`` -> ``conv`` (bias-free ``nn.Conv3d``) stack."""

    def __init__(self, in_channels, channels, k=3, s=1, p=1, inplace=False):
        super().__init__()
        stages = (
            ('bn', BatchNorm3d(in_channels)),
            ('relu', nn.ReLU(inplace=inplace)),
            ('conv', nn.Conv3d(in_channels, channels, k, s, p, bias=False)),
        )
        for label, stage in stages:
            self.add_module(label, stage)
16 |
17 |
class BNReLUDeConv3d(nn.Sequential):
    """``bn`` -> ``relu`` -> ``deconv`` (bias-free ``nn.ConvTranspose3d``) stack."""

    def __init__(self, in_channels, channels, k=3, s=1, p=1, inplace=False):
        super().__init__()
        stages = (
            ('bn', BatchNorm3d(in_channels)),
            ('relu', nn.ReLU(inplace=inplace)),
            ('deconv', nn.ConvTranspose3d(in_channels, channels, k, s, p, bias=False)),
        )
        for label, stage in stages:
            self.add_module(label, stage)
24 |
25 |
class BNReLUUpsampleConv3d(nn.Sequential):
    """``bn`` -> ``relu`` -> ``upsample_conv`` (bias-free ``UpsampleConv3d``) stack."""

    def __init__(self, in_channels, channels, k=3, s=1, p=1, upsample=(1, 2, 2), inplace=False):
        super().__init__()
        stages = (
            ('bn', BatchNorm3d(in_channels)),
            ('relu', nn.ReLU(inplace=inplace)),
            ('upsample_conv',
             UpsampleConv3d(in_channels, channels, k, s, p, bias=False, upsample=upsample)),
        )
        for label, stage in stages:
            self.add_module(label, stage)
32 |
class UpsampleConv2d(torch.nn.Module):
    """Optional bilinear upsampling followed by a 2-D convolution.

    Used as an alternative to a transposed convolution.
    ref: http://distill.pub/2016/deconv-checkerboard/

    Args:
        upsample: scale factor for ``torch.nn.Upsample``; a falsy value skips
            the upsampling step (no upsample layer is created).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=True, upsample=None):
        super().__init__()
        self.upsample = upsample
        if upsample:
            self.upsample_layer = torch.nn.Upsample(
                scale_factor=upsample, mode='bilinear', align_corners=True)

        self.conv2d = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size, stride, padding, bias=bias)

    def forward(self, x):
        if self.upsample:
            x = self.upsample_layer(x)
        return self.conv2d(x)
54 |
class UpsampleConv3d(torch.nn.Module):
    """Optional trilinear upsampling followed by a 3-D convolution.

    Used as an alternative to a transposed convolution.
    ref: http://distill.pub/2016/deconv-checkerboard/

    Args:
        upsample: scale factor for ``torch.nn.Upsample``; a falsy value skips
            the upsampling step (no upsample layer is created).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=True, upsample=None):
        super().__init__()
        self.upsample = upsample
        if upsample:
            self.upsample_layer = torch.nn.Upsample(
                scale_factor=upsample, mode='trilinear', align_corners=True)

        self.conv3d = torch.nn.Conv3d(
            in_channels, out_channels, kernel_size, stride, padding, bias=bias)

    def forward(self, x):
        if self.upsample:
            x = self.upsample_layer(x)
        return self.conv3d(x)
76 |
class BasicConv2d(nn.Sequential):
    """Optional ``bn`` (on the input channels) followed by ``conv`` (``nn.Conv2d``)."""

    def __init__(self, in_channels, channels, k=3, s=1, p=1, bias=False, bn=True):
        super().__init__()
        if bn:
            self.add_module('bn', BatchNorm2d(in_channels))
        conv = nn.Conv2d(in_channels, channels, k, s, p, bias=bias)
        self.add_module('conv', conv)
83 |
class BasicConv3d(nn.Sequential):
    """Optional ``bn`` (on the input channels) followed by ``conv`` (``nn.Conv3d``)."""

    def __init__(self, in_channels, channels, k=3, s=1, p=1, bias=False, bn=True):
        super().__init__()
        if bn:
            self.add_module('bn', BatchNorm3d(in_channels))
        conv = nn.Conv3d(in_channels, channels, k, s, p, bias=bias)
        self.add_module('conv', conv)
90 |
class BasicDeConv2d(nn.Sequential):
    """Optional ``bn`` (on the input channels) followed by ``deconv`` (``nn.ConvTranspose2d``)."""

    def __init__(self, in_channels, channels, k=3, s=1, p=1, bias=False, bn=True):
        super().__init__()
        if bn:
            self.add_module('bn', BatchNorm2d(in_channels))
        deconv = nn.ConvTranspose2d(in_channels, channels, k, s, p, bias=bias)
        self.add_module('deconv', deconv)
97 |
class BasicDeConv3d(nn.Sequential):
    """Optional ``bn`` (on the input channels) followed by ``deconv`` (``nn.ConvTranspose3d``)."""

    def __init__(self, in_channels, channels, k=3, s=1, p=1, bias=False, bn=True):
        super().__init__()
        if bn:
            self.add_module('bn', BatchNorm3d(in_channels))
        deconv = nn.ConvTranspose3d(in_channels, channels, k, s, p, bias=bias)
        self.add_module('deconv', deconv)
104 |
class BasicUpsampleConv2d(nn.Sequential):
    """Optional ``bn`` followed by ``upsample_conv`` (bias-free ``UpsampleConv2d``)."""

    def __init__(self, in_channels, channels, k=3, s=1, p=1, upsample=(2, 2), bn=True):
        super().__init__()
        if bn:
            self.add_module('bn', BatchNorm2d(in_channels))
        upsample_conv = UpsampleConv2d(in_channels, channels, k, s, p, bias=False, upsample=upsample)
        self.add_module('upsample_conv', upsample_conv)
111 |
class BasicUpsampleConv3d(nn.Sequential):
    """Optional ``bn`` followed by ``upsample_conv`` (bias-free ``UpsampleConv3d``)."""

    def __init__(self, in_channels, channels, k=3, s=1, p=1, upsample=(1, 2, 2), bn=True):
        super().__init__()
        if bn:
            self.add_module('bn', BatchNorm3d(in_channels))
        upsample_conv = UpsampleConv3d(in_channels, channels, k, s, p, bias=False, upsample=upsample)
        self.add_module('upsample_conv', upsample_conv)
118 |
119 |
120 |
--------------------------------------------------------------------------------
/hsir/data/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 |
4 | import scipy.io
5 | import torch
6 | import lmdb
7 | import numpy as np
8 |
9 | from torch.utils.data import Dataset
10 | from torchvision.transforms import Compose
11 | from functools import partial
12 |
13 | __all__ = [
14 | 'HSITestDataset',
15 | 'HSITrainDataset',
16 | 'HSITransformTestDataset'
17 | ]
18 |
19 |
class HSITestDataset(Dataset):
    """Test set read from ``.mat`` files under ``root``.

    Each item is a dict with ``'input'`` and ``'target'`` (loaded from the
    ``'input'`` and ``'gt'`` entries of the file) plus ``'filename'`` when
    ``return_name`` is set. Unless ``use_chw`` is true, a leading axis is added
    to both arrays.
    """

    def __init__(self, root, size=None, use_chw=False, return_name=False):
        super().__init__()
        self.dataset = MatDataFromFolder(root, size=size)
        add_leading_axis = None if use_chw else partial(np.expand_dims, axis=0)
        loader = LoadMatHSI(input_key='input', gt_key='gt', transform=add_leading_axis)
        self.transform = Compose([loader])
        self.return_name = return_name

    def __getitem__(self, index):
        mat, filename = self.dataset[index]
        inputs, targets = self.transform(mat)
        sample = {'input': inputs, 'target': targets}
        if self.return_name:
            sample['filename'] = filename
        return sample

    def __len__(self):
        return len(self.dataset)
40 |
41 |
class HSITransformTestDataset(Dataset):
    """Test set whose inputs are generated from the ground truth.

    The ``'gt'`` entry of each ``.mat`` file is moved from its stored axis
    order to ``(2, 0, 1)``, cast to float32, and passed to ``transform`` to
    produce the input. Items are ``((input, gt), filename)``.
    """

    def __init__(self, root, transform, size=None):
        super().__init__()
        self.dataset = MatDataFromFolder(root, size=size)
        self.transform = transform

    def __getitem__(self, index):
        mat, filename = self.dataset[index]
        reference = mat['gt'].transpose(2, 0, 1).astype('float32')
        degraded = self.transform(reference)
        return (degraded, reference), filename

    def __len__(self):
        return len(self.dataset)
56 |
57 |
class HSITrainDataset(Dataset):
    """Training set backed by an LMDB database at ``root``.

    Every item goes through ``common_transform``; a copy of the result becomes
    the target. ``input_transform`` and ``target_transform`` are then applied
    to the input and the target respectively when they are not ``None``.
    """

    def __init__(self,
                 root,
                 input_transform,
                 target_transform,
                 common_transform,
                 repeat=1,
                 ):
        super().__init__()
        self.dataset = LMDBDataset(root, repeat=repeat)
        self.common_transform = common_transform
        self.input_transform = input_transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        sample = self.common_transform(self.dataset[index])
        # Copy before the input transform so the target is unaffected by it.
        target = sample.copy()
        if self.input_transform is not None:
            sample = self.input_transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return {'input': sample, 'target': target}

    def __len__(self):
        return len(self.dataset)
84 |
85 |
class LoadMatHSI(object):
    """Callable that extracts an (input, gt) pair of float tensors from a mat dict.

    Both arrays are transposed with axes ``(2, 0, 1)``, optionally passed
    through ``transform``, and converted with ``torch.from_numpy(...).float()``.
    """

    def __init__(self, input_key, gt_key, transform=None):
        self.gt_key = gt_key
        self.input_key = input_key
        self.transform = transform

    def __call__(self, mat):
        arrays = []
        for key in (self.input_key, self.gt_key):
            chw = mat[key].transpose((2, 0, 1))
            if self.transform:
                chw = self.transform(chw)
            arrays.append(chw)

        noisy, clean = arrays
        noisy = torch.from_numpy(noisy).float()
        clean = torch.from_numpy(clean).float()
        return noisy, clean
104 |
105 |
class MatDataFromFolder(Dataset):
    """Dataset over files in ``data_dir``; items are ``(load(path), path)``.

    Args:
        data_dir: directory joined in front of every file name.
        load: callable applied to a path to read it.
        suffix: when ``fns`` is ``None``, only directory entries ending with
            this suffix are used.
        fns: explicit list of file names; bypasses the directory listing.
        size: if truthy and not larger than the number of files, keep only the
            first ``size`` files.
    """

    def __init__(self, data_dir, load=scipy.io.loadmat, suffix='mat', fns=None, size=None):
        super().__init__()
        if fns is None:
            names = [fn for fn in os.listdir(data_dir) if fn.endswith(suffix)]
        else:
            names = fns
        self.filenames = [os.path.join(data_dir, fn) for fn in names]
        self.load = load

        if size and size <= len(self.filenames):
            del self.filenames[size:]

    def __getitem__(self, index):
        path = self.filenames[index]
        return self.load(path), path

    def __len__(self):
        return len(self.filenames)
134 |
135 |
class LMDBDataset(Dataset):
    """Read-only dataset over an LMDB database.

    Records are looked up under zero-padded 8-digit ASCII keys
    (``'00000000'``, ``'00000001'``, ...). The reported length is the number of
    entries times ``repeat``; indices wrap around modulo the entry count.

    Args:
        db_path: path handed to ``lmdb.open``.
        repeat: multiplier applied to the reported length.
        backend: ``'caffe'`` parses each record as a caffe ``Datum`` holding
            float32 data; any other value unpickles the record.
    """

    def __init__(self, db_path, repeat=1, backend='pickle'):
        self.db_path = db_path
        self.env = lmdb.open(db_path, max_readers=1, readonly=True, lock=False,
                             readahead=False, meminit=False)
        with self.env.begin(write=False) as txn:
            self.length = txn.stat()['entries']
        self.repeat = repeat
        self.backend = backend

    def __getitem__(self, index):
        """Return the decoded record for ``index`` (wrapped modulo the entry count).

        Raises:
            KeyError: if the database has no record under the computed key.
        """
        index = index % self.length
        key = '{:08}'.format(index).encode('ascii')
        with self.env.begin(write=False) as txn:
            raw_datum = txn.get(key)

        if raw_datum is None:
            # Fail with the offending key instead of an opaque decoding error.
            raise KeyError('no record {!r} in {}'.format(key, self.db_path))

        if self.backend == 'caffe':
            import caffe
            datum = caffe.proto.caffe_pb2.Datum()
            datum.ParseFromString(raw_datum)

            # np.fromstring's binary mode is deprecated; frombuffer yields a
            # read-only view, so copy to keep the result writable as before.
            flat_x = np.frombuffer(datum.data, dtype=np.float32).copy()
            x = flat_x.reshape(datum.channels, datum.height, datum.width)
        else:
            # NOTE(review): unpickling executes code embedded in the record;
            # only open databases from a trusted source.
            x = pickle.loads(raw_datum)

        return x

    def __len__(self):
        return self.length * self.repeat

    def __repr__(self):
        return self.__class__.__name__ + ' (' + self.db_path + ')'
169 |
--------------------------------------------------------------------------------
/hsir/model/hsdt/__init__.py:
--------------------------------------------------------------------------------
1 | from .arch import HSDT
2 | from .attention import GSSA, SMFFN, TransformerBlock
3 | from .sepconv import S3Conv
4 |
def hsdt():
    """Return ``HSDT(1, 16, 5, [1, 3])`` with ``use_2dconv`` and ``bandwise`` set to False."""
    model = HSDT(1, 16, 5, [1, 3])
    model.use_2dconv = model.bandwise = False
    return model
10 |
11 |
def hsdt_4():
    """Return ``HSDT(1, 4, 5, [1, 3])`` with ``use_2dconv`` and ``bandwise`` set to False."""
    model = HSDT(1, 4, 5, [1, 3])
    model.use_2dconv = model.bandwise = False
    return model
17 |
18 |
def hsdt_8():
    """Return ``HSDT(1, 8, 5, [1, 3])`` with ``use_2dconv`` and ``bandwise`` set to False."""
    model = HSDT(1, 8, 5, [1, 3])
    model.use_2dconv = model.bandwise = False
    return model
24 |
25 |
def hsdt_24():
    """Return ``HSDT(1, 24, 5, [1, 3])`` with ``use_2dconv`` and ``bandwise`` set to False."""
    model = HSDT(1, 24, 5, [1, 3])
    model.use_2dconv = model.bandwise = False
    return model
31 |
32 |
def hsdt_32():
    """Return ``HSDT(1, 32, 5, [1, 3])`` with ``use_2dconv`` and ``bandwise`` set to False."""
    model = HSDT(1, 32, 5, [1, 3])
    model.use_2dconv = model.bandwise = False
    return model
38 |
39 |
def hsdt_deep():
    """Return ``HSDT(1, 16, 7, [1, 3, 5])`` with ``use_2dconv`` and ``bandwise`` set to False."""
    model = HSDT(1, 16, 7, [1, 3, 5])
    model.use_2dconv = model.bandwise = False
    return model
45 |
46 |
47 | """ Extension
48 | """
49 |
50 |
def hsdt_pnp():
    """Return ``HSDT(2, 16, 5, [1, 3])`` with ``use_2dconv`` and ``bandwise`` set to False."""
    model = HSDT(2, 16, 5, [1, 3])
    model.use_2dconv = model.bandwise = False
    return model
56 |
57 |
def hsdt_ssr():
    """Return ``HSDTSSR(1, 16, 5, [1, 3])`` with ``use_2dconv`` and ``bandwise`` set to False."""
    from .arch import HSDTSSR as ssr_cls

    model = ssr_cls(1, 16, 5, [1, 3])
    model.use_2dconv = model.bandwise = False
    return model
64 |
65 |
66 | """ ablations
67 | """
68 |
69 |
def hsdt_pixelwise():
    """Return ``HSDT(1, 16, 5, [1, 3])`` built after rebinding
    ``arch.TransformerBlock`` to ``PixelwiseTransformerBlock``.

    The rebinding is module-level and is not undone here, so it stays in
    effect after this call returns.
    """
    from . import arch as arch_module
    from .attention import PixelwiseTransformerBlock as block_cls

    arch_module.TransformerBlock = block_cls
    model = HSDT(1, 16, 5, [1, 3])
    model.use_2dconv = model.bandwise = False
    return model
79 |
80 | # ablation of ffn
81 |
82 |
def hsdt_ffn():
    """Return ``HSDT(1, 16, 5, [1, 3])`` built after rebinding
    ``arch.TransformerBlock`` to ``FFNTransformerBlock``.

    The rebinding is module-level and is not undone here, so it stays in
    effect after this call returns.
    """
    from . import arch as arch_module
    from .attention import FFNTransformerBlock as block_cls

    arch_module.TransformerBlock = block_cls
    model = HSDT(1, 16, 5, [1, 3])
    model.use_2dconv = model.bandwise = False
    return model
92 |
93 |
def hsdt_ffn_flex():
    """Return ``HSDT(1, 16, 5, [1, 3])`` built after rebinding
    ``arch.TransformerBlock`` to ``FFNTransformerBlock`` with ``flex=True``.

    The rebinding is module-level and is not undone here, so it stays in
    effect after this call returns.
    """
    from . import arch as arch_module
    from .attention import FFNTransformerBlock as block_cls
    from functools import partial

    arch_module.TransformerBlock = partial(block_cls, flex=True)
    model = HSDT(1, 16, 5, [1, 3])
    model.use_2dconv = model.bandwise = False
    return model
104 |
105 |
def hsdt_gdfn():
    """Return ``HSDT(1, 16, 5, [1, 3])`` built after rebinding
    ``arch.TransformerBlock`` to ``GDFNTransformerBlock``.

    The rebinding is module-level and is not undone here, so it stays in
    effect after this call returns.
    """
    from . import arch as arch_module
    from .attention import GDFNTransformerBlock as block_cls

    arch_module.TransformerBlock = block_cls
    model = HSDT(1, 16, 5, [1, 3])
    model.use_2dconv = model.bandwise = False
    return model
115 |
116 |
def hsdt_smffn1():
    """Return ``HSDT(1, 16, 5, [1, 3])`` built after rebinding
    ``arch.TransformerBlock`` to ``GFNTransformerBlock``.

    The rebinding is module-level and is not undone here, so it stays in
    effect after this call returns.
    """
    from . import arch as arch_module
    from .attention import GFNTransformerBlock as block_cls

    arch_module.TransformerBlock = block_cls
    model = HSDT(1, 16, 5, [1, 3])
    model.use_2dconv = model.bandwise = False
    return model
126 |
127 | # ablation of ssa
128 |
129 |
def hsdt_ssa():
    """Return ``HSDT(1, 16, 5, [1, 3])`` built after rebinding
    ``arch.TransformerBlock`` to ``SSATransformerBlock``.

    The rebinding is module-level and is not undone here, so it stays in
    effect after this call returns.
    """
    from . import arch as arch_module
    from .attention import SSATransformerBlock as block_cls

    arch_module.TransformerBlock = block_cls
    model = HSDT(1, 16, 5, [1, 3])
    model.use_2dconv = model.bandwise = False
    return model
139 |
140 | # ablation of s3conv
141 |
142 |
def hsdt_conv3d():
    """Return ``HSDT(1, 16, 5, [1, 3])`` built after rebinding ``arch.Conv3d``
    to ``torch.nn.Conv3d``.

    The rebinding is module-level and is not undone here, so it stays in
    effect after this call returns.
    """
    from . import arch as arch_module
    import torch.nn as torch_nn

    arch_module.Conv3d = torch_nn.Conv3d
    model = HSDT(1, 16, 5, [1, 3])
    model.use_2dconv = model.bandwise = False
    return model
152 |
153 |
def hsdt_s3conv_sep():
    """Return ``HSDT(1, 16, 5, [1, 3])`` built after rebinding ``arch.Conv3d``
    to ``SepConv_DP.of(nn.Conv3d)``.

    The rebinding is module-level and is not undone here, so it stays in
    effect after this call returns.
    """
    from . import arch as arch_module
    import torch.nn as torch_nn
    from .sepconv import SepConv_DP as conv_cls

    arch_module.Conv3d = conv_cls.of(torch_nn.Conv3d)
    model = HSDT(1, 16, 5, [1, 3])
    model.use_2dconv = model.bandwise = False
    return model
163 |
164 |
def hsdt_s3conv_seq():
    """Return ``HSDT(1, 16, 5, [1, 3])`` built after rebinding ``arch.Conv3d``
    to ``S3Conv_Seq.of(nn.Conv3d)``.

    The rebinding is module-level and is not undone here, so it stays in
    effect after this call returns.
    """
    from . import arch as arch_module
    import torch.nn as torch_nn
    from .sepconv import S3Conv_Seq as conv_cls

    arch_module.Conv3d = conv_cls.of(torch_nn.Conv3d)
    model = HSDT(1, 16, 5, [1, 3])
    model.use_2dconv = model.bandwise = False
    return model
174 |
175 |
def hsdt_s3conv1():
    """Return ``HSDT(1, 16, 5, [1, 3])`` built after rebinding ``arch.Conv3d``
    to ``S3Conv1.of(nn.Conv3d)``.

    The rebinding is module-level and is not undone here, so it stays in
    effect after this call returns.
    """
    from . import arch as arch_module
    import torch.nn as torch_nn
    from .sepconv import S3Conv1 as conv_cls

    arch_module.Conv3d = conv_cls.of(torch_nn.Conv3d)
    model = HSDT(1, 16, 5, [1, 3])
    model.use_2dconv = model.bandwise = False
    return model
185 |
186 |
187 | """ Break down
188 | """
189 |
190 |
def baseline_s3conv():
    """Return ``HSDT(1, 16, 5, [1, 3])`` built after rebinding
    ``arch.TransformerBlock`` to ``DummyTransformerBlock`` and setting
    ``arch.UseBN`` to False.

    Both module-level changes are not undone here, so they stay in effect
    after this call returns.
    """
    from . import arch as arch_module
    from .attention import DummyTransformerBlock as block_cls

    arch_module.TransformerBlock = block_cls
    arch_module.UseBN = False

    model = HSDT(1, 16, 5, [1, 3])
    model.use_2dconv = model.bandwise = False
    return model
201 |
202 |
def baseline_conv3d():
    """Return ``HSDT(1, 16, 5, [1, 3])`` built after rebinding ``arch.Conv3d``
    to ``torch.nn.Conv3d``, ``arch.TransformerBlock`` to
    ``DummyTransformerBlock``, and setting ``arch.UseBN`` to False.

    These module-level changes are not undone here, so they stay in effect
    after this call returns.
    """
    from . import arch as arch_module
    import torch.nn as torch_nn
    arch_module.Conv3d = torch_nn.Conv3d

    from .attention import DummyTransformerBlock as block_cls
    arch_module.TransformerBlock = block_cls
    arch_module.UseBN = False

    model = HSDT(1, 16, 5, [1, 3])
    model.use_2dconv = model.bandwise = False
    return model
215 |
216 |
def baseline_gssa():
    """Build the baseline with ``nn.Conv3d`` and ``FFNTransformerBlock``.

    Rebinds ``arch.Conv3d``, ``arch.TransformerBlock`` and sets
    ``arch.UseBN`` to ``True`` at module level (the overrides stay in place
    after this call), then constructs ``HSDT(1, 16, 5, [1, 3])`` with
    ``use_2dconv`` and ``bandwise`` disabled.
    """
    from . import arch as arch_module
    import torch.nn as torch_nn

    # Module-level overrides: must happen before the network is constructed.
    setattr(arch_module, 'Conv3d', torch_nn.Conv3d)
    from .attention import FFNTransformerBlock as block_cls
    setattr(arch_module, 'TransformerBlock', block_cls)
    setattr(arch_module, 'UseBN', True)

    model = HSDT(1, 16, 5, [1, 3])
    model.use_2dconv = False
    model.bandwise = False
    return model
229 |
230 |
def baseline_ssa():
    """Build the baseline with ``nn.Conv3d`` and ``SSAFFNTransformerBlock``.

    Rebinds ``arch.Conv3d``, ``arch.TransformerBlock`` and sets
    ``arch.UseBN`` to ``True`` at module level (the overrides stay in place
    after this call), then constructs ``HSDT(1, 16, 5, [1, 3])`` with
    ``use_2dconv`` and ``bandwise`` disabled.
    """
    from . import arch as arch_module
    import torch.nn as torch_nn

    # Module-level overrides: must happen before the network is constructed.
    setattr(arch_module, 'Conv3d', torch_nn.Conv3d)
    from .attention import SSAFFNTransformerBlock as block_cls
    setattr(arch_module, 'TransformerBlock', block_cls)
    setattr(arch_module, 'UseBN', True)

    model = HSDT(1, 16, 5, [1, 3])
    model.use_2dconv = False
    model.bandwise = False
    return model
243 |
--------------------------------------------------------------------------------
/hsir/model/qrnn3d/net.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
class QRNNREDC3D(nn.Module):
    """Residual encoder-decoder network assembled from injected layer factories.

    The network is ``feature_extractor -> encoder -> decoder -> reconstructor``
    with two additive skip connections: the extracted features are added back
    after the decoder, and the raw input is added to the reconstructor output.

    Args:
        in_channels: channel count of the input (and of the output).
        channels: channel count produced by the feature extractor.
        num_half_layer: forwarded to the encoder and decoder factories.
        sample_idx: ``None`` or a list; forwarded to encoder/decoder. The
            decoder is created with ``channels * 2 ** len(sample_idx)``.
        BiQRNNConv3D, BiQRNNDeConv3D: factories for the first and last layer.
        QRNN3DEncoder, QRNN3DDecoder: factories for the encoder and decoder.
        is_2d: when true, first/last layers receive ``k=(1,3,3), s=1, p=(0,1,1)``.
        has_ad: when true, the encoder returns ``(out, reverse)`` and the
            decoder is called with that ``reverse`` flag.
        bn, act, plain: forwarded to the factories. For the reconstructor an
            ``act`` of ``'relu'`` is replaced by ``'none'``.
    """

    def __init__(self, in_channels, channels, num_half_layer, sample_idx,
                 BiQRNNConv3D=None, BiQRNNDeConv3D=None,
                 QRNN3DEncoder=None, QRNN3DDecoder=None, is_2d=False, has_ad=True, bn=True, act='tanh', plain=False):
        super(QRNNREDC3D, self).__init__()
        assert sample_idx is None or isinstance(sample_idx, list)

        self.enable_ad = has_ad
        sample_idx = [] if sample_idx is None else sample_idx

        # Kernel/stride/padding overrides are only passed in the 2-D variant.
        planar = dict(k=(1, 3, 3), s=1, p=(0, 1, 1)) if is_2d else {}
        shared = dict(is_2d=is_2d, has_ad=has_ad, bn=bn, act=act, plain=plain)

        # Submodules are registered in the same order as before so that
        # parameter/state_dict ordering is unchanged.
        self.feature_extractor = BiQRNNConv3D(in_channels, channels, bn=bn, act=act, **planar)
        self.encoder = QRNN3DEncoder(channels, num_half_layer, sample_idx, **shared)
        self.decoder = QRNN3DDecoder(channels * (2 ** len(sample_idx)), num_half_layer, sample_idx, **shared)

        # The reconstructor never uses 'relu'; it is mapped to 'none'.
        final_act = 'none' if act == 'relu' else act
        self.reconstructor = BiQRNNDeConv3D(channels, in_channels, bias=True, bn=bn, act=final_act, **planar)

    def forward(self, x):
        """Run the network on ``x`` and return a tensor combined with ``x`` by addition."""
        features = self.feature_extractor(x)
        # Stack of skip tensors; encoder pushes onto it, decoder pops from it.
        skips = [x, features]
        if self.enable_ad:
            hidden, direction = self.encoder(features, skips, reverse=False)
            hidden = self.decoder(hidden, skips, reverse=(direction))
        else:
            hidden = self.encoder(features, skips)
            hidden = self.decoder(hidden, skips)
        hidden = hidden + skips.pop()
        hidden = self.reconstructor(hidden)
        return hidden + skips.pop()
43 |
44 |
class QRNN3DEncoder(nn.Module):
    """Stack of ``num_half_layer`` layers built from the ``QRNNConv3D`` factory.

    Layers whose index is in ``sample_idx`` double the channel count and
    receive explicit kernel/stride/padding arguments; the other layers keep
    the channel count.

    Args:
        channels: channel count entering the first layer.
        num_half_layer: number of layers to create.
        sample_idx: container of layer indices that double the channels.
        QRNNConv3D: layer factory called as ``QRNNConv3D(cin, cout, **kwargs)``.
        is_2d: use ``k=(1,3,3)``, ``p=(0,1,1)`` for every layer.
        has_ad: when true, ``forward`` alternates a ``reverse`` flag per layer.
        bn, act: forwarded to every layer.
        plain: in the non-2-D case, doubling layers use stride ``(1,1,1)``
            instead of ``(1,2,2)``.
    """

    def __init__(self, channels, num_half_layer, sample_idx, QRNNConv3D=None,
                 is_2d=False, has_ad=True, bn=True, act='tanh', plain=False):
        super(QRNN3DEncoder, self).__init__()
        # Encoder
        self.layers = nn.ModuleList()
        self.enable_ad = has_ad
        for depth in range(num_half_layer):
            widen = depth in sample_idx
            out_channels = 2 * channels if widen else channels
            if is_2d:
                geometry = dict(k=(1, 3, 3), s=(1, 2, 2) if widen else 1, p=(0, 1, 1))
            elif widen:
                geometry = dict(k=3, s=(1, 1, 1) if plain else (1, 2, 2), p=1)
            else:
                # Non-sampling 3-D layers rely on the factory defaults.
                geometry = {}
            self.layers.append(QRNNConv3D(channels, out_channels, bn=bn, act=act, **geometry))
            channels = out_channels

    def forward(self, x, xs, reverse=False):
        """Apply all layers, pushing every output except the last onto ``xs``.

        Returns ``x`` when ``has_ad`` was false, otherwise ``(x, reverse)``
        where ``reverse`` has been toggled once per layer.
        """
        inner = len(self.layers) - 1
        if self.enable_ad:
            for idx in range(inner):
                x = self.layers[idx](x, reverse=reverse)
                reverse = not reverse
                xs.append(x)
            x = self.layers[-1](x, reverse=reverse)
            return x, not reverse

        for idx in range(inner):
            x = self.layers[idx](x)
            xs.append(x)
        return self.layers[-1](x)
89 |
90 |
class QRNN3DDecoder(nn.Module):
    """Stack of ``num_half_layer`` layers mirroring the encoder, built in reverse index order.

    Layers whose index is in ``sample_idx`` halve the channel count; they are
    created with ``QRNNUpsampleConv3d`` unless ``plain`` is set in the
    non-2-D case, where ``QRNNDeConv3D`` is used instead.

    Args:
        channels: channel count entering the first decoder layer.
        num_half_layer: number of layers to create.
        sample_idx: container of layer indices that halve the channels.
        QRNNDeConv3D, QRNNUpsampleConv3d: layer factories.
        is_2d: pass ``k=(1,3,3), s=1, p=(0,1,1)`` to every layer.
        has_ad: when true, ``forward`` alternates a ``reverse`` flag per layer.
        bn, act: forwarded to every layer.
        plain: see above.
    """

    def __init__(self, channels, num_half_layer, sample_idx, QRNNDeConv3D=None, QRNNUpsampleConv3d=None,
                 is_2d=False, has_ad=True, bn=True, act='tanh', plain=False):
        super(QRNN3DDecoder, self).__init__()
        # Decoder
        self.layers = nn.ModuleList()
        self.enable_ad = has_ad
        planar = dict(k=(1, 3, 3), s=1, p=(0, 1, 1)) if is_2d else {}
        for depth in reversed(range(num_half_layer)):
            if depth in sample_idx:
                # Only the plain 3-D variant avoids the upsampling factory.
                factory = QRNNDeConv3D if (plain and not is_2d) else QRNNUpsampleConv3d
                out_channels = channels // 2
            else:
                factory = QRNNDeConv3D
                out_channels = channels
            self.layers.append(factory(channels, out_channels, bn=bn, act=act, **planar))
            channels = out_channels

    def forward(self, x, xs, reverse=False):
        """Apply all layers, adding a tensor popped from ``xs`` before every layer but the first."""
        depth = len(self.layers)
        if self.enable_ad:
            x = self.layers[0](x, reverse=reverse)
            for idx in range(1, depth):
                reverse = not reverse
                x = x + xs.pop()
                x = self.layers[idx](x, reverse=reverse)
            return x

        x = self.layers[0](x)
        for idx in range(1, depth):
            x = x + xs.pop()
            x = self.layers[idx](x)
        return x
--------------------------------------------------------------------------------
/benchmark/datasets/CAVE/cave_build_lmdb.py:
--------------------------------------------------------------------------------
1 | """Create lmdb dataset"""
2 | import os
3 | import numpy as np
4 | import lmdb
5 | import caffe
6 | from itertools import product
7 | from scipy.ndimage import zoom
8 | import random
9 | from scipy.io import loadmat
10 |
def Data2Volume(data, ksizes, strides):
    """
    Construct Volumes from Original High Dimensional (D) Data

    Extracts every patch of shape ``ksizes`` whose origin lies on the grid
    defined by ``strides`` and stacks the patches along a new leading axis,
    in C (row-major) order of their origins.

    Args:
        data: array with exactly ``len(ksizes)`` dimensions.
        ksizes: patch size per dimension (any sequence of ints).
        strides: step between patch origins per dimension; only the first
            ``len(ksizes)`` entries are used.

    Returns:
        ``float64`` array of shape ``(num_patches, *ksizes)``.

    Raises:
        ValueError: if the dimensionality does not match or a patch is larger
            than the data along some dimension.
    """
    data = np.asarray(data)
    ksizes = tuple(int(k) for k in ksizes)
    strides = tuple(int(s) for s in strides)[:len(ksizes)]

    if data.ndim != len(ksizes) or len(strides) != len(ksizes):
        raise ValueError(
            'data has %d dims but got %d ksizes and %d strides'
            % (data.ndim, len(ksizes), len(strides)))
    if any(dim < k for dim, k in zip(data.shape, ksizes)):
        raise ValueError(
            'patch size %r does not fit into data of shape %r'
            % (ksizes, data.shape))

    # View of all dense windows: shape (*num_positions, *ksizes), no copy yet.
    windows = np.lib.stride_tricks.sliding_window_view(data, ksizes)
    # Keep only the windows whose origin lies on the stride grid.
    windows = windows[tuple(slice(None, None, s) for s in strides)]

    # Single materialising copy; float64 matches the previous np.zeros buffer.
    return np.array(windows, dtype=np.float64).reshape((-1,) + ksizes)
31 |
32 |
def minmax_normalize(array):
    """Linearly rescale ``array`` so its minimum maps to 0 and its maximum to 1.

    A constant array (maximum equal to minimum) is mapped to all zeros
    instead of producing NaN through a 0/0 division.
    """
    amin = np.min(array)
    span = np.max(array) - amin
    shifted = array - amin
    # Guard against 0/0: every entry of ``shifted`` is already 0 in that case.
    return shifted / (span if span != 0 else 1)
37 |
38 |
def crop_center(img, cropx, cropy):
    """Return the central ``cropy`` x ``cropx`` window of a 3-D array.

    The first axis is kept whole; ``cropy`` applies to the second axis and
    ``cropx`` to the third.
    """
    height, width = img.shape[1:]
    top = height // 2 - (cropy // 2)
    left = width // 2 - (cropx // 2)
    rows = slice(top, top + cropy)
    cols = slice(left, left + cropx)
    return img[:, rows, cols]
44 |
45 |
def data_augmentation(image, mode=None):
    """Apply one of eight rotate/flip variants, then a random reversal of axis 0.

    Args:
        image: np.ndarray, shape: C X H X W
        mode: value in 0..7 selecting the spatial transform; ``None`` draws
            one with ``random.randint(0, 7)``. Even modes rotate by
            ``mode // 2`` quarter turns (``np.rot90`` on the last two axes),
            odd modes additionally flip the second axis. Any other value
            leaves the spatial layout unchanged.

    Returns:
        A C-contiguous array with the transformed content.
    """
    if mode is None:
        mode = random.randint(0, 7)

    if mode in range(8):
        quarter_turns, flip = divmod(int(mode), 2)
        if quarter_turns:
            image = np.rot90(image, k=quarter_turns, axes=(-2, -1))
        if flip:
            # flip up and down
            image = image[:, ::-1, :]

    # we apply spectrum reversal for training 3D CNN, e.g. QRNN3D.
    # disable it when training 2D CNN, e.g. MemNet
    reverse_spectrum = random.random() < 0.5
    if reverse_spectrum:
        image = image[::-1, :, :]

    return np.ascontiguousarray(image)
90 |
91 |
def create_lmdb_train(
        datadir, fns, name, matkey,
        crop_sizes, scales, ksizes, strides,
        load=loadmat, augment=True,
        seed=2022):
    """
    Create Augmented Dataset

    Each file in ``fns`` is loaded, moved to channel-first layout, min-max
    normalised, optionally centre cropped, rescaled by every entry of
    ``scales`` and cut into patches. Every patch is stored in ``<name>.db``
    as a serialized caffe ``Datum`` under a zero-padded running index.

    Args:
        datadir: directory prefix concatenated with each file name.
        fns: file names to process; files that fail to load are skipped.
        name: database path without the ``.db`` suffix.
        matkey: key of the data array inside each loaded file.
        crop_sizes: ``(cropx, cropy)`` for ``crop_center`` or ``None``.
        scales: spatial zoom factors, one per entry of ``strides``.
        ksizes: patch size passed to ``Data2Volume``.
        strides: one stride triple per scale.
        load: callable returning a mapping for a file path.
        augment: apply ``data_augmentation`` to every patch.
        seed: seed for both ``numpy.random`` and ``random``.

    Raises:
        Exception: if ``<name>.db`` already exists.
    """
    def preprocess(data):
        data = np.float32(data)
        # data = minmax_normalize(data)
        # data = np.rot90(data, k=2, axes=(1,2)) # ICVL
        data = minmax_normalize(data.transpose((2, 0, 1)))  # for Remote Sensing
        # Visualize3D(data)
        if crop_sizes is not None:
            data = crop_center(data, crop_sizes[0], crop_sizes[1])

        volumes = []
        for scale, stride in zip(scales, strides):
            scaled = data if scale == 1 else zoom(data, zoom=(1, scale, scale))
            volumes.append(Data2Volume(scaled, ksizes=ksizes, strides=list(stride)))
        new_data = np.concatenate(volumes, axis=0)
        if augment:
            for i in range(new_data.shape[0]):
                new_data[i, ...] = data_augmentation(new_data[i, ...])

        return new_data.astype(np.float32)

    np.random.seed(seed)
    # data_augmentation draws from the stdlib ``random`` module, so it has to
    # be seeded as well for the database to be reproducible.
    random.seed(seed)
    scales = list(scales)
    ksizes = list(ksizes)
    assert len(scales) == len(strides)

    # Estimate the database size from the first file.
    data = load(datadir + fns[0])[matkey]
    data = preprocess(data)

    print(data.shape)
    # lmdb expects an integer number of bytes.
    map_size = int(data.nbytes * len(fns) * 1.2)
    print('map size (GB):', map_size / 1024 / 1024 / 1024)

    # import ipdb; ipdb.set_trace()
    if os.path.exists(name + '.db'):
        raise Exception('database already exist!')
    env = lmdb.open(name + '.db', map_size=map_size, writemap=True)
    try:
        with env.begin(write=True) as txn:
            # txn is a Transaction object
            k = 0
            for i, fn in enumerate(fns):
                try:
                    X = load(datadir + fn)[matkey]
                except Exception as err:
                    # Best effort: report the file and keep going.
                    print('loading', datadir + fn, 'fail', repr(err))
                    continue
                X = preprocess(X)
                for j in range(X.shape[0]):
                    datum = caffe.proto.caffe_pb2.Datum()
                    datum.channels = X.shape[1]
                    datum.height = X.shape[2]
                    datum.width = X.shape[3]
                    datum.data = X[j].tobytes()
                    str_id = '{:08}'.format(k)
                    k += 1
                    txn.put(str_id.encode('ascii'), datum.SerializeToString())
                print('load mat (%d/%d): %s' % (i, len(fns), fn))
    finally:
        # Release the memory map even when writing fails half-way.
        env.close()

    print('done')
164 |
def create_cave():
    """Build the ``CAVE64_31`` training database from the files listed in ``cave_train.txt``.

    The list file is read from the current working directory, one file name
    per line.
    """
    print('create cave...')
    datadir = '/home/wzliu/projects/data/cave_mat/'
    with open('cave_train.txt') as list_file:
        fns = [entry.strip() for entry in list_file.readlines()]

    options = dict(
        crop_sizes=(512, 512),
        scales=(1, 0.5, 0.25),
        ksizes=(31, 64, 64),
        strides=[(31, 64, 64), (31, 32, 32), (31, 32, 32)],
        load=loadmat,
        augment=True,
    )
    create_lmdb_train(datadir, fns, 'CAVE64_31', 'gt', **options)
179 |
180 |
# Script entry point: build the CAVE training database when run directly.
if __name__ == '__main__':
    create_cave()
183 |
--------------------------------------------------------------------------------
/benchmark/datasets/Harvard/harvard_build_lmdb.py:
--------------------------------------------------------------------------------
1 | """Create lmdb dataset"""
2 | import os
3 | import numpy as np
4 | import lmdb
5 | import caffe
6 | from itertools import product
7 | from scipy.ndimage import zoom
8 | import random
9 | from scipy.io import loadmat
10 |
def Data2Volume(data, ksizes, strides):
    """
    Construct Volumes from Original High Dimensional (D) Data

    Extracts every patch of shape ``ksizes`` whose origin lies on the grid
    defined by ``strides`` and stacks the patches along a new leading axis,
    in C (row-major) order of their origins.

    Args:
        data: array with exactly ``len(ksizes)`` dimensions.
        ksizes: patch size per dimension (any sequence of ints).
        strides: step between patch origins per dimension; only the first
            ``len(ksizes)`` entries are used.

    Returns:
        ``float64`` array of shape ``(num_patches, *ksizes)``.

    Raises:
        ValueError: if the dimensionality does not match or a patch is larger
            than the data along some dimension.
    """
    data = np.asarray(data)
    ksizes = tuple(int(k) for k in ksizes)
    strides = tuple(int(s) for s in strides)[:len(ksizes)]

    if data.ndim != len(ksizes) or len(strides) != len(ksizes):
        raise ValueError(
            'data has %d dims but got %d ksizes and %d strides'
            % (data.ndim, len(ksizes), len(strides)))
    if any(dim < k for dim, k in zip(data.shape, ksizes)):
        raise ValueError(
            'patch size %r does not fit into data of shape %r'
            % (ksizes, data.shape))

    # View of all dense windows: shape (*num_positions, *ksizes), no copy yet.
    windows = np.lib.stride_tricks.sliding_window_view(data, ksizes)
    # Keep only the windows whose origin lies on the stride grid.
    windows = windows[tuple(slice(None, None, s) for s in strides)]

    # Single materialising copy; float64 matches the previous np.zeros buffer.
    return np.array(windows, dtype=np.float64).reshape((-1,) + ksizes)
31 |
32 |
def minmax_normalize(array):
    """Linearly rescale ``array`` so its minimum maps to 0 and its maximum to 1.

    A constant array (maximum equal to minimum) is mapped to all zeros
    instead of producing NaN through a 0/0 division.
    """
    amin = np.min(array)
    span = np.max(array) - amin
    shifted = array - amin
    # Guard against 0/0: every entry of ``shifted`` is already 0 in that case.
    return shifted / (span if span != 0 else 1)
37 |
38 |
def crop_center(img, cropx, cropy):
    """Return the central ``cropy`` x ``cropx`` window of a 3-D array.

    The first axis is kept whole; ``cropy`` applies to the second axis and
    ``cropx`` to the third.
    """
    height, width = img.shape[1:]
    top = height // 2 - (cropy // 2)
    left = width // 2 - (cropx // 2)
    rows = slice(top, top + cropy)
    cols = slice(left, left + cropx)
    return img[:, rows, cols]
44 |
45 |
def data_augmentation(image, mode=None):
    """Apply one of eight rotate/flip variants, then a random reversal of axis 0.

    Args:
        image: np.ndarray, shape: C X H X W
        mode: value in 0..7 selecting the spatial transform; ``None`` draws
            one with ``random.randint(0, 7)``. Even modes rotate by
            ``mode // 2`` quarter turns (``np.rot90`` on the last two axes),
            odd modes additionally flip the second axis. Any other value
            leaves the spatial layout unchanged.

    Returns:
        A C-contiguous array with the transformed content.
    """
    if mode is None:
        mode = random.randint(0, 7)

    if mode in range(8):
        quarter_turns, flip = divmod(int(mode), 2)
        if quarter_turns:
            image = np.rot90(image, k=quarter_turns, axes=(-2, -1))
        if flip:
            # flip up and down
            image = image[:, ::-1, :]

    # we apply spectrum reversal for training 3D CNN, e.g. QRNN3D.
    # disable it when training 2D CNN, e.g. MemNet
    reverse_spectrum = random.random() < 0.5
    if reverse_spectrum:
        image = image[::-1, :, :]

    return np.ascontiguousarray(image)
90 |
91 |
def create_lmdb_train(
        datadir, fns, name, matkey,
        crop_sizes, scales, ksizes, strides,
        load=loadmat, augment=True,
        seed=2022):
    """
    Create Augmented Dataset

    Each file in ``fns`` is loaded, moved to channel-first layout, min-max
    normalised, optionally centre cropped, rescaled by every entry of
    ``scales`` and cut into patches. Every patch is stored in ``<name>.db``
    as a serialized caffe ``Datum`` under a zero-padded running index.

    Args:
        datadir: directory prefix concatenated with each file name.
        fns: file names to process; files that fail to load are skipped.
        name: database path without the ``.db`` suffix.
        matkey: key of the data array inside each loaded file.
        crop_sizes: ``(cropx, cropy)`` for ``crop_center`` or ``None``.
        scales: spatial zoom factors, one per entry of ``strides``.
        ksizes: patch size passed to ``Data2Volume``.
        strides: one stride triple per scale.
        load: callable returning a mapping for a file path.
        augment: apply ``data_augmentation`` to every patch.
        seed: seed for both ``numpy.random`` and ``random``.

    Raises:
        Exception: if ``<name>.db`` already exists.
    """
    def preprocess(data):
        data = np.float32(data)
        # data = minmax_normalize(data)
        # data = np.rot90(data, k=2, axes=(1,2)) # ICVL
        data = minmax_normalize(data.transpose((2, 0, 1)))  # for Remote Sensing
        # Visualize3D(data)
        if crop_sizes is not None:
            data = crop_center(data, crop_sizes[0], crop_sizes[1])

        volumes = []
        for scale, stride in zip(scales, strides):
            scaled = data if scale == 1 else zoom(data, zoom=(1, scale, scale))
            volumes.append(Data2Volume(scaled, ksizes=ksizes, strides=list(stride)))
        new_data = np.concatenate(volumes, axis=0)
        if augment:
            for i in range(new_data.shape[0]):
                new_data[i, ...] = data_augmentation(new_data[i, ...])

        return new_data.astype(np.float32)

    np.random.seed(seed)
    # data_augmentation draws from the stdlib ``random`` module, so it has to
    # be seeded as well for the database to be reproducible.
    random.seed(seed)
    scales = list(scales)
    ksizes = list(ksizes)
    assert len(scales) == len(strides)

    # Estimate the database size from the first file.
    data = load(datadir + fns[0])[matkey]
    data = preprocess(data)

    print(data.shape)
    # lmdb expects an integer number of bytes.
    map_size = int(data.nbytes * len(fns) * 1.2)
    print('map size (GB):', map_size / 1024 / 1024 / 1024)

    # import ipdb; ipdb.set_trace()
    if os.path.exists(name + '.db'):
        raise Exception('database already exist!')
    env = lmdb.open(name + '.db', map_size=map_size, writemap=True)
    try:
        with env.begin(write=True) as txn:
            # txn is a Transaction object
            k = 0
            for i, fn in enumerate(fns):
                try:
                    X = load(datadir + fn)[matkey]
                except Exception as err:
                    # Best effort: report the file and keep going.
                    print('loading', datadir + fn, 'fail', repr(err))
                    continue
                X = preprocess(X)
                for j in range(X.shape[0]):
                    datum = caffe.proto.caffe_pb2.Datum()
                    datum.channels = X.shape[1]
                    datum.height = X.shape[2]
                    datum.width = X.shape[3]
                    datum.data = X[j].tobytes()
                    str_id = '{:08}'.format(k)
                    k += 1
                    txn.put(str_id.encode('ascii'), datum.SerializeToString())
                print('load mat (%d/%d): %s' % (i, len(fns), fn))
    finally:
        # Release the memory map even when writing fails half-way.
        env.close()

    print('done')
164 |
def create_cave():
    """Build the ``Harvard64_31`` training database from ``harvard_train.txt``.

    Despite its name, this function processes the Harvard data: it reads the
    ``'ref'`` key of every listed file. The list file is read from the current
    working directory, one file name per line.
    """
    print('create harvard...')
    datadir = '/home/wzliu/projects/data/harvard/CZ_hsdb/'
    with open('harvard_train.txt') as list_file:
        fns = [entry.strip() for entry in list_file.readlines()]

    options = dict(
        crop_sizes=(1024, 1024),
        scales=(1, 0.5, 0.25),
        ksizes=(31, 64, 64),
        strides=[(31, 64, 64), (31, 32, 32), (31, 32, 32)],
        load=loadmat,
        augment=True,
    )
    create_lmdb_train(datadir, fns, 'Harvard64_31', 'ref', **options)
179 |
180 |
# Script entry point: build the Harvard training database when run directly
# (the builder function in this file is named ``create_cave``).
if __name__ == '__main__':
    create_cave()
183 |
--------------------------------------------------------------------------------
/benchmark/datasets/RealHSI/real_build_lmdb.py:
--------------------------------------------------------------------------------
1 | """Create lmdb dataset"""
2 | import os
3 | import numpy as np
4 | import lmdb
5 | from itertools import product
6 | from scipy.ndimage import zoom
7 | import random
8 | from scipy.io import loadmat
9 | import pickle
10 |
def Data2Volume(data, ksizes, strides):
    """
    Construct Volumes from Original High Dimensional (D) Data

    Extracts every patch of shape ``ksizes`` whose origin lies on the grid
    defined by ``strides`` and stacks the patches along a new leading axis,
    in C (row-major) order of their origins.

    Args:
        data: array with exactly ``len(ksizes)`` dimensions.
        ksizes: patch size per dimension (any sequence of ints).
        strides: step between patch origins per dimension; only the first
            ``len(ksizes)`` entries are used.

    Returns:
        ``float64`` array of shape ``(num_patches, *ksizes)``.

    Raises:
        ValueError: if the dimensionality does not match or a patch is larger
            than the data along some dimension.
    """
    data = np.asarray(data)
    ksizes = tuple(int(k) for k in ksizes)
    strides = tuple(int(s) for s in strides)[:len(ksizes)]

    if data.ndim != len(ksizes) or len(strides) != len(ksizes):
        raise ValueError(
            'data has %d dims but got %d ksizes and %d strides'
            % (data.ndim, len(ksizes), len(strides)))
    if any(dim < k for dim, k in zip(data.shape, ksizes)):
        raise ValueError(
            'patch size %r does not fit into data of shape %r'
            % (ksizes, data.shape))

    # View of all dense windows: shape (*num_positions, *ksizes), no copy yet.
    windows = np.lib.stride_tricks.sliding_window_view(data, ksizes)
    # Keep only the windows whose origin lies on the stride grid.
    windows = windows[tuple(slice(None, None, s) for s in strides)]

    # Single materialising copy; float64 matches the previous np.zeros buffer.
    return np.array(windows, dtype=np.float64).reshape((-1,) + ksizes)
31 |
32 |
def minmax_normalize(array):
    """Linearly rescale ``array`` so its minimum maps to 0 and its maximum to 1.

    A constant array (maximum equal to minimum) is mapped to all zeros
    instead of producing NaN through a 0/0 division.
    """
    amin = np.min(array)
    span = np.max(array) - amin
    shifted = array - amin
    # Guard against 0/0: every entry of ``shifted`` is already 0 in that case.
    return shifted / (span if span != 0 else 1)
37 |
38 |
def crop_center(img, cropx, cropy):
    """Return the central ``cropy`` x ``cropx`` window of a 3-D array.

    The first axis is kept whole; ``cropy`` applies to the second axis and
    ``cropx`` to the third.
    """
    height, width = img.shape[1:]
    top = height // 2 - (cropy // 2)
    left = width // 2 - (cropx // 2)
    rows = slice(top, top + cropy)
    cols = slice(left, left + cropx)
    return img[:, rows, cols]
44 |
45 |
def data_augmentation(image, mode=None):
    """Apply one of eight rotate/flip variants, then a random reversal of axis 0.

    Args:
        image: np.ndarray, shape: C X H X W
        mode: value in 0..7 selecting the spatial transform; ``None`` draws
            one with ``random.randint(0, 7)``. Even modes rotate by
            ``mode // 2`` quarter turns (``np.rot90`` on the last two axes),
            odd modes additionally flip the second axis. Any other value
            leaves the spatial layout unchanged.

    Returns:
        A C-contiguous array with the transformed content.
    """
    if mode is None:
        mode = random.randint(0, 7)

    if mode in range(8):
        quarter_turns, flip = divmod(int(mode), 2)
        if quarter_turns:
            image = np.rot90(image, k=quarter_turns, axes=(-2, -1))
        if flip:
            # flip up and down
            image = image[:, ::-1, :]

    # we apply spectrum reversal for training 3D CNN, e.g. QRNN3D.
    # disable it when training 2D CNN, e.g. MemNet
    reverse_spectrum = random.random() < 0.5
    if reverse_spectrum:
        image = image[::-1, :, :]

    return np.ascontiguousarray(image)
90 |
91 |
def create_lmdb_train(
        datadir, fns, name, matkey,
        crop_sizes, scales, ksizes, strides,
        load=loadmat, augment=True,
        seed=2022):
    """
    Create Augmented Dataset

    Each file in ``fns`` must provide a ``'gt'`` and an ``'input'`` array.
    Both are moved to channel-first layout, divided by 4096, optionally
    centre cropped, rescaled by every entry of ``scales`` and cut into
    patches. Each patch pair is stored in ``<name>.db`` as a pickled
    ``{'gt': ..., 'input': ...}`` dict under a zero-padded running index.

    Args:
        datadir: directory joined with each file name.
        fns: file names to process; files that fail to load are skipped.
        name: database path without the ``.db`` suffix.
        matkey: accepted for signature compatibility; not used here because
            the keys ``'gt'`` and ``'input'`` are read directly.
        crop_sizes: ``(cropx, cropy)`` for ``crop_center`` or ``None``.
        scales: spatial zoom factors, one per entry of ``strides``.
        ksizes: patch size passed to ``Data2Volume``.
        strides: one stride triple per scale.
        load: callable returning a mapping for a file path.
        augment: apply the same random augmentation to both patches of a pair.
        seed: seed for both ``numpy.random`` and ``random``.

    Raises:
        Exception: if ``<name>.db`` already exists.
    """
    def preprocess(data):
        data = np.float32(data)
        # data = minmax_normalize(data.transpose((2, 0, 1)))
        data = data.transpose(2, 0, 1) / 4096
        if crop_sizes is not None:
            data = crop_center(data, crop_sizes[0], crop_sizes[1])

        volumes = []
        for scale, stride in zip(scales, strides):
            scaled = data if scale == 1 else zoom(data, zoom=(1, scale, scale))
            volumes.append(Data2Volume(scaled, ksizes=ksizes, strides=list(stride)))
        return np.concatenate(volumes, axis=0).astype(np.float32)

    np.random.seed(seed)
    # The augmentation draws from the stdlib ``random`` module, so it has to
    # be seeded as well for the database to be reproducible.
    random.seed(seed)
    scales = list(scales)
    ksizes = list(ksizes)
    assert len(scales) == len(strides)

    # Estimate the database size from the ground truth of the first file.
    data = preprocess(load(os.path.join(datadir, fns[0]))['gt'])

    print(data.shape)
    # lmdb expects an integer number of bytes (2.4 = two arrays * 1.2 margin).
    map_size = int(data.nbytes * len(fns) * 2.4)
    print('map size (GB):', map_size / 1024 / 1024 / 1024)

    if os.path.exists(name + '.db'):
        raise Exception('database already exist!')
    env = lmdb.open(name + '.db', map_size=map_size, writemap=True)
    try:
        with env.begin(write=True) as txn:
            # txn is a Transaction object
            k = 0
            for i, fn in enumerate(fns):
                path = os.path.join(datadir, fn)
                try:
                    data = load(path)
                    gt = data['gt']
                    noisy = data['input']
                except Exception as err:
                    # Best effort: report the file and keep going.
                    print('loading', path, 'fail', repr(err))
                    continue
                gt = preprocess(gt)
                noisy = preprocess(noisy)
                if augment:
                    for t in range(gt.shape[0]):
                        mode = random.randint(0, 7)
                        # data_augmentation additionally draws a random
                        # spectral reversal. Replay the RNG state so that both
                        # patches of the pair receive the identical transform.
                        rng_state = random.getstate()
                        gt[t, ...] = data_augmentation(gt[t, ...], mode)
                        random.setstate(rng_state)
                        noisy[t, ...] = data_augmentation(noisy[t, ...], mode)

                for j in range(gt.shape[0]):
                    str_id = '{:08}'.format(k)
                    k += 1
                    record = pickle.dumps({'gt': gt[j], 'input': noisy[j]})
                    txn.put(str_id.encode('ascii'), record)
                print('load mat (%d/%d): %s' % (i, len(fns), fn))
    finally:
        # Release the memory map even when writing fails half-way.
        env.close()

    print('done')
166 |
def create_real():
    """Build the ``Real64_31_2`` training database from the files listed in ``real_train.txt``.

    The list file is read from the current working directory, one file name
    per line. No centre crop is applied.
    """
    print('create real...')
    datadir = '/media/exthdd/datasets/hsi/real_hsi/real_dataset/mat/'
    with open('real_train.txt') as list_file:
        fns = [entry.strip() for entry in list_file.readlines()]

    options = dict(
        crop_sizes=None,
        scales=(1, 0.5, 0.25),
        ksizes=(31, 64, 64),
        strides=[(31, 64, 64), (31, 32, 32), (31, 32, 32)],
        load=loadmat,
        augment=True,
    )
    create_lmdb_train(datadir, fns, 'Real64_31_2', 'gt', **options)
181 |
182 |
# Script entry point: build the real-data training database when run directly.
if __name__ == '__main__':
    create_real()
185 |
--------------------------------------------------------------------------------
/hsirun/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from collections import OrderedDict
3 | import os
4 | from os.path import join, exists
5 |
6 | import torch
7 | import torch.utils.data
8 | import imageio
9 | import numpy as np
10 | from tqdm import tqdm
11 | from tabulate import tabulate
12 |
13 | import hsir.model
14 | import hsir.data.utils
15 | from hsir.data import HSITestDataset
16 |
17 | import torchlight as tl
18 |
19 | tl.metrics.set_data_format('chw')
20 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21 |
22 |
def bchw2hwc(x):
    """Convert a single-image tensor to an ``uint8`` channel-last numpy array.

    The leading axis is squeezed, the remaining axes are permuted with
    ``(1, 2, 0)``, values are multiplied by 255 and cast to ``uint8``
    (fractions are truncated by the cast).
    """
    channel_last = x.cpu().squeeze(0).permute(1, 2, 0)
    scaled = channel_last.numpy() * 255
    return np.uint8(scaled)
25 |
26 |
def eval(net, loader, name, logdir, clamp, bandwise):
    """Evaluate ``net`` on every sample of ``loader`` and write results to ``logdir``.

    For each sample the output is clamped to [0, 1], saved as a colour and a
    gray PNG under ``logdir/color`` and ``logdir/gray``, and scored with
    MPSNR, MSSIM and SAM. Averages and per-file numbers are written to
    ``logdir/log.json``.

    Args:
        net: a callable network, or the string ``'gt'`` / any other string to
            score the targets / the raw inputs instead of a prediction.
        loader: iterable of dicts with ``'filename'``, ``'input'``, ``'target'``.
        name: label shown on the progress bar.
        logdir: output directory (created if missing).
        clamp: clamp the inputs to [0, 1] before running the network.
        bandwise: accepted for interface compatibility; not used here.
    """
    os.makedirs(join(logdir, 'color'), exist_ok=True)
    os.makedirs(join(logdir, 'gray'), exist_ok=True)

    tracker = tl.trainer.util.MetricTracker()
    detail_stat = {}

    # The context manager closes the progress bar even if a sample fails.
    with torch.no_grad(), tqdm(total=len(loader), dynamic_ncols=True) as pbar:
        pbar.set_description(name)
        for data in loader:
            filename = data['filename'][0]
            basename = tl.utils.filename(filename)
            inputs, targets = data['input'].to(device), data['target'].to(device)

            if clamp:
                inputs = torch.clamp(inputs, 0., 1.)
            tl.utils.timer.tic()
            if isinstance(net, str):
                outputs = targets if net == 'gt' else inputs
            else:
                outputs = net(inputs)

            # Wait for pending GPU work before stopping the timer. On a
            # CPU-only machine synchronize() would raise, so skip it there.
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            run_time = float(tl.utils.timer.toc()) / 1000

            outputs = outputs.clamp(0, 1)
            inputs = inputs.squeeze(1)
            outputs = outputs.squeeze(1)
            targets = targets.squeeze(1)

            imageio.imwrite(
                join(logdir, 'color', basename + '.png'),
                bchw2hwc(hsir.data.utils.visualize_color(outputs))
            )
            imageio.imwrite(
                join(logdir, 'gray', basename + '.png'),
                bchw2hwc(hsir.data.utils.visualize_gray(outputs))
            )

            psnr = tl.metrics.mpsnr(targets, outputs).item()
            ssim = tl.metrics.mssim(targets, outputs).item()
            sam = tl.metrics.sam(targets, outputs).item()

            tracker.update('MPSNR', psnr)
            tracker.update('MSSIM', ssim)
            tracker.update('SAM', sam)
            tracker.update('Time', run_time)

            pbar.set_postfix({k: f'{v:.4f}' for k, v in tracker.result().items()})
            pbar.update()

            detail_stat[basename] = {'MPSNR': psnr, 'MSSIM': ssim, 'SAM': sam, 'Time': run_time}

    avg_speed = tl.utils.format_time(tracker['Time'])
    print(f'Average speed {avg_speed}')
    print(f'Average results {tracker.summary()}')

    # log structural results
    avg_stat = dict(tracker.result())
    tl.utils.io.jsonwrite(join(logdir, 'log.json'), {'avg': avg_stat, 'detail': detail_stat})
90 |
91 |
def pretty_summary(logdir):
    """Print a GitHub-style table of the averaged results found under ``logdir``.

    Every sub-directory containing a ``log.json`` contributes one row made of
    its folder name and the file's ``'avg'`` entries.
    """
    rows = []
    print('')
    for folder in os.listdir(logdir):
        if not os.path.isdir(join(logdir, folder)):
            continue
        log_path = join(logdir, folder, 'log.json')
        if not exists(log_path):
            continue
        record = tl.utils.io.jsonload(log_path)
        row = OrderedDict([('Name', folder)])
        row.update(record['avg'])
        rows.append(row)
    print(tabulate(rows, headers='keys', tablefmt='github'))
    print('')
106 |
107 |
def main(args, logdir):
    """Instantiate (and optionally restore) the network, then evaluate every test set.

    Args:
        args: parsed command-line namespace; reads ``arch``, ``resume``,
            ``key_path``, ``testset``, ``basedir``, ``use_conv2d``, ``clamp``
            and ``bandwise``.
        logdir: directory that receives one sub-directory per test set.

    Raises:
        KeyError: if ``args.key_path`` does not lead to a state dict inside
            the checkpoint.
    """
    if args.arch == 'gt' or args.arch == 'input':
        # Pseudo-architectures: score the targets / the raw inputs.
        net = args.arch
    else:
        net = tl.utils.instantiate(args.arch)
        net = net.to(device)
        if args.resume:
            # map_location lets checkpoints saved on GPU load on CPU-only hosts.
            ckpt = tl.utils.dict_get(torch.load(args.resume, map_location=device), args.key_path)
            if ckpt is None:
                # Fail with the actual cause instead of an obscure error from
                # load_state_dict(None).
                raise KeyError(f'key_path {args.key_path} might be wrong')
            net.load_state_dict(ckpt)
        net.eval()

    for testset in args.testset:
        print('Evaluating {}'.format(args.arch))
        print('On {}'.format(testset))
        print('With {}'.format(args.resume))
        testdir = join(args.basedir, testset)
        dataset = HSITestDataset(testdir, use_chw=args.use_conv2d, return_name=True)
        loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
        eval(net, loader, testset, join(logdir, testset), args.clamp, args.bandwise)

    pretty_summary(logdir)
130 |
131 |
# Command-line entry point: parse options, guard against accidentally
# overwriting earlier results, record the invocation, then run the evaluation.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='HSIR Test script')
    parser.add_argument('-a', '--arch', required=True, help='architecture name')
    parser.add_argument('-n', '--name', default=None, help='save name')
    parser.add_argument('-r', '--resume', default=None, help='checkpoint')
    parser.add_argument('-t', '--testset', nargs='+', default=['icvl_512_50'], help='testset')
    parser.add_argument('-d', '--basedir', default='data', help='basedir')
    parser.add_argument('--logdir', default='results', help='logdir')
    # NOTE(review): --save_img is parsed but not read anywhere in the visible
    # part of this script; images are always written by eval().
    parser.add_argument('--save_img', action='store_true', help='whether to save image')
    parser.add_argument('--clamp', action='store_true', help='whether clamp input into [0, 1]')
    parser.add_argument('-kp', '--key_path', default='net', help='key path to access network state_dict in ckpt')
    parser.add_argument('--bandwise', action='store_true')
    parser.add_argument('--use-conv2d', action='store_true')
    parser.add_argument('--force-run', action='store_true')
    args = parser.parse_args()

    save_name = args.arch if args.name is None else args.name
    logdir = join(args.logdir, save_name)
    if exists(logdir) and not args.force_run:
        print(f'It seems that you have evaluated {args.arch} before.')
        pretty_summary(logdir)
        action = input('Are you sure you want to continue? (y) continue (n) exit\n')
        if action != 'y':
            # ``exit()`` is injected by the ``site`` module for interactive
            # sessions and is not guaranteed to exist; raise SystemExit instead.
            raise SystemExit

    os.makedirs(logdir, exist_ok=True)
    with open(join(logdir, 'meta.txt'), 'w') as f:
        f.write(tl.utils.get_datetime() + '\n')
        f.write(tl.utils.get_cmd() + '\n')
        f.write(str(args) + '\n')

    main(args, logdir)
163 |
--------------------------------------------------------------------------------
/hsir/model/t3sc/layers/lowrank_sc_layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | import math
5 | import logging
6 |
7 | from .encoding_layer import EncodingLayer
8 | from .soft_thresholding import SoftThresholding
9 |
10 | logger = logging.getLogger(__name__)
11 | logger.setLevel(logging.DEBUG)
12 |
13 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14 |
15 |
16 | class LowRankSCLayer(EncodingLayer):
def __init__(
    self,
    patch_side,
    stride,
    K,
    rank,
    patch_centering,
    lbda_init,
    lbda_mode,
    beta=0,
    ssl=0,
    **kwargs,
):
    """Set up the factorised dictionaries, thresholds and optional helpers.

    Args:
        patch_side: side length of the square patches.
        stride: stored on the instance; not used inside this constructor.
        K: passed to ``SoftThresholding`` and stored as ``self.K``.
        rank: inner dimension of the two factors of each dictionary.
        patch_centering: when truthy, mean kernels are prepared; combining
            it with ``patch_side == 1`` raises ``ValueError``.
        lbda_init, lbda_mode: forwarded to ``SoftThresholding``.
        beta: when truthy, a small convolutional ``beta_estimator`` is built.
        ssl: stored as ``self.ssl``.
        **kwargs: forwarded to the base class constructor, which is expected
            to set ``in_channels`` and ``code_size`` (asserted below).

    Raises:
        ValueError: for patch centering with a 1x1 patch.
    """
    super().__init__(**kwargs)
    # Both attributes are needed below to size the dictionaries.
    assert self.in_channels is not None
    assert self.code_size is not None
    self.patch_side = patch_side
    self.stride = stride
    self.K = K
    self.rank = rank
    self.patch_centering = patch_centering
    self.lbda_init = lbda_init
    self.lbda_mode = lbda_mode
    self.patch_size = self.in_channels * self.patch_side ** 2
    self.spat_dim = self.patch_side ** 2
    self.spec_dim = self.in_channels
    self.beta = beta
    self.ssl = ssl

    # first is spectral, second is spatial
    self.init_weights(
        [
            (self.code_size, self.spec_dim, self.rank),
            (self.code_size, self.rank, self.spat_dim),
        ]
    )

    self.thresholds = SoftThresholding(
        mode=self.lbda_mode,
        lbda_init=self.lbda_init,
        code_size=self.code_size,
        K=self.K,
    )
    if self.patch_centering and self.patch_side == 1:
        raise ValueError(
            "Patch centering and 1x1 kernel will result in null patches"
        )

    if self.patch_centering:
        # Per-channel averaging kernel (entries sum to 1 over each patch).
        ones = torch.ones(
            self.in_channels, 1, self.patch_side, self.patch_side
        )
        # NOTE(review): these are plain tensors placed on the module-level
        # ``device``, so they will not follow a later ``.to(...)`` of the
        # layer -- confirm this is intended.
        self.ker_mean = (ones / self.patch_side ** 2).to(device)
        self.ker_divider = torch.ones(
            1, 1, self.patch_side, self.patch_side
        ).to(device)
        # Filled lazily by the decoder.
        self.divider = None

    if self.beta:
        # Maps a single-channel crop to one value in (0, 1) via the final Sigmoid.
        self.beta_estimator = nn.Sequential(
            # layer1
            nn.Conv2d(
                in_channels=1, out_channels=64, kernel_size=5, stride=2
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            # layer2
            nn.Conv2d(
                in_channels=64, out_channels=128, kernel_size=3, stride=2
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            # layer3
            nn.Conv2d(
                in_channels=128, out_channels=1, kernel_size=3, stride=1
            ),
            nn.Sigmoid(),
        )
95 |
def init_weights(self, shape):
    """Create the attributes ``C``, ``D`` and ``W``, each from its own ``init_param(shape)`` call."""
    self.C = self.init_param(shape)
    self.D = self.init_param(shape)
    self.W = self.init_param(shape)
99 |
def init_param(self, shape):
    """Return a Kaiming-uniform initialised parameter, or a list of them.

    Args:
        shape: a shape tuple, or a ``list`` of shapes. A list is handled
            recursively and yields a ``torch.nn.ParameterList``.
    """
    if isinstance(shape, list):
        members = list(map(self.init_param, shape))
        return torch.nn.ParameterList(members)

    weight = torch.empty(*shape)
    torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
    return torch.nn.Parameter(weight)
109 |
def _encode(self, x, sigmas=None, ssl_idx=None, **kwargs):
    """Compute the code ``alpha`` of ``x`` using ``self.K`` thresholding steps.

    Args:
        x: 4-D tensor, unpacked below as ``(bs, c, h, w)``.
        sigmas: accepted but not used in this method.
        ssl_idx: channel indices whose weight is set to zero when
            ``self.ssl`` is truthy; must support ``.long()``.
        **kwargs: ignored.

    Returns:
        The tensor returned by the last call to ``self.thresholds``.

    Side effects:
        Stores ``self.shape_in`` and, with patch centering enabled,
        ``self.means``.
    """
    self.shape_in = x.shape
    bs, c, h, w = self.shape_in

    if self.beta:
        # Estimate one weight per (sample, channel) from a central crop.
        # NOTE(review): the crop side is derived from ``h`` only; presumably
        # inputs are at least 56 wide as well -- confirm with callers.
        block = min(56, h)
        c_w = (w - block) // 2
        c_h = (h - block) // 2
        to_estimate = x[:, :, c_h : c_h + block, c_w : c_w + block].view(
            bs * c, 1, block, block
        )
        beta = 1 - self.beta_estimator(to_estimate)
        # (bs * c, 1)
        beta = beta.view(bs, c, 1, 1)
    else:
        beta = torch.ones((bs, c, 1, 1), device=x.device)

    if self.ssl:
        # discard error on bands we want to predict
        with torch.no_grad():
            mask = torch.ones_like(beta)
            mask[:, ssl_idx.long()] = 0.0

        beta = beta * mask

    if self.beta or self.ssl:
        # applying beta before or after centering is equivalent
        x = x * beta

    # Rebuild the full kernel from its two low-rank factors.
    CT = (self.C[0] @ self.C[1]).view(
        self.code_size,
        self.in_channels,
        self.patch_side,
        self.patch_side,
    )

    if self.patch_centering:
        # Subtracting the spatial mean from the kernel is used here in place
        # of centering every patch; the patch means are kept for later use.
        A = F.conv2d(x, CT - CT.mean(dim=[2, 3], keepdim=True))
        self.means = F.conv2d(x, self.ker_mean, groups=self.in_channels)
    else:
        A = F.conv2d(x, CT)

    # Step 0 thresholds the initial response directly.
    alpha = self.thresholds(A, 0)

    D = (self.D[0] @ self.D[1]).view(
        self.code_size,
        self.in_channels,
        self.patch_side,
        self.patch_side,
    )

    # Remaining steps: reconstruct with D, re-weight with beta, project back
    # with CT and threshold the corrected code with the step-k threshold.
    for k in range(1, self.K):
        D_alpha = F.conv_transpose2d(alpha, D)
        D_alpha = D_alpha * beta
        alpha = self.thresholds(A + alpha - F.conv2d(D_alpha, CT), k)

    return alpha
167 |
168 | def _decode(self, alpha, **kwargs):
169 | W = ((self.W[0]) @ self.W[1]).view(
170 | self.code_size,
171 | self.in_channels,
172 | self.patch_side,
173 | self.patch_side,
174 | )
175 |
176 | x = F.conv_transpose2d(alpha, W)
177 |
178 | if self.patch_centering:
179 | x += F.conv_transpose2d(
180 | self.means,
181 | self.ker_mean * self.patch_side ** 2,
182 | groups=self.in_channels,
183 | )
184 | if self.divider is None or self.divider.shape[-2:] != (x.shape[-2:]):
185 | ones = torch.ones(
186 | 1, 1, alpha.shape[2], alpha.shape[3], device=alpha.device
187 | ).to(alpha.device)
188 | self.divider = F.conv_transpose2d(ones, self.ker_divider)
189 |
190 | x = x / self.divider
191 |
192 | return x
193 |
--------------------------------------------------------------------------------
/hsir/data/transform/noise.py:
--------------------------------------------------------------------------------
1 | import threading
2 |
3 | import numpy as np
4 |
5 |
6 | # helper
7 |
class LockedIterator(object):
    """Iterator wrapper that serialises ``next()`` calls with a lock.

    The wrapped iterable is turned into an iterator once, at construction
    time; every ``__next__`` call holds ``self.lock`` while advancing it.
    """

    def __init__(self, it):
        self.lock = threading.Lock()
        self.it = it.__iter__()

    def __iter__(self):
        return self

    def __next__(self):
        # The context manager releases the lock even when the underlying
        # iterator raises (including StopIteration).
        with self.lock:
            return next(self.it)
21 |
22 |
23 | # noise
24 |
class AddNoise(object):
    """Add zero-mean Gaussian noise of a fixed level to a numpy array.

    Args:
        sigma: Noise standard deviation on a 0-255 scale; it is divided by
            255 before being applied.
    """

    def __init__(self, sigma):
        self.sigma_ratio = sigma / 255.

    def __call__(self, img):
        """Return ``img`` plus freshly sampled noise (input is not modified)."""
        gaussian = np.random.randn(*img.shape)
        return img + gaussian * self.sigma_ratio
35 |
36 |
class AddNoiseBlind(object):
    """Add Gaussian noise whose level cycles through a list of sigmas.

    Each call uses the next entry of ``sigmas`` (0-255 scale, divided by
    255), wrapping around at the end. The position counter is shared through
    a ``LockedIterator``.
    """

    def __init__(self, sigmas):
        self.sigmas = np.array(sigmas) / 255.
        self.pos = LockedIterator(self.__pos(len(sigmas)))

    def __pos(self, n):
        """Yield 0, 1, ..., n - 1 and start over, forever."""
        index = 0
        while True:
            yield index
            index = (index + 1) % n

    def __call__(self, img):
        """Return ``img`` plus noise at the next sigma in the cycle."""
        level = self.sigmas[next(self.pos)]
        return img + np.random.randn(*img.shape) * level
53 |
54 |
class AddNoiseBlindv2(object):
    """Add Gaussian noise with a level drawn uniformly from a range.

    Args:
        min_sigma: Lower bound of the noise level (0-255 scale).
        max_sigma: Upper bound of the noise level (0-255 scale).
    """

    def __init__(self, min_sigma, max_sigma):
        self.min_sigma = min_sigma
        self.max_sigma = max_sigma

    def __call__(self, img):
        """Return ``img`` plus noise at a freshly sampled level."""
        # Sample the noise field first, then the level, to keep the random
        # stream in the same order.
        gaussian = np.random.randn(*img.shape)
        level = np.random.uniform(self.min_sigma, self.max_sigma)
        return img + gaussian * level / 255
65 |
66 |
class AddNoiseNoniid(object):
    """Add Gaussian noise with an independently chosen level per band.

    Args:
        sigmas: Candidate noise levels (0-255 scale, divided by 255).
    """

    def __init__(self, sigmas):
        self.sigmas = np.array(sigmas) / 255.

    def __call__(self, img):
        """Return ``img`` plus noise; the first axis indexes the bands."""
        picks = np.random.randint(0, len(self.sigmas), img.shape[0])
        # One sigma per band, shaped to broadcast over the two spatial axes.
        bwsigmas = self.sigmas[picks].reshape(-1, 1, 1)
        gaussian = np.random.randn(*img.shape)
        return img + gaussian * bwsigmas
77 |
78 |
class AddNoiseMixed(object):
    """Apply several noise makers to disjoint random subsets of bands.

    Args:
        noise_bank: List of callables ``maker(img, bands)`` returning the
            corrupted image (e.g. ``_AddNoiseImpulse``).
        num_bands: One entry per noise maker. A value in ``(0, 1]`` is a
            fraction of the band count (rounded down); any other value is
            used directly as a number of bands.
    """

    def __init__(self, noise_bank, num_bands):
        assert len(noise_bank) == len(num_bands)
        self.noise_bank = noise_bank
        self.num_bands = num_bands

    def __call__(self, img):
        """Corrupt ``img`` (B, H, W) and return the result of the last maker."""
        B, _, _ = img.shape
        shuffled = np.random.permutation(range(B))
        start = 0
        for maker, count in zip(self.noise_bank, self.num_bands):
            if 0 < count <= 1:
                count = int(np.floor(count * B))
            stop = start + count
            # Consecutive slices of one permutation never overlap.
            img = maker(img, shuffled[start:stop])
            start = stop
        return img
101 |
102 |
103 | class _AddNoiseImpulse(object):
104 | """add impulse noise to the given numpy array (B,H,W)"""
105 |
106 | def __init__(self, amounts, s_vs_p=0.5):
107 | self.amounts = np.array(amounts)
108 | self.s_vs_p = s_vs_p
109 |
110 | def __call__(self, img, bands):
111 | # bands = np.random.permutation(range(img.shape[0]))[:self.num_band]
112 | bwamounts = self.amounts[np.random.randint(0, len(self.amounts), len(bands))]
113 | for i, amount in zip(bands, bwamounts):
114 | self.add_noise(img[i, ...], amount=amount, salt_vs_pepper=self.s_vs_p)
115 | return img
116 |
117 | def add_noise(self, image, amount, salt_vs_pepper):
118 | # out = image.copy()
119 | out = image
120 | p = amount
121 | q = salt_vs_pepper
122 | flipped = np.random.choice([True, False], size=image.shape,
123 | p=[p, 1 - p])
124 | salted = np.random.choice([True, False], size=image.shape,
125 | p=[q, 1 - q])
126 | peppered = ~salted
127 | out[flipped & salted] = 1
128 | out[flipped & peppered] = 0
129 | return out
130 |
131 |
132 | class _AddNoiseStripe(object):
133 | """add stripe noise to the given numpy array (B,H,W)"""
134 |
135 | def __init__(self, min_amount, max_amount):
136 | assert max_amount > min_amount
137 | self.min_amount = min_amount
138 | self.max_amount = max_amount
139 |
140 | def __call__(self, img, bands):
141 | B, H, W = img.shape
142 | # bands = np.random.permutation(range(img.shape[0]))[:len(bands)]
143 | num_stripe = np.random.randint(np.floor(self.min_amount*W), np.floor(self.max_amount*W), len(bands))
144 | for i, n in zip(bands, num_stripe):
145 | loc = np.random.permutation(range(W))
146 | loc = loc[:n]
147 | stripe = np.random.uniform(0, 1, size=(len(loc),))*0.5-0.25
148 | img[i, :, loc] -= np.reshape(stripe, (-1, 1))
149 | return img
150 |
151 |
152 | class _AddNoiseDeadline(object):
153 | """add deadline noise to the given numpy array (B,H,W)"""
154 |
155 | def __init__(self, min_amount, max_amount):
156 | assert max_amount > min_amount
157 | self.min_amount = min_amount
158 | self.max_amount = max_amount
159 |
160 | def __call__(self, img, bands):
161 | B, H, W = img.shape
162 | # bands = np.random.permutation(range(img.shape[0]))[:len(bands)]
163 | num_deadline = np.random.randint(np.ceil(self.min_amount*W), np.ceil(self.max_amount*W), len(bands))
164 | for i, n in zip(bands, num_deadline):
165 | loc = np.random.permutation(range(W))
166 | loc = loc[:n]
167 | img[i, :, loc] = 0
168 | return img
169 |
170 |
class AddNoiseImpulse(AddNoiseMixed):
    """Impulse noise applied to one third of the bands."""

    def __init__(self):
        super().__init__(
            noise_bank=[_AddNoiseImpulse([0.1, 0.3, 0.5, 0.7])],
            num_bands=[1/3],
        )
175 |
176 |
class AddNoiseStripe(AddNoiseMixed):
    """Stripe noise (5%-15% of the columns) applied to one third of the bands."""

    def __init__(self):
        super().__init__(
            noise_bank=[_AddNoiseStripe(0.05, 0.15)],
            num_bands=[1/3],
        )
181 |
182 |
class AddNoiseDeadline(AddNoiseMixed):
    """Deadline noise (5%-15% of the columns) applied to one third of the bands."""

    def __init__(self):
        super().__init__(
            noise_bank=[_AddNoiseDeadline(0.05, 0.15)],
            num_bands=[1/3],
        )
187 |
188 |
class AddNoiseComplex(AddNoiseMixed):
    """Stripe, deadline and impulse noise, each on its own third of the bands."""

    def __init__(self):
        stripe = _AddNoiseStripe(0.05, 0.15)
        deadline = _AddNoiseDeadline(0.05, 0.15)
        impulse = _AddNoiseImpulse([0.1, 0.3, 0.5, 0.7])
        super().__init__(
            noise_bank=[stripe, deadline, impulse],
            num_bands=[1/3, 1/3, 1/3],
        )
197 |
--------------------------------------------------------------------------------
/hsir/model/hsdt/arch.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from .attention import TransformerBlock
8 | from .sepconv import SepConv_DP, SepConv_DP_CA, S3Conv
9 |
# Module-level switches selecting the building blocks used by the factory
# functions and models defined below.
BatchNorm3d = nn.BatchNorm3d  # normalisation layer used in the 'bn' slot when UseBN is true
Conv3d = S3Conv.of(nn.Conv3d)  # convolution used by PlainConv; produced by S3Conv.of (defined in .sepconv)
TransformerBlock = TransformerBlock  # re-binds the imported class under the same module-level name
IsConvImpl = False  # when true, load_state_dict appends three singleton dims to matching attention weights
UseBN = True  # when false, the 'bn' slot of every conv block is nn.Identity()
15 |
16 |
def PlainConv(in_ch, out_ch):
    """Build a resolution-preserving block: conv -> (batch norm) -> attention.

    The convolution is the module-level ``Conv3d`` with kernel 3, stride 1,
    padding 1 and no bias; batch norm is replaced by ``nn.Identity`` when
    ``UseBN`` is false.
    """
    conv = Conv3d(in_ch, out_ch, 3, 1, 1, bias=False)
    if UseBN:
        norm = BatchNorm3d(out_ch)
    else:
        norm = nn.Identity()
    attn = TransformerBlock(out_ch, bias=True)
    stages = OrderedDict()
    stages['conv'] = conv
    stages['bn'] = norm
    stages['attn'] = attn
    return nn.Sequential(stages)
23 |
24 |
def DownConv(in_ch, out_ch):
    """Build a downsampling block: strided conv -> (batch norm) -> attention.

    Uses a plain ``nn.Conv3d`` (not the module-level ``Conv3d``) with stride
    ``(1, 2, 2)``, i.e. stride 2 on the last two axes only.
    """
    conv = nn.Conv3d(in_ch, out_ch, 3, (1, 2, 2), 1, bias=False)
    if UseBN:
        norm = BatchNorm3d(out_ch)
    else:
        norm = nn.Identity()
    attn = TransformerBlock(out_ch, bias=True)
    stages = OrderedDict()
    stages['conv'] = conv
    stages['bn'] = norm
    stages['attn'] = attn
    return nn.Sequential(stages)
31 |
32 |
def UpConv(in_ch, out_ch):
    """Build an upsampling block: upsample -> conv -> (batch norm) -> attention.

    Upsampling is trilinear with scale ``(1, 2, 2)`` and
    ``align_corners=True``; the convolution is a plain ``nn.Conv3d``.
    """
    upsample = nn.Upsample(
        scale_factor=(1, 2, 2), mode='trilinear', align_corners=True
    )
    conv = nn.Conv3d(in_ch, out_ch, 3, 1, 1, bias=False)
    if UseBN:
        norm = BatchNorm3d(out_ch)
    else:
        norm = nn.Identity()
    attn = TransformerBlock(out_ch, bias=True)
    stages = OrderedDict()
    stages['up'] = upsample
    stages['conv'] = conv
    stages['bn'] = norm
    stages['attn'] = attn
    return nn.Sequential(stages)
40 |
41 |
class Encoder(nn.Module):
    """Stack of ``num_half_layer`` blocks forming the contracting path.

    Layers whose index is in ``sample_idx`` are ``DownConv`` blocks that
    double the channel count; all others are ``PlainConv`` blocks.
    """

    def __init__(self, channels, num_half_layer, sample_idx):
        super(Encoder, self).__init__()
        self.layers = nn.ModuleList()
        width = channels
        for depth in range(num_half_layer):
            if depth in sample_idx:
                block = DownConv(width, 2 * width)
                width *= 2
            else:
                block = PlainConv(width, width)
            self.layers.append(block)

    def forward(self, x, xs):
        """Run all layers; push every output except the last onto ``xs``."""
        for layer in list(self.layers)[:-1]:
            x = layer(x)
            xs.append(x)
        return self.layers[-1](x)
61 |
62 |
class Decoder(nn.Module):
    """Expanding path mirroring ``Encoder``, with optional learned fusion.

    Layers are created for indices ``num_half_layer - 1 .. 0``; an index in
    ``sample_idx`` yields an ``UpConv`` that halves the channel count. When
    ``Fusion`` is given, one fusion module per layer is created with that
    layer's input channel count; otherwise skips are merged by addition.
    """

    count = 1

    def __init__(self, channels, num_half_layer, sample_idx, Fusion=None):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList()
        self.enable_fusion = Fusion is not None

        if self.enable_fusion:
            self.fusions = nn.ModuleList()
            fusion_width = channels
            for depth in reversed(range(num_half_layer)):
                self.fusions.append(Fusion(fusion_width))
                if depth in sample_idx:
                    fusion_width //= 2

        width = channels
        for depth in reversed(range(num_half_layer)):
            if depth in sample_idx:
                block = UpConv(width, width // 2)
                width //= 2
            else:
                block = PlainConv(width, width)
            self.layers.append(block)

    def forward(self, x, xs):
        """Decode ``x``, consuming one skip tensor from ``xs`` per later layer."""
        x = self.layers[0](x)
        # The first layer has no skip input, so fusions[0] is never applied.
        for idx in range(1, len(self.layers)):
            if self.enable_fusion:
                x = self.fusions[idx](x, xs.pop())
            else:
                x = x + xs.pop()
            x = self.layers[idx](x)
        return x
98 |
99 |
class HSDT(nn.Module):
    """Encoder-decoder network built from conv + attention blocks.

    Args:
        in_channels: Channel count of the input tensor.
        channels: Base feature width after the head block.
        num_half_layer: Number of layers in each of encoder and decoder.
        sample_idx: Layer indices at which the encoder downsamples (and the
            decoder upsamples).
        Fusion: Optional skip-fusion module class, called as ``Fusion(ch)``.
    """

    def __init__(self, in_channels, channels, num_half_layer, sample_idx, Fusion=None):
        super(HSDT, self).__init__()
        deepest = channels * (2**len(sample_idx))
        self.head = PlainConv(in_channels, channels)
        self.encoder = Encoder(channels, num_half_layer, sample_idx)
        self.decoder = Decoder(deepest, num_half_layer, sample_idx, Fusion=Fusion)
        self.tail = nn.Conv3d(channels, 1, 3, 1, 1, bias=True)

    def forward(self, x):
        """Return the network output plus channel 0 of the input (residual)."""
        features = self.head(x)
        skips = [x, features]
        out = self.encoder(features, skips)
        out = self.decoder(out, skips)
        out = out + skips.pop()
        out = self.tail(out)
        return out + skips.pop()[:, 0:1, :, :, :]

    def load_state_dict(self, state_dict, strict: bool = True):
        """Load weights, reshaping attention weights when ``IsConvImpl`` is set.

        With ``IsConvImpl`` true, every entry whose key contains
        ``'attn.attn'`` and ``'weight'`` but not ``'attn_proj'`` gets three
        trailing singleton dimensions before loading.
        """
        if IsConvImpl:
            def needs_expansion(key):
                return ('attn.attn' in key) and 'weight' in key and 'attn_proj' not in key

            converted = {}
            for key, value in state_dict.items():
                if needs_expansion(key):
                    value = value.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
                converted[key] = value
            state_dict = converted
        return super().load_state_dict(state_dict, strict)
129 |
130 |
class HSDTSSR(nn.Module):
    """HSDT variant that first projects a 3-channel image to 31 channels.

    A bias-free 1x1 ``nn.Conv2d(3, 31)`` lifts the input, the result is
    treated as a single-channel volume, and the same head / encoder /
    decoder / tail pipeline as ``HSDT`` is applied.
    """

    def __init__(self, in_channels, channels, num_half_layer, sample_idx, Fusion=None):
        super(HSDTSSR, self).__init__()
        deepest = channels * (2**len(sample_idx))
        self.proj = nn.Conv2d(3, 31, 1, bias=False)
        self.head = PlainConv(in_channels, channels)
        self.encoder = Encoder(channels, num_half_layer, sample_idx)
        self.decoder = Decoder(deepest, num_half_layer, sample_idx, Fusion=Fusion)
        self.tail = nn.Conv3d(channels, 1, 3, 1, 1, bias=True)

    def forward(self, x):
        """Dispatch to the training or the padded evaluation path."""
        handler = self.forward_train if self.training else self.forward_test
        return handler(x)

    def forward_train(self, x):
        """Project, run the encoder-decoder and drop the singleton channel."""
        volume = F.leaky_relu(self.proj(x)).unsqueeze(1)
        features = self.head(volume)
        skips = [volume, features]
        out = self.encoder(features, skips)
        out = self.decoder(out, skips)
        out = out + skips.pop()
        out = self.tail(out)
        out = out + skips.pop()[:, 0:1, :, :, :]
        return out.squeeze(1)

    def forward_test(self, x):
        """Zero-pad the last two axes via ``pad_mod(x, 8)``, run, crop back."""
        padded, height, width = pad_mod(x, 8)
        return self.forward_train(padded)[..., :height, :width]

    def load_state_dict(self, state_dict, strict: bool = True):
        """Load weights, reshaping attention weights when ``IsConvImpl`` is set.

        With ``IsConvImpl`` true, every entry whose key contains
        ``'attn.attn'`` and ``'weight'`` but not ``'attn_proj'`` gets three
        trailing singleton dimensions before loading.
        """
        if IsConvImpl:
            def needs_expansion(key):
                return ('attn.attn' in key) and 'weight' in key and 'attn_proj' not in key

            converted = {}
            for key, value in state_dict.items():
                if needs_expansion(key):
                    value = value.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
                converted[key] = value
            state_dict = converted
        return super().load_state_dict(state_dict, strict)
173 |
174 |
def pad_mod(x, mod):
    """Zero-pad the last two axes of ``x`` up to a multiple of ``mod``.

    The original data sits in the top-left corner of the result; padding is
    added at the bottom and on the right.

    NOTE: the target size is ``(size // mod + 1) * mod``, so an axis that is
    already a multiple of ``mod`` still grows by one full ``mod``. This is
    kept unchanged because callers crop back to ``(h, w)`` and changing the
    padded size could alter model outputs.

    Args:
        x: Tensor with at least two dimensions.
        mod: Positive integer the padded height and width must divide by.

    Returns:
        Tuple ``(padded, h, w)`` where ``padded`` has the dtype and device of
        ``x`` and ``h``/``w`` are the original sizes of the last two axes.
    """
    h, w = x.shape[-2:]
    h_out = (h // mod + 1) * mod
    w_out = (w // mod + 1) * mod
    # Allocate directly with x's dtype and device instead of creating a
    # default CPU tensor and converting/moving it afterwards.
    out = x.new_zeros((*x.shape[:-2], h_out, w_out))
    out[..., :h, :w] = x
    return out, h, w
182 |
--------------------------------------------------------------------------------
/hsir/model/t3sc/base.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import time
3 |
4 | import numpy as np
5 | import pytorch_lightning as pl
6 | import torch
7 |
8 | from .metrics import mpsnr, mse, psnr
9 | from .utils import PatchesHandler
10 |
# Module-level logger. Its level is set to DEBUG here, so the debug messages
# emitted below are not filtered out by this logger itself.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
13 |
14 |
class BaseModel(pl.LightningModule):
    """Shared Lightning training / evaluation logic using manual optimization.

    Subclasses provide ``forward``. Batches are dicts containing at least
    ``"x"`` (network input) and ``"y"`` (target); the remaining entries are
    forwarded to ``forward`` as keyword arguments.

    Args:
        optimizer: Config object with ``class_name`` (a name in
            ``torch.optim``) and ``params`` (keyword arguments).
        lr_scheduler: Optional config object with ``class_name`` (a name in
            ``torch.optim.lr_scheduler``) and ``params``. May be ``None``.
        block_inference: Optional config object with ``use_bi``,
            ``block_size``, ``overlap`` and ``padding``.
    """

    def __init__(
        self, optimizer=None, lr_scheduler=None, block_inference=None
    ):
        super().__init__()
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.block_inference = block_inference
        # Self-supervised mode switch and number of bands held out per step;
        # both default to "off" here and are read by the step methods.
        self.ssl = 0
        self.n_ssl = 0

        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        """Run one manually optimised training step and log its metrics."""
        opt = self.optimizers()
        opt.zero_grad()

        y = batch.pop("y")
        if self.ssl:
            x = batch["x"]
            bs, c, h, w = x.shape
            # Hold out a random subset of bands and train to predict them
            # from the input itself.
            ssl_idx = torch.randperm(c)[: self.n_ssl].to(self.device)
            batch["ssl_idx"] = ssl_idx
            out = self.forward(**batch)
            band_out = out[:, ssl_idx]
            band_target = x[:, ssl_idx]
            loss = mse(band_out, band_target)
            self.log("train_mse", loss)
            self.log("train_psnr", psnr(band_out, band_target))
            self.log("train_mpsnr", mpsnr(band_out.detach(), band_target))
            band_y = y[:, ssl_idx]
            self.log("train_mse_y", mse(band_out, band_y))
            self.log("train_psnr_y", psnr(band_out, band_y))
            self.log("train_mpsnr_y", mpsnr(band_out.detach(), band_y))
        else:
            out = self.forward(**batch)
            loss = mse(out, y)
            self.log("train_mse", loss)
            self.log("train_psnr", psnr(out, y))
            self.log("train_mpsnr", mpsnr(out.detach(), y))

        self.manual_backward(loss)
        opt.step()

        # No scheduler is configured when lr_scheduler is None, so guard
        # every use of it.
        sch = self.lr_schedulers()
        if self.trainer.is_last_batch:
            epoch = self.current_epoch
            lr = sch.get_last_lr() if sch is not None else None
            logger.info(f"Epoch {epoch} : lr={lr} \t loss={loss:.6f}")
            if sch is not None:
                sch.step()

    def _run_inference(self, batch, stage):
        """Denoise ``batch`` for evaluation and clamp the result to [0, 1].

        ``stage`` is only used as the prefix of the debug message.
        """
        use_blocks = self.block_inference and self.block_inference.use_bi
        predict = self.forward_blocks if use_blocks else self.forward

        if self.ssl:
            bs, c, h, w = batch["x"].shape
            out = torch.zeros_like(batch["x"])
            n_groups = int(np.ceil(c / self.n_ssl))
            # Predict each interleaved group of bands in turn and keep only
            # the bands that were held out in that pass.
            for i in range(n_groups):
                ssl_idx = self.get_ssl_idx(i, c).long()
                batch["ssl_idx"] = ssl_idx
                _out = predict(**batch)
                out[:, ssl_idx] = _out[:, ssl_idx]
        else:
            out = predict(**batch)
            logger.debug(f"{stage} denoised shape: {out.shape}")
        return out.clamp(0, 1)

    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch and log metrics and elapsed time."""
        y = batch.pop("y")
        start = time.time()
        out = self._run_inference(batch, "Val")
        elapsed = time.time() - start
        _mse = mse(out, y)
        _mpsnr = mpsnr(out, y)
        logger.debug(f"Val mse : {_mse}, mpsnr: {_mpsnr}")
        self.log("val_mse", _mse)
        self.log("val_psnr", psnr(out, y))
        self.log("val_mpsnr", _mpsnr)
        self.log("val_batch_time", elapsed)
        self.log("val_psnr_noise", psnr(batch["x"], y))
        self.log("val_mpsnr_noise", mpsnr(batch["x"], y))

    def get_ssl_idx(self, i, c):
        """Return the band indices ``i, i + N, i + 2N, ...`` below ``c``.

        ``N`` is ``ceil(c / n_ssl)``. Integer arithmetic is used so the
        result is an integer tensor and no NumPy scalar is mixed with a
        tensor.
        """
        N = int(np.ceil(c / self.n_ssl))
        L = int(np.ceil((c - i) / N))
        return i + N * torch.arange(L)

    def test_step(self, batch, batch_idx):
        """Evaluate one test batch and log its metrics."""
        y = batch.pop("y")
        out = self._run_inference(batch, "Test")
        self.log("test_mse", mse(out, y))
        self.log("test_psnr", psnr(out, y))
        self.log("test_mpsnr", mpsnr(out, y))
        self.log("test_psnr_noise", psnr(batch["x"], y))
        self.log("test_mpsnr_noise", mpsnr(batch["x"], y))

    def forward_blocks(self, x, **kwargs):
        """Run ``forward`` block by block and aggregate the outputs.

        Blocks are extracted and re-assembled by ``PatchesHandler``; the
        block side is capped by the larger spatial size of ``x``.
        """
        logger.debug("Starting block inference")
        block_size = min(
            max(x.shape[-1], x.shape[-2]), self.block_inference.block_size
        )
        patches_handler = PatchesHandler(
            size=(block_size,) * 2,
            channels=x.shape[1],
            stride=block_size - self.block_inference.overlap,
            padding=self.block_inference.padding,
        )

        logger.debug("Forward patches handler")
        blocks_in = patches_handler(x, mode="extract").clone()
        blocks_grid = tuple(blocks_in.shape[-2:])
        logger.debug(f"blocks grid : {blocks_in.shape}")

        blocks_out = torch.zeros_like(blocks_in)

        logger.debug(f"Processing blocks {blocks_grid}")
        for i in range(blocks_grid[0]):
            for j in range(blocks_grid[1]):
                blocks_ij = self.forward(blocks_in[:, :, :, :, i, j], **kwargs)
                blocks_out[:, :, :, :, i, j] = blocks_ij
        x = patches_handler(blocks_out, mode="aggregate")
        logger.debug(f"Blocks aggregated to shape : {tuple(x.shape)}")
        return x

    def configure_optimizers(self):
        """Instantiate the optimizer and, if configured, the LR scheduler.

        Returns:
            The optimizer alone when ``lr_scheduler`` is ``None`` (this case
            previously failed on an unbound ``scheduler`` variable),
            otherwise ``([optimizer], [scheduler])``.
        """
        logger.debug("Configuring optimizer")
        optim_class = torch.optim.__dict__[self.optimizer.class_name]
        optimizer = optim_class(self.parameters(), **self.optimizer.params)

        if self.lr_scheduler is None:
            return optimizer

        scheduler_class = torch.optim.lr_scheduler.__dict__[
            self.lr_scheduler.class_name
        ]
        scheduler = scheduler_class(optimizer, **self.lr_scheduler.params)
        return [optimizer], [scheduler]

    def count_params(self):
        """Log and return ``(count, description)`` of trainable parameters."""
        desc = "Model parameters:\n"
        counter = 0
        for name, param in self.named_parameters():
            if param.requires_grad:
                count = param.numel()
                desc += f"\t{name} : {count}\n"
                counter += count
        desc += f"Total number of learnable parameters : {counter}\n"
        logger.info(desc)
        return counter, desc
185 |
--------------------------------------------------------------------------------
/hsir/model/man/v2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | """ Utility block """
7 |
8 |
class UpsampleConv3d(torch.nn.Module):
    """Trilinear upsampling followed by a 3-D convolution.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, bias:
            Forwarded to ``nn.Conv3d``.
        upsample: Scale factor passed to ``nn.Upsample``
            (``align_corners=True``).
        group: Number of convolution groups.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=True, upsample=None, group=1):
        super(UpsampleConv3d, self).__init__()
        self.upsample_layer = nn.Upsample(
            scale_factor=upsample, mode='trilinear', align_corners=True
        )
        self.conv3d = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            bias=bias,
            groups=group,
        )

    def forward(self, x):
        """Upsample ``x`` and convolve the result."""
        return self.conv3d(self.upsample_layer(x))
22 |
23 |
class MLP(nn.Module):
    """Two 1x1x1 convolutions with a tanh in between.

    Args:
        channel: Number of input and output channels of both convolutions.
        bias: Whether the convolutions have a bias term.
    """

    def __init__(self, channel, bias=True):
        super().__init__()
        self.w_1 = nn.Conv3d(channel, channel, bias=bias, kernel_size=1)
        self.w_2 = nn.Conv3d(channel, channel, bias=bias, kernel_size=1)

    def forward(self, x):
        """Return ``w_2(tanh(w_1(x)))``."""
        # torch.tanh replaces the deprecated torch.nn.functional.tanh.
        return self.w_2(torch.tanh(self.w_1(x)))
36 |
37 |
38 | """ The proposed blocks
39 | """
40 |
41 |
class PSCA(nn.Module):
    """Progressive Spectral Channel Attention (PSCA).

    A gated residual (``w_3(x) * x + x``) followed by a two-layer 1x1x1
    feed-forward with GELU. ``w_3`` is zero-initialised, so the gate is the
    identity at the start of training.

    Args:
        d_model: Input and output channel count.
        d_ff: Hidden channel count of the feed-forward part.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        self.w_1 = nn.Conv3d(d_model, d_ff, 1, bias=False)
        self.w_2 = nn.Conv3d(d_ff, d_model, 1, bias=False)
        self.w_3 = nn.Conv3d(d_model, d_model, 1, bias=False)

        nn.init.zeros_(self.w_3.weight)

    def forward(self, x):
        """Apply the gate, then the feed-forward projection."""
        gated = self.w_3(x) * x + x
        hidden = F.gelu(self.w_1(gated))
        return self.w_2(hidden)
60 |
61 |
class ASC(nn.Module):
    """Attentive Skip Connection.

    Blends two tensors with a learned per-element weight in (0, 1) computed
    from their channel-wise concatenation.

    Args:
        channel: Channel count of each of the two inputs.
    """

    def __init__(self, channel):
        super().__init__()
        self.weight = nn.Sequential(
            nn.Conv3d(channel * 2, channel, 1),
            nn.LeakyReLU(),
            nn.Conv3d(channel, channel, 3, 1, 1),
            nn.Sigmoid()
        )

    def forward(self, x, y):
        """Return ``(1 - w) * x + w * y`` with ``w`` predicted from both."""
        stacked = torch.cat([x, y], dim=1)
        w = self.weight(stacked)
        return (1 - w) * x + w * y
79 |
80 |
class MHRSA(nn.Module):
    """Multi-Head Recurrent Spectral Attention.

    A gated recurrence ``h_t = f_t * h_{t-1} + (1 - f_t) * z_t`` is run
    along dimension 2 of the input. With ``multi_head`` the second half of
    the channels is processed in the opposite direction.

    Args:
        channels: Number of channels of the candidate and gate tensors.
        multi_head: Scan the two channel halves in opposite directions.
        ffn: Produce candidate and gate with two ``MLP`` modules; otherwise
            the input is split in two along the channel axis.
    """

    def __init__(self, channels, multi_head=True, ffn=True):
        super().__init__()
        self.channels = channels
        self.multi_head = multi_head
        self.ffn = ffn

        if ffn:
            self.ffn1 = MLP(channels)
            self.ffn2 = MLP(channels)

    def _conv_step(self, inputs):
        """Return the tanh candidate and the sigmoid gate for ``inputs``."""
        if self.ffn:
            candidate = self.ffn1(inputs)
            gate = self.ffn2(inputs)
        else:
            candidate, gate = inputs.split(split_size=self.channels, dim=1)
        return candidate.tanh(), gate.sigmoid()

    def _rnn_step(self, z, f, h):
        """Advance the recurrence by one step (``h`` is None at the start)."""
        if h is None:
            return (1 - f) * z
        return f * h + (1 - f) * z

    def _flip_second_head(self, tensor):
        """Reverse dimension 2 for the second half of the channels."""
        first, second = tensor.split(self.channels // 2, 1)
        return torch.cat([first, torch.flip(second, [2])], dim=1)

    def forward(self, inputs, reverse=False):
        """Scan ``inputs`` along dimension 2, backwards if ``reverse``."""
        candidates, gates = self._conv_step(inputs)

        if self.multi_head:
            candidates = self._flip_second_head(candidates)
            gates = self._flip_second_head(gates)

        # One (candidate, gate) pair per position along dimension 2.
        steps = list(zip(candidates.split(1, 2), gates.split(1, 2)))
        if reverse:
            steps.reverse()

        state = None
        states = []
        for z, f in steps:
            state = self._rnn_step(z, f, state)
            states.append(state)

        if reverse:
            # Restore the original ordering of the positions.
            states.reverse()

        y = torch.cat(states, dim=2)

        if self.multi_head:
            y = self._flip_second_head(y)

        return y
142 |
143 |
class MAB(nn.Module):
    """Convolution followed by inter-band (MHRSA) and intra-band (PSCA) attention.

    Args:
        conv_layer: Module applied first; it must output ``channels`` channels.
        channels: Channel count handled by both attention modules.
        multi_head: Forwarded to ``MHRSA``.
        ffn: Forwarded to ``MHRSA``.
    """

    def __init__(self, conv_layer, channels, multi_head=True, ffn=True):
        super().__init__()
        self.conv = conv_layer
        self.inter_sa = MHRSA(channels, multi_head=multi_head, ffn=ffn)
        self.intra_sa = PSCA(channels, channels * 2)

    def forward(self, x, reverse=False):
        """Apply conv, recurrent attention (optionally reversed), then PSCA."""
        features = self.conv(x)
        attended = self.inter_sa(features, reverse=reverse)
        return self.intra_sa(attended)
156 |
157 |
158 | """ Encoder-Decoder
159 | """
160 |
161 |
def PlainMAB(in_ch, out_ch, bias=False):
    """Build a ``MAB`` around a stride-1, kernel-3, padding-1 ``nn.Conv3d``."""
    conv = nn.Conv3d(in_ch, out_ch, 3, 1, 1, bias=bias)
    return MAB(conv, out_ch)
164 |
165 |
def DownMAB(in_ch, out_ch, bias=False):
    """Build a ``MAB`` around an ``nn.Conv3d`` with stride ``(1, 2, 2)``."""
    conv = nn.Conv3d(in_ch, out_ch, 3, (1, 2, 2), 1, bias=bias)
    return MAB(conv, out_ch)
168 |
169 |
def UpMAB(in_ch, out_ch, bias=False):
    """Build a ``MAB`` around an ``UpsampleConv3d`` with scale ``(1, 2, 2)``."""
    conv = UpsampleConv3d(in_ch, out_ch, 3, 1, 1, bias=bias, upsample=(1, 2, 2))
    return MAB(conv, out_ch)
172 |
173 |
class Encoder(nn.Module):
    """Contracting path made of ``num_half_layer`` attention blocks.

    Layers whose index is in ``sample_idx`` are ``DownMAB`` blocks that
    double the channel count; the others are ``PlainMAB`` blocks.
    """

    def __init__(self, channels, num_half_layer, sample_idx):
        super(Encoder, self).__init__()
        self.layers = nn.ModuleList()
        width = channels
        for depth in range(num_half_layer):
            if depth in sample_idx:
                block = DownMAB(width, 2 * width)
                width *= 2
            else:
                block = PlainMAB(width, width)
            self.layers.append(block)

    def forward(self, x, xs, reverse=False):
        """Run all layers, alternating the scan direction at every layer.

        Every output except the last one is appended to ``xs``.
        """
        direction = reverse
        for layer in list(self.layers)[:-1]:
            x = layer(x, direction)
            direction = not direction
            xs.append(x)
        return self.layers[-1](x, direction)
195 |
196 |
class Decoder(nn.Module):
    """Expanding path mirroring ``Encoder``, with optional learned fusion.

    Layers are created for indices ``num_half_layer - 1 .. 0``; an index in
    ``sample_idx`` yields an ``UpMAB`` that halves the channel count. When
    ``Fusion`` is given, one fusion module per layer is created with that
    layer's input channel count; otherwise skips are merged by addition.
    """

    def __init__(self, channels, num_half_layer, sample_idx, Fusion=None):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList()
        self.enable_fusion = Fusion is not None

        if self.enable_fusion:
            self.fusions = nn.ModuleList()
            fusion_width = channels
            for depth in reversed(range(num_half_layer)):
                self.fusions.append(Fusion(fusion_width))
                if depth in sample_idx:
                    fusion_width //= 2

        width = channels
        for depth in reversed(range(num_half_layer)):
            if depth in sample_idx:
                block = UpMAB(width, width // 2)
                width //= 2
            else:
                block = PlainMAB(width, width)
            self.layers.append(block)

    def forward(self, x, xs, reverse=False):
        """Decode ``x``, alternating scan direction and consuming ``xs``."""
        x = self.layers[0](x, reverse)
        direction = not reverse
        # The first layer has no skip input, so fusions[0] is never applied.
        for idx in range(1, len(self.layers)):
            if self.enable_fusion:
                x = self.fusions[idx](x, xs.pop())
            else:
                x = x + xs.pop()
            x = self.layers[idx](x, direction)
            direction = not direction
        return x
232 |
233 |
class MAN(nn.Module):
    """Encoder-decoder network built from ``MAB`` blocks with a global residual.

    Args:
        in_channels: Channel count of the input (and of the output).
        channels: Base feature width after the head block.
        num_half_layer: Number of layers in each of encoder and decoder.
        sample_idx: Layer indices at which the resolution changes.
        Fusion: Optional skip-fusion module class, called as ``Fusion(ch)``.
    """

    def __init__(self, in_channels, channels, num_half_layer, sample_idx, Fusion=None):
        super().__init__()
        deepest = channels * (2**len(sample_idx))
        self.head = PlainMAB(in_channels, channels)
        self.encoder = Encoder(channels, num_half_layer, sample_idx)
        self.decoder = Decoder(deepest, num_half_layer, sample_idx, Fusion=Fusion)
        self.tail = nn.Conv3d(channels, in_channels, 3, 1, 1, bias=True)

    def forward(self, x):
        """Return the predicted residual added to the input ``x``."""
        features = self.head(x)
        skips = [x, features]
        # Encoder and decoder are both started with the reversed direction.
        reverse = True
        out = self.encoder(features, skips, reverse)
        out = self.decoder(out, skips, reverse)
        out = out + skips.pop()
        out = self.tail(out)
        return out + skips.pop()
253 |
254 |
class MAN_T(MAN):
    """``MAN`` whose tail convolution is replaced by a biased ``PlainMAB``."""

    def __init__(self, in_channels, channels, num_half_layer, sample_idx, Fusion=None):
        super().__init__(
            in_channels, channels, num_half_layer, sample_idx, Fusion
        )
        # Overrides the nn.Conv3d tail created by the parent constructor.
        tail_block = PlainMAB(channels, in_channels, bias=True)
        self.tail = tail_block
259 |
260 |
261 | """ Models
262 | """
263 |
264 |
def man_s():
    """Build the small MAN: width 12, 5 half layers, resampling at layers 1 and 3."""
    model = MAN(
        in_channels=1,
        channels=12,
        num_half_layer=5,
        sample_idx=[1, 3],
        Fusion=ASC,
    )
    model.use_2dconv = False
    model.bandwise = False
    return model
270 |
271 |
def man_m():
    """Build the medium MAN: width 16, 5 half layers, resampling at layers 1 and 3."""
    model = MAN(
        in_channels=1,
        channels=16,
        num_half_layer=5,
        sample_idx=[1, 3],
        Fusion=ASC,
    )
    model.use_2dconv = False
    model.bandwise = False
    return model
277 |
278 |
def man_l():
    """Build the large MAN: width 20, 5 half layers, resampling at layers 1 and 3."""
    model = MAN(
        in_channels=1,
        channels=20,
        num_half_layer=5,
        sample_idx=[1, 3],
        Fusion=ASC,
    )
    model.use_2dconv = False
    model.bandwise = False
    return model
284 |
285 |
--------------------------------------------------------------------------------
/hsir/model/man/v1.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | """ Skip Connection """
8 |
9 |
class AdditiveConnection(nn.Module):
    """Skip connection that simply adds its two inputs.

    ``channel`` is accepted so the class can be constructed like the other
    skip connections; it is not used.
    """

    def __init__(self, channel):
        super().__init__()

    def forward(self, x, y):
        """Return the element-wise sum of ``x`` and ``y``."""
        merged = x + y
        return merged
16 |
17 |
class ConcatSkipConnection(nn.Module):
    """Skip connection that concatenates both inputs and convolves them.

    Args:
        channel: Channel count of each input and of the output.
    """

    def __init__(self, channel):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv3d(channel * 2, channel, 3, 1, 1),
        )

    def forward(self, x, y):
        """Concatenate along the channel axis and reduce back to ``channel``."""
        stacked = torch.cat([x, y], dim=1)
        return self.conv(stacked)
27 |
28 |
class AdaptiveSkipConnection(nn.Module):
    """Skip connection that blends its inputs with a learned weight map.

    The weight is computed from the concatenated inputs by a 1x1x1
    convolution, tanh, a 3x3x3 convolution and a sigmoid.

    Args:
        channel: Channel count of each input and of the output.
    """

    def __init__(self, channel):
        super().__init__()
        self.conv_weight = nn.Sequential(
            nn.Conv3d(channel * 2, channel, 1),
            nn.Tanh(),
            nn.Conv3d(channel, channel, 3, 1, 1),
            nn.Sigmoid()
        )

    def forward(self, x, y):
        """Return ``(1 - w) * x + w * y`` with ``w`` predicted from both."""
        stacked = torch.cat([x, y], dim=1)
        w = self.conv_weight(stacked)
        blended = (1 - w) * x + w * y
        return blended
42 |
43 |
44 | """ Channel Attention """
45 |
46 |
class GlobalAvgPool3d(nn.Module):
    """Average a 5-D tensor over its last three axes, keeping them as size 1."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return the mean over dimensions 2, 3 and 4 (computed as sum / count)."""
        total = torch.sum(x, dim=(2, 3, 4), keepdim=True)
        element_count = x.shape[2] * x.shape[3] * x.shape[4]
        return total / element_count
54 |
55 |
class SimplifiedChannelAttention(nn.Module):
    """Channel attention: global average pool -> 1x1x1 conv -> sigmoid.

    Args:
        ch: Number of input (and output) channels.
    """

    def __init__(self, ch):
        super().__init__()
        self.sca = nn.Sequential(
            GlobalAvgPool3d(),
            nn.Conv3d(ch, ch, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Scale every channel of ``x`` by its attention weight."""
        channel_weights = self.sca(x)
        return channel_weights * x
68 |
69 |
class ChannelAttention(nn.Module):
    """Channel attention with a two-layer gate.

    Global average pool -> 1x1x1 conv -> ReLU -> 1x1x1 conv -> sigmoid.

    Args:
        ch: Number of input (and output) channels.
    """

    def __init__(self, ch):
        super().__init__()
        self.ca = nn.Sequential(
            GlobalAvgPool3d(),
            nn.Conv3d(ch, ch, kernel_size=1),
            nn.ReLU(),
            nn.Conv3d(ch, ch, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Scale every channel of ``x`` by its attention weight."""
        channel_weights = self.ca(x)
        return channel_weights * x
83 |
84 |
85 | """ Mixed Attention Block """
86 |
87 |
class PositionwiseFeedForward(nn.Module):
    """Position-wise feed-forward: two 1x1x1 convolutions with a tanh between.

    Args:
        channel: Number of input and output channels of both convolutions.
    """

    def __init__(self, channel):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Conv3d(channel, channel, kernel_size=1)
        self.w_2 = nn.Conv3d(channel, channel, kernel_size=1)

    def forward(self, x):
        """Return ``w_2(tanh(w_1(x)))``."""
        hidden = torch.tanh(self.w_1(x))
        return self.w_2(hidden)
98 |
99 |
class MAB(nn.Module):
    """Mixed attention block: gated recurrence along dim 2 plus channel attention.

    Args:
        channels: Channel count of the input.
        enable_ca: Apply ``SimplifiedChannelAttention`` to the result.
        reverse: Scan dimension 2 from the last position to the first.
    """

    def __init__(self, channels, enable_ca=True, reverse=False):
        super(MAB, self).__init__()
        self.channels = channels
        self.enable_ca = enable_ca
        self.reverse = reverse
        self.sca = SimplifiedChannelAttention(channels)
        self.ffn_f = PositionwiseFeedForward(channels)
        self.ffn_w = PositionwiseFeedForward(channels)

    def _rnn_step(self, z, f, h):
        """Return ``f * h + (1 - f) * z``, or ``(1 - f) * z`` when ``h`` is None."""
        if h is None:
            return (1 - f) * z
        return f * h + (1 - f) * z

    def forward(self, inputs):
        """Run the recurrence over ``inputs`` in the configured direction."""
        candidates = self.ffn_f(inputs).tanh()
        gates = self.ffn_w(inputs).sigmoid()

        # One (candidate, gate) pair per position along dimension 2.
        steps = list(zip(candidates.split(1, 2), gates.split(1, 2)))
        backwards = bool(self.reverse)
        if backwards:
            steps.reverse()

        state = None
        states = []
        for z, f in steps:
            state = self._rnn_step(z, f, state)
            states.append(state)

        if backwards:
            # Restore the original ordering of the positions.
            states.reverse()

        out = torch.cat(states, dim=2)
        if self.enable_ca:
            out = self.sca(out)
        return out
135 |
136 |
class BiMAB(MAB):
    """Bidirectional mixed attention block.

    Shares one candidate branch between a forward scan (gated by ``ffn_w``)
    and a backward scan (gated by the extra ``ffn_w2``) along dim 2, and sums
    the two results.

    Args:
        channels: number of feature channels.
        enable_ca: apply ``SimplifiedChannelAttention`` to the summed result.
    """

    def __init__(self, channels, enable_ca=True):
        # `reverse` is irrelevant here because forward() is overridden.
        super().__init__(channels, enable_ca, None)
        self.ffn_w2 = PositionwiseFeedForward(channels)

    def _scan(self, candidates, gates):
        """Run the gated recurrence over paired slices, returning every state."""
        state = None
        states = []
        for z, f in zip(candidates, gates):
            state = self._rnn_step(z, f, state)
            states.append(state)
        return states

    def forward(self, inputs):
        """Return forward-scan plus backward-scan states, shaped like ``inputs``."""
        candidates = self.ffn_f(inputs).tanh().split(1, 2)
        fwd_gates = self.ffn_w(inputs).sigmoid().split(1, 2)
        bwd_gates = self.ffn_w2(inputs).sigmoid().split(1, 2)

        left = self._scan(candidates, fwd_gates)
        right = self._scan(candidates[::-1], bwd_gates[::-1])
        right.reverse()  # back to the original slice order

        out = torch.cat(left, dim=2) + torch.cat(right, dim=2)
        return self.sca(out) if self.enable_ca else out
168 |
169 |
170 | """ Mixed Attention Network """
171 |
172 |
class Encoder(nn.Module):
    """Stack of 3-D convolutional stages with optional attention.

    Stage ``i`` keeps the channel count unless ``i`` is in ``sample_idx``, in
    which case it doubles the channels and uses stride ``(1, 2, 2)``.

    Args:
        channels: number of channels entering the first stage.
        num_half_layer: number of stages.
        sample_idx: container of stage indices that downsample.
        Attn: optional module class called as ``Attn(ch, reverse=bool)``. When
            ``None`` each stage consists of the convolution only.
    """

    def __init__(self, channels, num_half_layer, sample_idx, Attn=None):
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(num_half_layer):
            if i in sample_idx:
                out_channels = channels * 2
                conv = nn.Conv3d(channels, out_channels, 3, (1, 2, 2), 1, bias=False)
            else:
                out_channels = channels
                conv = nn.Conv3d(channels, out_channels, 3, 1, 1, bias=False)

            # Build the stage as (name, module) pairs and add the attention
            # entry only when requested: a bare None inside the list would
            # make OrderedDict raise TypeError.
            stages = [('conv', conv)]
            if Attn is not None:
                # Odd-indexed stages get reverse=True.
                stages.append(('attn', Attn(out_channels, reverse=i % 2 == 1)))

            self.layers.append(nn.Sequential(OrderedDict(stages)))
            channels = out_channels

    def forward(self, x, xs):
        """Encode ``x``, appending every stage output except the last to ``xs``.

        Args:
            x: input tensor.
            xs: list mutated in place; it receives the skip features.

        Returns:
            Output of the final stage.
        """
        for layer in self.layers[:-1]:
            x = layer(x)
            xs.append(x)
        return self.layers[-1](x)
198 |
199 |
class Decoder(nn.Module):
    """Mirror of the encoder that consumes skip features in reverse order.

    Stages are built for ``i`` running from ``num_half_layer - 1`` down to 0.
    A stage whose index is in ``sample_idx`` upsamples by ``(1, 2, 2)`` and
    halves the channels; the others keep the channel count.

    Args:
        channels: number of channels entering the first decoder stage.
        num_half_layer: number of stages.
        sample_idx: container of stage indices that upsample.
        Fusion: optional module class called as ``Fusion(ch)`` and applied as
            ``fusion(x, skip)``. When ``None`` skips are added element-wise.
        Attn: optional module class called as ``Attn(ch, reverse=bool)``. When
            ``None`` each stage has no attention module.
    """

    def __init__(self, channels, num_half_layer, sample_idx, Fusion=None, Attn=None):
        super().__init__()

        self.layers = nn.ModuleList()
        self.enable_fusion = Fusion is not None

        if self.enable_fusion:
            # One fusion module per stage. forward() only indexes entries from
            # 1 upward; entry 0 is kept so parameter names stay unchanged.
            self.fusions = nn.ModuleList()
            ch = channels
            for i in reversed(range(num_half_layer)):
                self.fusions.append(Fusion(ch))
                if i in sample_idx:
                    ch //= 2

        for i in reversed(range(num_half_layer)):
            # Collect (name, module) pairs and add attention only on request:
            # a bare None inside the list would make OrderedDict raise.
            if i in sample_idx:
                out_channels = channels // 2
                stages = [
                    ('up', nn.Upsample(scale_factor=(1, 2, 2), mode='trilinear', align_corners=True)),
                    ('conv', nn.Conv3d(channels, out_channels, 3, 1, 1, bias=False)),
                ]
            else:
                out_channels = channels
                stages = [
                    ('conv', nn.ConvTranspose3d(channels, out_channels, 3, 1, 1, bias=False)),
                ]
            if Attn is not None:
                # Even-indexed stages get reverse=True.
                stages.append(('attn', Attn(out_channels, reverse=i % 2 == 0)))

            self.layers.append(nn.Sequential(OrderedDict(stages)))
            channels = out_channels

    def forward(self, x, xs):
        """Decode ``x``, popping one skip feature from ``xs`` before each stage
        after the first.

        Args:
            x: encoder output.
            xs: list of skip features, mutated in place by ``pop()``.

        Returns:
            Output of the final decoder stage.
        """
        x = self.layers[0](x)
        for i in range(1, len(self.layers)):
            skip = xs.pop()
            if self.enable_fusion:
                x = self.fusions[i](x, skip)
            else:
                x = x + skip
            x = self.layers[i](x)
        return x
241 |
242 |
class MAN(nn.Module):
    """Mixed Attention Network: head -> encoder -> decoder -> tail with skips.

    Args:
        in_channels: channels of the input (and output) tensor.
        channels: base feature width produced by the head.
        num_half_layer: number of encoder stages (and of decoder stages).
        sample_idx: stage indices that down/upsample. Only membership tests and
            ``len()`` are applied to it, so any sized container works.
        Attn: attention class forwarded to the encoder and decoder, or None.
        BiAttn: attention class used in head and tail as ``BiAttn(ch)``, or
            None to use the plain convolutions only.
        Fusion: skip-fusion class forwarded to the decoder, or None.
    """

    def __init__(
        self,
        in_channels=1,
        channels=16,
        num_half_layer=5,
        sample_idx=(1, 3),  # immutable default instead of a shared list
        Attn=MAB,
        BiAttn=BiMAB,
        Fusion=AdaptiveSkipConnection,
    ):
        super().__init__()

        # Only real modules go into the Sequential containers: a None child is
        # accepted at construction but is then called during forward.
        head = [nn.Conv3d(in_channels, channels, 3, 1, 1, bias=False)]
        if BiAttn is not None:
            head.append(BiAttn(channels))
        self.head = nn.Sequential(*head)

        self.encoder = Encoder(channels, num_half_layer, sample_idx, Attn)
        # Every sampling stage doubles the width, so the decoder starts at
        # channels * 2 ** len(sample_idx).
        self.decoder = Decoder(
            channels * (2 ** len(sample_idx)),
            num_half_layer,
            sample_idx,
            Fusion=Fusion,
            Attn=Attn,
        )

        tail = [nn.ConvTranspose3d(channels, in_channels, 3, 1, 1, bias=True)]
        if BiAttn is not None:
            tail.append(BiAttn(in_channels))
        self.tail = nn.Sequential(*tail)

    def forward(self, x):
        """Return the network output; it has the same shape as ``x`` because the
        input is added back as the final residual."""
        xs = [x]                         # skip stack: network input first
        out = self.head(x)
        xs.append(out)                   # then the head features
        out = self.encoder(out, xs)      # encoder pushes its own skips
        out = self.decoder(out, xs)      # decoder pops exactly those
        out = out + xs.pop()             # residual with the head features
        out = self.tail(out)
        out = out + xs.pop()             # residual with the input
        return out
280 |
281 |
282 | """ Model Variants """
283 |
284 |
def man():
    """Build a MAN with 1 input channel, 16 base channels, 5 half layers and
    sampling at stages 1 and 3.

    The returned module carries ``use_2dconv`` and ``bandwise`` attributes,
    both set to False (their consumers are outside this block).
    """
    model = MAN(in_channels=1, channels=16, num_half_layer=5, sample_idx=[1, 3])
    model.use_2dconv, model.bandwise = False, False
    return model
290 |
291 |
def man_m():
    """Build a MAN with 1 input channel, 12 base channels, 5 half layers and
    sampling at stages 1 and 3.

    The returned module carries ``use_2dconv`` and ``bandwise`` attributes,
    both set to False (their consumers are outside this block).
    """
    model = MAN(in_channels=1, channels=12, num_half_layer=5, sample_idx=[1, 3])
    model.use_2dconv, model.bandwise = False, False
    return model
297 |
298 |
def man_s():
    """Build a MAN with 1 input channel, 8 base channels, 5 half layers and
    sampling at stages 1 and 3.

    The returned module carries ``use_2dconv`` and ``bandwise`` attributes,
    both set to False (their consumers are outside this block).
    """
    model = MAN(in_channels=1, channels=8, num_half_layer=5, sample_idx=[1, 3])
    model.use_2dconv, model.bandwise = False, False
    return model
304 |
305 |
def man_b():
    """Build a MAN with 1 input channel, 24 base channels, 5 half layers and
    sampling at stages 1 and 3.

    The returned module carries ``use_2dconv`` and ``bandwise`` attributes,
    both set to False (their consumers are outside this block).
    """
    model = MAN(in_channels=1, channels=24, num_half_layer=5, sample_idx=[1, 3])
    model.use_2dconv, model.bandwise = False, False
    return model
311 |
312 |
def man_deep():
    """Build a MAN with 1 input channel, 16 base channels, 7 half layers and
    sampling at stages 1, 3 and 5.

    The returned module carries ``use_2dconv`` and ``bandwise`` attributes,
    both set to False (their consumers are outside this block).
    """
    model = MAN(in_channels=1, channels=16, num_half_layer=7, sample_idx=[1, 3, 5])
    model.use_2dconv, model.bandwise = False, False
    return model
318 |
319 |
320 | """ Baseline """
321 |
322 |
def baseline():
    """Build the baseline MAN: default sizes (1 input channel, 16 base
    channels, 5 half layers, sampling at stages 1 and 3) with attention,
    bidirectional attention and skip fusion all disabled (passed as None).

    The returned module carries ``use_2dconv`` and ``bandwise`` attributes,
    both set to False (their consumers are outside this block).
    """
    model = MAN(
        1, 16, 5, [1, 3],
        Attn=None,
        BiAttn=None,
        Fusion=None,
    )
    model.use_2dconv, model.bandwise = False, False
    return model
328 |
--------------------------------------------------------------------------------
/hsir/model/hsdt/attention.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from einops import rearrange
8 |
9 |
class SSA(nn.Module):
    """ Spectral Self Attention (SSA) in "Hybrid Spectral Denoising Transformer with Learnable Query"
        GSSA without guidance.

    The band-to-band attention is computed from the spatially averaged
    features of the input itself, so it is a D x D matrix for whatever band
    count D the input has.

    Args:
        channel: number of feature channels C.
        num_bands: stored on the instance; not needed by the computation.
        flex: accepted for signature parity with GSSA and ignored.
    """

    def __init__(self, channel, num_bands, flex=False):
        super().__init__()
        self.channel = channel
        self.num_bands = num_bands

        self.value_proj = nn.Linear(channel, channel, bias=False)
        self.fc = nn.Linear(channel, channel, bias=False)

    def forward(self, x):
        """Apply spectral self attention.

        Args:
            x: tensor of shape (B, C, D, H, W).

        Returns:
            Tuple ``(out, attn)``: ``out`` has the shape of ``x`` (attention
            output plus the residual) and ``attn`` has shape (B, 1, 1, D, D).
        """
        B, C, D, H, W = x.shape

        residual = x

        # Spatially averaged descriptor of every band: (B, D, C).
        tmp = x.reshape(B, C, D, H * W).mean(-1).permute(0, 2, 1)

        # The similarity matrix is already (B, D, D); it is not forced to
        # num_bands, which would only reject inputs with a different band count.
        attn = tmp @ tmp.transpose(1, 2)
        attn = F.softmax(attn, dim=-1)  # B, band, band
        attn = attn.unsqueeze(1).unsqueeze(1)  # broadcast over H and W

        # 'b c d h w -> b h w d c' so the linear layers act on channels.
        v = self.value_proj(x.permute(0, 3, 4, 2, 1))

        q = torch.matmul(attn, v)

        q = self.fc(q)
        # 'b h w d c -> b c d h w'
        q = q.permute(0, 4, 3, 1, 2)

        q += residual

        return q, attn
44 |
45 |
class GSSA(nn.Module):
    """ Guided Spectral Self Attention (GSSA) in "Hybrid Spectral Denoising Transformer with Learnable Query"

    Two ways of producing the band-to-band attention logits exist:

    * learned: ``attn_proj`` maps each band descriptor to ``num_bands``
      logits, which requires the input to have exactly ``num_bands`` bands;
    * data-driven: similarity between band descriptors, giving a D x D
      matrix for any band count D.

    During training one of the two is picked at random per call. In eval mode
    the data-driven form is used when ``flex`` is truthy, the learned one
    otherwise.

    Args:
        channel: number of feature channels C.
        num_bands: number of bands the learned query is built for.
        flex: use the data-driven attention at evaluation time.
    """

    def __init__(self, channel, num_bands, flex=False):
        super().__init__()
        self.channel = channel
        self.num_bands = num_bands
        self.flex = flex

        # learnable query
        self.attn_proj = nn.Linear(channel, num_bands)
        self.value_proj = nn.Linear(channel, channel, bias=False)
        self.fc = nn.Linear(channel, channel, bias=False)

    def forward(self, x):
        """Apply guided spectral self attention.

        Args:
            x: tensor of shape (B, C, D, H, W).

        Returns:
            Tuple ``(out, attn)``: ``out`` has the shape of ``x`` (attention
            output plus the residual) and ``attn`` has shape (B, 1, 1, D, D).

        Raises:
            RuntimeError: on the learned path when D differs from
                ``num_bands`` (raised by the reshape).
        """
        B, C, D, H, W = x.shape

        residual = x

        # Spatially averaged descriptor of every band: (B, D, C).
        tmp = x.reshape(B, C, D, H * W).mean(-1).permute(0, 2, 1)

        if self.training:
            # Randomly alternate between the two attention sources.
            data_driven = random.random() > 0.5
        else:
            data_driven = self.flex

        if data_driven:
            # Already (B, D, D); no dependence on num_bands, so inputs with a
            # different band count are supported on this path.
            attn = tmp @ tmp.transpose(1, 2)
        else:
            attn = self.attn_proj(tmp)
            # The learned query only fits inputs with num_bands bands.
            attn = attn.reshape(B, self.num_bands, self.num_bands)

        attn = F.softmax(attn, dim=-1)  # B, band, band
        attn = attn.unsqueeze(1).unsqueeze(1)  # broadcast over H and W

        # 'b c d h w -> b h w d c' so the linear layers act on channels.
        v = self.value_proj(x.permute(0, 3, 4, 2, 1))

        q = torch.matmul(attn, v)

        q = self.fc(q)
        # 'b h w d c -> b c d h w'
        q = q.permute(0, 4, 3, 1, 2)

        q += residual

        return q, attn
93 |
94 |
class PixelwiseGSSA(GSSA):
    """ Pixelwise GSSA

    Computes one band-to-band attention matrix per spatial location instead
    of one per sample.

    Args:
        channel: number of feature channels C.
        num_bands: number of bands the learned query is built for.
        flex: accepted so the class can be constructed with the same keyword
            arguments as ``GSSA``; it is stored by the parent but does not
            influence ``forward`` here.
    """

    def __init__(self, channel, num_bands, flex=False):
        super().__init__(channel, num_bands, flex=flex)

    def forward(self, x):
        """Apply per-pixel spectral attention.

        Args:
            x: tensor of shape (B, C, D, H, W).

        Returns:
            Tuple ``(out, attn)``: ``out`` has the shape of ``x``; ``attn``
            has one matrix per sample and pixel, shape (B, H*W, D, K), with
            K = D on the data-driven path and K = ``num_bands`` on the learned
            one.
        """
        B, C, D, H, W = x.shape
        x = rearrange(x, 'b c d h w -> b (h w) d c')

        residual = x

        # Training alternates at random between data-driven and learned
        # logits; evaluation always uses the data-driven similarity.
        if self.training and random.random() <= 0.5:
            attn = self.attn_proj(x)
        else:
            attn = x @ x.transpose(-1, -2)

        attn = F.softmax(attn, dim=-1)

        v = self.value_proj(x)
        q = torch.matmul(attn, v)

        q = self.fc(q)
        q += residual

        q = rearrange(q, 'b (h w) d c -> b c d h w', d=D, h=H, w=W)

        return q, attn
126 |
127 |
class GSSAConvImpl(GSSA):
    """ GSSA fast convolutional implementation

    Always uses the learned query (``attn_proj``) and applies the resulting
    band-mixing matrix as a 1x1 convolution whose weights are the attention
    values.

    Args:
        channel: number of feature channels C.
        num_bands: number of bands the learned query is built for.
        flex: accepted for signature parity with ``GSSA``; not forwarded to
            the parent and not used.
    """

    def __init__(self, channel, num_bands, flex=False):
        super().__init__(channel, num_bands)

    def forward(self, x):
        """Apply learned-query spectral attention via convolution.

        Args:
            x: tensor of shape (B, C, D, H, W).

        Returns:
            Tuple ``(out, attn)``: ``out`` has the shape of ``x``; ``attn`` is
            the softmaxed output of ``attn_proj``, shape (B, D, num_bands).
        """
        B, C, D, H, W = x.shape

        x = rearrange(x, 'b c d h w -> b h w d c')
        residual = x

        tmp = rearrange(x, 'b h w d c -> b (h w) d c').mean(1)  # b, band, c
        attn = self.attn_proj(tmp)  # b, band, num_bands
        attn = F.softmax(attn, dim=-1)

        v = self.value_proj(x)  # b, h, w, band, c

        if B > 1:
            # One convolution group per sample: sample b mixes its own D bands
            # with its own D x D attention matrix.
            stacked = rearrange(v, 'b h w d c -> (b d) c h w').unsqueeze(0)
            weight = attn.reshape(B * D, D, 1, 1, 1)
            q = F.conv3d(stacked, weight, groups=B)  # 1, b*d, c, h, w
            q = rearrange(q.squeeze(0), '(b d) c h w -> b h w d c', b=B)
        else:
            # Single sample: channels act as the conv batch, bands as the conv
            # channels.
            stacked = rearrange(v, 'b h w d c -> c (b d) h w')
            weight = attn.reshape(D, D, 1, 1)
            q = F.conv2d(stacked, weight)  # c, b*d, h, w
            # Axes must be labelled in their real order (h before w);
            # labelling them 'w h' spatially transposes the result.
            q = rearrange(q, 'c (b d) h w -> b h w d c', b=B)

        q = self.fc(q)
        q += residual

        q = rearrange(q, 'b h w d c -> b c d h w')
        return q, attn
162 |
163 |
164 | """ Feedforward
165 | """
166 |
167 |
class SMFFN(nn.Module):
    """ Self Modulated Feed Forward Network (SM-FFN) in "Hybrid Spectral Denoising Transformer with Learnable Query"

    Sums two branches computed from the same input:

    * a two-layer FFN ``w_2(gelu(w_1(x)))`` with hidden width ``d_ff``;
    * a self-modulated branch ``a * sigmoid(g)`` where ``(a, g)`` are the two
      halves of ``w_3(x)``.

    Args:
        d_model: size of the last (feature) axis of input and output.
        d_ff: hidden width of the FFN branch.
        bias: whether the linear layers carry a bias.
    """

    def __init__(self, d_model, d_ff, bias=False):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff, bias=bias)
        self.w_2 = nn.Linear(d_ff, d_model, bias=bias)
        # Each half of w_3's output must be d_model wide to be added to the
        # FFN branch, so its width is 2 * d_model regardless of d_ff. This is
        # the same shape as before whenever d_ff == 2 * d_model.
        self.w_3 = nn.Linear(d_model, 2 * d_model, bias=bias)

    def forward(self, input):
        """Return FFN branch plus self-modulated branch; same shape as ``input``."""
        x1 = self.w_2(F.gelu(self.w_1(input)))

        value, gate = torch.chunk(self.w_3(input), 2, dim=-1)
        x2 = value * torch.sigmoid(gate)

        return x1 + x2
188 |
189 |
class SMFFNBranch(nn.Module):
    """Gated linear unit: one projection split into a value and a sigmoid gate.

    Args:
        d_model: size of the last (feature) axis of the input.
        d_ff: output width of the projection; the result has half of it.
        bias: whether the linear layer carries a bias.
    """

    def __init__(self, d_model, d_ff, bias=False):
        super().__init__()
        self.w_3 = nn.Linear(d_model, d_ff, bias=bias)

    def forward(self, input):
        """Return ``value * sigmoid(gate)`` for the two halves of ``w_3(input)``."""
        value, gate = self.w_3(input).chunk(2, dim=-1)
        return value * gate.sigmoid()
200 |
201 |
class GDFN(nn.Module):
    """ 3D version of GDFN from Restormer.

    Pipeline: 1x1x1 projection to ``d_ff`` channels, depth-wise 3x3x3
    convolution, split into two halves, ``gelu(first) * second``, then a
    1x1x1 projection back to ``d_model`` channels.

    Args:
        d_model: number of input and output features.
        d_ff: width of the expanded representation; must be even because it
            is split into two equal halves.
        bias: whether the convolutions carry a bias.

    Raises:
        ValueError: if ``d_ff`` is odd.
    """

    def __init__(self, d_model, d_ff, bias=False):
        super().__init__()
        if d_ff % 2 != 0:
            raise ValueError(f"d_ff must be even to be split into two halves, got {d_ff}")
        self.project_in = nn.Conv3d(d_model, d_ff, kernel_size=1, bias=bias)
        self.dwconv = nn.Conv3d(d_ff, d_ff, kernel_size=3, stride=1, padding=1, groups=d_ff, bias=bias)
        # The gated product has d_ff // 2 channels, so that is what the output
        # projection consumes (equal to d_model when d_ff == 2 * d_model).
        self.project_out = nn.Conv3d(d_ff // 2, d_model, kernel_size=1, bias=bias)

    def forward(self, input):
        """Apply the gated feed-forward network.

        Args:
            input: channels-last tensor of shape (B, D, H, W, C) with
                C == ``d_model``.

        Returns:
            Tensor of the same shape as ``input``.
        """
        # 'b d h w c -> b c d h w' for the convolutions.
        x = input.permute(0, 4, 1, 2, 3)
        x = self.project_in(x)
        x1, x2 = self.dwconv(x).chunk(2, dim=1)
        x = F.gelu(x1) * x2
        x = self.project_out(x)
        # 'b c d h w -> b d h w c'
        return x.permute(0, 2, 3, 4, 1)
220 |
221 |
class FFN(nn.Module):
    """Plain two-layer feed-forward network with a GELU in between.

    Args:
        d_model: size of the last (feature) axis of input and output.
        d_ff: hidden width.
        bias: whether the linear layers carry a bias.
    """

    def __init__(self, d_model, d_ff, bias=False):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff, bias=bias)
        self.w_2 = nn.Linear(d_ff, d_model, bias=bias)
        self.act = nn.GELU()

    def forward(self, input):
        """Return ``w_2(gelu(w_1(input)))``."""
        hidden = self.w_1(input)
        hidden = self.act(hidden)
        return self.w_2(hidden)
232 |
233 |
234 | """ Transformer Block
235 | """
236 |
237 |
class TransformerBlock(nn.Module):
    """Spectral attention (GSSA) followed by a feed-forward network (SMFFN).

    Args:
        channels: number of feature channels.
        num_bands: band count passed to the attention module.
        bias: whether the feed-forward linear layers carry a bias.
        flex: passed to the attention module.
    """

    def __init__(self, channels, num_bands=31, bias=False, flex=False):
        super().__init__()
        self.channels = channels
        self.attn = GSSA(channels, num_bands, flex=flex)
        self.ffn = SMFFN(channels, channels * 2, bias=bias)

    def forward(self, inputs):
        """Run attention then the FFN on a (B, C, D, H, W) tensor.

        The attention map returned by ``self.attn`` is discarded. The FFN
        operates channels-last, hence the two layout changes.
        """
        attended, _ = self.attn(inputs)
        channels_last = rearrange(attended, 'b c d h w -> b d h w c')
        mixed = self.ffn(channels_last)
        return rearrange(mixed, 'b d h w c -> b c d h w')
251 |
252 |
class DummyTransformerBlock(nn.Module):
    """Parameter-free stand-in for a transformer block: applies tanh only.

    ``num_bands``, ``bias`` and ``flex`` are accepted so the constructor
    matches ``TransformerBlock``; none of them is used.
    """

    def __init__(self, channels, num_bands=31, bias=False, flex=False):
        super().__init__()
        self.channels = channels

    def forward(self, inputs):
        """Return the element-wise tanh of ``inputs``."""
        return torch.tanh(inputs)
260 |
261 |
262 | """ Ablation
263 | """
264 |
265 |
class PixelwiseTransformerBlock(TransformerBlock):
    """Ablation variant: per-pixel attention (PixelwiseGSSA) with SMFFN.

    Args:
        channels: number of feature channels.
        num_bands: band count passed to the attention module.
        bias: whether the feed-forward linear layers carry a bias.
        flex: forwarded to the parent constructor only; the pixel-wise
            attention built here is constructed without it.
    """

    def __init__(self, channels, num_bands=31, bias=False, flex=False):
        super().__init__(channels, num_bands, bias, flex)
        self.channels = channels
        # Built from positional arguments only: passing flex= to a
        # PixelwiseGSSA constructor that takes just (channel, num_bands)
        # raises TypeError.
        self.attn = PixelwiseGSSA(channels, num_bands)
        self.ffn = SMFFN(channels, channels * 2, bias=bias)
272 |
273 |
class FFNTransformerBlock(TransformerBlock):
    """Ablation variant: GSSA attention with a plain FFN feed-forward.

    Takes the same constructor arguments as ``TransformerBlock``; the modules
    created by the parent are replaced after ``super().__init__``.
    """

    def __init__(self, channels, num_bands=31, bias=False, flex=False):
        super().__init__(channels, num_bands, bias, flex)
        hidden = channels * 2
        self.channels = channels
        self.attn = GSSA(channels, num_bands, flex=flex)
        self.ffn = FFN(channels, hidden, bias=bias)
280 |
281 |
class GDFNTransformerBlock(TransformerBlock):
    """Ablation variant: GSSA attention with a GDFN feed-forward.

    Takes the same constructor arguments as ``TransformerBlock``; the modules
    created by the parent are replaced after ``super().__init__``.
    """

    def __init__(self, channels, num_bands=31, bias=False, flex=False):
        super().__init__(channels, num_bands, bias, flex)
        hidden = channels * 2
        self.channels = channels
        self.attn = GSSA(channels, num_bands, flex=flex)
        self.ffn = GDFN(channels, hidden, bias=bias)
288 |
289 |
class GFNTransformerBlock(TransformerBlock):
    """Ablation variant: GSSA attention with only the gated branch (SMFFNBranch).

    Takes the same constructor arguments as ``TransformerBlock``; the modules
    created by the parent are replaced after ``super().__init__``.
    """

    def __init__(self, channels, num_bands=31, bias=False, flex=False):
        super().__init__(channels, num_bands, bias, flex)
        hidden = channels * 2
        self.channels = channels
        self.attn = GSSA(channels, num_bands, flex=flex)
        self.ffn = SMFFNBranch(channels, hidden, bias=bias)
296 |
297 |
class SSATransformerBlock(TransformerBlock):
    """Ablation variant: unguided SSA attention with an SMFFN feed-forward.

    Takes the same constructor arguments as ``TransformerBlock``; the modules
    created by the parent are replaced after ``super().__init__``.
    """

    def __init__(self, channels, num_bands=31, bias=False, flex=False):
        super().__init__(channels, num_bands, bias, flex)
        hidden = channels * 2
        self.channels = channels
        self.attn = SSA(channels, num_bands, flex=flex)
        self.ffn = SMFFN(channels, hidden, bias=bias)
304 |
305 |
class SSAFFNTransformerBlock(TransformerBlock):
    """Ablation variant: unguided SSA attention with a plain FFN feed-forward.

    Takes the same constructor arguments as ``TransformerBlock``; the modules
    created by the parent are replaced after ``super().__init__``.
    """

    def __init__(self, channels, num_bands=31, bias=False, flex=False):
        super().__init__(channels, num_bands, bias, flex)
        hidden = channels * 2
        self.channels = channels
        self.attn = SSA(channels, num_bands, flex=flex)
        self.ffn = FFN(channels, hidden, bias=bias)
312 |
--------------------------------------------------------------------------------