├── benchmark ├── README.md └── datasets │ ├── ICVL │ └── todo.txt │ ├── RealHSI │ ├── real_test.txt │ ├── real_train.txt │ ├── build_test.py │ └── real_build_lmdb.py │ ├── Harvard │ ├── harvard_test.txt │ ├── harvard_train.txt │ ├── harvard_meta.txt │ ├── harvard_build_test.py │ └── harvard_build_lmdb.py │ └── CAVE │ ├── cave_test.txt │ ├── cave_train.txt │ ├── cave_meta.txt │ ├── cave_preprocess.py │ ├── cave_build_test.py │ └── cave_build_lmdb.py ├── hsir ├── model │ ├── man │ │ ├── __init__.py │ │ ├── v2.py │ │ └── v1.py │ ├── t3sc │ │ ├── utils │ │ │ ├── __init__.py │ │ │ └── patches_handler.py │ │ ├── metrics │ │ │ ├── __init__.py │ │ │ └── utils.py │ │ ├── layers │ │ │ ├── __init__.py │ │ │ ├── soft_thresholding.py │ │ │ ├── encoding_layer.py │ │ │ └── lowrank_sc_layer.py │ │ ├── __init__.py │ │ ├── configs │ │ │ └── t3sc.yaml │ │ ├── multilayer.py │ │ └── base.py │ ├── __init__.py │ ├── trq3d │ │ ├── __init__.py │ │ ├── conv.py │ │ └── combinations.py │ ├── qrnn3d │ │ ├── __init__.py │ │ ├── conv.py │ │ ├── qrnn.py │ │ └── net.py │ ├── denet.py │ ├── memnet.py │ ├── hsidcnn.py │ ├── unet3d.py │ ├── grunet.py │ └── hsdt │ │ ├── sepconv.py │ │ ├── __init__.py │ │ ├── arch.py │ │ └── attention.py ├── data │ ├── transform │ │ ├── __init__.py │ │ ├── general.py │ │ ├── sr.py │ │ └── noise.py │ ├── __init__.py │ ├── utils.py │ ├── dataloader.py │ └── dataset.py ├── __init__.py ├── scheduler.py └── schedule.py ├── docs ├── requirements.txt ├── source │ ├── getstart.md │ ├── dataset.md │ ├── index.md │ └── conf.py ├── Makefile └── make.bat ├── hsiboard ├── app.py ├── main.py ├── table.py ├── viewer.py ├── box.py ├── util.py └── cmp.py ├── setup.py ├── .readthedocs.yaml ├── LICENSE ├── hsirun ├── benchmark.py ├── train_ssr.py ├── train.py └── test.py └── .gitignore /benchmark/README.md: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /hsir/model/man/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /benchmark/datasets/ICVL/todo.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /hsir/data/transform/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /hsir/model/t3sc/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .patches_handler import PatchesHandler 2 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | myst-parser 2 | sphinx-rtd-theme 3 | furo 4 | sphinx-copybutton 5 | sphinx-inline-tabs -------------------------------------------------------------------------------- /hsir/model/t3sc/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .metrics import mergas, mfsim, mpsnr, msam, mse, mssim, psnr 2 | -------------------------------------------------------------------------------- /hsir/__init__.py: -------------------------------------------------------------------------------- 1 | from . import scheduler 2 | from . import schedule 3 | from . 
import trainer 4 | from . import model 5 | from . import data -------------------------------------------------------------------------------- /hsir/data/__init__.py: -------------------------------------------------------------------------------- 1 | from . import dataset 2 | from . import transform 3 | from .dataset import HSITestDataset, HSITrainDataset, HSITransformTestDataset 4 | from . import ssr -------------------------------------------------------------------------------- /hsir/model/t3sc/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .lowrank_sc_layer import LowRankSCLayer 2 | from .encoding_layer import EncodingLayer 3 | from .soft_thresholding import SoftThresholding 4 | -------------------------------------------------------------------------------- /benchmark/datasets/RealHSI/real_test.txt: -------------------------------------------------------------------------------- 1 | 41 2 | 42 3 | 43 4 | 44 5 | 45 6 | 46 7 | 47 8 | 48 9 | 49 10 | 50 11 | 51 12 | 52 13 | 53 14 | 54 15 | 55 16 | 56 17 | 57 18 | 58 19 | 59 -------------------------------------------------------------------------------- /hsir/model/__init__.py: -------------------------------------------------------------------------------- 1 | from . import qrnn3d 2 | from . import t3sc 3 | from . import grunet 4 | from . import restormer 5 | from . import uformer 6 | from . import unet3d 7 | from . import hsidcnn -------------------------------------------------------------------------------- /benchmark/datasets/Harvard/harvard_test.txt: -------------------------------------------------------------------------------- 1 | imgc7 2 | imgd3 3 | imgd9 4 | imge3 5 | imge7 6 | imgf4 7 | imgf8 8 | imgh3 9 | imga1 10 | imga7 11 | imgb3 12 | imgb7 13 | imgc2 14 | imgc8 15 | imgd4 16 | imge0 17 | imge4 18 | imgf1 19 | imgf5 20 | imgh0 -------------------------------------------------------------------------------- /benchmark/datasets/CAVE/cave_test.txt: -------------------------------------------------------------------------------- 1 | thread_spools_ms 2 | jelly_beans_ms 3 | beads_ms 4 | fake_and_real_peppers_ms 5 | face_ms 6 | pompoms_ms 7 | flowers_ms 8 | photo_and_face_ms 9 | fake_and_real_lemons_ms 10 | stuffed_toys_ms 11 | glass_tiles_ms 12 | oil_painting_ms -------------------------------------------------------------------------------- /benchmark/datasets/RealHSI/real_train.txt: -------------------------------------------------------------------------------- 1 | 1 2 | 2 3 | 3 4 | 4 5 | 5 6 | 6 7 | 7 8 | 8 9 | 9 10 | 10 11 | 11 12 | 12 13 | 13 14 | 14 15 | 15 16 | 16 17 | 17 18 | 18 19 | 19 20 | 20 21 | 21 22 | 22 23 | 23 24 | 24 25 | 25 26 | 26 27 | 27 28 | 28 29 | 29 30 | 30 31 | 31 32 | 32 33 | 33 34 | 34 35 | 35 36 | 36 37 | 37 38 | 38 39 | 39 40 | 40 -------------------------------------------------------------------------------- /benchmark/datasets/Harvard/harvard_train.txt: -------------------------------------------------------------------------------- 1 | imga2 2 | imgb0 3 | imgb4 4 | imgb8 5 | imgc4 6 | imgc9 7 | imgd7 8 | imge1 9 | imge5 10 | imgf2 11 | imgf6 12 | imgh1 13 | img1 14 | imga5 15 | imgb1 16 | imgb5 17 | imgb9 18 | imgc5 19 | imgd2 20 | imgd8 21 | imge2 22 | imge6 23 | imgf3 24 | imgf7 25 | imgh2 26 | img2 27 | imga6 28 | imgb2 29 | imgb6 30 | imgc1 -------------------------------------------------------------------------------- /hsiboard/app.py: -------------------------------------------------------------------------------- 1 | import 
argparse 2 | import os 3 | import subprocess 4 | 5 | 6 | if __name__ == '__main__': 7 | parser = argparse.ArgumentParser('HSIR Board') 8 | parser.add_argument('--logdir', default='results') 9 | args = parser.parse_args() 10 | 11 | path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'main.py') 12 | subprocess.call(['streamlit', 'run', path, '--', '--logdir', args.logdir]) 13 | -------------------------------------------------------------------------------- /benchmark/datasets/CAVE/cave_train.txt: -------------------------------------------------------------------------------- 1 | fake_and_real_sushi_ms 2 | watercolors_ms 3 | chart_and_stuffed_toy_ms 4 | superballs_ms 5 | real_and_fake_peppers_ms 6 | fake_and_real_tomatoes_ms 7 | balloons_ms 8 | fake_and_real_beers_ms 9 | real_and_fake_apples_ms 10 | cloth_ms 11 | feathers_ms 12 | hairs_ms 13 | paints_ms 14 | sponges_ms 15 | clay_ms 16 | fake_and_real_food_ms 17 | fake_and_real_strawberries_ms 18 | egyptian_statue_ms 19 | fake_and_real_lemon_slices_ms 20 | cd_ms -------------------------------------------------------------------------------- /hsir/data/utils.py: -------------------------------------------------------------------------------- 1 | import torchlight as tl 2 | import torch 3 | import numpy as np 4 | 5 | 6 | def visualize_gray(hsi, band=20): 7 | return hsi[:, band:band + 1, :, :] 8 | 9 | 10 | def visualize_color(hsi): 11 | srf = tl.transforms.HSI2RGB().srf 12 | srf = torch.from_numpy(srf).float().to(hsi.device) 13 | hsi = hsi.permute(0, 2, 3, 1) @ srf.T 14 | return hsi.permute(0, 3, 1, 2) 15 | 16 | 17 | def worker_init_fn(worker_id): 18 | np.random.seed(np.random.get_state()[1][0] + worker_id) 19 | -------------------------------------------------------------------------------- /hsir/model/t3sc/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .multilayer import MultilayerModel 3 | 4 | 5 | def build_t3sc(num_band): 6 | from omegaconf import OmegaConf 7 | CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) 8 | cfg_path = os.path.join(CURRENT_DIR, 'configs', 't3sc.yaml') 9 | cfg = OmegaConf.load(cfg_path) 10 | cfg.params.channels = num_band 11 | net = MultilayerModel(**cfg.params) 12 | net.use_2dconv = True 13 | net.bandwise = False 14 | return net 15 | 16 | 17 | def t3sc(): 18 | return build_t3sc(31) 19 | -------------------------------------------------------------------------------- /hsir/model/trq3d/__init__.py: -------------------------------------------------------------------------------- 1 | # the original version, with FCU before and after 2 | from functools import partial 3 | from .trq3d import * 4 | 5 | 6 | TRQ3D = partial( 7 | TRQ3D, 8 | Encoder=TRQ3DEncoder, 9 | Decoder=TRQ3DDecoder 10 | ) 11 | 12 | 13 | def trq3d(): 14 | net = TRQ3D(in_channels=1, in_channels_tr=31, channels=16, channels_tr=16, med_channels=31, num_half_layer=4, sample_idx=[1, 3], has_ad=True, 15 | input_resolution=(512, 512)) 16 | net.use_2dconv = False 17 | net.bandwise = False 18 | return net 19 | -------------------------------------------------------------------------------- /docs/source/getstart.md: -------------------------------------------------------------------------------- 1 | # Getting Started 2 | 3 | ```{toctree} 4 | :maxdepth: 2 5 | ``` 6 | 7 | 8 | [![PyPI](https://img.shields.io/pypi/v/hsir)](https://pypi.org/project/hsir/) 9 | 10 | Out-of-box Hyperspectral Image Restoration Toolbox 11 | 12 | 13 | ## Install 14 | 15 | ```shell 16 | pip install hsir 17 | 
``` 18 | 19 | ## Usage 20 | 21 | Here are some runnable examples; please refer to the code for more options. 22 | 23 | ```shell 24 | python hsirun/train.py -a qrnn3d.qrnn3d 25 | python hsirun/test.py -a qrnn3d.qrnn3d -r qrnn3d.pth -t icvl_512_50 26 | ```
-------------------------------------------------------------------------------- /benchmark/datasets/Harvard/harvard_meta.txt: -------------------------------------------------------------------------------- 1 | imga2 2 | imgb0 3 | imgb4 4 | imgb8 5 | imgc4 6 | imgc9 7 | imgd7 8 | imge1 9 | imge5 10 | imgf2 11 | imgf6 12 | imgh1 13 | img1 14 | imga5 15 | imgb1 16 | imgb5 17 | imgb9 18 | imgc5 19 | imgd2 20 | imgd8 21 | imge2 22 | imge6 23 | imgf3 24 | imgf7 25 | imgh2 26 | img2 27 | imga6 28 | imgb2 29 | imgb6 30 | imgc1 31 | imgc7 32 | imgd3 33 | imgd9 34 | imge3 35 | imge7 36 | imgf4 37 | imgf8 38 | imgh3 39 | imga1 40 | imga7 41 | imgb3 42 | imgb7 43 | imgc2 44 | imgc8 45 | imgd4 46 | imge0 47 | imge4 48 | imgf1 49 | imgf5 50 | imgh0
-------------------------------------------------------------------------------- /benchmark/datasets/CAVE/cave_meta.txt: -------------------------------------------------------------------------------- 1 | fake_and_real_sushi_ms 2 | watercolors_ms 3 | chart_and_stuffed_toy_ms 4 | superballs_ms 5 | real_and_fake_peppers_ms 6 | fake_and_real_tomatoes_ms 7 | balloons_ms 8 | fake_and_real_beers_ms 9 | real_and_fake_apples_ms 10 | cloth_ms 11 | feathers_ms 12 | hairs_ms 13 | paints_ms 14 | sponges_ms 15 | clay_ms 16 | fake_and_real_food_ms 17 | fake_and_real_strawberries_ms 18 | egyptian_statue_ms 19 | fake_and_real_lemon_slices_ms 20 | cd_ms 21 | thread_spools_ms 22 | jelly_beans_ms 23 | beads_ms 24 | fake_and_real_peppers_ms 25 | face_ms 26 | pompoms_ms 27 | flowers_ms 28 | photo_and_face_ms 29 | fake_and_real_lemons_ms 30 | stuffed_toys_ms 31 | glass_tiles_ms 32 | oil_painting_ms
-------------------------------------------------------------------------------- /hsir/scheduler.py: -------------------------------------------------------------------------------- 1 | 2 | def adjust_learning_rate(optimizer, lr): 3 | for param_group in optimizer.param_groups: 4 | param_group['lr'] = lr 5 | 6 | class MultiStepSetLR: 7 | def __init__(self, optimizer, schedule, epoch=0) -> None: 8 | self.optimizer = optimizer 9 | self.schedule = schedule 10 | self.epoch = epoch 11 | 12 | def step(self): 13 | self.epoch += 1 14 | if self.epoch in self.schedule.keys(): 15 | adjust_learning_rate(self.optimizer, self.schedule[self.epoch]) 16 | 17 | def state_dict(self): 18 | return {'epoch': self.epoch} 19 | 20 | def load_state_dict(self, state_dict): 21 | self.epoch = state_dict['epoch'] 22 |
-------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
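19 | # For example, `make html` is simply routed to `sphinx-build -M html "source" "build"`, 20 | # leaving the rendered pages under build/html (assuming Sphinx is installed).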
21 | %: Makefile 22 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 23 |
-------------------------------------------------------------------------------- /benchmark/datasets/CAVE/cave_preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | import scipy.io as sio 5 | 6 | root = '/home/wzliu/projects/data/cave' 7 | names = os.listdir(root) 8 | names = [n for n in names if n.endswith('ms')] 9 | 10 | save_dir = '/home/wzliu/projects/data/cave_mat' 11 | os.makedirs(save_dir, exist_ok=True) 12 | 13 | for n in names: 14 | data = [] 15 | for b in range(31): 16 | path = os.path.join(root, n, n, '{}_{:0>2d}.png'.format(n, b + 1)) 17 | img = cv2.imread(path, cv2.IMREAD_GRAYSCALE) 18 | data.append(np.expand_dims(img, 2)) 19 | data = np.concatenate(data, axis=2) 20 | 21 | print(n) 22 | sio.savemat(os.path.join(save_dir, n + '.mat'), {'gt': data}) 23 |
-------------------------------------------------------------------------------- /hsiboard/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import streamlit as st 3 | import os 4 | import sys 5 | 6 | sys.path.append(os.path.dirname(os.path.abspath(__file__))) 7 | 8 | import viewer 9 | import table 10 | import cmp 11 | 12 | def main(logdir): 13 | page_names_to_funcs = { 14 | "Viewer": viewer.main, 15 | "Comparator": cmp.main, 16 | "Table": table.main, 17 | } 18 | selected_page = st.sidebar.selectbox("Select a page", page_names_to_funcs.keys()) 19 | page_names_to_funcs[selected_page](logdir) 20 | 21 | 22 | if __name__ == '__main__': 23 | parser = argparse.ArgumentParser('HSIR Board') 24 | parser.add_argument('--logdir', default='results') 25 | args = parser.parse_args() 26 | 27 | main(args.logdir) 28 |
-------------------------------------------------------------------------------- /hsir/model/t3sc/metrics/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | EPS = 1e-12 6 | 7 | 8 | def abs(x): 9 | return torch.sqrt(x[:, :, :, :, 0] ** 2 + x[:, :, :, :, 1] ** 2 + EPS) 10 | 11 | 12 | def real(x): 13 | return x[:, :, :, :, 0] 14 | 15 | 16 | def imag(x): 17 | return x[:, :, :, :, 1] 18 | 19 | 20 | def downsample(img1, img2, maxSize=256): 21 | _, channels, H, W = img1.shape 22 | f = int(max(1, np.round(min(H, W) / maxSize))) 23 | if f > 1: 24 | aveKernel = (torch.ones(channels, 1, f, f) / f ** 2).to(img1.device) 25 | img1 = F.conv2d(img1, aveKernel, stride=f, padding=0, groups=channels) 26 | img2 = F.conv2d(img2, aveKernel, stride=f, padding=0, groups=channels) 27 | return img1, img2 28 |
-------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | 4 | from pathlib import Path 5 | this_directory = Path(__file__).parent 6 | long_description = (this_directory / "README.md").read_text() 7 | 8 | packages = find_packages() 9 | packages.append('hsirun') 10 | packages.append('hsiboard') 11 | 12 | setup( 13 | name='hsir', 14 | description="Hyperspectral Image Restoration Toolbox", 15 | long_description=long_description, 16 | long_description_content_type='text/markdown', 17 | url='https://github.com/Zeqiang-Lai/HSIR', 18 | packages=packages, 19 | package_dir={'hsir': 'hsir', 'hsirun': 'hsirun',
'hsiboard': 'hsiboard'}, 20 | version='0.1.1', 21 | include_package_data=True, 22 | install_requires=['tqdm', 'qqdm', 'timm', 'streamlit', 'torchlights'], 23 | ) 24 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Set the version of Python and other tools you might need 9 | build: 10 | os: ubuntu-20.04 11 | tools: 12 | python: "3.9" 13 | # You can also specify other tool versions: 14 | # nodejs: "16" 15 | # rust: "1.55" 16 | # golang: "1.17" 17 | 18 | # Build documentation in the docs/ directory with Sphinx 19 | sphinx: 20 | configuration: docs/source/conf.py 21 | 22 | # If using Sphinx, optionally build your docs in additional formats such as PDF 23 | # formats: 24 | # - pdf 25 | 26 | # Optionally declare the Python requirements required to build your docs 27 | python: 28 | install: 29 | - requirements: docs/requirements.txt -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.https://www.sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /hsir/model/qrnn3d/__init__.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from .net import QRNN3DDecoder, QRNN3DEncoder, QRNNREDC3D 3 | from .qrnn import QRNNConv3D, QRNNDeConv3D, QRNNUpsampleConv3d, BiQRNNConv3D, BiQRNNDeConv3D 4 | 5 | 6 | QRNN3DEncoder = partial( 7 | QRNN3DEncoder, 8 | QRNNConv3D=QRNNConv3D) 9 | 10 | QRNN3DDecoder = partial( 11 | QRNN3DDecoder, 12 | QRNNDeConv3D=QRNNDeConv3D, 13 | QRNNUpsampleConv3d=QRNNUpsampleConv3d) 14 | 15 | QRNNREDC3D = partial( 16 | QRNNREDC3D, 17 | BiQRNNConv3D=BiQRNNConv3D, 18 | BiQRNNDeConv3D=BiQRNNDeConv3D, 19 | QRNN3DEncoder=QRNN3DEncoder, 20 | QRNN3DDecoder=QRNN3DDecoder 21 | ) 22 | 23 | 24 | def qrnn3d(): 25 | net = QRNNREDC3D(1, 16, 5, [1, 3], has_ad=True) 26 | net.use_2dconv = False 27 | net.bandwise = False 28 | return net 29 | 30 | 31 | def qrnn3d_nobn(): 32 | net = QRNNREDC3D(1, 16, 5, [1, 3], has_ad=True, bn=False) 33 | net.use_2dconv = False 34 | net.bandwise = False 35 | return net 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Zeqiang Lai 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 |
-------------------------------------------------------------------------------- /benchmark/datasets/RealHSI/build_test.py: -------------------------------------------------------------------------------- 1 | from skimage import io 2 | import numpy as np 3 | import scipy.io as sio 4 | import os 5 | 6 | def load_tif_img(filepath): 7 | img = io.imread(filepath) 8 | img = img.astype(np.float32) 9 | return img 10 | 11 | 12 | root = '/media/exthdd/datasets/hsi/real_hsi/real_dataset' 13 | 14 | names = os.listdir(os.path.join(root, 'gt')) 15 | names = [n for n in names if n.endswith('tif')] 16 | 17 | save_dir = '/media/exthdd/datasets/hsi/real_hsi/real_dataset/test34_50' 18 | os.makedirs(save_dir, exist_ok=True) 19 | 20 | names = os.listdir('data/real34') 21 | 22 | for n in names: 23 | n = n[:-4] 24 | print(n) 25 | gt = load_tif_img(os.path.join(root, 'gt', n+'.tif')).transpose((1, 2, 0)) 26 | input = load_tif_img(os.path.join(root, 'input50', n+'.tif')).transpose((1, 2, 0)) 27 | gt = gt / gt.max() 28 | input = input / input.max() 29 | from skimage import exposure 30 | input = exposure.match_histograms(input, gt, multichannel=False) 31 | input = input[:688,:512,:] 32 | gt = gt[:688,:512,:] 33 | sio.savemat(os.path.join(save_dir, n + '.mat'), {'gt': gt, 'input': input})
-------------------------------------------------------------------------------- /hsir/model/t3sc/configs/t3sc.yaml: -------------------------------------------------------------------------------- 1 | class_name: "MultilayerModel" 2 | trainable: true 3 | beta: 0 4 | ckpt: null 5 | params: 6 | ssl: 0 7 | n_ssl: 4 8 | channels: 31 9 | ckpt: null 10 | layers: 11 | l0: 12 | name: "LowRankSCLayer" 13 | params: 14 | patch_side: 1 15 | K: 12 16 | rank: 1 17 | code_size: 64 18 | stride: 1 19 | input_centering: 1 20 | patch_centering: 0 21 | tied: "D" 22 | init_method: "kaiming_uniform" 23 | lbda_init: 0.001 24 | lbda_mode: "MC" 25 | beta: 0 26 | ssl: 0 27 | l1: 28 | name: "LowRankSCLayer" 29 | params: 30 | patch_side: 5 31 | K: 5 32 | rank: 3 33 | code_size: 1024 34 | stride: 1 35 | input_centering: 0 36 | patch_centering: 1 37 | tied: "D" 38 | init_method: "kaiming_uniform" 39 | lbda_init: 0.001 40 | lbda_mode: "MC" 41 | beta: 0 42 | ssl: 0 43 | 44 | backtracking: 45 | monitor: "val_mpsnr" 46 | mode: "max" 47 | dirpath: "backtracking" 48 | period: 5 49 | div_thresh: 4 50 | dummy: False 51 | lr_decay: 0.8 52 | id: T3SC 53 |
-------------------------------------------------------------------------------- /docs/source/dataset.md: -------------------------------------------------------------------------------- 1 | # Datasets 2 | 3 | 🌟 Storage location: `/share/dataset/hsi` 4 | 5 | ```tree 6 | ├── cave 7 | ├── harvard 8 | ├── icvl 9 | │ ├── test 10 | │ │ ├── icvl_512_30 11 | │ │ ├── icvl_512_50 12 | │ │ ├── icvl_512_70 13 | │ │ ├── icvl_512_blind 14 | │ │ ├── icvl_512_deadline 15 | │ │ ├── icvl_512_impulse 16 | │ │ ├── icvl_512_mixture 17 | │ │ ├── icvl_512_noniid 18 | │ │ └── icvl_512_stripe 19 | │ └── train 20 | │ └── ICVL64_31_100.db 21 | ├── raw 22 | └── remote 23 | ├── Indian_pines.mat 24 | ├── Pavia.mat 25 | ├── PaviaU.mat 26 | ├── Salinas.mat 27 | └── Urban_R162.mat 28 | ``` 29 | ## ICVL 30 | 31 | A hyperspectral image dataset of natural scenes. 32 | 33 | - Source: [Official Website](http://icvl.cs.bgu.ac.il/hyperspectral/) 34 | - Server path: `/data/dataset/hsi/icvl` 35 | 36 | Meta Data 37 | 38 | - 201 images 39 | 40 | Example image: 41 | 42 | ![icvl](../imgs/icvl.png) 43 | 44 | File structure: 45 | ``` 46 | 47 | ``` 48 | 49 | Cite 50 | ```bibtex 51 | @inproceedings{arad_and_ben_shahar_2016_ECCV, 52 |
title={Sparse Recovery of Hyperspectral Signal from Natural RGB Images}, 53 | author={Arad, Boaz and Ben-Shahar, Ohad}, 54 | booktitle={European Conference on Computer Vision}, 55 | pages={19--34}, 56 | year={2016}, 57 | organization={Springer} 58 | } 59 | ``` 60 | 61 | ## CAVE 62 | 63 | 64 | 65 | ## Harvard 66 | 67 | ## Remote Sensing 68 | 69 | 70 | 71 |
-------------------------------------------------------------------------------- /hsir/model/t3sc/layers/soft_thresholding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | MODES = ["SG", "SC", "MG", "MC"] 7 | 8 | 9 | class SoftThresholding(nn.Module): 10 | def __init__(self, mode, lbda_init, code_size=None, K=None): 11 | super().__init__() 12 | assert mode in MODES, f"Mode {mode!r} not recognized" 13 | self.mode = mode 14 | 15 | if self.mode[1] == "C": 16 | # 1 lambda per channel 17 | lbda_shape = (1, code_size, 1, 1) 18 | else: 19 | # 1 lambda for all channels 20 | lbda_shape = (1, 1, 1, 1) 21 | 22 | if self.mode[0] == "M": 23 | # 1 set of lambdas per unfolding 24 | self.lbda = nn.ParameterList( 25 | [ 26 | nn.Parameter(lbda_init * torch.ones(*lbda_shape)) 27 | for _ in range(K) 28 | ] 29 | ) 30 | else: 31 | # 1 set of lambdas for all unfoldings 32 | self.lbda = nn.Parameter(lbda_init * torch.ones(*lbda_shape)) 33 | 34 | def forward(self, x, k=None): 35 | if self.mode[0] == "M": 36 | return self._forward(x, self.lbda[k]) 37 | else: 38 | return self._forward(x, self.lbda) 39 | 40 | def _forward(self, x, lbda): 41 | return F.relu(x - lbda) - F.relu(-x - lbda) 42 |
-------------------------------------------------------------------------------- /benchmark/datasets/CAVE/cave_build_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from scipy.io import loadmat, savemat 4 | 5 | 6 | def minmax_normalize(array): 7 | amin = np.min(array) 8 | amax = np.max(array) 9 | return (array - amin) / (amax - amin) 10 | 11 | 12 | def main(get_noise, sigma): 13 | root = '/home/wzliu/projects/data/cave_mat/' 14 | save_root = os.path.join('/home/wzliu/projects/data/cave_test/', 'cave_512_{}'.format(sigma)) 15 | os.makedirs(save_root, exist_ok=True) 16 | 17 | with open('cave_test.txt') as f: 18 | fns = [l.strip() for l in f.readlines()] 19 | 20 | np.random.seed(2022) 21 | for fn in fns: 22 | print(fn) 23 | path = os.path.join(root, fn) 24 | data = loadmat(path) 25 | gt = data['gt'] 26 | gt = minmax_normalize(gt) 27 | noise = get_noise(gt.shape) 28 | input = gt + noise 29 | data = {'input': input, 'gt': gt, 'sigma': sigma} 30 | savemat(os.path.join(save_root, fn+'.mat'), data) 31 | 32 | 33 | def fixed(sigma): 34 | def get_noise(shape): 35 | return np.random.randn(*shape) * sigma / 255. 36 | return get_noise 37 | 38 | def blind(min, max): 39 | def get_noise(shape): 40 | return np.random.randn(*shape) * (min+np.random.rand(1)*(max-min)) / 255.
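41 | # NOTE: np.random.rand(1) is evaluated once per get_noise call, so each image 42 | # receives a single noise level sigma drawn uniformly from [min, max) / 255.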
43 | return get_noise 44 | 45 | main(fixed(30),30) 46 | main(fixed(50),50) 47 | main(fixed(70),70) 48 | main(blind(10,70),'blind') 49 | 50 |
-------------------------------------------------------------------------------- /hsir/model/t3sc/layers/encoding_layer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch.nn as nn 4 | 5 | logger = logging.getLogger(__name__) 6 | logger.setLevel(logging.DEBUG) 7 | 8 | 9 | class EncodingLayer(nn.Module): 10 | def __init__( 11 | self, 12 | in_channels=None, 13 | code_size=None, 14 | input_centering=False, 15 | **kwargs, 16 | ): 17 | super().__init__() 18 | self.in_channels = in_channels 19 | self.code_size = code_size 20 | self.input_centering = input_centering 21 | 22 | def forward(self, x, mode=None, **kwargs): 23 | assert mode in ["encode", "decode", None], f"Mode {mode!r} unknown" 24 | 25 | if mode in ["encode", None]: 26 | x = self.encode(x, **kwargs) 27 | if mode in ["decode", None]: 28 | x = self.decode(x, **kwargs) 29 | return x 30 | 31 | def encode(self, x, **kwargs): 32 | if self.input_centering: 33 | self.input_means = x.mean(dim=[2, 3], keepdim=True) 34 | x -= self.input_means 35 | 36 | x = self._encode(x, **kwargs) 37 | 38 | return x 39 | 40 | def decode(self, x, **kwargs): 41 | x = self._decode(x, **kwargs) 42 | 43 | if self.input_centering: 44 | x += self.input_means 45 | 46 | return x 47 | 48 | def _encode(self, x, **kwargs): 49 | raise NotImplementedError 50 | 51 | def _decode(self, x, **kwargs): 52 | raise NotImplementedError 53 |
-------------------------------------------------------------------------------- /hsiboard/table.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from util import * 3 | 4 | 5 | def main(logdir): 6 | 7 | with st.sidebar: 8 | st.title('HSIR Board') 9 | st.subheader('Table') 10 | st.caption(f'Directory: {logdir}') 11 | methods = listdir(logdir, exclude=['gt']) 12 | selected_methods = st.sidebar.multiselect( 13 | "Select Method", 14 | methods, 15 | default=methods, 16 | ) 17 | 18 | stat = load_stat(logdir) 19 | 20 | st.subheader('Not loaded') 21 | st.text(set(selected_methods)-set(stat.keys())) 22 | 23 | if len(selected_methods) == 1: 24 | selected_method = selected_methods[0] 25 | print(stat[selected_method][0]) 26 | st.header(selected_method) 27 | st.dataframe(stat[selected_method]) 28 | else: 29 | table = {} 30 | for m in stat.keys(): 31 | for d in stat[m]: 32 | row = {'Method': m} 33 | for k, v in d.items(): 34 | if k != 'Name': 35 | row[k] = v 36 | if d['Name'] not in table: 37 | table[d['Name']] = [] 38 | table[d['Name']].append(row) 39 | for k, v in table.items(): 40 | st.subheader(k) 41 | st.dataframe(v) 42 | 43 | 44 | if __name__ == '__main__': 45 | parser = argparse.ArgumentParser('HSIR Board') 46 | parser.add_argument('--logdir', default='results') 47 | args = parser.parse_args() 48 | 49 | main(args.logdir) 50 |
-------------------------------------------------------------------------------- /docs/source/index.md: -------------------------------------------------------------------------------- 1 | --- 2 | hide-toc: true 3 | --- 4 | 5 | # Welcome to HSIR 6 | 7 | 8 | ```{toctree} 9 | :maxdepth: 3 10 | :hidden: true 11 | 12 | getstart 13 | dataset 14 | benchmark 15 | ``` 16 | 17 | 18 | ```{toctree} 19 | :caption: Useful Links 20 | :hidden: 21 | PyPI page <https://pypi.org/project/hsir/> 22 | GitHub Repository <https://github.com/bit-isp/HSIR> 23 | ``` 24 | 25 | 26 |
[![PyPI](https://img.shields.io/pypi/v/hsir)](https://pypi.org/project/hsir/) 27 | 28 | Out-of-box Hyperspectral Image Restoration Toolbox 29 | 30 | 31 | ## Install 32 | 33 | ```shell 34 | pip install hsir 35 | ``` 36 | 37 | ## Usage 38 | 39 | Here are some runnable examples; please refer to the code for more options. 40 | 41 | ```shell 42 | python hsirun/train.py -a qrnn3d.qrnn3d 43 | python hsirun/test.py -a qrnn3d.qrnn3d -r qrnn3d.pth -t icvl_512_50 44 | ``` 45 | 46 | 47 | ## Acknowledgement 48 | 49 | - [QRNN3D](https://github.com/Vandermode/QRNN3D) 50 | - [DPHSIR](https://github.com/Zeqiang-Lai/DPHSIR) 51 | - [MST](https://github.com/caiyuanhao1998/MST) 52 | - [TLC](https://github.com/megvii-research/TLC) 53 | 54 | ## Citation 55 | 56 | If you find this repo helpful, please consider citing us. 57 | 58 | ```bibtex 59 | @misc{hsir, 60 | author={Zeqiang Lai and Miaoyu Li and Ying Fu}, 61 | title={HSIR: Out-of-box Hyperspectral Image Restoration Toolbox}, 62 | year={2022}, 63 | url={https://github.com/bit-isp/HSIR}, 64 | } 65 | ``` 66 | 67 | ![Visitors](https://api.visitorbadge.io/api/visitors?path=https%3A%2F%2Fgithub.com%2Fbit-isp%2FHSIR&countColor=%23263759&style=flat)
-------------------------------------------------------------------------------- /benchmark/datasets/Harvard/harvard_build_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from scipy.io import loadmat, savemat 4 | 5 | 6 | def minmax_normalize(array): 7 | amin = np.min(array) 8 | amax = np.max(array) 9 | return (array - amin) / (amax - amin) 10 | 11 | def crop_center(img, cropx, cropy): 12 | y, x, _ = img.shape 13 | startx = x // 2 - (cropx // 2) 14 | starty = y // 2 - (cropy // 2) 15 | return img[starty:starty + cropy, startx:startx + cropx, :] 16 | 17 | def main(get_noise, sigma): 18 | root = '/home/wzliu/projects/data/harvard/CZ_hsdb/' 19 | save_root = os.path.join('/home/wzliu/projects/data/harvard_test/', 'harvard_512_{}'.format(sigma)) 20 | os.makedirs(save_root, exist_ok=True) 21 | 22 | with open('harvard_test.txt') as f: 23 | fns = [l.strip() for l in f.readlines()] 24 | 25 | np.random.seed(2022) 26 | for fn in fns: 27 | print(fn) 28 | path = os.path.join(root, fn) 29 | data = loadmat(path) 30 | gt = data['ref'] 31 | gt = minmax_normalize(gt) 32 | gt = crop_center(gt, 512, 512) 33 | print(gt.shape) 34 | noise = get_noise(gt.shape) 35 | input = gt + noise 36 | data = {'input': input, 'gt': gt, 'sigma': sigma} 37 | savemat(os.path.join(save_root, fn+'.mat'), data) 38 | 39 | 40 | def fixed(sigma): 41 | def get_noise(shape): 42 | return np.random.randn(*shape) * sigma / 255. 43 | return get_noise 44 | 45 | def blind(min, max): 46 | def get_noise(shape): 47 | return np.random.randn(*shape) * (min+np.random.rand(1)*(max-min)) / 255.
48 | return get_noise 49 | 50 | main(fixed(30),30) 51 | main(fixed(50),50) 52 | main(fixed(70),70) 53 | main(blind(10,70),'blind') 54 | 55 |
-------------------------------------------------------------------------------- /hsir/data/transform/general.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import threading 4 | 5 | 6 | class HSI2Tensor(object): 7 | """ 8 | Transform a numpy array with shape (C, H, W) 9 | into a torch tensor of shape (1, C, H, W) or (C, H, W) 10 | """ 11 | 12 | def __init__(self, use_chw=True): 13 | """ use_chw: True keeps the (C, H, W) layout; False adds a leading dim -> (1, C, H, W) """ 14 | self.use_chw = use_chw 15 | 16 | def __call__(self, hsi): 17 | img = torch.from_numpy(hsi) 18 | if not self.use_chw: 19 | img = img.unsqueeze(0) 20 | return img.float() 21 | 22 | 23 | class Identity(object): 24 | """ 25 | Identity transform 26 | """ 27 | 28 | def __call__(self, data): 29 | return data 30 | 31 | 32 | class RandomCrop(object): 33 | """ 34 | Random crop transform 35 | """ 36 | 37 | def __init__(self, croph, cropw): 38 | self.croph = croph 39 | self.cropw = cropw 40 | 41 | def __call__(self, img): 42 | _, h, w = img.shape 43 | croph, cropw = self.croph, self.cropw 44 | h1 = random.randint(0, h - croph) 45 | w1 = random.randint(0, w - cropw) 46 | return img[:, h1:h1 + croph, w1:w1 + cropw] 47 | 48 | 49 | class LockedIterator(object): 50 | def __init__(self, it): 51 | self.lock = threading.Lock() 52 | self.it = it.__iter__() 53 | 54 | def __iter__(self): return self 55 | 56 | def __next__(self): 57 | self.lock.acquire() 58 | try: 59 | return next(self.it) 60 | finally: 61 | self.lock.release() 62 | 63 | 64 | class SequentialSelect(object): 65 | def __pos(self, n): 66 | i = 0 67 | while True: 68 | # print(i) 69 | yield i 70 | i = (i + 1) % n 71 | 72 | def __init__(self, transforms): 73 | self.transforms = transforms 74 | self.pos = LockedIterator(self.__pos(len(transforms))) 75 | 76 | def __call__(self, img): 77 | out = self.transforms[next(self.pos)](img) 78 | return out 79 |
-------------------------------------------------------------------------------- /hsirun/benchmark.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | from datetime import datetime 5 | from ast import literal_eval 6 | import platform 7 | 8 | import torch 9 | import torch.cuda 10 | import torch.backends.cudnn 11 | 12 | from torchlight.utils import instantiate 13 | 14 | torch.set_grad_enabled(False) 15 | torch.backends.cudnn.benchmark = True 16 | 17 | 18 | def main(arch, input_shape, total_steps=10, save_path='benchmark.json'): 19 | net = instantiate(arch) 20 | 21 | net = net.cuda().eval() 22 | input = torch.randn(*input_shape).cuda() 23 | 24 | total_params = sum([param.nelement() for param in net.parameters()]) / 1e6 25 | print("Number of parameter: %.2fM" % (total_params)) 26 | 27 | # warm up for benchmark 28 | for _ in range(10): 29 | net(input) 30 | 31 | start = torch.cuda.Event(enable_timing=True) 32 | end = torch.cuda.Event(enable_timing=True) 33 | 34 | steps = int(total_steps) 35 | start.record() 36 | 37 | for _ in range(steps): 38 | net(input) 39 | end.record() 40 | 41 | torch.cuda.synchronize() 42 | 43 | avg_time = start.elapsed_time(end) / steps 44 | print('Time: {} ms'.format(avg_time)) 45 | 46 | database = {} 47 | if os.path.exists(save_path): 48 | database = json.load(open(save_path, 'r')) 49 | entry = database.get(arch, []) 50 | dtstr =
datetime.now().strftime("%Y-%m-%d-%H:%M:%S") 51 | entry.append({ 52 | 'runtime': avg_time, 'params': total_params, 'date': dtstr, 53 | 'os': platform.platform(), 54 | 'processor': platform.processor(), 55 | }) 56 | database[arch] = entry 57 | with open(save_path, 'w') as f: 58 | json.dump(database, f, indent=4) 59 | 60 | 61 | if __name__ == '__main__': 62 | parser = argparse.ArgumentParser('Benchmark model runtime and params') 63 | parser.add_argument('-a', '--arch', type=str, required=True) 64 | parser.add_argument('-t', '--steps', type=int, default=10) 65 | parser.add_argument('-s', '--save-path', type=str, default='benchmark.json') 66 | parser.add_argument('-i', '--input-shape', type=str, default='[1,1,31,512,512]') 67 | args = parser.parse_args() 68 | input_shape = literal_eval(args.input_shape) 69 | main(args.arch, input_shape, args.steps, args.save_path) 70 | -------------------------------------------------------------------------------- /hsiboard/viewer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from os.path import join 3 | 4 | from util import * 5 | from box import * 6 | 7 | 8 | def main(logdir): 9 | set_page_container_style() 10 | with st.sidebar: 11 | st.title('HSIR Board') 12 | st.subheader('Viewer') 13 | st.caption(f'Directory: {logdir}') 14 | methods = listdir(logdir, exclude=['gt']) 15 | selected_method = st.sidebar.selectbox( 16 | "Select Method", 17 | methods 18 | ) 19 | datasets = listdir(join(logdir, selected_method)) 20 | selected_dataset = st.sidebar.selectbox( 21 | "Select Dataset", 22 | datasets 23 | ) 24 | selected_vis_type = st.sidebar.selectbox( 25 | "Select Image Type", 26 | ['color', 'gray'] 27 | ) 28 | 29 | st.header('Box') 30 | enable_enlarge = st.checkbox('Enlarge') 31 | crow = st.slider('row coordinate', min_value=0.0, max_value=1.0, value=0.2) 32 | ccol = st.slider('col coordinate', min_value=0.0, max_value=1.0, value=0.2) 33 | selected_box_pos = st.sidebar.selectbox( 34 | "Select Box Position", 35 | ['Bottom Right', 'Bottom Left', 'Top Right', 'Top Left'], 36 | ) 37 | 38 | st.header('Layout') 39 | ncol = st.slider('number of columns', min_value=1, max_value=20, value=6) 40 | 41 | imgs = load_imgs(join(logdir, selected_method, selected_dataset, selected_vis_type)) 42 | 43 | nrow = len(imgs) // ncol + 1 44 | grids = make_grid(nrow, ncol) 45 | details = load_per_image_stat(join(logdir, selected_method, selected_dataset)) 46 | with st.container(): 47 | idx = 0 48 | for name, img in imgs.items(): 49 | name = os.path.splitext(name)[-2] 50 | if enable_enlarge: 51 | h, w = img.shape[:2] 52 | img = addbox(img.copy(), (int(h * crow), int(w * ccol)), 53 | bbpos=mapbbpox[selected_box_pos]) 54 | ct = grids[idx // ncol][idx % ncol] 55 | ct.image(img, caption='%s [%.4f]' % (name, details[name]['MPSNR'])) 56 | idx += 1 57 | 58 | 59 | 60 | if __name__ == '__main__': 61 | parser = argparse.ArgumentParser('HSIR Board') 62 | parser.add_argument('--logdir', default='results') 63 | args = parser.parse_args() 64 | 65 | main(args.logdir) 66 | -------------------------------------------------------------------------------- /hsir/model/denet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.init as init 3 | 4 | 5 | """ 6 | each tuple unit: [out_channels, dilation] 7 | """ 8 | cfg = ([64, 1], [64, 1], [64, 1], 9 | [128, 1], [128, 1], [128, 1], 10 | [256, 2], [256, 2], [256, 2], 11 | [128, 1], [128, 1], [128, 1], 12 | [64, 1], [64, 1], 
[64, 1], [64, 1]) 13 | 14 | 15 | class DeNet(nn.Module): 16 | def __init__(self, in_channels=10, kernel_size=3, init_weights=True): 17 | super(DeNet, self).__init__() 18 | num_channels = in_channels  # keep the input band count; the final conv must map back to it 19 | layers = [] 20 | out_channels = 64 21 | # add CR 22 | layers.append(nn.Conv2d( 23 | in_channels=in_channels, out_channels=out_channels, 24 | kernel_size=kernel_size, padding=1, bias=True)) 25 | layers.append(nn.ReLU(inplace=True)) 26 | # add CBR1 to CBR16 27 | in_channels = out_channels 28 | for out_channels, dilation in cfg: 29 | if dilation == 1: 30 | padding = 1 31 | elif dilation == 2: 32 | padding = 2 33 | layers.append(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, 34 | kernel_size=kernel_size, padding=padding, 35 | dilation=dilation, bias=False) 36 | ) 37 | layers.append(nn.BatchNorm2d(num_features=out_channels)) 38 | layers.append(nn.ReLU(inplace=True)) 39 | in_channels = out_channels 40 | 41 | # add C (reconstruct the input band count so the residual `y - out` in forward is valid) 42 | layers.append(nn.Conv2d(in_channels=64, out_channels=num_channels, 43 | kernel_size=kernel_size, padding=1, bias=False) 44 | ) 45 | self.denet = nn.Sequential(*layers) 46 | # if init_weights: 47 | # self._initialize_weights() 48 | 49 | def forward(self, x): 50 | y = x 51 | out = self.denet(x) 52 | return y - out 53 | 54 | def _initialize_weights(self): 55 | print("===> Start initializing weights") 56 | for m in self.modules(): 57 | if isinstance(m, nn.Conv2d): 58 | init.orthogonal_(m.weight) 59 | if m.bias is not None: 60 | init.constant_(m.bias, 0) 61 | elif isinstance(m, nn.BatchNorm2d): 62 | init.constant_(m.weight, 1) 63 | init.constant_(m.bias, 0) 64 | 65 | def hsidenet(): 66 | net = DeNet(in_channels=31) 67 | net.use_2dconv = True 68 | net.bandwise = False 69 | return net
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | checkpoints 2 | results 3 | model_zoo 4 | ./data 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | -------------------------------------------------------------------------------- /hsir/model/t3sc/multilayer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | #from .base import BaseModel 7 | from . import layers 8 | 9 | logger = logging.getLogger(__name__) 10 | logger.setLevel(logging.DEBUG) 11 | 12 | 13 | class MultilayerModel(nn.Module): 14 | def __init__( 15 | self, 16 | channels, 17 | layers, 18 | ssl=0, 19 | n_ssl=0, 20 | ckpt=None, 21 | ): 22 | super().__init__() 23 | self.channels = channels 24 | self.layers_params = layers 25 | self.ssl = ssl 26 | self.n_ssl = n_ssl 27 | logger.debug(f"ssl : {self.ssl}, n_ssl : {self.n_ssl}") 28 | 29 | self.init_layers() 30 | self.normalized_dict = False 31 | 32 | logger.info(f"Using SSL : {self.ssl}") 33 | self.ckpt = ckpt 34 | if self.ckpt is not None: 35 | logger.info(f"Loading ckpt {self.ckpt!r}") 36 | d = torch.load(self.ckpt) 37 | self.load_state_dict(d["state_dict"]) 38 | 39 | def init_layers(self): 40 | list_layers = [] 41 | in_channels = self.channels 42 | 43 | for i in range(len(self.layers_params)): 44 | logger.debug(f"Initializing layer {i}") 45 | name = self.layers_params[f"l{i}"]["name"] 46 | params = self.layers_params[f"l{i}"]["params"] 47 | layer_cls = layers.__dict__[name] 48 | layer = layer_cls( 49 | in_channels=in_channels, 50 | **params, 51 | ) 52 | in_channels = layer.code_size 53 | 54 | list_layers.append(layer) 55 | self.layers = nn.ModuleList(list_layers) 56 | 57 | def forward( 58 | self, x, mode=None, img_id=None, sigmas=None, ssl_idx=None, **kwargs 59 | ): 60 | assert mode in ["encode", "decode", None], f"Mode {mode!r} unknown" 61 | x = x.float().clone() 62 | 63 | if mode in ["encode", None]: 64 | x = self.encode(x, img_id, sigmas=sigmas, ssl_idx=ssl_idx) 65 | if mode in ["decode", None]: 66 | x = self.decode(x, img_id) 67 | return x 68 | 69 | def encode(self, x, img_id, sigmas, ssl_idx): 70 | 71 | for layer in self.layers: 72 | x = layer( 73 | x, 74 | mode="encode", 75 | img_id=img_id, 76 | sigmas=sigmas, 77 | ssl_idx=ssl_idx, 78 | ) 79 | return x 80 | 81 | def decode(self, x, img_id): 82 | for layer in self.layers[::-1]: 83 | x = layer(x, mode="decode", img_id=img_id) 84 | 85 | return x 86 | -------------------------------------------------------------------------------- /hsiboard/box.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | mapbbpox = { 5 | 'Bottom Right': 'br', 6 | 'Bottom Left': 'bl', 7 | 'Top Right': 'ur', 8 | 'Top Left': 'ul', 9 | } 10 | 11 | def 
get_border_box_pt(h, w, size, thickness, pos): 12 | if pos == 'br': 13 | return (w-size-thickness, h-size-thickness) 14 | if pos == 'ur': 15 | return (w-size-thickness, 0) 16 | if pos == 'ul': 17 | return (0, 0) 18 | if pos == 'bl': 19 | return (0, h-size-thickness) 20 | 21 | def addbox(img, pt, size=100, 22 | bbsize=200, bbpos='br', 23 | color=(255, 0, 0), 24 | thickness=1, 25 | bbthickness=2): 26 | H, W = img.shape[0], img.shape[1] 27 | 28 | ptb = get_border_box_pt(H, W, bbsize, bbthickness, pos=bbpos) 29 | crop_img = img[pt[1]:pt[1]+size, pt[0]:pt[0]+size] 30 | crop_img = cv2.resize(crop_img, (bbsize, bbsize)) 31 | img[ptb[1]:ptb[1]+bbsize, ptb[0]:ptb[0]+bbsize] = crop_img 32 | 33 | # big box 34 | pt1 = ptb 35 | pt2 = (pt1[0]+bbsize, pt1[1]+bbsize) 36 | cv2.rectangle(img, pt1, pt2, color, bbthickness) 37 | 38 | # small box 39 | pt1 = pt 40 | pt2 = (pt1[0]+size, pt1[1]+size) 41 | cv2.rectangle(img, pt1, pt2, color, thickness) 42 | 43 | return img 44 | 45 | 46 | def convert_color(arr, cmap='viridis', vmin=0, vmax=0.1): 47 | import matplotlib.cm as cm 48 | sm = cm.ScalarMappable(cmap=cmap) 49 | sm.set_clim(vmin, vmax) 50 | rgba = sm.to_rgba(arr, alpha=1) 51 | return np.array(rgba[:, :, :3]) 52 | 53 | 54 | def addbox_with_diff(input, gt, pt, vmax=0.1, size=100, 55 | color=(255, 0, 0), 56 | thickness=1, 57 | bbthickness=2, 58 | sep=4): 59 | if len(input.shape) == 2: 60 | input = np.expand_dims(input, -1) 61 | if len(gt.shape) == 2: 62 | gt = np.expand_dims(gt, -1) 63 | 64 | H, W = input.shape[:2] 65 | C = 3 66 | 67 | out = np.zeros((H, W + H // 2 + sep//2, C)) 68 | out[:, :W] = input 69 | 70 | # small box. 71 | pt1 = pt 72 | pt2 = (pt1[0] + size, pt1[1] + size) 73 | cv2.rectangle(out, pt1, pt2, color, thickness) 74 | 75 | # crop. 76 | bbsize = H // 2 - sep//2 77 | crop_img = input[pt[1]:pt[1] + size, pt[0]:pt[0] + size, :] 78 | crop_img = cv2.resize(crop_img, (bbsize, bbsize)) 79 | 80 | # diff 81 | diff = convert_color(np.abs(input-gt)[:,:,0], vmin=0, vmax=vmax) 82 | 83 | crop_img_diff = diff[pt[1]:pt[1] + size, pt[0]:pt[0] + size, :] 84 | crop_img_diff = cv2.resize(crop_img_diff, (bbsize, bbsize)) 85 | 86 | if len(crop_img.shape) == 2: crop_img = np.expand_dims(crop_img, axis=-1) 87 | if len(crop_img_diff.shape) == 2: crop_img_diff = np.expand_dims(crop_img_diff, axis=-1) 88 | 89 | out[:bbsize, W+sep:W+sep + bbsize, :] = crop_img 90 | out[bbsize+sep:, W+sep:W+sep + bbsize, :] = crop_img_diff 91 | 92 | return out 93 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 
12 | # 13 | # import os 14 | # import sys 15 | # sys.path.insert(0, os.path.abspath('.')) 16 | 17 | 18 | # -- Project information ----------------------------------------------------- 19 | 20 | project = 'HSIR' 21 | copyright = '2022, BIT-ISP Lab' 22 | author = 'Zeqiang-Lai, Miaoyu Li' 23 | 24 | # The full version, including alpha/beta/rc tags 25 | release = '0.0.1' 26 | 27 | 28 | # -- General configuration --------------------------------------------------- 29 | 30 | # Add any Sphinx extension module names here, as strings. They can be 31 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 32 | # ones. 33 | extensions = [ 34 | 'myst_parser', 35 | 'sphinx_copybutton', 36 | # "sphinx_inline_tabs", 37 | 'sphinx.ext.autodoc', 38 | 'sphinx.ext.autosummary', 39 | 'sphinx.ext.doctest', 40 | 'sphinx.ext.todo', 41 | 'sphinx.ext.coverage', 42 | 'sphinx.ext.mathjax', 43 | 'sphinx.ext.viewcode', 44 | 'sphinx.ext.napoleon', 45 | ] 46 | 47 | # Add any paths that contain templates here, relative to this directory. 48 | templates_path = ['_templates'] 49 | 50 | # List of patterns, relative to source directory, that match files and 51 | # directories to ignore when looking for source files. 52 | # This pattern also affects html_static_path and html_extra_path. 53 | exclude_patterns = [] 54 | 55 | # The name of the Pygments (syntax highlighting) style to use. 56 | pygments_style = 'sphinx' 57 | 58 | # -- Options for HTML output ------------------------------------------------- 59 | 60 | # The theme to use for HTML and HTML Help pages. See the documentation for 61 | # a list of builtin themes. 62 | # 63 | # html_theme = 'alabaster' 64 | # html_theme = 'sphinx_rtd_theme' 65 | html_theme = 'furo' 66 | 67 | html_title = project 68 | 69 | # Add any paths that contain custom static files (such as style sheets) here, 70 | # relative to this directory. They are copied after the builtin static files, 71 | # so a file named "default.css" will overwrite the builtin "default.css". 
72 | html_static_path = ['_static'] 73 | 74 | 75 | # 76 | # -- Options for TODOs ------------------------------------------------------- 77 | # 78 | todo_include_todos = True 79 | 80 | # 81 | # -- Options for Markdown files ---------------------------------------------- 82 | # 83 | myst_admonition_enable = True 84 | myst_deflist_enable = True 85 | myst_heading_anchors = 3 86 | -------------------------------------------------------------------------------- /hsirun/train_ssr.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | from torch.utils.data import DataLoader 5 | from torch.optim.lr_scheduler import CosineAnnealingLR 6 | from torchlight.utils import instantiate 7 | from torchlight.nn.utils import get_learning_rate 8 | 9 | from hsir.data.ssr.dataset import TrainDataset, ValidDataset 10 | from hsir.trainer import Trainer 11 | 12 | 13 | def train_cfg(): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--arch', '-a', required=True) 16 | parser.add_argument('--name', '-n', type=str, default=None, 17 | help='name of the experiment, if not specified, arch will be used.') 18 | parser.add_argument('--lr', type=float, default=4e-4) 19 | parser.add_argument('--bs', type=int, default=20) 20 | parser.add_argument('--epochs', type=int, default=5) 21 | parser.add_argument('--schedule', type=str, default='hsir.schedule.denoise_default') 22 | parser.add_argument('--resume', '-r', action='store_true') 23 | parser.add_argument('--resume-path', '-rp', type=str, default=None) 24 | parser.add_argument('--data-root', type=str, default='data/rgb2hsi') 25 | parser.add_argument('--data-size', type=int, default=None) 26 | parser.add_argument('--save-root', type=str, default='checkpoints/ssr') 27 | parser.add_argument('--gpu-ids', type=str, default='0', help='gpu ids') 28 | cfg = parser.parse_args() 29 | cfg.gpu_ids = [int(id) for id in cfg.gpu_ids.split(',')] 30 | cfg.name = cfg.arch if cfg.name is None else cfg.name 31 | return cfg 32 | 33 | 34 | def main(): 35 | cfg = train_cfg() 36 | 37 | net = instantiate(cfg.arch) 38 | trainer = Trainer( 39 | net, 40 | lr=cfg.lr, 41 | save_dir=os.path.join(cfg.save_root, cfg.name), 42 | gpu_ids=cfg.gpu_ids, 43 | ) 44 | trainer.logger.print(cfg) 45 | if cfg.resume: trainer.load(cfg.resume_path) 46 | 47 | dataset = TrainDataset(cfg.data_root, size=cfg.data_size, stride=64) 48 | train_loader = DataLoader(dataset, batch_size=cfg.bs, shuffle=True, num_workers=8, pin_memory=True) 49 | dataset = ValidDataset(cfg.data_root) 50 | val_loader = DataLoader(dataset, batch_size=1) 51 | 52 | """Main loop""" 53 | 54 | # lr_scheduler = CosineAnnealingLR(trainer.optimizer, cfg.max_epochs, eta_min=1e-6) 55 | epoch_per_save = 10 56 | best_psnr = 0 57 | 58 | while trainer.epoch < cfg.epochs: 59 | trainer.logger.print('Epoch [{}] Use lr={}'.format(trainer.epoch, get_learning_rate(trainer.optimizer))) 60 | 61 | trainer.train(train_loader) 62 | 63 | # save ckpt 64 | trainer.save_checkpoint('model_latest.pth') 65 | metrics = trainer.validate(val_loader, 'NITRE') 66 | if metrics['psnr'] > best_psnr: 67 | best_psnr = metrics['psnr'] 68 | trainer.save_checkpoint('model_best.pth') 69 | if trainer.epoch % epoch_per_save == 0: 70 | trainer.save_checkpoint() 71 | 72 | 73 | 74 | if __name__ == '__main__': 75 | main() 76 | -------------------------------------------------------------------------------- /hsir/model/t3sc/utils/patches_handler.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import logging 6 | 7 | logger = logging.getLogger(__name__) 8 | logger.setLevel(logging.DEBUG) 9 | 10 | 11 | class PatchesHandler(nn.Module): 12 | def __init__(self, size, channels, stride, padding="constant"): 13 | super().__init__() 14 | if isinstance(size, int): 15 | self.size = np.array([size, size]) 16 | else: 17 | self.size = np.array(size) 18 | self.channels = channels 19 | self.n_elements = self.channels * self.size[0] * self.size[1] 20 | self.stride = stride 21 | 22 | self.padding = padding 23 | self.fold = None 24 | self.normalizer = None 25 | self.img_size = None 26 | 27 | def forward(self, x, mode="extract"): 28 | if mode == "extract": 29 | x = self.pad(x) 30 | x = self.extract(x) 31 | return x 32 | elif mode == "aggregate": 33 | x = self.aggregate(x) 34 | x = self.unpad(x) 35 | return x 36 | else: 37 | raise ValueError(f"Mode {mode!r} not recognized") 38 | 39 | def set_img_size(self, img_size): 40 | if np.any(self.img_size != np.array(img_size)): 41 | self.img_size = np.array(img_size) 42 | 43 | self.n_patches = 1 + np.ceil( 44 | np.maximum(self.img_size - self.size, 0) / self.stride 45 | ).astype(int) 46 | pads = ( 47 | self.size + (self.n_patches - 1) * self.stride - self.img_size 48 | ) 49 | self.padded_size = tuple(self.img_size + pads) 50 | _pads = [] 51 | for i in reversed(range(2)): 52 | _pads += [0, pads[i]] 53 | self.pads = _pads 54 | 55 | def pad(self, x): 56 | self.set_img_size(x.shape[2:]) 57 | x = F.pad(x, pad=self.pads, mode=self.padding) 58 | return x 59 | 60 | def unpad(self, x): 61 | x = x[:, :, : self.img_size[0], : self.img_size[1]] 62 | return x 63 | 64 | def extract(self, x): 65 | 66 | x = x.unfold(dimension=2, size=self.size[0], step=self.stride) 67 | x = x.unfold(dimension=3, size=self.size[1], step=self.stride) 68 | x = x.permute(0, 1, 4, 5, 2, 3) 69 | 70 | x = x.contiguous() 71 | 72 | return x 73 | 74 | def aggregate(self, x): 75 | x = x.view(-1, self.n_elements, self.n_patches[0] * self.n_patches[1]) 76 | 77 | if self.fold is None or self.padded_size != self.fold.output_size: 78 | self.init_fold() 79 | 80 | out = self.fold(x) 81 | if self.normalizer is None or self.normalizer.shape != x.shape: 82 | ones_input = torch.ones_like(x) 83 | self.normalizer = self.fold(ones_input) 84 | return (out / self.normalizer).squeeze(1) 85 | 86 | def init_fold(self): 87 | logger.debug(f"Initializing fold, padded shape: {self.padded_size}") 88 | self.fold = nn.Fold( 89 | output_size=self.padded_size, 90 | kernel_size=tuple(self.size), 91 | stride=self.stride, 92 | ) 93 | -------------------------------------------------------------------------------- /hsir/model/memnet.py: -------------------------------------------------------------------------------- 1 | """MemNet""" 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from sync_batchnorm import SynchronizedBatchNorm2d 6 | 7 | # BatchNorm2d = SynchronizedBatchNorm2d 8 | BatchNorm2d = nn.BatchNorm2d 9 | 10 | 11 | class MemNet(nn.Module): 12 | def __init__(self, in_channels, channels, num_memblock, num_resblock): 13 | super(MemNet, self).__init__() 14 | self.feature_extractor = BNReLUConv(in_channels, channels) 15 | self.reconstructor = BNReLUConv(channels, in_channels) 16 | self.dense_memory = nn.ModuleList( 17 | [MemoryBlock(channels, num_resblock, i+1) for i in range(num_memblock)] 18 | ) 19 | 
self.freeze_bn = True 20 | self.freeze_bn_affine = True 21 | 22 | def forward(self, x): 23 | residual = x 24 | out = self.feature_extractor(x) 25 | ys = [out] 26 | for memory_block in self.dense_memory: 27 | out = memory_block(out, ys) 28 | out = self.reconstructor(out) 29 | 30 | out = out + residual 31 | 32 | return out 33 | 34 | 35 | class MemoryBlock(nn.Module): 36 | """Note: num_memblock denotes the number of MemoryBlock currently""" 37 | def __init__(self, channels, num_resblock, num_memblock): 38 | super(MemoryBlock, self).__init__() 39 | self.recursive_unit = nn.ModuleList( 40 | [ResidualBlock(channels) for i in range(num_resblock)] 41 | ) 42 | self.gate_unit = BNReLUConv((num_resblock+num_memblock) * channels, channels, 1, 1, 0) 43 | 44 | def forward(self, x, ys): 45 | """ys is a list which contains long-term memory coming from previous memory block 46 | xs denotes the short-term memory coming from recursive unit 47 | """ 48 | xs = [] 49 | residual = x 50 | for layer in self.recursive_unit: 51 | x = layer(x) 52 | xs.append(x) 53 | 54 | gate_out = self.gate_unit(torch.cat(xs+ys, 1)) 55 | ys.append(gate_out) 56 | return gate_out 57 | 58 | 59 | class ResidualBlock(torch.nn.Module): 60 | """ResidualBlock 61 | introduced in: https://arxiv.org/abs/1512.03385 62 | x - Relu - Conv - Relu - Conv - x 63 | """ 64 | 65 | def __init__(self, channels, k=3, s=1, p=1): 66 | super(ResidualBlock, self).__init__() 67 | self.relu_conv1 = BNReLUConv(channels, channels, k, s, p) 68 | self.relu_conv2 = BNReLUConv(channels, channels, k, s, p) 69 | 70 | def forward(self, x): 71 | residual = x 72 | out = self.relu_conv1(x) 73 | out = self.relu_conv2(out) 74 | out = out + residual 75 | return out 76 | 77 | 78 | class BNReLUConv(nn.Sequential): 79 | def __init__(self, in_channels, channels, k=3, s=1, p=1, inplace=True): 80 | super(BNReLUConv, self).__init__() 81 | self.add_module('bn', BatchNorm2d(in_channels)) 82 | self.add_module('relu', nn.ReLU(inplace=inplace)) 83 | self.add_module('conv', nn.Conv2d(in_channels, channels, k, s, p, bias=False)) 84 | 85 | 86 | def memnet(): 87 | net = MemNet(31, 64, 6, 6) 88 | net.use_2dconv = True 89 | net.bandwise = False 90 | return net -------------------------------------------------------------------------------- /hsiboard/util.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from collections import OrderedDict 4 | from os.path import exists, join 5 | 6 | import imageio 7 | import numpy as np 8 | import streamlit as st 9 | import cv2 10 | 11 | BACKGROUND_COLOR = 'white' 12 | COLOR = 'black' 13 | 14 | 15 | def set_page_container_style( 16 | max_width: int = 1100, max_width_100_percent: bool = True, 17 | padding_top: int = 2, padding_right: int = 2, padding_left: int = 2, padding_bottom: int = 10, 18 | color: str = COLOR, background_color: str = BACKGROUND_COLOR, 19 | ): 20 | if max_width_100_percent: 21 | max_width_str = f'max-width: 100%;' 22 | else: 23 | max_width_str = f'max-width: {max_width}px;' 24 | st.markdown( 25 | f''' 26 | <style> 27 | .reportview-container .sidebar-content {{ 28 | padding-top: {padding_top}rem; 29 | }} 30 | .reportview-container .main .block-container {{ 31 | {max_width_str} 32 | padding-top: {padding_top}rem; 33 | padding-right: {padding_right}rem; 34 | padding-left: {padding_left}rem; 35 | padding-bottom: {padding_bottom}rem; 36 | }} 37 | .reportview-container .main {{ 38 | color: {color}; 39 | background-color: {background_color}; 40 | }} 41 | </style> 42 | ''', 43 | unsafe_allow_html=True, 44 | ) 45 | 46 | 47 | def listdir(path, exclude=[]): 48 | outputs = [] 49 | for folder in os.listdir(path): 50 | if os.path.isdir(join(path, folder)) and folder not in exclude: 51 | outputs.append(folder) 52 | return outputs 53 | 54 | 55 | def load_method_stat(logdir): 56 | stat = [] 57 | for folder in listdir(logdir): 58 | path = join(logdir, folder, 'log.json') 59 | if exists(path): 60 | data = json.load(open(path, 'r')) 61 | s = 
OrderedDict() 62 | s['Name'] = folder 63 | s.update(data['avg']) 64 | stat.append(s) 65 | return stat 66 | 67 | 68 | def load_stat(logdir): 69 | total_stat = {} 70 | for folder in listdir(logdir, exclude=['gt']): 71 | try: 72 | stat = load_method_stat(join(logdir, folder)) 73 | total_stat[folder] = stat 74 | except: 75 | print('error loading', folder, 'ignored') 76 | return total_stat 77 | 78 | def load_per_image_stat(logdir): 79 | path = join(logdir, 'log.json') 80 | data = json.load(open(path, 'r')) 81 | return data['detail'] 82 | 83 | def load_imgs(path): 84 | out = {} 85 | for name in os.listdir(path): 86 | if name.endswith('png'): 87 | img = imageio.imread(join(path, name)) 88 | img = np.array(img).clip(0, 255) 89 | out[name] = img 90 | return out 91 | 92 | 93 | def make_grid(rows, cols): 94 | grid = [0] * rows 95 | for i in range(rows): 96 | with st.container(): 97 | grid[i] = st.columns(cols) 98 | return grid 99 | 100 | def encode_image(img): 101 | img = np.uint8(img.clip(0,1)*255) 102 | if len(img.shape) == 3: img = img[:,:,::-1] 103 | _, encoded_image = cv2.imencode('.png', img) 104 | data = encoded_image.tobytes() 105 | return data -------------------------------------------------------------------------------- /hsir/model/qrnn3d/conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from sync_batchnorm import SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 4 | 5 | BatchNorm3d = SynchronizedBatchNorm3d 6 | 7 | 8 | class BNReLUConv3d(nn.Sequential): 9 | def __init__(self, in_channels, channels, k=3, s=1, p=1, inplace=False): 10 | super(BNReLUConv3d, self).__init__() 11 | self.add_module('bn', BatchNorm3d(in_channels)) 12 | self.add_module('relu', nn.ReLU(inplace=inplace)) 13 | self.add_module('conv', nn.Conv3d(in_channels, channels, k, s, p, bias=False)) 14 | 15 | 16 | class BNReLUDeConv3d(nn.Sequential): 17 | def __init__(self, in_channels, channels, k=3, s=1, p=1, inplace=False): 18 | super(BNReLUDeConv3d, self).__init__() 19 | self.add_module('bn', BatchNorm3d(in_channels)) 20 | self.add_module('relu', nn.ReLU(inplace=inplace)) 21 | self.add_module('deconv', nn.ConvTranspose3d(in_channels, channels, k, s, p, bias=False)) 22 | 23 | 24 | class BNReLUUpsampleConv3d(nn.Sequential): 25 | def __init__(self, in_channels, channels, k=3, s=1, p=1, upsample=(1,2,2), inplace=False): 26 | super(BNReLUUpsampleConv3d, self).__init__() 27 | self.add_module('bn', BatchNorm3d(in_channels)) 28 | self.add_module('relu', nn.ReLU(inplace=inplace)) 29 | self.add_module('upsample_conv', UpsampleConv3d(in_channels, channels, k, s, p, bias=False, upsample=upsample)) 30 | 31 | 32 | class UpsampleConv3d(torch.nn.Module): 33 | """UpsampleConvLayer 34 | Upsamples the input and then does a convolution. This method gives better results 35 | compared to ConvTranspose2d. 
36 | ref: http://distill.pub/2016/deconv-checkerboard/ 37 | """ 38 | 39 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=True, upsample=None): 40 | super(UpsampleConv3d, self).__init__() 41 | self.upsample = upsample 42 | if upsample: 43 | self.upsample_layer = torch.nn.Upsample(scale_factor=upsample, mode='trilinear', align_corners=True) 44 | 45 | self.conv3d = torch.nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, bias=bias) 46 | 47 | def forward(self, x): 48 | x_in = x 49 | if self.upsample: 50 | x_in = self.upsample_layer(x_in) 51 | out = self.conv3d(x_in) 52 | return out 53 | 54 | 55 | class BasicConv3d(nn.Sequential): 56 | def __init__(self, in_channels, channels, k=3, s=1, p=1, bias=False, bn=True): 57 | super(BasicConv3d, self).__init__() 58 | if bn: 59 | self.add_module('bn', BatchNorm3d(in_channels)) 60 | self.add_module('conv', nn.Conv3d(in_channels, channels, k, s, p, bias=bias)) 61 | 62 | 63 | class BasicDeConv3d(nn.Sequential): 64 | def __init__(self, in_channels, channels, k=3, s=1, p=1, bias=False, bn=True): 65 | super(BasicDeConv3d, self).__init__() 66 | if bn: 67 | self.add_module('bn', BatchNorm3d(in_channels)) 68 | self.add_module('deconv', nn.ConvTranspose3d(in_channels, channels, k, s, p, bias=bias)) 69 | 70 | 71 | class BasicUpsampleConv3d(nn.Sequential): 72 | def __init__(self, in_channels, channels, k=3, s=1, p=1, upsample=(1,2,2), bn=True): 73 | super(BasicUpsampleConv3d, self).__init__() 74 | if bn: 75 | self.add_module('bn', BatchNorm3d(in_channels)) 76 | self.add_module('upsample_conv', UpsampleConv3d(in_channels, channels, k, s, p, bias=False, upsample=upsample)) 77 | -------------------------------------------------------------------------------- /hsir/data/dataloader.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms as T 2 | import torch.utils.data as D 3 | import hsir.data.transform.noise as N 4 | import hsir.data.transform.general as G 5 | from hsir.data import HSITestDataset, HSITrainDataset 6 | from hsir.data.utils import worker_init_fn 7 | 8 | 9 | def gaussian_loader_train_s1(root, use_chw=False): 10 | common_transform = G.Identity() 11 | input_transform = T.Compose([ 12 | N.AddNoise(50), 13 | G.HSI2Tensor(use_chw=use_chw) 14 | ]) 15 | target_transform = G.HSI2Tensor(use_chw=use_chw) 16 | dataset = HSITrainDataset(root, input_transform, target_transform, common_transform) 17 | loader = D.DataLoader(dataset, batch_size=16, shuffle=True, num_workers=8, 18 | pin_memory=True, worker_init_fn=worker_init_fn) 19 | return loader 20 | 21 | 22 | def gaussian_loader_train_s2(root, use_chw=False): 23 | common_transform = G.RandomCrop(32, 32) 24 | input_transform = T.Compose([ 25 | N.AddNoiseBlind([10, 30, 50, 70]), 26 | G.HSI2Tensor(use_chw=use_chw) 27 | ]) 28 | target_transform = G.HSI2Tensor(use_chw=use_chw) 29 | dataset = HSITrainDataset(root, input_transform, target_transform, common_transform) 30 | loader = D.DataLoader(dataset, batch_size=64, shuffle=True, num_workers=8, 31 | pin_memory=True, worker_init_fn=worker_init_fn) 32 | return loader 33 | 34 | 35 | def gaussian_loader_train_s2_16(root, use_chw=False): 36 | common_transform = G.Identity() 37 | input_transform = T.Compose([ 38 | N.AddNoiseBlind([10, 30, 50, 70]), 39 | G.HSI2Tensor(use_chw=use_chw) 40 | ]) 41 | target_transform = G.HSI2Tensor(use_chw=use_chw) 42 | dataset = HSITrainDataset(root, input_transform, target_transform, common_transform) 43 | loader = D.DataLoader(dataset, 
batch_size=16, shuffle=True, num_workers=8, 44 | pin_memory=True, worker_init_fn=worker_init_fn) 45 | return loader 46 | 47 | 48 | def gaussian_loader_val(root, use_chw=False): 49 | dataset = HSITestDataset(root, size=5, use_chw=use_chw) 50 | loader = D.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1) 51 | return loader 52 | 53 | 54 | def gaussian_loader_test(root, use_chw=False): 55 | dataset = HSITestDataset(root, return_name=True, use_chw=use_chw) 56 | loader = D.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1) 57 | return loader 58 | 59 | 60 | def complex_loader_train(root, use_chw=False): 61 | sigmas = [10, 30, 50, 70] 62 | 63 | common_transform = G.Identity() 64 | input_transform = T.Compose([ 65 | N.AddNoiseNoniid(sigmas), 66 | G.SequentialSelect( 67 | transforms=[ 68 | lambda x: x, 69 | N.AddNoiseImpulse(), 70 | N.AddNoiseStripe(), 71 | N.AddNoiseDeadline() 72 | ] 73 | ), 74 | G.HSI2Tensor(use_chw=use_chw) 75 | ]) 76 | target_transform = G.HSI2Tensor(use_chw=use_chw) 77 | dataset = HSITrainDataset(root, input_transform, target_transform, common_transform) 78 | loader = D.DataLoader(dataset, batch_size=16, shuffle=True, num_workers=8, 79 | pin_memory=True, worker_init_fn=worker_init_fn) 80 | return loader 81 | 82 | 83 | def complex_loader_val(root, use_chw=False): 84 | dataset = HSITestDataset(root, use_chw=use_chw, size=5) 85 | loader = D.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1) 86 | return loader 87 | -------------------------------------------------------------------------------- /hsir/data/transform/sr.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import scipy.ndimage 4 | 5 | # C,H,W format 6 | 7 | 8 | """For Super-Resolution""" 9 | 10 | 11 | class SRDegrade(object): 12 | def __init__(self, scale_factor=2): 13 | self.scale_factor = scale_factor 14 | 15 | def __call__(self, img): 16 | from scipy.ndimage import zoom 17 | img = zoom(img, zoom=(1, 1. / self.scale_factor, 1. 
/ self.scale_factor)) 18 | img = zoom(img, zoom=(1, self.scale_factor, self.scale_factor)) 19 | return img 20 | 21 | 22 | class GaussianBlurScipy(object): 23 | def __init__(self, ksize=8, sigma=3): 24 | self.sigma = sigma 25 | self.truncate = (((ksize - 1) / 2) - 0.5) / sigma 26 | 27 | def __call__(self, img): 28 | from scipy.ndimage import gaussian_filter 29 | img = gaussian_filter(img, sigma=self.sigma, truncate=self.truncate) 30 | return img 31 | 32 | 33 | # ---------------------- Blur ----------------------- # 34 | 35 | def fspecial_gaussian(hsize, sigma): 36 | hsize = [hsize, hsize] 37 | siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0] 38 | std = sigma 39 | [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), 40 | np.arange(-siz[0], siz[0] + 1)) 41 | arg = -(x * x + y * y) / (2 * std * std) 42 | h = np.exp(arg) 43 | h[h < np.finfo(float).eps * h.max()] = 0 44 | sumh = h.sum() 45 | if sumh != 0: 46 | h = h / sumh 47 | return h 48 | 49 | 50 | class AbstractBlur: 51 | def __call__(self, img): 52 | img_L = scipy.ndimage.convolve( 53 | img, np.expand_dims(self.kernel, axis=0), mode='wrap') 54 | return img_L 55 | 56 | 57 | class GaussianBlur(AbstractBlur): 58 | def __init__(self, ksize=8, sigma=3): 59 | self.kernel = fspecial_gaussian(ksize, sigma) 60 | 61 | 62 | class UniformBlur(AbstractBlur): 63 | def __init__(self, ksize): 64 | self.kernel = np.ones((ksize, ksize)) / (ksize * ksize) 65 | 66 | 67 | ## -------------------- Resize -------------------- ## 68 | 69 | class KFoldDownsample: 70 | ''' k-fold downsampler: 71 | Keeping the upper-left pixel for each distinct sfxsf patch and discarding the others 72 | ''' 73 | 74 | def __init__(self, sf): 75 | self.sf = sf 76 | 77 | def __call__(self, img): 78 | st = 0 79 | return img[:, st::self.sf, st::self.sf] 80 | 81 | 82 | class Resize: 83 | def __init__(self, sf, mode='cubic'): 84 | self.sf = sf 85 | self.mode = self.mode_map[mode] 86 | 87 | def __call__(self, img): 88 | img = img.transpose(1, 2, 0) 89 | img = cv2.resize(img, (int(img.shape[1] * self.sf), int(img.shape[0] * self.sf)), interpolation=self.mode) 90 | img = img.transpose(2, 0, 1) 91 | return img 92 | 93 | mode_map = { 94 | 'nearest': cv2.INTER_NEAREST, 95 | 'linear': cv2.INTER_LINEAR, 96 | 'cubic': cv2.INTER_CUBIC, 97 | 'area': cv2.INTER_AREA 98 | } 99 | 100 | 101 | class BicubicDownsample(Resize): 102 | def __init__(self, sf): 103 | super().__init__(1 / sf, 'cubic') 104 | 105 | def __call__(self, img): 106 | return super().__call__(img) 107 | 108 | 109 | class BicubicUpsample(Resize): 110 | def __init__(self, sf): 111 | super().__init__(sf, 'cubic') 112 | 113 | def __call__(self, img): 114 | return super().__call__(img) 115 | -------------------------------------------------------------------------------- /hsir/schedule.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | 4 | TrainSchedule = namedtuple('TrainSchedule', ['base_lr', 'stage1_epoch', 'lr_schedule', 'max_epochs']) 5 | 6 | 7 | denoise_default = TrainSchedule( 8 | max_epochs=80, 9 | stage1_epoch=30, 10 | base_lr=1e-3, 11 | lr_schedule={ 12 | 0: 1e-3, 13 | 20: 1e-4, 14 | 30: 1e-3, 15 | 45: 1e-4, 16 | 55: 5e-5, 17 | 60: 1e-5, 18 | 65: 5e-6, 19 | 75: 1e-6, 20 | }, 21 | ) 22 | 23 | denoise_1x = TrainSchedule( 24 | max_epochs=30, 25 | stage1_epoch=15, 26 | base_lr=1e-3, 27 | lr_schedule={ 28 | 0: 1e-3, 29 | 5: 1e-4, 30 | 10: 1e-3, 31 | 15: 1e-4, 32 | 20: 5e-5, 33 | 25: 1e-5, 34 | 30: 5e-6, 35 | }, 36 | ) 37 |
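# The schedules in this module are plain (epoch -> lr) maps; hsirun/train.py
# hands them to hsir.scheduler.MultiStepSetLR, which is assumed to set the
# optimizer lr at each listed epoch. A hypothetical helper (not part of the
# hsir API) sketching the lookup a schedule implies:


def lr_at(schedule, epoch):
    """Return the lr in effect at `epoch`: the value set at the most recent
    listed epoch, or base_lr before the first entry."""
    lr = schedule.base_lr
    for e in sorted(schedule.lr_schedule):
        if e <= epoch:
            lr = schedule.lr_schedule[e]
    return lr


# e.g. lr_at(denoise_default, 25) == 1e-4, and epoch 30 jumps back to 1e-3
# for the stage-2 restart (matching stage1_epoch=30).

38 |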
denoise_unet = TrainSchedule( 39 | max_epochs=80, 40 | stage1_epoch=30, 41 | base_lr=2e-4, 42 | lr_schedule={ 43 | 0: 2e-4, 44 | 20: 1e-4, 45 | 30: 2e-4, 46 | 40: 1e-4, 47 | 50: 5e-5, 48 | 60: 1e-5, 49 | 70: 5e-6, 50 | }, 51 | ) 52 | 53 | denoise_grunet = TrainSchedule( 54 | max_epochs=80, 55 | stage1_epoch=30, 56 | base_lr=1e-4, 57 | lr_schedule={ 58 | 45: 5e-5, 59 | 60: 1e-5, 60 | 70: 5e-6, 61 | }, 62 | ) 63 | 64 | denoise_restormer = TrainSchedule( 65 | max_epochs=80, 66 | stage1_epoch=30, 67 | base_lr=1e-4, 68 | lr_schedule={ 69 | 0: 1e-4, 70 | 45: 5e-5, 71 | 55: 1e-5, 72 | 65: 5e-6, 73 | 75: 1e-6, 74 | }, 75 | ) 76 | 77 | denoise_hsid_cnn = denoise_restormer 78 | 79 | denoise_swinir = TrainSchedule( 80 | max_epochs=80, 81 | stage1_epoch=30, 82 | base_lr=2e-4, 83 | lr_schedule={ 84 | 0: 2e-4, 85 | 17: 1e-4, 86 | 22: 5e-5, 87 | 30: 1e-4, 88 | 50: 1e-5, 89 | 60: 5e-6, 90 | 70: 1e-6, 91 | }, 92 | ) 93 | 94 | denoise_uformer = TrainSchedule( 95 | max_epochs=80, 96 | stage1_epoch=30, 97 | base_lr=2e-4, 98 | lr_schedule={ 99 | 0: 2e-4, 100 | 17: 1e-4, 101 | 40: 5e-5, 102 | 50: 1e-5, 103 | 60: 5e-6, 104 | 70: 1e-6, 105 | }, 106 | ) 107 | 108 | 109 | denoise_complex_uformer = TrainSchedule( 110 | max_epochs=110, 111 | stage1_epoch=30, 112 | base_lr=1e-4, 113 | lr_schedule={ 114 | 80: 1e-4, 115 | 85: 5e-5, 116 | 90: 1e-5, 117 | 95: 5e-6, 118 | 100: 1e-6 119 | }, 120 | ) 121 | 122 | 123 | denoise_complex_swinir = TrainSchedule( 124 | max_epochs=110, 125 | stage1_epoch=30, 126 | base_lr=1e-4, 127 | lr_schedule={ 128 | 80: 1e-4, 129 | 85: 5e-5, 130 | 90: 1e-5, 131 | 95: 5e-6, 132 | 100: 1e-6 133 | }, 134 | ) 135 | 136 | 137 | denoise_complex_default = TrainSchedule( 138 | max_epochs=110, 139 | stage1_epoch=30, 140 | base_lr=1e-3, 141 | lr_schedule={ 142 | 80: 1e-3, 143 | 90: 5e-4, 144 | 95: 1e-4, 145 | 100: 5e-5, 146 | 105: 1e-5, 147 | }, 148 | ) 149 | 150 | denoise_complex_restormer = TrainSchedule( 151 | max_epochs=110, 152 | stage1_epoch=30, 153 | base_lr=1e-4, 154 | lr_schedule={ 155 | 80: 1e-4, 156 | 82: 5e-5, 157 | 90: 1e-5, 158 | 95: 5e-6, 159 | 100: 1e-6 160 | }, 161 | ) 162 | 163 | denoise_complex_hsid_cnn = TrainSchedule( 164 | max_epochs=110, 165 | stage1_epoch=30, 166 | base_lr=1e-4, 167 | lr_schedule={ 168 | 80: 1e-4, 169 | 95: 5e-5, 170 | 100: 1e-5, 171 | 105: 1e-6, 172 | }, 173 | ) 174 | -------------------------------------------------------------------------------- /hsirun/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | from os.path import join 4 | 5 | from torchlight.utils import instantiate, locate 6 | from torchlight.nn.utils import adjust_learning_rate, get_learning_rate 7 | 8 | import hsir.data.dataloader as loaders 9 | from hsir.trainer import Trainer 10 | from hsir.scheduler import MultiStepSetLR 11 | 12 | 13 | def train_cfg(): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--arch', '-a', required=True) 16 | parser.add_argument('--noise', default='gaussian', choices=['gaussian', 'complex']) 17 | parser.add_argument('--name', '-n', type=str, default=None, 18 | help='name of the experiment, if not specified, arch will be used.') 19 | parser.add_argument('--lr', type=float, default=None) 20 | parser.add_argument('-s', '--schedule', type=str, default='hsir.schedule.denoise_default') 21 | parser.add_argument('--resume', '-r', type=str, default=None) 22 | parser.add_argument('--bandwise', action='store_true') 23 | parser.add_argument('--use-conv2d', action='store_true') 24 
| parser.add_argument('--train-root', type=str, default='data/ICVL64_31_100.db') 25 | parser.add_argument('--test-root', type=str, default='data') 26 | parser.add_argument('--save-root', type=str, default='checkpoints') 27 | parser.add_argument('--save-freq', type=int, default=10) 28 | parser.add_argument('--gpu-ids', type=str, default='0', help='gpu ids') 29 | cfg = parser.parse_args() 30 | cfg.gpu_ids = [int(id) for id in cfg.gpu_ids.split(',')] 31 | cfg.name = cfg.arch if cfg.name is None else cfg.name 32 | return cfg 33 | 34 | 35 | def main(): 36 | cfg = train_cfg() 37 | net = instantiate(cfg.arch) 38 | schedule = locate(cfg.schedule) 39 | trainer = Trainer( 40 | net, 41 | lr=schedule.base_lr, 42 | save_dir=join(cfg.save_root, cfg.name), 43 | gpu_ids=cfg.gpu_ids, 44 | bandwise=cfg.bandwise, 45 | ) 46 | trainer.logger.print(cfg) 47 | if cfg.resume: trainer.load(cfg.resume) 48 | 49 | # prepare dataset 50 | if cfg.noise == 'gaussian': 51 | train_loader1 = loaders.gaussian_loader_train_s1(cfg.train_root, cfg.use_conv2d) 52 | train_loader2 = loaders.gaussian_loader_train_s2_16(cfg.train_root, cfg.use_conv2d) 53 | val_name = 'icvl_512_50' 54 | val_loader = loaders.gaussian_loader_val(join(cfg.test_root, val_name), cfg.use_conv2d) 55 | else: 56 | train_loader = loaders.complex_loader_train(cfg.train_root, cfg.use_conv2d) 57 | val_name = 'icvl_512_mixture' 58 | val_loader = loaders.complex_loader_val(join(cfg.test_root, val_name), cfg.use_conv2d) 59 | 60 | """Main loop""" 61 | if cfg.lr: adjust_learning_rate(trainer.optimizer, cfg.lr) # override lr 62 | lr_scheduler = MultiStepSetLR(trainer.optimizer, schedule.lr_schedule, epoch=trainer.epoch) 63 | epoch_per_save = cfg.save_freq 64 | best_psnr = 0 65 | while trainer.epoch < schedule.max_epochs: 66 | np.random.seed() # reset seed per epoch, otherwise the noise will be added with a specific pattern 67 | trainer.logger.print('Epoch [{}] Use lr={}'.format(trainer.epoch, get_learning_rate(trainer.optimizer))) 68 | 69 | # train 70 | if cfg.noise == 'gaussian': 71 | if trainer.epoch == 30: best_psnr = 0 72 | if trainer.epoch < 30: trainer.train(train_loader1) 73 | else: trainer.train(train_loader2, warm_up=trainer.epoch == 30) 74 | else: 75 | trainer.train(train_loader, warm_up=trainer.epoch == 80) 76 | 77 | # save ckpt 78 | metrics = trainer.validate(val_loader, val_name) 79 | if metrics['psnr'] > best_psnr: 80 | best_psnr = metrics['psnr'] 81 | trainer.save_checkpoint('model_best.pth') 82 | if trainer.epoch % epoch_per_save == 0: 83 | trainer.save_checkpoint() 84 | trainer.save_checkpoint('model_latest.pth') 85 | 86 | lr_scheduler.step() 87 | 88 | 89 | if __name__ == '__main__': 90 | main() 91 | -------------------------------------------------------------------------------- /hsir/model/hsidcnn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Ported from official caffe implementation 3 | https://github.com/qzhang95/HSID-CNN 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | def hsid_cnn(): 11 | return HSIDCNN() 12 | 13 | 14 | class HSIDCNN(nn.Module): 15 | def __init__(self, num_adj_bands=12): 16 | super().__init__() 17 | self.num_adj_bands = num_adj_bands 18 | 19 | 20 | self.conv_k3 = nn.Conv2d(num_adj_bands * 2, 20, 3, 1, 1) 21 | self.conv_k5 = nn.Conv2d(num_adj_bands * 2, 20, 5, 1, 2) 22 | self.conv_k7 = nn.Conv2d(num_adj_bands * 2, 20, 7, 1, 3) 23 | 24 | self.conv_k3_2 = nn.Conv2d(1, 20, 3, 1, 1) 25 | self.conv_k5_2 = nn.Conv2d(1, 20, 5, 1, 2) 26 | self.conv_k7_2 = nn.Conv2d(1, 20, 
7, 1, 3) 27 | 28 | self.conv1 = nn.Conv2d(120, 60, 3, 1, 1) 29 | self.conv2 = nn.Conv2d(60, 60, 3, 1, 1) 30 | self.conv3 = nn.Conv2d(60, 60, 3, 1, 1) 31 | self.conv4 = nn.Conv2d(60, 60, 3, 1, 1) 32 | self.conv5 = nn.Conv2d(60, 60, 3, 1, 1) 33 | self.conv6 = nn.Conv2d(60, 60, 3, 1, 1) 34 | self.conv7 = nn.Conv2d(60, 60, 3, 1, 1) 35 | self.conv8 = nn.Conv2d(60, 60, 3, 1, 1) 36 | self.conv9 = nn.Conv2d(60, 60, 3, 1, 1) 37 | 38 | self.tail = nn.Sequential( 39 | nn.Conv2d(60 * 4, 15, 3, 1, 1), 40 | nn.ReLU(), 41 | nn.Conv2d(15, 1, 3, 1, 1), 42 | ) 43 | 44 | self._reset_parameters() 45 | 46 | def _reset_parameters(self): 47 | for m in self.modules(): 48 | if isinstance(m, nn.Conv2d): 49 | torch.nn.init.kaiming_normal_(m.weight, mode='fan_out') 50 | torch.nn.init.constant_(m.bias, 0) 51 | 52 | def forward(self, x): 53 | num_bands = x.shape[1] 54 | outputs = [] 55 | 56 | inputs_x, inputs_adj_x = [], [] 57 | for i in range(self.num_adj_bands): 58 | inputs_x.append(x[:, i:i + 1, :, :]) 59 | inputs_adj_x.append(x[:, :self.num_adj_bands * 2, :, :]) 60 | 61 | for i in range(self.num_adj_bands, num_bands - self.num_adj_bands): 62 | adj = torch.cat([x[:, i - self.num_adj_bands:i, :, :], 63 | x[:, i + 1:i + 1 + self.num_adj_bands, :, :]], dim=1) 64 | inputs_x.append(x[:, i:i + 1, :, :]) 65 | inputs_adj_x.append(adj) 66 | 67 | for i in range(num_bands - self.num_adj_bands, num_bands): 68 | inputs_x.append(x[:, i:i + 1, :, :]) 69 | inputs_adj_x.append(x[:, -self.num_adj_bands * 2:, :, :]) 70 | 71 | for i in range(num_bands): 72 | output = self._forward(inputs_x[i], inputs_adj_x[i]) 73 | outputs.append(output) 74 | return torch.cat(outputs, dim=1) 75 | 76 | 77 | def _forward(self, x, adj_x): 78 | feat3 = self.conv_k3(adj_x) 79 | feat5 = self.conv_k5(adj_x) 80 | feat7 = self.conv_k7(adj_x) 81 | feat_3_5_7 = torch.cat([feat3, feat5, feat7], dim=1).relu() 82 | 83 | feat3_2 = self.conv_k3_2(x) 84 | feat5_2 = self.conv_k5_2(x) 85 | feat7_2 = self.conv_k7_2(x) 86 | feat_3_5_7_2 = torch.cat([feat3_2, feat5_2, feat7_2], dim=1).relu() 87 | 88 | feat_all = torch.cat([feat_3_5_7, feat_3_5_7_2], dim=1) 89 | 90 | tmp = self.conv1(feat_all).relu() 91 | tmp = self.conv2(tmp).relu() 92 | feat_conv3 = self.conv3(tmp).relu() 93 | 94 | tmp = self.conv4(feat_conv3).relu() 95 | feat_conv5 = self.conv5(tmp).relu() 96 | 97 | tmp = self.conv6(feat_conv5).relu() 98 | feat_conv7 = self.conv7(tmp).relu() 99 | 100 | tmp = self.conv8(feat_conv7).relu() 101 | feat_conv9 = self.conv9(tmp).relu() 102 | 103 | feat_all = torch.cat([feat_conv3, feat_conv5, feat_conv7, feat_conv9], dim=1) 104 | out = self.tail(feat_all) 105 | 106 | return out 107 | 108 | 109 | if __name__ == '__main__': 110 | net = HSIDCNN().cuda() 111 | x = torch.randn(4,31,64,64).cuda() 112 | out = net(x) 113 | print(out.shape) 114 | -------------------------------------------------------------------------------- /hsir/model/unet3d.py: -------------------------------------------------------------------------------- 1 | """ 2 | Model from 3 | Hyperspectral Image Denoising With Realistic Data ICCV 2021 4 | """ 5 | 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch.utils.data 9 | import torch 10 | import math 11 | 12 | class conv_block(nn.Module): 13 | """ 14 | Convolution Block 15 | """ 16 | def __init__(self, in_ch, out_ch): 17 | super(conv_block, self).__init__() 18 | 19 | self.conv = nn.Sequential( 20 | # nn.InstanceNorm3d(in_ch), 21 | nn.Conv3d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True), 22 | nn.LeakyReLU(negative_slope=0.01, 
inplace=True), 23 | # nn.InstanceNorm3d(out_ch), 24 | nn.Conv3d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True), 25 | nn.LeakyReLU(negative_slope=0.01, inplace=True)) 26 | 27 | self.conv_residual = nn.Conv3d(in_ch, out_ch, kernel_size=1, stride=1, bias=True) 28 | 29 | def forward(self, x): 30 | 31 | x = self.conv(x) + self.conv_residual(x) 32 | return x 33 | 34 | class U_Net_3D(nn.Module): 35 | def __init__(self, in_ch=1, out_ch=1, dim=32): 36 | super(U_Net_3D, self).__init__() 37 | 38 | n1 = dim 39 | filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16] 40 | 41 | self.Down1 = nn.Conv3d(filters[0], filters[0], kernel_size=[1,4,4], stride=[1,2,2], padding=[0,1,1], bias=True) 42 | self.Down2 = nn.Conv3d(filters[1], filters[1], kernel_size=[1,4,4], stride=[1,2,2], padding=[0,1,1], bias=True) 43 | self.Down3 = nn.Conv3d(filters[2], filters[2], kernel_size=[1,4,4], stride=[1,2,2], padding=[0,1,1], bias=True) 44 | self.Down4 = nn.Conv3d(filters[3], filters[3], kernel_size=[1,4,4], stride=[1,2,2], padding=[0,1,1], bias=True) 45 | 46 | self.Conv1 = conv_block(in_ch, filters[0]) 47 | self.Conv2 = conv_block(filters[0], filters[1]) 48 | self.Conv3 = conv_block(filters[1], filters[2]) 49 | self.Conv4 = conv_block(filters[2], filters[3]) 50 | self.Conv5 = conv_block(filters[3], filters[4]) 51 | 52 | self.Up5 = nn.ConvTranspose3d(filters[4], filters[3], kernel_size=[1,2,2], stride=[1,2,2], padding=0, bias=True) 53 | self.Up_conv5 = conv_block(filters[4], filters[3]) 54 | 55 | self.Up4 = nn.ConvTranspose3d(filters[3], filters[2], kernel_size=[1,2,2], stride=[1,2,2], padding=0, bias=True) 56 | self.Up_conv4 = conv_block(filters[3], filters[2]) 57 | 58 | self.Up3 = nn.ConvTranspose3d(filters[2], filters[1], kernel_size=[1,2,2], stride=[1,2,2], padding=0, bias=True) 59 | self.Up_conv3 = conv_block(filters[2], filters[1]) 60 | 61 | self.Up2 = nn.ConvTranspose3d(filters[1], filters[0], kernel_size=[1,2,2], stride=[1,2,2], padding=0, bias=True) 62 | self.Up_conv2 = conv_block(filters[1], filters[0]) 63 | 64 | self.Conv = nn.Conv3d(filters[0], out_ch, kernel_size=1, stride=1, padding=0) 65 | 66 | def forward(self, x): 67 | 68 | e1 = self.Conv1(x) 69 | 70 | e2 = self.Down1(e1) 71 | e2 = self.Conv2(e2) 72 | 73 | e3 = self.Down2(e2) 74 | e3 = self.Conv3(e3) 75 | 76 | e4 = self.Down3(e3) 77 | e4 = self.Conv4(e4) 78 | 79 | e5 = self.Down4(e4) 80 | e5 = self.Conv5(e5) 81 | 82 | d5 = self.Up5(e5) 83 | d5 = torch.cat((e4, d5), dim=1) 84 | d5 = self.Up_conv5(d5) 85 | 86 | d4 = self.Up4(d5) 87 | d4 = torch.cat((e3, d4), dim=1) 88 | d4 = self.Up_conv4(d4) 89 | 90 | d3 = self.Up3(d4) 91 | d3 = torch.cat((e2, d3), dim=1) 92 | d3 = self.Up_conv3(d3) 93 | 94 | d2 = self.Up2(d3) 95 | d2 = torch.cat((e1, d2), dim=1) 96 | d2 = self.Up_conv2(d2) 97 | 98 | out = self.Conv(d2) 99 | 100 | return out+x 101 | 102 | def unet3d(): 103 | net = U_Net_3D() 104 | net.use_2dconv = False 105 | net.bandwise = False 106 | return net 107 | 108 | def unet3d_m(): 109 | net = U_Net_3D(dim=16) 110 | net.use_2dconv = False 111 | net.bandwise = False 112 | return net -------------------------------------------------------------------------------- /hsir/model/grunet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.utils.data 3 | import torch 4 | 5 | from .qrnn3d.qrnn import QRNNConv3D, QRNNUpsampleConv3d, BiQRNNConv3D, BiQRNNDeConv3D, QRNNDeConv3D 6 | 7 | 8 | class conv_block(nn.Module): 9 | def __init__(self, in_ch, out_ch, bn=False): 10 | super(conv_block, 
self).__init__() 11 | 12 | self.conv1 = QRNNConv3D(in_ch, out_ch, bn=bn) 13 | self.conv2 = QRNNConv3D(out_ch, out_ch, bn=bn) 14 | self.conv_residual = QRNNConv3D(in_ch, out_ch, k=1, s=1, p=0, bn=bn) 15 | 16 | def forward(self, x, reverse=False): 17 | residual = self.conv2(self.conv1(x, reverse=reverse), reverse=reverse) 18 | x = residual + self.conv_residual(x, reverse=reverse) 19 | return x 20 | 21 | 22 | class deconv_block(nn.Module): 23 | def __init__(self, in_ch, out_ch, bn=False): 24 | super(deconv_block, self).__init__() 25 | 26 | self.conv1 = QRNNDeConv3D(in_ch, out_ch, bn=bn) 27 | self.conv2 = QRNNDeConv3D(out_ch, out_ch, bn=bn) 28 | self.conv_residual = QRNNDeConv3D(in_ch, out_ch, k=1, s=1, p=0, bn=bn) 29 | 30 | def forward(self, x, reverse=False): 31 | residual = self.conv2(self.conv1(x, reverse=reverse), reverse=reverse) 32 | x = residual + self.conv_residual(x, reverse=reverse) 33 | return x 34 | 35 | 36 | class GRUnet(nn.Module): 37 | def __init__(self, in_ch=1, out_ch=1, use_noise_map=False, bn=False): 38 | super(GRUnet, self).__init__() 39 | self.use_2dconv = False 40 | self.bandwise = False 41 | self.use_noise_map = use_noise_map 42 | 43 | n1 = 16 44 | filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16] 45 | 46 | self.Down1 = QRNNConv3D(filters[0], filters[0], k=3, s=(1, 2, 2), p=1, bn=bn) 47 | self.Down2 = QRNNConv3D(filters[1], filters[1], k=3, s=(1, 2, 2), p=1, bn=bn) 48 | self.Down3 = QRNNConv3D(filters[2], filters[2], k=3, s=(1, 2, 2), p=1, bn=bn) 49 | self.Down4 = QRNNConv3D(filters[3], filters[3], k=3, s=(1, 2, 2), p=1, bn=bn) 50 | 51 | self.Conv1 = BiQRNNConv3D(in_ch, filters[0], bn=bn) 52 | self.Conv2 = conv_block(filters[0], filters[1], bn=bn) 53 | self.Conv3 = conv_block(filters[1], filters[2], bn=bn) 54 | self.Conv4 = conv_block(filters[2], filters[3], bn=bn) 55 | self.Conv5 = conv_block(filters[3], filters[4], bn=bn) 56 | 57 | self.Up5 = QRNNUpsampleConv3d(filters[4], filters[3], bn=bn) 58 | self.Up_conv5 = deconv_block(filters[4], filters[3], bn=bn) 59 | 60 | self.Up4 = QRNNUpsampleConv3d(filters[3], filters[2], bn=bn) 61 | self.Up_conv4 = deconv_block(filters[3], filters[2], bn=bn) 62 | 63 | self.Up3 = QRNNUpsampleConv3d(filters[2], filters[1], bn=bn) 64 | self.Up_conv3 = deconv_block(filters[2], filters[1], bn=bn) 65 | 66 | self.Up2 = QRNNUpsampleConv3d(filters[1], filters[0], bn=bn) 67 | self.Up_conv2 = deconv_block(filters[1], filters[0], bn=bn) 68 | 69 | self.Conv = BiQRNNDeConv3D(filters[0], 1, bias=True, bn=bn) 70 | 71 | def forward(self, x): 72 | # x: [B, C, B, W, H] 73 | e1 = self.Conv1(x) 74 | 75 | e2 = self.Down1(e1, reverse=True) 76 | e2 = self.Conv2(e2, reverse=False) 77 | 78 | e3 = self.Down2(e2, reverse=True) 79 | e3 = self.Conv3(e3, reverse=False) 80 | 81 | e4 = self.Down3(e3, reverse=True) 82 | e4 = self.Conv4(e4, reverse=False) 83 | 84 | e5 = self.Down4(e4, reverse=True) 85 | e5 = self.Conv5(e5, reverse=False) 86 | 87 | d5 = self.Up5(e5, reverse=True) 88 | d5 = torch.cat((e4, d5), dim=1) 89 | d5 = self.Up_conv5(d5, reverse=False) 90 | 91 | d4 = self.Up4(d5, reverse=True) 92 | d4 = torch.cat((e3, d4), dim=1) 93 | d4 = self.Up_conv4(d4, reverse=False) 94 | 95 | d3 = self.Up3(d4, reverse=True) 96 | d3 = torch.cat((e2, d3), dim=1) 97 | d3 = self.Up_conv3(d3, reverse=False) 98 | 99 | d2 = self.Up2(d3, reverse=True) 100 | d2 = torch.cat((e1, d2), dim=1) 101 | d2 = self.Up_conv2(d2, reverse=False) 102 | 103 | out = self.Conv(d2) 104 | 105 | if self.use_noise_map: 106 | return out + x[:, 0, :, :, :].unsqueeze(1) 107 | else: 108 | return out + x 109 | 110 
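# A minimal smoke test for the 5-D [batch, channel, band, height, width]
# layout GRUnet expects -- a sketch assuming a 31-band cube whose spatial
# size is divisible by 16 (four stride-(1, 2, 2) downsamplings):


if __name__ == '__main__':
    net = GRUnet()
    x = torch.randn(1, 1, 31, 64, 64)
    assert net(x).shape == x.shape  # residual head preserves the input shape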
| 111 | def grunet(): 112 | return GRUnet() 113 | 114 | 115 | def grunet_noise_map(): 116 | return GRUnet(in_ch=2, use_noise_map=True) 117 | -------------------------------------------------------------------------------- /hsiboard/cmp.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import imageio 3 | from os.path import join 4 | import numpy as np 5 | import json 6 | import zipfile 7 | import io 8 | 9 | from util import * 10 | from box import * 11 | 12 | 13 | def main(logdir): 14 | set_page_container_style() 15 | with st.sidebar: 16 | st.title('HSIR Board') 17 | st.subheader('Comparator') 18 | st.caption(f'Directory: {logdir}') 19 | methods = listdir(logdir) 20 | selected_methods = st.sidebar.multiselect( 21 | "Select Method", 22 | methods, 23 | default=methods 24 | ) 25 | datasets_list = [listdir(join(logdir, m)) for m in selected_methods] 26 | datasets = set(datasets_list[0]).intersection(*map(set, datasets_list)) 27 | selected_dataset = st.sidebar.selectbox( 28 | "Select Dataset", 29 | datasets 30 | ) 31 | selected_vis_type = st.sidebar.selectbox( 32 | "Select Image Type", 33 | ['color', 'gray'] 34 | ) 35 | 36 | img_names = os.listdir(join(logdir, selected_methods[0], selected_dataset, selected_vis_type)) 37 | selected_img = st.sidebar.selectbox( 38 | "Select Image", 39 | img_names 40 | ) 41 | 42 | st.header('Box') 43 | 44 | enable_enlarge = st.checkbox('Enlarge') 45 | enable_diff = st.checkbox('Difference Map') 46 | enable_sidebyside = st.checkbox('Side by Side') 47 | 48 | selected_box_pos = st.sidebar.selectbox( 49 | "Select Box Position", 50 | ['Bottom Right', 'Bottom Left', 'Top Right', 'Top Left'], 51 | ) 52 | 53 | crow = st.slider('row coordinate', min_value=0.0, max_value=1.0, value=0.2) 54 | ccol = st.slider('col coordinate', min_value=0.0, max_value=1.0, value=0.2) 55 | vmax = st.slider('vmax', min_value=0.0, max_value=1.0, value=0.1) 56 | 57 | st.header('Layout') 58 | ncol = st.slider('number of columns', min_value=1, max_value=20, value=4) 59 | 60 | imgs = {} 61 | stats = {} 62 | for m in selected_methods: 63 | img = imageio.imread(join(logdir, m, selected_dataset, selected_vis_type, selected_img)) 64 | imgs[m] = np.array(img, dtype=np.float32) / 255 65 | stat = json.load(open(join(logdir, m, selected_dataset, 'log.json'))) 66 | stats[m] = stat['detail'][os.path.splitext(selected_img)[-2]] 67 | 68 | gt = imgs['gt'] 69 | nrow = len(imgs) // ncol + 1 70 | grids = make_grid(nrow, ncol) 71 | 72 | download_data = {'img': {}, 'meta': {}} 73 | with st.container(): 74 | idx = 0 75 | for name, im in imgs.items(): 76 | img = im.copy() 77 | name = os.path.splitext(name)[-2] 78 | if enable_diff: 79 | diff = np.abs(img - gt) 80 | if len(diff.shape) == 3: diff = diff.mean(-1) 81 | img = convert_color(diff, vmax=vmax) 82 | 83 | if enable_sidebyside: 84 | h, w = img.shape[:2] 85 | img = addbox_with_diff(img, gt, (int(h * crow), int(w * ccol)), vmax=vmax) 86 | 87 | elif enable_enlarge: 88 | h, w = img.shape[:2] 89 | img = addbox(img, (int(h * crow), int(w * ccol)), bbpos=mapbbpox[selected_box_pos]) 90 | 91 | ct = grids[idx // ncol][idx % ncol] 92 | ct.image(img, caption='%s [%.4f]' % (name, stats[name]['MPSNR']), clamp=[0, 1]) 93 | idx += 1 94 | 95 | download_data['img'][name] = img 96 | download_data['meta'][name] = stats[name] 97 | 98 | with st.sidebar: 99 | st.header('Download') 100 | filename = os.path.splitext(selected_img)[-2] 101 | st.download_button('Download ' + filename, 102 | to_zip(download_data), 103 | 
file_name=f'{filename}_{selected_dataset}.zip', 104 | mime='application/zip') 105 | 106 | 107 | def to_zip(data): 108 | zip_buffer = io.BytesIO() 109 | with zipfile.ZipFile(zip_buffer, "a", 110 | zipfile.ZIP_DEFLATED, False) as zip_file: 111 | for file_name, img in data['img'].items(): 112 | zip_file.writestr(file_name + '.png', encode_image(img)) 113 | 114 | zip_file.writestr('meta.json', json.dumps(data['meta'], indent=4)) 115 | return zip_buffer 116 | 117 | 118 | if __name__ == '__main__': 119 | parser = argparse.ArgumentParser('HSIR Board') 120 | parser.add_argument('--logdir', default='results') 121 | args = parser.parse_args() 122 | 123 | main(args.logdir) 124 | -------------------------------------------------------------------------------- /hsir/model/trq3d/conv.py: -------------------------------------------------------------------------------- 1 | from .combinations import * 2 | 3 | 4 | class QRNN3DLayer(nn.Module): 5 | def __init__(self, in_channels, hidden_channels, conv_layer, act='tanh'): 6 | super(QRNN3DLayer, self).__init__() 7 | self.in_channels = in_channels 8 | self.hidden_channels = hidden_channels 9 | # quasi_conv_layer 10 | self.conv = conv_layer 11 | self.act = act 12 | 13 | def _conv_step(self, inputs): 14 | gates = self.conv(inputs) 15 | Z, F = gates.split(split_size=self.hidden_channels, dim=1) 16 | if self.act == 'tanh': 17 | return Z.tanh(), F.sigmoid() 18 | elif self.act == 'relu': 19 | return Z.relu(), F.sigmoid() 20 | elif self.act == 'none': 21 | return Z, F.sigmoid() 22 | else: 23 | raise NotImplementedError 24 | 25 | def _rnn_step(self, z, f, h): 26 | # uses 'f pooling' at each time step 27 | h_ = (1 - f) * z if h is None else f * h + (1 - f) * z 28 | return h_ 29 | 30 | def forward(self, inputs, reverse=False): 31 | h = None 32 | Z, F = self._conv_step(inputs) 33 | h_time = [] 34 | 35 | if not reverse: 36 | for time, (z, f) in enumerate(zip(Z.split(1, 2), F.split(1, 2))): # split along timestep 37 | h = self._rnn_step(z, f, h) 38 | h_time.append(h) 39 | else: 40 | for time, (z, f) in enumerate((zip( 41 | reversed(Z.split(1, 2)), reversed(F.split(1, 2)) 42 | ))): # split along timestep 43 | h = self._rnn_step(z, f, h) 44 | h_time.insert(0, h) 45 | 46 | # return concatenated hidden states 47 | return torch.cat(h_time, dim=2) 48 | 49 | def extra_repr(self): 50 | return 'act={}'.format(self.act) 51 | 52 | 53 | class BiQRNN3DLayer(QRNN3DLayer): 54 | def _conv_step(self, inputs): 55 | gates = self.conv(inputs) 56 | Z, F1, F2 = gates.split(split_size=self.hidden_channels, dim=1) 57 | if self.act == 'tanh': 58 | return Z.tanh(), F1.sigmoid(), F2.sigmoid() 59 | elif self.act == 'relu': 60 | return Z.relu(), F1.sigmoid(), F2.sigmoid() 61 | elif self.act == 'none': 62 | return Z, F1.sigmoid(), F2.sigmoid() 63 | else: 64 | raise NotImplementedError 65 | 66 | def forward(self, inputs, fname=None): 67 | h = None 68 | Z, F1, F2 = self._conv_step(inputs) 69 | hsl = [] 70 | hsr = [] 71 | zs = Z.split(1, 2) 72 | 73 | for time, (z, f) in enumerate(zip(zs, F1.split(1, 2))): # split along timestep 74 | h = self._rnn_step(z, f, h) 75 | hsl.append(h) 76 | 77 | h = None 78 | for time, (z, f) in enumerate((zip( 79 | reversed(zs), reversed(F2.split(1, 2)) 80 | ))): # split along timestep 81 | h = self._rnn_step(z, f, h) 82 | hsr.insert(0, h) 83 | 84 | # return concatenated hidden states 85 | hsl = torch.cat(hsl, dim=2) 86 | hsr = torch.cat(hsr, dim=2) 87 | 88 | if fname is not None: 89 | stats_dict = {'z': Z, 'fl': F1, 'fr': F2, 'hsl': hsl, 'hsr': hsr} 90 | torch.save(stats_dict, 
fname) 91 | return hsl + hsr 92 | 93 | 94 | class BiQRNNConv3D(BiQRNN3DLayer): 95 | def __init__(self, in_channels, hidden_channels, k=3, s=1, p=1, bn=True, act='tanh'): 96 | super(BiQRNNConv3D, self).__init__( 97 | in_channels, hidden_channels, BasicConv3d(in_channels, hidden_channels * 3, k, s, p, bn=bn), act=act) 98 | 99 | 100 | class BiQRNNDeConv3D(BiQRNN3DLayer): 101 | def __init__(self, in_channels, hidden_channels, k=3, s=1, p=1, bias=False, bn=True, act='tanh'): 102 | super(BiQRNNDeConv3D, self).__init__( 103 | in_channels, hidden_channels, BasicDeConv3d(in_channels, hidden_channels * 3, k, s, p, bias=bias, bn=bn), 104 | act=act) 105 | 106 | 107 | class QRNNConv3D(QRNN3DLayer): 108 | def __init__(self, in_channels, hidden_channels, k=3, s=1, p=1, bn=True, act='tanh'): 109 | super(QRNNConv3D, self).__init__( 110 | in_channels, hidden_channels, BasicConv3d(in_channels, hidden_channels * 2, k, s, p, bn=bn), act=act) 111 | 112 | 113 | class QRNNDeConv3D(QRNN3DLayer): 114 | def __init__(self, in_channels, hidden_channels, k=3, s=1, p=1, bn=True, act='tanh'): 115 | super(QRNNDeConv3D, self).__init__( 116 | in_channels, hidden_channels, BasicDeConv3d(in_channels, hidden_channels * 2, k, s, p, bn=bn), act=act) 117 | 118 | 119 | class QRNNUpsampleConv3D(QRNN3DLayer): 120 | def __init__(self, in_channels, hidden_channels, k=3, s=1, p=1, upsample=(1, 2, 2), bn=True, act='tanh'): 121 | super(QRNNUpsampleConv3D, self).__init__( 122 | in_channels, hidden_channels, 123 | BasicUpsampleConv3d(in_channels, hidden_channels * 2, k, s, p, upsample, bn=bn), act=act) 124 | 125 | 126 | 127 | -------------------------------------------------------------------------------- /hsir/model/qrnn3d/qrnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .conv import * 5 | 6 | 7 | """F pooling""" 8 | 9 | 10 | class QRNN3DLayer(nn.Module): 11 | def __init__(self, in_channels, hidden_channels, conv_layer, act='tanh'): 12 | super(QRNN3DLayer, self).__init__() 13 | self.in_channels = in_channels 14 | self.hidden_channels = hidden_channels 15 | # quasi_conv_layer 16 | self.conv = conv_layer 17 | self.act = act 18 | 19 | def _conv_step(self, inputs): 20 | gates = self.conv(inputs) 21 | Z, F = gates.split(split_size=self.hidden_channels, dim=1) 22 | if self.act == 'tanh': 23 | return Z.tanh(), F.sigmoid() 24 | elif self.act == 'relu': 25 | return Z.relu(), F.sigmoid() 26 | elif self.act == 'none': 27 | return Z, F.sigmoid() 28 | else: 29 | raise NotImplementedError 30 | 31 | def _rnn_step(self, z, f, h): 32 | # uses 'f pooling' at each time step 33 | h_ = (1 - f) * z if h is None else f * h + (1 - f) * z 34 | return h_ 35 | 36 | def forward(self, inputs, reverse=False): 37 | h = None 38 | Z, F = self._conv_step(inputs) 39 | h_time = [] 40 | 41 | if not reverse: 42 | for time, (z, f) in enumerate(zip(Z.split(1, 2), F.split(1, 2))): # split along timestep 43 | h = self._rnn_step(z, f, h) 44 | h_time.append(h) 45 | else: 46 | for time, (z, f) in enumerate((zip( 47 | reversed(Z.split(1, 2)), reversed(F.split(1, 2)) 48 | ))): # split along timestep 49 | h = self._rnn_step(z, f, h) 50 | h_time.insert(0, h) 51 | 52 | # return concatenated hidden states 53 | return torch.cat(h_time, dim=2) 54 | 55 | def extra_repr(self): 56 | return 'act={}'.format(self.act) 57 | 58 | 59 | class BiQRNN3DLayer(QRNN3DLayer): 60 | def _conv_step(self, inputs): 61 | gates = self.conv(inputs) 62 | Z, F1, F2 = gates.split(split_size=self.hidden_channels, dim=1) 63 | 
if self.act == 'tanh': 64 | return Z.tanh(), F1.sigmoid(), F2.sigmoid() 65 | elif self.act == 'relu': 66 | return Z.relu(), F1.sigmoid(), F2.sigmoid() 67 | elif self.act == 'none': 68 | return Z, F1.sigmoid(), F2.sigmoid() 69 | else: 70 | raise NotImplementedError 71 | 72 | def forward(self, inputs, fname=None): 73 | h = None 74 | Z, F1, F2 = self._conv_step(inputs) 75 | hsl = [] 76 | hsr = [] 77 | zs = Z.split(1, 2) 78 | 79 | for time, (z, f) in enumerate(zip(zs, F1.split(1, 2))): # split along timestep 80 | h = self._rnn_step(z, f, h) 81 | hsl.append(h) 82 | 83 | h = None 84 | for time, (z, f) in enumerate((zip( 85 | reversed(zs), reversed(F2.split(1, 2)) 86 | ))): # split along timestep 87 | h = self._rnn_step(z, f, h) 88 | hsr.insert(0, h) 89 | 90 | # return concatenated hidden states 91 | hsl = torch.cat(hsl, dim=2) 92 | hsr = torch.cat(hsr, dim=2) 93 | 94 | if fname is not None: 95 | stats_dict = {'z': Z, 'fl': F1, 'fr': F2, 'hsl': hsl, 'hsr': hsr} 96 | torch.save(stats_dict, fname) 97 | return hsl + hsr 98 | 99 | 100 | class BiQRNNConv3D(BiQRNN3DLayer): 101 | def __init__(self, in_channels, hidden_channels, k=3, s=1, p=1, bn=True, act='tanh'): 102 | super(BiQRNNConv3D, self).__init__( 103 | in_channels, hidden_channels, BasicConv3d(in_channels, hidden_channels*3, k, s, p, bn=bn), act=act) 104 | 105 | 106 | class BiQRNNDeConv3D(BiQRNN3DLayer): 107 | def __init__(self, in_channels, hidden_channels, k=3, s=1, p=1, bias=False, bn=True, act='tanh'): 108 | super(BiQRNNDeConv3D, self).__init__( 109 | in_channels, hidden_channels, BasicDeConv3d(in_channels, hidden_channels*3, k, s, p, bias=bias, bn=bn), act=act) 110 | 111 | 112 | class QRNNConv3D(QRNN3DLayer): 113 | def __init__(self, in_channels, hidden_channels, k=3, s=1, p=1, bn=True, act='tanh'): 114 | super(QRNNConv3D, self).__init__( 115 | in_channels, hidden_channels, BasicConv3d(in_channels, hidden_channels*2, k, s, p, bn=bn), act=act) 116 | 117 | 118 | class QRNNDeConv3D(QRNN3DLayer): 119 | def __init__(self, in_channels, hidden_channels, k=3, s=1, p=1, bn=True, act='tanh'): 120 | super(QRNNDeConv3D, self).__init__( 121 | in_channels, hidden_channels, BasicDeConv3d(in_channels, hidden_channels*2, k, s, p, bn=bn), act=act) 122 | 123 | 124 | class QRNNUpsampleConv3d(QRNN3DLayer): 125 | def __init__(self, in_channels, hidden_channels, k=3, s=1, p=1, upsample=(1, 2, 2), bn=True, act='tanh'): 126 | super(QRNNUpsampleConv3d, self).__init__( 127 | in_channels, hidden_channels, BasicUpsampleConv3d(in_channels, hidden_channels*2, k, s, p, upsample, bn=bn), act=act) 128 | -------------------------------------------------------------------------------- /hsir/model/hsdt/sepconv.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import torch.nn as nn 3 | import torch 4 | 5 | 6 | def repeat(x): 7 | if isinstance(x, (tuple, list)): 8 | return x 9 | return [x] * 3 10 | 11 | 12 | class SepConv_PD(nn.Module): 13 | def __init__(self, BASECONV, in_ch, out_ch, k, s=1, p=1, bias=False): 14 | super().__init__() 15 | k, s, p = repeat(k), repeat(s), repeat(p) 16 | 17 | padding_mode = 'zeros' 18 | self.pw_conv = BASECONV(in_ch, out_ch, (k[0], 1, 1), (s[0], 1, 1), (p[0], 0, 0), bias=bias, padding_mode=padding_mode) 19 | self.dw_conv = BASECONV(out_ch, out_ch, (1, k[1], k[2]), (1, s[1], s[2]), (0, p[1], p[2]), bias=bias, padding_mode=padding_mode) 20 | 21 | def forward(self, x): 22 | x = self.pw_conv(x) 23 | x = self.dw_conv(x) 24 | return x 25 | 26 | @staticmethod 27 | def of(base_conv): 
28 | return partial(SepConv_PD, base_conv) 29 | 30 | 31 | class SepConv_DP(nn.Module): 32 | def __init__(self, BASECONV, in_ch, out_ch, k, s=1, p=1, bias=False): 33 | super().__init__() 34 | k, s, p = repeat(k), repeat(s), repeat(p) 35 | 36 | padding_mode = 'zeros' 37 | self.dw_conv = BASECONV(in_ch, out_ch, (1, k[1], k[2]), (1, s[1], s[2]), (0, p[1], p[2]), bias=bias, padding_mode=padding_mode) 38 | self.pw_conv = BASECONV(out_ch, out_ch, (k[0], 1, 1), (s[0], 1, 1), (p[0], 0, 0), bias=bias, padding_mode=padding_mode) 39 | 40 | def forward(self, x): 41 | x = self.dw_conv(x) 42 | x = self.pw_conv(x) 43 | return x 44 | 45 | @staticmethod 46 | def of(base_conv): 47 | return partial(SepConv_DP, base_conv) 48 | 49 | 50 | class SepConv_DP_CA(nn.Module): 51 | def __init__(self, BASECONV, in_ch, out_ch, k, s=1, p=1, bias=False): 52 | super().__init__() 53 | k, s, p = repeat(k), repeat(s), repeat(p) 54 | 55 | padding_mode = 'zeros' 56 | self.dw_conv = BASECONV(in_ch, out_ch, (1, k[1], k[2]), (1, s[1], s[2]), (0, p[1], p[2]), bias=bias, padding_mode=padding_mode) 57 | self.pw_conv = BASECONV(out_ch, out_ch * 2, (k[0], 1, 1), (s[0], 1, 1), (p[0], 0, 0), bias=bias, padding_mode=padding_mode) 58 | 59 | def forward(self, x): 60 | x = self.dw_conv(x) 61 | x = self.pw_conv(x) 62 | x, w = torch.chunk(x, 2, dim=1) 63 | x = x * w.sigmoid() 64 | return x 65 | 66 | @staticmethod 67 | def of(base_conv): 68 | return partial(SepConv_DP_CA, base_conv) 69 | 70 | 71 | class S3Conv(nn.Module): 72 | # deep wise then point wise 73 | def __init__(self, BASECONV, in_ch, out_ch, k, s=1, p=1, bias=False): 74 | super().__init__() 75 | k, s, p = repeat(k), repeat(s), repeat(p) 76 | 77 | padding_mode = 'zeros' 78 | self.dw_conv = nn.Sequential( 79 | BASECONV(in_ch, out_ch, (1, k[1], k[2]), (1, s[1], s[2]), (0, p[1], p[2]), bias=bias, padding_mode=padding_mode), 80 | nn.LeakyReLU(), 81 | BASECONV(out_ch, out_ch, (1, k[1], k[2]), 1, (0, p[1], p[2]), bias=bias, padding_mode=padding_mode), 82 | nn.LeakyReLU(), 83 | ) 84 | self.pw_conv = BASECONV(in_ch, out_ch, (k[0], 1, 1), (s[0], 1, 1), (p[0], 0, 0), bias=bias, padding_mode=padding_mode) 85 | 86 | def forward(self, x): 87 | x1 = self.dw_conv(x) 88 | x2 = self.pw_conv(x) 89 | return x1 + x2 90 | 91 | @staticmethod 92 | def of(base_conv): 93 | return partial(S3Conv, base_conv) 94 | 95 | 96 | class S3Conv_Seq(nn.Module): 97 | def __init__(self, BASECONV, in_ch, out_ch, k, s=1, p=1, bias=False): 98 | super().__init__() 99 | k, s, p = repeat(k), repeat(s), repeat(p) 100 | 101 | padding_mode = 'zeros' 102 | self.dw_conv = nn.Sequential( 103 | BASECONV(in_ch, out_ch, (1, k[1], k[2]), (1, s[1], s[2]), (0, p[1], p[2]), bias=bias, padding_mode=padding_mode), 104 | nn.LeakyReLU(), 105 | BASECONV(out_ch, out_ch, (1, k[1], k[2]), 1, (0, p[1], p[2]), bias=bias, padding_mode=padding_mode), 106 | nn.LeakyReLU(), 107 | ) 108 | self.pw_conv = BASECONV(out_ch, out_ch, (k[0], 1, 1), (s[0], 1, 1), (p[0], 0, 0), bias=bias, padding_mode=padding_mode) 109 | 110 | def forward(self, x): 111 | x = self.dw_conv(x) 112 | x = self.pw_conv(x) 113 | return x 114 | 115 | @staticmethod 116 | def of(base_conv): 117 | return partial(S3Conv_Seq, base_conv) 118 | 119 | 120 | class S3Conv1(nn.Module): 121 | def __init__(self, BASECONV, in_ch, out_ch, k, s=1, p=1, bias=False): 122 | super().__init__() 123 | k, s, p = repeat(k), repeat(s), repeat(p) 124 | 125 | padding_mode = 'zeros' 126 | self.dw_conv = nn.Sequential( 127 | BASECONV(in_ch, out_ch, (1, k[1], k[2]), (1, s[1], s[2]), (0, p[1], p[2]), bias=bias, 
padding_mode=padding_mode), 128 | nn.LeakyReLU(), 129 | ) 130 | self.pw_conv = BASECONV(in_ch, out_ch, (k[0], 1, 1), (s[0], 1, 1), (p[0], 0, 0), bias=bias, padding_mode=padding_mode) 131 | 132 | def forward(self, x): 133 | x1 = self.dw_conv(x) 134 | x2 = self.pw_conv(x) 135 | return x1 + x2 136 | 137 | @staticmethod 138 | def of(base_conv): 139 | return partial(S3Conv1, base_conv) 140 | -------------------------------------------------------------------------------- /hsir/model/trq3d/combinations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional 4 | from sync_batchnorm import SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 5 | 6 | BatchNorm2d = SynchronizedBatchNorm2d 7 | BatchNorm3d = SynchronizedBatchNorm3d 8 | 9 | 10 | class BNReLUConv3d(nn.Sequential): 11 | def __init__(self, in_channels, channels, k=3, s=1, p=1, inplace=False): 12 | super(BNReLUConv3d, self).__init__() 13 | self.add_module('bn', BatchNorm3d(in_channels)) 14 | self.add_module('relu', nn.ReLU(inplace=inplace)) 15 | self.add_module('conv', nn.Conv3d(in_channels, channels, k, s, p, bias=False)) 16 | 17 | 18 | class BNReLUDeConv3d(nn.Sequential): 19 | def __init__(self, in_channels, channels, k=3, s=1, p=1, inplace=False): 20 | super(BNReLUDeConv3d, self).__init__() 21 | self.add_module('bn', BatchNorm3d(in_channels)) 22 | self.add_module('relu', nn.ReLU(inplace=inplace)) 23 | self.add_module('deconv', nn.ConvTranspose3d(in_channels, channels, k, s, p, bias=False)) 24 | 25 | 26 | class BNReLUUpsampleConv3d(nn.Sequential): 27 | def __init__(self, in_channels, channels, k=3, s=1, p=1, upsample=(1, 2, 2), inplace=False): 28 | super(BNReLUUpsampleConv3d, self).__init__() 29 | self.add_module('bn', BatchNorm3d(in_channels)) 30 | self.add_module('relu', nn.ReLU(inplace=inplace)) 31 | self.add_module('upsample_conv', UpsampleConv3d(in_channels, channels, k, s, p, bias=False, upsample=upsample)) 32 | 33 | class UpsampleConv2d(torch.nn.Module): 34 | """UpsampleConvLayer 35 | Upsamples the input and then does a convolution. This method gives better results 36 | compared to ConvTranspose2d. 37 | ref: http://distill.pub/2016/deconv-checkerboard/ 38 | """ 39 | 40 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=True, upsample=None): 41 | super(UpsampleConv2d, self).__init__() 42 | self.upsample = upsample 43 | if upsample: 44 | self.upsample_layer = torch.nn.Upsample(scale_factor=upsample, mode='bilinear', align_corners=True) 45 | 46 | self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias) 47 | 48 | def forward(self, x): 49 | x_in = x 50 | if self.upsample: 51 | x_in = self.upsample_layer(x_in) 52 | out = self.conv2d(x_in) 53 | return out 54 | 55 | class UpsampleConv3d(torch.nn.Module): 56 | """UpsampleConvLayer 57 | Upsamples the input and then does a convolution. This method gives better results 58 | compared to ConvTranspose2d. 
59 | ref: http://distill.pub/2016/deconv-checkerboard/ 60 | """ 61 | 62 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=True, upsample=None): 63 | super(UpsampleConv3d, self).__init__() 64 | self.upsample = upsample 65 | if upsample: 66 | self.upsample_layer = torch.nn.Upsample(scale_factor=upsample, mode='trilinear', align_corners=True) 67 | 68 | self.conv3d = torch.nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, bias=bias) 69 | 70 | def forward(self, x): 71 | x_in = x 72 | if self.upsample: 73 | x_in = self.upsample_layer(x_in) 74 | out = self.conv3d(x_in) 75 | return out 76 | 77 | class BasicConv2d(nn.Sequential): 78 | def __init__(self, in_channels, channels, k=3, s=1, p=1, bias=False, bn=True): 79 | super(BasicConv2d, self).__init__() 80 | if bn: 81 | self.add_module('bn', BatchNorm2d(in_channels)) 82 | self.add_module('conv', nn.Conv2d(in_channels, channels, k, s, p, bias=bias)) 83 | 84 | class BasicConv3d(nn.Sequential): 85 | def __init__(self, in_channels, channels, k=3, s=1, p=1, bias=False, bn=True): 86 | super(BasicConv3d, self).__init__() 87 | if bn: 88 | self.add_module('bn', BatchNorm3d(in_channels)) 89 | self.add_module('conv', nn.Conv3d(in_channels, channels, k, s, p, bias=bias)) 90 | 91 | class BasicDeConv2d(nn.Sequential): 92 | def __init__(self, in_channels, channels, k=3, s=1, p=1, bias=False, bn=True): 93 | super(BasicDeConv2d, self).__init__() 94 | if bn: 95 | self.add_module('bn', BatchNorm2d(in_channels)) 96 | self.add_module('deconv', nn.ConvTranspose2d(in_channels, channels, k, s, p, bias=bias)) 97 | 98 | class BasicDeConv3d(nn.Sequential): 99 | def __init__(self, in_channels, channels, k=3, s=1, p=1, bias=False, bn=True): 100 | super(BasicDeConv3d, self).__init__() 101 | if bn: 102 | self.add_module('bn', BatchNorm3d(in_channels)) 103 | self.add_module('deconv', nn.ConvTranspose3d(in_channels, channels, k, s, p, bias=bias)) 104 | 105 | class BasicUpsampleConv2d(nn.Sequential): 106 | def __init__(self, in_channels, channels, k=3, s=1, p=1, upsample=(2, 2), bn=True): 107 | super(BasicUpsampleConv2d, self).__init__() 108 | if bn: 109 | self.add_module('bn', BatchNorm2d(in_channels)) 110 | self.add_module('upsample_conv', UpsampleConv2d(in_channels, channels, k, s, p, bias=False, upsample=upsample)) 111 | 112 | class BasicUpsampleConv3d(nn.Sequential): 113 | def __init__(self, in_channels, channels, k=3, s=1, p=1, upsample=(1, 2, 2), bn=True): 114 | super(BasicUpsampleConv3d, self).__init__() 115 | if bn: 116 | self.add_module('bn', BatchNorm3d(in_channels)) 117 | self.add_module('upsample_conv', UpsampleConv3d(in_channels, channels, k, s, p, bias=False, upsample=upsample)) 118 | 119 | 120 | -------------------------------------------------------------------------------- /hsir/data/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | import scipy.io 5 | import torch 6 | import lmdb 7 | import numpy as np 8 | 9 | from torch.utils.data import Dataset 10 | from torchvision.transforms import Compose 11 | from functools import partial 12 | 13 | __all__ = [ 14 | 'HSITestDataset', 15 | 'HSITrainDataset', 16 | 'HSITransformTestDataset' 17 | ] 18 | 19 | 20 | class HSITestDataset(Dataset): 21 | def __init__(self, root, size=None, use_chw=False, return_name=False): 22 | super().__init__() 23 | self.dataset = MatDataFromFolder(root, size=size) 24 | self.transform = Compose([ 25 | LoadMatHSI(input_key='input', gt_key='gt', 26 | transform=None if 
use_chw else partial(np.expand_dims, axis=0)), 27 | ]) 28 | self.return_name = return_name 29 | 30 | def __getitem__(self, index): 31 | mat, filename = self.dataset[index] 32 | inputs, targets = self.transform(mat) 33 | output = {'input': inputs, 'target': targets} 34 | if self.return_name: 35 | output['filename'] = filename 36 | return output 37 | 38 | def __len__(self): 39 | return len(self.dataset) 40 | 41 | 42 | class HSITransformTestDataset(Dataset): 43 | def __init__(self, root, transform, size=None): 44 | super().__init__() 45 | self.dataset = MatDataFromFolder(root, size=size) 46 | self.transform = transform 47 | 48 | def __getitem__(self, index): 49 | mat, filename = self.dataset[index] 50 | gt = mat['gt'].transpose(2, 0, 1).astype('float32') 51 | input = self.transform(gt) 52 | return (input, gt), filename 53 | 54 | def __len__(self): 55 | return len(self.dataset) 56 | 57 | 58 | class HSITrainDataset(Dataset): 59 | def __init__(self, 60 | root, 61 | input_transform, 62 | target_transform, 63 | common_transform, 64 | repeat=1, 65 | ): 66 | super().__init__() 67 | self.dataset = LMDBDataset(root, repeat=repeat) 68 | self.common_transform = common_transform 69 | self.input_transform = input_transform 70 | self.target_transform = target_transform 71 | 72 | def __getitem__(self, index): 73 | img = self.dataset[index] 74 | img = self.common_transform(img) 75 | target = img.copy() 76 | if self.input_transform is not None: 77 | img = self.input_transform(img) 78 | if self.target_transform is not None: 79 | target = self.target_transform(target) 80 | return {'input': img, 'target': target} 81 | 82 | def __len__(self): 83 | return len(self.dataset) 84 | 85 | 86 | class LoadMatHSI(object): 87 | def __init__(self, input_key, gt_key, transform=None): 88 | self.gt_key = gt_key 89 | self.input_key = input_key 90 | self.transform = transform 91 | 92 | def __call__(self, mat): 93 | if self.transform: 94 | input = self.transform(mat[self.input_key].transpose((2, 0, 1))) 95 | gt = self.transform(mat[self.gt_key].transpose((2, 0, 1))) 96 | else: 97 | input = mat[self.input_key].transpose((2, 0, 1)) 98 | gt = mat[self.gt_key].transpose((2, 0, 1)) 99 | 100 | input = torch.from_numpy(input).float() 101 | gt = torch.from_numpy(gt).float() 102 | 103 | return input, gt 104 | 105 | 106 | class MatDataFromFolder(Dataset): 107 | """Wrap mat data from folder""" 108 | 109 | def __init__(self, data_dir, load=scipy.io.loadmat, suffix='mat', fns=None, size=None): 110 | super(MatDataFromFolder, self).__init__() 111 | if fns is not None: 112 | self.filenames = [ 113 | os.path.join(data_dir, fn) for fn in fns 114 | ] 115 | else: 116 | self.filenames = [ 117 | os.path.join(data_dir, fn) 118 | for fn in os.listdir(data_dir) 119 | if fn.endswith(suffix) 120 | ] 121 | 122 | self.load = load 123 | 124 | if size and size <= len(self.filenames): 125 | self.filenames = self.filenames[:size] 126 | 127 | def __getitem__(self, index): 128 | filename = self.filenames[index] 129 | mat = self.load(filename) 130 | return mat, filename 131 | 132 | def __len__(self): 133 | return len(self.filenames) 134 | 135 | 136 | class LMDBDataset(Dataset): 137 | def __init__(self, db_path, repeat=1, backend='pickle'): 138 | self.db_path = db_path 139 | self.env = lmdb.open(db_path, max_readers=1, readonly=True, lock=False, 140 | readahead=False, meminit=False) 141 | with self.env.begin(write=False) as txn: 142 | self.length = txn.stat()['entries'] 143 | self.repeat = repeat 144 | self.backend = backend 145 | 146 | def __getitem__(self, index): 
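        # LMDB keys are the zero-padded decimal indices written by the benchmark
        # build scripts ('{:08}'.format(k)); `repeat` only virtually enlarges the
        # dataset, so the index wraps modulo the true entry count below.
        # Minimal usage sketch, assuming a pickle-backed database built by
        # benchmark/datasets/RealHSI/real_build_lmdb.py:
        #   ds = LMDBDataset('Real64_31_2.db')
        #   sample = ds[0]  # here a dict {'gt': ndarray, 'input': ndarray}
        # Databases serialized with caffe Datum (the CAVE/Harvard scripts) need backend='caffe'.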
147 |         index = index % self.length 148 |         env = self.env 149 |         with env.begin(write=False) as txn: 150 |             raw_datum = txn.get('{:08}'.format(index).encode('ascii')) 151 | 152 |         if self.backend == 'caffe': 153 |             import caffe 154 |             datum = caffe.proto.caffe_pb2.Datum() 155 |             datum.ParseFromString(raw_datum) 156 | 157 |             flat_x = np.frombuffer(datum.data, dtype=np.float32).copy()  # np.fromstring is deprecated; .copy() keeps the array writable 158 |             x = flat_x.reshape(datum.channels, datum.height, datum.width) 159 |         else: 160 |             x = pickle.loads(raw_datum) 161 | 162 |         return x 163 | 164 |     def __len__(self): 165 |         return self.length * self.repeat 166 | 167 |     def __repr__(self): 168 |         return self.__class__.__name__ + ' (' + self.db_path + ')' 169 |
-------------------------------------------------------------------------------- /hsir/model/hsdt/__init__.py: -------------------------------------------------------------------------------- 1 | from .arch import HSDT 2 | from .attention import GSSA, SMFFN, TransformerBlock 3 | from .sepconv import S3Conv 4 | 5 | def hsdt():  # default model: in_ch=1, width 16, 5 half layers, downsample at [1, 3] 6 |     net = HSDT(1, 16, 5, [1, 3]) 7 |     net.use_2dconv = False 8 |     net.bandwise = False 9 |     return net 10 | 11 | 12 | def hsdt_4(): 13 |     net = HSDT(1, 4, 5, [1, 3]) 14 |     net.use_2dconv = False 15 |     net.bandwise = False 16 |     return net 17 | 18 | 19 | def hsdt_8(): 20 |     net = HSDT(1, 8, 5, [1, 3]) 21 |     net.use_2dconv = False 22 |     net.bandwise = False 23 |     return net 24 | 25 | 26 | def hsdt_24(): 27 |     net = HSDT(1, 24, 5, [1, 3]) 28 |     net.use_2dconv = False 29 |     net.bandwise = False 30 |     return net 31 | 32 | 33 | def hsdt_32(): 34 |     net = HSDT(1, 32, 5, [1, 3]) 35 |     net.use_2dconv = False 36 |     net.bandwise = False 37 |     return net 38 | 39 | 40 | def hsdt_deep(): 41 |     net = HSDT(1, 16, 7, [1, 3, 5]) 42 |     net.use_2dconv = False 43 |     net.bandwise = False 44 |     return net 45 | 46 | 47 | """ Extension 48 | """ 49 | 50 | 51 | def hsdt_pnp(): 52 |     net = HSDT(2, 16, 5, [1, 3]) 53 |     net.use_2dconv = False 54 |     net.bandwise = False 55 |     return net 56 | 57 | 58 | def hsdt_ssr(): 59 |     from .arch import HSDTSSR 60 |     net = HSDTSSR(1, 16, 5, [1, 3]) 61 |     net.use_2dconv = False 62 |     net.bandwise = False 63 |     return net 64 | 65 | 66 | """ ablations 67 | """ 68 | 69 | 70 | def hsdt_pixelwise(): 71 |     from . import arch 72 |     from .attention import PixelwiseTransformerBlock 73 |     arch.TransformerBlock = PixelwiseTransformerBlock 74 | 75 |     net = HSDT(1, 16, 5, [1, 3]) 76 |     net.use_2dconv = False 77 |     net.bandwise = False 78 |     return net 79 | 80 | # ablation of ffn 81 | 82 | 83 | def hsdt_ffn(): 84 |     from . import arch 85 |     from .attention import FFNTransformerBlock 86 |     arch.TransformerBlock = FFNTransformerBlock 87 | 88 |     net = HSDT(1, 16, 5, [1, 3]) 89 |     net.use_2dconv = False 90 |     net.bandwise = False 91 |     return net 92 | 93 | 94 | def hsdt_ffn_flex(): 95 |     from . import arch 96 |     from .attention import FFNTransformerBlock 97 |     from functools import partial 98 |     arch.TransformerBlock = partial(FFNTransformerBlock, flex=True) 99 | 100 |     net = HSDT(1, 16, 5, [1, 3]) 101 |     net.use_2dconv = False 102 |     net.bandwise = False 103 |     return net 104 | 105 | 106 | def hsdt_gdfn(): 107 |     from . import arch 108 |     from .attention import GDFNTransformerBlock 109 |     arch.TransformerBlock = GDFNTransformerBlock 110 | 111 |     net = HSDT(1, 16, 5, [1, 3]) 112 |     net.use_2dconv = False 113 |     net.bandwise = False 114 |     return net 115 | 116 | 117 | def hsdt_smffn1(): 118 |     from . import arch 119 |     from .attention import GFNTransformerBlock 120 |     arch.TransformerBlock = GFNTransformerBlock 121 | 122 |     net = HSDT(1, 16, 5, [1, 3]) 123 |     net.use_2dconv = False 124 |     net.bandwise = False 125 |     return net 126 | 127 | # ablation of ssa 128 | 129 | 130 | def hsdt_ssa(): 131 |     from . import arch 132 |     from .attention import SSATransformerBlock 133 |     arch.TransformerBlock = SSATransformerBlock 134 | 135 |     net = HSDT(1, 16, 5, [1, 3]) 136 |     net.use_2dconv = False 137 |     net.bandwise = False 138 |     return net 139 | 140 | # ablation of s3conv 141 | 142 | 143 | def hsdt_conv3d(): 144 |     from . import arch 145 |     import torch.nn as nn 146 |     arch.Conv3d = nn.Conv3d 147 | 148 |     net = HSDT(1, 16, 5, [1, 3]) 149 |     net.use_2dconv = False 150 |     net.bandwise = False 151 |     return net 152 | 153 | 154 | def hsdt_s3conv_sep(): 155 |     from . import arch 156 |     import torch.nn as nn 157 |     from .sepconv import SepConv_DP 158 |     arch.Conv3d = SepConv_DP.of(nn.Conv3d) 159 |     net = HSDT(1, 16, 5, [1, 3]) 160 |     net.use_2dconv = False 161 |     net.bandwise = False 162 |     return net 163 | 164 | 165 | def hsdt_s3conv_seq(): 166 |     from . import arch 167 |     import torch.nn as nn 168 |     from .sepconv import S3Conv_Seq 169 |     arch.Conv3d = S3Conv_Seq.of(nn.Conv3d) 170 |     net = HSDT(1, 16, 5, [1, 3]) 171 |     net.use_2dconv = False 172 |     net.bandwise = False 173 |     return net 174 | 175 | 176 | def hsdt_s3conv1(): 177 |     from . import arch 178 |     import torch.nn as nn 179 |     from .sepconv import S3Conv1 180 |     arch.Conv3d = S3Conv1.of(nn.Conv3d) 181 |     net = HSDT(1, 16, 5, [1, 3]) 182 |     net.use_2dconv = False 183 |     net.bandwise = False 184 |     return net 185 | 186 | 187 | """ Break down 188 | """ 189 | 190 | 191 | def baseline_s3conv(): 192 |     from . import arch 193 |     from .attention import DummyTransformerBlock 194 |     arch.TransformerBlock = DummyTransformerBlock 195 |     arch.UseBN = False 196 | 197 |     net = HSDT(1, 16, 5, [1, 3]) 198 |     net.use_2dconv = False 199 |     net.bandwise = False 200 |     return net 201 | 202 | 203 | def baseline_conv3d(): 204 |     from . import arch 205 |     import torch.nn as nn 206 |     arch.Conv3d = nn.Conv3d 207 |     from .attention import DummyTransformerBlock 208 |     arch.TransformerBlock = DummyTransformerBlock 209 |     arch.UseBN = False 210 | 211 |     net = HSDT(1, 16, 5, [1, 3]) 212 |     net.use_2dconv = False 213 |     net.bandwise = False 214 |     return net 215 | 216 | 217 | def baseline_gssa(): 218 |     from . import arch 219 |     import torch.nn as nn 220 |     arch.Conv3d = nn.Conv3d 221 |     from .attention import FFNTransformerBlock 222 |     arch.TransformerBlock = FFNTransformerBlock 223 |     arch.UseBN = True 224 | 225 |     net = HSDT(1, 16, 5, [1, 3]) 226 |     net.use_2dconv = False 227 |     net.bandwise = False 228 |     return net 229 | 230 | 231 | def baseline_ssa(): 232 |     from . import arch 233 |     import torch.nn as nn 234 |     arch.Conv3d = nn.Conv3d 235 |     from .attention import SSAFFNTransformerBlock 236 |     arch.TransformerBlock = SSAFFNTransformerBlock 237 |     arch.UseBN = True 238 | 239 |     net = HSDT(1, 16, 5, [1, 3]) 240 |     net.use_2dconv = False 241 |     net.bandwise = False 242 |     return net 243 |
-------------------------------------------------------------------------------- /hsir/model/qrnn3d/net.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class QRNNREDC3D(nn.Module): 5 |     def __init__(self, in_channels, channels, num_half_layer, sample_idx, 6 |                  BiQRNNConv3D=None, BiQRNNDeConv3D=None, 7 |                  QRNN3DEncoder=None, QRNN3DDecoder=None, is_2d=False, has_ad=True, bn=True, act='tanh', plain=False): 8 |         super(QRNNREDC3D, self).__init__() 9 |         assert sample_idx is None or isinstance(sample_idx, list) 10 | 11 |         self.enable_ad = has_ad  # alternating-direction scanning between consecutive layers 12 |         if sample_idx is None: sample_idx = [] 13 |         if is_2d: 14 |             self.feature_extractor = BiQRNNConv3D(in_channels, channels, k=(1,3,3), s=1, p=(0,1,1), bn=bn, act=act) 15 |         else: 16 |             self.feature_extractor = BiQRNNConv3D(in_channels, channels, bn=bn, act=act) 17 | 18 |         self.encoder = QRNN3DEncoder(channels, num_half_layer, sample_idx, is_2d=is_2d, has_ad=has_ad, bn=bn, act=act, plain=plain) 19 |         self.decoder = QRNN3DDecoder(channels*(2**len(sample_idx)), num_half_layer, sample_idx, is_2d=is_2d, has_ad=has_ad, bn=bn, act=act, plain=plain) 20 | 21 |         if act == 'relu': 22 |             act = 'none' 23 | 24 |         if is_2d: 25 |             self.reconstructor = BiQRNNDeConv3D(channels, in_channels, bias=True, k=(1,3,3), s=1, p=(0,1,1), bn=bn, act=act) 26 |         else: 27 |             self.reconstructor = BiQRNNDeConv3D(channels, in_channels, bias=True, bn=bn, act=act) 28 | 29 |     def forward(self, x): 30 |         xs = [x] 31 |         out = self.feature_extractor(xs[0]) 32 |         xs.append(out) 33 |         if self.enable_ad: 34 |             out, reverse = self.encoder(out, xs, reverse=False) 35 |             out = self.decoder(out, xs, reverse=reverse) 36 |         else: 37 |             out = self.encoder(out, xs) 38 |             out = self.decoder(out, xs) 39 |         out = out + xs.pop() 40 |         out = self.reconstructor(out) 41 |         out = out + xs.pop() 42 |         return out 43 | 44 | 45 | class QRNN3DEncoder(nn.Module): 46 |     def __init__(self, channels, num_half_layer, sample_idx, QRNNConv3D=None, 47 |                  is_2d=False, has_ad=True, bn=True, act='tanh', plain=False): 48 |         super(QRNN3DEncoder, self).__init__() 49 |         # Encoder 50 |         self.layers = nn.ModuleList() 51 |         self.enable_ad = has_ad 52 |         for i in range(num_half_layer): 53 |             if i not in sample_idx: 54 |                 if is_2d: 55 |                     encoder_layer = QRNNConv3D(channels, channels, k=(1,3,3), s=1, p=(0,1,1), bn=bn, act=act) 56 |                 else: 57 |                     encoder_layer = QRNNConv3D(channels, channels, bn=bn, act=act) 58 |             else: 59 |                 if is_2d: 60 |                     encoder_layer = QRNNConv3D(channels, 2*channels, k=(1,3,3), s=(1,2,2), p=(0,1,1), bn=bn, act=act) 61 |                 else: 62 |                     if not plain: 63 |                         encoder_layer = QRNNConv3D(channels, 2*channels, k=3, s=(1,2,2), p=1, bn=bn, act=act) 64 |                     else: 65 |                         encoder_layer = QRNNConv3D(channels, 2*channels, k=3, s=(1,1,1), p=1, bn=bn, act=act) 66 | 67 |                 channels *= 2 68 |             self.layers.append(encoder_layer) 69 | 70 |     def forward(self, x, xs, reverse=False): 71 |         if not self.enable_ad: 72 |             num_half_layer = len(self.layers) 73 |             for i in range(num_half_layer-1): 74 |                 x = self.layers[i](x) 75 |                 xs.append(x) 76 |             x = self.layers[-1](x) 77 | 78 |             return x 79 |         else: 80 |             num_half_layer = len(self.layers) 81 |             for i in range(num_half_layer-1): 82 |                 x = self.layers[i](x, reverse=reverse) 83 |
reverse = not reverse 84 | xs.append(x) 85 | x = self.layers[-1](x, reverse=reverse) 86 | reverse = not reverse 87 | 88 | return x, reverse 89 | 90 | 91 | class QRNN3DDecoder(nn.Module): 92 | def __init__(self, channels, num_half_layer, sample_idx, QRNNDeConv3D=None, QRNNUpsampleConv3d=None, 93 | is_2d=False, has_ad=True, bn=True, act='tanh', plain=False): 94 | super(QRNN3DDecoder, self).__init__() 95 | # Decoder 96 | self.layers = nn.ModuleList() 97 | self.enable_ad = has_ad 98 | for i in reversed(range(num_half_layer)): 99 | if i not in sample_idx: 100 | if is_2d: 101 | decoder_layer = QRNNDeConv3D(channels, channels, k=(1,3,3), s=1, p=(0,1,1), bn=bn, act=act) 102 | else: 103 | decoder_layer = QRNNDeConv3D(channels, channels, bn=bn, act=act) 104 | else: 105 | if is_2d: 106 | decoder_layer = QRNNUpsampleConv3d(channels, channels//2, k=(1,3,3), s=1, p=(0,1,1), bn=bn, act=act) 107 | else: 108 | if not plain: 109 | decoder_layer = QRNNUpsampleConv3d(channels, channels//2, bn=bn, act=act) 110 | else: 111 | decoder_layer = QRNNDeConv3D(channels, channels//2, bn=bn, act=act) 112 | 113 | channels //= 2 114 | self.layers.append(decoder_layer) 115 | 116 | 117 | def forward(self, x, xs, reverse=False): 118 | if not self.enable_ad: 119 | num_half_layer = len(self.layers) 120 | x = self.layers[0](x) 121 | for i in range(1, num_half_layer): 122 | x = x + xs.pop() 123 | x = self.layers[i](x) 124 | return x 125 | else: 126 | num_half_layer = len(self.layers) 127 | x = self.layers[0](x, reverse=reverse) 128 | reverse = not reverse 129 | for i in range(1, num_half_layer): 130 | x = x + xs.pop() 131 | x = self.layers[i](x, reverse=reverse) 132 | reverse = not reverse 133 | return x -------------------------------------------------------------------------------- /benchmark/datasets/CAVE/cave_build_lmdb.py: -------------------------------------------------------------------------------- 1 | """Create lmdb dataset""" 2 | import os 3 | import numpy as np 4 | import lmdb 5 | import caffe 6 | from itertools import product 7 | from scipy.ndimage import zoom 8 | import random 9 | from scipy.io import loadmat 10 | 11 | def Data2Volume(data, ksizes, strides): 12 | """ 13 | Construct Volumes from Original High Dimensional (D) Data 14 | """ 15 | dshape = data.shape 16 | def PatNum(l, k, s): return (np.floor((l - k) / s) + 1) 17 | 18 | TotalPatNum = 1 19 | for i in range(len(ksizes)): 20 | TotalPatNum = TotalPatNum * PatNum(dshape[i], ksizes[i], strides[i]) 21 | 22 | V = np.zeros([int(TotalPatNum)] + ksizes) # create D+1 dimension volume 23 | 24 | args = [range(kz) for kz in ksizes] 25 | for s in product(*args): 26 | s1 = (slice(None),) + s 27 | s2 = tuple([slice(key, -ksizes[i] + key + 1 or None, strides[i]) for i, key in enumerate(s)]) 28 | V[s1] = np.reshape(data[s2], (-1,)) 29 | 30 | return V 31 | 32 | 33 | def minmax_normalize(array): 34 | amin = np.min(array) 35 | amax = np.max(array) 36 | return (array - amin) / (amax - amin) 37 | 38 | 39 | def crop_center(img, cropx, cropy): 40 | _, y, x = img.shape 41 | startx = x // 2 - (cropx // 2) 42 | starty = y // 2 - (cropy // 2) 43 | return img[:, starty:starty + cropy, startx:startx + cropx] 44 | 45 | 46 | def data_augmentation(image, mode=None): 47 | """ 48 | Args: 49 | image: np.ndarray, shape: C X H X W 50 | """ 51 | axes = (-2, -1) 52 | def flipud(x): return x[:, ::-1, :] 53 | 54 | if mode is None: 55 | mode = random.randint(0, 7) 56 | if mode == 0: 57 | # original 58 | image = image 59 | elif mode == 1: 60 | # flip up and down 61 | image = flipud(image) 62 | elif 
mode == 2: 63 | # rotate counterwise 90 degree 64 | image = np.rot90(image, axes=axes) 65 | elif mode == 3: 66 | # rotate 90 degree and flip up and down 67 | image = np.rot90(image, axes=axes) 68 | image = flipud(image) 69 | elif mode == 4: 70 | # rotate 180 degree 71 | image = np.rot90(image, k=2, axes=axes) 72 | elif mode == 5: 73 | # rotate 180 degree and flip 74 | image = np.rot90(image, k=2, axes=axes) 75 | image = flipud(image) 76 | elif mode == 6: 77 | # rotate 270 degree 78 | image = np.rot90(image, k=3, axes=axes) 79 | elif mode == 7: 80 | # rotate 270 degree and flip 81 | image = np.rot90(image, k=3, axes=axes) 82 | image = flipud(image) 83 | 84 | # we apply spectrum reversal for training 3D CNN, e.g. QRNN3D. 85 | # disable it when training 2D CNN, e.g. MemNet 86 | if random.random() < 0.5: 87 | image = image[::-1, :, :] 88 | 89 | return np.ascontiguousarray(image) 90 | 91 | 92 | def create_lmdb_train( 93 | datadir, fns, name, matkey, 94 | crop_sizes, scales, ksizes, strides, 95 | load=loadmat, augment=True, 96 | seed=2022): 97 | """ 98 | Create Augmented Dataset 99 | """ 100 | def preprocess(data): 101 | new_data = [] 102 | data = np.float32(data) 103 | # data = minmax_normalize(data) 104 | # data = np.rot90(data, k=2, axes=(1,2)) # ICVL 105 | data = minmax_normalize(data.transpose((2, 0, 1))) # for Remote Sensing 106 | # Visualize3D(data) 107 | if crop_sizes is not None: 108 | data = crop_center(data, crop_sizes[0], crop_sizes[1]) 109 | 110 | for i in range(len(scales)): 111 | if scales[i] != 1: 112 | temp = zoom(data, zoom=(1, scales[i], scales[i])) 113 | else: 114 | temp = data 115 | temp = Data2Volume(temp, ksizes=ksizes, strides=list(strides[i])) 116 | new_data.append(temp) 117 | new_data = np.concatenate(new_data, axis=0) 118 | if augment: 119 | for i in range(new_data.shape[0]): 120 | new_data[i, ...] 
= data_augmentation(new_data[i, ...]) 121 | 122 | return new_data.astype(np.float32) 123 | 124 | np.random.seed(seed) 125 | scales = list(scales) 126 | ksizes = list(ksizes) 127 | assert len(scales) == len(strides) 128 | # calculate the shape of dataset 129 | data = load(datadir + fns[0])[matkey] 130 | data = preprocess(data) 131 | N = data.shape[0] 132 | 133 | print(data.shape) 134 | map_size = data.nbytes * len(fns) * 1.2 135 | print('map size (GB):', map_size / 1024 / 1024 / 1024) 136 | 137 | # import ipdb; ipdb.set_trace() 138 | if os.path.exists(name + '.db'): 139 | raise Exception('database already exist!') 140 | env = lmdb.open(name + '.db', map_size=map_size, writemap=True) 141 | with env.begin(write=True) as txn: 142 | # txn is a Transaction object 143 | k = 0 144 | for i, fn in enumerate(fns): 145 | try: 146 | X = load(datadir + fn)[matkey] 147 | except: 148 | print('loading', datadir + fn, 'fail') 149 | continue 150 | X = preprocess(X) 151 | N = X.shape[0] 152 | for j in range(N): 153 | datum = caffe.proto.caffe_pb2.Datum() 154 | datum.channels = X.shape[1] 155 | datum.height = X.shape[2] 156 | datum.width = X.shape[3] 157 | datum.data = X[j].tobytes() 158 | str_id = '{:08}'.format(k) 159 | k += 1 160 | txn.put(str_id.encode('ascii'), datum.SerializeToString()) 161 | print('load mat (%d/%d): %s' % (i, len(fns), fn)) 162 | 163 | print('done') 164 | 165 | def create_cave(): 166 | print('create cave...') 167 | datadir = '/home/wzliu/projects/data/cave_mat/' 168 | with open('cave_train.txt') as f: 169 | fns = [l.strip() for l in f.readlines()] 170 | 171 | create_lmdb_train( 172 | datadir, fns, 'CAVE64_31', 'gt', 173 | crop_sizes=(512,512), 174 | scales=(1, 0.5, 0.25), 175 | ksizes=(31, 64, 64), 176 | strides=[(31, 64, 64), (31, 32, 32), (31, 32, 32)], 177 | load=loadmat, augment=True, 178 | ) 179 | 180 | 181 | if __name__ == '__main__': 182 | create_cave() 183 | -------------------------------------------------------------------------------- /benchmark/datasets/Harvard/harvard_build_lmdb.py: -------------------------------------------------------------------------------- 1 | """Create lmdb dataset""" 2 | import os 3 | import numpy as np 4 | import lmdb 5 | import caffe 6 | from itertools import product 7 | from scipy.ndimage import zoom 8 | import random 9 | from scipy.io import loadmat 10 | 11 | def Data2Volume(data, ksizes, strides): 12 | """ 13 | Construct Volumes from Original High Dimensional (D) Data 14 | """ 15 | dshape = data.shape 16 | def PatNum(l, k, s): return (np.floor((l - k) / s) + 1) 17 | 18 | TotalPatNum = 1 19 | for i in range(len(ksizes)): 20 | TotalPatNum = TotalPatNum * PatNum(dshape[i], ksizes[i], strides[i]) 21 | 22 | V = np.zeros([int(TotalPatNum)] + ksizes) # create D+1 dimension volume 23 | 24 | args = [range(kz) for kz in ksizes] 25 | for s in product(*args): 26 | s1 = (slice(None),) + s 27 | s2 = tuple([slice(key, -ksizes[i] + key + 1 or None, strides[i]) for i, key in enumerate(s)]) 28 | V[s1] = np.reshape(data[s2], (-1,)) 29 | 30 | return V 31 | 32 | 33 | def minmax_normalize(array): 34 | amin = np.min(array) 35 | amax = np.max(array) 36 | return (array - amin) / (amax - amin) 37 | 38 | 39 | def crop_center(img, cropx, cropy): 40 | _, y, x = img.shape 41 | startx = x // 2 - (cropx // 2) 42 | starty = y // 2 - (cropy // 2) 43 | return img[:, starty:starty + cropy, startx:startx + cropx] 44 | 45 | 46 | def data_augmentation(image, mode=None): 47 | """ 48 | Args: 49 | image: np.ndarray, shape: C X H X W 50 | """ 51 | axes = (-2, -1) 52 | def flipud(x): 
return x[:, ::-1, :] 53 | 54 | if mode is None: 55 | mode = random.randint(0, 7) 56 | if mode == 0: 57 | # original 58 | image = image 59 | elif mode == 1: 60 | # flip up and down 61 | image = flipud(image) 62 | elif mode == 2: 63 | # rotate counterwise 90 degree 64 | image = np.rot90(image, axes=axes) 65 | elif mode == 3: 66 | # rotate 90 degree and flip up and down 67 | image = np.rot90(image, axes=axes) 68 | image = flipud(image) 69 | elif mode == 4: 70 | # rotate 180 degree 71 | image = np.rot90(image, k=2, axes=axes) 72 | elif mode == 5: 73 | # rotate 180 degree and flip 74 | image = np.rot90(image, k=2, axes=axes) 75 | image = flipud(image) 76 | elif mode == 6: 77 | # rotate 270 degree 78 | image = np.rot90(image, k=3, axes=axes) 79 | elif mode == 7: 80 | # rotate 270 degree and flip 81 | image = np.rot90(image, k=3, axes=axes) 82 | image = flipud(image) 83 | 84 | # we apply spectrum reversal for training 3D CNN, e.g. QRNN3D. 85 | # disable it when training 2D CNN, e.g. MemNet 86 | if random.random() < 0.5: 87 | image = image[::-1, :, :] 88 | 89 | return np.ascontiguousarray(image) 90 | 91 | 92 | def create_lmdb_train( 93 | datadir, fns, name, matkey, 94 | crop_sizes, scales, ksizes, strides, 95 | load=loadmat, augment=True, 96 | seed=2022): 97 | """ 98 | Create Augmented Dataset 99 | """ 100 | def preprocess(data): 101 | new_data = [] 102 | data = np.float32(data) 103 | # data = minmax_normalize(data) 104 | # data = np.rot90(data, k=2, axes=(1,2)) # ICVL 105 | data = minmax_normalize(data.transpose((2, 0, 1))) # for Remote Sensing 106 | # Visualize3D(data) 107 | if crop_sizes is not None: 108 | data = crop_center(data, crop_sizes[0], crop_sizes[1]) 109 | 110 | for i in range(len(scales)): 111 | if scales[i] != 1: 112 | temp = zoom(data, zoom=(1, scales[i], scales[i])) 113 | else: 114 | temp = data 115 | temp = Data2Volume(temp, ksizes=ksizes, strides=list(strides[i])) 116 | new_data.append(temp) 117 | new_data = np.concatenate(new_data, axis=0) 118 | if augment: 119 | for i in range(new_data.shape[0]): 120 | new_data[i, ...] 
= data_augmentation(new_data[i, ...]) 121 | 122 |         return new_data.astype(np.float32) 123 | 124 |     np.random.seed(seed) 125 |     scales = list(scales) 126 |     ksizes = list(ksizes) 127 |     assert len(scales) == len(strides) 128 |     # calculate the shape of dataset 129 |     data = load(datadir + fns[0])[matkey] 130 |     data = preprocess(data) 131 |     N = data.shape[0] 132 | 133 |     print(data.shape) 134 |     map_size = data.nbytes * len(fns) * 1.2 135 |     print('map size (GB):', map_size / 1024 / 1024 / 1024) 136 | 137 |     # import ipdb; ipdb.set_trace() 138 |     if os.path.exists(name + '.db'): 139 |         raise Exception('database already exists!') 140 |     env = lmdb.open(name + '.db', map_size=map_size, writemap=True) 141 |     with env.begin(write=True) as txn: 142 |         # txn is a Transaction object 143 |         k = 0 144 |         for i, fn in enumerate(fns): 145 |             try: 146 |                 X = load(datadir + fn)[matkey] 147 |             except Exception: 148 |                 print('loading', datadir + fn, 'failed') 149 |                 continue 150 |             X = preprocess(X) 151 |             N = X.shape[0] 152 |             for j in range(N): 153 |                 datum = caffe.proto.caffe_pb2.Datum() 154 |                 datum.channels = X.shape[1] 155 |                 datum.height = X.shape[2] 156 |                 datum.width = X.shape[3] 157 |                 datum.data = X[j].tobytes() 158 |                 str_id = '{:08}'.format(k) 159 |                 k += 1 160 |                 txn.put(str_id.encode('ascii'), datum.SerializeToString()) 161 |             print('load mat (%d/%d): %s' % (i, len(fns), fn)) 162 | 163 |     print('done') 164 | 165 | def create_harvard(): 166 |     print('create harvard...') 167 |     datadir = '/home/wzliu/projects/data/harvard/CZ_hsdb/' 168 |     with open('harvard_train.txt') as f: 169 |         fns = [l.strip() for l in f.readlines()] 170 | 171 |     create_lmdb_train( 172 |         datadir, fns, 'Harvard64_31', 'ref', 173 |         crop_sizes=(1024,1024), 174 |         scales=(1, 0.5, 0.25), 175 |         ksizes=(31, 64, 64), 176 |         strides=[(31, 64, 64), (31, 32, 32), (31, 32, 32)], 177 |         load=loadmat, augment=True, 178 |     ) 179 | 180 | 181 | if __name__ == '__main__': 182 |     create_harvard() 183 |
-------------------------------------------------------------------------------- /benchmark/datasets/RealHSI/real_build_lmdb.py: -------------------------------------------------------------------------------- 1 | """Create lmdb dataset""" 2 | import os 3 | import numpy as np 4 | import lmdb 5 | from itertools import product 6 | from scipy.ndimage import zoom 7 | import random 8 | from scipy.io import loadmat 9 | import pickle 10 | 11 | def Data2Volume(data, ksizes, strides): 12 |     """ 13 |     Construct Volumes from Original High Dimensional (D) Data 14 |     """ 15 |     dshape = data.shape 16 |     def PatNum(l, k, s): return (np.floor((l - k) / s) + 1) 17 | 18 |     TotalPatNum = 1 19 |     for i in range(len(ksizes)): 20 |         TotalPatNum = TotalPatNum * PatNum(dshape[i], ksizes[i], strides[i]) 21 | 22 |     V = np.zeros([int(TotalPatNum)] + ksizes)  # create D+1 dimension volume 23 | 24 |     args = [range(kz) for kz in ksizes] 25 |     for s in product(*args): 26 |         s1 = (slice(None),) + s 27 |         s2 = tuple([slice(key, -ksizes[i] + key + 1 or None, strides[i]) for i, key in enumerate(s)]) 28 |         V[s1] = np.reshape(data[s2], (-1,)) 29 | 30 |     return V 31 | 32 | 33 | def minmax_normalize(array): 34 |     amin = np.min(array) 35 |     amax = np.max(array) 36 |     return (array - amin) / (amax - amin) 37 | 38 | 39 | def crop_center(img, cropx, cropy): 40 |     _, y, x = img.shape 41 |     startx = x // 2 - (cropx // 2) 42 |     starty = y // 2 - (cropy // 2) 43 |     return img[:, starty:starty + cropy, startx:startx + cropx] 44 | 45 | 46 | def data_augmentation(image, mode=None): 47 |     """ 48 |     Args: 49 |         image: np.ndarray, shape: C X H X W 50 |     """ 51 |     axes = (-2, -1) 52 |
def flipud(x): return x[:, ::-1, :] 53 | 54 | if mode is None: 55 | mode = random.randint(0, 7) 56 | if mode == 0: 57 | # original 58 | image = image 59 | elif mode == 1: 60 | # flip up and down 61 | image = flipud(image) 62 | elif mode == 2: 63 | # rotate counterwise 90 degree 64 | image = np.rot90(image, axes=axes) 65 | elif mode == 3: 66 | # rotate 90 degree and flip up and down 67 | image = np.rot90(image, axes=axes) 68 | image = flipud(image) 69 | elif mode == 4: 70 | # rotate 180 degree 71 | image = np.rot90(image, k=2, axes=axes) 72 | elif mode == 5: 73 | # rotate 180 degree and flip 74 | image = np.rot90(image, k=2, axes=axes) 75 | image = flipud(image) 76 | elif mode == 6: 77 | # rotate 270 degree 78 | image = np.rot90(image, k=3, axes=axes) 79 | elif mode == 7: 80 | # rotate 270 degree and flip 81 | image = np.rot90(image, k=3, axes=axes) 82 | image = flipud(image) 83 | 84 | # we apply spectrum reversal for training 3D CNN, e.g. QRNN3D. 85 | # disable it when training 2D CNN, e.g. MemNet 86 | if random.random() < 0.5: 87 | image = image[::-1, :, :] 88 | 89 | return np.ascontiguousarray(image) 90 | 91 | 92 | def create_lmdb_train( 93 | datadir, fns, name, matkey, 94 | crop_sizes, scales, ksizes, strides, 95 | load=loadmat, augment=True, 96 | seed=2022): 97 | """ 98 | Create Augmented Dataset 99 | """ 100 | def preprocess(data): 101 | new_data = [] 102 | data = np.float32(data) 103 | # data = minmax_normalize(data.transpose((2, 0, 1))) 104 | data = data.transpose(2, 0, 1) / 4096 105 | if crop_sizes is not None: 106 | data = crop_center(data, crop_sizes[0], crop_sizes[1]) 107 | 108 | for i in range(len(scales)): 109 | if scales[i] != 1: 110 | temp = zoom(data, zoom=(1, scales[i], scales[i])) 111 | else: 112 | temp = data 113 | temp = Data2Volume(temp, ksizes=ksizes, strides=list(strides[i])) 114 | new_data.append(temp) 115 | new_data = np.concatenate(new_data, axis=0) 116 | return new_data.astype(np.float32) 117 | 118 | np.random.seed(seed) 119 | scales = list(scales) 120 | ksizes = list(ksizes) 121 | assert len(scales) == len(strides) 122 | # calculate the shape of dataset 123 | data = load(os.path.join(datadir, fns[0])) 124 | gt = data['gt'] 125 | input = data['input'] 126 | data = preprocess(gt) 127 | 128 | N = data.shape[0] 129 | 130 | print(data.shape) 131 | map_size = data.nbytes * len(fns) * 2.4 132 | print('map size (GB):', map_size / 1024 / 1024 / 1024) 133 | 134 | if os.path.exists(name + '.db'): 135 | raise Exception('database already exist!') 136 | env = lmdb.open(name + '.db', map_size=map_size, writemap=True) 137 | with env.begin(write=True) as txn: 138 | # txn is a Transaction object 139 | k = 0 140 | for i, fn in enumerate(fns): 141 | try: 142 | data = load(os.path.join(datadir, fn)) 143 | gt = data['gt'] 144 | input = data['input'] 145 | except: 146 | print('loading', datadir + fn, 'fail') 147 | continue 148 | gt = preprocess(gt) 149 | input = preprocess(input) 150 | if augment: 151 | for t in range(gt.shape[0]): 152 | mode = random.randint(0, 7) 153 | gt[t, ...] = data_augmentation(gt[t, ...], mode) 154 | input[t, ...] 
= data_augmentation(input[t, ...], mode) 155 | 156 | N = gt.shape[0] 157 | for j in range(N): 158 | str_id = '{:08}'.format(k) 159 | k += 1 160 | data = {'gt': gt[j], 'input': input[j]} 161 | data = pickle.dumps(data) 162 | txn.put(str_id.encode('ascii'), data) 163 | print('load mat (%d/%d): %s' % (i, len(fns), fn)) 164 | 165 | print('done') 166 | 167 | def create_real(): 168 | print('create real...') 169 | datadir = '/media/exthdd/datasets/hsi/real_hsi/real_dataset/mat/' 170 | with open('real_train.txt') as f: 171 | fns = [l.strip() for l in f.readlines()] 172 | 173 | create_lmdb_train( 174 | datadir, fns, 'Real64_31_2', 'gt', 175 | crop_sizes=None, 176 | scales=(1, 0.5, 0.25), 177 | ksizes=(31, 64, 64), 178 | strides=[(31, 64, 64), (31, 32, 32), (31, 32, 32)], 179 | load=loadmat, augment=True, 180 | ) 181 | 182 | 183 | if __name__ == '__main__': 184 | create_real() 185 | -------------------------------------------------------------------------------- /hsirun/test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from collections import OrderedDict 3 | import os 4 | from os.path import join, exists 5 | 6 | import torch 7 | import torch.utils.data 8 | import imageio 9 | import numpy as np 10 | from tqdm import tqdm 11 | from tabulate import tabulate 12 | 13 | import hsir.model 14 | import hsir.data.utils 15 | from hsir.data import HSITestDataset 16 | 17 | import torchlight as tl 18 | 19 | tl.metrics.set_data_format('chw') 20 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 21 | 22 | 23 | def bchw2hwc(x): 24 | return np.uint8(x.cpu().squeeze(0).permute(1, 2, 0).numpy() * 255) 25 | 26 | 27 | def eval(net, loader, name, logdir, clamp, bandwise): 28 | os.makedirs(join(logdir, 'color'), exist_ok=True) 29 | os.makedirs(join(logdir, 'gray'), exist_ok=True) 30 | 31 | tracker = tl.trainer.util.MetricTracker() 32 | detail_stat = {} 33 | 34 | with torch.no_grad(): 35 | pbar = tqdm(total=len(loader), dynamic_ncols=True) 36 | pbar.set_description(name) 37 | for data in loader: 38 | filename = data['filename'][0] 39 | basename = tl.utils.filename(filename) 40 | inputs, targets = data['input'].to(device), data['target'].to(device) 41 | 42 | if clamp: 43 | inputs = torch.clamp(inputs, 0., 1.) 
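            # Only the forward pass falls between tic/toc below; passing the
            # string 'gt' or 'input' as `net` short-circuits the model and
            # copies the clean or noisy cube through, which yields upper/lower
            # reference scores for the metrics.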
44 | tl.utils.timer.tic() 45 | if isinstance(net, str): 46 | outputs = targets if net == 'gt' else inputs 47 | else: 48 | outputs = net(inputs) 49 | 50 | torch.cuda.synchronize() 51 | run_time = float(tl.utils.timer.toc()) / 1000 52 | 53 | outputs = outputs.clamp(0,1) 54 | inputs = inputs.squeeze(1) 55 | outputs = outputs.squeeze(1) 56 | targets = targets.squeeze(1) 57 | 58 | imageio.imwrite( 59 | join(logdir, 'color', basename + '.png'), 60 | bchw2hwc(hsir.data.utils.visualize_color(outputs)) 61 | ) 62 | imageio.imwrite( 63 | join(logdir, 'gray', basename + '.png'), 64 | bchw2hwc(hsir.data.utils.visualize_gray(outputs)) 65 | ) 66 | 67 | psnr = tl.metrics.mpsnr(targets, outputs).item() 68 | ssim = tl.metrics.mssim(targets, outputs).item() 69 | sam = tl.metrics.sam(targets, outputs).item() 70 | 71 | tracker.update('MPSNR', psnr) 72 | tracker.update('MSSIM', ssim) 73 | tracker.update('SAM', sam) 74 | tracker.update('Time', run_time) 75 | 76 | pbar.set_postfix({k: f'{v:.4f}' for k, v in tracker.result().items()}) 77 | pbar.update() 78 | 79 | detail_stat[basename] = {'MPSNR': psnr, 'MSSIM': ssim, 'SAM': sam, 'Time': run_time} 80 | 81 | pbar.close() 82 | 83 | avg_speed = tl.utils.format_time(tracker['Time']) 84 | print(f'Average speed {avg_speed}') 85 | print(f'Average results {tracker.summary()}') 86 | 87 | # log structural results 88 | avg_stat = {k: v for k, v in tracker.result().items()} 89 | tl.utils.io.jsonwrite(join(logdir, 'log.json'), {'avg': avg_stat, 'detail': detail_stat}) 90 | 91 | 92 | def pretty_summary(logdir): 93 | stat = [] 94 | print('') 95 | for folder in os.listdir(logdir): 96 | if os.path.isdir(join(logdir, folder)): 97 | path = join(logdir, folder, 'log.json') 98 | if exists(path): 99 | data = tl.utils.io.jsonload(path) 100 | s = OrderedDict() 101 | s['Name'] = folder 102 | s.update(data['avg']) 103 | stat.append(s) 104 | print(tabulate(stat, headers='keys', tablefmt='github')) 105 | print('') 106 | 107 | 108 | def main(args, logdir): 109 | if args.arch == 'gt' or args.arch == 'input': 110 | net = args.arch 111 | else: 112 | net = tl.utils.instantiate(args.arch) 113 | net = net.to(device) 114 | if args.resume: 115 | ckpt = tl.utils.dict_get(torch.load(args.resume), args.key_path) 116 | if ckpt is None: print(f'key_path {args.key_path} might be wrong') 117 | net.load_state_dict(ckpt) 118 | net.eval() 119 | 120 | for testset in args.testset: 121 | print('Evaluating {}'.format(args.arch)) 122 | print('On {}'.format(testset)) 123 | print('With {}'.format(args.resume)) 124 | testdir = join(args.basedir, testset) 125 | dataset = HSITestDataset(testdir, use_chw=args.use_conv2d, return_name=True) 126 | loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1) 127 | eval(net, loader, testset, join(logdir, testset), args.clamp, args.bandwise) 128 | 129 | pretty_summary(logdir) 130 | 131 | 132 | if __name__ == '__main__': 133 | parser = argparse.ArgumentParser(description='HSIR Test script') 134 | parser.add_argument('-a', '--arch', required=True, help='architecture name') 135 | parser.add_argument('-n', '--name', default=None, help='save name') 136 | parser.add_argument('-r', '--resume', default=None, help='checkpoint') 137 | parser.add_argument('-t', '--testset', nargs='+', default=['icvl_512_50'], help='testset') 138 | parser.add_argument('-d', '--basedir', default='data', help='basedir') 139 | parser.add_argument('--logdir', default='results', help='logdir') 140 | parser.add_argument('--save_img', action='store_true', help='whether to save image') 
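    # Example invocation (checkpoint path and arch spec are placeholders; the
    # arch string is resolved by tl.utils.instantiate):
    #   python hsirun/test.py -a <arch> -r <ckpt.pth> -t icvl_512_50 --clamp
    # Metrics are written to <logdir>/<name or arch>/<testset>/log.json.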
141 | parser.add_argument('--clamp', action='store_true', help='whether clamp input into [0, 1]') 142 | parser.add_argument('-kp', '--key_path', default='net', help='key path to access network state_dict in ckpt') 143 | parser.add_argument('--bandwise', action='store_true') 144 | parser.add_argument('--use-conv2d', action='store_true') 145 | parser.add_argument('--force-run', action='store_true') 146 | args = parser.parse_args() 147 | 148 | save_name = args.arch if args.name is None else args.name 149 | logdir = join(args.logdir, save_name) 150 | if exists(logdir) and not args.force_run: 151 | print(f'It seems that you have evaluated {args.arch} before.') 152 | pretty_summary(logdir) 153 | action = input('Are you sure you want to continue? (y) continue (n) exit\n') 154 | if action != 'y': exit() 155 | 156 | os.makedirs(logdir, exist_ok=True) 157 | with open(join(logdir, 'meta.txt'), 'w') as f: 158 | f.write(tl.utils.get_datetime() + '\n') 159 | f.write(tl.utils.get_cmd() + '\n') 160 | f.write(str(args) + '\n') 161 | 162 | main(args, logdir) 163 | -------------------------------------------------------------------------------- /hsir/model/t3sc/layers/lowrank_sc_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | import math 5 | import logging 6 | 7 | from .encoding_layer import EncodingLayer 8 | from .soft_thresholding import SoftThresholding 9 | 10 | logger = logging.getLogger(__name__) 11 | logger.setLevel(logging.DEBUG) 12 | 13 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 14 | 15 | 16 | class LowRankSCLayer(EncodingLayer): 17 | def __init__( 18 | self, 19 | patch_side, 20 | stride, 21 | K, 22 | rank, 23 | patch_centering, 24 | lbda_init, 25 | lbda_mode, 26 | beta=0, 27 | ssl=0, 28 | **kwargs, 29 | ): 30 | super().__init__(**kwargs) 31 | assert self.in_channels is not None 32 | assert self.code_size is not None 33 | self.patch_side = patch_side 34 | self.stride = stride 35 | self.K = K 36 | self.rank = rank 37 | self.patch_centering = patch_centering 38 | self.lbda_init = lbda_init 39 | self.lbda_mode = lbda_mode 40 | self.patch_size = self.in_channels * self.patch_side ** 2 41 | self.spat_dim = self.patch_side ** 2 42 | self.spec_dim = self.in_channels 43 | self.beta = beta 44 | self.ssl = ssl 45 | 46 | # first is spectral, second is spatial 47 | self.init_weights( 48 | [ 49 | (self.code_size, self.spec_dim, self.rank), 50 | (self.code_size, self.rank, self.spat_dim), 51 | ] 52 | ) 53 | 54 | self.thresholds = SoftThresholding( 55 | mode=self.lbda_mode, 56 | lbda_init=self.lbda_init, 57 | code_size=self.code_size, 58 | K=self.K, 59 | ) 60 | if self.patch_centering and self.patch_side == 1: 61 | raise ValueError( 62 | "Patch centering and 1x1 kernel will result in null patches" 63 | ) 64 | 65 | if self.patch_centering: 66 | ones = torch.ones( 67 | self.in_channels, 1, self.patch_side, self.patch_side 68 | ) 69 | self.ker_mean = (ones / self.patch_side ** 2).to(device) 70 | self.ker_divider = torch.ones( 71 | 1, 1, self.patch_side, self.patch_side 72 | ).to(device) 73 | self.divider = None 74 | 75 | if self.beta: 76 | self.beta_estimator = nn.Sequential( 77 | # layer1 78 | nn.Conv2d( 79 | in_channels=1, out_channels=64, kernel_size=5, stride=2 80 | ), 81 | nn.ReLU(), 82 | nn.MaxPool2d(kernel_size=2), 83 | # layer2 84 | nn.Conv2d( 85 | in_channels=64, out_channels=128, kernel_size=3, stride=2 86 | ), 87 | nn.ReLU(), 88 | 
nn.MaxPool2d(kernel_size=2), 89 | # layer3 90 | nn.Conv2d( 91 | in_channels=128, out_channels=1, kernel_size=3, stride=1 92 | ), 93 | nn.Sigmoid(), 94 | ) 95 | 96 | def init_weights(self, shape): 97 | for w in ["C", "D", "W"]: 98 | setattr(self, w, self.init_param(shape)) 99 | 100 | def init_param(self, shape): 101 | def init_tensor(shape): 102 | tensor = torch.empty(*shape) 103 | torch.nn.init.kaiming_uniform_(tensor, a=math.sqrt(5)) 104 | return tensor 105 | 106 | if isinstance(shape, list): 107 | return torch.nn.ParameterList([self.init_param(s) for s in shape]) 108 | return torch.nn.Parameter(init_tensor(shape)) 109 | 110 | def _encode(self, x, sigmas=None, ssl_idx=None, **kwargs): 111 | self.shape_in = x.shape 112 | bs, c, h, w = self.shape_in 113 | 114 | if self.beta: 115 | block = min(56, h) 116 | c_w = (w - block) // 2 117 | c_h = (h - block) // 2 118 | to_estimate = x[:, :, c_h : c_h + block, c_w : c_w + block].view( 119 | bs * c, 1, block, block 120 | ) 121 | beta = 1 - self.beta_estimator(to_estimate) 122 | # (bs * c, 1) 123 | beta = beta.view(bs, c, 1, 1) 124 | else: 125 | beta = torch.ones((bs, c, 1, 1), device=x.device) 126 | 127 | if self.ssl: 128 | # discard error on bands we want to predict 129 | with torch.no_grad(): 130 | mask = torch.ones_like(beta) 131 | mask[:, ssl_idx.long()] = 0.0 132 | 133 | beta = beta * mask 134 | 135 | if self.beta or self.ssl: 136 | # applying beta before or after centering is equivalent 137 | x = x * beta 138 | 139 | CT = (self.C[0] @ self.C[1]).view( 140 | self.code_size, 141 | self.in_channels, 142 | self.patch_side, 143 | self.patch_side, 144 | ) 145 | 146 | if self.patch_centering: 147 | A = F.conv2d(x, CT - CT.mean(dim=[2, 3], keepdim=True)) 148 | self.means = F.conv2d(x, self.ker_mean, groups=self.in_channels) 149 | else: 150 | A = F.conv2d(x, CT) 151 | 152 | alpha = self.thresholds(A, 0) 153 | 154 | D = (self.D[0] @ self.D[1]).view( 155 | self.code_size, 156 | self.in_channels, 157 | self.patch_side, 158 | self.patch_side, 159 | ) 160 | 161 | for k in range(1, self.K): 162 | D_alpha = F.conv_transpose2d(alpha, D) 163 | D_alpha = D_alpha * beta 164 | alpha = self.thresholds(A + alpha - F.conv2d(D_alpha, CT), k) 165 | 166 | return alpha 167 | 168 | def _decode(self, alpha, **kwargs): 169 | W = ((self.W[0]) @ self.W[1]).view( 170 | self.code_size, 171 | self.in_channels, 172 | self.patch_side, 173 | self.patch_side, 174 | ) 175 | 176 | x = F.conv_transpose2d(alpha, W) 177 | 178 | if self.patch_centering: 179 | x += F.conv_transpose2d( 180 | self.means, 181 | self.ker_mean * self.patch_side ** 2, 182 | groups=self.in_channels, 183 | ) 184 | if self.divider is None or self.divider.shape[-2:] != (x.shape[-2:]): 185 | ones = torch.ones( 186 | 1, 1, alpha.shape[2], alpha.shape[3], device=alpha.device 187 | ).to(alpha.device) 188 | self.divider = F.conv_transpose2d(ones, self.ker_divider) 189 | 190 | x = x / self.divider 191 | 192 | return x 193 | -------------------------------------------------------------------------------- /hsir/data/transform/noise.py: -------------------------------------------------------------------------------- 1 | import threading 2 | 3 | import numpy as np 4 | 5 | 6 | # helper 7 | 8 | class LockedIterator(object): 9 | def __init__(self, it): 10 | self.lock = threading.Lock() 11 | self.it = it.__iter__() 12 | 13 | def __iter__(self): return self 14 | 15 | def __next__(self): 16 | self.lock.acquire() 17 | try: 18 | return next(self.it) 19 | finally: 20 | self.lock.release() 21 | 22 | 23 | # noise 24 | 25 | class 
AddNoise(object): 26 | """add gaussian noise to the given numpy array (B,H,W)""" 27 | 28 | def __init__(self, sigma): 29 | self.sigma_ratio = sigma / 255. 30 | 31 | def __call__(self, img): 32 | noise = np.random.randn(*img.shape) * self.sigma_ratio 33 | # print(img.sum(), noise.sum()) 34 | return img + noise 35 | 36 | 37 | class AddNoiseBlind(object): 38 | """add blind gaussian noise to the given numpy array (B,H,W)""" 39 | 40 | def __pos(self, n): 41 | i = 0 42 | while True: 43 | yield i 44 | i = (i + 1) % n 45 | 46 | def __init__(self, sigmas): 47 | self.sigmas = np.array(sigmas) / 255. 48 | self.pos = LockedIterator(self.__pos(len(sigmas))) 49 | 50 | def __call__(self, img): 51 | noise = np.random.randn(*img.shape) * self.sigmas[next(self.pos)] 52 | return img + noise 53 | 54 | 55 | class AddNoiseBlindv2(object): 56 | """add blind gaussian noise to the given numpy array (B,H,W)""" 57 | 58 | def __init__(self, min_sigma, max_sigma): 59 | self.min_sigma = min_sigma 60 | self.max_sigma = max_sigma 61 | 62 | def __call__(self, img): 63 | noise = np.random.randn(*img.shape) * np.random.uniform(self.min_sigma, self.max_sigma) / 255 64 | return img + noise 65 | 66 | 67 | class AddNoiseNoniid(object): 68 | """add non-iid gaussian noise to the given numpy array (B,H,W)""" 69 | 70 | def __init__(self, sigmas): 71 | self.sigmas = np.array(sigmas) / 255. 72 | 73 | def __call__(self, img): 74 | bwsigmas = np.reshape(self.sigmas[np.random.randint(0, len(self.sigmas), img.shape[0])], (-1, 1, 1)) 75 | noise = np.random.randn(*img.shape) * bwsigmas 76 | return img + noise 77 | 78 | 79 | class AddNoiseMixed(object): 80 | """add mixed noise to the given numpy array (B,H,W) 81 | Args: 82 | noise_bank: list of noise maker (e.g. AddNoiseImpulse) 83 | num_bands: list of number of band which is corrupted by each item in noise_bank""" 84 | 85 | def __init__(self, noise_bank, num_bands): 86 | assert len(noise_bank) == len(num_bands) 87 | self.noise_bank = noise_bank 88 | self.num_bands = num_bands 89 | 90 | def __call__(self, img): 91 | B, H, W = img.shape 92 | all_bands = np.random.permutation(range(B)) 93 | pos = 0 94 | for noise_maker, num_band in zip(self.noise_bank, self.num_bands): 95 | if 0 < num_band <= 1: 96 | num_band = int(np.floor(num_band * B)) 97 | bands = all_bands[pos:pos+num_band] 98 | pos += num_band 99 | img = noise_maker(img, bands) 100 | return img 101 | 102 | 103 | class _AddNoiseImpulse(object): 104 | """add impulse noise to the given numpy array (B,H,W)""" 105 | 106 | def __init__(self, amounts, s_vs_p=0.5): 107 | self.amounts = np.array(amounts) 108 | self.s_vs_p = s_vs_p 109 | 110 | def __call__(self, img, bands): 111 | # bands = np.random.permutation(range(img.shape[0]))[:self.num_band] 112 | bwamounts = self.amounts[np.random.randint(0, len(self.amounts), len(bands))] 113 | for i, amount in zip(bands, bwamounts): 114 | self.add_noise(img[i, ...], amount=amount, salt_vs_pepper=self.s_vs_p) 115 | return img 116 | 117 | def add_noise(self, image, amount, salt_vs_pepper): 118 | # out = image.copy() 119 | out = image 120 | p = amount 121 | q = salt_vs_pepper 122 | flipped = np.random.choice([True, False], size=image.shape, 123 | p=[p, 1 - p]) 124 | salted = np.random.choice([True, False], size=image.shape, 125 | p=[q, 1 - q]) 126 | peppered = ~salted 127 | out[flipped & salted] = 1 128 | out[flipped & peppered] = 0 129 | return out 130 | 131 | 132 | class _AddNoiseStripe(object): 133 | """add stripe noise to the given numpy array (B,H,W)""" 134 | 135 | def __init__(self, min_amount, 
max_amount): 136 | assert max_amount > min_amount 137 | self.min_amount = min_amount 138 | self.max_amount = max_amount 139 | 140 | def __call__(self, img, bands): 141 | B, H, W = img.shape 142 | # bands = np.random.permutation(range(img.shape[0]))[:len(bands)] 143 | num_stripe = np.random.randint(np.floor(self.min_amount*W), np.floor(self.max_amount*W), len(bands)) 144 | for i, n in zip(bands, num_stripe): 145 | loc = np.random.permutation(range(W)) 146 | loc = loc[:n] 147 | stripe = np.random.uniform(0, 1, size=(len(loc),))*0.5-0.25 148 | img[i, :, loc] -= np.reshape(stripe, (-1, 1)) 149 | return img 150 | 151 | 152 | class _AddNoiseDeadline(object): 153 | """add deadline noise to the given numpy array (B,H,W)""" 154 | 155 | def __init__(self, min_amount, max_amount): 156 | assert max_amount > min_amount 157 | self.min_amount = min_amount 158 | self.max_amount = max_amount 159 | 160 | def __call__(self, img, bands): 161 | B, H, W = img.shape 162 | # bands = np.random.permutation(range(img.shape[0]))[:len(bands)] 163 | num_deadline = np.random.randint(np.ceil(self.min_amount*W), np.ceil(self.max_amount*W), len(bands)) 164 | for i, n in zip(bands, num_deadline): 165 | loc = np.random.permutation(range(W)) 166 | loc = loc[:n] 167 | img[i, :, loc] = 0 168 | return img 169 | 170 | 171 | class AddNoiseImpulse(AddNoiseMixed): 172 | def __init__(self): 173 | self.noise_bank = [_AddNoiseImpulse([0.1, 0.3, 0.5, 0.7])] 174 | self.num_bands = [1/3] 175 | 176 | 177 | class AddNoiseStripe(AddNoiseMixed): 178 | def __init__(self): 179 | self.noise_bank = [_AddNoiseStripe(0.05, 0.15)] 180 | self.num_bands = [1/3] 181 | 182 | 183 | class AddNoiseDeadline(AddNoiseMixed): 184 | def __init__(self): 185 | self.noise_bank = [_AddNoiseDeadline(0.05, 0.15)] 186 | self.num_bands = [1/3] 187 | 188 | 189 | class AddNoiseComplex(AddNoiseMixed): 190 | def __init__(self): 191 | self.noise_bank = [ 192 | _AddNoiseStripe(0.05, 0.15), 193 | _AddNoiseDeadline(0.05, 0.15), 194 | _AddNoiseImpulse([0.1, 0.3, 0.5, 0.7]) 195 | ] 196 | self.num_bands = [1/3, 1/3, 1/3] 197 | -------------------------------------------------------------------------------- /hsir/model/hsdt/arch.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from .attention import TransformerBlock 8 | from .sepconv import SepConv_DP, SepConv_DP_CA, S3Conv 9 | 10 | BatchNorm3d = nn.BatchNorm3d 11 | Conv3d = S3Conv.of(nn.Conv3d) 12 | TransformerBlock = TransformerBlock 13 | IsConvImpl = False 14 | UseBN = True 15 | 16 | 17 | def PlainConv(in_ch, out_ch): 18 | return nn.Sequential(OrderedDict([ 19 | ('conv', Conv3d(in_ch, out_ch, 3, 1, 1, bias=False)), 20 | ('bn', BatchNorm3d(out_ch) if UseBN else nn.Identity()), 21 | ('attn', TransformerBlock(out_ch, bias=True)) 22 | ])) 23 | 24 | 25 | def DownConv(in_ch, out_ch): 26 | return nn.Sequential(OrderedDict([ 27 | ('conv', nn. 
Conv3d(in_ch, out_ch, 3, (1, 2, 2), 1, bias=False)), 28 | ('bn', BatchNorm3d(out_ch)if UseBN else nn.Identity()), 29 | ('attn', TransformerBlock(out_ch, bias=True)) 30 | ])) 31 | 32 | 33 | def UpConv(in_ch, out_ch): 34 | return nn.Sequential(OrderedDict([ 35 | ('up', nn.Upsample(scale_factor=(1, 2, 2), mode='trilinear', align_corners=True)), 36 | ('conv', nn.Conv3d(in_ch, out_ch, 3, 1, 1, bias=False)), 37 | ('bn', BatchNorm3d(out_ch) if UseBN else nn.Identity()), 38 | ('attn', TransformerBlock(out_ch, bias=True)) 39 | ])) 40 | 41 | 42 | class Encoder(nn.Module): 43 | def __init__(self, channels, num_half_layer, sample_idx): 44 | super(Encoder, self).__init__() 45 | self.layers = nn.ModuleList() 46 | for i in range(num_half_layer): 47 | if i not in sample_idx: 48 | encoder_layer = PlainConv(channels, channels) 49 | else: 50 | encoder_layer = DownConv(channels, 2 * channels) 51 | channels *= 2 52 | self.layers.append(encoder_layer) 53 | 54 | def forward(self, x, xs): 55 | num_half_layer = len(self.layers) 56 | for i in range(num_half_layer - 1): 57 | x = self.layers[i](x) 58 | xs.append(x) 59 | x = self.layers[-1](x) 60 | return x 61 | 62 | 63 | class Decoder(nn.Module): 64 | count = 1 65 | def __init__(self, channels, num_half_layer, sample_idx, Fusion=None): 66 | super(Decoder, self).__init__() 67 | # Decoder 68 | self.layers = nn.ModuleList() 69 | self.enable_fusion = Fusion is not None 70 | 71 | if self.enable_fusion: 72 | self.fusions = nn.ModuleList() 73 | ch = channels 74 | for i in reversed(range(num_half_layer)): 75 | fusion_layer = Fusion(ch) 76 | if i in sample_idx: 77 | ch //= 2 78 | self.fusions.append(fusion_layer) 79 | 80 | for i in reversed(range(num_half_layer)): 81 | if i not in sample_idx: 82 | decoder_layer = PlainConv(channels, channels) 83 | else: 84 | decoder_layer = UpConv(channels, channels // 2) 85 | channels //= 2 86 | self.layers.append(decoder_layer) 87 | 88 | def forward(self, x, xs): 89 | num_half_layer = len(self.layers) 90 | x = self.layers[0](x) 91 | for i in range(1, num_half_layer): 92 | if self.enable_fusion: 93 | x = self.fusions[i](x, xs.pop()) 94 | else: 95 | x = x + xs.pop() 96 | x = self.layers[i](x) 97 | return x 98 | 99 | 100 | class HSDT(nn.Module): 101 | def __init__(self, in_channels, channels, num_half_layer, sample_idx, Fusion=None): 102 | super(HSDT, self).__init__() 103 | self.head = PlainConv(in_channels, channels) 104 | self.encoder = Encoder(channels, num_half_layer, sample_idx) 105 | self.decoder = Decoder(channels * (2**len(sample_idx)), num_half_layer, sample_idx, Fusion=Fusion) 106 | self.tail = nn.Conv3d(channels, 1, 3, 1, 1, bias=True) 107 | 108 | def forward(self, x): 109 | xs = [x] 110 | out = self.head(xs[0]) 111 | xs.append(out) 112 | out = self.encoder(out, xs) 113 | out = self.decoder(out, xs) 114 | out = out + xs.pop() 115 | out = self.tail(out) 116 | out = out + xs.pop()[:, 0:1, :, :, :] 117 | return out 118 | 119 | def load_state_dict(self, state_dict, strict: bool = True): 120 | if IsConvImpl: 121 | new_state_dict = {} 122 | for k, v in state_dict.items(): 123 | if ('attn.attn' in k) and 'weight' in k and 'attn_proj' not in k: 124 | new_state_dict[k] = v.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) 125 | else: 126 | new_state_dict[k] = v 127 | state_dict = new_state_dict 128 | return super().load_state_dict(state_dict, strict) 129 | 130 | 131 | class HSDTSSR(nn.Module): 132 | def __init__(self, in_channels, channels, num_half_layer, sample_idx, Fusion=None): 133 | super(HSDTSSR, self).__init__() 134 | self.proj = nn.Conv2d(3, 
31, 1, bias=False) 135 | self.head = PlainConv(in_channels, channels) 136 | self.encoder = Encoder(channels, num_half_layer, sample_idx) 137 | self.decoder = Decoder(channels * (2**len(sample_idx)), num_half_layer, sample_idx, Fusion=Fusion) 138 | self.tail = nn.Conv3d(channels, 1, 3, 1, 1, bias=True) 139 | 140 | def forward(self, x): 141 | if self.training: 142 | return self.forward_train(x) 143 | return self.forward_test(x) 144 | 145 | def forward_train(self, x): 146 | x = F.leaky_relu(self.proj(x)).unsqueeze(1) 147 | xs = [x] 148 | out = self.head(xs[0]) 149 | xs.append(out) 150 | out = self.encoder(out, xs) 151 | out = self.decoder(out, xs) 152 | out = out + xs.pop() 153 | out = self.tail(out) 154 | out = out + xs.pop()[:, 0:1, :, :, :] 155 | out = out.squeeze(1) 156 | return out 157 | 158 | def forward_test(self, x): 159 | pad_x, H, W = pad_mod(x, 8) 160 | output = self.forward_train(pad_x)[..., :H, :W] 161 | return output 162 | 163 | def load_state_dict(self, state_dict, strict: bool = True): 164 | if IsConvImpl: 165 | new_state_dict = {} 166 | for k, v in state_dict.items(): 167 | if ('attn.attn' in k) and 'weight' in k and 'attn_proj' not in k: 168 | new_state_dict[k] = v.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) 169 | else: 170 | new_state_dict[k] = v 171 | state_dict = new_state_dict 172 | return super().load_state_dict(state_dict, strict) 173 | 174 | 175 | def pad_mod(x, mod): 176 | h, w = x.shape[-2:] 177 | h_out = (h // mod + 1) * mod 178 | w_out = (w // mod + 1) * mod 179 | out = torch.zeros(*x.shape[:-2], h_out, w_out).type_as(x) 180 | out[..., :h, :w] = x 181 | return out.to(x.device), h, w 182 | -------------------------------------------------------------------------------- /hsir/model/t3sc/base.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | 4 | import numpy as np 5 | import pytorch_lightning as pl 6 | import torch 7 | 8 | from .metrics import mpsnr, mse, psnr 9 | from .utils import PatchesHandler 10 | 11 | logger = logging.getLogger(__name__) 12 | logger.setLevel(logging.DEBUG) 13 | 14 | 15 | class BaseModel(pl.LightningModule): 16 | def __init__( 17 | self, optimizer=None, lr_scheduler=None, block_inference=None 18 | ): 19 | super().__init__() 20 | self.optimizer = optimizer 21 | self.lr_scheduler = lr_scheduler 22 | self.block_inference = block_inference 23 | self.ssl = 0 24 | self.n_ssl = 0 25 | 26 | self.automatic_optimization = False 27 | 28 | def training_step(self, batch, batch_idx): 29 | opt = self.optimizers() 30 | opt.zero_grad() 31 | 32 | y = batch.pop("y") 33 | if self.ssl: 34 | x = batch["x"] 35 | bs, c, h, w = x.shape 36 | ssl_idx = torch.randperm(c)[: self.n_ssl].to(self.device) 37 | batch["ssl_idx"] = ssl_idx 38 | out = self.forward(**batch) 39 | band_out = out[:, ssl_idx] 40 | band_target = x[:, ssl_idx] 41 | self.log("train_mse", mse(band_out, band_target)) 42 | self.log("train_psnr", psnr(band_out, band_target)) 43 | self.log("train_mpsnr", mpsnr(band_out.detach(), band_target)) 44 | band_y = y[:, ssl_idx] 45 | self.log("train_mse_y", mse(band_out, band_y)) 46 | self.log("train_psnr_y", psnr(band_out, band_y)) 47 | self.log("train_mpsnr_y", mpsnr(band_out.detach(), band_y)) 48 | 49 | loss = mse(band_out, band_target) 50 | 51 | else: 52 | out = self.forward(**batch) 53 | self.log("train_mse", mse(out, y)) 54 | self.log("train_psnr", psnr(out, y)) 55 | self.log("train_mpsnr", mpsnr(out.detach(), y)) 56 | 57 | loss = mse(out, y) 58 | 59 | self.manual_backward(loss) 60 | 
opt.step() 61 | 62 | sch = self.lr_schedulers() 63 | if self.trainer.is_last_batch: 64 | epoch = self.current_epoch 65 | lr = sch.get_last_lr() 66 | logger.info(f"Epoch {epoch} : lr={lr} \t loss={loss:.6f}") 67 | sch.step() 68 | 69 | def validation_step(self, batch, batch_idx): 70 | y = batch.pop("y") 71 | start = time.time() 72 | if self.ssl: 73 | bs, c, h, w = batch["x"].shape 74 | out = torch.zeros_like(batch["x"]) 75 | N = int(np.ceil(c / self.n_ssl)) 76 | for i in range(N): 77 | ssl_idx = self.get_ssl_idx(i, c).long() 78 | batch["ssl_idx"] = ssl_idx 79 | if self.block_inference and self.block_inference.use_bi: 80 | _out = self.forward_blocks(**batch) 81 | else: 82 | _out = self.forward(**batch) 83 | out[:, ssl_idx] = _out[:, ssl_idx] 84 | else: 85 | if self.block_inference and self.block_inference.use_bi: 86 | out = self.forward_blocks(**batch) 87 | else: 88 | out = self.forward(**batch) 89 | logger.debug(f"Val denoised shape: {out.shape}") 90 | out = out.clamp(0, 1) 91 | elapsed = time.time() - start 92 | _mse = mse(out, y) 93 | _mpsnr = mpsnr(out, y) 94 | logger.debug(f"Val mse : {_mse}, mpsnr: {_mpsnr}") 95 | self.log("val_mse", mse(out, y)) 96 | self.log("val_psnr", psnr(out, y)) 97 | self.log("val_mpsnr", mpsnr(out, y)) 98 | self.log("val_batch_time", elapsed) 99 | self.log("val_psnr_noise", psnr(batch["x"], y)) 100 | self.log("val_mpsnr_noise", mpsnr(batch["x"], y)) 101 | 102 | def get_ssl_idx(self, i, c): 103 | N = np.ceil(c / self.n_ssl) 104 | L = int(np.ceil((c - i) / N)) 105 | return i + N * torch.arange(L) 106 | 107 | def test_step(self, batch, batch_idx): 108 | y = batch.pop("y") 109 | if self.ssl: 110 | bs, c, h, w = batch["x"].shape 111 | out = torch.zeros_like(batch["x"]) 112 | N = int(np.ceil(c / self.n_ssl)) 113 | for i in range(N): 114 | ssl_idx = self.get_ssl_idx(i, c).long() 115 | batch["ssl_idx"] = ssl_idx 116 | if self.block_inference and self.block_inference.use_bi: 117 | _out = self.forward_blocks(**batch) 118 | else: 119 | _out = self.forward(**batch) 120 | out[:, ssl_idx] = _out[:, ssl_idx] 121 | else: 122 | if self.block_inference and self.block_inference.use_bi: 123 | out = self.forward_blocks(**batch) 124 | else: 125 | out = self.forward(**batch) 126 | logger.debug(f"Test denoised shape: {out.shape}") 127 | out = out.clamp(0, 1) 128 | self.log("test_mse", mse(out, y)) 129 | self.log("test_psnr", psnr(out, y)) 130 | self.log("test_mpsnr", mpsnr(out, y)) 131 | self.log("test_psnr_noise", psnr(batch["x"], y)) 132 | self.log("test_mpsnr_noise", mpsnr(batch["x"], y)) 133 | 134 | def forward_blocks(self, x, **kwargs): 135 | logger.debug(f"Starting block inference") 136 | block_size = min( 137 | max(x.shape[-1], x.shape[-2]), self.block_inference.block_size 138 | ) 139 | patches_handler = PatchesHandler( 140 | size=(block_size,) * 2, 141 | channels=x.shape[1], 142 | stride=block_size - self.block_inference.overlap, 143 | padding=self.block_inference.padding, 144 | ) 145 | 146 | logger.debug(f"Forward patches handler") 147 | blocks_in = patches_handler(x, mode="extract").clone() 148 | blocks_grid = tuple(blocks_in.shape[-2:]) 149 | logger.debug(f"blocks grid : {blocks_in.shape}") 150 | 151 | blocks_out = torch.zeros_like(blocks_in) 152 | 153 | logger.debug(f"Processing blocks {blocks_grid}") 154 | for i in range(blocks_grid[0]): 155 | for j in range(blocks_grid[1]): 156 | blocks_ij = self.forward(blocks_in[:, :, :, :, i, j], **kwargs) 157 | blocks_out[:, :, :, :, i, j] = blocks_ij 158 | x = patches_handler(blocks_out, mode="aggregate") 159 | logger.debug(f"Blocks 
aggregated to shape : {tuple(x.shape)}")
        return x

    def configure_optimizers(self):
        logger.debug("Configuring optimizer")
        optim_class = torch.optim.__dict__[self.optimizer.class_name]
        optimizer = optim_class(self.parameters(), **self.optimizer.params)

        if self.lr_scheduler is not None:
            scheduler_class = torch.optim.lr_scheduler.__dict__[
                self.lr_scheduler.class_name
            ]
            scheduler = scheduler_class(optimizer, **self.lr_scheduler.params)
            return [optimizer], [scheduler]
        # Without a scheduler, return the bare optimizer instead of
        # falling through and implicitly returning None.
        return optimizer

    def count_params(self):
        desc = "Model parameters:\n"
        counter = 0
        for name, param in self.named_parameters():
            if param.requires_grad:
                count = param.numel()
                desc += f"\t{name} : {count}\n"
                counter += count
        desc += f"Total number of learnable parameters : {counter}\n"
        logger.info(desc)
        return counter, desc
--------------------------------------------------------------------------------
/hsir/model/man/v2.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


""" Utility block """


class UpsampleConv3d(torch.nn.Module):
    """ UpsampleConvLayer
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=True, upsample=None, group=1):
        super(UpsampleConv3d, self).__init__()
        self.upsample_layer = nn.Upsample(scale_factor=upsample, mode='trilinear', align_corners=True)
        self.conv3d = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, bias=bias, groups=group)

    def forward(self, x):
        x = self.upsample_layer(x)
        x = self.conv3d(x)
        return x


class MLP(nn.Module):
    """
    Multilayer Perceptron (MLP)
    """

    def __init__(self, channel, bias=True):
        super().__init__()
        self.w_1 = nn.Conv3d(channel, channel, bias=bias, kernel_size=1)
        self.w_2 = nn.Conv3d(channel, channel, bias=bias, kernel_size=1)

    def forward(self, x):
        return self.w_2(torch.tanh(self.w_1(x)))  # torch.tanh: F.tanh is deprecated


""" The proposed blocks
"""


class PSCA(nn.Module):
    """ Progressive Spectral Channel Attention (PSCA)
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        self.w_1 = nn.Conv3d(d_model, d_ff, 1, bias=False)
        self.w_2 = nn.Conv3d(d_ff, d_model, 1, bias=False)
        self.w_3 = nn.Conv3d(d_model, d_model, 1, bias=False)

        nn.init.zeros_(self.w_3.weight)

    def forward(self, x):
        x = self.w_3(x) * x + x
        x = self.w_1(x)
        x = F.gelu(x)
        x = self.w_2(x)
        return x


class ASC(nn.Module):
    """ Attentive Skip Connection
    """

    def __init__(self, channel):
        super().__init__()
        self.weight = nn.Sequential(
            nn.Conv3d(channel * 2, channel, 1),
            nn.LeakyReLU(),
            nn.Conv3d(channel, channel, 3, 1, 1),
            nn.Sigmoid()
        )

    def forward(self, x, y):
        w = self.weight(torch.cat([x, y], dim=1))
        out = (1 - w) * x + w * y
        return out


class MHRSA(nn.Module):
    """ Multi-Head Recurrent Spectral Attention
    """

    def __init__(self, channels, multi_head=True, ffn=True):
        super().__init__()
        self.channels = channels
        self.multi_head = multi_head
        self.ffn = ffn

        if ffn:
            self.ffn1 = MLP(channels)
            self.ffn2 = MLP(channels)

    def
_conv_step(self, inputs): 96 | if self.ffn: 97 | Z = self.ffn1(inputs).tanh() 98 | F = self.ffn2(inputs).sigmoid() 99 | else: 100 | Z, F = inputs.split(split_size=self.channels, dim=1) 101 | Z, F = Z.tanh(), F.sigmoid() 102 | return Z, F 103 | 104 | def _rnn_step(self, z, f, h): 105 | h_ = (1 - f) * z if h is None else f * h + (1 - f) * z 106 | return h_ 107 | 108 | def forward(self, inputs, reverse=False): 109 | Z, F = self._conv_step(inputs) 110 | 111 | if self.multi_head: 112 | Z1, Z2 = Z.split(self.channels // 2, 1) 113 | Z2 = torch.flip(Z2, [2]) 114 | Z = torch.cat([Z1, Z2], dim=1) 115 | 116 | F1, F2 = F.split(self.channels // 2, 1) 117 | F2 = torch.flip(F2, [2]) 118 | F = torch.cat([F1, F2], dim=1) 119 | 120 | h = None 121 | h_time = [] 122 | 123 | if not reverse: 124 | for _, (z, f) in enumerate(zip(Z.split(1, 2), F.split(1, 2))): 125 | h = self._rnn_step(z, f, h) 126 | h_time.append(h) 127 | else: 128 | for _, (z, f) in enumerate((zip( 129 | reversed(Z.split(1, 2)), reversed(F.split(1, 2)) 130 | ))): # split along timestep 131 | h = self._rnn_step(z, f, h) 132 | h_time.insert(0, h) 133 | 134 | y = torch.cat(h_time, dim=2) 135 | 136 | if self.multi_head: 137 | y1, y2 = y.split(self.channels // 2, 1) 138 | y2 = torch.flip(y2, [2]) 139 | y = torch.cat([y1, y2], dim=1) 140 | 141 | return y 142 | 143 | 144 | class MAB(nn.Module): 145 | def __init__(self, conv_layer, channels, multi_head=True, ffn=True): 146 | super().__init__() 147 | self.conv = conv_layer 148 | self.inter_sa = MHRSA(channels, multi_head=multi_head, ffn=ffn) 149 | self.intra_sa = PSCA(channels, channels * 2) 150 | 151 | def forward(self, x, reverse=False): 152 | x = self.conv(x) 153 | x = self.inter_sa(x, reverse=reverse) 154 | x = self.intra_sa(x) 155 | return x 156 | 157 | 158 | """ Encoder-Decoder 159 | """ 160 | 161 | 162 | def PlainMAB(in_ch, out_ch, bias=False): 163 | return MAB(nn.Conv3d(in_ch, out_ch, 3, 1, 1, bias=bias), out_ch) 164 | 165 | 166 | def DownMAB(in_ch, out_ch, bias=False): 167 | return MAB(nn.Conv3d(in_ch, out_ch, 3, (1, 2, 2), 1, bias=bias), out_ch) 168 | 169 | 170 | def UpMAB(in_ch, out_ch, bias=False): 171 | return MAB(UpsampleConv3d(in_ch, out_ch, 3, 1, 1, bias=bias, upsample=(1, 2, 2)), out_ch) 172 | 173 | 174 | class Encoder(nn.Module): 175 | def __init__(self, channels, num_half_layer, sample_idx): 176 | super(Encoder, self).__init__() 177 | self.layers = nn.ModuleList() 178 | for i in range(num_half_layer): 179 | if i not in sample_idx: 180 | encoder_layer = PlainMAB(channels, channels) 181 | else: 182 | encoder_layer = DownMAB(channels, 2 * channels) 183 | channels *= 2 184 | self.layers.append(encoder_layer) 185 | 186 | def forward(self, x, xs, reverse=False): 187 | num_half_layer = len(self.layers) 188 | for i in range(num_half_layer - 1): 189 | x = self.layers[i](x, reverse) 190 | reverse = not reverse 191 | xs.append(x) 192 | x = self.layers[-1](x, reverse) 193 | reverse = not reverse 194 | return x 195 | 196 | 197 | class Decoder(nn.Module): 198 | def __init__(self, channels, num_half_layer, sample_idx, Fusion=None): 199 | super(Decoder, self).__init__() 200 | self.layers = nn.ModuleList() 201 | self.enable_fusion = Fusion is not None 202 | 203 | if self.enable_fusion: 204 | self.fusions = nn.ModuleList() 205 | ch = channels 206 | for i in reversed(range(num_half_layer)): 207 | fusion_layer = Fusion(ch) 208 | if i in sample_idx: 209 | ch //= 2 210 | self.fusions.append(fusion_layer) 211 | 212 | for i in reversed(range(num_half_layer)): 213 | if i not in sample_idx: 214 | decoder_layer 
= PlainMAB(channels, channels) 215 | else: 216 | decoder_layer = UpMAB(channels, channels // 2) 217 | channels //= 2 218 | self.layers.append(decoder_layer) 219 | 220 | def forward(self, x, xs, reverse=False): 221 | num_half_layer = len(self.layers) 222 | x = self.layers[0](x, reverse) 223 | reverse = not reverse 224 | for i in range(1, num_half_layer): 225 | if self.enable_fusion: 226 | x = self.fusions[i](x, xs.pop()) 227 | else: 228 | x = x + xs.pop() 229 | x = self.layers[i](x, reverse) 230 | reverse = not reverse 231 | return x 232 | 233 | 234 | class MAN(nn.Module): 235 | def __init__(self, in_channels, channels, num_half_layer, sample_idx, Fusion=None): 236 | super().__init__() 237 | self.head = PlainMAB(in_channels, channels) 238 | self.encoder = Encoder(channels, num_half_layer, sample_idx) 239 | self.decoder = Decoder(channels * (2**len(sample_idx)), num_half_layer, sample_idx, Fusion=Fusion) 240 | self.tail = nn.Conv3d(channels, in_channels, 3, 1, 1, bias=True) 241 | 242 | def forward(self, x): 243 | xs = [x] 244 | out = self.head(xs[0]) 245 | xs.append(out) 246 | reverse = True 247 | out = self.encoder(out, xs, reverse) 248 | out = self.decoder(out, xs, reverse) 249 | out = out + xs.pop() 250 | out = self.tail(out) 251 | out = out + xs.pop() 252 | return out 253 | 254 | 255 | class MAN_T(MAN): 256 | def __init__(self, in_channels, channels, num_half_layer, sample_idx, Fusion=None): 257 | super().__init__(in_channels, channels, num_half_layer, sample_idx, Fusion) 258 | self.tail = PlainMAB(channels, in_channels, bias=True) 259 | 260 | 261 | """ Models 262 | """ 263 | 264 | 265 | def man_s(): 266 | net = MAN(1, 12, 5, [1, 3], Fusion=ASC) 267 | net.use_2dconv = False 268 | net.bandwise = False 269 | return net 270 | 271 | 272 | def man_m(): 273 | net = MAN(1, 16, 5, [1, 3], Fusion=ASC) 274 | net.use_2dconv = False 275 | net.bandwise = False 276 | return net 277 | 278 | 279 | def man_l(): 280 | net = MAN(1, 20, 5, [1, 3], Fusion=ASC) 281 | net.use_2dconv = False 282 | net.bandwise = False 283 | return net 284 | 285 | -------------------------------------------------------------------------------- /hsir/model/man/v1.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | """ Skip Connection """ 8 | 9 | 10 | class AdditiveConnection(nn.Module): 11 | def __init__(self, channel): 12 | super().__init__() 13 | 14 | def forward(self, x, y): 15 | return x + y 16 | 17 | 18 | class ConcatSkipConnection(nn.Module): 19 | def __init__(self, channel): 20 | super().__init__() 21 | self.conv = nn.Sequential( 22 | nn.Conv3d(channel * 2, channel, 3, 1, 1), 23 | ) 24 | 25 | def forward(self, x, y): 26 | return self.conv(torch.cat([x, y], dim=1)) 27 | 28 | 29 | class AdaptiveSkipConnection(nn.Module): 30 | def __init__(self, channel): 31 | super().__init__() 32 | self.conv_weight = nn.Sequential( 33 | nn.Conv3d(channel * 2, channel, 1), 34 | nn.Tanh(), 35 | nn.Conv3d(channel, channel, 3, 1, 1), 36 | nn.Sigmoid() 37 | ) 38 | 39 | def forward(self, x, y): 40 | w = self.conv_weight(torch.cat([x, y], dim=1)) 41 | return (1 - w) * x + w * y 42 | 43 | 44 | """ Channel Attention """ 45 | 46 | 47 | class GlobalAvgPool3d(nn.Module): 48 | def __init__(self): 49 | super().__init__() 50 | 51 | def forward(self, x): 52 | sum = torch.sum(x, dim=(2, 3, 4), keepdim=True) 53 | return sum / (x.shape[2] * x.shape[3] * x.shape[4]) 54 | 55 | 56 | class 
SimplifiedChannelAttention(nn.Module):
    def __init__(self, ch):
        super().__init__()
        self.sca = nn.Sequential(
            # nn.AdaptiveAvgPool3d(1),
            GlobalAvgPool3d(),
            nn.Conv3d(ch, ch, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.sca(x) * x


class ChannelAttention(nn.Module):
    def __init__(self, ch):
        super().__init__()
        self.ca = nn.Sequential(
            GlobalAvgPool3d(),
            nn.Conv3d(ch, ch, kernel_size=1),
            nn.ReLU(),
            nn.Conv3d(ch, ch, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.ca(x) * x


""" Mixed Attention Block """


class PositionwiseFeedForward(nn.Module):
    "Implements FFN equation."

    def __init__(self, channel):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Conv3d(channel, channel, kernel_size=1)
        self.w_2 = nn.Conv3d(channel, channel, kernel_size=1)

    def forward(self, x):
        return self.w_2(torch.tanh(self.w_1(x)))


class MAB(nn.Module):
    def __init__(self, channels, enable_ca=True, reverse=False):
        super(MAB, self).__init__()
        self.channels = channels
        self.enable_ca = enable_ca
        self.reverse = reverse
        self.sca = SimplifiedChannelAttention(channels)
        self.ffn_f = PositionwiseFeedForward(channels)
        self.ffn_w = PositionwiseFeedForward(channels)

    def _rnn_step(self, z, f, h):
        h_ = (1 - f) * z if h is None else f * h + (1 - f) * z
        return h_

    def forward(self, inputs):
        h = None
        Z = self.ffn_f(inputs).tanh()
        F = self.ffn_w(inputs).sigmoid()
        h_time = []

        if not self.reverse:
            for time, (z, f) in enumerate(zip(Z.split(1, 2), F.split(1, 2))):
                h = self._rnn_step(z, f, h)
                h_time.append(h)
        else:
            for time, (z, f) in enumerate((zip(
                reversed(Z.split(1, 2)), reversed(F.split(1, 2))
            ))):
                h = self._rnn_step(z, f, h)
                h_time.insert(0, h)

        out = torch.cat(h_time, dim=2)
        if self.enable_ca:
            out = self.sca(out)
        return out


class BiMAB(MAB):
    def __init__(self, channels, enable_ca=True):
        super().__init__(channels, enable_ca, None)
        self.ffn_w2 = PositionwiseFeedForward(channels)

    def forward(self, inputs):
        Z = self.ffn_f(inputs).tanh()
        F1 = self.ffn_w(inputs).sigmoid()
        F2 = self.ffn_w2(inputs).sigmoid()

        h = None
        hsl = []
        hsr = []
        zs = Z.split(1, 2)

        for time, (z, f) in enumerate(zip(zs, F1.split(1, 2))):
            h = self._rnn_step(z, f, h)
            hsl.append(h)

        h = None
        for time, (z, f) in enumerate((zip(reversed(zs), reversed(F2.split(1, 2))))):
            h = self._rnn_step(z, f, h)
            hsr.insert(0, h)

        hsl = torch.cat(hsl, dim=2)
        hsr = torch.cat(hsr, dim=2)

        out = hsl + hsr
        if self.enable_ca:
            out = self.sca(out)
        return out
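# Illustrative note (added for exposition; `m` and `x` below are hypothetical): MAB
# scans the band axis (dim 2) with a QRNN-style gate, h_t = f_t * h_{t-1} + (1 - f_t) * z_t,
# so each band's features become a gated running average of their neighbours, e.g.:
#
#   m = MAB(channels=16)                  # unidirectional scan over the bands
#   x = torch.randn(2, 16, 31, 8, 8)      # (batch, channel, band, H, W)
#   assert m(x).shape == x.shape          # BiMAB additionally sums a reversed scan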
""" Mixed Attention Network """


class Encoder(nn.Module):
    def __init__(self, channels, num_half_layer, sample_idx, Attn=None):
        super(Encoder, self).__init__()
        self.layers = nn.ModuleList()
        for i in range(num_half_layer):
            if i not in sample_idx:
                # Filter the placeholder entry so the layer still builds when Attn is None.
                encoder_layer = nn.Sequential(OrderedDict([kv for kv in [
                    ('conv', nn.Conv3d(channels, channels, 3, 1, 1, bias=False)),
                    ('attn', Attn(channels, reverse=i % 2 == 1)) if Attn else None,
                ] if kv is not None]))
            else:
                encoder_layer = nn.Sequential(OrderedDict([kv for kv in [
                    ('conv', nn.Conv3d(channels, channels * 2, 3, (1, 2, 2), 1, bias=False)),
                    ('attn', Attn(channels * 2, reverse=i % 2 == 1)) if Attn else None,
                ] if kv is not None]))
                channels *= 2
            self.layers.append(encoder_layer)

    def forward(self, x, xs):
        num_half_layer = len(self.layers)
        for i in range(num_half_layer - 1):
            x = self.layers[i](x)
            xs.append(x)
        x = self.layers[-1](x)
        return x


class Decoder(nn.Module):
    def __init__(self, channels, num_half_layer, sample_idx, Fusion=None, Attn=None):
        super(Decoder, self).__init__()

        self.layers = nn.ModuleList()
        self.enable_fusion = Fusion is not None

        if self.enable_fusion:
            self.fusions = nn.ModuleList()
            ch = channels
            for i in reversed(range(num_half_layer)):
                fusion_layer = Fusion(ch)
                if i in sample_idx:
                    ch //= 2
                self.fusions.append(fusion_layer)

        for i in reversed(range(num_half_layer)):
            if i not in sample_idx:
                decoder_layer = nn.Sequential(OrderedDict([kv for kv in [
                    ('conv', nn.ConvTranspose3d(channels, channels, 3, 1, 1, bias=False)),
                    ('attn', Attn(channels, reverse=i % 2 == 0)) if Attn else None,
                ] if kv is not None]))
            else:
                decoder_layer = nn.Sequential(OrderedDict([kv for kv in [
                    ('up', nn.Upsample(scale_factor=(1, 2, 2), mode='trilinear', align_corners=True)),
                    ('conv', nn.Conv3d(channels, channels // 2, 3, 1, 1, bias=False)),
                    ('attn', Attn(channels // 2, reverse=i % 2 == 0)) if Attn else None,
                ] if kv is not None]))
                channels //= 2
            self.layers.append(decoder_layer)

    def forward(self, x, xs):
        num_half_layer = len(self.layers)
        x = self.layers[0](x)
        for i in range(1, num_half_layer):
            if self.enable_fusion:
                x = self.fusions[i](x, xs.pop())
            else:
                x = x + xs.pop()
            x = self.layers[i](x)
        return x


class MAN(nn.Module):
    def __init__(
        self,
        in_channels=1,
        channels=16,
        num_half_layer=5,
        sample_idx=[1, 3],
        Attn=MAB,
        BiAttn=BiMAB,
        Fusion=AdaptiveSkipConnection,
    ):
        super(MAN, self).__init__()

        # Drop the None placeholder so the baseline (BiAttn=None) still builds.
        self.head = nn.Sequential(*[m for m in [
            nn.Conv3d(in_channels, channels, 3, 1, 1, bias=False),
            BiAttn(channels) if BiAttn else None,
        ] if m is not None])

        self.encoder = Encoder(channels, num_half_layer, sample_idx, Attn)
        self.decoder = Decoder(channels * (2**len(sample_idx)), num_half_layer, sample_idx,
                               Fusion=Fusion, Attn=Attn)

        self.tail = nn.Sequential(*[m for m in [
            nn.ConvTranspose3d(channels, in_channels, 3, 1, 1, bias=True),
            BiAttn(in_channels) if BiAttn else None,
        ] if m is not None])

    def forward(self, x):
        xs = [x]
        out = self.head(xs[0])
        xs.append(out)
        out = self.encoder(out, xs)
        out = self.decoder(out, xs)
        out = out + xs.pop()
        out = self.tail(out)
        out = out + xs.pop()
        return out


""" Model Variants """


def man():
    net = MAN(1, 16, 5, [1, 3])
    net.use_2dconv = False
    net.bandwise = False
    return net


def man_m():
    net = MAN(1, 12, 5, [1, 3])
    net.use_2dconv = False
    net.bandwise = False
    return net


def man_s():
    net = MAN(1, 8, 5, [1, 3])
    net.use_2dconv = False
    net.bandwise = False
    return net
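# Illustrative note (added for exposition): in these variants the second argument is
# the base channel width (8/12/16/24) and `sample_idx=[1, 3]` marks the encoder
# layers that downsample, so inputs should have H and W divisible by
# 2 ** len(sample_idx) = 4.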
def man_b():
    net = MAN(1, 24, 5, [1, 3])
    net.use_2dconv = False
    net.bandwise = False
    return net


def man_deep():
    net = MAN(1, 16, 7, [1, 3, 5])
    net.use_2dconv = False
    net.bandwise = False
    return net


""" Baseline """


def baseline():
    net = MAN(1, 16, 5, [1, 3], Attn=None, BiAttn=None, Fusion=None)
    net.use_2dconv = False
    net.bandwise = False
    return net
--------------------------------------------------------------------------------
/hsir/model/hsdt/attention.py:
--------------------------------------------------------------------------------
import random

import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange


class SSA(nn.Module):
    """ Spectral Self Attention (SSA) in "Hybrid Spectral Denoising Transformer with Learnable Query"
    GSSA without guidance.
    """
    def __init__(self, channel, num_bands, flex=False):
        # `flex` is kept for signature parity; SSA always uses data-driven attention.
        super().__init__()
        self.channel = channel
        self.num_bands = num_bands

        self.value_proj = nn.Linear(channel, channel, bias=False)
        self.fc = nn.Linear(channel, channel, bias=False)

    def forward(self, x):
        B, C, D, H, W = x.shape

        residual = x

        tmp = x.reshape(B, C, D, H * W).mean(-1).permute(0, 2, 1)
        attn = tmp @ tmp.transpose(1, 2)

        attn = attn.reshape(B, self.num_bands, self.num_bands)
        attn = F.softmax(attn, dim=-1)  # B, band, band
        attn = attn.unsqueeze(1).unsqueeze(1)

        v = self.value_proj(rearrange(x, 'b c d h w -> b h w d c'))

        q = torch.matmul(attn, v)

        q = self.fc(q)
        q = rearrange(q, 'b h w d c -> b c d h w')

        q += residual

        return q, attn


class GSSA(nn.Module):
    """ Guided Spectral Self Attention (GSSA) in "Hybrid Spectral Denoising Transformer with Learnable Query"
    """

    def __init__(self, channel, num_bands, flex=False):
        super().__init__()
        self.channel = channel
        self.num_bands = num_bands
        self.flex = flex

        # learnable query
        self.attn_proj = nn.Linear(channel, num_bands)
        self.value_proj = nn.Linear(channel, channel, bias=False)
        self.fc = nn.Linear(channel, channel, bias=False)

    def forward(self, x):
        B, C, D, H, W = x.shape

        residual = x

        tmp = x.reshape(B, C, D, H * W).mean(-1).permute(0, 2, 1)

        if self.training:
            if random.random() > 0.5:
                attn = tmp @ tmp.transpose(1, 2)
            else:
                attn = self.attn_proj(tmp)
        else:
            if self.flex:
                attn = tmp @ tmp.transpose(1, 2)
            else:
                attn = self.attn_proj(tmp)

        attn = attn.reshape(B, self.num_bands, self.num_bands)
        attn = F.softmax(attn, dim=-1)  # B, band, band
        attn = attn.unsqueeze(1).unsqueeze(1)

        v = self.value_proj(rearrange(x, 'b c d h w -> b h w d c'))

        q = torch.matmul(attn, v)

        q = self.fc(q)
        q = rearrange(q, 'b h w d c -> b c d h w')

        q += residual

        return q, attn
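# Illustrative note (added for exposition; `g` and `x` below are hypothetical): during
# training GSSA randomly alternates between data-driven attention (tmp @ tmp^T) and
# the learnable-query path attn_proj(tmp); at eval time `flex=True` selects the former.
#
#   g = GSSA(channel=16, num_bands=31)
#   x = torch.randn(1, 16, 31, 8, 8)      # (batch, channel, band, H, W)
#   q, attn = g(x)                        # q: same shape as x; attn: (1, 1, 1, 31, 31)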
class PixelwiseGSSA(GSSA):
    """ Pixelwise GSSA
    """
    def __init__(self, channel, num_bands, flex=False):
        # Accept and forward `flex` so PixelwiseTransformerBlock(..., flex=...) works.
        super().__init__(channel, num_bands, flex)

    def forward(self, x):
        B, C, D, H, W = x.shape
        x = rearrange(x, 'b c d h w -> b (h w) d c')

        residual = x

        if self.training:
            if random.random() > 0.5:
                attn = x @ x.transpose(-1, -2)
            else:
                attn = self.attn_proj(x)
        else:
            attn = x @ x.transpose(-1, -2)  # eval always uses data-driven attention here

        attn = F.softmax(attn, dim=-1)

        v = self.value_proj(x)
        q = torch.matmul(attn, v)

        q = self.fc(q)
        q += residual

        q = rearrange(q, 'b (h w) d c -> b c d h w', d=D, h=H, w=W)

        return q, attn


class GSSAConvImpl(GSSA):
    """ GSSA fast convolutional implementation
    """
    def __init__(self, channel, num_bands, flex=False):
        super().__init__(channel, num_bands, flex)

    def forward(self, x):
        B, C, D, H, W = x.shape

        x = rearrange(x, 'b c d h w -> b h w d c')
        residual = x

        tmp = rearrange(x, 'b h w d c -> b (h w) d c').mean(1)
        attn = self.attn_proj(tmp)  # B,band,band
        attn = F.softmax(attn, dim=-1)  # b,band,band

        v = self.value_proj(x)  # b c band w h

        if B > 1:
            input = rearrange(v, 'b h w d c -> (b d) c h w').unsqueeze(0)
            weight = attn.reshape(B * D, D, 1, 1, 1)
            q = F.conv3d(input, weight, groups=B)  # 1, b*d, c, w, h
            q = rearrange(q.squeeze(0), '(b d) c h w -> b h w d c', b=B)
        else:
            input = rearrange(v, 'b h w d c -> c (b d) h w')
            weight = attn.reshape(D, D, 1, 1)
            q = F.conv2d(input, weight)  # 1, b*d, c, w, h
            q = rearrange(q, 'c (b d) w h -> b h w d c', b=B)

        q = self.fc(q)
        q += residual

        q = rearrange(q, 'b h w d c -> b c d h w')
        return q, attn


""" Feedforward
"""


class SMFFN(nn.Module):
    """ Self Modulated Feed Forward Network (SM-FFN) in "Hybrid Spectral Denoising Transformer with Learnable Query"
    """

    def __init__(self, d_model, d_ff, bias=False):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff, bias=bias)
        self.w_2 = nn.Linear(d_ff, d_model, bias=bias)
        self.w_3 = nn.Linear(d_model, d_ff, bias=bias)

    def forward(self, input):
        x = self.w_1(input)
        x = F.gelu(x)
        x1 = self.w_2(x)

        x = self.w_3(input)
        x, w = torch.chunk(x, 2, dim=-1)
        x2 = x * torch.sigmoid(w)

        return x1 + x2
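# Illustrative note (added for exposition): SM-FFN sums a plain two-layer branch with
# a self-modulated branch,
#   out = W2 GELU(W1 x) + a * sigmoid(b),  where [a, b] = chunk(W3 x, 2, dim=-1),
# so the sum x1 + x2 only type-checks when d_ff == 2 * d_model, which is how the
# transformer blocks below instantiate it (SMFFN(channels, channels * 2)).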
class SMFFNBranch(nn.Module):
    def __init__(self, d_model, d_ff, bias=False):
        super().__init__()
        self.w_3 = nn.Linear(d_model, d_ff, bias=bias)

    def forward(self, input):
        x = self.w_3(input)
        x, w = torch.chunk(x, 2, dim=-1)
        x2 = x * torch.sigmoid(w)
        return x2


class GDFN(nn.Module):
    """ 3D version of GDFN from Restormer.
    """
    def __init__(self, d_model, d_ff, bias=False):
        super(GDFN, self).__init__()
        self.project_in = nn.Conv3d(d_model, d_ff, kernel_size=1, bias=bias)
        self.dwconv = nn.Conv3d(d_ff, d_ff, kernel_size=3, stride=1, padding=1, groups=d_ff, bias=bias)
        # After the gated chunk the width is d_ff // 2, so project_out expects d_ff == 2 * d_model.
        self.project_out = nn.Conv3d(d_model, d_model, kernel_size=1, bias=bias)

    def forward(self, input):
        input = rearrange(input, 'b d h w c -> b c d h w')
        x = self.project_in(input)
        x1, x2 = self.dwconv(x).chunk(2, dim=1)
        x = F.gelu(x1) * x2
        x = self.project_out(x)
        output = x
        output = rearrange(output, 'b c d h w -> b d h w c')
        return output


class FFN(nn.Module):
    def __init__(self, d_model, d_ff, bias=False):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff, bias=bias)
        self.w_2 = nn.Linear(d_ff, d_model, bias=bias)
        self.act = nn.GELU()

    def forward(self, input):
        output = self.w_2(self.act(self.w_1(input)))
        return output


""" Transformer Block
"""


class TransformerBlock(nn.Module):
    def __init__(self, channels, num_bands=31, bias=False, flex=False):
        super().__init__()
        self.channels = channels
        self.attn = GSSA(channels, num_bands, flex=flex)
        self.ffn = SMFFN(channels, channels * 2, bias=bias)

    def forward(self, inputs):
        r, _ = self.attn(inputs)
        r = rearrange(r, 'b c d h w -> b d h w c')
        r = self.ffn(r)
        r = rearrange(r, 'b d h w c -> b c d h w')
        return r
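# Illustrative usage sketch (added for exposition; `blk` and `x` are hypothetical):
#
#   blk = TransformerBlock(channels=16, num_bands=31)
#   x = torch.randn(1, 16, 31, 8, 8)      # the band count must match num_bands
#   y = blk(x)                            # GSSA (residual inside) followed by SM-FFN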
class DummyTransformerBlock(nn.Module):
    def __init__(self, channels, num_bands=31, bias=False, flex=False):
        super().__init__()
        self.channels = channels

    def forward(self, inputs):
        return inputs.tanh()


""" Ablation
"""


class PixelwiseTransformerBlock(TransformerBlock):
    def __init__(self, channels, num_bands=31, bias=False, flex=False):
        super().__init__(channels, num_bands, bias, flex)
        self.channels = channels
        self.attn = PixelwiseGSSA(channels, num_bands, flex=flex)
        self.ffn = SMFFN(channels, channels * 2, bias=bias)


class FFNTransformerBlock(TransformerBlock):
    def __init__(self, channels, num_bands=31, bias=False, flex=False):
        super().__init__(channels, num_bands, bias, flex)
        self.channels = channels
        self.attn = GSSA(channels, num_bands, flex=flex)
        self.ffn = FFN(channels, channels * 2, bias=bias)


class GDFNTransformerBlock(TransformerBlock):
    def __init__(self, channels, num_bands=31, bias=False, flex=False):
        super().__init__(channels, num_bands, bias, flex)
        self.channels = channels
        self.attn = GSSA(channels, num_bands, flex=flex)
        self.ffn = GDFN(channels, channels * 2, bias=bias)


class GFNTransformerBlock(TransformerBlock):
    def __init__(self, channels, num_bands=31, bias=False, flex=False):
        super().__init__(channels, num_bands, bias, flex)
        self.channels = channels
        self.attn = GSSA(channels, num_bands, flex=flex)
        self.ffn = SMFFNBranch(channels, channels * 2, bias=bias)


class SSATransformerBlock(TransformerBlock):
    def __init__(self, channels, num_bands=31, bias=False, flex=False):
        super().__init__(channels, num_bands, bias, flex)
        self.channels = channels
        self.attn = SSA(channels, num_bands, flex=flex)
        self.ffn = SMFFN(channels, channels * 2, bias=bias)


class SSAFFNTransformerBlock(TransformerBlock):
    def __init__(self, channels, num_bands=31, bias=False, flex=False):
        super().__init__(channels, num_bands, bias, flex)
        self.channels = channels
        self.attn = SSA(channels, num_bands, flex=flex)
        self.ffn = FFN(channels, channels * 2, bias=bias)
--------------------------------------------------------------------------------
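A minimal end-to-end sketch (not part of the repository) of how the pieces above fit together, assuming the import path shown in the tree and the `PlainConv` stem defined earlier in `arch.py`; exact exported names in `hsir.model.hsdt.__init__` may differ:

```python
import torch
from hsir.model.hsdt import arch  # assumed import path from the tree above

# 31-band HSI denoising: bands live on the depth axis of a 5D tensor, and
# sample_idx=[1, 3] downsamples H and W twice, so both must be divisible by 4.
net = arch.HSDT(in_channels=1, channels=16, num_half_layer=5, sample_idx=[1, 3])
x = torch.randn(1, 1, 31, 64, 64)  # (batch, channel, band, H, W)
with torch.no_grad():
    y = net(x)
assert y.shape == x.shape  # residual denoising keeps the input shape
```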